From 264a64231bcf5e10c66d457ef6afee51449b28ff Mon Sep 17 00:00:00 2001 From: zhhsplendid Date: Thu, 28 Nov 2024 16:58:39 +0800 Subject: [PATCH 1/5] Add model and module code for Chameleon Lumina --- internlm/model/modeling_chameleon.py | 884 +++++++++++++++++++++++++++ internlm/model/modules/mha.py | 74 ++- internlm/model/modules/norm.py | 4 +- internlm/model/ops/attention.py | 6 +- internlm/model/ops/norm.py | 25 +- internlm/model/registry.py | 2 + 6 files changed, 973 insertions(+), 22 deletions(-) create mode 100644 internlm/model/modeling_chameleon.py diff --git a/internlm/model/modeling_chameleon.py b/internlm/model/modeling_chameleon.py new file mode 100644 index 00000000..8797ce07 --- /dev/null +++ b/internlm/model/modeling_chameleon.py @@ -0,0 +1,884 @@ +import math +import os +from functools import cached_property +from typing import Optional + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +from tqdm import tqdm + +# Should re-implement CausalLMOutputWithPast? +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast + +from internlm.accelerator import get_accelerator +from internlm.core.context import ParallelMode +from internlm.core.context.parallel_context import global_context as gpc +from internlm.core.naive_amp import set_output_attr_to_module +from internlm.initialize.initialize_tensor import ( + normal_, + scaled_init_method_normal, + scaled_init_method_uniform, + uniform_, +) +from internlm.model.base_model import BaseModel +from internlm.model.modules.embedding import Embedding1D +from internlm.model.modules.linear import new_linear +from internlm.model.modules.mha import GQA +from internlm.model.modules.mlp import new_feed_forward +from internlm.model.modules.norm import new_layer_norm +from internlm.model.utils import ( + convert_attn_args_to_kwargs, + convert_attn_kwargs_to_args, +) +from internlm.solver.activation_checkpoint import activation_checkpoint +from internlm.utils.logger import get_logger +from internlm.utils.storage_manager import get_fns, llm_load, llm_save +from transformers.modeling_utils import ( + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + shard_checkpoint, +) + +internlm_accelerator = get_accelerator() +logger = get_logger(__file__) + + +class ChameleonDecoderLayer(nn.Module): + """ + Chameleon Decoder Layer. + + Args: + hidden_size (int): The hidden size of model. 768 by default. + num_attention_heads (int): The number of attention heads. 12 by default. + num_kv_attention_heads (int): The number of key/value attention heads. Defaults to 12. + mlp_ratio (int): The ratio of MLP layers. 4 by default. + attn_drop_rate (float): The dropout rate of attention module. 0 by default. + drop_rate (float): The dropout rate of the input hidden state. 0.0 by default. + dtype (torch.dtype): Type of data. torch.float by default. + layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. + checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. + layer_idx (int): The index of current layer. 0 by default. + residual_in_fp32 (bool): Whether to use residual in fp32. False by default. + device (Optional[Union[str, torch.device]]): The device will be used. + apply_post_layer_norm (bool): Whether to apply layer normalization after the attention and mlp. + Defaults to False. + fused_dropout_add_ln (bool): Whether to fuse dropout, residual addition, and layer normalization. + Defaults to True. + no_bias (bool): Whether to exclude bias in attention and feed-forward networks. Defaults to False. + norm_type (str): Use RMS norm or layernorm."rmsnorm" by default. + qk_interleaved (bool): Whether the odd and even columns of the wq and wk are normally interleaved. + dropout_selective_checkpoint (bool): Whether to selectively checkpoint dropout layers only. + use_scaled_init (bool): Whether to use scaled initialization for weights. + use_swiglu (bool): Whether to use SwiGLU activation in the mlp module. + attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, + attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, + ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu + otherwise init fc1 weight in ffn. 0.02 by default, + ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default, + init_type (str): Initialization type. Use uniform or normal. "normal" by default, + rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. + mlp_layer_fusion (bool): Whether to fuse layers in the mlp module for optimization. + multiple_of (int): Ensures mlp dimensions are multiples of this value for efficient hardware utilization. + qk_norm (bool): Support q norm and k norm. + """ + + def __init__( + self, + hidden_size: int = 768, + num_attention_heads: int = 12, + num_kv_attention_heads: int = 12, + mlp_ratio: int = 4, + attn_drop_rate: float = 0, + drop_rate: float = 0.0, + dtype: torch.dtype = torch.float, + layer_norm_epsilon: float = 1e-6, + checkpoint: bool = False, + layer_idx: int = 0, + residual_in_fp32: bool = False, + device: Optional[torch.device] = None, + apply_post_layer_norm: bool = False, + fused_dropout_add_ln: bool = True, + no_bias: bool = False, + norm_type: str = "rmsnorm", + qk_interleaved: bool = False, + dropout_selective_checkpoint: bool = True, + use_scaled_init: bool = True, + use_swiglu: bool = True, + attn_wqkv_init_std: float = 0.02, + attn_other_init_std: float = 0.02, + ffn_uplayer_init_std: float = 0.02, + ffn_other_init_std: float = 0.02, + init_type: str = "normal", + rope_base: int = 10000, + mlp_layer_fusion: bool = False, + multiple_of: int = 256, + qk_norm = True, + chameleon_mp_size = 1, + ): + super().__init__() + self.checkpoint = checkpoint + # dropout selective checkpoint can only be enabled when checkpoint is disabled. + self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False + self.layer_idx = layer_idx + self.prenorm = not apply_post_layer_norm + assert not fused_dropout_add_ln, "dropout_add_layer_norm can not be used here" + self.fused_dropout_add_ln = fused_dropout_add_ln + self.attn_wqkv_init_std = attn_wqkv_init_std + self.attn_other_init_std = attn_other_init_std + self.ffn_uplayer_init_std = ffn_uplayer_init_std + self.ffn_other_init_std = ffn_other_init_std + + head_dim = hidden_size // num_attention_heads + + self.attention = GQA( + embed_dim=hidden_size, + num_heads=num_attention_heads, + num_kv_heads=num_kv_attention_heads, + dropout=attn_drop_rate, + softmax_scale=1 / math.sqrt(head_dim), + causal=True, + layer_idx=layer_idx, + rotary_emb_dim=head_dim, + rotary_emb_scale_base=0, + device=device, + dtype=dtype, + qk_interleaved=qk_interleaved, + bias=not no_bias, + rope_base=rope_base, + enable_qkv_fusion=False, + qk_norm=qk_norm, + chameleon_mp_size=chameleon_mp_size, + ) + + self.dropout = nn.Dropout(drop_rate) + self.attention_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon, add_unit_offset=False, is_Chameleon=True) + self.ffn_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon, add_unit_offset=False, is_Chameleon=True) + + self.feed_forward = new_feed_forward( + hidden_size, + int(hidden_size * mlp_ratio), + out_features=hidden_size, + bias=False, + device=device, + dtype=dtype, + mlp_layer_fusion=mlp_layer_fusion, + multiple_of=multiple_of, + # TODO: to support more activation functions + activation_type="swiglu" if use_swiglu else "swiglu", + ) + + self.use_swiglu = use_swiglu + self.use_scaled_init = use_scaled_init + self.residual_in_fp32 = residual_in_fp32 # only make sense when using prenorm + self.return_residual = False + + if init_type == "normal": + self.init_func = normal_ + self.scaled_init_func = scaled_init_method_normal + else: + self.init_func = uniform_ + self.scaled_init_func = scaled_init_method_uniform + + self.reset_parameters() + + def reset_parameters(self): + with torch.no_grad(): + for name, param in self.attention.named_parameters(): + if param.ndim == 1: + param.data.zero_() + elif "wq" in name or "wk" in name or "wv" in name: + self.init_func(std=self.attn_wqkv_init_std)(param.data) + elif self.use_scaled_init: # wo + self.scaled_init_func(sigma=self.attn_other_init_std, num_layers=self.layer_idx + 1)(param.data) + else: + self.init_func(std=self.attn_other_init_std)(param.data) + + for name, param in self.feed_forward.named_parameters(): + if self.use_swiglu: + if self.use_scaled_init and "w2" in name: + self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data) + else: + # candidate: w1, w3, fused_w1_w3 + self.init_func( + std=self.ffn_uplayer_init_std if "w1" in name or "w3" in name else self.ffn_other_init_std + )(param.data) + else: + if self.use_scaled_init and "fc1" not in name: + self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data) + else: + self.init_func(std=self.ffn_uplayer_init_std if "fc1" in name else self.ffn_other_init_std)( + param.data + ) + + def forward(self, hidden_states, residual=None, **kwargs): + if self.checkpoint and self.training: + # NOTICE: activation_checkpiont do not support kwargs when use_reentrant = True. + args = convert_attn_kwargs_to_args(kwargs) + return activation_checkpoint(self._forward, False, hidden_states, residual, *args) + else: + return self._forward(hidden_states, residual, **kwargs) + + def _forward(self, hidden_states, residual, *args, **kwargs): + r"""Pass the input through the encoder layer. + + Args: + hidden_states: the sequence to the encoder layer (required). + residual: hidden_states = Attn/MLP(LN(residual)) + cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1 + indexes: the length of index is same as hidden states, which stand for the current position + """ + if self.prenorm: + + residual = hidden_states + + hidden_states = self.attention_norm(hidden_states) + + # def _dropout_and_norm_attn(_residual, _hidden_states): + # _dropped = self.dropout1(_hidden_states) + # _residual = (_dropped + _residual) if _residual is not None else _dropped + # _hidden_states = self.attention_norm(_residual.to(dtype=self.attention_norm.weight.dtype)) + + # return _residual, _hidden_states + + # if self.dropout_selective_checkpoint: + # residual, hidden_states = activation_checkpoint(_dropout_and_norm_attn, False, residual, hidden_states) + # else: + # residual, hidden_states = _dropout_and_norm_attn(residual, hidden_states) + + # if self.residual_in_fp32: + # residual = residual.to(torch.float32) + + attn_kwargs = convert_attn_args_to_kwargs(args, kwargs) + hidden_states = self.attention(hidden_states, **attn_kwargs) + dropped = self.dropout(hidden_states) + hidden_states = residual + dropped + + if not isinstance(self.feed_forward, nn.Identity): + residual = hidden_states + hidden_states = self.ffn_norm(hidden_states) + hidden_states = self.feed_forward(hidden_states) + dropped = self.dropout(hidden_states) + hidden_states = residual + dropped + return hidden_states + else: + assert residual is None + + residual = hidden_states + hidden_states = self.attention(hidden_states, **kwargs) + hidden_states = self.attention_norm(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + if not isinstance(self.feed_forward, nn.Identity): + residual = hidden_states + hidden_states = self.feed_forward(hidden_states) + hidden_states = self.ffn_norm(hidden_states) + hidden_states = residual + self.dropout(hidden_states) + return hidden_states + + +class ChameleonImageVocabularyMapping: + """ + A class for mapping discrete image tokens from VQGAN to BPE tokens. + + Reference: + https://github.com/Alpha-VLLM/Lumina-mGPT/blob/104abe453ec1acca5863698629c4db2111b0b3fc/lumina_mgpt/model/chameleon/modeling_chameleon.py#L1036 + """ + + def __init__(self, vocab_map): + self.vocab_map = vocab_map + self.image_token_id = vocab_map.get("") + + @cached_property + def val2name(self): + return {v: k for k, v in self.vocab_map.items()} + + @cached_property + def image_tokens(self): + return sorted([val for name, val in self.vocab_map.items() if name.startswith("IMGIMG")]) + + @cached_property + def bpe2img(self): + img_tkn_chr_mapping = {chr(ord("A") + i): str(i) for i in range(10)} + + def remap(old_name: str) -> str: + return "".join(img_tkn_chr_mapping.get(c, c) for c in old_name[len("IMGIMG") : -1]) + + return {tok: int(remap(self.val2name[tok])) for tok in self.image_tokens} + + @cached_property + def img2bpe(self): + return {v: k for k, v in self.bpe2img.items()} + + @cached_property + def bpe2img_search_tensors(self): + return torch.tensor(sorted(self.bpe2img.keys())), torch.tensor(sorted(self.bpe2img.values())) + + @cached_property + def img2bpe_mapping_tensor(self): + mapping = torch.zeros(max(self.img2bpe.keys()) + 1, dtype=torch.int) + for k, v in self.img2bpe.items(): + mapping[k] = v + return mapping + + def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor: + device = img_batch.device + img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")] + return img_tokens.to(device) + + + +class ChameleonModel(BaseModel): + """ + Chameleon Model. + + Args: + num_layers (int): The number of layer. 12 by default. + hidden_size (int): The size of hidden state. 768 by default. + num_attention_heads (int): The number of attention head. 12 by default. + num_kv_attention_heads (int): The number of key/value attention heads. Defaults to 12. + vocab_size (int): The size of vocabulary. 50304 by default. + mlp_ratio (int): The ratio of MLP layers. 4 by default. + attn_drop_rate (float): The dropout rate of attention module. 0.0 by default. + drop_rate (float): The dropout rate of input hidden state. 0.0 by default. + dtype (torch.dtype): The type of data. torch.float by default. + checkpoint (float): The proportion of layers that need to be checkpointed compared to the total number + of layers. 0.0 by default. + layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default. + first (bool): Whether input embedding layer or not. False by default. + last (bool): Whether output embedding layer or not. False by default. + embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. + parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. + start_layer_idx (int): The index of start layer in the pipeline. 0 by default. + device (Optional[Union[str, torch.device]]): The device will be used. None by default. + apply_post_layer_norm (bool): Whether to apply layer normalization after the attention and mlp. + Defaults to False. + no_bias (bool): Whether to exclude bias in attention and feed-forward networks. Defaults to False. + residual_in_fp32 (bool): Whether to use residual in fp32. False by default. + norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. + qk_interleaved (bool): Whether the odd and even columns of the wq and wk are normally interleaved. + dropout_selective_checkpoint (bool): Whether to selectively checkpoint dropout and norm layers. + use_scaled_init (bool): Whether to use scaled initialization for weights. + use_swiglu (bool): Whether to use SwiGLU activation in the mlp module. + embedding_init_std (float): std used to init embedding weight. 0.02 by default, + attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, + attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, + ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu + otherwise init fc1 weight in ffn. 0.02 by default, + ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default, + out_head_init_std (float): std used to init output lmhead weight. 0.02 by default, + init_type (str): Initialization type. Use uniform or normal. "normal" by default, + rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. + mlp_layer_fusion (bool): Whether to fuse layers in the mlp module for optimization. + multiple_of (int): Ensures mlp dimensions are multiples of this value for efficient hardware utilization. + qk_norm (bool): Support q norm and k norm. + """ + + def __init__( + self, + num_layers: int = 48, + hidden_size: int = 2048, + num_attention_heads: int = 32, + num_kv_attention_heads: int = 32, + vocab_size: int = 50304, + mlp_ratio: float = 4.0, + attn_drop_rate: float = 0.0, + drop_rate: float = 0.0, + dtype: torch.dtype = torch.float, + checkpoint: float = 0.0, + layer_norm_epsilon: float = 1e-5, + first: bool = False, + last: bool = False, + embed_grad_scale: float = 0.1, + parallel_output: bool = True, + start_layer_idx: int = 0, + device: Optional[torch.device] = None, + apply_post_layer_norm=False, + no_bias=False, + residual_in_fp32: bool = False, + norm_type: str = "rmsnorm", + qk_interleaved: bool = False, + is_reward: bool = False, + dropout_selective_checkpoint: bool = True, + use_scaled_init: bool = True, + use_swiglu: bool = True, + embedding_init_std: float = 0.02, + attn_wqkv_init_std: float = 0.02, + attn_other_init_std: float = 0.02, + ffn_uplayer_init_std: float = 0.02, + ffn_other_init_std: float = 0.02, + out_head_init_std: float = 0.02, + init_type: str = "normal", + rope_base: int = 10000, + mlp_layer_fusion: bool = False, + multiple_of: int = 256, + qk_norm = True, + chameleon_mp_size = 1, + ): + super().__init__() + + checkpoint_layer_num = int(num_layers * checkpoint) + self.embed_grad_scale = embed_grad_scale + self.parallel_output = parallel_output + if chameleon_mp_size == 4: + apply_post_layer_norm = True + + + if first: + self.padding_idx = None + self.tok_embeddings = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size, padding_idx=self.padding_idx) + for _, param in self.tok_embeddings.named_parameters(): + if init_type == "normal": + normal_(std=embedding_init_std)(param) + else: + uniform_(std=embedding_init_std)(param) + + self.layers = nn.ModuleList( + [ + ChameleonDecoderLayer( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_kv_attention_heads=num_kv_attention_heads, + mlp_ratio=mlp_ratio, + attn_drop_rate=attn_drop_rate, + drop_rate=drop_rate, + dtype=dtype, + layer_norm_epsilon=layer_norm_epsilon, + checkpoint=lid < checkpoint_layer_num, + layer_idx=lid + start_layer_idx, # This parameter is used for caching during generation + residual_in_fp32=residual_in_fp32, + device=device, + apply_post_layer_norm=apply_post_layer_norm, + fused_dropout_add_ln=False, + no_bias=no_bias, + norm_type=norm_type, + dropout_selective_checkpoint=dropout_selective_checkpoint, + use_scaled_init=use_scaled_init, + use_swiglu=use_swiglu, + qk_interleaved=qk_interleaved, + attn_wqkv_init_std=attn_wqkv_init_std, + attn_other_init_std=attn_other_init_std, + ffn_uplayer_init_std=ffn_uplayer_init_std, + ffn_other_init_std=ffn_other_init_std, + init_type=init_type, + rope_base=rope_base, + mlp_layer_fusion=mlp_layer_fusion, + multiple_of=multiple_of, + qk_norm = qk_norm, + chameleon_mp_size=chameleon_mp_size, + ) + for lid in range(num_layers) + ] + ) + + if last: + self.norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) + + self.output = new_linear( + name="output", + in_features=hidden_size, + out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size, + bias=False, + device=device, + dtype=dtype, + is_reward=is_reward, + weight_scale=embed_grad_scale, + ) + set_output_attr_to_module(self.output) + for _, param in self.output.named_parameters(): + if init_type == "normal": + normal_(std=out_head_init_std)(param) + else: + uniform_(std=out_head_init_std)(param) + + def forward(self, hidden_states=None, input_ids=None, **kwargs): + # attention_mask: compute attention on the places where the value is 1 + if hasattr(self, "tok_embeddings") and input_ids is not None: + hidden_states = self.tok_embeddings(input_ids) + if self.embed_grad_scale != 1: + hidden_states = ( + self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach() + ) + + for _, block in enumerate(self.layers): + hidden_states = block(hidden_states, residual=None, **kwargs) + + if hasattr(self, "norm"): + hidden_states = self.norm(hidden_states) + + if hasattr(self, "output"): + hidden_states = self.output(hidden_states).float() + + return hidden_states + + @staticmethod + def load_hf_weights(folder: str, model: nn.Module): + """NOTE: when loading huggingface's llama pretrained weights, you should set `adapt_hf=True` in your config.""" + """NOTE: specified for meta-llama/Llama-2-7b-hf""" + assert folder is not None, "Please specify the folder of the pretrained model" + if gpc.is_rank_for_log(): + logger.info(f"Loading pretrained model from {folder}") + + fns = get_fns(folder) + model_fns = [ + os.path.join(folder, fn) + for fn in fns + if (fn.endswith(".bin") and fn.startswith("pytorch_model")) + or (fn.endswith(".safetensors") and fn.startswith("model")) + ] + model_fns.sort() + + state_dict = {} + for model_fn in model_fns: + state_dict.update(llm_load(model_fn, map_location="cpu")) + + tp_size = gpc.get_world_size(ParallelMode.TENSOR) + tp_rank = gpc.get_local_rank(ParallelMode.TENSOR) + wp_size = gpc.get_world_size(ParallelMode.WEIGHT) + wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT) + tp_mode = gpc.config.parallel.tensor["mode"] + split_size = wp_size if tp_mode == "isp" else tp_size + local_rank = wp_rank if tp_mode == "isp" else tp_rank + row_dim = 0 if tp_mode == "isp" else 1 + if gpc.config.model.get("embed_split_hidden", True): + embed_concat_dim = 1 + else: + embed_concat_dim = 0 + + new_state_dict = {} + + for idx, i in enumerate(range(model.first_layer, model.last_layer)): + layer_ids = i + + # attn + state_dict[f"layers.{i}.attention.wq.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.self_attn.q_proj.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.attention.wk.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.self_attn.k_proj.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.attention.wv.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.self_attn.v_proj.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.attention.wo.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.self_attn.o_proj.weight"), + split_size, + dim=row_dim, + )[local_rank] + + state_dict[f"layers.{i}.attention.q_norm.weight"] = state_dict.pop( + f"model.layers.{i}.self_attn.q_norm.weight" + ) + state_dict[f"layers.{i}.attention.q_norm.bias"] = state_dict.pop( + f"model.layers.{i}.self_attn.q_norm.bias" + ) + state_dict[f"layers.{i}.attention.k_norm.weight"] = state_dict.pop( + f"model.layers.{i}.self_attn.k_norm.weight" + ) + state_dict[f"layers.{i}.attention.k_norm.bias"] = state_dict.pop( + f"model.layers.{i}.self_attn.k_norm.bias" + ) + + # ffn + state_dict[f"layers.{i}.feed_forward.w1.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.mlp.gate_proj.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.feed_forward.w3.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.mlp.up_proj.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.feed_forward.w2.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.mlp.down_proj.weight"), + split_size, + dim=row_dim, + )[local_rank] + + # TODO(zhhsplendid): add qk norm and mapping parameter names? + + # attn norm + state_dict[f"layers.{i}.attention_norm.weight"] = state_dict.pop( + f"model.layers.{layer_ids}.input_layernorm.weight" + ) + + # ffn norm + state_dict[f"layers.{i}.ffn_norm.weight"] = state_dict.pop( + f"model.layers.{layer_ids}.post_attention_layernorm.weight" + ) + + # skip rotary_emb inv_freq + if f"model.layers.{layer_ids}.self_attn.rotary_emb.inv_freq" in state_dict: + state_dict.pop(f"model.layers.{layer_ids}.self_attn.rotary_emb.inv_freq") + + # replace value within decoder layer + for name in list(state_dict.keys()): + if name.startswith(f"model.vqmodel"): + state_dict.pop(name) + if name.startswith(f"layers.{i}"): + new_state_dict[name.replace(f".{i}.", f".{idx}.")] = state_dict.pop(name) + + # embedding + if (gpc.get_local_rank(ParallelMode.PIPELINE) == 0) or (not gpc.is_using_parallel_mode(ParallelMode.PIPELINE)): + new_state_dict["tok_embeddings.weight"] = torch.chunk( + state_dict.pop("model.embed_tokens.weight"), + split_size, + dim=embed_concat_dim, + )[local_rank] + + # output + if gpc.is_last_rank(ParallelMode.PIPELINE): + new_state_dict["output.weight"] = torch.chunk( + state_dict.pop("lm_head.weight"), + split_size, + dim=0, + )[local_rank] + new_state_dict["norm.weight"] = state_dict.pop("model.norm.weight") + + missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False) + + if gpc.get_local_rank(ParallelMode.DATA) == 0: + pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + logger.info( + f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in " + f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}" + ) + + internlm_accelerator.empty_cache() + + @staticmethod + def load_llama_pretrained_weights(folder: str, model: nn.Module) -> None: + """NOTE: when loading huggingface's llama pretrained weights, you should set `adapt_hf=True` in your config.""" + """NOTE: specified for meta-llama/Llama-2-7b""" + assert folder is not None, "Please specify the folder of the pretrained model" + if gpc.is_rank_for_log(): + logger.info(f"Loading pretrained model from {folder}") + + fns = get_fns(folder) + model_fns = [] + for fn in fns: + if fn.startswith("model_t") and not fn.endswith("md5"): + model_fns.append(os.path.join(folder, fn)) + + if len(model_fns) == 0: + model_fns = [os.path.join(folder, fn) for fn in fns if fn.endswith(".pth") or fn.endswith(".pt")] + + if len(model_fns) == 0: + raise FileNotFoundError(f"No checkpoint file found in {folder}") + + model_fns.sort() + + old_tp = len(model_fns) + cur_tp = gpc.get_world_size(ParallelMode.TENSOR) + # If the two tp are inconsistent, you need to consider the merge before splitting + if old_tp != cur_tp: + raise RuntimeError( + f"Your current tp is `{cur_tp}`, but the tp in folder:`{folder}` is `{old_tp}`, use `` to convert first" + ) + + states = llm_load(model_fns[gpc.get_local_rank(ParallelMode.TENSOR)], map_location="cpu") + + current_states = {} + for idx, i in enumerate(range(model.first_layer, model.last_layer)): + for name in list(states.keys()): + if f".{i}." in name: + current_states[name.replace(f".{i}.", f".{idx}.")] = states.pop(name) + + model_state_keys = set(list(model.state_dict().keys())) + + if "tok_embeddings.weight" in model_state_keys: + current_states["tok_embeddings.weight"] = states["tok_embeddings.weight"] + assert model.first_layer == 0, f"Expect model.NaiveAMPModel to be 0, but got {model.first_layer}" + if "output.weight" in model_state_keys: + current_states["norm.weight"] = states["norm.weight"] + current_states["output.weight"] = states["output.weight"] + missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False) + + if gpc.get_local_rank(ParallelMode.DATA) == 0: + pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + logger.info( + f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in " + f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}" + ) + + internlm_accelerator.empty_cache() + + @staticmethod + def convert_internevo2hf_weights(src: str, tgt: str) -> None: + model_config = gpc.config.model + tp_mode = gpc.config.parallel.tensor["mode"] + row_dim = 0 if tp_mode == "isp" else 1 + if model_config["embed_split_hidden"]: + embed_concat_dim = 1 + else: + embed_concat_dim = 0 + + # load states + states, num_shards = ChameleonModel.load_sharded_states(src) + + # convert state_dict + state_dict = {} + for layer_i in tqdm(range(model_config["num_layers"])): + # attn norm, ffn norm + state_dict.update( + { + f"model.layers.{layer_i}.input_layernorm.weight": states[0][ + f"layers.{layer_i}.attention_norm.weight" + ].clone(), + f"model.layers.{layer_i}.post_attention_layernorm.weight": states[0][ + f"layers.{layer_i}.ffn_norm.weight" + ].clone(), + } + ) + # attn + state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.attention.wq.weight"] for i in range(num_shards)], dim=0 + ) + state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.attention.wk.weight"] for i in range(num_shards)], dim=0 + ) + state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.attention.wv.weight"] for i in range(num_shards)], dim=0 + ) + state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=row_dim + ) + state_dict.update( + { + f"model.layers.{layer_i}.self_attn.q_norm.weight" : states[0][f"layers.{layer_i}.attention.q_norm.weight"].clone(), + f"model.layers.{layer_i}.self_attn.q_norm.bias" : states[0][f"layers.{layer_i}.attention.q_norm.bias"].clone(), + f"model.layers.{layer_i}.self_attn.k_norm.weight" : states[0][f"layers.{layer_i}.attention.k_norm.weight"].clone(), + f"model.layers.{layer_i}.self_attn.k_norm.bias" : states[0][f"layers.{layer_i}.attention.k_norm.bias"].clone(), + } + ) + + # ffn + state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0 + ) + state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=row_dim + ) + state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0 + ) + # embedding, output + state_dict.update( + { + "model.norm.weight": states[0]["norm.weight"], + "model.embed_tokens.weight": torch.cat( + [states[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=embed_concat_dim + ), + "lm_head.weight": torch.cat([states[i]["output.weight"] for i in range(num_shards)], dim=0), + }, + ) + + # save state_dict to hf format + shards, index = shard_checkpoint(state_dict, weights_name=SAFE_WEIGHTS_NAME) + for shard_file, shard in shards.items(): + llm_save(save_path=os.path.join(tgt, shard_file), saved_obj=shard, metadata={"format": "pt"}) + if index is not None: + llm_save(save_path=os.path.join(tgt, SAFE_WEIGHTS_INDEX_NAME), saved_obj=index) +''' + +class ChameleonForConditionalGeneration(BaseModel): + def __init__(self, + max_position_embeddings: int, + output_attentions: bool, + output_hidden_states: bool, + return_dict: bool, + mask_image_logits: bool): + self.max_position_embeddings = max_position_embeddings + + self.model = ChameleonModel() + + self.output_attentions = output_attentions + self.output_hidden_states = output_hidden_states + self.return_dict = return_dict + self.mask_image_logits = mask_image_logits + + def forward(self, + input_ids=None, + labels=None, + pixel_values: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs): + + + + + # Data to torch.tensor + max_tokens = max([len(_) for _ in input_ids]) + max_tokens = min(max_tokens, self.max_position_embeddings) + input_ids = [_[:max_tokens] for _ in input_ids] + labels = [_[:max_tokens] for _ in labels] + input_ids = [example + [0] * (max_tokens - len(example)) for example in input_ids] + input_ids = torch.tensor(input_ids, dtype=torch.int64, device=self.device) + labels = [label + [-100] * (max_tokens - len(label)) for label in labels] + labels = torch.tensor(labels, dtype=torch.int64, device=self.device) + + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + if self.mask_image_logits: + # Disallow image tokens which does not include special begin-image and end-image tokens + image_tokens = self.model.vocabulary_mapping.image_tokens + logits[:, :, image_tokens] = torch.finfo(logits.dtype).min + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not self.return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + +''' + diff --git a/internlm/model/modules/mha.py b/internlm/model/modules/mha.py index 42418a21..15e8eede 100644 --- a/internlm/model/modules/mha.py +++ b/internlm/model/modules/mha.py @@ -10,7 +10,9 @@ from torch import nn from torch.nn import functional as F +from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc +from internlm.core.parallel.comm.utils import gather_forward_split_backward, split_forward_gather_backward from internlm.model.modules.embedding import new_rotary_embedding from internlm.model.modules.linear import new_linear from internlm.model.modules.utils import update_kv_cache @@ -372,6 +374,28 @@ def _inference(self, x, inference_params, **kwargs): # pylint: disable=W0613 # wo return self.out_proj(rearrange(context, "b s h d -> b s (h d)")) +class ChameleonLayerNorm(nn.LayerNorm): + """ + LayerNorm but computes stats only over the last dim because Chameleon applies gamma and beta + from each shard separately to each head, instead of reducing. We can apply each head's own + gamma/beta by repeat-interleaving weights from each shard, but the stats have to be computed + in the last dimension. This module applies gamma/beta manually to fulfill this requirement. + """ + + def __init__(self, hidden_size, head_group_num, n_heads_per_group, *args, **kwargs): + if isinstance(hidden_size, int): + hidden_size = (hidden_size,) + super().__init__([head_group_num, *hidden_size], *args, **kwargs) + self.normalized_shape = (hidden_size[-1],) + self.n_heads_per_group = n_heads_per_group + + def repeat_param(self, param): + return param.repeat_interleave(self.n_heads_per_group, dim=0) + + def forward(self, hidden_states): + hidden_states = F.layer_norm(hidden_states, self.normalized_shape, None, None, eps=1e-5) + hidden_states = hidden_states * self.repeat_param(self.weight) + self.repeat_param(self.bias) + return hidden_states class GQA(nn.Module): """ @@ -397,6 +421,7 @@ class GQA(nn.Module): dtype (Optional[torch.dtype]): The type of data. qk_interleaved (Optional[bool]): whether the odd and even columns of wq and wk is interleaved. True by default. enable_qkv_fusion (bool): whether wq, wk and wv lienar is fused. True by default. + qk_norm (Optional[bool]): if set, the query and key will be applied by layer norm after qk_linear. False by default. """ def __init__( @@ -419,6 +444,8 @@ def __init__( dtype: Optional[torch.dtype] = None, qk_interleaved: Optional[bool] = True, enable_qkv_fusion: bool = True, + qk_norm: bool = False, + chameleon_mp_size: int = 1, ) -> None: super().__init__() self.layer_idx = layer_idx @@ -459,6 +486,10 @@ def __init__( rotary_type="dynamic_ntk" if self.use_dynamic_ntk_rope else "native", ) + self.qk_norm = qk_norm + if qk_norm: + assert enable_qkv_fusion is False, "qk_norm cannot be applied when fused wqkv" + if enable_qkv_fusion: assert bias is False, "Fuesd wqkv only support bias is False." self.wqkv = new_linear("wqkv", embed_dim, q_dim + 2 * self.kv_dim, bias, **factory_kwargs) @@ -470,6 +501,11 @@ def __init__( self.wq = new_linear("wq", embed_dim, q_dim, bias, **factory_kwargs) self.wk = new_linear("wk", embed_dim, self.kv_dim, bias, **factory_kwargs) self.wv = new_linear("wv", embed_dim, self.kv_dim, bias, **factory_kwargs) + if qk_norm: + assert num_heads%chameleon_mp_size == 0, "num_heads%chameleon_mp_size != 0 in GQA" + assert num_kv_heads%chameleon_mp_size == 0, "num_kv_heads%chameleon_mp_size != 0 in GQA" + self.q_norm = ChameleonLayerNorm(self.head_dim,chameleon_mp_size,num_heads//chameleon_mp_size) + self.k_norm = ChameleonLayerNorm(self.head_dim,chameleon_mp_size,num_kv_heads//chameleon_mp_size) self.inner_attn = SelfAttention( causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout, layer_idx=layer_idx @@ -508,10 +544,21 @@ def _training(self, x, **kwargs): q = rearrange(q, "b s h gs d -> b s (h gs) d") else: q, k, v = self.wq(x), self.wk(x), self.wv(x) - q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim) - k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim) - v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim) - + if self.qk_norm: + q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim) + k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim) + q_all = gather_forward_split_backward(q, ParallelMode.TENSOR, dim=-2) + q_norm_out = self.q_norm(q_all) + q = split_forward_gather_backward(q_norm_out, ParallelMode.TENSOR, dim=-2) + k_all = gather_forward_split_backward(k, ParallelMode.TENSOR, dim=-2) + k_norm_out = self.k_norm(k_all) + k = split_forward_gather_backward(k_norm_out, ParallelMode.TENSOR, dim=-2) + + v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim) + else: + q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim) + k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim) + v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim) kwargs = _convert_cu_seqlens_for_qksplited(kwargs) # rotary embedding @@ -584,9 +631,22 @@ def _inference(self, x, inference_params, **kwargs): # pylint: disable=W0613 q = rearrange(q, "b s h gs d -> b s (h gs) d") else: q, k, v = self.wq(x), self.wk(x), self.wv(x) - q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim) - k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim) - v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim) + if self.qk_norm: + q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim) + k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim) + # TODO: using repeat + (fwd: split, bwd: allgather or allreducesum) or (fwd: split + repeat, bwd: allreducesum) is better + q_all = gather_forward_split_backward(q, ParallelMode.TENSOR, dim=-2) + q_norm_out = self.q_norm(q_all) + q = split_forward_gather_backward(q_norm_out, ParallelMode.TENSOR, dim=-2) + k_all = gather_forward_split_backward(k, ParallelMode.TENSOR, dim=-2) + k_norm_out = self.k_norm(k_all) + k = split_forward_gather_backward(k_norm_out, ParallelMode.TENSOR, dim=-2) + + v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim) + else: + q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim) + k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim) + v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim) # rotary embedding, output: q, kv assert self.rotary_emb_dim > 0 diff --git a/internlm/model/modules/norm.py b/internlm/model/modules/norm.py index 2a9700f8..3aaf356d 100644 --- a/internlm/model/modules/norm.py +++ b/internlm/model/modules/norm.py @@ -13,11 +13,11 @@ Shape = Union[int, List[int], torch.Size] -def new_layer_norm(norm_type: str, normalized_shape: Shape, eps: float = 1e-5, add_unit_offset=False): +def new_layer_norm(norm_type: str, normalized_shape: Shape, eps: float = 1e-5, add_unit_offset=False, is_Chameleon=False): if norm_type == "rmsnorm": rmsnorm_params = inspect.signature(RMSNorm).parameters if "add_unit_offset" in rmsnorm_params: - return RMSNorm(normalized_shape, eps, add_unit_offset) + return RMSNorm(normalized_shape, eps, add_unit_offset, is_Chameleon) else: return RMSNorm(normalized_shape, eps) else: # default: layernorm diff --git a/internlm/model/ops/attention.py b/internlm/model/ops/attention.py index d0a668c8..fd503668 100644 --- a/internlm/model/ops/attention.py +++ b/internlm/model/ops/attention.py @@ -944,7 +944,7 @@ def _q_kv_without_cu_seqlens(self, q, kv, softmax_scale=None, causal=None, key_p attn_type, op = _select_attn_op(AttnOpType.FixedLenKVPacked) dropout = self.dropout if attn_type is AttnType.Torch else self.dropout.p - extra_args = (key_padding_mask) if attn_type is AttnType.Torch else () + extra_args = (key_padding_mask,) if attn_type is AttnType.Torch else () extra_kwargs = {} if attn_type is AttnType.SlidingWindowZigZagFlash: @@ -1007,7 +1007,7 @@ def _q_kv_with_cu_seqlens( attn_type, op = _select_attn_op(AttnOpType.VarLenKVPacked) dropout = self.dropout if attn_type is AttnType.Torch else self.dropout.p - extra_args = (key_padding_mask) if attn_type is AttnType.Torch else () + extra_args = (key_padding_mask,) if attn_type is AttnType.Torch else () return op( q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout, softmax_scale, causal, *extra_args @@ -1088,7 +1088,7 @@ def _q_kv_without_cu_seqlens(self, q, kv, softmax_scale=None, causal=None, key_p attn_type, op = _select_attn_op(AttnOpType.FixedLenKVPacked) dropout = self.dropout if attn_type is AttnType.Torch else self.dropout.p - extra_args = (key_padding_mask) if attn_type is AttnType.Torch else () + extra_args = (key_padding_mask, ) if attn_type is AttnType.Torch else () return op(q, kv, dropout, softmax_scale, causal, *extra_args) diff --git a/internlm/model/ops/norm.py b/internlm/model/ops/norm.py index 8565db4c..c7dade0e 100644 --- a/internlm/model/ops/norm.py +++ b/internlm/model/ops/norm.py @@ -35,8 +35,9 @@ torchnpu_rmsnorm_impl = False -def manual_rms_norm(my_input, weight, normalized_shape, eps, add_unit_offset=False): +def manual_rms_norm(my_input, weight, normalized_shape, eps, add_unit_offset=False, is_Chameleon=False): # layer norm should always be calculated in float32 + input_dtype = my_input.dtype dims = tuple(i for i in range(-1, -len(normalized_shape) - 1, -1)) variance = my_input.to(torch.float32).pow(2).mean(dims, keepdim=True) my_input = my_input * torch.rsqrt(variance + eps) @@ -44,20 +45,23 @@ def manual_rms_norm(my_input, weight, normalized_shape, eps, add_unit_offset=Fal if weight is None: return my_input - # convert into half-precision if necessary - if weight.dtype in [torch.float16, torch.bfloat16]: - my_input = my_input.to(weight.dtype) - - if add_unit_offset: - return (1 + weight) * my_input + if is_Chameleon: + return weight * my_input.to(input_dtype) else: - return weight * my_input + # convert into half-precision if necessary + if weight.dtype in [torch.float16, torch.bfloat16]: + my_input = my_input.to(weight.dtype) + + if add_unit_offset: + return (1 + weight) * my_input + else: + return weight * my_input class _RMSNorm(torch.nn.Module): """A generic module for RMS normalization.""" - def __init__(self, normalized_shape, eps=1e-5, add_unit_offset=False): + def __init__(self, normalized_shape, eps=1e-5, add_unit_offset=False, is_Chameleon=False): super().__init__() if isinstance(normalized_shape, numbers.Integral): @@ -67,6 +71,7 @@ def __init__(self, normalized_shape, eps=1e-5, add_unit_offset=False): self.weight = Parameter(torch.empty(*normalized_shape)) self.add_unit_offset = add_unit_offset self.reset_parameters() + self.is_Chameleon = is_Chameleon def forward(self, _input: torch.Tensor): if apex_rmsnorm_impl: @@ -74,7 +79,7 @@ def forward(self, _input: torch.Tensor): return _norm_func(_input, self.weight, self.normalized_shape, self.eps) else: _norm_func = manual_rms_norm - return _norm_func(_input, self.weight, self.normalized_shape, self.eps, self.add_unit_offset) + return _norm_func(_input, self.weight, self.normalized_shape, self.eps, self.add_unit_offset, self.is_Chameleon) def reset_parameters(self): if self.add_unit_offset: diff --git a/internlm/model/registry.py b/internlm/model/registry.py index cd4f38fb..2c01fb6e 100644 --- a/internlm/model/registry.py +++ b/internlm/model/registry.py @@ -4,6 +4,7 @@ from typing import Callable from internlm.model.modeling_baichuan2 import Baichuan2 +from internlm.model.modeling_chameleon import ChameleonModel from internlm.model.modeling_gemma import Gemma from internlm.model.modeling_internlm import InternLM1 from internlm.model.modeling_internlm2 import InternLM2 @@ -93,6 +94,7 @@ def register_model_initializer() -> None: model_initializer.register_module(ModelType.GEMMA.name, Gemma) model_initializer.register_module(ModelType.QWEN2MOE.name, Qwen2Moe) model_initializer.register_module(ModelType.MIXTRALMOE.name, MixtralMoE) + model_initializer.register_module(ModelType.CHAMELEON.name, ChameleonModel) register_model_initializer() From 910032b44bee37eae7a8c71a4ec6a06ee78689f6 Mon Sep 17 00:00:00 2001 From: zhhsplendid Date: Thu, 28 Nov 2024 18:13:38 +0800 Subject: [PATCH 2/5] Modify based on pre-commit --- internlm/model/modeling_chameleon.py | 176 +++++---------------------- internlm/model/modules/mha.py | 20 +-- internlm/model/modules/norm.py | 4 +- internlm/model/ops/attention.py | 2 +- internlm/model/ops/norm.py | 4 +- 5 files changed, 52 insertions(+), 154 deletions(-) diff --git a/internlm/model/modeling_chameleon.py b/internlm/model/modeling_chameleon.py index 8797ce07..9958aa99 100644 --- a/internlm/model/modeling_chameleon.py +++ b/internlm/model/modeling_chameleon.py @@ -5,12 +5,8 @@ import torch from torch import nn -from torch.nn import CrossEntropyLoss from tqdm import tqdm -# Should re-implement CausalLMOutputWithPast? -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast - from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode from internlm.core.context.parallel_context import global_context as gpc @@ -68,7 +64,6 @@ class ChameleonDecoderLayer(nn.Module): no_bias (bool): Whether to exclude bias in attention and feed-forward networks. Defaults to False. norm_type (str): Use RMS norm or layernorm."rmsnorm" by default. qk_interleaved (bool): Whether the odd and even columns of the wq and wk are normally interleaved. - dropout_selective_checkpoint (bool): Whether to selectively checkpoint dropout layers only. use_scaled_init (bool): Whether to use scaled initialization for weights. use_swiglu (bool): Whether to use SwiGLU activation in the mlp module. attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, @@ -102,7 +97,6 @@ def __init__( no_bias: bool = False, norm_type: str = "rmsnorm", qk_interleaved: bool = False, - dropout_selective_checkpoint: bool = True, use_scaled_init: bool = True, use_swiglu: bool = True, attn_wqkv_init_std: float = 0.02, @@ -113,13 +107,11 @@ def __init__( rope_base: int = 10000, mlp_layer_fusion: bool = False, multiple_of: int = 256, - qk_norm = True, - chameleon_mp_size = 1, + qk_norm=True, + chameleon_mp_size=1, ): super().__init__() self.checkpoint = checkpoint - # dropout selective checkpoint can only be enabled when checkpoint is disabled. - self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False self.layer_idx = layer_idx self.prenorm = not apply_post_layer_norm assert not fused_dropout_add_ln, "dropout_add_layer_norm can not be used here" @@ -152,8 +144,12 @@ def __init__( ) self.dropout = nn.Dropout(drop_rate) - self.attention_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon, add_unit_offset=False, is_Chameleon=True) - self.ffn_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon, add_unit_offset=False, is_Chameleon=True) + self.attention_norm = new_layer_norm( + norm_type, hidden_size, eps=layer_norm_epsilon, add_unit_offset=False, is_Chameleon=True + ) + self.ffn_norm = new_layer_norm( + norm_type, hidden_size, eps=layer_norm_epsilon, add_unit_offset=False, is_Chameleon=True + ) self.feed_forward = new_feed_forward( hidden_size, @@ -234,20 +230,8 @@ def _forward(self, hidden_states, residual, *args, **kwargs): hidden_states = self.attention_norm(hidden_states) - # def _dropout_and_norm_attn(_residual, _hidden_states): - # _dropped = self.dropout1(_hidden_states) - # _residual = (_dropped + _residual) if _residual is not None else _dropped - # _hidden_states = self.attention_norm(_residual.to(dtype=self.attention_norm.weight.dtype)) - - # return _residual, _hidden_states - - # if self.dropout_selective_checkpoint: - # residual, hidden_states = activation_checkpoint(_dropout_and_norm_attn, False, residual, hidden_states) - # else: - # residual, hidden_states = _dropout_and_norm_attn(residual, hidden_states) - - # if self.residual_in_fp32: - # residual = residual.to(torch.float32) + if self.residual_in_fp32: + residual = residual.to(torch.float32) attn_kwargs = convert_attn_args_to_kwargs(args, kwargs) hidden_states = self.attention(hidden_states, **attn_kwargs) @@ -279,7 +263,7 @@ def _forward(self, hidden_states, residual, *args, **kwargs): class ChameleonImageVocabularyMapping: """ A class for mapping discrete image tokens from VQGAN to BPE tokens. - + Reference: https://github.com/Alpha-VLLM/Lumina-mGPT/blob/104abe453ec1acca5863698629c4db2111b0b3fc/lumina_mgpt/model/chameleon/modeling_chameleon.py#L1036 """ @@ -326,7 +310,6 @@ def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor: return img_tokens.to(device) - class ChameleonModel(BaseModel): """ Chameleon Model. @@ -356,7 +339,6 @@ class ChameleonModel(BaseModel): residual_in_fp32 (bool): Whether to use residual in fp32. False by default. norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. qk_interleaved (bool): Whether the odd and even columns of the wq and wk are normally interleaved. - dropout_selective_checkpoint (bool): Whether to selectively checkpoint dropout and norm layers. use_scaled_init (bool): Whether to use scaled initialization for weights. use_swiglu (bool): Whether to use SwiGLU activation in the mlp module. embedding_init_std (float): std used to init embedding weight. 0.02 by default, @@ -398,7 +380,6 @@ def __init__( norm_type: str = "rmsnorm", qk_interleaved: bool = False, is_reward: bool = False, - dropout_selective_checkpoint: bool = True, use_scaled_init: bool = True, use_swiglu: bool = True, embedding_init_std: float = 0.02, @@ -411,8 +392,8 @@ def __init__( rope_base: int = 10000, mlp_layer_fusion: bool = False, multiple_of: int = 256, - qk_norm = True, - chameleon_mp_size = 1, + qk_norm=True, + chameleon_mp_size=1, ): super().__init__() @@ -422,10 +403,11 @@ def __init__( if chameleon_mp_size == 4: apply_post_layer_norm = True - if first: self.padding_idx = None - self.tok_embeddings = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size, padding_idx=self.padding_idx) + self.tok_embeddings = Embedding1D( + num_embeddings=vocab_size, embedding_dim=hidden_size, padding_idx=self.padding_idx + ) for _, param in self.tok_embeddings.named_parameters(): if init_type == "normal": normal_(std=embedding_init_std)(param) @@ -451,7 +433,6 @@ def __init__( fused_dropout_add_ln=False, no_bias=no_bias, norm_type=norm_type, - dropout_selective_checkpoint=dropout_selective_checkpoint, use_scaled_init=use_scaled_init, use_swiglu=use_swiglu, qk_interleaved=qk_interleaved, @@ -463,7 +444,7 @@ def __init__( rope_base=rope_base, mlp_layer_fusion=mlp_layer_fusion, multiple_of=multiple_of, - qk_norm = qk_norm, + qk_norm=qk_norm, chameleon_mp_size=chameleon_mp_size, ) for lid in range(num_layers) @@ -574,15 +555,11 @@ def load_hf_weights(folder: str, model: nn.Module): state_dict[f"layers.{i}.attention.q_norm.weight"] = state_dict.pop( f"model.layers.{i}.self_attn.q_norm.weight" ) - state_dict[f"layers.{i}.attention.q_norm.bias"] = state_dict.pop( - f"model.layers.{i}.self_attn.q_norm.bias" - ) + state_dict[f"layers.{i}.attention.q_norm.bias"] = state_dict.pop(f"model.layers.{i}.self_attn.q_norm.bias") state_dict[f"layers.{i}.attention.k_norm.weight"] = state_dict.pop( f"model.layers.{i}.self_attn.k_norm.weight" ) - state_dict[f"layers.{i}.attention.k_norm.bias"] = state_dict.pop( - f"model.layers.{i}.self_attn.k_norm.bias" - ) + state_dict[f"layers.{i}.attention.k_norm.bias"] = state_dict.pop(f"model.layers.{i}.self_attn.k_norm.bias") # ffn state_dict[f"layers.{i}.feed_forward.w1.weight"] = torch.chunk( @@ -619,7 +596,7 @@ def load_hf_weights(folder: str, model: nn.Module): # replace value within decoder layer for name in list(state_dict.keys()): - if name.startswith(f"model.vqmodel"): + if name.startswith("model.vqmodel"): state_dict.pop(name) if name.startswith(f"layers.{i}"): new_state_dict[name.replace(f".{i}.", f".{idx}.")] = state_dict.pop(name) @@ -751,10 +728,18 @@ def convert_internevo2hf_weights(src: str, tgt: str) -> None: ) state_dict.update( { - f"model.layers.{layer_i}.self_attn.q_norm.weight" : states[0][f"layers.{layer_i}.attention.q_norm.weight"].clone(), - f"model.layers.{layer_i}.self_attn.q_norm.bias" : states[0][f"layers.{layer_i}.attention.q_norm.bias"].clone(), - f"model.layers.{layer_i}.self_attn.k_norm.weight" : states[0][f"layers.{layer_i}.attention.k_norm.weight"].clone(), - f"model.layers.{layer_i}.self_attn.k_norm.bias" : states[0][f"layers.{layer_i}.attention.k_norm.bias"].clone(), + f"model.layers.{layer_i}.self_attn.q_norm.weight": states[0][ + f"layers.{layer_i}.attention.q_norm.weight" + ].clone(), + f"model.layers.{layer_i}.self_attn.q_norm.bias": states[0][ + f"layers.{layer_i}.attention.q_norm.bias" + ].clone(), + f"model.layers.{layer_i}.self_attn.k_norm.weight": states[0][ + f"layers.{layer_i}.attention.k_norm.weight" + ].clone(), + f"model.layers.{layer_i}.self_attn.k_norm.bias": states[0][ + f"layers.{layer_i}.attention.k_norm.bias" + ].clone(), } ) @@ -785,100 +770,3 @@ def convert_internevo2hf_weights(src: str, tgt: str) -> None: llm_save(save_path=os.path.join(tgt, shard_file), saved_obj=shard, metadata={"format": "pt"}) if index is not None: llm_save(save_path=os.path.join(tgt, SAFE_WEIGHTS_INDEX_NAME), saved_obj=index) -''' - -class ChameleonForConditionalGeneration(BaseModel): - def __init__(self, - max_position_embeddings: int, - output_attentions: bool, - output_hidden_states: bool, - return_dict: bool, - mask_image_logits: bool): - self.max_position_embeddings = max_position_embeddings - - self.model = ChameleonModel() - - self.output_attentions = output_attentions - self.output_hidden_states = output_hidden_states - self.return_dict = return_dict - self.mask_image_logits = mask_image_logits - - def forward(self, - input_ids=None, - labels=None, - pixel_values: torch.FloatTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs): - - - - - # Data to torch.tensor - max_tokens = max([len(_) for _ in input_ids]) - max_tokens = min(max_tokens, self.max_position_embeddings) - input_ids = [_[:max_tokens] for _ in input_ids] - labels = [_[:max_tokens] for _ in labels] - input_ids = [example + [0] * (max_tokens - len(example)) for example in input_ids] - input_ids = torch.tensor(input_ids, dtype=torch.int64, device=self.device) - labels = [label + [-100] * (max_tokens - len(label)) for label in labels] - labels = torch.tensor(labels, dtype=torch.int64, device=self.device) - - - outputs = self.model( - input_ids=input_ids, - pixel_values=pixel_values, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - ) - - hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - logits = logits.float() - - if self.mask_image_logits: - # Disallow image tokens which does not include special begin-image and end-image tokens - image_tokens = self.model.vocabulary_mapping.image_tokens - logits[:, :, image_tokens] = torch.finfo(logits.dtype).min - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not self.return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - -''' - diff --git a/internlm/model/modules/mha.py b/internlm/model/modules/mha.py index 15e8eede..0094084d 100644 --- a/internlm/model/modules/mha.py +++ b/internlm/model/modules/mha.py @@ -12,7 +12,10 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.core.parallel.comm.utils import gather_forward_split_backward, split_forward_gather_backward +from internlm.core.parallel.comm.utils import ( + gather_forward_split_backward, + split_forward_gather_backward, +) from internlm.model.modules.embedding import new_rotary_embedding from internlm.model.modules.linear import new_linear from internlm.model.modules.utils import update_kv_cache @@ -374,6 +377,7 @@ def _inference(self, x, inference_params, **kwargs): # pylint: disable=W0613 # wo return self.out_proj(rearrange(context, "b s h d -> b s (h d)")) + class ChameleonLayerNorm(nn.LayerNorm): """ LayerNorm but computes stats only over the last dim because Chameleon applies gamma and beta @@ -397,6 +401,7 @@ def forward(self, hidden_states): hidden_states = hidden_states * self.repeat_param(self.weight) + self.repeat_param(self.bias) return hidden_states + class GQA(nn.Module): """ Multi-head self-attention and cross-attention. @@ -421,7 +426,8 @@ class GQA(nn.Module): dtype (Optional[torch.dtype]): The type of data. qk_interleaved (Optional[bool]): whether the odd and even columns of wq and wk is interleaved. True by default. enable_qkv_fusion (bool): whether wq, wk and wv lienar is fused. True by default. - qk_norm (Optional[bool]): if set, the query and key will be applied by layer norm after qk_linear. False by default. + qk_norm (Optional[bool]): if set, the query and key will be applied by layer norm after qk_linear. + False by default. """ def __init__( @@ -502,10 +508,10 @@ def __init__( self.wk = new_linear("wk", embed_dim, self.kv_dim, bias, **factory_kwargs) self.wv = new_linear("wv", embed_dim, self.kv_dim, bias, **factory_kwargs) if qk_norm: - assert num_heads%chameleon_mp_size == 0, "num_heads%chameleon_mp_size != 0 in GQA" - assert num_kv_heads%chameleon_mp_size == 0, "num_kv_heads%chameleon_mp_size != 0 in GQA" - self.q_norm = ChameleonLayerNorm(self.head_dim,chameleon_mp_size,num_heads//chameleon_mp_size) - self.k_norm = ChameleonLayerNorm(self.head_dim,chameleon_mp_size,num_kv_heads//chameleon_mp_size) + assert num_heads % chameleon_mp_size == 0, "num_heads%chameleon_mp_size != 0 in GQA" + assert num_kv_heads % chameleon_mp_size == 0, "num_kv_heads%chameleon_mp_size != 0 in GQA" + self.q_norm = ChameleonLayerNorm(self.head_dim, chameleon_mp_size, num_heads // chameleon_mp_size) + self.k_norm = ChameleonLayerNorm(self.head_dim, chameleon_mp_size, num_kv_heads // chameleon_mp_size) self.inner_attn = SelfAttention( causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout, layer_idx=layer_idx @@ -634,7 +640,7 @@ def _inference(self, x, inference_params, **kwargs): # pylint: disable=W0613 if self.qk_norm: q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim) k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim) - # TODO: using repeat + (fwd: split, bwd: allgather or allreducesum) or (fwd: split + repeat, bwd: allreducesum) is better + q_all = gather_forward_split_backward(q, ParallelMode.TENSOR, dim=-2) q_norm_out = self.q_norm(q_all) q = split_forward_gather_backward(q_norm_out, ParallelMode.TENSOR, dim=-2) diff --git a/internlm/model/modules/norm.py b/internlm/model/modules/norm.py index 3aaf356d..41c75822 100644 --- a/internlm/model/modules/norm.py +++ b/internlm/model/modules/norm.py @@ -13,7 +13,9 @@ Shape = Union[int, List[int], torch.Size] -def new_layer_norm(norm_type: str, normalized_shape: Shape, eps: float = 1e-5, add_unit_offset=False, is_Chameleon=False): +def new_layer_norm( + norm_type: str, normalized_shape: Shape, eps: float = 1e-5, add_unit_offset=False, is_Chameleon=False +): if norm_type == "rmsnorm": rmsnorm_params = inspect.signature(RMSNorm).parameters if "add_unit_offset" in rmsnorm_params: diff --git a/internlm/model/ops/attention.py b/internlm/model/ops/attention.py index fd503668..f888070a 100644 --- a/internlm/model/ops/attention.py +++ b/internlm/model/ops/attention.py @@ -1088,7 +1088,7 @@ def _q_kv_without_cu_seqlens(self, q, kv, softmax_scale=None, causal=None, key_p attn_type, op = _select_attn_op(AttnOpType.FixedLenKVPacked) dropout = self.dropout if attn_type is AttnType.Torch else self.dropout.p - extra_args = (key_padding_mask, ) if attn_type is AttnType.Torch else () + extra_args = (key_padding_mask,) if attn_type is AttnType.Torch else () return op(q, kv, dropout, softmax_scale, causal, *extra_args) diff --git a/internlm/model/ops/norm.py b/internlm/model/ops/norm.py index c7dade0e..d3dd5215 100644 --- a/internlm/model/ops/norm.py +++ b/internlm/model/ops/norm.py @@ -79,7 +79,9 @@ def forward(self, _input: torch.Tensor): return _norm_func(_input, self.weight, self.normalized_shape, self.eps) else: _norm_func = manual_rms_norm - return _norm_func(_input, self.weight, self.normalized_shape, self.eps, self.add_unit_offset, self.is_Chameleon) + return _norm_func( + _input, self.weight, self.normalized_shape, self.eps, self.add_unit_offset, self.is_Chameleon + ) def reset_parameters(self): if self.add_unit_offset: From 841ed0a4f8abd0dbfed0d72ca034188e2ef7c1ad Mon Sep 17 00:00:00 2001 From: zhhsplendid Date: Thu, 28 Nov 2024 18:34:13 +0800 Subject: [PATCH 3/5] Modify norm more reuseable --- internlm/model/modeling_chameleon.py | 4 ++-- internlm/model/modules/norm.py | 4 ++-- internlm/model/ops/norm.py | 27 +++++++++++++-------------- 3 files changed, 17 insertions(+), 18 deletions(-) diff --git a/internlm/model/modeling_chameleon.py b/internlm/model/modeling_chameleon.py index 9958aa99..851fbc2e 100644 --- a/internlm/model/modeling_chameleon.py +++ b/internlm/model/modeling_chameleon.py @@ -145,10 +145,10 @@ def __init__( self.dropout = nn.Dropout(drop_rate) self.attention_norm = new_layer_norm( - norm_type, hidden_size, eps=layer_norm_epsilon, add_unit_offset=False, is_Chameleon=True + norm_type, hidden_size, eps=layer_norm_epsilon, add_unit_offset=False, convert_to_input_dtype=True ) self.ffn_norm = new_layer_norm( - norm_type, hidden_size, eps=layer_norm_epsilon, add_unit_offset=False, is_Chameleon=True + norm_type, hidden_size, eps=layer_norm_epsilon, add_unit_offset=False, convert_to_input_dtype=True ) self.feed_forward = new_feed_forward( diff --git a/internlm/model/modules/norm.py b/internlm/model/modules/norm.py index 41c75822..7b24d67f 100644 --- a/internlm/model/modules/norm.py +++ b/internlm/model/modules/norm.py @@ -14,12 +14,12 @@ def new_layer_norm( - norm_type: str, normalized_shape: Shape, eps: float = 1e-5, add_unit_offset=False, is_Chameleon=False + norm_type: str, normalized_shape: Shape, eps: float = 1e-5, add_unit_offset=False, convert_to_input_dtype=False ): if norm_type == "rmsnorm": rmsnorm_params = inspect.signature(RMSNorm).parameters if "add_unit_offset" in rmsnorm_params: - return RMSNorm(normalized_shape, eps, add_unit_offset, is_Chameleon) + return RMSNorm(normalized_shape, eps, add_unit_offset, convert_to_input_dtype) else: return RMSNorm(normalized_shape, eps) else: # default: layernorm diff --git a/internlm/model/ops/norm.py b/internlm/model/ops/norm.py index d3dd5215..ea51ebe9 100644 --- a/internlm/model/ops/norm.py +++ b/internlm/model/ops/norm.py @@ -35,9 +35,8 @@ torchnpu_rmsnorm_impl = False -def manual_rms_norm(my_input, weight, normalized_shape, eps, add_unit_offset=False, is_Chameleon=False): +def manual_rms_norm(my_input, weight, normalized_shape, eps, add_unit_offset=False, convert_to_input_dtype=False): # layer norm should always be calculated in float32 - input_dtype = my_input.dtype dims = tuple(i for i in range(-1, -len(normalized_shape) - 1, -1)) variance = my_input.to(torch.float32).pow(2).mean(dims, keepdim=True) my_input = my_input * torch.rsqrt(variance + eps) @@ -45,23 +44,23 @@ def manual_rms_norm(my_input, weight, normalized_shape, eps, add_unit_offset=Fal if weight is None: return my_input - if is_Chameleon: - return weight * my_input.to(input_dtype) - else: + if convert_to_input_dtype: + input_dtype = my_input.dtype + my_input = my_input.to(input_dtype) + elif weight.dtype in [torch.float16, torch.bfloat16]: # convert into half-precision if necessary - if weight.dtype in [torch.float16, torch.bfloat16]: - my_input = my_input.to(weight.dtype) + my_input = my_input.to(weight.dtype) - if add_unit_offset: - return (1 + weight) * my_input - else: - return weight * my_input + if add_unit_offset: + return (1 + weight) * my_input + else: + return weight * my_input class _RMSNorm(torch.nn.Module): """A generic module for RMS normalization.""" - def __init__(self, normalized_shape, eps=1e-5, add_unit_offset=False, is_Chameleon=False): + def __init__(self, normalized_shape, eps=1e-5, add_unit_offset=False, convert_to_input_dtype=False): super().__init__() if isinstance(normalized_shape, numbers.Integral): @@ -71,7 +70,7 @@ def __init__(self, normalized_shape, eps=1e-5, add_unit_offset=False, is_Chamele self.weight = Parameter(torch.empty(*normalized_shape)) self.add_unit_offset = add_unit_offset self.reset_parameters() - self.is_Chameleon = is_Chameleon + self.convert_to_input_dtype = convert_to_input_dtype def forward(self, _input: torch.Tensor): if apex_rmsnorm_impl: @@ -80,7 +79,7 @@ def forward(self, _input: torch.Tensor): else: _norm_func = manual_rms_norm return _norm_func( - _input, self.weight, self.normalized_shape, self.eps, self.add_unit_offset, self.is_Chameleon + _input, self.weight, self.normalized_shape, self.eps, self.add_unit_offset, self.convert_to_input_dtype ) def reset_parameters(self): From 99a3c580abdc4ddc2c5eabc597ed1c8553096b70 Mon Sep 17 00:00:00 2001 From: zhhsplendid Date: Fri, 29 Nov 2024 05:26:03 +0000 Subject: [PATCH 4/5] Fix model name --- internlm/utils/utils.py | 2 ++ tests/test_model/test_norm.py | 23 ++++++++++++++++++++--- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/internlm/utils/utils.py b/internlm/utils/utils.py index d01d13f9..876169b7 100644 --- a/internlm/utils/utils.py +++ b/internlm/utils/utils.py @@ -56,6 +56,7 @@ class ModelType(Enum): GEMMA = 8 QWEN2MOE = 9 MIXTRALMOE = 10 + CHAMELEON = 11 class DataType(Enum): @@ -63,6 +64,7 @@ class DataType(Enum): tokenized = 2 megatron = 3 mocked = 4 + lumina_pickle = 5 class TensorParallelMode(Enum): diff --git a/tests/test_model/test_norm.py b/tests/test_model/test_norm.py index 83861b36..0e7fc8f5 100644 --- a/tests/test_model/test_norm.py +++ b/tests/test_model/test_norm.py @@ -11,7 +11,7 @@ def check_norm(args): # init - rank, world_size, free_port = args + rank, world_size, free_port, convert_to_input_dtype = args build_environment(rank, world_size, free_port) device = get_current_device() rtol, atol = (1e-3, 5e-3) @@ -22,7 +22,12 @@ def check_norm(args): seed_all(1024) # define norm - norm = new_layer_norm(norm_type="rmsnorm", normalized_shape=hidden_size, eps=layer_norm_epsilon) + norm = new_layer_norm( + norm_type="rmsnorm", + normalized_shape=hidden_size, + eps=layer_norm_epsilon, + convert_to_input_dtype=convert_to_input_dtype, + ) norm = norm.to(device) # create input @@ -83,8 +88,20 @@ def check_norm(args): def test_norm(): ctx = mp.get_context("spawn") free_port = str(find_free_port()) + convert_input_dtype = False with ctx.Pool(processes=8) as pool: - pool.map(check_norm, [[rank, 8, free_port] for rank in range(8)]) + pool.map(check_norm, [[rank, 8, free_port, convert_input_dtype] for rank in range(8)]) + pool.close() + pool.join() + + +@pytest.mark.norm_convert_to_input_dtype +def test_norm_convert_to_input_dtype(): + ctx = mp.get_context("spawn") + free_port = str(find_free_port()) + convert_input_dtype = True + with ctx.Pool(processes=8) as pool: + pool.map(check_norm, [[rank, 8, free_port, convert_input_dtype] for rank in range(8)]) pool.close() pool.join() From 692a97aa2a1fa2ed3c357a4622fd6c06579305b3 Mon Sep 17 00:00:00 2001 From: zhhsplendid Date: Mon, 2 Dec 2024 09:11:22 +0000 Subject: [PATCH 5/5] Modify based on reviewer's comment 1 --- internlm/model/modeling_chameleon.py | 20 +++++++++++++------- internlm/model/ops/norm.py | 2 +- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/internlm/model/modeling_chameleon.py b/internlm/model/modeling_chameleon.py index 851fbc2e..d89e21d2 100644 --- a/internlm/model/modeling_chameleon.py +++ b/internlm/model/modeling_chameleon.py @@ -75,7 +75,10 @@ class ChameleonDecoderLayer(nn.Module): rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. mlp_layer_fusion (bool): Whether to fuse layers in the mlp module for optimization. multiple_of (int): Ensures mlp dimensions are multiples of this value for efficient hardware utilization. - qk_norm (bool): Support q norm and k norm. + qk_norm (bool): Whether supports q norm and k norm. + chameleon_mp_size (int): For ChameleonLayerNorm, it applies gamma and beta from each shard separately + to each head, instead of reducing. The chameleon_mp_size means the number + of groups of ChameleonLayerNorm headers. It is 1 in 7B model and 4 in 34B model. """ def __init__( @@ -107,8 +110,8 @@ def __init__( rope_base: int = 10000, mlp_layer_fusion: bool = False, multiple_of: int = 256, - qk_norm=True, - chameleon_mp_size=1, + qk_norm: bool = True, + chameleon_mp_size: int = 1, ): super().__init__() self.checkpoint = checkpoint @@ -161,7 +164,7 @@ def __init__( mlp_layer_fusion=mlp_layer_fusion, multiple_of=multiple_of, # TODO: to support more activation functions - activation_type="swiglu" if use_swiglu else "swiglu", + activation_type="swiglu" if use_swiglu else "gelu", ) self.use_swiglu = use_swiglu @@ -352,7 +355,10 @@ class ChameleonModel(BaseModel): rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. mlp_layer_fusion (bool): Whether to fuse layers in the mlp module for optimization. multiple_of (int): Ensures mlp dimensions are multiples of this value for efficient hardware utilization. - qk_norm (bool): Support q norm and k norm. + qk_norm (bool): Whether supports q norm and k norm. + chameleon_mp_size (int): For ChameleonLayerNorm, it applies gamma and beta from each shard separately + to each head, instead of reducing. The chameleon_mp_size means the number + of groups of ChameleonLayerNorm headers. It is 1 in 7B model and 4 in 34B model. """ def __init__( @@ -392,8 +398,8 @@ def __init__( rope_base: int = 10000, mlp_layer_fusion: bool = False, multiple_of: int = 256, - qk_norm=True, - chameleon_mp_size=1, + qk_norm: bool = True, + chameleon_mp_size: int = 1, ): super().__init__() diff --git a/internlm/model/ops/norm.py b/internlm/model/ops/norm.py index ea51ebe9..e0113735 100644 --- a/internlm/model/ops/norm.py +++ b/internlm/model/ops/norm.py @@ -37,6 +37,7 @@ def manual_rms_norm(my_input, weight, normalized_shape, eps, add_unit_offset=False, convert_to_input_dtype=False): # layer norm should always be calculated in float32 + input_dtype = my_input.dtype dims = tuple(i for i in range(-1, -len(normalized_shape) - 1, -1)) variance = my_input.to(torch.float32).pow(2).mean(dims, keepdim=True) my_input = my_input * torch.rsqrt(variance + eps) @@ -45,7 +46,6 @@ def manual_rms_norm(my_input, weight, normalized_shape, eps, add_unit_offset=Fal return my_input if convert_to_input_dtype: - input_dtype = my_input.dtype my_input = my_input.to(input_dtype) elif weight.dtype in [torch.float16, torch.bfloat16]: # convert into half-precision if necessary