From 582de087c3bb09cb281264a157e229f647779789 Mon Sep 17 00:00:00 2001 From: Gabe Grand Date: Fri, 3 Jan 2025 16:42:28 -0500 Subject: [PATCH 1/2] Fix ValueError for when the mask has no support --- hfppl/distributions/lmcontext.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/hfppl/distributions/lmcontext.py b/hfppl/distributions/lmcontext.py index ba0626a..9fea4f3 100644 --- a/hfppl/distributions/lmcontext.py +++ b/hfppl/distributions/lmcontext.py @@ -78,6 +78,9 @@ async def log_prob(self, v): if v else self.ctx.model_mask - self.mask ) + if len(good_tokens) == 0: + # If there are no good tokens, the log probability of v under the mask is -inf + return float("-inf") bad_tokens = [i for i in self.ctx.model_mask if i not in good_tokens] logprob_good = logsumexp(self.ctx.next_token_logprobs[list(good_tokens)]) self.ctx.next_token_logprobs[bad_tokens] = float("-inf") From ec6b35a14b5369b8056f1e3977f8311ac14d5c0d Mon Sep 17 00:00:00 2001 From: Gabe Grand Date: Mon, 6 Jan 2025 13:30:02 -0500 Subject: [PATCH 2/2] set self.ctx.model_mask to empty set when len(good_tokens) == 0 --- hfppl/distributions/lmcontext.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/hfppl/distributions/lmcontext.py b/hfppl/distributions/lmcontext.py index 9fea4f3..233892a 100644 --- a/hfppl/distributions/lmcontext.py +++ b/hfppl/distributions/lmcontext.py @@ -80,9 +80,11 @@ async def log_prob(self, v): ) if len(good_tokens) == 0: # If there are no good tokens, the log probability of v under the mask is -inf - return float("-inf") + logprob_good = float("-inf") + else: + logprob_good = logsumexp(self.ctx.next_token_logprobs[list(good_tokens)]) + bad_tokens = [i for i in self.ctx.model_mask if i not in good_tokens] - logprob_good = logsumexp(self.ctx.next_token_logprobs[list(good_tokens)]) self.ctx.next_token_logprobs[bad_tokens] = float("-inf") self.ctx.next_token_logprobs -= logprob_good self.ctx.model_mask = good_tokens