diff --git a/swarm_copy/run.py b/swarm_copy/run.py index 0b0f418..93b275d 100644 --- a/swarm_copy/run.py +++ b/swarm_copy/run.py @@ -145,49 +145,92 @@ async def handle_tool_call( tool_metadata = tool.__annotations__["metadata"](**context_variables) - logging.info("Calling tool: {}".format(name)) - parameters = { - "brain region": input_schema.brain_region, - "mtype": input_schema.mtype, - "etype": input_schema.etype - } - phantasm = Phantasm( - host="localhost", - port=2505, - ) - response = phantasm.get_approval( - name="Resolve Entities", - parameters=parameters - ) - if response.approved: - logging.error(f"Usage of tool {name} was approved.") - tool_instance = tool(input_schema=input_schema, metadata=tool_metadata) - # pass context_variables to agent functions - try: - raw_result = await tool_instance.arun() - except Exception as err: - response = { + if name == "resolve-entities-tool": + parameters = { + "brain region": input_schema.brain_region, + "mtype": input_schema.mtype, + "etype": input_schema.etype + } + phantasm = Phantasm( + host="localhost", + port=2505, + ) + response = phantasm.get_approval( + name="Resolve Entities", + parameters=parameters + ) + logging.error(f"Response: {response}") + if response.approved: + logging.error(f"Usage of tool {name} was approved. New parameters: {response.parameters}") + input_schema.brain_region = response.parameters["brain region"] + input_schema.mtype = response.parameters["mtype"] + input_schema.etype = response.parameters["etype"] + result = await self.call_tool(input_schema, name, tool, tool_call, tool_metadata) + logging.error(f"Result: {result}") + return result + else: + logging.error(f"Usage of tool {name} was NOT approved.") + return { + "role": "tool", + "tool_call_id": tool_call.id, + "tool_name": name, + "content": f"Error: usage of tool '{name}' was denied.", + }, None + elif name == "get-traces-tool": + parameters = { + "brain region id": input_schema.brain_region_id, + "etype id": input_schema.etype_id + } + phantasm = Phantasm( + host="localhost", + port=2505, + ) + response = phantasm.get_approval( + name="Get Traces", + parameters=parameters + ) + logging.error(f"Response: {response}") + if response.approved: + logging.error(f"Usage of tool {name} was approved. New parameters: {response.parameters}") + input_schema.brain_region_id = response.parameters["brain region id"] + input_schema.etype_id = response.parameters["etype id"] + return await self.call_tool(input_schema, name, tool, tool_call, tool_metadata) + else: + logging.error(f"Usage of tool {name} was NOT approved.") + return { "role": "tool", "tool_call_id": tool_call.id, "tool_name": name, - "content": str(err), - } - return response, None + "content": f"Error: usage of tool '{name}' was denied.", + }, None + else: + return await self.call_tool(input_schema, name, tool, tool_call, tool_metadata) - result: Result = self.handle_function_result(raw_result) + async def call_tool(self, input_schema, name, tool, tool_call, tool_metadata): + tool_instance = tool(input_schema=input_schema, metadata=tool_metadata) + # pass context_variables to agent functions + try: + raw_result = await tool_instance.arun() + except Exception as err: response = { "role": "tool", "tool_call_id": tool_call.id, "tool_name": name, - "content": result.value, + "content": str(err), } - if result.agent: - agent = result.agent - else: - agent = None - return response, agent + return response, None + result: Result = self.handle_function_result(raw_result) + response = { + "role": "tool", + "tool_call_id": tool_call.id, + "tool_name": name, + "content": result.value, + } + if result.agent: + agent = result.agent else: - logging.error(f"Usage of tool {name} was NOT approved.") + agent = None + return response, agent async def arun( self, diff --git a/swarm_copy/tools/base_tool.py b/swarm_copy/tools/base_tool.py index 71af055..de94dc5 100644 --- a/swarm_copy/tools/base_tool.py +++ b/swarm_copy/tools/base_tool.py @@ -83,3 +83,7 @@ def pydantic_to_openai_schema(cls) -> ChatCompletionToolParam: @abstractmethod async def arun(self) -> Any: """Run the tool.""" + + +class HILToolOutput(BaseModel): + status: str = "approved" diff --git a/swarm_copy/tools/resolve_entities_tool.py b/swarm_copy/tools/resolve_entities_tool.py index 1264ac1..7247be0 100644 --- a/swarm_copy/tools/resolve_entities_tool.py +++ b/swarm_copy/tools/resolve_entities_tool.py @@ -10,7 +10,7 @@ ETYPE_IDS, BaseMetadata, BaseTool, - EtypesLiteral, + EtypesLiteral, HILToolOutput, ) logger = logging.getLogger(__name__) @@ -41,21 +41,21 @@ class ResolveBRInput(BaseModel): ) -class BRResolveOutput(BaseModel): +class BRResolveOutput(HILToolOutput): """Output schema for the Brain region resolver.""" brain_region_name: str brain_region_id: str -class MTypeResolveOutput(BaseModel): +class MTypeResolveOutput(HILToolOutput): """Output schema for the Mtype resolver.""" mtype_name: str mtype_id: str -class EtypeResolveOutput(BaseModel): +class EtypeResolveOutput(HILToolOutput): """Output schema for the Mtype resolver.""" etype_name: str diff --git a/swarm_copy/tools/traces_tool.py b/swarm_copy/tools/traces_tool.py index 41028b2..f9882c4 100644 --- a/swarm_copy/tools/traces_tool.py +++ b/swarm_copy/tools/traces_tool.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, Field -from swarm_copy.tools.base_tool import BaseMetadata, BaseTool +from swarm_copy.tools.base_tool import BaseMetadata, BaseTool, HILToolOutput from swarm_copy.utils import get_descendants_id logger = logging.getLogger(__name__) @@ -25,7 +25,7 @@ class GetTracesInput(BaseModel): ) -class TracesOutput(BaseModel): +class TracesOutput(HILToolOutput): """Output schema for the traces.""" trace_id: str