refactor(moe): refactor moe for extensibility (#34)
blankde authored Jan 30, 2024
1 parent fbff756 commit d28d204
Showing 10 changed files with 243 additions and 204 deletions.
13 changes: 12 additions & 1 deletion configs/7B_MoE4_sft.py
@@ -143,7 +143,7 @@
num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
num_experts=4,
moe_use_residual=False,
moe_gate_k=2,
moe_type="GShard",
)

# zero1 parallel:
@@ -176,6 +176,17 @@
),
)

# custom moe impl configs
moe = dict(
top_k=2,
capacity_factor=1.0,
eval_capacity_factor=1.0,
min_capacity=4,
noisy_gate_policy=None,
drop_tokens=True,
use_rts=True,
)

model_type = "INTERNLM_MoE"

# metric_dtype can be "fp32" or other string
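The net effect on the training config: the gating hyperparameters move out of the model dict into a new top-level moe dict that is forwarded to the selected implementation, while model.moe_type picks that implementation by name. A condensed before/after sketch of the MoE-related keys only (other model keys omitted):

# Before this commit: gating knobs lived inside the model dict.
model = dict(
    num_experts=4,
    moe_use_residual=False,
    moe_gate_k=2,
)

# After this commit: the model dict only selects the implementation;
# implementation-specific settings live in a separate top-level dict.
model = dict(
    num_experts=4,
    moe_use_residual=False,
    moe_type="GShard",
)
moe = dict(
    top_k=2,
    capacity_factor=1.0,
    eval_capacity_factor=1.0,
    min_capacity=4,
    noisy_gate_policy=None,
    drop_tokens=True,
    use_rts=True,
)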
4 changes: 2 additions & 2 deletions internlm/initialize/launch.py
@@ -304,8 +304,8 @@ def args_sanity_check():
model._add_item("num_experts", 1)
if "moe_use_residual" not in model:
model._add_item("moe_use_residual", False)
if "moe_gate_k" not in model:
model._add_item("moe_gate_k", 2)
if "moe_type" not in model:
model._add_item("moe_type", "GShard")
# process the parallel config
if "sequence_parallel" not in gpc.config.parallel:
gpc.config.parallel._add_item("sequence_parallel", False)
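With the sanity check updated, a config that omits the MoE keys still resolves to a sensible default: a dense model whose MoE fields fall back to the GShard implementation. A condensed sketch of the defaulting pattern, pieced together from this hunk and the surrounding checks (not a verbatim excerpt):

# Defaults applied by args_sanity_check() when a config omits these keys:
if "num_experts" not in model:
    model._add_item("num_experts", 1)            # <= 1 means dense, no MoE
if "moe_use_residual" not in model:
    model._add_item("moe_use_residual", False)   # residual MoE off by default
if "moe_type" not in model:
    model._add_item("moe_type", "GShard")        # resolved via the registry later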
86 changes: 6 additions & 80 deletions internlm/model/modeling_moe.py
@@ -53,16 +53,9 @@ class PackedFlashBaseLayer1D(nn.Module):
norm_type (str): Use RMS norm or layernorm. "rmsnorm" by default.
use_flash_attn (bool): Whether to use flash-attn. True by default.
num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'.
moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to
infinite capacity).
moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
(https://arxiv.org/abs/2201.05596) layer.
moe_type (str): determines which MoE implementation will be used; defaults to GShardMoE.
"""

def __init__(
@@ -86,14 +79,6 @@ def __init__(
use_swiglu: bool = True,
use_flash_attn: bool = True,
num_experts: int = 1,
moe_gate_k: int = 1,
moe_capacity_factor: float = 1.0,
moe_eval_capacity_factor: float = 1.0,
moe_min_capacity: int = 4,
moe_noisy_gate_policy: str = None,
moe_drop_tokens: bool = True,
moe_use_rts: bool = True,
moe_use_residual: bool = False,
):
super().__init__()
self.checkpoint = checkpoint
@@ -131,14 +116,6 @@ def __init__(
set_fp32_attr_to_module(self.norm2)

self.num_experts = num_experts
self.moe_gate_k = moe_gate_k
self.moe_capacity_factor = moe_capacity_factor
self.moe_eval_capacity_factor = moe_eval_capacity_factor
self.moe_min_capacity = moe_min_capacity
self.moe_noisy_gate_policy = moe_noisy_gate_policy
self.moe_drop_tokens = moe_drop_tokens
self.moe_use_rts = moe_use_rts
self.moe_use_residual = moe_use_residual
ep_size = gpc.get_world_size(ParallelMode.EXPERT)
if num_experts <= 1: # dense, not MoE
if use_swiglu:
@@ -174,15 +151,8 @@ def __init__(
self.mlp = MoE(
hidden_size=hidden_size,
num_experts=num_experts,
ep_group=gpc.get_group(ParallelMode.EXPERT),
ep_size=ep_size,
k=moe_gate_k,
capacity_factor=moe_capacity_factor,
eval_capacity_factor=moe_eval_capacity_factor,
min_capacity=moe_min_capacity,
noisy_gate_policy=moe_noisy_gate_policy,
drop_tokens=moe_drop_tokens,
use_rts=moe_use_rts,
use_residual=moe_use_residual,
device=device,
dtype=dtype,
)
@@ -316,16 +286,9 @@ class PackedFlashInternLm1D(nn.Module):
norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
use_flash_attn (bool): Whether to use flash-attn. True by default.
num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'.
moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent
to infinite capacity).
moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
(https://arxiv.org/abs/2201.05596) layer.
moe_type (str): determines which MoE implementation will be used; defaults to GShardMoE.
"""

def __init__(
@@ -357,14 +320,6 @@ def __init__(
use_swiglu: bool = True,
use_flash_attn: bool = True,
num_experts: int = 1,
moe_gate_k: int = 1,
moe_capacity_factor: float = 1.0,
moe_eval_capacity_factor: float = 1.0,
moe_min_capacity: int = 4,
moe_noisy_gate_policy: str = None,
moe_drop_tokens: bool = True,
moe_use_rts: bool = True,
moe_use_residual: bool = False,
):
super().__init__()

@@ -415,14 +370,6 @@ def __init__(
use_swiglu=use_swiglu,
use_flash_attn=use_flash_attn,
num_experts=num_experts,
moe_gate_k=moe_gate_k,
moe_capacity_factor=moe_capacity_factor,
moe_eval_capacity_factor=moe_eval_capacity_factor,
moe_min_capacity=moe_min_capacity,
moe_noisy_gate_policy=moe_noisy_gate_policy,
moe_drop_tokens=moe_drop_tokens,
moe_use_rts=moe_use_rts,
moe_use_residual=moe_use_residual,
)
for lid in range(num_layers)
]
@@ -559,14 +506,8 @@ def build_model_with_moe_cfg(
use_swiglu: bool = True,
use_flash_attn: bool = True,
num_experts: int = 1,
moe_gate_k: int = 1,
moe_capacity_factor: float = 1.0,
moe_eval_capacity_factor: float = 1.0,
moe_min_capacity: int = 4,
moe_noisy_gate_policy: str = None,
moe_drop_tokens: bool = True,
moe_use_rts: bool = True,
moe_use_residual: bool = False,
moe_use_residual: bool = False, # pylint: disable=W0613
moe_type: str = None, # pylint: disable=W0613
):
"""
Build model with config.
@@ -598,16 +539,9 @@ def build_model_with_moe_cfg(
use_swiglu (bool): Whether to use swiglu. True by default.
use_flash_attn (bool): Whether to use flash-attn. True by default.
num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default.
moe_gate_k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
moe_capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
moe_eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
moe_min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
moe_noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample'.
moe_drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent
to infinite capacity).
moe_use_rts (bool, optional): default=True, whether to use Random Token Selection.
moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE
(https://arxiv.org/abs/2201.05596) layer.
moe_type (str): determines which MoE implementation will be used; defaults to GShardMoE.
"""

cfg = dict(
@@ -633,14 +567,6 @@
use_swiglu=use_swiglu,
use_flash_attn=use_flash_attn,
num_experts=num_experts,
moe_gate_k=moe_gate_k,
moe_capacity_factor=moe_capacity_factor,
moe_eval_capacity_factor=moe_eval_capacity_factor,
moe_min_capacity=moe_min_capacity,
moe_noisy_gate_policy=moe_noisy_gate_policy,
moe_drop_tokens=moe_drop_tokens,
moe_use_rts=moe_use_rts,
moe_use_residual=moe_use_residual,
)

return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)
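Taken together, the changes in this file shrink layer construction to generic arguments only; the implementation-specific knobs are no longer threaded through the model classes but read from gpc.config inside the MoE wrapper. An illustrative, condensed view of the new construction path, pieced together from this diff and internlm/model/moe.py below (not a verbatim excerpt):

# Inside PackedFlashBaseLayer1D.__init__ (num_experts > 1):
ep_size = gpc.get_world_size(ParallelMode.EXPERT)
self.mlp = MoE(
    hidden_size=hidden_size,
    num_experts=num_experts,
    ep_group=gpc.get_group(ParallelMode.EXPERT),
    ep_size=ep_size,
    device=device,
    dtype=dtype,
)

# Inside MoE.__init__ (see internlm/model/moe.py below), the concrete layer is
# resolved by name and receives the implementation-specific settings:
self.moe_layer = MODEL_INITIALIZER.get_module(module_name=gpc.config.model.moe_type)(
    hidden_size=hidden_size,
    num_experts=num_experts,
    ep_group=ep_group,
    ep_size=ep_size,
    device=device,
    dtype=dtype,
    **(gpc.config.moe),
)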
79 changes: 16 additions & 63 deletions internlm/model/moe.py
@@ -1,13 +1,11 @@
import typing

import torch

import internlm.moe # noqa # pylint: disable=W0611
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.linear import FeedForward
from internlm.moe.experts import Experts
from internlm.moe.sharded_moe import MOELayer, TopKGate
from internlm.utils.logger import get_logger
from internlm.utils.registry import MODEL_INITIALIZER

# global llm logger
logger = get_logger(__file__)
@@ -39,75 +37,30 @@ def __init__(
self,
hidden_size,
num_experts=1,
ep_group=None,
ep_size=1,
k=1,
capacity_factor=1.0,
eval_capacity_factor=1.0,
min_capacity=4,
noisy_gate_policy: typing.Optional[str] = None,
drop_tokens: bool = True,
use_rts: bool = True,
using_default_moe: bool = True,
use_residual=False,
device=None,
dtype=None,
):

super().__init__()

assert (
num_experts % ep_size == 0
), f"Number of experts ({num_experts}) should be divisible by expert parallel size ({ep_size})"
self.ep_size = ep_size
self.num_experts = num_experts
self.num_local_experts = num_experts // self.ep_size

assert noisy_gate_policy is None or noisy_gate_policy in ["None", "Jitter", "RSample"], (
"Unsupported noisy_gate_policy: " + noisy_gate_policy
)

# for elastic expert parallel, experts may have multiple groups
expert_group_name = f"moe_ep_size_{self.ep_size}"
if expert_group_name not in gpc.expert_parallel_group_names:
gpc.expert_parallel_group_names.append(expert_group_name)
experts = torch.nn.ModuleList(
[
FeedForward(
hidden_size,
int(hidden_size * gpc.config.model.mlp_ratio),
out_features=hidden_size,
process_group=gpc.get_group(ParallelMode.TENSOR),
bias=False,
device=device,
dtype=dtype,
)
for _ in range(self.num_local_experts)
]
if not hasattr(gpc.config, "moe"):
gpc.config.moe = dict()

self.moe_layer = MODEL_INITIALIZER.get_module(module_name=gpc.config.model.moe_type)(
hidden_size=hidden_size,
num_experts=num_experts,
ep_group=ep_group,
ep_size=ep_size,
device=device,
dtype=dtype,
**(gpc.config.moe)
)
experts = Experts(experts, self.num_local_experts, expert_group_name)

if using_default_moe:
self.moe_layer = MOELayer(
TopKGate(
hidden_size,
num_experts,
k,
capacity_factor,
eval_capacity_factor,
min_capacity,
noisy_gate_policy,
drop_tokens,
use_rts,
),
experts,
gpc.get_group(ParallelMode.EXPERT),
self.ep_size,
self.num_local_experts,
)

# residual network, see https://arxiv.org/pdf/2201.05596.pdf, seems useful for convergence
self.use_residual = use_residual
if use_residual:
self.use_residual = gpc.config.model.moe_use_residual
if self.use_residual:
self.residual_mlp = FeedForward(
hidden_size,
int(hidden_size * gpc.config.model.mlp_ratio),
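The wrapper no longer hard-codes the GShard wiring: it resolves the concrete layer from the MODEL_INITIALIZER registry using gpc.config.model.moe_type and forwards gpc.config.moe as keyword arguments, and the new internlm/moe/__init__.py below exports GShardMOELayer, presumably so that the bare "import internlm.moe" above triggers its registration. A hedged sketch of how a new variant could be plugged in; the register_module decorator is assumed to mirror get_module, and "MyMoE" / MyMoELayer are hypothetical names, not part of this commit:

# Sketch only: register_module(...) is assumed to be the counterpart of
# MODEL_INITIALIZER.get_module(...); MyMoE / MyMoELayer are hypothetical.
from internlm.moe.base_moe import BaseMoELayer
from internlm.utils.registry import MODEL_INITIALIZER


@MODEL_INITIALIZER.register_module(module_name="MyMoE")
class MyMoELayer(BaseMoELayer):
    def __init__(self, hidden_size, num_experts, ep_group, ep_size, device=None, dtype=None, **moe_kwargs):
        # Build a gate and a ModuleList of experts here, then let BaseMoELayer
        # wrap the experts and track ep_group / l_aux / exp_counts.
        ...


# Selected from the training config:
#   model = dict(..., moe_type="MyMoE")
#   moe = dict(...)   # forwarded to MyMoELayer.__init__ as **moe_kwargs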
3 changes: 3 additions & 0 deletions internlm/moe/__init__.py
@@ -0,0 +1,3 @@
from .gshard_moe import GShardMOELayer

__all__ = ["GShardMOELayer"]
35 changes: 35 additions & 0 deletions internlm/moe/base_moe.py
@@ -0,0 +1,35 @@
from typing import TYPE_CHECKING, Union

import torch
from torch import Tensor
from torch.nn import Module, ModuleList

from internlm.core.context import global_context as gpc
from internlm.moe.experts import Experts

if TYPE_CHECKING:
Base = Module[Tensor]
else:
Base = Module


class BaseMoELayer(Base):
"""
Base MoE Layer.
"""

def __init__(
self, gate: Module, experts: Union[Module, ModuleList], ep_group, ep_size: int, num_local_experts: int
) -> None:
super().__init__()
# for elastic expert parallel, experts may have multiple groups
expert_group_name = f"moe_ep_size_{ep_size}"
if expert_group_name not in gpc.expert_parallel_group_names:
gpc.expert_parallel_group_names.append(expert_group_name)
self.gate = gate
self.experts = Experts(experts, num_local_experts, expert_group_name)
self.ep_group = ep_group
self.ep_size = ep_size
self.num_local_experts = num_local_experts
self.l_aux = torch.tensor(0.0, device=torch.cuda.current_device(), dtype=gpc.config.model.get("dtype"))
self.exp_counts = None
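BaseMoELayer factors out the bookkeeping that previously lived in MoE.__init__: naming and recording the expert-parallel group, wrapping the raw experts in Experts, and holding l_aux / exp_counts. The concrete GShardMOELayer (internlm/moe/gshard_moe.py) is collapsed in this view, so the following subclass sketch is illustrative only; the gate and expert wiring mirrors the code removed from internlm/model/moe.py above, and the TopKGate import path follows the pre-refactor layout and may have moved in this commit:

import torch

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.linear import FeedForward
from internlm.moe.base_moe import BaseMoELayer
from internlm.moe.sharded_moe import TopKGate  # pre-refactor path; possibly gshard_moe now


class SketchMoELayer(BaseMoELayer):
    """Illustrative subclass showing how the pieces fit together; not part of this commit."""

    def __init__(self, hidden_size, num_experts, ep_group, ep_size, device=None, dtype=None,
                 top_k=1, capacity_factor=1.0, eval_capacity_factor=1.0, min_capacity=4,
                 noisy_gate_policy=None, drop_tokens=True, use_rts=True):
        num_local_experts = num_experts // ep_size
        # One FeedForward expert per local expert slot, sharded over the tensor-parallel group.
        experts = torch.nn.ModuleList(
            [
                FeedForward(
                    hidden_size,
                    int(hidden_size * gpc.config.model.mlp_ratio),
                    out_features=hidden_size,
                    process_group=gpc.get_group(ParallelMode.TENSOR),
                    bias=False,
                    device=device,
                    dtype=dtype,
                )
                for _ in range(num_local_experts)
            ]
        )
        gate = TopKGate(
            hidden_size,
            num_experts,
            top_k,
            capacity_factor,
            eval_capacity_factor,
            min_capacity,
            noisy_gate_policy,
            drop_tokens,
            use_rts,
        )
        # BaseMoELayer registers the expert-parallel group name, wraps the experts
        # in Experts(...), and initializes l_aux / exp_counts.
        super().__init__(gate, experts, ep_group, ep_size, num_local_experts)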