From d346146b3c4d9e17b0381751054e15f5f651a57f Mon Sep 17 00:00:00 2001
From: Vedanuj Goswami
Date: Tue, 5 Dec 2023 03:46:22 +0530
Subject: [PATCH] changes to keep reduced grad in fp32 (#1152)

---
 .../fully_sharded_data_parallel.py | 26 +++++++++----------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
index 3960bbbae..cc994d1a8 100644
--- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -1715,10 +1715,10 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
             self._use_fp32_param_shard([param])
 
         if self.fp32_reduce_scatter:
-            if getattr(param, "main_grad", None) is None:
-                param.main_grad = param.grad.to(torch.float32)
+            if getattr(param, "unsharded_main_grad", None) is None:
+                param.unsharded_main_grad = param.grad.to(torch.float32)
             else:
-                param.main_grad.add_(param.grad.data)
+                param.unsharded_main_grad.add_(param.grad.data)
 
             param.grad = None
 
@@ -1732,14 +1732,14 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
             if self.fp32_reduce_scatter:
                 # Cast grad to FP32.
-                orig_grad_data = param.main_grad.data
+                orig_grad_data = param.unsharded_main_grad.data
             else:
                 orig_grad_data = param.grad.data
 
             if self.gradient_predivide_factor > 1:
                 # Average grad by world_size for consistency with PyTorch DDP.
-                if getattr(param, "main_grad", None) is not None:
-                    param.main_grad.data.div_(self.gradient_predivide_factor)
+                if getattr(param, "unsharded_main_grad", None) is not None:
+                    param.unsharded_main_grad.data.div_(self.gradient_predivide_factor)
                 else:
                     param.grad.data.div_(self.gradient_predivide_factor)
 
@@ -1749,9 +1749,9 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
                 # param._saved_grad_shard. If this FSDP module was called multiple times it's possible that multiple
                 # gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
                 # matter, neglecting rounding.
-                if getattr(param, "main_grad", None) is not None:
-                    grad = param.main_grad.data
-                    param.main_grad = None
+                if getattr(param, "unsharded_main_grad", None) is not None:
+                    grad = param.unsharded_main_grad.data
+                    param.unsharded_main_grad = None
                 else:
                     grad = param.grad.data
                     param.grad = None
@@ -1776,8 +1776,8 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
                 # world_size == 1. This could be relaxed in the future, in which
                 # case grads should be all-reduced here.
                 assert self.world_size == 1
-                if getattr(param, "main_grad", None) is not None:
-                    self._post_reduction_hook(param, param.main_grad)
+                if getattr(param, "unsharded_main_grad", None) is not None:
+                    self._post_reduction_hook(param, param.unsharded_main_grad)
                 else:
                     self._post_reduction_hook(param, param.grad)
 
@@ -1805,7 +1805,7 @@ def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
         # non-blocking. The downside is a bit more D2H transfer in that case.
         if self.fp32_reduce_scatter:
             orig_param_grad_data = reduced_grad.data
-            reduced_grad.data = reduced_grad.data.to(dtype=param.data.dtype)
+            # reduced_grad.data = reduced_grad.data.to(dtype=param.data.dtype)
             # Don't let this memory get reused until after the transfer.
             orig_param_grad_data.record_stream(torch.cuda.current_stream())
 
@@ -1907,7 +1907,7 @@ def _finalize_parameters(fsdp_module: FullyShardedDataParallel) -> None:
                     if p.shape != p._saved_grad_shard.shape:
                         self._use_fp32_param_shard([p])
                     if p._saved_grad_shard.dtype != p.dtype:
-                        p.grad = p._saved_grad_shard.to(p.dtype)
+                        p.main_grad = p._saved_grad_shard
                     else:
                         p.grad = p._saved_grad_shard
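Editorial note (not part of the patch): the diff renames the per-parameter fp32 accumulation buffer to `unsharded_main_grad`, stops casting the reduced gradient back to the parameter's dtype in `_post_reduction_hook`, and, when the reduced shard's dtype differs from the parameter's, exposes it as `p.main_grad` instead of `p.grad`. The sketch below illustrates only the accumulate-in-fp32 pattern in isolation; the helper name `accumulate_fp32_grad` and the standalone driver are illustrative assumptions, not fairscale API.

# Minimal sketch (assumed names, not fairscale code) of the fp32 accumulation
# pattern the patch keeps: a float32 side buffer per parameter instead of
# accumulating gradients in the parameter's low-precision dtype.
import torch

def accumulate_fp32_grad(param: torch.nn.Parameter) -> None:
    # Fold the current (possibly bf16/fp16) param.grad into an fp32 buffer,
    # mirroring the patch's `unsharded_main_grad` attribute.
    if param.grad is None:
        return
    if getattr(param, "unsharded_main_grad", None) is None:
        # First accumulation: allocate the fp32 buffer by upcasting the grad.
        param.unsharded_main_grad = param.grad.to(torch.float32)
    else:
        # Later accumulations happen in fp32, avoiding low-precision rounding.
        param.unsharded_main_grad.add_(param.grad)
    # The low-precision grad is no longer needed once folded into fp32.
    param.grad = None

# Tiny demonstration with a bf16 parameter and two hand-made gradients.
p = torch.nn.Parameter(torch.zeros(4, dtype=torch.bfloat16))
for _ in range(2):
    p.grad = torch.ones(4, dtype=torch.bfloat16)
    accumulate_fp32_grad(p)
print(p.unsharded_main_grad.dtype, p.unsharded_main_grad)  # float32 tensor of 2s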