bug fix: variable number of max decode tokens within batch (#73)
This PR fixes a previously unidentified bug and adds pytests for validation.

**Changes**:

- Addresses the logic error described below by introducing `SpyreCausalLM.indices`, a mask indicating the unfinished sequences in the current batch. -> [commit](3f087a7)
- Adapts the generation functions in [tests/spyre/spyre_util.py](main...ysc-fix-variable-max-tokens#diff-d232e0cf89b92b0ec7da17e322bb2ca675af8a704099e5ae0c54995ddb4a3f9a) for `hf` and `vllm` so that they accept a different number of max decoding tokens for sequences within the same batch. -> [commit](f632e8e)
- Adds [tests/spyre/test_spyre_max_new_tokens.py](main...ysc-fix-variable-max-tokens#diff-82d9214a22b1db2e524795c8a649a40c115fd95a40b279e4d3245c7820e6ddf8) to validate functionality when some sequences in a batch finish decoding before others. -> [commit](f632e8e)

**Bug description**: Requesting different numbers of output tokens for sequences in the same batch leads to some sequences being removed from the batch while others are still decoding. Previously, the code did not account for the offset that a removed sequence introduces in the position ids (`positions`) and attention `masks`. The error remains undetected if all prompts have the same length (they then share position ids and attention masks) or if it is always the last sequence in the batch that finishes early (an offset at the end does not affect sequences with smaller indices in the same batch). A small sketch of the failure mode is given after this description.

_bug example_:

<img width="1392" alt="Screenshot 2025-01-31 at 12 39 26" src="https://github.com/user-attachments/assets/b19deee5-af32-48cd-9b1a-051e9f074737" />

---------

Signed-off-by: Yannick Schnider <[email protected]>
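The following is a minimal sketch of the misalignment, using position ids only; it is not the repository's code, and the tensor values and the `positions`/`finished`/`indices` names are purely illustrative (the attention masks are affected in the same way).

```python
import torch

# Toy decode state for a batch of four sequences: sequence 0 had a shorter
# prompt, so its next position id differs from the others.
positions = torch.tensor([5, 7, 7, 7])
finished = torch.tensor([True, False, False, False])  # sequence 0 reaches its max_tokens first

# Buggy behaviour: after the finished request is dropped, the next decode step
# keeps using the first len(remaining) rows of the precomputed tensors, so the
# surviving sequences inherit rows that belonged to other requests.
n_remaining = int((~finished).sum())
buggy_positions = positions[:n_remaining]  # tensor([5, 7, 7]): sequence 1 now decodes at position 5

# Fixed behaviour (as described above): keep a boolean mask of the unfinished
# sequences and index with it, so each live request keeps its own row.
indices = ~finished
correct_positions = positions[indices]  # tensor([7, 7, 7])

# If all prompts share one length, or if only the last sequence finishes early,
# the two selections coincide -- which is why the bug went unnoticed before.
```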
Showing 4 changed files with 145 additions and 33 deletions.
**tests/spyre/test_spyre_max_new_tokens.py** (new file, 102 lines added):

```python
"""Verification of vLLM output by comparing with HF
Run `python -m pytest tests/spyre/test_spyre_max_new_tokens.py`.
"""

from typing import List, Tuple

import pytest
from spyre_util import (compare_results, generate_hf_output,
                        generate_spyre_vllm_output)

from vllm import SamplingParams

template = (
    "Below is an instruction that describes a task. Write a response that "
    "appropriately completes the request. Be polite in your response to the "
    "user.\n\n### Instruction:\n{}\n\n### Response:")

prompt1 = template.format("Provide a recipe for chicken soup.")
prompt2 = template.format("Provide a list of instructions for preparing "
                          "chicken soup for a family of four.")


@pytest.mark.parametrize("model", ["/models/llama-194m"])
@pytest.mark.parametrize("prompts", [[prompt1, prompt2, prompt2, prompt2],
                                     [prompt2, prompt2, prompt2, prompt1],
                                     [prompt2, prompt2, prompt2, prompt2]])
@pytest.mark.parametrize("stop_last", [True, False])
@pytest.mark.parametrize("warmup_shape", [(64, 10, 4)]
                         )  # (prompt_length/new_tokens/batch_size)
@pytest.mark.parametrize("backend",
                         ["eager"])  #, "inductor", "sendnn_decoder"
def test_output(
    model: str,
    prompts: List[str],
    stop_last: bool,
    warmup_shape: Tuple[int, int, int],
    backend: str,
) -> None:
    '''
    The warmup is based on a single shape. After the warmup,
    one request with the provided prompts is input to vLLM.
    The same prompts are also input to HF. The generated output,
    including text, token ids, and logprobs, is verified to be
    identical for vLLM and HF.

    If errors occur, they can be analyzed/debugged by setting
    'DISABLE_ASSERTS = True' in spyre_util.py and by rerunning the
    test using 'pytest --capture=no tests/spyre/test_spyre_max_new_tokens.py'.
    After debugging, DISABLE_ASSERTS should be reset to 'False'.
    '''

    max_new_tokens_warmup = warmup_shape[1]
    max_new_tokens_early_stop = 1

    vllm_sampling_params_normal = SamplingParams(
        max_tokens=max_new_tokens_warmup,
        temperature=0,
        logprobs=0,  # return logprobs of generated tokens only
        ignore_eos=False)

    vllm_sampling_params_early_stop = SamplingParams(
        max_tokens=max_new_tokens_early_stop,
        temperature=0,
        logprobs=0,  # return logprobs of generated tokens only
        ignore_eos=False)

    vllm_sampling_params = [vllm_sampling_params_normal] * 3
    hf_max_new_tokens = [max_new_tokens_warmup] * 3

    # stop the last or the first sequence in the batch early
    if stop_last:
        vllm_sampling_params = vllm_sampling_params + [
            vllm_sampling_params_early_stop
        ]
        hf_max_new_tokens = hf_max_new_tokens + [max_new_tokens_early_stop]
    else:
        vllm_sampling_params = [vllm_sampling_params_early_stop
                                ] + vllm_sampling_params
        hf_max_new_tokens = [max_new_tokens_early_stop] + hf_max_new_tokens

    vllm_results = generate_spyre_vllm_output(
        model=model,
        prompts=prompts,
        warmup_shapes=[warmup_shape],
        max_model_len=2048,
        block_size=2048,
        sampling_params=vllm_sampling_params,
        tensor_parallel_size=1,
        backend=backend)

    hf_results = generate_hf_output(model=model,
                                    prompts=prompts,
                                    max_new_tokens=hf_max_new_tokens)

    compare_results(model=model,
                    prompts=prompts,
                    warmup_shapes=[warmup_shape],
                    tensor_parallel_size=1,
                    backend=backend,
                    vllm_results=vllm_results,
                    hf_results=hf_results)
```
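For context on how a batch ends up with different max decode tokens in the first place: vLLM accepts either a single `SamplingParams` or one per prompt, which is presumably how the `generate_spyre_vllm_output` helper above is driven. A minimal, illustrative sketch (the model path is just the placeholder used in the test parametrization):

```python
from vllm import LLM, SamplingParams

llm = LLM(model="/models/llama-194m")  # placeholder path from the test above

prompts = ["First example prompt.", "Second example prompt."]

# One SamplingParams per prompt: the second sequence stops after a single
# token and leaves the batch while the first one is still decoding.
sampling_params = [
    SamplingParams(max_tokens=10, temperature=0),
    SamplingParams(max_tokens=1, temperature=0),
]

for output in llm.generate(prompts, sampling_params):
    print(output.outputs[0].text)
```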