diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py
index 74c819efea23b..c4d33711cc9a2 100644
--- a/tests/samplers/test_sampler.py
+++ b/tests/samplers/test_sampler.py
@@ -1,9 +1,9 @@
 # pylint: disable=protected-access
-import pytest
 import random
 from typing import Tuple
 from unittest.mock import patch
 
+import pytest
 import torch
 
 from vllm.model_executor.layers.sampler import Sampler
@@ -69,7 +69,7 @@ def test_sampler_all_greedy(seed: int):
                              input_metadata=input_metadata)
     expected = torch.argmax(fake_logits, dim=-1)
     for i, sequence_output in enumerate(sampler_output):
-        for nth_output in sequence_output:
+        for nth_output in sequence_output.samples:
             assert nth_output.output_token == expected[i].item()
 
 
@@ -101,7 +101,7 @@ def test_sampler_all_random(seed: int):
                              hidden_states=input_tensor,
                              input_metadata=input_metadata)
     for i, sequence_output in enumerate(sampler_output):
-        for nth_output in sequence_output:
+        for nth_output in sequence_output.samples:
             assert nth_output.output_token == i
 
 
@@ -181,5 +181,5 @@ def test_sampler_mixed(seed: int):
     for i, sequence_output in enumerate(sampler_output):
         if seq_group_metadata_list[i].sampling_params.use_beam_search:
             continue
-        for nth_output in sequence_output:
+        for nth_output in sequence_output.samples:
            assert nth_output.output_token in expected_tokens