Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] Heterogeneous Code Part 1: Add Model and Module Code for Chameleon Lumina #377

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
772 changes: 772 additions & 0 deletions internlm/model/modeling_chameleon.py

Large diffs are not rendered by default.

80 changes: 73 additions & 7 deletions internlm/model/modules/mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
from torch import nn
from torch.nn import functional as F

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.parallel.comm.utils import (
gather_forward_split_backward,
split_forward_gather_backward,
)
from internlm.model.modules.embedding import new_rotary_embedding
from internlm.model.modules.linear import new_linear
from internlm.model.modules.utils import update_kv_cache
Expand Down Expand Up @@ -373,6 +378,30 @@ def _inference(self, x, inference_params, **kwargs): # pylint: disable=W0613
return self.out_proj(rearrange(context, "b s h d -> b s (h d)"))


class ChameleonLayerNorm(nn.LayerNorm):
    """
    LayerNorm that computes statistics only over the last dim and applies per-head gamma/beta.

    Chameleon applies gamma and beta from each shard separately to each head, instead of
    reducing. Each head's own gamma/beta is obtained by repeat-interleaving the per-group
    parameters, while the normalization statistics are computed over the last dimension only.
    The affine transform is therefore applied manually after a parameter-free layer norm.

    Args:
        hidden_size (int or sequence of int): per-head feature shape; only its last entry is
            used as the normalized shape.
        head_group_num (int): number of parameter groups; becomes the leading dimension of
            ``weight`` and ``bias``.
        n_heads_per_group (int): how many consecutive heads share one group's parameters.
        *args, **kwargs: forwarded to ``nn.LayerNorm`` (e.g. ``eps``, ``elementwise_affine``).
    """

    def __init__(self, hidden_size, head_group_num, n_heads_per_group, *args, **kwargs):
        if isinstance(hidden_size, int):
            hidden_size = (hidden_size,)
        # Let nn.LayerNorm allocate weight/bias with shape (head_group_num, *hidden_size) ...
        super().__init__([head_group_num, *hidden_size], *args, **kwargs)
        # ... but normalize over the last dimension only.
        self.normalized_shape = (hidden_size[-1],)
        self.n_heads_per_group = n_heads_per_group

    def repeat_param(self, param):
        """Expand a per-group parameter along dim 0 so that every head gets its group's values."""
        return param.repeat_interleave(self.n_heads_per_group, dim=0)

    def forward(self, hidden_states):
        """
        Normalize ``hidden_states`` over the last dim, then apply the per-head affine transform.

        The trailing dimensions of ``hidden_states`` must broadcast against a tensor of shape
        ``(head_group_num * n_heads_per_group, *hidden_size)``.
        """
        # Use the configured epsilon rather than a hard-coded constant so that an ``eps``
        # passed to the constructor is honoured (the nn.LayerNorm default is still 1e-5).
        hidden_states = F.layer_norm(hidden_states, self.normalized_shape, None, None, eps=self.eps)
        # weight/bias are None when the module is built with elementwise_affine=False
        # (or bias=False); skip the corresponding step instead of failing on None.
        if self.weight is not None:
            hidden_states = hidden_states * self.repeat_param(self.weight)
        if self.bias is not None:
            hidden_states = hidden_states + self.repeat_param(self.bias)
        return hidden_states


class GQA(nn.Module):
"""
Multi-head self-attention and cross-attention.
Expand All @@ -397,6 +426,8 @@ class GQA(nn.Module):
dtype (Optional[torch.dtype]): The type of data.
qk_interleaved (Optional[bool]): whether the odd and even columns of wq and wk is interleaved. True by default.
enable_qkv_fusion (bool): whether the wq, wk and wv linear layers are fused. True by default.
qk_norm (Optional[bool]): if set, layer norm is applied to the query and key after the q/k linear projections.
False by default.
"""

