
Commit

changes to keep reduced grad in fp32 (#1152)
vedanuj authored Dec 4, 2023
1 parent 9791920 commit d346146
Showing 1 changed file with 13 additions and 13 deletions.
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -1715,10 +1715,10 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
             self._use_fp32_param_shard([param])

         if self.fp32_reduce_scatter:
-            if getattr(param, "main_grad", None) is None:
-                param.main_grad = param.grad.to(torch.float32)
+            if getattr(param, "unsharded_main_grad", None) is None:
+                param.unsharded_main_grad = param.grad.to(torch.float32)
             else:
-                param.main_grad.add_(param.grad.data)
+                param.unsharded_main_grad.add_(param.grad.data)

             param.grad = None

@@ -1732,14 +1732,14 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:

             if self.fp32_reduce_scatter:
                 # Cast grad to FP32.
-                orig_grad_data = param.main_grad.data
+                orig_grad_data = param.unsharded_main_grad.data
             else:
                 orig_grad_data = param.grad.data

             if self.gradient_predivide_factor > 1:
                 # Average grad by world_size for consistency with PyTorch DDP.
-                if getattr(param, "main_grad", None) is not None:
-                    param.main_grad.data.div_(self.gradient_predivide_factor)
+                if getattr(param, "unsharded_main_grad", None) is not None:
+                    param.unsharded_main_grad.data.div_(self.gradient_predivide_factor)
                 else:
                     param.grad.data.div_(self.gradient_predivide_factor)

@@ -1749,9 +1749,9 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
                 # param._saved_grad_shard. If this FSDP module was called multiple times it's possible that multiple
                 # gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
                 # matter, neglecting rounding.
-                if getattr(param, "main_grad", None) is not None:
-                    grad = param.main_grad.data
-                    param.main_grad = None
+                if getattr(param, "unsharded_main_grad", None) is not None:
+                    grad = param.unsharded_main_grad.data
+                    param.unsharded_main_grad = None
                 else:
                     grad = param.grad.data
                     param.grad = None
@@ -1776,8 +1776,8 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
                 # world_size == 1. This could be relaxed in the future, in which
                 # case grads should be all-reduced here.
                 assert self.world_size == 1
-                if getattr(param, "main_grad", None) is not None:
-                    self._post_reduction_hook(param, param.main_grad)
+                if getattr(param, "unsharded_main_grad", None) is not None:
+                    self._post_reduction_hook(param, param.unsharded_main_grad)
                 else:
                     self._post_reduction_hook(param, param.grad)

@@ -1805,7 +1805,7 @@ def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
         # non-blocking. The downside is a bit more D2H transfer in that case.
         if self.fp32_reduce_scatter:
             orig_param_grad_data = reduced_grad.data
-            reduced_grad.data = reduced_grad.data.to(dtype=param.data.dtype)
+            # reduced_grad.data = reduced_grad.data.to(dtype=param.data.dtype)
             # Don't let this memory get reused until after the transfer.
             orig_param_grad_data.record_stream(torch.cuda.current_stream())

@@ -1907,7 +1907,7 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
                    if p.shape != p._saved_grad_shard.shape:
                        self._use_fp32_param_shard([p])
                    if p._saved_grad_shard.dtype != p.dtype:
-                       p.grad = p._saved_grad_shard.to(p.dtype)
+                       p.main_grad = p._saved_grad_shard
                    else:
                        p.grad = p._saved_grad_shard

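The first four hunks accumulate gradients into a float32 `unsharded_main_grad` buffer instead of accumulating in the parameter's low-precision dtype. Below is a minimal standalone sketch of that pattern, not fairscale's FSDP code: the toy parameter, the loop, and the helper name `accumulate_grad_fp32` are assumptions for illustration; only the attribute name `unsharded_main_grad` is taken from the diff.

# Minimal sketch (not fairscale code): accumulate low-precision grads into an
# fp32 "unsharded_main_grad" buffer, mirroring the post-backward hook above.
import torch


def accumulate_grad_fp32(param: torch.nn.Parameter) -> None:
    """Copy or accumulate param.grad into a float32 buffer, then clear param.grad."""
    if param.grad is None:
        return
    if getattr(param, "unsharded_main_grad", None) is None:
        # First backward for this parameter: make an fp32 copy of the grad.
        param.unsharded_main_grad = param.grad.to(torch.float32)
    else:
        # Later backwards: accumulate in fp32 to avoid low-precision rounding.
        param.unsharded_main_grad.add_(param.grad.data)
    # Drop the low-precision grad so it is not double-counted.
    param.grad = None


# Toy usage with a bfloat16 parameter and hand-made "gradients".
p = torch.nn.Parameter(torch.zeros(4, dtype=torch.bfloat16))
for _ in range(2):  # stand-in for two micro-batches
    p.grad = torch.ones_like(p)
    accumulate_grad_fp32(p)

assert p.unsharded_main_grad.dtype == torch.float32
assert p.unsharded_main_grad.tolist() == [2.0, 2.0, 2.0, 2.0]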
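The last two hunks stop casting the reduced gradient back to the parameter's dtype and instead expose the fp32 shard through a `main_grad` attribute. How that buffer is consumed is outside this diff; the sketch below shows one plausible consumer, a plain SGD step that prefers `main_grad` over `grad`. It is an assumption for illustration, not fairscale's or any caller's actual optimizer.

# Hypothetical consumer (not from this repo): an SGD step that reads the fp32
# main_grad buffer when present, falling back to the regular grad otherwise.
import torch


def sgd_step_with_fp32_grads(params, lr: float = 0.1) -> None:
    with torch.no_grad():
        for p in params:
            grad = getattr(p, "main_grad", None)
            if grad is None:
                grad = p.grad
            if grad is None:
                continue
            # Do the update in fp32, then write back in the parameter's dtype.
            updated = p.detach().to(torch.float32) - lr * grad.to(torch.float32)
            p.copy_(updated.to(p.dtype))


# Toy usage: a bfloat16 parameter whose reduced gradient was kept in fp32.
p = torch.nn.Parameter(torch.zeros(3, dtype=torch.bfloat16))
p.main_grad = torch.ones(3, dtype=torch.float32)  # as left by the hooks above
sgd_step_with_fp32_grads([p])
print(p)  # roughly -0.1 in each slot, stored in bfloat16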
