Commit

Renormalize LMNextToken.sample() probs to fix floating point errors
gabegrand committed Jan 3, 2025
1 parent f172d8b commit f40ee1b
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions hfppl/distributions/lmcontext.py
@@ -27,6 +27,7 @@ async def log_prob(self, x):

     async def sample(self):
         probs = np.exp(self.ctx.next_token_logprobs)
+        probs /= np.sum(probs)  # Renormalize to fix floating point errors
         token_id = np.random.choice(len(probs), p=(probs))
         self.ctx.tokens.append(token_id)
         logprob = self.ctx.next_token_logprobs[token_id]
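
Why the added line matters, as a minimal standalone sketch: `np.random.choice` raises `ValueError` when the entries of `p` do not sum to 1 within a small tolerance, and exponentiating log-probabilities can leave the sum slightly off. The helper `sample_token`, the three-token distribution, and the `1e-6` drift below are assumptions chosen for illustration; they are not taken from hfppl.

```python
import numpy as np


def sample_token(logprobs, renormalize=True):
    """Hypothetical helper mirroring the sampling step in the diff."""
    probs = np.exp(np.asarray(logprobs, dtype=np.float64))
    if renormalize:
        # Same idea as the added line: divide by the sum so p sums to 1.
        probs = probs / np.sum(probs)
    return int(np.random.choice(len(probs), p=probs))


# Synthetic log-probabilities whose exponentials sum to roughly 1 + 1e-6,
# standing in for accumulated floating point error (illustrative value).
drifted_logprobs = np.log(np.array([0.5, 0.3, 0.2]) * (1.0 + 1e-6))

# Without renormalization, np.random.choice rejects the distribution.
try:
    sample_token(drifted_logprobs, renormalize=False)
    raised = False
except ValueError:
    raised = True

# With renormalization, sampling succeeds and returns a valid index.
token_id = sample_token(drifted_logprobs, renormalize=True)
```

Note that in the commit only the sampling probabilities are renormalized; the returned `logprob` is still read from the original `next_token_logprobs`.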