def __init__(
Expand All @@ -419,6 +450,8 @@ def __init__(
dtype: Optional[torch.dtype] = None,
qk_interleaved: Optional[bool] = True,
enable_qkv_fusion: bool = True,
qk_norm: bool = False,
chameleon_mp_size: int = 1,
) -> None:
super().__init__()
self.layer_idx = layer_idx
Expand Down Expand Up @@ -459,6 +492,10 @@ def __init__(
rotary_type="dynamic_ntk" if self.use_dynamic_ntk_rope else "native",
)

self.qk_norm = qk_norm
if qk_norm:
assert enable_qkv_fusion is False, "qk_norm cannot be applied when fused wqkv"

if enable_qkv_fusion:
assert bias is False, "Fuesd wqkv only support bias is False."
self.wqkv = new_linear("wqkv", embed_dim, q_dim + 2 * self.kv_dim, bias, **factory_kwargs)
Expand All @@ -470,6 +507,11 @@ def __init__(
self.wq = new_linear("wq", embed_dim, q_dim, bias, **factory_kwargs)
self.wk = new_linear("wk", embed_dim, self.kv_dim, bias, **factory_kwargs)
self.wv = new_linear("wv", embed_dim, self.kv_dim, bias, **factory_kwargs)
if qk_norm:
assert num_heads % chameleon_mp_size == 0, "num_heads%chameleon_mp_size != 0 in GQA"
assert num_kv_heads % chameleon_mp_size == 0, "num_kv_heads%chameleon_mp_size != 0 in GQA"
self.q_norm = ChameleonLayerNorm(self.head_dim, chameleon_mp_size, num_heads // chameleon_mp_size)
self.k_norm = ChameleonLayerNorm(self.head_dim, chameleon_mp_size, num_kv_heads // chameleon_mp_size)

self.inner_attn = SelfAttention(
causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout, layer_idx=layer_idx
Expand Down Expand Up @@ -508,10 +550,21 @@ def _training(self, x, **kwargs):
q = rearrange(q, "b s h gs d -> b s (h gs) d")
else:
q, k, v = self.wq(x), self.wk(x), self.wv(x)
q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim)
k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim)
v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim)

if self.qk_norm:
q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim)
k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim)
q_all = gather_forward_split_backward(q, ParallelMode.TENSOR, dim=-2)
q_norm_out = self.q_norm(q_all)
q = split_forward_gather_backward(q_norm_out, ParallelMode.TENSOR, dim=-2)
k_all = gather_forward_split_backward(k, ParallelMode.TENSOR, dim=-2)
k_norm_out = self.k_norm(k_all)
k = split_forward_gather_backward(k_norm_out, ParallelMode.TENSOR, dim=-2)

v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

qk_norm这块可能要额外区分一下,如果is_using_isp()wp并行算法的话,就不需要走gather的逻辑了,因为isp算法forward时权重是完整的,算出来的qkv head也是完整的,不需要gather

else:
q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim)
k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim)
v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim)
kwargs = _convert_cu_seqlens_for_qksplited(kwargs)

# rotary embedding
Expand Down Expand Up @@ -584,9 +637,22 @@ def _inference(self, x, inference_params, **kwargs): # pylint: disable=W0613
q = rearrange(q, "b s h gs d -> b s (h gs) d")
else:
q, k, v = self.wq(x), self.wk(x), self.wv(x)
q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim)
k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim)
v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim)
if self.qk_norm:
q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim)
k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim)

q_all = gather_forward_split_backward(q, ParallelMode.TENSOR, dim=-2)
q_norm_out = self.q_norm(q_all)
q = split_forward_gather_backward(q_norm_out, ParallelMode.TENSOR, dim=-2)
k_all = gather_forward_split_backward(k, ParallelMode.TENSOR, dim=-2)
k_norm_out = self.k_norm(k_all)
k = split_forward_gather_backward(k_norm_out, ParallelMode.TENSOR, dim=-2)

v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这块同上

else:
q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim)
k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim)
v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim)

# rotary embedding, output: q, kv
assert self.rotary_emb_dim > 0
Expand Down
6 changes: 4 additions & 2 deletions internlm/model/modules/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
Shape = Union[int, List[int], torch.Size]


