Skip to content

Commit

Permalink
🐛 handle missing function calls for openai (#35) (#38)
Browse files Browse the repository at this point in the history
* 🐛 handle missing function calls for openai (#35)

* fix: black format
  • Loading branch information
dexhunter authored Jan 9, 2025
1 parent c92f194 commit 9c55a42
Showing 1 changed file with 62 additions and 17 deletions.
79 changes: 62 additions & 17 deletions aide/backend/backend_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,28 +19,65 @@
openai.InternalServerError,
)

# (docs) https://platform.openai.com/docs/guides/function-calling/supported-models
SUPPORTED_FUNCTION_CALL_MODELS = {
"gpt-4o",
"gpt-4o-2024-08-06",
"gpt-4o-2024-05-13",
"gpt-4o-mini",
"gpt-4o-mini-2024-07-18",
"gpt-4-turbo",
"gpt-4-turbo-2024-04-09",
"gpt-4-turbo-preview",
"gpt-4-0125-preview",
"gpt-4-1106-preview",
"gpt-3.5-turbo",
"gpt-3.5-turbo-0125",
"gpt-3.5-turbo-1106",
}


@once
def _setup_openai_client():
global _client
_client = openai.OpenAI(max_retries=0)


def is_function_call_supported(model_name: str) -> bool:
"""Return True if the model supports function calling."""
return model_name in SUPPORTED_FUNCTION_CALL_MODELS


def query(
system_message: str | None,
user_message: str | None,
func_spec: FunctionSpec | None = None,
**model_kwargs,
) -> tuple[OutputType, float, int, int, dict]:
"""
Query the OpenAI API, optionally with function calling.
Function calling support is only checked for feedback/review operations.
"""
_setup_openai_client()
filtered_kwargs: dict = select_values(notnone, model_kwargs) # type: ignore
filtered_kwargs: dict = select_values(notnone, model_kwargs)
model_name = filtered_kwargs.get("model", "")
logger.debug(f"OpenAI query called with model='{model_name}'")

messages = opt_messages_to_list(system_message, user_message)

if func_spec is not None:
filtered_kwargs["tools"] = [func_spec.as_openai_tool_dict]
# force the model the use the function
filtered_kwargs["tool_choice"] = func_spec.openai_tool_choice_dict
# Only check function call support for feedback/search operations
if func_spec.name == "submit_review":
if not is_function_call_supported(model_name):
logger.warning(
f"Review function calling was requested, but model '{model_name}' "
"does not support function calling. Falling back to plain text generation."
)
filtered_kwargs.pop("tools", None)
filtered_kwargs.pop("tool_choice", None)
else:
filtered_kwargs["tools"] = [func_spec.as_openai_tool_dict]
filtered_kwargs["tool_choice"] = func_spec.openai_tool_choice_dict

t0 = time.time()
completion = backoff_create(
Expand All @@ -53,22 +90,30 @@ def query(

choice = completion.choices[0]

if func_spec is None:
if func_spec is None or "tools" not in filtered_kwargs:
output = choice.message.content
else:
assert (
choice.message.tool_calls
), f"function_call is empty, it is not a function call: {choice.message}"
assert (
choice.message.tool_calls[0].function.name == func_spec.name
), "Function name mismatch"
try:
output = json.loads(choice.message.tool_calls[0].function.arguments)
except json.JSONDecodeError as e:
logger.error(
f"Error decoding the function arguments: {choice.message.tool_calls[0].function.arguments}"
tool_calls = getattr(choice.message, "tool_calls", None)

if not tool_calls:
logger.warning(
f"No function call used despite function spec. Fallback to text. "
f"Message content: {choice.message.content}"
)
output = choice.message.content
else:
first_call = tool_calls[0]
assert first_call.function.name == func_spec.name, (
f"Function name mismatch: expected {func_spec.name}, "
f"got {first_call.function.name}"
)
raise e
try:
output = json.loads(first_call.function.arguments)
except json.JSONDecodeError as e:
logger.error(
f"Error decoding function arguments:\n{first_call.function.arguments}"
)
raise e

in_tokens = completion.usage.prompt_tokens
out_tokens = completion.usage.completion_tokens
Expand Down

0 comments on commit 9c55a42

Please sign in to comment.