Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Add anthropic tool use #39

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 43 additions & 8 deletions aide/backend/backend_anthropic.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""Backend for Anthropic API."""

import logging
import time

from .utils import FunctionSpec, OutputType, opt_messages_to_list, backoff_create
from funcy import notnone, once, select_values
import anthropic

# Logger registered under the fixed name "aide".
logger = logging.getLogger("aide")

# Module-level client handle; it starts out as None, hence the type-ignore.
# NOTE(review): presumably assigned by _setup_anthropic_client() — confirm.
_client: anthropic.Anthropic = None # type: ignore

ANTHROPIC_TIMEOUT_EXCEPTIONS = (
Expand All @@ -15,6 +18,10 @@
anthropic.InternalServerError,
)

# Short model aliases mapped to the full, dated Anthropic model identifiers.
ANTHROPIC_MODEL_ALIASES = {
    # Claude 3.5 Sonnet identifiers use the "claude-3-5-sonnet" prefix;
    # "claude-3-sonnet-20241022" is not a published model ID.
    "claude-3.5-sonnet": "claude-3-5-sonnet-20241022",
}


@once
def _setup_anthropic_client():
Expand All @@ -28,23 +35,32 @@ def query(
func_spec: FunctionSpec | None = None,
**model_kwargs,
) -> tuple[OutputType, float, int, int, dict]:
"""
Query Anthropic's API, optionally with tool use (Anthropic's equivalent to function calling).
"""
_setup_anthropic_client()

filtered_kwargs: dict = select_values(notnone, model_kwargs) # type: ignore
if "max_tokens" not in filtered_kwargs:
filtered_kwargs["max_tokens"] = 4096 # default for Claude models

if func_spec is not None:
raise NotImplementedError(
"Anthropic does not support function calling for now."
)
model_name = filtered_kwargs.get("model", "")
logger.debug(f"Anthropic query called with model='{model_name}'")

if model_name in ANTHROPIC_MODEL_ALIASES:
model_name = ANTHROPIC_MODEL_ALIASES[model_name]

if func_spec is not None and func_spec.name == "submit_review":
filtered_kwargs["tools"] = [func_spec.as_anthropic_tool_dict]
# Force tool use
filtered_kwargs["tool_choice"] = func_spec.anthropic_tool_choice_dict

# Anthropic doesn't allow not having a user messages
# Anthropic doesn't allow not having user messages
# if we only have system msg -> use it as user msg
if system_message is not None and user_message is None:
system_message, user_message = user_message, system_message

# Anthropic passes the system messages as a separate argument
# Anthropic passes system messages as a separate argument
if system_message is not None:
filtered_kwargs["system"] = system_message

Expand All @@ -59,14 +75,33 @@ def query(
)
req_time = time.time() - t0

assert len(message.content) == 1 and message.content[0].type == "text"
# Handle tool calls if present
if (
func_spec is not None
and "tools" in filtered_kwargs
and len(message.content) > 0
and message.content[0].type == "tool_use"
):
block = message.content[0] # This is a "ToolUseBlock"
# block has attributes: type, id, name, input
assert (
block.name == func_spec.name
), f"Function name mismatch: expected {func_spec.name}, got {block.name}"
output = block.input # Anthropic calls the parameters "input"
else:
# For non-tool responses, ensure we have text content
assert len(message.content) == 1, "Expected single content item"
assert (
message.content[0].type == "text"
), f"Expected text response, got {message.content[0].type}"
output = message.content[0].text

output: str = message.content[0].text
in_tokens = message.usage.input_tokens
out_tokens = message.usage.output_tokens

info = {
"stop_reason": message.stop_reason,
"model": message.model,
}

return output, req_time, in_tokens, out_tokens, info
18 changes: 18 additions & 0 deletions aide/backend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __post_init__(self):

@property
def as_openai_tool_dict(self):
"""Convert to OpenAI's function format."""
return {
"type": "function",
"function": {
Expand All @@ -81,3 +82,20 @@ def openai_tool_choice_dict(self):
"type": "function",
"function": {"name": self.name},
}

@property
def as_anthropic_tool_dict(self):
    """Describe this spec as an Anthropic tool definition.

    Returns a dict with the keys ``name``, ``description`` and
    ``input_schema``, filled from this object's ``name``, ``description``
    and ``json_schema`` attributes respectively.
    """
    tool = {"name": self.name, "description": self.description}
    # Anthropic names the schema field "input_schema" rather than "parameters".
    tool["input_schema"] = self.json_schema
    return tool

@property
def anthropic_tool_choice_dict(self):
    """Build the Anthropic ``tool_choice`` value that selects this spec by name.

    Returns a dict with ``type`` set to ``"tool"`` and ``name`` set to this
    object's ``name`` attribute.
    """
    # Anthropic uses the type "tool" where OpenAI uses "function".
    return dict(type="tool", name=self.name)
Loading