feat(moe): add gshard token rearrange optim #352

Open · wants to merge 3 commits into base: develop
149 changes: 121 additions & 28 deletions internlm/model/moe/gshard_layer.py
@@ -4,7 +4,7 @@
Git commit hash: f3943cf9109226ed3ecf2d5dbb639a11cd925555
We retain the following license from the original files:
"""

from collections import namedtuple
from typing import Callable, Dict, Optional, Tuple

import torch
@@ -25,10 +25,26 @@
# global llm logger
logger = get_logger(__file__)

try:
# To enable Tutel MoE optimizations:
# python3 -m pip install --user --upgrade git+https://github.com/microsoft/tutel@v0.1.x
from tutel import moe as tutel_moe

TUTEL_INSTALLED = True
except (ModuleNotFoundError, ImportError):
# Tutel is optional; log a warning and fall back to the torch implementation if it is missing
TUTEL_INSTALLED = False
logger.warning("from tutel import moe failed")

uniform_map: Dict[torch.device, Callable] = {}
gumbel_map: Dict[torch.device, Callable] = {}
exp_selection_uniform_map: Dict[torch.device, Callable] = {}

GatingTokenRearrangeInfo = namedtuple(
"GatingTokenRearrangeInfo", ["token_rearranged_ec_idx", "token_exp_weights", "expert_select_token_idx"]
)
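# token_rearranged_ec_idx: flat (expert * capacity + slot) position of every routed token copy
# token_exp_weights: the gate weight of that copy, applied when combining expert outputs
# expert_select_token_idx: for each (expert, slot) pair, the index of the token to gather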


def multiplicative_jitter(x, device: torch.device, epsilon=1e-2):
"""
@@ -223,7 +239,7 @@ def top1gating(

dispatch_mask = combine_weights.bool()

return l_aux, combine_weights, dispatch_mask, exp_counts
return l_aux, combine_weights, dispatch_mask


def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
@@ -253,7 +269,7 @@ def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tup
locations2 += torch.sum(mask1, dim=0, keepdim=True)

# gating decisions
exp_counts = torch.sum(mask1, dim=0).detach().to("cpu")
# exp_counts = torch.sum(mask1, dim=0).detach().to("cpu")

# Compute l_aux
me = torch.mean(gates, dim=0)
@@ -289,14 +305,16 @@ def top2gating(logits: Tensor, capacity_factor: float, min_capacity: int) -> Tup
combine_weights = combine1_sec + combine2_sec
dispatch_mask = combine_weights.bool()

return l_aux, combine_weights, dispatch_mask, exp_counts
return l_aux, combine_weights, dispatch_mask


def fused_topkgating(
logits: Tensor,
k: int,
capacity_factor: float,
min_capacity: int,
enable_token_rearrange_opt: bool = True,
use_tutel: bool = True,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Implements TopKGating on logits."""
# everything is in fp32 in this function
@@ -306,19 +324,21 @@
capacity = _capacity(gates, torch.tensor(capacity_factor * k), torch.tensor(min_capacity))

# Create a mask by top-k experts
indices_s = torch.topk(gates, k, dim=1).indices
indices_s = indices_s.permute(1, 0).reshape(-1)
masks = F.one_hot(indices_s, num_classes=num_experts)
indices_s = torch.topk(gates, k, dim=1).indices.t()
masks = F.one_hot(indices_s.reshape(-1), num_classes=num_experts)

# Compute locations in capacity buffer
locations = torch.cumsum(masks, dim=0) - 1
if use_tutel and TUTEL_INSTALLED:
locations = tutel_moe.fast_cumsum_sub_one(masks)
else:
locations = torch.cumsum(masks, dim=0) - 1

# reshape (s,e) to (k,s,e)
masks = masks.reshape(-1, gates.shape[0], num_experts)
locations = locations.reshape(-1, gates.shape[0], num_experts)

# gating decisions
exp_counts = torch.sum(masks[0], dim=0).detach().to("cpu")
# exp_counts = torch.sum(masks[0], dim=0).detach()

# Compute l_aux
me = torch.mean(gates, dim=0)
@@ -333,20 +353,39 @@

# Normalize gate probabilities
mask_float = masks.type_as(logits)
gate_s = einsum("se,kse->ks", gates, mask_float)
# gate_s = einsum("se,kse->ks", gates, mask_float)
gate_s, indices_s = torch.max(gates * mask_float, dim=2)
denom_s = torch.sum(gate_s, dim=0)
# Avoid divide-by-zero
denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
gate_s /= denom_s

# Calculate combine_weights and dispatch_mask
gate_all = einsum("ks,kse->kse", gate_s, mask_float)
locations_sc = F.one_hot(locations_s, num_classes=capacity).type_as(logits)
combine_sec = einsum("kse,ksc->ksec", gate_all, locations_sc)
combine_weights = torch.sum(combine_sec, dim=0)
dispatch_mask = combine_weights.bool()
if enable_token_rearrange_opt:
token_rearranged_ec_idx = indices_s.int() * capacity + locations_s.int()
# shape:[S, E]->[C, E]->[E, C]->[E*C]
token_sel_exp_int_mask = masks * torch.arange(k, 0, -1, device=masks.device).reshape(k, 1, 1)
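# masks is (k, s, e); weighting by k..1 ranks a token's top-1 choice above its top-2 choice, so the
# per-expert top-capacity selection below keeps higher-priority tokens when an expert overflows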
expert_sel_top_c_token_idx = torch.topk(
torch.sum(token_sel_exp_int_mask, dim=0), k=capacity, dim=0, sorted=True
)[1]
expert_select_token_idx = expert_sel_top_c_token_idx.t().reshape(num_experts * capacity)
token_rearranged_ec_idx = token_rearranged_ec_idx.reshape(-1)
token_exp_weights = gate_s.reshape(-1)

top2_gating_token_infos = GatingTokenRearrangeInfo(
token_rearranged_ec_idx=token_rearranged_ec_idx,
token_exp_weights=token_exp_weights,
expert_select_token_idx=expert_select_token_idx,
)
return l_aux, top2_gating_token_infos
else:
# Calculate combine_weights and dispatch_mask
gate_all = einsum("ks,kse->kse", gate_s, mask_float)
locations_sc = F.one_hot(locations_s, num_classes=capacity).type_as(logits)
combine_sec = einsum("kse,ksc->ksec", gate_all, locations_sc)
combine_weights = torch.sum(combine_sec, dim=0)
dispatch_mask = combine_weights.bool()

return l_aux, combine_weights, dispatch_mask, exp_counts
return l_aux, combine_weights, dispatch_mask


class TopKGate(Module):
@@ -378,10 +417,13 @@ def __init__(
noisy_gate_policy: Optional[str] = None,
drop_tokens: bool = True,
use_rts: bool = True,
use_fused_gating: bool = False,
use_fused_gating: bool = True,
enable_token_rearrange_opt: bool = True,
use_tutel: bool = True,
) -> None:
super().__init__()
# alway use fp32

# DeepSpeed's mechanism: always use fp32
self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)
self.k = topk
self.capacity_factor = capacity_factor
@@ -393,6 +435,8 @@ def __init__(
self.drop_tokens = drop_tokens
self.use_rts = use_rts
self.use_fused_gating = use_fused_gating
self.enable_token_rearrange_opt = enable_token_rearrange_opt
self.use_tutel = use_tutel

def forward(
self, inputs: torch.Tensor, used_token: torch.Tensor = None
@@ -408,7 +452,12 @@ def forward(
if self.use_fused_gating or self.k > 2:
assert self.noisy_gate_policy != "RSample", "RSample noisy is not supported by fused_gating policy"
gate_output = fused_topkgating(
logits, self.k, self.capacity_factor if self.training else self.eval_capacity_factor, self.min_capacity
logits,
self.k,
self.capacity_factor if self.training else self.eval_capacity_factor,
self.min_capacity,
self.enable_token_rearrange_opt,
self.use_tutel,
)
# deepspeed-style code
elif self.k == 1:
@@ -437,11 +486,11 @@ def forward(self, *inputs: Tensor) -> Tensor:


class GShardMoELayer(BaseMoELayer):
"""MOELayer module which implements MixtureOfExperts as described in Gshard_.
"""MoELayer module which implements MixtureOfExperts as described in Gshard_.
::

gate = TopKGate(model_dim, num_experts)
moe = MOELayer(gate, expert)
moe = MoELayer(gate, expert)
output = moe(inputs)
l_aux = moe.l_aux

@@ -475,6 +524,8 @@ def __init__(
drop_tokens: bool = True,
use_rts: bool = True,
use_fused_gating: bool = True,
enable_token_rearrange_opt: bool = True,
use_tutel: bool = True,
use_grouped_mlp: bool = True,
) -> None:
assert noisy_gate_policy is None or noisy_gate_policy in ["None", "Jitter", "RSample"], (
@@ -483,6 +534,12 @@
assert (
num_experts % ep_size == 0
), f"Number of experts ({num_experts}) should be divisible by expert parallel size ({ep_size})"

if enable_token_rearrange_opt:
assert (
use_fused_gating or top_k > 2
), "enable_token_rearrange_opt only can be used when use_fused_gating or top_k>2"

if use_grouped_mlp:
experts = new_feed_forward(
in_features,
@@ -529,6 +586,8 @@ def __init__(
drop_tokens,
use_rts,
use_fused_gating,
enable_token_rearrange_opt,
use_tutel,
),
experts,
ep_group,
@@ -542,6 +601,9 @@ def __init__(
self.time_salltoall = 0.0
self.time_moe = 0.0
self.wall_clock_breakdown = False
self.enable_token_rearrange_opt = enable_token_rearrange_opt
self.num_experts = num_experts
self.topk = top_k

def forward(self, *inputs: Tensor) -> Tensor:
if self.wall_clock_breakdown:
@@ -555,11 +617,24 @@ def forward(self, *inputs: Tensor) -> Tensor:
# group_size = kwargs['group_size'] if 'group_size' in kwargs.keys() else 1
reshaped_inputs = inputs[0].reshape(-1, d_model)

self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_inputs, inputs[1])
dispatched_inputs = einsum(
"sec,sm->ecm", dispatch_mask.type_as(inputs[0]), reshaped_inputs
) # TODO: heavy memory usage due to long sequence length

if not self.enable_token_rearrange_opt:
self.l_aux, combine_weights, dispatch_mask = self.gate(reshaped_inputs, inputs[1])
dispatched_inputs = einsum(
"sec,sm->ecm", dispatch_mask.type_as(inputs[0]), reshaped_inputs
) # TODO: heavy memory usage due to long sequence length
else:
self.l_aux, token_rearrange_infos = self.gate(reshaped_inputs)
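# Token-rearrange path: gather each expert's tokens by index instead of materializing the
# dense (s, e, c) dispatch mask and the memory-heavy 'sec,sm->ecm' einsum used above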
org_dtype = reshaped_inputs.dtype
if org_dtype == torch.bfloat16:  # avoid precision loss
rearranged_input = torch.index_select(
reshaped_inputs.to(torch.float32), dim=0, index=token_rearrange_infos.expert_select_token_idx
).to(org_dtype)
else:
rearranged_input = torch.index_select(
reshaped_inputs, dim=0, index=token_rearrange_infos.expert_select_token_idx
)
capacity = token_rearrange_infos.expert_select_token_idx.size(0) // self.num_experts
dispatched_inputs = rearranged_input.reshape(self.num_experts, capacity, d_model).contiguous()
if self.wall_clock_breakdown:
timer("falltoall").start()

@@ -600,7 +675,25 @@ def forward(self, *inputs: Tensor) -> Tensor:
# Re-shape back: gecm -> ecm
expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)

combined_output = einsum("sec,ecm->sm", combine_weights.type_as(inputs[0]), expert_output)
if not self.enable_token_rearrange_opt:
combined_output = einsum("sec,ecm->sm", combine_weights.type_as(inputs[0]), expert_output)
else:
E, C, M = expert_output.shape
org_dtype = expert_output.dtype
if org_dtype == torch.bfloat16:
valid_expert_out = torch.index_select(
expert_output.view(E * C, M).to(torch.float32),
dim=0,
index=token_rearrange_infos.token_rearranged_ec_idx,
).to(org_dtype)
else:
valid_expert_out = torch.index_select(
expert_output.view(E * C, M), dim=0, index=token_rearrange_infos.token_rearranged_ec_idx
)
combined_output = valid_expert_out * token_rearrange_infos.token_exp_weights.unsqueeze(1).type_as(inputs[0])
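# valid_expert_out holds one row per (token, selected expert) pair; each row is weighted by its
# gate value and, for top-k > 1, the k copies are summed back into a single output per token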
if self.topk > 1:
combined_output = combined_output.reshape(self.topk, -1, M)
combined_output = torch.sum(combined_output, dim=0)

out = combined_output.reshape(inputs[0].shape)

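For reference, the snippet below is a minimal, self-contained sketch (not code from this PR; all names are illustrative) of why the index-based token rearrangement is equivalent to the dense einsum dispatch/combine it replaces. It assumes top-1 routing and a capacity large enough that no token is dropped.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
S, E, M = 8, 4, 6   # tokens, experts, model dim
C = S               # capacity chosen large enough that no token is dropped

x = torch.randn(S, M)
gates = F.softmax(torch.randn(S, E), dim=1)

# top-1 routing decisions
gate_s, expert_idx = gates.max(dim=1)                                   # (S,), (S,)
mask = F.one_hot(expert_idx, num_classes=E)                             # (S, E)
slot = (torch.cumsum(mask, dim=0) - 1).gather(1, expert_idx[:, None]).squeeze(1)  # slot inside the expert buffer

# dense GShard-style dispatch/combine
locations_sc = F.one_hot(slot, num_classes=C).float()                   # (S, C)
combine_weights = torch.einsum("se,sc->sec", mask.float() * gate_s[:, None], locations_sc)
dispatch_mask = combine_weights.bool()
dispatched_dense = torch.einsum("sec,sm->ecm", dispatch_mask.float(), x)

# index-based dispatch (token rearrange): flat (expert, slot) index per token ...
token_ec_idx = expert_idx * C + slot                                    # (S,)
# ... and its inverse: which token fills each (expert, slot) position
expert_select_token_idx = torch.zeros(E * C, dtype=torch.long)
expert_select_token_idx[token_ec_idx] = torch.arange(S)
dispatched_idx = x.index_select(0, expert_select_token_idx).reshape(E, C, M)
# unfilled slots picked up an arbitrary token above; zero them so both paths match exactly
filled = torch.zeros(E * C, dtype=torch.bool)
filled[token_ec_idx] = True
dispatched_idx = dispatched_idx * filled.reshape(E, C, 1).to(x.dtype)
assert torch.allclose(dispatched_dense, dispatched_idx)

# pretend the experts are identity functions so the combine step can be compared as well
expert_out = dispatched_idx
combined_dense = torch.einsum("sec,ecm->sm", combine_weights, expert_out)
combined_idx = expert_out.reshape(E * C, M).index_select(0, token_ec_idx) * gate_s[:, None]
assert torch.allclose(combined_dense, combined_idx)
print("index-based dispatch/combine matches the dense einsum path")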