-
Notifications
You must be signed in to change notification settings - Fork 283
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add main_grad #1140
Open
jianyuh
wants to merge
4
commits into
ngoyal_changes_for_pp_fp8
Choose a base branch
from
ngoyal_changes_for_pp_fp8_fix_handle_grad_main
base: ngoyal_changes_for_pp_fp8
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Add main_grad #1140
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1713,6 +1713,13 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: | |
|
||
# Switch to FP32 shard after backward. | ||
self._use_fp32_param_shard([param]) | ||
if self.fp32_reduce_scatter: | ||
if getattr(param, "main_grad", None) is None: | ||
param.main_grad = param.grad.to(torch.float32) | ||
else: | ||
param.main_grad.add_(param.grad.data) | ||
|
||
param.grad = None | ||
|
||
if not self._require_backward_grad_sync: | ||
return | ||
|
@@ -1721,23 +1728,31 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: | |
# reductions in post_backward stream. | ||
self._streams["post_backward"].wait_stream(torch.cuda.current_stream()) | ||
with torch.cuda.stream(self._streams["post_backward"]): | ||
orig_grad_data = param.grad.data | ||
|
||
if self.fp32_reduce_scatter: | ||
# Cast grad to FP32. | ||
param.grad.data = param.grad.data.float() | ||
|
||
orig_grad_data = param.grad.data | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Move here to make |
||
|
||
if self.gradient_predivide_factor > 1: | ||
# Average grad by world_size for consistency with PyTorch DDP. | ||
param.grad.data.div_(self.gradient_predivide_factor) | ||
if getattr(param, "main_grad", None) is not None: | ||
param.main_grad.data.div_(self.gradient_predivide_factor) | ||
else: | ||
param.grad.data.div_(self.gradient_predivide_factor) | ||
|
||
if param._is_sharded: | ||
assert self._reducer is not None | ||
# Save the unsharded grad for reduction. We will asynchronously accumulate the reduced gradient into | ||
# param._saved_grad_shard. If this FSDP module was called multiple times it's possible that multiple | ||
# gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't | ||
# matter, neglecting rounding. | ||
grad = param.grad.data | ||
if getattr(param, "main_grad", None) is not None: | ||
grad = param.main_grad.data | ||
param.main_grad = None | ||
else: | ||
grad = param.grad.data | ||
# Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction. | ||
# | ||
# The effect on memory consumption is not usually significant. No extra memory is allocated if this | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't feel this is right since param.grad will be None from L1722.
Overall, this PR creates main_grad for flat parameters while what we need to do is main_grad visible to TE modules. So probably we need to change FlatParameter as well?
Is this based on one of Naman's branches?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have a branch where i am adding param.main_grad to FlatParams to enable fuse wgrad accumulation. here is the PR : #1142
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Feel free to ignore the changes in this PR. Still learning about FlatParams etc.