Add Goodfire API Provider Support #1161

Open: wants to merge 60 commits into main

Conversation

@menhguin commented Jan 20, 2025


Overview

This PR introduces support for the Goodfire API, enabling the use of Meta's Llama models through Goodfire's inference service. The implementation provides basic chat completion functionality while maintaining compatibility with the existing evaluation framework.

Over the next few weeks, I expect to add more complex mechanistic interpretability techniques (feature search, inspect, feature steering) as shown in the Goodfire AI documentation. For now, this PR covers basic chat completion and standardisation in line with other model providers (not least because I have to keep merging new commits every day).

Critical Implementation Details

  1. Core Provider Implementation (inspect_ai/src/inspect_ai/model/_providers/goodfire.py); see the sketch after this list:

    • Implements GoodfireAPI class with synchronous API handling
    • Key methods: generate(), _to_goodfire_message(), connection management
    • Constants for defaults: DEFAULT_MAX_TOKENS=4096, DEFAULT_TEMPERATURE=0.7
    • Model mapping for supported variants in MODEL_MAP
  2. Known Limitations:

    • MMLU Few-shot Evaluation Issue:
      • Zero-shot works correctly (~0.57 accuracy)
      • Few-shot fails (~0.1 accuracy) due to strict format following; we note this stems from using Llama 3 instruct models rather than from anything Goodfire-specific
      • Model outputs bare letters instead of "Answer: A" format
      • Affects inspect_evals/src/inspect_evals/mmlu/mmlu_5_shot.py
    • Synchronous API in async framework:
      • Blocks event loop during generation
      • Affects progress bar updates
      • Located in generate() method
  3. API Differences (vs OpenAI/Anthropic):

    • Parameter naming:
      • Uses max_completion_tokens vs max_tokens
      • Different default values
    • Response handling:
      • Dictionary-based vs object-based responses
      • Manual extraction required for content/usage
      • No finish_reason field
    • Message handling:
      • Tool messages converted to user messages
      • Limited role support
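
To illustrate the shape of the provider, here is a minimal sketch. The method and constant names follow the description above; the bodies are simplified illustrations under stated assumptions, not the PR's exact code.

from typing import Any

DEFAULT_MAX_TOKENS = 4096
DEFAULT_TEMPERATURE = 0.7

def _to_goodfire_message(message: Any) -> dict[str, str]:
    """Convert an Inspect chat message to the dict format Goodfire expects."""
    role = message.role
    # Limited role support: tool messages are downgraded to user messages.
    if role == "tool":
        role = "user"
    return {"role": role, "content": message.text}

def generate_sync(client: Any, messages: list[Any], model: str) -> str:
    """Synchronous completion call; blocks the event loop when invoked from
    Inspect's async generate() (the progress bar limitation noted above)."""
    response = client.chat.completions.create(
        messages=[_to_goodfire_message(m) for m in messages],
        model=model,
        max_completion_tokens=DEFAULT_MAX_TOKENS,  # Goodfire's name for max_tokens
        temperature=DEFAULT_TEMPERATURE,
    )
    # Responses are dictionary-based (no finish_reason field), so content
    # and usage must be extracted manually; the exact shape is assumed here.
    return response["choices"][0]["message"]["content"]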

Required Configuration

  1. Environment Setup (see the usage sketch after this list):

    GOODFIRE_API_KEY=<key>
    GOODFIRE_BASE_URL=<optional>
    
    pip install goodfire
    
  2. Model Support:

    • Currently supports:
      • meta-llama/Meta-Llama-3.1-8B-Instruct
      • meta-llama/Llama-3.3-70B-Instruct
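
With the environment configured, the provider is selected like any other model provider. A usage sketch follows; the task name is illustrative (GSM8K is one of the evals tried later in this PR):

from inspect_ai import eval

eval(
    "inspect_evals/gsm8k",  # illustrative task; any Inspect task works
    model="goodfire/meta-llama/Meta-Llama-3.1-8B-Instruct",
)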

Pending Improvements (Prioritized)

  1. Critical:

    • Fix few-shot evaluation format handling
    • Implement proper async operation
    • Add progress tracking solution
  2. Important:

    • Add streaming support when available
    • Implement tool calls support
    • Enhance error handling
  3. Nice to Have:

    • Add feature analysis support
    • Expand model support
    • Add caching strategy

Testing Status

  1. Verified:

    • Basic chat completion
    • Zero-shot evaluations
    • Usage statistics collection
    • Parameter validation
  2. Known Issues:

    • Few-shot format compliance
    • Progress tracking during long runs
    • Type hints causing linter errors

Breaking Changes

None. This should not affect the use of other model providers, and effort has been taken to ensure standardisation. Code changes have been isolated to:

  • src/inspect_ai/model/_providers/goodfire.py for the core implementation
  • src/inspect_ai/model/_providers/providers.py to register Goodfire as a model provider
  • src/inspect_ai/model/_generate_config.py for certain Goodfire-specific generation functions (though do feel free to test this to make sure it doesn't affect any other providers)

So far, the model seems to generate and score similarly to vLLM-hosted Llama 8B-Instruct on GPQA, GSM8K and MMLU.

Conclusion

Once again, I will be improving and building on this initial chat generation implementation in the coming weeks with more advanced mech interp functions. If you come across issues in other evals, do let me know.

- Introduced `GoodfireConfig` dataclass for Goodfire-specific settings in `_generate_config.py`.
- Implemented `GoodfireAPI` class in a new file `_providers/goodfire.py` to handle interactions with the Goodfire API.
- Registered the Goodfire API provider in `_providers/providers.py`, including error handling for dependency imports.
- Updated `GenerateConfig` to include Goodfire configuration options.
…th version verification and improved error handling

- Added support for minimum version requirement for the Goodfire API.
- Introduced supported model literals and updated model name handling.
- Improved API key retrieval logic with environment variable checks.
- Enhanced client initialization to include base URL handling.
- Updated maximum completion tokens to 4096 for better performance.
- Refined message conversion to handle tool messages appropriately.
- Removed unsupported feature analysis configuration.

This commit improves the robustness and usability of the Goodfire API integration.
…odfireAPI generate method for improved error handling and parameter management

- Enhanced the generate method to use a try-except block for better error logging.
- Consolidated API request parameters into a dictionary for cleaner code.
- Added handling for usage statistics in the output if available.
- Improved message conversion process for better clarity and maintainability.

This update increases the robustness of the Goodfire API integration and enhances error reporting.
@menhguin (Author) commented Jan 20, 2025

OH YES almost forgot: You need to pip install goodfire as well, but I'm unsure where to add this.

@jjallaire (Collaborator) left a comment:
Fantastic! So happy to see this and excited to see it built out further. Left some feedback in the review. Some additional comments:

  1. Saw there was a note on streaming support -- we don't currently use streaming in our model interfaces, so I don't think this will be required (but perhaps there is a scenario I'm not thinking of?)

  2. Saw your note on caching -- would the built-in caching work? (We cache ModelOutput instances based on a key that hopefully reflects the full range of possible inputs.)

In terms of adding mech interp stuff, we've had initial discussions with a few others in the field on how to do this. At some point I think we'd like to define some common data structures that can go in ModelOutput but we aren't there yet. In the meantime, you should add any mech interp data to the metadata field of ModelOutput (using whatever schema you want). Later we can try to bring some of this back into something that is shared by multiple mech interp back ends.
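
For example (a hypothetical sketch; the metadata schema and values are invented for illustration):

from inspect_ai.model import ModelOutput

output = ModelOutput.from_content(
    model="goodfire/meta-llama/Meta-Llama-3.1-8B-Instruct",
    content="completion text here",
)
# Provider-specific mech interp data goes under metadata, with whatever
# schema you want, until shared data structures are defined.
output.metadata = {
    "goodfire": {
        "features": [{"label": "refusal", "activation": 0.83}],  # invented values
    }
}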

(Inline review comments, since resolved, on src/inspect_ai/model/_generate_config.py, src/inspect_ai/model/_providers/goodfire.py, and src/inspect_ai/model/_providers/providers.py.)
@jjallaire (Collaborator) commented:

> OH YES almost forgot: You need to pip install goodfire as well, but I'm unsure where to add this.

You can add this to the dev config of [project.optional-dependencies] in pyproject.toml
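
For example (a sketch; the other entries stand in for the existing contents of pyproject.toml):

[project.optional-dependencies]
dev = [
    "goodfire",
    # ...existing dev dependencies...
]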

@menhguin (Author) commented:
The proposed changes here seem reasonable. I will attempt to implement all of them by Friday morning-ish UK time.

Some of the more ... awkward design choices were me trying to patch dozens of little mismatches between Inspect and the Goodfire API (different function names, output formats, Goodfire's special functions). I tried to clean it up and standardise with the rest of Inspect, but evidently I missed a few things.

I'll try the metadata approach afterwards. Figuring out which mech interp function to allow and how is gonna be ... tricky. Do you have any reference examples where a model provider supports more than just text generation via Inspect? Even logits/logprob view might be a helpful reference.

@jjallaire (Collaborator) commented:
> The proposed changes here seem reasonable. I will attempt to implement all of them by Friday morning-ish UK time.

Great, thanks! (and feel free to ping me w/ any questions in the meantime)

> I'll try the metadata approach afterwards. Figuring out which mech interp function to allow and how is gonna be ... tricky. Do you have any reference examples where a model provider supports more than just text generation via Inspect? Even logits/logprob view might be a helpful reference.

Yes, several model providers (OpenAI, Grok, TogetherAI, Huggingface, llama-cpp-python, and vLLM) support Logprobs:

from pydantic import BaseModel, Field

class TopLogprob(BaseModel):
    """List of the most likely tokens and their log probability, at this token position."""

    token: str
    """The top-kth token represented as a string."""

    logprob: float
    """The log probability value of the model for the top-kth token."""

    bytes: list[int] | None = Field(default=None)
    """The top-kth token represented as a byte array (a list of integers)."""


class Logprob(BaseModel):
    """Log probability for a token."""

    token: str
    """The predicted token represented as a string."""

    logprob: float
    """The log probability value of the model for the predicted token."""

    bytes: list[int] | None = Field(default=None)
    """The predicted token represented as a byte array (a list of integers)."""

    top_logprobs: list[TopLogprob] | None = Field(default=None)
    """If the `top_logprobs` argument is greater than 0, this will contain an ordered list of the top K most likely tokens and their log probabilities."""


class Logprobs(BaseModel):
    """Log probability information for a completion choice."""

    content: list[Logprob]
    """a (num_generated_tokens,) length list containing the individual log probabilities for each generated token."""

Eventually I'd like to have some standard fields like this for mech interp payloads (so that readers of logs can benefit from some uniformity). Absent working out these schemas, I would put your own data structures in ModelOutput.metadata; then we can ideally learn from them and work towards standardization over time.

…e package dependency and remove GoodfireConfig class from GenerateConfig. Enhance goodfire provider with version verification.
…se runtime-safe string. Add note and TODO for potential issue in Goodfire's repo.
- Remove hardcoded MODEL_MAP and variant validation
- Directly use model name for Variant initialization
- Standardize max_tokens to use default value
- Enhance generate method to include additional model arguments
- Streamline model configuration and parameter handling
- Improve error handling by using specific Goodfire exception types
- Refactor rate limit and context length error detection
- Enhance parameter configuration with more flexible temperature and top_p handling
- Add type casting and improve type hints
- Simplify client initialization and method calls
- Simplify model argument collection and storage
- Update generate method to incorporate model arguments more flexibly
- Remove separate tracking of temperature and top_p
- Ensure all model-specific arguments are passed to generation parameters
- Update type hints for model arguments and parameters
- Improve parameter configuration in generate method
- Simplify base model selection and parameter passing
- Enhance code readability and type consistency
- Add default values for temperature and top_p when not specified in model arguments
- Prioritize model_args over config parameters
- Ensure consistent parameter configuration when generating completions
@menhguin (Author) commented Jan 25, 2025

Newest updates as of this weekend, basically improving robustness, error handling and minimising hardcoded values:

  • Replaced the synchronous API with the async one (I found it deep in the docs). This also enables progress tracking.
  • Refined the logic for passing model args, so that any model args added in the future are passed through as well.
  • Streamlined error handling to surface max tokens, invalid request, rate limit and connection/key errors specifically, since I'm told that lets Inspect do its own handling; all other errors are passed through as-is. For example, invalid model args are propagated with the specific error as expected: I tested an invalid top_p value of 1.1 and it gave the relevant error output.
│ RuntimeError:                                                                             │
│ RequestFailedException('{"detail":[{"type":"less_than_equal","loc":["body","top_p"],"msg… │
│ should be less than or equal to 1","input":1.1,"ctx":{"le":1.0}}]}') 
  • Model names are no longer hardcoded; they are simply passed to the Goodfire API. Hardcoding was needed early on, since it took a while to get model name passing right due to issues with "variant", Literal, prefixing and the actual model names being weird. This is robust to model names being edited or added, is more standardised, and eliminates the need to update model-specific params.
  • General standardisation and logic improvements to iron out hardcoded values, model args precedence, etc., so that all the new functions play nice with one another, especially when something unexpected happens. Though I may have missed something.

So text gen should be settled now? I am moving on to feature implementations this week, finally.
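
A rough sketch of the async flow and error pass-through described above (the AsyncClient usage and the exception import path are assumptions based on this thread, not verified against the goodfire package):

import goodfire

from goodfire.api.exceptions import RateLimitException  # import path assumed

async def generate_once(api_key: str, prompt: str) -> str:
    # The async client avoids blocking Inspect's event loop, which also
    # unblocks progress tracking.
    client = goodfire.AsyncClient(api_key=api_key)
    try:
        response = await client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            # Model names pass straight through; no hardcoded MODEL_MAP.
            model="meta-llama/Meta-Llama-3.1-8B-Instruct",
        )
    except RateLimitException:
        # Surfaced specifically so Inspect can apply its own handling; other
        # errors (e.g. the top_p=1.1 RequestFailedException above) propagate as-is.
        raise
    return response.choices[0].message["content"]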

@jjallaire (Collaborator) commented:
Thanks again for your diligent work here! Noted the changes, and the PR is looking good. I did a scan of the code as-is and did find a couple more things we should tweak (will post those in a new review shortly).

@jjallaire (Collaborator) left a comment:
Some additional comments (all of them quite small).

Could you also add a simple test (and related skip_if function) along these lines: https://github.com/UKGovernmentBEIS/inspect_ai/blob/main/tests/model/providers/test_groq.py

I noticed there were also some ruff errors when running the checks. You can clean this up by running make check locally.
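
A sketch of such a test, modeled on the linked test_groq.py (the skip_if_no_goodfire helper is assumed, following the pattern of skip_if_no_groq):

import pytest

from test_helpers.utils import skip_if_no_goodfire  # assumed helper
from inspect_ai.model import get_model

@pytest.mark.asyncio
@skip_if_no_goodfire
async def test_goodfire_api() -> None:
    model = get_model("goodfire/meta-llama/Meta-Llama-3.1-8B-Instruct")
    response = await model.generate("Say hello.")
    assert len(response.completion) > 0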

(Inline review comments, since resolved, on src/inspect_ai/model/_providers/goodfire.py.)

# Defer importing model api classes until they are actually used
# (this allows the package to load without the optional deps)
# Note that some api providers (e.g. Cloudflare, AzureAI) don't
# strictly require this treatment but we do it anyway for uniformity,

logger = logging.getLogger(__name__)
@jjallaire (Collaborator): We don't use the logger anymore, so we can remove it.

@menhguin (Author): Removed.


from inspect_ai._util.error import pip_dependency_error
from inspect_ai._util.version import verify_required_version

from .._model import ModelAPI
from .._registry import modelapi
from .._registry import modelapi, modelapi_register
from .goodfire import GoodfireAPI
@jjallaire (Collaborator): The import needs to be moved down into the goodfire() function (otherwise all Inspect users will need this package installed). See below for the recommended implementation of this function.

@menhguin (Author): Shifted. The first few lines are now just this (unchanged):

import os

from inspect_ai._util.error import pip_dependency_error
from inspect_ai._util.version import verify_required_version

from .._model import ModelAPI
from .._registry import modelapi

@@ -239,6 +243,21 @@ def mockllm() -> type[ModelAPI]:
return MockLLM


@modelapi(name="goodfire")
@jjallaire (Collaborator): Recommended implementation based on other providers:

@modelapi(name="anthropic")
def anthropic() -> type[ModelAPI]:
    FEATURE = "Goodfire API"
    PACKAGE = "goodfire"
    MIN_VERSION = "0.2.5"

    # verify we have the package
    try:
        import goodfire  # noqa: F401
    except ImportError:
        raise pip_dependency_error(FEATURE, [PACKAGE])

    # verify version
    verify_required_version(FEATURE, PACKAGE, MIN_VERSION)

    # in the clear
    from .goodfire import GoodfireAPI

    return GoodfireAPI

@menhguin (Author): Implemented in the latest version!

@modelapi("goodfire")
def goodfire() -> type[ModelAPI]:
    """Get the Goodfire API provider."""
    FEATURE = "Goodfire API"
    PACKAGE = "goodfire"
    MIN_VERSION = "0.3.4"  # Support for newer Llama models and OpenAI compatibility

    # verify we have the package
    try:
        import goodfire  # noqa: F401
    except ImportError:
        raise pip_dependency_error(FEATURE, [PACKAGE])

    # verify version
    verify_required_version(FEATURE, PACKAGE, MIN_VERSION)

    # in the clear
    from .goodfire import GoodfireAPI
    return GoodfireAPI

- Remove redundant rate limit exception handling in handle_error method
- Simplify import of Goodfire API exceptions
- Maintain existing InvalidRequestException handling
- Remove unnecessary type imports
- Simplify type casting in generate method
- Update type hints for parameters
- Remove redundant type casting for return values
- Remove unused logging import
- Refactor Goodfire provider initialization to improve error handling
- Streamline package import and version verification
- Remove unnecessary logger definition
- Create new test file for Goodfire model provider
- Add skip decorator for Goodfire API key requirement
- Implement basic test for model generation with sample configuration
- Verify response generation for Goodfire model
- Remove unnecessary imports and unused variables
- Simplify error handling and type conversion logic
- Streamline code by removing commented-out and redundant code
- Update providers.py to remove unused import
… test case

- Modify GoodfireAPI to filter out non-API parameters like api_key and base_url
- Update test case to use a specific Llama 3.1 model from SUPPORTED_MODELS
- Simplify test configuration and add tool_choice parameter
- Add RateLimitException to imports for potential future error handling
- Update test_goodfire.py to use GoodfireAPI directly
- Remove skip decorator and unnecessary imports
- Simplify test configuration
- Update providers.py to use deferred import for GoodfireAPI
- Improve code organization and import management
@menhguin (Author) commented:
Hello! Each of the requested changes should be applied now.


@jjallaire (Collaborator) commented:
> Hello! Each of the requested changes should be applied now.

Thanks! I noticed that there are 3 more small issues to resolve (all related to the providers.py file). We should also name that function goodfire() rather than get_goodfire().

- Add version check for goodfire package (minimum 0.3.4)
- Modify provider function to include package verification
- Remove version constraint in pyproject.toml
- Rename get_goodfire() to goodfire() for consistency
@menhguin (Author) commented:
> Hello! Each of the requested changes should be applied now.
>
> Thanks! I noticed that there are 3 more small issues to resolve (all related to the providers.py file). We should also name that function goodfire() rather than get_goodfire().

I actually did implement those, I just forgot to reply mentioning that. Anyway, get_goodfire() should be goodfire() now!
