Add support for stop sequences to HF models #1188

Merged · 5 commits · Jan 25, 2025
CHANGELOG.md (1 addition, 0 deletions)

@@ -4,6 +4,7 @@

- [Agent Bridge](https://inspect.ai-safety-institute.org.uk/agent-bridge.html) for integrating external agent frameworks with Inspect.
- Add `@wraps` to functions wrapped by Inspect decorators to preserve type information.
- Hugging Face: Add support for stop sequences for HF models.
- Docker: More robust parsing of version strings (handle development versions).

## v0.3.59 (24 January 2025)
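For orientation, here is a minimal sketch of how a caller opts in to the new option (the model name and stop sequence simply mirror the test added below, and this assumes get_model and GenerateConfig are importable from inspect_ai.model, as the tests do):

from inspect_ai.model import GenerateConfig, get_model

# completions will now be truncated once "w3" has been generated
model = get_model(
    "hf/EleutherAI/pythia-70m",
    config=GenerateConfig(stop_seqs=["w3"]),
)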
src/inspect_ai/model/_providers/hf.py (6 additions, 0 deletions)

@@ -150,6 +150,12 @@ async def generate(
kwargs["output_logits"] = config.logprobs
if "return_dict_in_generate" in kwargs:
assert kwargs["return_dict_in_generate"]
if config.stop_seqs is not None:
from transformers.generation import StopStringCriteria # type: ignore

stopping_criteria = [StopStringCriteria(self.tokenizer, config.stop_seqs)]
kwargs["stopping_criteria"] = stopping_criteria

kwargs["return_dict_in_generate"] = True
generator = functools.partial(self.model.generate, **kwargs)

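Under the hood this delegates to transformers' StopStringCriteria. A standalone sketch of the same mechanism, assuming a transformers release recent enough to ship StopStringCriteria, and using gpt2 purely for illustration:

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import StoppingCriteriaList, StopStringCriteria

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("https://", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=20,
    # stop decoding once any stop string appears in the generated text
    stopping_criteria=StoppingCriteriaList(
        [StopStringCriteria(tokenizer, ["w3"])]
    ),
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Note that the provider change passes a plain list rather than a StoppingCriteriaList; generate merges user-supplied criteria into its own list, so either form works.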
tests/model/providers/test_hf.py (33 additions, 0 deletions)

@@ -27,6 +27,28 @@ def model():
    )


@pytest.fixture
def model_with_stop_seqs():
    DEFAULT_CHAT_TEMPLATE = (
        "{% for message in messages %}{{ message.content }}{% endfor %}"
    )
    model = get_model(
        "hf/EleutherAI/pythia-70m",
        config=GenerateConfig(
            max_tokens=5,
            seed=42,
            temperature=0.001,
            stop_seqs=["w3"],
        ),
        # this allows us to run base models with the chat message scaffolding:
        chat_template=DEFAULT_CHAT_TEMPLATE,
        tokenizer_call_args={"truncation": True, "max_length": 10},
    )
    # chat template is not propagated by default from get_model to the model's tokenizer
    model.api.tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
    return model


@pytest.mark.asyncio
@skip_if_github_action
@skip_if_no_transformers
@@ -38,6 +60,17 @@ async def test_hf_api(model) -> None:
    assert len(response.completion) >= 1


@pytest.mark.asyncio
@skip_if_github_action
@skip_if_no_transformers
@skip_if_no_accelerate
async def test_hf_api_with_stop_seqs(model_with_stop_seqs) -> None:
    # pythia-70m greedily decodes "https://" to "https://www.w3.org", so the
    # stop sequence "w3" should truncate the completion to "www.w3": the stop
    # sequence itself is kept, everything after it is cut off
    message = ChatMessageUser(content="https://")
    response = await model_with_stop_seqs.generate(input=[message])
    assert response.completion == "www.w3"


@pytest.mark.asyncio
@skip_if_github_action
@skip_if_no_transformers