diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 229cf2ccd9..7319f13baa 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,6 +27,17 @@ repos: - id: trailing-whitespace - id: end-of-file-fixer - id: no-commit-to-branch + - repo: local + hooks: + - id: build-setup-scripts + name: build setup scripts + entry: "scripts/pre-commit-build-setup-files.sh" + language: python + # language_version: python3.9 + types: [python] + require_serial: true + verbose: true + additional_dependencies: ['jinja2', 'toml', 'ruff'] - repo: local hooks: - id: lint diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index f92e27eca1..2c99f9a1a7 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -164,8 +164,6 @@ def __init__( code_execution_config.copy() if hasattr(code_execution_config, "copy") else code_execution_config ) - self._validate_name(name) - self._name = name # a dictionary of conversations, default value is list if chat_messages is None: self._oai_messages = defaultdict(list) @@ -191,6 +189,8 @@ def __init__( ) from e self._validate_llm_config(llm_config) + self._validate_name(name) + self._name = name if logging_enabled(): log_new_agent(self, locals()) @@ -286,6 +286,15 @@ def __init__( } def _validate_name(self, name: str) -> None: + if not self.llm_config or "config_list" not in self.llm_config or len(self.llm_config["config_list"]) == 0: + return + + config_list = self.llm_config.get("config_list") + # The validation is currently done only for openai endpoints + # (other ones do not have the issue with whitespace in the name) + if "api_type" in config_list[0] and config_list[0]["api_type"] != "openai": + return + # Validation for name using regex to detect any whitespace if re.search(r"\s", name): raise ValueError(f"The name of the agent cannot contain any whitespace. 
The name provided is: '{name}'") diff --git a/autogen/coding/func_with_reqs.py b/autogen/coding/func_with_reqs.py index 5fc373cb90..8d2c27c8ba 100644 --- a/autogen/coding/func_with_reqs.py +++ b/autogen/coding/func_with_reqs.py @@ -4,7 +4,6 @@ # # Portions derived from https://github.com/microsoft/autogen are under the MIT License. # SPDX-License-Identifier: MIT -from __future__ import annotations import functools import importlib @@ -20,7 +19,7 @@ P = ParamSpec("P") -def _to_code(func: FunctionWithRequirements[T, P] | Callable[P, T] | FunctionWithRequirementsStr) -> str: +def _to_code(func: Union["FunctionWithRequirements[T, P]", Callable[P, T], "FunctionWithRequirementsStr"]) -> str: if isinstance(func, FunctionWithRequirementsStr): return func.func @@ -40,7 +39,7 @@ class Alias: @dataclass class ImportFromModule: module: str - imports: list[str | Alias] + imports: list[Union[str, Alias]] Import = Union[str, ImportFromModule, Alias] @@ -53,7 +52,7 @@ def _import_to_str(im: Import) -> str: return f"import {im.name} as {im.alias}" else: - def to_str(i: str | Alias) -> str: + def to_str(i: Union[str, Alias]) -> str: if isinstance(i, str): return i else: @@ -123,7 +122,7 @@ class FunctionWithRequirements(Generic[T, P]): @classmethod def from_callable( cls, func: Callable[P, T], python_packages: list[str] = [], global_imports: list[Import] = [] - ) -> FunctionWithRequirements[T, P]: + ) -> "FunctionWithRequirements[T, P]": return cls(python_packages=python_packages, global_imports=global_imports, func=func) @staticmethod @@ -162,7 +161,7 @@ def wrapper(func: Callable[P, T]) -> FunctionWithRequirements[T, P]: def _build_python_functions_file( - funcs: list[FunctionWithRequirements[Any, P] | Callable[..., Any] | FunctionWithRequirementsStr], + funcs: list[Union[FunctionWithRequirements[Any, P], Callable[..., Any], FunctionWithRequirementsStr]], ) -> str: # First collect all global imports global_imports: set[str] = set() @@ -178,7 +177,7 @@ def 
_build_python_functions_file( return content -def to_stub(func: Callable[..., Any] | FunctionWithRequirementsStr) -> str: +def to_stub(func: Union[Callable[..., Any], FunctionWithRequirementsStr]) -> str: """Generate a stub for a function as a string Args: diff --git a/autogen/oai/client.py b/autogen/oai/client.py index 75db1e5f94..e8ae0602e0 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -8,6 +8,7 @@ import inspect import logging +import re import sys import uuid import warnings @@ -289,6 +290,33 @@ def _format_content(content: str) -> str: for choice in choices ] + @staticmethod + def _is_agent_name_error_message(message: str) -> bool: + pattern = re.compile(r"Invalid 'messages\[\d+\]\.name': string does not match pattern.") + return True if pattern.match(message) else False + + @staticmethod + def _handle_openai_bad_request_error(func: Callable[..., Any]) -> Callable[..., Any]: + def wrapper(*args: Any, **kwargs: Any): + try: + return func(*args, **kwargs) + except openai.BadRequestError as e: + response_json = e.response.json() + # Check if the error message is related to the agent name. If so, raise a ValueError with a more informative message. + if "error" in response_json and "message" in response_json["error"]: + if OpenAIClient._is_agent_name_error_message(response_json["error"]["message"]): + error_message = ( + f"This error typically occurs when the agent name contains invalid characters, such as spaces or special symbols.\n" + "Please ensure that your agent name follows the correct format and doesn't include any unsupported characters.\n" + "Check the agent name and try again.\n" + f"Here is the full BadRequestError from openai:\n{e.message}." + ) + raise ValueError(error_message) + + raise e + + return wrapper + def create(self, params: dict[str, Any]) -> ChatCompletion: """Create a completion for a given config using openai's client. 
@@ -313,6 +341,8 @@ def _create_or_parse(*args, **kwargs): else: completions = self._oai_client.chat.completions if "messages" in params else self._oai_client.completions # type: ignore [attr-defined] create_or_parse = completions.create + # Wrap _create_or_parse with exception handling + create_or_parse = OpenAIClient._handle_openai_bad_request_error(create_or_parse) # needs to be updated when the o3 is released to generalize is_o1 = "model" in params and params["model"].startswith("o1") diff --git a/pyproject.toml b/pyproject.toml index 6466445b80..ee5472aae1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -180,15 +180,14 @@ bedrock = ["boto3>=1.34.149"] # test dependencies test = [ - "ipykernel", - "nbconvert", - "nbformat", - "pre-commit", - "pytest-cov>=5", - "pytest-asyncio", - "pytest>=8,<9", - "pandas", - "fastapi>=0.115.0,<1", + "ipykernel==6.29.5", + "nbconvert==7.16.5", + "nbformat==5.10.4", + "pytest-cov==6.0.0", + "pytest-asyncio==0.25.2", + "pytest==8.3.4", + "pandas==2.2.3", + "fastapi==0.115.6", ] # docs dependencies @@ -212,8 +211,9 @@ lint = [ ] dev = [ + "toml==0.10.2", "pyautogen[lint,test,types,docs]", - "pre-commit==4.0.1", + "pre-commit==4.1.0", "detect-secrets==1.5.0", "uv==0.5.21", ] @@ -271,7 +271,7 @@ fix = true line-length = 120 target-version = 'py39' #include = ["autogen", "test", "docs"] -#exclude = [] +exclude = ["setup_*.py"] [tool.ruff.lint] # Enable Pyflakes `E` and `F` codes by default. 
diff --git a/scripts/build-setup-files.py b/scripts/build-setup-files.py new file mode 100755 index 0000000000..f153681530 --- /dev/null +++ b/scripts/build-setup-files.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python + + +# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai +# +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path + +import toml +from jinja2 import Template + + +def get_optional_dependencies(pyproject_path: str) -> dict: + with open(pyproject_path, "r") as f: + pyproject_data = toml.load(f) + + optional_dependencies = pyproject_data.get("project", {}).get("optional-dependencies", {}) + return optional_dependencies + + +# Example usage +pyproject_path = Path(__file__).parent.joinpath("../pyproject.toml") +optional_dependencies = get_optional_dependencies(pyproject_path) +optional_groups = [group for group in optional_dependencies.keys()] + +# for group, dependencies in optional_dependencies.items(): +# print(f"Group: {group}") +# for dependency in dependencies: +# print(f" - {dependency}") + +template_path = Path(__file__).parents[1].joinpath("setup.jinja") +assert template_path.exists() + +with template_path.open("r") as f: + template_str = f.read() + +if len(template_str) < 100: + raise ValueError("Template string is too short") + +# Create a Jinja2 template object +template = Template(template_str) + +for name in ["ag2", "autogen"]: + file_name = f"setup_{name}.py" + file_path = Path(__file__).parents[1].joinpath(file_name) + # Render the template with the optional dependencies + rendered_setup_py = template.render(optional_dependencies=optional_dependencies, name=name) + + # Write the rendered setup.py to a file + with file_path.open("w") as setup_file: + setup_file.write(rendered_setup_py) diff --git a/scripts/pre-commit-build-setup-files.sh b/scripts/pre-commit-build-setup-files.sh new file mode 100755 index 0000000000..ebb6349205 --- /dev/null +++ b/scripts/pre-commit-build-setup-files.sh @@ -0,0 +1,14 @@ +#!/usr/bin/env 
bash + +# taken from: https://jaredkhan.com/blog/mypy-pre-commit + +# A script for building the setup_*.py files, +# with all its dependencies installed. + +set -o errexit + +# Change directory to the project root directory. +cd "$(dirname "$0")"/.. + +./scripts/build-setup-files.py +ruff check -s setup_*.py diff --git a/setup.jinja b/setup.jinja new file mode 100644 index 0000000000..5a2f479b66 --- /dev/null +++ b/setup.jinja @@ -0,0 +1,45 @@ +# Copyright (c) 2023 - 2024, Owners of https://github.com/ag2ai +# +# SPDX-License-Identifier: Apache-2.0 + +# this file is autogenerated, please do not edit it directly +# instead, edit the corresponding setup.jinja file and run the ./scripts/build-setup-files.py script + +import os + +import setuptools + +here = os.path.abspath(os.path.dirname(__file__)) + +with open("README.md", "r", encoding="UTF-8") as fh: + long_description = fh.read() + +# Get the code version +version = {} +with open(os.path.join(here, "autogen/version.py")) as fp: + exec(fp.read(), version) +__version__ = version["__version__"] + +setuptools.setup( + name="{{ name }}", + version=__version__, + description="Alias package for pyautogen", + long_description=long_description, + long_description_content_type="text/markdown", + install_requires=["pyautogen==" + __version__], + extras_require={ + {% for group, packages in optional_dependencies.items() -%} + "{{ group }}": ["pyautogen[{{ group }}]==" + __version__], + {% endfor %} + }, + url="https://github.com/ag2ai/ag2", + author="Chi Wang & Qingyun Wu", + author_email="support@ag2.ai", + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + ], + license="Apache Software License 2.0", + python_requires=">=3.9,<3.14", +) diff --git a/setup_ag2.py b/setup_ag2.py index a1336a05b9..4860c9d497 100644 --- a/setup_ag2.py +++ b/setup_ag2.py @@ -2,6 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 +# this file is autogenerated, please
do not edit it directly +# instead, edit the corresponding setup.jinja file and run the ./scripts/build-setup-files.py script + import os import setuptools @@ -25,15 +28,22 @@ long_description_content_type="text/markdown", install_requires=["pyautogen==" + __version__], extras_require={ - "test": ["pyautogen[test]==" + __version__], - "blendsearch": ["pyautogen[blendsearch]==" + __version__], - "mathchat": ["pyautogen[mathchat]==" + __version__], + "flaml": ["pyautogen[flaml]==" + __version__], + "jupyter-executor": ["pyautogen[jupyter-executor]==" + __version__], "retrievechat": ["pyautogen[retrievechat]==" + __version__], "retrievechat-pgvector": ["pyautogen[retrievechat-pgvector]==" + __version__], "retrievechat-mongodb": ["pyautogen[retrievechat-mongodb]==" + __version__], "retrievechat-qdrant": ["pyautogen[retrievechat-qdrant]==" + __version__], "graph-rag-falkor-db": ["pyautogen[graph-rag-falkor-db]==" + __version__], + "neo4j": ["pyautogen[neo4j]==" + __version__], + "twilio": ["pyautogen[twilio]==" + __version__], + "interop-crewai": ["pyautogen[interop-crewai]==" + __version__], + "interop-langchain": ["pyautogen[interop-langchain]==" + __version__], + "interop-pydantic-ai": ["pyautogen[interop-pydantic-ai]==" + __version__], + "interop": ["pyautogen[interop]==" + __version__], "autobuild": ["pyautogen[autobuild]==" + __version__], + "blendsearch": ["pyautogen[blendsearch]==" + __version__], + "mathchat": ["pyautogen[mathchat]==" + __version__], "captainagent": ["pyautogen[captainagent]==" + __version__], "teachable": ["pyautogen[teachable]==" + __version__], "lmm": ["pyautogen[lmm]==" + __version__], @@ -44,8 +54,6 @@ "redis": ["pyautogen[redis]==" + __version__], "cosmosdb": ["pyautogen[cosmosdb]==" + __version__], "websockets": ["pyautogen[websockets]==" + __version__], - "jupyter-executor": ["pyautogen[jupyter-executor]==" + __version__], - "types": ["pyautogen[types]==" + __version__], "long-context": ["pyautogen[long-context]==" + __version__], 
"anthropic": ["pyautogen[anthropic]==" + __version__], "cerebras": ["pyautogen[cerebras]==" + __version__], @@ -54,13 +62,12 @@ "cohere": ["pyautogen[cohere]==" + __version__], "ollama": ["pyautogen[ollama]==" + __version__], "bedrock": ["pyautogen[bedrock]==" + __version__], - "twilio": ["pyautogen[twilio]==" + __version__], - "interop-crewai": ["pyautogen[interop-crewai]==" + __version__], - "interop-langchain": ["pyautogen[interop-langchain]==" + __version__], - "interop-pydantic-ai": ["pyautogen[interop-pydantic-ai]==" + __version__], - "interop": ["pyautogen[interop]==" + __version__], - "neo4j": ["pyautogen[neo4j]==" + __version__], + "test": ["pyautogen[test]==" + __version__], "docs": ["pyautogen[docs]==" + __version__], + "types": ["pyautogen[types]==" + __version__], + "lint": ["pyautogen[lint]==" + __version__], + "dev": ["pyautogen[dev]==" + __version__], + }, url="https://github.com/ag2ai/ag2", author="Chi Wang & Qingyun Wu", diff --git a/setup_autogen.py b/setup_autogen.py index 3bad453eba..44f989f3fc 100644 --- a/setup_autogen.py +++ b/setup_autogen.py @@ -2,6 +2,9 @@ # # SPDX-License-Identifier: Apache-2.0 +# this file is autogenerated, please do not edit it directly +# instead, edit the corresponding setup.jinja file and run the ./scripts/build-setup-files.py script + import os import setuptools @@ -25,15 +28,22 @@ long_description_content_type="text/markdown", install_requires=["pyautogen==" + __version__], extras_require={ - "test": ["pyautogen[test]==" + __version__], - "blendsearch": ["pyautogen[blendsearch]==" + __version__], - "mathchat": ["pyautogen[mathchat]==" + __version__], + "flaml": ["pyautogen[flaml]==" + __version__], + "jupyter-executor": ["pyautogen[jupyter-executor]==" + __version__], "retrievechat": ["pyautogen[retrievechat]==" + __version__], "retrievechat-pgvector": ["pyautogen[retrievechat-pgvector]==" + __version__], "retrievechat-mongodb": ["pyautogen[retrievechat-mongodb]==" + __version__], "retrievechat-qdrant": 
["pyautogen[retrievechat-qdrant]==" + __version__], "graph-rag-falkor-db": ["pyautogen[graph-rag-falkor-db]==" + __version__], + "neo4j": ["pyautogen[neo4j]==" + __version__], + "twilio": ["pyautogen[twilio]==" + __version__], + "interop-crewai": ["pyautogen[interop-crewai]==" + __version__], + "interop-langchain": ["pyautogen[interop-langchain]==" + __version__], + "interop-pydantic-ai": ["pyautogen[interop-pydantic-ai]==" + __version__], + "interop": ["pyautogen[interop]==" + __version__], "autobuild": ["pyautogen[autobuild]==" + __version__], + "blendsearch": ["pyautogen[blendsearch]==" + __version__], + "mathchat": ["pyautogen[mathchat]==" + __version__], "captainagent": ["pyautogen[captainagent]==" + __version__], "teachable": ["pyautogen[teachable]==" + __version__], "lmm": ["pyautogen[lmm]==" + __version__], @@ -44,8 +54,6 @@ "redis": ["pyautogen[redis]==" + __version__], "cosmosdb": ["pyautogen[cosmosdb]==" + __version__], "websockets": ["pyautogen[websockets]==" + __version__], - "jupyter-executor": ["pyautogen[jupyter-executor]==" + __version__], - "types": ["pyautogen[types]==" + __version__], "long-context": ["pyautogen[long-context]==" + __version__], "anthropic": ["pyautogen[anthropic]==" + __version__], "cerebras": ["pyautogen[cerebras]==" + __version__], @@ -54,13 +62,12 @@ "cohere": ["pyautogen[cohere]==" + __version__], "ollama": ["pyautogen[ollama]==" + __version__], "bedrock": ["pyautogen[bedrock]==" + __version__], - "twilio": ["pyautogen[twilio]==" + __version__], - "interop-crewai": ["pyautogen[interop-crewai]==" + __version__], - "interop-langchain": ["pyautogen[interop-langchain]==" + __version__], - "interop-pydantic-ai": ["pyautogen[interop-pydantic-ai]==" + __version__], - "interop": ["pyautogen[interop]==" + __version__], - "neo4j": ["pyautogen[neo4j]==" + __version__], + "test": ["pyautogen[test]==" + __version__], "docs": ["pyautogen[docs]==" + __version__], + "types": ["pyautogen[types]==" + __version__], + "lint": 
["pyautogen[lint]==" + __version__], + "dev": ["pyautogen[dev]==" + __version__], + }, url="https://github.com/ag2ai/ag2", author="Chi Wang & Qingyun Wu", diff --git a/test/agentchat/test_async.py b/test/agentchat/test_async.py index af2f2f25c2..6d971096d1 100755 --- a/test/agentchat/test_async.py +++ b/test/agentchat/test_async.py @@ -89,10 +89,11 @@ async def _test_async_groupchat(credentials: Credentials): @pytest.mark.parametrize("credentials_from_test_param", credentials_all_llms, indirect=True) -def test_async_groupchat( +@pytest.mark.asyncio +async def test_async_groupchat( credentials_from_test_param: Credentials, ) -> None: - _test_async_groupchat(credentials_from_test_param) + await _test_async_groupchat(credentials_from_test_param) async def _test_stream(credentials: Credentials): @@ -162,7 +163,8 @@ async def add_data_reply(recipient, messages, sender, config): @pytest.mark.parametrize("credentials_from_test_param", credentials_all_llms, indirect=True) -def test_stream( +@pytest.mark.asyncio +async def test_stream( credentials_from_test_param: Credentials, ) -> None: - _test_stream(credentials_from_test_param) + await _test_stream(credentials_from_test_param) diff --git a/test/agentchat/test_async_get_human_input.py b/test/agentchat/test_async_get_human_input.py index 1eb80b6681..520d49c946 100755 --- a/test/agentchat/test_async_get_human_input.py +++ b/test/agentchat/test_async_get_human_input.py @@ -41,10 +41,11 @@ async def _test_async_get_human_input(credentials: Credentials) -> None: @pytest.mark.parametrize("credentials_from_test_param", credentials_all_llms, indirect=True) -def test_async_get_human_input( +@pytest.mark.asyncio +async def test_async_get_human_input( credentials_from_test_param: Credentials, ) -> None: - _test_async_get_human_input(credentials_from_test_param) + await _test_async_get_human_input(credentials_from_test_param) async def _test_async_max_turn(credentials: Credentials): @@ -76,7 +77,8 @@ async def 
_test_async_max_turn(credentials: Credentials): @pytest.mark.parametrize("credentials_from_test_param", credentials_all_llms, indirect=True) -def test_async_max_turn( +@pytest.mark.asyncio +async def test_async_max_turn( credentials_from_test_param: Credentials, ) -> None: - _test_async_max_turn(credentials_from_test_param) + await _test_async_max_turn(credentials_from_test_param) diff --git a/test/agentchat/test_conversable_agent.py b/test/agentchat/test_conversable_agent.py index 2430df22d8..4be2fa0a0e 100755 --- a/test/agentchat/test_conversable_agent.py +++ b/test/agentchat/test_conversable_agent.py @@ -12,7 +12,7 @@ import os import time import unittest -from typing import Annotated, Any, Callable, Literal +from typing import Annotated, Any, Callable, Literal, Optional from unittest.mock import MagicMock import pytest @@ -40,12 +40,23 @@ def conversable_agent(): @pytest.mark.parametrize("name", ["agent name", "agent_name ", " agent\nname", " agent\tname"]) -def test_conversable_agent_name_with_white_space_raises_error(name: str) -> None: +def test_conversable_agent_name_with_white_space( + name: str, + mock_credentials: Credentials, +) -> None: + agent = ConversableAgent(name=name) + assert agent.name == name + + llm_config = mock_credentials.llm_config with pytest.raises( ValueError, match=f"The name of the agent cannot contain any whitespace. 
The name provided is: '{name}'", ): - ConversableAgent(name=name) + ConversableAgent(name=name, llm_config=llm_config) + + llm_config["config_list"][0]["api_type"] = "something-else" + agent = ConversableAgent(name=name, llm_config=llm_config) + assert agent.name == name def test_sync_trigger(): @@ -1045,10 +1056,11 @@ def stopwatch(num_seconds: Annotated[str, "Number of seconds in the stopwatch."] @pytest.mark.parametrize("credentials_from_test_param", credentials_all_llms, indirect=True) -def test_function_registration_e2e_async( +@pytest.mark.asyncio +async def test_function_registration_e2e_async( credentials_from_test_param: Credentials, ) -> None: - _test_function_registration_e2e_async(credentials_from_test_param) + await _test_function_registration_e2e_async(credentials_from_test_param) @pytest.mark.openai @@ -1481,6 +1493,34 @@ def test_handle_carryover(): assert proc_content_empty_carryover == content, "Incorrect carryover processing" +@pytest.mark.parametrize("credentials_from_test_param", credentials_all_llms, indirect=True) +def test_conversable_agent_with_whitespaces_in_name_end2end( + credentials_from_test_param: Credentials, + request: pytest.FixtureRequest, +) -> None: + agent = ConversableAgent( + name="first_agent", + llm_config=credentials_from_test_param.llm_config, + ) + + user_proxy = UserProxyAgent( + name="user proxy", + human_input_mode="NEVER", + ) + + # Get the parameter name request node + current_llm = request.node.callspec.id + if "gpt_4" in current_llm: + with pytest.raises( + ValueError, + match="This error typically occurs when the agent name contains invalid characters, such as spaces or special symbols.", + ): + user_proxy.initiate_chat(agent, message="Hello, how are you?", max_turns=2) + # anthropic and gemini will not raise an error if agent name contains whitespaces + else: + user_proxy.initiate_chat(agent, message="Hello, how are you?", max_turns=2) + + @pytest.mark.openai def test_context_variables(): # Test initialization 
with context_variables @@ -1542,6 +1582,32 @@ def test_context_variables(): assert agent._context_variables == expected_final_context +@pytest.mark.skip(reason="This test is failing. We need to investigate the issue.") +@pytest.mark.gemini +def test_gemini_with_tools_parameters_set_to_is_annotated_with_none_as_default_value( + credentials_gemini_pro: Credentials, +) -> None: + agent = ConversableAgent(name="agent", llm_config=credentials_gemini_pro.llm_config) + + user_proxy = UserProxyAgent( + name="user_proxy_1", + human_input_mode="NEVER", + ) + + mock = MagicMock() + + @user_proxy.register_for_execution() + @agent.register_for_llm(description="Login function") + def login( + additional_notes: Annotated[Optional[str], "Additional notes"] = None, + ) -> str: + return "Login successful." + + user_proxy.initiate_chat(agent, message="Please login", max_turns=2) + + mock.assert_called_once() + + if __name__ == "__main__": # test_trigger() # test_context() diff --git a/test/agentchat/test_dependancy_injection.py b/test/agentchat/test_dependancy_injection.py index 9ffd49e5a0..00fdc7084b 100644 --- a/test/agentchat/test_dependancy_injection.py +++ b/test/agentchat/test_dependancy_injection.py @@ -200,7 +200,7 @@ async def test_register_tools( assert actual == expected - async def _test_end2end(self, credentials, is_async: bool) -> None: + async def _test_end2end(self, credentials: Credentials, is_async: bool) -> None: class UserContext(BaseContext, BaseModel): username: str password: str @@ -243,7 +243,6 @@ async def login( @agent.register_for_llm(description="Login function") def login( user: Annotated[UserContext, Depends(user)], - additional_notes: Annotated[Optional[str], "Additional notes"] = None, ) -> str: return _login(user) @@ -256,9 +255,10 @@ def login( @pytest.mark.parametrize("credentials_from_test_param", credentials_all_llms, indirect=True) @pytest.mark.parametrize("is_async", [False, True]) - def test_end2end( + @pytest.mark.asyncio + async def 
test_end2end( self, credentials_from_test_param: Credentials, is_async: bool, ) -> None: - self._test_end2end(credentials_from_test_param, is_async) + await self._test_end2end(credentials_from_test_param, is_async) diff --git a/test/oai/test_client.py b/test/oai/test_client.py index 5f8217c18f..774be3c335 100755 --- a/test/oai/test_client.py +++ b/test/oai/test_client.py @@ -10,6 +10,7 @@ import shutil import time from collections.abc import Generator +from unittest.mock import MagicMock import pytest @@ -290,6 +291,57 @@ def test_cache(credentials_gpt_4o_mini: Credentials): assert not os.path.exists(os.path.join(cache_dir, str(LEGACY_DEFAULT_CACHE_SEED))) +class TestOpenAIClientBadRequestsError: + def test_is_agent_name_error_message(self) -> None: + assert OpenAIClient._is_agent_name_error_message("Invalid 'messages[0].something") is False + for i in range(5): + error_message = f"Invalid 'messages[{i}].name': string does not match pattern. Expected a string that matches the pattern ..." + assert OpenAIClient._is_agent_name_error_message(error_message) is True + + @pytest.mark.parametrize( + "error_message, raise_new_error", + [ + ( + "Invalid 'messages[0].name': string does not match pattern. Expected a string that matches the pattern ...", + True, + ), + ( + "Invalid 'messages[1].name': string does not match pattern. Expected a string that matches the pattern ...", + True, + ), + ( + "Invalid 'messages[0].something': string does not match pattern. 
Expected a string that matches the pattern ...", + False, + ), + ], + ) + def test_handle_openai_bad_request_error(self, error_message: str, raise_new_error: bool) -> None: + def raise_bad_request_error(error_message: str) -> None: + mock_response = MagicMock() + mock_response.json.return_value = { + "error": { + "message": error_message, + } + } + body = {"error": {"message": "Bad Request error occurred"}} + raise openai.BadRequestError("Bad Request", response=mock_response, body=body) + + # Function raises BadRequestError + with pytest.raises(openai.BadRequestError): + raise_bad_request_error(error_message=error_message) + + wrapped_raise_bad_request_error = OpenAIClient._handle_openai_bad_request_error(raise_bad_request_error) + if raise_new_error: + with pytest.raises( + ValueError, + match="This error typically occurs when the agent name contains invalid characters, such as spaces or special symbols.", + ): + wrapped_raise_bad_request_error(error_message=error_message) + else: + with pytest.raises(openai.BadRequestError): + wrapped_raise_bad_request_error(error_message=error_message) + + class TestO1: @pytest.fixture def mock_oai_client(self, mock_credentials: Credentials) -> OpenAIClient: