Skip to content

Commit

Permalink
resolve tmux issues in fullscreen display (#876)
Browse files Browse the repository at this point in the history
* initial work on tmux fixes for textual/realtime

* explicitly enumerate known_fields for approval policy

work w/ more versions of pydantic

* improved theming across terminal environments

* ruff check

* fix openai image test

---------

Co-authored-by: aisi-inspect <[email protected]>
  • Loading branch information
jjallaire and aisi-inspect authored Nov 21, 2024
1 parent b62005b commit 0f342f7
Show file tree
Hide file tree
Showing 62 changed files with 3,738 additions and 1,229 deletions.
2 changes: 1 addition & 1 deletion docs/tutorial.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ computer security and provide a short response in a few words.

### Eval {.unlisted}

Discerning whether the correct security guidance was provided by the model might provide difficult using only text matching algorithms. Here we use a model to read the response and assess the quality of the answer.
Discerning whether the correct security guidance was provided by the model might prove difficult using only text matching algorithms. Here we use a model to read the response and assess the quality of the answer.

```{python}
@task
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ dev = [
"pytest-dotenv",
"pytest-xdist",
"ruff==0.7.4", # match version specified in .pre-commit-config.yaml
"textual-dev>=0.86.2",
"types-PyYAML",
"types-aiofiles",
"types-beautifulsoup4",
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,6 @@ s3fs>=2023
semver>=3.0.0
shortuuid
tenacity
textual>=0.86.2
typing_extensions>=4.9.0
zipp>=3.19.1 # not directly required, pinned by Snyk to avoid a vulnerability
2 changes: 1 addition & 1 deletion src/inspect_ai/_cli/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from rich import print
from rich.table import Table

from inspect_ai._display.logger import init_logger
from inspect_ai._util.logger import init_logger
from inspect_ai.model import (
ModelName,
cache_clear,
Expand Down
21 changes: 17 additions & 4 deletions src/inspect_ai/_cli/common.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
import functools
import os
from typing import Any, Callable, cast
from typing import Any, Callable, Literal, cast

import click
from typing_extensions import TypedDict

from inspect_ai._util.constants import (
ALL_LOG_LEVELS,
DEFAULT_DISPLAY,
DEFAULT_LOG_LEVEL,
DEFAULT_LOG_LEVEL_TRANSCRIPT,
)
from inspect_ai._util.display import init_display_type


class CommonOptions(TypedDict):
log_level: str
log_level_transcript: str
log_dir: str
display: Literal["full", "rich", "plain", "none"]
no_ansi: bool | None
debug: bool
debug_port: int
Expand Down Expand Up @@ -60,10 +62,18 @@ def common_options(func: Callable[..., Any]) -> Callable[..., click.Context]:
envvar="INSPECT_LOG_DIR",
help="Directory for log files.",
)
@click.option(
"--display",
type=click.Choice(["full", "rich", "plain", "none"], case_sensitive=False),
default=DEFAULT_DISPLAY,
envvar="INSPECT_DISPLAY",
help="Set the display type (defaults to 'full')",
)
@click.option(
"--no-ansi",
type=bool,
is_flag=True,
hidden=True,
help="Do not print ANSI control characters.",
envvar="INSPECT_NO_ANSI",
)
Expand Down Expand Up @@ -91,9 +101,12 @@ def wrapper(*args: Any, **kwargs: Any) -> click.Context:


def process_common_options(options: CommonOptions) -> None:
# disable ansi if requested
# propagate display
if options["no_ansi"]:
os.environ["INSPECT_NO_ANSI"] = "1"
display = "plain"
else:
display = options["display"].lower().strip()
init_display_type(display)

# attach debugger if requested
if options["debug"]:
Expand Down
2 changes: 1 addition & 1 deletion src/inspect_ai/_cli/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ async def score(
log_level_transcript: str | None,
) -> None:
# init eval context
init_eval_context(None, log_level, log_level_transcript)
init_eval_context(log_level, log_level_transcript)

# read the eval log
recorder = create_recorder_for_location(log_file, log_dir)
Expand Down
29 changes: 24 additions & 5 deletions src/inspect_ai/_display/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,25 @@
from ._display import Display
from .rich import rich_display
from .core.active import display
from .core.display import (
Display,
Progress,
TaskCancelled,
TaskError,
TaskProfile,
TaskResult,
TaskScreen,
TaskSuccess,
TaskWithResult,
)


def display() -> Display:
return rich_display()
__all__ = [
"display",
"Display",
"Progress",
"TaskCancelled",
"TaskError",
"TaskProfile",
"TaskResult",
"TaskScreen",
"TaskWithResult",
"TaskSuccess",
]
52 changes: 52 additions & 0 deletions src/inspect_ai/_display/core/active.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import sys
from contextvars import ContextVar

import rich

from inspect_ai._util.display import display_type
from inspect_ai.util._trace import trace_enabled

from ..rich.display import RichDisplay
from ..textual.display import TextualDisplay
from .display import Display, TaskScreen


def display() -> Display:
    """Return the shared display, creating it on the first call.

    The Textual display is selected only when all of the following hold:
    the configured display type is "full", stdout is a TTY, tracing is not
    enabled, and the rich console is not running in Jupyter. In every other
    case the Rich display is used. The choice is stored in the module-level
    `_active_display` and reused by later calls.
    """
    global _active_display

    # already resolved by an earlier call
    if _active_display is not None:
        return _active_display

    use_textual = (
        display_type() == "full"
        and sys.stdout.isatty()
        and not trace_enabled()
        and not rich.get_console().is_jupyter
    )
    _active_display = TextualDisplay() if use_textual else RichDisplay()
    return _active_display


# Cached display instance; stays None until display() assigns it on first call.
_active_display: Display | None = None


def task_screen() -> TaskScreen:
    """Return the task screen stored for the current context.

    Raises:
        RuntimeError: If no task screen is currently set.
    """
    active = _active_task_screen.get(None)
    if active is not None:
        return active
    raise RuntimeError(
        "console input function called outside of running evaluation."
    )


def init_task_screen(screen: TaskScreen) -> None:
    """Store `screen` as the task screen for the current context."""
    _active_task_screen.set(screen)


def clear_task_screen() -> None:
    """Reset the current context's task screen to None."""
    _active_task_screen.set(None)


# Context-local holder for the active TaskScreen: written by init_task_screen()
# and clear_task_screen(), read by task_screen(). Defaults to None.
_active_task_screen: ContextVar[TaskScreen | None] = ContextVar(
    "task_screen", default=None
)
43 changes: 43 additions & 0 deletions src/inspect_ai/_display/core/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from inspect_ai._util.registry import is_registry_dict

from .display import TaskProfile


def task_config(
    profile: TaskProfile, generate_config: bool = True, style: str = ""
) -> str:
    """Render a task's configuration as one comma-separated string.

    Task args, eval config and (optionally) generate config are merged in
    that order, so a later source wins when keys collide. Non-empty tags are
    added under the "tags" key as a comma-joined string.

    Args:
        profile: Task profile supplying task args, eval config, generate
            config and tags.
        generate_config: Whether to merge in the generate config.
        style: Optional markup style; when non-empty the result is wrapped
            in `[style]...[/style]` tags.

    Returns:
        "name: value" entries joined by ", ", or an empty string when there
        is nothing to show. The "limit" and "model" keys are omitted, and
        "approval" is shown as the comma-joined approver names.
    """
    # registry objects passed as task args are displayed by their name
    shown_args = {
        key: (value["name"] if is_registry_dict(value) else value)
        for key, value in dict(profile.task_args).items()
    }

    merged = shown_args | dict(profile.eval_config.model_dump(exclude_none=True))
    if generate_config:
        merged = merged | dict(profile.generate_config.model_dump(exclude_none=True))
    if profile.tags:
        merged["tags"] = ",".join(profile.tags)

    entries: list[str] = []
    for name, value in merged.items():
        if name == "approval":
            approvers = ",".join([approver["name"] for approver in value["approvers"]])
            entries.append(f"{name}: {approvers}")
        elif name not in ["limit", "model"]:
            entries.append(f"{name}: {value}")

    rendered = ", ".join(entries)
    if not rendered:
        return ""
    if style:
        return f"[{style}]{rendered}[/{style}]"
    return rendered


def task_dict(d: dict[str, str], bold_value: bool = False) -> str:
    """Format a mapping as space-separated "key: value" pairs with bold markup.

    Args:
        d: Entries to render, in iteration order.
        bold_value: When True the bold span covers both key and value;
            otherwise only the key (and its colon) is bold.

    Returns:
        The marked-up pairs joined by single spaces; "" for an empty mapping.
    """
    pairs: list[str] = []
    for key, value in d.items():
        if bold_value:
            pairs.append(f"[bold]{key}: {value}[/bold]")
        else:
            pairs.append(f"[bold]{key}:[/bold] {value}")
    return " ".join(pairs)
Original file line number Diff line number Diff line change
@@ -1,24 +1,38 @@
import abc
import contextlib
from contextvars import ContextVar
from dataclasses import dataclass
from types import TracebackType
from typing import Any, Iterator, Type, Union

from typing import (
Any,
AsyncIterator,
Coroutine,
Iterator,
Protocol,
Type,
TypeVar,
Union,
runtime_checkable,
)

import rich
from rich.console import Console

from inspect_ai.log import EvalConfig, EvalResults, EvalStats
from inspect_ai.model import GenerateConfig, ModelName


class Progress(abc.ABC):
@abc.abstractmethod
@runtime_checkable
class Progress(Protocol):
def update(self, n: int = 1) -> None: ...

@abc.abstractmethod
def complete(self) -> None: ...


@dataclass
class TaskSpec:
name: str
model: ModelName


@dataclass
class TaskProfile:
name: str
Expand Down Expand Up @@ -59,58 +73,54 @@ class TaskSuccess:
TaskResult = Union[TaskError, TaskCancelled, TaskSuccess]


@dataclass
class TaskWithResult:
profile: TaskProfile
result: TaskResult | None


TR = TypeVar("TR")


class TaskScreen(contextlib.AbstractContextManager["TaskScreen"]):
@abc.abstractmethod
def __exit__(self, *excinfo: Any) -> None:
pass

@contextlib.contextmanager
def input_screen(
self,
header: str | None = None,
transient: bool | None = None,
width: int | None = None,
) -> Iterator[Console]: ...
) -> Iterator[Console]:
yield rich.get_console()


class TaskDisplay(abc.ABC):
@abc.abstractmethod
@runtime_checkable
class TaskDisplay(Protocol):
@contextlib.contextmanager
def progress(self) -> Iterator[Progress]: ...

@abc.abstractmethod
def complete(self, result: TaskResult) -> None: ...


class Display(abc.ABC):
@abc.abstractmethod
@runtime_checkable
class Display(Protocol):
def print(self, message: str) -> None: ...

@abc.abstractmethod
@contextlib.contextmanager
def progress(self, total: int) -> Iterator[Progress]: ...

@abc.abstractmethod
@contextlib.contextmanager
def task_screen(self, total_tasks: int, parallel: bool) -> Iterator[TaskScreen]: ...
def run_task_app(self, main: Coroutine[Any, Any, TR]) -> TR: ...

@abc.abstractmethod
@contextlib.contextmanager
def task(self, profile: TaskProfile) -> Iterator[TaskDisplay]: ...


def task_screen() -> TaskScreen:
screen = _task_screen.get(None)
if screen is None:
raise RuntimeError(
"console input function called outside of running evaluation."
)
return screen

def suspend_task_app(self) -> Iterator[None]: ...

def init_task_screen(screen: TaskScreen) -> None:
_task_screen.set(screen)
@contextlib.asynccontextmanager
async def task_screen(
self, tasks: list[TaskSpec], parallel: bool
) -> AsyncIterator[TaskScreen]:
yield TaskScreen()


def clear_task_screen() -> None:
_task_screen.set(None)


_task_screen: ContextVar[TaskScreen | None] = ContextVar("task_screen", default=None)
@contextlib.contextmanager
def task(self, profile: TaskProfile) -> Iterator[TaskDisplay]: ...
27 changes: 27 additions & 0 deletions src/inspect_ai/_display/core/footer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from rich.console import RenderableType
from rich.text import Text

from inspect_ai._util.logger import http_rate_limit_count
from inspect_ai._util.throttle import throttle
from inspect_ai.util._concurrency import concurrency_status

from .config import task_dict


@throttle(1)
def task_footer(style: str = "") -> tuple[RenderableType, RenderableType]:
    """Build the footer as a pair: (resource usage, HTTP rate limit count).

    Both elements are rich `Text` objects parsed from markup with `style`
    applied.

    NOTE(review): `@throttle(1)` presumably limits how often this is
    recomputed -- confirm against the throttle helper.
    """
    resources = Text.from_markup(task_resources(), style=style)
    rate_limits = Text.from_markup(task_http_rate_limits(), style=style)
    return (resources, rate_limits)


def task_resources() -> str:
    """Render each concurrency_status() entry as a "first/second" pair.

    Each entry's first two elements are joined with "/" and the resulting
    mapping is formatted via task_dict().
    """
    usage: dict[str, str] = {
        model: f"{resource[0]}/{resource[1]}"
        for model, resource in concurrency_status().items()
    }
    return task_dict(usage)


def task_http_rate_limits() -> str:
    """Return the HTTP rate limit count as a label with thousands separators."""
    count = http_rate_limit_count()
    return f"HTTP rate limits: {count:,}"
Loading

0 comments on commit 0f342f7

Please sign in to comment.