From 29f8e8a394028879774b0d8c0cfd0c38c556e059 Mon Sep 17 00:00:00 2001 From: cszsol Date: Tue, 17 Dec 2024 13:33:37 +0100 Subject: [PATCH] Added db_utils tests (#53) * Added db_utils tests * Run new tests * Fixed merge issues * Review * Added new db util tests --------- Co-authored-by: kanesoban --- .github/workflows/ci.yaml | 2 +- CHANGELOG.md | 1 + swarm_copy/tools/bluenaas_memodel_getall.py | 4 +- swarm_copy/tools/bluenaas_memodel_getone.py | 4 +- swarm_copy/tools/traces_tool.py | 3 +- swarm_copy_tests/app/database/__init__.py | 1 + .../app/database/test_db_utils.py | 449 ++++++++++++++++++ swarm_copy_tests/conftest.py | 124 +++++ 8 files changed, 582 insertions(+), 6 deletions(-) create mode 100644 swarm_copy_tests/app/database/__init__.py create mode 100644 swarm_copy_tests/app/database/test_db_utils.py create mode 100644 swarm_copy_tests/conftest.py diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index d8becc5..0347c9b 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -90,4 +90,4 @@ jobs: mypy src/ swarm_copy/ # Include src/ directory in Python path to prioritize local files in pytest export PYTHONPATH=$(pwd)/src:$PYTHONPATH - pytest --color=yes + pytest --color=yes tests/ swarm_copy_tests/ diff --git a/CHANGELOG.md b/CHANGELOG.md index 52955bc..bceac7e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Tool implementations without langchain or langgraph dependencies - CRUDs. - BlueNaas CRUD tools +- Unit tests for database ### Fixed - Migrate LLM Evaluation logic to scripts and add tests diff --git a/swarm_copy/tools/bluenaas_memodel_getall.py b/swarm_copy/tools/bluenaas_memodel_getall.py index 8bda00e..a7cee19 100644 --- a/swarm_copy/tools/bluenaas_memodel_getall.py +++ b/swarm_copy/tools/bluenaas_memodel_getall.py @@ -29,7 +29,7 @@ class InputMEModelGetAll(BaseModel): page_size: int = Field( default=20, description="Number of results returned by the API." ) - model_type: Literal["single-neuron-simulation", "synaptome-simulation"] = Field( + memodel_type: Literal["single-neuron-simulation", "synaptome-simulation"] = Field( default="single-neuron-simulation", description="Type of simulation to retrieve.", ) @@ -55,7 +55,7 @@ async def arun(self) -> PaginatedResponseUnionMEModelResponseSynaptomeModelRespo response = await self.metadata.httpx_client.get( url=f"{self.metadata.bluenaas_url}/neuron-model/{self.metadata.vlab_id}/{self.metadata.project_id}/me-models", params={ - "simulation_type": self.input_schema.model_type, + "simulation_type": self.input_schema.memodel_type, "offset": self.input_schema.offset, "page_size": self.input_schema.page_size, }, diff --git a/swarm_copy/tools/bluenaas_memodel_getone.py b/swarm_copy/tools/bluenaas_memodel_getone.py index 4f4a3b3..f84acfa 100644 --- a/swarm_copy/tools/bluenaas_memodel_getone.py +++ b/swarm_copy/tools/bluenaas_memodel_getone.py @@ -24,7 +24,7 @@ class MEModelGetOneMetadata(BaseMetadata): class InputMEModelGetOne(BaseModel): """Inputs for the BlueNaaS single-neuron simulation.""" - model_id: str = Field( + memodel_id: str = Field( description="ID of the model to retrieve. Should be an https link." ) @@ -45,7 +45,7 @@ async def arun(self) -> MEModelResponse: ) response = await self.metadata.httpx_client.get( - url=f"{self.metadata.bluenaas_url}/neuron-model/{self.metadata.vlab_id}/{self.metadata.project_id}/{quote_plus(self.input_schema.model_id)}", + url=f"{self.metadata.bluenaas_url}/neuron-model/{self.metadata.vlab_id}/{self.metadata.project_id}/{quote_plus(self.input_schema.memodel_id)}", headers={"Authorization": f"Bearer {self.metadata.token}"}, ) diff --git a/swarm_copy/tools/traces_tool.py b/swarm_copy/tools/traces_tool.py index 41028b2..0434013 100644 --- a/swarm_copy/tools/traces_tool.py +++ b/swarm_copy/tools/traces_tool.py @@ -1,6 +1,7 @@ """Traces tool.""" import logging +from pathlib import Path from typing import Any, ClassVar from pydantic import BaseModel, Field @@ -46,7 +47,7 @@ class GetTracesMetadata(BaseMetadata): knowledge_graph_url: str token: str trace_search_size: int - brainregion_path: str + brainregion_path: str | Path class GetTracesTool(BaseTool): diff --git a/swarm_copy_tests/app/database/__init__.py b/swarm_copy_tests/app/database/__init__.py new file mode 100644 index 0000000..8ce3e8d --- /dev/null +++ b/swarm_copy_tests/app/database/__init__.py @@ -0,0 +1 @@ +"""Unit tests for database.""" diff --git a/swarm_copy_tests/app/database/test_db_utils.py b/swarm_copy_tests/app/database/test_db_utils.py new file mode 100644 index 0000000..840e697 --- /dev/null +++ b/swarm_copy_tests/app/database/test_db_utils.py @@ -0,0 +1,449 @@ +import json + +import pytest +from fastapi import HTTPException +from sqlalchemy import select + +from swarm_copy.app.app_utils import setup_engine +from swarm_copy.app.config import Settings +from swarm_copy.app.database.db_utils import get_thread, save_history, get_history +from swarm_copy.app.database.sql_schemas import Entity, Messages, Base, Threads +from swarm_copy.app.dependencies import get_session + + +@pytest.mark.asyncio +@pytest.mark.httpx_mock(can_send_already_matched_responses=True) +async def test_get_thread(patch_required_env, db_connection): + test_settings = Settings( + db={"prefix": db_connection}, + ) + engine = setup_engine(test_settings, db_connection) + session = await anext(get_session(engine)) + user_id = "test_user" + valid_thread_id = "test_thread_id" + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + new_thread = Threads( + user_id=user_id, + thread_id=valid_thread_id, + vlab_id="test_vlab_DB", + project_id="project_id_DB", + title="test_title", + ) + session.add(new_thread) + await session.commit() + + try: + thread = await get_thread( + user_id=user_id, + thread_id=valid_thread_id, + session=session, + ) + assert thread.user_id == user_id + assert thread.thread_id == valid_thread_id + assert thread.title == "test_title" + finally: + await session.close() + await engine.dispose() + + +@pytest.mark.asyncio +@pytest.mark.httpx_mock(can_send_already_matched_responses=True) +async def test_get_thread_invalid_thread_id(patch_required_env, db_connection): + test_settings = Settings( + db={"prefix": db_connection}, + ) + engine = setup_engine(test_settings, db_connection) + session = await anext(get_session(engine)) + user_id = "test_user" + valid_thread_id = "test_thread_id" + invalid_thread_id = "wrong_thread_id" + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + new_thread = Threads( + user_id=user_id, + thread_id=valid_thread_id, + vlab_id="test_vlab_DB", + project_id="project_id_DB", + title="test_title", + ) + session.add(new_thread) + await session.commit() + + try: + with pytest.raises(HTTPException) as exc_info: + await get_thread( + user_id=user_id, + thread_id=invalid_thread_id, + session=session, + ) + assert exc_info.value.status_code == 404 + assert exc_info.value.detail["detail"] == "Thread not found." + finally: + await session.close() + await engine.dispose() + + +@pytest.mark.asyncio +@pytest.mark.httpx_mock(can_send_already_matched_responses=True) +async def test_get_thread_invalid_user_id(patch_required_env, db_connection): + test_settings = Settings( + db={"prefix": db_connection}, + ) + engine = setup_engine(test_settings, db_connection) + session = await anext(get_session(engine)) + user_id = "test_user" + valid_thread_id = "test_thread_id" + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + new_thread = Threads( + user_id=user_id, + thread_id=valid_thread_id, + vlab_id="test_vlab_DB", + project_id="project_id_DB", + title="test_title", + ) + session.add(new_thread) + await session.commit() + + try: + with pytest.raises(HTTPException) as exc_info: + await get_thread( + user_id="wrong_user", + thread_id=valid_thread_id, + session=session, + ) + assert exc_info.value.status_code == 404 + assert exc_info.value.detail["detail"] == "Thread not found." + + finally: + await session.close() + await engine.dispose() + + +@pytest.mark.asyncio +async def test_save_history(patch_required_env, db_connection): + test_settings = Settings(db={"prefix": db_connection}) + engine = setup_engine(test_settings, db_connection) + session = await anext(get_session(engine)) + user_id = "test_user" + thread_id = "test_thread_id" + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + new_thread = Threads( + user_id=user_id, + thread_id=thread_id, + vlab_id="test_vlab_DB", + project_id="project_id_DB", + title="test_title", + ) + session.add(new_thread) + await session.commit() + + try: + history = [ + {"role": "user", "content": "User message"}, + {"role": "assistant", "content": "AI message"}, + ] + await save_history(history, user_id, thread_id, offset=0, session=session) + + result = await session.execute(select(Messages).where(Messages.thread_id == thread_id)) + messages = result.scalars().all() + + assert len(messages) == len(history) + assert messages[0].entity == Entity.USER + assert messages[0].content == json.dumps(history[0]) + assert messages[1].entity == Entity.AI_MESSAGE + assert messages[1].content == json.dumps(history[1]) + + updated_thread = await get_thread(user_id=user_id, thread_id=thread_id, session=session) + assert updated_thread.update_date is not None + + finally: + await session.close() + await engine.dispose() + + +@pytest.mark.asyncio +async def test_save_history_with_tool_messages(patch_required_env, db_connection): + test_settings = Settings(db={"prefix": db_connection}) + engine = setup_engine(test_settings, db_connection) + session = await anext(get_session(engine)) + user_id = "test_user" + thread_id = "test_thread_id" + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + new_thread = Threads( + user_id=user_id, + thread_id=thread_id, + vlab_id="test_vlab_DB", + project_id="project_id_DB", + title="test_title", + ) + session.add(new_thread) + await session.commit() + + try: + history = [ + {"role": "tool", "content": "Tool invoked"}, + {"role": "assistant", "content": ""}, + ] + await save_history(history, user_id, thread_id, offset=0, session=session) + + result = await session.execute(select(Messages).where(Messages.thread_id == thread_id)) + messages = result.scalars().all() + + assert len(messages) == len(history) + assert messages[0].entity == Entity.TOOL + assert messages[0].content == json.dumps(history[0]) + assert messages[1].entity == Entity.AI_TOOL + assert messages[1].content == json.dumps(history[1]) + + finally: + await session.close() + await engine.dispose() + + +@pytest.mark.asyncio +async def test_save_history_invalid_message_entity(patch_required_env, db_connection): + test_settings = Settings(db={"prefix": db_connection}) + engine = setup_engine(test_settings, db_connection) + session = await anext(get_session(engine)) + user_id = "test_user" + thread_id = "test_thread_id" + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + new_thread = Threads( + user_id=user_id, + thread_id=thread_id, + vlab_id="test_vlab_DB", + project_id="project_id_DB", + title="test_title", + ) + session.add(new_thread) + await session.commit() + + try: + history = [{"role": "unknown", "content": "Invalid entity message"}] + + with pytest.raises(HTTPException) as exc_info: + await save_history(history, user_id, thread_id, offset=0, session=session) + + assert exc_info.value.status_code == 500 + assert exc_info.value.detail == "Unknown message entity." + + finally: + await session.close() + await engine.dispose() + + +@pytest.mark.asyncio +async def test_save_history_with_offset(patch_required_env, db_connection): + test_settings = Settings(db={"prefix": db_connection}) + engine = setup_engine(test_settings, db_connection) + session = await anext(get_session(engine)) + user_id = "test_user" + thread_id = "test_thread_id" + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + new_thread = Threads( + user_id=user_id, + thread_id=thread_id, + vlab_id="test_vlab_DB", + project_id="project_id_DB", + title="test_title", + ) + session.add(new_thread) + await session.commit() + + try: + history = [ + {"role": "user", "content": "First user message"}, + {"role": "assistant", "content": "First AI message"}, + ] + await save_history(history, user_id, thread_id, offset=5, session=session) + + result = await session.execute(select(Messages).where(Messages.thread_id == thread_id)) + messages = result.scalars().all() + + assert len(messages) == len(history) + assert messages[0].order == 5 + assert messages[0].content == json.dumps(history[0]) + assert messages[1].order == 6 + assert messages[1].content == json.dumps(history[1]) + + finally: + await session.close() + await engine.dispose() + + +@pytest.mark.asyncio +async def test_get_history_empty_thread(patch_required_env, db_connection): + test_settings = Settings(db={"prefix": db_connection}) + engine = setup_engine(test_settings, db_connection) + session = await anext(get_session(engine)) + user_id = "test_user" + thread_id = "empty_thread" + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + new_thread = Threads( + user_id=user_id, + thread_id=thread_id, + vlab_id="test_vlab_DB", + project_id="project_id_DB", + title="test_title_empty", + ) + session.add(new_thread) + await session.commit() + + try: + thread = await get_thread(user_id=user_id, thread_id=thread_id, session=session) + history = await get_history(thread) + + assert history == [] + + finally: + await session.close() + await engine.dispose() + + +@pytest.mark.asyncio +async def test_get_history_with_messages(patch_required_env, db_connection): + test_settings = Settings(db={"prefix": db_connection}) + engine = setup_engine(test_settings, db_connection) + session = await anext(get_session(engine)) + user_id = "test_user" + thread_id = "valid_thread" + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + new_thread = Threads( + user_id=user_id, + thread_id=thread_id, + vlab_id="test_vlab_DB", + project_id="project_id_DB", + title="test_title_valid", + ) + session.add(new_thread) + await session.commit() + + messages_to_add = [ + Messages(order=1, thread_id=thread_id, entity=Entity.USER, + content=json.dumps({"role": "user", "content": "User message"})), + Messages(order=2, thread_id=thread_id, entity=Entity.AI_MESSAGE, + content=json.dumps({"role": "assistant", "content": "AI message"})), + ] + session.add_all(messages_to_add) + await session.commit() + + try: + thread = await get_thread(user_id=user_id, thread_id=thread_id, session=session) + history = await get_history(thread) + + assert len(history) == 2 + assert history[0] == {"role": "user", "content": "User message"} + assert history[1] == {"role": "assistant", "content": "AI message"} + + finally: + await session.close() + await engine.dispose() + + +@pytest.mark.asyncio +async def test_get_history_ignore_empty_messages(patch_required_env, db_connection): + test_settings = Settings(db={"prefix": db_connection}) + engine = setup_engine(test_settings, db_connection) + session = await anext(get_session(engine)) + user_id = "test_user" + thread_id = "thread_with_empty_messages" + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + new_thread = Threads( + user_id=user_id, + thread_id=thread_id, + vlab_id="test_vlab_DB", + project_id="project_id_DB", + title="test_title_ignore_empty", + ) + session.add(new_thread) + await session.commit() + + messages_to_add = [ + Messages(order=1, thread_id=thread_id, entity=Entity.USER, + content=json.dumps({"role": "user", "content": "User message"})), + Messages(order=2, thread_id=thread_id, entity=Entity.TOOL, content=""), # Empty content should be ignored + Messages(order=3, thread_id=thread_id, entity=Entity.AI_MESSAGE, + content=json.dumps({"role": "assistant", "content": "AI message"})), + ] + session.add_all(messages_to_add) + await session.commit() + + try: + thread = await get_thread(user_id=user_id, thread_id=thread_id, session=session) + history = await get_history(thread) + + assert len(history) == 2 + assert history[0] == {"role": "user", "content": "User message"} + assert history[1] == {"role": "assistant", "content": "AI message"} + + finally: + await session.close() + await engine.dispose() + + +@pytest.mark.asyncio +async def test_get_history_with_malformed_json(patch_required_env, db_connection): + test_settings = Settings(db={"prefix": db_connection}) + engine = setup_engine(test_settings, db_connection) + session = await anext(get_session(engine)) + user_id = "test_user" + thread_id = "malformed_thread" + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + new_thread = Threads( + user_id=user_id, + thread_id=thread_id, + vlab_id="test_vlab_DB", + project_id="project_id_DB", + title="test_title_malformed", + ) + session.add(new_thread) + await session.commit() + + messages_to_add = [ + Messages(order=1, thread_id=thread_id, entity=Entity.USER, + content=json.dumps({"role": "user", "content": "Valid message"})), + Messages(order=2, thread_id=thread_id, entity=Entity.AI_MESSAGE, content="MALFORMED_JSON"), # Malformed JSON + ] + session.add_all(messages_to_add) + await session.commit() + + try: + thread = await get_thread(user_id=user_id, thread_id=thread_id, session=session) + with pytest.raises(json.JSONDecodeError): + await get_history(thread) + + finally: + await session.close() + await engine.dispose() diff --git a/swarm_copy_tests/conftest.py b/swarm_copy_tests/conftest.py new file mode 100644 index 0000000..54ee423 --- /dev/null +++ b/swarm_copy_tests/conftest.py @@ -0,0 +1,124 @@ +"""Test configuration.""" + +import json +from pathlib import Path + +import pytest +import pytest_asyncio +from fastapi.testclient import TestClient +from sqlalchemy import MetaData +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine + +from swarm_copy.app.config import Settings +from swarm_copy.app.dependencies import get_kg_token, get_settings +from swarm_copy.app.main import app + + +@pytest.fixture(name="settings") +def settings(): + return Settings( + tools={ + "literature": { + "url": "fake_literature_url", + }, + }, + knowledge_graph={ + "base_url": "https://fake_url/api/nexus/v1", + }, + openai={ + "token": "fake_token", + }, + keycloak={ + "username": "fake_username", + "password": "fake_password", + }, + ) + + +@pytest.fixture(name="app_client") +def client_fixture(): + """Get client and clear app dependency_overrides.""" + app_client = TestClient(app) + test_settings = Settings( + tools={ + "literature": { + "url": "fake_literature_url", + }, + }, + knowledge_graph={ + "base_url": "https://fake_url/api/nexus/v1", + }, + openai={ + "token": "fake_token", + }, + keycloak={ + "username": "fake_username", + "password": "fake_password", + }, + ) + app.dependency_overrides[get_settings] = lambda: test_settings + # mock keycloak authentication + app.dependency_overrides[get_kg_token] = lambda: "fake_token" + yield app_client + app.dependency_overrides.clear() + + +@pytest.fixture(autouse=True, scope="session") +def dont_look_at_env_file(): + """Never look inside of the .env when running unit tests.""" + Settings.model_config["env_file"] = None + + +@pytest.fixture() +def patch_required_env(monkeypatch): + monkeypatch.setenv("NEUROAGENT_TOOLS__LITERATURE__URL", "https://fake_url") + monkeypatch.setenv( + "NEUROAGENT_KNOWLEDGE_GRAPH__BASE_URL", "https://fake_url/api/nexus/v1" + ) + monkeypatch.setenv("NEUROAGENT_OPENAI__TOKEN", "dummy") + monkeypatch.setenv("NEUROAGENT_KEYCLOAK__VALIDATE_TOKEN", "False") + monkeypatch.setenv("NEUROAGENT_KEYCLOAK__PASSWORD", "password") + + +@pytest_asyncio.fixture(params=["sqlite", "postgresql"], name="db_connection") +async def setup_sql_db(request, tmp_path): + db_type = request.param + + # To start the postgresql database: + # docker run -it --rm -p 5432:5432 -e POSTGRES_USER=test -e POSTGRES_PASSWORD=password postgres:latest + path = ( + f"sqlite+aiosqlite:///{tmp_path / 'test_db.db'}" + if db_type == "sqlite" + else "postgresql+asyncpg://test:password@localhost:5432" + ) + if db_type == "postgresql": + try: + async with create_async_engine(path).connect() as conn: + pass + except Exception: + pytest.skip("Postgres database not connected") + yield path + if db_type == "postgresql": + metadata = MetaData() + engine = create_async_engine(path) + session = AsyncSession(bind=engine) + async with engine.begin() as conn: + await conn.run_sync(metadata.reflect) + await conn.run_sync(metadata.drop_all) + + await session.commit() + await engine.dispose() + await session.aclose() + + +@pytest.fixture +def get_resolve_query_output(): + with open("tests/data/resolve_query.json") as f: + outputs = json.loads(f.read()) + return outputs + + +@pytest.fixture +def brain_region_json_path(): + br_path = Path(__file__).parent / "data" / "brainregion_hierarchy.json" + return br_path