SQUASHED & CHERRY PICKED from feature/anthropic-native-bash-tool
jjallaire authored and Eric Patey committed Jan 8, 2025
1 parent 6a90664 commit 45451b6
Showing 7 changed files with 154 additions and 19 deletions.
37 changes: 37 additions & 0 deletions examples/hello_computer.py
@@ -0,0 +1,37 @@
from inspect_ai import Task, task
from inspect_ai.dataset import Sample
from inspect_ai.solver import generate, use_tools
from inspect_ai.tool import tool


@tool
def computer():
    async def execute(
        action: str,
        text: str | None = None,
        coordinate: list[int] | None = None,
    ) -> str:
        """Take an action using a computer.

        Args:
            action: Action to take.
            text: Text related to the action.
            coordinate: Coordinate related to the action.

        Returns:
            The action that was taken.
        """
        return action

    return execute


@task
def hello_computer():
    return Task(
        dataset=[Sample(input="Call the computer tool with the action 'screenshot'")],
        solver=[
            use_tools([computer()]),
            generate(),
        ],
    )
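
To try the example end to end, something like the following should work from the Python API (a minimal sketch: it assumes eval() accepts GenerateConfigArgs keywords such as internal_tools, and the model name is illustrative):

from inspect_ai import eval

from hello_computer import hello_computer  # assumes the example module is importable

# with internal_tools left at its default (True), the 'computer' tool above is
# mapped to Anthropic's native computer use tool; pass internal_tools=False to
# send it as an ordinary JSON-schema tool instead
eval(
    hello_computer(),
    model="anthropic/claude-3-5-sonnet-20241022",  # illustrative model name
    internal_tools=False,
)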
13 changes: 13 additions & 0 deletions src/inspect_ai/_cli/eval.py
@@ -364,6 +364,14 @@ def eval_options(func: Callable[..., Any]) -> Callable[..., click.Context]:
help="Whether to enable parallel function calling during tool use (defaults to True) OpenAI and Groq only.",
envvar="INSPECT_EVAL_PARALLEL_TOOL_CALLS",
)
@click.option(
"--internal-tools/--no-internal-tools",
type=bool,
is_flag=True,
default=True,
help="Whether to automatically map tools to model internal implementations (e.g. 'computer' for anthropic).",
envvar="INSPECT_EVAL_INTERNAL_TOOLS",
)
@click.option(
"--max-tool-output",
type=int,
@@ -438,6 +446,7 @@ def eval_command(
    logprobs: bool | None,
    top_logprobs: int | None,
    parallel_tool_calls: bool | None,
    internal_tools: bool | None,
    max_tool_output: int | None,
    cache_prompt: str | None,
    reasoning_effort: str | None,
@@ -597,6 +606,7 @@ def eval_set_command(
    logprobs: bool | None,
    top_logprobs: int | None,
    parallel_tool_calls: bool | None,
    internal_tools: bool | None,
    max_tool_output: int | None,
    cache_prompt: str | None,
    reasoning_effort: str | None,
@@ -835,6 +845,9 @@ def config_from_locals(locals: dict[str, Any]) -> GenerateConfigArgs:
if key == "parallel_tool_calls":
if value is not False:
value = None
if key == "internal_tools":
if value is not False:
value = None
config[key] = value # type: ignore
return config
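
The 'value is not False' normalization above exists because a click flag with default=True cannot distinguish "left unset" from "explicitly enabled", so presumably only an explicit --no-internal-tools should be recorded (anything else is treated as unset and does not override other configuration sources). A minimal sketch of that rule in isolation (hypothetical helper, not the actual CLI wiring):

def normalize_internal_tools(value: bool | None) -> bool | None:
    # default/True collapses to None ("unset"); explicit False is preserved
    return False if value is False else None

assert normalize_internal_tools(True) is None
assert normalize_internal_tools(None) is None
assert normalize_internal_tools(False) is False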

15 changes: 15 additions & 0 deletions src/inspect_ai/_view/www/log-schema.json
@@ -1065,6 +1065,7 @@
"logprobs": null,
"top_logprobs": null,
"parallel_tool_calls": null,
"internal_tools": null,
"max_tool_output": null,
"cache_prompt": null,
"reasoning_effort": null
@@ -2118,6 +2119,18 @@
"default": null,
"title": "Parallel Tool Calls"
},
"internal_tools": {
"anyOf": [
{
"type": "boolean"
},
{
"type": "null"
}
],
"default": null,
"title": "Internal Tools"
},
"max_tool_output": {
"anyOf": [
{
@@ -2186,6 +2199,7 @@
"logprobs",
"top_logprobs",
"parallel_tool_calls",
"internal_tools",
"max_tool_output",
"cache_prompt",
"reasoning_effort"
@@ -4123,6 +4137,7 @@
"best_of": null,
"cache_prompt": null,
"frequency_penalty": null,
"internal_tools": null,
"logit_bias": null,
"logprobs": null,
"max_connections": null,
3 changes: 3 additions & 0 deletions src/inspect_ai/_view/www/src/types/log.d.ts
@@ -77,6 +77,7 @@ export type NumChoices = number | null;
export type Logprobs = boolean | null;
export type TopLogprobs = number | null;
export type ParallelToolCalls = boolean | null;
export type InternalTools = boolean | null;
export type MaxToolOutput = number | null;
export type CachePrompt = "auto" | boolean | null;
export type ReasoningEffort = ("low" | "medium" | "high") | null;
@@ -531,6 +532,7 @@ export interface GenerateConfig {
logprobs: Logprobs;
top_logprobs: TopLogprobs;
parallel_tool_calls: ParallelToolCalls;
internal_tools: InternalTools;
max_tool_output: MaxToolOutput;
cache_prompt: CachePrompt;
reasoning_effort: ReasoningEffort;
@@ -873,6 +875,7 @@ export interface GenerateConfig1 {
logprobs: Logprobs;
top_logprobs: TopLogprobs;
parallel_tool_calls: ParallelToolCalls;
internal_tools: InternalTools;
max_tool_output: MaxToolOutput;
cache_prompt: CachePrompt;
reasoning_effort: ReasoningEffort;
6 changes: 6 additions & 0 deletions src/inspect_ai/model/_generate_config.py
@@ -66,6 +66,9 @@ class GenerateConfigArgs(TypedDict, total=False):
    parallel_tool_calls: bool | None
    """Whether to enable parallel function calling during tool use (defaults to True). OpenAI and Groq only."""

    internal_tools: bool | None
    """Whether to automatically map tools to model internal implementations (e.g. 'computer' for anthropic)."""

    max_tool_output: int | None
    """Maximum tool output (in bytes). Defaults to 16 * 1024."""

@@ -136,6 +139,9 @@ class GenerateConfig(BaseModel):
    parallel_tool_calls: bool | None = Field(default=None)
    """Whether to enable parallel function calling during tool use (defaults to True). OpenAI and Groq only."""

    internal_tools: bool | None = Field(default=None)
    """Whether to automatically map tools to model internal implementations (e.g. 'computer' for anthropic)."""

    max_tool_output: int | None = Field(default=None)
    """Maximum tool output (in bytes). Defaults to 16 * 1024."""

96 changes: 77 additions & 19 deletions src/inspect_ai/model/_providers/anthropic.py
@@ -2,7 +2,7 @@
import os
from copy import copy
from logging import getLogger
-from typing import Any, Literal, Tuple, cast
+from typing import Any, Literal, NotRequired, Tuple, TypedDict, cast

from anthropic import (
    APIConnectionError,
@@ -142,7 +142,7 @@ def model_call() -> ModelCall:
            system_param,
            tools_param,
            messages,
-            cache_prompt,
+            computer_use,
        ) = await resolve_chat_input(self.model_name, input, tools, config)

        # prepare request params (assembled this way so we can log the raw model call)
@@ -158,13 +158,11 @@ def model_call() -> ModelCall:
        # additional options
        request = request | self.completion_params(config)

-        # caching header
-        if cache_prompt:
-            request["extra_headers"] = {
-                "anthropic-beta": "prompt-caching-2024-07-31"
-            }
+        # computer use beta
+        if computer_use:
+            request["extra_headers"] = {"anthropic-beta": "computer-use-2024-10-22"}

-        # call model
+        # make request
        message = await self.client.messages.create(**request, stream=False)

        # set response for ModelCall
@@ -256,6 +254,9 @@ def handle_bad_request(self, ex: BadRequestError) -> ModelOutput | None:
elif "content filtering" in error:
content = "Sorry, but I am unable to help with that request."
stop_reason = "content_filter"
else:
content = error
stop_reason = "unknown"

if content and stop_reason:
return ModelOutput.from_content(
@@ -268,12 +269,26 @@ def handle_bad_request(self, ex: BadRequestError) -> ModelOutput | None:
        return None


# native anthropic tool definitions for computer use beta
# https://docs.anthropic.com/en/docs/build-with-claude/computer-use
class ComputerUseToolParam(TypedDict):
    type: str
    name: str
    display_width_px: NotRequired[int]
    display_height_px: NotRequired[int]
    display_number: NotRequired[int]


# tools can be either a stock tool param or a special computer use tool param
ToolParamDef = ToolParam | ComputerUseToolParam


async def resolve_chat_input(
    model: str,
    input: list[ChatMessage],
    tools: list[ToolInfo],
    config: GenerateConfig,
-) -> Tuple[list[TextBlockParam] | None, list[ToolParam], list[MessageParam], bool]:
+) -> Tuple[list[TextBlockParam] | None, list[ToolParamDef], list[MessageParam], bool]:
    # extract system message
    system_messages, messages = split_system_messages(input, config)

@@ -286,14 +301,7 @@ async def resolve_chat_input(
    )

    # tools
-    tools_params = [
-        ToolParam(
-            name=tool.name,
-            description=tool.description,
-            input_schema=tool.parameters.model_dump(exclude_none=True),
-        )
-        for tool in tools
-    ]
+    tools_params, computer_use = tool_params_for_tools(tools, config)

    # system messages
    if len(system_messages) > 0:
@@ -343,10 +351,60 @@ async def resolve_chat_input(
        add_cache_control(cast(dict[str, Any], content[-1]))

    # return chat input
-    return system_param, tools_params, message_params, cache_prompt
+    return system_param, tools_params, message_params, computer_use


def tool_params_for_tools(
    tools: list[ToolInfo], config: GenerateConfig
) -> tuple[list[ToolParamDef], bool]:
    # tool params and computer_use bit to return
    tool_params: list[ToolParamDef] = []
    computer_use = False

    # for each tool, check if it has a native computer use implementation and use that
    # when available (noting that we need to set the computer use request header)
    for tool in tools:
        computer_use_tool = (
            computer_use_tool_param(tool)
            if config.internal_tools is not False
            else None
        )
        if computer_use_tool:
            tool_params.append(computer_use_tool)
            computer_use = True
        else:
            tool_params.append(
                ToolParam(
                    name=tool.name,
                    description=tool.description,
                    input_schema=tool.parameters.model_dump(exclude_none=True),
                )
            )

    return tool_params, computer_use


def computer_use_tool_param(tool: ToolInfo) -> ComputerUseToolParam | None:
    # check for compatible 'computer' tool
    if tool.name == "computer" and (
        sorted(tool.parameters.properties.keys())
        == sorted(["action", "coordinate", "text"])
    ):
        return ComputerUseToolParam(
            type="computer_20241022",
            name="computer",
            display_width_px=1024,
            display_height_px=768,
            display_number=1,
        )
    # not a computer_use tool
    else:
        return None


-def add_cache_control(param: TextBlockParam | ToolParam | dict[str, Any]) -> None:
+def add_cache_control(
+    param: TextBlockParam | ToolParam | ComputerUseToolParam | dict[str, Any],
+) -> None:
    cast(dict[str, Any], param)["cache_control"] = {"type": "ephemeral"}
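
Taken together: when internal_tools is not explicitly disabled and a tool named 'computer' exposes exactly the parameters action, coordinate, and text, the request carries Anthropic's native tool definition plus the beta header instead of a JSON-schema tool. Roughly, the relevant slice of the request then looks like this (a sketch assembled from the code above, not captured output):

request = {
    # beta header enabling computer use (set whenever a tool was mapped)
    "extra_headers": {"anthropic-beta": "computer-use-2024-10-22"},
    "tools": [
        {
            "type": "computer_20241022",
            "name": "computer",
            "display_width_px": 1024,
            "display_height_px": 768,
            "display_number": 1,
        }
    ],
    # ...model, messages, max_tokens, etc.
}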


3 changes: 3 additions & 0 deletions tools/vscode/src/@types/log.d.ts
@@ -77,6 +77,7 @@ export type NumChoices = number | null;
export type Logprobs = boolean | null;
export type TopLogprobs = number | null;
export type ParallelToolCalls = boolean | null;
export type InternalTools = boolean | null;
export type MaxToolOutput = number | null;
export type CachePrompt = "auto" | boolean | null;
export type ReasoningEffort = ("low" | "medium" | "high") | null;
@@ -531,6 +532,7 @@ export interface GenerateConfig {
logprobs: Logprobs;
top_logprobs: TopLogprobs;
parallel_tool_calls: ParallelToolCalls;
internal_tools: InternalTools;
max_tool_output: MaxToolOutput;
cache_prompt: CachePrompt;
reasoning_effort: ReasoningEffort;
@@ -873,6 +875,7 @@ export interface GenerateConfig1 {
logprobs: Logprobs;
top_logprobs: TopLogprobs;
parallel_tool_calls: ParallelToolCalls;
internal_tools: InternalTools;
max_tool_output: MaxToolOutput;
cache_prompt: CachePrompt;
reasoning_effort: ReasoningEffort;
