[Feat] Heterogeneous Code Part 1: Add Model and Module Code for Chameleon Lumina #377
Open · zhhsplendid wants to merge 5 commits into InternLM:develop from zhhsplendid:hetero_merge_1
Commits
264a642 Add model and module code for Chameleon Lumina (zhhsplendid)
910032b Modify based on pre-commit (zhhsplendid)
841ed0a Modify norm more reuseable (zhhsplendid)
99a3c58 Fix model name (zhhsplendid)
692a97a Modify based on reviewer's comment 1 (zhhsplendid)
@@ -10,7 +10,12 @@
 from torch import nn
 from torch.nn import functional as F

+from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
+from internlm.core.parallel.comm.utils import (
+    gather_forward_split_backward,
+    split_forward_gather_backward,
+)
 from internlm.model.modules.embedding import new_rotary_embedding
 from internlm.model.modules.linear import new_linear
 from internlm.model.modules.utils import update_kv_cache
@@ -373,6 +378,30 @@ def _inference(self, x, inference_params, **kwargs): # pylint: disable=W0613
         return self.out_proj(rearrange(context, "b s h d -> b s (h d)"))


+class ChameleonLayerNorm(nn.LayerNorm):
+    """
+    LayerNorm but computes stats only over the last dim because Chameleon applies gamma and beta
+    from each shard separately to each head, instead of reducing. We can apply each head's own
+    gamma/beta by repeat-interleaving weights from each shard, but the stats have to be computed
+    in the last dimension. This module applies gamma/beta manually to fulfill this requirement.
+    """
+
+    def __init__(self, hidden_size, head_group_num, n_heads_per_group, *args, **kwargs):
+        if isinstance(hidden_size, int):
+            hidden_size = (hidden_size,)
+        super().__init__([head_group_num, *hidden_size], *args, **kwargs)
+        self.normalized_shape = (hidden_size[-1],)
+        self.n_heads_per_group = n_heads_per_group
+
+    def repeat_param(self, param):
+        return param.repeat_interleave(self.n_heads_per_group, dim=0)
+
+    def forward(self, hidden_states):
+        hidden_states = F.layer_norm(hidden_states, self.normalized_shape, None, None, eps=1e-5)
+        hidden_states = hidden_states * self.repeat_param(self.weight) + self.repeat_param(self.bias)
+        return hidden_states
+
+
 class GQA(nn.Module):
     """
     Multi-head self-attention and cross-attention.
@@ -397,6 +426,8 @@ class GQA(nn.Module):
         dtype (Optional[torch.dtype]): The type of data.
         qk_interleaved (Optional[bool]): whether the odd and even columns of wq and wk are interleaved. True by default.
         enable_qkv_fusion (bool): whether the wq, wk and wv linear layers are fused. True by default.
+        qk_norm (bool): if True, layer norm is applied to the query and key after the wq/wk projections.
+            False by default.
     """

     def __init__(
@@ -419,6 +450,8 @@ def __init__(
         dtype: Optional[torch.dtype] = None,
         qk_interleaved: Optional[bool] = True,
         enable_qkv_fusion: bool = True,
+        qk_norm: bool = False,
+        chameleon_mp_size: int = 1,
     ) -> None:
         super().__init__()
         self.layer_idx = layer_idx
@@ -459,6 +492,10 @@ def __init__(
             rotary_type="dynamic_ntk" if self.use_dynamic_ntk_rope else "native",
         )

+        self.qk_norm = qk_norm
+        if qk_norm:
+            assert enable_qkv_fusion is False, "qk_norm cannot be applied when wqkv is fused"
+
         if enable_qkv_fusion:
             assert bias is False, "Fused wqkv only supports bias=False."
             self.wqkv = new_linear("wqkv", embed_dim, q_dim + 2 * self.kv_dim, bias, **factory_kwargs)
@@ -470,6 +507,11 @@ def __init__(
             self.wq = new_linear("wq", embed_dim, q_dim, bias, **factory_kwargs)
             self.wk = new_linear("wk", embed_dim, self.kv_dim, bias, **factory_kwargs)
             self.wv = new_linear("wv", embed_dim, self.kv_dim, bias, **factory_kwargs)
+            if qk_norm:
+                assert num_heads % chameleon_mp_size == 0, "num_heads%chameleon_mp_size != 0 in GQA"
+                assert num_kv_heads % chameleon_mp_size == 0, "num_kv_heads%chameleon_mp_size != 0 in GQA"
+                self.q_norm = ChameleonLayerNorm(self.head_dim, chameleon_mp_size, num_heads // chameleon_mp_size)
+                self.k_norm = ChameleonLayerNorm(self.head_dim, chameleon_mp_size, num_kv_heads // chameleon_mp_size)

         self.inner_attn = SelfAttention(
             causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout, layer_idx=layer_idx
@@ -508,10 +550,21 @@ def _training(self, x, **kwargs):
             q = rearrange(q, "b s h gs d -> b s (h gs) d")
         else:
             q, k, v = self.wq(x), self.wk(x), self.wv(x)
-            q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim)
-            k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim)
-            v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim)
+            if self.qk_norm:
+                q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim)
+                k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim)
+                q_all = gather_forward_split_backward(q, ParallelMode.TENSOR, dim=-2)
+                q_norm_out = self.q_norm(q_all)
+                q = split_forward_gather_backward(q_norm_out, ParallelMode.TENSOR, dim=-2)
+                k_all = gather_forward_split_backward(k, ParallelMode.TENSOR, dim=-2)
+                k_norm_out = self.k_norm(k_all)
+                k = split_forward_gather_backward(k_norm_out, ParallelMode.TENSOR, dim=-2)
+                v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim)
+            else:
+                q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim)
+                k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim)
+                v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim)

         kwargs = _convert_cu_seqlens_for_qksplited(kwargs)

         # rotary embedding
@@ -584,9 +637,22 @@ def _inference(self, x, inference_params, **kwargs): # pylint: disable=W0613
             q = rearrange(q, "b s h gs d -> b s (h gs) d")
         else:
             q, k, v = self.wq(x), self.wk(x), self.wv(x)
-            q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim)
-            k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim)
-            v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim)
+            if self.qk_norm:
+                q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim)
+                k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim)
+
+                q_all = gather_forward_split_backward(q, ParallelMode.TENSOR, dim=-2)
+                q_norm_out = self.q_norm(q_all)
+                q = split_forward_gather_backward(q_norm_out, ParallelMode.TENSOR, dim=-2)
+                k_all = gather_forward_split_backward(k, ParallelMode.TENSOR, dim=-2)
+                k_norm_out = self.k_norm(k_all)
+                k = split_forward_gather_backward(k_norm_out, ParallelMode.TENSOR, dim=-2)
+
+                v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim)
+            else:
+                q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim)
+                k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim)
+                v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim)

         # rotary embedding, output: q, kv
         assert self.rotary_emb_dim > 0

Review comment on the qk_norm block in `_inference`: "Same as above."
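To make the shape behavior of ChameleonLayerNorm concrete, here is a small self-contained sketch. The class body mirrors the diff above; the batch/sequence/head sizes and the chameleon_mp_size value are illustrative assumptions only, not values taken from this PR.

```python
import torch
from torch import nn
from torch.nn import functional as F


class ChameleonLayerNorm(nn.LayerNorm):
    # Same module as in the diff above: stats are computed over the last dim only,
    # and gamma/beta are repeat-interleaved so each head in a group gets its own affine params.
    def __init__(self, hidden_size, head_group_num, n_heads_per_group, *args, **kwargs):
        if isinstance(hidden_size, int):
            hidden_size = (hidden_size,)
        super().__init__([head_group_num, *hidden_size], *args, **kwargs)
        self.normalized_shape = (hidden_size[-1],)
        self.n_heads_per_group = n_heads_per_group

    def repeat_param(self, param):
        return param.repeat_interleave(self.n_heads_per_group, dim=0)

    def forward(self, hidden_states):
        hidden_states = F.layer_norm(hidden_states, self.normalized_shape, None, None, eps=1e-5)
        return hidden_states * self.repeat_param(self.weight) + self.repeat_param(self.bias)


# Hypothetical sizes, for illustration only.
batch, seqlen, num_heads, head_dim = 2, 8, 4, 16
chameleon_mp_size = 2  # assumed number of head groups (one per shard)

q = torch.randn(batch, seqlen, num_heads, head_dim)  # (b, s, h, d), as produced after rearrange
q_norm = ChameleonLayerNorm(head_dim, chameleon_mp_size, num_heads // chameleon_mp_size)
out = q_norm(q)
print(out.shape)  # torch.Size([2, 8, 4, 16]): per-head gamma/beta, stats over d only
```

The weight and bias have shape (chameleon_mp_size, head_dim); repeat_interleave expands them to (num_heads, head_dim) so they broadcast over the head dimension of the gathered q/k tensor.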
Review comment: The qk_norm part may need an extra distinction. If the is_using_isp() weight-parallel (wp) algorithm is used, the gather logic is not needed, because under the ISP algorithm the weights are complete during the forward pass, so the computed qkv heads are already complete and do not need to be gathered.
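A minimal sketch of what the reviewer is suggesting, assuming an is_using_isp()-style check is available (its import path is not shown in this PR); the helper name _apply_chameleon_norm is hypothetical, and only the gather/split utilities already imported in the diff above are relied on.

```python
# Sketch only: illustrates the reviewer's suggestion, not code from this PR.
from internlm.core.context import ParallelMode
from internlm.core.parallel.comm.utils import (
    gather_forward_split_backward,
    split_forward_gather_backward,
)


def _apply_chameleon_norm(norm, t, use_isp: bool):
    """Apply q/k norm, gathering heads across tensor-parallel ranks only when needed."""
    if use_isp:
        # With ISP (weight parallel), the forward pass sees complete weights, so the
        # computed q/k heads are already full -- no gather/split round trip is required.
        return norm(t)
    t_all = gather_forward_split_backward(t, ParallelMode.TENSOR, dim=-2)
    t_all = norm(t_all)
    return split_forward_gather_backward(t_all, ParallelMode.TENSOR, dim=-2)


# Inside GQA._training / GQA._inference this could replace the duplicated
# gather -> norm -> split sequences for q and k, e.g. (is_using_isp is assumed):
#   q = _apply_chameleon_norm(self.q_norm, q, use_isp=is_using_isp())
#   k = _apply_chameleon_norm(self.k_norm, k, use_isp=is_using_isp())
```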