diff --git a/src/inspect_ai/_eval/context.py b/src/inspect_ai/_eval/context.py index 31e48af84..0aa2c80dc 100644 --- a/src/inspect_ai/_eval/context.py +++ b/src/inspect_ai/_eval/context.py @@ -1,10 +1,12 @@ from inspect_ai.model import Model from inspect_ai.model._model import init_active_model, init_model_usage +from inspect_ai.util._concurrency import init_concurrency from inspect_ai.util._logger import init_logger_records from inspect_ai.util._subprocess import init_max_subprocesses def init_eval_context(max_subprocesses: int | None = None) -> None: + init_concurrency() init_max_subprocesses(max_subprocesses) diff --git a/src/inspect_ai/scorer/_pattern.py b/src/inspect_ai/scorer/_pattern.py index c4cca9963..7e71b236c 100644 --- a/src/inspect_ai/scorer/_pattern.py +++ b/src/inspect_ai/scorer/_pattern.py @@ -9,16 +9,23 @@ from ._target import Target +def match_target(match: str, target: Target, ignore_case: bool) -> bool: + if ignore_case: + match = match.lower() + target = Target([t.lower() for t in target]) + + return match in target + + def match_first( matches: tuple[str | Any, ...], target: Target, ignore_case: bool ) -> str | None: for match in matches: - if isinstance(match, str): - if ignore_case: - match = match.lower() + if not isinstance(match, str): + continue - if match in target: - return match + if match_target(match, target, ignore_case): + return match return None @@ -27,12 +34,11 @@ def match_all_groups( matches: tuple[str | Any, ...], target: Target, ignore_case: bool ) -> str | None: for match in matches: - if isinstance(match, str): - if ignore_case: - match = match.lower() + if not isinstance(match, str): + continue - if match not in target: - return None + if not match_target(match, target, ignore_case): + return None return target.text @@ -64,21 +70,29 @@ async def score(state: TaskState, target: Target) -> Score: ) if match: - if ignore_case: - target = Target([t.lower() for t in target]) - + groups = match.groups() if match_all: found_match = match_all_groups( - matches=match.groups(), target=target, ignore_case=ignore_case + matches=groups, target=target, ignore_case=ignore_case ) + answer = found_match else: found_match = match_first( - matches=match.groups(), target=target, ignore_case=ignore_case + matches=groups, target=target, ignore_case=ignore_case ) + if found_match is None and len(groups) == 1: + # A common use of a pattern is to extract a single answer + # from some templated text. If we fail to match in that + # scenario, it's worth returning the failed match because + # this is useful information for the user. + answer = groups[0] + else: + answer = found_match + return Score( value=CORRECT if found_match else INCORRECT, - answer=found_match, + answer=answer, explanation=state.output.completion, ) else: diff --git a/src/inspect_ai/solver/_tool/environment/docker/config.py b/src/inspect_ai/solver/_tool/environment/docker/config.py index 7fac9c16d..4e42dbd71 100644 --- a/src/inspect_ai/solver/_tool/environment/docker/config.py +++ b/src/inspect_ai/solver/_tool/environment/docker/config.py @@ -75,6 +75,7 @@ def safe_cleanup_auto_compose(file: str | None) -> None: image: "python:3.12-bookworm" command: "tail -f /dev/null" network_mode: none + stop_grace_period: 1s """ COMPOSE_DOCKERFILE_YAML = f"""{COMPOSE_COMMENT} @@ -84,6 +85,7 @@ def safe_cleanup_auto_compose(file: str | None) -> None: context: "." command: "tail -f /dev/null" network_mode: none + stop_grace_period: 1s """ diff --git a/src/inspect_ai/util/_concurrency.py b/src/inspect_ai/util/_concurrency.py index f130a3cb8..ddfbd6c2e 100644 --- a/src/inspect_ai/util/_concurrency.py +++ b/src/inspect_ai/util/_concurrency.py @@ -1,4 +1,5 @@ import asyncio +from contextvars import ContextVar from dataclasses import dataclass @@ -40,12 +41,12 @@ def concurrency( key = key if key else name # do we have an existing semaphore? if not create one and store it - semaphore = _concurrency_semaphores.get(key, None) + semaphore = _concurrency_semaphores.get().get(key, None) if semaphore is None: semaphore = ConcurencySempahore( name, concurrency, asyncio.Semaphore(concurrency) ) - _concurrency_semaphores[key] = semaphore + _concurrency_semaphores.get()[key] = semaphore # return the semaphore return semaphore.semaphore @@ -53,11 +54,15 @@ def concurrency( def concurrency_status() -> dict[str, tuple[int, int]]: status: dict[str, tuple[int, int]] = {} - for c in _concurrency_semaphores.values(): + for c in _concurrency_semaphores.get().values(): status[c.name] = (c.concurrency - c.semaphore._value, c.concurrency) return status +def init_concurrency() -> None: + _concurrency_semaphores.set({}) + + @dataclass class ConcurencySempahore: name: str @@ -65,4 +70,6 @@ class ConcurencySempahore: semaphore: asyncio.Semaphore -_concurrency_semaphores: dict[str, ConcurencySempahore] = {} +_concurrency_semaphores: ContextVar[dict[str, ConcurencySempahore]] = ContextVar( + "concurrency_semaphores", default={} +) diff --git a/tests/scorer/test_pattern.py b/tests/scorer/test_pattern.py index a6961c35a..c7e84ed3b 100644 --- a/tests/scorer/test_pattern.py +++ b/tests/scorer/test_pattern.py @@ -105,3 +105,31 @@ async def test_only_returns_exact_target_matches(): result = await scorer(state, Target(["bar"])) assert result.text == INCORRECT + + +@pytest.mark.asyncio +async def test_one_match_group_returns_incorrect_match(): + scorer = pattern( + "ANSWER: (A|B)", + ignore_case=False, + match_all=False, + ) + state = simple_task_state(model_output="ANSWER: A") + result = await scorer(state, Target(["B"])) + + assert result.answer == "A" + assert result.text == INCORRECT + + +@pytest.mark.asyncio +async def test_multiple_match_group_returns_none(): + scorer = pattern( + "ANSWER: (A|B) ALTERNATE_ANSWER: (A|B)", + ignore_case=False, + match_all=False, + ) + state = simple_task_state(model_output="ANSWER: A ALTERNATE_ANSWER: A") + result = await scorer(state, Target(["B"])) + + assert result.answer is None + assert result.text == INCORRECT diff --git a/tools/vscode/CHANGELOG.md b/tools/vscode/CHANGELOG.md index 0cf11ef5b..81581dbc5 100644 --- a/tools/vscode/CHANGELOG.md +++ b/tools/vscode/CHANGELOG.md @@ -1,5 +1,10 @@ # Changelog +## 0.3.23 + +- Ensure the log view only opens in the correct window when debugging a task +- Changes to improve performance and usability of large log files + ## 0.3.22 - Improve reliability of opening and viewing log files upon completion of evaluations diff --git a/tools/vscode/package.json b/tools/vscode/package.json index 3caae3d89..6748f2592 100644 --- a/tools/vscode/package.json +++ b/tools/vscode/package.json @@ -7,7 +7,7 @@ "author": { "name": "UK AI Safety Institute" }, - "version": "0.3.22", + "version": "0.3.23", "license": "MIT", "homepage": "https://ukgovernmentbeis.github.io/inspect_ai/", "repository": { diff --git a/tools/vscode/src/providers/inspect/inspect-eval.ts b/tools/vscode/src/providers/inspect/inspect-eval.ts index ebb878e91..cb3f643b4 100644 --- a/tools/vscode/src/providers/inspect/inspect-eval.ts +++ b/tools/vscode/src/providers/inspect/inspect-eval.ts @@ -1,4 +1,10 @@ -import { DebugConfiguration, ExtensionContext, debug, window, workspace } from "vscode"; +import { + DebugConfiguration, + ExtensionContext, + debug, + window, + workspace, +} from "vscode"; import { inspectEvalCommands } from "./inspect-eval-commands"; import { Command } from "../../core/command"; import { @@ -20,7 +26,6 @@ export async function activateEvalManager( // Activate the manager const inspectEvalMgr = new InspectEvalManager(stateManager); - // Set up our terminal environment // Update the workspace id used in our terminal environments await stateManager.initializeWorkspaceId(); @@ -32,15 +37,14 @@ export async function activateEvalManager( log.append(`new: ${workspaceId}`); - env.delete('INSPECT_WORKSPACE_ID'); - env.append('INSPECT_WORKSPACE_ID', workspaceId); + env.delete("INSPECT_WORKSPACE_ID"); + env.append("INSPECT_WORKSPACE_ID", workspaceId); return [inspectEvalCommands(inspectEvalMgr), inspectEvalMgr]; } export class InspectEvalManager { - constructor(private readonly stateManager_: WorkspaceStateManager) { - } + constructor(private readonly stateManager_: WorkspaceStateManager) { } public async startEval(file: AbsolutePath, task?: string, debug = false) { // if we don't have inspect bail and let the user know @@ -113,7 +117,6 @@ export class InspectEvalManager { // If we're debugging, launch using the debugger if (debug) { - // Handle debugging let debugPort = 5678; if (debug === true) { @@ -124,7 +127,19 @@ export class InspectEvalManager { args.push(debugPort.toString()); } - await runDebugger(inspectBinPath()?.path || "inspect", args, workspaceDir.path, debugPort); + // Pass the workspace ID to the debug environment so we'll + // properly target the workspace window when showing the logview + const env = { + INSPECT_WORKSPACE_ID: this.stateManager_.getWorkspaceInstance(), + }; + + await runDebugger( + inspectBinPath()?.path || "inspect", + args, + workspaceDir.path, + debugPort, + env + ); } else { // Run the command runEvalCmd(args, workspaceDir.path); @@ -146,7 +161,13 @@ const runEvalCmd = (args: string[], cwd: string) => { terminal.sendText(["inspect", ...args].join(" ")); }; -const runDebugger = async (program: string, args: string[], cwd: string, port: number) => { +const runDebugger = async ( + program: string, + args: string[], + cwd: string, + port: number, + env?: Record +) => { const name = "Inspect Eval"; const debugConfiguration: DebugConfiguration = { name, @@ -157,7 +178,8 @@ const runDebugger = async (program: string, args: string[], cwd: string, port: n console: "internalConsole", cwd, port, - "justMyCode": false + env, + justMyCode: false, }; await debug.startDebugging(activeWorkspaceFolder(), debugConfiguration); };