Fix sampler test (vllm-project#1379)
WoosukKwon authored Oct 16, 2023
1 parent e8ef4c0 commit d3a5bd9
Showing 1 changed file with 4 additions and 4 deletions.
tests/samplers/test_sampler.py
@@ -1,9 +1,9 @@
 # pylint: disable=protected-access
-import pytest
 import random
 from typing import Tuple
 from unittest.mock import patch
 
+import pytest
 import torch
 
 from vllm.model_executor.layers.sampler import Sampler
@@ -69,7 +69,7 @@ def test_sampler_all_greedy(seed: int):
                              input_metadata=input_metadata)
     expected = torch.argmax(fake_logits, dim=-1)
     for i, sequence_output in enumerate(sampler_output):
-        for nth_output in sequence_output:
+        for nth_output in sequence_output.samples:
             assert nth_output.output_token == expected[i].item()
 
 
@@ -101,7 +101,7 @@ def test_sampler_all_random(seed: int):
                              hidden_states=input_tensor,
                              input_metadata=input_metadata)
     for i, sequence_output in enumerate(sampler_output):
-        for nth_output in sequence_output:
+        for nth_output in sequence_output.samples:
             assert nth_output.output_token == i
 
 
@@ -181,5 +181,5 @@ def test_sampler_mixed(seed: int):
     for i, sequence_output in enumerate(sampler_output):
         if seq_group_metadata_list[i].sampling_params.use_beam_search:
             continue
-        for nth_output in sequence_output:
+        for nth_output in sequence_output.samples:
             assert nth_output.output_token in expected_tokens
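For context on why the loops changed: after this fix, the tests treat each per-sequence-group entry of the sampler output as an object whose sampled tokens live in a .samples list, rather than as something directly iterable. A minimal runnable sketch of that assumed shape follows; the class names SequenceGroupOutputs and SequenceOutputs are illustrative stand-ins inferred from the test, not confirmed vLLM API.

# Sketch of the output structure the updated test assumes; all names
# except `samples` and `output_token` are illustrative, not vLLM's API.
from dataclasses import dataclass, field
from typing import List


@dataclass
class SequenceOutputs:
    """One sampled token for one sequence (illustrative)."""
    output_token: int


@dataclass
class SequenceGroupOutputs:
    """Per-sequence-group entry of the sampler output (illustrative)."""
    samples: List[SequenceOutputs] = field(default_factory=list)


# The old test iterated the entry itself; the fixed test goes through
# the `.samples` list instead:
sampler_output = [SequenceGroupOutputs(samples=[SequenceOutputs(output_token=7)])]
for sequence_output in sampler_output:
    for nth_output in sequence_output.samples:
        assert nth_output.output_token == 7

Presumably, wrapping the sampled tokens in a named attribute leaves room for other per-group fields alongside them, which is why plain iteration over the entry stopped working.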
