Add main_grad #1140

Open · wants to merge 4 commits into base: ngoyal_changes_for_pp_fp8
21 changes: 18 additions & 3 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -1713,6 +1713,13 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:

         # Switch to FP32 shard after backward.
         self._use_fp32_param_shard([param])
+        if self.fp32_reduce_scatter:
+            if getattr(param, "main_grad", None) is None:
+                param.main_grad = param.grad.to(torch.float32)
+            else:
+                param.main_grad.add_(param.grad.data)
+
+            param.grad = None
 
         if not self._require_backward_grad_sync:
             return
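For readers skimming the diff, here is a minimal, self-contained sketch of the FP32 master-gradient accumulation pattern the added lines follow (illustrative only, outside of FSDP: the main_grad attribute name mirrors the diff, everything else is invented for the example):

```python
import torch

def accumulate_main_grad(param: torch.nn.Parameter) -> None:
    """Fold the current low-precision grad into an FP32 main_grad."""
    if getattr(param, "main_grad", None) is None:
        # First backward: allocate the FP32 master gradient.
        param.main_grad = param.grad.to(torch.float32)
    else:
        # Subsequent backwards: accumulate in FP32 to avoid low-precision drift.
        param.main_grad.add_(param.grad)
    # Drop the low-precision grad; autograd will re-create it next backward.
    param.grad = None

# Two simulated backward passes on one BF16 parameter.
p = torch.nn.Parameter(torch.zeros(8, dtype=torch.bfloat16))
for _ in range(2):
    p.grad = torch.ones_like(p)  # stand-in for a real backward pass
    accumulate_main_grad(p)
print(p.main_grad.dtype, p.main_grad.sum())  # -> torch.float32, tensor(16.)
```

In the diff, the same pattern runs per FSDP flat parameter inside _post_backward_hook, gated on self.fp32_reduce_scatter.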
@@ -1721,23 +1728,31 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         # reductions in post_backward stream.
         self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(self._streams["post_backward"]):
-            orig_grad_data = param.grad.data
 
             if self.fp32_reduce_scatter:
                 # Cast grad to FP32.
                 param.grad.data = param.grad.data.float()

I don't think this is right, since param.grad will be None from L1722.

Overall, this PR creates main_grad for the flat parameters, while what we need is a main_grad that is visible to TE modules. So we probably need to change FlatParameter as well?

Is this based on one of Naman's branches?
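A tiny repro of the param.grad concern above (illustrative, not fairscale code; only the new main_grad handling and the pre-existing cast are mimicked):

```python
import torch

param = torch.nn.Parameter(torch.zeros(4, dtype=torch.float16))
param.grad = torch.ones_like(param)

# What the new hunk does when fp32_reduce_scatter is set.
param.main_grad = param.grad.to(torch.float32)
param.grad = None

try:
    # The pre-existing FP32 cast further down the hook.
    param.grad.data = param.grad.data.float()
except AttributeError as err:
    print("param.grad is already None:", err)
```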


I have a branch where I am adding param.main_grad to FlatParams to enable fused wgrad accumulation. Here is the PR: #1142

Member Author:

Thanks! Feel free to ignore the changes in this PR. Still learning about FlatParams etc.
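For context on the FlatParams direction, a rough sketch of the general idea only (not the implementation in #1142; the helper name and buffer layout here are made up): per-parameter main_grad attributes can be exposed as views into one flat FP32 buffer, so a module that accumulates weight gradients into param.main_grad writes straight into the flat storage.

```python
import torch

def attach_main_grad_views(flat_main_grad: torch.Tensor, params: list) -> None:
    """Give each parameter a main_grad view into one flat FP32 buffer."""
    offset = 0
    for p in params:
        numel = p.numel()
        p.main_grad = flat_main_grad[offset:offset + numel].view_as(p)
        offset += numel

# Two toy parameters backed by a single FP32 gradient buffer.
params = [torch.nn.Parameter(torch.zeros(2, 3, dtype=torch.bfloat16)),
          torch.nn.Parameter(torch.zeros(5, dtype=torch.bfloat16))]
flat = torch.zeros(sum(p.numel() for p in params), dtype=torch.float32)
attach_main_grad_views(flat, params)

params[0].main_grad.add_(1.0)  # e.g. a module's fused wgrad accumulation step
print(flat[:6])                # the flat buffer received the update
```

Whether #1142 takes this route is not shown here; see that PR for the actual FlatParameter changes.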


+            orig_grad_data = param.grad.data
Member Author:

Moved here so that orig_grad_data is FP32. This was from #1139 (comment)
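A hedged sketch of the ordering this refers to, assuming (as elsewhere in the hook, not shown in this diff) that the tensor captured in orig_grad_data is later passed to record_stream to keep it alive while the post_backward stream works on it; the snippet below is standalone and only mimics that pattern:

```python
import torch

if torch.cuda.is_available():
    post_backward = torch.cuda.Stream()
    param = torch.nn.Parameter(torch.zeros(4, device="cuda", dtype=torch.float16))
    param.grad = torch.ones_like(param)

    post_backward.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(post_backward):
        # Cast first, then capture: orig_grad_data now aliases the FP32
        # storage that the divide / reduce-scatter path actually reads.
        param.grad.data = param.grad.data.float()
        orig_grad_data = param.grad.data
        param.grad.data.div_(2)  # stand-in for predivide + reduce-scatter
    # Protect the FP32 tensor (not the discarded FP16 one) until the
    # post_backward stream is done with it.
    orig_grad_data.record_stream(post_backward)
```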


             if self.gradient_predivide_factor > 1:
                 # Average grad by world_size for consistency with PyTorch DDP.
-                param.grad.data.div_(self.gradient_predivide_factor)
+                if getattr(param, "main_grad", None) is not None:
+                    param.main_grad.data.div_(self.gradient_predivide_factor)
+                else:
+                    param.grad.data.div_(self.gradient_predivide_factor)

             if param._is_sharded:
                 assert self._reducer is not None
                 # Save the unsharded grad for reduction. We will asynchronously accumulate the reduced gradient into
                 # param._saved_grad_shard. If this FSDP module was called multiple times it's possible that multiple
                 # gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
                 # matter, neglecting rounding.
-                grad = param.grad.data
+                if getattr(param, "main_grad", None) is not None:
+                    grad = param.main_grad.data
+                    param.main_grad = None
+                else:
+                    grad = param.grad.data
                 # Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
                 #
                 # The effect on memory consumption is not usually significant. No extra memory is allocated if this