Skip to content

Commit

Permalink
Tools migration (#32)
Browse files Browse the repository at this point in the history
* Migrated bluenaas

* Migration of the rest of the tools

* Some pydantic fixes

* Fixed error when calling agents

* Removed references from neuroagent in swarm_copy

* Remove changelog job

* CR fixes

* Review fixes

* lint

* lint

* MR comments

* Added missing calculated feature

* small cleanup

* MR comments

* Add bluenaas tool + fixes

* Fix mypy

* removed BlueNaaS

* fix KG output

---------

Co-authored-by: Boris Bergsma <[email protected]>
Co-authored-by: Nicolas Frank <[email protected]>
  • Loading branch information
3 people authored Nov 12, 2024
1 parent f048cd4 commit 2fd3786
Show file tree
Hide file tree
Showing 23 changed files with 3,146 additions and 85 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- LLM evaluation logic
- Tool implementations without langchain or langgraph dependencies

## [0.3.3] - 30.10.2024

Expand Down
4 changes: 2 additions & 2 deletions src/neuroagent/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Tools folder."""

from neuroagent.tools.bluenaas_tool import BlueNaaSTool
from neuroagent.tools.electrophys_tool import ElectrophysFeatureTool, FeaturesOutput
from neuroagent.tools.electrophys_tool import ElectrophysFeatureTool, FeatureOutput
from neuroagent.tools.get_me_model_tool import GetMEModelTool
from neuroagent.tools.get_morpho_tool import GetMorphoTool, KnowledgeGraphOutput
from neuroagent.tools.kg_morpho_features_tool import (
Expand All @@ -26,7 +26,7 @@
"BlueNaaSTool",
"BRResolveOutput",
"ElectrophysFeatureTool",
"FeaturesOutput",
"FeatureOutput",
"GetMorphoTool",
"GetTracesTool",
"KGMorphoFeatureOutput",
Expand Down
6 changes: 3 additions & 3 deletions src/neuroagent/tools/electrophys_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ class InputElectrophys(BaseModel):
)


class FeaturesOutput(BaseToolOutput):
class FeatureOutput(BaseToolOutput):
"""Output schema for the neurom tool."""

brain_region: str
Expand Down Expand Up @@ -193,7 +193,7 @@ async def _arun(
calculated_feature: CALCULATED_FEATURES | None = None,
stimuli_types: STIMULI_TYPES | None = None,
amplitude: AmplitudeInput | None = None,
) -> FeaturesOutput | dict[str, str]:
) -> FeatureOutput | dict[str, str]:
"""Give features about trace.
Parameters
Expand Down Expand Up @@ -327,7 +327,7 @@ async def _arun(
output_features[protocol_name]["stimulus_current"] = (
f"{protocol_def['step']['amp']} nA"
)
return FeaturesOutput(
return FeatureOutput(
brain_region=metadata.brain_region, feature_dict=output_features
)
except Exception as e:
Expand Down
156 changes: 153 additions & 3 deletions swarm_copy/app/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,45 @@
from typing import Annotated, Any, AsyncIterator

from fastapi import Depends, HTTPException, Request
from fastapi.security import HTTPBearer
from httpx import AsyncClient, HTTPStatusError
from keycloak import KeycloakOpenID
from openai import AsyncOpenAI
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
from starlette.status import HTTP_401_UNAUTHORIZED

from swarm_copy.app.config import Settings
from swarm_copy.app.database.sql_schemas import Threads
from swarm_copy.cell_types import CellTypesMeta
from swarm_copy.new_types import Agent
from swarm_copy.run import AgentsRoutine
from swarm_copy.tools import PrintAccountDetailsTool
from swarm_copy.tools import (
ElectrophysFeatureTool,
GetMEModelTool,
GetMorphoTool,
GetTracesTool,
KGMorphoFeatureTool,
LiteratureSearchTool,
MorphologyFeatureTool,
ResolveEntitiesTool,
)
from swarm_copy.utils import RegionMeta, get_file_from_KG

logger = logging.getLogger(__name__)


class HTTPBearerDirect(HTTPBearer):
    """HTTPBearer variant whose call yields the raw token string."""

    async def __call__(self, request: Request) -> str | None:  # type: ignore
        """Return the bearer token carried by the request, if any.

        Parameters
        ----------
        request
            Incoming request handed to the parent ``HTTPBearer`` call.

        Returns
        -------
        str | None
            The ``credentials`` attribute of what the parent call returned,
            or ``None`` when the parent call returned nothing truthy.
        """
        credentials = await super().__call__(request)
        if not credentials:
            return None
        return credentials.credentials


# Module-level bearer scheme instance, injected with ``Depends(auth)`` by the
# dependencies in this module. HTTPBearerDirect returns None when the parent
# call yields no credentials, so consumers must cope with a missing token.
auth = HTTPBearerDirect(auto_error=False)


@cache
def get_settings() -> Settings:
"""Get the global settings."""
Expand Down Expand Up @@ -93,12 +120,76 @@ def get_starting_agent(
instructions="""You are a helpful assistant helping scientists with neuro-scientific questions.
You must always specify in your answers from which brain regions the information is extracted.
Do no blindly repeat the brain region requested by the user, use the output of the tools instead.""",
tools=[PrintAccountDetailsTool],
tools=[
LiteratureSearchTool,
ElectrophysFeatureTool,
GetMEModelTool,
GetMorphoTool,
KGMorphoFeatureTool,
MorphologyFeatureTool,
ResolveEntitiesTool,
GetTracesTool,
],
model=settings.openai.model,
)
return agent


async def get_httpx_client(request: Request) -> AsyncIterator[AsyncClient]:
    """Yield an httpx client for the current request and close it afterwards.

    The client is created without a timeout, with certificate verification
    turned off, and with the incoming ``x-request-id`` header set as a
    default header.

    Parameters
    ----------
    request
        Incoming request; its ``x-request-id`` header is read by key.

    Yields
    ------
    AsyncClient
        Client that is closed once the consumer is done with it.
    """
    forwarded_headers = {"x-request-id": request.headers["x-request-id"]}
    httpx_client = AsyncClient(
        timeout=None, verify=False, headers=forwarded_headers
    )
    try:
        yield httpx_client
    finally:
        # Always release the connection pool, even if the consumer raised.
        await httpx_client.aclose()


def get_kg_token(
    settings: Annotated[Settings, Depends(get_settings)],
    token: Annotated[str | None, Depends(auth)],
) -> str:
    """Return the token used to query the Knowledge Graph.

    Parameters
    ----------
    settings
        Application settings; only the ``keycloak`` section is read.
    token
        Bearer token extracted from the request, or ``None``.

    Returns
    -------
    str
        The request's own token when one is present, otherwise the
        ``access_token`` obtained from Keycloak with the configured
        username and password.
    """
    # A token supplied by the caller takes precedence over the configured account.
    if token:
        return token

    keycloak_settings = settings.keycloak
    openid_client = KeycloakOpenID(
        server_url=keycloak_settings.server_url,
        realm_name=keycloak_settings.realm,
        client_id=keycloak_settings.client_id,
    )
    token_response = openid_client.token(
        username=keycloak_settings.username,
        password=keycloak_settings.password.get_secret_value(),  # type: ignore
    )
    return token_response["access_token"]


async def get_user_id(
    token: Annotated[str | None, Depends(auth)],
    settings: Annotated[Settings, Depends(get_settings)],
    httpx_client: Annotated[AsyncClient, Depends(get_httpx_client)],
) -> str:
    """Validate the bearer token and return the user ID.

    Parameters
    ----------
    token
        Bearer token extracted from the request. It may be ``None`` because
        the ``auth`` scheme is created with ``auto_error=False``.
    settings
        Application settings; only the ``keycloak`` section is read.
    httpx_client
        Client used to call the user info endpoint.

    Returns
    -------
    str
        The ``sub`` field of the user info response, or ``"dev"`` when token
        validation is disabled or no user info endpoint is configured.

    Raises
    ------
    HTTPException
        With status 401 when validation is enabled and the token is missing
        or the user info endpoint answers with an error status.
    """
    keycloak_settings = settings.keycloak
    if not (
        keycloak_settings.validate_token and keycloak_settings.user_info_endpoint
    ):
        return "dev"

    if not token:
        # Reject early instead of sending "Bearer None" to the endpoint.
        raise HTTPException(
            status_code=HTTP_401_UNAUTHORIZED, detail="Invalid token."
        )

    try:
        response = await httpx_client.get(
            keycloak_settings.user_info_endpoint,
            headers={"Authorization": f"Bearer {token}"},
        )
        response.raise_for_status()
    except HTTPStatusError as err:
        # Chain explicitly so the upstream status error stays in the traceback.
        raise HTTPException(
            status_code=HTTP_401_UNAUTHORIZED, detail="Invalid token."
        ) from err

    user_info = response.json()
    return user_info["sub"]


# TEMP function, will get replaced by the CRUDs.
async def get_thread_id(
session: Annotated[AsyncSession, Depends(get_session)],
Expand All @@ -123,13 +214,72 @@ async def get_thread_id(
def get_context_variables(
    settings: Annotated[Settings, Depends(get_settings)],
    starting_agent: Annotated[Agent, Depends(get_starting_agent)],
    token: Annotated[str, Depends(get_kg_token)],
    httpx_client: Annotated[AsyncClient, Depends(get_httpx_client)],
) -> dict[str, Any]:
    """Get the global context variables to feed the tool's metadata.

    Parameters
    ----------
    settings
        Application settings; the ``tools`` and ``knowledge_graph`` sections
        are read.
    starting_agent
        Agent stored under the ``starting_agent`` key.
    token
        Token stored under the ``token`` key.
    httpx_client
        Client stored under the ``httpx_client`` key.

    Returns
    -------
    dict[str, Any]
        Mapping holding the agent, the token, the httpx client and the
        settings values listed below.
    """
    # The former placeholder return ({"user_id": 1234, ...}) is dropped: it
    # preceded this dict and made it unreachable.
    tool_settings = settings.tools
    kg_settings = settings.knowledge_graph
    literature = tool_settings.literature
    return {
        "starting_agent": starting_agent,
        "token": token,
        "retriever_k": literature.retriever_k,
        "reranker_k": literature.reranker_k,
        "use_reranker": literature.use_reranker,
        "literature_search_url": literature.url,
        "knowledge_graph_url": kg_settings.url,
        "me_model_search_size": tool_settings.me_model.search_size,
        "brainregion_path": kg_settings.br_saving_path,
        "celltypes_path": kg_settings.ct_saving_path,
        "morpho_search_size": tool_settings.morpho.search_size,
        "kg_morpho_feature_search_size": tool_settings.kg_morpho_features.search_size,
        "trace_search_size": tool_settings.trace.search_size,
        "kg_sparql_url": kg_settings.sparql_url,
        "kg_class_view_url": kg_settings.class_view_url,
        "httpx_client": httpx_client,
    }


def get_agents_routine(
    openai: Annotated[AsyncOpenAI | None, Depends(get_openai_client)],
) -> AgentsRoutine:
    """Build the AgentsRoutine around the injected OpenAI client.

    Parameters
    ----------
    openai
        Client resolved by ``get_openai_client``; may be ``None``.

    Returns
    -------
    AgentsRoutine
        Routine constructed with ``openai`` as its only argument.
    """
    routine = AgentsRoutine(openai)
    return routine


async def get_update_kg_hierarchy(
    token: Annotated[str, Depends(get_kg_token)],
    httpx_client: Annotated[AsyncClient, Depends(get_httpx_client)],
    settings: Annotated[Settings, Depends(get_settings)],
    file_name: str = "brainregion.json",
) -> None:
    """Fetch the brain region hierarchy from the KG and save it locally.

    Parameters
    ----------
    token
        Token forwarded to ``get_file_from_KG``.
    httpx_client
        Client forwarded to ``get_file_from_KG``.
    settings
        Application settings; only the ``knowledge_graph`` section is read.
    file_name
        File name forwarded to ``get_file_from_KG``.
    """
    kg_settings = settings.knowledge_graph
    hierarchy_dict = await get_file_from_KG(
        file_url=f"<{kg_settings.hierarchy_url}/brainregion>",
        file_name=file_name,
        view_url=kg_settings.sparql_url,
        token=token,
        httpx_client=httpx_client,
    )
    # Convert the fetched dict to a RegionMeta and save it at the configured path.
    RegionMeta.from_KG_dict(hierarchy_dict).save_config(kg_settings.br_saving_path)
    logger.info("Knowledge Graph Brain Regions Hierarchy file updated.")


async def get_cell_types_kg_hierarchy(
    token: Annotated[str, Depends(get_kg_token)],
    httpx_client: Annotated[AsyncClient, Depends(get_httpx_client)],
    settings: Annotated[Settings, Depends(get_settings)],
    file_name: str = "celltypes.json",
) -> None:
    """Fetch the cell types hierarchy from the KG and save it locally.

    Parameters
    ----------
    token
        Token forwarded to ``get_file_from_KG``.
    httpx_client
        Client forwarded to ``get_file_from_KG``.
    settings
        Application settings; only the ``knowledge_graph`` section is read.
    file_name
        File name forwarded to ``get_file_from_KG``.
    """
    kg_settings = settings.knowledge_graph
    hierarchy_dict = await get_file_from_KG(
        file_url=f"<{kg_settings.hierarchy_url}/celltypes>",
        file_name=file_name,
        view_url=kg_settings.sparql_url,
        token=token,
        httpx_client=httpx_client,
    )
    # Convert the fetched dict to a CellTypesMeta and save it at the configured path.
    CellTypesMeta.from_dict(hierarchy_dict).save_config(kg_settings.ct_saving_path)
    logger.info("Knowledge Graph Cell Types Hierarchy file updated.")
19 changes: 18 additions & 1 deletion swarm_copy/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,17 @@
from asgi_correlation_id import CorrelationIdMiddleware
from fastapi import Depends, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from httpx import AsyncClient

from swarm_copy.app.app_utils import setup_engine
from swarm_copy.app.config import Settings
from swarm_copy.app.database.sql_schemas import Base
from swarm_copy.app.dependencies import (
get_cell_types_kg_hierarchy,
get_connection_string,
get_kg_token,
get_settings,
get_update_kg_hierarchy,
)
from swarm_copy.app.routers import qa

Expand Down Expand Up @@ -72,9 +76,22 @@ async def lifespan(fastapi_app: FastAPI) -> AsyncContextManager[None]: # type:
await conn.run_sync(Base.metadata.create_all)

logging.getLogger().setLevel(app_settings.logging.external_packages.upper())
logging.getLogger("neuroagent").setLevel(app_settings.logging.level.upper())
logging.getLogger("swarm_copy").setLevel(app_settings.logging.level.upper())
logging.getLogger("bluepyefe").setLevel("CRITICAL")

if app_settings.knowledge_graph.download_hierarchy:
# update KG hierarchy file if requested
await get_update_kg_hierarchy(
token=get_kg_token(app_settings, token=None),
httpx_client=AsyncClient(),
settings=app_settings,
)
await get_cell_types_kg_hierarchy(
token=get_kg_token(app_settings, token=None),
httpx_client=AsyncClient(),
settings=app_settings,
)

yield
if engine:
await engine.dispose()
Expand Down
Loading

0 comments on commit 2fd3786

Please sign in to comment.