Skip to content

Commit

Permalink
sync 11-07-24
Browse files Browse the repository at this point in the history
  • Loading branch information
aisi-inspect committed Jul 11, 2024
1 parent 44e432f commit 7f68746
Show file tree
Hide file tree
Showing 20 changed files with 231 additions and 83 deletions.
8 changes: 8 additions & 0 deletions docs/scorers.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,14 @@ Note also we use the `input_text` property of the `TaskState` to access a string

## Multiple Scorers {#sec-multiple-scorers}

::: {.callout-note appearance="simple"}
The multiple scorers feature described below is available only in the development version of Inspect (it is not yet published to PyPI). You can install the development version with:

```bash
$ pip install git+https://github.com/UKGovernmentBEIS/inspect_ai
```
:::

There are several ways to use multiple scorers in an evaluation:

1. You can provide a list of scorers in a `Task` definition (this is the best option when scorers are entirely independent)
Expand Down
4 changes: 2 additions & 2 deletions src/inspect_ai/_eval/context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from inspect_ai._util.telemetry import init_telemetry
from inspect_ai._util.hooks import init_hooks
from inspect_ai.model import Model
from inspect_ai.model._model import init_active_model, init_model_usage
from inspect_ai.util._concurrency import init_concurrency
Expand All @@ -9,7 +9,7 @@
def init_eval_context(max_subprocesses: int | None = None) -> None:
init_concurrency()
init_max_subprocesses(max_subprocesses)
init_telemetry()
init_hooks()


def init_task_context(model: Model) -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/inspect_ai/_eval/task/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
from inspect_ai._util.datetime import iso_now
from inspect_ai._util.error import exception_message
from inspect_ai._util.file import file, filesystem
from inspect_ai._util.hooks import send_telemetry
from inspect_ai._util.registry import (
is_registry_object,
registry_log_name,
)
from inspect_ai._util.telemetry import send_telemetry
from inspect_ai._util.url import data_uri_to_base64, is_data_uri
from inspect_ai._view.view import view_notify_eval
from inspect_ai.dataset import Dataset, Sample
Expand Down
118 changes: 118 additions & 0 deletions src/inspect_ai/_util/hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import importlib
import importlib.metadata
import os
from typing import Any, Awaitable, Callable, Literal, cast

from rich import print

from .constants import PKG_NAME
from .error import PrerequisiteError

# Hooks are functions inside packages that are registered via an
# environment variable (e.g. INSPECT_TELEMETRY='mypackage.send_telemetry').
# If one or more hooks are enabled, a message will be printed at startup
# indicating this, as well as which package/function implements each hook.


# Telemetry (INSPECT_TELEMETRY)
#
# Telemetry can be optionally enabled by setting an INSPECT_TELEMETRY
# environment variable that points to a function in a package which
# conforms to the TelemetrySend signature below.

# There are currently two types of telemetry sent:
# - model_usage (type ModelUsage)
# - eval_log (type EvalLog)

TelemetrySend = Callable[[str, str], Awaitable[None]]


async def send_telemetry(type: Literal["model_usage", "eval_log"], json: str) -> None:
global _send_telemetry
if _send_telemetry:
await _send_telemetry(type, json)


_send_telemetry: TelemetrySend | None = None

# API Key Override (INSPECT_API_KEY_OVERRIDE)
#
# API Key overrides can be optionally enabled by setting an
# INSPECT_API_KEY_OVERRIDE environment variable which conforms to the
# ApiKeyOverride signature below.
#
# The api key override function will be called with the name and value
# of provider specified environment variables that contain api keys,
# and it can optionally return an override value.

ApiKeyOverride = Callable[[str, str], str | None]


def override_api_key(var: str, value: str) -> str | None:
global _override_api_key
if _override_api_key:
return _override_api_key(var, value)
else:
return None


_override_api_key: ApiKeyOverride | None = None


