feat(mlp): support mlp layer fusion (#161)

SolenoidWGT authored Apr 1, 2024
1 parent 892862e commit c1a1936
Showing 7 changed files with 241 additions and 31 deletions.
3 changes: 3 additions & 0 deletions internlm/initialize/launch.py
@@ -360,6 +360,9 @@ def args_sanity_check():
check_megablock_installed()
check_stk_installed()

if "mlp_layer_fusion" not in model:
model._add_item("mlp_layer_fusion", False)

# process the parallel config
if "sequence_parallel" not in gpc.config.parallel:
gpc.config.parallel._add_item("sequence_parallel", False)
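The check above gives mlp_layer_fusion a default of False whenever the field is missing from the model config, so existing configs keep running unchanged. A minimal sketch of how the feature could be switched on in a training config's model dict (only mlp_layer_fusion and multiple_of come from this commit; the surrounding fields are illustrative):

# hypothetical excerpt from a training config, e.g. configs/7B_sft.py
model = dict(
    num_layers=32,
    hidden_size=4096,
    num_attention_heads=32,
    mlp_ratio=8 / 3,
    use_flash_attn=True,
    mlp_layer_fusion=True,  # fuse the SwiGLU w1/w3 projections into one matmul
    multiple_of=256,        # round the MLP hidden size up to a multiple of 256
)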
14 changes: 14 additions & 0 deletions internlm/model/modeling_internlm.py
@@ -75,6 +75,8 @@ def __init__(
use_flash_attn: bool = True,
tp_mode: str = "mtp",
rope_base: int = 10000,
mlp_layer_fusion: bool = False,
multiple_of: int = 256,
):
super().__init__()
self.checkpoint = checkpoint
@@ -125,6 +127,9 @@ def __init__(
bias=False,
device=device,
dtype=dtype,
mlp_layer_fusion=mlp_layer_fusion,
sequence_parallel=gpc.config.parallel.sequence_parallel,
multiple_of=multiple_of,
)
else:
assert gpc.config.use_cuda_flash_attn is True
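The hunk above threads the new arguments into InternLM's FeedForward. For orientation, here is a minimal, non-parallel sketch of what the fused path can look like when mlp_layer_fusion=True: the SwiGLU gate (w1) and up (w3) projections are merged into a single fused_w1_w3 matmul whose output is split before the activation, and multiple_of rounds the hidden size. The parameter names follow the "candidate: w1, w3, fused_w1_w3" comment added below; the real implementation uses InternLM's column/row-parallel layers, so treat this only as an illustration.

import torch
import torch.nn as nn
import torch.nn.functional as F

class FusedSwiGLUSketch(nn.Module):
    """Illustrative single-GPU MLP showing the w1/w3 fusion idea (not the InternLM class)."""

    def __init__(self, dim, hidden_dim, multiple_of=256, mlp_layer_fusion=False):
        super().__init__()
        # round the hidden size up to a multiple of `multiple_of`
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.mlp_layer_fusion = mlp_layer_fusion
        if mlp_layer_fusion:
            # one GEMM produces both the gate (w1) and up (w3) activations
            self.fused_w1_w3 = nn.Linear(dim, 2 * hidden_dim, bias=False)
        else:
            self.w1 = nn.Linear(dim, hidden_dim, bias=False)
            self.w3 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        if self.mlp_layer_fusion:
            gate, up = self.fused_w1_w3(x).chunk(2, dim=-1)
        else:
            gate, up = self.w1(x), self.w3(x)
        return self.w2(F.silu(gate) * up)

Fusing the two projections replaces two GEMMs over the same input with one larger GEMM, which is the main motivation for the flag.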
@@ -171,6 +176,7 @@ def reset_parameters(self):
if self.use_scaled_init and "w2" in name:
scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data)
else:
# candidate: w1, w3, fused_w1_w3
normal_(std=0.006 if "w1" in name or "w3" in name else 0.0015)(param.data)
else:
if self.use_scaled_init and "fc1" not in name:
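The "candidate: w1, w3, fused_w1_w3" comment added in reset_parameters works because the substring test on the normal_(...) line beneath it also matches the fused parameter name, so fused_w1_w3 automatically receives the same std (0.006) as w1/w3 while w2 keeps 0.0015. A quick illustration (not part of the diff):

for name in ("w1.weight", "w3.weight", "fused_w1_w3.weight", "w2.weight"):
    std = 0.006 if "w1" in name or "w3" in name else 0.0015
    print(name, std)  # fused_w1_w3.weight -> 0.006, w2.weight -> 0.0015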
@@ -297,6 +303,8 @@ def __init__(
use_swiglu: bool = True,
use_flash_attn: bool = True,
rope_base: int = 10000,
mlp_layer_fusion: bool = False,
multiple_of: int = 256,
):
super().__init__()

@@ -352,6 +360,8 @@ def __init__(
use_flash_attn=use_flash_attn,
tp_mode=self.tp_mode,
rope_base=rope_base,
mlp_layer_fusion=mlp_layer_fusion,
multiple_of=multiple_of,
)
for lid in range(num_layers)
]
@@ -488,6 +498,8 @@ def build_model_with_cfg(
use_swiglu: bool = True,
use_flash_attn: bool = True,
rope_base: int = 10000,
mlp_layer_fusion: bool = False,
multiple_of: int = 256,
):
"""
Build model with config.
@@ -545,6 +557,8 @@ def build_model_with_cfg(
use_swiglu=use_swiglu,
use_flash_attn=use_flash_attn,
rope_base=rope_base,
mlp_layer_fusion=mlp_layer_fusion,
multiple_of=multiple_of,
)

return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)
14 changes: 14 additions & 0 deletions internlm/model/modeling_internlm2.py
@@ -527,6 +527,8 @@ def __init__(
ffn_other_init_std: float = 0.02,
init_type: str = "normal",
rope_base: int = 10000,
mlp_layer_fusion: bool = False,
multiple_of: int = 256,
):
super().__init__()
self.checkpoint = checkpoint
@@ -595,6 +597,9 @@ def __init__(
bias=False,
device=device,
dtype=dtype,
mlp_layer_fusion=mlp_layer_fusion,
sequence_parallel=sequence_parallel,
multiple_of=multiple_of,
)
else:
from flash_attn.modules.mlp import ParallelFusedMLP
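modeling_internlm2.py forwards the same arguments to FeedForward (here with the locally resolved sequence_parallel). One practical consequence of the fusion is checkpoint layout: loading an unfused checkpoint with mlp_layer_fusion=True requires merging the separate w1/w3 weights into the fused parameter. A hedged sketch, assuming the fused weight is the row-wise concatenation [w1; w3] as in the module sketch above; verify the actual layout and key names used by InternLM before converting real checkpoints:

import torch

def fuse_w1_w3(state_dict, prefix="mlp."):
    """Merge <prefix>w1.weight and <prefix>w3.weight into <prefix>fused_w1_w3.weight (illustrative key names)."""
    out = dict(state_dict)
    w1 = out.pop(prefix + "w1.weight")  # (hidden, dim)
    w3 = out.pop(prefix + "w3.weight")  # (hidden, dim)
    out[prefix + "fused_w1_w3.weight"] = torch.cat([w1, w3], dim=0)  # (2 * hidden, dim)
    return out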
@@ -646,6 +651,7 @@ def reset_parameters(self):
if self.use_scaled_init and "w2" in name:
self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data)
else:
# candidate: w1, w3, fused_w1_w3
self.init_func(
std=self.ffn_uplayer_init_std if "w1" in name or "w3" in name else self.ffn_other_init_std
)(param.data)
@@ -835,6 +841,8 @@ def __init__(
rope_base: int = 10000,
norm_head: bool = False,
tp_mode: str = "mtp",
mlp_layer_fusion: bool = False,
multiple_of: int = 256,
):
super().__init__()

@@ -914,6 +922,8 @@ def __init__(
init_type=init_type,
tp_mode=self.tp_mode,
rope_base=rope_base,
mlp_layer_fusion=mlp_layer_fusion,
multiple_of=multiple_of,
)
for lid in range(num_layers)
]
@@ -1077,6 +1087,8 @@ def build_model_with_cfg(
norm_head: bool = False,
max_position_embeddings=2048,
use_dynamic_ntk_rope=False,
mlp_layer_fusion: bool = False,
multiple_of: int = 256,
):
"""
Build model with config.
@@ -1157,6 +1169,8 @@ def build_model_with_cfg(
norm_head=norm_head,
max_position_embeddings=max_position_embeddings,
use_dynamic_ntk_rope=use_dynamic_ntk_rope,
mlp_layer_fusion=mlp_layer_fusion,
multiple_of=multiple_of,
)

return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)
14 changes: 14 additions & 0 deletions internlm/model/modeling_llama.py
@@ -518,6 +518,8 @@ def __init__(
init_type: str = "normal",
rope_base: int = 10000,
tp_mode: str = "mtp",
mlp_layer_fusion: bool = False,
multiple_of: int = 256,
):
super().__init__()
self.checkpoint = checkpoint
@@ -582,6 +584,9 @@ def __init__(
bias=False,
device=device,
dtype=dtype,
mlp_layer_fusion=mlp_layer_fusion,
sequence_parallel=gpc.config.parallel.sequence_parallel,
multiple_of=multiple_of,
)
else:
from flash_attn.modules.mlp import ParallelFusedMLP
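In the Llama modeling code the same two arguments are forwarded as well. multiple_of controls how the MLP hidden size is rounded; assuming the usual LLaMA-style rule (round int(dim * mlp_ratio) up to a multiple of multiple_of; the exact expression lives inside InternLM's FeedForward and may differ in detail), a quick worked example for a 4096-wide model with mlp_ratio = 8/3:

dim, mlp_ratio, multiple_of = 4096, 8 / 3, 256
hidden = int(dim * mlp_ratio)                                        # 10922
hidden = multiple_of * ((hidden + multiple_of - 1) // multiple_of)   # 11008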
@@ -633,6 +638,7 @@ def reset_parameters(self):
if self.use_scaled_init and "w2" in name:
self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data)
else:
# candidate: w1, w3, fused_w1_w3
self.init_func(
std=self.ffn_uplayer_init_std if "w1" in name or "w3" in name else self.ffn_other_init_std
)(param.data)
@@ -813,6 +819,8 @@ def __init__(
out_head_init_std: float = 0.02,
init_type: str = "normal",
rope_base: int = 10000,
mlp_layer_fusion: bool = False,
multiple_of: int = 256,
):
super().__init__()

@@ -886,6 +894,8 @@ def __init__(
init_type=init_type,
rope_base=rope_base,
tp_mode=self.tp_mode,
mlp_layer_fusion=mlp_layer_fusion,
multiple_of=multiple_of,
)
for lid in range(num_layers)
]
@@ -1042,6 +1052,8 @@ def build_model_with_cfg(
out_head_init_std: float = 0.02,
init_type: str = "normal",
rope_base: int = 10000,
mlp_layer_fusion: bool = False,
multiple_of: int = 256,
):
"""
Build model with config.
@@ -1117,6 +1129,8 @@ def build_model_with_cfg(
out_head_init_std=out_head_init_std,
init_type=init_type,
rope_base=rope_base,
mlp_layer_fusion=mlp_layer_fusion,
multiple_of=multiple_of,
)

return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)
14 changes: 14 additions & 0 deletions internlm/model/modeling_moe.py
@@ -79,6 +79,8 @@ def __init__(
use_flash_attn: bool = True,
num_experts: int = 1,
tp_mode: str = "mtp",
mlp_layer_fusion: bool = False,
multiple_of: int = 256,
):
super().__init__()
self.checkpoint = checkpoint
@@ -130,6 +132,9 @@ def __init__(
bias=False,
device=device,
dtype=dtype,
mlp_layer_fusion=mlp_layer_fusion,
sequence_parallel=gpc.config.parallel.sequence_parallel,
multiple_of=multiple_of,
)
else:
from flash_attn.modules.mlp import ParallelFusedMLP
@@ -190,6 +195,7 @@ def reset_parameters(self):
if self.use_scaled_init and "w2" in name:
scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data)
else:
# candidate: w1, w3, fused_w1_w3
normal_(std=0.006 if "w1" in name or "w3" in name else 0.0015)(param.data)
else:
if self.use_scaled_init and "fc1" not in name:
@@ -323,6 +329,8 @@ def __init__(
use_swiglu: bool = True,
use_flash_attn: bool = True,
num_experts: int = 1,
mlp_layer_fusion: bool = False,
multiple_of: int = 256,
):
super().__init__()

@@ -378,6 +386,8 @@ def __init__(
use_flash_attn=use_flash_attn,
num_experts=num_experts,
tp_mode=self.tp_mode,
mlp_layer_fusion=mlp_layer_fusion,
multiple_of=multiple_of,
)
for lid in range(num_layers)
]
@@ -517,6 +527,8 @@ def build_model_with_moe_cfg(
num_experts: int = 1,
moe_use_residual: bool = False, # pylint: disable=W0613
moe_type: str = None, # pylint: disable=W0613
mlp_layer_fusion: bool = False,
multiple_of: int = 256,
):
"""
Build model with config.
@@ -576,6 +588,8 @@ def build_model_with_moe_cfg(
use_swiglu=use_swiglu,
use_flash_attn=use_flash_attn,
num_experts=num_experts,
mlp_layer_fusion=mlp_layer_fusion,
multiple_of=multiple_of,
)

return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)