refactor(model): refactor model architecture (#126)
Co-authored-by: lijiaxing <[email protected]>
Co-authored-by: huangting4201 <[email protected]>
3 people authored May 10, 2024
1 parent 9a2a30e commit 6dfdb34
Showing 85 changed files with 4,848 additions and 5,607 deletions.
8 changes: 8 additions & 0 deletions configs/7B_MoE4_sft.py
@@ -146,6 +146,14 @@
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
use_flash_attn=True,
# Whether the odd and even columns of the query and key are kept interleaved (the normal layout).
# If True, the odd and even columns stay in their natural interleaved order; if False, the model
# has already regrouped all odd columns into the front half and all even columns into the back
# half of the last dimension, which makes the RoPE computation more efficient.
# Example:
# qk_interleaved = True:  q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...]
# qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...]
qk_interleaved=False,
num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
num_experts=4,
moe_use_residual=False,
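The qk_interleaved option introduced above can be illustrated with a minimal standalone sketch; the helper names and the head_dim value are hypothetical, and this is not code from the repository:

import torch

def interleaved_to_split(x: torch.Tensor) -> torch.Tensor:
    # Reorder the last dim from [x1, x2, x3, x4, ...] (interleaved) to
    # [x1, x3, ..., x2, x4, ...] (all odd columns first, then all even columns).
    return torch.cat([x[..., 0::2], x[..., 1::2]], dim=-1)

def split_to_interleaved(x: torch.Tensor) -> torch.Tensor:
    # Inverse reordering: restore the natural interleaved column order.
    half = x.shape[-1] // 2
    out = torch.empty_like(x)
    out[..., 0::2] = x[..., :half]
    out[..., 1::2] = x[..., half:]
    return out

q = torch.arange(1.0, 9.0)                             # pretend head_dim = 8: [q1, ..., q8]
print(interleaved_to_split(q))                         # tensor([1., 3., 5., 7., 2., 4., 6., 8.])
print(split_to_interleaved(interleaved_to_split(q)))   # round-trips back to [1., 2., ..., 8.]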
8 changes: 8 additions & 0 deletions configs/7B_internlm2.py
@@ -144,6 +144,14 @@
layer_norm_epsilon=1e-5,
num_kv_attention_heads=NUM_KV_ATTENTION_HEAD,
use_flash_attn=True,
# Whether the odd and even columns of the query and key are kept interleaved (the normal layout).
# If True, the odd and even columns stay in their natural interleaved order; if False, the model
# has already regrouped all odd columns into the front half and all even columns into the back
# half of the last dimension, which makes the RoPE computation more efficient.
# Example:
# qk_interleaved = True:  q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...]
# qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...]
qk_interleaved=False,
)

"""
8 changes: 8 additions & 0 deletions configs/7B_isp_sft.py
@@ -146,6 +146,14 @@
layer_norm_epsilon=1e-5,
use_flash_attn=True,
num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
# Whether the odd and even columns of the query and key are kept interleaved (the normal layout).
# If True, the odd and even columns stay in their natural interleaved order; if False, the model
# has already regrouped all odd columns into the front half and all even columns into the back
# half of the last dimension, which makes the RoPE computation more efficient.
# Example:
# qk_interleaved = True:  q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...]
# qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...]
qk_interleaved=False,
)
"""
zero1 parallel (dict):
8 changes: 8 additions & 0 deletions configs/7B_llama2.py
@@ -144,6 +144,14 @@
layer_norm_epsilon=1e-5,
num_kv_attention_heads=NUM_KV_ATTENTION_HEAD,
use_flash_attn=True,
# Whether the odd and even columns of the query and key are kept interleaved (the normal layout).
# If True, the odd and even columns stay in their natural interleaved order; if False, the model
# has already regrouped all odd columns into the front half and all even columns into the back
# half of the last dimension, which makes the RoPE computation more efficient.
# Example:
# qk_interleaved = True:  q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...]
# qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...]
qk_interleaved=False,
)

"""
8 changes: 8 additions & 0 deletions configs/7B_sft.py
@@ -145,6 +145,14 @@
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
use_flash_attn=True,
# Whether the odd and even columns of the query and key are kept interleaved (the normal layout).
# If True, the odd and even columns stay in their natural interleaved order; if False, the model
# has already regrouped all odd columns into the front half and all even columns into the back
# half of the last dimension, which makes the RoPE computation more efficient.
# Example:
# qk_interleaved = True:  q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...]
# qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...]
qk_interleaved=False,
num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
)
"""
2 changes: 1 addition & 1 deletion internlm/checkpoint/load_funcs.py
@@ -6,7 +6,7 @@
from internlm.accelerator import get_accelerator
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.solver.pipeline_utils import partition_uniform
from internlm.core.parallel.shard import partition_uniform
from internlm.utils.logger import get_logger
from internlm.utils.storage_manager import get_fns, llm_load

231 changes: 0 additions & 231 deletions internlm/core/communication/utils.py

This file was deleted.

9 changes: 8 additions & 1 deletion internlm/core/naive_amp.py
@@ -4,7 +4,7 @@
# adopted from https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/amp

from functools import partial
from typing import Any, Union
from typing import Any, List, Union

import torch
import torch.distributed as dist
@@ -206,3 +206,10 @@ def _post_forward_hook_for_fp32(
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
sub_module.register_forward_pre_hook(partial(_pre_forward_hook_for_fp32))


def unwrap_naive_amp(model: Union[nn.Module, nn.ModuleList]) -> List[nn.Module]:
    """Return the underlying modules as a list, stripping any NaiveAMPModel wrappers."""
    if not isinstance(model, nn.ModuleList):
        model = [model]

    return [_chunk.model if isinstance(_chunk, NaiveAMPModel) else _chunk for _chunk in model]
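A minimal usage sketch of the new helper; the ModuleList below is an illustrative stand-in for pipeline model chunks, and the import path is assumed:

import torch.nn as nn

# from internlm.core.naive_amp import unwrap_naive_amp  # assumed import path

chunks = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4)])  # stand-in for pipeline model chunks
for chunk in unwrap_naive_amp(chunks):
    # each element is the bare nn.Module, with any NaiveAMPModel wrapper stripped
    print(type(chunk).__name__, sum(p.numel() for p in chunk.parameters()))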
File renamed without changes.