feat(isp): add early_reduce_scatter_release support (#380)
mwiacx authored Jan 2, 2025
Parent: 1ea1555 · Commit: d03c6f9
Showing 3 changed files with 44 additions and 6 deletions.
40 changes: 34 additions & 6 deletions internlm/core/parallel/comm/isp.py
@@ -309,6 +309,7 @@ def __init__(
process_group: dist.ProcessGroup = None,
is_moe: bool = False,
selective_ckpt_offload: bool = False,
early_reduce_scatter_release: bool = True,
) -> None:
self.process_group = process_group
self.overlap = overlap
@@ -317,6 +318,11 @@ def __init__(
self.is_forward = True
self.reduce_scatter_handlers = {}
self._forward_prefetch_prerequisites = []
self._zero_const_pool = {}

self._enable_early_reduce_scatter_release = early_reduce_scatter_release
self._early_prev_layer_rs_handles = []
self._early_curr_layer_rs_handles = []
self._forward_overlap_per = self._get_forward_overlap_granularity()
self._launch_before_module = self._get_launch_before_module()
# As an optimization, do not release weight after forward for the last
@@ -595,6 +601,13 @@ def _post_backward_hook_for_module(self, module, *args): # pylint: disable=W061
self._clear_handle(module)
self._clear_weight(module)

def _early_reduce_scatter_release_hook(self, *args): # pylint: disable=W0613
for handle in self._early_prev_layer_rs_handles:
handle.wait()

self._early_prev_layer_rs_handles = self._early_curr_layer_rs_handles
self._early_curr_layer_rs_handles = []

def _register_sync_parameters_hook(self) -> None:
"""
register forward hooks and backward hooks for isp modules.
@@ -625,12 +638,18 @@ def _register_sync_parameters_hook(self) -> None:
for module in self._isp_modules:
module.register_full_backward_hook(self._post_backward_hook_for_module)

if self._enable_early_reduce_scatter_release:
for block_idx in range(self._num_blocks):
block = self._index_to_block[block_idx]
block.register_full_backward_hook(self._early_reduce_scatter_release_hook)

def _get_constant_zero(self, size: tuple) -> torch.Tensor:
return torch.zeros(
*size,
dtype=self.model_conf.dtype,
device=self.model_conf.device,
).contiguous()
if size not in self._zero_const_pool:
self._zero_const_pool[size] = torch.zeros(
*size, dtype=self.model_conf.dtype, device=self.model_conf.device
).contiguous()

return self._zero_const_pool[size]

def communication_mode(self) -> str:
return "wp"
@@ -717,13 +736,18 @@ def grad_hook(
assert hasattr(module.weight, "isp_reduce_scatter_name")
key = getattr(module.weight, "isp_reduce_scatter_name")

self.reduce_scatter_handlers[key] = reduce_scatter_raw(
output, handle = reduce_scatter_raw(
tensor,
self.process_group,
op=reduce_op,
async_op=async_op,
)

if self._enable_early_reduce_scatter_release:
self._early_curr_layer_rs_handles.append(handle)

self.reduce_scatter_handlers[key] = (output, handle)

result, handle = (
self._get_constant_zero(
(
@@ -778,6 +802,10 @@ def after_backward(self, scheduler, inputs_grad) -> None: # pylint: disable=W06
):
self._zero_optim.reduce_left_grads_after_backward()

if self._isp_communicator and self._isp_communicator._enable_early_reduce_scatter_release:
self._isp_communicator._early_prev_layer_rs_handles = []
self._isp_communicator._early_curr_layer_rs_handles = []

def post_helper_func(self, scheduler, outputs, label) -> None: # pylint: disable=W0613
pass

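Note: the hunks above wire the feature in three pieces: grad_hook() keeps the work handle of every async reduce-scatter it launches, a full-backward hook on each block waits on the handles collected by the previously finished layer, and after_backward() clears both lists at the end of the step. Below is a minimal, self-contained sketch of that pattern; the names EarlyReleaser and _FakeWork are illustrative stand-ins (not InternEvo APIs), and the fake work object replaces a real torch.distributed async handle.

import torch
import torch.nn as nn


class _FakeWork:
    """Stand-in for the work handle returned by an async torch.distributed op."""

    def wait(self):
        pass


class EarlyReleaser:
    """Waits on the previous layer's reduce-scatter handles when the next block finishes backward."""

    def __init__(self):
        self._prev_layer_handles = []  # handles launched by the layer that finished backward last
        self._curr_layer_handles = []  # handles launched by the layer currently running backward

    def track(self, work):
        # called wherever an async reduce-scatter is issued (grad_hook in the diff above)
        self._curr_layer_handles.append(work)

    def block_backward_hook(self, module, grad_input, grad_output):
        # full-backward hook per block: drain the previous layer, then rotate the lists
        for work in self._prev_layer_handles:
            work.wait()
        self._prev_layer_handles = self._curr_layer_handles
        self._curr_layer_handles = []

    def reset(self):
        # mirrors the cleanup in after_backward(): start every step with empty lists
        self._prev_layer_handles = []
        self._curr_layer_handles = []


releaser = EarlyReleaser()
block = nn.Linear(8, 8)
block.register_full_backward_hook(releaser.block_backward_hook)

releaser.track(_FakeWork())  # pretend the block's backward launched one async reduce-scatter
block(torch.randn(2, 8, requires_grad=True)).sum().backward()  # hook fires and rotates the lists
releaser.reset()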
8 changes: 8 additions & 0 deletions internlm/initialize/launch.py
@@ -456,11 +456,19 @@ def args_sanity_check():
gpc.config.parallel["weight"]["overlap"] = False
if gpc.config.parallel["tensor"]["mode"] != TensorParallelMode.isp.name:
assert gpc.config.parallel["weight"]["size"] <= 1, "weight parallel is only supported with isp"

if "early_reduce_scatter_release" not in gpc.config.parallel.weight:
gpc.config.parallel.weight["early_reduce_scatter_release"] = True

# set default value for expert_weight parallel
if gpc.config.parallel["expert_weight"].get("overlap", None) is None:
gpc.config.parallel["expert_weight"]["overlap"] = False
if gpc.config.parallel["expert"].get("no_tp", None) is None:
gpc.config.parallel["expert"]["no_tp"] = False

if "early_reduce_scatter_release" not in gpc.config.parallel.expert_weight:
gpc.config.parallel.expert_weight["early_reduce_scatter_release"] = True

# currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy
if hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1:
assert (
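Note: args_sanity_check() above only backfills the default (True). A hedged sketch of the corresponding piece of a training config follows; only the keys this commit reads are shown, and the size/overlap values are placeholders, not recommendations.

# Hypothetical excerpt of an InternEvo-style training config; fields other
# than the ones touched by this commit are omitted.
parallel = dict(
    weight=dict(
        size=4,          # placeholder
        overlap=True,    # placeholder
        early_reduce_scatter_release=True,  # new flag; treated as True when omitted
    ),
    expert_weight=dict(
        size=1,          # placeholder
        overlap=False,   # placeholder
        early_reduce_scatter_release=True,  # same flag for the expert-weight (MoE) communicator
    ),
)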
2 changes: 2 additions & 0 deletions internlm/train/pipeline.py
@@ -364,6 +364,7 @@ def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]):
gpc.get_group(ParallelMode.WEIGHT),
is_moe=False,
selective_ckpt_offload=gpc.config.get("selective_checkpoint_offload", False),
early_reduce_scatter_release=gpc.config.parallel.weight.early_reduce_scatter_release,
)
# register communicator for isp column parallel linear.
ColumnParallelLinear.register_cls_communicator(isp_communicator)
@@ -389,6 +390,7 @@ def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]):
gpc.config.parallel.expert_weight.overlap,
gpc.get_group(ParallelMode.EXPERT_WEIGHT),
is_moe=True,
early_reduce_scatter_release=gpc.config.parallel.expert_weight.early_reduce_scatter_release,
)
for moe in _submodule_filter(model, Experts):
for column_linear in _submodule_filter(moe, (ColumnParallelLinear, GroupedWPLinear)):
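Note: the two call sites above read the flag directly from gpc.config.parallel.*, relying on args_sanity_check() to have filled in the default. If a communicator were built outside the normal launch path, a defensive read like the sketch below would keep the same default; resolve_early_release and the plain-dict cfg are hypothetical, standing in for the real gpc.config access shown above.

# Hypothetical helper: read the new flag with the same default that
# args_sanity_check() applies, in case the key is missing from an older config.
def resolve_early_release(parallel_cfg: dict, group: str = "weight") -> bool:
    return parallel_cfg.get(group, {}).get("early_reduce_scatter_release", True)


# usage sketch (cfg is a plain dict standing in for gpc.config.parallel)
cfg = {"weight": {"size": 4, "overlap": True}}
assert resolve_early_release(cfg, "weight") is True
assert resolve_early_release(cfg, "expert_weight") is True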