From c7fd85a342e1976445a8b4b1a2e167733f18c4aa Mon Sep 17 00:00:00 2001
From: Jiecao Yu
Date: Wed, 6 Dec 2023 15:51:52 -0800
Subject: [PATCH] fix .grad=None issue when param is not sharded (#1153)

---
 fairscale/nn/data_parallel/fully_sharded_data_parallel.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
index cc994d1a8..65eed6098 100644
--- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -1819,6 +1819,8 @@ def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) ->
                 ), f"{param._saved_grad_shard.shape} vs {reduced_grad.shape}"
                 param._saved_grad_shard.data += reduced_grad.data
             reduced_grad = param._saved_grad_shard.data
+        elif (param.grad is None) and self.fp32_reduce_scatter:
+            param.main_grad = reduced_grad.data
 
         # Optionally move gradients to CPU, typically used if one is running the optimizer on the CPU. Once the full
         # backwards pass completes, we will set `.grad` to the CPU copy.
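
Note (not part of the patch): a minimal sketch of the configuration this fix targets, i.e. FSDP with mixed precision and fp32_reduce_scatter enabled and parameters left unsharded (e.g. world_size == 1), where `.grad` could previously end up None after reduction. The process-group setup, toy model, and shapes below are illustrative assumptions, not taken from the PR.

    # Illustrative sketch only; assumes a single-process run with a CUDA device.
    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    # FSDP requires an initialized process group; with world_size == 1 the
    # wrapped parameters are not sharded, which is the case this fix addresses.
    dist.init_process_group("gloo", init_method="tcp://localhost:29500", rank=0, world_size=1)

    model = FSDP(
        nn.Linear(8, 8).cuda(),
        mixed_precision=True,      # compute in fp16, keep an fp32 copy of params
        fp32_reduce_scatter=True,  # reduce gradients in fp32
    )

    model(torch.randn(4, 8, device="cuda")).sum().backward()

    for p in model.parameters():
        # With the change above, an unsharded param whose .grad is None after the
        # post-reduction hook carries its reduced fp32 gradient on .main_grad.
        grad = p.grad if p.grad is not None else getattr(p, "main_grad", None)

The optimizer (or any downstream consumer) can then fall back to `p.main_grad` when `p.grad` is None, mirroring what the sharded branch of `_post_reduction_hook` already does for sharded parameters.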