Skip to content

Commit

Permalink
Refactor iteration 2
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicolas Frank committed Sep 17, 2024
1 parent 9524d96 commit 3c550ba
Show file tree
Hide file tree
Showing 14 changed files with 79 additions and 165 deletions.
11 changes: 4 additions & 7 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ NEUROAGENT_GENERATIVE__OPENAI__TOKEN=

# Important but not required
NEUROAGENT_AGENT__MODEL=
NEUROAGENT_AGENT__CHAT=

NEUROAGENT_KNOWLEDGE_GRAPH__USE_TOKEN=
NEUROAGENT_KNOWLEDGE_GRAPH__TOKEN=
NEUROAGENT_KNOWLEDGE_GRAPH__DOWNLOAD_HIERARCHY=
Expand All @@ -27,12 +27,9 @@ NEUROAGENT_TOOLS__TRACE__SEARCH_SIZE=

NEUROAGENT_TOOLS__KG_MORPHO__SEARCH_SIZE=

NEUROAGENT_GENERATIVE__LLM_TYPE= # can only be openai for now
NEUROAGENT_GENERATIVE__OPENAI__MODEL=
NEUROAGENT_GENERATIVE__OPENAI__TEMPERATURE=
NEUROAGENT_GENERATIVE__OPENAI__MAX_TOKENS=

NEUROAGENT_COHERE__TOKEN=
NEUROAGENT_OPENAI__MODEL=
NEUROAGENT_OPENAI__TEMPERATURE=
NEUROAGENT_OPENAI__MAX_TOKENS=

NEUROAGENT_LOGGING__LEVEL=
NEUROAGENT_LOGGING__EXTERNAL_PACKAGES=
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed
- Migration to pydantic V2.
- Deleted some legacy code.
52 changes: 0 additions & 52 deletions src/neuroagent/agents/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,61 +4,9 @@
from typing import Any, AsyncIterator

from langchain.chat_models.base import BaseChatModel
from langchain_core.messages import (
AIMessage,
ChatMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from langchain_core.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
PromptTemplate,
SystemMessagePromptTemplate,
)
from langchain_core.tools import BaseTool
from pydantic import BaseModel, ConfigDict

BASE_PROMPT = ChatPromptTemplate(
input_variables=["agent_scratchpad", "input"],
input_types={
"chat_history": list[
AIMessage
| HumanMessage
| ChatMessage
| SystemMessage
| FunctionMessage
| ToolMessage
],
"agent_scratchpad": list[
AIMessage
| HumanMessage
| ChatMessage
| SystemMessage
| FunctionMessage
| ToolMessage
],
},
messages=[
SystemMessagePromptTemplate(
prompt=PromptTemplate(
input_variables=[],
template="""You are a helpful assistant helping scientists with neuro-scientific questions.
You must always specify in your answers from which brain regions the information is extracted.
Do no blindly repeat the brain region requested by the user, use the output of the tools instead.""",
)
),
MessagesPlaceholder(variable_name="chat_history", optional=True),
HumanMessagePromptTemplate(
prompt=PromptTemplate(input_variables=["input"], template="{input}")
),
MessagesPlaceholder(variable_name="agent_scratchpad"),
],
)


class AgentStep(BaseModel):
"""Class for agent decision steps."""
Expand Down
31 changes: 4 additions & 27 deletions src/neuroagent/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
class SettingsAgent(BaseModel):
"""Agent setting."""

model: str = "simple"
chat: str = "simple"
model: Literal["simple", "multi"] = "simple"

model_config = ConfigDict(frozen=True)

Expand Down Expand Up @@ -84,9 +83,9 @@ class SettingsLiterature(BaseModel):
"""Literature search API settings."""

url: str
retriever_k: int = 700
retriever_k: int = 500
use_reranker: bool = True
reranker_k: int = 5
reranker_k: int = 8

model_config = ConfigDict(frozen=True)

Expand Down Expand Up @@ -173,23 +172,6 @@ class SettingsOpenAI(BaseModel):
model_config = ConfigDict(frozen=True)


class SettingsGenerative(BaseModel):
"""Generative QA settings."""

llm_type: Literal["fake", "openai"] = "openai"
openai: SettingsOpenAI = SettingsOpenAI()

model_config = ConfigDict(frozen=True)


class SettingsCohere(BaseModel):
"""Settings cohere reranker."""

token: Optional[SecretStr] = None

model_config = ConfigDict(frozen=True)


class SettingsLogging(BaseModel):
"""Metadata settings."""

