Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FLAML moved to optional dependancies #598

Merged
merged 2 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 33 additions & 23 deletions autogen/oai/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,13 @@
null_handler = logging.NullHandler()
flaml_logger.addHandler(null_handler)

from flaml import BlendSearch, tune
from flaml.tune.space import is_constant
from ..import_utils import optional_import_block, require_optional_import

with optional_import_block() as result:
from flaml import BlendSearch, tune
from flaml.tune.space import is_constant

FLAML_INSTALLED = result.is_successful

# Restore logging by removing the NullHandler
flaml_logger.removeHandler(null_handler)
Expand Down Expand Up @@ -111,26 +116,30 @@ class Completion(OpenAICompletion):
"gpt-4-32k-0613": (0.06, 0.12),
}

default_search_space = {
"model": tune.choice(
[
"text-ada-001",
"text-babbage-001",
"text-davinci-003",
"gpt-3.5-turbo",
"gpt-4",
]
),
"temperature_or_top_p": tune.choice(
[
{"temperature": tune.uniform(0, 2)},
{"top_p": tune.uniform(0, 1)},
]
),
"max_tokens": tune.lograndint(50, 1000),
"n": tune.randint(1, 100),
"prompt": "{prompt}",
}
default_search_space = (
{
"model": tune.choice(
[
"text-ada-001",
"text-babbage-001",
"text-davinci-003",
"gpt-3.5-turbo",
"gpt-4",
]
),
"temperature_or_top_p": tune.choice(
[
{"temperature": tune.uniform(0, 2)},
{"top_p": tune.uniform(0, 1)},
]
),
"max_tokens": tune.lograndint(50, 1000),
"n": tune.randint(1, 100),
"prompt": "{prompt}",
}
if FLAML_INSTALLED
else {}
)

cache_seed = 41
cache_path = f".cache/{cache_seed}"
Expand Down Expand Up @@ -525,6 +534,7 @@ def _eval(cls, config: dict, prune=True, eval_only=False):
return result

