diff --git a/CHANGELOG.md b/CHANGELOG.md index 565e35fc8..e75efa856 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Changelog +## v0.3.11 (30 May 2024) + +- Update to non-beta version of Anthropic tool use (remove legacy xml tools implementation). + ## v0.3.10 (29 May 2024) - **BREAKING:** The `pattern` scorer has been modified to match against any (or all) regex match groups. This replaces the previous behaviour when there was more than one group, which would only match the second group. diff --git a/src/inspect_ai/model/_providers/anthropic.py b/src/inspect_ai/model/_providers/anthropic.py index 774ec4a01..adb6937a3 100644 --- a/src/inspect_ai/model/_providers/anthropic.py +++ b/src/inspect_ai/model/_providers/anthropic.py @@ -1,10 +1,5 @@ -import ast -import builtins import os -import re -from copy import deepcopy from typing import Any, Tuple, cast -from xml.sax.saxutils import escape from anthropic import ( APIConnectionError, @@ -21,17 +16,13 @@ MessageParam, TextBlock, TextBlockParam, -) -from anthropic.types.beta.tools import ToolParam as BetaToolParam -from anthropic.types.beta.tools import ( + ToolParam, ToolResultBlockParam, - ToolsBetaMessage, - ToolsBetaMessageParam, ToolUseBlock, ToolUseBlockParam, message_create_params, ) -from anthropic.types.beta.tools.tool_param import ( +from anthropic.types.tool_param import ( InputSchema, ) from typing_extensions import override @@ -39,7 +30,6 @@ from inspect_ai._util.constants import DEFAULT_MAX_RETRIES, DEFAULT_MAX_TOKENS from inspect_ai._util.error import exception_message from inspect_ai._util.images import image_as_data_uri -from inspect_ai._util.json import json_type_to_python_type from inspect_ai._util.url import data_uri_mime_type, data_uri_to_base64, is_data_uri from inspect_ai.model._providers.util import model_base_url @@ -48,7 +38,6 @@ ChatMessage, ChatMessageAssistant, ChatMessageSystem, - ChatMessageTool, Content, ContentText, GenerateConfig, @@ -57,7 +46,7 @@ ModelUsage, StopReason, ) -from .._tool import ToolCall, ToolChoice, ToolFunction, ToolInfo, ToolParam +from .._tool import ToolCall, ToolChoice, ToolFunction, ToolInfo from .._util import chat_api_tool ANTHROPIC_API_KEY = "ANTHROPIC_API_KEY" @@ -70,13 +59,10 @@ def __init__( base_url: str | None, config: GenerateConfig = GenerateConfig(), bedrock: bool = False, - tools_beta: bool = True, **model_args: Any, ): super().__init__(model_name=model_name, base_url=base_url, config=config) - self.tools_beta = tools_beta and not bedrock - # create client if bedrock: base_url = model_base_url( @@ -115,49 +101,25 @@ async def generate( ) -> ModelOutput: # generate try: - # use tools beta endpoint if we have tools and haven't opted out (note that - # bedrock is an implicit opt-out as it doesn't yet support the tools api) - if ( - len(tools) > 0 - and self.tools_beta - and not isinstance(self.client, AsyncAnthropicBedrock) - ): - ( - system_message, - beta_tools, - beta_messages, - ) = await resolve_tools_beta_chat_input(input, tools, config) - - message = await self.client.beta.tools.messages.create( - stream=False, - messages=beta_messages, - system=system_message if system_message is not None else NOT_GIVEN, - stop_sequences=( - config.stop_seqs if config.stop_seqs is not None else NOT_GIVEN - ), - tools=beta_tools, - tool_choice=tools_beta_tool_choice(tool_choice), - **self.completion_params(config), - ) - - return tools_beta_model_output_from_message(message, tools) - - # otherwise use standard chat endpoint - else: - system_message, stop_seq, messages = await resolve_chat_input( - input, tools, config - ) + (system_message, tools_param, messages) = await resolve_chat_input( + input, tools, config + ) - message = await self.client.messages.create( - stream=False, - messages=messages, - system=system_message if system_message is not None else NOT_GIVEN, - stop_sequences=stop_seq if stop_seq is not None else NOT_GIVEN, - **self.completion_params(config), - ) + message = await self.client.messages.create( + stream=False, + messages=messages, + system=system_message if system_message is not None else NOT_GIVEN, + stop_sequences=( + config.stop_seqs if config.stop_seqs is not None else NOT_GIVEN + ), + tools=tools_param, + tool_choice=( + message_tool_choice(tool_choice) if len(tools) > 0 else NOT_GIVEN + ), + **self.completion_params(config), + ) - # extract model output from text response (may have tool calls) - return model_output_from_message(message, tools) + return model_output_from_message(message, tools) except BadRequestError as ex: return ModelOutput.from_content( @@ -196,9 +158,7 @@ def is_rate_limit(self, ex: BaseException) -> bool: # always be transient). Equating this to rate limit errors may occasionally # result in retrying too many times, but much more often will avert a failed # eval that just needed to survive a transient error - return ( - isinstance(ex, RateLimitError | InternalServerError | APIConnectionError) - ) + return isinstance(ex, RateLimitError | InternalServerError | APIConnectionError) @override def collapse_user_messages(self) -> bool: @@ -209,18 +169,11 @@ def collapse_assistant_messages(self) -> bool: return True -####################################################################################### -# Resolve input, tools, and config into the right shape of input for the Anthropic -# tool use beta. we also keep the legacy tools implementation around for now (see below) -# for users on Bedrock of who want to opt out for tools beta for any reason -####################################################################################### - - -async def resolve_tools_beta_chat_input( +async def resolve_chat_input( input: list[ChatMessage], tools: list[ToolInfo], config: GenerateConfig, -) -> Tuple[str | None, list[BetaToolParam], list[ToolsBetaMessageParam]]: +) -> Tuple[str | None, list[ToolParam], list[MessageParam]]: # extract system message system_message, messages = split_system_message(input, config) @@ -228,15 +181,15 @@ async def resolve_tools_beta_chat_input( if len(tools) > 0: # encourage claude to show its thinking, see # https://docs.anthropic.com/claude/docs/tool-use#chain-of-thought-tool-use - system_message = f"{system_message}\n\nBefore answering, explain your reasoning step-by-step." + system_message = f"{system_message}\n\nBefore answering, explain your reasoning step-by-step in tags." # messages - beta_messages = [(await tools_beta_message_param(message)) for message in messages] + beta_messages = [(await message_param(message)) for message in messages] # tools chat_functions = [chat_api_tool(tool)["function"] for tool in tools] beta_tools = [ - BetaToolParam( + ToolParam( name=function["name"], description=function["description"], input_schema=cast(InputSchema, function["parameters"]), @@ -247,7 +200,7 @@ async def resolve_tools_beta_chat_input( return system_message, beta_tools, beta_messages -def tools_beta_tool_choice(tool_choice: ToolChoice) -> message_create_params.ToolChoice: +def message_tool_choice(tool_choice: ToolChoice) -> message_create_params.ToolChoice: if isinstance(tool_choice, ToolFunction): return {"type": "tool", "name": tool_choice.name} elif tool_choice == "any": @@ -256,7 +209,7 @@ def tools_beta_tool_choice(tool_choice: ToolChoice) -> message_create_params.Too return {"type": "auto"} -async def tools_beta_message_param(message: ChatMessage) -> ToolsBetaMessageParam: +async def message_param(message: ChatMessage) -> MessageParam: # no system role for anthropic (this is more like an assertion, # as these should have already been filtered out) if message.role == "system": @@ -273,7 +226,7 @@ async def tools_beta_message_param(message: ChatMessage) -> ToolsBetaMessagePara await message_param_content(content) for content in message.content ] - return ToolsBetaMessageParam( + return MessageParam( role="user", content=[ ToolResultBlockParam( @@ -307,18 +260,18 @@ async def tools_beta_message_param(message: ChatMessage) -> ToolsBetaMessagePara ) ) - return ToolsBetaMessageParam( + return MessageParam( role=message.role, content=tools_content, ) # normal text content elif isinstance(message.content, str): - return ToolsBetaMessageParam(role=message.role, content=message.content) + return MessageParam(role=message.role, content=message.content) # mixed text/images else: - return ToolsBetaMessageParam( + return MessageParam( role=message.role, content=[ await message_param_content(content) for content in message.content @@ -326,9 +279,7 @@ async def tools_beta_message_param(message: ChatMessage) -> ToolsBetaMessagePara ) -def tools_beta_model_output_from_message( - message: ToolsBetaMessage, tools: list[ToolInfo] -) -> ModelOutput: +def model_output_from_message(message: Message, tools: list[ToolInfo]) -> ModelOutput: # extract content and tool calls content: list[Content] = [] tool_calls: list[ToolCall] | None = None @@ -359,7 +310,7 @@ def tools_beta_model_output_from_message( message=ChatMessageAssistant( content=content, tool_calls=tool_calls, source="generate" ), - stop_reason=tools_beta_message_stop_reason(message), + stop_reason=message_stop_reason(message), ) # return ModelOutput @@ -374,7 +325,7 @@ def tools_beta_model_output_from_message( ) -def tools_beta_message_stop_reason(message: ToolsBetaMessage) -> StopReason: +def message_stop_reason(message: Message) -> StopReason: match message.stop_reason: case "end_turn" | "stop_sequence": return "stop" @@ -408,119 +359,6 @@ def split_system_message( return system_message, cast(list[ChatMessage], messages) -####################################################################################### -# Resolve input, tools, and config into the right shape of input for Anthropic models. -# -# Anthropic tools are defined not using a tools component of their API, but rather by -# defining all available tools in the system message. If there are tools then there -# is also a requirement to define a custom stop sequence. This function sorts all of -# that out and returns a system message, a stop sequence (if necessary) and the list -# of anthropic-native MessageParam objects (including converting role="tool" messages -# into XML encoded role="user" messages for Claude -####################################################################################### - -FUNCTIONS_STOP_SEQ = "" - - -async def resolve_chat_input( - input: list[ChatMessage], tools: list[ToolInfo], config: GenerateConfig -) -> Tuple[str | None, list[str] | None, list[MessageParam]]: - # extract system message - system_message, messages = split_system_message(input, config) - - # resolve tool use (system message and stop sequences) - stop_seqs = deepcopy(config.stop_seqs) - if len(tools) > 0: - system_message = f"{system_message}\n\n{tools_system_message(tools)}" - stop_seqs = ( - config.stop_seqs if config.stop_seqs else ["\n\nHuman:", "\n\nAssistant"] - ) - stop_seqs.append(FUNCTIONS_STOP_SEQ) - - # create anthropic message params - message_params = [await message_param(m) for m in messages] - - # done! - return system_message, stop_seqs, message_params - - -def tools_system_message(tools: list[ToolInfo]) -> str: - tool_sep = "\n\n" - return f""" -In this environment you have access to a set of tools you can use to answer the user's question. - -You may call them like this: - - -$TOOL_NAME - -<$PARAMETER_NAME>$PARAMETER_VALUE -... - - - - -Here are the tools available: - -{tool_sep.join([tool_description(tool) for tool in tools])} - -""" - - -def tool_description(tool: ToolInfo) -> str: - newline = "\n" - return f""" - -{escape(tool.name)} -{escape(tool.description)} - -{newline.join(tool_param(param) for param in tool.params)} - - -""" - - -def tool_param(param: ToolParam) -> str: - return f""" - -{escape(param.name)} -{escape(param.type)} -{escape(param.description)} - -""" - - -async def message_param(message: ChatMessage) -> MessageParam: - # no system role for anthropic (this is more like an assertion, - # as these should have already been filtered out) - if message.role == "system": - raise ValueError("Anthropic models do not support the system role") - - # "tool" means serving a tool call result back to claude - elif message.role == "tool": - return tool_message_param(message) - - # tool_calls means claude is attempting to call our tools - elif message.role == "assistant" and message.tool_calls: - return MessageParam( - role=message.role, - content=f"{message.content}\n{function_calls(message.tool_calls)}", - ) - - # normal text content - elif isinstance(message.content, str): - return MessageParam(role=message.role, content=message.content) - - # mixed text/images - else: - return MessageParam( - role=message.role, - content=[ - await message_param_content(content) for content in message.content - ], - ) - - async def message_param_content( content: Content, ) -> TextBlockParam | ImageBlockParam: @@ -543,307 +381,3 @@ async def message_param_content( type="image", source=dict(type="base64", media_type=cast(Any, media_type), data=image), ) - - -def tool_message_param(message: ChatMessageTool) -> MessageParam: - results = f""" - -{function_result(message)} - -""" - return MessageParam(role="user", content=results) - - -def function_calls(tool_calls: list[ToolCall]) -> str: - nl = "\n" - return f""" - -{nl.join([function_call(tool_call) for tool_call in tool_calls])} - -""" - - -def function_call(tool_call: ToolCall) -> str: - nl = "\n" - return f""" - -{escape(tool_call.function)} - -{nl.join([function_parameter(name,value) for name, value in tool_call.arguments.items()])} - - -""" - - -def function_parameter(name: str, value: Any) -> str: - return f"<{name}>{value}" - - -def function_result(message: ChatMessageTool) -> str: - if message.tool_error: - return f""" - -{escape(message.tool_error)} - -""" - else: - return f""" - -{escape(str(message.tool_call_id))} - -{escape(message.text)} - - -""" - - -####################################################################################### -# Extract model output (including tool calls) from an Anthropic message -# -# Anthropic encodes tool calls (in XML) directly in role="assistant" messages. The -# code below deals with this by parsing out the tool calls and separating them into -# the Inspect native ToolCall objects. -####################################################################################### - - -def model_output_from_message(message: Message, tools: list[ToolInfo]) -> ModelOutput: - # extract function calls (if any); throws ValueError if xml is invalid - try: - content_with_functions = extract_function_calls(message) - if content_with_functions: - content = content_with_functions.content - tool_calls = [ - tool_call(function_call, tools) - for function_call in content_with_functions.function_calls - ] - else: - content = message_content(message) - tool_calls = None - except ValueError as ex: - return ModelOutput.from_content( - message.model, - f"{message_content(message)}\n\nError: {exception_message(ex)}", - ) - - # resolve choice - choice = ChatCompletionChoice( - message=ChatMessageAssistant( - content=content, tool_calls=tool_calls, source="generate" - ), - stop_reason=message_stop_reason(message), - ) - - # return ModelOutput - return ModelOutput( - model=message.model, - choices=[choice], - usage=ModelUsage( - input_tokens=message.usage.input_tokens, - output_tokens=message.usage.output_tokens, - total_tokens=message.usage.input_tokens + message.usage.output_tokens, - ), - ) - - -def message_stop_reason(message: Message) -> StopReason: - match message.stop_reason: - case "end_turn": - return "stop" - case "max_tokens": - return "length" - case "stop_sequence": - if message.stop_sequence == FUNCTIONS_STOP_SEQ: - return "tool_calls" - else: - return "stop" - case _: - return "unknown" - - -# This function call parsing code is adapted from the anthropic-tools package (which is in "alpha" -# and not on PyPI, This will likely end up in the main anthropic package -- when that happens we'll -# switch to using that. Here is the commit we forked: -# https://github.com/anthropics/anthropic-tools/blob/a7822678db8a0867b1d05da9c836c456d263e3d9/tool_use_package/tool_user.py#L243 - - -class FunctionCall: - def __init__(self, function: str, parameters: list[tuple[str, str]]) -> None: - self.function = function - self.parameters = parameters - - -def message_content(message: Message) -> str: - return "\n".join([content.text for content in message.content]) - - -class ContentWithFunctionCalls: - def __init__( - self, - content: str, - function_calls: list[FunctionCall], - ) -> None: - self.content = content - self.function_calls = function_calls - - -def extract_function_calls(message: Message) -> ContentWithFunctionCalls | None: - content = message_content(message) - - # see if we need to append the stop token - if ( - message.stop_reason == "stop_sequence" - and message.stop_sequence == "" - ): - content = f"{content}" - - """Check if the function call follows a valid format and extract the attempted function calls if so. - Does not check if the tools actually exist or if they are called with the requisite params.""" - # Check if there are any of the relevant XML tags present that would indicate an attempted function call. - function_call_tags = re.findall( - r"|||||||", - content, - re.DOTALL, - ) - if not function_call_tags: - return None - - # Extract content between tags. If there are multiple we will only parse the first and ignore the rest, regardless of their correctness. - match = re.search(r"(.*)", content, re.DOTALL) - if not match: - return None - func_calls = match.group(1) - - # get content appearing before the function calls - prefix_match = re.search(r"^(.*?)", content, re.DOTALL) - if prefix_match: - func_call_prefix_content = prefix_match.group(1) - - # Check for invoke tags - invoke_regex = r".*?" - if not re.search(invoke_regex, func_calls, re.DOTALL): - raise ValueError( - "Missing tags inside of tags." - ) - - # Check each invoke contains tool name and parameters - invoke_strings = re.findall(invoke_regex, func_calls, re.DOTALL) - invokes: list[FunctionCall] = [] - for invoke_string in invoke_strings: - tool_name = re.findall(r".*?", invoke_string, re.DOTALL) - if not tool_name: - raise ValueError( - "Missing tags inside of tags." - ) - - if len(tool_name) > 1: - raise ValueError( - "More than one tool_name specified inside single set of tags." - ) - - parameters = re.findall( - r".*?", invoke_string, re.DOTALL - ) - if not parameters: - raise ValueError( - "Missing tags inside of tags." - ) - - if len(parameters) > 1: - raise ValueError( - "More than one set of tags specified inside single set of tags." - ) - - # Check for balanced tags inside parameters - # TODO: This will fail if the parameter value contains <> pattern or if there is a parameter called parameters. Fix that issue. - tags = re.findall( - r"<.*?>", - parameters[0].replace("", "").replace("", ""), - re.DOTALL, - ) - if len(tags) % 2 != 0: - raise ValueError("Imbalanced tags inside tags.") - - # Loop through the tags and check if each even-indexed tag matches the tag in the position after it (with the / of course). - # If valid store their content for later use. - # TODO: Add a check to make sure there aren't duplicates provided of a given parameter. - parameters_with_values = [] - for i in range(0, len(tags), 2): - opening_tag = tags[i] - closing_tag = tags[i + 1] - closing_tag_without_second_char = closing_tag[:1] + closing_tag[2:] - if closing_tag[1] != "/" or opening_tag != closing_tag_without_second_char: - raise ValueError( - "Non-matching opening and closing tags inside tags." - ) - - match_param = re.search( - rf"{opening_tag}(.*?){closing_tag}", parameters[0], re.DOTALL - ) - if match_param: - parameters_with_values.append((opening_tag[1:-1], match_param.group(1))) - - # Parse out the full function call - invokes.append( - FunctionCall( - tool_name[0].replace("", "").replace("", ""), - parameters_with_values, - ) - ) - - return ContentWithFunctionCalls(func_call_prefix_content, invokes) - - -####################################################################################### -# These functions deal with converting Anthropic to our native ToolCall -####################################################################################### - - -def tool_call(invoke: FunctionCall, tools: list[ToolInfo]) -> ToolCall: - tool_def = next((tool for tool in tools if invoke.function == tool.name), None) - return ToolCall( - id=invoke.function, - function=invoke.function, - arguments=tool_arguments(invoke.parameters, tool_def), - type="function", - ) - - -def tool_arguments( - params: list[tuple[str, str]], tool_info: ToolInfo | None -) -> dict[str, Any]: - arguments: dict[str, Any] = dict() - for param in params: - # get params - name, value = param - - # coerce type if we have a tool_def - if tool_info: - type_str = next( - (param.type for param in tool_info.params if param.name == name), None - ) - if type_str: - value = tool_argument_value(value, type_str) - - arguments[name] = value - - return arguments - - -def tool_argument_value(value: Any, type_str: str) -> Any: - """Convert a string value into its appropriate Python data type based on the provided type string. - - Arg: - value: the value to convert - type_str: the type to convert the value to - Returns: - The value converted into the requested type or the original value - if the conversion failed. - """ - type_str = json_type_to_python_type(type_str) - if type_str in ("list", "dict"): - return ast.literal_eval(value) - type_class = getattr(builtins, type_str) - try: - return type_class(value) - except ValueError: - return value diff --git a/src/inspect_ai/model/_providers/providers.py b/src/inspect_ai/model/_providers/providers.py index 4dc61173d..d72139a3d 100644 --- a/src/inspect_ai/model/_providers/providers.py +++ b/src/inspect_ai/model/_providers/providers.py @@ -25,7 +25,7 @@ def openai() -> type[ModelAPI]: def anthropic() -> type[ModelAPI]: FEATURE = "Anthropic API" PACKAGE = "anthropic" - MIN_VERSION = "0.26.0" + MIN_VERSION = "0.27.0" # verify we have the package try: diff --git a/tools/vscode/CHANGELOG.md b/tools/vscode/CHANGELOG.md index e12f16891..cfc2cc422 100644 --- a/tools/vscode/CHANGELOG.md +++ b/tools/vscode/CHANGELOG.md @@ -1,5 +1,14 @@ # Changelog +## 0.3.19 + +- Fix an issue showing the log viewer when an evaluation completes (specific to Inspect 0.3.10 or later) + +## 0.3.18 + +- Fix issues with task params when type hints are provided +- Improve metric appearance in `inspect view` + ## 0.3.17 - Improve `inspect view` title bar treatment diff --git a/tools/vscode/assets/www/view/view-overrides.css b/tools/vscode/assets/www/view/view-overrides.css index 25b2731b3..a0caedb12 100644 --- a/tools/vscode/assets/www/view/view-overrides.css +++ b/tools/vscode/assets/www/view/view-overrides.css @@ -46,3 +46,11 @@ body[class^="vscode-"] #sidebarOffCanvas > div > span { body[class^="vscode-"] code:not(.sourceCode) { color: var(--bs-body-color); } + +/* temporary hack to improve the appearance of metrics in the navbar + to truly fix, remove 'navbar-brand' from metrics div and use `navbar-metrics` + to properly style it */ +body[class^="vscode-"] .navbar > div > .navbar-text:not(.navbar-brand) > div > div > div:last-of-type { + margin-top: -10px; + transform: scale(0.7); +} \ No newline at end of file diff --git a/tools/vscode/package.json b/tools/vscode/package.json index 5a7a3a5ff..f42f729cf 100644 --- a/tools/vscode/package.json +++ b/tools/vscode/package.json @@ -7,7 +7,7 @@ "author": { "name": "UK AI Safety Institute" }, - "version": "0.3.17", + "version": "0.3.19", "license": "MIT", "homepage": "https://ukgovernmentbeis.github.io/inspect_ai/", "repository": { diff --git a/tools/vscode/src/components/task.ts b/tools/vscode/src/components/task.ts index 691524533..43d1645f5 100644 --- a/tools/vscode/src/components/task.ts +++ b/tools/vscode/src/components/task.ts @@ -28,8 +28,8 @@ export interface TaskData { const kTaskPattern = /@task/; const kFunctionNamePattern = /def\s+(.*)\((.*)$/; -const kFunctionEndPattern = /\s*\):\s*/; -const kParamsPattern = /^(.*?)(\):)?$/; +const kFunctionEndPattern = /\s*\)\s*(->\s*\S+)?\s*:\s*/; +const kParamsPattern = /^(.*?)\s*(?:\)\s*:\s*|$|\)\s*(->\s*\S+)?\s*:\s*)/; export function readTaskData(document: TextDocument): TaskData[] { const tasks: TaskData[] = []; @@ -92,7 +92,9 @@ const readParams = (line: string, task: TaskData) => { const params = paramsStr.split(","); params.forEach((param) => { const name = param.split("=")[0].trim(); - if (name) { + if (name && name.includes(':')) { + task.params.push(name.split(':')[0]); + } else if (name) { task.params.push(name); } }); diff --git a/tools/vscode/src/extension.ts b/tools/vscode/src/extension.ts index ad23ad95a..4333bce89 100644 --- a/tools/vscode/src/extension.ts +++ b/tools/vscode/src/extension.ts @@ -62,7 +62,7 @@ export async function activate(context: ExtensionContext) { context.subscriptions.push(inspectManager); // Eval Manager - const [inspectEvalCommands, inspectEvalMgr] = activateEvalManager(stateManager); + const [inspectEvalCommands, inspectEvalMgr] = await activateEvalManager(stateManager, context); // Activate a watcher which inspects the active document and determines // the active task (if any) diff --git a/tools/vscode/src/providers/active-task/active-task-provider.ts b/tools/vscode/src/providers/active-task/active-task-provider.ts index bb8554ef8..1f1672126 100644 --- a/tools/vscode/src/providers/active-task/active-task-provider.ts +++ b/tools/vscode/src/providers/active-task/active-task-provider.ts @@ -2,14 +2,15 @@ import { Event, EventEmitter, ExtensionContext, + NotebookEditor, Position, Selection, TextDocument, + TextEditorSelectionChangeEvent, Uri, commands, window, workspace, - } from "vscode"; import { DebugActiveTaskCommand, @@ -19,6 +20,7 @@ import { InspectEvalManager } from "../inspect/inspect-eval"; import { Command } from "../../core/command"; import { DocumentTaskInfo, readTaskData } from "../../components/task"; import { cellTasks, isNotebook } from "../../components/notebook"; +import { debounce } from "lodash"; // Activates the provider which tracks the currently active task (document and task name) export function activateActiveTaskProvider( @@ -45,39 +47,50 @@ export class ActiveTaskManager { // Listen for the editor changing and update task state // when there is a new selection context.subscriptions.push( - window.onDidChangeTextEditorSelection(async (event) => { - await this.updateActiveTaskWithDocument( - event.textEditor.document, - event.selections[0] - ); - }) + window.onDidChangeTextEditorSelection( + debounce( + async (event: TextEditorSelectionChangeEvent) => { + await this.updateActiveTaskWithDocument( + event.textEditor.document, + event.selections[0] + ); + }, + 300, + { trailing: true } + ) + ) ); context.subscriptions.push( - window.onDidChangeActiveNotebookEditor(async (event) => { - if (window.activeNotebookEditor?.selection.start) { - const cell = event?.notebook.cellAt(window.activeNotebookEditor.selection.start); - await this.updateActiveTaskWithDocument( - cell?.document, - new Selection(new Position(0, 0), new Position(0, 0)) - ); + window.onDidChangeActiveNotebookEditor( + debounce(async (event: NotebookEditor | undefined) => { + if (window.activeNotebookEditor?.selection.start) { + const cell = event?.notebook.cellAt( + window.activeNotebookEditor.selection.start + ); + await this.updateActiveTaskWithDocument( + cell?.document, + new Selection(new Position(0, 0), new Position(0, 0)) + ); + } + }, 300, { trailing: true }) + )); + + context.subscriptions.push( + window.onDidChangeActiveTextEditor(async (event) => { + if (event) { + await this.updateActiveTaskWithDocument(event.document); } }) ); - - context.subscriptions.push(window.onDidChangeActiveTextEditor(async (event) => { - if (event) { - await this.updateActiveTaskWithDocument( - event.document - ); - } - })); } private activeTaskInfo_: DocumentTaskInfo | undefined; - private readonly onActiveTaskChanged_ = new EventEmitter(); + private readonly onActiveTaskChanged_ = + new EventEmitter(); // Event to be notified when task information changes - public readonly onActiveTaskChanged: Event = this.onActiveTaskChanged_.event; + public readonly onActiveTaskChanged: Event = + this.onActiveTaskChanged_.event; // Get the task information for the current selection public getActiveTaskInfo(): DocumentTaskInfo | undefined { @@ -101,13 +114,18 @@ export class ActiveTaskManager { "inspect_ai.activeTask", taskActive ); - } } - async updateActiveTaskWithDocument(document?: TextDocument, selection?: Selection) { + async updateActiveTaskWithDocument( + document?: TextDocument, + selection?: Selection + ) { if (document && selection) { - const activeTaskInfo = document.languageId === "python" ? getTaskInfoFromDocument(document, selection) : undefined; + const activeTaskInfo = + document.languageId === "python" + ? getTaskInfoFromDocument(document, selection) + : undefined; await this.updateTask(activeTaskInfo); } } @@ -115,10 +133,14 @@ export class ActiveTaskManager { async updateActiveTask(documentUri: Uri, task: string) { if (isNotebook(documentUri)) { // Compute the cell and position of the task - const notebookDocument = await workspace.openNotebookDocument(documentUri); + const notebookDocument = await workspace.openNotebookDocument( + documentUri + ); const cells = cellTasks(notebookDocument); const cellTask = cells.find((c) => { - return c.tasks.find((t) => { return t.name === task; }); + return c.tasks.find((t) => { + return t.name === task; + }); }); if (cellTask) { const cell = notebookDocument.cellAt(cellTask?.cellIndex); @@ -156,7 +178,6 @@ function getTaskInfoFromDocument( return undefined; } - const selectionLine = selection?.start.line || 0; // Find the first task that appears before the selection @@ -168,7 +189,7 @@ function getTaskInfoFromDocument( return { document: document.uri, tasks, - activeTask: activeTask || (tasks.length > 0 ? tasks[0] : undefined) + activeTask: activeTask || (tasks.length > 0 ? tasks[0] : undefined), }; } @@ -191,6 +212,6 @@ function getTaskInfo( return { document: document.uri, tasks, - activeTask: activeTask || (tasks.length > 0 ? tasks[0] : undefined) + activeTask: activeTask || (tasks.length > 0 ? tasks[0] : undefined), }; } diff --git a/tools/vscode/src/providers/inspect/inspect-eval.ts b/tools/vscode/src/providers/inspect/inspect-eval.ts index 85d4e5d8a..ebb878e91 100644 --- a/tools/vscode/src/providers/inspect/inspect-eval.ts +++ b/tools/vscode/src/providers/inspect/inspect-eval.ts @@ -1,4 +1,4 @@ -import { DebugConfiguration, debug, window, workspace } from "vscode"; +import { DebugConfiguration, ExtensionContext, debug, window, workspace } from "vscode"; import { inspectEvalCommands } from "./inspect-eval-commands"; import { Command } from "../../core/command"; import { @@ -11,11 +11,30 @@ import { inspectVersion } from "../../inspect"; import { inspectBinPath } from "../../inspect/props"; import { activeWorkspaceFolder } from "../../core/workspace"; import { findOpenPort } from "../../core/port"; +import { log } from "../../core/log"; -export function activateEvalManager( - stateManager: WorkspaceStateManager -): [Command[], InspectEvalManager] { +export async function activateEvalManager( + stateManager: WorkspaceStateManager, + context: ExtensionContext +): Promise<[Command[], InspectEvalManager]> { + // Activate the manager const inspectEvalMgr = new InspectEvalManager(stateManager); + + + // Set up our terminal environment + // Update the workspace id used in our terminal environments + await stateManager.initializeWorkspaceId(); + + const workspaceId = stateManager.getWorkspaceInstance(); + const env = context.environmentVariableCollection; + log.append(`Workspace: ${workspaceId}`); + log.append(`Resetting Terminal Workspace:`); + + log.append(`new: ${workspaceId}`); + + env.delete('INSPECT_WORKSPACE_ID'); + env.append('INSPECT_WORKSPACE_ID', workspaceId); + return [inspectEvalCommands(inspectEvalMgr), inspectEvalMgr]; } @@ -108,19 +127,19 @@ export class InspectEvalManager { await runDebugger(inspectBinPath()?.path || "inspect", args, workspaceDir.path, debugPort); } else { // Run the command - runEvalCmd(args, workspaceDir.path, this.stateManager_); + runEvalCmd(args, workspaceDir.path); } } } -const runEvalCmd = (args: string[], cwd: string, stateManager: WorkspaceStateManager) => { +const runEvalCmd = (args: string[], cwd: string) => { // See if there a non-busy terminal that we can re-use const name = "Inspect Eval"; let terminal = window.terminals.find((t) => { return t.name === name; }); if (!terminal) { - terminal = window.createTerminal({ name, cwd, env: { ["INSPECT_WORKSPACE_ID"]: stateManager.getWorkspaceInstance() } }); + terminal = window.createTerminal({ name, cwd }); } terminal.show(); terminal.sendText(`cd ${cwd}`); diff --git a/tools/vscode/src/providers/workspace/workspace-state-provider.ts b/tools/vscode/src/providers/workspace/workspace-state-provider.ts index 7841f23ae..0325c7cba 100644 --- a/tools/vscode/src/providers/workspace/workspace-state-provider.ts +++ b/tools/vscode/src/providers/workspace/workspace-state-provider.ts @@ -25,12 +25,18 @@ export interface ModelState { export class WorkspaceStateManager { constructor(private readonly context_: ExtensionContext) { - this.instanceId = `${Date.now()}-${randomInt(0, 100000)}`; } - private instanceId: string; - public getWorkspaceInstance() { - return this.instanceId; + public async initializeWorkspaceId() { + const existingKey = this.context_.workspaceState.get('INSPECT_WORKSPACE_ID'); + if (!existingKey) { + const key = `${Date.now()}-${randomInt(0, 100000)}`; + await this.context_.workspaceState.update('INSPECT_WORKSPACE_ID', key); + } + } + + public getWorkspaceInstance(): string { + return this.context_.workspaceState.get('INSPECT_WORKSPACE_ID')!; } public getState(key: string) {