Commit 0f2229e
move changes after orig_grad_data
vedanuj committed Oct 4, 2023
1 parent 60fa4f0 commit 0f2229e
Showing 1 changed file with 16 additions and 22 deletions.
38 changes: 16 additions & 22 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -1714,39 +1714,33 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         # Switch to FP32 shard after backward.
         self._use_fp32_param_shard([param])
 
-        if self.fp32_reduce_scatter:
-            if param.grad is not None:
-                if param.main_grad is not None:
-                    param.main_grad.add_(param.grad.data.float())
-                else:
-                    param.main_grad = param.grad.data.float()
-            param.grad = None
-
         if not self._require_backward_grad_sync:
             return
 
         # Wait for all work in the current stream to finish, then start the
         # reductions in post_backward stream.
         self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(self._streams["post_backward"]):
-            # orig_grad_data = param.main_grad.data
+            if param.main_grad is not None:
+                orig_grad_data = param.main_grad
+            else:
+                orig_grad_data = param.grad
 
             if self.fp32_reduce_scatter:
-                # Cast grad to FP32. with .main_grad params are already in FP32.
-                if param.main_grad is not None:
-                    orig_grad_data = param.main_grad.data
-                else:
-                    orig_grad_data = param.grad.data.to(torch.float32)
-            else:
-                orig_grad_data = param.grad.data
+                if param.grad is not None:
+                    if param.main_grad is not None:
+                        param.main_grad.copy_(param.grad.float())
+                    else:
+                        param.main_grad = param.grad.float()
+                param.grad = None
 
             if self.gradient_predivide_factor > 1:
                 # Average grad by world_size for consistency with PyTorch DDP.
                 # param.grad.data.div_(self.gradient_predivide_factor)
                 if param.main_grad is not None:
-                    param.main_grad.data.div_(self.gradient_predivide_factor)
+                    param.main_grad.div_(self.gradient_predivide_factor)
                 else:
-                    param.grad.data.div_(self.gradient_predivide_factor)
+                    param.grad.div_(self.gradient_predivide_factor)
 
             if param._is_sharded:
                 assert self._reducer is not None
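
Note: this first hunk is the heart of the commit. The FP32 conversion of the low-precision gradient moves inside the post_backward stream, after orig_grad_data has been captured, so the reference that keeps the gradient tensor alive is taken before param.grad is dropped. Below is a minimal standalone sketch of the resulting ordering; the helper name fp32_grad_handoff is hypothetical, and main_grad is assumed (per the diff) to be an FP32 gradient buffer attached to the parameter. A sketch under those assumptions, not fairscale's implementation:

import torch

def fp32_grad_handoff(param, fp32_reduce_scatter=True):
    # Sketch only, not fairscale code. Step 1: capture the reference that
    # must outlive the asynchronous reduction (mirrors the added
    # orig_grad_data block above).
    if getattr(param, "main_grad", None) is not None:
        orig_grad_data = param.main_grad
    else:
        orig_grad_data = param.grad
    # Step 2: only then fold the low-precision grad into the FP32
    # main_grad and release param.grad, as the added lines above do.
    if fp32_reduce_scatter:
        if param.grad is not None:
            if getattr(param, "main_grad", None) is not None:
                param.main_grad.copy_(param.grad.float())
            else:
                param.main_grad = param.grad.float()
        param.grad = None
    return orig_grad_data

# Example: a bf16 parameter whose grad ends up as an FP32 main_grad.
p = torch.nn.Parameter(torch.zeros(4, dtype=torch.bfloat16))
p.grad = torch.ones(4, dtype=torch.bfloat16)
ref = fp32_grad_handoff(p)
assert p.main_grad.dtype == torch.float32 and p.grad is None
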
@@ -1755,10 +1749,10 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
                 # gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
                 # matter, neglecting rounding.
                 if param.main_grad is not None:
-                    grad = param.main_grad.data
+                    grad = param.main_grad
                     param.main_grad = None
                 else:
-                    grad = param.grad.data
+                    grad = param.grad
                 # Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
                 #
                 # The effect on memory consumption is not usually significant. No extra memory is allocated if this
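
Note: this second hunk makes the reducer hand-off consistent with the new convention: the gradient tensor itself is passed along, and param.main_grad is nulled out so ownership transfers to the in-flight reduction. A sketch of the pattern, with a hypothetical helper name and the actual reducer call omitted:

import torch

def take_grad_for_reduction(param):
    # Sketch only, not fairscale code. Hand the full gradient to the
    # reducer by reference and clear the parameter's own slots so a
    # repeated backward pass cannot interfere with the reduce-scatter.
    if getattr(param, "main_grad", None) is not None:
        grad = param.main_grad   # FP32 buffer, preferred when present
        param.main_grad = None   # ownership moves to the reduction
    else:
        grad = param.grad
    param.grad = None            # keep repeated backwards from racing
    return grad

Dropping the legacy .data accessor here (and in the predivide and post-reduction paths) means downstream code receives the gradient tensor itself rather than a detached alias.
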
@@ -1781,9 +1775,9 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
                 # case grads should be all-reduced here.
                 assert self.world_size == 1
                 if param.main_grad is not None:
-                    self._post_reduction_hook(param, param.main_grad.data)
+                    self._post_reduction_hook(param, param.main_grad)
                 else:
-                    self._post_reduction_hook(param, param.grad.data)
+                    self._post_reduction_hook(param, param.grad)
 
             # After _post_backward_hook returns, orig_grad_data will eventually
             # go out of scope, at which point it could otherwise be freed for
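
Note: the trailing context comment (truncated here) describes the cross-stream lifetime problem that makes capturing orig_grad_data early worthwhile: once the last reference drops, the caching allocator may hand the memory back to the main stream while the post_backward stream is still reducing it. A minimal standalone illustration of that hazard and the usual record_stream remedy; this is generic CUDA-stream code, not fairscale's:

import torch

if torch.cuda.is_available():
    post_backward = torch.cuda.Stream()
    # Allocated on the current (main) stream.
    grad = torch.ones(1 << 20, device="cuda")
    post_backward.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(post_backward):
        grad.div_(8)  # async work queued on the side stream
        # Tell the allocator the side stream is using this memory, so it
        # is not recycled the moment the Python reference goes away.
        grad.record_stream(post_backward)
    del grad  # safe: the allocator waits for post_backward's queued work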