def init_hooks() -> None:
    """Install hooks declared via environment variables.

    Reads INSPECT_TELEMETRY and INSPECT_API_KEY_OVERRIDE (each naming a
    'module.function' in an installed package) and installs any hook that
    is defined and not already installed. If at least one hook was
    installed, prints a startup banner listing the active hooks.

    Raises:
        PrerequisiteError: If a declared hook cannot be resolved
            (propagated from init_hook).
    """
    # messages we'll print for hooks if we have them
    messages: list[str] = []

    # telemetry
    global _send_telemetry
    if not _send_telemetry:
        result = init_hook(
            "telemetry",
            "INSPECT_TELEMETRY",
            "(eval logs and token usage will be recorded by the provider)",
        )
        if result:
            _send_telemetry, message = result
            messages.append(message)

    # api key override
    global _override_api_key
    if not _override_api_key:
        result = init_hook(
            "api key override",
            "INSPECT_API_KEY_OVERRIDE",
            "(api keys will be read and modified by the provider)",
        )
        if result:
            _override_api_key, message = result
            messages.append(message)

    # if any hooks are enabled, let the user know
    if len(messages) > 0:
        # NOTE: requires 'import importlib.metadata' -- 'import importlib'
        # alone does not make the metadata submodule available
        version = importlib.metadata.version(PKG_NAME)
        all_messages = "\n".join([f"- {message}" for message in messages])
        print(
            f"[blue][bold]inspect_ai v{version}[/bold][/blue]\n[bright_black]{all_messages}[/bright_black]\n"
        )


def init_hook(
name: str, env: str, message: str
) -> tuple[Callable[..., Any], str] | None:
hook = os.environ.get(env, "")
if hook:
# parse module/function
module_name, function_name = hook.strip().rsplit(".", 1)
# load (fail gracefully w/ clear error)
try:
module = importlib.import_module(module_name)
return (
cast(Callable[..., Any], getattr(module, function_name)),
f"[bold]{name} enabled: {hook}[/bold]\n {message}",
)
except (AttributeError, ModuleNotFoundError):
raise PrerequisiteError(
f"{env} provider not found: {hook}\n"
+ "Please correct (or undefine) this environment variable before proceeding.\n"
)
else:
return None
53 changes: 0 additions & 53 deletions src/inspect_ai/_util/telemetry.py

This file was deleted.

30 changes: 26 additions & 4 deletions src/inspect_ai/model/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from inspect_ai._util.constants import DEFAULT_MAX_CONNECTIONS
from inspect_ai._util.content import Content, ContentText
from inspect_ai._util.entrypoints import ensure_entry_points
from inspect_ai._util.hooks import init_hooks, override_api_key, send_telemetry
from inspect_ai._util.platform import platform_init
from inspect_ai._util.registry import (
RegistryInfo,
Expand All @@ -28,7 +29,6 @@
registry_unqualified_name,
)
from inspect_ai._util.retry import log_rate_limit_retry
from inspect_ai._util.telemetry import init_telemetry, send_telemetry
from inspect_ai.tool import Tool, ToolChoice, ToolFunction, ToolInfo
from inspect_ai.util import concurrency

Expand All @@ -54,6 +54,7 @@ def __init__(
model_name: str,
base_url: str | None = None,
api_key: str | None = None,
api_key_vars: list[str] = [],
config: GenerateConfig = GenerateConfig(),
) -> None:
"""Create a model API provider.
Expand All @@ -62,13 +63,34 @@ def __init__(
model_name (str): Model name.
base_url (str | None): Alternate base URL for model.
api_key (str | None): API key for model.
api_key_vars (list[str]): Environment variables that
may contain keys for this provider (used for override)
config (GenerateConfig): Model configuration.
"""
self.model_name = model_name
self.base_url = base_url
self.api_key = api_key
self.config = config

# apply api key override
for key in api_key_vars:
# if there is an explicit api_key passed then it
# overrides anything in the environment so use it
if api_key is not None:
override = override_api_key(key, api_key)
if override is not None:
api_key = override
# otherwise look it up in the environment and
# override it if it has a value
else:
value = os.environ.get(key, None)
if value is not None:
override = override_api_key(key, value)
if override is not None:
os.environ[key] = override

# set any explicitly specified api key
self.api_key = api_key

