diff --git a/hfppl/distributions/lmcontext.py b/hfppl/distributions/lmcontext.py index ba0626a..233892a 100644 --- a/hfppl/distributions/lmcontext.py +++ b/hfppl/distributions/lmcontext.py @@ -78,8 +78,13 @@ async def log_prob(self, v): if v else self.ctx.model_mask - self.mask ) + if len(good_tokens) == 0: + # If there are no good tokens, the log probability of v under the mask is -inf + logprob_good = float("-inf") + else: + logprob_good = logsumexp(self.ctx.next_token_logprobs[list(good_tokens)]) + bad_tokens = [i for i in self.ctx.model_mask if i not in good_tokens] - logprob_good = logsumexp(self.ctx.next_token_logprobs[list(good_tokens)]) self.ctx.next_token_logprobs[bad_tokens] = float("-inf") self.ctx.next_token_logprobs -= logprob_good self.ctx.model_mask = good_tokens