From 83bef93195b664bbffa5a763f106c8e773dc0108 Mon Sep 17 00:00:00 2001
From: Wenwen Qu
Date: Fri, 13 Sep 2024 18:13:30 +0800
Subject: [PATCH 1/3] add moe async param handle

---
 internlm/core/parallel/comm/isp.py   | 12 +++++-
 internlm/core/parallel/comm/utils.py |  7 ++++
 internlm/core/parallel/comm/zero.py  | 61 ++++++++++++++++++++++++----
 internlm/initialize/launch.py        |  3 --
 internlm/train/pipeline.py           | 32 +++++++++++----
 5 files changed, 93 insertions(+), 22 deletions(-)

diff --git a/internlm/core/parallel/comm/isp.py b/internlm/core/parallel/comm/isp.py
index 8ca107bf..c36327fd 100644
--- a/internlm/core/parallel/comm/isp.py
+++ b/internlm/core/parallel/comm/isp.py
@@ -18,6 +18,7 @@
 from internlm.core.parallel.comm.utils import (
     DUMMY_HANDLE_CONST,
     AsyncCommHandle,
+    CommunicatorType,
     _gather,
     _split,
     all_gather_raw,
@@ -832,9 +833,8 @@ class ISPCommunicatorWrapper:
 
     def __init__(
         self,
-        isp_communicators: List[ISPCommunicator],
     ) -> None:
-        self.isp_communicators = isp_communicators
+        self.isp_communicators = [None for _ in range(len(CommunicatorType))]
         self.reduce_scatter_handlers = {}
 
         self.memory_pools = [
@@ -851,6 +851,14 @@ def __init__(
         else:
             self.enable_memory_pool = False
 
+    def set_communicator(self, index, communicator):
+        assert index < len(CommunicatorType)
+        self.isp_communicators[index] = communicator
+
+    def get_communicator(self, index):
+        assert index < len(CommunicatorType)
+        return self.isp_communicators[index]
+
     def free_reduce_scatter_memory(self, key, index):
         for memory_pool in self.memory_pools:
             if key in memory_pool._reduce_scatter_memory_pool:
diff --git a/internlm/core/parallel/comm/utils.py b/internlm/core/parallel/comm/utils.py
index a7f93c3b..fbb1c434 100644
--- a/internlm/core/parallel/comm/utils.py
+++ b/internlm/core/parallel/comm/utils.py
@@ -2,6 +2,7 @@
 # -*- encoding: utf-8 -*-
 
 from abc import ABC, abstractmethod
+from enum import IntEnum
 from typing import Callable
 
 import torch
@@ -359,3 +360,9 @@ def backward(ctx, grad_output):
 
 
 expandKVPacked = _ExpandKVPackedFunction.apply
+
+
+# used in isp and zero
+class CommunicatorType(IntEnum):
+    Non_MoE = 0
+    MoE = 1
diff --git a/internlm/core/parallel/comm/zero.py b/internlm/core/parallel/comm/zero.py
index 58929290..7b446e50 100644
--- a/internlm/core/parallel/comm/zero.py
+++ b/internlm/core/parallel/comm/zero.py
@@ -12,8 +12,10 @@
 from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import unwrap_naive_amp
 from internlm.core.parallel.comm.isp import ISPCommunicatorWrapper
+from internlm.core.parallel.comm.utils import CommunicatorType
 from internlm.model.modules.embedding import Embedding1D
 from internlm.model.modules.linear import ScaleColumnParallelLinear
+from internlm.model.modules.utils import is_moe_param
 from internlm.solver.optimizer.utils import flatten
@@ -27,6 +29,7 @@ def __init__(
         zero1_mode: ParallelMode,
         model: Union[nn.Module, nn.ModuleList],
         isp_communicator: ISPCommunicatorWrapper = None,
+        is_moe: bool = False,
     ) -> None:
         self._block_to_param: Dict[nn.Module, List[nn.Parameter]] = OrderedDict()
         self._param_to_rank: Dict[nn.Parameter, int] = {}
@@ -35,8 +38,7 @@ def __init__(
         self._block_to_name: Dict[nn.Module, str] = {}
 
         zero1_size = gpc.get_world_size(zero1_mode)
-        total_param_num = sum(p.numel() for p in model.parameters())
-        avg_param_num = total_param_num * 1.0 // zero1_size
+        total_param_num = 0
 
         # initialize an empty list for _bcast_handles of each rank
         self._bcast_handles = {rank: [] for rank in range(zero1_size)}
@@ -56,16 +58,25 @@
                     for idx, block in enumerate(children):
                         block_name = name + f"_{idx}"
                         # self._block_to_param[f"{name}.{idx}"] = list(block.parameters())
-                        self._block_to_param[block] = list(block.parameters())
-                        self._block_to_name[block] = block_name
+                        self._block_to_param[block] = []
+                        for param in block.parameters():
+                            if is_moe_param(param) == is_moe:
+                                total_param_num += param.numel()
+                                self._block_to_param[block].append(param)
+                        self._block_to_name[block] = block_name
                 else:
                     # record the block that a parameter belongs to
                     # self._block_to_param[name] = list(children.parameters())
-                    self._block_to_param[children] = list(children.parameters())
-                    self._block_to_name[children] = name
+                    self._block_to_param[children] = []
+                    for param in children.parameters():
+                        if is_moe_param(param) == is_moe:
+                            total_param_num += param.numel()
+                            self._block_to_param[children].append(param)
+                    self._block_to_name[children] = name
 
         alloc_num = 0
         rank_to_go = 0
+        avg_param_num = total_param_num * 1.0 // zero1_size
 
         # process the parameters in block_to_param sequencially,
         # allocate each parameter to a local rank of ParallelMode.ZERO1,
@@ -76,14 +87,14 @@ def __init__(
             # allocate a model block to a local rank of ParallelMode.ZERO1
             self._block_to_rank[block] = [rank_to_go]
             for p in params:
+                # allocate a parameter to a local rank of ParallelMode.ZERO1
+                self._param_to_rank[p] = rank_to_go
                 alloc_num = alloc_num + p.numel()
                 # in this case, allocate the param to next rank if possible
                 if alloc_num > avg_param_num * 1.01 and rank_to_go < zero1_size - 1:
                     rank_to_go = rank_to_go + 1
                     alloc_num = 0
                     self._block_to_rank[block].append(rank_to_go)
-                # allocate a parameter to a local rank of ParallelMode.ZERO1
-                self._param_to_rank[p] = rank_to_go
 
         for block_name in self._block_to_name.values():
             self._block_allgather_handles[block_name] = None
@@ -188,3 +199,37 @@ def add_allgather_handle(self, handle, master_param, working_param, gatherd_para
         self._block_working_params[block_name] = working_param
         self._block_gathered_params[block_name] = gatherd_param
         self._block_allgather_order[block_name] = 1
+
+
+class ParamAsyncBcastHandlerWrapper:
+    """
+    Wrapper for multiple ParamAsyncBcastHandlers.
+    TODO: check all param bcast sync handler external interfaces and wrap them.
+    """
+
+    def __init__(
+        self,
+    ) -> None:
+        self.param_bcast_sync_handlers = [None for _ in range(len(CommunicatorType))]
+
+    def set_handle(self, index, handler):
+        assert index < len(CommunicatorType)
+        self.param_bcast_sync_handlers[index] = handler
+
+    def get_handle(self, index):
+        assert index < len(CommunicatorType)
+        return self.param_bcast_sync_handlers[index]
+
+    def get_rank_by_param(self, param) -> int:
+        idx = CommunicatorType.MoE if is_moe_param(param) else CommunicatorType.Non_MoE
+        return self.get_handle(idx).get_rank_by_param(param)
+
+    def add_bcast_handle(self, rank, handle, is_moe_group=False) -> None:
+        idx = CommunicatorType.MoE if is_moe_group else CommunicatorType.Non_MoE
+        self.get_handle(idx).add_bcast_handle(rank, handle)
+
+    def add_allgather_handle(
+        self, handle, master_param, working_param, gatherd_param, block_name, is_moe_group=False
+    ) -> None:
+        idx = CommunicatorType.MoE if is_moe_group else CommunicatorType.Non_MoE
+        self.get_handle(idx).add_allgather_handle(handle, master_param, working_param, gatherd_param, block_name)
diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index 52db23cf..b07f7c62 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -508,9 +508,6 @@ def args_sanity_check():
     # moe not support overlap and zero1.5 for now
     if gpc.config.model.get("num_experts", 1) > 1:
         assert not gpc.config.parallel.zero1.fsdp, "FSDP does not support num_experts > 1"
-        assert (
-            not optim_ckpt.overlap_sync_grad & optim_ckpt.overlap_sync_param
-        ), "not support overlap and moe at the same time"
         assert gpc.config.parallel.zero1.size in (
             -1,
             gpc.get_world_size(ParallelMode.DATA),
diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py
index 05e5c30f..a9e4b83e 100644
--- a/internlm/train/pipeline.py
+++ b/internlm/train/pipeline.py
@@ -44,7 +44,11 @@
     SequenceParallelCommunicator,
     TensorParallelCommunicator,
 )
-from internlm.core.parallel.comm.zero import ParamAsyncBcastHandler
+from internlm.core.parallel.comm.utils import CommunicatorType
+from internlm.core.parallel.comm.zero import (
+    ParamAsyncBcastHandler,
+    ParamAsyncBcastHandlerWrapper,
+)
 from internlm.core.trainer import TrainState
 from internlm.data.utils import unpack_type_ids
 from internlm.model.builder import create_model
@@ -293,6 +297,7 @@ def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]):
     _retain_out_sharded = gpc.config.model.get("parallel_output", True)
 
     if is_using_isp():
+        isp_communicator_wrapper = ISPCommunicatorWrapper()
         isp_communicator = ISPCommunicator(
             model,
             ISPCommModelConfig(
@@ -304,6 +309,7 @@ def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]):
             gpc.config.parallel.weight.memory_pool,
             gpc.get_group(ParallelMode.WEIGHT),
         )
+        isp_communicator_wrapper.set_communicator(CommunicatorType.Non_MoE, isp_communicator)
         # register communicator for isp column parallel linear.
         ColumnParallelLinear.register_cls_communicator(isp_communicator)
         # row parallel linear will not be used.
@@ -329,16 +335,13 @@ def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]):
             gpc.config.parallel.expert_weight.memory_pool,
             gpc.get_group(ParallelMode.EXPERT_WEIGHT),
         )
+        isp_communicator_wrapper.set_communicator(CommunicatorType.MoE, moe_isp_communicator)
         for moe in _submodule_filter(model, Experts):
             for column_linear in _submodule_filter(moe, (ColumnParallelLinear)):
                 column_linear.register_communicator(moe_isp_communicator)
             for row_linear in _submodule_filter(moe, RowParallelLinear):
                 row_linear.register_communicator(None)
 
-        isp_communicator_wrapper = ISPCommunicatorWrapper([isp_communicator, moe_isp_communicator])
-    else:
-        isp_communicator_wrapper = ISPCommunicatorWrapper([isp_communicator])
-
     # register communictor for mtp/msp/fsp linear.
 
     # tensor parallel
@@ -460,9 +463,20 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList], isp_communicato
             zero_cfg.overlap_sync_grad = False
 
     if zero_cfg.overlap_sync_param:
-        param_bcast_sync_handler = ParamAsyncBcastHandler(ParallelMode.ZERO1, model, isp_communicator)
+        param_bcast_sync_handle_wrapper = ParamAsyncBcastHandlerWrapper()
+        non_moe_isp_communicator = (
+            isp_communicator.get_communicator(CommunicatorType.Non_MoE) if isp_communicator else None
+        )
+        param_bcast_sync_handler = ParamAsyncBcastHandler(ParallelMode.ZERO1, model, non_moe_isp_communicator)
+        param_bcast_sync_handle_wrapper.set_handle(CommunicatorType.Non_MoE, param_bcast_sync_handler)
+        if gpc.config.model.get("num_experts", 1) > 1:
+            moe_isp_communicator = isp_communicator.get_communicator(CommunicatorType.MoE) if isp_communicator else None
+            moe_param_bcast_sync_handler = ParamAsyncBcastHandler(
+                ParallelMode.EXPERT_DATA, model, moe_isp_communicator, is_moe=True
+            )
+            param_bcast_sync_handle_wrapper.set_handle(CommunicatorType.MoE, moe_param_bcast_sync_handler)
     else:
-        param_bcast_sync_handler = None
+        param_bcast_sync_handle_wrapper = None
 
     if not gpc.config.parallel.zero1.fsdp:
         if (
@@ -473,7 +487,7 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList], isp_communicato
             naive_optimizer,
             grad_scal_cfg=grad_scal_cfg,
             zero_cfg=zero_cfg,
-            param_bcast_sync_handler=param_bcast_sync_handler,
+            param_bcast_sync_handler=param_bcast_sync_handle_wrapper,
             isp_communicator=isp_communicator,
         )
     else:
@@ -481,7 +495,7 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList], isp_communicato
             naive_optimizer,
             grad_scal_cfg=grad_scal_cfg,
             zero_cfg=zero_cfg,
-            param_bcast_sync_handler=param_bcast_sync_handler,
+            param_bcast_sync_handler=param_bcast_sync_handle_wrapper,
             isp_communicator=isp_communicator,
         )
     else:

From 5fb554ba551f0250ae5276fdf447509237649409 Mon Sep 17 00:00:00 2001
From: Wenwen Qu
Date: Fri, 13 Sep 2024 18:55:47 +0800
Subject: [PATCH 2/3] fix isp

---
 internlm/core/parallel/comm/isp.py | 25 ++++++++++---------------
 1 file changed, 10 insertions(+), 15 deletions(-)

diff --git a/internlm/core/parallel/comm/isp.py b/internlm/core/parallel/comm/isp.py
index c36327fd..9b4a2a5a 100644
--- a/internlm/core/parallel/comm/isp.py
+++ b/internlm/core/parallel/comm/isp.py
@@ -835,25 +835,19 @@ def __init__(
         self,
     ) -> None:
         self.isp_communicators = [None for _ in range(len(CommunicatorType))]
-        self.reduce_scatter_handlers = {}
+        self.memory_pools = [None for _ in range(len(CommunicatorType))]
 
-        self.memory_pools = [
-            isp_communicator.memory_pool
-            for isp_communicator in self.isp_communicators
-            if isp_communicator.enable_memory_pool
-        ]
-        self.reduce_scatter_handlers = UniqueChainMap(
-            *(isp_communicator.reduce_scatter_handlers for isp_communicator in self.isp_communicators)
-        )
+        self.reduce_scatter_handlers = UniqueChainMap()
 
-        if self.memory_pools:
-            self.enable_memory_pool = True
-        else:
-            self.enable_memory_pool = False
+        self.enable_memory_pool = False
 
     def set_communicator(self, index, communicator):
         assert index < len(CommunicatorType)
         self.isp_communicators[index] = communicator
+        self.reduce_scatter_handlers = self.reduce_scatter_handlers.new_child(communicator.reduce_scatter_handlers)
+        if communicator.enable_memory_pool:
+            self.memory_pools[index] = communicator.memory_pool
+            self.enable_memory_pool = True
 
     def get_communicator(self, index):
         assert index < len(CommunicatorType)
@@ -861,12 +855,13 @@ def get_communicator(self, index):
 
     def free_reduce_scatter_memory(self, key, index):
         for memory_pool in self.memory_pools:
-            if key in memory_pool._reduce_scatter_memory_pool:
+            if memory_pool is not None and key in memory_pool._reduce_scatter_memory_pool:
                 memory_pool.free_reduce_scatter_memory(key, index)
 
     def reset_lazy_pools(self) -> None:
         for memory_pool in self.memory_pools:
-            memory_pool.reset_lazy_pools()
+            if memory_pool is not None:
+                memory_pool.reset_lazy_pools()
 
     def register_prerequisite_for_forward_prefetch_hooks(self, prerequisite_func: Callable) -> None:
         for isp_communicator in self.isp_communicators:

From d5f68a6696fb65dcdcab39cec601114b56b845ff Mon Sep 17 00:00:00 2001
From: Wenwen Qu
Date: Fri, 13 Sep 2024 19:12:37 +0800
Subject: [PATCH 3/3] fix little bugs

---
 internlm/core/parallel/comm/zero.py               | 10 ++++------
 internlm/solver/optimizer/hybrid_zero_optim.py    |  2 +-
 internlm/solver/optimizer/hybrid_zero_optim_v2.py |  1 +
 3 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/internlm/core/parallel/comm/zero.py b/internlm/core/parallel/comm/zero.py
index 7b446e50..78037c59 100644
--- a/internlm/core/parallel/comm/zero.py
+++ b/internlm/core/parallel/comm/zero.py
@@ -224,12 +224,10 @@ def get_rank_by_param(self, param) -> int:
         idx = CommunicatorType.MoE if is_moe_param(param) else CommunicatorType.Non_MoE
         return self.get_handle(idx).get_rank_by_param(param)
 
-    def add_bcast_handle(self, rank, handle, is_moe_group=False) -> None:
-        idx = CommunicatorType.MoE if is_moe_group else CommunicatorType.Non_MoE
+    def add_bcast_handle(self, rank, handle, bcast_mode) -> None:
+        idx = CommunicatorType.MoE if bcast_mode == ParallelMode.EXPERT_DATA else CommunicatorType.Non_MoE
         self.get_handle(idx).add_bcast_handle(rank, handle)
 
-    def add_allgather_handle(
-        self, handle, master_param, working_param, gatherd_param, block_name, is_moe_group=False
-    ) -> None:
-        idx = CommunicatorType.MoE if is_moe_group else CommunicatorType.Non_MoE
+    def add_allgather_handle(self, handle, master_param, working_param, gatherd_param, block_name, bcast_mode) -> None:
+        idx = CommunicatorType.MoE if bcast_mode == ParallelMode.EXPERT_DATA else CommunicatorType.Non_MoE
         self.get_handle(idx).add_allgather_handle(handle, master_param, working_param, gatherd_param, block_name)
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index b2d2063d..c5e40269 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -884,7 +884,7 @@ def broadcast_params(self):
             )
 
             if self._overlap_sync_param:
-                self._param_bcast_sync_handler.add_bcast_handle(rank, handle)
+                self._param_bcast_sync_handler.add_bcast_handle(rank, handle, self._broadcast_parallel_mode[group_id])
             else:
                 handles.append(handle)
 
diff --git a/internlm/solver/optimizer/hybrid_zero_optim_v2.py b/internlm/solver/optimizer/hybrid_zero_optim_v2.py
index 6f8d3206..63ab29ae 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim_v2.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim_v2.py
@@ -710,6 +710,7 @@ def all_gather_params(
                     all_gather_working_params,
                     gathered_params,
                     all_gather_working_params[0].block_name,
+                    self._zero_parallel_mode[group_id],
                 )
             else:
                 gathered_params_list.append(gathered_params)
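
Notes on the mechanisms in this series. The sketches below are simplified, self-contained stand-ins written against the patches above, not InternLM code; names such as allocate_to_ranks and block_to_numels are illustrative only.

Patch 1 changes the greedy ZERO1 bucketing in ParamAsyncBcastHandler.__init__ in two ways: avg_param_num is now computed only over parameters that survive the is_moe_param(param) == is_moe filter, and each parameter is recorded in _param_to_rank before the spill check rather than after it, so the parameter that pushes alloc_num past 1.01 * avg_param_num stays on the current rank instead of moving to the next one. A minimal sketch of the post-patch behavior, with plain integer sizes standing in for nn.Parameter objects:

    from collections import OrderedDict


    def allocate_to_ranks(block_to_numels, zero1_size):
        """Greedy ZeRO-1 split: walk parameters in order, assign each to the
        current rank, then spill to the next rank once the running total
        exceeds the per-rank average by 1%."""
        total_param_num = sum(sum(sizes) for sizes in block_to_numels.values())
        avg_param_num = total_param_num * 1.0 // zero1_size
        param_to_rank, block_to_rank = {}, {}
        alloc_num, rank_to_go = 0, 0
        for block, sizes in block_to_numels.items():
            block_to_rank[block] = [rank_to_go]
            for i, numel in enumerate(sizes):
                # assign before the spill check: the parameter that crosses
                # the threshold stays on the current rank (patch 1's reordering)
                param_to_rank[(block, i)] = rank_to_go
                alloc_num += numel
                if alloc_num > avg_param_num * 1.01 and rank_to_go < zero1_size - 1:
                    rank_to_go += 1
                    alloc_num = 0
                    block_to_rank[block].append(rank_to_go)
        return param_to_rank, block_to_rank


    blocks = OrderedDict(block_0=[4, 4], block_1=[4, 4])
    p2r, _ = allocate_to_ranks(blocks, zero1_size=2)
    # avg_param_num is 8.0: the third parameter (running total 12) triggers
    # the spill only after it is already assigned, so ranks come out 0, 0, 0, 1
    assert [p2r[("block_0", 0)], p2r[("block_0", 1)],
            p2r[("block_1", 0)], p2r[("block_1", 1)]] == [0, 0, 0, 1]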
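
Patch 2 rebuilds ISPCommunicatorWrapper so that reduce_scatter_handlers starts as an empty UniqueChainMap and gains one layer per set_communicator call through new_child. The value of new_child here is that it layers the communicator's dict without copying it, so handlers a communicator registers after being wrapped remain visible through the merged view. A short sketch of that property, using collections.ChainMap as a stand-in for InternLM's UniqueChainMap (assumed to keep new_child's layering semantics):

    from collections import ChainMap

    merged = ChainMap()  # starts empty, like the patched __init__
    non_moe_handlers, moe_handlers = {}, {}

    # each set_communicator call layers that communicator's handler dict on top
    merged = merged.new_child(non_moe_handlers)
    merged = merged.new_child(moe_handlers)

    # a handler registered *after* wrapping is still visible, because new_child
    # keeps live references to the underlying dicts instead of snapshotting them
    non_moe_handlers["blocks.0.w1.weight"] = "reduce-scatter-handle"
    assert merged["blocks.0.w1.weight"] == "reduce-scatter-handle"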
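
Patch 3 drops the is_moe_group flag from ParamAsyncBcastHandlerWrapper and keys the dispatch on the group's parallel mode instead, which is why hybrid_zero_optim.py can forward self._broadcast_parallel_mode[group_id] directly and hybrid_zero_optim_v2.py its self._zero_parallel_mode[group_id]. A compact sketch of the routing, with a two-member stub standing in for internlm.core.context's ParallelMode:

    from enum import Enum, IntEnum


    class ParallelMode(Enum):  # stub: the real enum lives in internlm.core.context
        ZERO1 = "zero1"
        EXPERT_DATA = "expert_data"


    class CommunicatorType(IntEnum):  # as added to comm/utils.py by patch 1
        Non_MoE = 0
        MoE = 1


    def handler_index(bcast_mode):
        # EXPERT_DATA broadcasts belong to the MoE handler; any other mode
        # (ZERO1 for the dense groups) routes to the non-MoE handler
        return CommunicatorType.MoE if bcast_mode == ParallelMode.EXPERT_DATA else CommunicatorType.Non_MoE


    assert handler_index(ParallelMode.EXPERT_DATA) is CommunicatorType.MoE
    assert handler_index(ParallelMode.ZERO1) is CommunicatorType.Non_MoE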