From 8bd65c672f2bb75c3991054d2a1c1b49fbd092de Mon Sep 17 00:00:00 2001 From: Sachin Joglekar Date: Thu, 16 Jan 2025 15:47:38 -0800 Subject: [PATCH 1/2] Add ChatCompletionCache along with AbstractStore for caching completions (#4924) * Add ChatCompletionCache along with AbstractStore for caching completions * Addressing comments * Improve interface for cachestore * Improve documentation & revert protocol * Make cache store typed, and improve docs * remove unnecessary casts --- .../autogen-core/docs/src/reference/index.md | 3 + .../autogen_ext.cache_store.diskcache.rst | 8 + .../python/autogen_ext.cache_store.redis.rst | 8 + .../python/autogen_ext.models.cache.rst | 8 + .../python/autogen_ext.models.replay.rst | 16 +- .../tutorial/models.ipynb | 6 +- .../framework/model-clients.ipynb | 88 +++++++- python/packages/autogen-core/pyproject.toml | 2 + .../autogen-core/src/autogen_core/__init__.py | 3 + .../src/autogen_core/_cache_store.py | 46 ++++ .../autogen-core/tests/test_cache_store.py | 48 ++++ python/packages/autogen-ext/pyproject.toml | 6 + .../src/autogen_ext/cache_store/__init__.py | 0 .../src/autogen_ext/cache_store/diskcache.py | 26 +++ .../src/autogen_ext/cache_store/redis.py | 29 +++ .../src/autogen_ext/models/cache/__init__.py | 6 + .../models/cache/_chat_completion_cache.py | 210 ++++++++++++++++++ .../replay/_replay_chat_completion_client.py | 14 +- .../tests/cache_store/test_diskcache_store.py | 48 ++++ .../tests/cache_store/test_redis_store.py | 53 +++++ .../models/test_chat_completion_cache.py | 133 +++++++++++ python/uv.lock | 59 ++++- 22 files changed, 802 insertions(+), 18 deletions(-) create mode 100644 python/packages/autogen-core/docs/src/reference/python/autogen_ext.cache_store.diskcache.rst create mode 100644 python/packages/autogen-core/docs/src/reference/python/autogen_ext.cache_store.redis.rst create mode 100644 python/packages/autogen-core/docs/src/reference/python/autogen_ext.models.cache.rst create mode 100644 python/packages/autogen-core/src/autogen_core/_cache_store.py create mode 100644 python/packages/autogen-core/tests/test_cache_store.py create mode 100644 python/packages/autogen-ext/src/autogen_ext/cache_store/__init__.py create mode 100644 python/packages/autogen-ext/src/autogen_ext/cache_store/diskcache.py create mode 100644 python/packages/autogen-ext/src/autogen_ext/cache_store/redis.py create mode 100644 python/packages/autogen-ext/src/autogen_ext/models/cache/__init__.py create mode 100644 python/packages/autogen-ext/src/autogen_ext/models/cache/_chat_completion_cache.py create mode 100644 python/packages/autogen-ext/tests/cache_store/test_diskcache_store.py create mode 100644 python/packages/autogen-ext/tests/cache_store/test_redis_store.py create mode 100644 python/packages/autogen-ext/tests/models/test_chat_completion_cache.py diff --git a/python/packages/autogen-core/docs/src/reference/index.md b/python/packages/autogen-core/docs/src/reference/index.md index 04c0e5e5227a..608866ce1a14 100644 --- a/python/packages/autogen-core/docs/src/reference/index.md +++ b/python/packages/autogen-core/docs/src/reference/index.md @@ -48,6 +48,7 @@ python/autogen_ext.agents.video_surfer python/autogen_ext.agents.video_surfer.tools python/autogen_ext.auth.azure python/autogen_ext.teams.magentic_one +python/autogen_ext.models.cache python/autogen_ext.models.openai python/autogen_ext.models.replay python/autogen_ext.tools.langchain @@ -56,5 +57,7 @@ python/autogen_ext.tools.code_execution python/autogen_ext.code_executors.local 
python/autogen_ext.code_executors.docker python/autogen_ext.code_executors.azure +python/autogen_ext.cache_store.diskcache +python/autogen_ext.cache_store.redis python/autogen_ext.runtimes.grpc ``` diff --git a/python/packages/autogen-core/docs/src/reference/python/autogen_ext.cache_store.diskcache.rst b/python/packages/autogen-core/docs/src/reference/python/autogen_ext.cache_store.diskcache.rst new file mode 100644 index 000000000000..5fbc4c8b35ac --- /dev/null +++ b/python/packages/autogen-core/docs/src/reference/python/autogen_ext.cache_store.diskcache.rst @@ -0,0 +1,8 @@ +autogen\_ext.cache_store.diskcache +================================== + + +.. automodule:: autogen_ext.cache_store.diskcache + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/python/packages/autogen-core/docs/src/reference/python/autogen_ext.cache_store.redis.rst b/python/packages/autogen-core/docs/src/reference/python/autogen_ext.cache_store.redis.rst new file mode 100644 index 000000000000..fab1b46d520a --- /dev/null +++ b/python/packages/autogen-core/docs/src/reference/python/autogen_ext.cache_store.redis.rst @@ -0,0 +1,8 @@ +autogen\_ext.cache_store.redis +============================== + + +.. automodule:: autogen_ext.cache_store.redis + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/python/packages/autogen-core/docs/src/reference/python/autogen_ext.models.cache.rst b/python/packages/autogen-core/docs/src/reference/python/autogen_ext.models.cache.rst new file mode 100644 index 000000000000..48956ace16e2 --- /dev/null +++ b/python/packages/autogen-core/docs/src/reference/python/autogen_ext.models.cache.rst @@ -0,0 +1,8 @@ +autogen\_ext.models.cache +========================= + + +.. automodule:: autogen_ext.models.cache + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/python/packages/autogen-core/docs/src/reference/python/autogen_ext.models.replay.rst b/python/packages/autogen-core/docs/src/reference/python/autogen_ext.models.replay.rst index a3630970fa2b..4fc9aefbfb3d 100644 --- a/python/packages/autogen-core/docs/src/reference/python/autogen_ext.models.replay.rst +++ b/python/packages/autogen-core/docs/src/reference/python/autogen_ext.models.replay.rst @@ -1,8 +1,8 @@ -autogen\_ext.models.replay -========================== - - -.. automodule:: autogen_ext.models.replay - :members: - :undoc-members: - :show-inheritance: +autogen\_ext.models.replay +========================== + + +.. automodule:: autogen_ext.models.replay + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/models.ipynb b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/models.ipynb index d7aed4fc5cda..a3a423553160 100644 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/models.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/models.ipynb @@ -6,7 +6,11 @@ "source": [ "# Models\n", "\n", - "In many cases, agents need access to LLM model services such as OpenAI, Azure OpenAI, or local models. Since there are many different providers with different APIs, `autogen-core` implements a protocol for [model clients](../../core-user-guide/framework/model-clients.ipynb) and `autogen-ext` implements a set of model clients for popular model services. 
AgentChat can use these model clients to interact with model services. "
+    "In many cases, agents need access to LLM model services such as OpenAI, Azure OpenAI, or local models. Since there are many different providers with different APIs, `autogen-core` implements a protocol for [model clients](../../core-user-guide/framework/model-clients.ipynb) and `autogen-ext` implements a set of model clients for popular model services. AgentChat can use these model clients to interact with model services. \n",
+    "\n",
+    "```{note}\n",
+    "See {py:class}`~autogen_ext.models.cache.ChatCompletionCache` for a caching wrapper to use with the following clients.\n",
+    "```"
   ]
  },
  {
diff --git a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/model-clients.ipynb b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/model-clients.ipynb
index cadd0466ab0c..38fd13195c6f 100644
--- a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/model-clients.ipynb
+++ b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/model-clients.ipynb
@@ -96,7 +96,13 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Default [Model Capabilities](../faqs.md#what-are-model-capabilities-and-how-do-i-specify-them) may be overridden should the need arise.\n",
+    "Default [Model Capabilities](../faqs.md#what-are-model-capabilities-and-how-do-i-specify-them) may be overridden should the need arise.\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
     "\n",
     "\n",
     "### Streaming Response\n",
@@ -315,6 +321,84 @@
    "```"
   ]
  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Caching Wrapper\n",
+    "\n",
+    "`autogen_ext` implements {py:class}`~autogen_ext.models.cache.ChatCompletionCache` that can wrap any {py:class}`~autogen_core.models.ChatCompletionClient`. Using this wrapper avoids incurring token usage when querying the underlying client with the same prompt multiple times.\n",
+    "\n",
+    "{py:class}`~autogen_ext.models.cache.ChatCompletionCache` uses a {py:class}`~autogen_core.CacheStore` protocol. We have implemented some useful variants of {py:class}`~autogen_core.CacheStore` including {py:class}`~autogen_ext.cache_store.diskcache.DiskCacheStore` and {py:class}`~autogen_ext.cache_store.redis.RedisStore`.\n",
+    "\n",
+    "Here's an example of using `diskcache` for local caching:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# pip install -U \"autogen-ext[openai, diskcache]\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "True\n"
+     ]
+    }
+   ],
+   "source": [
+    "import asyncio\n",
+    "import tempfile\n",
+    "\n",
+    "from autogen_core.models import UserMessage\n",
+    "from autogen_ext.cache_store.diskcache import DiskCacheStore\n",
+    "from autogen_ext.models.cache import CHAT_CACHE_VALUE_TYPE, ChatCompletionCache\n",
+    "from autogen_ext.models.openai import OpenAIChatCompletionClient\n",
+    "from diskcache import Cache\n",
+    "\n",
+    "\n",
+    "async def main() -> None:\n",
+    "    with tempfile.TemporaryDirectory() as tmpdirname:\n",
+    "        # Initialize the original client\n",
+    "        openai_model_client = OpenAIChatCompletionClient(model=\"gpt-4o\")\n",
+    "\n",
+    "        # Then initialize the CacheStore, in this case with diskcache.Cache.\n",
+    "        # You can also use redis like:\n",
+    "        # from autogen_ext.cache_store.redis import RedisStore\n",
+    "        # import redis\n",
+    "        # redis_instance = redis.Redis()\n",
+    "        # cache_store = RedisStore[CHAT_CACHE_VALUE_TYPE](redis_instance)\n",
+    "        cache_store = DiskCacheStore[CHAT_CACHE_VALUE_TYPE](Cache(tmpdirname))\n",
+    "        cache_client = ChatCompletionCache(openai_model_client, cache_store)\n",
+    "\n",
+    "        response = await cache_client.create([UserMessage(content=\"Hello, how are you?\", source=\"user\")])\n",
+    "        print(response)  # Should print response from OpenAI\n",
+    "        response = await cache_client.create([UserMessage(content=\"Hello, how are you?\", source=\"user\")])\n",
+    "        print(response)  # Should print cached response\n",
+    "\n",
+    "\n",
+    "asyncio.run(main())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Inspecting `cache_client.total_usage()` (or `openai_model_client.total_usage()`) before and after a cached response should yield identical counts.\n",
+    "\n",
+    "Note that the caching is sensitive to the exact arguments provided to `cache_client.create` or `cache_client.create_stream`, so changing `tools` or `json_output` arguments might lead to a cache miss.\n",
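+    "\n",
+    "As a minimal sketch of that sensitivity (reusing the `cache_client` from the example above), the same messages with a different `json_output` value hash to a different cache key, so the call below results in a fresh request:\n",
+    "\n",
+    "```python\n",
+    "# Same messages as before, but json_output=True changes the cache key,\n",
+    "# so this bypasses the cached entry and queries the model again.\n",
+    "response = await cache_client.create(\n",
+    "    [UserMessage(content=\"Hello, how are you?\", source=\"user\")], json_output=True\n",
+    ")\n",
+    "```"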
+ ] + }, { "cell_type": "markdown", "metadata": {}, @@ -615,7 +699,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.7" + "version": "3.12.1" } }, "nbformat": 4, diff --git a/python/packages/autogen-core/pyproject.toml b/python/packages/autogen-core/pyproject.toml index f0ff7d9920f6..78eb2bef3b4f 100644 --- a/python/packages/autogen-core/pyproject.toml +++ b/python/packages/autogen-core/pyproject.toml @@ -72,6 +72,8 @@ dev = [ "autogen_ext==0.4.3", # Documentation tooling + "diskcache", + "redis", "sphinx-autobuild", ] diff --git a/python/packages/autogen-core/src/autogen_core/__init__.py b/python/packages/autogen-core/src/autogen_core/__init__.py index 478ecc422e03..0198544ca61e 100644 --- a/python/packages/autogen-core/src/autogen_core/__init__.py +++ b/python/packages/autogen-core/src/autogen_core/__init__.py @@ -10,6 +10,7 @@ from ._agent_runtime import AgentRuntime from ._agent_type import AgentType from ._base_agent import BaseAgent +from ._cache_store import CacheStore, InMemoryStore from ._cancellation_token import CancellationToken from ._closure_agent import ClosureAgent, ClosureContext from ._component_config import ( @@ -85,6 +86,8 @@ "AgentMetadata", "AgentRuntime", "BaseAgent", + "CacheStore", + "InMemoryStore", "CancellationToken", "AgentInstantiationContext", "TopicId", diff --git a/python/packages/autogen-core/src/autogen_core/_cache_store.py b/python/packages/autogen-core/src/autogen_core/_cache_store.py new file mode 100644 index 000000000000..339048fdc8f8 --- /dev/null +++ b/python/packages/autogen-core/src/autogen_core/_cache_store.py @@ -0,0 +1,46 @@ +from typing import Dict, Generic, Optional, Protocol, TypeVar + +T = TypeVar("T") + + +class CacheStore(Protocol, Generic[T]): + """ + This protocol defines the basic interface for store/cache operations. + + Sub-classes should handle the lifecycle of underlying storage. + """ + + def get(self, key: str, default: Optional[T] = None) -> Optional[T]: + """ + Retrieve an item from the store. + + Args: + key: The key identifying the item in the store. + default (optional): The default value to return if the key is not found. + Defaults to None. + + Returns: + The value associated with the key if found, else the default value. + """ + ... + + def set(self, key: str, value: T) -> None: + """ + Set an item in the store. + + Args: + key: The key under which the item is to be stored. + value: The value to be stored in the store. + """ + ... 
+ + +class InMemoryStore(CacheStore[T]): + def __init__(self) -> None: + self.store: Dict[str, T] = {} + + def get(self, key: str, default: Optional[T] = None) -> Optional[T]: + return self.store.get(key, default) + + def set(self, key: str, value: T) -> None: + self.store[key] = value diff --git a/python/packages/autogen-core/tests/test_cache_store.py b/python/packages/autogen-core/tests/test_cache_store.py new file mode 100644 index 000000000000..3caf058af053 --- /dev/null +++ b/python/packages/autogen-core/tests/test_cache_store.py @@ -0,0 +1,48 @@ +from unittest.mock import Mock + +from autogen_core import CacheStore, InMemoryStore + + +def test_set_and_get_object_key_value() -> None: + mock_store = Mock(spec=CacheStore) + test_key = "test_key" + test_value = object() + mock_store.set(test_key, test_value) + mock_store.get.return_value = test_value + mock_store.set.assert_called_with(test_key, test_value) + assert mock_store.get(test_key) == test_value + + +def test_get_non_existent_key() -> None: + mock_store = Mock(spec=CacheStore) + key = "non_existent_key" + mock_store.get.return_value = None + assert mock_store.get(key) is None + + +def test_set_overwrite_existing_key() -> None: + mock_store = Mock(spec=CacheStore) + key = "test_key" + initial_value = "initial_value" + new_value = "new_value" + mock_store.set(key, initial_value) + mock_store.set(key, new_value) + mock_store.get.return_value = new_value + mock_store.set.assert_called_with(key, new_value) + assert mock_store.get(key) == new_value + + +def test_inmemory_store() -> None: + store = InMemoryStore[int]() + test_key = "test_key" + test_value = 42 + store.set(test_key, test_value) + assert store.get(test_key) == test_value + + new_value = 2 + store.set(test_key, new_value) + assert store.get(test_key) == new_value + + key = "non_existent_key" + default_value = 99 + assert store.get(key, default_value) == default_value diff --git a/python/packages/autogen-ext/pyproject.toml b/python/packages/autogen-ext/pyproject.toml index 5fa2ce54379c..e0f05425b487 100644 --- a/python/packages/autogen-ext/pyproject.toml +++ b/python/packages/autogen-ext/pyproject.toml @@ -46,6 +46,12 @@ video-surfer = [ "ffmpeg-python", "openai-whisper", ] +diskcache = [ + "diskcache>=5.6.3" +] +redis = [ + "redis>=5.2.1" +] grpc = [ "grpcio~=1.62.0", # TODO: update this once we have a stable version. diff --git a/python/packages/autogen-ext/src/autogen_ext/cache_store/__init__.py b/python/packages/autogen-ext/src/autogen_ext/cache_store/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/packages/autogen-ext/src/autogen_ext/cache_store/diskcache.py b/python/packages/autogen-ext/src/autogen_ext/cache_store/diskcache.py new file mode 100644 index 000000000000..afb1db224253 --- /dev/null +++ b/python/packages/autogen-ext/src/autogen_ext/cache_store/diskcache.py @@ -0,0 +1,26 @@ +from typing import Any, Optional, TypeVar, cast + +import diskcache +from autogen_core import CacheStore + +T = TypeVar("T") + + +class DiskCacheStore(CacheStore[T]): + """ + A typed CacheStore implementation that uses diskcache as the underlying storage. + See :class:`~autogen_ext.models.cache.ChatCompletionCache` for an example of usage. + + Args: + cache_instance: An instance of diskcache.Cache. + The user is responsible for managing the DiskCache instance's lifetime. 
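+
+    Example (a minimal sketch mirroring this package's tests; assumes the optional
+    ``diskcache`` extra is installed):
+
+    .. code-block:: python
+
+        import tempfile
+
+        from autogen_ext.cache_store.diskcache import DiskCacheStore
+        from diskcache import Cache
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # The store owns no lifecycle: the Cache is created and managed by the caller.
+            store = DiskCacheStore[int](Cache(tmpdir))
+            store.set("answer", 42)
+            assert store.get("answer") == 42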
+    """

+    def __init__(self, cache_instance: diskcache.Cache):  # type: ignore[no-any-unimported]
+        self.cache = cache_instance
+
+    def get(self, key: str, default: Optional[T] = None) -> Optional[T]:
+        return cast(Optional[T], self.cache.get(key, default))  # type: ignore[reportUnknownMemberType]
+
+    def set(self, key: str, value: T) -> None:
+        self.cache.set(key, cast(Any, value))  # type: ignore[reportUnknownMemberType]
diff --git a/python/packages/autogen-ext/src/autogen_ext/cache_store/redis.py b/python/packages/autogen-ext/src/autogen_ext/cache_store/redis.py
new file mode 100644
index 000000000000..e751f418082c
--- /dev/null
+++ b/python/packages/autogen-ext/src/autogen_ext/cache_store/redis.py
@@ -0,0 +1,29 @@
+from typing import Any, Optional, TypeVar, cast
+
+import redis
+from autogen_core import CacheStore
+
+T = TypeVar("T")
+
+
+class RedisStore(CacheStore[T]):
+    """
+    A typed CacheStore implementation that uses redis as the underlying storage.
+    See :class:`~autogen_ext.models.cache.ChatCompletionCache` for an example of usage.
+
+    Args:
+        redis_instance: An instance of `redis.Redis`.
+            The user is responsible for managing the Redis instance's lifetime.
+    """
+
+    def __init__(self, redis_instance: redis.Redis):
+        self.cache = redis_instance
+
+    def get(self, key: str, default: Optional[T] = None) -> Optional[T]:
+        value = cast(Optional[T], self.cache.get(key))
+        if value is None:
+            return default
+        return value
+
+    def set(self, key: str, value: T) -> None:
+        self.cache.set(key, cast(Any, value))
diff --git a/python/packages/autogen-ext/src/autogen_ext/models/cache/__init__.py b/python/packages/autogen-ext/src/autogen_ext/models/cache/__init__.py
new file mode 100644
index 000000000000..333d2b737a53
--- /dev/null
+++ b/python/packages/autogen-ext/src/autogen_ext/models/cache/__init__.py
@@ -0,0 +1,6 @@
+from ._chat_completion_cache import CHAT_CACHE_VALUE_TYPE, ChatCompletionCache
+
+__all__ = [
+    "CHAT_CACHE_VALUE_TYPE",
+    "ChatCompletionCache",
+]
diff --git a/python/packages/autogen-ext/src/autogen_ext/models/cache/_chat_completion_cache.py b/python/packages/autogen-ext/src/autogen_ext/models/cache/_chat_completion_cache.py
new file mode 100644
index 000000000000..79ed5f1660a0
--- /dev/null
+++ b/python/packages/autogen-ext/src/autogen_ext/models/cache/_chat_completion_cache.py
@@ -0,0 +1,210 @@
+import hashlib
+import json
+import warnings
+from typing import Any, AsyncGenerator, List, Mapping, Optional, Sequence, Union, cast
+
+from autogen_core import CacheStore, CancellationToken
+from autogen_core.models import (
+    ChatCompletionClient,
+    CreateResult,
+    LLMMessage,
+    ModelCapabilities,  # type: ignore
+    ModelInfo,
+    RequestUsage,
+)
+from autogen_core.tools import Tool, ToolSchema
+
+CHAT_CACHE_VALUE_TYPE = Union[CreateResult, List[Union[str, CreateResult]]]
+
+
+class ChatCompletionCache(ChatCompletionClient):
+    """
+    A wrapper around a :class:`~autogen_core.models.ChatCompletionClient` that caches
+    creation results from an underlying client.
+    Cache hits do not contribute to token usage of the original client.
+
+    Typical Usage:
+
+    Let's use caching on disk with the `openai` client as an example.
+    First install `autogen-ext` with the required packages:
+
+    .. code-block:: bash
+
+        pip install -U "autogen-ext[openai, diskcache]"
+
+    And use it as:
+
+    .. code-block:: python
+
+        import asyncio
+        import tempfile
+
+        from autogen_core.models import UserMessage
+        from autogen_ext.models.openai import OpenAIChatCompletionClient
+        from autogen_ext.models.cache import ChatCompletionCache, CHAT_CACHE_VALUE_TYPE
+        from autogen_ext.cache_store.diskcache import DiskCacheStore
+        from diskcache import Cache
+
+
+        async def main():
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                # Initialize the original client
+                openai_model_client = OpenAIChatCompletionClient(model="gpt-4o")
+
+                # Then initialize the CacheStore, in this case with diskcache.Cache.
+                # You can also use redis like:
+                # from autogen_ext.cache_store.redis import RedisStore
+                # import redis
+                # redis_instance = redis.Redis()
+                # cache_store = RedisStore[CHAT_CACHE_VALUE_TYPE](redis_instance)
+                cache_store = DiskCacheStore[CHAT_CACHE_VALUE_TYPE](Cache(tmpdirname))
+                cache_client = ChatCompletionCache(openai_model_client, cache_store)
+
+                response = await cache_client.create([UserMessage(content="Hello, how are you?", source="user")])
+                print(response)  # Should print response from OpenAI
+                response = await cache_client.create([UserMessage(content="Hello, how are you?", source="user")])
+                print(response)  # Should print cached response
+
+
+        asyncio.run(main())
+
+    You can now use the `cache_client` as you would the original client, but with caching enabled.
+
+    Args:
+        client (ChatCompletionClient): The original ChatCompletionClient to wrap.
+        store (CacheStore): A store object that implements get and set methods.
+            The user is responsible for managing the store's lifecycle & clearing it (if needed).
+    """
+
+    def __init__(self, client: ChatCompletionClient, store: CacheStore[CHAT_CACHE_VALUE_TYPE]):
+        self.client = client
+        self.store = store
+
+    def _check_cache(
+        self,
+        messages: Sequence[LLMMessage],
+        tools: Sequence[Tool | ToolSchema],
+        json_output: Optional[bool],
+        extra_create_args: Mapping[str, Any],
+    ) -> tuple[Optional[Union[CreateResult, List[Union[str, CreateResult]]]], str]:
+        """
+        Helper function to check the cache for a result.
+        Returns a tuple of (cached_result, cache_key).
+        """
+
+        data = {
+            "messages": [message.model_dump() for message in messages],
+            "tools": [(tool.schema if isinstance(tool, Tool) else tool) for tool in tools],
+            "json_output": json_output,
+            "extra_create_args": extra_create_args,
+        }
+        serialized_data = json.dumps(data, sort_keys=True)
+        cache_key = hashlib.sha256(serialized_data.encode()).hexdigest()
+
+        # The cached value may be a CreateResult (from create) or a list of
+        # streamed results (from create_stream), so cast to the full value type.
+        cached_result = cast(Optional[CHAT_CACHE_VALUE_TYPE], self.store.get(cache_key))
+        if cached_result is not None:
+            return cached_result, cache_key
+
+        return None, cache_key
+
+    async def create(
+        self,
+        messages: Sequence[LLMMessage],
+        *,
+        tools: Sequence[Tool | ToolSchema] = [],
+        json_output: Optional[bool] = None,
+        extra_create_args: Mapping[str, Any] = {},
+        cancellation_token: Optional[CancellationToken] = None,
+    ) -> CreateResult:
+        """
+        Cached version of ChatCompletionClient.create.
+        If the result of a call to create has been cached, it will be returned immediately
+        without invoking the underlying client.
+
+        NOTE: cancellation_token is ignored for cached results.
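+
+        For example (a sketch, assuming an already-constructed ``cache_client`` and ``messages``),
+        the returned :class:`~autogen_core.models.CreateResult` reports whether it was served
+        from the cache:
+
+        .. code-block:: python
+
+            result = await cache_client.create(messages)
+            if result.cached:
+                ...  # Served from the store; the underlying client was not called.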
+        """
+        cached_result, cache_key = self._check_cache(messages, tools, json_output, extra_create_args)
+        if cached_result:
+            assert isinstance(cached_result, CreateResult)
+            cached_result.cached = True
+            return cached_result
+
+        result = await self.client.create(
+            messages,
+            tools=tools,
+            json_output=json_output,
+            extra_create_args=extra_create_args,
+            cancellation_token=cancellation_token,
+        )
+        self.store.set(cache_key, result)
+        return result
+
+    def create_stream(
+        self,
+        messages: Sequence[LLMMessage],
+        *,
+        tools: Sequence[Tool | ToolSchema] = [],
+        json_output: Optional[bool] = None,
+        extra_create_args: Mapping[str, Any] = {},
+        cancellation_token: Optional[CancellationToken] = None,
+    ) -> AsyncGenerator[Union[str, CreateResult], None]:
+        """
+        Cached version of ChatCompletionClient.create_stream.
+        If the result of a call to create_stream has been cached, it will be returned
+        without streaming from the underlying client.
+
+        NOTE: cancellation_token is ignored for cached results.
+        """
+
+        async def _generator() -> AsyncGenerator[Union[str, CreateResult], None]:
+            cached_result, cache_key = self._check_cache(
+                messages,
+                tools,
+                json_output,
+                extra_create_args,
+            )
+            if cached_result:
+                assert isinstance(cached_result, list)
+                for result in cached_result:
+                    if isinstance(result, CreateResult):
+                        result.cached = True
+                    yield result
+                return
+
+            result_stream = self.client.create_stream(
+                messages,
+                tools=tools,
+                json_output=json_output,
+                extra_create_args=extra_create_args,
+                cancellation_token=cancellation_token,
+            )
+
+            output_results: List[Union[str, CreateResult]] = []
+
+            async for result in result_stream:
+                output_results.append(result)
+                yield result
+
+            # Cache only after the stream completes, so stores that serialize on
+            # set (e.g. diskcache, redis) persist the full list of results rather
+            # than a reference to an initially empty list.
+            self.store.set(cache_key, output_results)
+
+        return _generator()
+
+    def actual_usage(self) -> RequestUsage:
+        return self.client.actual_usage()
+
+    def count_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
+        return self.client.count_tokens(messages, tools=tools)
+
+    @property
+    def capabilities(self) -> ModelCapabilities:  # type: ignore
+        warnings.warn("capabilities is deprecated, use model_info instead", DeprecationWarning, stacklevel=2)
+        return self.client.capabilities
+
+    @property
+    def model_info(self) -> ModelInfo:
+        return self.client.model_info
+
+    def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[Tool | ToolSchema] = []) -> int:
+        return self.client.remaining_tokens(messages, tools=tools)
+
+    def total_usage(self) -> RequestUsage:
+        return self.client.total_usage()
diff --git a/python/packages/autogen-ext/src/autogen_ext/models/replay/_replay_chat_completion_client.py b/python/packages/autogen-ext/src/autogen_ext/models/replay/_replay_chat_completion_client.py
index b62084b646b1..5ae7b6b665eb 100644
--- a/python/packages/autogen-ext/src/autogen_ext/models/replay/_replay_chat_completion_client.py
+++ b/python/packages/autogen-ext/src/autogen_ext/models/replay/_replay_chat_completion_client.py
@@ -40,8 +40,8 @@ class ReplayChatCompletionClient(ChatCompletionClient):
 
     .. code-block:: python
 
-        from autogen_ext.models.replay import ReplayChatCompletionClient
         from autogen_core.models import UserMessage
+        from autogen_ext.models.replay import ReplayChatCompletionClient
 
 
         async def example():
@@ -60,8 +60,8 @@ async def example():
 
     ..
code-block:: python import asyncio - from autogen_ext.models.replay import ReplayChatCompletionClient from autogen_core.models import UserMessage + from autogen_ext.models.replay import ReplayChatCompletionClient async def example(): @@ -86,8 +86,8 @@ async def example(): .. code-block:: python import asyncio - from autogen_ext.models.replay import ReplayChatCompletionClient from autogen_core.models import UserMessage + from autogen_ext.models.replay import ReplayChatCompletionClient async def example(): @@ -129,6 +129,7 @@ def __init__( self._cur_usage = RequestUsage(prompt_tokens=0, completion_tokens=0) self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0) self._current_index = 0 + self._cached_bool_value = True async def create( self, @@ -148,7 +149,9 @@ async def create( if isinstance(response, str): _, output_token_count = self._tokenize(response) self._cur_usage = RequestUsage(prompt_tokens=prompt_token_count, completion_tokens=output_token_count) - response = CreateResult(finish_reason="stop", content=response, usage=self._cur_usage, cached=True) + response = CreateResult( + finish_reason="stop", content=response, usage=self._cur_usage, cached=self._cached_bool_value + ) else: self._cur_usage = RequestUsage( prompt_tokens=prompt_token_count, completion_tokens=response.usage.completion_tokens @@ -207,6 +210,9 @@ def remaining_tokens(self, messages: Sequence[LLMMessage], *, tools: Sequence[To 0, self._total_available_tokens - self._total_usage.prompt_tokens - self._total_usage.completion_tokens ) + def set_cached_bool_value(self, value: bool) -> None: + self._cached_bool_value = value + def _tokenize(self, messages: Union[str, LLMMessage, Sequence[LLMMessage]]) -> tuple[list[str], int]: total_tokens = 0 all_tokens: List[str] = [] diff --git a/python/packages/autogen-ext/tests/cache_store/test_diskcache_store.py b/python/packages/autogen-ext/tests/cache_store/test_diskcache_store.py new file mode 100644 index 000000000000..ddca0b82cdcc --- /dev/null +++ b/python/packages/autogen-ext/tests/cache_store/test_diskcache_store.py @@ -0,0 +1,48 @@ +import tempfile + +import pytest + +diskcache = pytest.importorskip("diskcache") + + +def test_diskcache_store_basic() -> None: + from autogen_ext.cache_store.diskcache import DiskCacheStore + from diskcache import Cache + + with tempfile.TemporaryDirectory() as temp_dir: + cache = Cache(temp_dir) + store = DiskCacheStore[int](cache) + test_key = "test_key" + test_value = 42 + store.set(test_key, test_value) + assert store.get(test_key) == test_value + + new_value = 2 + store.set(test_key, new_value) + assert store.get(test_key) == new_value + + key = "non_existent_key" + default_value = 99 + assert store.get(key, default_value) == default_value + + +def test_diskcache_with_different_instances() -> None: + from autogen_ext.cache_store.diskcache import DiskCacheStore + from diskcache import Cache + + with tempfile.TemporaryDirectory() as temp_dir_1, tempfile.TemporaryDirectory() as temp_dir_2: + cache_1 = Cache(temp_dir_1) + cache_2 = Cache(temp_dir_2) + + store_1 = DiskCacheStore[int](cache_1) + store_2 = DiskCacheStore[int](cache_2) + + test_key = "test_key" + test_value_1 = 5 + test_value_2 = 6 + + store_1.set(test_key, test_value_1) + assert store_1.get(test_key) == test_value_1 + + store_2.set(test_key, test_value_2) + assert store_2.get(test_key) == test_value_2 diff --git a/python/packages/autogen-ext/tests/cache_store/test_redis_store.py b/python/packages/autogen-ext/tests/cache_store/test_redis_store.py new file mode 100644 
index 000000000000..111f38a4fffd --- /dev/null +++ b/python/packages/autogen-ext/tests/cache_store/test_redis_store.py @@ -0,0 +1,53 @@ +from unittest.mock import MagicMock + +import pytest + +redis = pytest.importorskip("redis") + + +def test_redis_store_basic() -> None: + from autogen_ext.cache_store.redis import RedisStore + + redis_instance = MagicMock() + store = RedisStore[int](redis_instance) + test_key = "test_key" + test_value = 42 + store.set(test_key, test_value) + redis_instance.set.assert_called_with(test_key, test_value) + redis_instance.get.return_value = test_value + assert store.get(test_key) == test_value + + new_value = 2 + store.set(test_key, new_value) + redis_instance.set.assert_called_with(test_key, new_value) + redis_instance.get.return_value = new_value + assert store.get(test_key) == new_value + + key = "non_existent_key" + default_value = 99 + redis_instance.get.return_value = None + assert store.get(key, default_value) == default_value + + +def test_redis_with_different_instances() -> None: + from autogen_ext.cache_store.redis import RedisStore + + redis_instance_1 = MagicMock() + redis_instance_2 = MagicMock() + + store_1 = RedisStore[int](redis_instance_1) + store_2 = RedisStore[int](redis_instance_2) + + test_key = "test_key" + test_value_1 = 5 + test_value_2 = 6 + + store_1.set(test_key, test_value_1) + redis_instance_1.set.assert_called_with(test_key, test_value_1) + redis_instance_1.get.return_value = test_value_1 + assert store_1.get(test_key) == test_value_1 + + store_2.set(test_key, test_value_2) + redis_instance_2.set.assert_called_with(test_key, test_value_2) + redis_instance_2.get.return_value = test_value_2 + assert store_2.get(test_key) == test_value_2 diff --git a/python/packages/autogen-ext/tests/models/test_chat_completion_cache.py b/python/packages/autogen-ext/tests/models/test_chat_completion_cache.py new file mode 100644 index 000000000000..ceb4d9a9f72a --- /dev/null +++ b/python/packages/autogen-ext/tests/models/test_chat_completion_cache.py @@ -0,0 +1,133 @@ +import copy +from typing import List, Tuple, Union + +import pytest +from autogen_core import InMemoryStore +from autogen_core.models import ( + ChatCompletionClient, + CreateResult, + LLMMessage, + SystemMessage, + UserMessage, +) +from autogen_ext.models.cache import CHAT_CACHE_VALUE_TYPE, ChatCompletionCache +from autogen_ext.models.replay import ReplayChatCompletionClient + + +def get_test_data() -> Tuple[list[str], list[str], SystemMessage, ChatCompletionClient, ChatCompletionCache]: + num_messages = 3 + responses = [f"This is dummy message number {i}" for i in range(num_messages)] + prompts = [f"This is dummy prompt number {i}" for i in range(num_messages)] + system_prompt = SystemMessage(content="This is a system prompt") + replay_client = ReplayChatCompletionClient(responses) + replay_client.set_cached_bool_value(False) + store = InMemoryStore[CHAT_CACHE_VALUE_TYPE]() + cached_client = ChatCompletionCache(replay_client, store) + + return responses, prompts, system_prompt, replay_client, cached_client + + +@pytest.mark.asyncio +async def test_cache_basic_with_args() -> None: + responses, prompts, system_prompt, _, cached_client = get_test_data() + + response0 = await cached_client.create([system_prompt, UserMessage(content=prompts[0], source="user")]) + assert isinstance(response0, CreateResult) + assert not response0.cached + assert response0.content == responses[0] + + response1 = await cached_client.create([system_prompt, UserMessage(content=prompts[1], source="user")]) + 
assert not response1.cached
+    assert response1.content == responses[1]
+
+    # Cached output.
+    response0_cached = await cached_client.create([system_prompt, UserMessage(content=prompts[0], source="user")])
+    assert isinstance(response0_cached, CreateResult)
+    assert response0_cached.cached
+    assert response0_cached.content == responses[0]
+
+    # Cache miss if args change.
+    response2 = await cached_client.create(
+        [system_prompt, UserMessage(content=prompts[0], source="user")], json_output=True
+    )
+    assert isinstance(response2, CreateResult)
+    assert not response2.cached
+    assert response2.content == responses[2]
+
+
+@pytest.mark.asyncio
+async def test_cache_model_and_count_api() -> None:
+    _, prompts, system_prompt, replay_client, cached_client = get_test_data()
+
+    assert replay_client.model_info == cached_client.model_info
+    assert replay_client.capabilities == cached_client.capabilities
+
+    messages: List[LLMMessage] = [system_prompt, UserMessage(content=prompts[0], source="user")]
+    assert replay_client.count_tokens(messages) == cached_client.count_tokens(messages)
+    assert replay_client.remaining_tokens(messages) == cached_client.remaining_tokens(messages)
+
+
+@pytest.mark.asyncio
+async def test_cache_token_usage() -> None:
+    responses, prompts, system_prompt, replay_client, cached_client = get_test_data()
+
+    response0 = await cached_client.create([system_prompt, UserMessage(content=prompts[0], source="user")])
+    assert isinstance(response0, CreateResult)
+    assert not response0.cached
+    assert response0.content == responses[0]
+    actual_usage0 = copy.copy(cached_client.actual_usage())
+    total_usage0 = copy.copy(cached_client.total_usage())
+
+    response1 = await cached_client.create([system_prompt, UserMessage(content=prompts[1], source="user")])
+    assert not response1.cached
+    assert response1.content == responses[1]
+    actual_usage1 = copy.copy(cached_client.actual_usage())
+    total_usage1 = copy.copy(cached_client.total_usage())
+    assert total_usage1.prompt_tokens > total_usage0.prompt_tokens
+    assert total_usage1.completion_tokens > total_usage0.completion_tokens
+    assert actual_usage1.prompt_tokens == actual_usage0.prompt_tokens
+    assert actual_usage1.completion_tokens == actual_usage0.completion_tokens
+
+    # Cached output.
+    response0_cached = await cached_client.create([system_prompt, UserMessage(content=prompts[0], source="user")])
+    assert isinstance(response0_cached, CreateResult)
+    assert response0_cached.cached
+    assert response0_cached.content == responses[0]
+    total_usage2 = copy.copy(cached_client.total_usage())
+    assert total_usage2.prompt_tokens == total_usage1.prompt_tokens
+    assert total_usage2.completion_tokens == total_usage1.completion_tokens
+
+    assert cached_client.actual_usage() == replay_client.actual_usage()
+    assert cached_client.total_usage() == replay_client.total_usage()
+
+
+@pytest.mark.asyncio
+async def test_cache_create_stream() -> None:
+    _, prompts, system_prompt, _, cached_client = get_test_data()
+
+    original_streamed_results: List[Union[str, CreateResult]] = []
+    async for completion in cached_client.create_stream(
+        [system_prompt, UserMessage(content=prompts[0], source="user")]
+    ):
+        original_streamed_results.append(completion)
+    total_usage0 = copy.copy(cached_client.total_usage())
+
+    cached_completion_results: List[Union[str, CreateResult]] = []
+    async for completion in cached_client.create_stream(
+        [system_prompt, UserMessage(content=prompts[0], source="user")]
+    ):
+        cached_completion_results.append(completion)
+    total_usage1 = copy.copy(cached_client.total_usage())
+
+    assert total_usage1.prompt_tokens == total_usage0.prompt_tokens
+    assert total_usage1.completion_tokens == total_usage0.completion_tokens
+
+    for original, cached in zip(original_streamed_results, cached_completion_results, strict=False):
+        if isinstance(original, str):
+            assert original == cached
+        elif isinstance(original, CreateResult) and isinstance(cached, CreateResult):
+            assert original.content == cached.content
+            assert cached.cached
+            assert not original.cached
+        else:
+            raise ValueError(f"Unexpected types: {type(original)} and {type(cached)}")
diff --git a/python/uv.lock b/python/uv.lock
index 62e30fa36f70..f408c1cae420 100644
--- a/python/uv.lock
+++ b/python/uv.lock
@@ -130,7 +130,7 @@
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "aiohappyeyeballs" },
     { name = "aiosignal" },
-    { name = "async-timeout", marker = "python_full_version < '3.11'" },
+    { name = "async-timeout", version = "4.0.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
     { name = "attrs" },
     { name = "frozenlist" },
     { name = "multidict" },
@@ -317,11 +317,30 @@
 name = "async-timeout"
 version = "4.0.3"
 source = { registry = "https://pypi.org/simple" }
+resolution-markers = [
+    "python_full_version < '3.11' and sys_platform == 'darwin'",
+    "python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'",
+    "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')",
+]
 sdist = { url = "https://files.pythonhosted.org/packages/87/d6/21b30a550dafea84b1b8eee21b5e23fa16d010ae006011221f33dcd8d7f8/async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f", size = 8345 }
 wheels = [
     { url = "https://files.pythonhosted.org/packages/a7/fa/e01228c2938de91d47b307831c62ab9e4001e747789d0b05baf779a6488c/async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028", size = 5721 },
 ]
 
+[[package]]
+name = "async-timeout"
+version = "5.0.1"
+source = { registry = "https://pypi.org/simple" }
+resolution-markers = [ + "python_full_version == '3.11.*' and sys_platform == 'darwin'", + "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux')", +] +sdist = { url = "https://files.pythonhosted.org/packages/a5/ae/136395dfbfe00dfc94da3f3e136d0b13f394cba8f4841120e34226265780/async_timeout-5.0.1.tar.gz", hash = "sha256:d9321a7a3d5a6a5e187e824d2fa0793ce379a202935782d555d6e9d2735677d3", size = 9274 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fe/ba/e2081de779ca30d473f21f5b30e0e737c438205440784c7dfc81efc2b029/async_timeout-5.0.1-py3-none-any.whl", hash = "sha256:39e3809566ff85354557ec2398b55e096c8364bacac9405a7a1fa429e77fe76c", size = 6233 }, +] + [[package]] name = "asyncer" version = "0.0.7" @@ -399,6 +418,7 @@ dev = [ { name = "azure-identity" }, { name = "chess" }, { name = "colorama" }, + { name = "diskcache" }, { name = "langchain-openai" }, { name = "langgraph" }, { name = "llama-index" }, @@ -416,6 +436,7 @@ dev = [ { name = "pydata-sphinx-theme" }, { name = "pygments" }, { name = "python-dotenv" }, + { name = "redis" }, { name = "requests" }, { name = "sphinx" }, { name = "sphinx-autobuild" }, @@ -455,6 +476,7 @@ dev = [ { name = "azure-identity" }, { name = "chess" }, { name = "colorama" }, + { name = "diskcache" }, { name = "langchain-openai" }, { name = "langgraph" }, { name = "llama-index" }, @@ -472,6 +494,7 @@ dev = [ { name = "pydata-sphinx-theme", specifier = "==0.15.4" }, { name = "pygments" }, { name = "python-dotenv" }, + { name = "redis" }, { name = "requests" }, { name = "sphinx" }, { name = "sphinx-autobuild" }, @@ -504,6 +527,9 @@ azure = [ { name = "azure-core" }, { name = "azure-identity" }, ] +diskcache = [ + { name = "diskcache" }, +] docker = [ { name = "docker" }, ] @@ -531,6 +557,9 @@ openai = [ { name = "openai" }, { name = "tiktoken" }, ] +redis = [ + { name = "redis" }, +] video-surfer = [ { name = "autogen-agentchat" }, { name = "ffmpeg-python" }, @@ -561,6 +590,7 @@ requires-dist = [ { name = "autogen-core", editable = "packages/autogen-core" }, { name = "azure-core", marker = "extra == 'azure'" }, { name = "azure-identity", marker = "extra == 'azure'" }, + { name = "diskcache", marker = "extra == 'diskcache'", specifier = ">=5.6.3" }, { name = "docker", marker = "extra == 'docker'", specifier = "~=7.0" }, { name = "ffmpeg-python", marker = "extra == 'video-surfer'" }, { name = "graphrag", marker = "extra == 'graphrag'", specifier = ">=1.0.1" }, @@ -576,6 +606,7 @@ requires-dist = [ { name = "pillow", marker = "extra == 'web-surfer'", specifier = ">=11.0.0" }, { name = "playwright", marker = "extra == 'magentic-one'", specifier = ">=1.48.0" }, { name = "playwright", marker = "extra == 'web-surfer'", specifier = ">=1.48.0" }, + { name = "redis", marker = "extra == 'redis'", specifier = ">=5.2.1" }, { name = "tiktoken", marker = "extra == 'openai'", specifier = ">=0.8.0" }, ] @@ -1379,6 +1410,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/68/69/1bcf70f81de1b4a9f21b3a62ec0c83bdff991c88d6cc2267d02408457e88/dirtyjson-1.0.8-py3-none-any.whl", hash = "sha256:125e27248435a58acace26d5c2c4c11a1c0de0a9c5124c5a94ba78e517d74f53", size = 25197 }, ] +[[package]] +name = "diskcache" +version = "5.6.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/3f/21/1c1ffc1a039ddcc459db43cc108658f32c57d271d7289a2794e401d0fdb6/diskcache-5.6.3.tar.gz", hash = "sha256:2c3a3fa2743d8535d832ec61c2054a1641f41775aa7c556758a109941e33e4fc", size = 67916 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3f/27/4570e78fc0bf5ea0ca45eb1de3818a23787af9b390c0b0a0033a1b8236f9/diskcache-5.6.3-py3-none-any.whl", hash = "sha256:5e31b2d5fbad117cc363ebaf6b689474db18a1f6438bc82358b024abd4c2ca19", size = 45550 }, +] + [[package]] name = "distro" version = "1.9.0" @@ -2345,7 +2385,7 @@ version = "0.3.14" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp" }, - { name = "async-timeout", marker = "python_full_version < '3.11'" }, + { name = "async-timeout", version = "4.0.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "langchain-core" }, { name = "langchain-text-splitters" }, { name = "langsmith" }, @@ -4895,6 +4935,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ec/d2/3b2ab40f455a256cb6672186bea95cd97b459ce4594050132d71e76f0d6f/pyzmq-26.2.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:90412f2db8c02a3864cbfc67db0e3dcdbda336acf1c469526d3e869394fe001c", size = 550762 }, ] +[[package]] +name = "redis" +version = "5.2.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "async-timeout", version = "4.0.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "async-timeout", version = "5.0.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' and python_full_version < '3.11.3'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/47/da/d283a37303a995cd36f8b92db85135153dc4f7a8e4441aa827721b442cfb/redis-5.2.1.tar.gz", hash = "sha256:16f2e22dff21d5125e8481515e386711a34cbec50f0e44413dd7d9c060a54e0f", size = 4608355 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3c/5f/fa26b9b2672cbe30e07d9a5bdf39cf16e3b80b42916757c5f92bca88e4ba/redis-5.2.1-py3-none-any.whl", hash = "sha256:ee7e1056b9aea0f04c6c2ed59452947f34c4940ee025f5dd83e6a6418b6989e4", size = 261502 }, +] + [[package]] name = "referencing" version = "0.35.1" @@ -5917,7 +5970,7 @@ name = "triton" version = "3.1.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "filelock", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "filelock" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/98/29/69aa56dc0b2eb2602b553881e34243475ea2afd9699be042316842788ff5/triton-3.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b0dd10a925263abbe9fa37dcde67a5e9b2383fc269fdf59f5657cac38c5d1d8", size = 209460013 }, From 1f22a7b7a169f2f812185517f53437c43cb39fb9 Mon Sep 17 00:00:00 2001 From: zysoong Date: Fri, 17 Jan 2025 02:48:55 +0100 Subject: [PATCH 2/2] [Documentation] Update tools.ipynb: use system messages in the tool_agent_caller_loop session (#5068) * Update tools.ipynb: concat system messages in the tool_agent_caller_loop session * Fix type mismatch on list concatenation --------- Co-authored-by: Eric Zhu --- .../docs/src/user-guide/core-user-guide/framework/tools.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/tools.ipynb 
b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/tools.ipynb index 23fe4dffbb4e..2da5b2c7943c 100644 --- a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/tools.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/tools.ipynb @@ -182,7 +182,7 @@ " @message_handler\n", " async def handle_user_message(self, message: Message, ctx: MessageContext) -> Message:\n", " # Create a session of messages.\n", - " session: List[LLMMessage] = [UserMessage(content=message.content, source=\"user\")]\n", + " session: List[LLMMessage] = self._system_messages + [UserMessage(content=message.content, source=\"user\")]\n", " # Run the caller loop to handle tool calls.\n", " messages = await tool_agent_caller_loop(\n", " self,\n",