
Fix ValueError for LMTokenMask.log_prob() when the mask has no support #21

Merged: 2 commits into main from gg/good_tokens (Jan 15, 2025)

Conversation

gabegrand (Collaborator)

There are various scenarios in which token masking results in an LMTokenMask distribution with null support. This can occur when a mask rules out every token in the vocabulary, or when multiple sequential observations of mask_dist() are mutually incompatible (i.e., the intersection of their token sets is empty).
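
To make the second failure mode concrete, here is a minimal sketch using plain Python sets (the token ids are illustrative; this is not the hfppl API):

```python
# Two masks that are each satisfiable on their own, but mutually
# incompatible: observing both as True empties the running mask.
mask_a = {5, 9, 42}    # token ids allowed by the first observed mask
mask_b = {7, 13, 99}   # token ids allowed by a second observed mask

model_mask = mask_a & mask_b  # intersection maintained across observations
assert model_mask == set()    # null support: no token satisfies both masks
```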

Currently, this corner case is not explicitly handled and results in a fairly cryptic ValueError: zero-size array to reduction operation maximum which has no identity:

  File "hfppl/hfppl/modeling.py", line 222, in observe
    p = await dist.log_prob(x)
        ^^^^^^^^^^^^^^^^^^^^^^
  File "hfppl/hfppl/distributions/lmcontext.py", line 81, in log_prob
    logprob_good = logsumexp(self.ctx.next_token_logprobs[list(good_tokens)])
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "hfppl/hfppl/util.py", line 7, in logsumexp
    m = np.max(nums)
        ^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/numpy/core/fromnumeric.py", line 2810, in max
    return _wrapreduction(a, np.maximum, 'max', axis, None, out,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".venv/lib/python3.11/site-packages/numpy/core/fromnumeric.py", line 88, in _wrapreduction
    return ufunc.reduce(obj, axis, dtype, out, **passkwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: zero-size array to reduction operation maximum which has no identity

This PR proposes to address the issue by defining LMTokenMask.log_prob(v) := -inf whenever the mask has no support under v (i.e., there are no good_tokens for v). From a practical standpoint, this fix eliminates the zero-size array error above. However, it is still possible to instantiate "degenerate" LMTokenMask distributions; we simply define their density to be 0 everywhere.
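
A minimal sketch of the proposed guard (simplified from lmcontext.py; the function name and structure are illustrative, not the exact merged diff):

```python
from hfppl.util import logsumexp  # hfppl's own logsumexp (see traceback above)

def masked_log_prob(next_token_logprobs, good_tokens):
    # Degenerate case: the mask admits no tokens under the observed value,
    # so the density is 0 everywhere and the log-probability is -inf.
    if len(good_tokens) == 0:
        return float("-inf")
    # Otherwise, marginalize over the tokens consistent with the observation.
    return logsumexp(next_token_logprobs[list(good_tokens)])
```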

From a more theoretical standpoint, this issue is a bit tricky because LMTokenMask objects are closely intertwined with LMContext.model_mask. Other partial/complementary solutions include:

  • Raising an error when LMTokenMask is instantiated with null support (sketched below)
  • Monitoring LMContext.model_mask to ensure that it never becomes the empty set
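
As a rough sketch of the first alternative (hypothetical names; this is not code from the PR):

```python
class NullMaskSupportError(ValueError):
    """Raised when a token mask admits no tokens."""

def validate_mask_support(mask: set):
    # Fail fast at construction time, rather than deferring the failure
    # to a later call to log_prob().
    if not mask:
        raise NullMaskSupportError("LMTokenMask instantiated with null support")
```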

@alex-lew let me know what you think, happy to discuss.

gabegrand requested a review from alex-lew on January 3, 2025 at 22:15
alex-lew (Contributor) commented Jan 4, 2025

Thanks, Gabe!

I think ideally, the model_mask gets set to the empty set and we return -inf, because you shouldn't be able to observe a mask to be true and then sample something inconsistent with the mask.

I think the error should come when attempting to sample or observe the next_token distribution, if the model_mask is empty.
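
For concreteness, a minimal sketch of that check (hypothetical function name and placement; the conversation leaves the design open):

```python
def require_nonempty_model_mask(model_mask: set):
    # Guard to run when the next_token distribution is sampled or observed:
    # an empty model_mask means earlier mask observations were jointly
    # unsatisfiable, so no next token can be generated.
    if len(model_mask) == 0:
        raise RuntimeError(
            "model_mask is empty: previous mask observations ruled out "
            "every token in the vocabulary"
        )
```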

What do you think?

gabegrand (Collaborator, Author)

@alex-lew Thanks for the sanity check. Yes, I agree; I've updated the code so that observing a mask to be True always updates self.ctx.model_mask, even when the mask is degenerate.
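
In set terms, the updated behavior amounts to unconditionally intersecting the observed mask into the running mask (a simplified sketch, not the exact code):

```python
def observe_mask_true(model_mask: set, mask: set) -> set:
    # Observing the mask as True always narrows the running model_mask,
    # even if the intersection turns out to be empty (a degenerate mask).
    return model_mask & mask
```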

In terms of when an error should be raised, I'm a bit torn. In principle, we could allow programs to observe() masks that rule out all tokens; however, I can't think of any case where this is intentional or desirable. In practice, punting the error check to the next sample/observation of LMNextToken will just make these corner cases harder to debug. So my inclination would be to keep the error checking for null masks encapsulated within LMTokenMask, though I'm not 100% sold either way. What do you think?

alex-lew merged commit 7c68968 into main on January 15, 2025.
gabegrand deleted the gg/good_tokens branch on January 15, 2025 at 15:32.