Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(hf): fix convert_inetrnevo2hf for internlm2 model #401

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 8 additions & 19 deletions internlm/model/modeling_internlm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import Optional

import torch
from einops import rearrange
from torch import nn
from tqdm import tqdm

Expand Down Expand Up @@ -771,18 +770,6 @@ def unique_kv_index(i):

@staticmethod
def convert_internevo2hf_weights(src: str, tgt: str) -> None:
def permute(qkv, num_heads, num_kv_heads, head_dim, adapt_hf=True):
if adapt_hf:
return qkv
q_per_kv = num_heads // num_kv_heads
qkv = rearrange(qkv.T, "o (g n i) -> o g n i", n=q_per_kv + 2, i=head_dim)
q, k, v = qkv[..., :q_per_kv, :], qkv[..., -2:-1, :], qkv[..., -1:, :]
q = torch.cat([q[..., ::2], q[..., 1::2]], dim=-1)
k = torch.cat([k[..., ::2], k[..., 1::2]], dim=-1)
qkv = torch.cat((q, k, v), dim=2)
qkv = rearrange(qkv, "o g n i -> o (g n i)").T
return qkv

model_config = gpc.config.model
tp_mode = gpc.config.parallel.tensor["mode"]
row_dim = 0 if tp_mode == "isp" else 1
Expand All @@ -808,12 +795,14 @@ def permute(qkv, num_heads, num_kv_heads, head_dim, adapt_hf=True):
}
)
# attn
state_dict[f"model.layers.{layer_i}.attention.wqkv.weight"] = permute(
torch.cat([states[i][f"layers.{layer_i}.attention.wqkv.weight"] for i in range(num_shards)], dim=0),
num_heads=model_config["num_attention_heads"],
num_kv_heads=model_config["num_kv_attention_heads"],
head_dim=model_config["hidden_size"] // model_config["num_attention_heads"],
adapt_hf=model_config.get("adapt_hf", True),
state_dict[f"model.layers.{layer_i}.attention.wq.weight"] = torch.cat(
[states[i][f"layers.{layer_i}.attention.wq.weight"] for i in range(num_shards)], dim=0
)
state_dict[f"model.layers.{layer_i}.attention.wk.weight"] = torch.cat(
[states[i][f"layers.{layer_i}.attention.wk.weight"] for i in range(num_shards)], dim=0
)
state_dict[f"model.layers.{layer_i}.attention.wv.weight"] = torch.cat(
[states[i][f"layers.{layer_i}.attention.wv.weight"] for i in range(num_shards)], dim=0
)
state_dict[f"model.layers.{layer_i}.attention.wo.weight"] = torch.cat(
[states[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=row_dim
Expand Down
Loading