From c4b224373c9a7c267e11de9a79b1890c11f08812 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 22 Jan 2025 10:00:45 -0800 Subject: [PATCH] Fix deprecation warning, add test --- griptape/drivers/__init__.py | 45 ++++++++++++++++--- .../drivers/prompt/test_base_prompt_driver.py | 15 +++++++ 2 files changed, 54 insertions(+), 6 deletions(-) diff --git a/griptape/drivers/__init__.py b/griptape/drivers/__init__.py index 70b247a35..0f6aafaf4 100644 --- a/griptape/drivers/__init__.py +++ b/griptape/drivers/__init__.py @@ -1,4 +1,7 @@ -from griptape.utils.deprecation import deprecation_warn +from types import ModuleType +import warnings +import sys +from typing import Any from .prompt import BasePromptDriver from .prompt.openai import OpenAiChatPromptDriver @@ -123,11 +126,6 @@ from .assistant.griptape_cloud import GriptapeCloudAssistantDriver from .assistant.openai import OpenAiAssistantDriver -deprecation_warn( - "Importing from `griptape.drivers` is deprecated and will be removed in a future release. " - "Please import from the provider-specific package instead. " - "e.g., `from griptape.drivers import OpenAiChatPromptDriver` -> `from griptape.drivers.prompt.openai import OpenAiChatPromptDriver`" -) __all__ = [ "BasePromptDriver", @@ -236,3 +234,38 @@ "GriptapeCloudAssistantDriver", "OpenAiAssistantDriver", ] + + +class _DeprecationWarningModuleWrapper(ModuleType): + """Module wrapper that issues a deprecation warning when importing.""" + + __ignore_attrs__ = { + "__file__", + "__package__", + "__path__", + "__doc__", + "__all__", + "__name__", + "__loader__", + "__spec__", + } + + def __init__(self, real_module: Any) -> None: + self._real_module = real_module + + def __getattr__(self, name: str) -> Any: + if name in self.__ignore_attrs__: + return getattr(self._real_module, name) + + warnings.warn( + "Importing from `griptape.drivers` is deprecated and will be removed in a future release. " + "Please import from the provider-specific package instead.\n" + "e.g., `from griptape.drivers import OpenAiChatPromptDriver` -> `from griptape.drivers.prompt.openai import OpenAiChatPromptDriver`", + DeprecationWarning, + stacklevel=2, + ) + + return getattr(self._real_module, name) + + +sys.modules[__name__] = _DeprecationWarningModuleWrapper(sys.modules[__name__]) diff --git a/tests/unit/drivers/prompt/test_base_prompt_driver.py b/tests/unit/drivers/prompt/test_base_prompt_driver.py index 3ffcebce4..cb6741d23 100644 --- a/tests/unit/drivers/prompt/test_base_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_base_prompt_driver.py @@ -1,4 +1,7 @@ import json +import warnings + +import pytest from griptape.artifacts import ActionArtifact, ErrorArtifact, TextArtifact from griptape.common import Message, PromptStack @@ -139,3 +142,15 @@ def test_rule_structured_output_strategy_populated(self): assert "baz" in prompt_stack.messages[0].content[2].to_text() assert isinstance(output, TextArtifact) assert output.value == json.dumps({"baz": "foo"}) + + def test_deprecated_import(self): + with pytest.warns(DeprecationWarning): + from griptape.drivers import BasePromptDriver + + assert BasePromptDriver + + with warnings.catch_warnings(): + warnings.simplefilter("error") + from griptape.drivers.prompt.base_prompt_driver import BasePromptDriver + + assert BasePromptDriver