feat(moe): add moe async param handler #332

Open

wants to merge 3 commits into base: develop
35 changes: 19 additions & 16 deletions internlm/core/parallel/comm/isp.py
@@ -18,6 +18,7 @@
from internlm.core.parallel.comm.utils import (
DUMMY_HANDLE_CONST,
AsyncCommHandle,
CommunicatorType,
_gather,
_split,
all_gather_raw,
@@ -832,33 +833,35 @@ class ISPCommunicatorWrapper:

def __init__(
self,
isp_communicators: List[ISPCommunicator],
) -> None:
self.isp_communicators = isp_communicators
self.reduce_scatter_handlers = {}
self.isp_communicators = [None for _ in range(len(CommunicatorType))]
self.memory_pools = [None for _ in range(len(CommunicatorType))]

self.memory_pools = [
isp_communicator.memory_pool
for isp_communicator in self.isp_communicators
if isp_communicator.enable_memory_pool
]
self.reduce_scatter_handlers = UniqueChainMap(
*(isp_communicator.reduce_scatter_handlers for isp_communicator in self.isp_communicators)
)
self.reduce_scatter_handlers = UniqueChainMap()

if self.memory_pools:
self.enable_memory_pool = False

def set_communicator(self, index, communicator):
assert index < len(CommunicatorType)
self.isp_communicators[index] = communicator
self.reduce_scatter_handlers = self.reduce_scatter_handlers.new_child(communicator.reduce_scatter_handlers)
if communicator.enable_memory_pool:
self.memory_pools[index] = communicator.memory_pool
self.enable_memory_pool = True
else:
self.enable_memory_pool = False

def get_communicator(self, index):
assert index < len(CommunicatorType)
return self.isp_communicators[index]

def free_reduce_scatter_memory(self, key, index):
for memory_pool in self.memory_pools:
if key in memory_pool._reduce_scatter_memory_pool:
if memory_pool is not None and key in memory_pool._reduce_scatter_memory_pool:
memory_pool.free_reduce_scatter_memory(key, index)

def reset_lazy_pools(self) -> None:
for memory_pool in self.memory_pools:
memory_pool.reset_lazy_pools()
if memory_pool is not None:
memory_pool.reset_lazy_pools()

def register_prerequisite_for_forward_prefetch_hooks(self, prerequisite_func: Callable) -> None:
for isp_communicator in self.isp_communicators:
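
For reference, a runnable sketch (not the PR's code) of the registration pattern ISPCommunicatorWrapper switches to: the wrapper now starts with empty, fixed-size slots and chains each registered communicator's reduce_scatter_handlers into one shared lookup view. collections.ChainMap stands in for the repo's UniqueChainMap, and FakeISPCommunicator is a hypothetical stand-in for ISPCommunicator.

```python
from collections import ChainMap
from enum import IntEnum


class CommunicatorType(IntEnum):
    Non_MoE = 0
    MoE = 1


class FakeISPCommunicator:
    """Hypothetical stand-in exposing only the attributes the wrapper touches."""

    def __init__(self, enable_memory_pool=False):
        self.reduce_scatter_handlers = {}
        self.enable_memory_pool = enable_memory_pool
        self.memory_pool = object() if enable_memory_pool else None


class ISPWrapperSketch:
    def __init__(self):
        # one slot per CommunicatorType, filled later via set_communicator
        self.isp_communicators = [None] * len(CommunicatorType)
        self.memory_pools = [None] * len(CommunicatorType)
        self.reduce_scatter_handlers = ChainMap()  # UniqueChainMap in the PR
        self.enable_memory_pool = False

    def set_communicator(self, index, communicator):
        assert index < len(CommunicatorType)
        self.isp_communicators[index] = communicator
        # layer this communicator's handler dict into the shared lookup view
        self.reduce_scatter_handlers = self.reduce_scatter_handlers.new_child(
            communicator.reduce_scatter_handlers
        )
        if communicator.enable_memory_pool:
            self.memory_pools[index] = communicator.memory_pool
            self.enable_memory_pool = True


wrapper = ISPWrapperSketch()
dense, moe = FakeISPCommunicator(), FakeISPCommunicator(enable_memory_pool=True)
wrapper.set_communicator(CommunicatorType.Non_MoE, dense)
wrapper.set_communicator(CommunicatorType.MoE, moe)

moe.reduce_scatter_handlers["w1.weight"] = "handle"
print(wrapper.reduce_scatter_handlers["w1.weight"])  # found through the chained view
print(wrapper.memory_pools)  # [None, <MoE pool object>]: only the MoE slot holds a pool
```

Registering communicators one at a time, instead of passing a list to the constructor as before, is what lets pipeline.py attach the MoE communicator only when num_experts > 1.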
7 changes: 7 additions & 0 deletions internlm/core/parallel/comm/utils.py
@@ -2,6 +2,7 @@
# -*- encoding: utf-8 -*-

from abc import ABC, abstractmethod
from enum import IntEnum
from typing import Callable

import torch
@@ -359,3 +360,9 @@ def backward(ctx, grad_output):


expandKVPacked = _ExpandKVPackedFunction.apply


# used in isp and zero
class CommunicatorType(IntEnum):
Non_MoE = 0
MoE = 1
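
A small standalone illustration of why CommunicatorType is an IntEnum rather than a plain Enum: its members compare equal to ints, so both wrappers can size their slot lists with len(CommunicatorType) and index them with the enum members directly.

```python
from enum import IntEnum


class CommunicatorType(IntEnum):
    Non_MoE = 0
    MoE = 1


slots = [None] * len(CommunicatorType)  # [None, None]
slots[CommunicatorType.MoE] = "moe communicator"

print(slots[1])                   # "moe communicator"
print(CommunicatorType.MoE == 1)  # True: an IntEnum member behaves like its int value
```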
59 changes: 51 additions & 8 deletions internlm/core/parallel/comm/zero.py
@@ -12,8 +12,10 @@
from internlm.core.context import global_context as gpc
from internlm.core.naive_amp import unwrap_naive_amp
from internlm.core.parallel.comm.isp import ISPCommunicatorWrapper
from internlm.core.parallel.comm.utils import CommunicatorType
from internlm.model.modules.embedding import Embedding1D
from internlm.model.modules.linear import ScaleColumnParallelLinear
from internlm.model.modules.utils import is_moe_param
from internlm.solver.optimizer.utils import flatten


@@ -27,6 +29,7 @@ def __init__(
zero1_mode: ParallelMode,
model: Union[nn.Module, nn.ModuleList],
isp_communicator: ISPCommunicatorWrapper = None,
is_moe: bool = False,
) -> None:
self._block_to_param: Dict[nn.Module, List[nn.Parameter]] = OrderedDict()
self._param_to_rank: Dict[nn.Parameter, int] = {}
@@ -35,8 +38,7 @@ def __init__(
self._block_to_name: Dict[nn.Module, str] = {}

zero1_size = gpc.get_world_size(zero1_mode)
total_param_num = sum(p.numel() for p in model.parameters())
avg_param_num = total_param_num * 1.0 // zero1_size
total_param_num = 0

# initialize an empty list for _bcast_handles of each rank
self._bcast_handles = {rank: [] for rank in range(zero1_size)}
@@ -56,16 +58,25 @@ def __init__(
for idx, block in enumerate(children):
block_name = name + f"_{idx}"
# self._block_to_param[f"{name}.{idx}"] = list(block.parameters())
self._block_to_param[block] = list(block.parameters())
self._block_to_name[block] = block_name
self._block_to_param[block] = []
for param in block.parameters():
if is_moe_param(param) == is_moe:
total_param_num += param.numel()
self._block_to_param[block].append(param)
self._block_to_name[block] = block_name
else:
# record the block that a parameter belongs to
# self._block_to_param[name] = list(children.parameters())
self._block_to_param[children] = list(children.parameters())
self._block_to_name[children] = name
self._block_to_param[children] = []
for param in children.parameters():
if is_moe_param(param) == is_moe:
total_param_num += param.numel()
self._block_to_param[children].append(param)
self._block_to_name[children] = name

alloc_num = 0
rank_to_go = 0
avg_param_num = total_param_num * 1.0 // zero1_size

# process the parameters in block_to_param sequencially,
# allocate each parameter to a local rank of ParallelMode.ZERO1,
@@ -76,14 +87,14 @@ def __init__(
# allocate a model block to a local rank of ParallelMode.ZERO1
self._block_to_rank[block] = [rank_to_go]
for p in params:
# allocate a parameter to a local rank of ParallelMode.ZERO1
self._param_to_rank[p] = rank_to_go
alloc_num = alloc_num + p.numel()
# in this case, allocate the param to next rank if possible
if alloc_num > avg_param_num * 1.01 and rank_to_go < zero1_size - 1:
rank_to_go = rank_to_go + 1
alloc_num = 0
self._block_to_rank[block].append(rank_to_go)
# allocate a parameter to a local rank of ParallelMode.ZERO1
self._param_to_rank[p] = rank_to_go

for block_name in self._block_to_name.values():
self._block_allgather_handles[block_name] = None
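
The loop above decides which ZERO1 rank owns each parameter. A standalone worked sketch of that rule follows (hypothetical sizes; parameters are assumed to be already filtered by is_moe_param(p) == is_moe): accumulate numel per rank and spill to the next rank once the running total exceeds avg_param_num * 1.01.

```python
def assign_params_to_ranks(param_numels, zero1_size):
    """Greedy rank assignment, mirroring ParamAsyncBcastHandler.__init__ (sketch only)."""
    total_param_num = sum(param_numels)
    avg_param_num = total_param_num * 1.0 // zero1_size

    param_to_rank = []
    alloc_num, rank_to_go = 0, 0
    for numel in param_numels:
        alloc_num += numel
        # once this rank is more than 1% over the average, the current and
        # following parameters are assigned to the next rank
        if alloc_num > avg_param_num * 1.01 and rank_to_go < zero1_size - 1:
            rank_to_go += 1
            alloc_num = 0
        param_to_rank.append(rank_to_go)
    return param_to_rank


# six parameters of 4M elements each, split across 2 ZERO1 ranks:
print(assign_params_to_ranks([4_000_000] * 6, zero1_size=2))
# -> [0, 0, 0, 1, 1, 1]: the average is 12M, so the 4th parameter
#    (running total 16M) and everything after it spill to rank 1
```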
@@ -188,3 +199,35 @@ def add_allgather_handle(self, handle, master_param, working_param, gatherd_para
self._block_working_params[block_name] = working_param
self._block_gathered_params[block_name] = gatherd_param
self._block_allgather_order[block_name] = 1


class ParamAsyncBcastHandlerWrapper:
"""
Wrapper for multiple ISPCommunicators.
TODO: check all isp communicator external interfaces and wrap them.
"""

def __init__(
self,
) -> None:
self.param_bcast_sync_handlers = [None for _ in range(len(CommunicatorType))]

def set_handle(self, index, handler):
assert index < len(CommunicatorType)
self.param_bcast_sync_handlers[index] = handler

def get_handle(self, index):
assert index < len(CommunicatorType)
return self.param_bcast_sync_handlers[index]

def get_rank_by_param(self, param) -> int:
idx = CommunicatorType.MoE if is_moe_param(param) else CommunicatorType.Non_MoE
return self.get_handle(idx).get_rank_by_param(param)

def add_bcast_handle(self, rank, handle, bcast_mode) -> None:
idx = CommunicatorType.MoE if bcast_mode == ParallelMode.EXPERT_DATA else CommunicatorType.Non_MoE
Collaborator:

Does bcast_mode need to take expert zero into account here?

self.get_handle(idx).add_bcast_handle(rank, handle)

def add_allgather_handle(self, handle, master_param, working_param, gatherd_param, block_name, bcast_mode) -> None:
idx = CommunicatorType.MoE if bcast_mode == ParallelMode.EXPERT_DATA else CommunicatorType.Non_MoE
self.get_handle(idx).add_allgather_handle(handle, master_param, working_param, gatherd_param, block_name)
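
To make the routing concrete, here is a minimal standalone sketch (not the PR's code) of what ParamAsyncBcastHandlerWrapper dispatches on: per-parameter lookups go by is_moe_param, while broadcast and all-gather registrations go by the parallel mode of the parameter group. FakeHandler, the EXPERT_DATA string stand-in for ParallelMode.EXPERT_DATA, and the is_expert attribute check are assumptions for illustration.

```python
from enum import IntEnum
from types import SimpleNamespace


class CommunicatorType(IntEnum):
    Non_MoE = 0
    MoE = 1


EXPERT_DATA = "expert_data"  # stand-in for ParallelMode.EXPERT_DATA


def is_moe_param(param):  # simplified stand-in for internlm's is_moe_param
    return getattr(param, "is_expert", False)


class FakeHandler:
    """Hypothetical stand-in for ParamAsyncBcastHandler."""

    def __init__(self, name):
        self.name = name

    def get_rank_by_param(self, param):
        return f"rank from {self.name} handler"

    def add_bcast_handle(self, rank, handle):
        print(f"{self.name}: bcast handle registered for rank {rank}")


class BcastWrapperSketch:
    def __init__(self):
        self.handlers = [None] * len(CommunicatorType)

    def set_handle(self, index, handler):
        self.handlers[index] = handler

    def get_rank_by_param(self, param):
        # per-parameter routing: expert params go to the MoE handler
        idx = CommunicatorType.MoE if is_moe_param(param) else CommunicatorType.Non_MoE
        return self.handlers[idx].get_rank_by_param(param)

    def add_bcast_handle(self, rank, handle, bcast_mode):
        # expert parameter groups broadcast over EXPERT_DATA, everything else over ZERO1
        idx = CommunicatorType.MoE if bcast_mode == EXPERT_DATA else CommunicatorType.Non_MoE
        self.handlers[idx].add_bcast_handle(rank, handle)


wrapper = BcastWrapperSketch()
wrapper.set_handle(CommunicatorType.Non_MoE, FakeHandler("dense"))
wrapper.set_handle(CommunicatorType.MoE, FakeHandler("moe"))

print(wrapper.get_rank_by_param(SimpleNamespace(is_expert=True)))   # routed to the moe handler
wrapper.add_bcast_handle(rank=0, handle=None, bcast_mode=EXPERT_DATA)  # -> moe handler
wrapper.add_bcast_handle(rank=0, handle=None, bcast_mode="zero1")      # -> dense handler
```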
3 changes: 0 additions & 3 deletions internlm/initialize/launch.py
@@ -508,9 +508,6 @@ def args_sanity_check():
# moe not support overlap and zero1.5 for now
if gpc.config.model.get("num_experts", 1) > 1:
assert not gpc.config.parallel.zero1.fsdp, "FSDP does not support num_experts > 1"
assert (
not optim_ckpt.overlap_sync_grad & optim_ckpt.overlap_sync_param
), "not support overlap and moe at the same time"
assert gpc.config.parallel.zero1.size in (
-1,
gpc.get_world_size(ParallelMode.DATA),
2 changes: 1 addition & 1 deletion internlm/solver/optimizer/hybrid_zero_optim.py
@@ -884,7 +884,7 @@ def broadcast_params(self):
)

if self._overlap_sync_param:
self._param_bcast_sync_handler.add_bcast_handle(rank, handle)
self._param_bcast_sync_handler.add_bcast_handle(rank, handle, self._broadcast_parallel_mode[group_id])
else:
handles.append(handle)

1 change: 1 addition & 0 deletions internlm/solver/optimizer/hybrid_zero_optim_v2.py
@@ -710,6 +710,7 @@ def all_gather_params(
all_gather_working_params,
gathered_params,
all_gather_working_params[0].block_name,
self._zero_parallel_mode[group_id],
)
else:
gathered_params_list.append(gathered_params)
32 changes: 23 additions & 9 deletions internlm/train/pipeline.py
@@ -44,7 +44,11 @@
SequenceParallelCommunicator,
TensorParallelCommunicator,
)
from internlm.core.parallel.comm.zero import ParamAsyncBcastHandler
from internlm.core.parallel.comm.utils import CommunicatorType
from internlm.core.parallel.comm.zero import (
ParamAsyncBcastHandler,
ParamAsyncBcastHandlerWrapper,
)
from internlm.core.trainer import TrainState
from internlm.data.utils import unpack_type_ids
from internlm.model.builder import create_model
@@ -293,6 +297,7 @@ def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]):
_retain_out_sharded = gpc.config.model.get("parallel_output", True)

if is_using_isp():
isp_communicator_wrapper = ISPCommunicatorWrapper()
isp_communicator = ISPCommunicator(
model,
ISPCommModelConfig(
@@ -304,6 +309,7 @@ def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]):
gpc.config.parallel.weight.memory_pool,
gpc.get_group(ParallelMode.WEIGHT),
)
isp_communicator_wrapper.set_communicator(CommunicatorType.Non_MoE, isp_communicator)
# register communicator for isp column parallel linear.
ColumnParallelLinear.register_cls_communicator(isp_communicator)
# row parallel linear will not be used.
@@ -329,16 +335,13 @@ def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]):
gpc.config.parallel.expert_weight.memory_pool,
gpc.get_group(ParallelMode.EXPERT_WEIGHT),
)
isp_communicator_wrapper.set_communicator(CommunicatorType.MoE, moe_isp_communicator)
for moe in _submodule_filter(model, Experts):
for column_linear in _submodule_filter(moe, (ColumnParallelLinear)):
column_linear.register_communicator(moe_isp_communicator)
for row_linear in _submodule_filter(moe, RowParallelLinear):
row_linear.register_communicator(None)

isp_communicator_wrapper = ISPCommunicatorWrapper([isp_communicator, moe_isp_communicator])
else:
isp_communicator_wrapper = ISPCommunicatorWrapper([isp_communicator])

# register communictor for mtp/msp/fsp linear.

# tensor parallel
@@ -460,9 +463,20 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList], isp_communicato
zero_cfg.overlap_sync_grad = False

if zero_cfg.overlap_sync_param:
param_bcast_sync_handler = ParamAsyncBcastHandler(ParallelMode.ZERO1, model, isp_communicator)
param_bcast_sync_handle_wrapper = ParamAsyncBcastHandlerWrapper()
non_moe_isp_communicator = (
isp_communicator.get_communicator(CommunicatorType.Non_MoE) if isp_communicator else None
)
param_bcast_sync_handler = ParamAsyncBcastHandler(ParallelMode.ZERO1, model, non_moe_isp_communicator)
param_bcast_sync_handle_wrapper.set_handle(CommunicatorType.Non_MoE, param_bcast_sync_handler)
if gpc.config.model.get("num_experts", 1) > 1:
moe_isp_communicator = isp_communicator.get_communicator(CommunicatorType.MoE) if isp_communicator else None
moe_param_bcast_sync_handler = ParamAsyncBcastHandler(
ParallelMode.EXPERT_DATA, model, moe_isp_communicator, is_moe=True
)
param_bcast_sync_handle_wrapper.set_handle(CommunicatorType.MoE, moe_param_bcast_sync_handler)
else:
param_bcast_sync_handler = None
param_bcast_sync_handle_wrapper = None

if not gpc.config.parallel.zero1.fsdp:
if (
@@ -473,15 +487,15 @@
naive_optimizer,
grad_scal_cfg=grad_scal_cfg,
zero_cfg=zero_cfg,
param_bcast_sync_handler=param_bcast_sync_handler,
param_bcast_sync_handler=param_bcast_sync_handle_wrapper,
isp_communicator=isp_communicator,
)
else:
optimizer = HybridZeroOptimizer_v2(
naive_optimizer,
grad_scal_cfg=grad_scal_cfg,
zero_cfg=zero_cfg,
param_bcast_sync_handler=param_bcast_sync_handler,
param_bcast_sync_handler=param_bcast_sync_handle_wrapper,
isp_communicator=isp_communicator,
)
else:
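
Putting the pieces together, a condensed standalone sketch of the wiring this file now performs (the generic Wrapper class and the string placeholders are illustrative assumptions, not the real APIs): the dense slot is always filled, the MoE slot only when num_experts > 1, and the optimizer receives the bcast wrapper rather than a bare ParamAsyncBcastHandler.

```python
from enum import IntEnum


class CommunicatorType(IntEnum):
    Non_MoE = 0
    MoE = 1


class Wrapper:
    """Shared shape of ISPCommunicatorWrapper / ParamAsyncBcastHandlerWrapper (sketch)."""

    def __init__(self):
        self.slots = [None] * len(CommunicatorType)

    def set(self, index, obj):
        self.slots[index] = obj


def wire(num_experts, overlap_sync_param):
    # 1) ISP communicators: dense always, MoE only for MoE models.
    isp_wrapper = Wrapper()
    isp_wrapper.set(CommunicatorType.Non_MoE, "ISPCommunicator over WEIGHT group")
    if num_experts > 1:
        isp_wrapper.set(CommunicatorType.MoE, "ISPCommunicator over EXPERT_WEIGHT group")

    # 2) Async param broadcast handlers mirror the same slots:
    #    dense params broadcast over ZERO1, expert params over EXPERT_DATA.
    bcast_wrapper = None
    if overlap_sync_param:
        bcast_wrapper = Wrapper()
        bcast_wrapper.set(CommunicatorType.Non_MoE, "ParamAsyncBcastHandler(ZERO1)")
        if num_experts > 1:
            bcast_wrapper.set(CommunicatorType.MoE, "ParamAsyncBcastHandler(EXPERT_DATA, is_moe=True)")

    # 3) The optimizer is then built with param_bcast_sync_handler=bcast_wrapper
    #    and isp_communicator=isp_wrapper.
    return isp_wrapper, bcast_wrapper


isp_w, bcast_w = wire(num_experts=4, overlap_sync_param=True)
print(isp_w.slots)
print(bcast_w.slots)
```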