Fix Qwen2 MoE Loss Convergence Issue (#275)
Co-authored-by: 同润 <[email protected]>
jerryli1981 and 同润 authored Jun 23, 2024
1 parent 4c94f2e commit 0ac48a0
Showing 4 changed files with 82 additions and 3 deletions.
2 changes: 1 addition & 1 deletion megatron_patch/model/qwen1_5/moe/moe_layer.py
@@ -78,7 +78,7 @@ def __init__(
self.router = TopKRouter(config=self.config)
self.enable_shared_experts = config.enable_shared_expert
if config.enable_shared_expert:
self.shared_expert = MLP(self.config, submodules, is_expert=True, is_shared_expert=True)
self.shared_expert = MLP(self.config, submodules, is_expert=False, is_shared_expert=True)
self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)

if self.config.moe_grouped_gemm:
2 changes: 1 addition & 1 deletion megatron_patch/model/qwen2/moe/moe_layer.py
@@ -84,7 +84,7 @@ def __init__(
self.router = TopKRouter(config=self.config)
self.enable_shared_experts = config.enable_shared_expert
if config.enable_shared_expert:
self.shared_expert = MLP(self.config, submodules, is_expert=True, is_shared_expert=True)
self.shared_expert = MLP(self.config, submodules, is_expert=False, is_shared_expert=True)
self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)

if self.config.moe_grouped_gemm:
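Both moe_layer.py changes construct the shared expert as a regular dense MLP (is_expert=False) rather than an expert-parallel one, since the shared expert is not routed and runs on every token. A minimal, self-contained sketch of how such a shared expert is blended with the routed-expert output, following the Hugging Face Qwen2-MoE design (illustrative names and a simplified MLP, not the patch's actual forward code):

import torch

hidden_size, num_tokens = 16, 8
hidden_states = torch.randn(num_tokens, hidden_size)
routed_output = torch.randn(num_tokens, hidden_size)  # stand-in for the combined top-k expert output

# Stand-in for MLP(self.config, submodules, is_expert=False, is_shared_expert=True):
# an ordinary dense MLP applied to every token, not dispatched by the router.
shared_expert = torch.nn.Sequential(
    torch.nn.Linear(hidden_size, 4 * hidden_size, bias=False),
    torch.nn.SiLU(),
    torch.nn.Linear(4 * hidden_size, hidden_size, bias=False),
)
shared_expert_gate = torch.nn.Linear(hidden_size, 1, bias=False)

gate = torch.sigmoid(shared_expert_gate(hidden_states))   # per-token scalar gate in (0, 1)
output = routed_output + gate * shared_expert(hidden_states)
print(output.shape)  # torch.Size([8, 16])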
78 changes: 77 additions & 1 deletion megatron_patch/model/qwen2/moe/router.py
@@ -31,13 +31,89 @@
MoEAuxLossAutoScaler,
save_to_aux_losses_tracker,
sinkhorn,
get_capacity,
switch_load_balancing_loss_func,
topk_softmax_with_capacity,
z_loss_func,
)
from megatron.core.transformer.transformer_config import TransformerConfig


def topk_softmax_with_capacity(
logits: torch.Tensor,
topk: int,
capacity_factor: float = None,
pad_to_capacity: bool = False,
drop_policy: str = "probs",
):
"""Apply capacity and padding to the top-k selection.
Args:
logits (torch.Tensor): Logits tensor.
topk (int): The number of experts to select for each token.
capacity_factor (float): The capacity factor of each expert. Tokens are dropped if the number routed to an expert exceeds its capacity.
pad_to_capacity (bool): Whether to pad each expert's tokens up to the expert capacity in token-drop mode.
drop_policy (str): The policy used to drop tokens. Can be either "probs" or "position". If "probs", the tokens with the lowest probabilities are dropped. If "position", tokens at the end of each batch are dropped.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Probs, indices and tokens_per_expert tensor.
(1) If there's no token padding, the shape of probs and indices is [tokens, top_k], indicating the selected experts for each token.
(2) If there's token padding, the shape of probs and indices is [num_expert, capacity], indicating the tokens selected for each expert.
"""
# TODO: Add Pre softmax.
assert logits.dim() == 2, f"Expected 2D logits [num_tokens, num_experts], got {logits.dim()}."
num_tokens = logits.shape[0]
num_experts = logits.shape[1]

# Previous behaviour (kept for reference): top-k on the raw logits, then softmax over only the k selected logits.
#scores, top_indices = torch.topk(logits, k=topk, dim=1)
#probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)

# Fixed behaviour: softmax over all experts first, then keep the top-k probabilities
# without renormalising them, matching the Hugging Face Qwen2-MoE router.
routing_weights = torch.softmax(logits, dim=1, dtype=torch.float32).type_as(logits)
probs, top_indices = torch.topk(routing_weights, k=topk, dim=-1)

if capacity_factor is None:
# TopK without capacity
tokens_per_expert = torch.histc(top_indices, bins=num_experts, min=0, max=num_experts)
return probs, top_indices, tokens_per_expert
else:
# TopK with capacity
expert_capacity = get_capacity(
num_tokens=num_tokens * topk, num_experts=num_experts, capacity_factor=capacity_factor,
)
# TopK selection, Maskout unused experts
topk_masked_gates = torch.zeros_like(logits).scatter(1, top_indices, probs)
topk_mask = torch.zeros_like(logits).scatter(1, top_indices, 1)

# Maskout exceeded tokens
if drop_policy == "probs":
capacity_probs, capacity_indices = torch.topk(
topk_masked_gates, k=expert_capacity, dim=0, sorted=False
)
capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1)
elif drop_policy == "position":
_, capacity_indices = torch.topk(topk_mask, k=expert_capacity, dim=0, sorted=False)
capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1)
capacity_probs = torch.gather(topk_masked_gates, 0, capacity_indices)
else:
raise ValueError(f"Invalid drop_policy: {drop_policy}")

if pad_to_capacity:
final_probs, final_indices = (
capacity_probs.T.contiguous(),
capacity_indices.T.contiguous(),
)
tokens_per_expert_before_capacity = topk_mask.sum(dim=0)
else:
# Get exceed mask and maskout exceeded probs and indices
final_mask = torch.logical_and(topk_mask, capacity_mask)
drop_mask = torch.logical_not(final_mask)
exceed_mask = torch.gather(drop_mask, 1, top_indices)
final_probs = probs * torch.logical_not(exceed_mask)
final_indices = top_indices.clone().masked_fill_(
exceed_mask, torch.iinfo(torch.long).max
)
tokens_per_expert_before_capacity = topk_mask.sum(dim=0)
return final_probs, final_indices, tokens_per_expert_before_capacity

class Router(ABC, MegatronModule):
"""Base Router class"""

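The locally defined topk_softmax_with_capacity changes the order of operations in the router: softmax is taken over all experts first and the top-k probabilities are kept as-is, instead of top-k on the raw logits followed by a softmax that renormalises the selected weights to sum to 1. A small, self-contained sketch of that difference (toy shapes, plain PyTorch, not a call into megatron_patch):

import torch

num_tokens, num_experts, topk = 6, 4, 2
logits = torch.randn(num_tokens, num_experts)

# New behaviour: softmax over all experts, then select the top-k of those probabilities.
routing_weights = torch.softmax(logits, dim=1, dtype=torch.float32)
probs_new, top_indices = torch.topk(routing_weights, k=topk, dim=-1)

# Old behaviour (the commented-out lines in the diff): top-k on the raw logits,
# then softmax over only the k selected logits, which renormalises them to sum to 1.
scores, _ = torch.topk(logits, k=topk, dim=1)
probs_old = torch.softmax(scores, dim=-1, dtype=torch.float32)

print(probs_new.sum(dim=-1))  # generally < 1: the weights keep their global scale
print(probs_old.sum(dim=-1))  # always 1 after renormalisation

When capacity_factor is set, get_capacity bounds each expert to roughly capacity_factor * num_tokens * topk / num_experts slots; tokens beyond that are dropped by probability or by position, or padded when pad_to_capacity is enabled, as in the function above.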
3 changes: 3 additions & 0 deletions
@@ -121,6 +121,8 @@ moe_options=" \
--target-expert-model-parallel-size ${EP}\
--moe-ffn-hidden-size ${MOE_INTERMEDIATE_SIZE} \
--shared-moe-ffn-hidden-size ${SHARED_EXPERT_INTERMEDIATE_SIZE} \
--moe-router-load-balancing-type aux_loss \
--moe-aux-loss-coeff 1e-2 \
--enable-shared-expert"

cpu_options=" \
@@ -189,6 +191,7 @@ torchrun ${DISTRIBUTED_ARGS} hf2mcore_qwen2_dense_and_moe_gqa.py \
--num-query-groups ${NUM_KEY_VALUE_HEADS} \
--normalization RMSNorm \
--norm-epsilon ${RMS_NORM_EPS} \
--use-mcore-models \
--attention-dropout 0.0 \
--hidden-dropout 0.0 \
--rotary-base ${ROPE_THETA} \
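This script change also enables auxiliary-loss load balancing (--moe-router-load-balancing-type aux_loss with --moe-aux-loss-coeff 1e-2). A hedged sketch of the Switch-Transformer-style balancing term that this selects, written in the standard formulation (illustrative helper, not Megatron-Core's exact implementation):

import torch

def switch_load_balancing_aux_loss(router_probs, top_indices, num_experts, coeff=1e-2):
    # Switch-Transformer style balancing term:
    # num_experts * sum_i( fraction_of_tokens_routed_to_expert_i * mean_router_prob_of_expert_i ),
    # scaled by the --moe-aux-loss-coeff value.
    num_tokens = router_probs.shape[0]
    dispatch = torch.zeros(num_tokens, num_experts).scatter(1, top_indices, 1.0)
    fraction = dispatch.sum(dim=0) / dispatch.sum()   # share of top-k slots per expert
    mean_prob = router_probs.mean(dim=0)              # mean router probability per expert
    return coeff * num_experts * torch.sum(fraction * mean_prob)

logits = torch.randn(8, 4)
probs = torch.softmax(logits, dim=-1)
_, top_idx = torch.topk(probs, k=2, dim=-1)
print(switch_load_balancing_aux_loss(probs, top_idx, num_experts=4))

The 1e-2 coefficient scales this term before it is added to the language-modelling loss, nudging the router toward spreading tokens evenly across experts.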
