From b7dac034d1461101c314482b7dceeeed9b5293cc Mon Sep 17 00:00:00 2001 From: aisi-inspect <166920645+aisi-inspect@users.noreply.github.com> Date: Wed, 1 May 2024 17:45:16 +0000 Subject: [PATCH] sync 01-05-24 --- CHANGELOG.md | 7 ++- docs/_quarto.yml | 4 +- docs/theme.scss | 9 +++- src/inspect_ai/_eval/list.py | 54 +++++++++++-------- src/inspect_ai/_util/notebook.py | 16 ++++-- src/inspect_ai/_util/registry.py | 2 +- src/inspect_ai/_view/www/App.mjs | 38 ++++++++----- src/inspect_ai/_view/www/index.html | 2 +- src/inspect_ai/_view/www/src/Constants.mjs | 2 +- .../_view/www/src/components/CopyButton.mjs | 22 ++++++++ .../_view/www/src/title/TitleBlock.mjs | 16 ++---- src/inspect_ai/_view/www/src/utils/Format.mjs | 5 ++ .../_view/www/src/workspace/SampleFilter.mjs | 22 ++++++++ .../www/src/workspace/SamplesDescriptor.mjs | 14 +++++ .../_view/www/src/workspace/WorkSpace.mjs | 10 ++-- src/inspect_ai/scorer/_metric.py | 14 ++--- src/inspect_ai/scorer/_scorer.py | 11 ++-- src/inspect_ai/solver/_solver.py | 17 +++--- tests/test_metric.py | 8 +-- 19 files changed, 178 insertions(+), 95 deletions(-) create mode 100644 src/inspect_ai/_view/www/src/components/CopyButton.mjs diff --git a/CHANGELOG.md b/CHANGELOG.md index 1c2b9487f..3f757e858 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,11 +1,14 @@ # Changelog -## v0.3.4 (Unreleased) +## v0.3.4 (01 May 2024) - `write_eval_log()` now ignores unserializable objects in metadata fields. - `read_eval_log()` now takes a `str` or `FileInfo` (for compatibility w/ list returned from `list_eval_logs()`). +- Registry name lookups are now case sensitive (fixes issue w/ loading tasks w/ mixed case names). +- Resiliency to Python syntax errors that occur when enumerating tasks in a directory. - Do not throw error if unable to parse or load `.ipynb` file due to lack of dependencies (e.g. `nbformat`). -- Several small improvements to markdown rendering in log viewer (don't render intraword underscores, escape html tags). 
+- Various additions to log viewer display (log file name, dataset/scorer in listing, filter by complex score types). +- Improvements to markdown rendering in log viewer (don't render intraword underscores, escape html tags). ## v0.3.3 (28 April 2024) diff --git a/docs/_quarto.yml b/docs/_quarto.yml index d3f4d1e66..981f22fa7 100644 --- a/docs/_quarto.yml +++ b/docs/_quarto.yml @@ -12,9 +12,9 @@ book: repo-actions: [issue] downloads: [pdf, epub, docx] twitter-card: - description: "A framework for large language model evaluations" + description: "Open-source framework for large language model evaluations" open-graph: - description: "A framework for large language model evaluations" + description: "Open-source framework for large language model evaluations" sidebar: header: > [![](images/aisi-logo.png)](https://www.gov.uk/government/organisations/ai-safety-institute) diff --git a/docs/theme.scss b/docs/theme.scss index 5ab87312d..e36b65b63 100644 --- a/docs/theme.scss +++ b/docs/theme.scss @@ -38,4 +38,11 @@ .splash ul { padding-inline-start: 1rem; -} \ No newline at end of file +} + +@media(max-width: 991.98px) { + .sidebar-header-item .img-fluid { + max-width: 195px; + } +} + diff --git a/src/inspect_ai/_eval/list.py b/src/inspect_ai/_eval/list.py index ae1b1a950..aab23166f 100644 --- a/src/inspect_ai/_eval/list.py +++ b/src/inspect_ai/_eval/list.py @@ -248,17 +248,21 @@ def exec_filter(cells: list[str]) -> bool: def code_has_task(code: str) -> bool: - tree = ast.parse(code) - for node in ast.iter_child_nodes(tree): - if isinstance(node, ast.FunctionDef): - for decorator in node.decorator_list: - if isinstance(decorator, ast.Name): - if str(decorator.id) == "task": - return True - elif isinstance(decorator, ast.Call): - if isinstance(decorator.func, ast.Name): - if str(decorator.func.id) == "task": + try: + tree = ast.parse(code) + for node in ast.iter_child_nodes(tree): + if isinstance(node, ast.FunctionDef): + for decorator in node.decorator_list: + if 
isinstance(decorator, ast.Name): + if str(decorator.id) == "task": return True + elif isinstance(decorator, ast.Call): + if isinstance(decorator.func, ast.Name): + if str(decorator.func.id) == "task": + return True + except SyntaxError: + pass + return False @@ -283,20 +287,24 @@ def parse_tasks(path: Path, root_dir: Path, absolute: bool) -> list[TaskInfo]: # parse the top level tasks out of the code tasks: list[TaskInfo] = [] - tree = ast.parse(code) - for node in ast.iter_child_nodes(tree): - if isinstance(node, ast.FunctionDef): - for decorator in node.decorator_list: - result = parse_decorator(node, decorator) - if result: - name, attribs = result - tasks.append( - TaskInfo( - file=task_path(path, root_dir, absolute), - name=name, - attribs=attribs, + try: + tree = ast.parse(code) + for node in ast.iter_child_nodes(tree): + if isinstance(node, ast.FunctionDef): + for decorator in node.decorator_list: + result = parse_decorator(node, decorator) + if result: + name, attribs = result + tasks.append( + TaskInfo( + file=task_path(path, root_dir, absolute), + name=name, + attribs=attribs, + ) ) - ) + except SyntaxError: + pass + return tasks diff --git a/src/inspect_ai/_util/notebook.py b/src/inspect_ai/_util/notebook.py index 9deea38bf..2a9305e85 100644 --- a/src/inspect_ai/_util/notebook.py +++ b/src/inspect_ai/_util/notebook.py @@ -6,7 +6,8 @@ from IPython import get_ipython # type: ignore from IPython.core.interactiveshell import InteractiveShell -from nbformat import read +from nbformat import NBFormatError, ValidationError, read +from nbformat.reader import NotJSONError # from https://jupyter-notebook.readthedocs.io/en/stable/examples/Notebook/Importing%20Notebooks.html @@ -64,9 +65,16 @@ def load_module(self, fullname: str) -> types.ModuleType: def read_notebook_code(path: Path) -> str: - # load the notebook object - with io.open(path, "r", encoding="utf-8") as f: - nb = read(f, 4) # type: ignore + try: + # load the notebook object + with io.open(path, "r", 
encoding="utf-8") as f: + nb = read(f, 4) # type: ignore + except NotJSONError: + return "" + except ValidationError: + return "" + except NBFormatError: + return "" # for dealing w/ magics shell = InteractiveShell.instance() diff --git a/src/inspect_ai/_util/registry.py b/src/inspect_ai/_util/registry.py index 0ad6aeae7..6d2c3ea36 100644 --- a/src/inspect_ai/_util/registry.py +++ b/src/inspect_ai/_util/registry.py @@ -93,7 +93,7 @@ def registry_name(o: object, name: str) -> str: and if it is, preprends the package name as a namespace """ package = get_package_name(o) - return (f"{package}/{name}" if package else name).lower() + return f"{package}/{name}" if package else name def registry_lookup(type: RegistryType, name: str) -> object | None: diff --git a/src/inspect_ai/_view/www/App.mjs b/src/inspect_ai/_view/www/App.mjs index e7e6a4ca2..3aa285421 100644 --- a/src/inspect_ai/_view/www/App.mjs +++ b/src/inspect_ai/_view/www/App.mjs @@ -10,6 +10,7 @@ import "./src/Register.mjs"; import { icons } from "./src/Constants.mjs"; import { WorkSpace } from "./src/workspace/WorkSpace.mjs"; import { eval_log } from "./api.mjs"; +import { CopyButton } from "./src/components/CopyButton.mjs"; export function App() { const [selected, setSelected] = useState(0); @@ -89,7 +90,7 @@ export function App() { const appEnvelope = fullScreen ? "" : html` - <${Header} logs=${logs} offcanvas=${offcanvas} /> + <${Header} logs=${logs} selected=${selected} offcanvas=${offcanvas} /> <${Sidebar} logs=${logs} logHeaders=${logHeaders} @@ -123,6 +124,14 @@ export function App() { const Header = (props) => { const toggleOffCanClass = props.offcanvas ? "" : " d-md-none"; const gearOffCanClass = props.offcanvas ? "" : " d-md-inline"; + + const logFiles = props.logs.files || []; + const logSelected = props.selected || 0; + const logUri = logFiles.length > logSelected ? 
logFiles[logSelected].name : ""; + const logName =logUri.split('/').pop(); + + + return html` `; @@ -218,6 +227,9 @@ const Sidebar = (props) => { ...logHeader.eval?.task_args, } : undefined; + const model = logHeader?.eval?.model; + const dataset = logHeader?.eval?.dataset; + const scorer = logHeader?.results?.scorer?.name; return html`
  • {
    @@ -249,11 +261,7 @@ const Sidebar = (props) => { })} - ${logHeader?.eval?.model - ? html`
    - ${logHeader?.eval.model} -
    ` - : ""} + ${model ? html`
    ${model}
    `: ""}
    ${logHeader?.results?.metrics ? html`
    @@ -274,7 +282,7 @@ const Sidebar = (props) => { > ${logHeader?.results.metrics[metric].name}
    -
    +
    ${formatPrettyDecimal( logHeader?.results.metrics[metric].value )} @@ -286,13 +294,17 @@ const Sidebar = (props) => {
    ` : logHeader?.status === "error" ? html`
    Eval Error
    ` : ""}
    - +
    + ${ hyperparameters ? Object.keys((hyperparameters)).map((key) => { return `${key}: ${hyperparameters[key]}` }).join(", ") : "" } +
    + ${dataset || scorer ? html`
    dataset: ${dataset.name || "(samples)"}scorer: ${scorer}
    ` : ""} +
  • `; })} diff --git a/src/inspect_ai/_view/www/index.html b/src/inspect_ai/_view/www/index.html index dbb006294..f8231be56 100644 --- a/src/inspect_ai/_view/www/index.html +++ b/src/inspect_ai/_view/www/index.html @@ -55,7 +55,7 @@ diff --git a/src/inspect_ai/_view/www/src/Constants.mjs b/src/inspect_ai/_view/www/src/Constants.mjs index 2803b6391..8e52fb6ef 100644 --- a/src/inspect_ai/_view/www/src/Constants.mjs +++ b/src/inspect_ai/_view/www/src/Constants.mjs @@ -9,7 +9,7 @@ export const icons = { "close": "bi bi-x", config: "bi bi-gear", confirm: "bi bi-check", - copy: "bi bi-clipboard", + copy: "bi bi-copy", epoch: (epoch) => { return `bi bi-${epoch}-circle`; diff --git a/src/inspect_ai/_view/www/src/components/CopyButton.mjs b/src/inspect_ai/_view/www/src/components/CopyButton.mjs new file mode 100644 index 000000000..fee02c27a --- /dev/null +++ b/src/inspect_ai/_view/www/src/components/CopyButton.mjs @@ -0,0 +1,22 @@ +import { html } from "htm/preact"; +import { icons } from "../Constants.mjs"; + +export const CopyButton = ({ value }) => { + return html``; +}; diff --git a/src/inspect_ai/_view/www/src/title/TitleBlock.mjs b/src/inspect_ai/_view/www/src/title/TitleBlock.mjs index 3ccde92a7..b9e09ef56 100644 --- a/src/inspect_ai/_view/www/src/title/TitleBlock.mjs +++ b/src/inspect_ai/_view/www/src/title/TitleBlock.mjs @@ -2,7 +2,7 @@ import { html } from "htm/preact"; import { icons } from "../Constants.mjs"; import { LabeledValue } from "../components/LabeledValue.mjs"; -import { formatPrettyDecimal } from "../utils/Format.mjs"; +import { formatPrettyDecimal, formatDataset } from "../utils/Format.mjs"; export const TitleBlock = ({ title, @@ -160,20 +160,12 @@ const DatasetSummary = ({ dataset, samples, epochs, style }) => { return ""; } - const sampleCount = epochs > 0 ? samples.length / epochs : samples; - console - return html`
    ${dataset.name}${samples?.length - ? html` - ${dataset.name ? "— " : ""}${sampleCount + " "}${epochs > 1 - ? `x ${epochs} ` - : ""} - ${samples.length === 1 ? "sample" : "samples"}` + ? html` + ${formatDataset(dataset.name, samples.length, epochs)} + ` : ""}
    `; diff --git a/src/inspect_ai/_view/www/src/utils/Format.mjs b/src/inspect_ai/_view/www/src/utils/Format.mjs index ea2d7e4dd..83d3abf13 100644 --- a/src/inspect_ai/_view/www/src/utils/Format.mjs +++ b/src/inspect_ai/_view/www/src/utils/Format.mjs @@ -50,6 +50,11 @@ export const answerForSample = (sample) => { } }; +export const formatDataset = (name, samples, epochs) => { + const perEpochSamples = epochs > 0 ? samples / epochs : samples; + return `${name ? "— " : ""}${perEpochSamples + " "}${epochs > 1 ? `x ${epochs} ` : ""}${samples === 1 ? "sample" : "samples"}`; +} + export const userPromptForSample = (sample) => { if (sample) { if (typeof (sample.input) == "string") { diff --git a/src/inspect_ai/_view/www/src/workspace/SampleFilter.mjs b/src/inspect_ai/_view/www/src/workspace/SampleFilter.mjs index 51c19ed77..bdb5860f3 100644 --- a/src/inspect_ai/_view/www/src/workspace/SampleFilter.mjs +++ b/src/inspect_ai/_view/www/src/workspace/SampleFilter.mjs @@ -4,6 +4,7 @@ import { isNumeric } from "../utils/Type.mjs"; import { kScoreTypeCategorical, kScoreTypeNumeric, + kScoreTypeObject, kScoreTypePassFail, } from "./SamplesDescriptor.mjs"; @@ -21,6 +22,8 @@ export const SampleFilter = ({ descriptor, filter, filterChanged }) => { filterFn: (sample, value) => { if (typeof sample.score.value === "string") { return sample.score.value.toLowerCase() === value?.toLowerCase(); + } else if (typeof sample.score.value === "object") { + return JSON.stringify(sample.score.value) == value; } else { return sample.score.value === value; } @@ -76,6 +79,25 @@ export const SampleFilter = ({ descriptor, filter, filterChanged }) => { `; } + case kScoreTypeObject: { + if (!descriptor.scoreDescriptor.categories) { + return ""; + } + const options = [{ text: "All", value: "all" }]; + options.push( + ...descriptor.scoreDescriptor.categories.map((cat) => { + return { text: cat.text, value: cat.value}; + }) + ); + + + + return html`<${SelectFilter} + options=${options} + 
filterFn=${filterCategory} + />`; + } + default: { return undefined; } diff --git a/src/inspect_ai/_view/www/src/workspace/SamplesDescriptor.mjs b/src/inspect_ai/_view/www/src/workspace/SamplesDescriptor.mjs index b3ceedc87..a330b18d2 100644 --- a/src/inspect_ai/_view/www/src/workspace/SamplesDescriptor.mjs +++ b/src/inspect_ai/_view/www/src/workspace/SamplesDescriptor.mjs @@ -116,8 +116,22 @@ const scoreCategorizers = [ { describe: (values, types) => { if (types.length !== 0 && types[0] === "object") { + + const buckets = values.map((val) => { return JSON.stringify(val); }); + const vals = new Set(buckets); + let categories = undefined; + if (vals.size < 10) { + categories = Array.from(vals).map((val) => { + return { + val, + text: val + } + }); + } + return { scoreType: kScoreTypeObject, + categories, render: (score) => { if (score === null) { return "[null]"; diff --git a/src/inspect_ai/_view/www/src/workspace/WorkSpace.mjs b/src/inspect_ai/_view/www/src/workspace/WorkSpace.mjs index 54d593487..2d88ff413 100644 --- a/src/inspect_ai/_view/www/src/workspace/WorkSpace.mjs +++ b/src/inspect_ai/_view/www/src/workspace/WorkSpace.mjs @@ -365,10 +365,6 @@ export const WorkSpace = (props) => { [state] ); - const selectTab = (event) => { - const id = event.currentTarget.id; - setSelectedTab(state, id); - }; /** * @@ -443,6 +439,11 @@ export const WorkSpace = (props) => { } }); + const selectTab = (event) => { + const id = event.currentTarget.id; + setSelectedTab(state, id); + }; + return html`<${WorkspaceDisplay} divRef=${divRef} tabs=${tabs} @@ -451,6 +452,7 @@ export const WorkSpace = (props) => { fullScreen=${props.fullScreen} offcanvas=${props.offcanvas} context=${context} + selectTab=${selectTab} afterBodyElements=${afterBodyElements} />`; }; diff --git a/src/inspect_ai/scorer/_metric.py b/src/inspect_ai/scorer/_metric.py index b213c0529..d0ee69f86 100644 --- a/src/inspect_ai/scorer/_metric.py +++ b/src/inspect_ai/scorer/_metric.py @@ -158,8 +158,7 @@ class 
Metric(Protocol): Metric value """ - def __call__(self, scores: list[Score]) -> int | float: - ... + def __call__(self, scores: list[Score]) -> int | float: ... MetricType = TypeVar("MetricType", Callable[..., Metric], type[Metric]) @@ -182,7 +181,7 @@ def metric_register(metric: MetricType, name: str = "") -> MetricType: Returns: Metric type with registry attributes. """ - metric_name = (name if name else getattr(metric, "__name__")).lower() + metric_name = name if name else getattr(metric, "__name__") registry_add(metric, RegistryInfo(type="metric", name=metric_name)) return metric @@ -204,19 +203,16 @@ def metric_create(name: str, **kwargs: Any) -> Metric: @overload -def metric(name: str) -> Callable[..., MetricType]: - ... +def metric(name: str) -> Callable[..., MetricType]: ... @overload # type: ignore -def metric(name: Callable[..., Metric]) -> Callable[..., Metric]: - ... +def metric(name: Callable[..., Metric]) -> Callable[..., Metric]: ... @overload -def metric(name: type[Metric]) -> type[Metric]: - ... +def metric(name: type[Metric]) -> type[Metric]: ... def metric(name: str | MetricType) -> Callable[..., MetricType] | MetricType: diff --git a/src/inspect_ai/scorer/_scorer.py b/src/inspect_ai/scorer/_scorer.py index c06f848fa..7644a8e61 100644 --- a/src/inspect_ai/scorer/_scorer.py +++ b/src/inspect_ai/scorer/_scorer.py @@ -34,12 +34,10 @@ def __init__(self, target: str | list[str]) -> None: self.target = target if isinstance(target, list) else [target] @overload - def __getitem__(self, index: int) -> str: - ... + def __getitem__(self, index: int) -> str: ... @overload - def __getitem__(self, index: slice) -> Sequence[str]: - ... + def __getitem__(self, index: slice) -> Sequence[str]: ... def __getitem__(self, index: Union[int, slice]) -> Union[str, Sequence[str]]: return self.target[index] @@ -64,8 +62,7 @@ class Scorer(Protocol): target (Target): Ideal target for the output. 
""" - async def __call__(self, state: TaskState, target: Target) -> Score: - ... + async def __call__(self, state: TaskState, target: Target) -> Score: ... ScorerType = TypeVar("ScorerType", Callable[..., Scorer], type[Scorer]) @@ -89,7 +86,7 @@ def scorer_register(scorer: ScorerType, name: str = "") -> ScorerType: Returns: Scorer with registry attributes. """ - scorer_name = (name if name else getattr(scorer, "__name__")).lower() + scorer_name = name if name else getattr(scorer, "__name__") registry_add(scorer, RegistryInfo(type="scorer", name=scorer_name)) return scorer diff --git a/src/inspect_ai/solver/_solver.py b/src/inspect_ai/solver/_solver.py index bcf6d0665..c59ca61a0 100644 --- a/src/inspect_ai/solver/_solver.py +++ b/src/inspect_ai/solver/_solver.py @@ -133,8 +133,7 @@ class Generate(Protocol): async def __call__( self, state: TaskState, **kwargs: Unpack[GenerateConfigArgs] - ) -> TaskState: - ... + ) -> TaskState: ... @runtime_checkable @@ -160,8 +159,7 @@ async def __call__( self, state: TaskState, generate: Generate, - ) -> TaskState: - ... + ) -> TaskState: ... SolverType = TypeVar("SolverType", Callable[..., Solver], type[Solver]) @@ -184,7 +182,7 @@ def solver_register(solver: SolverType, name: str = "") -> SolverType: Returns: Solver with registry attributes. """ - solver_name = (name if name else getattr(solver, "__name__")).lower() + solver_name = name if name else getattr(solver, "__name__") registry_add(solver, RegistryInfo(type="solver", name=solver_name)) return solver @@ -203,19 +201,16 @@ def solver_create(name: str, **kwargs: Any) -> Solver: @overload -def solver(name: str) -> Callable[..., SolverType]: - ... +def solver(name: str) -> Callable[..., SolverType]: ... @overload # type: ignore -def solver(name: Callable[..., Solver]) -> Callable[..., Solver]: - ... +def solver(name: Callable[..., Solver]) -> Callable[..., Solver]: ... @overload -def solver(name: type[Solver]) -> type[Solver]: - ... 
+def solver(name: type[Solver]) -> type[Solver]: ... def solver(name: str | SolverType) -> Callable[..., SolverType] | SolverType: diff --git a/tests/test_metric.py b/tests/test_metric.py index de96ab52c..42a4c55f5 100644 --- a/tests/test_metric.py +++ b/tests/test_metric.py @@ -51,21 +51,21 @@ def __call__(self, scores: list[Score]) -> int | float: def test_metric_registry() -> None: registry_assert(accuracy1, "accuracy1") registry_assert(acc_fn, "accuracy2") - registry_assert(Accuracy3, "accuracy3") + registry_assert(Accuracy3, "Accuracy3") registry_assert(AccuracyNamedCls, "accuracy4") def test_metric_call() -> None: registry_assert(accuracy1(), "accuracy1") registry_assert(acc_fn(), "accuracy2") - registry_assert(Accuracy3(), "accuracy3") + registry_assert(Accuracy3(), "Accuracy3") registry_assert(AccuracyNamedCls(), "accuracy4") def test_metric_create() -> None: metric_create_assert("accuracy1", correct="C") metric_create_assert("accuracy1", correct="C") - metric_create_assert("accuracy3", correct="C") + metric_create_assert("Accuracy3", correct="C") metric_create_assert("accuracy4", correct="C") @@ -84,7 +84,7 @@ def check_log(log): "accuracy", "bootstrap_std", "accuracy1", - "accuracy3", + "Accuracy3", ] )