Skip to content

Commit

Permalink
Add test for get_messages
Browse files Browse the repository at this point in the history
  • Loading branch information
BoBer78 committed Dec 19, 2024
1 parent 1d45ce8 commit 9edab09
Showing 1 changed file with 101 additions and 10 deletions.
111 changes: 101 additions & 10 deletions swarm_copy_tests/app/routers/test_threads.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import logging
from unittest.mock import AsyncMock, Mock, patch

import pytest

from swarm_copy.agent_routine import Agent, AgentsRoutine
from swarm_copy.app.config import Settings
from swarm_copy.app.dependencies import get_settings
from swarm_copy.app.dependencies import (
get_agents_routine,
get_settings,
get_starting_agent,
)
from swarm_copy.app.main import app
from swarm_copy_tests.mock_client import create_mock_response


@pytest.mark.httpx_mock(can_send_already_matched_responses=True)
def test_create_thread(patch_required_env, httpx_mock, app_client, db_connection):
test_settings = Settings(
db={"prefix": db_connection},
Expand Down Expand Up @@ -52,6 +56,87 @@ def test_get_threads(patch_required_env, httpx_mock, app_client, db_connection):
assert threads[1] == create_output_2


@pytest.mark.httpx_mock(can_send_already_matched_responses=True)
@pytest.mark.asyncio
async def test_get_messages(
patch_required_env,
httpx_mock,
app_client,
db_connection,
mock_openai_client,
get_weather_tool,
):
# Put data in the db
routine = AgentsRoutine(client=mock_openai_client)

mock_openai_client.set_sequential_responses(
[
create_mock_response(
message={"role": "assistant", "content": ""},
function_calls=[
{"name": "get_weather", "args": {"location": "Geneva"}}
],
),
create_mock_response(
{"role": "assistant", "content": "sample response content"}
),
]
)
agent = Agent(tools=[get_weather_tool])

app.dependency_overrides[get_agents_routine] = lambda: routine
app.dependency_overrides[get_starting_agent] = lambda: agent

test_settings = Settings(
db={"prefix": db_connection},
)
app.dependency_overrides[get_settings] = lambda: test_settings
httpx_mock.add_response(
url=f"{test_settings.virtual_lab.get_project_url}/test_vlab/projects/test_project"
)

with app_client as app_client:
# wrong thread ID
wrong_response = app_client.get("/threads/test")
assert wrong_response.status_code == 404
assert wrong_response.json() == {"detail": {"detail": "Thread not found."}}

# Create a thread
create_output = app_client.post(
"/threads/?virtual_lab_id=test_vlab&project_id=test_project"
).json()
thread_id = create_output["thread_id"]

# Fill the thread
app_client.post(
f"/qa/chat/{thread_id}",
json={"query": "This is my query"},
headers={"x-virtual-lab-id": "test_vlab", "x-project-id": "test_project"},
)

create_output = app_client.post(
"/threads/?virtual_lab_id=test_vlab&project_id=test_project"
).json()
empty_thread_id = create_output["thread_id"]
empty_messages = app_client.get(f"/threads/{empty_thread_id}").json()
assert empty_messages == []

# Get the messages of the thread
messages = app_client.get(f"/threads/{thread_id}").json()

assert messages[0]["order"] == 0
assert messages[0]["entity"] == "user"
assert messages[0]["msg_content"] == "This is my query"
assert messages[0]["message_id"]
assert messages[0]["creation_date"]

assert messages[1]["order"] == 3
assert messages[1]["entity"] == "ai_message"
assert messages[1]["msg_content"] == "sample response content"
assert messages[1]["message_id"]
assert messages[1]["creation_date"]


@pytest.mark.httpx_mock(can_send_already_matched_responses=True)
def test_update_thread_title(patch_required_env, httpx_mock, app_client, db_connection):
test_settings = Settings(
Expand All @@ -66,6 +151,13 @@ def test_update_thread_title(patch_required_env, httpx_mock, app_client, db_conn
threads = app_client.get("/threads/").json()
assert not threads

# Check when wrong thread id
wrong_response = app_client.patch(
"/threads/wrong_id", json={"title": "great_title"}
)
assert wrong_response.status_code == 404
assert wrong_response.json() == {"detail": {"detail": "Thread not found."}}

create_thread_response = app_client.post(
"/threads/?virtual_lab_id=test_vlab&project_id=test_project"
).json()
Expand Down Expand Up @@ -93,6 +185,11 @@ def test_delete_thread(patch_required_env, httpx_mock, app_client, db_connection
threads = app_client.get("/threads/").json()
assert not threads

# Check when wrong thread id
wrong_response = app_client.delete("/threads/wrong_id")
assert wrong_response.status_code == 404
assert wrong_response.json() == {"detail": {"detail": "Thread not found."}}

create_thread_response = app_client.post(
"/threads/?virtual_lab_id=test_vlab&project_id=test_project"
).json()
Expand All @@ -107,9 +204,3 @@ def test_delete_thread(patch_required_env, httpx_mock, app_client, db_connection

threads = app_client.get("/threads/").json()
assert not threads


@pytest.fixture(autouse=True)
def stop_patches():
yield
patch.stopall()

0 comments on commit 9edab09

Please sign in to comment.