Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Human In the Loop for the BlueNaaS tool #20

Merged
merged 7 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- BlueNaaS simulation tool.
- Validation of the project ID.
- BlueNaaS tool test.
- Human in the loop for bluenaas.

## [0.1.1] - 26.09.2024

Expand Down
4 changes: 2 additions & 2 deletions src/neuroagent/app/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,6 @@ async def get_vlab_and_project(
def get_agent(
_: Annotated[None, Depends(get_vlab_and_project)],
llm: Annotated[ChatOpenAI, Depends(get_language_model)],
bluenaas_tool: Annotated[BlueNaaSTool, Depends(get_bluenaas_tool)],
literature_tool: Annotated[LiteratureSearchTool, Depends(get_literature_tool)],
br_resolver_tool: Annotated[
ResolveBrainRegionTool, Depends(get_brain_region_resolver_tool)
Expand Down Expand Up @@ -454,7 +453,6 @@ def get_agent(
return SupervisorMultiAgent(llm=llm, agents=tools_list) # type: ignore
else:
tools = [
bluenaas_tool,
literature_tool,
br_resolver_tool,
morpho_tool,
Expand All @@ -481,6 +479,7 @@ def get_chat_agent(
morphology_feature_tool: Annotated[
MorphologyFeatureTool, Depends(get_morphology_feature_tool)
],
me_model_tool: Annotated[GetMEModelTool, Depends(get_me_model_tool)],
kg_morpho_feature_tool: Annotated[
KGMorphoFeatureTool, Depends(get_kg_morpho_feature_tool)
],
Expand All @@ -495,6 +494,7 @@ def get_chat_agent(
bluenaas_tool,
literature_tool,
br_resolver_tool,
me_model_tool,
morpho_tool,
morphology_feature_tool,
kg_morpho_feature_tool,
Expand Down
118 changes: 103 additions & 15 deletions src/neuroagent/tools/bluenaas_tool.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
"""BlueNaaS single cell stimulation, simulation and synapse placement tool."""

import json
import logging
from typing import Any, Literal
from typing import Annotated, Any, Literal

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
from langchain_core.tools import ToolException
from langgraph.prebuilt import InjectedState
from pydantic import BaseModel, Field

from neuroagent.tools.base_tool import BaseToolOutput, BasicTool
Expand All @@ -29,7 +32,6 @@ class InputBlueNaaS(BaseModel):
" fetched using the 'get-me-model-tool'."
)
)

current_injection__inject_to: str = Field(
default="soma[0]", description="Section to inject the current to."
)
Expand Down Expand Up @@ -64,31 +66,47 @@ class InputBlueNaaS(BaseModel):
default=0.05, ge=0.001, le=10, description="Time step in ms"
)
conditions__seed: int = Field(default=100, description="Random seed")
messages: Annotated[list[BaseMessage], InjectedState("messages")]


class BlueNaaSOutput(BaseToolOutput):
class BlueNaaSValidatedOutput(BaseToolOutput):
"""Should return a successful POST request."""

status: Literal["success", "pending", "error"]


class BlueNaaSInvalidatedOutput(BaseModel):
"""Response to the user if the simulation has not been validated yet."""

inputs: dict[str, Any]

def __str__(self) -> str:
"""Format the response passed to the LLM."""
return f"A simulation will be ran with the following inputs <json>{self.inputs}</json>. \n Please confirm that you are satisfied by the simulation parameters, or correct them accordingly."


class BlueNaaSTool(BasicTool):
"""Class defining the BlueNaaS tool."""

name: str = "bluenaas-tool"
description: str = """Runs a single-neuron simulation using the BlueNaaS service.
Requires a "me_model_id" which must be fetched by get-me-model-tool.
Optionally, the user can specify simulation parameters.
The tool will always ask for config validation from the user before running.
If the user mentions an existing configuration, it must always be passed in the tool first to get user's approval.
Specify ALL of the parameters everytime you enter this tool.
"""
metadata: dict[str, Any]
args_schema: type[BaseModel] = InputBlueNaaS
response_format: Literal["content", "content_and_artifact"] = "content_and_artifact"

def _run(self) -> None:
pass

async def _arun(
self,
me_model_id: str,
messages: Annotated[list[BaseMessage], InjectedState("messages")],
current_injection__inject_to: str = "soma[0]",
current_injection__stimulus__stimulus_type: Literal[
"current_clamp", "voltage_clamp", "conductance"
Expand All @@ -104,9 +122,10 @@ async def _arun(
conditions__max_time: int = 100,
conditions__time_step: float = 0.05,
conditions__seed: int = 100,
) -> BaseToolOutput:
) -> tuple[BaseToolOutput | BaseModel, dict[str, bool]]:
"""Run the BlueNaaS tool."""
logger.info("Running BlueNaaS tool")

json_api = self.create_json_api(
current_injection__inject_to=current_injection__inject_to,
current_injection__stimulus__stimulus_type=current_injection__stimulus__stimulus_type,
Expand All @@ -117,22 +136,28 @@ async def _arun(
conditions__vinit=conditions__vinit,
conditions__hypamp=conditions__hypamp,
conditions__max_time=conditions__max_time,
conditions__time_step=conditions__time_step,
conditions__seed=conditions__seed,
)

try:
_ = await self.metadata["httpx_client"].post(
url=self.metadata["url"],
params={"model_id": me_model_id},
headers={"Authorization": f'Bearer {self.metadata["token"]}'},
json=json_api,
timeout=5.0,
)
if self.is_validated(messages, json_api):
try:
await self.metadata["httpx_client"].post(
url=self.metadata["url"],
params={"model_id": me_model_id},
headers={"Authorization": f'Bearer {self.metadata["token"]}'},
json=json_api,
timeout=5.0,
)

return BlueNaaSOutput(status="success")
return BlueNaaSValidatedOutput(status="success"), {
"is_validated": False
}

except Exception as e:
raise ToolException(str(e), self.name)
except Exception as e:
raise ToolException(str(e), self.name)
else:
return BlueNaaSInvalidatedOutput(inputs=json_api), {"is_validated": True}

@staticmethod
def create_json_api(
Expand Down Expand Up @@ -179,3 +204,66 @@ def create_json_api(
"simulationDuration": conditions__max_time,
}
return json_api

@staticmethod
def is_validated(messages: list[BaseMessage], json_api: dict[str, Any]) -> bool:
"""Decide whether the current configuration has been validated by the user.

Parameters
----------
messages
List of Langgraph messages extracted from the graph state.
json_api
Simulation configuration that the tool will run if it has been validated.

Returns
-------
is_validated
Boolean stating wether or not the configuration has been validated by the user.
"""
# If it is the first time the tool is called in the conversation, need validation
try:
# Get the last bluenaas call
last_bluenaas_call = next(
(
message
for message in reversed(messages)
if isinstance(message, ToolMessage)
and message.name == "bluenaas-tool"
)
)
except StopIteration:
return False

# Verify if the tool has been recently called to ask for validation
# There has to be at least 3 messages in the state otherwise cannot be validated
if len(messages) > 3:
last_messages = messages[-4:-1]
recently_validated = (
isinstance(last_messages[-1], HumanMessage) # Approval from the human
and isinstance(
last_messages[-2], AIMessage
) # AI answering the human and asking for validation
and isinstance(
last_messages[-3], ToolMessage
) # First tool call not validated
and last_messages[-3].name == "bluenaas-tool"
)
# If it hasn't been recently validated, we need more validation
if not recently_validated:
return False
# If there is not enough messages in the state to have a potential validation
else:
return False

# If the previous simulation was started, ask for validation on the new one
if not last_bluenaas_call.artifact.get("is_validated"):
return False

# Verify if the config has changed since previous call. Validate again if so
old_config = json.loads(
last_bluenaas_call.content.split("<json>")[-1] # type: ignore
.split("</json>")[0]
.replace("'", '"')
)
return old_config == json_api
8 changes: 4 additions & 4 deletions tests/app/test_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,9 +337,6 @@ async def test_get_agent(monkeypatch, httpx_mock, patch_required_env):
)

language_model = get_language_model(settings)
bluenaas_tool = get_bluenaas_tool(
settings=settings, token=token, httpx_client=httpx_client
)
literature_tool = get_literature_tool(
token=token, settings=settings, httpx_client=httpx_client
)
Expand Down Expand Up @@ -370,7 +367,6 @@ async def test_get_agent(monkeypatch, httpx_mock, patch_required_env):
agent = get_agent(
valid_project,
llm=language_model,
bluenaas_tool=bluenaas_tool,
literature_tool=literature_tool,
br_resolver_tool=br_resolver_tool,
morpho_tool=morpho_tool,
Expand Down Expand Up @@ -416,6 +412,9 @@ async def test_get_chat_agent(
literature_tool = get_literature_tool(
token=token, settings=settings, httpx_client=httpx_client
)
me_model_tool = get_me_model_tool(
settings=settings, token=token, httpx_client=httpx_client
)
morpho_tool = get_morpho_tool(
settings=settings, token=token, httpx_client=httpx_client
)
Expand Down Expand Up @@ -447,6 +446,7 @@ async def test_get_chat_agent(
br_resolver_tool=br_resolver_tool,
morpho_tool=morpho_tool,
morphology_feature_tool=morphology_feature_tool,
me_model_tool=me_model_tool,
kg_morpho_feature_tool=kg_morpho_feature_tool,
electrophys_feature_tool=electrophys_feature_tool,
traces_tool=traces_tool,
Expand Down
Loading
Loading