-
Notifications
You must be signed in to change notification settings - Fork 200
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improvements to learned round #1107
Conversation
00f496c
to
acee4d3
Compare
loss, loss_components = block_loss(quant_outs, fp_outs) | ||
else: | ||
# Run block forward to obtain quant outputs | ||
quant_outs = block_forward(block, inputs) | ||
fp_outs = send_to_device(fp_outs, quant_outs.device) | ||
loss, loss_components = block_loss(quant_outs.to(torch.float32), fp_outs.to(torch.float32)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The code for each condition is almost exactly the same. Maybe we could have autocast(enabled=use_amp, ...)
and just have a conditional for upcasting the outputs to float32 before computing the loss to avoid repeting block_forward/send_to_device/block_loss.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I understand it's less code, but I think it will be more confusing. I am leaving as it is. Extra verbosity for clarity, I am happy with that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, I've included a couple of minor comments.
a15163b
to
a51322e
Compare
a51322e
to
1ab9a0e
Compare
Reason for this PR
Fix entrypoint for learned scale
Fix training with float32 + amp
Testing Summary
NA
Risk Highlight
Checklist
dev
branch.