Expand Down Expand Up @@ -219,8 +201,7 @@ class Settings(BaseSettings):
knowledge_graph: SettingsKnowledgeGraph
agent: SettingsAgent = SettingsAgent() # has no required
db: SettingsDB = SettingsDB() # has no required
generative: SettingsGenerative = SettingsGenerative() # has no required
cohere: SettingsCohere = SettingsCohere() # has no required
openai: SettingsOpenAI = SettingsOpenAI() # has no required
logging: SettingsLogging = SettingsLogging() # has no required
keycloak: SettingsKeycloak = SettingsKeycloak() # has no required
misc: SettingsMisc = SettingsMisc() # has no required
Expand All @@ -240,10 +221,6 @@ def check_consistency(self) -> "Settings":
model validator is run during instantiation.
"""
# generative
if self.generative.llm_type == "openai":
if self.generative.openai.token is None:
raise ValueError("OpenAI token not provided")
if not self.keycloak.password and not self.keycloak.validate_token:
if not self.knowledge_graph.use_token:
raise ValueError("if no password is provided, please use token auth.")
Expand Down
84 changes: 42 additions & 42 deletions src/neuroagent/app/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,12 +308,12 @@ def get_language_model(
settings: Annotated[Settings, Depends(get_settings)],
) -> ChatOpenAI:
"""Get the language model."""
logger.info(f"OpenAI selected. Loading model {settings.generative.openai.model}.")
logger.info(f"OpenAI selected. Loading model {settings.openai.model}.")
return ChatOpenAI(
model_name=settings.generative.openai.model,
temperature=settings.generative.openai.temperature,
openai_api_key=settings.generative.openai.token.get_secret_value(), # type: ignore
max_tokens=settings.generative.openai.max_tokens,
model_name=settings.openai.model,
temperature=settings.openai.temperature,
openai_api_key=settings.openai.token.get_secret_value(), # type: ignore
max_tokens=settings.openai.max_tokens,
seed=78,
streaming=True,
)
Expand Down Expand Up @@ -369,43 +369,10 @@ def get_agent(
ElectrophysFeatureTool, Depends(get_electrophys_feature_tool)
],
traces_tool: Annotated[GetTracesTool, Depends(get_traces_tool)],
) -> BaseAgent | BaseMultiAgent:
"""Get the generative question answering service."""
tools = [
literature_tool,
br_resolver_tool,
morpho_tool,
morphology_feature_tool,
kg_morpho_feature_tool,
electrophys_feature_tool,
traces_tool,
]
logger.info("Load simple agent")
return SimpleAgent(llm=llm, tools=tools) # type: ignore


def get_chat_agent(
llm: Annotated[ChatOpenAI, Depends(get_language_model)],
memory: Annotated[BaseCheckpointSaver, Depends(get_agent_memory)],
literature_tool: Annotated[LiteratureSearchTool, Depends(get_literature_tool)],
br_resolver_tool: Annotated[
ResolveBrainRegionTool, Depends(get_brain_region_resolver_tool)
],
morpho_tool: Annotated[GetMorphoTool, Depends(get_morpho_tool)],
morphology_feature_tool: Annotated[
MorphologyFeatureTool, Depends(get_morphology_feature_tool)
],
kg_morpho_feature_tool: Annotated[
KGMorphoFeatureTool, Depends(get_kg_morpho_feature_tool)
],
electrophys_feature_tool: Annotated[
ElectrophysFeatureTool, Depends(get_electrophys_feature_tool)
],
traces_tool: Annotated[GetTracesTool, Depends(get_traces_tool)],
settings: Annotated[Settings, Depends(get_settings)],
) -> BaseAgent:
) -> BaseAgent | BaseMultiAgent:
"""Get the generative question answering service."""
if settings.agent.chat == "multi":
if settings.agent.model == "multi":
logger.info("Load multi-agent chat")
tools_list = [
("literature", [literature_tool]),
Expand All @@ -422,7 +389,6 @@ def get_chat_agent(
]
return SupervisorMultiAgent(llm=llm, agents=tools_list) # type: ignore
else:
logger.info("Load simple chat")
tools = [
literature_tool,
br_resolver_tool,
Expand All @@ -432,7 +398,41 @@ def get_chat_agent(
electrophys_feature_tool,
traces_tool,
]
return SimpleChatAgent(llm=llm, tools=tools, memory=memory) # type: ignore
logger.info("Load simple agent")
return SimpleAgent(llm=llm, tools=tools) # type: ignore


def get_chat_agent(
llm: Annotated[ChatOpenAI, Depends(get_language_model)],
memory: Annotated[BaseCheckpointSaver, Depends(get_agent_memory)],
literature_tool: Annotated[LiteratureSearchTool, Depends(get_literature_tool)],
br_resolver_tool: Annotated[
ResolveBrainRegionTool, Depends(get_brain_region_resolver_tool)
],
morpho_tool: Annotated[GetMorphoTool, Depends(get_morpho_tool)],
morphology_feature_tool: Annotated[
MorphologyFeatureTool, Depends(get_morphology_feature_tool)
],
kg_morpho_feature_tool: Annotated[
KGMorphoFeatureTool, Depends(get_kg_morpho_feature_tool)
],
electrophys_feature_tool: Annotated[
ElectrophysFeatureTool, Depends(get_electrophys_feature_tool)
],
traces_tool: Annotated[GetTracesTool, Depends(get_traces_tool)],
) -> BaseAgent:
"""Get the generative question answering service."""
logger.info("Load simple chat")
tools = [
literature_tool,
br_resolver_tool,
morpho_tool,
morphology_feature_tool,
kg_morpho_feature_tool,
electrophys_feature_tool,
traces_tool,
]
return SimpleChatAgent(llm=llm, tools=tools, memory=memory) # type: ignore


async def get_update_kg_hierarchy(
Expand Down
12 changes: 6 additions & 6 deletions src/neuroagent/app/routers/qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ async def run_agent(
) -> AgentOutput:
"""Run agent."""
logger.info("Running agent query.")
logger.info(f"User's query: {request.inputs}")
return await agent.arun(request.inputs)
logger.info(f"User's query: {request.query}")
return await agent.arun(request.query)


@router.post("/chat/{thread_id}", response_model=AgentOutput)
Expand All @@ -47,8 +47,8 @@ async def run_chat_agent(
) -> AgentOutput:
"""Run chat agent."""
logger.info("Running agent query.")
logger.info(f"User's query: {request.inputs}")
return await agent.arun(query=request.inputs, thread_id=thread_id)
logger.info(f"User's query: {request.query}")
return await agent.arun(query=request.query, thread_id=thread_id)


@router.post("/chat_streamed/{thread_id}")
Expand All @@ -60,5 +60,5 @@ async def run_streamed_chat_agent(
) -> StreamingResponse:
"""Run agent in streaming mode."""
logger.info("Running agent query.")
logger.info(f"User's query: {request.inputs}")
return StreamingResponse(agent.astream(query=request.inputs, thread_id=thread_id)) # type: ignore
logger.info(f"User's query: {request.query}")
return StreamingResponse(agent.astream(query=request.query, thread_id=thread_id)) # type: ignore
5 changes: 1 addition & 4 deletions src/neuroagent/app/schemas.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
"""Schemas."""

from typing import Any

from pydantic import BaseModel


class AgentRequest(BaseModel):
"""Class for agent request."""

inputs: str
parameters: dict[str, Any]
query: str
4 changes: 2 additions & 2 deletions tests/app/database/test_threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ async def test_get_thread(
# Fill the thread
app_client.post(
f"/qa/chat/{thread_id}",
json={"inputs": "This is my query", "parameters": {}},
json={"query": "This is my query"},
)

create_output = app_client.post("/threads/").json()
Expand Down Expand Up @@ -131,7 +131,7 @@ async def test_delete_thread(
# Fill the thread
app_client.post(
f"/qa/chat/{thread_id}",
json={"inputs": "This is my query", "parameters": {}},
json={"query": "This is my query"},
params={"thread_id": thread_id},
)
# Get the messages of the thread
Expand Down
4 changes: 2 additions & 2 deletions tests/app/database/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ async def test_get_tool_calls(
# Fill the thread
app_client.post(
f"/qa/chat/{thread_id}",
json={"inputs": "This is my query", "parameters": {}},
json={"query": "This is my query"},
params={"thread_id": thread_id},
)

Expand Down Expand Up @@ -121,7 +121,7 @@ async def test_get_tool_output(
# Fill the thread
app_client.post(
f"/qa/chat/{thread_id}",
json={"inputs": "This is my query", "parameters": {}},
json={"query": "This is my query"},
params={"thread_id": thread_id},
)

Expand Down
2 changes: 1 addition & 1 deletion tests/app/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_required(monkeypatch, patch_required_env):

assert settings.tools.literature.url == "https://fake_url"
assert settings.knowledge_graph.base_url == "https://fake_url/api/nexus/v1"
assert settings.generative.openai.token.get_secret_value() == "dummy"
assert settings.openai.token.get_secret_value() == "dummy"
assert settings.knowledge_graph.use_token
assert settings.knowledge_graph.token.get_secret_value() == "token"

Expand Down
Loading

0 comments on commit 3c550ba

Please sign in to comment.