fix .grad=None issue when param is not sharded (#1153)
jiecaoyu authored Dec 6, 2023
1 parent d346146 commit c7fd85a
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -1819,6 +1819,8 @@ def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) ->
                 ), f"{param._saved_grad_shard.shape} vs {reduced_grad.shape}"
                 param._saved_grad_shard.data += reduced_grad.data
             reduced_grad = param._saved_grad_shard.data
+        elif (param.grad is None) and self.fp32_reduce_scatter:
+            param.main_grad = reduced_grad.data
 
         # Optionally move gradients to CPU, typically used if one is running the optimizer on the CPU. Once the full
         # backwards pass completes, we will set `.grad` to the CPU copy.
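For context, below is a minimal standalone sketch, not the fairscale implementation, of how a post-reduction hook can route the reduced gradient after this fix. The attribute names `_is_sharded`, `_saved_grad_shard`, `fp32_reduce_scatter`, and `main_grad` mirror what the diff shows; the free-standing function, its signature, and the demo at the bottom are illustrative assumptions. The point of the added `elif` is that an unsharded parameter with FP32 reduce-scatter enabled can still have `param.grad is None` at this stage, so the reduced FP32 gradient is stashed on `param.main_grad` instead of being silently dropped.

# Hedged sketch only: assumes an unsharded param is marked with `_is_sharded = False`
# and that the optimizer later reads `main_grad`; these are assumptions for illustration.
import torch
from torch.nn import Parameter


def post_reduction_hook(param: Parameter, reduced_grad: torch.Tensor,
                        fp32_reduce_scatter: bool) -> None:
    """Route the reduced gradient depending on whether the param is sharded."""
    if getattr(param, "_is_sharded", False):
        # Sharded path: accumulate into the per-rank gradient shard.
        if getattr(param, "_saved_grad_shard", None) is None:
            param._saved_grad_shard = reduced_grad.data
        else:
            assert param._saved_grad_shard.shape == reduced_grad.shape, (
                f"{param._saved_grad_shard.shape} vs {reduced_grad.shape}"
            )
            param._saved_grad_shard.data += reduced_grad.data
    elif param.grad is None and fp32_reduce_scatter:
        # Unsharded path with FP32 reduce-scatter: .grad was never populated,
        # so keep the reduced FP32 gradient on `main_grad` (the behavior this commit adds).
        param.main_grad = reduced_grad.data


if __name__ == "__main__":
    p = Parameter(torch.zeros(4, dtype=torch.float16))
    p._is_sharded = False  # simulate an unsharded parameter
    post_reduction_hook(p, torch.ones(4, dtype=torch.float32), fp32_reduce_scatter=True)
    print(p.grad)       # None -- .grad was never set
    print(p.main_grad)  # tensor([1., 1., 1., 1.]) -- gradient preserved on main_grad

Keeping the FP32 copy on a separate `main_grad` attribute rather than writing it into `.grad` avoids a dtype mismatch with the half-precision parameter while still giving downstream code a non-None gradient to consume.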
