diff --git a/internlm/core/parallel/comm/isp.py b/internlm/core/parallel/comm/isp.py index 24677c09..23a92980 100644 --- a/internlm/core/parallel/comm/isp.py +++ b/internlm/core/parallel/comm/isp.py @@ -309,6 +309,7 @@ def __init__( process_group: dist.ProcessGroup = None, is_moe: bool = False, selective_ckpt_offload: bool = False, + early_reduce_scatter_release: bool = True, ) -> None: self.process_group = process_group self.overlap = overlap @@ -317,6 +318,11 @@ def __init__( self.is_forward = True self.reduce_scatter_handlers = {} self._forward_prefetch_prerequisites = [] + self._zero_const_pool = {} + + self._enable_early_reduce_scatter_release = early_reduce_scatter_release + self._early_prev_layer_rs_handles = [] + self._early_curr_layer_rs_handles = [] self._forward_overlap_per = self._get_forward_overlap_granularity() self._launch_before_module = self._get_launch_before_module() # As an optimization, do not release weight after forward for the last @@ -595,6 +601,13 @@ def _post_backward_hook_for_module(self, module, *args): # pylint: disable=W061 self._clear_handle(module) self._clear_weight(module) + def _early_reduce_scatter_release_hook(self, *args): # pylint: disable=W0613 + for handle in self._early_prev_layer_rs_handles: + handle.wait() + + self._early_prev_layer_rs_handles = self._early_curr_layer_rs_handles + self._early_curr_layer_rs_handles = [] + def _register_sync_parameters_hook(self) -> None: """ register forward hooks and backward hooks for isp modules. @@ -625,12 +638,18 @@ def _register_sync_parameters_hook(self) -> None: for module in self._isp_modules: module.register_full_backward_hook(self._post_backward_hook_for_module) + if self._enable_early_reduce_scatter_release: + for block_idx in range(self._num_blocks): + block = self._index_to_block[block_idx] + block.register_full_backward_hook(self._early_reduce_scatter_release_hook) + def _get_constant_zero(self, size: tuple) -> torch.Tensor: - return torch.zeros( - *size, - dtype=self.model_conf.dtype, - device=self.model_conf.device, - ).contiguous() + if size not in self._zero_const_pool: + self._zero_const_pool[size] = torch.zeros( + *size, dtype=self.model_conf.dtype, device=self.model_conf.device + ).contiguous() + + return self._zero_const_pool[size] def communication_mode(self) -> str: return "wp" @@ -717,13 +736,18 @@ def grad_hook( assert hasattr(module.weight, "isp_reduce_scatter_name") key = getattr(module.weight, "isp_reduce_scatter_name") - self.reduce_scatter_handlers[key] = reduce_scatter_raw( + output, handle = reduce_scatter_raw( tensor, self.process_group, op=reduce_op, async_op=async_op, ) + if self._enable_early_reduce_scatter_release: + self._early_curr_layer_rs_handles.append(handle) + + self.reduce_scatter_handlers[key] = (output, handle) + result, handle = ( self._get_constant_zero( ( @@ -778,6 +802,10 @@ def after_backward(self, scheduler, inputs_grad) -> None: # pylint: disable=W06 ): self._zero_optim.reduce_left_grads_after_backward() + if self._isp_communicator and self._isp_communicator._enable_early_reduce_scatter_release: + self._isp_communicator._early_prev_layer_rs_handles = [] + self._isp_communicator._early_curr_layer_rs_handles = [] + def post_helper_func(self, scheduler, outputs, label) -> None: # pylint: disable=W0613 pass diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index b9e8e41b..d6038d18 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -456,11 +456,19 @@ def args_sanity_check(): gpc.config.parallel["weight"]["overlap"] = False if gpc.config.parallel["tensor"]["mode"] != TensorParallelMode.isp.name: assert gpc.config.parallel["weight"]["size"] <= 1, "weight parallel is only supported with isp" + + if "early_reduce_scatter_release" not in gpc.config.parallel.weight: + gpc.config.parallel.weight["early_reduce_scatter_release"] = True + # set default value for expert_weight parallel if gpc.config.parallel["expert_weight"].get("overlap", None) is None: gpc.config.parallel["expert_weight"]["overlap"] = False if gpc.config.parallel["expert"].get("no_tp", None) is None: gpc.config.parallel["expert"]["no_tp"] = False + + if "early_reduce_scatter_release" not in gpc.config.parallel.expert_weight: + gpc.config.parallel.expert_weight["early_reduce_scatter_release"] = True + # currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy if hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1: assert ( diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py index ca11e689..784a5305 100644 --- a/internlm/train/pipeline.py +++ b/internlm/train/pipeline.py @@ -364,6 +364,7 @@ def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]): gpc.get_group(ParallelMode.WEIGHT), is_moe=False, selective_ckpt_offload=gpc.config.get("selective_checkpoint_offload", False), + early_reduce_scatter_release=gpc.config.parallel.weight.early_reduce_scatter_release, ) # register communicator for isp column parallel linear. ColumnParallelLinear.register_cls_communicator(isp_communicator) @@ -389,6 +390,7 @@ def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]): gpc.config.parallel.expert_weight.overlap, gpc.get_group(ParallelMode.EXPERT_WEIGHT), is_moe=True, + early_reduce_scatter_release=gpc.config.parallel.expert_weight.early_reduce_scatter_release, ) for moe in _submodule_filter(model, Experts): for column_linear in _submodule_filter(moe, (ColumnParallelLinear, GroupedWPLinear)):