def new_layer_norm(norm_type: str, normalized_shape: Shape, eps: float = 1e-5, add_unit_offset=False):
def new_layer_norm(
norm_type: str, normalized_shape: Shape, eps: float = 1e-5, add_unit_offset=False, convert_to_input_dtype=False
):
if norm_type == "rmsnorm":
rmsnorm_params = inspect.signature(RMSNorm).parameters
if "add_unit_offset" in rmsnorm_params:
return RMSNorm(normalized_shape, eps, add_unit_offset)
return RMSNorm(normalized_shape, eps, add_unit_offset, convert_to_input_dtype)
else:
return RMSNorm(normalized_shape, eps)
else: # default: layernorm
Expand Down
6 changes: 3 additions & 3 deletions internlm/model/ops/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,7 +944,7 @@ def _q_kv_without_cu_seqlens(self, q, kv, softmax_scale=None, causal=None, key_p
attn_type, op = _select_attn_op(AttnOpType.FixedLenKVPacked)

dropout = self.dropout if attn_type is AttnType.Torch else self.dropout.p
extra_args = (key_padding_mask) if attn_type is AttnType.Torch else ()
extra_args = (key_padding_mask,) if attn_type is AttnType.Torch else ()

extra_kwargs = {}
if attn_type is AttnType.SlidingWindowZigZagFlash:
Expand Down Expand Up @@ -1007,7 +1007,7 @@ def _q_kv_with_cu_seqlens(
attn_type, op = _select_attn_op(AttnOpType.VarLenKVPacked)

dropout = self.dropout if attn_type is AttnType.Torch else self.dropout.p
extra_args = (key_padding_mask) if attn_type is AttnType.Torch else ()
extra_args = (key_padding_mask,) if attn_type is AttnType.Torch else ()

return op(
q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout, softmax_scale, causal, *extra_args
Expand Down Expand Up @@ -1088,7 +1088,7 @@ def _q_kv_without_cu_seqlens(self, q, kv, softmax_scale=None, causal=None, key_p
attn_type, op = _select_attn_op(AttnOpType.FixedLenKVPacked)

dropout = self.dropout if attn_type is AttnType.Torch else self.dropout.p
extra_args = (key_padding_mask) if attn_type is AttnType.Torch else ()
extra_args = (key_padding_mask,) if attn_type is AttnType.Torch else ()

return op(q, kv, dropout, softmax_scale, causal, *extra_args)

Expand Down
16 changes: 11 additions & 5 deletions internlm/model/ops/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
torchnpu_rmsnorm_impl = False


def manual_rms_norm(my_input, weight, normalized_shape, eps, add_unit_offset=False):
def manual_rms_norm(my_input, weight, normalized_shape, eps, add_unit_offset=False, convert_to_input_dtype=False):
# layer norm should always be calculated in float32
dims = tuple(i for i in range(-1, -len(normalized_shape) - 1, -1))
variance = my_input.to(torch.float32).pow(2).mean(dims, keepdim=True)
Expand All @@ -44,8 +44,11 @@ def manual_rms_norm(my_input, weight, normalized_shape, eps, add_unit_offset=Fal
if weight is None:
return my_input

# convert into half-precision if necessary
if weight.dtype in [torch.float16, torch.bfloat16]:
if convert_to_input_dtype:
input_dtype = my_input.dtype
my_input = my_input.to(input_dtype)
huangting4201 marked this conversation as resolved.
Show resolved Hide resolved
elif weight.dtype in [torch.float16, torch.bfloat16]:
# convert into half-precision if necessary
my_input = my_input.to(weight.dtype)

if add_unit_offset:
Expand All @@ -57,7 +60,7 @@ def manual_rms_norm(my_input, weight, normalized_shape, eps, add_unit_offset=Fal
class _RMSNorm(torch.nn.Module):
"""A generic module for RMS normalization."""

def __init__(self, normalized_shape, eps=1e-5, add_unit_offset=False):
def __init__(self, normalized_shape, eps=1e-5, add_unit_offset=False, convert_to_input_dtype=False):
super().__init__()

if isinstance(normalized_shape, numbers.Integral):
Expand All @@ -67,14 +70,17 @@ def __init__(self, normalized_shape, eps=1e-5, add_unit_offset=False):
self.weight = Parameter(torch.empty(*normalized_shape))
self.add_unit_offset = add_unit_offset
self.reset_parameters()
self.convert_to_input_dtype = convert_to_input_dtype

def forward(self, _input: torch.Tensor):
if apex_rmsnorm_impl:
_norm_func = mixed_dtype_fused_rms_norm_affine
return _norm_func(_input, self.weight, self.normalized_shape, self.eps)
else:
_norm_func = manual_rms_norm
return _norm_func(_input, self.weight, self.normalized_shape, self.eps, self.add_unit_offset)
return _norm_func(
_input, self.weight, self.normalized_shape, self.eps, self.add_unit_offset, self.convert_to_input_dtype
)

def reset_parameters(self):
if self.add_unit_offset:
Expand Down
2 changes: 2 additions & 0 deletions internlm/model/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Callable

from internlm.model.modeling_baichuan2 import Baichuan2
from internlm.model.modeling_chameleon import ChameleonModel
from internlm.model.modeling_gemma import Gemma
from internlm.model.modeling_internlm import InternLM1
from internlm.model.modeling_internlm2 import InternLM2
Expand Down Expand Up @@ -93,6 +94,7 @@ def register_model_initializer() -> None:
model_initializer.register_module(ModelType.GEMMA.name, Gemma)
model_initializer.register_module(ModelType.QWEN2MOE.name, Qwen2Moe)
model_initializer.register_module(ModelType.MIXTRALMOE.name, MixtralMoE)
model_initializer.register_module(ModelType.CHAMELEON.name, ChameleonModel)


register_model_initializer()
2 changes: 2 additions & 0 deletions internlm/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,15 @@ class ModelType(Enum):
GEMMA = 8
QWEN2MOE = 9
MIXTRALMOE = 10
CHAMELEON = 11


class DataType(Enum):
streaming = 1
tokenized = 2
megatron = 3
mocked = 4
lumina_pickle = 5


class TensorParallelMode(Enum):
Expand Down
23 changes: 20 additions & 3 deletions tests/test_model/test_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

def check_norm(args):
# init
rank, world_size, free_port = args
rank, world_size, free_port, convert_to_input_dtype = args
build_environment(rank, world_size, free_port)
device = get_current_device()
rtol, atol = (1e-3, 5e-3)
Expand All @@ -22,7 +22,12 @@ def check_norm(args):
seed_all(1024)

# define norm
norm = new_layer_norm(norm_type="rmsnorm", normalized_shape=hidden_size, eps=layer_norm_epsilon)
norm = new_layer_norm(
norm_type="rmsnorm",
normalized_shape=hidden_size,
eps=layer_norm_epsilon,
convert_to_input_dtype=convert_to_input_dtype,
)
norm = norm.to(device)

# create input
Expand Down Expand Up @@ -83,8 +88,20 @@ def check_norm(args):
def test_norm():
ctx = mp.get_context("spawn")
free_port = str(find_free_port())
convert_input_dtype = False
with ctx.Pool(processes=8) as pool:
pool.map(check_norm, [[rank, 8, free_port] for rank in range(8)])
pool.map(check_norm, [[rank, 8, free_port, convert_input_dtype] for rank in range(8)])
pool.close()
pool.join()


@pytest.mark.norm_convert_to_input_dtype
def test_norm_convert_to_input_dtype():
    """Run ``check_norm`` on 8 spawned workers with ``convert_to_input_dtype`` enabled."""
    world_size = 8
    spawn_ctx = mp.get_context("spawn")
    port = str(find_free_port())
    # One argument list per rank: [rank, world_size, free_port, convert_to_input_dtype].
    worker_args = [[rank, world_size, port, True] for rank in range(world_size)]
    with spawn_ctx.Pool(processes=world_size) as pool:
        pool.map(check_norm, worker_args)
        pool.close()
        pool.join()

Expand Down
Loading