Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests of AgentsRoutine #58

Merged
merged 4 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Tool implementations without langchain or langgraph dependencies
- CRUDs.
- BlueNaas CRUD tools
- Tests of AgentsRoutine.
- Unit tests for database

### Fixed
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ convention = "numpy"

[tool.ruff.lint.per-file-ignores]
"tests/*" = ["D"]
"swarm_copy_tests/*" = ["D"]

[tool.mypy]
mypy_path = "src"
Expand Down
4 changes: 2 additions & 2 deletions swarm_copy/run.py → swarm_copy/agent_routine.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ async def astream(
) -> AsyncIterator[str | Response]:
"""Stream the agent response."""
active_agent = agent
context_variables = copy.deepcopy(context_variables)

history = copy.deepcopy(messages)
init_len = len(messages)
is_streaming = False
Expand Down Expand Up @@ -251,7 +251,7 @@ async def astream(
stream=True,
)
async for chunk in completion: # type: ignore
delta = json.loads(chunk.choices[0].delta.json())
delta = json.loads(chunk.choices[0].delta.model_dump_json())

# Check for tool calls
if delta["tool_calls"]:
Expand Down
2 changes: 1 addition & 1 deletion swarm_copy/app/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
from starlette.status import HTTP_401_UNAUTHORIZED

from swarm_copy.agent_routine import AgentsRoutine
from swarm_copy.app.app_utils import validate_project
from swarm_copy.app.config import Settings
from swarm_copy.app.database.sql_schemas import Threads
from swarm_copy.cell_types import CellTypesMeta
from swarm_copy.new_types import Agent
from swarm_copy.run import AgentsRoutine
from swarm_copy.tools import (
ElectrophysFeatureTool,
GetMorphoTool,
Expand Down
2 changes: 1 addition & 1 deletion swarm_copy/app/routers/qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from fastapi.responses import StreamingResponse
from sqlalchemy.ext.asyncio import AsyncSession

from swarm_copy.agent_routine import AgentsRoutine
from swarm_copy.app.database.db_utils import get_history, get_thread, save_history
from swarm_copy.app.database.sql_schemas import Threads
from swarm_copy.app.dependencies import (
Expand All @@ -17,7 +18,6 @@
get_user_id,
)
from swarm_copy.new_types import Agent, AgentRequest, AgentResponse
from swarm_copy.run import AgentsRoutine
from swarm_copy.stream import stream_agent_response

router = APIRouter(prefix="/qa", tags=["Run the agent"])
Expand Down
2 changes: 1 addition & 1 deletion swarm_copy/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from openai import AsyncOpenAI
from sqlalchemy.ext.asyncio import AsyncSession

from swarm_copy.agent_routine import AgentsRoutine
from swarm_copy.app.database.db_utils import save_history
from swarm_copy.new_types import Agent, Response
from swarm_copy.run import AgentsRoutine


async def stream_agent_response(
Expand Down
1 change: 1 addition & 0 deletions swarm_copy_tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Sarm copy tests."""
89 changes: 67 additions & 22 deletions swarm_copy_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,20 @@

import json
from pathlib import Path
from typing import ClassVar

import pytest
import pytest_asyncio
from fastapi.testclient import TestClient
from pydantic import BaseModel, ConfigDict
from sqlalchemy import MetaData
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine

from swarm_copy.app.config import Settings
from swarm_copy.app.dependencies import get_kg_token, get_settings
from swarm_copy.app.dependencies import Agent, get_kg_token, get_settings
from swarm_copy.app.main import app


@pytest.fixture(name="settings")
def settings():
return Settings(
tools={
"literature": {
"url": "fake_literature_url",
},
},
knowledge_graph={
"base_url": "https://fake_url/api/nexus/v1",
},
openai={
"token": "fake_token",
},
keycloak={
"username": "fake_username",
"password": "fake_password",
},
)
from swarm_copy.tools.base_tool import BaseTool
from swarm_copy_tests.mock_client import MockOpenAIClient, create_mock_response


@pytest.fixture(name="app_client")
Expand Down Expand Up @@ -62,6 +45,68 @@ def client_fixture():
yield app_client
app.dependency_overrides.clear()

@pytest.fixture
def mock_openai_client():
"""Fake openai client."""
m = MockOpenAIClient()
m.set_response(
create_mock_response(
{"role": "assistant", "content": "sample response content"}
)
)
return m


@pytest.fixture(name="get_weather_tool")
def fake_tool():
"""Fake get weather tool."""

class FakeToolInput(BaseModel):
location: str

class FakeToolMetadata(
BaseModel
): # Should be a BaseMetadata but we don't want httpx client here
model_config = ConfigDict(extra="ignore", arbitrary_types_allowed=True)
planet: str | None = None

class FakeTool(BaseTool):
name: ClassVar[str] = "get_weather"
description: ClassVar[str] = "Great description"
metadata: FakeToolMetadata
input_schema: FakeToolInput

async def arun(self):
if self.metadata.planet:
return f"It's sunny today in {self.input_schema.location} from planet {self.metadata.planet}."
return "It's sunny today."

return FakeTool


@pytest.fixture
def agent_handoff_tool():
"""Fake agent handoff tool."""

class HandoffToolInput(BaseModel):
pass

class HandoffToolMetadata(
BaseModel
): # Should be a BaseMetadata but we don't want httpx client here
to_agent: Agent
model_config = ConfigDict(extra="ignore", arbitrary_types_allowed=True)

class HandoffTool(BaseTool):
name: ClassVar[str] = "agent_handoff_tool"
description: ClassVar[str] = "Handoff to another agent."
metadata: HandoffToolMetadata
input_schema: HandoffToolInput

async def arun(self):
return self.metadata.to_agent

return HandoffTool

@pytest.fixture(autouse=True, scope="session")
def dont_look_at_env_file():
Expand Down
68 changes: 68 additions & 0 deletions swarm_copy_tests/mock_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import json
from unittest.mock import AsyncMock

from openai.types.chat import ChatCompletionMessage
from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
Function,
)


def create_mock_response(message, function_calls=[], model="gpt-4o-mini"):
role = message.get("role", "assistant")
content = message.get("content", "")
tool_calls = (
[
ChatCompletionMessageToolCall(
id="mock_tc_id",
type="function",
function=Function(
name=call.get("name", ""),
arguments=json.dumps(call.get("args", {})),
),
)
for call in function_calls
]
if function_calls
else None
)

return ChatCompletion(
id="mock_cc_id",
created=1234567890,
model=model,
object="chat.completion",
choices=[
Choice(
message=ChatCompletionMessage(
role=role, content=content, tool_calls=tool_calls
),
finish_reason="stop",
index=0,
)
],
)


class MockOpenAIClient:
def __init__(self):
self.chat = AsyncMock()
self.chat.completions = AsyncMock()

def set_response(self, response: ChatCompletion):
"""
Set the mock to return a specific response.
:param response: A ChatCompletion response to return.
"""
self.chat.completions.create.return_value = response

def set_sequential_responses(self, responses: list[ChatCompletion]):
"""
Set the mock to return different responses sequentially.
:param responses: A list of ChatCompletion responses to return in order.
"""
self.chat.completions.create.side_effect = responses

def assert_create_called_with(self, **kwargs):
self.chat.completions.create.assert_called_with(**kwargs)
Loading
Loading