diff --git a/internlm/core/parallel/comm/isp.py b/internlm/core/parallel/comm/isp.py index 24677c09..cffb898d 100644 --- a/internlm/core/parallel/comm/isp.py +++ b/internlm/core/parallel/comm/isp.py @@ -5,16 +5,20 @@ """ from abc import ABC, abstractmethod +from dataclasses import dataclass from functools import partial from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from torch import distributed as dist from torch import nn +from torch._prims_common import make_contiguous_strides_for +from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.core.naive_amp import unwrap_naive_amp +from internlm.core.parallel.comm.attn_offload import get_offload_manager from internlm.core.parallel.comm.utils import ( DUMMY_HANDLE_CONST, AsyncCommHandle, @@ -28,7 +32,7 @@ from internlm.model.modules.embedding import Embedding1D from internlm.model.modules.linear import ParallelLinearWithCommExt from internlm.model.modules.utils import is_moe_param -from internlm.utils.common import SchedulerHook, UniqueChainMap, get_current_device +from internlm.utils.common import SchedulerHook, get_current_device from internlm.utils.utils import ( CuSeqlenType, QKVPackType, @@ -37,7 +41,7 @@ params_dispatch_with_condition, ) -from .attn_offload import get_offload_manager +internlm_accelerator = get_accelerator() # not really useful, only for code hint. @@ -296,6 +300,453 @@ def __init__(self) -> None: self.bias_global_output: Dict[str, torch.Tensor] = {} +class ISPCommunicationContext(ABC): + """ + Common communication context interface for isp communication overlap. + """ + + @abstractmethod + def register_overlap_hooks(self, model: nn.Module) -> None: + """ + register hooks for communication. + """ + pass + + @abstractmethod + def switch_forward_backward_phase(self, is_forward: bool) -> None: + """switch forward/backward phase.""" + pass + + @abstractmethod + def switch_current_overlap_state(self, overlap_state: ISPOverlapState) -> None: + """switch current overlap state.""" + pass + + @abstractmethod + def all_gather( + self, module: nn.Module, tensor: torch.Tensor, async_op: bool = False, is_bias: bool = False + ) -> torch.Tensor: + """ + all gather proxy. + TODO: The interface should not have an is_bias parameter, but it is temporarily there. + """ + + @abstractmethod + def reduce_scatter( + self, key: str, tensor: torch.Tensor, reduce_op: dist.ReduceOp, async_op: bool = False + ) -> Tuple[torch.Tensor, AsyncCommHandle]: + """ + reduce scatter proxy. 
+ """ + pass + + def pop_reduced_grad(self, key: str) -> torch.Tensor: + """ + return reduce scatter results + """ + pass + + +class _WaitOrCommitHandle(AsyncCommHandle): + """ + commit or wait handle + """ + + def __init__(self, commit_func: Callable, real_handle: AsyncCommHandle = None): + self._handle = real_handle + self._commit_func = commit_func + + def set_handle(self, real_handle: AsyncCommHandle): + self._handle = real_handle + + def wait(self, stream=None): + if self._handle is None: + self._commit_func() + assert self._handle is not None, "should not happend" + + self._handle.wait(stream) + + +@dataclass +class ReduceScatterResult: + handle: _WaitOrCommitHandle + result: Optional[torch.Tensor] = None + + +@dataclass +class ReduceScatterOperation: + key: str + grad: torch.Tensor + reduce_op: dist.ReduceOp + + +# Leveraged the implementation of FSDP2 +# https://github.com/pytorch/pytorch/issues/114299 +class LayerAsyncCommContext(ISPCommunicationContext): + """ + layer level async communcation context. + """ + + def __init__(self, dtype, device, process_group) -> None: + self.dtype = dtype + self.device = device + self.process_group = process_group + + # streams for communication overlap + self._allgather_copy_in_stream = internlm_accelerator.Stream(priority=-1) + self._allgather_comm_stream = internlm_accelerator.Stream(priority=-1) + self._reduce_scatter_comm_stream = internlm_accelerator.Stream(priority=-1) + + self._is_forward: bool = True + self._overlap_state: Optional[ISPOverlapState] = None + + self._allgather_result = None + self._allgather_buffer = None + + self._reduce_scatter_state = None + self._reduce_scatter_ops: List[ReduceScatterOperation] = [] + self._reduce_scatter_results: Dict[str, ReduceScatterResult] = {} + + def switch_forward_backward_phase(self, is_forward: bool) -> None: + self._is_forward = is_forward + + def switch_current_overlap_state(self, overlap_state: ISPOverlapState) -> None: + self._overlap_state = overlap_state + + # Possible future support for communication between embedding and head layers. + # def parse_model_structure( + # self, chunk_id: int, state: ISPOverlapState, model: nn.Module, is_moe: bool = False + # ) -> None: + # """Rewrite the data structures needed for some LayerAsyncCommContext.""" + + # state.index_to_block = {} + # state.index_to_isp_modules = {} + # state.module_to_index = {} + + # idx = 0 + # for name, children in model.named_children(): + # if isinstance(children, (Embedding1D, ParallelLinearWithCommExt)): + # # embedding layer and head layer. + # if is_moe: + # continue + + # state.index_to_block[idx] = children + # state.module_to_index[children] = idx + # state.index_to_isp_modules[idx] = [] + + # full_name = f"{chunk_id}.{idx}.{name}" + # setattr(children.weight, "isp_reduce_scatter_name", f"{full_name}.weight") + # if getattr(children, "bias", None) is not None: + # setattr(children.weight, "isp_reduce_scatter_name", f"{full_name}.bias") + # idx += 1 + # elif isinstance(children, nn.ModuleList): + # # decoder layers. 
+ # for block in children: + # state.index_to_isp_modules[idx] = [] + # for name, child in block.named_modules(): + # if isinstance(child, (ParallelLinearWithCommExt)): + # if is_moe_param(child.weight) != is_moe: + # continue + # state.index_to_isp_modules[idx].append(child) + + # if len(state.index_to_isp_modules[idx]) > 0: + # state.index_to_block[idx] = block + # state.module_to_index[block] = idx + # idx += 1 + + # state.num_blocks = len(state.index_to_block) + + def register_overlap_hooks(self, model: nn.Module) -> None: + def _clear_all_gather_buffer(module, *args): # pylint: disable=W0613 + self._overlap_state.bias_global_output.clear() + self._overlap_state.weight_global_output.clear() + + def _clear_all_gather_result(module: nn.Module, *args): # pylint: disable=W0613 + self._allgather_result = None + + def _clear_reduce_scatter_result(module, *args): # pylint: disable=W0613 + self._allgather_result = None + self._reduce_scatter_ops = [] + + # Pre-fetch parameters for the first layer to run in each phase. + num_blocks = self._overlap_state.num_blocks + + first_block = self._overlap_state.index_to_block[0] + last_block = self._overlap_state.index_to_block[num_blocks - 1] + + # Pull parameters for the first layer during the forward phase. + first_block.register_forward_pre_hook(partial(self._pre_forward_for_block, -1)) + # Pull parameters for the last layer, which runs first during the backward phase. + last_block.register_full_backward_pre_hook(partial(self._pre_backward_for_block, num_blocks)) + + for _block_idx in range(num_blocks): + _block = self._overlap_state.index_to_block[_block_idx] + # Pre-fetch parameters for the next layer. + _block.register_forward_pre_hook(self._pre_forward_for_block) + _block.register_full_backward_pre_hook(self._pre_backward_for_block) + # Clean up the parameters that have been used. + _block.register_forward_hook(_clear_all_gather_buffer) + _block.register_full_backward_hook(_clear_all_gather_buffer) + # Reduce scatter gradients + _block.register_full_backward_hook(self._post_backward_for_block) + + last_block.register_forward_hook(_clear_all_gather_result) + first_block.register_full_backward_hook(_clear_reduce_scatter_result) + + def all_gather( + self, module: nn.Module, tensor: torch.Tensor, async_op: bool = False, is_bias: bool = False + ) -> torch.Tensor: + """ + all gather proxy. + """ + + already_gathered = ( + self._overlap_state.bias_global_output if is_bias else self._overlap_state.weight_global_output + ) + + if module not in already_gathered: + result, _ = all_gather_raw(tensor, self.process_group, async_op) + else: + result = already_gathered[module] + + return result + + def reduce_scatter( + self, key: str, tensor: torch.Tensor, reduce_op: dist.ReduceOp, async_op: bool = False + ) -> Tuple[torch.Tensor, AsyncCommHandle]: + """ + reduce scatter proxy.
+ """ + if not async_op: + result, handle = reduce_scatter_raw(tensor, self.process_group, op=reduce_op, async_op=async_op) + else: + self._reduce_scatter_ops.append(ReduceScatterOperation(key, tensor, reduce_op)) + result, handle = None, _WaitOrCommitHandle(self._post_backward_for_block) + self._reduce_scatter_results[key] = ReduceScatterResult(handle, result) + + result, handle = ( + torch.zeros( + *( + tensor.shape[0] // dist.get_world_size(self.process_group), + *tensor.shape[1:], + ), + dtype=self.dtype, + device=self.device, + ).contiguous(), + DUMMY_HANDLE_CONST, + ) + + return result, handle + + def pop_reduced_grad(self, key: str) -> torch.Tensor: + # Be cautious here not to directly pop, as _WaitOrCommitHandle might trigger a commit, + # update the corresponding reduce scatter result. + rs_result = self._reduce_scatter_results[key] + rs_result.handle.wait() + + _ = self._reduce_scatter_results.pop(key) + + return rs_result.result + + def _check_reduce_op(self, reduce_ops: List[dist.ReduceOp]) -> dist.ReduceOp: + _check_reduce_ops = set(reduce_ops) + assert len(_check_reduce_ops) == 1, f"cannot fuse reduce scatter with different reduce_op {_check_reduce_ops}" + + return _check_reduce_ops.pop() + + # Copied from FSDP2. + def _all_gather_copy_in( + self, + all_gather_inputs: List[torch.Tensor], + inp_split_sizes: List[int], + all_gather_input_numel: int, + world_size: int, + rank: int, + dtype: torch.dtype, + device: torch.device, + ) -> Tuple[torch.Tensor, torch.Tensor]: + all_gather_output = torch.empty((all_gather_input_numel * world_size,), dtype=dtype, device=device) + all_gather_input = all_gather_output.narrow(0, all_gather_input_numel * rank, all_gather_input_numel) + foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes) + + all_gather_inputs = [t.view(-1) for t in all_gather_inputs] + torch._foreach_copy_(foreach_copy_dsts, all_gather_inputs) + + return all_gather_input, all_gather_output + + # Copied from FSDP2. 
+ def _split_with_sizes_copy( + self, + all_gather_output: torch.Tensor, + all_gather_input_split_sizes: List[int], + dim: int, + out: List[torch.Tensor], + ) -> None: + torch.split_with_sizes_copy(all_gather_output, all_gather_input_split_sizes, dim=dim, out=out) + + def _all_gather_block_params(self, block_idx): + all_gather_inputs = [] + + for module in self._overlap_state.index_to_isp_modules[block_idx]: + all_gather_inputs.append(module.weight) + if module.bias is not None: + all_gather_inputs.append(module.bias) + inp_split_sizes = [t.numel() for t in all_gather_inputs] + all_gather_input_numel = sum(inp_split_sizes) + all_gather_input_shapes = [t.shape for t in all_gather_inputs] + + with internlm_accelerator.stream(self._allgather_copy_in_stream): + all_gather_input, all_gather_output = self._all_gather_copy_in( + all_gather_inputs, + inp_split_sizes, + all_gather_input_numel, + dist.get_world_size(self.process_group), + dist.get_rank(self.process_group), + self.dtype, + self.device, + ) + + # Submit the all-gather communication. + self._allgather_comm_stream.wait_stream(self._allgather_copy_in_stream) + + with internlm_accelerator.stream(self._allgather_comm_stream): + dist.all_gather_into_tensor( + output_tensor=all_gather_output, + input_tensor=all_gather_input, + group=self.process_group, + async_op=False, + ) + all_gather_event = self._allgather_comm_stream.record_event() + + return (all_gather_event, all_gather_output, all_gather_input_shapes) + + def _wait_and_copy_out_params(self, block_index: int) -> None: + cur_allgather_event, cur_allgather_output, cur_input_shapes = self._allgather_result + + internlm_accelerator.current_stream().wait_event(cur_allgather_event) + + world_size = dist.get_world_size(self.process_group) + cur_inp_split_sizes = [t.numel() for t in cur_input_shapes] + + allgather_outputs = [ + torch.empty(torch.Size([numel * world_size]), dtype=self.dtype, device=self.device) + for numel in cur_inp_split_sizes + ] + + cur_allgather_output = cur_allgather_output.view(world_size, -1) + out = [t.view(world_size, -1) for t in allgather_outputs] + self._split_with_sizes_copy(cur_allgather_output, cur_inp_split_sizes, dim=1, out=out) + + _idx = 0 + for module in self._overlap_state.index_to_isp_modules[block_index]: + self._overlap_state.weight_global_output[module] = out[_idx].view(-1, *cur_input_shapes[_idx][1:]) + _idx += 1 + + if module.bias is not None: + self._overlap_state.bias_global_output[module] = out[_idx].view(-1, *cur_input_shapes[_idx][1:]) + _idx += 1 + + @torch.no_grad() + def _pre_forward_for_block(self, block_or_idx: Union[int, nn.Module], *args): # pylint: disable=W0613 + if isinstance(block_or_idx, int): + block_index = block_or_idx + else: + block_index = block_or_idx.layer_idx + + self._allgather_copy_in_stream.wait_stream(internlm_accelerator.current_stream()) + + # Check if the communication for this layer's parameters is complete and unpack the communication results. + if self._allgather_result is not None: + self._wait_and_copy_out_params(block_index) + + # Pre-fetch parameters for the "next layer".
+ if self._is_forward and block_index + 1 < self._overlap_state.num_blocks: + # start the all-gather for the next block + next_all_gather_result = self._all_gather_block_params(block_index + 1) + else: + next_all_gather_result = None + + self._allgather_result = next_all_gather_result + + @torch.no_grad() + def _pre_backward_for_block(self, block_or_idx: Union[int, nn.Module], *args): # pylint: disable=W0613 + if isinstance(block_or_idx, int): + block_index = block_or_idx + else: + block_index = block_or_idx.layer_idx + + self._allgather_copy_in_stream.wait_stream(internlm_accelerator.current_stream()) + + # Check if the communication for this layer's parameters is complete and unpack the communication results. + if self._allgather_result is not None: + self._wait_and_copy_out_params(block_index) + + # Pre-fetch parameters for the next layer to run in backward, i.e. the previous block. + if block_index - 1 >= 0: + next_all_gather_result = self._all_gather_block_params(block_index - 1) + else: + next_all_gather_result = None + + self._allgather_result = next_all_gather_result + + @torch.no_grad() + def _post_backward_for_block(self, *args): # pylint: disable=W0613 + if len(self._reduce_scatter_ops) == 0: + return + + if self._reduce_scatter_state is not None: + internlm_accelerator.current_stream().wait_event(self._reduce_scatter_state[1]) + self._reduce_scatter_state = None + + # Aggregate gradients for the fused reduce scatter. + world_size = dist.get_world_size(self.process_group) + + reduce_ops = [_i.reduce_op for _i in self._reduce_scatter_ops] + reduce_op = self._check_reduce_op(reduce_ops) + + unshard_grads = [_i.grad for _i in self._reduce_scatter_ops] + unshard_grad_sizes = [_grad.size() for _grad in unshard_grads] + reduce_scatter_input_numel = sum(s.numel() for s in unshard_grad_sizes) + + # wait for the compute stream + self._reduce_scatter_comm_stream.wait_stream(internlm_accelerator.current_stream()) + + with internlm_accelerator.stream(self._reduce_scatter_comm_stream): + + reduce_scatter_input = torch.empty((reduce_scatter_input_numel,), dtype=self.dtype, device=self.device) + reduce_scatter_input = reduce_scatter_input.view(world_size, -1) + torch._chunk_cat(unshard_grads, dim=0, num_chunks=world_size, out=reduce_scatter_input) + + reduce_output, _ = reduce_scatter_raw(reduce_scatter_input, self.process_group, reduce_op) + + # unpack the reduce scatter result for each parameter + flat_grad_offset = 0 + + for _idx, _unshard_size in enumerate(unshard_grad_sizes): + _shard_size = (_unshard_size[0] // world_size, *_unshard_size[1:]) + _strides = make_contiguous_strides_for(_shard_size) + + _new_sharded_grad = torch.as_strided( + reduce_output, + size=_shard_size, + stride=_strides, + storage_offset=flat_grad_offset, + ) + + _key = self._reduce_scatter_ops[_idx].key + _event = self._reduce_scatter_comm_stream.record_event() + + self._reduce_scatter_results[_key].result = _new_sharded_grad + self._reduce_scatter_results[_key].handle.set_handle(_event) + + flat_grad_offset += _unshard_size.numel() // world_size + + reduce_scatter_event = self._reduce_scatter_comm_stream.record_event() + + self._reduce_scatter_state = (unshard_grads, reduce_scatter_event) + self._reduce_scatter_ops = [] + + class ISPCommunicator(WPCommunicator): """ ISP Communicator for managing the all-gather and reduce_scatter of Intern Sequence Parallel.
@@ -309,13 +760,14 @@ def __init__( process_group: dist.ProcessGroup = None, is_moe: bool = False, selective_ckpt_offload: bool = False, + enable_layer_fuse_isp_comm: bool = False, ) -> None: self.process_group = process_group self.overlap = overlap self.model_conf = model_conf self.is_moe = is_moe self.is_forward = True - self.reduce_scatter_handlers = {} + self._reduce_scatter_handlers = {} self._forward_prefetch_prerequisites = [] self._forward_overlap_per = self._get_forward_overlap_granularity() self._launch_before_module = self._get_launch_before_module() @@ -351,6 +803,16 @@ def __init__( # key: transformer block index; value: transformer block self._index_to_block = None + enable_layer_fuse_isp_comm = overlap and enable_layer_fuse_isp_comm + if enable_layer_fuse_isp_comm: + self._layer_level_comm_context = LayerAsyncCommContext( + dtype=self.model_conf.dtype, + device=self.model_conf.device, + process_group=self.process_group, + ) + else: + self._layer_level_comm_context = None + # init overlap states if necessary. if self.overlap: # build overlap states for every chunk. @@ -358,7 +820,7 @@ def __init__( self._parse_model_structure(chunk_id, chunk) self.switch_current_model_chunk(chunk_id) # register overlap hooks for every chunk. - self._register_sync_parameters_hook() + self._register_sync_parameters_hook(chunk) # switch to chunk 0 at first. self.switch_current_model_chunk(0) @@ -388,6 +850,14 @@ def _get_forward_overlap_granularity(self): assert _overlap_granularity in ["module", "layer"] return _overlap_granularity + def pop_reduced_grad(self, key: str) -> torch.Tensor: + if self._layer_level_comm_context is not None: + return self._layer_level_comm_context.pop_reduced_grad(key) + + result, handle = self._reduce_scatter_handlers.pop(key) + handle.wait() + return result + def _parse_model_structure(self, cid: int, model: nn.Module) -> None: self._overlap_states[cid] = ISPOverlapState() @@ -408,6 +878,9 @@ def is_allgather_launch_module(name, module): self._overlap_states[cid].ckpt_block_num = int(self.model_conf.activation_checkpointing * len(children)) for idx, block in enumerate(children): + if not hasattr(block, "layer_idx"): + setattr(block, "layer_idx", idx) + self._overlap_states[cid].index_to_isp_modules[idx] = [] self._overlap_states[cid].index_to_block[idx] = block for name, child in block.named_modules(): @@ -595,10 +1068,15 @@ def _post_backward_hook_for_module(self, module, *args): # pylint: disable=W061 self._clear_handle(module) self._clear_weight(module) - def _register_sync_parameters_hook(self) -> None: + def _register_sync_parameters_hook(self, model) -> None: """ register forward hooks and backward hooks for isp modules. """ + + if self._layer_level_comm_context is not None: + self._layer_level_comm_context.register_overlap_hooks(model) + return + # register forward hooks # 1. register pre_forward_hook @block_0 to prefetch weight for block 0. # 2. 
register pre_forward_hook @prefetch_launch_module to prefetch weight for next block, @@ -648,6 +1126,15 @@ def switch_current_model_chunk(self, chunk_id: int) -> None: self._ckpt_block_num = self._overlap_states[chunk_id].ckpt_block_num self._num_blocks = self._overlap_states[chunk_id].num_blocks + if self._layer_level_comm_context is not None: + self._layer_level_comm_context.switch_current_overlap_state(self._overlap_states[chunk_id]) + + def switch_forward_backward_phase(self, is_forward: int) -> None: + self.is_forward = is_forward + + if self._layer_level_comm_context is not None: + self._layer_level_comm_context.switch_forward_backward_phase(is_forward) + def register_prerequisite_for_forward_prefetch_hooks(self, prerequisite_func: Callable) -> None: """ Registers a callback function that specifies a prerequisite condition for @@ -685,6 +1172,11 @@ def weight_hook( if not self.overlap: result, _ = all_gather_raw(tensor, self.process_group, async_op=async_op) + return result + + if self._layer_level_comm_context is not None: + result = self._layer_level_comm_context.all_gather(module, tensor, async_op, is_bias) + return result elif is_bias: assert module is not None, "The module parameter must be specified" result = self._bias_global_output[module] @@ -717,22 +1209,25 @@ def grad_hook( assert hasattr(module.weight, "isp_reduce_scatter_name") key = getattr(module.weight, "isp_reduce_scatter_name") - self.reduce_scatter_handlers[key] = reduce_scatter_raw( - tensor, - self.process_group, - op=reduce_op, - async_op=async_op, - ) + if self._layer_level_comm_context is not None: + result, handle = self._layer_level_comm_context.reduce_scatter(key, tensor, reduce_op, async_op) + else: + self._reduce_scatter_handlers[key] = reduce_scatter_raw( + tensor, + self.process_group, + op=reduce_op, + async_op=async_op, + ) - result, handle = ( - self._get_constant_zero( - ( - tensor.shape[0] // dist.get_world_size(self.process_group), - *tensor.shape[1:], - ) - ), - DUMMY_HANDLE_CONST, - ) + result, handle = ( + self._get_constant_zero( + ( + tensor.shape[0] // dist.get_world_size(self.process_group), + *tensor.shape[1:], + ) + ), + DUMMY_HANDLE_CONST, + ) return result, handle @@ -747,7 +1242,7 @@ def __init__(self, overlap_handler: ISPCommunicator, zero_optim) -> None: self._zero_optim = zero_optim def before_forward(self, scheduler, inputs) -> None: # pylint: disable=W0613 - self._isp_communicator.is_forward = True + self._isp_communicator.switch_forward_backward_phase(is_forward=True) # switch model chunk before forward chunk_id = 0 if gpc.virtual_pipeline_parallel_rank is None else gpc.virtual_pipeline_parallel_rank self._isp_communicator.switch_current_model_chunk(chunk_id) @@ -762,7 +1257,7 @@ def after_criterion(self, scheduler, loss) -> None: # pylint: disable=W0613 pass def before_backward(self, scheduler, outputs, outputs_grad) -> None: # pylint: disable=W0613 - self._isp_communicator.is_forward = False + self._isp_communicator.switch_forward_backward_phase(is_forward=False) # switch model chunk before backward chunk_id = 0 if gpc.virtual_pipeline_parallel_rank is None else gpc.virtual_pipeline_parallel_rank self._isp_communicator.switch_current_model_chunk(chunk_id) @@ -793,11 +1288,16 @@ def __init__( isp_communicators: List[ISPCommunicator], ) -> None: self.isp_communicators = isp_communicators - self.reduce_scatter_handlers = {} - self.reduce_scatter_handlers = UniqueChainMap( - *(isp_communicator.reduce_scatter_handlers for isp_communicator in self.isp_communicators) - ) + def 
pop_reduced_grad(self, key) -> dict: + for communicator in self.isp_communicators: + try: + return communicator.pop_reduced_grad(key) + except KeyError: + continue + + # key is not in any communicator + raise KeyError(f"key {key} is not found") def register_prerequisite_for_forward_prefetch_hooks(self, prerequisite_func: Callable) -> None: for isp_communicator in self.isp_communicators: diff --git a/internlm/core/parallel/comm/utils.py b/internlm/core/parallel/comm/utils.py index 5cd8cb79..e2f73b3b 100644 --- a/internlm/core/parallel/comm/utils.py +++ b/internlm/core/parallel/comm/utils.py @@ -15,14 +15,14 @@ class AsyncCommHandle(ABC): """A interface for asynchronous communication handles.""" @abstractmethod - def wait(self) -> None: + def wait(self, stream=None) -> None: """wait asynchronous communication to complete.""" class DummyAsyncCommHandle(AsyncCommHandle): """A fake communication handle used to maintain consistency in code writing""" - def wait(self) -> None: + def wait(self, stream=None) -> None: pass diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index b9e8e41b..d83cf9c5 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -456,11 +456,17 @@ def args_sanity_check(): gpc.config.parallel["weight"]["overlap"] = False if gpc.config.parallel["tensor"]["mode"] != TensorParallelMode.isp.name: assert gpc.config.parallel["weight"]["size"] <= 1, "weight parallel is only supported with isp" + + if gpc.config.parallel["weight"].get("layer_fuse_isp_comm", None) is None: + gpc.config.parallel["weight"]["layer_fuse_isp_comm"] = False # set default value for expert_weight parallel if gpc.config.parallel["expert_weight"].get("overlap", None) is None: gpc.config.parallel["expert_weight"]["overlap"] = False if gpc.config.parallel["expert"].get("no_tp", None) is None: gpc.config.parallel["expert"]["no_tp"] = False + + if gpc.config.parallel["expert_weight"].get("layer_fuse_isp_comm", None) is None: + gpc.config.parallel["expert_weight"]["layer_fuse_isp_comm"] = False # currently only interleaved pipeline scheduler with overlap can guarantee loss accuracy if hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1: assert ( diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 49f3fbcf..94da0649 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -449,13 +449,11 @@ def _accum_grads_store_in_bucket(self, bucket: BucketStore, reduce_rank: Optiona # wait and accumulate gardient. _key = getattr(_param, "isp_reduce_scatter_name") - _grad, _comm_handle = self._isp_communicator.reduce_scatter_handlers[_key] - _comm_handle.wait() + _grad = self._isp_communicator.pop_reduced_grad(_key) _param.grad.add_(_grad) # release cuda memory. _grad = None - self._isp_communicator.reduce_scatter_handlers[_key] = None bucket.reset_by_rank(reduce_rank) diff --git a/internlm/solver/optimizer/hybrid_zero_optim_v2.py b/internlm/solver/optimizer/hybrid_zero_optim_v2.py index 36e5f073..11158aba 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim_v2.py +++ b/internlm/solver/optimizer/hybrid_zero_optim_v2.py @@ -231,13 +231,11 @@ def _accum_grads_store_in_bucket(self, bucket: BucketStore_v2) -> None: # wait and accumulate gardient. 
_key = getattr(_param, "isp_reduce_scatter_name") - _grad, _comm_handle = self._isp_communicator.reduce_scatter_handlers[_key] - _comm_handle.wait() + _grad = self._isp_communicator.pop_reduced_grad(_key) _param.grad.add_(_grad) # release cuda memory. _grad = None - self._isp_communicator.reduce_scatter_handlers[_key] = None bucket.reset_all() diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py index ca11e689..a57c6526 100644 --- a/internlm/train/pipeline.py +++ b/internlm/train/pipeline.py @@ -364,6 +364,7 @@ def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]): gpc.get_group(ParallelMode.WEIGHT), is_moe=False, selective_ckpt_offload=gpc.config.get("selective_checkpoint_offload", False), + enable_layer_fuse_isp_comm=gpc.config.parallel.weight.get("layer_fuse_isp_comm", False), ) # register communicator for isp column parallel linear. ColumnParallelLinear.register_cls_communicator(isp_communicator) @@ -389,6 +390,7 @@ def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]): gpc.config.parallel.expert_weight.overlap, gpc.get_group(ParallelMode.EXPERT_WEIGHT), is_moe=True, + enable_layer_fuse_isp_comm=gpc.config.parallel.expert_weight.get("layer_fuse_isp_comm", False), ) for moe in _submodule_filter(model, Experts): for column_linear in _submodule_filter(moe, (ColumnParallelLinear, GroupedWPLinear)):
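
Usage note (not part of the patch): the fused layer-level ISP communication path is opt-in. It only takes effect when weight-parallel overlap is enabled, and args_sanity_check defaults layer_fuse_isp_comm to False. A minimal config sketch, assuming the usual python-dict style InternLM parallel config; the sizes below are placeholders, and only the overlap and layer_fuse_isp_comm keys are exercised by this patch:

# Hypothetical excerpt of a training config; sizes are illustrative only.
parallel = dict(
    tensor=dict(size=1, mode="isp"),                                     # ISP requires tensor mode "isp"
    weight=dict(size=4, overlap=True, layer_fuse_isp_comm=True),         # fuse per-layer all-gather / reduce-scatter
    expert_weight=dict(size=1, overlap=True, layer_fuse_isp_comm=True),  # same switch for MoE expert weights
)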
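
For reviewers, a self-contained toy sketch (single process, CPU, collective call omitted) of the FSDP2-style flatten/copy-in and split/copy-out round trip that _all_gather_copy_in and _split_with_sizes_copy implement; every name below is illustrative and not part of the patch:

import torch

world_size, rank = 4, 0                                    # pretend topology
params = [torch.randn(8, 3), torch.randn(8)]               # this rank's parameter shards
split_sizes = [p.numel() for p in params]
shard_numel = sum(split_sizes)

# copy-in: pack the local shards into this rank's slice of the flat all-gather buffer
all_gather_output = torch.empty(shard_numel * world_size)
all_gather_input = all_gather_output.narrow(0, shard_numel * rank, shard_numel)
torch._foreach_copy_(torch.split(all_gather_input, split_sizes), [p.view(-1) for p in params])

# dist.all_gather_into_tensor(all_gather_output, all_gather_input, group=pg) would run here.

# copy-out: split the flat gathered buffer back into one contiguous tensor per parameter
outs = [torch.empty(world_size, numel) for numel in split_sizes]
torch.split_with_sizes_copy(all_gather_output.view(world_size, -1), split_sizes, dim=1, out=outs)
full_params = [o.reshape(-1, *p.shape[1:]) for o, p in zip(outs, params)]
# without the real collective, only this rank's slice of full_params holds meaningful data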