diff --git a/narrative_llm_tools/handlers/huggingface.py b/narrative_llm_tools/handlers/huggingface.py index addd158..9b1cf5c 100644 --- a/narrative_llm_tools/handlers/huggingface.py +++ b/narrative_llm_tools/handlers/huggingface.py @@ -24,8 +24,9 @@ class HandlerResponse(BaseModel): """Response from the handler.""" - tool_calls: list[dict[str, Any]] - warnings: list[str] | None + text_response: str | None = None + tool_calls: list[dict[str, Any]] | None = None + warnings: list[str] | None = None class ModelConfig(BaseModel): @@ -285,18 +286,25 @@ def __call__(self, data: dict[str, Any]) -> dict[str, Any]: }: self._process_conversation_turn(conversation_state) - return_msg = json.loads(conversation_state.get_last_message().content) + if conversation_state.tool_choice != "none": + return_msg = json.loads(conversation_state.get_last_message().content) - if not isinstance(return_msg, list): - raise ModelOutputError("Model output is not a list of tool calls.") - - for tool_call in return_msg: - if not isinstance(tool_call, dict): + if not isinstance(return_msg, list): raise ModelOutputError("Model output is not a list of tool calls.") - return HandlerResponse(tool_calls=return_msg, warnings=None).model_dump( - exclude_none=True - ) + for tool_call in return_msg: + if not isinstance(tool_call, dict): + raise ModelOutputError("Model output is not a list of tool calls.") + + return HandlerResponse( + tool_calls=return_msg, warnings=None, text_response=None + ).model_dump(exclude_none=True) + else: + return HandlerResponse( + tool_calls=None, + text_response=conversation_state.get_last_message().content, + warnings=None, + ).model_dump(exclude_none=True) except ( ValidationError, @@ -316,17 +324,42 @@ def __call__(self, data: dict[str, Any]) -> dict[str, Any]: def _process_conversation_turn(self, state: ConversationState) -> None: """Process a single turn of the conversation.""" conversation_text = self._format_conversation(state) - format_enforcer = get_format_enforcer(self.pipeline.tokenizer, state.update_current_tools()) + format_enforcer = self._get_format_enforcer(state) + model_output = self._generate_prediction( conversation_text, format_enforcer, state.pipeline_params ) - tool_calls = self._format_model_output(model_output) - serialized = [tool.model_dump() for tool in tool_calls] - state.add_message(ConversationMessage(role="tool_calls", content=json.dumps(serialized))) + formatted_output = self._format_model_output(model_output, state.tool_choice) + + if state.tool_choice != "none": + if not isinstance(formatted_output, list): + logger.warning("Expected list of tool calls but got different type") + return + + serialized = [tool.model_dump() for tool in formatted_output] + state.add_message( + ConversationMessage(role="tool_calls", content=json.dumps(serialized)) + ) - if state.only_called_rest_api_tools(tool_calls): - self._execute_tool_calls(tool_calls, state) + if state.only_called_rest_api_tools(formatted_output): + self._execute_tool_calls(formatted_output, state) + else: + if not isinstance(formatted_output, str): + logger.warning("Expected string response but got different type") + return + + state.add_message(ConversationMessage(role="assistant", content=formatted_output)) + + def _get_format_enforcer(self, state: ConversationState) -> FormatEnforcer | None: + """Get the format enforcer based on current tools state.""" + if not state.tools_catalog: + return None + + current_tools = state.update_current_tools() + return ( + get_format_enforcer(self.pipeline.tokenizer, current_tools) if current_tools else None + ) def _execute_tool_calls(self, tool_calls: list[Tool], state: ConversationState) -> None: """Execute tool calls and update conversation state.""" @@ -400,7 +433,11 @@ def _execute_tool_calls(self, tool_calls: list[Tool], state: ConversationState) if return_to_user and state.status != ConversationStatus.COMPLETED: state.transition_to(ConversationStatus.COMPLETED) - def _format_model_output(self, model_output: list[dict[str, Any]]) -> list[Tool]: + def _format_model_output( + self, + model_output: list[dict[str, Any]], + tool_choice: Literal["required", "none", "auto"], + ) -> list[Tool] | str: """Format the model output into a list of dictionaries.""" if not model_output: return [] @@ -413,6 +450,9 @@ def _format_model_output(self, model_output: list[dict[str, Any]]) -> list[Tool] if generated_text is None: raise ModelOutputError("No generated_text found in the model output.") + if tool_choice == "none": + return generated_text + try: logger.debug(f"Generated text: {generated_text}") parsed_output: list[Tool] = [ diff --git a/narrative_llm_tools/state/conversation_state.py b/narrative_llm_tools/state/conversation_state.py index 2f08583..798d940 100644 --- a/narrative_llm_tools/state/conversation_state.py +++ b/narrative_llm_tools/state/conversation_state.py @@ -82,7 +82,7 @@ class ConversationState(BaseModel): raw_messages: list[ConversationMessage] max_tool_rounds: int = 5 tool_choice: Literal["required", "auto", "none"] = "required" - tools_catalog: JsonSchemaTools = JsonSchemaTools.only_user_response_tool() + tools_catalog: JsonSchemaTools | None = JsonSchemaTools.only_user_response_tool() pipeline_params: dict[str, Any] status: ConversationStatus = ConversationStatus.RUNNING @@ -103,14 +103,14 @@ def from_api_request(cls, request_data: dict[str, Any]) -> "ConversationState": if k not in cls.RESERVED_KEYS and not k.startswith("_") } + tool_choice = request_data.get("tool_choice", "required") tools_data = request_data.get("tools", {}) tools_instance = ( JsonSchemaTools.model_validate(tools_data) - if tools_data - else JsonSchemaTools.only_user_response_tool() + if tools_data and tools_data != {} and tool_choice != "none" + else JsonSchemaTools.only_user_response_tool() if tool_choice != "none" else None ) - tool_choice = request_data.get("tool_choice", "required") status = ( ConversationStatus.WRAP_THINGS_UP if tool_choice == "none" @@ -188,7 +188,9 @@ def _has_non_rest_tool(self) -> bool: """ Internal helper to check if there's at least one non-REST API tool available. """ - return len(self.rest_api_names) != len(self.tools_catalog.items.anyOf) + return not self.tools_catalog or len(self.rest_api_names) != len( + self.tools_catalog.items.anyOf + ) def _has_rest_api_tools(self, content: str) -> bool: """Checks if the given content calls any REST API tools.""" @@ -235,13 +237,14 @@ def can_respond(self) -> bool: def get_rest_api_catalog(self) -> dict[str, RestApiClient]: """Returns all REST API tools from the current catalog.""" - return self.tools_catalog.get_rest_apis() + return self.tools_catalog.get_rest_apis() if self.tools_catalog else {} def remove_tool(self, tool_name: str) -> None: """ Removes the specified tool from the catalog if it exists. """ - self.tools_catalog = self.tools_catalog.remove_tool_by_name(tool_name) + if self.tools_catalog: + self.tools_catalog = self.tools_catalog.remove_tool_by_name(tool_name) @property def tool_calls_count(self) -> int: @@ -256,7 +259,11 @@ def _tool_catalog_message(self) -> ConversationMessage: """ return ConversationMessage( role="tool_catalog", - content=json.dumps(self.tools_catalog.model_dump(), separators=(",", ":")), + content=( + json.dumps(self.tools_catalog.model_dump(), separators=(",", ":")) + if self.tools_catalog + else "" + ), ) def add_message(self, message: ConversationMessage) -> None: @@ -269,12 +276,24 @@ def add_message(self, message: ConversationMessage) -> None: Raises: ValueError: If the message role is invalid or adding it violates state constraints. """ + logger.info(f"Adding message: {message}") + if message.role == "tool_calls": self._handle_tool_call(message) elif message.role == "tool_response": self._handle_tool_response(message) + elif message.role == "assistant": + self._handle_assistant_response(message) logger.info(f"Conversation state after adding message: {self}") + + def _handle_assistant_response(self, message: ConversationMessage) -> None: + """ + Handles adding an assistant response message and updating state accordingly. + """ + logger.info(f"Handling assistant response: {message}") + self.raw_messages.append(message) + self.transition_to(ConversationStatus.COMPLETED) def _handle_tool_call(self, message: ConversationMessage) -> None: """ @@ -348,16 +367,17 @@ def _remove_rest_api_tools(self) -> None: """ Removes all REST API tools from the catalog. """ - self.tools_catalog = self.tools_catalog.remove_rest_api_tools() + if self.tools_catalog: + self.tools_catalog = self.tools_catalog.remove_rest_api_tools() - def update_current_tools(self) -> JsonSchemaTools: + def update_current_tools(self) -> JsonSchemaTools | None: """ Returns the appropriate tool catalog for the current conversation state: - If status is WRAP_THINGS_UP, only return user-response tool. - If status is RUNNING but there's no way to respond, return a catalog that includes a user-response tool. Otherwise, return the current tools. """ - if len(self.tools_catalog.items.anyOf) == 0: + if self.tools_catalog and len(self.tools_catalog.items.anyOf) == 0: self.tools_catalog = JsonSchemaTools.only_user_response_tool() elif self.status == ConversationStatus.WRAP_THINGS_UP: logger.info( @@ -369,11 +389,13 @@ def update_current_tools(self) -> JsonSchemaTools: "After removing rest API tools, " "we have {len(self.tools_catalog.items.anyOf)} tools.", ) - if len(self.tools_catalog.items.anyOf) == 0: + if self.tools_catalog and len(self.tools_catalog.items.anyOf) == 0: self.tools_catalog = JsonSchemaTools.only_user_response_tool() elif self.status == ConversationStatus.RUNNING: if not self.can_respond(): - self.tools_catalog = self.tools_catalog.with_user_response_tool() + self.tools_catalog = ( + self.tools_catalog.with_user_response_tool() if self.tools_catalog else None + ) elif self.status in [ ConversationStatus.WAITING_TOOL_RESPONSE, ConversationStatus.COMPLETED, diff --git a/tests/handlers/test_huggingface.py b/tests/handlers/test_huggingface.py index 82f7a5c..3cdef99 100644 --- a/tests/handlers/test_huggingface.py +++ b/tests/handlers/test_huggingface.py @@ -222,7 +222,7 @@ def test_endpoint_handler_format_model_output(endpoint_handler: EndpointHandler) Test parsing valid JSON from the pipeline's model output. """ mock_output = [{"generated_text": '[{"name": "tool1", "parameters": {"p": 1}}]'}] - tools = endpoint_handler._format_model_output(mock_output) + tools = endpoint_handler._format_model_output(mock_output, "auto") assert len(tools) == 1 assert tools[0].name == "tool1" assert tools[0].parameters == {"p": 1} diff --git a/uv.lock b/uv.lock index fe280ec..ce6d23e 100644 --- a/uv.lock +++ b/uv.lock @@ -211,6 +211,17 @@ toml = [ { name = "tomli", marker = "python_full_version <= '3.11'" }, ] +[[package]] +name = "cssbeautifier" +version = "1.15.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "editorconfig" }, + { name = "jsbeautifier" }, + { name = "six" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e5/66/9bfd2d69fb4479d38439076132a620972939f7949015563dce5e61d29a8b/cssbeautifier-1.15.1.tar.gz", hash = "sha256:9f7064362aedd559c55eeecf6b6bed65e05f33488dcbe39044f0403c26e1c006", size = 25673 } + [[package]] name = "distlib" version = "0.3.9" @@ -220,6 +231,49 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/91/a1/cf2472db20f7ce4a6be1253a81cfdf85ad9c7885ffbed7047fb72c24cf87/distlib-0.3.9-py2.py3-none-any.whl", hash = "sha256:47f8c22fd27c27e25a65601af709b38e4f0a45ea4fc2e710f65755fa8caaaf87", size = 468973 }, ] +[[package]] +name = "djlint" +version = "1.36.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "colorama" }, + { name = "cssbeautifier" }, + { name = "jsbeautifier" }, + { name = "json5" }, + { name = "pathspec" }, + { name = "pyyaml" }, + { name = "regex" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, + { name = "tqdm" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/74/89/ecf5be9f5c59a0c53bcaa29671742c5e269cc7d0e2622e3f65f41df251bf/djlint-1.36.4.tar.gz", hash = "sha256:17254f218b46fe5a714b224c85074c099bcb74e3b2e1f15c2ddc2cf415a408a1", size = 47849 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/53/71/6a3ce2b49a62e635b85dce30ccf3eb3a18fe79275d45535325a55a63d3a3/djlint-1.36.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a2dfb60883ceb92465201bfd392291a7597c6752baede6fbb6f1980cac8d6c5c", size = 354135 }, + { url = "https://files.pythonhosted.org/packages/72/47/308412dc579e277c910774f41b380308d582862b16763425583e69e0fc14/djlint-1.36.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4bc6a1320c0030244b530ac200642f883d3daa451a115920ef3d56d08b644292", size = 328501 }, + { url = "https://files.pythonhosted.org/packages/9b/6f/428dc044d1e34363265b1301dc9b53253007acd858879d54b369d233aa96/djlint-1.36.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3164a048c7bb0baf042387b1e33f9bbbf99d90d1337bb4c3d66eb0f96f5400a1", size = 415849 }, + { url = "https://files.pythonhosted.org/packages/d6/13/0d488e551d73ddf369552fc6f4c7702ea683e4bc1305bcf5c1d198fbdace/djlint-1.36.4-cp310-cp310-win_amd64.whl", hash = "sha256:3196d5277da5934962d67ad6c33a948ba77a7b6eadf064648bef6ee5f216b03c", size = 360969 }, + { url = "https://files.pythonhosted.org/packages/04/68/18ecd1e4d54a523e1d077f01419d669116e5dede97f97f1eb8ddb918a872/djlint-1.36.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3d68da0ed10ee9ca1e32e225cbb8e9b98bf7e6f8b48a8e4836117b6605b88cc7", size = 344261 }, + { url = "https://files.pythonhosted.org/packages/1e/03/005cf5c66e57ca2d26249f8385bc64420b2a95fea81c5eb619c925199029/djlint-1.36.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c0478d5392247f1e6ee29220bbdbf7fb4e1bc0e7e83d291fda6fb926c1787ba7", size = 319580 }, + { url = "https://files.pythonhosted.org/packages/9f/88/aea3c81343a273a87362f30442abc13351dc8ada0b10e51daa285b4dddac/djlint-1.36.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:962f7b83aee166e499eff916d631c6dde7f1447d7610785a60ed2a75a5763483", size = 407070 }, + { url = "https://files.pythonhosted.org/packages/60/77/0f767ac0b72e9a664bb8c92b8940f21bc1b1e806e5bd727584d40a4ca551/djlint-1.36.4-cp311-cp311-win_amd64.whl", hash = "sha256:53cbc450aa425c832f09bc453b8a94a039d147b096740df54a3547fada77ed08", size = 360775 }, + { url = "https://files.pythonhosted.org/packages/53/f5/9ae02b875604755d4d00cebf96b218b0faa3198edc630f56a139581aed87/djlint-1.36.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:ff9faffd7d43ac20467493fa71d5355b5b330a00ade1c4d1e859022f4195223b", size = 354886 }, + { url = "https://files.pythonhosted.org/packages/97/51/284443ff2f2a278f61d4ae6ae55eaf820ad9f0fd386d781cdfe91f4de495/djlint-1.36.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:79489e262b5ac23a8dfb7ca37f1eea979674cfc2d2644f7061d95bea12c38f7e", size = 323237 }, + { url = "https://files.pythonhosted.org/packages/6d/5e/791f4c5571f3f168ad26fa3757af8f7a05c623fde1134a9c4de814ee33b7/djlint-1.36.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e58c5fa8c6477144a0be0a87273706a059e6dd0d6efae01146ae8c29cdfca675", size = 411719 }, + { url = "https://files.pythonhosted.org/packages/1f/11/894425add6f84deffcc6e373f2ce250f2f7b01aa58c7f230016ebe7a0085/djlint-1.36.4-cp312-cp312-win_amd64.whl", hash = "sha256:bb6903777bf3124f5efedcddf1f4716aef097a7ec4223fc0fa54b865829a6e08", size = 362076 }, + { url = "https://files.pythonhosted.org/packages/4b/67/f7aeea9be6fb3bd984487af8d0d80225a0b1e5f6f7126e3332d349fb13fe/djlint-1.36.4-py3-none-any.whl", hash = "sha256:e9699b8ac3057a6ed04fb90835b89bee954ed1959c01541ce4f8f729c938afdd", size = 52290 }, +] + +[[package]] +name = "editorconfig" +version = "0.17.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b4/29/785595a0d8b30ab8d2486559cfba1d46487b8dcbd99f74960b6b4cca92a4/editorconfig-0.17.0.tar.gz", hash = "sha256:8739052279699840065d3a9f5c125d7d5a98daeefe53b0e5274261d77cb49aa2", size = 13369 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/af/e5/8dba39ea24ca3de0e954e668107692f4dfc13a85300a531fa9a39e83fde4/EditorConfig-0.17.0-py3-none-any.whl", hash = "sha256:fe491719c5f65959ec00b167d07740e7ffec9a3f362038c72b289330b9991dfc", size = 16276 }, +] + [[package]] name = "exceptiongroup" version = "1.2.2" @@ -331,6 +385,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/31/b4/b9b800c45527aadd64d5b442f9b932b00648617eb5d63d2c7a6587b7cafc/jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980", size = 20256 }, ] +[[package]] +name = "jsbeautifier" +version = "1.15.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "editorconfig" }, + { name = "six" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/69/3e/dd37e1a7223247e3ef94714abf572415b89c4e121c4af48e9e4c392e2ca0/jsbeautifier-1.15.1.tar.gz", hash = "sha256:ebd733b560704c602d744eafc839db60a1ee9326e30a2a80c4adb8718adc1b24", size = 75606 } + +[[package]] +name = "json5" +version = "0.10.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/85/3d/bbe62f3d0c05a689c711cff57b2e3ac3d3e526380adb7c781989f075115c/json5-0.10.0.tar.gz", hash = "sha256:e66941c8f0a02026943c52c2eb34ebeb2a6f819a0be05920a6f5243cd30fd559", size = 48202 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/aa/42/797895b952b682c3dafe23b1834507ee7f02f4d6299b65aaa61425763278/json5-0.10.0-py3-none-any.whl", hash = "sha256:19b23410220a7271e8377f81ba8aacba2fdd56947fbb137ee5977cbe1f5e8dfa", size = 34049 }, +] + [[package]] name = "jsonschema" version = "4.23.0" @@ -494,6 +567,7 @@ dependencies = [ dev = [ { name = "accelerate" }, { name = "black" }, + { name = "djlint" }, { name = "mypy" }, { name = "pre-commit" }, { name = "pytest" }, @@ -525,6 +599,7 @@ requires-dist = [ dev = [ { name = "accelerate", specifier = ">=0.26.0" }, { name = "black", specifier = ">=24.2.0" }, + { name = "djlint", specifier = ">=1.36.4" }, { name = "mypy", specifier = ">=1.9.0" }, { name = "pre-commit", specifier = ">=4.0.1" }, { name = "pytest", specifier = ">=8.3.4" }, @@ -1144,6 +1219,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/19/46/5d11dc300feaad285c2f1bd784ff3f689f5e0ab6be49aaf568f3a77019eb/safetensors-0.4.5-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:21742b391b859e67b26c0b2ac37f52c9c0944a879a25ad2f9f9f3cd61e7fda8f", size = 606660 }, ] +[[package]] +name = "six" +version = "1.17.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/94/e7/b2c673351809dca68a0e064b6af791aa332cf192da575fd474ed7d6f16a2/six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81", size = 34031 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050 }, +] + [[package]] name = "sympy" version = "1.13.1"