From 81ee78d60824a7092b0566804fa418a2bbed8396 Mon Sep 17 00:00:00 2001
From: Jianyu Huang
Date: Sun, 1 Oct 2023 14:57:52 -0700
Subject: [PATCH 1/4] Fix fsdp+pp+te WPS decreasing issue

---
 .../fully_sharded_data_parallel.py | 27 ++++++++++++++++---
 1 file changed, 23 insertions(+), 4 deletions(-)

diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
index 759b9f445..de88c4d8b 100644
--- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -557,6 +557,7 @@ def __init__(
         self.dont_wait_current_stream_for_post_all_gather = False
         self._all_gather_free_event_queue = _FreeEventQueue() if limit_all_gather_events else None
         self._reduce_scatter_free_event_queue = _FreeEventQueue() if limit_reduce_scatter_events else None
+        self._module_fqn = None
 
     def _get_gradient_predivide_factor(self, world_size: int) -> float:
         factor: int = 1
@@ -1220,6 +1221,9 @@ def _lazy_init(self) -> None:
             self._set_is_root()
             self._setup_streams()
             self._setup_output_hook_list()
+            for module_name, module in self.named_modules():
+                if isinstance(module, FullyShardedDataParallel):
+                    module._module_fqn = module_name
 
         if self._is_root:
             # Buffers stay on GPU, and don't get sharded. Since _cast_buffers
@@ -1713,6 +1717,13 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         # Switch to FP32 shard after backward.
         self._use_fp32_param_shard([param])
 
+        if self.mixed_precision and self.fp32_reduce_scatter:
+            if getattr(param, "main_grad", None) is None:
+                param.main_grad = param.grad.to(torch.float32)
+            else:
+                param.main_grad.add_(param.grad.data)
+
+            param.grad = None
         if not self._require_backward_grad_sync:
             return
 
@@ -1721,15 +1732,19 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         # reductions in post_backward stream.
         self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(self._streams["post_backward"]):
-            orig_grad_data = param.grad.data
 
             if self.fp32_reduce_scatter:
                 # Cast grad to FP32.
-                param.grad.data = param.grad.data.float()
+                orig_grad_data = param.grad.data.float()
+            else:
+                orig_grad_data = param.grad.data
 
             if self.gradient_predivide_factor > 1:
                 # Average grad by world_size for consistency with PyTorch DDP.
-                param.grad.data.div_(self.gradient_predivide_factor)
+                if getattr(param, "main_grad", None) is not None:
+                    param.main_grad.data.div_(self.gradient_predivide_factor)
+                else:
+                    param.grad.data.div_(self.gradient_predivide_factor)
 
             if param._is_sharded:
                 assert self._reducer is not None
@@ -1737,7 +1752,11 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
                 # param._saved_grad_shard. If this FSDP module was called multiple times it's possible that multiple
                 # gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
                 # matter, neglecting rounding.
-                grad = param.grad.data
+                if getattr(param, "main_grad", None) is not None:
+                    grad = param.main_grad.data
+                    param.main_grad = None
+                else:
+                    grad = param.grad.data
                 # Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
                 #
                 # The effect on memory consumption is not usually significant. No extra memory is allocated if this

From 71495bad6fea73bb92b2c13664f8bf1f966387c5 Mon Sep 17 00:00:00 2001
From: Jianyu Huang
Date: Sun, 1 Oct 2023 17:16:34 -0700
Subject: [PATCH 2/4] Address comment; remove unused stuff

---
 .../nn/data_parallel/fully_sharded_data_parallel.py | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
index de88c4d8b..2f6eb6d54 100644
--- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -557,7 +557,6 @@ def __init__(
         self.dont_wait_current_stream_for_post_all_gather = False
         self._all_gather_free_event_queue = _FreeEventQueue() if limit_all_gather_events else None
         self._reduce_scatter_free_event_queue = _FreeEventQueue() if limit_reduce_scatter_events else None
-        self._module_fqn = None
 
     def _get_gradient_predivide_factor(self, world_size: int) -> float:
         factor: int = 1
@@ -1221,9 +1220,6 @@ def _lazy_init(self) -> None:
             self._set_is_root()
             self._setup_streams()
             self._setup_output_hook_list()
-            for module_name, module in self.named_modules():
-                if isinstance(module, FullyShardedDataParallel):
-                    module._module_fqn = module_name
 
         if self._is_root:
             # Buffers stay on GPU, and don't get sharded. Since _cast_buffers
@@ -1735,9 +1731,9 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
 
             if self.fp32_reduce_scatter:
                 # Cast grad to FP32.
-                orig_grad_data = param.grad.data.float()
-            else:
-                orig_grad_data = param.grad.data
+                param.grad.data = param.grad.data.float()
+
+            orig_grad_data = param.grad.data
 
             if self.gradient_predivide_factor > 1:
                 # Average grad by world_size for consistency with PyTorch DDP.

From f3ae46e1ce7690ebab23378fec2f3d5441569214 Mon Sep 17 00:00:00 2001
From: Jianyu Huang
Date: Sun, 1 Oct 2023 18:05:58 -0700
Subject: [PATCH 3/4] split into wps fix P841842878 only and main_grad fix

---
 .../fully_sharded_data_parallel.py | 21 +++----------------
 1 file changed, 3 insertions(+), 18 deletions(-)

diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
index 2f6eb6d54..759b9f445 100644
--- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -1713,13 +1713,6 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         # Switch to FP32 shard after backward.
         self._use_fp32_param_shard([param])
 
-        if self.mixed_precision and self.fp32_reduce_scatter:
-            if getattr(param, "main_grad", None) is None:
-                param.main_grad = param.grad.to(torch.float32)
-            else:
-                param.main_grad.add_(param.grad.data)
-
-            param.grad = None
         if not self._require_backward_grad_sync:
             return
 
@@ -1728,19 +1721,15 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         # reductions in post_backward stream.
         self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(self._streams["post_backward"]):
+            orig_grad_data = param.grad.data
 
             if self.fp32_reduce_scatter:
                 # Cast grad to FP32.
                 param.grad.data = param.grad.data.float()
 
-            orig_grad_data = param.grad.data
-
             if self.gradient_predivide_factor > 1:
                 # Average grad by world_size for consistency with PyTorch DDP.
-                if getattr(param, "main_grad", None) is not None:
-                    param.main_grad.data.div_(self.gradient_predivide_factor)
-                else:
-                    param.grad.data.div_(self.gradient_predivide_factor)
+                param.grad.data.div_(self.gradient_predivide_factor)
 
             if param._is_sharded:
                 assert self._reducer is not None
@@ -1748,11 +1737,7 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
                 # param._saved_grad_shard. If this FSDP module was called multiple times it's possible that multiple
                 # gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
                 # matter, neglecting rounding.
-                if getattr(param, "main_grad", None) is not None:
-                    grad = param.main_grad.data
-                    param.main_grad = None
-                else:
-                    grad = param.grad.data
+                grad = param.grad.data
                 # Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
                 #
                 # The effect on memory consumption is not usually significant. No extra memory is allocated if this

From ad54660bb54b90ef224019513f0bcd7e7998561a Mon Sep 17 00:00:00 2001
From: Jianyu Huang
Date: Sun, 1 Oct 2023 18:07:28 -0700
Subject: [PATCH 4/4] Add main_grad

---
 .../fully_sharded_data_parallel.py | 21 ++++++++++++++++---
 1 file changed, 18 insertions(+), 3 deletions(-)

diff --git a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
index 759b9f445..0eaa454f2 100644
--- a/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+++ b/fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -1713,6 +1713,13 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         # Switch to FP32 shard after backward.
         self._use_fp32_param_shard([param])
 
+        if self.fp32_reduce_scatter:
+            if getattr(param, "main_grad", None) is None:
+                param.main_grad = param.grad.to(torch.float32)
+            else:
+                param.main_grad.add_(param.grad.data)
+
+            param.grad = None
         if not self._require_backward_grad_sync:
             return
 
@@ -1721,15 +1728,19 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
         # reductions in post_backward stream.
         self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(self._streams["post_backward"]):
-            orig_grad_data = param.grad.data
 
             if self.fp32_reduce_scatter:
                 # Cast grad to FP32.
                 param.grad.data = param.grad.data.float()
 
+            orig_grad_data = param.grad.data
+
             if self.gradient_predivide_factor > 1:
                 # Average grad by world_size for consistency with PyTorch DDP.
-                param.grad.data.div_(self.gradient_predivide_factor)
+                if getattr(param, "main_grad", None) is not None:
+                    param.main_grad.data.div_(self.gradient_predivide_factor)
+                else:
+                    param.grad.data.div_(self.gradient_predivide_factor)
 
             if param._is_sharded:
                 assert self._reducer is not None
@@ -1737,7 +1748,11 @@ def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
                 # param._saved_grad_shard. If this FSDP module was called multiple times it's possible that multiple
                 # gradient reductions will happen in an undefined order. But addition commutes, so this order doesn't
                 # matter, neglecting rounding.
-                grad = param.grad.data
+                if getattr(param, "main_grad", None) is not None:
+                    grad = param.main_grad.data
+                    param.main_grad = None
+                else:
+                    grad = param.grad.data
                 # Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
                 #
                 # The effect on memory consumption is not usually significant. No extra memory is allocated if this
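
Note (illustrative sketch, not part of the patch series): below is a minimal standalone Python example of the FP32 main_grad accumulation pattern that PATCH 4/4 adds to _post_backward_hook, shown with bare tensors so the intended gradient flow is easy to follow. The helper names accumulate_main_grad and predivide_and_extract are hypothetical and do not exist in fairscale; the real hook operates on sharded FSDP parameters inside the post_backward CUDA stream and hands the resulting gradient to the reduce-scatter.

import torch


def accumulate_main_grad(param: torch.nn.Parameter, fp32_reduce_scatter: bool = True) -> None:
    # Mirrors the logic added right after backward: keep an FP32 accumulator
    # across micro-batches and drop the low-precision .grad each time.
    if fp32_reduce_scatter and param.grad is not None:
        if getattr(param, "main_grad", None) is None:
            # First micro-batch: allocate the FP32 accumulator from the low-precision grad.
            param.main_grad = param.grad.to(torch.float32)
        else:
            # Subsequent micro-batches (e.g. pipeline parallelism): accumulate in FP32.
            param.main_grad.add_(param.grad.data)
        param.grad = None


def predivide_and_extract(param: torch.nn.Parameter, predivide_factor: float) -> torch.Tensor:
    # Mirrors the post_backward-stream logic: prefer main_grad when it exists,
    # pre-divide, and return the tensor that would feed the reduce-scatter.
    if getattr(param, "main_grad", None) is not None:
        grad = param.main_grad
        param.main_grad = None
    else:
        grad = param.grad.data
    if predivide_factor > 1:
        grad.div_(predivide_factor)
    return grad


if __name__ == "__main__":
    p = torch.nn.Parameter(torch.zeros(4, dtype=torch.bfloat16))
    for _ in range(2):  # two micro-batches producing BF16 grads
        p.grad = torch.ones(4, dtype=torch.bfloat16)
        accumulate_main_grad(p)
    grad = predivide_and_extract(p, predivide_factor=2.0)
    print(grad.dtype, grad)  # torch.float32 tensor([1., 1., 1., 1.])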