Skip to content

Commit

Permalink
Fix deprecation warning, add test
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Jan 22, 2025
1 parent 8c01377 commit c4b2243
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 6 deletions.
45 changes: 39 additions & 6 deletions griptape/drivers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from griptape.utils.deprecation import deprecation_warn
from types import ModuleType
import warnings
import sys
from typing import Any

from .prompt import BasePromptDriver
from .prompt.openai import OpenAiChatPromptDriver
Expand Down Expand Up @@ -123,11 +126,6 @@
from .assistant.griptape_cloud import GriptapeCloudAssistantDriver
from .assistant.openai import OpenAiAssistantDriver

deprecation_warn(
"Importing from `griptape.drivers` is deprecated and will be removed in a future release. "
"Please import from the provider-specific package instead. "
"e.g., `from griptape.drivers import OpenAiChatPromptDriver` -> `from griptape.drivers.prompt.openai import OpenAiChatPromptDriver`"
)

__all__ = [
"BasePromptDriver",
Expand Down Expand Up @@ -236,3 +234,38 @@
"GriptapeCloudAssistantDriver",
"OpenAiAssistantDriver",
]


class _DeprecationWarningModuleWrapper(ModuleType):
"""Module wrapper that issues a deprecation warning when importing."""

__ignore_attrs__ = {
"__file__",
"__package__",
"__path__",
"__doc__",
"__all__",
"__name__",
"__loader__",
"__spec__",
}

def __init__(self, real_module: Any) -> None:
self._real_module = real_module

def __getattr__(self, name: str) -> Any:
if name in self.__ignore_attrs__:
return getattr(self._real_module, name)

warnings.warn(
"Importing from `griptape.drivers` is deprecated and will be removed in a future release. "
"Please import from the provider-specific package instead.\n"
"e.g., `from griptape.drivers import OpenAiChatPromptDriver` -> `from griptape.drivers.prompt.openai import OpenAiChatPromptDriver`",
DeprecationWarning,
stacklevel=2,
)

return getattr(self._real_module, name)


sys.modules[__name__] = _DeprecationWarningModuleWrapper(sys.modules[__name__])
15 changes: 15 additions & 0 deletions tests/unit/drivers/prompt/test_base_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import json
import warnings

import pytest

from griptape.artifacts import ActionArtifact, ErrorArtifact, TextArtifact
from griptape.common import Message, PromptStack
Expand Down Expand Up @@ -139,3 +142,15 @@ def test_rule_structured_output_strategy_populated(self):
assert "baz" in prompt_stack.messages[0].content[2].to_text()
assert isinstance(output, TextArtifact)
assert output.value == json.dumps({"baz": "foo"})

def test_deprecated_import(self):
with pytest.warns(DeprecationWarning):
from griptape.drivers import BasePromptDriver

assert BasePromptDriver

with warnings.catch_warnings():
warnings.simplefilter("error")
from griptape.drivers.prompt.base_prompt_driver import BasePromptDriver

assert BasePromptDriver

0 comments on commit c4b2243

Please sign in to comment.