WIP: bug in beam search kv cache
Waino committed Sep 9, 2024
1 parent 803fb58 commit 8fd079d
Showing 6 changed files with 473 additions and 15 deletions.
30 changes: 18 additions & 12 deletions mammoth/tests/test_beam_search.py
@@ -61,7 +61,8 @@ def test_advance_with_all_repeats_gets_blocked(self):
word_probs[:, repeat_idx] = 0

attns = torch.randn(1, batch_sz * beam_sz, 53)
-beam.advance(word_probs, attns)
+beam.set_cache(attns)
+beam.advance(word_probs)

if i < ngram_repeat:
# before repeat, scores are either 0 or -inf
@@ -130,7 +131,8 @@ def test_advance_with_some_repeats_gets_blocked(self):
# continue pushing around what beam 1 predicts
word_probs[1::beam_sz, repeat_idx + i + 1] = 0
attns = torch.randn(1, batch_sz * beam_sz, 53)
-beam.advance(word_probs, attns)
+beam.set_cache(attns)
+beam.advance(word_probs)
if i < ngram_repeat:
self.assertFalse(beam.topk_log_probs[:, 0].eq(self.BLOCKED_SCORE).any())
self.assertFalse(beam.topk_log_probs[:, 1].eq(self.BLOCKED_SCORE).any())
@@ -196,7 +198,8 @@ def test_repeating_excluded_index_does_not_die(self):
# predict the allowed-repeat again in beam 2
word_probs[2::beam_sz, repeat_idx_ignored] = 0
attns = torch.randn(1, batch_sz * beam_sz, 53)
-beam.advance(word_probs, attns)
+beam.set_cache(attns)
+beam.advance(word_probs)
if i < ngram_repeat:
self.assertFalse(beam.topk_log_probs[:, 0].eq(self.BLOCKED_SCORE).any())
self.assertFalse(beam.topk_log_probs[:, 1].eq(self.BLOCKED_SCORE).any())
@@ -245,7 +248,6 @@ def test_doesnt_predict_eos_if_shorter_than_min_len(self):
device=device_init.device,
)
beam.initialize()
-all_attns = []
for i in range(min_length + 4):
# non-interesting beams are going to get dummy values
word_probs = torch.full((batch_sz * beam_sz, n_words), -float('inf'))
@@ -265,16 +267,18 @@ def test_doesnt_predict_eos_if_shorter_than_min_len(self):
word_probs[beam_idx::beam_sz, j] = score

attns = torch.randn(1, batch_sz * beam_sz, 53)
-all_attns.append(attns)
-beam.advance(word_probs, attns)
+beam.set_cache(attns)
+beam.advance(word_probs)
if i < min_length:
expected_score_dist = (i + 1) * valid_score_dist[1:].unsqueeze(0)
# Note that when batch_sz is > 1, expected is broadcast across the batch
self.assertTrue(beam.topk_log_probs.allclose(expected_score_dist))
+self.assertTrue(beam.cache.shape == torch.Size([1, batch_sz * beam_sz, 53]))
elif i == min_length:
# now the top beam has ended and no others have
self.assertTrue(beam.is_finished[:, 0].eq(1).all())
self.assertTrue(beam.is_finished[:, 1:].eq(0).all())
+self.assertTrue(beam.cache.shape == torch.Size([1, batch_sz * (beam_sz - 1), 53]))
else: # i > min_length
# not of interest, but want to make sure it keeps running
# since only beam 0 terminates and n_best = 2
@@ -337,7 +341,8 @@ def test_beam_is_done_when_n_best_beams_eos_using_min_length(self):
word_probs[beam_idx::beam_sz, j] = score

attns = torch.randn(1, batch_sz * beam_sz, 53)
-beam.advance(word_probs, attns)
+beam.set_cache(attns)
+beam.advance(word_probs)
if i < min_length:
self.assertFalse(beam.done)
elif i == min_length:
@@ -408,7 +413,8 @@ def test_beam_returns_attn_with_correct_length(self):
word_probs[beam_idx::beam_sz, j] = score

attns = torch.randn(1, batch_sz * beam_sz, 53)
-beam.advance(word_probs, attns)
+beam.set_cache(attns)
+beam.advance(word_probs)
if i < min_length:
self.assertFalse(beam.done)
# no top beams are finished yet
@@ -465,7 +471,7 @@ def init_step(self, beam, expected_len_pen):
expected_beam_scores, expected_preds_0 = new_scores.view(self.BATCH_SZ, self.BEAM_SZ * self.N_WORDS).topk(
self.BEAM_SZ, dim=-1
)
-beam.advance(deepcopy(init_scores), self.random_attn())
+beam.advance(deepcopy(init_scores))
self.assertTrue(beam.topk_log_probs.allclose(expected_beam_scores))
self.assertTrue(beam.topk_ids.equal(expected_preds_0))
self.assertFalse(beam.is_finished.any())
@@ -489,7 +495,7 @@ def first_step(self, beam, expected_beam_scores, expected_len_pen):
)
scores_1 = scores_1.repeat(self.BATCH_SZ, 1)

-beam.advance(deepcopy(scores_1), self.random_attn())
+beam.advance(deepcopy(scores_1))

new_scores = scores_1 + expected_beam_scores.view(-1).unsqueeze(1)
expected_beam_scores, unreduced_preds = new_scores.view(self.BATCH_SZ, self.BEAM_SZ * self.N_WORDS).topk(
@@ -525,7 +531,7 @@ def second_step(self, beam, expected_beam_scores, expected_len_pen):
)
scores_2 = scores_2.repeat(self.BATCH_SZ, 1)

-beam.advance(deepcopy(scores_2), self.random_attn())
+beam.advance(deepcopy(scores_2))

# ended beam 2 shouldn't continue
expected_beam_scores[:, 2::self.BEAM_SZ] = self.DEAD_SCORE
@@ -568,7 +574,7 @@ def third_step(self, beam, expected_beam_scores, expected_len_pen):
)
scores_3 = scores_3.repeat(self.BATCH_SZ, 1)

-beam.advance(deepcopy(scores_3), self.random_attn())
+beam.advance(deepcopy(scores_3))

expected_beam_scores[:, 0::self.BEAM_SZ] = self.DEAD_SCORE
new_scores = scores_3 + expected_beam_scores.view(-1).unsqueeze(1)
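The updated tests drive the strategy with a two-step convention: the decoder cache is handed over via set_cache() before advance() is called with only the log-probabilities, and the new shape assertions expect update_finished() to prune beam.cache rather than drop it. A minimal sketch of one decoding step as these tests exercise it (attns is the stand-in cache tensor the tests use, not a real layer cache):

    # one decoding step, as the tests above drive it
    attns = torch.randn(1, batch_sz * beam_sz, 53)  # stand-in cache entry
    beam.set_cache(attns)       # the strategy keeps its own view of the cache
    beam.advance(word_probs)    # advance now takes only log-probabilities
    if beam.is_finished.any():
        beam.update_finished()  # the tests expect beam.cache to shrink here, not disappear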
5 changes: 4 additions & 1 deletion mammoth/translate/beam_search.py
@@ -234,6 +234,9 @@ def update_finished(self):

_B_new = non_finished.shape[0]
self.remove_finished_batches(_B_new, _B_old, non_finished, predictions, attention, step)
+if self.cache is not None:
+    # FIXME: self.cache is a list of LayerIntermediates. Reach in and manipulate it?
+    self.cache = None

def remove_finished_batches(self, _B_new, _B_old, non_finished, predictions, attention, step):
# Remove finished batches for the next step.
@@ -253,7 +256,7 @@ def remove_finished_batches(self, _B_new, _B_old, non_finished, predictions, attention, step):
step - 1, _B_new * self.beam_size, inp_seq_len
)

-def advance(self, log_probs, new_cache):
+def advance(self, log_probs):
vocab_size = log_probs.size(-1)

# using integer division to get an integer _B without casting
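The FIXME above is the open question in this WIP commit: update_finished() currently sets self.cache to None, while the new test assertions expect the cache to survive with its flattened batch*beam dimension pruned, in the spirit of what remove_finished_batches() already does for alive_attn. A hedged sketch of such pruning, assuming each per-layer cache entry exposes key/value tensors whose first dimension is the flattened batch*beam axis (the keys/values attribute names are placeholders, not the actual LayerIntermediates layout):

    def prune_cache(cache, non_finished, beam_size):
        # expand the surviving batch indices to the flattened batch*beam axis
        offsets = torch.arange(beam_size, device=non_finished.device)
        keep = (non_finished.unsqueeze(1) * beam_size + offsets).view(-1)
        for layer in cache:
            # placeholder attributes; the real per-layer object may nest these deeper
            layer.keys = layer.keys.index_select(0, keep)
            layer.values = layer.values.index_select(0, keep)
        return cache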
2 changes: 1 addition & 1 deletion mammoth/translate/decode_strategy.py
@@ -273,7 +273,7 @@ def maybe_update_target_prefix(self, select_index):
return
self.target_prefix = self.target_prefix.index_select(0, select_index)

-def advance(self, logits, new_cache):
+def advance(self, logits):
"""DecodeStrategy subclasses should override :func:`advance()`.
Advance is used to update ``self.alive_seq``, ``self.is_finished``,
2 changes: 1 addition & 1 deletion mammoth/translate/translator.py
@@ -906,7 +906,7 @@ def _translate_batch_with_strategy(self, batch, src_vocabs, decode_strategy):
logits = logits[:, -1]
log_probs = torch.log_softmax(logits, dim=-1)

-decode_strategy.advance(log_probs, new_cache)
+decode_strategy.advance(log_probs)
any_finished = decode_strategy.is_finished.any()
if any_finished:
decode_strategy.update_finished()
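With the new signature, the translator loop no longer threads new_cache through advance(); presumably the cache is handed to the strategy separately, as the tests above do. A hypothetical placement (a set_cache() call on the strategy is assumed here; this WIP diff does not yet show it being wired in):

    log_probs = torch.log_softmax(logits, dim=-1)
    decode_strategy.set_cache(new_cache)   # assumed hand-over point, not part of this diff
    decode_strategy.advance(log_probs)
    if decode_strategy.is_finished.any():
        decode_strategy.update_finished()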