diff --git a/llmclient/llms.py b/llmclient/llms.py
index 8dd5725..aabae44 100644
--- a/llmclient/llms.py
+++ b/llmclient/llms.py
@@ -557,7 +557,7 @@ async def acompletion(self, messages: list[Message], **kwargs) -> list[LLMResult
         # cast is necessary for LiteLLM typing bug: https://github.com/BerriAI/litellm/issues/7641
         prompts = cast(
             list[litellm.types.llms.openai.AllMessageValues],
-            [m.model_dump(by_alias=True) for m in messages if m.content],
+            [m.model_dump(by_alias=True) for m in messages],
         )
         completions = await track_costs(self.router.acompletion)(
             self.name, prompts, **kwargs
diff --git a/tests/test_llms.py b/tests/test_llms.py
index c058798..9ae82fa 100644
--- a/tests/test_llms.py
+++ b/tests/test_llms.py
@@ -7,7 +7,7 @@
 import litellm
 import numpy as np
 import pytest
-from aviary.core import Message, Tool, ToolRequestMessage
+from aviary.core import Message, Tool, ToolRequestMessage, ToolResponseMessage
 from pydantic import BaseModel, Field, TypeAdapter, computed_field
 
 from llmclient.exceptions import JSONSchemaValidationError
@@ -532,6 +532,65 @@ async def test_multiple_completion(self, model_name: str, request) -> None:
         assert len(results) == self.NUM_COMPLETIONS
 
 
+class TestTooling:
+    @pytest.mark.asyncio
+    # @pytest.mark.vcr
+    async def test_tool_selection(self) -> None:
+        model = LiteLLMModel(name=CommonLLMNames.OPENAI_TEST.value, config={"n": 1})
+
+        def double(x: int) -> int:
+            """Double the input.
+
+            Args:
+                x: The input to double
+            Returns:
+                The double of the input.
+            """
+            return 2 * x
+
+        tools = [Tool.from_function(double)]
+        messages = [
+            Message(
+                role="system",
+                content="You are a helpful assistant who can use tools to do math. Use a tool if needed. If you don't need a tool, just respond with the answer.",
+            ),
+            Message(role="user", content="What is double of 8?"),
+        ]
+
+        results = await model.call(
+            messages, tools=tools, tool_choice=LiteLLMModel.MODEL_CHOOSES_TOOL
+        )
+        assert isinstance(results, list)
+        assert isinstance(results[0].messages, list)
+
+        tool_message = results[0].messages[0]
+
+        assert isinstance(
+            tool_message, ToolRequestMessage
+        ), "It should have selected a tool"
+        assert not tool_message.content
+        assert (
+            tool_message.tool_calls[0].function.arguments["x"] == 8
+        ), "LLM failed to select the correct tool or arguments"
+
+        # Simulate the observation
+        observation = ToolResponseMessage(
+            role="tool",
+            name="double",
+            content="Observation: 16",
+            tool_call_id=tool_message.tool_calls[0].id,
+        )
+        messages.extend([tool_message, observation])
+
+        results = await model.call(
+            messages, tools=tools, tool_choice=LiteLLMModel.MODEL_CHOOSES_TOOL
+        )
+        assert isinstance(results, list)
+        assert isinstance(results[0].messages, list)
+        assert results[0].messages[0].content
+        assert "16" in results[0].messages[0].content
+
+
 def test_json_schema_validation() -> None:
     # Invalid JSON
     mock_completion1 = Mock()