@abc.abstractmethod
async def generate(
self,
Expand Down Expand Up @@ -459,9 +481,9 @@ def match_modelapi_type(info: RegistryInfo) -> bool:
# find a matching model type
modelapi_types = registry_find(match_modelapi_type)
if len(modelapi_types) > 0:
# create the model (init_telemetry here in case the model api
# create the model (init_hooks here in case the model api
    # is being used as a standalone model interface outside of evals)
init_telemetry()
init_hooks()
modelapi_type = cast(type[ModelAPI], modelapi_types[0])
modelapi_instance = modelapi_type(
model_name=model,
Expand Down
6 changes: 5 additions & 1 deletion src/inspect_ai/model/_providers/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,11 @@ def __init__(
**model_args: Any,
):
super().__init__(
model_name=model_name, base_url=base_url, api_key=api_key, config=config
model_name=model_name,
base_url=base_url,
api_key=api_key,
api_key_vars=[ANTHROPIC_API_KEY],
config=config,
)

# create client
Expand Down
6 changes: 5 additions & 1 deletion src/inspect_ai/model/_providers/azureai.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,11 @@ def __init__(
**model_args: Any,
):
super().__init__(
model_name=model_name, base_url=base_url, api_key=api_key, config=config
model_name=model_name,
base_url=base_url,
api_key=api_key,
api_key_vars=[AZURE_API_KEY],
config=config,
)

# required for some deployments
Expand Down
8 changes: 7 additions & 1 deletion src/inspect_ai/model/_providers/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from .._generate_config import GenerateConfig
from .._model import ModelAPI, simple_input_messages
from .._model_output import ChatCompletionChoice, ModelOutput, ModelUsage
from .anthropic import ANTHROPIC_API_KEY
from .util import as_stop_reason, model_base_url


Expand All @@ -35,7 +36,12 @@ def __init__(
config: GenerateConfig = GenerateConfig(),
**model_args: Any,
):
super().__init__(model_name=model_name, base_url=base_url, config=config)
super().__init__(
model_name=model_name,
base_url=base_url,
api_key_vars=[ANTHROPIC_API_KEY],
config=config,
)

# we can optionally proxy to another ModelAPI
self.model_api: ModelAPI | None = None
Expand Down
15 changes: 12 additions & 3 deletions src/inspect_ai/model/_providers/cloudflare.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
# https://developers.cloudflare.com/workers-ai/models/#text-generation


CLOUDFLARE_API_TOKEN = "CLOUDFLARE_API_TOKEN"


class CloudFlareAPI(ModelAPI):
def __init__(
self,
Expand All @@ -29,15 +32,21 @@ def __init__(
**model_args: Any,
):
super().__init__(
model_name=model_name, base_url=base_url, api_key=api_key, config=config
model_name=model_name,
base_url=base_url,
api_key=api_key,
api_key_vars=[CLOUDFLARE_API_TOKEN],
config=config,
)
self.account_id = os.getenv("CLOUDFLARE_ACCOUNT_ID")
if not self.account_id:
raise RuntimeError("CLOUDFLARE_ACCOUNT_ID environment variable not set")
if not self.api_key:
self.api_key = os.getenv("CLOUDFLARE_API_TOKEN")
self.api_key = os.getenv(CLOUDFLARE_API_TOKEN)
if not self.api_key:
raise RuntimeError("CLOUDFLARE_API_TOKEN environment variable not set")
raise RuntimeError(
f"{CLOUDFLARE_API_TOKEN} environment variable not set"
)
self.client = httpx.AsyncClient()
base_url = model_base_url(base_url, "CLOUDFLARE_BASE_URL")
self.base_url = (
Expand Down
8 changes: 7 additions & 1 deletion src/inspect_ai/model/_providers/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
}

GOOGLE_API_KEY = "GOOGLE_API_KEY"


class GoogleAPI(ModelAPI):
def __init__(
Expand All @@ -70,7 +72,11 @@ def __init__(
**model_args: Any,
) -> None:
super().__init__(
model_name=model_name, base_url=base_url, api_key=api_key, config=config
model_name=model_name,
base_url=base_url,
api_key=api_key,
api_key_vars=[GOOGLE_API_KEY],
config=config,
)

# configure genai client
Expand Down
Loading

0 comments on commit 7f68746

Please sign in to comment.