feat: Add llama-cpp-python model #666

Merged · 4 commits · Nov 3, 2024
Changes from 3 commits

1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -82,6 +82,7 @@
- Prune `sample_reductions` when returning eval logs with `header_only=True`.
- Improved error message for undecorated solvers.
- For simple matching scorers, only include explanation if it differs from answer.
- Add `llama-cpp-python` local model provider.

## v0.3.39 (3 October 2024)

2 changes: 1 addition & 1 deletion docs/index.qmd
@@ -76,7 +76,7 @@ $ inspect eval arc.py --model vllm/meta-llama/Llama-2-7b-chat-hf
```
:::

In addition to the model providers shown above, Inspect also supports models hosted on AWS Bedrock, Azure AI, Grok, TogetherAI, Groq, and Cloudflare, as well as local models with Ollama.
In addition to the model providers shown above, Inspect also supports models hosted on AWS Bedrock, Azure AI, Grok, TogetherAI, Groq, and Cloudflare, as well as local models with Ollama or llama-cpp-python.

## Hello, Inspect {#sec-hello-inspect}

60 changes: 32 additions & 28 deletions docs/models.qmd
@@ -21,12 +21,13 @@ Inspect has built in support for a variety of language model API providers and c
| Hugging Face | `pip install transformers` | None required |
| vLLM | `pip install vllm` | None required |
| Ollama | `pip install openai` | None required |
| llama-cpp-python | `pip install openai` | None required |
| Vertex | `pip install google-cloud-aiplatform` | None required |

: {tbl-colwidths="\[18,45,37\]"}

::: {.callout-note appearance="minimal"}
Note that some providers ([Grok](https://docs.x.ai/api/integrations#openai-sdk), [Ollama](https://github.com/ollama/ollama/blob/main/docs/openai.md) and [TogetherAI](https://docs.together.ai/docs/openai-api-compatibility)) support the OpenAI Python package as a client, which is why you need to `pip install openai` for these providers even though you aren't actually interacting with the OpenAI service when you use them.
Note that some providers ([Grok](https://docs.x.ai/api/integrations#openai-sdk), [Ollama](https://github.com/ollama/ollama/blob/main/docs/openai.md), [llama-cpp-python](https://llama-cpp-python.readthedocs.io/en/latest/server/) and [TogetherAI](https://docs.together.ai/docs/openai-api-compatibility)) support the OpenAI Python package as a client, which is why you need to `pip install openai` for these providers even though you aren't actually interacting with the OpenAI service when you use them.
:::
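
To make the note concrete, here is a minimal sketch of what "using the OpenAI Python package as a client" looks like for llama-cpp-python. The address and placeholder key mirror the defaults in the new provider further down; the model alias `default` is an assumption about what the local server serves.

```python
from openai import OpenAI

# The llama-cpp-python server exposes an OpenAI-compatible chat endpoint, so the
# stock OpenAI client works against it; no request ever reaches api.openai.com.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="llama-cpp-python")
response = client.chat.completions.create(
    model="default",  # assumption: whatever alias the local server exposes
    messages=[{"role": "user", "content": "Say hello."}],
)
print(response.choices[0].message.content)
```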

## Using Models
@@ -43,6 +44,7 @@ To select a model for use in an evaluation task you specify it using a *model na
| Hugging Face | `hf/openai-community/gpt2` | [Hugging Face Models](https://huggingface.co/models?pipeline_tag=text-generation&sort=trending) |
| vLLM | `vllm/openai-community/gpt2` | [vLLM Models](https://docs.vllm.ai/en/latest/models/supported_models.html) |
| Ollama | `ollama/llama3` | [Ollama Models](https://ollama.com/library) |
| llama-cpp-python | `llama-cpp-python/llama3` | [llama-cpp-python Models](https://llama-cpp-python.readthedocs.io/en/latest/#openai-compatible-web-server) |
| TogetherAI | `together/google/gemma-7b-it` | [TogetherAI Models](https://docs.together.ai/docs/inference-models#chat-models) |
| AWS Bedrock | `bedrock/meta.llama2-70b-chat-v1` | [AWS Bedrock Models](https://aws.amazon.com/bedrock/) |
| Azure AI | `azureai/azure-deployment-name` | [Azure AI Models](https://ai.azure.com/explore/models) |
@@ -80,19 +82,20 @@ If are using Google, Azure AI, AWS Bedrock, Hugging Face, or vLLM you should add

Each model also can use a different base URL than the default (e.g. if running through a proxy server). The base URL can be specified with the same prefix as the `API_KEY`, for example, the following are all valid base URLs:

| Provider | Environment Variable |
|-------------|-----------------------|
| OpenAI | `OPENAI_BASE_URL` |
| Anthropic | `ANTHROPIC_BASE_URL` |
| Google | `GOOGLE_BASE_URL` |
| Mistral | `MISTRAL_BASE_URL` |
| Grok | `GROK_BASE_URL` |
| TogetherAI | `TOGETHER_BASE_URL` |
| Ollama | `OLLAMA_BASE_URL` |
| AWS Bedrock | `BEDROCK_BASE_URL` |
| Azure AI | `AZUREAI_BASE_URL` |
| Groq | `GROQ_BASE_URL` |
| Cloudflare | `CLOUDFLARE_BASE_URL` |
| Provider | Environment Variable |
|------------------|-----------------------------|
| OpenAI | `OPENAI_BASE_URL` |
| Anthropic | `ANTHROPIC_BASE_URL` |
| Google | `GOOGLE_BASE_URL` |
| Mistral | `MISTRAL_BASE_URL` |
| Grok | `GROK_BASE_URL` |
| TogetherAI | `TOGETHER_BASE_URL` |
| Ollama | `OLLAMA_BASE_URL` |
| llama-cpp-python | `LLAMA_CPP_PYTHON_BASE_URL` |
| AWS Bedrock | `BEDROCK_BASE_URL` |
| Azure AI | `AZUREAI_BASE_URL` |
| Groq | `GROQ_BASE_URL` |
| Cloudflare | `CLOUDFLARE_BASE_URL` |

: {tbl-colwidths="\[50,50\]"}
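
As a quick illustration of the new `LLAMA_CPP_PYTHON_BASE_URL` entry above, a short sketch of overriding the provider's endpoint before resolving a model. The address is hypothetical; with the variable unset, the provider falls back to `http://localhost:8000/v1`.

```python
import os

from inspect_ai.model import get_model

# Hypothetical address: point this at wherever your llama-cpp-python server listens.
os.environ["LLAMA_CPP_PYTHON_BASE_URL"] = "http://192.168.1.10:8080/v1"

model = get_model("llama-cpp-python/default")
```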

@@ -349,7 +352,7 @@ Similar to the Hugging Face provider, you can also use local models with the vLL
$ inspect eval popularity --model vllm/local -M model_path=./my-model
```

#### vLLM Server
#### vLLM Server {#sec-vllm-server}

vLLM provides an HTTP server that implements OpenAI’s Chat API. To use this with Inspect, use the OpenAI provider rather than the vLLM provider, setting the model base URL to point to the vLLM server rather than OpenAI. For example:
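
For comparison with the new llama-cpp-python provider, a hedged sketch of the pattern described above; the server address, placeholder key, and served model name are all assumptions.

```python
import os

from inspect_ai.model import get_model

# Point the OpenAI provider at a local vLLM server instead of api.openai.com.
os.environ["OPENAI_BASE_URL"] = "http://localhost:8080/v1"
os.environ["OPENAI_API_KEY"] = "local"  # match the --api-key the server was started with, if any

model = get_model("openai/meta-llama/Llama-2-7b-chat-hf")
```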

@@ -413,19 +416,20 @@ inspect eval popularity --model google/gemini-1.0-pro -M transport:grpc

The additional `model_args` are forwarded as follows for the various providers:

| Provider | Forwarded to |
|--------------|----------------------------------------|
| OpenAI | `AsyncOpenAI` |
| Anthropic | `AsyncAnthropic` |
| Google | `genai.configure` |
| Mistral | `Mistral` |
| Hugging Face | `AutoModelForCausalLM.from_pretrained` |
| vLLM | `SamplingParams` |
| Ollama | `AsyncOpenAI` |
| TogetherAI | `AsyncOpenAI` |
| Groq | `AsyncGroq` |
| AzureAI | Chat HTTP Post Body |
| Cloudflare | Chat HTTP Post Body |
| Provider | Forwarded to |
|------------------|----------------------------------------|
| OpenAI | `AsyncOpenAI` |
| Anthropic | `AsyncAnthropic` |
| Google | `genai.configure` |
| Mistral | `Mistral` |
| Hugging Face | `AutoModelForCausalLM.from_pretrained` |
| vLLM | `SamplingParams` |
| Ollama | `AsyncOpenAI` |
| llama-cpp-python | `AsyncOpenAI` |
| TogetherAI | `AsyncOpenAI` |
| Groq | `AsyncGroq` |
| AzureAI | Chat HTTP Post Body |
| Cloudflare | Chat HTTP Post Body |

: {tbl-colwidths="\[30,70\]"}

4 changes: 2 additions & 2 deletions src/inspect_ai/_cli/eval.py
@@ -261,13 +261,13 @@ def eval_options(func: Callable[..., Any]) -> Callable[..., click.Context]:
@click.option(
"--frequency-penalty",
type=float,
help="Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. OpenAI, Grok, Groq, and vLLM only.",
help="Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. OpenAI, Grok, Groq, llama-cpp-python and vLLM only.",
envvar="INSPECT_EVAL_FREQUENCY_PENALTY",
)
@click.option(
"--presence-penalty",
type=float,
help="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. OpenAI, Grok, Groq, and vLLM only.",
help="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. OpenAI, Grok, Groq, llama-cpp-python and vLLM only.",
envvar="INSPECT_EVAL_PRESENCE_PENALTY",
)
@click.option(
21 changes: 21 additions & 0 deletions src/inspect_ai/model/_providers/llama_cpp_python.py
@@ -0,0 +1,21 @@
from inspect_ai.model._providers.util import model_base_url

from .._generate_config import GenerateConfig
from .openai import OpenAIAPI


class LlamaCppPythonAPI(OpenAIAPI):
def __init__(
self,
model_name: str,
base_url: str | None = None,
api_key: str | None = None,
config: GenerateConfig = GenerateConfig(),
) -> None:
base_url = model_base_url(base_url, "LLAMA_CPP_PYTHON_BASE_URL")
base_url = base_url if base_url else "http://localhost:8000/v1"
if not api_key:
api_key = "llama-cpp-python"
super().__init__(
model_name=model_name, base_url=base_url, api_key=api_key, config=config
)
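
A short usage sketch for the class above, mirroring the test added later in this PR. The prompt is illustrative, and it assumes a llama-cpp-python server is already listening on the default port.

```python
import asyncio

from inspect_ai.model import ChatMessageUser, get_model


async def main() -> None:
    # "llama-cpp-python/default" resolves to LlamaCppPythonAPI via the @modelapi
    # registration in providers.py; with LLAMA_CPP_PYTHON_BASE_URL unset it targets
    # http://localhost:8000/v1 and supplies the placeholder key "llama-cpp-python".
    model = get_model("llama-cpp-python/default")
    response = await model.generate(input=[ChatMessageUser(content="Say hello.")])
    print(response.completion)


asyncio.run(main())
```
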
11 changes: 11 additions & 0 deletions src/inspect_ai/model/_providers/providers.py
@@ -198,6 +198,17 @@ def ollama() -> type[ModelAPI]:
return OllamaAPI


@modelapi(name="llama-cpp-python")
def llama_cpp_python() -> type[ModelAPI]:
# validate
validate_openai_client("llama-cpp-python API")

# in the clear
from .llama_cpp_python import LlamaCppPythonAPI

return LlamaCppPythonAPI


@modelapi(name="azureai")
def azureai() -> type[ModelAPI]:
from .azureai import AzureAIAPI
30 changes: 30 additions & 0 deletions tests/model/providers/test_llama_cpp_python.py
@@ -0,0 +1,30 @@
import pytest
from test_helpers.utils import skip_if_no_llama_cpp_python

from inspect_ai.model import (
ChatMessageUser,
GenerateConfig,
get_model,
)


@pytest.mark.asyncio
@skip_if_no_llama_cpp_python
async def test_llama_cpp_python_api() -> None:
model = get_model(
"llama-cpp-python/default",
config=GenerateConfig(
frequency_penalty=0.0,
stop_seqs=None,
max_tokens=50,
presence_penalty=0.0,
logit_bias=dict([(42, 10), (43, -10)]),
seed=None,
temperature=0.0,
top_p=1.0,
),
)

message = ChatMessageUser(content="This is a test string. What are you?")
response = await model.generate(input=[message])
assert len(response.completion) >= 1
13 changes: 13 additions & 0 deletions tests/model/test_logprobs.py
@@ -3,6 +3,7 @@
skip_if_github_action,
skip_if_no_accelerate,
skip_if_no_grok,
skip_if_no_llama_cpp_python,
skip_if_no_openai,
skip_if_no_together,
skip_if_no_transformers,
@@ -84,3 +85,15 @@ async def test_vllm_logprobs() -> None:
and response.choices[0].logprobs.content[0].top_logprobs is not None
)
assert len(response.choices[0].logprobs.content[0].top_logprobs) == 2


@pytest.mark.asyncio
@skip_if_github_action
@skip_if_no_llama_cpp_python
async def test_llama_cpp_python_logprobs() -> None:
response = await generate_with_logprobs("llama-cpp-python/default")
assert (
response.choices[0].logprobs
and response.choices[0].logprobs.content[0].top_logprobs is not None
)
assert len(response.choices[0].logprobs.content[0].top_logprobs) == 2
6 changes: 6 additions & 0 deletions tests/test_helpers/utils.py
@@ -84,6 +84,12 @@ def skip_if_no_azureai(func):
return pytest.mark.api(skip_if_env_var("AZURE_API_KEY", exists=False)(func))


def skip_if_no_llama_cpp_python(func):
return pytest.mark.api(
skip_if_env_var("ENABLE_LLAMA_CPP_PYTHON_TESTS", exists=False)(func)
)
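
To run the gated tests locally, a sketch that sets the flag in-process before invoking pytest; the assumption is that the helper only checks that the variable exists, so any value enables the tests.

```python
import os

import pytest

# Enable the llama-cpp-python provider tests, then run just that test module.
os.environ["ENABLE_LLAMA_CPP_PYTHON_TESTS"] = "1"
pytest.main(["tests/model/providers/test_llama_cpp_python.py", "-v"])
```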


def skip_if_no_vertex(func):
return pytest.mark.api(skip_if_env_var("ENABLE_VERTEX_TESTS", exists=False)(func))

@@ -56,7 +56,8 @@ const kInspectModels: Record<string, string> = {
"bedrock": "0.3.8",
"ollama": "0.3.9",
"azureai": "0.3.8",
"cf": "0.3.8"
"cf": "0.3.8",
"llama-cpp-python": "0.3.39"
};

const inspectModels = () => {