Skip to content

Commit

Permalink
Added more tools for approval
Browse files Browse the repository at this point in the history
  • Loading branch information
cszsol committed Nov 14, 2024
1 parent 3c0133a commit 9810596
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 39 deletions.
109 changes: 76 additions & 33 deletions swarm_copy/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,49 +145,92 @@ async def handle_tool_call(

tool_metadata = tool.__annotations__["metadata"](**context_variables)

logging.info("Calling tool: {}".format(name))
parameters = {
"brain region": input_schema.brain_region,
"mtype": input_schema.mtype,
"etype": input_schema.etype
}
phantasm = Phantasm(
host="localhost",
port=2505,
)
response = phantasm.get_approval(
name="Resolve Entities",
parameters=parameters
)
if response.approved:
logging.error(f"Usage of tool {name} was approved.")
tool_instance = tool(input_schema=input_schema, metadata=tool_metadata)
# pass context_variables to agent functions
try:
raw_result = await tool_instance.arun()
except Exception as err:
response = {
if name == "resolve-entities-tool":
parameters = {
"brain region": input_schema.brain_region,
"mtype": input_schema.mtype,
"etype": input_schema.etype
}
phantasm = Phantasm(
host="localhost",
port=2505,
)
response = phantasm.get_approval(
name="Resolve Entities",
parameters=parameters
)
logging.error(f"Response: {response}")
if response.approved:
logging.error(f"Usage of tool {name} was approved. New parameters: {response.parameters}")
input_schema.brain_region = response.parameters["brain region"]
input_schema.mtype = response.parameters["mtype"]
input_schema.etype = response.parameters["etype"]
result = await self.call_tool(input_schema, name, tool, tool_call, tool_metadata)
logging.error(f"Result: {result}")
return result
else:
logging.error(f"Usage of tool {name} was NOT approved.")
return {
"role": "tool",
"tool_call_id": tool_call.id,
"tool_name": name,
"content": f"Error: usage of tool '{name}' was denied.",
}, None
elif name == "get-traces-tool":
parameters = {
"brain region id": input_schema.brain_region_id,
"etype id": input_schema.etype_id
}
phantasm = Phantasm(
host="localhost",
port=2505,
)
response = phantasm.get_approval(
name="Get Traces",
parameters=parameters
)
logging.error(f"Response: {response}")
if response.approved:
logging.error(f"Usage of tool {name} was approved. New parameters: {response.parameters}")
input_schema.brain_region_id = response.parameters["brain region id"]
input_schema.etype_id = response.parameters["etype id"]
return await self.call_tool(input_schema, name, tool, tool_call, tool_metadata)
else:
logging.error(f"Usage of tool {name} was NOT approved.")
return {
"role": "tool",
"tool_call_id": tool_call.id,
"tool_name": name,
"content": str(err),
}
return response, None
"content": f"Error: usage of tool '{name}' was denied.",
}, None
else:
return await self.call_tool(input_schema, name, tool, tool_call, tool_metadata)

result: Result = self.handle_function_result(raw_result)
async def call_tool(self, input_schema, name, tool, tool_call, tool_metadata):
tool_instance = tool(input_schema=input_schema, metadata=tool_metadata)
# pass context_variables to agent functions
try:
raw_result = await tool_instance.arun()
except Exception as err:
response = {
"role": "tool",
"tool_call_id": tool_call.id,
"tool_name": name,
"content": result.value,
"content": str(err),
}
if result.agent:
agent = result.agent
else:
agent = None
return response, agent
return response, None
result: Result = self.handle_function_result(raw_result)
response = {
"role": "tool",
"tool_call_id": tool_call.id,
"tool_name": name,
"content": result.value,
}
if result.agent:
agent = result.agent
else:
logging.error(f"Usage of tool {name} was NOT approved.")
agent = None
return response, agent

async def arun(
self,
Expand Down
4 changes: 4 additions & 0 deletions swarm_copy/tools/base_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,7 @@ def pydantic_to_openai_schema(cls) -> ChatCompletionToolParam:
@abstractmethod
async def arun(self) -> Any:
"""Run the tool."""


class HILToolOutput(BaseModel):
status: str = "approved"
8 changes: 4 additions & 4 deletions swarm_copy/tools/resolve_entities_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
ETYPE_IDS,
BaseMetadata,
BaseTool,
EtypesLiteral,
EtypesLiteral, HILToolOutput,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -41,21 +41,21 @@ class ResolveBRInput(BaseModel):
)


class BRResolveOutput(BaseModel):
class BRResolveOutput(HILToolOutput):
"""Output schema for the Brain region resolver."""

brain_region_name: str
brain_region_id: str


class MTypeResolveOutput(BaseModel):
class MTypeResolveOutput(HILToolOutput):
"""Output schema for the Mtype resolver."""

mtype_name: str
mtype_id: str


class EtypeResolveOutput(BaseModel):
class EtypeResolveOutput(HILToolOutput):
"""Output schema for the Mtype resolver."""

etype_name: str
Expand Down
4 changes: 2 additions & 2 deletions swarm_copy/tools/traces_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from pydantic import BaseModel, Field

from swarm_copy.tools.base_tool import BaseMetadata, BaseTool
from swarm_copy.tools.base_tool import BaseMetadata, BaseTool, HILToolOutput
from swarm_copy.utils import get_descendants_id

logger = logging.getLogger(__name__)
Expand All @@ -25,7 +25,7 @@ class GetTracesInput(BaseModel):
)


class TracesOutput(BaseModel):
class TracesOutput(HILToolOutput):
"""Output schema for the traces."""

trace_id: str
Expand Down

0 comments on commit 9810596

Please sign in to comment.