From 3c550ba1c331caf6ec0d99ef192dfe0146650c59 Mon Sep 17 00:00:00 2001 From: Nicolas Frank Date: Tue, 17 Sep 2024 13:47:03 +0200 Subject: [PATCH] Refactor iteration 2 --- .env.example | 11 ++-- CHANGELOG.md | 1 + src/neuroagent/agents/base_agent.py | 52 ------------------ src/neuroagent/app/config.py | 31 ++--------- src/neuroagent/app/dependencies.py | 84 ++++++++++++++--------------- src/neuroagent/app/routers/qa.py | 12 ++--- src/neuroagent/app/schemas.py | 5 +- tests/app/database/test_threads.py | 4 +- tests/app/database/test_tools.py | 4 +- tests/app/test_config.py | 2 +- tests/app/test_dependencies.py | 12 ++--- tests/app/test_main.py | 2 +- tests/app/test_qa.py | 16 +++--- tests/conftest.py | 8 ++- 14 files changed, 79 insertions(+), 165 deletions(-) diff --git a/.env.example b/.env.example index e7b68dd..983437e 100644 --- a/.env.example +++ b/.env.example @@ -5,7 +5,7 @@ NEUROAGENT_GENERATIVE__OPENAI__TOKEN= # Important but not required NEUROAGENT_AGENT__MODEL= -NEUROAGENT_AGENT__CHAT= + NEUROAGENT_KNOWLEDGE_GRAPH__USE_TOKEN= NEUROAGENT_KNOWLEDGE_GRAPH__TOKEN= NEUROAGENT_KNOWLEDGE_GRAPH__DOWNLOAD_HIERARCHY= @@ -27,12 +27,9 @@ NEUROAGENT_TOOLS__TRACE__SEARCH_SIZE= NEUROAGENT_TOOLS__KG_MORPHO__SEARCH_SIZE= -NEUROAGENT_GENERATIVE__LLM_TYPE= # can only be openai for now -NEUROAGENT_GENERATIVE__OPENAI__MODEL= -NEUROAGENT_GENERATIVE__OPENAI__TEMPERATURE= -NEUROAGENT_GENERATIVE__OPENAI__MAX_TOKENS= - -NEUROAGENT_COHERE__TOKEN= +NEUROAGENT_OPENAI__MODEL= +NEUROAGENT_OPENAI__TEMPERATURE= +NEUROAGENT_OPENAI__MAX_TOKENS= NEUROAGENT_LOGGING__LEVEL= NEUROAGENT_LOGGING__EXTERNAL_PACKAGES= diff --git a/CHANGELOG.md b/CHANGELOG.md index fd91bc3..e9a2f85 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,3 +12,4 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - Migration to pydantic V2. +- Deleted some legacy code. diff --git a/src/neuroagent/agents/base_agent.py b/src/neuroagent/agents/base_agent.py index df0833d..dbe3faa 100644 --- a/src/neuroagent/agents/base_agent.py +++ b/src/neuroagent/agents/base_agent.py @@ -4,61 +4,9 @@ from typing import Any, AsyncIterator from langchain.chat_models.base import BaseChatModel -from langchain_core.messages import ( - AIMessage, - ChatMessage, - FunctionMessage, - HumanMessage, - SystemMessage, - ToolMessage, -) -from langchain_core.prompts import ( - ChatPromptTemplate, - HumanMessagePromptTemplate, - MessagesPlaceholder, - PromptTemplate, - SystemMessagePromptTemplate, -) from langchain_core.tools import BaseTool from pydantic import BaseModel, ConfigDict -BASE_PROMPT = ChatPromptTemplate( - input_variables=["agent_scratchpad", "input"], - input_types={ - "chat_history": list[ - AIMessage - | HumanMessage - | ChatMessage - | SystemMessage - | FunctionMessage - | ToolMessage - ], - "agent_scratchpad": list[ - AIMessage - | HumanMessage - | ChatMessage - | SystemMessage - | FunctionMessage - | ToolMessage - ], - }, - messages=[ - SystemMessagePromptTemplate( - prompt=PromptTemplate( - input_variables=[], - template="""You are a helpful assistant helping scientists with neuro-scientific questions. - You must always specify in your answers from which brain regions the information is extracted. - Do no blindly repeat the brain region requested by the user, use the output of the tools instead.""", - ) - ), - MessagesPlaceholder(variable_name="chat_history", optional=True), - HumanMessagePromptTemplate( - prompt=PromptTemplate(input_variables=["input"], template="{input}") - ), - MessagesPlaceholder(variable_name="agent_scratchpad"), - ], -) - class AgentStep(BaseModel): """Class for agent decision steps.""" diff --git a/src/neuroagent/app/config.py b/src/neuroagent/app/config.py index f61f5b2..9055def 100644 --- a/src/neuroagent/app/config.py +++ b/src/neuroagent/app/config.py @@ -13,8 +13,7 @@ class SettingsAgent(BaseModel): """Agent setting.""" - model: str = "simple" - chat: str = "simple" + model: Literal["simple", "multi"] = "simple" model_config = ConfigDict(frozen=True) @@ -84,9 +83,9 @@ class SettingsLiterature(BaseModel): """Literature search API settings.""" url: str - retriever_k: int = 700 + retriever_k: int = 500 use_reranker: bool = True - reranker_k: int = 5 + reranker_k: int = 8 model_config = ConfigDict(frozen=True) @@ -173,23 +172,6 @@ class SettingsOpenAI(BaseModel): model_config = ConfigDict(frozen=True) -class SettingsGenerative(BaseModel): - """Generative QA settings.""" - - llm_type: Literal["fake", "openai"] = "openai" - openai: SettingsOpenAI = SettingsOpenAI() - - model_config = ConfigDict(frozen=True) - - -class SettingsCohere(BaseModel): - """Settings cohere reranker.""" - - token: Optional[SecretStr] = None - - model_config = ConfigDict(frozen=True) - - class SettingsLogging(BaseModel): """Metadata settings.""" @@ -219,8 +201,7 @@ class Settings(BaseSettings): knowledge_graph: SettingsKnowledgeGraph agent: SettingsAgent = SettingsAgent() # has no required db: SettingsDB = SettingsDB() # has no required - generative: SettingsGenerative = SettingsGenerative() # has no required - cohere: SettingsCohere = SettingsCohere() # has no required + openai: SettingsOpenAI = SettingsOpenAI() # has no required logging: SettingsLogging = SettingsLogging() # has no required keycloak: SettingsKeycloak = SettingsKeycloak() # has no required misc: SettingsMisc = SettingsMisc() # has no required @@ -240,10 +221,6 @@ def check_consistency(self) -> "Settings": model validator is run during instantiation. """ - # generative - if self.generative.llm_type == "openai": - if self.generative.openai.token is None: - raise ValueError("OpenAI token not provided") if not self.keycloak.password and not self.keycloak.validate_token: if not self.knowledge_graph.use_token: raise ValueError("if no password is provided, please use token auth.") diff --git a/src/neuroagent/app/dependencies.py b/src/neuroagent/app/dependencies.py index 84791a4..be00639 100644 --- a/src/neuroagent/app/dependencies.py +++ b/src/neuroagent/app/dependencies.py @@ -308,12 +308,12 @@ def get_language_model( settings: Annotated[Settings, Depends(get_settings)], ) -> ChatOpenAI: """Get the language model.""" - logger.info(f"OpenAI selected. Loading model {settings.generative.openai.model}.") + logger.info(f"OpenAI selected. Loading model {settings.openai.model}.") return ChatOpenAI( - model_name=settings.generative.openai.model, - temperature=settings.generative.openai.temperature, - openai_api_key=settings.generative.openai.token.get_secret_value(), # type: ignore - max_tokens=settings.generative.openai.max_tokens, + model_name=settings.openai.model, + temperature=settings.openai.temperature, + openai_api_key=settings.openai.token.get_secret_value(), # type: ignore + max_tokens=settings.openai.max_tokens, seed=78, streaming=True, ) @@ -369,43 +369,10 @@ def get_agent( ElectrophysFeatureTool, Depends(get_electrophys_feature_tool) ], traces_tool: Annotated[GetTracesTool, Depends(get_traces_tool)], -) -> BaseAgent | BaseMultiAgent: - """Get the generative question answering service.""" - tools = [ - literature_tool, - br_resolver_tool, - morpho_tool, - morphology_feature_tool, - kg_morpho_feature_tool, - electrophys_feature_tool, - traces_tool, - ] - logger.info("Load simple agent") - return SimpleAgent(llm=llm, tools=tools) # type: ignore - - -def get_chat_agent( - llm: Annotated[ChatOpenAI, Depends(get_language_model)], - memory: Annotated[BaseCheckpointSaver, Depends(get_agent_memory)], - literature_tool: Annotated[LiteratureSearchTool, Depends(get_literature_tool)], - br_resolver_tool: Annotated[ - ResolveBrainRegionTool, Depends(get_brain_region_resolver_tool) - ], - morpho_tool: Annotated[GetMorphoTool, Depends(get_morpho_tool)], - morphology_feature_tool: Annotated[ - MorphologyFeatureTool, Depends(get_morphology_feature_tool) - ], - kg_morpho_feature_tool: Annotated[ - KGMorphoFeatureTool, Depends(get_kg_morpho_feature_tool) - ], - electrophys_feature_tool: Annotated[ - ElectrophysFeatureTool, Depends(get_electrophys_feature_tool) - ], - traces_tool: Annotated[GetTracesTool, Depends(get_traces_tool)], settings: Annotated[Settings, Depends(get_settings)], -) -> BaseAgent: +) -> BaseAgent | BaseMultiAgent: """Get the generative question answering service.""" - if settings.agent.chat == "multi": + if settings.agent.model == "multi": logger.info("Load multi-agent chat") tools_list = [ ("literature", [literature_tool]), @@ -422,7 +389,6 @@ def get_chat_agent( ] return SupervisorMultiAgent(llm=llm, agents=tools_list) # type: ignore else: - logger.info("Load simple chat") tools = [ literature_tool, br_resolver_tool, @@ -432,7 +398,41 @@ def get_chat_agent( electrophys_feature_tool, traces_tool, ] - return SimpleChatAgent(llm=llm, tools=tools, memory=memory) # type: ignore + logger.info("Load simple agent") + return SimpleAgent(llm=llm, tools=tools) # type: ignore + + +def get_chat_agent( + llm: Annotated[ChatOpenAI, Depends(get_language_model)], + memory: Annotated[BaseCheckpointSaver, Depends(get_agent_memory)], + literature_tool: Annotated[LiteratureSearchTool, Depends(get_literature_tool)], + br_resolver_tool: Annotated[ + ResolveBrainRegionTool, Depends(get_brain_region_resolver_tool) + ], + morpho_tool: Annotated[GetMorphoTool, Depends(get_morpho_tool)], + morphology_feature_tool: Annotated[ + MorphologyFeatureTool, Depends(get_morphology_feature_tool) + ], + kg_morpho_feature_tool: Annotated[ + KGMorphoFeatureTool, Depends(get_kg_morpho_feature_tool) + ], + electrophys_feature_tool: Annotated[ + ElectrophysFeatureTool, Depends(get_electrophys_feature_tool) + ], + traces_tool: Annotated[GetTracesTool, Depends(get_traces_tool)], +) -> BaseAgent: + """Get the generative question answering service.""" + logger.info("Load simple chat") + tools = [ + literature_tool, + br_resolver_tool, + morpho_tool, + morphology_feature_tool, + kg_morpho_feature_tool, + electrophys_feature_tool, + traces_tool, + ] + return SimpleChatAgent(llm=llm, tools=tools, memory=memory) # type: ignore async def get_update_kg_hierarchy( diff --git a/src/neuroagent/app/routers/qa.py b/src/neuroagent/app/routers/qa.py index 447bb28..1104d00 100644 --- a/src/neuroagent/app/routers/qa.py +++ b/src/neuroagent/app/routers/qa.py @@ -34,8 +34,8 @@ async def run_agent( ) -> AgentOutput: """Run agent.""" logger.info("Running agent query.") - logger.info(f"User's query: {request.inputs}") - return await agent.arun(request.inputs) + logger.info(f"User's query: {request.query}") + return await agent.arun(request.query) @router.post("/chat/{thread_id}", response_model=AgentOutput) @@ -47,8 +47,8 @@ async def run_chat_agent( ) -> AgentOutput: """Run chat agent.""" logger.info("Running agent query.") - logger.info(f"User's query: {request.inputs}") - return await agent.arun(query=request.inputs, thread_id=thread_id) + logger.info(f"User's query: {request.query}") + return await agent.arun(query=request.query, thread_id=thread_id) @router.post("/chat_streamed/{thread_id}") @@ -60,5 +60,5 @@ async def run_streamed_chat_agent( ) -> StreamingResponse: """Run agent in streaming mode.""" logger.info("Running agent query.") - logger.info(f"User's query: {request.inputs}") - return StreamingResponse(agent.astream(query=request.inputs, thread_id=thread_id)) # type: ignore + logger.info(f"User's query: {request.query}") + return StreamingResponse(agent.astream(query=request.query, thread_id=thread_id)) # type: ignore diff --git a/src/neuroagent/app/schemas.py b/src/neuroagent/app/schemas.py index 344a957..962ae84 100644 --- a/src/neuroagent/app/schemas.py +++ b/src/neuroagent/app/schemas.py @@ -1,12 +1,9 @@ """Schemas.""" -from typing import Any - from pydantic import BaseModel class AgentRequest(BaseModel): """Class for agent request.""" - inputs: str - parameters: dict[str, Any] + query: str diff --git a/tests/app/database/test_threads.py b/tests/app/database/test_threads.py index 0d0f2da..7807826 100644 --- a/tests/app/database/test_threads.py +++ b/tests/app/database/test_threads.py @@ -65,7 +65,7 @@ async def test_get_thread( # Fill the thread app_client.post( f"/qa/chat/{thread_id}", - json={"inputs": "This is my query", "parameters": {}}, + json={"query": "This is my query"}, ) create_output = app_client.post("/threads/").json() @@ -131,7 +131,7 @@ async def test_delete_thread( # Fill the thread app_client.post( f"/qa/chat/{thread_id}", - json={"inputs": "This is my query", "parameters": {}}, + json={"query": "This is my query"}, params={"thread_id": thread_id}, ) # Get the messages of the thread diff --git a/tests/app/database/test_tools.py b/tests/app/database/test_tools.py index 6c5a7fd..a5f55e0 100644 --- a/tests/app/database/test_tools.py +++ b/tests/app/database/test_tools.py @@ -30,7 +30,7 @@ async def test_get_tool_calls( # Fill the thread app_client.post( f"/qa/chat/{thread_id}", - json={"inputs": "This is my query", "parameters": {}}, + json={"query": "This is my query"}, params={"thread_id": thread_id}, ) @@ -121,7 +121,7 @@ async def test_get_tool_output( # Fill the thread app_client.post( f"/qa/chat/{thread_id}", - json={"inputs": "This is my query", "parameters": {}}, + json={"query": "This is my query"}, params={"thread_id": thread_id}, ) diff --git a/tests/app/test_config.py b/tests/app/test_config.py index b3543b9..459d457 100644 --- a/tests/app/test_config.py +++ b/tests/app/test_config.py @@ -10,7 +10,7 @@ def test_required(monkeypatch, patch_required_env): assert settings.tools.literature.url == "https://fake_url" assert settings.knowledge_graph.base_url == "https://fake_url/api/nexus/v1" - assert settings.generative.openai.token.get_secret_value() == "dummy" + assert settings.openai.token.get_secret_value() == "dummy" assert settings.knowledge_graph.use_token assert settings.knowledge_graph.token.get_secret_value() == "token" diff --git a/tests/app/test_dependencies.py b/tests/app/test_dependencies.py index 4611047..1acfa4b 100644 --- a/tests/app/test_dependencies.py +++ b/tests/app/test_dependencies.py @@ -91,8 +91,8 @@ def test_get_literature_tool(monkeypatch, patch_required_env): literature_tool = get_literature_tool(token, settings, httpx_client) assert isinstance(literature_tool, LiteratureSearchTool) assert literature_tool.metadata["url"] == url - assert literature_tool.metadata["retriever_k"] == 700 - assert literature_tool.metadata["reranker_k"] == 5 + assert literature_tool.metadata["retriever_k"] == 500 + assert literature_tool.metadata["reranker_k"] == 8 assert literature_tool.metadata["use_reranker"] is True monkeypatch.setenv("NEUROAGENT_TOOLS__LITERATURE__RETRIEVER_K", "30") @@ -163,9 +163,9 @@ async def test_get_memory(patch_required_env, db_connection): def test_language_model(monkeypatch, patch_required_env): - monkeypatch.setenv("NEUROAGENT_GENERATIVE__OPENAI__MODEL", "dummy") - monkeypatch.setenv("NEUROAGENT_GENERATIVE__OPENAI__TEMPERATURE", "99") - monkeypatch.setenv("NEUROAGENT_GENERATIVE__OPENAI__MAX_TOKENS", "99") + monkeypatch.setenv("NEUROAGENT_OPENAI__MODEL", "dummy") + monkeypatch.setenv("NEUROAGENT_OPENAI__TEMPERATURE", "99") + monkeypatch.setenv("NEUROAGENT_OPENAI__MAX_TOKENS", "99") settings = Settings() @@ -217,6 +217,7 @@ def test_get_agent(monkeypatch, patch_required_env): kg_morpho_feature_tool=kg_morpho_feature_tool, electrophys_feature_tool=electrophys_feature_tool, traces_tool=traces_tool, + settings=settings, ) assert isinstance(agent, SimpleAgent) @@ -267,7 +268,6 @@ async def test_get_chat_agent(monkeypatch, db_connection, patch_required_env): electrophys_feature_tool=electrophys_feature_tool, traces_tool=traces_tool, memory=memory, - settings=settings, ) assert isinstance(agent, SimpleChatAgent) diff --git a/tests/app/test_main.py b/tests/app/test_main.py index a21b6fd..79fe9a8 100644 --- a/tests/app/test_main.py +++ b/tests/app/test_main.py @@ -12,7 +12,7 @@ def test_settings_endpoint(app_client, dont_look_at_env_file): replace_secretstr = settings.model_dump() replace_secretstr["keycloak"]["password"] = "**********" - replace_secretstr["generative"]["openai"]["token"] = "**********" + replace_secretstr["openai"]["token"] = "**********" assert response.json() == replace_secretstr diff --git a/tests/app/test_qa.py b/tests/app/test_qa.py index 3847764..3959280 100644 --- a/tests/app/test_qa.py +++ b/tests/app/test_qa.py @@ -21,13 +21,11 @@ def test_run_agent(app_client): agent_mock.arun.return_value = agent_output app.dependency_overrides[get_agent] = lambda: agent_mock - response = app_client.post( - "/qa/run", json={"inputs": "This is my query", "parameters": {}} - ) + response = app_client.post("/qa/run", json={"query": "This is my query"}) assert response.status_code == 200 assert response.json() == agent_output.model_dump() - # Missing inputs + # Missing query response = app_client.post("/qa/run", json={}) assert response.status_code == 422 @@ -55,15 +53,13 @@ def test_run_chat_agent(app_client, tmp_path, patch_required_env): create_output = app_client.post("/threads/").json() response = app_client.post( f"/qa/chat/{create_output['thread_id']}", - json={"inputs": "This is my query", "parameters": {}}, + json={"query": "This is my query"}, ) assert response.status_code == 200 assert response.json() == agent_output.model_dump() - # Missing thread_id inputs - response = app_client.post( - "/qa/chat", json={"inputs": "This is my query", "parameters": {}} - ) + # Missing thread_id query + response = app_client.post("/qa/chat", json={"query": "This is my query"}) assert response.status_code == 404 @@ -123,7 +119,7 @@ def test_chat_streamed(app_client, tmp_path, patch_required_env): create_output = app_client.post("/threads/").json() response = app_client.post( f"/qa/chat_streamed/{create_output['thread_id']}", - json={"inputs": "This is my query", "parameters": {}}, + json={"query": "This is my query"}, ) assert response.status_code == 200 assert response.content == expected_tokens diff --git a/tests/conftest.py b/tests/conftest.py index 0ee31f7..d4c5f6d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,10 +30,8 @@ def client_fixture(): knowledge_graph={ "base_url": "https://fake_url/api/nexus/v1", }, - generative={ - "openai": { - "token": "fake_token", - } + openai={ + "token": "fake_token", }, keycloak={ "username": "fake_username", @@ -59,7 +57,7 @@ def patch_required_env(monkeypatch): monkeypatch.setenv( "NEUROAGENT_KNOWLEDGE_GRAPH__BASE_URL", "https://fake_url/api/nexus/v1" ) - monkeypatch.setenv("NEUROAGENT_GENERATIVE__OPENAI__TOKEN", "dummy") + monkeypatch.setenv("NEUROAGENT_OPENAI__TOKEN", "dummy") monkeypatch.setenv("NEUROAGENT_KNOWLEDGE_GRAPH__TOKEN", "token") monkeypatch.setenv("NEUROAGENT_KNOWLEDGE_GRAPH__USE_TOKEN", "true")