Merge branch 'main' into add-kwargs-call-single
maykcaldas authored Jan 28, 2025
2 parents 3695b42 + b6ca776 commit dca8808
Showing 2 changed files with 61 additions and 2 deletions.
2 changes: 1 addition & 1 deletion llmclient/llms.py
@@ -557,7 +557,7 @@ async def acompletion(self, messages: list[Message], **kwargs) -> list[LLMResult]:
         # cast is necessary for LiteLLM typing bug: https://github.com/BerriAI/litellm/issues/7641
         prompts = cast(
             list[litellm.types.llms.openai.AllMessageValues],
-            [m.model_dump(by_alias=True) for m in messages if m.content],
+            [m.model_dump(by_alias=True) for m in messages],
         )
         completions = await track_costs(self.router.acompletion)(
             self.name, prompts, **kwargs
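Note on the one-line change above: a ToolRequestMessage typically carries tool_calls but no content (the new test below asserts `not tool_message.content`), so the old truthy-content filter would silently drop the assistant's tool-call turn from the prompt sent to LiteLLM. A self-contained sketch of that failure mode — `Msg` here is a hypothetical stand-in, not aviary's `Message`:

from dataclasses import dataclass, field

@dataclass
class Msg:
    role: str
    content: str | None = None
    tool_calls: list[dict] = field(default_factory=list)

    def model_dump(self, by_alias: bool = True) -> dict:
        return {"role": self.role, "content": self.content, "tool_calls": self.tool_calls}

messages = [
    Msg(role="user", content="What is double of 8?"),
    # Tool-request turn: tool_calls populated, content left as None
    Msg(
        role="assistant",
        tool_calls=[{"id": "call_1", "function": {"name": "double", "arguments": {"x": 8}}}],
    ),
]

# Old behavior: the truthy-content filter drops the tool-call turn entirely
assert len([m.model_dump(by_alias=True) for m in messages if m.content]) == 1
# New behavior: every message survives serialization
assert len([m.model_dump(by_alias=True) for m in messages]) == 2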
61 changes: 60 additions & 1 deletion tests/test_llms.py
@@ -7,7 +7,7 @@
 import litellm
 import numpy as np
 import pytest
-from aviary.core import Message, Tool, ToolRequestMessage
+from aviary.core import Message, Tool, ToolRequestMessage, ToolResponseMessage
 from pydantic import BaseModel, Field, TypeAdapter, computed_field

 from llmclient.exceptions import JSONSchemaValidationError
@@ -532,6 +532,65 @@ async def test_multiple_completion(self, model_name: str, request) -> None:
         assert len(results) == self.NUM_COMPLETIONS
+
+
+class TestTooling:
+    @pytest.mark.asyncio
+    # @pytest.mark.vcr
+    async def test_tool_selection(self) -> None:
+        model = LiteLLMModel(name=CommonLLMNames.OPENAI_TEST.value, config={"n": 1})
+
+        def double(x: int) -> int:
+            """Double the input.
+            Args:
+                x: The input to double
+            Returns:
+                The double of the input.
+            """
+            return 2 * x
+
+        tools = [Tool.from_function(double)]
+        messages = [
+            Message(
+                role="system",
+                content="You are a helpful assistant who can use tools to do math. Use a tool if needed. If you don't need a tool, just respond with the answer.",
+            ),
+            Message(role="user", content="What is double of 8?"),
+        ]
+
+        results = await model.call(
+            messages, tools=tools, tool_choice=LiteLLMModel.MODEL_CHOOSES_TOOL
+        )
+        assert isinstance(results, list)
+        assert isinstance(results[0].messages, list)
+
+        tool_message = results[0].messages[0]
+
+        assert isinstance(
+            tool_message, ToolRequestMessage
+        ), "It should have selected a tool"
+        assert not tool_message.content
+        assert (
+            tool_message.tool_calls[0].function.arguments["x"] == 8
+        ), "LLM failed to select the correct tool or arguments"
+
+        # Simulate the observation
+        observation = ToolResponseMessage(
+            role="tool",
+            name="double",
+            content="Observation: 16",
+            tool_call_id=tool_message.tool_calls[0].id,
+        )
+        messages.extend([tool_message, observation])
+
+        results = await model.call(
+            messages, tools=tools, tool_choice=LiteLLMModel.MODEL_CHOOSES_TOOL
+        )
+        assert isinstance(results, list)
+        assert isinstance(results[0].messages, list)
+        assert results[0].messages[0].content
+        assert "16" in results[0].messages[0].content


 def test_json_schema_validation() -> None:
     # Invalid JSON
     mock_completion1 = Mock()
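The test above simulates the observation by hand-writing the "Observation: 16" message; in a real agent loop the selected tool would be executed and its return value wrapped into the ToolResponseMessage. A minimal sketch of that dispatch step, reusing the message shapes the test asserts on (`tool_calls[0].function.arguments`, `tool_calls[0].id`) — the `execute_tool_calls` helper and the `function.name` attribute are assumptions, not part of this commit:

from aviary.core import ToolResponseMessage

def execute_tool_calls(tool_message, functions: dict) -> list[ToolResponseMessage]:
    """Run each requested tool and wrap its result as a tool response."""
    responses = []
    for call in tool_message.tool_calls:
        # Look up the tool by name and invoke it with the parsed arguments
        result = functions[call.function.name](**call.function.arguments)
        responses.append(
            ToolResponseMessage(
                role="tool",
                name=call.function.name,
                content=str(result),
                tool_call_id=call.id,
            )
        )
    return responses

# Usage, continuing the test's flow:
# messages.extend([tool_message, *execute_tool_calls(tool_message, {"double": double})])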
