From ef9597f114a98d44595c6cbafdd2acbb631826c8 Mon Sep 17 00:00:00 2001
From: Alexey Klimov
Date: Tue, 5 Mar 2024 14:56:23 +0000
Subject: [PATCH] Return discarded messages as a list of message indices.

---
 .../application/assistant_application.py      |  41 ++++--
 aidial_assistant/chain/history.py             | 136 ++++++++++--------
 aidial_assistant/commands/run_plugin.py       |   4 +-
 aidial_assistant/model/model_client.py        |  41 ++++--
 aidial_assistant/utils/state.py               |  19 ++-
 poetry.lock                                   |  22 +--
 pyproject.toml                                |   2 +-
 .../chain/test_command_chain_best_effort.py   |   4 +-
 tests/unit_tests/chain/test_history.py        |  99 ++++---------
 tests/unit_tests/model/test_model_client.py   |   7 +-
 .../utils/test_exception_handler.py           |   8 +-
 tests/unit_tests/utils/test_state.py          |   8 ++
 12 files changed, 210 insertions(+), 181 deletions(-)

diff --git a/aidial_assistant/application/assistant_application.py b/aidial_assistant/application/assistant_application.py
index ed15f0c..bd53866 100644
--- a/aidial_assistant/application/assistant_application.py
+++ b/aidial_assistant/application/assistant_application.py
@@ -26,7 +26,7 @@
     CommandConstructor,
     CommandDict,
 )
-from aidial_assistant.chain.history import History
+from aidial_assistant.chain.history import History, ScopedMessage
 from aidial_assistant.commands.reply import Reply
 from aidial_assistant.commands.run_plugin import PluginInfo, RunPlugin
 from aidial_assistant.commands.run_tool import RunTool
@@ -109,6 +109,25 @@ def _construct_tool(name: str, description: str) -> ChatCompletionToolParam:
     )


+def _create_history(
+    messages: list[ScopedMessage], plugins: list[PluginInfo]
+) -> History:
+    plugin_descriptions = {
+        plugin.info.ai_plugin.name_for_model: plugin.info.open_api.info.description
+        or plugin.info.ai_plugin.description_for_human
+        for plugin in plugins
+    }
+    return History(
+        assistant_system_message_template=MAIN_SYSTEM_DIALOG_MESSAGE.build(
+            addons=plugin_descriptions
+        ),
+        best_effort_template=MAIN_BEST_EFFORT_TEMPLATE.build(
+            addons=plugin_descriptions
+        ),
+        scoped_messages=messages,
+    )
+
+
 class AssistantApplication(ChatCompletion):
     def __init__(
         self, config_dir: Path, tools_supporting_deployments: set[str]
@@ -204,6 +223,7 @@ def create_command(addon: PluginInfo):
             or addon.info.ai_plugin.description_for_human
             for addon in addons
         }
+        scoped_messages = parse_history(request.messages)
         history = History(
             assistant_system_message_template=MAIN_SYSTEM_DIALOG_MESSAGE.build(
                 addons=addon_descriptions
@@ -211,14 +231,17 @@ def create_command(addon: PluginInfo):
             best_effort_template=MAIN_BEST_EFFORT_TEMPLATE.build(
                 addons=addon_descriptions
             ),
-            scoped_messages=parse_history(request.messages),
+            scoped_messages=scoped_messages,
         )
-        discarded_messages: int | None = None
+        discarded_user_messages: set[int] | None = None
         if request.max_prompt_tokens is not None:
-            original_size = history.user_message_count
-            history = await history.truncate(request.max_prompt_tokens, model)
-            truncated_size = history.user_message_count
-            discarded_messages = original_size - truncated_size
+            history, discarded_messages = await history.truncate(
+                model, request.max_prompt_tokens
+            )
+            discarded_user_messages = set(
+                scoped_messages[index].user_index
+                for index in discarded_messages
+            )
         # TODO: else compare the history size to the max prompt tokens of the underlying model

         choice = response.create_single_choice()
@@ -243,8 +266,8 @@
             model.total_prompt_tokens, model.total_completion_tokens
         )

-        if discarded_messages is not None:
-            response.set_discarded_messages(discarded_messages)
+        if discarded_user_messages is not None:
+            response.set_discarded_messages(list(discarded_user_messages))

     @staticmethod
     async def _run_native_tools_chat(
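The hunk above relies on an index mapping that is easy to get wrong, so a small illustration may help (not part of the patch; the literal values are made up): `truncate` reports indices of *scoped* messages, while the API response must carry indices of the original *request* messages, and `user_index` bridges the two.

```python
# Toy version of the mapping in chat_completion() above, with hypothetical values.
# scoped_user_index[i] is the request-message index of scoped message i,
# mirroring ScopedMessage.user_index.
scoped_user_index = [0, 1, 1, 1, 2]  # one assistant turn expanded into 3 entries
discarded_scoped = [1, 2, 3]         # indices reported by History.truncate()

# The set comprehension collapses the whole addon exchange back onto
# request message 1, so the client sees a single discarded index.
discarded_user_messages = {scoped_user_index[i] for i in discarded_scoped}
assert sorted(discarded_user_messages) == [1]
```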
diff --git a/aidial_assistant/chain/history.py b/aidial_assistant/chain/history.py
index 6e8db05..f4e594c 100644
--- a/aidial_assistant/chain/history.py
+++ b/aidial_assistant/chain/history.py
@@ -1,7 +1,11 @@
 from enum import Enum
+from typing import Tuple, cast

 from jinja2 import Template
-from openai.types.chat import ChatCompletionMessageParam
+from openai.types.chat import (
+    ChatCompletionMessageParam,
+    ChatCompletionSystemMessageParam,
+)
 from pydantic import BaseModel

 from aidial_assistant.chain.command_result import (
@@ -26,6 +30,7 @@ class MessageScope(str, Enum):
 class ScopedMessage(BaseModel):
     scope: MessageScope = MessageScope.USER
     message: ChatCompletionMessageParam
+    user_index: int


 class History:
@@ -40,35 +45,32 @@ def __init__(
         )
         self.best_effort_template = best_effort_template
         self.scoped_messages = scoped_messages
-        self._user_message_count = sum(
-            1
-            for message in scoped_messages
-            if message.scope == MessageScope.USER
-        )

     def to_protocol_messages(self) -> list[ChatCompletionMessageParam]:
         messages: list[ChatCompletionMessageParam] = []
-        for index, scoped_message in enumerate(self.scoped_messages):
+        scoped_message_iterator = iter(self.scoped_messages)
+        if self._is_first_system_message():
+            message = cast(
+                ChatCompletionSystemMessageParam,
+                next(scoped_message_iterator).message,
+            )
+            messages.append(
+                system_message(
+                    self.assistant_system_message_template.render(
+                        system_prefix=message["content"]
+                    )
+                )
+            )
+        else:
+            messages.append(
+                system_message(self.assistant_system_message_template.render())
+            )
+
+        for scoped_message in scoped_message_iterator:
             message = scoped_message.message
             scope = scoped_message.scope

-            if index == 0:
-                if message["role"] == "system":
-                    messages.append(
-                        system_message(
-                            self.assistant_system_message_template.render(
-                                system_prefix=message["content"]
-                            )
-                        )
-                    )
-                else:
-                    messages.append(
-                        system_message(
-                            self.assistant_system_message_template.render()
-                        )
-                    )
-                messages.append(message)
-            elif scope == MessageScope.USER and message["role"] == "assistant":
+            if scope == MessageScope.USER and message["role"] == "assistant":
                 # Clients see replies in plain text, but the model should understand how to reply appropriately.
                 content = commands_to_text(
                     [
@@ -107,51 +109,59 @@ def to_best_effort_messages(
         return messages

     async def truncate(
-        self, max_prompt_tokens: int, model_client: ModelClient
-    ) -> "History":
-        discarded_messages = await model_client.get_discarded_messages(
-            self.to_protocol_messages(),
-            max_prompt_tokens,
+        self, model_client: ModelClient, max_prompt_tokens: int
+    ) -> Tuple["History", list[int]]:
+        discarded_messages = await self._get_discarded_messages(
+            model_client, max_prompt_tokens
         )
-        if discarded_messages > 0:
-            return History(
+        if not discarded_messages:
+            return self, []
+
+        discarded_messages_set = set(discarded_messages)
+        return (
+            History(
                 assistant_system_message_template=self.assistant_system_message_template,
                 best_effort_template=self.best_effort_template,
-                scoped_messages=self._skip_messages(discarded_messages),
-            )
-
-        return self
-
-    @property
-    def user_message_count(self) -> int:
-        return self._user_message_count
-
-    def _skip_messages(self, discarded_messages: int) -> list[ScopedMessage]:
-        messages: list[ScopedMessage] = []
-        current_message = self.scoped_messages[0]
-        message_iterator = iter(self.scoped_messages)
-        for _ in range(discarded_messages):
-            current_message = next(message_iterator)
-            while current_message.message["role"] == "system":
-                # System messages should be kept in the history
-                messages.append(current_message)
-                current_message = next(message_iterator)
+                scoped_messages=[
+                    scoped_message
+                    for index, scoped_message in enumerate(self.scoped_messages)
+                    if index not in discarded_messages_set
+                ],
+            ),
+            discarded_messages,
+        )

-        if current_message.scope == MessageScope.INTERNAL:
-            while current_message.scope == MessageScope.INTERNAL:
-                current_message = next(message_iterator)
+    async def _get_discarded_messages(
+        self, model_client: ModelClient, max_prompt_tokens: int
+    ) -> list[int]:
+        discarded_protocol_messages = await model_client.get_discarded_messages(
+            self.to_protocol_messages(),
+            max_prompt_tokens,
+        )

-            # Internal messages (i.e. addon requests/responses) are always followed by an assistant reply
-            assert (
-                current_message.message["role"] == "assistant"
-            ), "Internal messages must be followed by an assistant reply."
+        if discarded_protocol_messages:
+            discarded_protocol_messages.sort()
+            discarded_messages = (
+                discarded_protocol_messages
+                if self._is_first_system_message()
+                else [index - 1 for index in discarded_protocol_messages]
+            )
+            user_indices = set(
+                self.scoped_messages[index].user_index
+                for index in discarded_messages
+            )

-        remaining_messages = list(message_iterator)
-        assert (
-            len(remaining_messages) > 0
-        ), "No user messages left after history truncation."
+            return [
+                index
+                for index, scoped_message in enumerate(self.scoped_messages)
+                if scoped_message.user_index in user_indices
+            ]

-        messages += remaining_messages
+        return discarded_protocol_messages

-        return messages
+    def _is_first_system_message(self) -> bool:
+        return (
+            len(self.scoped_messages) > 0
+            and self.scoped_messages[0].message["role"] == "system"
+        )
diff --git a/aidial_assistant/commands/run_plugin.py b/aidial_assistant/commands/run_plugin.py
index 96f2913..02f157a 100644
--- a/aidial_assistant/commands/run_plugin.py
+++ b/aidial_assistant/commands/run_plugin.py
@@ -87,7 +87,9 @@ def create_command(op: APIOperation):
         best_effort_template=ADDON_BEST_EFFORT_TEMPLATE.build(
             api_schema=api_schema
         ),
-        scoped_messages=[ScopedMessage(message=user_message(query))],
+        scoped_messages=[
+            ScopedMessage(message=user_message(query), user_index=0)
+        ],
     )

     chat = CommandChain(
diff --git a/aidial_assistant/model/model_client.py b/aidial_assistant/model/model_client.py
index 83dfa8e..198d3f5 100644
--- a/aidial_assistant/model/model_client.py
+++ b/aidial_assistant/model/model_client.py
@@ -1,4 +1,5 @@
 from abc import ABC
+from itertools import islice
 from typing import Any, AsyncIterator, List

 from aidial_sdk.utils.merge_chunks import merge
@@ -16,7 +17,7 @@ class ReasonLengthException(Exception):


 class ExtraResultsCallback:
-    def on_discarded_messages(self, discarded_messages: int):
+    def on_discarded_messages(self, discarded_messages: list[int]):
         pass

     def on_prompt_tokens(self, prompt_tokens: int):
@@ -36,6 +37,21 @@ async def _flush_stream(stream: AsyncIterator[str]):
         pass


+def _discarded_messages_count_to_indices(
+    messages: list[ChatCompletionMessageParam], discarded_messages: int
+) -> list[int]:
+    return list(
+        islice(
+            (
+                i
+                for i, message in enumerate(messages)
+                if message["role"] != "system"
+            ),
+            discarded_messages,
+        )
+    )
+
+
 class ModelClient(ABC):
     def __init__(self, client: AsyncOpenAI, model_args: dict[str, Any]):
         self.client = client
@@ -70,12 +86,16 @@ async def agenerate(
                     extra_results_callback.on_prompt_tokens(prompt_tokens)

                 if extra_results_callback:
-                    discarded_messages: int | None = chunk_dict.get(
+                    discarded_messages: int | list[int] | None = chunk_dict.get(
                         "statistics", {}
                     ).get("discarded_messages")
                     if discarded_messages is not None:
                         extra_results_callback.on_discarded_messages(
-                            discarded_messages
+                            _discarded_messages_count_to_indices(
+                                messages, discarded_messages
+                            )
+                            if isinstance(discarded_messages, int)
+                            else discarded_messages
                         )

                 choice = chunk.choices[0]
@@ -128,15 +148,16 @@ def on_prompt_tokens(self, prompt_tokens: int):
         return callback.token_count

     # TODO: Use a dedicated endpoint for discarded_messages.
+    # https://github.com/epam/ai-dial-assistant/issues/39
     async def get_discarded_messages(
         self, messages: list[ChatCompletionMessageParam], max_prompt_tokens: int
-    ) -> int:
+    ) -> list[int]:
         class DiscardedMessagesCallback(ExtraResultsCallback):
             def __init__(self):
-                self.message_count: int | None = None
+                self.discarded_messages: list[int] | None = None

-            def on_discarded_messages(self, discarded_messages: int):
-                self.message_count = discarded_messages
+            def on_discarded_messages(self, discarded_messages: list[int]):
+                self.discarded_messages = discarded_messages

         callback = DiscardedMessagesCallback()
         await _flush_stream(
@@ -147,10 +168,10 @@ def on_discarded_messages(self, discarded_messages: int):
                 max_tokens=1,
             )
         )
-        if callback.message_count is None:
-            raise Exception("No message count received.")
+        if callback.discarded_messages is None:
+            raise Exception("Discarded messages were not provided.")

-        return callback.message_count
+        return callback.discarded_messages

     @property
     def total_prompt_tokens(self) -> int:
diff --git a/aidial_assistant/utils/state.py b/aidial_assistant/utils/state.py
index f8c9dd5..91a4581 100644
--- a/aidial_assistant/utils/state.py
+++ b/aidial_assistant/utils/state.py
@@ -70,7 +70,7 @@ def _convert_old_commands(string: str) -> str:

 def parse_history(history: list[Message]) -> list[ScopedMessage]:
     messages: list[ScopedMessage] = []
-    for message in history:
+    for index, message in enumerate(history):
         if message.role == Role.ASSISTANT:
             invocations = _get_invocations(message.custom_content)
             for invocation in invocations:
@@ -80,25 +80,36 @@ def parse_history(history: list[Message]) -> list[ScopedMessage]:
                         message=assistant_message(
                             _convert_old_commands(invocation["request"])
                         ),
+                        user_index=index,
                     )
                 )
                 messages.append(
                     ScopedMessage(
                         scope=MessageScope.INTERNAL,
                         message=user_message(invocation["response"]),
+                        user_index=index,
                     )
                 )
             messages.append(
-                ScopedMessage(message=assistant_message(message.content or ""))
+                ScopedMessage(
+                    message=assistant_message(message.content or ""),
+                    user_index=index,
+                )
             )
         elif message.role == Role.USER:
             messages.append(
-                ScopedMessage(message=user_message(message.content or ""))
+                ScopedMessage(
+                    message=user_message(message.content or ""),
+                    user_index=index,
+                )
             )
         elif message.role == Role.SYSTEM:
             messages.append(
-                ScopedMessage(message=system_message(message.content or ""))
+                ScopedMessage(
+                    message=system_message(message.content or ""),
+                    user_index=index,
+                )
             )
         else:
             raise RequestParameterValidationError(
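To make the numbering above concrete, here is a self-contained toy analogue of `parse_history` (not code from the patch): every scoped message produced from request message *i* carries `user_index=i`, so one assistant turn with two addon invocations expands into five scoped messages sharing a single index.

```python
from dataclasses import dataclass

@dataclass
class Msg:
    role: str
    invocations: int = 0  # addon request/response pairs on an assistant turn

def toy_user_indices(history: list[Msg]) -> list[int]:
    """Return the user_index of each scoped message parse_history would emit."""
    indices: list[int] = []
    for i, msg in enumerate(history):
        if msg.role == "assistant":
            indices += [i, i] * msg.invocations  # request + response per invocation
        indices.append(i)  # the visible message itself
    return indices

# Matches the expectations in test_parse_history below: [0, 1, 1, 1, 1, 1, 2, 3].
assert toy_user_indices(
    [Msg("user"), Msg("assistant", invocations=2), Msg("user"), Msg("assistant")]
) == [0, 1, 1, 1, 1, 1, 2, 3]
```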
diff --git a/poetry.lock b/poetry.lock
index dfcbe5d..fdcb845 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,14 +1,14 @@
-# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.

 [[package]]
 name = "aidial-sdk"
-version = "0.6.2"
+version = "0.7.0"
 description = "Framework to create applications and model adapters for AI DIAL"
 optional = false
 python-versions = ">=3.8.1,<4.0"
 files = [
-    {file = "aidial_sdk-0.6.2-py3-none-any.whl", hash = "sha256:fa1cc43f1f8f70047e81adc5fae9914ddf6c94e4d7f55b83ba7ecca3cea5d122"},
-    {file = "aidial_sdk-0.6.2.tar.gz", hash = "sha256:46dafb6360cad6cea08531d3cea7600d87cda06cd8c86a560330b61d0a492cab"},
+    {file = "aidial_sdk-0.7.0-py3-none-any.whl", hash = "sha256:e22a948011f6ed55d7f7eef4c0f589f26d6e2412a6b55072be5fd37e8adc5752"},
+    {file = "aidial_sdk-0.7.0.tar.gz", hash = "sha256:a239af55a29742c18446df8a8a29ced8fedd2deebc8cf351d565fcbf8299c295"},
 ]

 [package.dependencies]
@@ -666,7 +666,7 @@ files = [
     {file = "greenlet-3.0.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0b72b802496cccbd9b31acea72b6f87e7771ccfd7f7927437d592e5c92ed703c"},
     {file = "greenlet-3.0.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:527cd90ba3d8d7ae7dceb06fda619895768a46a1b4e423bdb24c1969823b8362"},
     {file = "greenlet-3.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:37f60b3a42d8b5499be910d1267b24355c495064f271cfe74bf28b17b099133c"},
-    {file = "greenlet-3.0.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:1482fba7fbed96ea7842b5a7fc11d61727e8be75a077e603e8ab49d24e234383"},
+    {file = "greenlet-3.0.0-cp311-universal2-macosx_10_9_universal2.whl", hash = "sha256:c3692ecf3fe754c8c0f2c95ff19626584459eab110eaab66413b1e7425cd84e9"},
     {file = "greenlet-3.0.0-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:be557119bf467d37a8099d91fbf11b2de5eb1fd5fc5b91598407574848dc910f"},
     {file = "greenlet-3.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:73b2f1922a39d5d59cc0e597987300df3396b148a9bd10b76a058a2f2772fc04"},
     {file = "greenlet-3.0.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d1e22c22f7826096ad503e9bb681b05b8c1f5a8138469b255eb91f26a76634f2"},
@@ -676,6 +676,7 @@ files = [
     {file = "greenlet-3.0.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:952256c2bc5b4ee8df8dfc54fc4de330970bf5d79253c863fb5e6761f00dda35"},
     {file = "greenlet-3.0.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:269d06fa0f9624455ce08ae0179430eea61085e3cf6457f05982b37fd2cefe17"},
     {file = "greenlet-3.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:9adbd8ecf097e34ada8efde9b6fec4dd2a903b1e98037adf72d12993a1c80b51"},
+    {file = "greenlet-3.0.0-cp312-universal2-macosx_10_9_universal2.whl", hash = "sha256:553d6fb2324e7f4f0899e5ad2c427a4579ed4873f42124beba763f16032959af"},
     {file = "greenlet-3.0.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6b5ce7f40f0e2f8b88c28e6691ca6806814157ff05e794cdd161be928550f4c"},
     {file = "greenlet-3.0.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ecf94aa539e97a8411b5ea52fc6ccd8371be9550c4041011a091eb8b3ca1d810"},
     {file = "greenlet-3.0.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:80dcd3c938cbcac986c5c92779db8e8ce51a89a849c135172c88ecbdc8c056b7"},
@@ -1959,7 +1960,6 @@ files = [
     {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
     {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
-    {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
"PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -2065,14 +2065,6 @@ files = [ {file = "SQLAlchemy-2.0.21-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b69f1f754d92eb1cc6b50938359dead36b96a1dcf11a8670bff65fd9b21a4b09"}, {file = "SQLAlchemy-2.0.21-cp311-cp311-win32.whl", hash = "sha256:af520a730d523eab77d754f5cf44cc7dd7ad2d54907adeb3233177eeb22f271b"}, {file = "SQLAlchemy-2.0.21-cp311-cp311-win_amd64.whl", hash = "sha256:141675dae56522126986fa4ca713739d00ed3a6f08f3c2eb92c39c6dfec463ce"}, - {file = "SQLAlchemy-2.0.21-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:56628ca27aa17b5890391ded4e385bf0480209726f198799b7e980c6bd473bd7"}, - {file = "SQLAlchemy-2.0.21-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:db726be58837fe5ac39859e0fa40baafe54c6d54c02aba1d47d25536170b690f"}, - {file = "SQLAlchemy-2.0.21-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e7421c1bfdbb7214313919472307be650bd45c4dc2fcb317d64d078993de045b"}, - {file = "SQLAlchemy-2.0.21-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:632784f7a6f12cfa0e84bf2a5003b07660addccf5563c132cd23b7cc1d7371a9"}, - {file = "SQLAlchemy-2.0.21-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:f6f7276cf26145a888f2182a98f204541b519d9ea358a65d82095d9c9e22f917"}, - {file = "SQLAlchemy-2.0.21-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2a1f7ffac934bc0ea717fa1596f938483fb8c402233f9b26679b4f7b38d6ab6e"}, - {file = "SQLAlchemy-2.0.21-cp312-cp312-win32.whl", hash = "sha256:bfece2f7cec502ec5f759bbc09ce711445372deeac3628f6fa1c16b7fb45b682"}, - {file = "SQLAlchemy-2.0.21-cp312-cp312-win_amd64.whl", hash = "sha256:526b869a0f4f000d8d8ee3409d0becca30ae73f494cbb48801da0129601f72c6"}, {file = "SQLAlchemy-2.0.21-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:7614f1eab4336df7dd6bee05bc974f2b02c38d3d0c78060c5faa4cd1ca2af3b8"}, {file = "SQLAlchemy-2.0.21-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d59cb9e20d79686aa473e0302e4a82882d7118744d30bb1dfb62d3c47141b3ec"}, {file = "SQLAlchemy-2.0.21-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a95aa0672e3065d43c8aa80080cdd5cc40fe92dc873749e6c1cf23914c4b83af"}, @@ -2449,4 +2441,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "7820f937addbd7f17fcbfe89f84835c082dc3b0f185de1436df3315084e5811e" +content-hash = "7aca6482f0432e0e972ca0224a900ede7e49d582fcd720c08b0e0a8b65bac214" diff --git a/pyproject.toml b/pyproject.toml index 15bf4f0..496f36c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ openai = "^1.3.9" pydantic = "1.10.13" pyyaml = "^6.0.1" typing-extensions = "^4.8.0" -aidial-sdk = { version = "^0.6.2", extras = ["telemetry"] } +aidial-sdk = { version = "^0.7.0", extras = ["telemetry"] } aiohttp = "^3.9.2" openapi-schema-pydantic = "^1.2.4" openapi-pydantic = "^0.3.2" diff --git 
diff --git a/tests/unit_tests/chain/test_command_chain_best_effort.py b/tests/unit_tests/chain/test_command_chain_best_effort.py
index f17fd4e..f0bced2 100644
--- a/tests/unit_tests/chain/test_command_chain_best_effort.py
+++ b/tests/unit_tests/chain/test_command_chain_best_effort.py
@@ -50,8 +50,8 @@
         "user_message={{message}}, error={{error}}, dialogue={{dialogue}}"
     ),
     scoped_messages=[
-        ScopedMessage(message=system_message(SYSTEM_MESSAGE)),
-        ScopedMessage(message=user_message(USER_MESSAGE)),
+        ScopedMessage(message=system_message(SYSTEM_MESSAGE), user_index=0),
+        ScopedMessage(message=user_message(USER_MESSAGE), user_index=1),
     ],
 )

diff --git a/tests/unit_tests/chain/test_history.py b/tests/unit_tests/chain/test_history.py
index c3e6317..7d4186e 100644
--- a/tests/unit_tests/chain/test_history.py
+++ b/tests/unit_tests/chain/test_history.py
@@ -12,11 +12,11 @@
 )

 TRUNCATION_TEST_DATA = [
-    (0, [0, 1, 2, 3, 4, 5, 6]),
-    (1, [0, 2, 3, 4, 5, 6]),
-    (2, [0, 2, 6]),
-    (3, [0, 2, 6]),
-    (4, [0, 2, 6]),
+    ([], [0, 1, 2, 3, 4, 5, 6]),
+    ([1], [0, 2, 3, 4, 5, 6]),
+    ([1, 3], [0, 2, 6]),
+    ([1, 3, 4], [0, 2, 6]),
+    ([1, 3, 4, 5], [0, 2, 6]),
 ]

 MAX_PROMPT_TOKENS = 123
@@ -24,92 +24,51 @@

 @pytest.mark.asyncio
 @pytest.mark.parametrize(
-    "discarded_messages,expected_indices", TRUNCATION_TEST_DATA
+    "discarded_model_messages,expected_indices", TRUNCATION_TEST_DATA
 )
 async def test_history_truncation(
-    discarded_messages: int, expected_indices: list[int]
+    discarded_model_messages: list[int], expected_indices: list[int]
 ):
-    history = History(
+    full_history = History(
         assistant_system_message_template=Template(""),
         best_effort_template=Template(""),
         scoped_messages=[
-            ScopedMessage(message=system_message("a")),
-            ScopedMessage(message=user_message("b")),
-            ScopedMessage(message=system_message("c")),
+            ScopedMessage(message=system_message("a"), user_index=0),
+            ScopedMessage(message=user_message("b"), user_index=1),
+            ScopedMessage(message=system_message("c"), user_index=2),
             ScopedMessage(
                 message=assistant_message("d"),
                 scope=MessageScope.INTERNAL,
+                user_index=3,
             ),
             ScopedMessage(
                 message=user_message(content="e"),
                 scope=MessageScope.INTERNAL,
+                user_index=3,
             ),
-            ScopedMessage(message=assistant_message("f")),
-            ScopedMessage(message=user_message("g")),
+            ScopedMessage(message=assistant_message("f"), user_index=3),
+            ScopedMessage(message=user_message("g"), user_index=4),
         ],
     )

     model_client = Mock(spec=ModelClient)
-    model_client.get_discarded_messages.return_value = discarded_messages
-
-    actual = await history.truncate(MAX_PROMPT_TOKENS, model_client)
+    model_client.get_discarded_messages.return_value = discarded_model_messages

-    assert (
-        actual.assistant_system_message_template
-        == history.assistant_system_message_template
-    )
-    assert actual.best_effort_template == history.best_effort_template
-    assert actual.scoped_messages == [
-        history.scoped_messages[i] for i in expected_indices
-    ]
-
-
-@pytest.mark.asyncio
-async def test_truncation_overflow():
-    history = History(
-        assistant_system_message_template=Template(""),
-        best_effort_template=Template(""),
-        scoped_messages=[
-            ScopedMessage(message=system_message("a")),
-            ScopedMessage(message=user_message("b")),
-        ],
+    truncated_history, _ = await full_history.truncate(
+        model_client, MAX_PROMPT_TOKENS
     )

-    model_client = Mock(spec=ModelClient)
-    model_client.get_discarded_messages.return_value = 1
-
-    with pytest.raises(Exception) as exc_info:
-        await history.truncate(MAX_PROMPT_TOKENS, model_client)
-
     assert (
-        str(exc_info.value) == "No user messages left after history truncation."
+        truncated_history.assistant_system_message_template
+        == full_history.assistant_system_message_template
     )
-
-
-@pytest.mark.asyncio
-async def test_truncation_with_incorrect_message_sequence():
-    history = History(
-        assistant_system_message_template=Template(""),
-        best_effort_template=Template(""),
-        scoped_messages=[
-            ScopedMessage(
-                message=user_message("a"),
-                scope=MessageScope.INTERNAL,
-            ),
-            ScopedMessage(message=user_message("b")),
-        ],
-    )
-
-    model_client = Mock(spec=ModelClient)
-    model_client.get_discarded_messages.return_value = 1
-
-    with pytest.raises(Exception) as exc_info:
-        await history.truncate(MAX_PROMPT_TOKENS, model_client)
-
     assert (
-        str(exc_info.value)
-        == "Internal messages must be followed by an assistant reply."
+        truncated_history.best_effort_template
+        == full_history.best_effort_template
     )
+    assert truncated_history.scoped_messages == [
+        full_history.scoped_messages[i] for i in expected_indices
+    ]


 def test_protocol_messages_with_system_message():
@@ -122,9 +81,11 @@
         ),
         best_effort_template=Template(""),
         scoped_messages=[
-            ScopedMessage(message=system_message(system_content)),
-            ScopedMessage(message=user_message(user_content)),
-            ScopedMessage(message=assistant_message(assistant_content)),
+            ScopedMessage(message=system_message(system_content), user_index=0),
+            ScopedMessage(message=user_message(user_content), user_index=1),
+            ScopedMessage(
+                message=assistant_message(assistant_content), user_index=2
+            ),
         ],
     )

diff --git a/tests/unit_tests/model/test_model_client.py b/tests/unit_tests/model/test_model_client.py
index a5ed1cf..f23b526 100644
--- a/tests/unit_tests/model/test_model_client.py
+++ b/tests/unit_tests/model/test_model_client.py
@@ -1,3 +1,4 @@
+from typing import Any
 from unittest.mock import Mock, call

 import pytest
@@ -34,7 +35,7 @@ class Choice(BaseModel):

 class Chunk(BaseModel):
     choices: list[Choice]
-    statistics: dict[str, int] | None = None
+    statistics: dict[str, Any] | None = None
     usage: Usage | None = None


@@ -46,7 +47,7 @@ async def test_discarded_messages():
         [
             Chunk(
                 choices=[Choice(delta=Delta(content=""))],
-                statistics={"discarded_messages": 2},
+                statistics={"discarded_messages": [0, 1]},
             )
         ]
     )
@@ -56,7 +57,7 @@ async def test_discarded_messages():
     await join_string(model_client.agenerate([], extra_results_callback))

     assert extra_results_callback.on_discarded_messages.call_args_list == [
-        call(2)
+        call([0, 1])
     ]

diff --git a/tests/unit_tests/utils/test_exception_handler.py b/tests/unit_tests/utils/test_exception_handler.py
index a9bfc31..84a3402 100644
--- a/tests/unit_tests/utils/test_exception_handler.py
+++ b/tests/unit_tests/utils/test_exception_handler.py
@@ -24,7 +24,7 @@ async def function():
     assert (
         repr(exc_info.value)
         == f"HTTPException(message='{ERROR_MESSAGE}', status_code=422,"
-        f" type='invalid_request_error', param='{PARAM}', code=None)"
+        f" type='invalid_request_error', param='{PARAM}', code=None, display_message=None)"
     )


@@ -40,7 +40,7 @@ async def function():
     assert (
         repr(exc_info.value)
         == f"HTTPException(message='{ERROR_MESSAGE}', status_code=500,"
-        f" type='internal_server_error', param=None, code=None)"
+        f" type='internal_server_error', param=None, code=None, display_message=None)"
     )


@@ -72,7 +72,7 @@ async def function():
     assert (
         repr(exc_info.value)
         == f"HTTPException(message='{ERROR_MESSAGE}', status_code={http_status},"
-        f" type='{error_type}', param='{PARAM}', code='{error_code}')"
+        f" type='{error_type}', param='{PARAM}', code='{error_code}', display_message=None)"
     )


@@ -88,5 +88,5 @@ async def function():
     assert (
         repr(exc_info.value)
         == f"HTTPException(message='{ERROR_MESSAGE}', status_code=500,"
-        f" type='internal_server_error', param=None, code=None)"
+        f" type='internal_server_error', param=None, code=None, display_message=None)"
     )
diff --git a/tests/unit_tests/utils/test_state.py b/tests/unit_tests/utils/test_state.py
index 58569e8..d0e5a11 100644
--- a/tests/unit_tests/utils/test_state.py
+++ b/tests/unit_tests/utils/test_state.py
@@ -46,33 +46,41 @@ def test_parse_history():
         ScopedMessage(
             scope=MessageScope.USER,
             message=user_message(FIRST_USER_MESSAGE),
+            user_index=0,
         ),
         ScopedMessage(
             scope=MessageScope.INTERNAL,
             message=assistant_message(FIRST_REQUEST_FIXED),
+            user_index=1,
         ),
         ScopedMessage(
             scope=MessageScope.INTERNAL,
             message=user_message(FIRST_RESPONSE),
+            user_index=1,
         ),
         ScopedMessage(
             scope=MessageScope.INTERNAL,
             message=assistant_message(SECOND_REQUEST),
+            user_index=1,
        ),
         ScopedMessage(
             scope=MessageScope.INTERNAL,
             message=user_message(content=SECOND_RESPONSE),
+            user_index=1,
         ),
         ScopedMessage(
             scope=MessageScope.USER,
             message=assistant_message(FIRST_ASSISTANT_MESSAGE),
+            user_index=1,
         ),
         ScopedMessage(
             scope=MessageScope.USER,
             message=user_message(SECOND_USER_MESSAGE),
+            user_index=2,
         ),
         ScopedMessage(
             scope=MessageScope.USER,
             message=assistant_message(SECOND_ASSISTANT_MESSAGE),
+            user_index=3,
         ),
     ]
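For reference, the count-to-indices conversion added in `model_client.py` can be restated standalone (a minimal sketch with hypothetical plain-dict messages; the real code receives `ChatCompletionMessageParam` values):

```python
from itertools import islice

def count_to_indices(messages: list[dict], discarded_count: int) -> list[int]:
    # A legacy integer count refers to the first N non-system messages,
    # since the model never discards system messages.
    non_system = (i for i, m in enumerate(messages) if m["role"] != "system")
    return list(islice(non_system, discarded_count))

messages = [
    {"role": "system"},
    {"role": "user"},
    {"role": "assistant"},
    {"role": "user"},
]
assert count_to_indices(messages, 2) == [1, 2]  # the system prompt at 0 is skipped
```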