diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py index c43ae53fd..3960bbbae 100644 --- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py +++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py @@ -1777,9 +1777,9 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None: # case grads should be all-reduced here. assert self.world_size == 1 if getattr(param, "main_grad", None) is not None: - self._post_reduction_hook(param, param.main_grad.data) + self._post_reduction_hook(param, param.main_grad) else: - self._post_reduction_hook(param, param.grad.data) + self._post_reduction_hook(param, param.grad) # After _post_backward_hook returns, orig_grad_data will eventually # go out of scope, at which point it could otherwise be freed for