
Question about potentially incomplete usage of 'logits_mask' #131

Open
ruohuali opened this issue Apr 13, 2023 · 1 comment

Comments

@ruohuali

Hi @HobbitLong,
Thanks for this great work!
My question concerns these two lines of the loss:

exp_logits = torch.exp(logits) * logits_mask

log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

Here, logits_mask is applied only to exp_logits and not to logits itself. I cannot figure out the mathematical reason for this. Could you please shed some light?

@forgotton-wind

I have also been studying the code for this loss today, and my understanding is as follows:
logits_mask is used to build the denominator (positives and negatives).
mask is used to build the numerator (positives only).
You can see that in line 92, mask is applied to log_prob:

mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
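The interplay can be checked numerically. The sketch below is a minimal, self-contained reconstruction (not the repository's actual code): it assumes a small batch with hypothetical labels, builds logits_mask as a matrix with zeros on the diagonal, and verifies that masking exp_logits after exponentiation gives the same log-probabilities, wherever logits_mask is 1, as the standard alternative of setting the excluded logits to -inf before a log-softmax. The diagonal entries of log_prob are left wrong, but mask (which is zero on the diagonal) removes them from the numerator, so the final loss is unaffected.

```python
import torch

torch.manual_seed(0)
n = 4
logits = torch.randn(n, n)

# exclude self-contrast: 0 on the diagonal, 1 elsewhere
logits_mask = 1.0 - torch.eye(n)

# hypothetical labels: samples 0,1 share a class, as do 2,3
labels = torch.tensor([0, 0, 1, 1])
mask = (labels.unsqueeze(0) == labels.unsqueeze(1)).float() * logits_mask

# the loss as written: mask AFTER exponentiation
exp_logits = torch.exp(logits) * logits_mask
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

# reference: mask BEFORE exponentiation with -inf, then log-softmax
neg_inf_logits = logits.masked_fill(logits_mask == 0, float('-inf'))
log_prob_ref = torch.log_softmax(neg_inf_logits, dim=1)

# the two agree at every position where logits_mask == 1
assert torch.allclose(log_prob[logits_mask.bool()],
                      log_prob_ref[logits_mask.bool()], atol=1e-6)

# the diagonal entries of log_prob are meaningless, but mask
# zeroes them out, so they never reach the averaged numerator
mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
```

In other words, applying logits_mask to logits directly would be redundant: the only positions where the unmasked logits leak into log_prob are ones that mask later discards.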
