bug fix: variable number of max decode tokens within batch (#73)
This PR fixes a previously unidentified bug and adds pytests for
validation.

**Changes**: 
- addressing the logic error described below by introducing
`SpyreCausalLM.indices`, a boolean mask marking the unfinished
sequences in the current batch (see the sketch after this list). ->
[commit](3f087a7)
- adapting the generation functions in
[tests/spyre/spyre_util.py](main...ysc-fix-variable-max-tokens#diff-d232e0cf89b92b0ec7da17e322bb2ca675af8a704099e5ae0c54995ddb4a3f9a)
for `hf` and `vllm` to accept a different number of max decode tokens for
sequences within the same batch ->
[commit](f632e8e)
- adding
[tests/spyre/test_spyre_max_new_tokens.py](main...ysc-fix-variable-max-tokens#diff-82d9214a22b1db2e524795c8a649a40c115fd95a40b279e4d3245c7820e6ddf8)
to validate functionality when sequences in a batch finish decoding
before others. ->
[commit](f632e8e)
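
A minimal sketch of the masking idea (illustrative tensor shapes and values, not the actual `SpyreCausalLM` code): finished and padded slots stay in the compiled batch, and a boolean index selects only the logits of the unfinished sequences.

```python
import torch

# Toy setup: the warmed-up batch has 4 slots and a vocab size of 8.
# Slot 1 has already finished decoding and slot 3 is batch padding.
logits = torch.randn(4, 8)                          # [padded_batch_size, vocab_size]
indices = torch.tensor([True, False, True, False])  # True = still decoding

# Boolean indexing keeps only the rows of unfinished sequences, in order,
# so sampling sees a dense [num_unfinished, vocab_size] tensor.
active_logits = logits[indices]
assert active_logits.shape == (2, 8)
```

Because every sequence keeps its slot in the padded batch, the position ids and attention masks of the remaining sequences stay aligned.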

**Bug description**:

Requesting a different number of output tokens for sequences within the same
batch leads to some sequences being removed from the batch while others are
still decoding. Previously, the code did not account for the offset that a
removed sequence introduces into the position ids (`positions`) and attention
masks (`masks`). The error goes undetected if all prompts have the same length
(they then share the same position ids and attention masks) or if it is always
the last sequence in the batch that finishes early (an offset at the end does
not affect sequences at smaller indices within the same batch).

_bug example_ (screenshot): https://github.com/user-attachments/assets/b19deee5-af32-48cd-9b1a-051e9f074737
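
A toy illustration of the misalignment (made-up values, not the real tensors): if a finished sequence is dropped from the token list while the per-slot position ids and masks are kept, the remaining sequences shift onto the wrong slots.

```python
import torch

# Three compiled batch slots; each slot tracks its own next decode position.
position_ids = torch.tensor([5, 7, 7])  # slot 0 finished early, slots 1 and 2 still decode

# Before the fix: the finished sequence was removed and the rest appended in
# order, so the sequence belonging to slot 1 landed on slot 0 and was decoded
# with position id 5 instead of 7 (its attention mask shifted the same way).
tokens_buggy = [[101], [102]]

# After the fix: every sequence keeps its original slot, finished/padded slots
# get a dummy token, and `indices` masks them out when computing the logits.
tokens_fixed = [[0], [101], [102]]
indices = torch.tensor([False, True, True])
```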

---------

Signed-off-by: Yannick Schnider <[email protected]>
yannicks1 authored Feb 4, 2025
1 parent e7dc638 commit 938fea3
Showing 4 changed files with 145 additions and 33 deletions.
24 changes: 15 additions & 9 deletions tests/spyre/spyre_util.py
@@ -1,6 +1,6 @@
import math
import os
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Tuple, Union

import numpy as np
from sentence_transformers import SentenceTransformer, util
@@ -18,7 +18,8 @@
def generate_spyre_vllm_output(model: str, prompts: List[str],
warmup_shapes: List[Tuple[int, int, int]],
max_model_len: int, block_size: int,
sampling_params: SamplingParams,
sampling_params: Union[SamplingParams,
List[SamplingParams]],
tensor_parallel_size: int,
backend: str) -> List[Dict[str, Any]]:

@@ -62,20 +63,25 @@ def generate_spyre_vllm_output(model: str, prompts: List[str],


# Hugging Face
def generate_hf_output(model: str, prompts: List[str],
max_new_tokens: int) -> List[Dict[str, Any]]:
def generate_hf_output(
model: str, prompts: List[str],
max_new_tokens: Union[int, List[int]]) -> List[Dict[str, Any]]:

if not isinstance(max_new_tokens, list):
max_new_tokens = [max_new_tokens] * len(prompts)

hf_model = AutoModelForCausalLM.from_pretrained(model)
hf_tokenizer = AutoTokenizer.from_pretrained(model)

results = []
for prompt_index, prompt in enumerate(prompts):
hf_input_tokens = hf_tokenizer(prompt, return_tensors="pt").input_ids
hf_output = hf_model.generate(hf_input_tokens,
do_sample=False,
max_new_tokens=max_new_tokens,
return_dict_in_generate=True,
output_scores=True)
hf_output = hf_model.generate(
hf_input_tokens,
do_sample=False,
max_new_tokens=max_new_tokens[prompt_index],
return_dict_in_generate=True,
output_scores=True)

# decode output tokens after first removing input tokens (prompt)
hf_generated_text = hf_tokenizer.batch_decode(
102 changes: 102 additions & 0 deletions tests/spyre/test_spyre_max_new_tokens.py
@@ -0,0 +1,102 @@
"""Verification of vLLM output by comparing with HF
Run `python -m pytest tests/spyre/test_spyre_max_new_tokens.py`.
"""

from typing import List, Tuple

import pytest
from spyre_util import (compare_results, generate_hf_output,
generate_spyre_vllm_output)

from vllm import SamplingParams

template = (
"Below is an instruction that describes a task. Write a response that "
"appropriately completes the request. Be polite in your response to the "
"user.\n\n### Instruction:\n{}\n\n### Response:")

prompt1 = template.format("Provide a recipe for chicken soup.")
prompt2 = template.format("Provide a list of instructions for preparing "
"chicken soup for a family of four.")


@pytest.mark.parametrize("model", ["/models/llama-194m"])
@pytest.mark.parametrize("prompts", [[prompt1, prompt2, prompt2, prompt2],
[prompt2, prompt2, prompt2, prompt1],
[prompt2, prompt2, prompt2, prompt2]])
@pytest.mark.parametrize("stop_last", [True, False])
@pytest.mark.parametrize("warmup_shape", [(64, 10, 4)]
) # (prompt_length/new_tokens/batch_size)
@pytest.mark.parametrize("backend",
["eager"]) #, "inductor", "sendnn_decoder"])
def test_output(
model: str,
prompts: List[str],
stop_last: bool,
warmup_shape: Tuple[int, int, int],
backend: str,
) -> None:
'''
The warmup is based on a single shape. After the warmup,
one request with the provided prompts is input to vLLM.
The same prompts are also input to HF. The generated output
including text, token ids, and logprobs, is verified to be
identical for vLLM and HF.
If errors occur, these can be analyzed/debugged by setting
'DISABLE_ASSERTS = True' in spyre_util.py and by rerunning the
test using 'pytest --capture=no tests/spyre/test_spyre_max_new_tokens.py'
After debugging, DISABLE_ASSERTS should be reset to 'False'.
'''

max_new_tokens_warmup = warmup_shape[1]
max_new_tokens_early_stop = 1

vllm_sampling_params_normal = SamplingParams(
max_tokens=max_new_tokens_warmup,
temperature=0,
logprobs=0, # return logprobs of generated tokens only
ignore_eos=False)

vllm_sampling_params_early_stop = SamplingParams(
max_tokens=max_new_tokens_early_stop,
temperature=0,
logprobs=0, # return logprobs of generated tokens only
ignore_eos=False)

vllm_sampling_params = [vllm_sampling_params_normal] * 3
hf_max_new_tokens = [max_new_tokens_warmup] * 3

# stop last or first sequence in batch early
if stop_last:
vllm_sampling_params = vllm_sampling_params + [
vllm_sampling_params_early_stop
]
hf_max_new_tokens = hf_max_new_tokens + [max_new_tokens_early_stop]
else:
vllm_sampling_params = [vllm_sampling_params_early_stop
] + vllm_sampling_params
hf_max_new_tokens = [max_new_tokens_early_stop] + hf_max_new_tokens

vllm_results = generate_spyre_vllm_output(
model=model,
prompts=prompts,
warmup_shapes=[warmup_shape],
max_model_len=2048,
block_size=2048,
sampling_params=vllm_sampling_params,
tensor_parallel_size=1,
backend=backend)

hf_results = generate_hf_output(model=model,
prompts=prompts,
max_new_tokens=hf_max_new_tokens)

compare_results(model=model,
prompts=prompts,
warmup_shapes=[warmup_shape],
tensor_parallel_size=1,
backend=backend,
vllm_results=vllm_results,
hf_results=hf_results)
13 changes: 6 additions & 7 deletions vllm/model_executor/model_loader/spyre.py
@@ -46,9 +46,10 @@ def __init__(
self.past_key_value_states = None
self.dtype = torch.float16 if envs.VLLM_SPYRE_DYNAMO_BACKEND == \
'sendnn_decoder' else torch.float32
# number of added padding sequences to fill
# batch to warmed up batch size
self.num_padded_sequences = 0
# boolean tensor of length batch size with indices:
# True for unfinished sequences and
# False for finished or padded sequences
self.indices = None

# Lazy initialized
self.model: nn.Module
@@ -89,10 +90,8 @@ def forward(
for tensor in layer:
torch._dynamo.mark_dynamic(tensor, 2)

# removing batch padding sequences to compute logits
batch_size = input_ids.shape[0]

logits = logits[:batch_size - self.num_padded_sequences]
# removing finished or padded sequences
logits = logits[self.indices]

return logits

39 changes: 22 additions & 17 deletions vllm/worker/spyre_model_runner.py
@@ -96,6 +96,8 @@ def __init__(
self._position_ids: torch.Tensor = None
# attention masks of all the sequences in current batch
self._mask: torch.Tensor = None
# mapping: request id to index in batch
self._req_ids2idx: dict = {}
# Lazy initialization: after load_model.
self.model: nn.Module

@@ -148,8 +150,10 @@ def _prepare_prompt(
'prompt_length']
padded_batch_size = applicable_spyre_warmup_shapes[0]['batch_size']

for seq_group_metadata in seq_group_metadata_list:
self._req_ids2idx = {}
for idx, seq_group_metadata in enumerate(seq_group_metadata_list):
assert seq_group_metadata.is_prompt
self._req_ids2idx[seq_group_metadata.request_id] = idx
seq_ids = list(seq_group_metadata.seq_data.keys())
assert len(seq_ids) == 1
seq_id = seq_ids[0]
@@ -163,9 +167,13 @@
dtype=torch.long,
device=torch.device("cpu")))

# set number of added padding sequences used for computing logits
self.model.num_padded_sequences = padded_batch_size - len(
input_token_list)
actual_batch_size = len(input_token_list)
self.model.indices = torch.cat([
torch.ones(actual_batch_size, dtype=torch.bool, device='cpu'),
torch.zeros(padded_batch_size - actual_batch_size,
dtype=torch.bool,
device='cpu')
])

# padding to compiled batch size
while len(input_token_list) < padded_batch_size:
@@ -187,7 +195,9 @@ def _prepare_decode(
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_tokens: List[List[int]] = [
[0] for _ in range(self._position_ids.shape[0])
]

for seq_group_metadata in seq_group_metadata_list:
assert not seq_group_metadata.is_prompt
@@ -197,18 +207,9 @@

seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id()
input_tokens.append([generation_token])

# padding to compiled batch size
actual_batch_size = len(seq_group_metadata_list)
padded_batch_size = self._position_ids.shape[0]

# set number of added padding sequences used for computing logits
self.model.num_padded_sequences = padded_batch_size - actual_batch_size

while actual_batch_size < padded_batch_size:
input_tokens.append([0])
actual_batch_size += 1
input_tokens[self._req_ids2idx[seq_group_metadata.request_id]] = [
generation_token
]

# update position ids and attention mask
self._update_position_ids()
@@ -274,6 +275,10 @@ def prepare_model_input(
input_tokens.shape[1] for i in range(input_tokens.shape[0])
]
else:
# updating indices: set indices of newly finished sequences False
if finished_requests_ids:
for seq_id in finished_requests_ids:
self.model.indices[self._req_ids2idx[seq_id]] = False
(input_tokens, input_positions,
input_masks) = self._prepare_decode(seq_group_metadata_list)
seq_lens = []
