Skip to content

Commit

Permalink
Add tool CRUD tests
Browse files Browse the repository at this point in the history
  • Loading branch information
WonderPG committed Dec 19, 2024
1 parent 56a81ea commit ec83653
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- app unit tests
- Tests of AgentsRoutine.
- Unit tests for database
- Tests for tool CRUD endpoints.

### Fixed
- Migrate LLM Evaluation logic to scripts and add tests
Expand Down
2 changes: 1 addition & 1 deletion swarm_copy/app/routers/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,6 @@ async def get_tool_returns(
for msg in tool_messages:
msg_content = json.loads(msg.content)
if msg_content.get("tool_call_id") == tool_call_id:
tool_output.append(msg_content)
tool_output.append(msg_content["content"])

return tool_output
163 changes: 163 additions & 0 deletions swarm_copy_tests/app/routers/test_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
"""Test of the tool router."""

import json

import pytest

from swarm_copy.agent_routine import Agent, AgentsRoutine
from swarm_copy.app.config import Settings
from swarm_copy.app.database.schemas import ToolCallSchema
from swarm_copy.app.dependencies import (
get_agents_routine,
get_context_variables,
get_settings,
get_starting_agent,
)
from swarm_copy.app.main import app
from swarm_copy_tests.mock_client import create_mock_response


@pytest.mark.httpx_mock(can_send_already_matched_responses=True)
@pytest.mark.asyncio
async def test_get_tool_calls(
patch_required_env,
httpx_mock,
app_client,
db_connection,
mock_openai_client,
get_weather_tool,
):
routine = AgentsRoutine(client=mock_openai_client)

mock_openai_client.set_sequential_responses(
[
create_mock_response(
message={"role": "assistant", "content": ""},
function_calls=[
{"name": "get_weather", "args": {"location": "Geneva"}}
],
),
create_mock_response(
{"role": "assistant", "content": "sample response content"}
),
]
)
agent = Agent(tools=[get_weather_tool])

app.dependency_overrides[get_agents_routine] = lambda: routine
app.dependency_overrides[get_starting_agent] = lambda: agent
test_settings = Settings(
db={"prefix": db_connection},
)
app.dependency_overrides[get_settings] = lambda: test_settings
httpx_mock.add_response(
url=f"{test_settings.virtual_lab.get_project_url}/test_vlab/projects/test_project"
)

with app_client as app_client:
wrong_response = app_client.get("/tools/test/1234")
assert wrong_response.status_code == 404
assert wrong_response.json() == {"detail": {"detail": "Thread not found."}}

# Create a thread
create_output = app_client.post(
"/threads/?virtual_lab_id=test_vlab&project_id=test_project"
).json()
thread_id = create_output["thread_id"]

# Fill the thread
app_client.post(
f"/qa/chat/{thread_id}",
json={"query": "This is my query"},
params={"thread_id": thread_id},
headers={"x-virtual-lab-id": "test_vlab", "x-project-id": "test_project"},
)

tool_calls = app_client.get(f"/tools/{thread_id}/wrong_id")
assert tool_calls.status_code == 404
assert tool_calls.json() == {"detail": {"detail": "Message not found."}}

# Get the messages of the thread
messages = app_client.get(f"/threads/{thread_id}").json()
message_id = messages[-1]["message_id"]
tool_calls = app_client.get(f"/tools/{thread_id}/{message_id}").json()

assert (
tool_calls[0]
== ToolCallSchema(
tool_call_id="mock_tc_id",
name="get_weather",
arguments={"location": "Geneva"},
).model_dump()
)


@pytest.mark.httpx_mock(can_send_already_matched_responses=True)
@pytest.mark.asyncio
async def test_get_tool_output(
patch_required_env,
app_client,
httpx_mock,
db_connection,
mock_openai_client,
agent_handoff_tool,
):
routine = AgentsRoutine(client=mock_openai_client)

mock_openai_client.set_sequential_responses(
[
create_mock_response(
message={"role": "assistant", "content": ""},
function_calls=[{"name": "agent_handoff_tool", "args": {}}],
),
create_mock_response(
{"role": "assistant", "content": "sample response content"}
),
]
)
agent_1 = Agent(name="Test agent 1", tools=[agent_handoff_tool])
agent_2 = Agent(name="Test agent 2", tools=[])

app.dependency_overrides[get_agents_routine] = lambda: routine
app.dependency_overrides[get_starting_agent] = lambda: agent_1
app.dependency_overrides[get_context_variables] = lambda: {"to_agent": agent_2}
test_settings = Settings(
db={"prefix": db_connection},
)
app.dependency_overrides[get_settings] = lambda: test_settings
httpx_mock.add_response(
url=f"{test_settings.virtual_lab.get_project_url}/test_vlab/projects/test_project"
)

with app_client as app_client:
wrong_response = app_client.get("/tools/output/test/123")
assert wrong_response.status_code == 404
assert wrong_response.json() == {"detail": {"detail": "Thread not found."}}

# Create a thread
create_output = app_client.post(
"/threads/?virtual_lab_id=test_vlab&project_id=test_project"
).json()
thread_id = create_output["thread_id"]

# Fill the thread
app_client.post(
f"/qa/chat/{thread_id}",
json={"query": "This is my query"},
params={"thread_id": thread_id},
headers={"x-virtual-lab-id": "test_vlab", "x-project-id": "test_project"},
)

tool_output = app_client.get(f"/tools/output/{thread_id}/123")
assert tool_output.status_code == 200
assert tool_output.json() == []

# Get the messages of the thread
messages = app_client.get(f"/threads/{thread_id}").json()
message_id = messages[-1]["message_id"]
tool_calls = app_client.get(f"/tools/{thread_id}/{message_id}").json()

tool_call_id = tool_calls[0]["tool_call_id"]
tool_output = app_client.get(f"/tools/output/{thread_id}/{tool_call_id}")

assert tool_output.json() == [json.dumps({"assistant": agent_2.name})]

0 comments on commit ec83653

Please sign in to comment.