diff --git a/aidial_assistant/application/prompts.py b/aidial_assistant/application/prompts.py index 1672d5b..9a49485 100644 --- a/aidial_assistant/application/prompts.py +++ b/aidial_assistant/application/prompts.py @@ -28,82 +28,88 @@ def build(self, **kwargs) -> Template: ) -_REQUEST_EXAMPLE_TEXT = """ -ALWAYS reply with a JSON containing an array of available commands. You must not use natural language: +_REQUEST_FORMAT_TEXT = """ +You should ALWAYS reply with a JSON containing an array of commands: { "commands": [ { "command": "", "args": [ - + // ] } ] } +The commands are invoked by the system on the user's behalf. +""".strip() + +_PROTOCOL_FOOTER = """ +* reply +The command delivers the final response to the user. +Arguments: + - MESSAGE is a string containing the final and complete result for the user. -Example: +Your goal is to answer user questions. Use relevant commands when they help to achieve the goal. + +## Example {"commands": [{"command": "reply", "args": ["Hello, world!"]}]} """.strip() _SYSTEM_TEXT = """ Today's date is {{today_date}}. +This message defines the following communication protocol. {%- if system_prefix %} {{system_prefix}} {%- endif %} -Protocol -The following commands are available to reply to user or find out the answer to the user's question: +# Protocol +{{request_format}} + +## Commands {%- if tools %} -> run-plugin -The command runs a specified plugin to solve a one-shot task written in natural language. -Plugins do not see current conversation and require all details to be provided in the query to solve the task. -The command returns the result of the plugin call. +* run-addon +This command executes a specified addon to address a one-time task described in natural language. +Addons do not see the current conversation and require all details to be provided in the query to solve the task. 
Arguments: - - NAME is one of the following plugins: + - NAME is one of the following addons: {%- for name, description in tools.items() %} * {{name}} - {{description | decap}} {%- endfor %} - - QUERY is a string formulating the query to the plugin. + - QUERY is the query string. {%- endif %} -> reply -The command delivers ultimate result to the user. -Arguments: - - MESSAGE is a string containing response for user. - -{{request_response}} +{{protocol_footer}} """.strip() -_PLUGIN_SYSTEM_TEXT = """ +_ADDON_SYSTEM_TEXT = """ Today's date is {{today_date}}. +This message defines the following communication protocol. -Service +# Service API_DESCRIPTION: {{api_description}} -API_SCHEMA: +# API Schema ```typescript {{api_schema}}} ``` -Protocol -The following commands are available to reply to user or find out the answer to the user's question: +# Protocol +{{request_format}} + +## Commands {%- for command_name in command_names %} -> {{command_name}} +* {{command_name}} Arguments: - - + - {%- endfor %} -> reply -The command delivers ultimate result to the user -Arguments: - - MESSAGE is a string containing response for user. - -{{request_response}} +{{protocol_footer}} """.strip() _ENFORCE_JSON_FORMAT_TEXT = """ {{response}} -**Remember to reply with a JSON with commands** + +**Protocol reminder: reply with commands** """.strip() _MAIN_BEST_EFFORT_TEXT = ( @@ -134,7 +140,7 @@ def build(self, **kwargs) -> Template: """ ).strip() -_PLUGIN_BEST_EFFORT_TEXT = ( +_ADDON_BEST_EFFORT_TEXT = ( """ You were allowed to use the following API to answer the query below. 
@@ -165,13 +171,19 @@ def build(self, **kwargs) -> Template: MAIN_SYSTEM_DIALOG_MESSAGE = PartialTemplate( _SYSTEM_TEXT, - globals={"request_response": _REQUEST_EXAMPLE_TEXT}, + globals={ + "request_format": _REQUEST_FORMAT_TEXT, + "protocol_footer": _PROTOCOL_FOOTER, + }, template_class=DateAwareTemplate, ) -PLUGIN_SYSTEM_DIALOG_MESSAGE = PartialTemplate( - _PLUGIN_SYSTEM_TEXT, - globals={"request_response": _REQUEST_EXAMPLE_TEXT}, +ADDON_SYSTEM_DIALOG_MESSAGE = PartialTemplate( + _ADDON_SYSTEM_TEXT, + globals={ + "request_format": _REQUEST_FORMAT_TEXT, + "protocol_footer": _PROTOCOL_FOOTER, + }, template_class=DateAwareTemplate, ) @@ -179,4 +191,4 @@ def build(self, **kwargs) -> Template: MAIN_BEST_EFFORT_TEMPLATE = PartialTemplate(_MAIN_BEST_EFFORT_TEXT) -PLUGIN_BEST_EFFORT_TEMPLATE = PartialTemplate(_PLUGIN_BEST_EFFORT_TEXT) +ADDON_BEST_EFFORT_TEMPLATE = PartialTemplate(_ADDON_BEST_EFFORT_TEXT) diff --git a/aidial_assistant/chain/command_chain.py b/aidial_assistant/chain/command_chain.py index 244d4d5..c5e54f9 100644 --- a/aidial_assistant/chain/command_chain.py +++ b/aidial_assistant/chain/command_chain.py @@ -150,7 +150,8 @@ async def _run_with_protocol_failure_retries( retry += 1 last_error = e retries.append( - chunk_stream.buffer, json.dumps({"error": str(e)}) + chunk_stream.buffer, + "Failed to parse JSON commands: " + str(e), ) finally: self._log_message(Role.ASSISTANT, chunk_stream.buffer) diff --git a/aidial_assistant/commands/run_plugin.py b/aidial_assistant/commands/run_plugin.py index 72f951f..08fe0e2 100644 --- a/aidial_assistant/commands/run_plugin.py +++ b/aidial_assistant/commands/run_plugin.py @@ -5,8 +5,8 @@ from typing_extensions import override from aidial_assistant.application.prompts import ( - PLUGIN_BEST_EFFORT_TEMPLATE, - PLUGIN_SYSTEM_DIALOG_MESSAGE, + ADDON_BEST_EFFORT_TEMPLATE, + ADDON_SYSTEM_DIALOG_MESSAGE, ) from aidial_assistant.chain.command_chain import ( CommandChain, @@ -50,7 +50,7 @@ def __init__( @staticmethod def token(): - 
return "run-plugin" + return "run-addon" @override async def execute( @@ -78,7 +78,7 @@ async def _run_plugin( ) -> ResultObject: if name not in self.plugins: raise ValueError( - f"Unknown plugin: {name}. Available plugins: {[*self.plugins.keys()]}" + f"Unknown addon: {name}. Available addons: {[*self.plugins.keys()]}" ) plugin = self.plugins[name] @@ -98,12 +98,12 @@ def create_command(op: APIOperation): command_dict[Reply.token()] = Reply history = History( - assistant_system_message_template=PLUGIN_SYSTEM_DIALOG_MESSAGE.build( + assistant_system_message_template=ADDON_SYSTEM_DIALOG_MESSAGE.build( command_names=ops.keys(), api_description=info.ai_plugin.description_for_model, api_schema=api_schema, ), - best_effort_template=PLUGIN_BEST_EFFORT_TEMPLATE.build( + best_effort_template=ADDON_BEST_EFFORT_TEMPLATE.build( api_schema=api_schema ), scoped_messages=[ScopedMessage(message=Message.user(query))], diff --git a/poetry.lock b/poetry.lock index 073694f..0d7f901 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2,13 +2,13 @@ [[package]] name = "aidial-sdk" -version = "0.1.0" +version = "0.2.0" description = "Framework to create applications and model adapters for AI DIAL" optional = false python-versions = ">=3.8.1,<4.0" files = [ - {file = "aidial_sdk-0.1.0-py3-none-any.whl", hash = "sha256:596c9e7aca688e56b1749fb70b0c97ebd508827b6a39cfe6035a3b860cf9f7af"}, - {file = "aidial_sdk-0.1.0.tar.gz", hash = "sha256:fe8fa9ea9d3ccd3f9e719daac08d8dd946f423cb4f2511d9ec43bcc747ef51ad"}, + {file = "aidial_sdk-0.2.0-py3-none-any.whl", hash = "sha256:ce3c2e2ea5ef133d2594bb64c7a70f54970e1f8339608ecfb47b0f955e1536e7"}, + {file = "aidial_sdk-0.2.0.tar.gz", hash = "sha256:fcb00ccfa6fbed7482d6d78828a95ba7e29f45269708cfc3691db9711b91f3fe"}, ] [package.dependencies] @@ -19,6 +19,9 @@ requests = ">=2.19,<3.0" uvicorn = ">=0.19,<1.0" wrapt = ">=1.14,<2.0" +[package.extras] +telemetry = ["opentelemetry-api (==1.20.0)", "opentelemetry-distro (==0.41b0)", 
"opentelemetry-exporter-otlp-proto-grpc (==1.20.0)", "opentelemetry-exporter-prometheus (==1.12.0rc1)", "opentelemetry-instrumentation (==0.41b0)", "opentelemetry-instrumentation-aiohttp-client (==0.41b0)", "opentelemetry-instrumentation-fastapi (==0.41b0)", "opentelemetry-instrumentation-logging (==0.41b0)", "opentelemetry-instrumentation-requests (==0.41b0)", "opentelemetry-instrumentation-system-metrics (==0.41b0)", "opentelemetry-instrumentation-urllib (==0.41b0)", "opentelemetry-sdk (==1.20.0)", "starlette-exporter (==0.16.0)"] + [[package]] name = "aiocache" version = "0.12.2" @@ -580,7 +583,7 @@ files = [ {file = "greenlet-3.0.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0b72b802496cccbd9b31acea72b6f87e7771ccfd7f7927437d592e5c92ed703c"}, {file = "greenlet-3.0.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:527cd90ba3d8d7ae7dceb06fda619895768a46a1b4e423bdb24c1969823b8362"}, {file = "greenlet-3.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:37f60b3a42d8b5499be910d1267b24355c495064f271cfe74bf28b17b099133c"}, - {file = "greenlet-3.0.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:1482fba7fbed96ea7842b5a7fc11d61727e8be75a077e603e8ab49d24e234383"}, + {file = "greenlet-3.0.0-cp311-universal2-macosx_10_9_universal2.whl", hash = "sha256:c3692ecf3fe754c8c0f2c95ff19626584459eab110eaab66413b1e7425cd84e9"}, {file = "greenlet-3.0.0-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:be557119bf467d37a8099d91fbf11b2de5eb1fd5fc5b91598407574848dc910f"}, {file = "greenlet-3.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:73b2f1922a39d5d59cc0e597987300df3396b148a9bd10b76a058a2f2772fc04"}, {file = "greenlet-3.0.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d1e22c22f7826096ad503e9bb681b05b8c1f5a8138469b255eb91f26a76634f2"}, @@ -590,6 +593,7 @@ files = [ {file = "greenlet-3.0.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = 
"sha256:952256c2bc5b4ee8df8dfc54fc4de330970bf5d79253c863fb5e6761f00dda35"}, {file = "greenlet-3.0.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:269d06fa0f9624455ce08ae0179430eea61085e3cf6457f05982b37fd2cefe17"}, {file = "greenlet-3.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:9adbd8ecf097e34ada8efde9b6fec4dd2a903b1e98037adf72d12993a1c80b51"}, + {file = "greenlet-3.0.0-cp312-universal2-macosx_10_9_universal2.whl", hash = "sha256:553d6fb2324e7f4f0899e5ad2c427a4579ed4873f42124beba763f16032959af"}, {file = "greenlet-3.0.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6b5ce7f40f0e2f8b88c28e6691ca6806814157ff05e794cdd161be928550f4c"}, {file = "greenlet-3.0.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ecf94aa539e97a8411b5ea52fc6ccd8371be9550c4041011a091eb8b3ca1d810"}, {file = "greenlet-3.0.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:80dcd3c938cbcac986c5c92779db8e8ce51a89a849c135172c88ecbdc8c056b7"}, @@ -1413,14 +1417,6 @@ files = [ {file = "SQLAlchemy-2.0.21-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b69f1f754d92eb1cc6b50938359dead36b96a1dcf11a8670bff65fd9b21a4b09"}, {file = "SQLAlchemy-2.0.21-cp311-cp311-win32.whl", hash = "sha256:af520a730d523eab77d754f5cf44cc7dd7ad2d54907adeb3233177eeb22f271b"}, {file = "SQLAlchemy-2.0.21-cp311-cp311-win_amd64.whl", hash = "sha256:141675dae56522126986fa4ca713739d00ed3a6f08f3c2eb92c39c6dfec463ce"}, - {file = "SQLAlchemy-2.0.21-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:56628ca27aa17b5890391ded4e385bf0480209726f198799b7e980c6bd473bd7"}, - {file = "SQLAlchemy-2.0.21-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:db726be58837fe5ac39859e0fa40baafe54c6d54c02aba1d47d25536170b690f"}, - {file = "SQLAlchemy-2.0.21-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e7421c1bfdbb7214313919472307be650bd45c4dc2fcb317d64d078993de045b"}, - {file = 
"SQLAlchemy-2.0.21-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:632784f7a6f12cfa0e84bf2a5003b07660addccf5563c132cd23b7cc1d7371a9"}, - {file = "SQLAlchemy-2.0.21-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:f6f7276cf26145a888f2182a98f204541b519d9ea358a65d82095d9c9e22f917"}, - {file = "SQLAlchemy-2.0.21-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2a1f7ffac934bc0ea717fa1596f938483fb8c402233f9b26679b4f7b38d6ab6e"}, - {file = "SQLAlchemy-2.0.21-cp312-cp312-win32.whl", hash = "sha256:bfece2f7cec502ec5f759bbc09ce711445372deeac3628f6fa1c16b7fb45b682"}, - {file = "SQLAlchemy-2.0.21-cp312-cp312-win_amd64.whl", hash = "sha256:526b869a0f4f000d8d8ee3409d0becca30ae73f494cbb48801da0129601f72c6"}, {file = "SQLAlchemy-2.0.21-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:7614f1eab4336df7dd6bee05bc974f2b02c38d3d0c78060c5faa4cd1ca2af3b8"}, {file = "SQLAlchemy-2.0.21-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d59cb9e20d79686aa473e0302e4a82882d7118744d30bb1dfb62d3c47141b3ec"}, {file = "SQLAlchemy-2.0.21-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a95aa0672e3065d43c8aa80080cdd5cc40fe92dc873749e6c1cf23914c4b83af"}, @@ -1782,4 +1778,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "78ba813fea2eebfe6b8a3408b00caf7c8497ca2a7910862a3ebdbe1406685956" +content-hash = "a13c588da4a63c1ae51b331f9dacdb495243077706a5b68b9ccb3520ad0176c4" diff --git a/pyproject.toml b/pyproject.toml index ab9ced4..6649e52 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ openai = "^0.28.0" pydantic = "1.10.13" pyyaml = "^6.0.1" typing-extensions = "^4.8.0" -aidial-sdk = "^0.1.0" +aidial-sdk = "^0.2.0" aiohttp = "^3.9.0" openapi-schema-pydantic = "^1.2.4" openapi-pydantic = "^0.3.2" diff --git a/tests/unit_tests/application/test_prompts.py b/tests/unit_tests/application/test_prompts.py index 44bca26..5600e6f 100644 --- 
a/tests/unit_tests/application/test_prompts.py +++ b/tests/unit_tests/application/test_prompts.py @@ -1,6 +1,6 @@ from aidial_assistant.application.prompts import ( + ADDON_BEST_EFFORT_TEMPLATE, MAIN_BEST_EFFORT_TEMPLATE, - PLUGIN_BEST_EFFORT_TEMPLATE, ) @@ -65,9 +65,7 @@ def test_main_best_effort_prompt_with_empty_dialogue(): def test_plugin_best_effort_prompt(): - actual = PLUGIN_BEST_EFFORT_TEMPLATE.build( - api_schema="" - ).render( + actual = ADDON_BEST_EFFORT_TEMPLATE.build(api_schema="").render( error="", message="", dialogue=[{"role": "", "content": ""}], @@ -99,9 +97,7 @@ def test_plugin_best_effort_prompt(): def test_plugin_best_effort_prompt_with_empty_dialogue(): - actual = PLUGIN_BEST_EFFORT_TEMPLATE.build( - api_schema="" - ).render( + actual = ADDON_BEST_EFFORT_TEMPLATE.build(api_schema="").render( error="", message="", dialogue=[], diff --git a/tests/unit_tests/chain/test_command_chain_best_effort.py b/tests/unit_tests/chain/test_command_chain_best_effort.py index 0656066..b90e284 100644 --- a/tests/unit_tests/chain/test_command_chain_best_effort.py +++ b/tests/unit_tests/chain/test_command_chain_best_effort.py @@ -20,7 +20,7 @@ SYSTEM_MESSAGE = "" USER_MESSAGE = "" -ENFORCE_JSON_FORMAT = "**Remember to reply with a JSON with commands**" +ENFORCE_JSON_FORMAT = "\n\n**Protocol reminder: reply with commands**" BEST_EFFORT_ANSWER = "" NO_TOKENS_ERROR = "No tokens left" FAILED_PROTOCOL_ERROR = "The next constructed API request is incorrect." 
@@ -78,7 +78,7 @@ async def test_model_doesnt_support_protocol(): call( [ Message.system(f"system_prefix={SYSTEM_MESSAGE}"), - Message.user(f"{USER_MESSAGE}\n{ENFORCE_JSON_FORMAT}"), + Message.user(f"{USER_MESSAGE}{ENFORCE_JSON_FORMAT}"), ], usage_publisher, ), @@ -132,7 +132,7 @@ async def test_model_partially_supports_protocol(): call( [ Message.system(f"system_prefix={SYSTEM_MESSAGE}"), - Message.user(f"{USER_MESSAGE}\n{ENFORCE_JSON_FORMAT}"), + Message.user(f"{USER_MESSAGE}{ENFORCE_JSON_FORMAT}"), ], usage_publisher, ), @@ -141,7 +141,7 @@ async def test_model_partially_supports_protocol(): Message.system(f"system_prefix={SYSTEM_MESSAGE}"), Message.user(USER_MESSAGE), Message.assistant(TEST_COMMAND_REQUEST), - Message.user(f"{TEST_COMMAND_RESPONSE}\n{ENFORCE_JSON_FORMAT}"), + Message.user(f"{TEST_COMMAND_RESPONSE}{ENFORCE_JSON_FORMAT}"), ], usage_publisher, ), @@ -191,7 +191,7 @@ async def test_no_tokens_for_tools(): call( [ Message.system(f"system_prefix={SYSTEM_MESSAGE}"), - Message.user(f"{USER_MESSAGE}\n{ENFORCE_JSON_FORMAT}"), + Message.user(f"{USER_MESSAGE}{ENFORCE_JSON_FORMAT}"), ], usage_publisher, ), @@ -200,7 +200,7 @@ async def test_no_tokens_for_tools(): Message.system(f"system_prefix={SYSTEM_MESSAGE}"), Message.user(USER_MESSAGE), Message.assistant(TEST_COMMAND_REQUEST), - Message.user(f"{TEST_COMMAND_RESPONSE}\n{ENFORCE_JSON_FORMAT}"), + Message.user(f"{TEST_COMMAND_RESPONSE}{ENFORCE_JSON_FORMAT}"), ], usage_publisher, ),