Add Goodfire API Provider Support #1161
base: main
Conversation
- Introduced `GoodfireConfig` dataclass for Goodfire-specific settings in `_generate_config.py`.
- Implemented `GoodfireAPI` class in a new file `_providers/goodfire.py` to handle interactions with the Goodfire API.
- Registered the Goodfire API provider in `_providers/providers.py`, including error handling for dependency imports.
- Updated `GenerateConfig` to include Goodfire configuration options.
…th version verification and improved error handling
- Added support for minimum version requirement for the Goodfire API.
- Introduced supported model literals and updated model name handling.
- Improved API key retrieval logic with environment variable checks.
- Enhanced client initialization to include base URL handling.
- Updated maximum completion tokens to 4096 for better performance.
- Refined message conversion to handle tool messages appropriately.
- Removed unsupported feature analysis configuration.

This commit improves the robustness and usability of the Goodfire API integration.
…odfireAPI generate method for improved error handling and parameter management
- Enhanced the generate method to use a try-except block for better error logging.
- Consolidated API request parameters into a dictionary for cleaner code.
- Added handling for usage statistics in the output if available.
- Improved message conversion process for better clarity and maintainability.

This update increases the robustness of the Goodfire API integration and enhances error reporting.
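For context on the usage-statistics handling mentioned above, here is a hedged sketch of mapping provider usage stats onto Inspect's `ModelUsage`. The OpenAI-style attribute names on the Goodfire response object are an assumption, not taken from this PR:

```python
# Sketch: map OpenAI-style usage stats onto Inspect's ModelUsage.
# The prompt_tokens/completion_tokens/total_tokens attribute names on
# the provider response are assumed, not confirmed from the PR.
from typing import Any

from inspect_ai.model import ModelUsage


def to_model_usage(usage: Any) -> ModelUsage:
    return ModelUsage(
        input_tokens=usage.prompt_tokens,
        output_tokens=usage.completion_tokens,
        total_tokens=usage.total_tokens,
    )
```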
OH YES almost forgot: You need to
Fantastic! So happy to see this and excited to see it built out further. Left some feedback in the review. Some additional comments:
- Saw there was a note on streaming support -- currently we don't use streaming in our model interfaces so I don't think this will be required (but perhaps there is a scenario I'm not thinking of?)
- Saw your note on caching -- would the built-in caching work? (We cache `ModelOutput` instances based on a key that hopefully reflects the full range of possible inputs.)

In terms of adding mech interp stuff, we've had initial discussions with a few others in the field on how to do this. At some point I think we'd like to define some common data structures that can go in `ModelOutput`, but we aren't there yet. In the meantime, you should add any mech interp data to the `metadata` field of `ModelOutput` (using whatever schema you want). Later we can try to bring some of this back into something that is shared by multiple mech interp back ends.
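As a minimal sketch of what that could look like (the `goodfire_features` key, labels, and values are purely illustrative, not part of this PR):

```python
# Illustrative only: attach mech interp data to ModelOutput.metadata.
from inspect_ai.model import ModelOutput

output = ModelOutput.from_content(
    model="goodfire/meta-llama/Meta-Llama-3.1-8B-Instruct",
    content="Hello!",
)
output.metadata = {
    "goodfire_features": [
        {"label": "greeting", "activation": 0.87},
        {"label": "formal tone", "activation": 0.12},
    ]
}
```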
You can add this to the
The proposed changes here seem reasonable. I will attempt to implement all of them by Friday morning-ish UK time. Some of the more ... awkward design choices were me trying to patch dozens of little mismatches between Inspect and the Goodfire API (different function names, output formats, Goodfire's special functions). I tried to clean it up and standardise with the rest of Inspect, but evidently I missed a few things. I'll try the metadata approach afterwards. Figuring out which mech interp functions to allow, and how, is going to be ... tricky. Do you have any reference examples where a model provider supports more than just text generation via Inspect? Even a logits/logprob view might be a helpful reference.
Great, thanks! (and feel free to ping me w/ any questions in the meantime)
Yes, several model providers (OpenAI, Grok, TogetherAI, Huggingface, llama-cpp-python, and vLLM) support logprobs: see inspect_ai/src/inspect_ai/model/_model_output.py, lines 39 to 72 at 124d837.
Eventually I'd like to have some standard fields like this for mech interp payloads (so that readers of logs can benefit from some uniformity). Absent working out these schemas, I would put your own data structures in the `metadata` field.
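For orientation, the structures at that link have roughly the following shape (paraphrased, not copied verbatim; see the linked lines for the authoritative definitions):

```python
# Approximate shape of Inspect's logprob structures (paraphrased).
from pydantic import BaseModel


class TopLogprob(BaseModel):
    token: str                      # candidate token
    logprob: float                  # its log probability
    bytes: list[int] | None = None  # UTF-8 bytes, when available


class Logprob(BaseModel):
    token: str
    logprob: float
    bytes: list[int] | None = None
    top_logprobs: list[TopLogprob] | None = None  # most likely alternatives


class Logprobs(BaseModel):
    content: list[Logprob]          # one entry per generated token
```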
…e package dependency and remove GoodfireConfig class from GenerateConfig. Enhance goodfire provider with version verification.
…se runtime-safe string. Add note and TODO for potential issue in Goodfire's repo.
- Remove hardcoded MODEL_MAP and variant validation
- Directly use model name for Variant initialization
- Standardize max_tokens to use default value
- Enhance generate method to include additional model arguments
- Streamline model configuration and parameter handling

- Improve error handling by using specific Goodfire exception types
- Refactor rate limit and context length error detection
- Enhance parameter configuration with more flexible temperature and top_p handling
- Add type casting and improve type hints
- Simplify client initialization and method calls

- Simplify model argument collection and storage
- Update generate method to incorporate model arguments more flexibly
- Remove separate tracking of temperature and top_p
- Ensure all model-specific arguments are passed to generation parameters

- Update type hints for model arguments and parameters
- Improve parameter configuration in generate method
- Simplify base model selection and parameter passing
- Enhance code readability and type consistency

- Add default values for temperature and top_p when not specified in model arguments
- Prioritize model_args over config parameters
- Ensure consistent parameter configuration when generating completions
Newest updates as of this weekend. Basically improving robustness, handling and minimising hardcodes:
So text gen should be settled now? I am moving on to feature implementations this week, finally.
Thanks again for your diligent work here! Noted the changes and the PR is looking good. I did a scan of the code as-is and did find a couple more things we should tweak (will post those in a new review shortly).
Some additional comments (all of them quite small).
Could you also add a simple test (and related `skip_if` function) along these lines: https://github.com/UKGovernmentBEIS/inspect_ai/blob/main/tests/model/providers/test_groq.py

I noticed there were also some `ruff` errors when running the checks. You can clean this up by running `make check` locally.
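A minimal sketch of such a test, modeled loosely on the linked test_groq.py. The helper name `skip_if_no_goodfire`, the model string, and the assertions are assumptions, not the PR's final test:

```python
# Hypothetical Goodfire provider test, loosely following test_groq.py.
# skip_if_no_goodfire is an assumed helper mirroring the existing
# skip_if_* utilities; the real test may differ.
import os

import pytest

from inspect_ai.model import GenerateConfig, get_model


def skip_if_no_goodfire(func):
    return pytest.mark.skipif(
        os.environ.get("GOODFIRE_API_KEY") is None,
        reason="Test requires GOODFIRE_API_KEY environment variable",
    )(func)


@pytest.mark.asyncio
@skip_if_no_goodfire
async def test_goodfire_api() -> None:
    model = get_model(
        "goodfire/meta-llama/Meta-Llama-3.1-8B-Instruct",
        config=GenerateConfig(temperature=0.0, max_tokens=50),
    )
    response = await model.generate("What is 2+2?")
    assert len(response.completion) >= 1
```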
# Defer importing model api classes until they are actually used
# (this allows the package to load without the optional deps)
# Note that some api providers (e.g. Cloudflare, AzureAI) don't
# strictly require this treatment but we do it anyway for uniformity,

logger = logging.getLogger(__name__)
We don't use the logger anymore so we can remove it.
Removed.
from inspect_ai._util.error import pip_dependency_error
from inspect_ai._util.version import verify_required_version

from .._model import ModelAPI
from .._registry import modelapi
from .._registry import modelapi, modelapi_register
from .goodfire import GoodfireAPI
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The import needs to be moved down into the `goodfire()` function (otherwise all Inspect users will need this package installed). See below for the recommended implementation of this function.
Shifted. The first few lines are just this (unchanged):
import os
from inspect_ai._util.error import pip_dependency_error
from inspect_ai._util.version import verify_required_version
from .._model import ModelAPI
from .._registry import modelapi
@@ -239,6 +243,21 @@ def mockllm() -> type[ModelAPI]:
    return MockLLM


@modelapi(name="goodfire")
Recommended implementation based on other providers:
@modelapi(name="anthropic")
def anthropic() -> type[ModelAPI]:
FEATURE = "Goodfire API"
PACKAGE = "goodfire"
MIN_VERSION = "0.2.5"
# verify we have the package
try:
import goodfire # noqa: F401
except ImportError:
raise pip_dependency_error(FEATURE, [PACKAGE])
# verify version
verify_required_version(FEATURE, PACKAGE, MIN_VERSION)
# in the clear
from .goodfire import GoodfireAPI
return GoodfireAPI
Implemented in latest version!
@modelapi("goodfire")
def goodfire() -> type[ModelAPI]:
"""Get the Goodfire API provider."""
FEATURE = "Goodfire API"
PACKAGE = "goodfire"
MIN_VERSION = "0.3.4" # Support for newer Llama models and OpenAI compatibility
# verify we have the package
try:
import goodfire # noqa: F401
except ImportError:
raise pip_dependency_error(FEATURE, [PACKAGE])
# verify version
verify_required_version(FEATURE, PACKAGE, MIN_VERSION)
# in the clear
from .goodfire import GoodfireAPI
return GoodfireAPI
- Remove redundant rate limit exception handling in handle_error method
- Simplify import of Goodfire API exceptions
- Maintain existing InvalidRequestException handling

- Remove unnecessary type imports
- Simplify type casting in generate method
- Update type hints for parameters
- Remove redundant type casting for return values

- Remove unused logging import
- Refactor Goodfire provider initialization to improve error handling
- Streamline package import and version verification
- Remove unnecessary logger definition

- Create new test file for Goodfire model provider
- Add skip decorator for Goodfire API key requirement
- Implement basic test for model generation with sample configuration
- Verify response generation for Goodfire model

- Remove unnecessary imports and unused variables
- Simplify error handling and type conversion logic
- Streamline code by removing commented-out and redundant code
- Update providers.py to remove unused import

… test case
- Modify GoodfireAPI to filter out non-API parameters like api_key and base_url
- Update test case to use a specific Llama 3.1 model from SUPPORTED_MODELS
- Simplify test configuration and add tool_choice parameter
- Add RateLimitException to imports for potential future error handling

- Update test_goodfire.py to use GoodfireAPI directly
- Remove skip decorator and unnecessary imports
- Simplify test configuration
- Update providers.py to use deferred import for GoodfireAPI
- Improve code organization and import management
Thanks! I noticed that there are 3 more small issues to resolve (all related to the providers.py file). We should also name that function `goodfire()`.
- Add version check for goodfire package (minimum 0.3.4)
- Modify provider function to include package verification
- Remove version constraint in pyproject.toml
- Rename get_goodfire() to goodfire() for consistency
I actually did implement those, I just forgot to reply mentioning that. Anyway,
- Remove unused logging import
- Refactor Goodfire provider initialization to improve error handling
- Streamline package import and version verification
- Remove unnecessary logger definition
Add Goodfire API Provider Support
Overview
This PR introduces support for the Goodfire API, enabling the use of Meta's Llama models through Goodfire's inference service. The implementation provides basic chat completion functionality while maintaining compatibility with the existing evaluation framework.
Over the next few weeks, I expect to add more complex mechanistic interpretability techniques (feature search, inspection, feature steering) as shown in the Goodfire AI documentation. For now, this PR seeks to cover basic chat completion and standardisation in line with other model providers (since I have to keep merging new commits every day).
Critical Implementation Details
Core Provider Implementation (`inspect_ai/src/inspect_ai/model/_providers/goodfire.py`):
- `GoodfireAPI` class with synchronous API handling
- `generate()`, `_to_goodfire_message()` (sketched below), connection management
- `DEFAULT_MAX_TOKENS=4096`, `DEFAULT_TEMPERATURE=0.7`
- `MODEL_MAP`
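As a rough sketch of the message conversion named in the list above (the tool-message handling shown is an assumed strategy, not necessarily the PR's exact code):

```python
# Rough sketch of a _to_goodfire_message()-style conversion: Inspect chat
# messages to OpenAI-style role/content dicts. Folding tool results into
# user messages is an assumption, since Goodfire's chat API has no tool role.
from inspect_ai.model import ChatMessage


def to_goodfire_message(message: ChatMessage) -> dict[str, str]:
    role = message.role
    if role == "tool":
        role = "user"  # assumed fallback for tool results
    return {"role": role, "content": message.text}
```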
Known Limitations:
- `inspect_evals/src/inspect_evals/mmlu/mmlu_5_shot.py`
- `generate()` method

API Differences (vs OpenAI/Anthropic):
- `max_completion_tokens` vs `max_tokens`
- `finish_reason` field

Required Configuration
Environment Setup:
Model Support:
meta-llama/Meta-Llama-3.1-8B-Instruct
meta-llama/Llama-3.3-70B-Instruct
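For reference, hypothetical usage once the provider is registered (the `goodfire/` model prefix follows Inspect's provider/model naming convention; this assumes GOODFIRE_API_KEY is set in the environment):

```python
# Hypothetical end-to-end usage of the registered Goodfire provider.
import asyncio

from inspect_ai.model import get_model


async def main() -> None:
    model = get_model("goodfire/meta-llama/Meta-Llama-3.1-8B-Instruct")
    output = await model.generate("Say hello in one short sentence.")
    print(output.completion)


asyncio.run(main())
```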
Pending Improvements (Prioritized)
Critical:
Important:
Nice to Have:
Testing Status
Verified:
Known Issues:
Breaking Changes
None. This should not affect the use of other model providers, and effort has been taken to ensure standardisation. Code changes have been isolated to:
- `/src/inspect_ai/model/_providers/goodfire.py` for the core implementation script
- `src/inspect_ai/model/_providers/providers.py` to register Goodfire as a model provider
- `src/inspect_ai/model/_generate_config.py` for certain Goodfire-specific generation functions (though do feel free to test this to make sure it doesn't affect any other providers)

So far, the model seems to generate and score similarly to vLLM-hosted Llama 8B-Instruct on GPQA, GSM8K and MMLU.
Conclusion
Once again, I will be improving and building on this initial chat generation implementation in the coming weeks with more advanced mech interp functions. If you come across issues in other evals, do let me know.