From 0f342f7af6c73b8127cf613e849f55b45f273696 Mon Sep 17 00:00:00 2001 From: jjallaire Date: Thu, 21 Nov 2024 11:44:17 -0800 Subject: [PATCH] resolve tmux issues in fullscreen display (#876) * initial work on tmux fixes for textual/realtime * explicitly enumerate known_fields for approval policy work w/ more versions of pydantic * improved theming across terminal environments * ruff check * fix openai image test --------- Co-authored-by: aisi-inspect <166920645+aisi-inspect@users.noreply.github.com> --- docs/tutorial.qmd | 2 +- pyproject.toml | 1 + requirements.txt | 1 + src/inspect_ai/_cli/cache.py | 2 +- src/inspect_ai/_cli/common.py | 21 +- src/inspect_ai/_cli/score.py | 2 +- src/inspect_ai/_display/__init__.py | 29 +- src/inspect_ai/_display/core/active.py | 52 ++ src/inspect_ai/_display/core/config.py | 43 + .../_display/{_display.py => core/display.py} | 86 +- src/inspect_ai/_display/core/footer.py | 27 + src/inspect_ai/_display/core/group.py | 79 ++ src/inspect_ai/_display/core/panel.py | 92 +++ src/inspect_ai/_display/core/progress.py | 111 +++ src/inspect_ai/_display/core/results.py | 180 ++++ src/inspect_ai/_display/core/rich.py | 99 +++ src/inspect_ai/_display/logger.py | 148 ---- src/inspect_ai/_display/rich.py | 782 ------------------ src/inspect_ai/_display/rich/__init__.py | 0 src/inspect_ai/_display/rich/display.py | 341 ++++++++ src/inspect_ai/_display/textual/app.py | 359 ++++++++ src/inspect_ai/_display/textual/app.tcss | 26 + src/inspect_ai/_display/textual/display.py | 71 ++ src/inspect_ai/_display/textual/theme.py | 30 + .../_display/textual/widgets/clock.py | 55 ++ .../_display/textual/widgets/console.py | 52 ++ .../_display/textual/widgets/footer.py | 38 + .../_display/textual/widgets/samples.py | 259 ++++++ .../_display/textual/widgets/tasks.py | 195 +++++ .../_display/textual/widgets/titlebar.py | 90 ++ .../_display/textual/widgets/transcript.py | 344 ++++++++ src/inspect_ai/_eval/context.py | 8 +- src/inspect_ai/_eval/eval.py | 67 +- src/inspect_ai/_eval/run.py | 51 +- src/inspect_ai/_eval/task/run.py | 86 +- src/inspect_ai/_util/ansi.py | 5 - src/inspect_ai/_util/constants.py | 1 + src/inspect_ai/_util/display.py | 34 + src/inspect_ai/_util/logger.py | 152 +++- src/inspect_ai/_util/rich.py | 24 + src/inspect_ai/_util/terminal.py | 138 ++++ src/inspect_ai/_util/transcript.py | 86 ++ src/inspect_ai/_view/view.py | 2 +- src/inspect_ai/_view/www/log-schema.json | 226 ++++- src/inspect_ai/_view/www/src/types/log.d.ts | 38 +- src/inspect_ai/approval/_human.py | 8 +- src/inspect_ai/approval/_policy.py | 2 +- src/inspect_ai/log/_samples.py | 70 ++ src/inspect_ai/log/_transcript.py | 43 +- src/inspect_ai/model/_call_tools.py | 32 +- src/inspect_ai/model/_model.py | 78 +- src/inspect_ai/model/_render.py | 24 + src/inspect_ai/model/_trace.py | 42 +- src/inspect_ai/solver/_task_state.py | 4 +- src/inspect_ai/util/_concurrency.py | 13 +- src/inspect_ai/util/_console.py | 2 +- .../util/_sandbox/docker/compose.py | 8 +- .../util/_sandbox/docker/internal.py | 22 +- src/inspect_ai/util/_subtask.py | 8 +- src/inspect_ai/util/_trace.py | 36 +- tests/util/test_images.py | 2 +- tools/vscode/src/@types/log.d.ts | 38 +- 62 files changed, 3738 insertions(+), 1229 deletions(-) create mode 100644 src/inspect_ai/_display/core/active.py create mode 100644 src/inspect_ai/_display/core/config.py rename src/inspect_ai/_display/{_display.py => core/display.py} (63%) create mode 100644 src/inspect_ai/_display/core/footer.py create mode 100644 src/inspect_ai/_display/core/group.py create mode 
100644 src/inspect_ai/_display/core/panel.py create mode 100644 src/inspect_ai/_display/core/progress.py create mode 100644 src/inspect_ai/_display/core/results.py create mode 100644 src/inspect_ai/_display/core/rich.py delete mode 100644 src/inspect_ai/_display/logger.py delete mode 100644 src/inspect_ai/_display/rich.py create mode 100644 src/inspect_ai/_display/rich/__init__.py create mode 100644 src/inspect_ai/_display/rich/display.py create mode 100644 src/inspect_ai/_display/textual/app.py create mode 100644 src/inspect_ai/_display/textual/app.tcss create mode 100644 src/inspect_ai/_display/textual/display.py create mode 100644 src/inspect_ai/_display/textual/theme.py create mode 100644 src/inspect_ai/_display/textual/widgets/clock.py create mode 100644 src/inspect_ai/_display/textual/widgets/console.py create mode 100644 src/inspect_ai/_display/textual/widgets/footer.py create mode 100644 src/inspect_ai/_display/textual/widgets/samples.py create mode 100644 src/inspect_ai/_display/textual/widgets/tasks.py create mode 100644 src/inspect_ai/_display/textual/widgets/titlebar.py create mode 100644 src/inspect_ai/_display/textual/widgets/transcript.py delete mode 100644 src/inspect_ai/_util/ansi.py create mode 100644 src/inspect_ai/_util/display.py create mode 100644 src/inspect_ai/_util/rich.py create mode 100644 src/inspect_ai/_util/terminal.py create mode 100644 src/inspect_ai/_util/transcript.py create mode 100644 src/inspect_ai/log/_samples.py create mode 100644 src/inspect_ai/model/_render.py diff --git a/docs/tutorial.qmd b/docs/tutorial.qmd index e93860c47..2922b0125 100644 --- a/docs/tutorial.qmd +++ b/docs/tutorial.qmd @@ -54,7 +54,7 @@ computer security and provide a short response in a few words. ### Eval {.unlisted} -Discerning whether the correct security guidance was provided by the model might provide difficult using only text matching algorithms. Here we use a model to read the response and assess the quality of the answer. +Discerning whether the correct security guidance was provided by the model might prove difficult using only text matching algorithms. Here we use a model to read the response and assess the quality of the answer. 
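For context, the model-graded check this tutorial passage describes is typically wired up with Inspect's `model_graded_fact()` scorer. The sketch below is illustrative only and not part of this patch; the dataset name, abbreviated system message, and solver chain are assumptions based on the tutorial text above (the tutorial's actual task definition follows).

```python
from inspect_ai import Task, task
from inspect_ai.dataset import example_dataset
from inspect_ai.scorer import model_graded_fact
from inspect_ai.solver import generate, system_message

@task
def security_guide():
    # have a grading model read the response and judge whether the expected
    # security guidance is present, rather than relying on text matching
    return Task(
        dataset=example_dataset("security_guide"),
        solver=[system_message("You are a computer security expert."), generate()],
        scorer=model_graded_fact(),
    )
```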
```{python} @task diff --git a/pyproject.toml b/pyproject.toml index 019e8ea33..0971a3b5e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -130,6 +130,7 @@ dev = [ "pytest-dotenv", "pytest-xdist", "ruff==0.7.4", # match version specified in .pre-commit-config.yaml + "textual-dev>=0.86.2", "types-PyYAML", "types-aiofiles", "types-beautifulsoup4", diff --git a/requirements.txt b/requirements.txt index 94a3a7e89..1e82560eb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,5 +24,6 @@ s3fs>=2023 semver>=3.0.0 shortuuid tenacity +textual>=0.86.2 typing_extensions>=4.9.0 zipp>=3.19.1 # not directly required, pinned by Snyk to avoid a vulnerability diff --git a/src/inspect_ai/_cli/cache.py b/src/inspect_ai/_cli/cache.py index 8209bfa0c..b6c2d50ca 100644 --- a/src/inspect_ai/_cli/cache.py +++ b/src/inspect_ai/_cli/cache.py @@ -2,7 +2,7 @@ from rich import print from rich.table import Table -from inspect_ai._display.logger import init_logger +from inspect_ai._util.logger import init_logger from inspect_ai.model import ( ModelName, cache_clear, diff --git a/src/inspect_ai/_cli/common.py b/src/inspect_ai/_cli/common.py index 5f17a3438..8494da529 100644 --- a/src/inspect_ai/_cli/common.py +++ b/src/inspect_ai/_cli/common.py @@ -1,21 +1,23 @@ import functools -import os -from typing import Any, Callable, cast +from typing import Any, Callable, Literal, cast import click from typing_extensions import TypedDict from inspect_ai._util.constants import ( ALL_LOG_LEVELS, + DEFAULT_DISPLAY, DEFAULT_LOG_LEVEL, DEFAULT_LOG_LEVEL_TRANSCRIPT, ) +from inspect_ai._util.display import init_display_type class CommonOptions(TypedDict): log_level: str log_level_transcript: str log_dir: str + display: Literal["full", "rich", "plain", "none"] no_ansi: bool | None debug: bool debug_port: int @@ -60,10 +62,18 @@ def common_options(func: Callable[..., Any]) -> Callable[..., click.Context]: envvar="INSPECT_LOG_DIR", help="Directory for log files.", ) + @click.option( + "--display", + type=click.Choice(["full", "rich", "plain", "none"], case_sensitive=False), + default=DEFAULT_DISPLAY, + envvar="INSPECT_DISPLAY", + help="Set the display type (defaults to 'full')", + ) @click.option( "--no-ansi", type=bool, is_flag=True, + hidden=True, help="Do not print ANSI control characters.", envvar="INSPECT_NO_ANSI", ) @@ -91,9 +101,12 @@ def wrapper(*args: Any, **kwargs: Any) -> click.Context: def process_common_options(options: CommonOptions) -> None: - # disable ansi if requested + # propagate display if options["no_ansi"]: - os.environ["INSPECT_NO_ANSI"] = "1" + display = "plain" + else: + display = options["display"].lower().strip() + init_display_type(display) # attach debugger if requested if options["debug"]: diff --git a/src/inspect_ai/_cli/score.py b/src/inspect_ai/_cli/score.py index 71e46fb46..6d0651038 100644 --- a/src/inspect_ai/_cli/score.py +++ b/src/inspect_ai/_cli/score.py @@ -57,7 +57,7 @@ async def score( log_level_transcript: str | None, ) -> None: # init eval context - init_eval_context(None, log_level, log_level_transcript) + init_eval_context(log_level, log_level_transcript) # read the eval log recorder = create_recorder_for_location(log_file, log_dir) diff --git a/src/inspect_ai/_display/__init__.py b/src/inspect_ai/_display/__init__.py index fc0421956..4f288d101 100644 --- a/src/inspect_ai/_display/__init__.py +++ b/src/inspect_ai/_display/__init__.py @@ -1,6 +1,25 @@ -from ._display import Display -from .rich import rich_display +from .core.active import display +from .core.display import ( + 
Display, + Progress, + TaskCancelled, + TaskError, + TaskProfile, + TaskResult, + TaskScreen, + TaskSuccess, + TaskWithResult, +) - -def display() -> Display: - return rich_display() +__all__ = [ + "display", + "Display", + "Progress", + "TaskCancelled", + "TaskError", + "TaskProfile", + "TaskResult", + "TaskScreen", + "TaskWithResult", + "TaskSuccess", +] diff --git a/src/inspect_ai/_display/core/active.py b/src/inspect_ai/_display/core/active.py new file mode 100644 index 000000000..5f452aa21 --- /dev/null +++ b/src/inspect_ai/_display/core/active.py @@ -0,0 +1,52 @@ +import sys +from contextvars import ContextVar + +import rich + +from inspect_ai._util.display import display_type +from inspect_ai.util._trace import trace_enabled + +from ..rich.display import RichDisplay +from ..textual.display import TextualDisplay +from .display import Display, TaskScreen + + +def display() -> Display: + global _active_display + if _active_display is None: + if ( + display_type() == "full" + and sys.stdout.isatty() + and not trace_enabled() + and not rich.get_console().is_jupyter + ): + _active_display = TextualDisplay() + else: + _active_display = RichDisplay() + + return _active_display + + +_active_display: Display | None = None + + +def task_screen() -> TaskScreen: + screen = _active_task_screen.get(None) + if screen is None: + raise RuntimeError( + "console input function called outside of running evaluation." + ) + return screen + + +def init_task_screen(screen: TaskScreen) -> None: + _active_task_screen.set(screen) + + +def clear_task_screen() -> None: + _active_task_screen.set(None) + + +_active_task_screen: ContextVar[TaskScreen | None] = ContextVar( + "task_screen", default=None +) diff --git a/src/inspect_ai/_display/core/config.py b/src/inspect_ai/_display/core/config.py new file mode 100644 index 000000000..327eefdd5 --- /dev/null +++ b/src/inspect_ai/_display/core/config.py @@ -0,0 +1,43 @@ +from inspect_ai._util.registry import is_registry_dict + +from .display import TaskProfile + + +def task_config( + profile: TaskProfile, generate_config: bool = True, style: str = "" +) -> str: + # merge config + # wind params back for display + task_args = dict(profile.task_args) + for key in task_args.keys(): + value = task_args[key] + if is_registry_dict(value): + task_args[key] = value["name"] + config = task_args | dict(profile.eval_config.model_dump(exclude_none=True)) + if generate_config: + config = config | dict(profile.generate_config.model_dump(exclude_none=True)) + if profile.tags: + config["tags"] = ",".join(profile.tags) + config_print: list[str] = [] + for name, value in config.items(): + if name == "approval": + config_print.append( + f"{name}: {','.join([approver['name'] for approver in value['approvers']])}" + ) + elif name not in ["limit", "model"]: + config_print.append(f"{name}: {value}") + values = ", ".join(config_print) + if values: + if style: + return f"[{style}]{values}[/{style}]" + else: + return values + else: + return "" + + +def task_dict(d: dict[str, str], bold_value: bool = False) -> str: + slot1, slot2 = ("", "[/bold]") if bold_value else ("[/bold]", "") + return " ".join( + [f"[bold]{key}:{slot1} {value}{slot2}" for key, value in d.items()] + ) diff --git a/src/inspect_ai/_display/_display.py b/src/inspect_ai/_display/core/display.py similarity index 63% rename from src/inspect_ai/_display/_display.py rename to src/inspect_ai/_display/core/display.py index 0eb6fde26..4758b4495 100644 --- a/src/inspect_ai/_display/_display.py +++ 
b/src/inspect_ai/_display/core/display.py @@ -1,24 +1,38 @@ -import abc import contextlib -from contextvars import ContextVar from dataclasses import dataclass from types import TracebackType -from typing import Any, Iterator, Type, Union - +from typing import ( + Any, + AsyncIterator, + Coroutine, + Iterator, + Protocol, + Type, + TypeVar, + Union, + runtime_checkable, +) + +import rich from rich.console import Console from inspect_ai.log import EvalConfig, EvalResults, EvalStats from inspect_ai.model import GenerateConfig, ModelName -class Progress(abc.ABC): - @abc.abstractmethod +@runtime_checkable +class Progress(Protocol): def update(self, n: int = 1) -> None: ... - @abc.abstractmethod def complete(self) -> None: ... +@dataclass +class TaskSpec: + name: str + model: ModelName + + @dataclass class TaskProfile: name: str @@ -59,58 +73,54 @@ class TaskSuccess: TaskResult = Union[TaskError, TaskCancelled, TaskSuccess] +@dataclass +class TaskWithResult: + profile: TaskProfile + result: TaskResult | None + + +TR = TypeVar("TR") + + class TaskScreen(contextlib.AbstractContextManager["TaskScreen"]): - @abc.abstractmethod + def __exit__(self, *excinfo: Any) -> None: + pass + @contextlib.contextmanager def input_screen( self, header: str | None = None, transient: bool | None = None, width: int | None = None, - ) -> Iterator[Console]: ... + ) -> Iterator[Console]: + yield rich.get_console() -class TaskDisplay(abc.ABC): - @abc.abstractmethod +@runtime_checkable +class TaskDisplay(Protocol): @contextlib.contextmanager def progress(self) -> Iterator[Progress]: ... - @abc.abstractmethod def complete(self, result: TaskResult) -> None: ... -class Display(abc.ABC): - @abc.abstractmethod +@runtime_checkable +class Display(Protocol): def print(self, message: str) -> None: ... - @abc.abstractmethod @contextlib.contextmanager def progress(self, total: int) -> Iterator[Progress]: ... - @abc.abstractmethod - @contextlib.contextmanager - def task_screen(self, total_tasks: int, parallel: bool) -> Iterator[TaskScreen]: ... + def run_task_app(self, main: Coroutine[Any, Any, TR]) -> TR: ... - @abc.abstractmethod @contextlib.contextmanager - def task(self, profile: TaskProfile) -> Iterator[TaskDisplay]: ... - - -def task_screen() -> TaskScreen: - screen = _task_screen.get(None) - if screen is None: - raise RuntimeError( - "console input function called outside of running evaluation." - ) - return screen - + def suspend_task_app(self) -> Iterator[None]: ... -def init_task_screen(screen: TaskScreen) -> None: - _task_screen.set(screen) + @contextlib.asynccontextmanager + async def task_screen( + self, tasks: list[TaskSpec], parallel: bool + ) -> AsyncIterator[TaskScreen]: + yield TaskScreen() - -def clear_task_screen() -> None: - _task_screen.set(None) - - -_task_screen: ContextVar[TaskScreen | None] = ContextVar("task_screen", default=None) + @contextlib.contextmanager + def task(self, profile: TaskProfile) -> Iterator[TaskDisplay]: ... 
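A note on the refactor above: the display interfaces in `core/display.py` are now `@runtime_checkable` `Protocol` classes rather than `abc.ABC` subclasses, so implementations satisfy them structurally instead of by inheritance. A minimal sketch (not part of the patch; it assumes the re-exports shown in `_display/__init__.py` above):

```python
from inspect_ai._display import Progress


class NullProgress:
    """No-op progress reporter that satisfies the Progress protocol by shape alone."""

    def update(self, n: int = 1) -> None:
        pass  # a real display would advance its progress bar here

    def complete(self) -> None:
        pass  # a real display would mark the bar as finished


progress: Progress = NullProgress()    # accepted by the type checker (structural typing)
assert isinstance(progress, Progress)  # works because Progress is @runtime_checkable
```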
diff --git a/src/inspect_ai/_display/core/footer.py b/src/inspect_ai/_display/core/footer.py new file mode 100644 index 000000000..fbb252156 --- /dev/null +++ b/src/inspect_ai/_display/core/footer.py @@ -0,0 +1,27 @@ +from rich.console import RenderableType +from rich.text import Text + +from inspect_ai._util.logger import http_rate_limit_count +from inspect_ai._util.throttle import throttle +from inspect_ai.util._concurrency import concurrency_status + +from .config import task_dict + + +@throttle(1) +def task_footer(style: str = "") -> tuple[RenderableType, RenderableType]: + return ( + Text.from_markup(task_resources(), style=style), + Text.from_markup(task_http_rate_limits(), style=style), + ) + + +def task_resources() -> str: + resources: dict[str, str] = {} + for model, resource in concurrency_status().items(): + resources[model] = f"{resource[0]}/{resource[1]}" + return task_dict(resources) + + +def task_http_rate_limits() -> str: + return f"HTTP rate limits: {http_rate_limit_count():,}" diff --git a/src/inspect_ai/_display/core/group.py b/src/inspect_ai/_display/core/group.py new file mode 100644 index 000000000..0da22e200 --- /dev/null +++ b/src/inspect_ai/_display/core/group.py @@ -0,0 +1,79 @@ +from dataclasses import dataclass +from typing import Sequence + +from inspect_ai.log._transcript import Event, StepEvent, SubtaskEvent, ToolEvent + + +@dataclass +class EventGroup: + """Event and (optionally) its embedded event groups. + + - Some events (e.g. SampleInitEvent, LogEvent) have no embedded event groups. + - Some events (e.g. ToolEvent, SubtaskEvent) contain lists of event groups. + - StepEvent has an implicit of event groups based on its begin/end instances. + """ + + event: Event + level: int + groups: list["EventGroup"] | None = None + + +def group_events(events: Sequence[Event], level: int = 1) -> list[EventGroup]: + """Transform ordinary list of events into list of event groups.""" + # groups are either plain events (some of which can have sub-events) + # and higher level steps (e.g. 
solvers/scorers) that contain events + event_groups: list[EventGroup] = [] + + # track stack of active steps + active_steps: list[tuple[StepEvent, list[EventGroup]]] = [] + + # iterate though events + for event in events: + # manage step events + if isinstance(event, StepEvent): + if event.action == "begin": + active_steps.append((event, [])) + elif event.action == "end": + begin_step, step_groups = active_steps.pop() + target_group = ( + active_steps[-1][1] if len(active_steps) else event_groups + ) + target_group.append( + EventGroup( + event=begin_step, + level=level + len(active_steps), + groups=step_groups, + ) + ) + + # other events + else: + # target level depends on whether we are appending to a set + target_level = level + len(active_steps) + + # tool and subtask events have their own nested event lists + if isinstance(event, ToolEvent | SubtaskEvent): + group = EventGroup( + event=event, + groups=group_events(event.events, level=target_level + 1), + level=target_level, + ) + else: + group = EventGroup(event=event, level=target_level) + + # add to active step if we have one + if len(active_steps) > 0: + active_steps[-1][1].append(group) + # otherwise just add to root list + else: + event_groups.append(group) + + # if there are active steps alive then collect them (an error + # may have prevented them from having end steps) + while len(active_steps) > 0: + begin_step, step_groups = active_steps.pop() + event_groups.append( + EventGroup(event=begin_step, level=level, groups=step_groups) + ) + + return event_groups diff --git a/src/inspect_ai/_display/core/panel.py b/src/inspect_ai/_display/core/panel.py new file mode 100644 index 000000000..3cada1da5 --- /dev/null +++ b/src/inspect_ai/_display/core/panel.py @@ -0,0 +1,92 @@ +import rich +from rich.console import RenderableType +from rich.panel import Panel +from rich.table import Table +from rich.text import Text + +from inspect_ai._util.constants import CONSOLE_DISPLAY_WIDTH +from inspect_ai._util.path import cwd_relative_path + +from .display import TaskProfile +from .rich import is_vscode_notebook, rich_theme + + +def task_panel( + profile: TaskProfile, + show_model: bool, + body: RenderableType, + footer: RenderableType | tuple[RenderableType, RenderableType] | None, + log_location: str | None, +) -> Panel: + # rendering context + theme = rich_theme() + console = rich.get_console() + width = CONSOLE_DISPLAY_WIDTH if is_vscode_notebook(console) else None + jupyter = console.is_jupyter + + # setup table + table = Table.grid(expand=True) + table.add_column() + table.add_column(justify="right") + + # main progress and task info + targets = Text.from_markup(task_targets(profile), style=theme.meta) + table.add_row(body, targets) + + # footer if specified + if footer: + table.add_row() + if isinstance(footer, tuple): + table.add_row(footer[0], footer[1]) + else: + table.add_row(footer) + + # enclose in outer table for log link footer + root = table + if log_location: + # if we are in jupyter then use a real hyperlink + if jupyter: + log_location = f"[link={log_location}]{log_location}[/link]" + + # Print a cwd relative path + try: + log_location_relative = cwd_relative_path(log_location, walk_up=True) + except ValueError: + log_location_relative = log_location + + root = Table.grid(expand=True) + root.add_column() + root.add_row(table) + root.add_row() + root.add_row( + f"[bold][{theme.light}]Log:[/{theme.light}][/bold] " + + f"[{theme.link}]{log_location_relative}[/{theme.link}]" + ) + + # create panel w/ title + panel = Panel( + 
root, + title=f"[bold][{theme.meta}]{task_title(profile, show_model)}[/{theme.meta}][/bold]", + title_align="left", + width=width, + expand=True, + ) + return panel + + +def tasks_title(completed: int, total: int) -> str: + return f"{completed}/{total} tasks complete" + + +def task_title(profile: TaskProfile, show_model: bool) -> str: + eval_epochs = profile.eval_config.epochs or 1 + epochs = f" x {profile.eval_config.epochs}" if eval_epochs > 1 else "" + samples = f"{profile.samples//eval_epochs:,}{epochs} sample{'s' if profile.samples > 1 else ''}" + title = f"{profile.name} ({samples})" + if show_model: + title = f"{title}: {profile.model}" + return title + + +def task_targets(profile: TaskProfile) -> str: + return f"dataset: {profile.dataset}" diff --git a/src/inspect_ai/_display/core/progress.py b/src/inspect_ai/_display/core/progress.py new file mode 100644 index 000000000..3f35b9e89 --- /dev/null +++ b/src/inspect_ai/_display/core/progress.py @@ -0,0 +1,111 @@ +from typing import Callable + +import rich +from rich.progress import ( + BarColumn, + TaskProgressColumn, + TextColumn, + TimeElapsedColumn, +) +from rich.progress import Progress as RProgress +from rich.text import Text +from typing_extensions import override + +from inspect_ai._util.registry import registry_unqualified_name +from inspect_ai.model._model import ModelName + +from .display import Progress, TaskCancelled, TaskError, TaskProfile, TaskResult +from .rich import is_vscode_notebook, rich_theme + +# Note that use of rich progress seems to result in an extra +# empty cell after execution, see: https://github.com/Textualize/rich/issues/3274 + +PROGRESS_TOTAL = 102 + + +class RichProgress(Progress): + def __init__( + self, + total: int, + progress: RProgress, + description: str = "", + model: str = "", + status: Callable[[], str] | None = None, + on_update: Callable[[], None] | None = None, + ) -> None: + self.total = total + self.progress = progress + self.status = status if status else lambda: "" + self.on_update = on_update + self.task_id = progress.add_task( + description, total=PROGRESS_TOTAL, model=model, status=self.status() + ) + + @override + def update(self, n: int = 1) -> None: + advance = (float(n) / float(self.total)) * 100 + self.progress.update( + task_id=self.task_id, advance=advance, refresh=True, status=self.status() + ) + if self.on_update: + self.on_update() + + @override + def complete(self) -> None: + self.progress.update( + task_id=self.task_id, completed=PROGRESS_TOTAL, status=self.status() + ) + + +def rich_progress() -> RProgress: + console = rich.get_console() + return RProgress( + TextColumn("{task.fields[status]}"), + TextColumn("{task.description}"), + TextColumn("{task.fields[model]}"), + BarColumn(bar_width=40 if is_vscode_notebook(console) else None), + TaskProgressColumn(), + TimeElapsedColumn(), + transient=True, + console=console, + expand=not is_vscode_notebook(console), + ) + + +MAX_MODEL_NAME_WIDTH = 25 +MAX_DESCRIPTION_WIDTH = 25 + + +def progress_model_name( + model_name: ModelName, max_width: int = MAX_MODEL_NAME_WIDTH, pad: bool = False +) -> Text: + model = Text(str(model_name)) + model.truncate(max_width, overflow="ellipsis", pad=pad) + return model + + +def progress_description( + profile: TaskProfile, max_width: int = MAX_DESCRIPTION_WIDTH, pad: bool = False +) -> Text: + description = Text(registry_unqualified_name(profile.name)) + description.truncate(max_width, overflow="ellipsis", pad=pad) + return description + + +def progress_status_icon(result: TaskResult | None) -> 
str: + theme = rich_theme() + if result: + if isinstance(result, TaskError): + return f"[{theme.error}]✗[{theme.error}]" + elif isinstance(result, TaskCancelled): + return f"[{theme.error}]✗[{theme.error}]" + else: + return f"[{theme.success}]✔[{theme.success}]" + else: + return f"[{theme.meta}]⠿[{theme.meta}]" + + +def progress_time(time: float) -> str: + minutes, seconds = divmod(time, 60) + hours, minutes = divmod(minutes, 60) + return f"{hours:2.0f}:{minutes:02.0f}:{seconds:02.0f}" diff --git a/src/inspect_ai/_display/core/results.py b/src/inspect_ai/_display/core/results.py new file mode 100644 index 000000000..0d78ea2a7 --- /dev/null +++ b/src/inspect_ai/_display/core/results.py @@ -0,0 +1,180 @@ +from datetime import datetime +from typing import Sequence, Set + +from rich.console import Group, RenderableType +from rich.table import Table +from rich.text import Text + +from inspect_ai.log import EvalStats +from inspect_ai.log._log import rich_traceback + +from .config import task_config, task_dict +from .display import ( + TaskCancelled, + TaskError, + TaskProfile, + TaskSuccess, + TaskWithResult, +) +from .panel import task_panel +from .rich import rich_theme + + +def tasks_results(tasks: Sequence[TaskWithResult]) -> RenderableType: + def render_task(task: TaskWithResult) -> RenderableType: + if isinstance(task.result, TaskCancelled): + return task_result_cancelled(task.profile, task.result) + elif isinstance(task.result, TaskError): + return task_result_error(task.profile, task.result) + elif isinstance(task.result, TaskSuccess): + return task_result_summary(task.profile, task.result) + else: + return "" + + return Group(*[render_task(task) for task in tasks]) + + +def task_result_cancelled( + profile: TaskProfile, cancelled: TaskCancelled +) -> RenderableType: + return task_panel( + profile=profile, + show_model=True, + body=task_stats(profile, cancelled.stats), + footer=task_interrupted(profile, cancelled.samples_completed), + log_location=profile.log_location, + ) + + +def task_results(profile: TaskProfile, success: TaskSuccess) -> RenderableType: + theme = rich_theme() + + # do we have more than one scorer name? 
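+    # (a scorer-name prefix is added to each metric key only when more than
+    # one scorer is present, and the reducer suffix is shown only when a
+    # reducer other than the default "avg" is in use)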
+ results = success.results + scorer_names: Set[str] = {score.name for score in results.scores} + reducer_names: Set[str] = { + score.reducer for score in results.scores if score.reducer is not None + } + show_reducer = len(reducer_names) > 1 or "avg" not in reducer_names + output: dict[str, str] = {} + for score in results.scores: + for name, metric in score.metrics.items(): + value = ( + "1.0" + if metric.value == 1 + else ( + str(metric.value) + if isinstance(metric.value, int) + else f"{metric.value:.3g}" + ) + ) + name = ( + rf"{name}\[{score.reducer}]" + if show_reducer and score.reducer is not None + else name + ) + key = f"{score.name}/{name}" if (len(scorer_names) > 1) else name + output[key] = value + + if output: + message = f"[{theme.metric}]{task_dict(output, True)}[/{theme.metric}]" + else: + message = "" + + # note if some of our samples had errors + if success.samples_completed < profile.samples: + sample_errors = profile.samples - success.samples_completed + sample_error_pct = int(float(sample_errors) / float(profile.samples) * 100) + if message: + message = f"{message}\n\n" + message = f"{message}[{theme.warning}]WARNING: {sample_errors} of {profile.samples} samples ({sample_error_pct}%) had errors and were not scored.[/{theme.warning}]" + + return message + + +def task_result_summary(profile: TaskProfile, success: TaskSuccess) -> RenderableType: + return task_panel( + profile=profile, + show_model=True, + body=task_stats(profile, success.stats), + footer=task_results(profile, success), + log_location=profile.log_location, + ) + + +def task_result_error(profile: TaskProfile, error: TaskError) -> RenderableType: + return task_panel( + profile=profile, + show_model=True, + body=rich_traceback(error.exc_type, error.exc_value, error.traceback), + footer=task_interrupted(profile, error.samples_completed), + log_location=profile.log_location, + ) + + +def task_stats(profile: TaskProfile, stats: EvalStats) -> RenderableType: + theme = rich_theme() + panel = Table.grid(expand=True) + panel.add_column() + config = task_config(profile) + if config: + panel.add_row(config) + panel.add_row() + elif len(stats.model_usage) < 2: + panel.add_row() + + table = Table.grid(expand=True) + table.add_column(style="bold") + table.add_column() + + # eval time + started = datetime.fromisoformat(stats.started_at) + completed = datetime.fromisoformat(stats.completed_at) + elapsed = completed - started + table.add_row(Text("total time:", style="bold"), f" {elapsed}", style=theme.light) + + # token usage + for model, usage in stats.model_usage.items(): + if ( + usage.input_tokens_cache_read is not None + or usage.input_tokens_cache_write is not None + ): + input_tokens_cache_read = usage.input_tokens_cache_read or 0 + input_tokens_cache_write = usage.input_tokens_cache_write or 0 + input_tokens = f"[bold]I: [/bold]{usage.input_tokens:,}, [bold]CW: [/bold]{input_tokens_cache_write:,}, [bold]CR: [/bold]{input_tokens_cache_read:,}" + else: + input_tokens = f"[bold]I: [/bold]{usage.input_tokens:,}" + + table.add_row( + Text(model, style="bold"), + f" {usage.total_tokens:,} tokens [{input_tokens}, [bold]O: [/bold]{usage.output_tokens:,}]", + style=theme.light, + ) + + panel.add_row(table) + return panel + + +def task_can_retry(profile: TaskProfile) -> bool: + return profile.file is not None or "/" in profile.name + + +def task_interrupted(profile: TaskProfile, samples_completed: int) -> RenderableType: + log_location = profile.log_location + theme = rich_theme() + message = f"[bold][{theme.error}]Task 
interrupted (" + if samples_completed > 0: + message = f"{message}{samples_completed:,} of {profile.samples:,} total samples logged before interruption)." + if task_can_retry(profile): + message = ( + f"{message} Resume task with:[/{theme.error}][/bold]\n\n" + + f"[bold][{theme.light}]inspect eval-retry {log_location}[/{theme.light}][/bold]" + ) + else: + message = f"{message}[/{theme.error}][/bold]" + else: + message = ( + f"{message}no samples completed before interruption)[/{theme.error}][/bold]" + ) + + return message diff --git a/src/inspect_ai/_display/core/rich.py b/src/inspect_ai/_display/core/rich.py new file mode 100644 index 000000000..46cdeb77b --- /dev/null +++ b/src/inspect_ai/_display/core/rich.py @@ -0,0 +1,99 @@ +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any, Iterator + +import rich +from rich.console import Console, ConsoleOptions, RenderResult +from rich.markdown import CodeBlock, Markdown +from rich.segment import Segment +from rich.syntax import Syntax +from typing_extensions import override + +from inspect_ai._util.display import display_type +from inspect_ai._util.platform import is_running_in_jupyterlab, is_running_in_vscode +from inspect_ai._util.transcript import transcript_code_theme + + +def is_vscode_notebook(console: Console) -> bool: + return console.is_jupyter and is_running_in_vscode() + + +def rich_no_color() -> bool: + return ( + display_type() == "plain" + or not is_running_in_vscode() + or is_running_in_jupyterlab() + ) + + +def rich_initialise() -> None: + # reflect ansi prefs + if display_type() == "plain": + rich.reconfigure(no_color=True, force_terminal=False, force_interactive=False) + elif rich_no_color(): + rich.reconfigure(no_color=True) + + # reflect display == none + if display_type() == "none": + rich.reconfigure(quiet=True) + + # consistent markdown code bock background + class CustomCodeBlock(CodeBlock): + @override + def __rich_console__( + self, console: Console, options: ConsoleOptions + ) -> RenderResult: + code = str(self.text).rstrip() + syntax = Syntax( + code, + self.lexer_name, + theme=transcript_code_theme(), + word_wrap=True, + background_color="#282c34", + padding=0, + ) + yield syntax + + Markdown.elements["fence"] = CustomCodeBlock + Markdown.elements["code_block"] = CustomCodeBlock + + +@dataclass +class RichTheme: + meta: str = "blue" + light: str = "bright_black" + metric: str = "green" + link: str = "blue" + success: str = "green" + error: str = "red" + warning: str = "orange3" + + +def rich_theme() -> RichTheme: + global _theme + if _theme is None: + _theme = RichTheme() + return _theme + + +_theme: RichTheme | None = None + + +@contextmanager +def record_console_input() -> Iterator[None]: + # monkey patch .input method to record inputs + input_original = Console.input + + def input_with_record(self: Console, *args: Any, **kwargs: Any) -> str: + result = input_original(self, *args, **kwargs) + if self.record: + with self._record_buffer_lock: + self._record_buffer.append(Segment(result)) + return result + + Console.input = input_with_record # type: ignore + + try: + yield + finally: + Console.input = input_original # type: ignore diff --git a/src/inspect_ai/_display/logger.py b/src/inspect_ai/_display/logger.py deleted file mode 100644 index 7f0ccb571..000000000 --- a/src/inspect_ai/_display/logger.py +++ /dev/null @@ -1,148 +0,0 @@ -import os -from logging import ( - INFO, - WARNING, - FileHandler, - Formatter, - LogRecord, - addLevelName, - getLevelName, - getLogger, -) 
- -from rich.console import ConsoleRenderable -from rich.logging import RichHandler -from rich.text import Text -from typing_extensions import override - -from inspect_ai._util.constants import ( - ALL_LOG_LEVELS, - DEFAULT_LOG_LEVEL, - DEFAULT_LOG_LEVEL_TRANSCRIPT, - HTTP, - HTTP_LOG_LEVEL, - PKG_NAME, - SANDBOX, - SANDBOX_LOG_LEVEL, -) -from inspect_ai._util.error import PrerequisiteError -from inspect_ai._util.logger import notify_logger_record - -from .rich import rich_console - - -# log handler that filters messages to stderr and the log file -class LogHandler(RichHandler): - def __init__(self, levelno: int, transcript_levelno: int) -> None: - super().__init__(levelno, console=rich_console()) - self.transcript_levelno = transcript_levelno - self.display_level = WARNING - # log into an external file if requested via env var - file_logger = os.environ.get("INSPECT_PY_LOGGER_FILE", None) - self.file_logger = FileHandler(file_logger) if file_logger else None - if self.file_logger: - self.file_logger.setFormatter( - Formatter("%(asctime)s - %(levelname)s - %(message)s") - ) - - # see if the user has a special log level for the file - file_logger_level = os.environ.get("INSPECT_PY_LOGGER_LEVEL", "") - if file_logger_level: - self.file_logger_level = int(getLevelName(file_logger_level.upper())) - else: - self.file_logger_level = 0 - - @override - def emit(self, record: LogRecord) -> None: - # demote httpx and return notifications to log_level http - if ( - record.name == "httpx" - or "http" in record.name - or "Retrying request" in record.getMessage() - ): - record.levelno = HTTP - record.levelname = HTTP_LOG_LEVEL - - # skip httpx event loop is closed errors - if "Event loop is closed" in record.getMessage(): - return - - # write to stderr if we are at or above the threshold - if record.levelno >= self.display_level: - super().emit(record) - - # write to file if the log file level matches. if the - # user hasn't explicitly specified a level then we - # take the minimum of 'info' and the display level - if self.file_logger and record.levelno >= ( - self.file_logger_level or min(self.display_level, INFO) - ): - self.file_logger.emit(record) - - # eval log always gets info level and higher records - # eval log only gets debug or http if we opt-in - write = record.levelno >= self.transcript_levelno - notify_logger_record(record, write) - - @override - def render_message(self, record: LogRecord, message: str) -> ConsoleRenderable: - return Text.from_ansi(message) - - -# initialize logging -- this function can be called multiple times -# in the lifetime of the process (the levelno will update globally) -def init_logger( - log_level: str | None = None, log_level_transcript: str | None = None -) -> None: - # backwards compatibility for 'tools' - if log_level == "tools": - log_level = "sandbox" - - # register http and tools levels - addLevelName(HTTP, HTTP_LOG_LEVEL) - addLevelName(SANDBOX, SANDBOX_LOG_LEVEL) - - def validate_level(option: str, level: str) -> None: - if level not in ALL_LOG_LEVELS: - log_levels = ", ".join([level.lower() for level in ALL_LOG_LEVELS]) - raise PrerequisiteError( - f"Invalid {option} '{level.lower()}'. 
Log level must be one of {log_levels}" - ) - - # resolve default log level - log_level = ( - log_level if log_level else os.getenv("INSPECT_LOG_LEVEL", DEFAULT_LOG_LEVEL) - ).upper() - validate_level("log level", log_level) - - # reolve log file level - log_level_transcript = ( - log_level_transcript - if log_level_transcript - else os.getenv("INSPECT_LOG_LEVEL_TRANSCRIPT", DEFAULT_LOG_LEVEL_TRANSCRIPT) - ).upper() - validate_level("log file level", log_level_transcript) - - # convert to integer - levelno = getLevelName(log_level) - transcript_levelno = getLevelName(log_level_transcript) - - # init logging handler on demand - global _logHandler - if not _logHandler: - _logHandler = LogHandler(min(HTTP, levelno), transcript_levelno) - getLogger().addHandler(_logHandler) - - # establish default capture level - capture_level = min(HTTP, levelno) - - # see all the messages (we won't actually display/write all of them) - getLogger().setLevel(capture_level) - getLogger(PKG_NAME).setLevel(capture_level) - getLogger("httpx").setLevel(capture_level) - - # set the levelno on the global handler - _logHandler.display_level = levelno - - -_logHandler: LogHandler | None = None diff --git a/src/inspect_ai/_display/rich.py b/src/inspect_ai/_display/rich.py deleted file mode 100644 index cc666f088..000000000 --- a/src/inspect_ai/_display/rich.py +++ /dev/null @@ -1,782 +0,0 @@ -import asyncio -import contextlib -import datetime -from contextlib import contextmanager -from dataclasses import dataclass -from typing import Any, Callable, Iterator, Set - -import rich -from rich.console import Console, ConsoleOptions, Group, RenderableType, RenderResult -from rich.live import Live -from rich.markdown import CodeBlock, Markdown -from rich.panel import Panel -from rich.progress import ( - BarColumn, - TaskProgressColumn, - TextColumn, - TimeElapsedColumn, -) -from rich.progress import Progress as RProgress -from rich.segment import Segment -from rich.syntax import Syntax -from rich.table import Table -from rich.text import Text -from typing_extensions import override - -from inspect_ai._util.ansi import no_ansi -from inspect_ai._util.constants import CONSOLE_DISPLAY_WIDTH -from inspect_ai._util.logger import http_rate_limit_count -from inspect_ai._util.path import cwd_relative_path -from inspect_ai._util.platform import is_running_in_jupyterlab, is_running_in_vscode -from inspect_ai._util.registry import is_registry_dict -from inspect_ai._util.throttle import throttle -from inspect_ai.log import EvalStats -from inspect_ai.log._log import rich_traceback -from inspect_ai.log._transcript import InputEvent, transcript -from inspect_ai.util._concurrency import concurrency_status -from inspect_ai.util._trace import trace_enabled - -from ._display import ( - Display, - Progress, - TaskCancelled, - TaskDisplay, - TaskError, - TaskProfile, - TaskResult, - TaskScreen, - TaskSuccess, -) - - -@dataclass -class Theme: - meta: str = "blue" - light: str = "bright_black" - metric: str = "green" - link: str = "blue" - success: str = "green" - error: str = "red" - warning: str = "orange3" - - -@dataclass -class TaskStatus: - profile: TaskProfile - result: TaskResult | None - progress: RProgress - - -class RichDisplay(Display): - def __init__(self) -> None: - self.total_tasks: int = 0 - self.tasks: list[TaskStatus] = [] - self.progress_ui: RProgress | None = None - self.parallel = False - self.live: Live | None = None - self.timer_handle: asyncio.TimerHandle | None = None - rich_initialise() - - @override - def print(self, message: 
str) -> None: - rich_console().print(message, markup=False, highlight=False) - - @override - @contextlib.contextmanager - def progress(self, total: int) -> Iterator[Progress]: - with rich_progress() as progress: - yield RichProgress(total, progress) - - @override - @contextlib.contextmanager - def task_screen(self, total_tasks: int, parallel: bool) -> Iterator[TaskScreen]: - self.total_tasks = total_tasks - self.tasks = [] - self.progress_ui = rich_progress() - self.parallel = parallel - try: - with ( - Live( - None, - console=rich_console(), - transient=True, - auto_refresh=False, - ) as live, - ): - with RichTaskScreen(live) as task_screen: - # save reference to live - self.live = live - - # enque a display update - self.timer_handle = asyncio.get_event_loop().call_later( - 1, self._update_display - ) - - # yield - yield task_screen - - # render task results (re-enable live if necessary) - if not live.is_started: - live.start() - live.transient = False - live.update(tasks_results(self.tasks), refresh=True) - finally: - # clear tasks and progress - self.total_tasks = 0 - self.tasks = [] - self.progress_ui = None - self.parallel = False - self.live = None - if self.timer_handle: - self.timer_handle.cancel() - - @override - @contextlib.contextmanager - def task(self, profile: TaskProfile) -> Iterator[TaskDisplay]: - # if there is no ansi display than all of the below will - # be a no-op, so we print a simple text message for the task - if no_ansi(): - rich_console().print(task_no_ansi(profile)) - - # for typechekcer - if self.tasks is None: - self.tasks = [] - if self.progress_ui is None: - self.progress_ui = rich_progress() - - status = TaskStatus(profile, None, self.progress_ui) - self.tasks.append(status) - self._update_display() - yield RichTaskDisplay( - status, show_name=self.parallel, on_update=self._update_display - ) - - @throttle(1) - def _update_display(self) -> None: - if ( - self.tasks is not None - and self.tasks - and self.progress_ui is not None - and self.live is not None - and self.live.is_started - ): - if self.parallel: - r = tasks_live_status(self.total_tasks, self.tasks, self.progress_ui) - else: - r = task_live_status(self.tasks, self.progress_ui) - self.live.update(r, refresh=True) - - self.timer_handle = asyncio.get_event_loop().call_later(1, self._update_display) - - -class RichTaskScreen(TaskScreen): - def __init__(self, live: Live) -> None: - theme = rich_theme() - self.live = live - status_text = "Working" if trace_enabled() else "Task running" - self.status = self.live.console.status( - f"[{theme.meta} bold]{status_text}...[/{theme.meta} bold]", spinner="clock" - ) - - def __exit__(self, *excinfo: Any) -> None: - self.status.stop() - - @override - @contextlib.contextmanager - def input_screen( - self, - header: str | None = None, - transient: bool | None = None, - width: int | None = None, - ) -> Iterator[Console]: - # determine transient based on trace mode - if transient is None: - transient = not trace_enabled() - - # clear live task status and transient status - self.live.update("", refresh=True) - self.status.stop() - - # show cursor for input - self.live.console.show_cursor(True) - - # set width - old_width: int | None = None - if width: - old_width = self.live.console.width - self.live.console.width = min(old_width, width) - - # record console activity for event - self.live.console.record = True - - try: - # print header if requested - if header: - style = f"{rich_theme().meta} bold" - self.live.console.rule(f"[{style}]{header}[/{style}]", style="black") 
- self.live.console.print("") - - # yield the console - with record_console_input(): - yield self.live.console - - finally: - # capture recording then yield input event - input = self.live.console.export_text(clear=False, styles=False) - input_ansi = self.live.console.export_text(clear=True, styles=True) - self.live.console.record = False - transcript()._event(InputEvent(input=input, input_ansi=input_ansi)) - - # print one blank line - self.live.console.print("") - - # reset width - if old_width: - self.live.console.width = old_width - - # disable cursor while not collecting input - self.live.console.show_cursor(False) - - # if transient then disable live updates entirely - if transient is False and self.live.is_started: - self.live.stop() - - # otherwise make sure they are enabled - elif transient is True and not self.live.is_started: - self.live.start() - - # if not transient then display mini-status - if not transient: - self.status.start() - - -class RichTaskDisplay(TaskDisplay): - def __init__( - self, - status: TaskStatus, - show_name: bool, - on_update: Callable[[], None] | None = None, - ) -> None: - theme = rich_theme() - self.status = status - model = Text(str(self.status.profile.model)) - model.truncate(25, overflow="ellipsis") - description = Text(f"{self.status.profile.name} " if show_name else "") - if show_name: - description.truncate(20, overflow="ellipsis") - - def task_status() -> str: - if self.status.result: - if isinstance(self.status.result, TaskError): - return f"[{theme.error}]✗[{theme.error}]" - elif isinstance(self.status.result, TaskCancelled): - return f"[{theme.error}]✗[{theme.error}]" - else: - return f"[{theme.success}]✔[{theme.success}]" - else: - return f"[{theme.meta}]⠿[{theme.meta}]" - - self.p = RichProgress( - total=self.status.profile.steps, - progress=self.status.progress, - description=f"{description.markup}", - model=f"{model.markup} ", - status=task_status, - on_update=on_update, - ) - - @override - @contextlib.contextmanager - def progress(self) -> Iterator[Progress]: - yield self.p - - @override - def complete(self, result: TaskResult) -> None: - self.status.result = result - self.p.complete() - - -# Note that use of rich progress seems to result in an extra -# empty cell after execution, see: https://github.com/Textualize/rich/issues/3274 - -PROGRESS_TOTAL = 102 - - -class RichProgress(Progress): - def __init__( - self, - total: int, - progress: RProgress, - description: str = "", - model: str = "", - status: Callable[[], str] | None = None, - on_update: Callable[[], None] | None = None, - ) -> None: - self.total = total - self.progress = progress - self.status = status if status else lambda: "" - self.on_update = on_update - self.task_id = progress.add_task( - description, total=PROGRESS_TOTAL, model=model, status=self.status() - ) - - @override - def update(self, n: int = 1) -> None: - advance = (float(n) / float(self.total)) * 100 - self.progress.update( - task_id=self.task_id, advance=advance, refresh=True, status=self.status() - ) - if self.on_update: - self.on_update() - - @override - def complete(self) -> None: - self.progress.update( - task_id=self.task_id, completed=PROGRESS_TOTAL, status=self.status() - ) - - -def tasks_results(tasks: list[TaskStatus]) -> RenderableType: - def render_task(task: TaskStatus) -> RenderableType: - if isinstance(task.result, TaskCancelled): - return task_result_cancelled(task.profile, task.result) - elif isinstance(task.result, TaskError): - return task_result_error(task.profile, task.result) - elif 
isinstance(task.result, TaskSuccess): - return task_result_summary(task.profile, task.result) - else: - return "" - - return Group(*[render_task(task) for task in tasks]) - - -def task_live_status(tasks: list[TaskStatus], progress: RProgress) -> RenderableType: - body: list[RenderableType] = ["", progress] - config = task_config(tasks[0].profile) - if config: - body = [config] + body - - return task_panel( - profile=tasks[0].profile, - show_model=len(tasks) == 1, - body=Group(*body), - footer=live_task_footer(), - log_location=None, - ) - - -def tasks_live_status( - total_tasks: int, tasks: list[TaskStatus], progress: RProgress -) -> RenderableType: - # rendering context - theme = rich_theme() - console = rich_console() - width = CONSOLE_DISPLAY_WIDTH if is_vscode_notebook(console) else None - - # compute completed tasks - completed = sum(1 for task in tasks if task.result is not None) - - # get config - config = task_config(tasks[0].profile, generate_config=False) - if config: - config += "\n" - - # build footer table - footer_table = Table.grid(expand=True) - footer_table.add_column() - footer_table.add_column(justify="right") - footer = live_task_footer() - footer_table.add_row() - footer_table.add_row(footer[0], footer[1]) - - # create panel w/ title - panel = Panel( - Group(config, progress, footer_table, fit=False), - title=f"[bold][{theme.meta}]eval: {completed}/{total_tasks} tasks complete[/{theme.meta}][/bold]", - title_align="left", - width=width, - expand=True, - ) - return panel - - -def task_result_cancelled( - profile: TaskProfile, cancelled: TaskCancelled -) -> RenderableType: - return task_panel( - profile=profile, - show_model=True, - body=task_stats(profile, cancelled.stats), - footer=task_interrupted(profile, cancelled.samples_completed), - log_location=profile.log_location, - ) - - -def task_result_summary(profile: TaskProfile, success: TaskSuccess) -> RenderableType: - return task_panel( - profile=profile, - show_model=True, - body=task_stats(profile, success.stats), - footer=task_results(profile, success), - log_location=profile.log_location, - ) - - -def task_result_error(profile: TaskProfile, error: TaskError) -> RenderableType: - return task_panel( - profile=profile, - show_model=True, - body=rich_traceback(error.exc_type, error.exc_value, error.traceback), - footer=task_interrupted(profile, error.samples_completed), - log_location=profile.log_location, - ) - - -def task_panel( - profile: TaskProfile, - show_model: bool, - body: RenderableType, - footer: RenderableType | tuple[RenderableType, RenderableType] | None, - log_location: str | None, -) -> Panel: - # rendering context - theme = rich_theme() - console = rich_console() - width = CONSOLE_DISPLAY_WIDTH if is_vscode_notebook(console) else None - jupyter = console.is_jupyter - - # setup table - table = Table.grid(expand=True) - table.add_column() - table.add_column(justify="right") - - # main progress and task info - table.add_row( - body, - Text(task_targets(profile), style=theme.meta), - ) - - # footer if specified - if footer: - table.add_row() - if isinstance(footer, tuple): - table.add_row(footer[0], footer[1]) - else: - table.add_row(footer) - - # enclose in outer table for log link footer - root = table - if log_location: - # if we are in jupyter then use a real hyperlink - if jupyter: - log_location = f"[link={log_location}]{log_location}[/link]" - - # Print a cwd relative path - try: - log_location_relative = cwd_relative_path(log_location, walk_up=True) - except ValueError: - log_location_relative = 
log_location - - root = Table.grid(expand=True) - root.add_column() - root.add_row(table) - root.add_row() - root.add_row( - f"[bold][{theme.light}]Log:[/{theme.light}][/bold] " - + f"[{theme.link}]{log_location_relative}[/{theme.link}]" - ) - - # create panel w/ title - panel = Panel( - root, - title=f"[bold][{theme.meta}]{task_title(profile, show_model)}[/{theme.meta}][/bold]", - title_align="left", - width=width, - expand=True, - ) - return panel - - -def task_title(profile: TaskProfile, show_model: bool) -> str: - eval_epochs = profile.eval_config.epochs or 1 - epochs = f" x {profile.eval_config.epochs}" if eval_epochs > 1 else "" - samples = f"{profile.samples//eval_epochs:,}{epochs} sample{'s' if profile.samples > 1 else ''}" - title = f"{profile.name} ({samples})" - if show_model: - title = f"{title}: {profile.model}" - return title - - -def task_targets(profile: TaskProfile) -> str: - targets = [f"dataset: {profile.dataset}", f"scorer: {profile.scorer}"] - return " " + "\n ".join(targets) - - -def task_no_ansi(profile: TaskProfile) -> str: - message = f"Running task {task_title(profile, True)}" - config = task_config(profile) - if config: - message = f"{message} (config: {config})" - return f"{message}...\n" - - -def task_config(profile: TaskProfile, generate_config: bool = True) -> str: - # merge config - theme = rich_theme() - # wind params back for display - task_args = dict(profile.task_args) - for key in task_args.keys(): - value = task_args[key] - if is_registry_dict(value): - task_args[key] = value["name"] - config = task_args | dict(profile.eval_config.model_dump(exclude_none=True)) - if generate_config: - config = config | dict(profile.generate_config.model_dump(exclude_none=True)) - if profile.tags: - config["tags"] = ",".join(profile.tags) - config_print: list[str] = [] - for name, value in config.items(): - if name == "approval": - config_print.append( - f"{name}: {','.join([approver['name'] for approver in value['approvers']])}" - ) - elif name not in ["limit", "model"]: - config_print.append(f"{name}: {value}") - values = ", ".join(config_print) - if values: - return f"[{theme.light}]{values}[/{theme.light}]" - else: - return "" - - -def task_resources() -> str: - resources: dict[str, str] = {} - for model, resource in concurrency_status().items(): - resources[model] = f"{resource[0]}/{resource[1]}" - return task_dict(resources) - - -@throttle(1) -def live_task_footer() -> tuple[RenderableType, RenderableType]: - theme = rich_theme() - return ( - f"[{theme.light}]{task_resources()}[/{theme.light}]", - Text(task_http_rate_limits(), style=theme.light), - ) - - -def task_interrupted(profile: TaskProfile, samples_completed: int) -> RenderableType: - log_location = profile.log_location - theme = rich_theme() - message = f"[bold][{theme.error}]Task interrupted (" - if samples_completed > 0: - message = f"{message}{samples_completed:,} of {profile.samples:,} total samples logged before interruption)." 
- if task_can_retry(profile): - message = ( - f"{message} Resume task with:[/{theme.error}][/bold]\n\n" - + f"[bold][{theme.light}]inspect eval-retry {log_location}[/{theme.light}][/bold]" - ) - else: - message = f"{message}[/{theme.error}][/bold]" - else: - message = ( - f"{message}no samples completed before interruption)[/{theme.error}][/bold]" - ) - - return message - - -def task_can_retry(profile: TaskProfile) -> bool: - return profile.file is not None or "/" in profile.name - - -def task_results(profile: TaskProfile, success: TaskSuccess) -> RenderableType: - theme = rich_theme() - - # do we have more than one scorer name? - results = success.results - scorer_names: Set[str] = {score.name for score in results.scores} - reducer_names: Set[str] = { - score.reducer for score in results.scores if score.reducer is not None - } - show_reducer = len(reducer_names) > 1 or "avg" not in reducer_names - output: dict[str, str] = {} - for score in results.scores: - for name, metric in score.metrics.items(): - value = ( - "1.0" - if metric.value == 1 - else ( - str(metric.value) - if isinstance(metric.value, int) - else f"{metric.value:.3g}" - ) - ) - name = ( - rf"{name}\[{score.reducer}]" - if show_reducer and score.reducer is not None - else name - ) - key = f"{score.name}/{name}" if (len(scorer_names) > 1) else name - output[key] = value - - if output: - message = f"[{theme.metric}]{task_dict(output, True)}[/{theme.metric}]" - else: - message = "" - - # note if some of our samples had errors - if success.samples_completed < profile.samples: - sample_errors = profile.samples - success.samples_completed - sample_error_pct = int(float(sample_errors) / float(profile.samples) * 100) - if message: - message = f"{message}\n\n" - message = f"{message}[{theme.warning}]WARNING: {sample_errors} of {profile.samples} samples ({sample_error_pct}%) had errors and were not scored.[/{theme.warning}]" - - return message - - -def task_stats(profile: TaskProfile, stats: EvalStats) -> RenderableType: - theme = rich_theme() - panel = Table.grid(expand=True) - panel.add_column() - config = task_config(profile) - if config: - panel.add_row(config) - panel.add_row() - elif len(stats.model_usage) < 2: - panel.add_row() - - table = Table.grid(expand=True) - table.add_column(style="bold") - table.add_column() - - # eval time - started = datetime.datetime.fromisoformat(stats.started_at) - completed = datetime.datetime.fromisoformat(stats.completed_at) - elapsed = completed - started - table.add_row(Text("total time:", style="bold"), f" {elapsed}", style=theme.light) - - # token usage - for model, usage in stats.model_usage.items(): - if ( - usage.input_tokens_cache_read is not None - or usage.input_tokens_cache_write is not None - ): - input_tokens_cache_read = usage.input_tokens_cache_read or 0 - input_tokens_cache_write = usage.input_tokens_cache_write or 0 - input_tokens = f"[bold]I: [/bold]{usage.input_tokens:,}, [bold]CW: [/bold]{input_tokens_cache_write:,}, [bold]CR: [/bold]{input_tokens_cache_read:,}" - else: - input_tokens = f"[bold]I: [/bold]{usage.input_tokens:,}" - - table.add_row( - Text(model, style="bold"), - f" {usage.total_tokens:,} tokens [{input_tokens}, [bold]O: [/bold]{usage.output_tokens:,}]", - style=theme.light, - ) - - panel.add_row(table) - return panel - - -def task_http_rate_limits() -> str: - return f"HTTP rate limits: {http_rate_limit_count():,}" - - -def task_dict(d: dict[str, str], bold_value: bool = False) -> str: - slot1, slot2 = ("", "[/bold]") if bold_value else ("[/bold]", "") - return 
" ".join( - [f"[bold]{key}:{slot1} {value}{slot2}" for key, value in d.items()] - ) - - -def is_vscode_notebook(console: Console) -> bool: - return console.is_jupyter and is_running_in_vscode() - - -def rich_no_color() -> bool: - return no_ansi() or not is_running_in_vscode() or is_running_in_jupyterlab() - - -def rich_initialise() -> None: - # reflect ansi prefs - if no_ansi(): - rich.reconfigure(no_color=True, force_terminal=False, force_interactive=False) - elif rich_no_color(): - rich.reconfigure(no_color=True) - - # disable markdown code bock backgrounds (don't work well across light/dark themes) - class CustomCodeBlock(CodeBlock): - @override - def __rich_console__( - self, console: Console, options: ConsoleOptions - ) -> RenderResult: - code = str(self.text).rstrip() - syntax = Syntax( - code, - self.lexer_name, - theme=self.theme, - word_wrap=True, - background_color="default", - ) - yield syntax - - Markdown.elements["fence"] = CustomCodeBlock - Markdown.elements["code_block"] = CustomCodeBlock - - -def rich_theme() -> Theme: - global _theme - if _theme is None: - _theme = Theme() - return _theme - - -def rich_console() -> Console: - return rich.get_console() - - -def rich_display() -> RichDisplay: - global _display - if _display is None: - _display = RichDisplay() - return _display - - -def rich_progress() -> RProgress: - console = rich_console() - return RProgress( - TextColumn("{task.fields[status]}"), - TextColumn("{task.description}"), - TextColumn("{task.fields[model]}"), - BarColumn(bar_width=40 if is_vscode_notebook(console) else None), - TaskProgressColumn(), - TimeElapsedColumn(), - transient=True, - console=console, - expand=not is_vscode_notebook(console), - ) - - -_theme: Theme | None = None -_display: RichDisplay | None = None - - -@contextmanager -def record_console_input() -> Iterator[None]: - # monkey patch .input method to record inputs - input_original = Console.input - - def input_with_record(self: Console, *args: Any, **kwargs: Any) -> str: - result = input_original(self, *args, **kwargs) - if self.record: - with self._record_buffer_lock: - self._record_buffer.append(Segment(result)) - return result - - Console.input = input_with_record # type: ignore - - try: - yield - finally: - Console.input = input_original # type: ignore diff --git a/src/inspect_ai/_display/rich/__init__.py b/src/inspect_ai/_display/rich/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/inspect_ai/_display/rich/display.py b/src/inspect_ai/_display/rich/display.py new file mode 100644 index 000000000..afbf94e00 --- /dev/null +++ b/src/inspect_ai/_display/rich/display.py @@ -0,0 +1,341 @@ +import asyncio +import contextlib +from dataclasses import dataclass +from typing import Any, AsyncIterator, Callable, Coroutine, Iterator + +import rich +from rich.console import Console, Group, RenderableType +from rich.live import Live +from rich.panel import Panel +from rich.progress import Progress as RProgress +from rich.table import Table +from typing_extensions import override + +from inspect_ai._util.constants import CONSOLE_DISPLAY_WIDTH +from inspect_ai._util.display import display_type +from inspect_ai._util.throttle import throttle +from inspect_ai.log._transcript import InputEvent, transcript +from inspect_ai.util._trace import trace_enabled + +from ..core.config import task_config +from ..core.display import ( + TR, + Display, + Progress, + TaskDisplay, + TaskProfile, + TaskResult, + TaskScreen, + TaskSpec, + TaskWithResult, +) +from ..core.footer import 
task_footer
+from ..core.panel import task_panel, task_title, tasks_title
+from ..core.progress import (
+    RichProgress,
+    progress_description,
+    progress_model_name,
+    progress_status_icon,
+    rich_progress,
+)
+from ..core.results import tasks_results
+from ..core.rich import (
+    is_vscode_notebook,
+    record_console_input,
+    rich_initialise,
+    rich_theme,
+)
+
+
+@dataclass
+class TaskStatus(TaskWithResult):
+    progress: RProgress
+
+
+class RichDisplay(Display):
+    def __init__(self) -> None:
+        self.total_tasks: int = 0
+        self.tasks: list[TaskStatus] = []
+        self.progress_ui: RProgress | None = None
+        self.parallel = False
+        self.live: Live | None = None
+        self.timer_handle: asyncio.TimerHandle | None = None
+        rich_initialise()
+
+    @override
+    def print(self, message: str) -> None:
+        rich.get_console().print(message, markup=False, highlight=False)
+
+    @override
+    @contextlib.contextmanager
+    def progress(self, total: int) -> Iterator[Progress]:
+        with rich_progress() as progress:
+            yield RichProgress(total, progress)
+
+    @override
+    def run_task_app(self, main: Coroutine[Any, Any, TR]) -> TR:
+        return asyncio.run(main)
+
+    @override
+    @contextlib.contextmanager
+    def suspend_task_app(self) -> Iterator[None]:
+        yield
+
+    @override
+    @contextlib.asynccontextmanager
+    async def task_screen(
+        self, tasks: list[TaskSpec], parallel: bool
+    ) -> AsyncIterator[TaskScreen]:
+        self.total_tasks = len(tasks)
+        self.tasks = []
+        self.progress_ui = rich_progress()
+        self.parallel = parallel
+        try:
+            with (
+                Live(
+                    None,
+                    console=rich.get_console(),
+                    transient=True,
+                    auto_refresh=False,
+                ) as live,
+            ):
+                # save reference to live
+                with RichTaskScreen(live) as task_screen:
+                    self.live = live
+
+                    # enqueue a display update
+                    self.timer_handle = asyncio.get_event_loop().call_later(
+                        1, self._update_display
+                    )
+
+                    # yield
+                    yield task_screen
+
+                # render task results (re-enable live if necessary)
+                if not live.is_started:
+                    live.start()
+                live.transient = False
+                live.update(tasks_results(self.tasks), refresh=True)
+        finally:
+            # clear tasks and progress
+            self.total_tasks = 0
+            self.tasks = []
+            self.progress_ui = None
+            self.parallel = False
+            self.live = None
+            if self.timer_handle:
+                self.timer_handle.cancel()
+
+    @override
+    @contextlib.contextmanager
+    def task(self, profile: TaskProfile) -> Iterator[TaskDisplay]:
+        # if there is no ansi display then all of the below will
+        # be a no-op, so we print a simple text message for the task
+        if display_type() == "plain":
+            rich.get_console().print(task_no_ansi(profile))
+
+        # for typechecker
+        if self.tasks is None:
+            self.tasks = []
+        if self.progress_ui is None:
+            self.progress_ui = rich_progress()
+
+        status = TaskStatus(profile, None, self.progress_ui)
+        self.tasks.append(status)
+        self._update_display()
+        yield RichTaskDisplay(
+            status, show_name=self.parallel, on_update=self._update_display
+        )
+
+    @throttle(1)
+    def _update_display(self) -> None:
+        if (
+            self.tasks is not None
+            and self.tasks
+            and self.progress_ui is not None
+            and self.live is not None
+            and self.live.is_started
+        ):
+            if self.parallel:
+                r = tasks_live_status(self.total_tasks, self.tasks, self.progress_ui)
+            else:
+                r = task_live_status(self.tasks, self.progress_ui)
+            self.live.update(r, refresh=True)
+
+        self.timer_handle = asyncio.get_event_loop().call_later(1, self._update_display)
+
+
+class RichTaskScreen(TaskScreen):
+    def __init__(self, live: Live) -> None:
+        self.theme = rich_theme()
+        self.live = live
+        status_text = "Working" if trace_enabled() else "Task running"
+        self.status = self.live.console.status(
+            f"[{self.theme.meta} bold]{status_text}...[/{self.theme.meta} bold]",
+            spinner="clock",
+        )
+
+    def __exit__(self, *excinfo: Any) -> None:
+        self.status.stop()
+
+    @override
+    @contextlib.contextmanager
+    def input_screen(
+        self,
+        header: str | None = None,
+        transient: bool | None = None,
+        width: int | None = None,
+    ) -> Iterator[Console]:
+        # determine transient based on trace mode
+        if transient is None:
+            transient = not trace_enabled()
+
+        # clear live task status and transient status
+        self.live.update("", refresh=True)
+        self.status.stop()
+
+        # show cursor for input
+        self.live.console.show_cursor(True)
+
+        # set width
+        old_width: int | None = None
+        if width:
+            old_width = self.live.console.width
+            self.live.console.width = min(old_width, width)
+
+        # record console activity for event
+        self.live.console.record = True
+
+        try:
+            # print header if requested
+            if header:
+                style = f"{rich_theme().meta} bold"
+                self.live.console.rule(f"[{style}]{header}[/{style}]", style="black")
+                self.live.console.print("")
+
+            # yield the console
+            with record_console_input():
+                yield self.live.console
+
+        finally:
+            # capture recording then yield input event
+            input = self.live.console.export_text(clear=False, styles=False)
+            input_ansi = self.live.console.export_text(clear=True, styles=True)
+            self.live.console.record = False
+            transcript()._event(InputEvent(input=input, input_ansi=input_ansi))
+
+            # print one blank line
+            self.live.console.print("")
+
+            # reset width
+            if old_width:
+                self.live.console.width = old_width
+
+            # disable cursor while not collecting input
+            self.live.console.show_cursor(False)
+
+            # if not transient then disable live updates entirely
+            if transient is False and self.live.is_started:
+                self.live.stop()
+
+            # otherwise make sure they are enabled
+            elif transient is True and not self.live.is_started:
+                self.live.start()
+
+            # if not transient then display mini-status
+            if not transient:
+                self.status.start()
+
+
+class RichTaskDisplay(TaskDisplay):
+    def __init__(
+        self,
+        status: TaskStatus,
+        show_name: bool,
+        on_update: Callable[[], None] | None = None,
+    ) -> None:
+        self.status = status
+        model = progress_model_name(self.status.profile.model)
+        description = progress_description(self.status.profile)
+
+        def task_status() -> str:
+            return progress_status_icon(self.status.result)
+
+        self.p = RichProgress(
+            total=self.status.profile.steps,
+            progress=self.status.progress,
+            description=f"{description.markup}",
+            model=f"{model.markup} ",
+            status=task_status,
+            on_update=on_update,
+        )
+
+    @override
+    @contextlib.contextmanager
+    def progress(self) -> Iterator[Progress]:
+        yield self.p
+
+    @override
+    def complete(self, result: TaskResult) -> None:
+        self.status.result = result
+        self.p.complete()
+
+
+def task_live_status(tasks: list[TaskStatus], progress: RProgress) -> RenderableType:
+    theme = rich_theme()
+    body: list[RenderableType] = ["", progress]
+    config = task_config(tasks[0].profile, style=theme.light)
+    if config:
+        body = [config] + body
+
+    return task_panel(
+        profile=tasks[0].profile,
+        show_model=len(tasks) == 1,
+        body=Group(*body),
+        footer=task_footer(theme.light),
+        log_location=None,
+    )
+
+
+def tasks_live_status(
+    total_tasks: int, tasks: list[TaskStatus], progress: RProgress
+) -> RenderableType:
+    # rendering context
+    theme = rich_theme()
+    console = rich.get_console()
+    width = CONSOLE_DISPLAY_WIDTH if is_vscode_notebook(console) else None
+
+    # compute completed tasks
+    completed 
= sum(1 for task in tasks if task.result is not None) + + # get config + config = task_config(tasks[0].profile, generate_config=False, style=theme.light) + if config: + config += "\n" + + # build footer table + footer_table = Table.grid(expand=True) + footer_table.add_column() + footer_table.add_column(justify="right") + footer = task_footer(theme.light) + footer_table.add_row() + footer_table.add_row(footer[0], footer[1]) + + # create panel w/ title + panel = Panel( + Group(config, progress, footer_table, fit=False), + title=f"[bold][{theme.meta}]{tasks_title(completed, total_tasks)}[/{theme.meta}][/bold]", + title_align="left", + width=width, + expand=True, + ) + return panel + + +def task_no_ansi(profile: TaskProfile) -> str: + theme = rich_theme() + message = f"Running task {task_title(profile, True)}" + config = task_config(profile, style=theme.light) + if config: + message = f"{message} (config: {config})" + return f"{message}...\n" diff --git a/src/inspect_ai/_display/textual/app.py b/src/inspect_ai/_display/textual/app.py new file mode 100644 index 000000000..8f89a6de3 --- /dev/null +++ b/src/inspect_ai/_display/textual/app.py @@ -0,0 +1,359 @@ +import asyncio +import contextlib +from asyncio import CancelledError +from typing import Any, AsyncIterator, Coroutine, Generic, Iterator + +import rich +from rich.console import Console +from rich.text import Text +from textual.app import App, ComposeResult +from textual.events import Print +from textual.widgets import TabbedContent, TabPane +from textual.worker import Worker, WorkerState +from typing_extensions import override + +from inspect_ai.log._samples import active_samples +from inspect_ai.log._transcript import InputEvent, transcript + +from ..core.config import task_config +from ..core.display import ( + TR, + TaskDisplay, + TaskProfile, + TaskScreen, + TaskSpec, + TaskWithResult, +) +from ..core.footer import task_footer +from ..core.panel import task_targets, task_title, tasks_title +from ..core.rich import record_console_input, rich_initialise, rich_theme +from .theme import inspect_dark, inspect_light +from .widgets.console import ConsoleView +from .widgets.footer import AppFooter +from .widgets.samples import SamplesView +from .widgets.tasks import TasksView +from .widgets.titlebar import AppTitlebar + + +class TaskScreenResult(Generic[TR]): + def __init__( + self, + value: TR | BaseException, + tasks: list[TaskWithResult], + output: list[str], + ) -> None: + self.value = value + self.tasks = tasks + self.output = output + + +class TaskScreenApp(App[TR]): + CSS_PATH = "app.tcss" + + def __init__(self) -> None: + # call super + super().__init__() + + # worker and output + self._worker: Worker[TR] | None = None + self._error: BaseException | None = None + self._output: list[str] = [] + + # task screen + self._total_tasks = 0 + self._parallel = False + self._tasks: list[TaskWithResult] = [] + + # all tasks processed by app + self._app_tasks: list[TaskWithResult] = [] + + # enable rich hooks + rich_initialise() + + def run_app(self, main: Coroutine[Any, Any, TR]) -> TaskScreenResult[TR]: + # create the worker + self._worker = self.run_worker(main, start=False, exit_on_error=False) + + # run the app + self.run() + + # determine result value + if self.return_value is not None: + value: TR | BaseException = self.return_value + elif self._error is not None: + value = self._error + else: + value = CancelledError() + + # return result w/ output + return TaskScreenResult(value=value, tasks=self._app_tasks, output=self._output) + + 
async def on_load(self) -> None:
+        # events used to synchronise loading
+        self._on_load_app = asyncio.Event()
+        self._on_app_loaded = asyncio.Event()
+
+        # run the workers
+        self.workers.start_all()
+
+        # wait until we are given the signal to load
+        await self._on_load_app.wait()
+
+    @contextlib.contextmanager
+    def suspend_app(self) -> Iterator[None]:
+        # suspend only if the app is already loaded
+        # (otherwise it's not yet displayed)
+        if self._on_app_loaded.is_set():
+            with self.app.suspend():
+                try:
+                    yield
+                finally:
+                    self.app.refresh(repaint=True)
+        else:
+            yield
+
+    # exit the app when the worker terminates
+    def on_worker_state_changed(self, event: Worker.StateChanged) -> None:
+        if event.worker.state == WorkerState.ERROR:
+            self._error = event.worker.error
+            self.exit(None, 1)
+        elif event.worker.state == WorkerState.CANCELLED:
+            self._error = CancelledError()
+            self.exit(None, 1)
+        elif event.worker.state == WorkerState.SUCCESS:
+            self.exit(event.worker.result)
+
+    # notification that a new top level set of tasks is being run
+    @contextlib.asynccontextmanager
+    async def task_screen(
+        self, tasks: list[TaskSpec], parallel: bool
+    ) -> AsyncIterator[TaskScreen]:
+        # indicate it's time to load then wait on the load
+        self._on_load_app.set()
+        await self._on_app_loaded.wait()
+
+        # reset state
+        self._tasks = []
+        self._total_tasks = len(tasks)
+        self._parallel = parallel
+
+        # clear existing task progress
+        tasks_view = self.query_one(TasksView)
+        tasks_view.init_tasks(tasks)
+
+        # dynamic tab caption for task(s)
+        tabs = self.query_one(TabbedContent)
+        tasks_tab = tabs.get_tab("tasks")
+        tasks_tab.label = Text.from_markup("Tasks" if self._total_tasks > 1 else "Task")
+
+        # update display
+        self.update_display()
+
+        # force repaint
+        self.refresh(repaint=True)
+
+        try:
+            yield TextualTaskScreen(self)
+        finally:
+            self._tasks = []
+            self._total_tasks = 0
+            self._parallel = False
+
+    # notification that a task is running and requires display
+    @contextlib.contextmanager
+    def task_display(self, profile: TaskProfile) -> Iterator[TaskDisplay]:
+        # create and track task
+        task = TaskWithResult(profile, None)
+        self._app_tasks.append(task)
+        self._tasks.append(task)
+
+        # update display
+        self.update_display()
+
+        # add task
+        try:
+            yield self.query_one(TasksView).add_task(task)
+        finally:
+            pass
+
+    # compose ui
+    def compose(self) -> ComposeResult:
+        yield AppTitlebar()
+        yield AppFooter()
+
+        with TabbedContent(id="tabs", initial="tasks"):
+            with TabPane("Tasks", id="tasks"):
+                yield TasksView()
+            with TabPane("Samples", id="samples"):
+                yield SamplesView()
+            with TabPane("Console", id="console"):
+                yield ConsoleView()
+
+    def on_mount(self) -> None:
+        # register and set theme
+        self.register_theme(inspect_dark)
+        self.register_theme(inspect_light)
+        self.theme = "inspect-dark"
+
+        # capture stdout/stderr (works w/ on_print)
+        self.begin_capture_print(self)
+
+        # handle tab activations
+        self.handle_tab_activations()
+
+        # handle console unread
+        self.handle_console_unread()
+
+        # update display every second
+        self.set_interval(1, self.update_display)
+
+        # indicate that the app is loaded
+        self._on_app_loaded.set()
+
+    # update dynamic parts of display
+    def update_display(self) -> None:
+        self.update_title()
+        self.update_tasks()
+        self.update_samples()
+        self.update_footer()
+
+    # update the header title
+    def update_title(self) -> None:
+        # determine title
+        if len(self._tasks) > 0:
+            if self._parallel:
+                completed = sum(1 for task in self._tasks if task.result is not None)
+
title = f"{tasks_title(completed, self._total_tasks)}" + else: + title = f"{task_title(self._tasks[0].profile, show_model=len(self._tasks) == 1)}" + else: + title = "" + + # set if required + header = self.query_one(AppTitlebar) + if header.title != title: + header.title = title + + def update_tasks(self) -> None: + tasks = self.query_one(TasksView) + if len(self._tasks) > 0: + tasks.config = task_config( + self._tasks[0].profile, generate_config=not self._parallel + ) + if not self._parallel: + tasks.targets = task_targets(self._tasks[0].profile) + else: + tasks.targets = " \n " + else: + tasks.config = "" + tasks.targets = "" + + def update_samples(self) -> None: + samples_view = self.query_one(SamplesView) + samples_view.set_samples(active_samples()) + + def update_footer(self) -> None: + left, right = task_footer() + footer = self.query_one(AppFooter) + footer.left = left + footer.right = right + + # track and display console unread state + def handle_console_unread(self) -> None: + # unread management + tabs = self.query_one(TabbedContent) + console_tab = tabs.get_tab("console") + console_view = self.query_one(ConsoleView) + + def set_unread(unread: int | None) -> None: + if unread is not None: + console_tab.label = Text.from_markup(f"Console ({unread})") + else: + console_tab.label = Text.from_markup("Console") + + self.watch(console_view, "unread", set_unread) + + # handle tab activations + def handle_tab_activations(self) -> None: + tabs = self.query_one(TabbedContent) + console_view = self.query_one(ConsoleView) + samples_view = self.query_one(SamplesView) + + async def set_active_tab(active: str) -> None: + await console_view.notify_active(active == "console") + await samples_view.notify_active(active == "samples") + + self.watch(tabs, "active", set_active_tab) + + # capture output and route to console view and our buffer + def on_print(self, event: Print) -> None: + # remove trailing newline + text = event.text + if text.endswith("\n"): + text = text[:-1] + + # track output (for printing at the end) + self._output.append(text) + + # write to console view + self.query_one(ConsoleView).write_ansi(text) + + # map ctrl+c to cancelling the worker + @override + async def action_quit(self) -> None: + if self._worker and self._worker.is_running: + self._worker.cancel() + + +class TextualTaskScreen(TaskScreen, Generic[TR]): + def __init__(self, app: TaskScreenApp[TR]) -> None: + self.app = app + + def __exit__(self, *excinfo: Any) -> None: + pass + + @override + @contextlib.contextmanager + def input_screen( + self, + header: str | None = None, + transient: bool | None = None, + width: int | None = None, + ) -> Iterator[Console]: + with self.app.suspend_app(): + # get rich console + console = rich.get_console() + + # set width + old_width: int | None = None + if width: + old_width = console.width + console.width = min(old_width, width) + + # record console activity for event + console.record = True + + try: + # print header if requested + if header: + style = f"{rich_theme().meta} bold" + console.rule(f"[{style}]{header}[/{style}]", style="black") + console.print("") + + # yield the console + with record_console_input(): + yield console + + finally: + # capture recording then yield input event + input = console.export_text(clear=False, styles=False) + input_ansi = console.export_text(clear=True, styles=True) + console.record = False + transcript()._event(InputEvent(input=input, input_ansi=input_ansi)) + + # print one blank line + console.print("") + + # reset width + if old_width: + 
console.width = old_width diff --git a/src/inspect_ai/_display/textual/app.tcss b/src/inspect_ai/_display/textual/app.tcss new file mode 100644 index 000000000..b589b4f1a --- /dev/null +++ b/src/inspect_ai/_display/textual/app.tcss @@ -0,0 +1,26 @@ + +ContentTab { + height: 1; + padding: 0 1 0 2; +} + +ContentTabs { + height: 2; +} + +ContentTabs #tabs-list { + min-height: 1; +} + +TabPane { + padding: 0 0 0 0; +} + +Tabs { + &:focus { + & .-active { + background: transparent; + } + } +} + diff --git a/src/inspect_ai/_display/textual/display.py b/src/inspect_ai/_display/textual/display.py new file mode 100644 index 000000000..719840b0d --- /dev/null +++ b/src/inspect_ai/_display/textual/display.py @@ -0,0 +1,71 @@ +import contextlib +from typing import Any, AsyncIterator, Coroutine, Iterator + +import rich +from typing_extensions import override + +from ..core.display import ( + TR, + Display, + Progress, + TaskDisplay, + TaskProfile, + TaskScreen, + TaskSpec, +) +from ..core.progress import RichProgress, rich_progress +from ..core.results import tasks_results +from .app import TaskScreenApp + + +class TextualDisplay(Display): + @override + def print(self, message: str) -> None: + rich.get_console().print(message, markup=False, highlight=False) + + @override + @contextlib.contextmanager + def progress(self, total: int) -> Iterator[Progress]: + with rich_progress() as progress: + yield RichProgress(total, progress) + + @override + def run_task_app(self, main: Coroutine[Any, Any, TR]) -> TR: + # create and run the app + self.app = TaskScreenApp[TR]() + result = self.app.run_app(main) + + # print output + if result.output: + print("\n".join(result.output)) + + # print tasks + rich.print(tasks_results(result.tasks)) + + # raise error as required + if isinstance(result.value, BaseException): + raise result.value + + # success! 
return value + else: + return result.value + + @override + @contextlib.contextmanager + def suspend_task_app(self) -> Iterator[None]: + with self.app.suspend_app(): + yield + + @override + @contextlib.asynccontextmanager + async def task_screen( + self, tasks: list[TaskSpec], parallel: bool + ) -> AsyncIterator[TaskScreen]: + async with self.app.task_screen(tasks, parallel) as task_screen: + yield task_screen + + @override + @contextlib.contextmanager + def task(self, profile: TaskProfile) -> Iterator[TaskDisplay]: + with self.app.task_display(profile) as task_display: + yield task_display diff --git a/src/inspect_ai/_display/textual/theme.py b/src/inspect_ai/_display/textual/theme.py new file mode 100644 index 000000000..3ab9e1fed --- /dev/null +++ b/src/inspect_ai/_display/textual/theme.py @@ -0,0 +1,30 @@ +from textual.theme import Theme + +inspect_dark = Theme( + name="inspect-dark", + primary="#3376CD", + secondary="#004578", + accent="#ffa62b", + warning="#ffa62b", + error="#ba3c5b", + success="#408558", + foreground="#e0e0e0", +) + + +inspect_light = Theme( + name="inspect-light", + primary="#4283CA", + secondary="#0178D4", + accent="#ffa62b", + warning="#ffa62b", + error="#ba3c5b", + success="#54B98F", + surface="#D8D8D8", + panel="#DFDFDF", + background="#F8F8F8", + dark=False, + variables={ + "footer-key-foreground": "#0178D4", + }, +) diff --git a/src/inspect_ai/_display/textual/widgets/clock.py b/src/inspect_ai/_display/textual/widgets/clock.py new file mode 100644 index 000000000..f22c314a7 --- /dev/null +++ b/src/inspect_ai/_display/textual/widgets/clock.py @@ -0,0 +1,55 @@ +from datetime import datetime + +from textual.reactive import reactive +from textual.timer import Timer +from textual.widgets import Static + +from inspect_ai._display.core.progress import progress_time + + +class Clock(Static): + DEFAULT_CSS = """ + Clock { + color: $primary-lighten-3; + } + """ + + time: reactive[float] = reactive(datetime.now().timestamp) + timer: Timer | None = None + + def __init__(self, interval: int = 1) -> None: + super().__init__() + self.start_time: float | None = None + self.time = datetime.now().timestamp() + self.interval = interval + + def start(self, start_time: float) -> None: + if start_time != self.start_time: + self.stop() + self.start_time = start_time + self.update_time() + self.timer = self.set_interval(self.interval, self.update_time) + + def stop(self) -> None: + self.start_time = None + if self.timer: + self.timer.stop() + self.timer = None + + def on_unmount(self) -> None: + self.stop() + + def watch_start_time(self, start_time: float | None) -> None: + if start_time is not None: + if self.timer is None: + self.timer = self.set_interval(self.interval, self.update_time) + self.update(progress_time(start_time)) + else: + self.stop() + + def update_time(self) -> None: + if self.start_time is not None: + self.time = datetime.now().timestamp() - self.start_time + + def watch_time(self, time: float) -> None: + self.update(progress_time(time)) diff --git a/src/inspect_ai/_display/textual/widgets/console.py b/src/inspect_ai/_display/textual/widgets/console.py new file mode 100644 index 000000000..dda0b3118 --- /dev/null +++ b/src/inspect_ai/_display/textual/widgets/console.py @@ -0,0 +1,52 @@ +from rich.text import Text +from textual.reactive import reactive +from textual.widgets import RichLog + + +class ConsoleView(RichLog): + DEFAULT_CSS = """ + ConsoleView { + scrollbar-size-horizontal: 1; + scrollbar-size-vertical: 1; + scrollbar-gutter: stable; + background: 
transparent; + } + """ + + # enable tab container to print our unread count + unread: reactive[int | None] = reactive(None) + + def __init__(self) -> None: + super().__init__() + self.active = False + self.show_horizontal_scrollbar = False + + async def notify_active(self, active: bool) -> None: + self.active = active + if self.active: + self.unread = None + + def write_ansi(self, text: str) -> None: + # process line by line + for line in text.splitlines(): + self.write_ansi_line(line) + + # tick unread if we aren't active + if not self.active and len(text.strip()) > 0: + self.unread = (self.unread or 0) + 1 + + def write_ansi_line(self, line: str) -> None: + # tweak rich console lines with path at end to not go under the scrollbar + # (remove two inner spaces and add a space at the end) + if "[2m" in line: + chars = list(line) + removed = 0 + for i in range(len(chars) - 1, -1, -1): + if chars[i].isspace(): + chars.pop(i) + removed += 1 + if removed > 1: + break + line = "".join(chars) + " " + + self.write(Text.from_ansi(line)) diff --git a/src/inspect_ai/_display/textual/widgets/footer.py b/src/inspect_ai/_display/textual/widgets/footer.py new file mode 100644 index 000000000..39938695e --- /dev/null +++ b/src/inspect_ai/_display/textual/widgets/footer.py @@ -0,0 +1,38 @@ +from typing import cast + +from rich.console import RenderableType +from textual.app import ComposeResult +from textual.reactive import reactive +from textual.widget import Widget +from textual.widgets import Static + + +class AppFooter(Widget): + DEFAULT_CSS = """ + AppFooter { + layout: grid; + grid-size: 2 1; + grid-columns: 1fr auto; + grid-gutter: 2; + background: $foreground 5%; + color: $text-muted; + dock: bottom; + height: auto; + padding: 0 1 + } + """ + + left: reactive[RenderableType] = reactive("") + right: reactive[RenderableType] = reactive("") + + def compose(self) -> ComposeResult: + yield Static(id="footer-left") + yield Static(id="footer-right") + + def watch_left(self, new_left: RenderableType) -> None: + footer_left = cast(Static, self.query_one("#footer-left")) + footer_left.update(new_left) + + def watch_right(self, new_right: RenderableType) -> None: + footer_right = cast(Static, self.query_one("#footer-right")) + footer_right.update(new_right) diff --git a/src/inspect_ai/_display/textual/widgets/samples.py b/src/inspect_ai/_display/textual/widgets/samples.py new file mode 100644 index 000000000..44b58cc5b --- /dev/null +++ b/src/inspect_ai/_display/textual/widgets/samples.py @@ -0,0 +1,259 @@ +from typing import cast + +from rich.table import Table +from rich.text import Text +from textual.app import ComposeResult +from textual.containers import Horizontal, HorizontalGroup, VerticalGroup +from textual.widget import Widget +from textual.widgets import Button, LoadingIndicator, OptionList, Static +from textual.widgets.option_list import Option, Separator + +from inspect_ai._util.registry import registry_unqualified_name +from inspect_ai.log._samples import ActiveSample + +from ...core.progress import progress_time +from .clock import Clock +from .transcript import TranscriptView + + +class SamplesView(Widget): + DEFAULT_CSS = """ + SamplesView { + width: 1fr; + height: 1fr; + padding: 0 1 0 1; + layout: grid; + grid-size: 2 2; + grid-rows: 1fr auto; + grid-columns: 30 1fr; + grid-gutter: 1; + } + SamplesView OptionList { + height: 100%; + scrollbar-size-vertical: 1; + margin-bottom: 1; + row-span: 2; + background: transparent; + } + SamplesView OptionList:focus > .option-list--option-highlighted { 
+ background: $primary 40%; + } + + SamplesView OptionList > .option-list--option-highlighted { + background: $primary 40%; + } + + SamplesView TranscriptView { + scrollbar-size-vertical: 1; + scrollbar-gutter: stable; + } + """ + + def __init__(self) -> None: + super().__init__() + self.samples: list[ActiveSample] = [] + + def compose(self) -> ComposeResult: + yield OptionList() + yield TranscriptView() + yield SampleToolbar() + + def on_mount(self) -> None: + self.watch(self.query_one(OptionList), "highlighted", self.set_highlighted) + + async def notify_active(self, active: bool) -> None: + await self.query_one(TranscriptView).notify_active(active) + + def set_samples(self, samples: list[ActiveSample]) -> None: + # check for a highlighted sample (make sure we don't remove it) + option_list = self.query_one(OptionList) + highlighted_id = ( + option_list.get_option_at_index(option_list.highlighted).id + if option_list.highlighted is not None + else None + ) + highlighted_sample = ( + sample_for_id(self.samples, highlighted_id) + if highlighted_id is not None + else None + ) + + # assign the new samples + self.samples = samples.copy() + + # add the highlighted sample if its no longer in the list + if highlighted_sample and (highlighted_sample not in self.samples): + self.samples.append(highlighted_sample) + + # sort the samples by execution time + self.samples.sort(key=lambda sample: sample.execution_time, reverse=True) + + # rebuild the list + option_list.clear_options() + options: list[Option | Separator] = [] + for sample in self.samples: + table = Table.grid(expand=True) + table.add_column() + table.add_column(justify="right") + table.add_column() + task_name = Text.from_markup(f"{registry_unqualified_name(sample.task)}") + task_name.truncate(18, overflow="ellipsis", pad=True) + task_time = Text.from_markup(f"{progress_time(sample.execution_time)}") + table.add_row(task_name, task_time, " ") + sample_id = Text.from_markup(f"id: {sample.sample.id}") + sample_id.truncate(18, overflow="ellipsis", pad=True) + sample_epoch = Text.from_markup(f"epoch: {sample.epoch:.0f}") + table.add_row( + sample_id, + sample_epoch, + " ", + ) + options.append(Option(table, id=sample.id)) + options.append(Separator()) + + option_list.add_options(options) + + # select sample (re-select the highlighted sample if there is one) + if len(self.samples) > 0: + if highlighted_id is not None: + index = sample_index_for_id(self.samples, highlighted_id) + else: + index = 0 + option_list.highlighted = index + option_list.scroll_to_highlight() + + async def set_highlighted(self, highlighted: int | None) -> None: + # alias widgets + option_list = self.query_one(OptionList) + transcript_view = self.query_one(TranscriptView) + sample_toolbar = self.query_one(SampleToolbar) + + # look for a highlighted sample to sync + if highlighted is not None: + highlighted_id = option_list.get_option_at_index(highlighted).id + if highlighted_id is not None: + sample = sample_for_id(self.samples, highlighted_id) + if sample: + await sample_toolbar.sync_sample(sample) + await transcript_view.sync_sample(sample) + return + + +class SampleToolbar(Horizontal): + DEFAULT_CSS = """ + SampleToolbar Button { + margin-bottom: 1; + margin-right: 2; + min-width: 20; + } + SampleToolbar #cancel-score-output { + color: $primary-darken-3; + } + SampleToolbar #cancel-raise-error { + color: $warning-darken-3; + } + """ + + def __init__(self) -> None: + super().__init__() + self.sample: ActiveSample | None = None + + def compose(self) -> ComposeResult: + 
with VerticalGroup(id="pending-status"): + yield Static("Executing...", id="pending-caption") + yield HorizontalGroup(EventLoadingIndicator(), Clock()) + yield Button( + Text("Cancel (Score)"), + id="cancel-score-output", + tooltip="Cancel the sample and score whatever output has been generated so far.", + ) + yield Button( + Text("Cancel (Error)"), + id="cancel-raise-error", + tooltip="Cancel the sample and raise an error (task will exit unless fail_on_error is set)", + ) + + def on_mount(self) -> None: + self.query_one("#pending-status").visible = False + self.query_one("#cancel-score-output").display = False + self.query_one("#cancel-raise-error").display = False + + def on_button_pressed(self, event: Button.Pressed) -> None: + if self.sample: + if event.button.id == "cancel-score-output": + self.sample.interrupt("score") + elif event.button.id == "cancel-raise-error": + self.sample.interrupt("error") + + async def sync_sample(self, sample: ActiveSample | None) -> None: + from inspect_ai.log._transcript import ModelEvent + + # track the sample + self.sample = sample + + pending_status = self.query_one("#pending-status") + clock = self.query_one(Clock) + cancel_score_output = cast(Button, self.query_one("#cancel-score-output")) + cancel_with_error = cast(Button, self.query_one("#cancel-raise-error")) + if sample and not sample.completed: + # update visibility and button status + self.display = True + cancel_score_output.display = True + cancel_with_error.display = not sample.fails_on_error + + # if we have a pending event then start the clock and show pending status + last_event = ( + sample.transcript.events[-1] + if len(sample.transcript.events) > 0 + else None + ) + if last_event and last_event.pending: + pending_status.visible = True + pending_caption = cast(Static, self.query_one("#pending-caption")) + pending_caption_text = ( + "Generating..." + if isinstance(last_event, ModelEvent) + else "Executing..." 
+ ) + pending_caption.update( + Text.from_markup(f"[italic]{pending_caption_text}[/italic]") + ) + clock.start(last_event.timestamp.timestamp()) + else: + pending_status.visible = False + clock.stop() + + else: + self.display = False + pending_status.visible = False + clock.stop() + + +class EventLoadingIndicator(LoadingIndicator): + DEFAULT_CSS = """ + EventLoadingIndicator { + width: auto; + height: 1; + color: $primary; + text-style: not reverse; + margin-right: 1; + } + """ + + def __init__(self) -> None: + super().__init__() + + +def sample_for_id(samples: list[ActiveSample], id: str) -> ActiveSample | None: + index = sample_index_for_id(samples, id) + if index != -1: + return samples[index] + else: + return None + + +def sample_index_for_id(samples: list[ActiveSample], id: str) -> int: + for i, sample in enumerate(samples): + if sample.id == id: + return i + return -1 diff --git a/src/inspect_ai/_display/textual/widgets/tasks.py b/src/inspect_ai/_display/textual/widgets/tasks.py new file mode 100644 index 000000000..115f4d543 --- /dev/null +++ b/src/inspect_ai/_display/textual/widgets/tasks.py @@ -0,0 +1,195 @@ +import contextlib +from datetime import datetime +from typing import Iterator, cast + +from rich.console import RenderableType +from rich.text import Text +from textual.app import ComposeResult +from textual.containers import Container, ScrollableContainer +from textual.reactive import reactive +from textual.widget import Widget +from textual.widgets import ProgressBar, Static +from typing_extensions import override + +from inspect_ai._display.textual.widgets.clock import Clock + +from ...core.display import ( + Progress, + TaskCancelled, + TaskDisplay, + TaskError, + TaskResult, + TaskSpec, + TaskWithResult, +) +from ...core.progress import ( + MAX_DESCRIPTION_WIDTH, + MAX_MODEL_NAME_WIDTH, + progress_description, + progress_model_name, +) + + +class TasksView(Container): + DEFAULT_CSS = """ + TasksView { + padding: 0 1; + layout: grid; + grid-size: 2 2; + grid-columns: 1fr auto; + grid-rows: auto 1fr; + } + #tasks-progress { + column-span: 2; + scrollbar-size-vertical: 1; + margin-top: 1; + margin-bottom: 1; + } + #tasks-config { + color: $text-muted; + } + #tasks-targets { + text-align: right; + color: $text-muted; + } + """ + + config: reactive[RenderableType] = reactive("") + targets: reactive[RenderableType] = reactive("") + + def __init__(self) -> None: + super().__init__() + self.description_width = MAX_DESCRIPTION_WIDTH + self.model_name_width = MAX_MODEL_NAME_WIDTH + + def init_tasks(self, tasks: list[TaskSpec]) -> None: + # clear existing tasks + self.tasks.remove_children() + + # compute the column widths by looking all of the tasks + self.description_width = min( + max([len(task.name) for task in tasks]), MAX_DESCRIPTION_WIDTH + ) + self.model_name_width = min( + max([len(str(task.model)) for task in tasks]), MAX_MODEL_NAME_WIDTH + ) + + def add_task(self, task: TaskWithResult) -> TaskDisplay: + task_display = TaskProgressView( + task, self.description_width, self.model_name_width + ) + self.tasks.mount(task_display) + self.tasks.scroll_to_widget(task_display) + return task_display + + def compose(self) -> ComposeResult: + yield Static(id="tasks-config") + yield Static(id="tasks-targets") + yield ScrollableContainer(id="tasks-progress") + + def watch_config(self, new_config: RenderableType) -> None: + tasks_config = cast(Static, self.query_one("#tasks-config")) + tasks_config.update(new_config) + + def watch_targets(self, new_targets: RenderableType) -> None: + 
tasks_targets = cast(Static, self.query_one("#tasks-targets")) + tasks_targets.update(new_targets) + + @property + def tasks(self) -> ScrollableContainer: + return cast(ScrollableContainer, self.query_one("#tasks-progress")) + + +class TaskProgressView(Widget): + DEFAULT_CSS = """ + TaskProgressView { + height: auto; + width: 1fr; + layout: grid; + grid-size: 5 1; + grid-columns: auto auto auto 1fr auto; + grid-gutter: 1; + } + TaskProgressView Bar { + width: 1fr; + &> .bar--bar { + color: $warning 90%; + } + &> .bar--complete { + color: $success; + } + } + """ + + def __init__( + self, task: TaskWithResult, description_width: int, model_name_width: int + ) -> None: + super().__init__() + self.t = task + self.description_width = description_width + self.model_name_width = model_name_width + self.progress_bar = ProgressBar(total=task.profile.steps, show_eta=False) + self.task_progress = TaskProgress(self.progress_bar) + + def compose(self) -> ComposeResult: + yield TaskStatusIcon() + yield Static( + progress_description(self.t.profile, self.description_width, pad=True) + ) + yield Static( + progress_model_name(self.t.profile.model, self.model_name_width, pad=True) + ) + yield self.progress_bar + yield Clock() + + def on_mount(self) -> None: + self.query_one(Clock).start(datetime.now().timestamp()) + + @contextlib.contextmanager + def progress(self) -> Iterator[Progress]: + yield self.task_progress + + def complete(self, result: TaskResult) -> None: + self.t.result = result + self.query_one(TaskStatusIcon).result = result + self.query_one(Clock).stop() + self.task_progress.complete() + + +class TaskStatusIcon(Static): + result: reactive[TaskResult | None] = reactive(None) + + def __init__(self) -> None: + super().__init__() + self.watch_result(None) + + def watch_result(self, new_result: TaskResult | None) -> None: + self.update(self._status_icon(new_result)) + + def _status_icon(self, result: TaskResult | None) -> RenderableType: + error = self.app.current_theme.error or "" + succcess = self.app.current_theme.success or "" + running = self.app.current_theme.secondary or "" + if result: + if isinstance(result, TaskError): + return Text("✗", style=error) + elif isinstance(result, TaskCancelled): + return Text("✗", style=error) + else: + return Text("✔", style=succcess) + else: + return Text("⠿", style=running) + + +class TaskProgress(Progress): + def __init__(self, progress_bar: ProgressBar) -> None: + self.progress_bar = progress_bar + + @override + def update(self, n: int = 1) -> None: + self.progress_bar.update(advance=n) + + @override + def complete(self) -> None: + if self.progress_bar.total is not None: + self.progress_bar.update(progress=self.progress_bar.total) diff --git a/src/inspect_ai/_display/textual/widgets/titlebar.py b/src/inspect_ai/_display/textual/widgets/titlebar.py new file mode 100644 index 000000000..927162d98 --- /dev/null +++ b/src/inspect_ai/_display/textual/widgets/titlebar.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +from typing import Iterator + +from rich.console import RenderableType +from rich.text import Text +from textual.reactive import Reactive +from textual.widget import Widget + + +class AppTitlebar(Widget): + DEFAULT_CSS = """ + AppTitlebar { + dock: top; + width: 100%; + background: $panel; + color: $primary; + height: 1; + text-style: bold; + } + """ + + DEFAULT_CLASSES = "" + + def __init__( + self, + *, + name: str | None = None, + id: str | None = None, + classes: str | None = None, + ): + """Initialise the header widget. 
+ + Args: + name: The name of the header widget. + id: The ID of the header widget in the DOM. + classes: The CSS classes of the header widget. + """ + super().__init__(name=name, id=id, classes=classes) + + def compose(self) -> Iterator[Widget]: + yield AppTitlebarTitle() + + @property + def title(self) -> str: + return self._header_title().text + + @title.setter + def title(self, title: str) -> None: + self._header_title().text = title + + @property + def sub_title(self) -> str: + return self._header_title().sub_text + + @sub_title.setter + def sub_title(self, sub_title: str) -> None: + self._header_title().sub_text = sub_title + + def _header_title(self) -> AppTitlebarTitle: + return self.query_one(AppTitlebarTitle) + + +class AppTitlebarTitle(Widget): + """Display the title / subtitle in the header.""" + + DEFAULT_CSS = """ + AppTitlebarTitle { + content-align: center middle; + width: 100%; + } + """ + + text: Reactive[str] = Reactive("") + """The main title text.""" + + sub_text = Reactive("") + """The sub-title text.""" + + def render(self) -> RenderableType: + """Render the title and sub-title. + + Returns: + The value to render. + """ + text = Text(self.text, no_wrap=True, overflow="ellipsis") + if self.sub_text: + text.append(" — ") + text.append(self.sub_text, "dim") + return text diff --git a/src/inspect_ai/_display/textual/widgets/transcript.py b/src/inspect_ai/_display/textual/widgets/transcript.py new file mode 100644 index 000000000..6e5fe6e25 --- /dev/null +++ b/src/inspect_ai/_display/textual/widgets/transcript.py @@ -0,0 +1,344 @@ +from typing import Any, Callable, NamedTuple, Sequence, Type + +from pydantic_core import to_json +from rich.console import Group, RenderableType +from rich.markdown import Markdown +from rich.table import Table +from rich.text import Text +from textual.containers import ScrollableContainer +from textual.widget import Widget +from textual.widgets import Static + +from inspect_ai._util.content import ContentText +from inspect_ai._util.format import format_function_call +from inspect_ai._util.rich import lines_display +from inspect_ai._util.transcript import ( + set_transcript_markdown_options, + transcript_markdown, + transcript_separator, +) +from inspect_ai.log._samples import ActiveSample +from inspect_ai.log._transcript import ( + ApprovalEvent, + ErrorEvent, + Event, + InfoEvent, + InputEvent, + LoggerEvent, + ModelEvent, + SampleInitEvent, + SampleLimitEvent, + ScoreEvent, + StepEvent, + SubtaskEvent, + ToolEvent, +) +from inspect_ai.model._chat_message import ChatMessage, ChatMessageUser +from inspect_ai.model._render import messages_preceding_assistant +from inspect_ai.tool._tool import ToolResult + + +class TranscriptView(ScrollableContainer): + def __init__(self) -> None: + super().__init__() + self._sample_id: str | None = None + self._sample_events: int | None = None + + self._active = False + self._pending_sample: ActiveSample | None = None + + async def notify_active(self, active: bool) -> None: + self._active = active + if self._active and self._pending_sample: + await self.sync_sample(self._pending_sample) + self._pending_sample = None + + async def sync_sample(self, sample: ActiveSample | None) -> None: + # if sample is none then reset + if sample is None: + self._sample = None + self._sample_events = None + await self.remove_children() + + # process sample if we are active + elif self._active: + # if we have either a new sample or a new event count then proceed + if ( + sample.id != self._sample_id + or 
len(sample.transcript.events) != self._sample_events + ): + # update (scrolling to end if we are already close to it) + new_sample = sample.id != self._sample_id + scroll_to_end = ( + new_sample or abs(self.scroll_y - self.max_scroll_y) <= 20 + ) + async with self.batch(): + await self.remove_children() + await self.mount_all( + self._widgets_for_events(sample.transcript.events) + ) + if scroll_to_end: + self.scroll_end(animate=not new_sample) + + # set members + self._sample_id = sample.id + self._sample_events = len(sample.transcript.events) + + # if we aren't active then save as a pending sample + else: + self._pending_sample = sample + + def _widgets_for_events(self, events: Sequence[Event]) -> list[Widget]: + widgets: list[Widget] = [] + for event in events: + display = render_event(event) + if display: + for d in display: + if d.content: + widgets.append( + Static( + transcript_separator( + d.title, self.app.current_theme.primary + ) + ) + ) + if isinstance(d.content, Markdown): + set_transcript_markdown_options(d.content) + widgets.append(Static(d.content)) + widgets.append(Static(Text(" "))) + return widgets + + +class EventDisplay(NamedTuple): + """Display for an event group.""" + + title: str + """Text for title bar""" + + content: RenderableType | None = None + """Optional custom content to display.""" + + +def render_event(event: Event) -> list[EventDisplay] | None: + # see if we have a renderer + for event_type, renderer in _renderers: + if isinstance(event, event_type): + display = renderer(event) + if display is not None: + return display if isinstance(display, list) else [display] + + # no renderer + return None + + +def render_sample_init_event(event: SampleInitEvent) -> EventDisplay: + # alias sample + sample = event.sample + + # input + messages: list[ChatMessage] = ( + [ChatMessageUser(content=sample.input)] + if isinstance(sample.input, str) + else sample.input + ) + content: list[RenderableType] = [] + for message in messages: + content.extend(render_message(message)) + + # target + if sample.target: + content.append(Text()) + content.append(Text("Target", style="bold")) + content.append(Text()) + content.append(str(sample.target).strip()) + + return EventDisplay("sample init", Group(*content)) + + +def render_sample_limit_event(event: SampleLimitEvent) -> EventDisplay: + return EventDisplay(f"limit: {event.type}", Text(event.message)) + + +def render_model_event(event: ModelEvent) -> EventDisplay: + # content + content: list[RenderableType] = [] + + def append_message(message: ChatMessage, text: str | None = None) -> None: + content.extend(render_message(message, text)) + + # render preceding messages + preceding = messages_preceding_assistant(event.input) + for message in preceding: + append_message(message) + content.append(Text()) + + # display assistant message (note that we don't render tool calls + # because they will be handled as part of render_tool) + if event.output.message and event.output.message.text: + append_message(event.output.message) + + return EventDisplay(f"model: {event.model}", Group(*content)) + + +def render_tool_event(event: ToolEvent) -> list[EventDisplay]: + # render sub-events + display: list[EventDisplay] = [] + if event.events: + for e in event.events: + display.extend(render_event(e) or []) + + # render the call + content: list[RenderableType] = [] + if event.view: + if event.view.title: + content.append(Text.from_markup(f"[bold]{event.view.title}[/bold]\n")) + if event.view.format == "markdown": + 
content.append(transcript_markdown(event.view.content)) + else: + content.append(event.view.content) + else: + content.append(render_function_call(event.function, event.arguments)) + content.append(Text()) + + # render the output + if isinstance(event.result, list): + result: ToolResult = "\n".join( + [ + content.text + for content in event.result + if isinstance(content, ContentText) + ] + ) + else: + result = event.result + + if result: + result = str(result).strip() + content.extend(lines_display(result, 50)) + + return display + [EventDisplay("tool call", Group(*content))] + + +def render_step_event(event: StepEvent) -> EventDisplay: + if event.type == "solver": + return render_solver_event(event) + if event.type == "scorer": + return render_scorer_event(event) + else: + return EventDisplay(step_title(event)) + + +def render_solver_event(event: StepEvent) -> EventDisplay: + return EventDisplay(step_title(event)) + + +def render_scorer_event(event: StepEvent) -> EventDisplay: + return EventDisplay(step_title(event)) + + +def render_score_event(event: ScoreEvent) -> EventDisplay: + table = Table(box=None, show_header=False) + table.add_column("", min_width=10, justify="left") + table.add_column("", justify="left") + table.add_row("Target", str(event.target).strip()) + if event.score.answer: + table.add_row("Answer", transcript_markdown(event.score.answer)) + table.add_row("Score", str(event.score.value).strip()) + if event.score.explanation: + table.add_row("Explanation", transcript_markdown(event.score.explanation)) + + return EventDisplay("score", table) + + +def render_subtask_event(event: SubtaskEvent) -> list[EventDisplay]: + # render sub-events + display: list[EventDisplay] = [] + if event.events: + for e in event.events: + display.extend(render_event(e) or []) + + content: list[RenderableType] = [render_function_call(event.name, event.input)] + content.append(Text()) + content.append(render_as_json(event.result)) + + return display + [EventDisplay(f"subtask: {event.name}", Group(*content))] + + +def render_input_event(event: InputEvent) -> EventDisplay: + return EventDisplay("input", Text.from_ansi(event.input_ansi.strip())) + + +def render_approval_event(event: ApprovalEvent) -> EventDisplay: + content: list[RenderableType] = [ + f"[bold]{event.approver}[/bold]: {event.decision} ({event.explanation})" + ] + + return EventDisplay("approval", Group(*content)) + + +def render_info_event(event: InfoEvent) -> EventDisplay: + if isinstance(event.data, str): + content: RenderableType = transcript_markdown(event.data) + else: + content = render_as_json(event.data) + return EventDisplay("info", content) + + +def render_logger_event(event: LoggerEvent) -> EventDisplay: + content = event.message.level.upper() + if event.message.name: + content = f"{content} (${event.message.name})" + content = f"{content}: {event.message.message}" + return EventDisplay("logger", content) + + +def render_error_event(event: ErrorEvent) -> EventDisplay: + return EventDisplay("error", event.error.traceback.strip()) + + +def render_function_call(function: str, arguments: dict[str, Any]) -> RenderableType: + call = format_function_call(function, arguments) + return transcript_markdown("```python\n" + call + "\n```\n") + + +def render_as_json(json: Any) -> RenderableType: + return transcript_markdown( + "```json\n" + + to_json(json, indent=2, fallback=lambda _: None).decode() + + "\n```\n" + ) + + +def render_message( + message: ChatMessage, text: str | None = None +) -> list[RenderableType]: + content: 
list[RenderableType] = [ + Text(message.role.capitalize(), style="bold"), + Text(), + ] + text = text or message.text + if text: + content.extend([transcript_markdown(text.strip())]) + return content + + +def step_title(event: StepEvent) -> str: + return f"{event.type or 'step'}: {event.name}" + + +EventRenderer = Callable[[Any], EventDisplay | list[EventDisplay] | None] + +_renderers: list[tuple[Type[Event], EventRenderer]] = [ + (SampleInitEvent, render_sample_init_event), + (SampleLimitEvent, render_sample_limit_event), + (StepEvent, render_step_event), + (ModelEvent, render_model_event), + (ToolEvent, render_tool_event), + (SubtaskEvent, render_subtask_event), + (ScoreEvent, render_score_event), + (InputEvent, render_input_event), + (ApprovalEvent, render_approval_event), + (InfoEvent, render_info_event), + (LoggerEvent, render_logger_event), + (ErrorEvent, render_error_event), +] diff --git a/src/inspect_ai/_eval/context.py b/src/inspect_ai/_eval/context.py index d0917e238..02b816911 100644 --- a/src/inspect_ai/_eval/context.py +++ b/src/inspect_ai/_eval/context.py @@ -1,16 +1,14 @@ -from inspect_ai._display.logger import init_logger from inspect_ai._util.dotenv import init_dotenv from inspect_ai._util.hooks import init_hooks -from inspect_ai._util.logger import init_http_rate_limit_count +from inspect_ai._util.logger import init_http_rate_limit_count, init_logger +from inspect_ai.log._samples import init_active_samples from inspect_ai.model import GenerateConfig, Model from inspect_ai.model._model import init_active_model, init_model_usage from inspect_ai.util._concurrency import init_concurrency from inspect_ai.util._subprocess import init_max_subprocesses -from inspect_ai.util._trace import init_trace def init_eval_context( - trace: bool | None, log_level: str | None, log_level_transcript: str | None, max_subprocesses: int | None = None, @@ -21,7 +19,7 @@ def init_eval_context( init_max_subprocesses(max_subprocesses) init_http_rate_limit_count() init_hooks() - init_trace(trace) + init_active_samples() def init_task_context(model: Model, config: GenerateConfig = GenerateConfig()) -> None: diff --git a/src/inspect_ai/_eval/eval.py b/src/inspect_ai/_eval/eval.py index d1d38b099..414dd2677 100644 --- a/src/inspect_ai/_eval/eval.py +++ b/src/inspect_ai/_eval/eval.py @@ -1,4 +1,3 @@ -import asyncio import logging import os from pathlib import Path @@ -8,6 +7,7 @@ from typing_extensions import Unpack from inspect_ai._cli.util import parse_cli_args +from inspect_ai._display.core.active import display from inspect_ai._util.config import resolve_args from inspect_ai._util.constants import DEFAULT_LOG_FORMAT from inspect_ai._util.error import PrerequisiteError @@ -33,6 +33,7 @@ from inspect_ai.solver._chain import chain from inspect_ai.solver._solver import Solver, SolverSpec from inspect_ai.util import SandboxEnvironmentType +from inspect_ai.util._trace import init_trace from .context import init_eval_context from .loader import ResolvedTask, resolve_tasks @@ -140,8 +141,11 @@ def eval( # standard platform init for top level entry points platform_init() - return asyncio.run( - eval_async( + # resolve eval trace + max_tasks, max_samples = init_eval_trace(trace, max_tasks, max_samples, model) + + return display().run_task_app( + main=eval_async( tasks=tasks, model=model, model_base_url=model_base_url, @@ -295,7 +299,6 @@ async def eval_async( model_args=model_args, task_args=task_args, sandbox=sandbox, - trace=trace, approval=approval, max_subprocesses=max_subprocesses, log_level=log_level, @@ 
-490,10 +493,14 @@ def eval_retry( Returns: List of EvalLog (one for each task) """ + # standard platform init for top level entry points platform_init() - return asyncio.run( - eval_retry_async( + # resolve eval trace + max_tasks, max_samples = init_eval_trace(trace, max_tasks, max_samples) + + return display().run_task_app( + main=eval_retry_async( tasks=tasks, log_level=log_level, log_level_transcript=log_level_transcript, @@ -503,7 +510,6 @@ def eval_retry( max_tasks=max_tasks, max_subprocesses=max_subprocesses, sandbox_cleanup=sandbox_cleanup, - trace=trace, fail_on_error=fail_on_error, debug_errors=debug_errors, log_samples=log_samples, @@ -513,7 +519,7 @@ def eval_retry( max_retries=max_retries, timeout=timeout, max_connections=max_connections, - ) + ), ) @@ -527,7 +533,6 @@ async def eval_retry_async( max_tasks: int | None = None, max_subprocesses: int | None = None, sandbox_cleanup: bool | None = None, - trace: bool | None = None, fail_on_error: bool | float | None = None, debug_errors: bool | None = None, log_samples: bool | None = None, @@ -558,7 +563,6 @@ async def eval_retry_async( run in parallel (default is os.cpu_count()) sandbox_cleanup (bool | None): Cleanup sandbox environments after task completes (defaults to True) - trace (bool | None): Trace message interactions with evaluated model to terminal. fail_on_error (bool | float | None): `True` to fail on first sample error (default); `False` to never fail on sample errors; Value between 0 and 1 to fail if a proportion of total samples fails. Value greater than 1 to fail @@ -640,7 +644,6 @@ async def eval_retry_async( if eval_log.eval.config.epochs else None ) - trace = eval_log.eval.config.trace or trace approval = eval_log.eval.config.approval message_limit = eval_log.eval.config.message_limit token_limit = eval_log.eval.config.token_limit @@ -687,7 +690,6 @@ async def eval_retry_async( sandbox_cleanup=sandbox_cleanup, solver=solver, tags=tags, - trace=trace, approval=approval, log_level=log_level, log_level_transcript=log_level_transcript, @@ -724,7 +726,6 @@ def eval_init( model_args: dict[str, Any] | str = dict(), task_args: dict[str, Any] | str = dict(), sandbox: SandboxEnvironmentType | None = None, - trace: bool | None = None, approval: str | list[ApprovalPolicy] | ApprovalPolicyConfig | None = None, max_subprocesses: int | None = None, log_level: str | None = None, @@ -732,7 +733,7 @@ def eval_init( **kwargs: Unpack[GenerateConfigArgs], ) -> tuple[list[Model], list[ApprovalPolicy] | None, list[ResolvedTask]]: # init eval context - init_eval_context(trace, log_level, log_level_transcript, max_subprocesses) + init_eval_context(log_level, log_level_transcript, max_subprocesses) # resolve model and task args model_args = resolve_args(model_args) @@ -751,10 +752,13 @@ def eval_init( # resolve tasks (set active model to resolve uses of the # 'default' model in tools, solvers, and scorers) - resolved_tasks: list[ResolvedTask] = [] - for m in models: - init_active_model(m, generate_config) - resolved_tasks.extend(resolve_tasks(tasks, task_args, m, sandbox)) + from inspect_ai._display.core.active import display + + with display().suspend_task_app(): + resolved_tasks: list[ResolvedTask] = [] + for m in models: + init_active_model(m, generate_config) + resolved_tasks.extend(resolve_tasks(tasks, task_args, m, sandbox)) # resolve approval if isinstance(approval, str | ApprovalPolicyConfig): @@ -764,6 +768,33 @@ def eval_init( return models, approval, resolved_tasks +def init_eval_trace( + trace: bool | None, + max_tasks: int | 
None, + max_samples: int | None, + model: Any = None, +) -> tuple[int | None, int | None]: + # init trace setting + init_trace(trace) + + # adapt task/samples as required + if trace: + # single task at a time + if max_tasks is not None: + max_tasks = 1 + + # single sample at a time + max_samples = 1 + + # multiple models not allowed in trace mode + if isinstance(model, list) and len(model) > 1: + raise PrerequisiteError( + "Trace mode cannot be used when evaluating multiple models." + ) + + return max_tasks, max_samples + + # A list of eval logs is returned from eval(). We've already displayed # all of the output we need to to though, so we make the return # value 'invisible' diff --git a/src/inspect_ai/_eval/run.py b/src/inspect_ai/_eval/run.py index 99e5af830..886e87c98 100644 --- a/src/inspect_ai/_eval/run.py +++ b/src/inspect_ai/_eval/run.py @@ -7,12 +7,18 @@ from typing_extensions import Unpack from inspect_ai._display import display -from inspect_ai._display._display import clear_task_screen, init_task_screen +from inspect_ai._display.core.active import ( + clear_task_screen, + init_task_screen, +) +from inspect_ai._display.core.display import TaskSpec from inspect_ai._util.error import exception_message from inspect_ai._util.path import chdir +from inspect_ai._util.registry import registry_unqualified_name from inspect_ai.log import EvalConfig, EvalLog from inspect_ai.log._recorders import Recorder from inspect_ai.model import GenerateConfigArgs +from inspect_ai.model._model import ModelName from inspect_ai.scorer._reducer import ScoreReducer, reducer_log_names from inspect_ai.scorer._reducer.registry import validate_reducer from inspect_ai.solver._solver import Solver, SolverSpec @@ -207,9 +213,10 @@ async def eval_run( async def run_single(tasks: list[TaskRunOptions]) -> list[EvalLog]: # https://discuss.python.org/t/asyncio-cancel-a-cancellation-utility-as-a-coroutine-this-time-with-feeling/26304/3 - with display().task_screen(total_tasks=len(tasks), parallel=False) as screen: + async with display().task_screen(task_specs(tasks), parallel=False) as screen: init_task_screen(screen) asyncio_tasks = [asyncio.create_task(task_run(task)) for task in tasks] + try: return await asyncio.gather(*asyncio_tasks) except asyncio.CancelledError: @@ -221,9 +228,9 @@ async def run_single(tasks: list[TaskRunOptions]) -> list[EvalLog]: task.cancel() await task results.append(task.result()) + return results finally: clear_task_screen() - return results # multiple mode -- run multiple logical tasks (requires some smart @@ -285,8 +292,8 @@ async def worker() -> None: break # with task display - with display().task_screen(total_tasks=len(tasks), parallel=True) as screen: - # set screen + async with display().task_screen(task_specs(tasks), parallel=True) as screen: + # init screen init_task_screen(screen) # start worker tasks @@ -325,18 +332,21 @@ async def startup_sandbox_environments( # initialiase sandboxenvs (track cleanups) cleanups: list[tuple[TaskCleanup, str | None, str]] = [] - for sandboxenv in sandboxenvs: - # find type - sandboxenv_type = registry_find_sandboxenv(sandboxenv.sandbox.type) - - # run startup - task_init = cast(TaskInit, getattr(sandboxenv_type, "task_init")) - with chdir(sandboxenv.run_dir): - await task_init("startup", sandboxenv.sandbox.config) - - # append cleanup method - task_cleanup = cast(TaskCleanup, getattr(sandboxenv_type, "task_cleanup")) - cleanups.append((task_cleanup, sandboxenv.sandbox.config, sandboxenv.run_dir)) + with display().suspend_task_app(): + for 
sandboxenv in sandboxenvs: + # find type + sandboxenv_type = registry_find_sandboxenv(sandboxenv.sandbox.type) + + # run startup + task_init = cast(TaskInit, getattr(sandboxenv_type, "task_init")) + with chdir(sandboxenv.run_dir): + await task_init("startup", sandboxenv.sandbox.config) + + # append cleanup method + task_cleanup = cast(TaskCleanup, getattr(sandboxenv_type, "task_cleanup")) + cleanups.append( + (task_cleanup, sandboxenv.sandbox.config, sandboxenv.run_dir) + ) # return shutdown method async def shutdown() -> None: @@ -351,3 +361,10 @@ async def shutdown() -> None: ) return shutdown + + +def task_specs(tasks: list[TaskRunOptions]) -> list[TaskSpec]: + return [ + TaskSpec(registry_unqualified_name(task.task.name), ModelName(task.model)) + for task in tasks + ] diff --git a/src/inspect_ai/_eval/task/run.py b/src/inspect_ai/_eval/task/run.py index 4fcb66df0..6c68f1a61 100644 --- a/src/inspect_ai/_eval/task/run.py +++ b/src/inspect_ai/_eval/task/run.py @@ -9,12 +9,12 @@ from typing_extensions import Unpack -from inspect_ai._display import display -from inspect_ai._display._display import ( +from inspect_ai._display import ( TaskCancelled, TaskError, TaskProfile, TaskSuccess, + display, ) from inspect_ai._util.constants import ( DEFAULT_EPOCHS, @@ -42,6 +42,7 @@ from inspect_ai.log._condense import condense_sample from inspect_ai.log._file import eval_log_json from inspect_ai.log._log import EvalSampleLimit, EvalSampleReductions, eval_error +from inspect_ai.log._samples import ActiveSample, active_sample from inspect_ai.log._transcript import ( ErrorEvent, SampleInitEvent, @@ -258,6 +259,10 @@ async def generate( log_images=log_images, sample_source=sample_source, sample_error=sample_error_handler, + fails_on_error=( + config.fail_on_error is None + or config.fail_on_error is True + ), time_limit=config.time_limit, semaphore=sample_semaphore, ) @@ -362,6 +367,7 @@ async def task_run_sample( log_images: bool, sample_source: EvalSampleSource | None, sample_error: Callable[[BaseException], EvalError], + fails_on_error: bool, time_limit: int | None, semaphore: asyncio.Semaphore | None, ) -> dict[str, SampleScore] | None: @@ -399,7 +405,7 @@ async def task_run_sample( # initialise subtask and scoring context init_sample_model_usage() set_sample_state(state) - init_subtask(SAMPLE_SUBTASK, state.store) + sample_transcript = init_subtask(SAMPLE_SUBTASK, state.store) if scorers: init_scoring_context(scorers, Target(sample.target)) @@ -415,8 +421,27 @@ async def task_run_sample( timeout(time_limit) if time_limit is not None else contextlib.nullcontext() ) + # helper to handle exceptions (will throw if we've exceeded the limit) + def handle_error(ex: BaseException) -> EvalError: + err = sample_error(ex) + transcript()._event(ErrorEvent(error=err)) + return err + # solver loop - async with semaphore_cm, sandboxenv_cm: + async with ( + semaphore_cm, + sandboxenv_cm, + active_sample( + ActiveSample( + task_name, + str(state.model), + sample, + state.epoch, + fails_on_error, + sample_transcript, + ) + ) as active, + ): error: EvalError | None = None try: async with timeout_cm: @@ -439,24 +464,38 @@ async def task_run_sample( transcript()._event( SampleLimitEvent( type="time", - limit=time_limit, message=f"Sample completed: exceeded time limit ({time_limit:,} seconds)", + limit=time_limit, ) ) # capture most recent state for scoring state = sample_state() or state - except asyncio.CancelledError: - # allow cancelled error to propagate - raise + except asyncio.CancelledError as ex: + if 
active.interrupt_action: + # record event + transcript()._event( + SampleLimitEvent( + type="operator", + message="Sample completed: interrupted by operator", + ) + ) - except BaseException as ex: - # handle error (this will throw if we've exceeded the limit) - error = sample_error(ex) + # handle the action + match active.interrupt_action: + case "score": + # continue to scoring (capture the most recent state) + state = sample_state() or state + case "error": + # default error handling + error = handle_error(ex) - # fire error event - transcript()._event(ErrorEvent(error=error)) + else: + raise + + except BaseException as ex: + error = handle_error(ex) # set timeout for scoring. if the original timeout was never hit # then just create a new timeout_cm targeting the original @@ -501,7 +540,14 @@ async def task_run_sample( results[scorer_name] = sample_score except asyncio.CancelledError: - # allow cancelled error to propagate + if active.interrupt_action: + transcript()._event( + SampleLimitEvent( + type="operator", + message="Unable to score sample due to operator interruption", + ) + ) + raise except BaseException as ex: @@ -510,16 +556,13 @@ async def task_run_sample( transcript()._event( SampleLimitEvent( type="time", - limit=time_limit, message=f"Unable to score sample due to exceeded time limit ({time_limit:,} seconds)", + limit=time_limit, ) ) # handle error (this will throw if we've exceeded the limit) - error = sample_error(ex) - - # fire error event - transcript()._event(ErrorEvent(error=error)) + error = handle_error(ex) progress() @@ -567,11 +610,10 @@ def log_sample( ) # construct sample for logging - sample_events = transcript().events # if a limit was hit, note that in the Eval Sample limit = None - for e in sample_events: + for e in transcript().events: if e.event == "sample_limit": limit = EvalSampleLimit( type=e.type, limit=e.limit if e.limit is not None else -1 @@ -592,7 +634,7 @@ def log_sample( output=state.output, scores=cast(dict[str, Score], scores), store=dict(state.store.items()), - events=sample_events, + events=list(transcript().events), model_usage=sample_model_usage(), error=error, limit=limit, diff --git a/src/inspect_ai/_util/ansi.py b/src/inspect_ai/_util/ansi.py deleted file mode 100644 index 1d42c0a47..000000000 --- a/src/inspect_ai/_util/ansi.py +++ /dev/null @@ -1,5 +0,0 @@ -import os - - -def no_ansi() -> bool: - return os.environ.get("INSPECT_NO_ANSI", None) is not None diff --git a/src/inspect_ai/_util/constants.py b/src/inspect_ai/_util/constants.py index 331f3949a..5a3ce99fd 100644 --- a/src/inspect_ai/_util/constants.py +++ b/src/inspect_ai/_util/constants.py @@ -30,6 +30,7 @@ ALL_LOG_FORMATS = ["eval", "json"] DEFAULT_LOG_FORMAT: Literal["eval", "json"] = "eval" EVAL_LOG_FORMAT = "eval" +DEFAULT_DISPLAY = "rich" LOG_SCHEMA_VERSION = 2 SCORED_SUFFIX = "-scored" SAMPLE_SUBTASK = "sample" diff --git a/src/inspect_ai/_util/display.py b/src/inspect_ai/_util/display.py new file mode 100644 index 000000000..100732b13 --- /dev/null +++ b/src/inspect_ai/_util/display.py @@ -0,0 +1,34 @@ +import os +from logging import getLogger +from typing import Literal + +from inspect_ai._util.constants import DEFAULT_DISPLAY + +logger = getLogger(__name__) + +DisplayType = Literal["full", "rich", "plain", "none"] + + +_display_type: DisplayType | None = None + + +def init_display_type(display: str | None = None) -> DisplayType: + global _display_type + display = ( + display or os.environ.get("INSPECT_DISPLAY", DEFAULT_DISPLAY).lower().strip() + ) + match display: +
case "full" | "rich" | "plain" | "none": + _display_type = display + case _: + logger.warning(f"Unknown display type '{display}'") + _display_type = "full" + return _display_type + + +def display_type() -> DisplayType: + global _display_type + if _display_type: + return _display_type + else: + return init_display_type() diff --git a/src/inspect_ai/_util/logger.py b/src/inspect_ai/_util/logger.py index 479328bf9..a404f73ef 100644 --- a/src/inspect_ai/_util/logger.py +++ b/src/inspect_ai/_util/logger.py @@ -1,11 +1,157 @@ +import os from contextvars import ContextVar -from logging import INFO, Logger, LogRecord +from logging import ( + INFO, + WARNING, + FileHandler, + Formatter, + Logger, + LogRecord, + addLevelName, + getLevelName, + getLogger, +) -from inspect_ai.log._message import LoggingMessage -from inspect_ai.log._transcript import LoggerEvent, transcript +import rich +from rich.console import ConsoleRenderable +from rich.logging import RichHandler +from rich.text import Text +from typing_extensions import override + +from inspect_ai._util.constants import ( + ALL_LOG_LEVELS, + DEFAULT_LOG_LEVEL, + DEFAULT_LOG_LEVEL_TRANSCRIPT, + HTTP, + HTTP_LOG_LEVEL, + PKG_NAME, + SANDBOX, + SANDBOX_LOG_LEVEL, +) +from inspect_ai._util.error import PrerequisiteError + + +# log handler that filters messages to stderr and the log file +class LogHandler(RichHandler): + def __init__(self, levelno: int, transcript_levelno: int) -> None: + super().__init__(levelno, console=rich.get_console()) + self.transcript_levelno = transcript_levelno + self.display_level = WARNING + # log into an external file if requested via env var + file_logger = os.environ.get("INSPECT_PY_LOGGER_FILE", None) + self.file_logger = FileHandler(file_logger) if file_logger else None + if self.file_logger: + self.file_logger.setFormatter( + Formatter("%(asctime)s - %(levelname)s - %(message)s") + ) + + # see if the user has a special log level for the file + file_logger_level = os.environ.get("INSPECT_PY_LOGGER_LEVEL", "") + if file_logger_level: + self.file_logger_level = int(getLevelName(file_logger_level.upper())) + else: + self.file_logger_level = 0 + + @override + def emit(self, record: LogRecord) -> None: + # demote httpx and return notifications to log_level http + if ( + record.name == "httpx" + or "http" in record.name + or "Retrying request" in record.getMessage() + ): + record.levelno = HTTP + record.levelname = HTTP_LOG_LEVEL + + # skip httpx event loop is closed errors + if "Event loop is closed" in record.getMessage(): + return + + # write to stderr if we are at or above the threshold + if record.levelno >= self.display_level: + super().emit(record) + + # write to file if the log file level matches. 
if the + # user hasn't explicitly specified a level then we + # take the minimum of 'info' and the display level + if self.file_logger and record.levelno >= ( + self.file_logger_level or min(self.display_level, INFO) + ): + self.file_logger.emit(record) + + # eval log always gets info level and higher records + # eval log only gets debug or http if we opt-in + write = record.levelno >= self.transcript_levelno + notify_logger_record(record, write) + + @override + def render_message(self, record: LogRecord, message: str) -> ConsoleRenderable: + return Text.from_ansi(message) + + +# initialize logging -- this function can be called multiple times +# in the lifetime of the process (the levelno will update globally) +def init_logger( + log_level: str | None = None, log_level_transcript: str | None = None +) -> None: + # backwards compatibility for 'tools' + if log_level == "tools": + log_level = "sandbox" + + # register http and tools levels + addLevelName(HTTP, HTTP_LOG_LEVEL) + addLevelName(SANDBOX, SANDBOX_LOG_LEVEL) + + def validate_level(option: str, level: str) -> None: + if level not in ALL_LOG_LEVELS: + log_levels = ", ".join([level.lower() for level in ALL_LOG_LEVELS]) + raise PrerequisiteError( + f"Invalid {option} '{level.lower()}'. Log level must be one of {log_levels}" + ) + + # resolve default log level + log_level = ( + log_level if log_level else os.getenv("INSPECT_LOG_LEVEL", DEFAULT_LOG_LEVEL) + ).upper() + validate_level("log level", log_level) + + # resolve log file level + log_level_transcript = ( + log_level_transcript + if log_level_transcript + else os.getenv("INSPECT_LOG_LEVEL_TRANSCRIPT", DEFAULT_LOG_LEVEL_TRANSCRIPT) + ).upper() + validate_level("log file level", log_level_transcript) + + # convert to integer + levelno = getLevelName(log_level) + transcript_levelno = getLevelName(log_level_transcript) + + # init logging handler on demand + global _logHandler + if not _logHandler: + _logHandler = LogHandler(min(HTTP, levelno), transcript_levelno) + getLogger().addHandler(_logHandler) + + # establish default capture level + capture_level = min(HTTP, levelno) + + # see all the messages (we won't actually display/write all of them) + getLogger().setLevel(capture_level) + getLogger(PKG_NAME).setLevel(capture_level) + getLogger("httpx").setLevel(capture_level) + + # set the levelno on the global handler + _logHandler.display_level = levelno + + +_logHandler: LogHandler | None = None def notify_logger_record(record: LogRecord, write: bool) -> None: + from inspect_ai.log._message import LoggingMessage + from inspect_ai.log._transcript import LoggerEvent, transcript + if write: transcript()._event(LoggerEvent(message=LoggingMessage.from_log_record(record))) if record.levelno <= INFO and "429" in record.getMessage(): diff --git a/src/inspect_ai/_util/rich.py b/src/inspect_ai/_util/rich.py new file mode 100644 index 000000000..68565183d --- /dev/null +++ b/src/inspect_ai/_util/rich.py @@ -0,0 +1,24 @@ +from rich.console import RenderableType +from rich.style import Style +from rich.text import Text + + +def lines_display( + text: str, max_lines: int = 100, style: str | Style = "" +) -> list[RenderableType]: + lines = text.splitlines() + if len(lines) > max_lines: + content: list[RenderableType] = [ + Text("\n".join(lines[0:max_lines]), style=style) + ] + content.append(Text()) + content.append( + Text.from_markup( + f"[italic]Output truncated ({len(lines) - max_lines} additional lines)...[/italic]", + style=style, + ) + ) + else: + content = [Text(text, style=style)] + + return
content diff --git a/src/inspect_ai/_util/terminal.py b/src/inspect_ai/_util/terminal.py new file mode 100644 index 000000000..dc46430fb --- /dev/null +++ b/src/inspect_ai/_util/terminal.py @@ -0,0 +1,138 @@ +import functools +import re +import select +import sys +from dataclasses import dataclass +from logging import getLogger +from typing import Any + +logger = getLogger(__name__) + + +@dataclass +class RGB: + r: int + g: int + b: int + + +@dataclass +class TerminalBackground: + color: RGB + brightness: float + dark: bool + + +@functools.cache +def detect_terminal_background( + default_color: RGB = RGB(0, 0, 0), +) -> TerminalBackground: + """Query the terminal background color using an OSC escape sequence. + + Based on https://dystroy.org/blog/terminal-light/#detect-whether-the-terminal-is-dark-or-light + and https://github.com/Canop/terminal-light/blob/main/src/xterm.rs + + The `default_color` parameter ensures that you always get back a color + even when on Windows or if an error occurs while querying the terminal + (dark terminal is detected in this case). + + Args: + default_color (RGB): Default color in the case that we + are unable to successfully query for colors. + + Returns: + TerminalBackground: Terminal background color, brightness, and type. + """ + + def background_from_color(color: RGB) -> TerminalBackground: + # compute brightness + brightness = (color.r * 299 + color.g * 587 + color.b * 114) / 1000 + + # return background + return TerminalBackground( + color=color, brightness=brightness, dark=brightness <= 128 + ) + + # this does not work on windows so in that case we return the default + if sys.platform == "win32": + return background_from_color(default_color) + + try: + # Send OSC 11 query for background color + response = _query("\x1b]11;?\x07", 500) + + # Parse the response + # Expected format: ]11;rgb:RRRR/GGGG/BBBB + match = re.search( + r"]11;rgb:([0-9a-fA-F]{2,4})/([0-9a-fA-F]{2,4})/([0-9a-fA-F]{2,4})", + response, + ) + if not match: + raise RuntimeError(f"Unexpected OSC response format: {response}") + + # Extract RGB values (using first 2 digits of each component) + r = int(match.group(1)[:2], 16) + g = int(match.group(2)[:2], 16) + b = int(match.group(3)[:2], 16) + color = RGB(r, g, b) + + # return background + return background_from_color(color) + + except Exception as e: + logger.debug("Error attempting to query terminal background color: " + str(e)) + return background_from_color(default_color) + + +if sys.platform != "win32": + import termios + import tty + + def _query(query_str: str, timeout_ms: int) -> str: + """Send a query to the terminal and wait for response""" + old_settings = None + + try: + switch_to_raw = not _is_raw_mode_enabled() + if switch_to_raw: + old_settings = _enable_raw_mode() + + # Send the query + sys.stdout.write(query_str) + sys.stdout.flush() + + # Wait for response + readable, _, _ = select.select([sys.stdin], [], [], timeout_ms / 1000.0) + if not readable: + raise RuntimeError("Timeout waiting for terminal query response") + + # Read response + response: str = "" + while True: + char = sys.stdin.read(1) + response += char + if char == "\\" or (len(response) > 1 and response[-2:] == "\x1b\\"): + break + + return response + + finally: + if old_settings is not None: + _disable_raw_mode(old_settings) + + def _is_raw_mode_enabled() -> bool: + """Check if the terminal is in raw mode""" + mode = termios.tcgetattr(sys.stdin.fileno()) + return not bool(mode[3] & termios.ICANON) + + def _enable_raw_mode() -> Any: + """Enable raw mode for
the terminal""" + fd = sys.stdin.fileno() + old_settings = termios.tcgetattr(fd) + tty.setraw(fd) + return old_settings + + def _disable_raw_mode(old_settings: Any) -> None: + """Disable raw mode for the terminal""" + fd = sys.stdin.fileno() + termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) diff --git a/src/inspect_ai/_util/transcript.py b/src/inspect_ai/_util/transcript.py new file mode 100644 index 000000000..34ea457be --- /dev/null +++ b/src/inspect_ai/_util/transcript.py @@ -0,0 +1,86 @@ +from rich.align import AlignMethod +from rich.box import ROUNDED, Box +from rich.console import Group, RenderableType +from rich.markdown import Markdown +from rich.panel import Panel +from rich.rule import Rule +from rich.text import Text + + +def transcript_code_theme() -> str: + return "github-dark" + + +def transcript_markdown(content: str) -> Markdown: + code_theme = transcript_code_theme() + return Markdown( + content, + code_theme=code_theme, + inline_code_lexer="python", + inline_code_theme=code_theme, + ) + + +def set_transcript_markdown_options(markdown: Markdown) -> None: + code_theme = transcript_code_theme() + markdown.code_theme = code_theme + markdown.inline_code_lexer = "python" + markdown.inline_code_theme = code_theme + + +def transcript_panel( + title: str, + subtitle: str | None = None, + content: RenderableType | list[RenderableType] = [], + level: int = 1, +) -> Panel: + # resolve content to list + content = content if isinstance(content, list) else [content] + + # no padding if there is no content + padding = (0, 1) if content else (0, 0) + + # handle title/level + if level == 1: + title = f"[bold][blue]{title}[/blue][/bold]" + title_align: AlignMethod = "left" + # box if content, else line + box = ROUNDED if content else LINE + else: + title = f"[bold]{title}[/bold]" + title_align = "center" + if level == 2: + box = LINE + else: + box = DOTTED + + # inject subtitle + if subtitle: + content.insert(0, Text()) + content.insert(0, Text.from_markup(f"[bold]{subtitle}[/bold]")) + + # use xcode theme for markdown code + for c in content: + if isinstance(c, Markdown): + set_transcript_markdown_options(c) + + return Panel( + Group(*content), + title=title, + title_align=title_align, + box=box, + padding=padding, + highlight=True, + expand=True, + ) + + +def transcript_separator(title: str, color: str) -> RenderableType: + return Rule(title=title, style=f"{color} bold", align="center", end="\n\n") + + +LINE = Box(" ── \n" " \n" " \n" " \n" " \n" " \n" " \n" " \n") + +DOTTED = Box(" ·· \n" " \n" " \n" " \n" " \n" " \n" " \n" " \n") + +NOBORDER = Box(" \n" " \n" " \n" " \n" " \n" " \n" " \n" " \n") diff --git a/src/inspect_ai/_view/view.py b/src/inspect_ai/_view/view.py index c3a94be96..22b04aec1 100644 --- a/src/inspect_ai/_view/view.py +++ b/src/inspect_ai/_view/view.py @@ -7,13 +7,13 @@ import psutil from inspect_ai._display import display -from inspect_ai._display.logger import init_logger from inspect_ai._util.constants import ( DEFAULT_SERVER_HOST, DEFAULT_VIEW_PORT, ) from inspect_ai._util.dotenv import init_dotenv from inspect_ai._util.error import exception_message +from inspect_ai._util.logger import init_logger from inspect_ai._view.server import view_server from .notify import view_data_dir diff --git a/src/inspect_ai/_view/www/log-schema.json b/src/inspect_ai/_view/www/log-schema.json index 09d1f42cd..24e486ded 100644 --- a/src/inspect_ai/_view/www/log-schema.json +++ b/src/inspect_ai/_view/www/log-schema.json @@ -8,6 +8,18 @@ "title": "Timestamp", "type": "string" }, + 
"pending": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Pending" + }, "event": { "const": "approval", "default": "approval", @@ -76,6 +88,7 @@ }, "required": [ "timestamp", + "pending", "event", "message", "call", @@ -533,6 +546,18 @@ "title": "Timestamp", "type": "string" }, + "pending": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Pending" + }, "event": { "const": "error", "default": "error", @@ -548,6 +573,7 @@ }, "required": [ "timestamp", + "pending", "event", "error" ], @@ -1989,6 +2015,18 @@ "title": "Timestamp", "type": "string" }, + "pending": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Pending" + }, "event": { "const": "info", "default": "info", @@ -2004,6 +2042,7 @@ }, "required": [ "timestamp", + "pending", "event", "data" ], @@ -2019,6 +2058,18 @@ "title": "Timestamp", "type": "string" }, + "pending": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Pending" + }, "event": { "const": "input", "default": "input", @@ -2039,6 +2090,7 @@ }, "required": [ "timestamp", + "pending", "event", "input", "input_ansi" @@ -2107,6 +2159,18 @@ "title": "Timestamp", "type": "string" }, + "pending": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Pending" + }, "event": { "const": "logger", "default": "logger", @@ -2122,6 +2186,7 @@ }, "required": [ "timestamp", + "pending", "event", "message" ], @@ -2297,6 +2362,18 @@ "title": "Timestamp", "type": "string" }, + "pending": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Pending" + }, "event": { "const": "model", "default": "model", @@ -2389,6 +2466,7 @@ }, "required": [ "timestamp", + "pending", "event", "model", "input", @@ -2663,6 +2741,18 @@ "title": "Timestamp", "type": "string" }, + "pending": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Pending" + }, "event": { "const": "sample_init", "default": "sample_init", @@ -2681,6 +2771,7 @@ }, "required": [ "timestamp", + "pending", "event", "sample", "state" @@ -2697,6 +2788,18 @@ "title": "Timestamp", "type": "string" }, + "pending": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Pending" + }, "event": { "const": "sample_limit", "default": "sample_limit", @@ -2716,6 +2819,10 @@ "title": "Type", "type": "string" }, + "message": { + "title": "Message", + "type": "string" + }, "limit": { "anyOf": [ { @@ -2727,26 +2834,15 @@ ], "default": null, "title": "Limit" - }, - "message": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "default": null, - "title": "Message" } }, "required": [ "timestamp", + "pending", "event", "type", - "limit", - "message" + "message", + "limit" ], "title": "SampleLimitEvent", "type": "object", @@ -3015,6 +3111,18 @@ "title": "Timestamp", "type": "string" }, + "pending": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Pending" + }, "event": { "const": "score", "default": "score", @@ -3048,6 +3156,7 @@ }, "required": [ "timestamp", + "pending", "event", "score", "target" @@ -3064,6 +3173,18 @@ "title": "Timestamp", "type": "string" }, + "pending": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "default": 
null, + "title": "Pending" + }, "event": { "const": "state", "default": "state", @@ -3083,6 +3204,7 @@ }, "required": [ "timestamp", + "pending", "event", "changes" ], @@ -3098,6 +3220,18 @@ "title": "Timestamp", "type": "string" }, + "pending": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Pending" + }, "event": { "const": "step", "default": "step", @@ -3134,6 +3268,7 @@ }, "required": [ "timestamp", + "pending", "event", "action", "type", @@ -3151,6 +3286,18 @@ "title": "Timestamp", "type": "string" }, + "pending": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Pending" + }, "event": { "const": "store", "default": "store", @@ -3170,6 +3317,7 @@ }, "required": [ "timestamp", + "pending", "event", "changes" ], @@ -3185,6 +3333,18 @@ "title": "Timestamp", "type": "string" }, + "pending": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Pending" + }, "event": { "const": "subtask", "default": "subtask", @@ -3271,6 +3431,7 @@ }, "required": [ "timestamp", + "pending", "event", "name", "type", @@ -3437,6 +3598,18 @@ "title": "Timestamp", "type": "string" }, + "pending": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Pending" + }, "event": { "const": "tool", "default": "tool", @@ -3470,6 +3643,17 @@ "title": "Arguments", "type": "object" }, + "view": { + "anyOf": [ + { + "$ref": "#/$defs/ToolCallContent" + }, + { + "type": "null" + } + ], + "default": null + }, "result": { "anyOf": [ { @@ -3522,17 +3706,6 @@ "default": null, "title": "Truncated" }, - "view": { - "anyOf": [ - { - "$ref": "#/$defs/ToolCallContent" - }, - { - "type": "null" - } - ], - "default": null - }, "error": { "anyOf": [ { @@ -3597,14 +3770,15 @@ }, "required": [ "timestamp", + "pending", "event", "type", "id", "function", "arguments", + "view", "result", "truncated", - "view", "error", "events" ], diff --git a/src/inspect_ai/_view/www/src/types/log.d.ts b/src/inspect_ai/_view/www/src/types/log.d.ts index a91bb9815..12cce32fd 100644 --- a/src/inspect_ai/_view/www/src/types/log.d.ts +++ b/src/inspect_ai/_view/www/src/types/log.d.ts @@ -187,6 +187,7 @@ export type Answer = string | null; export type Explanation = string | null; export type Metadata5 = {} | null; export type Timestamp = string; +export type Pending = boolean | null; export type Event = "sample_init"; export type Input1 = | string @@ -206,20 +207,24 @@ export type Files1 = { export type Setup1 = string | null; export type JsonValue = unknown; export type Timestamp1 = string; +export type Pending1 = boolean | null; export type Event1 = "sample_limit"; export type Type5 = "message" | "time" | "token" | "operator"; +export type Message2 = string; export type Limit1 = number | null; -export type Message2 = string | null; export type Timestamp2 = string; +export type Pending2 = boolean | null; export type Event2 = "state"; export type Op = "remove" | "add" | "replace" | "move" | "test" | "copy"; export type Path = string; export type From = string | null; export type Changes = JsonChange[]; export type Timestamp3 = string; +export type Pending3 = boolean | null; export type Event3 = "store"; export type Changes1 = JsonChange[]; export type Timestamp4 = string; +export type Pending4 = boolean | null; export type Event4 = "model"; export type Model2 = string; export type Input2 = ( @@ -253,16 +258,18 @@ export type ToolChoice = ("auto" | "any" 
| "none") | ToolFunction; export type Name6 = string; export type Cache = ("read" | "write") | null; export type Timestamp5 = string; +export type Pending5 = boolean | null; export type Event5 = "tool"; export type Type8 = "function"; export type Id3 = string; export type Function2 = string; -export type Result = string | number | boolean | (ContentText | ContentImage)[]; -export type Truncated = [unknown, unknown] | null; export type Title = string | null; export type Format = "text" | "markdown"; export type Content5 = string; +export type Result = string | number | boolean | (ContentText | ContentImage)[]; +export type Truncated = [unknown, unknown] | null; export type Timestamp6 = string; +export type Pending6 = boolean | null; export type Event6 = "approval"; export type Message3 = string; export type Approver = string; @@ -274,15 +281,19 @@ export type Decision = | "terminate"; export type Explanation1 = string | null; export type Timestamp7 = string; +export type Pending7 = boolean | null; export type Event7 = "input"; export type Input3 = string; export type InputAnsi = string; export type Timestamp8 = string; +export type Pending8 = boolean | null; export type Event8 = "score"; export type Target2 = string | string[] | null; export type Timestamp9 = string; +export type Pending9 = boolean | null; export type Event9 = "error"; export type Timestamp10 = string; +export type Pending10 = boolean | null; export type Event10 = "logger"; export type Name7 = string | null; export type Level = @@ -299,13 +310,16 @@ export type Filename = string; export type Module = string; export type Lineno = number; export type Timestamp11 = string; +export type Pending11 = boolean | null; export type Event11 = "info"; export type Timestamp12 = string; +export type Pending12 = boolean | null; export type Event12 = "step"; export type Action = "begin" | "end"; export type Type9 = string | null; export type Name8 = string; export type Timestamp13 = string; +export type Pending13 = boolean | null; export type Event13 = "subtask"; export type Name9 = string; export type Type10 = string | null; @@ -674,6 +688,7 @@ export interface Store {} */ export interface SampleInitEvent { timestamp: Timestamp; + pending: Pending; event: Event; sample: Sample; state: JsonValue; @@ -693,16 +708,18 @@ export interface Sample { */ export interface SampleLimitEvent { timestamp: Timestamp1; + pending: Pending1; event: Event1; type: Type5; - limit: Limit1; message: Message2; + limit: Limit1; } /** * Change to the current `TaskState` */ export interface StateEvent { timestamp: Timestamp2; + pending: Pending2; event: Event2; changes: Changes; } @@ -725,6 +742,7 @@ export interface JsonChange { */ export interface StoreEvent { timestamp: Timestamp3; + pending: Pending3; event: Event3; changes: Changes1; } @@ -733,6 +751,7 @@ export interface StoreEvent { */ export interface ModelEvent { timestamp: Timestamp4; + pending: Pending4; event: Event4; model: Model2; input: Input2; @@ -849,14 +868,15 @@ export interface Response { */ export interface ToolEvent { timestamp: Timestamp5; + pending: Pending5; event: Event5; type: Type8; id: Id3; function: Function2; arguments: Arguments1; + view: ToolCallContent | null; result: Result; truncated: Truncated; - view: ToolCallContent | null; error: ToolCallError | null; events: Events1; } @@ -876,6 +896,7 @@ export interface ToolCallContent { */ export interface ApprovalEvent { timestamp: Timestamp6; + pending: Pending6; event: Event6; message: Message3; call: ToolCall; @@ -900,6 +921,7 @@ 
export interface ToolCallView { */ export interface InputEvent { timestamp: Timestamp7; + pending: Pending7; event: Event7; input: Input3; input_ansi: InputAnsi; @@ -909,6 +931,7 @@ export interface InputEvent { */ export interface ScoreEvent { timestamp: Timestamp8; + pending: Pending8; event: Event8; score: Score; target: Target2; @@ -918,6 +941,7 @@ export interface ScoreEvent { */ export interface ErrorEvent { timestamp: Timestamp9; + pending: Pending9; event: Event9; error: EvalError; } @@ -926,6 +950,7 @@ export interface ErrorEvent { */ export interface LoggerEvent { timestamp: Timestamp10; + pending: Pending10; event: Event10; message: LoggingMessage; } @@ -943,6 +968,7 @@ export interface LoggingMessage { */ export interface InfoEvent { timestamp: Timestamp11; + pending: Pending11; event: Event11; data: JsonValue; } @@ -951,6 +977,7 @@ export interface InfoEvent { */ export interface StepEvent { timestamp: Timestamp12; + pending: Pending12; event: Event12; action: Action; type: Type9; @@ -961,6 +988,7 @@ export interface StepEvent { */ export interface SubtaskEvent { timestamp: Timestamp13; + pending: Pending13; event: Event13; name: Name9; type: Type10; diff --git a/src/inspect_ai/approval/_human.py b/src/inspect_ai/approval/_human.py index b87ad4cb7..43d7dcb1b 100644 --- a/src/inspect_ai/approval/_human.py +++ b/src/inspect_ai/approval/_human.py @@ -1,14 +1,14 @@ from rich.console import RenderableType from rich.highlighter import ReprHighlighter -from rich.markdown import Markdown from rich.prompt import Prompt from rich.rule import Rule from rich.text import Text +from inspect_ai._util.transcript import transcript_markdown, transcript_panel from inspect_ai.solver._task_state import TaskState from inspect_ai.tool._tool_call import ToolCall, ToolCallContent, ToolCallView from inspect_ai.util._console import input_screen -from inspect_ai.util._trace import TracePanel, trace_enabled +from inspect_ai.util._trace import trace_enabled from ._approval import Approval, ApprovalDecision from ._approver import Approver @@ -45,7 +45,7 @@ def add_view_content(view_content: ToolCallContent) -> None: Text.from_markup(f"[bold]{view_content.title}[/bold]\n") ) if view_content.format == "markdown": - renderables.append(Markdown(view_content.content)) + renderables.append(transcript_markdown(view_content.content)) else: text_content = text_highlighter(Text(view_content.content)) renderables.append(text_content) @@ -69,7 +69,7 @@ def add_view_content(view_content: ToolCallContent) -> None: add_view_content(view.call) renderables.append(Text()) - console.print(TracePanel(title="Approve Tool", content=renderables)) + console.print(transcript_panel(title="Approve Tool", content=renderables)) # provide choices prompts: dict[str, str] = {} diff --git a/src/inspect_ai/approval/_policy.py b/src/inspect_ai/approval/_policy.py index 4d05d8904..8314934ba 100644 --- a/src/inspect_ai/approval/_policy.py +++ b/src/inspect_ai/approval/_policy.py @@ -102,7 +102,7 @@ def collect_unknown_fields(cls, data: Any) -> Any: if not isinstance(data, dict): return data - known_fields = set(cls.__pydantic_fields__.keys()) + known_fields = set(["name", "tools", "params"]) unknown_fields = {k: v for k, v in data.items() if k not in known_fields} if unknown_fields: diff --git a/src/inspect_ai/log/_samples.py b/src/inspect_ai/log/_samples.py new file mode 100644 index 000000000..9042a763d --- /dev/null +++ b/src/inspect_ai/log/_samples.py @@ -0,0 +1,70 @@ +import asyncio +import contextlib +from datetime import datetime +from 
typing import AsyncGenerator, Literal + +from shortuuid import uuid + +from inspect_ai.dataset._dataset import Sample + +from ._transcript import Transcript + + +class ActiveSample: + def __init__( + self, + task: str, + model: str, + sample: Sample, + epoch: int, + fails_on_error: bool, + transcript: Transcript, + ) -> None: + self.id = uuid() + self.started = datetime.now().timestamp() + self.completed: float | None = None + self.task = task + self.model = model + self.sample = sample + self.epoch = epoch + self.fails_on_error = fails_on_error + self.transcript = transcript + self._sample_task = asyncio.current_task() + self._interrupt_action: Literal["score", "error"] | None = None + + @property + def execution_time(self) -> float: + completed = ( + self.completed if self.completed is not None else datetime.now().timestamp() + ) + return completed - self.started + + def interrupt(self, action: Literal["score", "error"]) -> None: + self._interrupt_action = action + assert self._sample_task + self._sample_task.cancel() + + @property + def interrupt_action(self) -> Literal["score", "error"] | None: + return self._interrupt_action + + +def init_active_samples() -> None: + _active_samples.clear() + + +@contextlib.asynccontextmanager +async def active_sample(sample: ActiveSample) -> AsyncGenerator[ActiveSample, None]: + _active_samples.append(sample) + try: + yield sample + finally: + sample.completed = datetime.now().timestamp() + _active_samples.remove(sample) + + +def active_samples() -> list[ActiveSample]: + return _active_samples + + +_active_samples: list[ActiveSample] = [] diff --git a/src/inspect_ai/log/_transcript.py b/src/inspect_ai/log/_transcript.py index 73db78efc..bebd59cda 100644 --- a/src/inspect_ai/log/_transcript.py +++ b/src/inspect_ai/log/_transcript.py @@ -6,6 +6,7 @@ Any, Iterator, Literal, + Sequence, TypeAlias, Union, ) @@ -41,6 +42,9 @@ class BaseEvent(BaseModel): timestamp: datetime = Field(default_factory=datetime.now) """Time at which event occurred.""" + pending: bool | None = Field(default=None) + """Is this event pending?""" + @field_serializer("timestamp") def serialize_timestamp(self, dt: datetime) -> str: return dt.astimezone().isoformat() @@ -68,12 +72,12 @@ class SampleLimitEvent(BaseEvent): type: Literal["message", "time", "token", "operator"] """Type of limit that halted processing""" - limit: int | None = Field(default=None) - """The limit value""" - - message: str | None = Field(default=None) + message: str """A message associated with this limit""" + limit: int | None = Field(default=None) + """The limit value (if any)""" + class StoreEvent(BaseEvent): """Change to data within the current `Store`.""" @@ -144,21 +148,34 @@ class ToolEvent(BaseEvent): arguments: dict[str, JsonValue] """Arguments to function.""" - result: ToolResult + view: ToolCallContent | None = Field(default=None) + """Custom view of tool call input.""" + + result: ToolResult = Field(default_factory=str) """Function return value.""" truncated: tuple[int, int] | None = Field(default=None) """Bytes truncated (from,to) if truncation occurred""" - view: ToolCallContent | None = Field(default=None) - """Custom view of tool call output.""" - error: ToolCallError | None = Field(default=None) """Error that occurred during tool call.""" - events: list["Event"] + events: list["Event"] = Field(default_factory=list) """Transcript of events for tool.""" + def set_result( + self, + result: ToolResult, + truncated: tuple[int, int] | None, + error: ToolCallError | None, + events: list["Event"], + ) 
-> None: + self.result = result + self.truncated = truncated + self.error = error + self.events = events + self.pending = None + class ApprovalEvent(BaseEvent): """Tool approval.""" @@ -306,7 +323,7 @@ class Transcript: def __init__(self, name: str = "") -> None: self.name = name - self.events: list[Event] = [] + self._events: list[Event] = [] def info(self, data: JsonValue) -> None: """Add an `InfoEvent` to the transcript. @@ -334,8 +351,12 @@ def step(self, name: str, type: str | None = None) -> Iterator[None]: # end step event self._event(StepEvent(action="end", name=name, type=type)) + @property + def events(self) -> Sequence[Event]: + return self._events + def _event(self, event: Event) -> None: - self.events.append(event) + self._events.append(event) def transcript() -> Transcript: diff --git a/src/inspect_ai/model/_call_tools.py b/src/inspect_ai/model/_call_tools.py index c27cc8ca5..21bead80b 100644 --- a/src/inspect_ai/model/_call_tools.py +++ b/src/inspect_ai/model/_call_tools.py @@ -140,7 +140,7 @@ async def call_tool_task(call: ToolCall) -> tuple[ChatMessageTool, ToolEvent]: truncated=truncated, view=tool_call_view(call, tdefs), error=tool_error, - events=transcript().events, + events=list(transcript().events), ) # return message and event @@ -152,18 +152,36 @@ async def call_tool_task(call: ToolCall) -> tuple[ChatMessageTool, ToolEvent]: ), event # call tools - results: list[tuple[ChatMessageTool, ToolEvent]] = [] + tool_messages: list[ChatMessageTool] = [] for call in message.tool_calls: + # create pending tool event and add it to the transcript + event = ToolEvent( + id=call.id, + function=call.function, + arguments=call.arguments, + view=tool_call_view(call, tdefs), + pending=True, + ) + transcript()._event(event) + + # execute the tool call task = asyncio.create_task(call_tool_task(call)) - results.append(await task) + tool_message, result_event = await task + tool_messages.append(tool_message) - # trace and fire tool events for each result - for tool_message, event in [result for result in results]: + # trace if we are tracing trace_tool_mesage(tool_message) - transcript()._event(event) + + # update the event with the results + event.set_result( + result=result_event.result, + truncated=result_event.truncated, + error=result_event.error, + events=result_event.events, + ) # return tool messages - return [result[0] for result in results] + return tool_messages else: return [] diff --git a/src/inspect_ai/model/_model.py b/src/inspect_ai/model/_model.py index 165a23563..af0cc7227 100644 --- a/src/inspect_ai/model/_model.py +++ b/src/inspect_ai/model/_model.py @@ -328,13 +328,13 @@ async def generate() -> ModelOutput: ) existing = cache_fetch(cache_entry) if isinstance(existing, ModelOutput): - await self._record_model_interaction( + self._record_model_interaction( input=input, tools=tools, tool_choice=tool_choice, config=config, - output=existing, cache="read", + output=existing, call=None, ) return existing @@ -342,6 +342,16 @@ async def generate() -> ModelOutput: # verify that model apis are allowed self.verify_model_apis() + # record the interaction before the call to generate + # (we'll update it with the results once we have them) + complete = self._record_model_interaction( + input=input, + tools=tools, + tool_choice=tool_choice, + config=config, + cache="write" if cache else None, + ) + result = await self.api.generate( input=input, tools=tools, @@ -354,16 +364,8 @@ async def generate() -> ModelOutput: output = result call = None - # write to transcript - await 
self._record_model_interaction( - input=input, - tools=tools, - tool_choice=tool_choice, - config=config, - output=output, - cache="write" if cache else None, - call=call, - ) + # complete the transcript event + complete(output, call) # record usage if output.usage: @@ -417,32 +419,50 @@ def _connection_concurrency(self, config: GenerateConfig) -> asyncio.Semaphore: key=f"Model{self.api.connection_key()}", ) - async def _record_model_interaction( + def _record_model_interaction( self, input: list[ChatMessage], tools: list[ToolInfo], tool_choice: ToolChoice, config: GenerateConfig, - output: ModelOutput, cache: Literal["read", "write"] | None, - call: ModelCall | None, - ) -> None: + output: ModelOutput | None = None, + call: ModelCall | None = None, + ) -> Callable[[ModelOutput, ModelCall | None], None]: from inspect_ai.log._transcript import ModelEvent, transcript - trace_assistant_message(input, output.choices[0].message) - - transcript()._event( - ModelEvent( - model=str(self), - input=input, - tools=tools, - tool_choice=tool_choice, - config=config, - output=output, - cache=cache, - call=call, - ) + # create event and add it to the transcript + model = str(self) + event = ModelEvent( + model=model, + input=input, + tools=tools, + tool_choice=tool_choice, + config=config, + output=output if output else ModelOutput.from_content(model, ""), + cache=cache, + call=call, + pending=output is None, ) + transcript()._event(event) + + # callable that can be used to update the interaction w/ output + def complete( + updated_output: ModelOutput, updated_call: ModelCall | None + ) -> None: + # trace + trace_assistant_message(input, updated_output.choices[0].message) + + # update event + event.output = updated_output + event.call = updated_call + event.pending = None + + # if we have output then complete it now + if output: + complete(output, call) + + return complete class ModelName: diff --git a/src/inspect_ai/model/_render.py b/src/inspect_ai/model/_render.py new file mode 100644 index 000000000..ec2be21a0 --- /dev/null +++ b/src/inspect_ai/model/_render.py @@ -0,0 +1,24 @@ +from rich.console import RenderableType + +from inspect_ai._util.format import format_function_call +from inspect_ai._util.transcript import transcript_markdown +from inspect_ai.tool._tool_call import ToolCall + +from ._chat_message import ChatMessage, ChatMessageAssistant, ChatMessageTool + + +def messages_preceding_assistant(messages: list[ChatMessage]) -> list[ChatMessage]: + preceding: list[ChatMessage] = [] + for m in reversed(messages): + if not isinstance(m, ChatMessageTool | ChatMessageAssistant): + preceding.append(m) + else: + break + return list(reversed(preceding)) + + +def render_tool_calls(tool_calls: list[ToolCall]) -> RenderableType: + formatted_calls: list[str] = [] + for call in tool_calls: + formatted_calls.append(format_function_call(call.function, call.arguments)) + return transcript_markdown("```python\n" + "\n\n".join(formatted_calls) + "\n```\n") diff --git a/src/inspect_ai/model/_trace.py b/src/inspect_ai/model/_trace.py index f1ccc8c98..319545158 100644 --- a/src/inspect_ai/model/_trace.py +++ b/src/inspect_ai/model/_trace.py @@ -1,11 +1,12 @@ from rich.console import RenderableType -from rich.markdown import Markdown from rich.text import Text -from inspect_ai._util.format import format_function_call +from inspect_ai._util.rich import lines_display +from inspect_ai._util.transcript import transcript_markdown from inspect_ai.util._trace import trace_enabled, trace_panel from ._chat_message import 
ChatMessage, ChatMessageAssistant, ChatMessageTool +from ._render import messages_preceding_assistant, render_tool_calls MESSAGE_TITLE = "Message" @@ -13,19 +14,8 @@ def trace_tool_mesage(message: ChatMessageTool) -> None: if trace_enabled(): # truncate output to 100 lines - MAX_LINES = 100 output = message.error.message if message.error else message.text.strip() - lines = output.splitlines() - if len(lines) > MAX_LINES: - content: list[RenderableType] = ["\n".join(lines[0:MAX_LINES])] - content.append(Text()) - content.append( - Text.from_markup( - f"[italic]Output truncated ({len(lines) - MAX_LINES} additional lines)...[/italic]" - ) - ) - else: - content = [output] + content = lines_display(output, 100) trace_panel( title=f"Tool Output: {message.function}", @@ -38,31 +28,21 @@ def trace_assistant_message( ) -> None: if trace_enabled(): # print precding messages that aren't tool or assistant - preceding: list[ChatMessage] = [] - for m in reversed(input): - if not isinstance(m, ChatMessageTool | ChatMessageAssistant): - preceding.append(m) - else: - break - for m in reversed(preceding): + for m in messages_preceding_assistant(input): trace_panel( title=m.role.capitalize(), - content=m.text, + content=transcript_markdown(m.text), ) # start with assistant content - content: list[RenderableType] = [message.text] if message.text else [] + content: list[RenderableType] = ( + [transcript_markdown(message.text)] if message.text else [] + ) # print tool calls if message.tool_calls: - if content: - content.append(Text()) - tool_calls: list[str] = [] - for call in message.tool_calls: - tool_calls.append(format_function_call(call.function, call.arguments)) - content.append( - Markdown("```python\n" + "\n\n".join(tool_calls) + "\n```\n"), - ) + content.append(Text()) + content.append(render_tool_calls(message.tool_calls)) # print the assistant message trace_panel(title="Assistant", content=content) diff --git a/src/inspect_ai/solver/_task_state.py b/src/inspect_ai/solver/_task_state.py index 30dd40356..c64650585 100644 --- a/src/inspect_ai/solver/_task_state.py +++ b/src/inspect_ai/solver/_task_state.py @@ -297,8 +297,8 @@ def completed(self) -> bool: transcript()._event( SampleLimitEvent( type="message", - limit=self.message_limit, message=f"Sample completed: exceeded message limit ({self.message_limit})", + limit=self.message_limit, ) ) return True @@ -309,8 +309,8 @@ def completed(self) -> bool: transcript()._event( SampleLimitEvent( type="token", - limit=self.token_limit, message=f"Sample completed: exceeded token limit ({self.token_limit:,})", + limit=self.token_limit, ) ) return True diff --git a/src/inspect_ai/util/_concurrency.py b/src/inspect_ai/util/_concurrency.py index ddfbd6c2e..6d3eb5915 100644 --- a/src/inspect_ai/util/_concurrency.py +++ b/src/inspect_ai/util/_concurrency.py @@ -1,5 +1,4 @@ import asyncio -from contextvars import ContextVar from dataclasses import dataclass @@ -41,12 +40,12 @@ def concurrency( key = key if key else name # do we have an existing semaphore? 
if not create one and store it - semaphore = _concurrency_semaphores.get().get(key, None) + semaphore = _concurrency_semaphores.get(key, None) if semaphore is None: semaphore = ConcurencySempahore( name, concurrency, asyncio.Semaphore(concurrency) ) - _concurrency_semaphores.get()[key] = semaphore + _concurrency_semaphores[key] = semaphore # return the semaphore return semaphore.semaphore @@ -54,13 +53,13 @@ def concurrency( def concurrency_status() -> dict[str, tuple[int, int]]: status: dict[str, tuple[int, int]] = {} - for c in _concurrency_semaphores.get().values(): + for c in _concurrency_semaphores.values(): status[c.name] = (c.concurrency - c.semaphore._value, c.concurrency) return status def init_concurrency() -> None: - _concurrency_semaphores.set({}) + _concurrency_semaphores.clear() @dataclass @@ -70,6 +69,4 @@ class ConcurencySempahore: semaphore: asyncio.Semaphore -_concurrency_semaphores: ContextVar[dict[str, ConcurencySempahore]] = ContextVar( - "concurrency_semaphores", default={} -) +_concurrency_semaphores: dict[str, ConcurencySempahore] = {} diff --git a/src/inspect_ai/util/_console.py b/src/inspect_ai/util/_console.py index 4bd3c7238..7b80ae7d6 100644 --- a/src/inspect_ai/util/_console.py +++ b/src/inspect_ai/util/_console.py @@ -33,7 +33,7 @@ def input_screen( Returns: Console to use for input. """ - from inspect_ai._display._display import task_screen + from inspect_ai._display.core.active import task_screen with task_screen().input_screen( header=header, transient=transient, width=width diff --git a/src/inspect_ai/util/_sandbox/docker/compose.py b/src/inspect_ai/util/_sandbox/docker/compose.py index 66b5be35f..0ea7f99dd 100644 --- a/src/inspect_ai/util/_sandbox/docker/compose.py +++ b/src/inspect_ai/util/_sandbox/docker/compose.py @@ -8,7 +8,7 @@ import yaml from pydantic import BaseModel -from inspect_ai._util.ansi import no_ansi +from inspect_ai._util.display import display_type from inspect_ai._util.error import PrerequisiteError from inspect_ai.util._subprocess import ExecResult, subprocess @@ -250,11 +250,15 @@ async def compose_command( env = project.env if (project.env and forward_env) else {} # ansi (apply global override) - if no_ansi(): + if display_type() == "plain": ansi = "never" if ansi: compose_command = compose_command + ["--ansi", ansi] + # quiet if display is none + if display_type() == "none": + compose_command = compose_command + ["--progress", "quiet"] + # add project scope compose_command = compose_command + ["--project-name", project.name] diff --git a/src/inspect_ai/util/_sandbox/docker/internal.py b/src/inspect_ai/util/_sandbox/docker/internal.py index 816d12c42..3d5237eb4 100644 --- a/src/inspect_ai/util/_sandbox/docker/internal.py +++ b/src/inspect_ai/util/_sandbox/docker/internal.py @@ -1,5 +1,5 @@ -from inspect_ai._util.ansi import no_ansi from inspect_ai._util.constants import PKG_PATH +from inspect_ai._util.display import display_type from inspect_ai._util.error import PrerequisiteError from inspect_ai.util._subprocess import subprocess @@ -24,16 +24,18 @@ async def is_internal_image_built(image: str) -> bool: async def build_internal_image(image: str) -> None: + args = [ + "docker", + "build", + "--tag", + image, + "--progress", + "plain" if display_type() == "plain" else "auto", + ] + if display_type() == "none": + args.append("--quiet") result = await subprocess( - [ - "docker", - "build", - "--tag", - image, - "--progress", - "plain" if no_ansi() else "auto", - INTERNAL_IMAGES[image].as_posix(), - ], + args + 
[INTERNAL_IMAGES[image].as_posix()], capture_output=False, ) if not result.success: diff --git a/src/inspect_ai/util/_subtask.py b/src/inspect_ai/util/_subtask.py index 4a241f813..68024ad36 100644 --- a/src/inspect_ai/util/_subtask.py +++ b/src/inspect_ai/util/_subtask.py @@ -125,7 +125,7 @@ async def run() -> tuple[RT, SubtaskEvent]: name=subtask_name, input=log_input, result=result, - events=transcript().events, + events=list(transcript().events), type=type, ) @@ -154,11 +154,13 @@ def wrapper(func: Subtask) -> Subtask: return create_subtask_wrapper(name) -def init_subtask(name: str, store: Store) -> None: +def init_subtask(name: str, store: Store) -> Any: from inspect_ai.log._transcript import ( Transcript, init_transcript, ) init_subtask_store(store) - init_transcript(Transcript(name=name)) + transcript = Transcript(name=name) + init_transcript(transcript) + return transcript diff --git a/src/inspect_ai/util/_trace.py b/src/inspect_ai/util/_trace.py index e03058f4b..aa0a80cf5 100644 --- a/src/inspect_ai/util/_trace.py +++ b/src/inspect_ai/util/_trace.py @@ -1,11 +1,11 @@ from contextvars import ContextVar from rich import print -from rich.console import Group, RenderableType -from rich.markdown import Markdown -from rich.panel import Panel +from rich.console import RenderableType from rich.text import Text +from inspect_ai._util.transcript import transcript_panel + def trace_enabled() -> bool: """Is trace mode currently enabled.""" @@ -29,39 +29,11 @@ def trace_panel( content (RenderableType | list[RenderableType]): One or more Rich renderables. """ print( - TracePanel(title, subtitle, content), + transcript_panel(title, subtitle, content), Text(), ) -class TracePanel(Panel): - def __init__( - self, - title: str, - subtitle: str | None = None, - content: RenderableType | list[RenderableType] = [], - ) -> None: - # resolve content to list - content = content if isinstance(content, list) else [content] - - # inject subtitle - if subtitle: - content.insert(0, Text()) - content.insert(0, Text.from_markup(f"[bold]{subtitle}[/bold]")) - - # use vs theme for markdown code - for c in content: - if isinstance(c, Markdown): - c.code_theme = "xcode" - - super().__init__( - Group(*content), - title=f"[bold][blue]{title}[/blue][/bold]", - highlight=True, - expand=True, - ) - - def init_trace(trace: bool | None) -> None: _trace.set(trace) diff --git a/tests/util/test_images.py b/tests/util/test_images.py index 357e929d7..15cc6907c 100644 --- a/tests/util/test_images.py +++ b/tests/util/test_images.py @@ -40,7 +40,7 @@ def test_google_images(): @skip_if_no_openai def test_openai_images(): - check_images("openai/gpt-4") + check_images("openai/gpt-4o") @skip_if_no_vertex diff --git a/tools/vscode/src/@types/log.d.ts b/tools/vscode/src/@types/log.d.ts index a91bb9815..12cce32fd 100644 --- a/tools/vscode/src/@types/log.d.ts +++ b/tools/vscode/src/@types/log.d.ts @@ -187,6 +187,7 @@ export type Answer = string | null; export type Explanation = string | null; export type Metadata5 = {} | null; export type Timestamp = string; +export type Pending = boolean | null; export type Event = "sample_init"; export type Input1 = | string @@ -206,20 +207,24 @@ export type Files1 = { export type Setup1 = string | null; export type JsonValue = unknown; export type Timestamp1 = string; +export type Pending1 = boolean | null; export type Event1 = "sample_limit"; export type Type5 = "message" | "time" | "token" | "operator"; +export type Message2 = string; export type Limit1 = number | null; -export type Message2 = string 
| null; export type Timestamp2 = string; +export type Pending2 = boolean | null; export type Event2 = "state"; export type Op = "remove" | "add" | "replace" | "move" | "test" | "copy"; export type Path = string; export type From = string | null; export type Changes = JsonChange[]; export type Timestamp3 = string; +export type Pending3 = boolean | null; export type Event3 = "store"; export type Changes1 = JsonChange[]; export type Timestamp4 = string; +export type Pending4 = boolean | null; export type Event4 = "model"; export type Model2 = string; export type Input2 = ( @@ -253,16 +258,18 @@ export type ToolChoice = ("auto" | "any" | "none") | ToolFunction; export type Name6 = string; export type Cache = ("read" | "write") | null; export type Timestamp5 = string; +export type Pending5 = boolean | null; export type Event5 = "tool"; export type Type8 = "function"; export type Id3 = string; export type Function2 = string; -export type Result = string | number | boolean | (ContentText | ContentImage)[]; -export type Truncated = [unknown, unknown] | null; export type Title = string | null; export type Format = "text" | "markdown"; export type Content5 = string; +export type Result = string | number | boolean | (ContentText | ContentImage)[]; +export type Truncated = [unknown, unknown] | null; export type Timestamp6 = string; +export type Pending6 = boolean | null; export type Event6 = "approval"; export type Message3 = string; export type Approver = string; @@ -274,15 +281,19 @@ export type Decision = | "terminate"; export type Explanation1 = string | null; export type Timestamp7 = string; +export type Pending7 = boolean | null; export type Event7 = "input"; export type Input3 = string; export type InputAnsi = string; export type Timestamp8 = string; +export type Pending8 = boolean | null; export type Event8 = "score"; export type Target2 = string | string[] | null; export type Timestamp9 = string; +export type Pending9 = boolean | null; export type Event9 = "error"; export type Timestamp10 = string; +export type Pending10 = boolean | null; export type Event10 = "logger"; export type Name7 = string | null; export type Level = @@ -299,13 +310,16 @@ export type Filename = string; export type Module = string; export type Lineno = number; export type Timestamp11 = string; +export type Pending11 = boolean | null; export type Event11 = "info"; export type Timestamp12 = string; +export type Pending12 = boolean | null; export type Event12 = "step"; export type Action = "begin" | "end"; export type Type9 = string | null; export type Name8 = string; export type Timestamp13 = string; +export type Pending13 = boolean | null; export type Event13 = "subtask"; export type Name9 = string; export type Type10 = string | null; @@ -674,6 +688,7 @@ export interface Store {} */ export interface SampleInitEvent { timestamp: Timestamp; + pending: Pending; event: Event; sample: Sample; state: JsonValue; @@ -693,16 +708,18 @@ export interface Sample { */ export interface SampleLimitEvent { timestamp: Timestamp1; + pending: Pending1; event: Event1; type: Type5; - limit: Limit1; message: Message2; + limit: Limit1; } /** * Change to the current `TaskState` */ export interface StateEvent { timestamp: Timestamp2; + pending: Pending2; event: Event2; changes: Changes; } @@ -725,6 +742,7 @@ export interface JsonChange { */ export interface StoreEvent { timestamp: Timestamp3; + pending: Pending3; event: Event3; changes: Changes1; } @@ -733,6 +751,7 @@ export interface StoreEvent { */ export interface ModelEvent { timestamp: 
Timestamp4; + pending: Pending4; event: Event4; model: Model2; input: Input2; @@ -849,14 +868,15 @@ export interface Response { */ export interface ToolEvent { timestamp: Timestamp5; + pending: Pending5; event: Event5; type: Type8; id: Id3; function: Function2; arguments: Arguments1; + view: ToolCallContent | null; result: Result; truncated: Truncated; - view: ToolCallContent | null; error: ToolCallError | null; events: Events1; } @@ -876,6 +896,7 @@ export interface ToolCallContent { */ export interface ApprovalEvent { timestamp: Timestamp6; + pending: Pending6; event: Event6; message: Message3; call: ToolCall; @@ -900,6 +921,7 @@ export interface ToolCallView { */ export interface InputEvent { timestamp: Timestamp7; + pending: Pending7; event: Event7; input: Input3; input_ansi: InputAnsi; @@ -909,6 +931,7 @@ export interface InputEvent { */ export interface ScoreEvent { timestamp: Timestamp8; + pending: Pending8; event: Event8; score: Score; target: Target2; @@ -918,6 +941,7 @@ export interface ScoreEvent { */ export interface ErrorEvent { timestamp: Timestamp9; + pending: Pending9; event: Event9; error: EvalError; } @@ -926,6 +950,7 @@ export interface ErrorEvent { */ export interface LoggerEvent { timestamp: Timestamp10; + pending: Pending10; event: Event10; message: LoggingMessage; } @@ -943,6 +968,7 @@ export interface LoggingMessage { */ export interface InfoEvent { timestamp: Timestamp11; + pending: Pending11; event: Event11; data: JsonValue; } @@ -951,6 +977,7 @@ export interface InfoEvent { */ export interface StepEvent { timestamp: Timestamp12; + pending: Pending12; event: Event12; action: Action; type: Type9; @@ -961,6 +988,7 @@ export interface StepEvent { */ export interface SubtaskEvent { timestamp: Timestamp13; + pending: Pending13; event: Event13; name: Name9; type: Type10;