@classmethod
@require_optional_import("flaml", "flaml")
def tune(
cls,
data: list[dict],
Expand Down Expand Up @@ -1213,5 +1223,5 @@ class ChatCompletion(Completion):
"""`(openai<1)` A class for OpenAI API ChatCompletion. Share the same API as Completion."""

default_search_space = Completion.default_search_space.copy()
default_search_space["model"] = tune.choice(["gpt-3.5-turbo", "gpt-4"])
default_search_space["model"] = tune.choice(["gpt-3.5-turbo", "gpt-4"]) if FLAML_INSTALLED else {}
openai_completion_class = not ERROR and openai.ChatCompletion
12 changes: 8 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,9 @@ dependencies = [
"openai>=1.58",
"diskcache",
"termcolor",
"flaml",
# numpy is installed by flaml, but we want to pin the version to below 2.x (see https://github.com/microsoft/autogen/issues/1960)
"numpy>=2.1; python_version>='3.13'", # numpy 2.1+ required for Python 3.13
"numpy>=1.24.0,<2.0.0; python_version<'3.13'", # numpy 1.24+ for older Python versions
"python-dotenv",
"tiktoken",
"numpy",
# Disallowing 2.6.0 can be removed when this is fixed https://github.com/pydantic/pydantic/issues/8705
"pydantic>=2.6.1,<3",
"docker",
Expand All @@ -72,6 +69,13 @@ dependencies = [

[project.optional-dependencies]

flaml = [
"flaml",
# numpy is installed by flaml, but we want to pin the version to below 2.x (see https://github.com/microsoft/autogen/issues/1960)
"numpy>=2.1; python_version>='3.13'", # numpy 2.1+ required for Python 3.13
"numpy>=1.24.0,<2.0.0; python_version<'3.13'", # numpy 1.24+ for older Python versions
]

# public distributions
jupyter-executor = [
"jupyter-kernel-gateway",
Expand Down
24 changes: 6 additions & 18 deletions test/interop/langchain/test_langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,24 @@
#
# SPDX-License-Identifier: Apache-2.0

import sys
from unittest.mock import MagicMock

import pytest
from langchain.tools import tool as langchain_tool
from pydantic import BaseModel, Field

from autogen import AssistantAgent, UserProxyAgent
from autogen.import_utils import optional_import_block, skip_on_missing_imports
from autogen.interop import Interoperable
from autogen.interop.langchain import LangChainInteroperability

from ...conftest import Credentials

with optional_import_block():
from langchain.tools import tool as langchain_tool


# skip if python version is not >= 3.9
@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Only Python 3.9 and above are supported for LangchainInteroperability"
)
@skip_on_missing_imports("langchain", "interop-langchain")
class TestLangChainInteroperability:
@pytest.fixture(autouse=True)
def setup(self) -> None:
Expand Down Expand Up @@ -76,10 +76,7 @@ def test_get_unsupported_reason(self) -> None:
assert LangChainInteroperability.get_unsupported_reason() is None


# skip if python version is not >= 3.9
@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Only Python 3.9 and above are supported for LangchainInteroperability"
)
@skip_on_missing_imports("langchain", "interop-langchain")
class TestLangChainInteroperabilityWithoutPydanticInput:
@pytest.fixture(autouse=True)
def setup(self) -> None:
Expand Down Expand Up @@ -129,12 +126,3 @@ def test_with_llm(self, credentials_gpt_4o: Credentials) -> None:
user_proxy.initiate_chat(recipient=chatbot, message="search for LangChain, Use max 100 characters", max_turns=5)

self.mock.assert_called()


@pytest.mark.skipif(sys.version_info >= (3, 9), reason="LangChain Interoperability is supported")
class TestLangChainInteroperabilityIfNotSupported:
def test_get_unsupported_reason(self) -> None:
assert (
LangChainInteroperability.get_unsupported_reason()
== "This submodule is only supported for Python versions 3.9 and above"
)
30 changes: 8 additions & 22 deletions test/interop/pydantic_ai/test_pydantic_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,25 @@
# SPDX-License-Identifier: Apache-2.0

import random
import sys
from inspect import signature
from typing import Any, Optional

import pytest
from pydantic import BaseModel
from pydantic_ai import RunContext
from pydantic_ai.tools import Tool as PydanticAITool

from autogen import AssistantAgent, UserProxyAgent
from autogen.import_utils import optional_import_block, skip_on_missing_imports
from autogen.interop import Interoperable
from autogen.interop.pydantic_ai import PydanticAIInteroperability

from ...conftest import Credentials

with optional_import_block():
from pydantic_ai import RunContext
from pydantic_ai.tools import Tool as PydanticAITool

# skip if python version is not >= 3.9
@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Only Python 3.9 and above are supported for LangchainInteroperability"
)

@skip_on_missing_imports("pydantic_ai", "interop-pydantic-ai")
class TestPydanticAIInteroperabilityWithotContext:
@pytest.fixture(autouse=True)
def setup(self) -> None:
Expand Down Expand Up @@ -66,9 +65,7 @@ def test_with_llm(self, credentials_gpt_4o: Credentials) -> None:
assert False, "No tool response found in chat messages"


@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Only Python 3.9 and above are supported for LangchainInteroperability"
)
@skip_on_missing_imports("pydantic_ai", "interop-pydantic-ai")
class TestPydanticAIInteroperabilityDependencyInjection:
def test_dependency_injection(self) -> None:
def f(
Expand Down Expand Up @@ -127,9 +124,7 @@ def f(
assert pydantic_ai_tool.current_retry == 3


@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Only Python 3.9 and above are supported for LangchainInteroperability"
)
@skip_on_missing_imports("pydantic_ai", "interop-pydantic-ai")
class TestPydanticAIInteroperabilityWithContext:
@pytest.fixture(autouse=True)
def setup(self) -> None:
Expand Down Expand Up @@ -210,12 +205,3 @@ def test_with_llm(self, credentials_gpt_4o: Credentials) -> None:
return

assert False, "No tool response found in chat messages"


@pytest.mark.skipif(sys.version_info >= (3, 9), reason="LangChain Interoperability is supported")
class TestPydanticAIInteroperabilityIfNotSupported:
def test_get_unsupported_reason(self) -> None:
assert (
PydanticAIInteroperability.get_unsupported_reason()
== "This submodule is only supported for Python versions 3.9 and above"
)
14 changes: 5 additions & 9 deletions test/interop/pydantic_ai/test_pydantic_ai_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,15 @@
#
# SPDX-License-Identifier: Apache-2.0

import sys

import pytest
from pydantic_ai.tools import Tool as PydanticAITool

from autogen import AssistantAgent
from autogen.import_utils import optional_import_block, skip_on_missing_imports
from autogen.interop.pydantic_ai.pydantic_ai_tool import PydanticAITool as AG2PydanticAITool

with optional_import_block():
from pydantic_ai.tools import Tool as PydanticAITool


# skip if python version is not >= 3.9
@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Only Python 3.9 and above are supported for LangchainInteroperability"
)
@skip_on_missing_imports("pydantic_ai", "interop-pydantic-ai")
class TestPydanticAITool:
def test_register_for_llm(self) -> None:
def foobar(a: int, b: str, c: dict[str, list[float]]) -> str: # type: ignore[misc]
Expand Down
19 changes: 11 additions & 8 deletions test/interop/test_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,20 @@

import pytest

from autogen.import_utils import optional_import_block, skip_on_missing_imports
from autogen.interop import Interoperability

from ..conftest import MOCK_OPEN_AI_API_KEY

with optional_import_block():
from crewai_tools import FileReadTool

with optional_import_block():
pass # type: ignore[import]


class TestInteroperability:
@skip_on_missing_imports(["crewai_tools", "langchain", "pydantic_ai"], "interop")
def test_supported_types(self) -> None:
actual = Interoperability.get_supported_types()

Expand All @@ -28,12 +36,9 @@ def test_supported_types(self) -> None:
if sys.version_info >= (3, 13):
assert actual == ["langchain", "pydanticai"]

@pytest.mark.skipif(
sys.version_info < (3, 10) or sys.version_info >= (3, 13), reason="Only Python 3.10, 3.11, 3.12 are supported"
)
@skip_on_missing_imports("crewai_tools", "interop-crewai")
def test_crewai(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("OPENAI_API_KEY", MOCK_OPEN_AI_API_KEY)
from crewai_tools import FileReadTool

crewai_tool = FileReadTool()

Expand All @@ -56,9 +61,7 @@ def test_crewai(self, monkeypatch: pytest.MonkeyPatch) -> None:

assert tool.func(args=args) == "Hello, World!"

@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Only Python 3.9 and above are supported for LangchainInteroperability"
)
@pytest.mark.skip(reason="This test is not yet implemented")
@pytest.mark.skip("This test is not yet implemented")
@skip_on_missing_imports("langchain", "interop-langchain")
def test_langchain(self) -> None:
raise NotImplementedError("This test is not yet implemented")
Loading