From c5f8df99478f449dd999550372b65d5824a562c2 Mon Sep 17 00:00:00 2001
From: Margaret Li
Date: Tue, 21 Nov 2023 14:28:00 -0800
Subject: [PATCH] Fix _free_full_params()

I've been getting `AttributeError: 'FlatParameter' object has no attribute
'_full_param_padded'` triggered by
`p._full_param_padded.record_stream(current_stream)`

Adding a check to not free full params if none have been added.
---
 fairscale/nn/data_parallel/fully_sharded_data_parallel.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
index 0815b86c1..5f62c14f4 100644
--- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -2096,6 +2096,8 @@ def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
         """Free up storage for full parameters."""
         if params is None:
             params = self.params
+        if not self.has_full_params:
+            return
         self.has_full_params = False
         current_stream = torch.cuda.current_stream()
         for p in params: