Skip to content

Commit

Permalink
Add requires_optional_import decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
kumaranvpl committed Jan 17, 2025
1 parent c20b56f commit c2bc4c5
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 3 deletions.
24 changes: 23 additions & 1 deletion autogen/exception_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
from typing import Any
from typing import Any, Iterable, Union

__all__ = [
"AgentNameConflict",
"AutogenImportError",
"InvalidCarryOverType",
"ModelToolNotSupportedError",
"NoEligibleSpeaker",
"SenderRequired",
"UndefinedNextAgent",
Expand Down Expand Up @@ -65,3 +67,23 @@ def __init__(
):
self.message = f"Tools are not supported with {model} models. Refer to the documentation at https://platform.openai.com/docs/guides/reasoning#limitations"
super().__init__(self.message)


class AutogenImportError(ImportError):
"""
Exception raised when a required module is not found.
"""

def __init__(
self,
missing_modules: Union[str, Iterable[str]],
dep_target: str,
):
if isinstance(missing_modules, str):
missing_modules = [missing_modules]

missing_modules_str = ", ".join(f"'{x}'" for x in missing_modules)
self.message = (
f"Missing imported module {missing_modules_str}, please install it using 'pip install ag2[{dep_target}]'"
)
super().__init__(self.message)
55 changes: 54 additions & 1 deletion autogen/handle_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@
#
# SPDX-License-Identifier: Apache-2.0

import importlib
import sys
from contextlib import contextmanager
from typing import Generator
from functools import wraps
from typing import Any, Generator, Iterable, Union

from autogen.exception_utils import AutogenImportError


@contextmanager
Expand All @@ -16,3 +21,51 @@ def check_for_missing_imports() -> Generator[None, None, None]:
yield
except ImportError:
pass # Ignore ImportErrors during this context


class DummyModule:
"""A dummy module that raises ImportError when any attribute is accessed"""

def __init__(self, name: str, dep_target: str):
self._name = name
self._dep_target = dep_target

def __getattr__(self, attr: str) -> Any:
raise AutogenImportError(missing_modules=self._name, dep_target=self._dep_target)


def requires_optional_import(modules: Union[str, Iterable[str]], dep_target: str):
"""Decorator to handle optional module dependencies
Args:
modules: Module name or list of module names required
dep_target: Target name for pip installation (e.g. 'test' in pip install ag2[test])
"""
if isinstance(modules, str):
modules = [modules]

def decorator(cls):
# Check if all required modules are available
missing_modules = []
dummy_modules = {}

for module_name in modules:
try:
importlib.import_module(module_name)
except ImportError:
missing_modules.append(module_name)
# Create dummy module
dummy_module = DummyModule(module_name, dep_target)
dummy_modules[module_name] = dummy_module
sys.modules[module_name] = dummy_module

if missing_modules:
# Replace real class with dummy that raises ImportError
@wraps(cls)
def dummy_class(*args, **kwargs):
raise AutogenImportError(missing_modules=missing_modules, dep_target=dep_target)

return dummy_class
return cls

return decorator
53 changes: 52 additions & 1 deletion test/test_handle_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,61 @@
#
# SPDX-License-Identifier: Apache-2.0

from autogen.handle_imports import check_for_missing_imports
import pytest

from autogen.exception_utils import AutogenImportError
from autogen.handle_imports import check_for_missing_imports, requires_optional_import


def test_check_for_missing_imports():
with check_for_missing_imports():
pass # Safe to attempt, even if it fails
assert True


class TestRequiresOptionalImport:
def test_with_class(
self,
):
@requires_optional_import("some_optional_module", "test")
class DummyClass:
def __init__(self):
o = some_optional_module.SomeClass() # Raises MissingImportError if module is missing

assert DummyClass is not None
with pytest.raises(AutogenImportError) as e:
DummyClass()

assert (
str(e.value)
== "Missing imported module 'some_optional_module', please install it using 'pip install ag2[test]'"
)

def test_with_function(self):
@requires_optional_import("some_other_optional_module", "test")
def dummy_function():
o = some_other_optional_module.SomeOtherClass()

assert dummy_function is not None
with pytest.raises(AutogenImportError) as e:
dummy_function()

assert (
str(e.value)
== "Missing imported module 'some_other_optional_module', please install it using 'pip install ag2[test]'"
)

def test_with_multiple_modules(self):
@requires_optional_import(["module1", "module2"], "test")
def dummy_function():
o = module1.SomeClass()
o2 = module2.SomeOtherClass()

assert dummy_function is not None
with pytest.raises(AutogenImportError) as e:
dummy_function()

assert (
str(e.value)
== "Missing imported module 'module1', 'module2', please install it using 'pip install ag2[test]'"
)

0 comments on commit c2bc4c5

Please sign in to comment.