diff --git a/aide/backend/backend_openai.py b/aide/backend/backend_openai.py
index e751e5e..db76938 100644
--- a/aide/backend/backend_openai.py
+++ b/aide/backend/backend_openai.py
@@ -19,6 +19,23 @@
     openai.InternalServerError,
 )
 
+# (docs) https://platform.openai.com/docs/guides/function-calling/supported-models
+SUPPORTED_FUNCTION_CALL_MODELS = {
+    "gpt-4o",
+    "gpt-4o-2024-08-06",
+    "gpt-4o-2024-05-13",
+    "gpt-4o-mini",
+    "gpt-4o-mini-2024-07-18",
+    "gpt-4-turbo",
+    "gpt-4-turbo-2024-04-09",
+    "gpt-4-turbo-preview",
+    "gpt-4-0125-preview",
+    "gpt-4-1106-preview",
+    "gpt-3.5-turbo",
+    "gpt-3.5-turbo-0125",
+    "gpt-3.5-turbo-1106",
+}
+
 
 @once
 def _setup_openai_client():
@@ -26,21 +43,41 @@ def _setup_openai_client():
     _client = openai.OpenAI(max_retries=0)
 
 
+def is_function_call_supported(model_name: str) -> bool:
+    """Return True if the model supports function calling."""
+    return model_name in SUPPORTED_FUNCTION_CALL_MODELS
+
+
 def query(
     system_message: str | None,
     user_message: str | None,
     func_spec: FunctionSpec | None = None,
     **model_kwargs,
 ) -> tuple[OutputType, float, int, int, dict]:
+    """
+    Query the OpenAI API, optionally with function calling.
+    Function calling support is only checked for feedback/review operations.
+    """
     _setup_openai_client()
-    filtered_kwargs: dict = select_values(notnone, model_kwargs)  # type: ignore
+    filtered_kwargs: dict = select_values(notnone, model_kwargs)
+    model_name = filtered_kwargs.get("model", "")
+    logger.debug(f"OpenAI query called with model='{model_name}'")
 
     messages = opt_messages_to_list(system_message, user_message)
 
     if func_spec is not None:
-        filtered_kwargs["tools"] = [func_spec.as_openai_tool_dict]
-        # force the model the use the function
-        filtered_kwargs["tool_choice"] = func_spec.openai_tool_choice_dict
+        # Only check function call support for feedback/review operations
+        if func_spec.name == "submit_review":
+            if not is_function_call_supported(model_name):
+                logger.warning(
+                    f"Review function calling was requested, but model '{model_name}' "
+                    "does not support function calling. Falling back to plain text generation."
+                )
+                filtered_kwargs.pop("tools", None)
+                filtered_kwargs.pop("tool_choice", None)
+            else:
+                filtered_kwargs["tools"] = [func_spec.as_openai_tool_dict]
+                filtered_kwargs["tool_choice"] = func_spec.openai_tool_choice_dict
 
     t0 = time.time()
     completion = backoff_create(
@@ -53,22 +90,30 @@ def query(
 
     choice = completion.choices[0]
 
-    if func_spec is None:
+    if func_spec is None or "tools" not in filtered_kwargs:
         output = choice.message.content
     else:
-        assert (
-            choice.message.tool_calls
-        ), f"function_call is empty, it is not a function call: {choice.message}"
-        assert (
-            choice.message.tool_calls[0].function.name == func_spec.name
-        ), "Function name mismatch"
-        try:
-            output = json.loads(choice.message.tool_calls[0].function.arguments)
-        except json.JSONDecodeError as e:
-            logger.error(
-                f"Error decoding the function arguments: {choice.message.tool_calls[0].function.arguments}"
+        tool_calls = getattr(choice.message, "tool_calls", None)
+
+        if not tool_calls:
+            logger.warning(
+                f"No function call used despite function spec. Fallback to text. "
+                f"Message content: {choice.message.content}"
+            )
+            output = choice.message.content
+        else:
+            first_call = tool_calls[0]
+            assert first_call.function.name == func_spec.name, (
+                f"Function name mismatch: expected {func_spec.name}, "
+                f"got {first_call.function.name}"
             )
-            raise e
+            try:
+                output = json.loads(first_call.function.arguments)
+            except json.JSONDecodeError as e:
+                logger.error(
+                    f"Error decoding function arguments:\n{first_call.function.arguments}"
+                )
+                raise e
 
     in_tokens = completion.usage.prompt_tokens
     out_tokens = completion.usage.completion_tokens