From 6dfdb34d5b3569739d319c73690022f6ba723e84 Mon Sep 17 00:00:00 2001 From: cx <759046501@qq.com> Date: Fri, 10 May 2024 16:28:36 +0800 Subject: [PATCH] refactor(model): refactor model architecture (#126) Co-authored-by: lijiaxing Co-authored-by: huangting4201 <1538303371@qq.com> --- configs/7B_MoE4_sft.py | 8 + configs/7B_internlm2.py | 8 + configs/7B_isp_sft.py | 8 + configs/7B_llama2.py | 8 + configs/7B_sft.py | 8 + internlm/checkpoint/load_funcs.py | 2 +- internlm/core/communication/utils.py | 231 ----- internlm/core/naive_amp.py | 9 +- .../parallel}/__init__.py | 0 .../{communication => parallel/comm}/isp.py | 276 +++++- internlm/core/parallel/comm/tensor.py | 369 ++++++++ internlm/core/parallel/comm/utils.py | 226 +++++ internlm/core/parallel/comm/zero.py | 106 +++ internlm/core/parallel/shard.py | 119 +++ .../comm}/__init__.py | 0 .../{communication => scheduler/comm}/p2p.py | 0 internlm/core/scheduler/comm/utils.py | 125 +++ .../core/scheduler/no_pipeline_scheduler.py | 5 +- internlm/core/scheduler/pipeline_scheduler.py | 13 +- internlm/data/utils.py | 57 +- internlm/initialize/initialize_trainer.py | 9 +- internlm/initialize/launch.py | 23 +- internlm/model/__init__.py | 33 - internlm/model/builder.py | 36 + internlm/model/llava/__init__.py | 0 .../{llava_modules => llava}/clip_builder.py | 0 .../{llava_modules => llava}/clip_encoder.py | 0 .../projector_builder.py | 0 internlm/model/losses/ce_loss.py | 13 +- internlm/model/metrics.py | 43 +- internlm/model/modeling_internlm.py | 321 ++----- internlm/model/modeling_internlm2.py | 838 ++--------------- internlm/model/modeling_llama.py | 809 ++-------------- internlm/model/modeling_llava.py | 345 ++----- internlm/model/modeling_moe.py | 342 ++----- internlm/model/modules/embedding.py | 371 +++----- internlm/model/modules/linear.py | 605 ++++++++++++ internlm/model/modules/mha.py | 596 ++++++++++++ internlm/model/modules/mlp.py | 280 ++---- .../model/modules/multi_head_attention.py | 867 ------------------ internlm/model/modules/norm.py | 19 + internlm/model/modules/utils.py | 82 ++ internlm/model/moe/__init__.py | 28 - internlm/model/moe/gshard_layer.py | 7 +- .../model/moe/megablock/megablock_dmoe.py | 13 +- internlm/model/moe/megablock/megablock_moe.py | 8 +- internlm/model/moe/megablock/mlp.py | 2 +- internlm/model/moe/megablock/utils.py | 27 +- internlm/model/moe/moe.py | 29 +- internlm/model/ops/attention.py | 847 +++++++++++++++++ internlm/model/ops/cross_entropy.py | 60 ++ .../model/ops/fusion_ops_import_helper.py | 211 ----- internlm/model/ops/linear.py | 417 +-------- internlm/model/ops/norm.py | 44 +- internlm/model/ops/rotary_emb.py | 158 ++++ internlm/model/ops/utils.py | 48 + internlm/{utils => model}/registry.py | 28 +- internlm/model/utils.py | 723 +-------------- internlm/solver/optimizer/compatible_adamw.py | 52 ++ .../solver/optimizer/hybrid_zero_optim.py | 2 +- internlm/solver/pipeline_utils.py | 34 - internlm/train/__init__.py | 4 +- internlm/train/pipeline.py | 185 ++-- internlm/train/utils.py | 2 +- internlm/utils/common.py | 53 +- internlm/utils/parallel.py | 4 - internlm/utils/utils.py | 109 +++ tests/common_fixture.py | 4 +- tests/test_core/utils.py | 4 +- tests/test_model/test_feed_forward.py | 26 +- .../test_fused_precision.py | 6 +- tests/test_model/test_model_internlm.py | 37 +- tests/test_model/test_norm.py | 6 +- tests/test_model/test_npu_ops.py | 4 +- tests/test_solver/test_optimizer.py | 2 +- .../test_forward_output_no_fa.py | 8 +- tests/test_training/test_load_ckpt_loss.py | 4 +- 
tests/test_training/test_loss.py | 4 +- tests/test_training/test_no_fa_train_temp.py | 4 +- tests/test_training/test_norm_weight.py | 4 +- .../test_swap_nb_loss_and_gradnorm.py | 7 +- tests/test_training/train_CI.py | 9 +- tests/test_utils/common_fixture.py | 12 +- tools/load_internlm_model.py | 5 +- train.py | 4 +- 85 files changed, 4848 insertions(+), 5607 deletions(-) delete mode 100644 internlm/core/communication/utils.py rename internlm/{model/llava_modules => core/parallel}/__init__.py (100%) rename internlm/core/{communication => parallel/comm}/isp.py (70%) create mode 100644 internlm/core/parallel/comm/tensor.py create mode 100644 internlm/core/parallel/comm/utils.py create mode 100644 internlm/core/parallel/comm/zero.py create mode 100644 internlm/core/parallel/shard.py rename internlm/core/{communication => scheduler/comm}/__init__.py (100%) rename internlm/core/{communication => scheduler/comm}/p2p.py (100%) create mode 100644 internlm/core/scheduler/comm/utils.py create mode 100644 internlm/model/builder.py create mode 100644 internlm/model/llava/__init__.py rename internlm/model/{llava_modules => llava}/clip_builder.py (100%) rename internlm/model/{llava_modules => llava}/clip_encoder.py (100%) rename internlm/model/{llava_modules => llava}/projector_builder.py (100%) create mode 100644 internlm/model/modules/linear.py create mode 100644 internlm/model/modules/mha.py delete mode 100644 internlm/model/modules/multi_head_attention.py create mode 100644 internlm/model/modules/norm.py create mode 100644 internlm/model/modules/utils.py create mode 100644 internlm/model/ops/attention.py create mode 100644 internlm/model/ops/cross_entropy.py delete mode 100644 internlm/model/ops/fusion_ops_import_helper.py create mode 100644 internlm/model/ops/rotary_emb.py create mode 100644 internlm/model/ops/utils.py rename internlm/{utils => model}/registry.py (71%) create mode 100644 internlm/solver/optimizer/compatible_adamw.py delete mode 100644 internlm/solver/pipeline_utils.py diff --git a/configs/7B_MoE4_sft.py b/configs/7B_MoE4_sft.py index 891e8ee3..ef20dc60 100644 --- a/configs/7B_MoE4_sft.py +++ b/configs/7B_MoE4_sft.py @@ -146,6 +146,14 @@ norm_type="rmsnorm", layer_norm_epsilon=1e-5, use_flash_attn=True, + # Whether the odd and even columns of the query and key in the model are normally interleaved. + # If it's True, the model's odd and even columns are normally ordered; if it's False, + # it means that the model has prematurely concatenated all odd columns and even columns in front + # and back, in order to improve the RoPE's computational efficiency. + # Example: + # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...] + # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...] + qk_interleaved=False, num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. num_experts=4, moe_use_residual=False, diff --git a/configs/7B_internlm2.py b/configs/7B_internlm2.py index 9e0fc91d..a69896ce 100644 --- a/configs/7B_internlm2.py +++ b/configs/7B_internlm2.py @@ -144,6 +144,14 @@ layer_norm_epsilon=1e-5, num_kv_attention_heads=NUM_KV_ATTENTION_HEAD, use_flash_attn=True, + # Whether the odd and even columns of the query and key in the model are normally interleaved. 
+ # If it's True, the model's odd and even columns are normally ordered; if it's False, + # it means that the model has prematurely concatenated all odd columns and even columns in front + # and back, in order to improve the RoPE's computational efficiency. + # Example: + # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...] + # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...] + qk_interleaved=False, ) """ diff --git a/configs/7B_isp_sft.py b/configs/7B_isp_sft.py index dc6408cd..7e88772f 100644 --- a/configs/7B_isp_sft.py +++ b/configs/7B_isp_sft.py @@ -146,6 +146,14 @@ layer_norm_epsilon=1e-5, use_flash_attn=True, num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. + # Whether the odd and even columns of the query and key in the model are normally interleaved. + # If it's True, the model's odd and even columns are normally ordered; if it's False, + # it means that the model has prematurely concatenated all odd columns and even columns in front + # and back, in order to improve the RoPE's computational efficiency. + # Example: + # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...] + # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...] + qk_interleaved=False, ) """ zero1 parallel (dict): diff --git a/configs/7B_llama2.py b/configs/7B_llama2.py index 9f464164..fb88be4a 100644 --- a/configs/7B_llama2.py +++ b/configs/7B_llama2.py @@ -144,6 +144,14 @@ layer_norm_epsilon=1e-5, num_kv_attention_heads=NUM_KV_ATTENTION_HEAD, use_flash_attn=True, + # Whether the odd and even columns of the query and key in the model are normally interleaved. + # If it's True, the model's odd and even columns are normally ordered; if it's False, + # it means that the model has prematurely concatenated all odd columns and even columns in front + # and back, in order to improve the RoPE's computational efficiency. + # Example: + # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...] + # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...] + qk_interleaved=False, ) """ diff --git a/configs/7B_sft.py b/configs/7B_sft.py index c2ae7078..746b4867 100644 --- a/configs/7B_sft.py +++ b/configs/7B_sft.py @@ -145,6 +145,14 @@ norm_type="rmsnorm", layer_norm_epsilon=1e-5, use_flash_attn=True, + # Whether the odd and even columns of the query and key in the model are normally interleaved. + # If it's True, the model's odd and even columns are normally ordered; if it's False, + # it means that the model has prematurely concatenated all odd columns and even columns in front + # and back, in order to improve the RoPE's computational efficiency. + # Example: + # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...] + # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...] + qk_interleaved=False, num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. 
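The `qk_interleaved` comments repeated in the configs above describe two column layouts for the query/key heads. A minimal sketch of the two orderings and the permutation between them (illustrative only, not part of the patch; the tensor names are placeholders):

```python
import torch

head_dim = 8
q = torch.arange(head_dim)  # natural order: [q1, q2, q3, q4, q5, q6, q7, q8]

# qk_interleaved=False layout: odd-position columns packed first, even-position
# columns last, i.e. [q1, q3, q5, q7, q2, q4, q6, q8] -- the layout the config
# comment attributes to RoPE computational efficiency.
q_split = torch.cat([q[0::2], q[1::2]])

# Restore the interleaved (qk_interleaved=True) ordering from the split layout.
q_interleaved = torch.empty_like(q_split)
q_interleaved[0::2] = q_split[: head_dim // 2]
q_interleaved[1::2] = q_split[head_dim // 2 :]
assert torch.equal(q_interleaved, q)
```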
) """ diff --git a/internlm/checkpoint/load_funcs.py b/internlm/checkpoint/load_funcs.py index ee4ed472..6bdfd634 100644 --- a/internlm/checkpoint/load_funcs.py +++ b/internlm/checkpoint/load_funcs.py @@ -6,7 +6,7 @@ from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.solver.pipeline_utils import partition_uniform +from internlm.core.parallel.shard import partition_uniform from internlm.utils.logger import get_logger from internlm.utils.storage_manager import get_fns, llm_load diff --git a/internlm/core/communication/utils.py b/internlm/core/communication/utils.py deleted file mode 100644 index 5d08327a..00000000 --- a/internlm/core/communication/utils.py +++ /dev/null @@ -1,231 +0,0 @@ -# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/communication - -from collections import OrderedDict -from typing import Dict, List, Tuple, Union - -import torch -import torch.distributed as dist -from torch import nn - -from internlm.core.communication.isp import ISPCommunicator -from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc -from internlm.core.naive_amp import NaiveAMPModel -from internlm.model.modules.embedding import Embedding1D -from internlm.model.ops.linear import BaseScaleColumnParallelLinear -from internlm.utils.common import get_current_device - -TensorShape = Union[torch.Size, List[int], Tuple[int]] - - -def send_meta_helper(obj, next_rank, tensor_kwargs): - send_shape = torch.tensor(obj.size(), **tensor_kwargs) - send_ndims = torch.tensor(len(obj.size()), **tensor_kwargs) - dist.send(send_ndims, next_rank) - dist.send(send_shape, next_rank) - - -def send_obj_meta(obj, next_rank=None): - """Sends obj meta information before sending a specific obj. - Since the recipient must know the shape of the obj in p2p communications, - meta information of the obj should be sent before communications. This function - synchronizes with :func:`recv_obj_meta`. - - Args: - obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): obj to be sent. - need_meta (bool, optional): If False, meta information won't be sent. - next_rank (int): The rank of the next member in pipeline parallel group. - - Returns: - bool: False - """ - if next_rank is None: - next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) - - tensor_kwargs = {"dtype": torch.long, "device": get_current_device()} - if isinstance(obj, torch.Tensor): - send_obj_nums = torch.tensor(1, **tensor_kwargs) - dist.send(send_obj_nums, next_rank) - send_meta_helper(obj, next_rank, tensor_kwargs) - else: - send_obj_nums = torch.tensor(len(obj), **tensor_kwargs) - dist.send(send_obj_nums, next_rank) - for tensor_to_send in obj: - send_meta_helper(tensor_to_send, next_rank, tensor_kwargs) - - -def recv_meta_helper(prev_rank, tensor_kwargs): - recv_ndims = torch.empty((), **tensor_kwargs) - dist.recv(recv_ndims, prev_rank) - recv_shape = torch.empty(recv_ndims, **tensor_kwargs) - dist.recv(recv_shape, prev_rank) - return recv_shape - - -def recv_obj_meta(prev_rank=None) -> torch.Size: - """Receives obj meta information before receiving a specific obj. - Since the recipient must know the shape of the obj in p2p communications, - meta information of the obj should be received before communications. This function - synchronizes with :func:`send_obj_meta`. 
- - Args: - obj_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the obj to be received. - prev_rank (int): The rank of the source of the obj. - - Returns: - Union[:class:`torch.Size`, List[:class:`torch.Size`]]: The shape of the obj to be received. - """ - if prev_rank is None: - prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) - - tensor_kwargs = {"dtype": torch.long, "device": get_current_device()} - recv_obj_nums = torch.empty((), **tensor_kwargs) - dist.recv(recv_obj_nums, prev_rank) - if recv_obj_nums.item() == 1: - recv_shape = recv_meta_helper(prev_rank, tensor_kwargs) - obj_shape = torch.Size(recv_shape) - else: - obj_shape = [] - for _ in range(recv_obj_nums.item()): - recv_shape = recv_meta_helper(prev_rank, tensor_kwargs) - obj_shape.append(torch.Size(recv_shape)) - - return obj_shape - - -def split_tensor_into_1d_equal_chunks(tensor: torch.Tensor, new_buffer=False) -> torch.Tensor: - """Break a tensor into equal 1D chunks. - - Args: - tensor (:class:`torch.Tensor`): Tensor to be split before communication. - new_buffer (bool, optional): Whether to use a new buffer to store sliced tensor. - - Returns: - :class:`torch.Tensor`: The split tensor - """ - partition_size = torch.numel(tensor) // gpc.get_world_size(ParallelMode.TENSOR) - start_index = partition_size * gpc.get_local_rank(ParallelMode.TENSOR) - end_index = start_index + partition_size - if new_buffer: - data = torch.empty(partition_size, dtype=tensor.dtype, device=get_current_device(), requires_grad=False) - data.copy_(tensor.view(-1)[start_index:end_index]) - else: - data = tensor.view(-1)[start_index:end_index] - return data - - -def gather_split_1d_tensor(tensor: torch.Tensor) -> torch.Tensor: - """Opposite of above function, gather values from model parallel ranks. - - Args: - tensor (:class:`torch.Tensor`): Tensor to be gathered after communication. - Returns: - :class:`torch.Tensor`: The gathered tensor. 
- """ - world_size = gpc.get_world_size(ParallelMode.TENSOR) - numel = torch.numel(tensor) - numel_gathered = world_size * numel - gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=get_current_device(), requires_grad=False) - chunks = [gathered[i * numel : (i + 1) * numel] for i in range(world_size)] - dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.TENSOR)) - return gathered - - -class ParamAsyncBcastHandler: - """ - Model Partition Handler for overlap broadcast with forward - """ - - def __init__( - self, zero1_mode: ParallelMode, model: Union[nn.Module, nn.ModuleList], isp_communicator: ISPCommunicator = None - ) -> None: - self._block_to_param: Dict[nn.Module, List[nn.Parameter]] = OrderedDict() - self._param_to_rank: Dict[nn.Parameter, int] = {} - self._block_to_rank: Dict[nn.Module, int] = {} - self._bcast_handles: Dict[int, List[dist.Work]] = {} - - zero1_size = gpc.get_world_size(zero1_mode) - total_param_num = sum(p.numel() for p in model.parameters()) - avg_param_num = total_param_num * 1.0 // zero1_size - - # initialize an empty list for _bcast_handles of each rank - self._bcast_handles = {rank: [] for rank in range(zero1_size)} - - # just want to share same for loop for ModuleList and Module - if not isinstance(model, nn.ModuleList): - model = [model] - - # record the parameters to transformer/embeding/head/norm block - for _chunk in model: - if isinstance(_chunk, NaiveAMPModel): - _chunk = _chunk.model - - for _, children in _chunk.named_children(): - # should be the transformer block definaton in modeling_xxx.py - if isinstance(children, nn.ModuleList): - # record the block that a parameter belongs to - for _, block in enumerate(children): - # self._block_to_param[f"{name}.{idx}"] = list(block.parameters()) - self._block_to_param[block] = list(block.parameters()) - else: - # record the block that a parameter belongs to - # self._block_to_param[name] = list(children.parameters()) - self._block_to_param[children] = list(children.parameters()) - - alloc_num = 0 - rank_to_go = 0 - - # process the parameters in block_to_param sequencially, - # allocate each parameter to a local rank of ParallelMode.ZERO1, - # NOTE that we do NOT consider following scenarios: - # 1) whether a parameter is trainable; - # 2) paramters maybe in different optimizer group - for block, params in self._block_to_param.items(): - # allocate a model block to a local rank of ParallelMode.ZERO1 - self._block_to_rank[block] = [rank_to_go] - for p in params: - alloc_num = alloc_num + p.numel() - # in this case, allocate the param to next rank if possible - if alloc_num > avg_param_num * 1.01 and rank_to_go < zero1_size - 1: - rank_to_go = rank_to_go + 1 - alloc_num = 0 - self._block_to_rank[block].append(rank_to_go) - # allocate a parameter to a local rank of ParallelMode.ZERO1 - self._param_to_rank[p] = rank_to_go - - # register_forward_pre_hook for transformer/embeding/norm/xxx block - self._register_sync_parameters_hook(isp_communicator) - - def _register_sync_parameters_hook(self, isp_communicator: ISPCommunicator = None) -> None: - def _pre_forward_hook(model: nn.Module, *args, **kwargs): # pylint: disable=W0613 - bcast_handles = [] - # gather all required broadcast hanles into a list - for rank in self._block_to_rank[model]: - bcast_handles.extend(self._bcast_handles[rank]) - # need to clear _bcast_handles since they would be processed later - self._bcast_handles[rank] = [] - # wait all required broadcast handles to be completed - for handle in bcast_handles: - 
handle.wait() - - # register_forward_pre_hook for transformer/embeding/norm/xxx block - for block, _ in self._block_to_rank.items(): - # TODO: remove special handling for embedding and head layers, - # instead implement support for weight parallelism of embedding and head layers within the ISP. - - # NOTE: Although the layernorm layer does not have explicit processing, - # both ISPCommunicator and ParamAsyncBcastHandler handle transformer blocks as granularity, - # so everything is fine. - - embedding_head_cls = (Embedding1D, BaseScaleColumnParallelLinear) - - if isp_communicator is None or isinstance(block, embedding_head_cls): - block.register_forward_pre_hook(_pre_forward_hook) - if isp_communicator: - isp_communicator.register_prerequisite_for_forward_prefetch_hooks(_pre_forward_hook) - - def get_rank_by_param(self, param) -> int: - return self._param_to_rank[param] - - def add_bcast_handle(self, rank, handle) -> None: - self._bcast_handles[rank].append(handle) diff --git a/internlm/core/naive_amp.py b/internlm/core/naive_amp.py index 4e1427eb..7cac640d 100644 --- a/internlm/core/naive_amp.py +++ b/internlm/core/naive_amp.py @@ -4,7 +4,7 @@ # adopted from https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/amp from functools import partial -from typing import Any, Union +from typing import Any, List, Union import torch import torch.distributed as dist @@ -206,3 +206,10 @@ def _post_forward_hook_for_fp32( torch.backends.cudnn.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True sub_module.register_forward_pre_hook(partial(_pre_forward_hook_for_fp32)) + + +def unwrap_naive_amp(model: Union[nn.Module, nn.ModuleList]) -> List[nn.Module]: + if not isinstance(model, nn.ModuleList): + model = [model] + + return [_chunk.model if isinstance(_chunk, NaiveAMPModel) else _chunk for _chunk in model] diff --git a/internlm/model/llava_modules/__init__.py b/internlm/core/parallel/__init__.py similarity index 100% rename from internlm/model/llava_modules/__init__.py rename to internlm/core/parallel/__init__.py diff --git a/internlm/core/communication/isp.py b/internlm/core/parallel/comm/isp.py similarity index 70% rename from internlm/core/communication/isp.py rename to internlm/core/parallel/comm/isp.py index 2976899a..14637912 100644 --- a/internlm/core/communication/isp.py +++ b/internlm/core/parallel/comm/isp.py @@ -1,18 +1,62 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +""" +communication for isp parallel. +""" +from abc import ABC, abstractmethod from functools import partial -from typing import Any, Callable, Dict, List, Union +from typing import Any, Callable, Dict, List, Tuple, Union import torch from torch import distributed as dist from torch import nn +from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.core.naive_amp import NaiveAMPModel -from internlm.model.ops.linear import ISPLinear -from internlm.model.utils import all_gather_raw, reduce_scatter_raw +from internlm.core.naive_amp import unwrap_naive_amp +from internlm.core.parallel.comm.utils import ( + DUMMY_HANDLE_CONST, + AsyncCommHandle, + all_gather_raw, + reduce_scatter_raw, +) +from internlm.model.modules.linear import ParallelLinearWithCommExt from internlm.utils.common import SchedulerHook, get_current_device +from internlm.utils.utils import ( + CuSeqlenType, + QKVPackType, + check_attention_argument, + params_dispatch_with_condition, +) + + +# not really useful, only for code hint. 
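For reference, a small usage sketch of the `unwrap_naive_amp` helper added in `naive_amp.py` above (illustrative only; the modules are placeholders). It normalizes either a single module or a `ModuleList` of possibly `NaiveAMPModel`-wrapped chunks into a plain list of `nn.Module`s, which is how both `ISPCommunicator` and `ParamAsyncBcastHandler` now iterate over model chunks:

```python
from torch import nn

from internlm.core.naive_amp import unwrap_naive_amp

single = nn.Linear(4, 4)
chunks = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4)])

# A bare module is wrapped into a one-element list; a ModuleList yields one
# entry per chunk. NaiveAMPModel wrappers, if present, are stripped to their
# inner `.model`.
assert len(unwrap_naive_amp(single)) == 1
assert len(unwrap_naive_amp(chunks)) == 2
```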
+class WPCommunicator(ABC): + """ + Common communicator interface for weight parallel + """ + + @abstractmethod + def communication_mode(self) -> str: + """ + communication mode of communictor + """ + pass + + @abstractmethod + def weight_hook(self, tensor: torch.Tensor, async_op: bool = False, **kwargs) -> torch.Tensor: + """ + communication for weight when forward/backward. + """ + pass + + @abstractmethod + def grad_hook(self, tensor: torch.Tensor, async_op: bool = False, **kwargs) -> Tuple[torch.Tensor, AsyncCommHandle]: + """ + communication for grad when backward. + """ + pass class ISPCommModelConfig: @@ -148,7 +192,7 @@ def __init__(self) -> None: self.bias_global_output: Dict[str, torch.Tensor] = {} -class ISPCommunicator: +class ISPCommunicator(WPCommunicator): """ ISP Communicator for managing the all-gather and reduce_scatter of Intern Sequence Parallel. """ @@ -195,16 +239,11 @@ def __init__( # init overlap states if necessary. if self.overlap: - # just want to share same for loop for modulelist and module. - model = model if isinstance(model, nn.ModuleList) else [model] # build overlap states for every chunk. - for chunk_id, chunk in enumerate(model): - if isinstance(chunk, NaiveAMPModel): - chunk = chunk.model + for chunk_id, chunk in enumerate(unwrap_naive_amp(model)): self._parse_model_structure(chunk_id, chunk) - # register overlap hooks for every chunk. - for chunk_id in range(len(model)): self.switch_current_model_chunk(chunk_id) + # register overlap hooks for every chunk. self._register_sync_parameters_hook() # switch to chunk 0 at first. self.switch_current_model_chunk(0) @@ -232,7 +271,7 @@ def _parse_model_structure(self, cid: int, model: nn.Module) -> None: if name in ["out_proj", "wo"]: self._overlap_states[cid].isp_outs.append(child) self._overlap_states[cid].module_to_index[child] = idx - if isinstance(child, ISPLinear): + if isinstance(child, ParallelLinearWithCommExt): if name not in self._module_shapes: origin_shape = tuple( [child.weight.shape[0] * gpc.weight_parallel_size] @@ -436,6 +475,9 @@ def _get_constant_zero(self, size: tuple) -> torch.Tensor: device=self.model_conf.device, ).contiguous() + def communication_mode(self) -> str: + return "wp" + def switch_current_model_chunk(self, chunk_id: int) -> None: self._isp_outs = self._overlap_states[chunk_id].isp_outs self._isp_modules = self._overlap_states[chunk_id].isp_modules @@ -478,44 +520,51 @@ def register_prerequisite_for_forward_prefetch_hooks(self, prerequisite_func: Ca # communication operation interfaces - def all_gather(self, tensor: torch.Tensor, module: nn.Module, is_bias: bool = False): + def weight_hook( + self, tensor: torch.Tensor, async_op: bool = False, module: nn.Module = None, is_bias: bool = False + ) -> torch.Tensor: if dist.get_world_size(self.process_group) <= 1: return tensor if not self.overlap: - result, _ = all_gather_raw(tensor, self.process_group, async_op=False) + result, _ = all_gather_raw(tensor, self.process_group, async_op=async_op) elif is_bias: + assert module is not None, "The module parameter must be specified" result = self._bias_global_output[module] else: + assert module is not None, "The module parameter must be specified" result = self._weight_global_output[module] return result - def reduce_scatter( + def grad_hook( self, tensor: torch.Tensor, - model: nn.Module, - op: dist.ReduceOp, + async_op: bool = False, + module: nn.Module = None, + reduce_op: dist.ReduceOp = dist.ReduceOp.AVG, is_bias: bool = False, - ): + ) -> Tuple[torch.Tensor, AsyncCommHandle]: if 
dist.get_world_size(self.process_group) <= 1: - return tensor, None + return tensor, DUMMY_HANDLE_CONST if not self.overlap: - result, handle = reduce_scatter_raw(tensor, self.process_group, op=op, async_op=True) + result, handle = reduce_scatter_raw(tensor, self.process_group, op=reduce_op, async_op=async_op) else: + assert module is not None, "The module parameter must be specified" + if is_bias: - assert hasattr(model.bias, "isp_reduce_scatter_name") - key = getattr(model.bias, "isp_reduce_scatter_name") + assert hasattr(module.bias, "isp_reduce_scatter_name") + key = getattr(module.bias, "isp_reduce_scatter_name") else: - assert hasattr(model.weight, "isp_reduce_scatter_name") - key = getattr(model.weight, "isp_reduce_scatter_name") + assert hasattr(module.weight, "isp_reduce_scatter_name") + key = getattr(module.weight, "isp_reduce_scatter_name") self.reduce_scatter_handlers[key] = reduce_scatter_raw( tensor, self.process_group, - op=op, - async_op=True, + op=reduce_op, + async_op=async_op, memory_pool_allocator=( self.memory_pool.allocate_reduce_scatter_memory if self.enable_memory_pool else None ), @@ -528,7 +577,7 @@ def reduce_scatter( *tensor.shape[1:], ) ), - None, + DUMMY_HANDLE_CONST, ) return result, handle @@ -543,33 +592,190 @@ def __init__(self, overlap_handler: ISPCommunicator, zero_optim) -> None: self._isp_communicator = overlap_handler self._zero_optim = zero_optim - def before_forward(self, scheduler, inputs) -> None: + def before_forward(self, scheduler, inputs) -> None: # pylint: disable=W0613 self._isp_communicator.is_forward = True # switch model chunk before forward chunk_id = 0 if gpc.virtual_pipeline_parallel_rank is None else gpc.virtual_pipeline_parallel_rank self._isp_communicator.switch_current_model_chunk(chunk_id) - def after_forward(self, scheduler, outputs) -> None: + def after_forward(self, scheduler, outputs) -> None: # pylint: disable=W0613 pass - def before_criterion(self, scheduler, outputs, label) -> None: + def before_criterion(self, scheduler, outputs, label) -> None: # pylint: disable=W0613 pass - def after_criterion(self, scheduler, loss) -> None: + def after_criterion(self, scheduler, loss) -> None: # pylint: disable=W0613 pass - def before_backward(self, scheduler, outputs, outputs_grad) -> None: + def before_backward(self, scheduler, outputs, outputs_grad) -> None: # pylint: disable=W0613 self._isp_communicator.is_forward = False # switch model chunk before backward chunk_id = 0 if gpc.virtual_pipeline_parallel_rank is None else gpc.virtual_pipeline_parallel_rank self._isp_communicator.switch_current_model_chunk(chunk_id) - def after_backward(self, scheduler, inputs_grad) -> None: + def after_backward(self, scheduler, inputs_grad) -> None: # pylint: disable=W0613 # accumulate left gradients in last bucket after backward. self._zero_optim.accumulate_left_grads_after_backward() # reset lazy memory pools for reduce scatter after every micro step. 
if self._isp_communicator and self._isp_communicator.enable_memory_pool: self._isp_communicator.memory_pool.reset_lazy_pools() - def post_helper_func(self, scheduler, outputs, label) -> None: + def post_helper_func(self, scheduler, outputs, label) -> None: # pylint: disable=W0613 pass + + +# adpated from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/sequence/layer.py +class _SeqAllToAll(torch.autograd.Function): + "sequence alltoall function" + + @staticmethod + def forward(ctx, group: dist.ProcessGroup, input_: torch.Tensor, scatter_idx: int, gather_idx: int) -> torch.Tensor: + ctx.group = group + ctx.scatter_idx = scatter_idx + ctx.gather_idx = gather_idx + + if dist.get_world_size(group) <= 1: + return input_ + + seq_world_size = dist.get_world_size(group) + + input_list = [t.contiguous() for t in torch.tensor_split(input_, seq_world_size, scatter_idx)] + output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)] + # TODO: use all_to_all_single instead + dist.all_to_all(output_list, input_list, group=group) + return torch.cat(output_list, dim=gather_idx).contiguous() + + @staticmethod + def backward(ctx, *grad_output: torch.Tensor) -> Tuple[None, torch.Tensor, None, None]: + if dist.get_world_size(ctx.group) <= 1: + return (None, *grad_output, None, None) + + return (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None) + + +# adpated from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/sequence/layer.py +class DistributedAttention(nn.Module): + """Initialization. + + Arguments: + local_attention (Module): local self-attention module + sequence_process_group (ProcessGroup): sequence parallel process group + """ + + def __init__( + self, + local_attention: nn.Module, + sequence_process_group: dist.ProcessGroup, + ) -> None: + super().__init__() + self.local_attn = local_attention + self.spg = sequence_process_group + + @params_dispatch_with_condition(condition=check_attention_argument) + def forward(self) -> torch.Tensor: + assert False, "Should never arrive" + + @forward.register(conditions=(str(QKVPackType.QKVPACKED), str(CuSeqlenType.With))) + @forward.register(conditions=(str(QKVPackType.QKVPACKED), str(CuSeqlenType.WithOut))) + def _(self, qkv: torch.Tensor, **kwargs) -> torch.Tensor: + """forward + + Arguments: + qkv (Tensor): packed qkv input to the layer + kwargs: other args + + Returns: + * output (Tensor): context output + """ + # qkv shape: [1, packlen, 3, n_head, head_dim] or [batch, seqlen, 3, n_head, head_dim] + # scatter in n_head and gather in seqlen(packlen) + qkv = _SeqAllToAll.apply(self.spg, qkv, 3, 1) + + context = self.local_attn(qkv, **kwargs) + + # context shape: [1, packlen, n_head, head_dim] or [batch, seqlen, n_head, head_dim] + # scatter in seqlen(packlen) and gather in n_head + context = _SeqAllToAll.apply(self.spg, context, 1, 2) + + return context + + @forward.register(conditions=(str(QKVPackType.KVPACKED), str(CuSeqlenType.With))) + @forward.register(conditions=(str(QKVPackType.KVPACKED), str(CuSeqlenType.WithOut))) + def _(self, q: torch.Tensor, kv: torch.Tensor, **kwargs) -> torch.Tensor: + """forward + + Arguments: + q (Tensor): q input to the layer + kv (Tensor): packed kv input to the layer + kwargs: other args + + Returns: + output (Tensor): context output + """ + # q shpae: [1, packlen, n_head, head_dim] or [batch, seqlen, n_head, head_dim] + # scatter in n_head and gather in seqlen(packlen) + q = _SeqAllToAll.apply(self.spg, q, 2, 1) + # kv shape: [1, packlen, 
2, n_head, head_dim] or [batch, seqlen, 2, n_head, head_dim] + # scatter in n_head and gather in seqlen(packlen) + kv = _SeqAllToAll.apply(self.spg, kv, 3, 1) + + context = self.local_attn(q, kv, **kwargs) + + # context shape: [1, packlen, n_head, head_dim] or [batch, seqlen, n_head, head_dim] + # scatter in seqlen(packlen) and gather in n_head + context = _SeqAllToAll.apply(self.spg, context, 1, 2) + + return context + + @forward.register(conditions=(str(QKVPackType.QKVSPLITED), str(CuSeqlenType.With))) + @forward.register(conditions=(str(QKVPackType.QKVSPLITED), str(CuSeqlenType.WithOut))) + def _(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, **kwargs) -> torch.Tensor: + """forward + + Arguments: + q (Tensor): q input to the layer + k (Tensor): k input to the layer + v (Tensor): v input to the layer + kwargs: other args + + Returns: + * output (Tensor): context output + """ + # self._scatter_gather_idx["q"] = [1, 0] # q/k/v shape: [sequence, head, head_dim] + # q shpae: [1, packlen, n_head, head_dim] or [batch, seqlen, n_head, head_dim] + # scatter in n_head and gather in seqlen(packlen) + q = _SeqAllToAll.apply(self.spg, q, 2, 1) + # k shpae: [1, packlen, n_head, head_dim] or [batch, seqlen, n_head, head_dim] + # scatter in n_head and gather in seqlen(packlen) + k = _SeqAllToAll.apply(self.spg, k, 2, 1) + # v shpae: [1, packlen, n_head, head_dim] or [batch, seqlen, n_head, head_dim] + # scatter in n_head and gather in seqlen(packlen) + v = _SeqAllToAll.apply(self.spg, v, 2, 1) + + context = self.local_attn(q, k, v, **kwargs) + + # context shape: [1, packlen, n_head, head_dim] or [batch, seqlen, n_head, head_dim] + # scatter in seqlen(packlen) and gather in n_head + context = _SeqAllToAll.apply(self.spg, context, 1, 2) + + return context + + +def auto_wrap_distributed_attention(cls: nn.Module) -> Callable[[bool, Any, float], nn.Module]: + """ + Wrap a local attention module to a distributed one, which will be used in the ISP parallelism. + """ + + # should we impl distributed attention as a metaclass? + def _attetion_constructor( + local_attn_cls: type, causal=False, softmax_scale=None, attention_dropout=0.0 + ) -> nn.Module: + if gpc.config.parallel["tensor"].get("mode", "mtp") != "isp": + return local_attn_cls(causal, softmax_scale, attention_dropout) + else: + return DistributedAttention( + local_attention=local_attn_cls(causal, softmax_scale, attention_dropout), + sequence_process_group=gpc.get_group(ParallelMode.TENSOR), + ) + + return partial(_attetion_constructor, local_attn_cls=cls) diff --git a/internlm/core/parallel/comm/tensor.py b/internlm/core/parallel/comm/tensor.py new file mode 100644 index 00000000..47086ad9 --- /dev/null +++ b/internlm/core/parallel/comm/tensor.py @@ -0,0 +1,369 @@ +""" +communication for tensor/sequence parallel. 
+""" + +from abc import ABC, abstractmethod +from enum import Enum +from typing import Any, Tuple + +import torch +from torch import distributed as dist + +from internlm.core.context import ParallelMode +from internlm.core.context.parallel_context import global_context as gpc +from internlm.core.parallel.comm.utils import ( + DUMMY_HANDLE_CONST, + AsyncCommHandle, + _gather, + _split, + all_gather_raw, + all_reduce_raw, + gather_forward_split_backward, + reduce_scatter_raw, + split_forward_gather_backward, +) +from internlm.model.modules.embedding import Embedding1D +from internlm.model.moe.moe import MoE + +# input gather dim +_GATHER_DIM = 1 # shape: [batch, seqlen, dim] or [1, packlen, dim] +_REDUCE_DIM = 1 # shape: [batch, seqlen, dim] or [1, packlen, dim] + + +class LinearRole(Enum): + COLUMN = "column" + ROW = "row" + + +# not really useful, only for code hint. +class TPCommunicator(ABC): + """ + Common communicator interafce for tensor/sequence parallel. + """ + + @abstractmethod + def save_total_input(self) -> bool: + """ + Should linear save total input after all gather as activation in sequence parallel. + """ + pass + + @abstractmethod + def communication_mode(self) -> str: + """ + communication mode of communictor + """ + pass + + @abstractmethod + def input_hook( + self, _input: torch.Tensor, async_op: bool = False, is_forward: bool = True + ) -> Tuple[torch.Tensor, AsyncCommHandle]: + """ + communication for input when forward/backward. + """ + pass + + @abstractmethod + def grad_output_hook( + self, grad_output: torch.Tensor, async_op: bool = False + ) -> Tuple[torch.Tensor, AsyncCommHandle]: + """ + communication for grad_output when backward. + """ + pass + + @abstractmethod + def grad_input_hook(self, grad_input: torch.Tensor, async_op: bool = False) -> Tuple[torch.Tensor, AsyncCommHandle]: + """ + communication for grad_input when backward. + """ + pass + + @abstractmethod + def output_hook(self, output: torch.Tensor, async_op: bool = False) -> Tuple[torch.Tensor, AsyncCommHandle]: + """ + communication for output when forward. + """ + pass + + +class TensorParallelCommunicator(TPCommunicator): + """ + tensor parallel communicator for linear + """ + + def __init__(self, process_group: dist.ProcessGroup, role: LinearRole) -> None: + assert role in (LinearRole.COLUMN, LinearRole.ROW), f"Unknown linear role: {role}" + + self._process_group = process_group + self._role = role + + self._save_total_input = False + + def save_total_input(self) -> bool: + return self._save_total_input + + def communication_mode(self) -> str: + return "tp" + + def input_hook( + self, _input: torch.Tensor, async_op: bool = False, is_forward: bool = True # pylint: disable=W0613 + ) -> Tuple[torch.Tensor, AsyncCommHandle]: + """ + tensor parallel should do nothing for input. + """ + return _input, DUMMY_HANDLE_CONST + + def grad_output_hook( + self, grad_output: torch.Tensor, async_op: bool = False # pylint: disable=W0613 + ) -> Tuple[torch.Tensor, AsyncCommHandle]: + """ + tensor parallel should do nothing for grad_output. + """ + return grad_output, DUMMY_HANDLE_CONST + + def grad_input_hook(self, grad_input: torch.Tensor, async_op: bool = False) -> Tuple[torch.Tensor, AsyncCommHandle]: + """ + all reduce grad_input only for column parallel linear when backward. 
+ """ + if dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.ROW: + return grad_input, DUMMY_HANDLE_CONST + + return all_reduce_raw(grad_input, process_group=self._process_group, async_op=async_op) + + def output_hook(self, output: torch.Tensor, async_op: bool = False) -> Tuple[torch.Tensor, AsyncCommHandle]: + """ + all reduce output only for row parallel linear when forward. + """ + if dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN: + return output, DUMMY_HANDLE_CONST + + return all_reduce_raw(output, process_group=self._process_group, async_op=async_op) + + +class SequenceParallelCommunicator(TPCommunicator): + """ + sequence parallel communicator for linear + """ + + def __init__( + self, process_group: dist.ProcessGroup, role: LinearRole, save_total_input_as_activation: bool = False + ) -> None: + assert role in (LinearRole.COLUMN, LinearRole.ROW), f"Unknown linear role: {role}" + + self._process_group = process_group + self._role = role + + self._save_total_input = save_total_input_as_activation + + def save_total_input(self) -> bool: + return self._save_total_input + + def communication_mode(self) -> str: + return "sp" + + def input_hook( + self, _input: torch.Tensor, async_op: bool = False, is_forward: bool = True + ) -> Tuple[torch.Tensor, AsyncCommHandle]: + """ + all gather input only for column parallel linear when forward/backward. + """ + # 1. world_size <= 1 + # 2. row parallel linear should not allgather input. + # 3. column parallel linear should not allgather input if save_total_input_as_activation and backward is True. + if ( + dist.get_world_size(self._process_group) <= 1 + or self._role == LinearRole.ROW + or (is_forward is False and self._save_total_input) + ): + return _input, DUMMY_HANDLE_CONST + + return all_gather_raw(_input, process_group=self._process_group, async_op=async_op, gather_dim=_GATHER_DIM) + + def grad_output_hook( + self, grad_output: torch.Tensor, async_op: bool = False + ) -> Tuple[torch.Tensor, AsyncCommHandle]: + """ + all gather grad_output only for row parallel linear when backward. + """ + if dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN: + return grad_output, DUMMY_HANDLE_CONST + + return all_gather_raw(grad_output, process_group=self._process_group, async_op=async_op, gather_dim=_GATHER_DIM) + + def grad_input_hook(self, grad_input: torch.Tensor, async_op: bool = False) -> Tuple[torch.Tensor, AsyncCommHandle]: + """ + reduce scatter grad_input only for column parallel linear when backward. + """ + if dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.ROW: + return grad_input, DUMMY_HANDLE_CONST + + return reduce_scatter_raw( + grad_input, process_group=self._process_group, async_op=async_op, reduce_dim=_REDUCE_DIM + ) + + def output_hook(self, output: torch.Tensor, async_op: bool = False) -> Tuple[torch.Tensor, AsyncCommHandle]: + """ + reduce scatter output only for row parallel linear when forward. 
+ """ + if dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN: + return output, DUMMY_HANDLE_CONST + + return reduce_scatter_raw(output, process_group=self._process_group, async_op=async_op, reduce_dim=_REDUCE_DIM) + + +class HeadTensorParallelCommunicator(TensorParallelCommunicator): + """ + tensor parallel communicator for head linear + """ + + def __init__(self, parallel_mode: ParallelMode, retain_out_sharded: bool = True) -> None: + super().__init__(process_group=gpc.get_group(parallel_mode), role=LinearRole.COLUMN) + + self._parallel_mode = parallel_mode + self._retain_out_sharded = retain_out_sharded + + def grad_output_hook( + self, grad_output: torch.Tensor, async_op: bool = False # pylint: disable=W0613 + ) -> Tuple[torch.Tensor, AsyncCommHandle]: + """ + split grad_output if retain_out_sharded is False. + """ + if self._retain_out_sharded or dist.get_world_size(self._process_group) <= 1: + return grad_output, DUMMY_HANDLE_CONST + + return _split(grad_output, parallel_mode=self._parallel_mode, dim=-1) + + def output_hook( + self, output: torch.Tensor, async_op: bool = False # pylint: disable=W0613 + ) -> Tuple[torch.Tensor, AsyncCommHandle]: + """ + all gather output for head layer if retain_out_sharded is False. + """ + if self._retain_out_sharded or dist.get_world_size(self._process_group) <= 1: + return output, DUMMY_HANDLE_CONST + + return _gather(output, parallel_mode=self._parallel_mode, dim=-1) + + +class HeadSequenceParallelCommunicator(SequenceParallelCommunicator): + """ + sequence parallel communicator for head linear + """ + + def __init__( + self, parallel_mode: ParallelMode, retain_out_sharded: bool = True, save_total_input_as_activation: bool = False + ) -> None: + super().__init__( + process_group=gpc.get_group(parallel_mode), + role=LinearRole.COLUMN, + save_total_input_as_activation=save_total_input_as_activation, + ) + + self._parallel_mode = parallel_mode + self._retain_out_sharded = retain_out_sharded + + # rewrite grad_output communication hook + def grad_output_hook( + self, grad_output: torch.Tensor, async_op: bool = False # pylint: disable=W0613 + ) -> Tuple[torch.Tensor, AsyncCommHandle]: + """ + split grad_output if retain_out_sharded is False. + """ + if self._retain_out_sharded or dist.get_world_size(self._process_group) <= 1: + return grad_output, DUMMY_HANDLE_CONST + + return _split(grad_output, parallel_mode=self._parallel_mode, dim=-1) + + # rewrite ouput communication hook + def output_hook( + self, output: torch.Tensor, async_op: bool = False # pylint: disable=W0613 + ) -> Tuple[torch.Tensor, AsyncCommHandle]: + """ + all gather output for head layer if retain_out_sharded is False. + """ + if self._retain_out_sharded or dist.get_world_size(self._process_group) <= 1: + return output, DUMMY_HANDLE_CONST + + return _gather(output, parallel_mode=self._parallel_mode, dim=-1) + + +class MoESequenceParallelCommunicator: + """ + sequence parallel communicator for moe layer + """ + + def __init__(self, parallel_mode: ParallelMode) -> None: + self._parallel_mode = parallel_mode + + def register_module_hook(self, module: MoE) -> None: + assert isinstance(module, MoE), "MoE sequence parallel communicator is only support moe module" + + module.register_forward_pre_hook(self.input_hook, with_kwargs=True) + module.register_forward_hook(self.output_hook) + + def input_hook(self, module: MoE, args, kwargs) -> torch.Tensor: # pylint: disable=W0613 + """ + allgather input before forward and split grad_input after backward. 
+ """ + _input = args[0] if len(args) > 0 else kwargs.pop("hidden_states") + _input = gather_forward_split_backward(_input, self._parallel_mode, dim=_GATHER_DIM) + + return (_input, *args), kwargs + + def output_hook(self, module: MoE, args: Any, output: Tuple[Any]) -> Tuple[Any]: # pylint: disable=W0613 + """ + split output after forward and allgather grad_output before backward. + """ + _output, *_others = output + _output = split_forward_gather_backward(_output, self._parallel_mode, dim=_REDUCE_DIM) + + return (_output, *_others) + + +class EmbbedingTensorParallelCommunicator: + """ + tensor parallel communicator for embbeding layer + """ + + def __init__(self, parallel_mode: ParallelMode) -> None: + self._parallel_mode = parallel_mode + + def register_module_hook(self, module: Embedding1D) -> None: + assert isinstance(module, Embedding1D), "Embbeding tensor parallel communicator is only support Embedding1D" + + module.register_forward_hook(self.output_hook) + + def output_hook(self, module: Embedding1D, args: Any, output: Tuple[Any]) -> Tuple[Any]: # pylint: disable=W0613 + """ + split output after forward and allgather grad_output before backward. + """ + _emb_dim = 2 # [bsz, seqlen, emb_dim] + + return gather_forward_split_backward(output, self._parallel_mode, dim=_emb_dim) + + +class EmbbedingSequenceParallelCommunicator: + """ + sequence parallel communictor for embbeding layer + """ + + def __init__(self, parallel_mode: ParallelMode) -> None: + self._parallel_mode = parallel_mode + + def register_module_hook(self, module: Embedding1D) -> None: + assert isinstance(module, Embedding1D), "Embbeding sequence parallel communicator is only support Embedding1D" + + module.register_forward_hook(self.output_hook) + + def output_hook(self, module: Embedding1D, args: Any, output: Tuple[Any]) -> Tuple[Any]: # pylint: disable=W0613 + """ + split output after forward and allgather grad_output before backward. 
+ """ + _emb_dim, _seq_dim = 2, 1 # [bsz, seqlen, emb_dim] + + output = gather_forward_split_backward(output, self._parallel_mode, dim=_emb_dim) + output = split_forward_gather_backward(output, self._parallel_mode, dim=_seq_dim) + + return output diff --git a/internlm/core/parallel/comm/utils.py b/internlm/core/parallel/comm/utils.py new file mode 100644 index 00000000..dbfeb3fd --- /dev/null +++ b/internlm/core/parallel/comm/utils.py @@ -0,0 +1,226 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from abc import ABC, abstractmethod +from typing import Callable + +import torch +import torch.distributed as dist +from torch import Tensor +from torch.distributed import ProcessGroup + +from internlm.core.context import global_context as gpc + + +class AsyncCommHandle(ABC): + """A interface for asynchronous communication handles.""" + + @abstractmethod + def wait(self) -> None: + """wait asynchronous communication to complete.""" + + +class DummyAsyncCommHandle(AsyncCommHandle): + """A fake communication handle used to maintain consistency in code writing""" + + def wait(self) -> None: + pass + + +DUMMY_HANDLE_CONST = DummyAsyncCommHandle() + + +# Raw operation, does not support autograd, but does support async +def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): + input_ = input_.contiguous() + handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op) + return input_, handle + + +class ReduceScatterFunc(torch.autograd.Function): + """Reduce scatter the input from the sequence parallel region and concatenate.""" + + @staticmethod + def forward(ctx, input_: Tensor, process_group: ProcessGroup, reduce_dim: int = 0) -> Tensor: + ctx.process_group = process_group + ctx.reduce_dim = reduce_dim + output, _ = reduce_scatter_raw(input_, process_group, reduce_dim=reduce_dim) + return output + + @staticmethod + def backward(ctx, grad_output: Tensor): + gather_dim = ctx.reduce_dim + grad_input, _ = all_gather_raw(grad_output, ctx.process_group, gather_dim=gather_dim) + return grad_input, None, None + + +# Supports autograd, but does not support async +reduce_scatter = ReduceScatterFunc.apply + + +class AllReduceFunc(torch.autograd.Function): + """Gather the input from sequence parallel region and concatenate.""" + + @staticmethod + def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor: + ctx.process_group = process_group + output, _ = all_reduce_raw(input_, process_group) + return output + + @staticmethod + def backward(ctx, grad_output: Tensor): + _ = ctx # avoid lint warning W0613 + return grad_output, None + + +# Supports autograd, but does not support async +all_reduce = AllReduceFunc.apply + + +def _split(input_, parallel_mode, dim=-1): + # skip if only one rank involved + world_size = gpc.get_world_size(parallel_mode) + if world_size == 1: + return input_ + + # Split along last dimension. 
+ dim_size = input_.size(dim) + assert dim_size % world_size == 0, ( + f"The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), " + f"cannot split tensor evenly" + ) + + tensor_list = torch.split(input_, dim_size // world_size, dim=dim) + rank = gpc.get_local_rank(parallel_mode) + output = tensor_list[rank].contiguous() + output = output.detach().clone() + + return output + + +def _gather(input_, parallel_mode, dim=-1): + # skip if only one rank involved + world_size = gpc.get_world_size(parallel_mode) + if world_size == 1: + return input_ + + # all gather + rank = gpc.get_local_rank(parallel_mode) + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank] = input_ + group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode) + dist.all_gather(tensor_list, input_, group=group) + + # concat + output = torch.cat(tensor_list, dim=dim).contiguous() + + return output + + +class _GatherForwardSplitBackward(torch.autograd.Function): + """Gather the input from model parallel region and concatenate. + + Args: + input_: input matrix. + parallel_mode: parallel mode. + dim: dimension + """ + + @staticmethod + def symbolic(input_): + return _gather(input_, parallel_mode=None) + + @staticmethod + def forward(ctx, input_, parallel_mode, dim): + ctx.mode = parallel_mode + ctx.dim = dim + return _gather(input_, parallel_mode, dim) + + @staticmethod + def backward(ctx, grad_output): + return _split(grad_output, ctx.mode, ctx.dim), None, None + + +def gather_forward_split_backward(input_, parallel_mode, dim): + return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim) + + +class _SplitForwardGatherBackward(torch.autograd.Function): + """ + Split the input and keep only the corresponding chuck to the rank. + + Args: + input_: input matrix. + parallel_mode: parallel mode. 
+ dim: dimension + """ + + @staticmethod + def symbolic(input_): + return _split(input_, parallel_mode=None) + + @staticmethod + def forward(ctx, input_, parallel_mode, dim): + ctx.mode = parallel_mode + ctx.dim = dim + return _split(input_, parallel_mode, dim) + + @staticmethod + def backward(ctx, grad_output): + return _gather(grad_output, ctx.mode, ctx.dim), None, None + + +def split_forward_gather_backward(input_, parallel_mode, dim): + return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim) + + +def all_gather_raw( + input_: Tensor, + process_group: ProcessGroup, + async_op: bool = False, + gather_dim: int = 0, + memory_pool_allocator: Callable = None, +): + world_size = dist.get_world_size(process_group) + if world_size <= 1: + return input_, None + + if memory_pool_allocator is not None: + output = memory_pool_allocator() + else: + shape = list(input_.shape) + shape[gather_dim] = shape[gather_dim] * world_size + output = torch.empty(shape, dtype=input_.dtype, device=input_.device) + + handle = dist.all_gather_into_tensor(output, input_.contiguous(), group=process_group, async_op=async_op) + return output, handle + + +def reduce_scatter_raw( + input_: Tensor, + process_group: ProcessGroup, + op=dist.ReduceOp.SUM, + async_op: bool = False, + reduce_dim: int = 0, + memory_pool_allocator: Callable = None, +): + world_size = dist.get_world_size(process_group) + assert input_.shape[reduce_dim] % world_size == 0 + + if world_size <= 1: + return input_, None + + shape_list = list(input_.shape) + shape_list[reduce_dim] = shape_list[reduce_dim] // world_size + + if memory_pool_allocator is not None: + output = memory_pool_allocator(tuple(shape_list)) + else: + output = torch.empty( + shape_list, + dtype=input_.dtype, + device=input_.device, + ).contiguous() + + handle = dist.reduce_scatter_tensor(output, input_.contiguous(), op=op, group=process_group, async_op=async_op) + return output, handle diff --git a/internlm/core/parallel/comm/zero.py b/internlm/core/parallel/comm/zero.py new file mode 100644 index 00000000..db4ff4e2 --- /dev/null +++ b/internlm/core/parallel/comm/zero.py @@ -0,0 +1,106 @@ +""" +communication for zero parallel +""" + +from collections import OrderedDict +from typing import Dict, List, Union + +from torch import distributed as dist +from torch import nn + +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc +from internlm.core.naive_amp import unwrap_naive_amp +from internlm.core.parallel.comm.isp import ISPCommunicator +from internlm.model.modules.embedding import Embedding1D +from internlm.model.modules.linear import ScaleColumnParallelLinear + + +class ParamAsyncBcastHandler: + """ + Model Partition Handler for overlap broadcast with forward + """ + + def __init__( + self, zero1_mode: ParallelMode, model: Union[nn.Module, nn.ModuleList], isp_communicator: ISPCommunicator = None + ) -> None: + self._block_to_param: Dict[nn.Module, List[nn.Parameter]] = OrderedDict() + self._param_to_rank: Dict[nn.Parameter, int] = {} + self._block_to_rank: Dict[nn.Module, int] = {} + self._bcast_handles: Dict[int, List[dist.Work]] = {} + + zero1_size = gpc.get_world_size(zero1_mode) + total_param_num = sum(p.numel() for p in model.parameters()) + avg_param_num = total_param_num * 1.0 // zero1_size + + # initialize an empty list for _bcast_handles of each rank + self._bcast_handles = {rank: [] for rank in range(zero1_size)} + + # record the parameters to transformer/embeding/head/norm block + for _chunk in 
unwrap_naive_amp(model): + for _, children in _chunk.named_children(): + # should be the transformer block definaton in modeling_xxx.py + if isinstance(children, nn.ModuleList): + # record the block that a parameter belongs to + for _, block in enumerate(children): + # self._block_to_param[f"{name}.{idx}"] = list(block.parameters()) + self._block_to_param[block] = list(block.parameters()) + else: + # record the block that a parameter belongs to + # self._block_to_param[name] = list(children.parameters()) + self._block_to_param[children] = list(children.parameters()) + + alloc_num = 0 + rank_to_go = 0 + + # process the parameters in block_to_param sequencially, + # allocate each parameter to a local rank of ParallelMode.ZERO1, + # NOTE that we do NOT consider following scenarios: + # 1) whether a parameter is trainable; + # 2) paramters maybe in different optimizer group + for block, params in self._block_to_param.items(): + # allocate a model block to a local rank of ParallelMode.ZERO1 + self._block_to_rank[block] = [rank_to_go] + for p in params: + alloc_num = alloc_num + p.numel() + # in this case, allocate the param to next rank if possible + if alloc_num > avg_param_num * 1.01 and rank_to_go < zero1_size - 1: + rank_to_go = rank_to_go + 1 + alloc_num = 0 + self._block_to_rank[block].append(rank_to_go) + # allocate a parameter to a local rank of ParallelMode.ZERO1 + self._param_to_rank[p] = rank_to_go + + # register_forward_pre_hook for transformer/embeding/norm/xxx block + self._register_sync_parameters_hook(isp_communicator) + + def _register_sync_parameters_hook(self, isp_communicator: ISPCommunicator = None) -> None: + def _pre_forward_hook(model: nn.Module, *args, **kwargs): # pylint: disable=W0613 + bcast_handles = [] + # gather all required broadcast hanles into a list + for rank in self._block_to_rank[model]: + bcast_handles.extend(self._bcast_handles[rank]) + # need to clear _bcast_handles since they would be processed later + self._bcast_handles[rank] = [] + # wait all required broadcast handles to be completed + for handle in bcast_handles: + handle.wait() + + # register_forward_pre_hook for transformer/embeding/norm/xxx block + for block, _ in self._block_to_rank.items(): + # TODO: remove special handling for embedding and head layers, + # instead implement support for weight parallelism of embedding and head layers within the ISP. + + # NOTE: Although the layernorm layer does not have explicit processing, + # both ISPCommunicator and ParamAsyncBcastHandler handle transformer blocks as granularity, + # so everything is fine. 
+ if isp_communicator is None or isinstance(block, (Embedding1D, ScaleColumnParallelLinear)): + block.register_forward_pre_hook(_pre_forward_hook) + if isp_communicator: + isp_communicator.register_prerequisite_for_forward_prefetch_hooks(_pre_forward_hook) + + def get_rank_by_param(self, param) -> int: + return self._param_to_rank[param] + + def add_bcast_handle(self, rank, handle) -> None: + self._bcast_handles[rank].append(handle) diff --git a/internlm/core/parallel/shard.py b/internlm/core/parallel/shard.py new file mode 100644 index 00000000..33c187ec --- /dev/null +++ b/internlm/core/parallel/shard.py @@ -0,0 +1,119 @@ +""" +shard strategies for parallel +""" + +from typing import Callable + +import torch +from torch import nn + +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc +from internlm.utils.logger import get_logger + +logger = get_logger(__file__) + + +# The head layer in ISP mode is actually a special case, +# and we would prefer a unified segmentation and communication logic. +def get_tensor_split_parallel_mode(is_head: bool = False) -> ParallelMode: + tp_mode = gpc.config.parallel.tensor.mode + + if tp_mode == "isp" and is_head is False: + return ParallelMode.WEIGHT + else: + return ParallelMode.TENSOR + + +def get_head_parallel_mode() -> ParallelMode: + return ParallelMode.TENSOR + + +def get_parallel_strategies_split_mode(linear_name: str) -> str: + tp_mode = gpc.config.parallel.tensor.mode + + if linear_name in ("head", "output"): + return "head" + elif linear_name in ("wqkv", "wq", "wk", "wv", "wkv", "w1", "w3", "w13"): + return "column" + elif linear_name in ("wo", "out_proj", "w2") and tp_mode == "isp": + return "column" + elif linear_name in ("wo", "out_proj", "w2"): + return "row" + else: + return "unknown" + + +def partition_uniform(num_items: int, pipeline_parallel_size: int, num_chunks: int): + assert ( + num_items % num_chunks == 0 + ), "Layer length should be divided by the number of chunks, otherwise parameter method is recomended" + + parts = [[] for _ in range(pipeline_parallel_size)] + partition_items = num_items // num_chunks + for idx in range(num_chunks): + base_idx = idx * partition_items + chunk_size = partition_items // pipeline_parallel_size + left = pipeline_parallel_size - partition_items % pipeline_parallel_size + if chunk_size == 0: + raise ValueError("Some nodes in Pipeline have no requests") + + for p in range(pipeline_parallel_size): + st = base_idx + base_idx += chunk_size + (p >= left) + parts[p].append((st, base_idx)) + + indexes = [] + for _parts in parts: + for s, e in _parts: + indexes.extend(list(range(s, e))) + assert len(indexes) == len(set(indexes)), indexes # should have no duplicates + assert set(indexes) == set(list(range(num_items))), (indexes, num_items) # should have the same indexes as expected + return parts + + +def pipeline_parallel_sharding_wrapper( + num_layers: int, num_chunks: int, model_builder: Callable, device: torch.device, **kwargs +): + """ + build generic model 1d + + Args: + num_layers (int): The number of layer. + num_chunks (int): The number of partitions in pipeline parallel. + device (Optional[Union[str, torch.device]]): The device will be used. torch.device("cuda") by default. 
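+        kwargs: Extra keyword arguments forwarded to ``model_builder`` for every chunk.
+
+    Example (a worked case of ``partition_uniform`` above; the numbers are only an
+    illustration): with ``num_layers=8``, a pipeline world size of 2 and ``num_chunks=2``,
+    the partition is ``[[(0, 2), (4, 6)], [(2, 4), (6, 8)]]``, so pipeline rank 0 builds two
+    chunks covering layers 0-1 and 4-5, while rank 1 builds layers 2-3 and 6-7.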
+ + """ + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + + all_parts = partition_uniform(num_layers, pipeline_size, num_chunks) + parts = all_parts[pipeline_rank] + + if gpc.is_rank_for_log(): + logger.info("The layer sharding is %r.", all_parts) + + models = [] + + for start, end in parts: + kwargs["num_layers"] = end - start + kwargs["first"] = start == 0 + # If there is no content in the final layer, assign the last layer. + kwargs["last"] = end == num_layers and len(all_parts[-1]) != 0 + kwargs["device"] = device + kwargs["start_layer_idx"] = start + + chunk = model_builder(**kwargs).to(device) + setattr(chunk, "first_layer", start) + setattr(chunk, "last_layer", end) + + models.append(chunk) + + torch.distributed.barrier() + + if len(models) == 1: + model = models[0] + else: + model = nn.ModuleList(models) + + return model diff --git a/internlm/core/communication/__init__.py b/internlm/core/scheduler/comm/__init__.py similarity index 100% rename from internlm/core/communication/__init__.py rename to internlm/core/scheduler/comm/__init__.py diff --git a/internlm/core/communication/p2p.py b/internlm/core/scheduler/comm/p2p.py similarity index 100% rename from internlm/core/communication/p2p.py rename to internlm/core/scheduler/comm/p2p.py diff --git a/internlm/core/scheduler/comm/utils.py b/internlm/core/scheduler/comm/utils.py new file mode 100644 index 00000000..d9e6f7e8 --- /dev/null +++ b/internlm/core/scheduler/comm/utils.py @@ -0,0 +1,125 @@ +# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/communication + +from typing import List, Tuple, Union + +import torch +import torch.distributed as dist + +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc +from internlm.utils.common import get_current_device + +TensorShape = Union[torch.Size, List[int], Tuple[int]] + + +def send_meta_helper(obj, next_rank, tensor_kwargs): + send_shape = torch.tensor(obj.size(), **tensor_kwargs) + send_ndims = torch.tensor(len(obj.size()), **tensor_kwargs) + dist.send(send_ndims, next_rank) + dist.send(send_shape, next_rank) + + +def send_obj_meta(obj, next_rank=None): + """Sends obj meta information before sending a specific obj. + Since the recipient must know the shape of the obj in p2p communications, + meta information of the obj should be sent before communications. This function + synchronizes with :func:`recv_obj_meta`. + + Args: + obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): obj to be sent. + need_meta (bool, optional): If False, meta information won't be sent. + next_rank (int): The rank of the next member in pipeline parallel group. 
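+
+    Note:
+        The metadata is sent as two ``torch.long`` tensors: the number of dimensions first,
+        then the shape itself, matching the order in which :func:`recv_obj_meta` receives them.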
+ + Returns: + bool: False + """ + if next_rank is None: + next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) + + tensor_kwargs = {"dtype": torch.long, "device": get_current_device()} + if isinstance(obj, torch.Tensor): + send_obj_nums = torch.tensor(1, **tensor_kwargs) + dist.send(send_obj_nums, next_rank) + send_meta_helper(obj, next_rank, tensor_kwargs) + else: + send_obj_nums = torch.tensor(len(obj), **tensor_kwargs) + dist.send(send_obj_nums, next_rank) + for tensor_to_send in obj: + send_meta_helper(tensor_to_send, next_rank, tensor_kwargs) + + +def recv_meta_helper(prev_rank, tensor_kwargs): + recv_ndims = torch.empty((), **tensor_kwargs) + dist.recv(recv_ndims, prev_rank) + recv_shape = torch.empty(recv_ndims, **tensor_kwargs) + dist.recv(recv_shape, prev_rank) + return recv_shape + + +def recv_obj_meta(prev_rank=None) -> torch.Size: + """Receives obj meta information before receiving a specific obj. + Since the recipient must know the shape of the obj in p2p communications, + meta information of the obj should be received before communications. This function + synchronizes with :func:`send_obj_meta`. + + Args: + obj_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the obj to be received. + prev_rank (int): The rank of the source of the obj. + + Returns: + Union[:class:`torch.Size`, List[:class:`torch.Size`]]: The shape of the obj to be received. + """ + if prev_rank is None: + prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) + + tensor_kwargs = {"dtype": torch.long, "device": get_current_device()} + recv_obj_nums = torch.empty((), **tensor_kwargs) + dist.recv(recv_obj_nums, prev_rank) + if recv_obj_nums.item() == 1: + recv_shape = recv_meta_helper(prev_rank, tensor_kwargs) + obj_shape = torch.Size(recv_shape) + else: + obj_shape = [] + for _ in range(recv_obj_nums.item()): + recv_shape = recv_meta_helper(prev_rank, tensor_kwargs) + obj_shape.append(torch.Size(recv_shape)) + + return obj_shape + + +def split_tensor_into_1d_equal_chunks(tensor: torch.Tensor, new_buffer=False) -> torch.Tensor: + """Break a tensor into equal 1D chunks. + + Args: + tensor (:class:`torch.Tensor`): Tensor to be split before communication. + new_buffer (bool, optional): Whether to use a new buffer to store sliced tensor. + + Returns: + :class:`torch.Tensor`: The split tensor + """ + partition_size = torch.numel(tensor) // gpc.get_world_size(ParallelMode.TENSOR) + start_index = partition_size * gpc.get_local_rank(ParallelMode.TENSOR) + end_index = start_index + partition_size + if new_buffer: + data = torch.empty(partition_size, dtype=tensor.dtype, device=get_current_device(), requires_grad=False) + data.copy_(tensor.view(-1)[start_index:end_index]) + else: + data = tensor.view(-1)[start_index:end_index] + return data + + +def gather_split_1d_tensor(tensor: torch.Tensor) -> torch.Tensor: + """Opposite of above function, gather values from model parallel ranks. + + Args: + tensor (:class:`torch.Tensor`): Tensor to be gathered after communication. + Returns: + :class:`torch.Tensor`: The gathered tensor. 
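+
+    The input is expected to be the local 1D chunk produced by
+    :func:`split_tensor_into_1d_equal_chunks`; for example, with a tensor-parallel world size
+    of 4 and a local chunk of 1024 elements, the gathered result is a flat tensor of
+    4 * 1024 = 4096 elements.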
+ """ + world_size = gpc.get_world_size(ParallelMode.TENSOR) + numel = torch.numel(tensor) + numel_gathered = world_size * numel + gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=get_current_device(), requires_grad=False) + chunks = [gathered[i * numel : (i + 1) * numel] for i in range(world_size)] + dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.TENSOR)) + return gathered diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py index 3aacf77a..339a404e 100644 --- a/internlm/core/scheduler/no_pipeline_scheduler.py +++ b/internlm/core/scheduler/no_pipeline_scheduler.py @@ -85,10 +85,7 @@ def _load_accum_batch(self, data: Any, label: Any): self._grad_accum_offset += self._bsz_stride if self.data_process_func: - _data["input_ids"] = self.data_process_func(_data["input_ids"], _data["cu_seqlens"]) - _label = self.data_process_func(_label, _data["cu_seqlens"], padding_v=-100) - _data.pop("cu_seqlens") - _data.pop("indexes") + _data, _label = self.data_process_func(_data, _label) return _data, _label diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py index 66d1cca2..97caa9f0 100644 --- a/internlm/core/scheduler/pipeline_scheduler.py +++ b/internlm/core/scheduler/pipeline_scheduler.py @@ -9,11 +9,11 @@ import torch import torch.distributed as dist -import internlm.core.communication as comm from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.core.engine import Engine from internlm.core.naive_amp import NaiveAMPModel +from internlm.core.scheduler import comm from internlm.utils.common import ( SchedulerHook, check_data_is_packed, @@ -220,16 +220,9 @@ def load_micro_batch(self): micro_batch_data, micro_batch_label = self._load_micro_batch( data=self.batch_data, label=self.batch_label, offset=self.microbatch_offset, bsz_stride=self.bsz_stride ) - if self.data_process_func: - micro_batch_data["input_ids"] = self.data_process_func( - micro_batch_data["input_ids"], micro_batch_data["cu_seqlens"] - ) - micro_batch_label = self.data_process_func( - micro_batch_label, micro_batch_data["cu_seqlens"], padding_v=-100 - ) - micro_batch_data.pop("cu_seqlens") - micro_batch_data.pop("indexes") + if self.data_process_func: + micro_batch_data, micro_batch_label = self.data_process_func(micro_batch_data, micro_batch_label) micro_batch_data["label"] = micro_batch_label self.microbatch_offset += self.bsz_stride diff --git a/internlm/data/utils.py b/internlm/data/utils.py index 4461c001..91585a70 100644 --- a/internlm/data/utils.py +++ b/internlm/data/utils.py @@ -5,7 +5,9 @@ import torch +from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc +from internlm.core.parallel.comm.utils import _split def get_dataset_type_ids_map(path): @@ -24,34 +26,51 @@ def get_dataset_type_id(dataset_type_ids_map, path): return match_idxes[0] -def unpack_data(input_ids, cu_seqlens, is_type_ids: bool = False, padding_v: int = 0): - """ - input_ids: if input_ids is not type_ids, the shape is (1, packed_length) - else the shape is (micro_num, packed_length) - is_type_ids: whether the input_ids is type_ids - - Return: - output: if input_ids is not type ids, the shape is (micro_bsz, max_length) - else the shape is (micro_num, micro_bsz, max_length) - """ - bsz = input_ids.shape[0] +def _unpack_data(data, cu_seqlens, padding_v: int = 0): + bsz = data.shape[0] num_seq = 
gpc.config.data["micro_bsz"]
     seq_len_ = gpc.config.data.seq_len
-    dtype_ = input_ids.dtype
+    dtype_ = data.dtype
 
-    outputs = torch.empty(bsz, num_seq, seq_len_, device=input_ids.device, dtype=dtype_).fill_(padding_v)
+    outputs = torch.empty(bsz, num_seq, seq_len_, device=data.device, dtype=dtype_).fill_(padding_v)
 
     for i in range(bsz):
-        output = torch.empty(num_seq, seq_len_, device=input_ids.device, dtype=dtype_).fill_(padding_v)
+        output = torch.empty(num_seq, seq_len_, device=data.device, dtype=dtype_).fill_(padding_v)
         cu_seqlens_slice = cu_seqlens[i]
         for j in range(num_seq):
             length = cu_seqlens_slice[j + 1] - cu_seqlens_slice[j]
-            output[j, 0:length] = input_ids[i, cu_seqlens_slice[j] : cu_seqlens_slice[j + 1]]
+            output[j, 0:length] = data[i, cu_seqlens_slice[j] : cu_seqlens_slice[j + 1]]
         outputs[i] = output
 
-    # if the input_ids is not type_ids, we need squeeze the first dimension if it is 1.
-    if bsz == 1 and not is_type_ids:
-        outputs = outputs.squeeze(0)
-    return outputs
+    return outputs
+
+
+def unpack_type_ids(type_ids, cu_seqlens):
+    return _unpack_data(type_ids, cu_seqlens)
+
+
+def unpack_data(data, label):
+    data["input_ids"] = _unpack_data(data["input_ids"], data["cu_seqlens"], padding_v=0).squeeze(0)
+    label = _unpack_data(label, data["cu_seqlens"], padding_v=-100).squeeze(0)
+
+    data.pop("cu_seqlens")
+    data.pop("indexes")
+
+    return data, label
+
+
+def packed_data_normalizer(data, label):
+    # Whether to normalize packed data here in the data processor or leave it to the
+    # dataset is still open; for now we lean towards the dataset.
+    assert data["input_ids"].shape[0] == 1, "data should be packed with batch size 1"
+
+    data["indexes"] = data["indexes"][0]
+    data["cu_seqlens"] = data["cu_seqlens"][0].squeeze(0)
+    data["max_seqlen"] = (data["cu_seqlens"][1:] - data["cu_seqlens"][:-1]).max().item()
+
+    # Consider moving this split into the parallel package for standardization.
+    if gpc.config.parallel.sequence_parallel and gpc.config.parallel["tensor"].get("mode", "mtp") == "isp":
+        data["indexes"] = _split(data["indexes"], ParallelMode.TENSOR, dim=0)
+
+    return data, label
diff --git a/internlm/initialize/initialize_trainer.py b/internlm/initialize/initialize_trainer.py
index be42897e..b90a25e9 100644
--- a/internlm/initialize/initialize_trainer.py
+++ b/internlm/initialize/initialize_trainer.py
@@ -22,7 +22,7 @@
 )
 from internlm.core.scheduler.pipeline_scheduler import get_tensor_shape
 from internlm.core.trainer import Trainer
-from internlm.data.utils import unpack_data
+from internlm.data.utils import packed_data_normalizer, unpack_data
 from internlm.solver.optimizer.hybrid_zero_optim import BaseOptimizer
 from internlm.solver.schedulers.beta2_scheduler import Beta2Scheduler
 from internlm.utils.common import SchedulerHook, get_current_device
@@ -79,10 +79,9 @@ def initialize_trainer(
     # initialize scheduler for trainer
     scheduler = None
-    if gpc.config.data.use_packed_dataset:
-        data_fn = None
-    else:
-        data_fn = unpack_data
+
+    data_fn = packed_data_normalizer if gpc.config.data.use_packed_dataset else unpack_data
+
     if gpc.is_using_parallel_mode(ParallelMode.PIPELINE):
         gpc.config.NUM_MICRO_BATCHES = gpc.config.data.micro_num
         tensor_shape = get_tensor_shape()
diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index 0ea000e9..df835681 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -13,10 +13,6 @@
 from internlm.core.context import Config
 from internlm.core.context import global_context as gpc
 from internlm.core.context.process_group_initializer import ParallelMode
-from internlm.model.moe.megablock.utils import ( - check_megablock_installed, - check_stk_installed, -) from internlm.utils.common import get_master_node from internlm.utils.gputest import warmup_process_group from internlm.utils.logger import get_logger @@ -89,7 +85,7 @@ def args_sanity_check(): gpc.config.parallel._add_item("pipeline", dict(size=1, interleaved_overlap=False)) if "tensor" not in gpc.config.parallel: - gpc.config.parallel._add_item("tensor", 1) + gpc.config.parallel._add_item("tensor", dict(size=1, mode="mtp")) if "weight" not in gpc.config.parallel: gpc.config.parallel._add_item("weight", dict(size=1, overlap=False, memory_pool=False)) @@ -339,16 +335,21 @@ def args_sanity_check(): model._add_item("moe_use_residual", False) if "moe_type" not in model: model._add_item("moe_type", "GShard") - # check dependency - if gpc.config.model.moe_type == "MegaBlock": - check_megablock_installed() - if gpc.config.model.moe_type == "MegaBlock-D": - check_megablock_installed() - check_stk_installed() if "mlp_layer_fusion" not in model: model._add_item("mlp_layer_fusion", False) + # qk_interleaved config + if "qk_interleaved" not in gpc.config.model: + if "adapt_hf" in gpc.config.model: + model._add_item("qk_interleaved", not gpc.config.model.adapt_hf) + else: + model._add_item("qk_interleaved", False) + elif "adapt_hf" in gpc.config.model: + assert gpc.config.model.adapt_hf == ( + not gpc.config.model.qk_interleaved + ), "adapt_hf and qk_interleaved must be opposite" + # process the parallel config if "sequence_parallel" not in gpc.config.parallel: gpc.config.parallel._add_item("sequence_parallel", False) diff --git a/internlm/model/__init__.py b/internlm/model/__init__.py index 26ac3e7c..e69de29b 100644 --- a/internlm/model/__init__.py +++ b/internlm/model/__init__.py @@ -1,33 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from .metrics import AccPerplex -from .modeling_internlm import build_model_with_cfg -from .modeling_internlm2 import build_model_with_cfg as build_model_with_cfg2 -from .modeling_llama import build_model_with_cfg as build_model_with_llama_cfg -from .modeling_llava import build_model_with_cfg as build_model_with_llava_cfg -from .modeling_moe import build_model_with_moe_cfg -from .modules.embedding import Embedding1D, RotaryEmbedding -from .modules.mlp import FeedForward -from .modules.multi_head_attention import MHA, DistributedAttention -from .moe.moe import MoE -from .ops.linear import RewardModelLinear, ScaleColumnParallelLinear -from .utils import gather_forward_split_backward - -__all__ = [ - "Embedding1D", - "FeedForward", - "MoE", - "RotaryEmbedding", - "RewardModelLinear", - "ScaleColumnParallelLinear", - "AccPerplex", - "MHA", - "DistributedAttention", - "gather_forward_split_backward", - "build_model_with_cfg", - "build_model_with_cfg2", - "build_model_with_moe_cfg", - "build_model_with_llama_cfg", - "build_model_with_llava_cfg", -] diff --git a/internlm/model/builder.py b/internlm/model/builder.py new file mode 100644 index 00000000..2b10406b --- /dev/null +++ b/internlm/model/builder.py @@ -0,0 +1,36 @@ +from typing import List, Union + +from torch import nn + +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc +from internlm.core.parallel.shard import pipeline_parallel_sharding_wrapper +from internlm.model.registry import model_initializer +from internlm.utils.common import get_current_device + + +def create_model(model_type, *args, **kwargs) -> Union[nn.Module, List[nn.Module]]: + 
num_layers = kwargs.pop("num_layers") + num_chunks = kwargs.pop("num_chunks", 1) + + # TODO: fix use_flash_attn parameter config + kwargs.pop("use_flash_attn", False) + kwargs.pop("apply_post_layer_norm") + kwargs.pop("embed_split_hidden", True) + + kwargs["checkpoint"] = float(kwargs.get("checkpoint", False)) + kwargs["device"] = get_current_device() + + model_buidler = model_initializer.get_module(module_name=model_type) + + if not gpc.is_using_parallel_mode(ParallelMode.PIPELINE): + kwargs["first"] = kwargs["last"] = True + kwargs["start_layer_idx"] = 0 + kwargs["num_layers"] = num_layers + model = model_buidler(*args, **kwargs).to(kwargs["device"]) + setattr(model, "first_layer", 0) + setattr(model, "last_layer", num_layers) + else: + model = pipeline_parallel_sharding_wrapper(num_layers, num_chunks, model_buidler, *args, **kwargs) + + return model diff --git a/internlm/model/llava/__init__.py b/internlm/model/llava/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/internlm/model/llava_modules/clip_builder.py b/internlm/model/llava/clip_builder.py similarity index 100% rename from internlm/model/llava_modules/clip_builder.py rename to internlm/model/llava/clip_builder.py diff --git a/internlm/model/llava_modules/clip_encoder.py b/internlm/model/llava/clip_encoder.py similarity index 100% rename from internlm/model/llava_modules/clip_encoder.py rename to internlm/model/llava/clip_encoder.py diff --git a/internlm/model/llava_modules/projector_builder.py b/internlm/model/llava/projector_builder.py similarity index 100% rename from internlm/model/llava_modules/projector_builder.py rename to internlm/model/llava/projector_builder.py diff --git a/internlm/model/losses/ce_loss.py b/internlm/model/losses/ce_loss.py index 3fe4858b..69e09d2f 100644 --- a/internlm/model/losses/ce_loss.py +++ b/internlm/model/losses/ce_loss.py @@ -3,9 +3,11 @@ from torch import nn -from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.model.ops.fusion_ops_import_helper import internlm_init_CrossEntropyLoss +from internlm.model.ops.cross_entropy import new_cross_entropy +from internlm.utils.logger import get_logger + +logger = get_logger(__file__) class FlashGPTLMLoss(nn.Module): @@ -24,12 +26,11 @@ def __init__(self, parallel_output=True, label_smoothing=0): label_smoothing = 0 self.label_smoothing = label_smoothing - self.loss_fn = internlm_init_CrossEntropyLoss( - parallel_output=parallel_output, + self.loss_fn = new_cross_entropy( reduction="mean", - inplace_backward=True, - process_group=gpc.get_group(ParallelMode.TENSOR), label_smoothing=self.label_smoothing, + parallel_output=parallel_output, + inplace_backward=True, ) def forward(self, *args): diff --git a/internlm/model/metrics.py b/internlm/model/metrics.py index 6db8044a..54cc41ba 100644 --- a/internlm/model/metrics.py +++ b/internlm/model/metrics.py @@ -3,17 +3,21 @@ import torch from internlm.accelerator import AcceleratorType, get_accelerator -from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.model.ops.fusion_ops_import_helper import ( - internlm_init_CrossEntropyLoss, - try_import_scatter_sum, -) +from internlm.model.ops.cross_entropy import new_cross_entropy from internlm.utils.common import SchedulerHook, get_current_device +from internlm.utils.logger import get_logger from internlm.utils.megatron_timers import megatron_timer as timer +try: + from torch_scatter import scatter as cuda_scatter + + 
cuda_scatter_impl = True +except (ModuleNotFoundError, ImportError): + cuda_scatter_impl = False + +logger = get_logger(__file__) internlm_accelerator = get_accelerator() -scatter_sum = try_import_scatter_sum() def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): @@ -51,6 +55,24 @@ def vanilla_scatter( return out.scatter_add_(dim, index, src) +# move to ops when there are more than one files use it. +def _get_scatter_sum_impl(): + if cuda_scatter_impl and internlm_accelerator.get_accelerator_backend() in ( + AcceleratorType.GPU, + AcceleratorType.DIPU, + ): + if gpc.is_rank_for_log(): + logger.warning("Use cuda_scatter. Please note this!") + return cuda_scatter + else: + if gpc.is_rank_for_log(): + logger.warning("Use vanilla_scatter rather than cuda_scatter. Please note this!") + return vanilla_scatter + + +scatter_sum_impl = _get_scatter_sum_impl() + + class AccPerplex: """ AccPerplex module for calculating model's accuracy and perplexity metrics. @@ -88,7 +110,7 @@ def __init__(self, device, tp_pg, dp_pg, tokenizer=None, dataset_types: List[str self.ds_tokens = torch.zeros(self.total_type_count, dtype=torch.long, device=device) self.loss_with_type_id = LossWithTypeId(device, dp_pg, dataset_types) - self.scatter_sum = scatter_sum if scatter_sum else vanilla_scatter + self.scatter_sum = scatter_sum_impl def set_current_type_ids(self, type_ids: torch.Tensor): self.batch_shift = 0 @@ -257,13 +279,12 @@ def __init__(self, device, dp_pg, dataset_types: List[str] = None) -> None: self.ds_loss = torch.zeros(self.total_type_count, dtype=torch.float, device=device) self.ds_token_num = torch.zeros(self.total_type_count, dtype=torch.float, device=device) - self.loss_fn = internlm_init_CrossEntropyLoss( - parallel_output=gpc.config.model.parallel_output, + self.loss_fn = new_cross_entropy( reduction="none", + parallel_output=gpc.config.model.parallel_output, inplace_backward=True, - process_group=gpc.get_group(ParallelMode.TENSOR), ) - self.scatter_sum = scatter_sum if scatter_sum else vanilla_scatter + self.scatter_sum = scatter_sum_impl def update(self, logits, labels, type_ids=None): with torch.no_grad(): diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py index ef5f7e9f..28c5a69c 100644 --- a/internlm/model/modeling_internlm.py +++ b/internlm/model/modeling_internlm.py @@ -12,27 +12,21 @@ from internlm.core.naive_amp import set_output_attr_to_module from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal from internlm.model.modules.embedding import Embedding1D -from internlm.model.modules.mlp import get_mlp_cls -from internlm.model.modules.multi_head_attention import MHA -from internlm.model.ops.fusion_ops_import_helper import try_import_RMSNorm -from internlm.model.ops.linear import RewardModelLinear, ScaleColumnParallelLinear +from internlm.model.modules.linear import new_linear +from internlm.model.modules.mha import MHA +from internlm.model.modules.mlp import new_feed_forward +from internlm.model.modules.norm import new_layer_norm from internlm.model.utils import ( - gather_forward_split_backward, - split_forward_gather_backward, + internlm1_mha_pre_load_convert, + internlm1_mha_save_convert, ) from internlm.solver.activation_checkpoint import activation_checkpoint -from internlm.solver.pipeline_utils import partition_uniform -from internlm.utils.common import filter_kwargs, get_current_device from internlm.utils.logger import get_logger -from internlm.utils.registry import MODEL_INITIALIZER - -MODEL_TYPE = 
"INTERNLM" logger = get_logger(__file__) -RMSNorm = try_import_RMSNorm() -class PackedFlashBaseLayer1D(nn.Module): +class InternLM1Decoder(nn.Module): """ 1D Packed Flash Base Layer. @@ -42,15 +36,22 @@ class PackedFlashBaseLayer1D(nn.Module): mlp_ratio (int): The ratio of MLP layers. 4 by default. attn_drop_rate (float): The dropout rate of attention module. 0 by default. drop_rate (float): The dropout rate of the input hidden state. 0.0 by default. + max_position_embeddings (int): The maximum position embeddings. 2048 by default. dtype (torch.dtype): Type of data. torch.float by default. layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. layer_idx (int): The index of current layer. 0 by default. + use_dynamic_ntk_rope (bool): Whether to use dynamic ntk rope. False by default. residual_in_fp32 (bool): Whether to use residual in fp32. False by default. device (Optional[Union[str, torch.device]]): The device will be used. norm_type (str): Use RMS norm or layernorm."rmsnorm" by default. - use_flash_attn (bool): Whether use flash-attn. True by default. + qk_interleaved (bool): Whether the odd and even columns of the wq and wk are normally interleaved. + dropout_selective_checkpoint (bool): Whether to selectively checkpoint dropout layers only. + use_scaled_init (bool): Whether to use scaled initialization for weights. + use_swiglu (bool): Whether to use SwiGLU activation in the mlp module. rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. + mlp_layer_fusion (bool): Whether to fuse layers in the mlp module for optimization. + multiple_of (int): Ensures mlp dimensions are multiples of this value for efficient hardware utilization. """ def __init__( @@ -69,11 +70,10 @@ def __init__( residual_in_fp32: bool = False, device: Optional[torch.device] = None, norm_type: str = "rmsnorm", + qk_interleaved: bool = False, dropout_selective_checkpoint: bool = True, use_scaled_init: bool = True, use_swiglu: bool = True, - use_flash_attn: bool = True, - tp_mode: str = "mtp", rope_base: int = 10000, mlp_layer_fusion: bool = False, multiple_of: int = 256, @@ -83,18 +83,14 @@ def __init__( # dropout selective checkpoint can only be enabled when checkpoint is disabled. 
self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False self.layer_idx = layer_idx - self.use_flash_attn = use_flash_attn head_dim = hidden_size // num_attention_heads - self.tp_mode = tp_mode - parallel_mode = ParallelMode.WEIGHT if self.tp_mode == "isp" else ParallelMode.TENSOR self.mixer = MHA( embed_dim=hidden_size, num_heads=num_attention_heads, - process_group=gpc.get_group(parallel_mode), - sequence_process_group=gpc.get_group(ParallelMode.TENSOR), dropout=attn_drop_rate, + bias=True, max_position_embeddings=max_position_embeddings, softmax_scale=1 / math.sqrt(head_dim), causal=True, @@ -102,37 +98,35 @@ def __init__( use_dynamic_ntk_rope=use_dynamic_ntk_rope, rotary_emb_dim=head_dim, rotary_emb_scale_base=0, - use_flash_attn=use_flash_attn, rope_base=rope_base, device=device, dtype=dtype, - tp_mode=self.tp_mode, + qk_interleaved=qk_interleaved, + enable_qkv_fusion=True, ) - self.dropout1 = nn.Dropout(drop_rate) - if norm_type == "rmsnorm": - self.norm1 = RMSNorm(hidden_size, eps=layer_norm_epsilon) - self.norm2 = RMSNorm(hidden_size, eps=layer_norm_epsilon) - else: - self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) - self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) - - if use_swiglu or not use_flash_attn: - mlp_cls = get_mlp_cls(self.tp_mode) - self.mlp = mlp_cls( - hidden_size, - int(hidden_size * mlp_ratio), - out_features=hidden_size, - process_group=gpc.get_group(parallel_mode), - bias=False, - device=device, - dtype=dtype, - mlp_layer_fusion=mlp_layer_fusion, - sequence_parallel=gpc.config.parallel.sequence_parallel, - multiple_of=multiple_of, - ) + # Compatible with the name of internlm1 Wqkv linear layer + self.mixer.register_checkpoint_compatibility_hooks(internlm1_mha_pre_load_convert, internlm1_mha_save_convert) + self.dropout1 = nn.Dropout(drop_rate) self.dropout2 = nn.Dropout(drop_rate) + + self.norm1 = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) + self.norm2 = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) + + self.mlp = new_feed_forward( + hidden_size, + int(hidden_size * mlp_ratio), + out_features=hidden_size, + bias=False, + device=device, + dtype=dtype, + mlp_layer_fusion=mlp_layer_fusion, + multiple_of=multiple_of, + # TODO: to support more activation functions + activation_type="swiglu" if use_swiglu else "swiglu", + ) + self.use_swiglu = use_swiglu self.use_scaled_init = use_scaled_init self.residual_in_fp32 = residual_in_fp32 # only make sense when using prenorm @@ -144,7 +138,7 @@ def reset_parameters(self): for name, param in self.mixer.named_parameters(): if param.ndim == 1: param.data.zero_() - elif "Wqkv" in name: + elif "wqkv" in name: normal_(std=0.006)(param.data) elif self.use_scaled_init: scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data) @@ -166,15 +160,13 @@ def reset_parameters(self): else: normal_(std=0.006 if "fc1" in name else 0.0015)(param.data) - def forward(self, hidden_states, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None): + def forward(self, hidden_states, **kwargs): if self.checkpoint and self.training: - return activation_checkpoint( - self._forward, False, hidden_states, cu_seqlens, indexes, inference_params, max_seqlen - ) + return activation_checkpoint(self._forward, False, hidden_states, **kwargs) else: - return self._forward(hidden_states, cu_seqlens, indexes, inference_params, max_seqlen) + return self._forward(hidden_states, **kwargs) - def _forward(self, hidden_states=None, 
cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None): + def _forward(self, hidden_states=None, **kwargs): r"""Pass the input through the encoder layer. Args: @@ -183,12 +175,6 @@ def _forward(self, hidden_states=None, cu_seqlens=None, indexes=None, inference_ cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1 indexes: the length of index is same as hidden states, which stand for the current position """ - mixer_kwargs = { - "cu_seqlens": cu_seqlens, - "max_seqlen": max_seqlen, - "indexes": indexes, - "inference_params": inference_params, - } def _dropout_and_norm_attn(_hidden_states): _dropped = self.dropout1(_hidden_states) @@ -204,7 +190,7 @@ def _dropout_and_norm_attn(_hidden_states): if self.residual_in_fp32: residual = residual.to(torch.float32) - hidden_states = self.mixer(hidden_states, **mixer_kwargs) + hidden_states = self.mixer(hidden_states, **kwargs) def _dropout_and_norm_ffn(_residual, _hidden_states): _dropped = self.dropout2(_hidden_states) @@ -225,7 +211,7 @@ def _dropout_and_norm_ffn(_residual, _hidden_states): return hidden_states + residual -class PackedFlashInternLm1D(nn.Module): +class InternLM1(nn.Module): """ 1D Packed Flash InternLm. @@ -237,23 +223,27 @@ class PackedFlashInternLm1D(nn.Module): mlp_ratio (int): The ratio of MLP layers. 4 by default. attn_drop_rate (float): The dropout rate of attention module. 0.0 by default. drop_rate (float): The dropout rate of input hidden state. 0.0 by default. + max_position_embeddings (int): The maximum position embeddings. 2048 by default. dtype (torch.dtype): The type of data. torch.float by default. checkpoint (float): The proportion of layers that need to be checkpointed compared to the total number of layers. 0.0 by default. - layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default. + layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. first (bool): Whether input embedding layer or not. False by default. last (bool): Whether output embedding layer or not. False by default. - embed_split_hidden (bool): Split the embedding layer in the hidden state dimention or vocabulary dimention. - True by default. embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. start_layer_idx (int): The index of start layer in the pipeline. 0 by default. + use_dynamic_ntk_rope (bool): Whether to use dynamic ntk rope. False by default. device (Optional[Union[str, torch.device]]): The device will be used. None by default. residual_in_fp32 (bool): Whether to use residual in fp32. False by default. norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. - use_flash_attn (bool): Whether to use flash-attn. True by default. + qk_interleaved (bool): Whether the odd and even columns of the wq and wk are normally interleaved. + dropout_selective_checkpoint (bool): Whether to selectively checkpoint dropout and norm layers. + use_scaled_init (bool): Whether to use scaled initialization for weights. + use_swiglu (bool): Whether to use SwiGLU activation in the mlp module. rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. - + mlp_layer_fusion (bool): Whether to fuse layers in the mlp module for optimization. + multiple_of (int): Ensures mlp dimensions are multiples of this value for efficient hardware utilization. 
""" def __init__( @@ -271,7 +261,6 @@ def __init__( layer_norm_epsilon: float = 1e-5, first: bool = False, last: bool = False, - embed_split_hidden: bool = False, embed_grad_scale: float = 0.1, parallel_output: bool = True, start_layer_idx: int = 0, @@ -279,11 +268,11 @@ def __init__( device: Optional[torch.device] = None, residual_in_fp32: bool = False, norm_type: str = "rmsnorm", + qk_interleaved: bool = False, is_reward: bool = False, dropout_selective_checkpoint: bool = True, use_scaled_init: bool = True, use_swiglu: bool = True, - use_flash_attn: bool = True, rope_base: int = 10000, mlp_layer_fusion: bool = False, multiple_of: int = 256, @@ -291,25 +280,17 @@ def __init__( super().__init__() checkpoint_layer_num = int(num_layers * checkpoint) - self.tp_mode = "mtp" - if isinstance(gpc.config.parallel["tensor"], dict): - self.tp_mode = gpc.config.parallel["tensor"].get("mode", "mtp") + self.embed_grad_scale = embed_grad_scale + self.parallel_output = parallel_output - if is_reward: - head_cls = RewardModelLinear - else: - head_cls = ScaleColumnParallelLinear if first: - self.embedding = Embedding1D( - num_embeddings=vocab_size, embedding_dim=hidden_size, embed_split_hidden=embed_split_hidden - ) + self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) for _, param in self.embedding.named_parameters(): normal_(std=0.0052)(param) - self.embed_grad_scale = embed_grad_scale self.blocks = nn.ModuleList( [ - PackedFlashBaseLayer1D( + InternLM1Decoder( hidden_size=hidden_size, num_attention_heads=num_attention_heads, mlp_ratio=mlp_ratio, @@ -327,36 +308,32 @@ def __init__( dropout_selective_checkpoint=dropout_selective_checkpoint, use_scaled_init=use_scaled_init, use_swiglu=use_swiglu, - use_flash_attn=use_flash_attn, - tp_mode=self.tp_mode, rope_base=rope_base, + qk_interleaved=qk_interleaved, mlp_layer_fusion=mlp_layer_fusion, multiple_of=multiple_of, ) for lid in range(num_layers) ] ) + if last: - if norm_type == "rmsnorm": - self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon) - else: - self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) - self.head = head_cls( + self.norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) + self.head = new_linear( + name="head", in_features=hidden_size, out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size, - process_group=gpc.get_group(ParallelMode.TENSOR), bias=False, device=device, dtype=dtype, + is_reward=is_reward, weight_scale=embed_grad_scale, ) set_output_attr_to_module(self.head) for _, param in self.head.named_parameters(): normal_(std=0.0052)(param) - self.parallel_output = parallel_output - - def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None): + def forward(self, hidden_states=None, input_ids=None, **kwargs): # attention_mask: compute attention on the places where the value is 1 if hasattr(self, "embedding") and input_ids is not None: hidden_states = self.embedding(input_ids) @@ -365,172 +342,12 @@ def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=N self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach() ) - if isinstance(cu_seqlens, list): - assert len(cu_seqlens) == 1 - cu_seqlens = cu_seqlens[0].to(hidden_states.device) - - if cu_seqlens is not None: - cu_seqlens = cu_seqlens.squeeze(0) - - if indexes is not None: - assert len(indexes) == 1 - # The indexes are used to indicate the actual position IDs of each token in the packed input. 
- indexes = indexes[0] - # if the sequence parallel mode is 'isp', the indexes should also be split in sequence dimension. - if gpc.config.parallel.sequence_parallel and self.tp_mode == "isp": - indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0) - - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None - for _, block in enumerate(self.blocks): - hidden_states = block( - hidden_states, - cu_seqlens=cu_seqlens, - indexes=indexes, - inference_params=inference_params, - max_seqlen=max_seqlen, - ) + hidden_states = block(hidden_states, **kwargs) if hasattr(self, "norm"): hidden_states = self.norm(hidden_states.float()) if hasattr(self, "head"): - hidden_states = self.head(hidden_states, gather_dim=1, tp_mode=self.tp_mode) + hidden_states = self.head(hidden_states) - if not self.parallel_output and gpc.is_pipeline_last_stage(): - hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1) return hidden_states - - -def _build_generic_model_1d(num_layers, num_chunks, **kwargs): - """ - build generic model 1d - - Args: - num_layers (int): The number of layer. - num_chunks (int): The number of partitions in pipeline parallel. - device (Optional[Union[str, torch.device]]): The device will be used. internlm_accelerator.device() by default. - - """ - device = get_current_device() - pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) - pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - - all_parts = partition_uniform(num_layers, pipeline_size, num_chunks) - parts = all_parts[pipeline_rank] - if gpc.is_rank_for_log(): - logger.info(f"The layer sharding is {all_parts}.") - - models = [] - start_idx, end_idx = 0, 0 - for start, end in parts: - start_idx, end_idx = start, end - kwargs["num_layers"] = end - start - kwargs["first"] = start == 0 - # If there is no content in the final layer, assign the last layer. - kwargs["last"] = end == num_layers and len(all_parts[-1]) != 0 - kwargs["device"] = device - kwargs["start_layer_idx"] = start - chunk = PackedFlashInternLm1D(**filter_kwargs(PackedFlashInternLm1D.__init__, kwargs)).to(device) - - models.append(chunk) - torch.distributed.barrier() - if len(models) == 1: - model = models[0] - else: - model = nn.ModuleList(models) - setattr(model, "first_layer", start_idx) - setattr(model, "last_layer", end_idx) - return model - - -@MODEL_INITIALIZER.register_module(module_name=MODEL_TYPE) -def build_model_with_cfg( - num_chunks=1, - checkpoint=0.0, - dtype=torch.float, - embed_split_hidden=False, - num_layers=48, - hidden_size=2048, - vocab_size=50304, - embed_grad_scale=1, - parallel_output=True, - num_attention_heads=32, - max_position_embeddings=2048, - mlp_ratio=4.0, - residual_in_fp32=False, - use_dynamic_ntk_rope=False, - norm_type="rmsnorm", - drop_rate=0, - attn_drop_rate=0, - apply_post_layer_norm=False, # pylint: disable=W0613 - layer_norm_epsilon=1e-5, - is_reward=False, - dropout_selective_checkpoint=True, - use_scaled_init: bool = True, - use_swiglu: bool = True, - use_flash_attn: bool = True, - rope_base: int = 10000, - mlp_layer_fusion: bool = False, - multiple_of: int = 256, -): - """ - Build model with config. - - Args: - num_chunks (int): The number of partitions in pipeline parallel. 1 by default. - checkpoint (bool): Whether to use checkpointing to save VRAM. False by default. - dtype (torch.dtype): The type of data. torch.float by default. 
- embed_split_hidden (bool): Split the embedding layer in the hidden state dimention or vocabulary dimention. - False by default. - num_layers (int): The number of layer. 48 by default. - hidden_size (int): The size of hidden state. 2048 by default. - vocab_size (int): The size of vocabulary. 50304 by default. - embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. - parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. - num_attention_heads (int): The number of attention head. 32 by default. - mlp_ratio (int): The ratio of MLP layers. 4.0 by default. - residual_in_fp32 (bool): Whether to use residual in fp32. False by default. It cannot be used temporarily - because this parameter requires inconsistent data types to be passed between pipelines, - which requires significant modifications to internlm. - norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. - drop_rate (float): The dropout rate of input hidden state. 0 by default. - attn_drop_rate (float): The dropout rate of attention module. 0 by default. - apply_post_layer_norm (bool): Whether to apply post layer norm. False by default. - layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. - is_reward (bool): Whether to use reward model. False by default. - dropout_selective_checkpoint (bool): It can only be enabled when checkpoint is disabled. True by default. - use_scaled_init (bool): Whether to use scaled init. True by default. - use_swiglu (bool): Whether to use swiglu. True by default. - use_flash_attn (bool): Whether to use flash-attn. True by default. - rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. 
- - """ - - cfg = dict( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - checkpoint=checkpoint, - dtype=dtype, - embed_split_hidden=embed_split_hidden, - vocab_size=vocab_size, - embed_grad_scale=embed_grad_scale, - parallel_output=parallel_output, - mlp_ratio=mlp_ratio, - residual_in_fp32=residual_in_fp32, - max_position_embeddings=max_position_embeddings, - use_dynamic_ntk_rope=use_dynamic_ntk_rope, - norm_type=norm_type, - drop_rate=drop_rate, - attn_drop_rate=attn_drop_rate, - layer_norm_epsilon=layer_norm_epsilon, - is_reward=is_reward, - dropout_selective_checkpoint=dropout_selective_checkpoint, - use_scaled_init=use_scaled_init, - use_swiglu=use_swiglu, - use_flash_attn=use_flash_attn, - rope_base=rope_base, - mlp_layer_fusion=mlp_layer_fusion, - multiple_of=multiple_of, - ) - - return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg) diff --git a/internlm/model/modeling_internlm2.py b/internlm/model/modeling_internlm2.py index 08065ddc..fa65db79 100644 --- a/internlm/model/modeling_internlm2.py +++ b/internlm/model/modeling_internlm2.py @@ -3,11 +3,8 @@ from typing import Optional import torch -import torch.nn.functional as F -from einops import rearrange from torch import nn -from internlm.accelerator import AcceleratorType, get_accelerator from internlm.core.context import ParallelMode from internlm.core.context.parallel_context import global_context as gpc from internlm.initialize.initialize_tensor import ( @@ -16,467 +13,25 @@ scaled_init_method_uniform, uniform_, ) -from internlm.model.modules.embedding import ( - DynamicNTKScalingRotaryEmbedding, - Embedding1D, - RotaryEmbedding, -) -from internlm.model.modules.mlp import get_mlp_cls -from internlm.model.modules.multi_head_attention import ( - _update_kv_cache, - get_gqa_attn_cls, -) -from internlm.model.ops.fusion_ops_import_helper import try_import_RMSNorm -from internlm.model.ops.linear import ( - RewardModelLinear, - ScaleColumnParallelLinearWithNormHead, - get_linear_cls, -) -from internlm.model.utils import ( - gather_forward_split_backward, - pack_output_after_attn, - split_forward_gather_backward, - unpack_qkv_before_attn, -) +from internlm.model.modules.embedding import Embedding1D +from internlm.model.modules.linear import new_linear +from internlm.model.modules.mha import GQA +from internlm.model.modules.mlp import new_feed_forward +from internlm.model.modules.norm import new_layer_norm from internlm.solver.activation_checkpoint import activation_checkpoint -from internlm.solver.pipeline_utils import partition_uniform -from internlm.utils.common import filter_kwargs, get_current_device from internlm.utils.logger import get_logger -from internlm.utils.registry import MODEL_INITIALIZER - -MODEL_TYPE = "INTERNLM2_PUBLIC" logger = get_logger(__file__) -RMSNorm = try_import_RMSNorm() -internlm_accelerator = get_accelerator() -class MHA(nn.Module): +class InternLM2Decoder(nn.Module): """ - Multi-head self-attention and cross-attention. - - Args: - embed_dim (int): The dimention of hidden state. - num_heads (int): The number of attention heads. - num_kv_heads (int): The number of attention heads for key and value. - process_group (torch.distributed.ProcessGroup): The group of the current device for `parallel_mode`. - sequence_process_group (torch.distributed.ProcessGroup): The process group for attention calculation. - bias (bool): Whether the bias is needed for linears. Will be used when initializing QKV matrix and - output projection. False by default. 
- dropout (float): The dropout rate for cross attention and self attention. 0.0 by default. - softmax_scale (float): The temperature to use for the softmax attention. - causal (boolean): Whether to apply causal attention mask. False by default. - layer_idx (int): The index of current layer. None by default. - rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. - rotary_emb_dim (int): The dimention of Rotary Embedding. 0 by default. - rotary_emb_scale_base (int): The scaling factor of Rotary Embedding. If scale_base > 0, this implements - XPos(Sun et al., https://arxiv.org/abs/2212.10554). 0 by default. - use_flash_attn (bool): Whether to use flash attention or not.If False, vanilla attention module will be used. - False by default. - device (Optional[Union[str, torch.device]]): The device will be used. - dtype (Optional[torch.dtype]): The type of data. - rot_embed_HF_impl (Optional[bool]): Whether to use the rotary embedding implementation from HuggingFace. - True by default. - tp_mode (str): The string value of tensor parallel mode, should be in ["mtp", "msp", "fsp", "isp"], - "mtp" by default. - """ - - def __init__( - self, - embed_dim: int, - num_heads: int, - num_kv_heads: int, - process_group: Optional[torch.distributed.ProcessGroup], - sequence_process_group: Optional[torch.distributed.ProcessGroup], - max_position_embeddings: int = 2048, - bias: bool = False, - dropout: float = 0.0, - softmax_scale: float = None, - causal: bool = False, - layer_idx: int = None, - use_dynamic_ntk_rope: bool = False, - use_flash_attn: bool = True, - rope_base: int = 10000, - rotary_emb_dim: int = 0, - rotary_emb_scale_base: int = 0, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - rot_embed_HF_impl: Optional[bool] = True, - tp_mode: str = "mtp", - ) -> None: - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - assert self.embed_dim % num_heads == 0, "embedding dim must be divisible by num_heads" - - self.head_dim = self.embed_dim // num_heads - self.num_kv_heads = num_kv_heads - self.kv_dim = self.head_dim * num_kv_heads - self.causal = causal - self.layer_idx = layer_idx - self.rotary_emb_dim = rotary_emb_dim - self.use_flash_attn = use_flash_attn - self.dtype = dtype - - self.q_per_kv = num_heads // num_kv_heads - - self.rot_embed_HF_impl = rot_embed_HF_impl - sequence_parallel = gpc.config.parallel.get("sequence_parallel", False) - - self.max_position_embeddings = max_position_embeddings - self.use_dynamic_ntk_rope = use_dynamic_ntk_rope - self.tp_mode = tp_mode - - if self.rotary_emb_dim > 0: - if self.use_dynamic_ntk_rope: - self.rotary_emb = DynamicNTKScalingRotaryEmbedding( - self.rotary_emb_dim, - base=rope_base, - scale_base=rotary_emb_scale_base, - device=device, - max_position_embeddings=max_position_embeddings, - scaling_factor=1.0, # Currently do not support dynamic scaling. 
- ) - else: - self.rotary_emb = RotaryEmbedding( - self.rotary_emb_dim, base=rope_base, scale_base=rotary_emb_scale_base, device=device - ) - - Wqkv_cls = get_linear_cls(self.tp_mode, "column") - self.wqkv = Wqkv_cls( - embed_dim, - embed_dim + 2 * self.kv_dim, - process_group, - bias=bias, - sequence_parallel=sequence_parallel, - **factory_kwargs, - ) - - self.inner_attn, self.inner_cross_attn = get_gqa_attn_cls( - use_flash_attn, self.tp_mode, causal, softmax_scale, dropout, sequence_process_group - ) - self.inner_cross_attn_causal = causal - self.inner_cross_attn_softmax_scale = softmax_scale - self.inner_cross_attn_dropout = dropout - - wo_cls = get_linear_cls(self.tp_mode, "row") - self.wo = wo_cls( - embed_dim, - embed_dim, - process_group, - bias=bias, - sequence_parallel=sequence_parallel, - **factory_kwargs, - ) - - def forward(self, x, seqlen=None, inference_params=None, **kwargs): - if kwargs.get("indexes", None) is not None: - return self._packed_forward(x=x, inference_params=inference_params, **kwargs) - else: - return self._forward(x=x, seqlen=seqlen, inference_params=inference_params, **kwargs) - - def _forward(self, x, seqlen=None, inference_params=None, **kwargs): # pylint: disable=W0613 - """ - Arguments: - x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None. - If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we - split x during sequence parallel, we split the batch * seqlen dimension - (in case batch is small). - """ - bsz, _, _ = x.shape - qkv = self.wqkv(x) - - if seqlen is None: - qkv = rearrange(qkv, "b s (h gs d) -> b s h gs d", gs=self.q_per_kv + 2, d=self.head_dim) - else: - qkv = rearrange(qkv, "(b s) (h gs d) -> b s h gs d", s=seqlen, gs=self.q_per_kv + 2, d=self.head_dim) - - q, k, v = (qkv[..., : self.q_per_kv, :], qkv[..., -2, :], qkv[..., -1, :]) - - q = rearrange(q, "b s h gs d -> b s (h gs) d") - - if not self.rot_embed_HF_impl: - q = torch.cat([q[..., ::2], q[..., 1::2]], dim=-1) - k = torch.cat([k[..., ::2], k[..., 1::2]], dim=-1) - - if inference_params is None: - if self.rotary_emb_dim > 0: - q = self.rotary_emb._single_eval_forward(q) - k = self.rotary_emb._single_eval_forward(k) - kv = torch.concat([k.unsqueeze(2), v.unsqueeze(2)], dim=2) - if self.dtype is torch.float32 and self.use_flash_attn: - if q.dtype not in [torch.float16, torch.bfloat16]: - q = q.to(torch.bfloat16) - if kv.dtype not in [torch.float16, torch.bfloat16]: - kv = kv.to(torch.bfloat16) - with internlm_accelerator.amp.autocast(dtype=torch.bfloat16): - context = self.inner_cross_attn(q=q, kv=kv).to(self.dtype) - else: - context = self.inner_cross_attn(q=q, kv=kv) - - else: - assert self.rotary_emb_dim > 0 - if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None: - empties = inference_params.attention_mask[..., -1].sum(dim=-1) - moved_q = q.clone() - moved_k = k.clone() - if inference_params.sequence_len_offset == 0: - for i in range(len(empties)): - if empties[i] != 0: - moved_q[i][: -empties[i]] = q[i][empties[i] :] - moved_k[i][: -empties[i]] = k[i][empties[i] :] - moved_q = self.rotary_emb._single_eval_forward( - moved_q, seqlen_offset=inference_params.sequence_len_offset - ) - moved_k = self.rotary_emb._single_eval_forward( - moved_k, seqlen_offset=inference_params.sequence_len_offset - ) - for i in range(len(empties)): - if empties[i] != 0: - q[i][empties[i] :] = moved_q[i][: -empties[i]] - k[i][empties[i] :] = moved_k[i][: -empties[i]] - else: - q[i] = moved_q[i] - k[i] = 
moved_k[i] - else: - q = self.rotary_emb._single_forward( - q, - inference_params.sequence_len_offset * torch.ones(q.size(0), dtype=torch.int, device=q.device) - - empties, - ) - k = self.rotary_emb._single_forward( - k, - inference_params.sequence_len_offset * torch.ones(k.size(0), dtype=torch.int, device=k.device) - - empties, - ) - else: - raise NotImplementedError( - "You should make sure you are aware that you are changing the method of generating." - "According to your generation function instead of inference/seq_generator_module.py, " - "You may implement here for normal running." - ) - - kv = torch.stack([k, v], dim=2) - - assert self.layer_idx is not None, "Generation requires layer_idx in the constructor" - if hasattr(inference_params, "window_size") and inference_params.window_size is not None: - if inference_params.window_size <= inference_params.sequence_len_offset: - assert kv.size(1) == 1, "update kv lenth more than 1" - inference_params.key_value_memory_dict[self.layer_idx][ - :, inference_params.keep_first : inference_params.window_size - 1, ... - ] = inference_params.key_value_memory_dict[self.layer_idx][ - :, -(inference_params.window_size - 1 - inference_params.keep_first) :, ... - ].clone() - inference_params.real_sequence_len_offset = inference_params.sequence_len_offset - inference_params.sequence_len_offset = inference_params.window_size - 1 - - kv = _update_kv_cache(kv, inference_params, self.layer_idx) - - inference_params.sequence_len_offset = inference_params.real_sequence_len_offset - else: - kv = _update_kv_cache(kv, inference_params, self.layer_idx) - else: - kv = _update_kv_cache(kv, inference_params, self.layer_idx) - - # When using FP16, there is a high probability of NAN in the KV. - # Since NAN cannot be removed by multiplying with and 0, it needs - # to be removed manually here. - kv = torch.where(torch.isnan(kv), 0, kv) - - if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None: - from flash_attn import flash_attn_varlen_kvpacked_func - - if inference_params.sequence_len_offset == 0: # First entrance, attnmask (bs*seqlen*seqlen) - attn_mask = inference_params.attention_mask[:, None, ...] 
- attn_mask = torch.logical_or( - torch.ones_like(attn_mask, dtype=torch.bool).triu(diagonal=1), attn_mask - ) - attn_mask4flsh = ~attn_mask[:, :, -1, :].view(bsz, -1) - cu_seqlens = torch.concat( - [ - torch.tensor([0], dtype=torch.int32, device=attn_mask4flsh.device), - attn_mask4flsh.sum(dim=-1).to(dtype=torch.int32), - ], - dim=0, - ) - cu_seqlens = cu_seqlens.cumsum(dim=0, dtype=torch.int32) - max_seqlen_q = attn_mask4flsh.shape[-1] - max_seqlen_k = attn_mask4flsh.shape[-1] - total_q = q.masked_select(attn_mask4flsh.view(bsz, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1]) - total_kv = kv.masked_select(attn_mask4flsh.view(bsz, -1, 1, 1, 1)).view( - -1, kv.shape[-3], kv.shape[-2], kv.shape[-1] - ) - if self.dtype is torch.float32: - if total_q.dtype not in [torch.float16, torch.bfloat16]: - total_q = total_q.to(torch.bfloat16) - if total_kv.dtype not in [torch.float16, torch.bfloat16]: - total_kv = total_kv.to(torch.bfloat16) - with internlm_accelerator.amp.autocast(dtype=torch.bfloat16): - output = flash_attn_varlen_kvpacked_func( - q=total_q, - kv=total_kv, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - dropout_p=0.0, - causal=True, - ).to(self.dtype) - else: - output = flash_attn_varlen_kvpacked_func( - q=total_q, - kv=total_kv, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - dropout_p=0.0, - causal=True, - ) - - context = torch.zeros_like(q) - context = context.masked_scatter_(attn_mask4flsh.view(bsz, -1, 1, 1), output) - - else: - attn_mask = inference_params.attention_mask[:, -1, :].view(bsz, 1, 1, -1) - if hasattr(inference_params, "window_size") and inference_params.window_size is not None: - if inference_params.window_size <= inference_params.sequence_len_offset: - attn_mask = torch.concat( - [ - attn_mask[..., : inference_params.keep_first], - attn_mask[..., -(inference_params.window_size - inference_params.keep_first) :], - ], - dim=-1, - ) - - k, v = torch.chunk(kv, 2, dim=2) - k = k.squeeze(2) - v = v.squeeze(2) - sp = k.shape - expansion = q.size(2) // k.size(2) - scores = torch.einsum( - "blhd,bnhd->bhln", - q, - k.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]), - ) / math.sqrt(q.size(-1)) - scores = scores.masked_fill(attn_mask, -65000.0) - scores = F.softmax(scores, dim=-1) # bsz x h x L x L - context = torch.einsum( - "bhmn,bnhd->bmhd", - scores, - v.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]), - ) - else: - if self.dtype is torch.float32 and self.use_flash_attn: - if q.dtype not in [torch.float16, torch.bfloat16]: - q = q.to(torch.bfloat16) - if kv.dtype not in [torch.float16, torch.bfloat16]: - kv = kv.to(torch.bfloat16) - with internlm_accelerator.amp.autocast(dtype=torch.bfloat16): - context = self.inner_cross_attn(q=q, kv=kv, causal=True).to(self.dtype) - else: - context = self.inner_cross_attn(q=q, kv=kv, causal=True) - - if seqlen is None: - context = rearrange(context, "b s h d -> b s (h d)") - else: - context = rearrange(context, "b s h d -> (b s) (h d)") - - out = self.wo(context) - return out - - def _packed_forward(self, x, inference_params=None, **kwargs): - """ - Arguments: - x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None. - If seqlen is not None, x is (batch * seqlen, hidden_dim). 
This is so that when we - split x during sequence parallel, we split the batch * seqlen dimension - (in case batch is small). - """ - assert self.use_flash_attn is True - - qkv = self.wqkv(x) - - qkv = rearrange(qkv, "b t (h gs d) -> b t h gs d", gs=self.q_per_kv + 2, d=self.head_dim) - - q, k, v = (qkv[..., : self.q_per_kv, :], qkv[..., -2, :], qkv[..., -1, :]) - - q = rearrange(q, "b t h gs d -> b t (h gs) d") - - # qkv shift - # the rotary embedding in flash attention module in performed by separating the front and back parts, while - # most of others are done by odd-even methods. - if not self.rot_embed_HF_impl: - q = torch.cat([q[..., ::2], q[..., 1::2]], dim=-1) - k = torch.cat([k[..., ::2], k[..., 1::2]], dim=-1) - - indexes = kwargs.pop("indexes") - - q = self.rotary_emb._single_forward(q, indexes=indexes) - k = self.rotary_emb._single_forward(k, indexes=indexes) - - if inference_params is None: - kv = torch.concat([k.unsqueeze(2), v.unsqueeze(2)], dim=2) - # for packed data, batch dimension with a size of 1 should be directly squeezed off. - if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU: - q = q.squeeze(0) - kv = kv.squeeze(0) - # since torch_npu only supports fa with no packed data currently, qkv should be unpacked - elif internlm_accelerator.get_accelerator_backend() in [AcceleratorType.NPU, AcceleratorType.DIPU]: - q = unpack_qkv_before_attn(q, kwargs["cu_seqlens"]) - kv = unpack_qkv_before_attn(kv, kwargs["cu_seqlens"]) - - if self.dtype is torch.float32: - if q.dtype not in [torch.float16, torch.bfloat16]: - q = q.to(torch.bfloat16) - if kv.dtype not in [torch.float16, torch.bfloat16]: - kv = kv.to(torch.bfloat16) - with internlm_accelerator.amp.autocast(dtype=torch.bfloat16): - context = self.inner_attn( - q=q, - kv=kv, - cu_seqlens_q=kwargs["cu_seqlens"], - cu_seqlens_k=kwargs["cu_seqlens"], - max_seqlen_q=kwargs["max_seqlen"], - max_seqlen_k=kwargs["max_seqlen"], - dropout_p=self.inner_cross_attn_dropout, - softmax_scale=self.inner_cross_attn_softmax_scale, - causal=self.inner_cross_attn_causal, - ).to(self.dtype) - else: - context = self.inner_attn( - q=q, - kv=kv, - cu_seqlens_q=kwargs["cu_seqlens"], - cu_seqlens_k=kwargs["cu_seqlens"], - max_seqlen_q=kwargs["max_seqlen"], - max_seqlen_k=kwargs["max_seqlen"], - dropout_p=self.inner_cross_attn_dropout, - softmax_scale=self.inner_cross_attn_softmax_scale, - causal=self.inner_cross_attn_causal, - ) - else: - raise RuntimeError("Not support this right now") - - if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU: - context = rearrange(context, "s h d -> s (h d)") # recover the shape - context = context.unsqueeze(0) # restore bsz dimension - elif internlm_accelerator.get_accelerator_backend() in [AcceleratorType.NPU, AcceleratorType.DIPU]: - context = rearrange(context, "b s h d -> b s (h d)") # recover the shape - context = pack_output_after_attn(context, kwargs["cu_seqlens"]) - - out = self.wo(context) - return out - - -class PackedFlashLlamaLayer1D(nn.Module): - """ - InternLM2 layer. + InternLM2 Decoder layer. Args: hidden_size (int): The hidden size of model. 768 by default. num_attention_heads (int): The number of attention heads. 12 by default. + num_kv_attention_heads (int): The number of key/value attention heads. Defaults to 12. mlp_ratio (int): The ratio of MLP layers. 4 by default. attn_drop_rate (float): The dropout rate of attention module. 0 by default. drop_rate (float): The dropout rate of the input hidden state. 0.0 by default. 
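Note on the packed path removed above: `_packed_forward` delimits concatenated sequences with `cu_seqlens`/`max_seqlen` before calling the varlen flash-attention kernel. A minimal pure-PyTorch sketch of that packed layout, with made-up shapes, showing what the varlen call computes (causal attention run independently inside each `[start, end)` segment):

    import torch
    import torch.nn.functional as F

    # Three sequences of lengths 3, 5, 2 packed along a single token axis.
    lengths = torch.tensor([3, 5, 2], dtype=torch.int32)
    cu_seqlens = F.pad(lengths.cumsum(dim=0, dtype=torch.int32), (1, 0))   # [0, 3, 8, 10]
    max_seqlen = int(lengths.max())

    num_heads, head_dim = 4, 8
    total_tokens = int(cu_seqlens[-1])
    q = torch.randn(total_tokens, num_heads, head_dim)
    kv = torch.randn(total_tokens, 2, num_heads, head_dim)                 # packed (k, v)

    # Naive reference: causal attention restricted to each packed segment.
    out = torch.zeros_like(q)
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        q_i = q[start:end].transpose(0, 1)                                 # (h, s, d)
        k_i = kv[start:end, 0].transpose(0, 1)
        v_i = kv[start:end, 1].transpose(0, 1)
        scores = q_i @ k_i.transpose(-1, -2) / head_dim**0.5
        causal = torch.ones(end - start, end - start, dtype=torch.bool).triu(1)
        scores = scores.masked_fill(causal, float("-inf"))
        out[start:end] = (scores.softmax(dim=-1) @ v_i).transpose(0, 1)

This is only a reference for the data layout; the actual kernel fuses the loop and never materializes the per-segment score matrices.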
@@ -488,10 +43,16 @@ class PackedFlashLlamaLayer1D(nn.Module): use_dynamic_ntk_rope (bool): Whether to use dynamic ntk rope. False by default. residual_in_fp32 (bool): Whether to use residual in fp32. False by default. device (Optional[Union[str, torch.device]]): The device will be used. + apply_post_layer_norm (bool): Whether to apply layer normalization after the attention and mlp. + Defaults to False. + fused_dropout_add_ln (bool): Whether to fuse dropout, residual addition, and layer normalization. + Defaults to True. + no_bias (bool): Whether to exclude bias in attention and feed-forward networks. Defaults to False. norm_type (str): Use RMS norm or layernorm."rmsnorm" by default. - use_flash_attn (bool): Whether use flash-attn. True by default. - tp_mode (str): The string value of tensor parallel mode, should be in ["mtp", "msp", "fsp", "isp"], - "mtp" by default. + qk_interleaved (bool): Whether the odd and even columns of the wq and wk are normally interleaved. + dropout_selective_checkpoint (bool): Whether to selectively checkpoint dropout layers only. + use_scaled_init (bool): Whether to use scaled initialization for weights. + use_swiglu (bool): Whether to use SwiGLU activation in the mlp module. attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu @@ -499,6 +60,8 @@ class PackedFlashLlamaLayer1D(nn.Module): ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default, init_type (str): Initialization type. Use uniform or normal. "normal" by default, rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. + mlp_layer_fusion (bool): Whether to fuse layers in the mlp module for optimization. + multiple_of (int): Ensures mlp dimensions are multiples of this value for efficient hardware utilization. """ def __init__( @@ -521,12 +84,10 @@ def __init__( fused_dropout_add_ln: bool = True, no_bias: bool = False, norm_type: str = "rmsnorm", - adapt_hf: bool = True, + qk_interleaved: bool = False, dropout_selective_checkpoint: bool = True, use_scaled_init: bool = True, use_swiglu: bool = True, - use_flash_attn: bool = True, - tp_mode: str = "mtp", attn_wqkv_init_std: float = 0.02, attn_other_init_std: float = 0.02, ffn_uplayer_init_std: float = 0.02, @@ -541,7 +102,6 @@ def __init__( # dropout selective checkpoint can only be enabled when checkpoint is disabled. 
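The `qk_interleaved` argument replaces the old `adapt_hf`/`rot_embed_HF_impl` switch, which governed the layout shift the deleted attention code applied with `torch.cat([q[..., ::2], q[..., 1::2]], dim=-1)` before RoPE. A toy sketch of that shift and its inverse, using a made-up `head_dim`:

    import torch

    head_dim = 8
    q = torch.arange(head_dim).float()        # stand-in for one head's query row

    # Interleaved (q1, q2, q3, ...) -> front/back (q1, q3, ..., q2, q4, ...) layout.
    q_front_back = torch.cat([q[..., ::2], q[..., 1::2]], dim=-1)

    # Inverse shift: recover the interleaved ordering from the front/back layout.
    half = head_dim // 2
    q_interleaved = torch.stack(
        [q_front_back[..., :half], q_front_back[..., half:]], dim=-1
    ).flatten(-2)
    assert torch.equal(q_interleaved, q)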
self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False self.layer_idx = layer_idx - self.use_flash_attn = use_flash_attn self.prenorm = not apply_post_layer_norm assert not fused_dropout_add_ln, "dropout_add_layer_norm can not be used here" self.fused_dropout_add_ln = fused_dropout_add_ln @@ -552,16 +112,12 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.use_dynamic_ntk_rope = use_dynamic_ntk_rope - self.tp_mode = tp_mode - parallel_mode = ParallelMode.WEIGHT if self.tp_mode == "isp" else ParallelMode.TENSOR head_dim = hidden_size // num_attention_heads - self.attention = MHA( + self.attention = GQA( embed_dim=hidden_size, num_heads=num_attention_heads, num_kv_heads=num_kv_attention_heads, - process_group=gpc.get_group(parallel_mode), - sequence_process_group=gpc.get_group(ParallelMode.TENSOR), dropout=attn_drop_rate, max_position_embeddings=max_position_embeddings, softmax_scale=1 / math.sqrt(head_dim), @@ -570,39 +126,32 @@ def __init__( use_dynamic_ntk_rope=use_dynamic_ntk_rope, rotary_emb_dim=head_dim, rotary_emb_scale_base=0, - use_flash_attn=use_flash_attn, device=device, dtype=dtype, - rot_embed_HF_impl=adapt_hf, + qk_interleaved=qk_interleaved, bias=not no_bias, rope_base=rope_base, - tp_mode=self.tp_mode, + enable_qkv_fusion=True, ) self.dropout1 = nn.Dropout(drop_rate) - if norm_type == "rmsnorm": - self.attention_norm = RMSNorm(hidden_size, eps=layer_norm_epsilon) - self.ffn_norm = RMSNorm(hidden_size, eps=layer_norm_epsilon) - else: - self.attention_norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) - self.ffn_norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) + self.dropout2 = nn.Dropout(drop_rate) + self.attention_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) + self.ffn_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) - sequence_parallel = gpc.config.parallel.get("sequence_parallel", False) - self.feed_forward = get_mlp_cls(self.tp_mode)( + self.feed_forward = new_feed_forward( hidden_size, int(hidden_size * mlp_ratio), out_features=hidden_size, - process_group=gpc.get_group(parallel_mode), bias=False, device=device, dtype=dtype, mlp_layer_fusion=mlp_layer_fusion, - sequence_parallel=sequence_parallel, multiple_of=multiple_of, + # TODO: to support more activation functions + activation_type="swiglu" if use_swiglu else "swiglu", ) - assert use_swiglu is True, "InternLM2 only support swiglu." 
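As shown in the hunk above, the inline `norm_type` branch over `RMSNorm`/`nn.LayerNorm` is replaced by the `new_layer_norm` factory (imported from `internlm.model.modules.norm` elsewhere in this patch). A minimal sketch of such a factory, assuming a bare-bones RMSNorm rather than the project's fused implementation:

    import torch
    from torch import nn

    class SimpleRMSNorm(nn.Module):
        """Bare-bones RMSNorm, used only for this sketch."""

        def __init__(self, dim: int, eps: float = 1e-5):
            super().__init__()
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(dim))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            variance = x.float().pow(2).mean(dim=-1, keepdim=True)
            return (self.weight * x.float() * torch.rsqrt(variance + self.eps)).to(x.dtype)

    def new_layer_norm(norm_type: str, normalized_shape: int, eps: float = 1e-5) -> nn.Module:
        # Dispatch once on the configured norm type instead of branching at every call site.
        if norm_type == "rmsnorm":
            return SimpleRMSNorm(normalized_shape, eps=eps)
        return nn.LayerNorm(normalized_shape, eps=eps)

The same pattern backs `new_feed_forward` and `new_linear`, which take the place of the per-tp-mode class lookups (`get_mlp_cls`, `get_linear_cls`), so the modeling code no longer threads a process group or `tp_mode` through every layer.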
- self.dropout2 = nn.Dropout(drop_rate) self.use_swiglu = use_swiglu self.use_scaled_init = use_scaled_init self.residual_in_fp32 = residual_in_fp32 # only make sense when using prenorm @@ -646,19 +195,13 @@ def reset_parameters(self): param.data ) - def forward( - self, hidden_states, residual=None, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None - ): + def forward(self, hidden_states, residual=None, **kwargs): if self.checkpoint and self.training: - return activation_checkpoint( - self._forward, False, hidden_states, residual, cu_seqlens, indexes, inference_params, max_seqlen - ) + return activation_checkpoint(self._forward, False, hidden_states, residual, **kwargs) else: - return self._forward(hidden_states, residual, cu_seqlens, indexes, inference_params, max_seqlen) + return self._forward(hidden_states, residual, **kwargs) - def _forward( - self, hidden_states=None, residual=None, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None - ): + def _forward(self, hidden_states=None, residual=None, **kwargs): r"""Pass the input through the encoder layer. Args: @@ -683,13 +226,8 @@ def _dropout_and_norm_attn(_residual, _hidden_states): if self.residual_in_fp32: residual = residual.to(torch.float32) - mixer_kwargs = { - "cu_seqlens": cu_seqlens, - "max_seqlen": max_seqlen, - "indexes": indexes, - "inference_params": inference_params, - } - hidden_states = self.attention(hidden_states, **mixer_kwargs) + + hidden_states = self.attention(hidden_states, **kwargs) if not isinstance(self.feed_forward, nn.Identity): if not self.fused_dropout_add_ln: @@ -715,13 +253,8 @@ def _dropout_and_norm_ffn(_residual, _hidden_states): return hidden_states + residual else: assert residual is None - mixer_kwargs = { - "cu_seqlens": cu_seqlens, - "max_seqlen": max_seqlen, - "indexes": indexes, - "inference_params": inference_params, - } - mixer_out = self.attention(hidden_states, **mixer_kwargs) + + mixer_out = self.attention(hidden_states, **kwargs) if self.return_residual: # mixer out is actually a pair here mixer_out, hidden_states = mixer_out hidden_states = self.attention_norm(self.dropout1(mixer_out) + hidden_states).to( @@ -737,28 +270,26 @@ def _dropout_and_norm_ffn(_residual, _hidden_states): return hidden_states -class PackedFlashLlama1D(nn.Module): +class InternLM2(nn.Module): """ - 1D Packed Flash InternLM2. + InternLM2 Model. Args: - num_layers (int): The number of layer. 12 by default. - hidden_size (int): The size of hidden state. 768 by default. - num_attention_heads (int): The number of attention head. 12 by default. + num_layers (int): The number of layer. 48 by default. + hidden_size (int): The size of hidden state. 2048 by default. + num_attention_heads (int): The number of attention head. 32 by default. + num_kv_attention_heads (int): The number of key/value attention heads. Defaults to 32. vocab_size (int): The size of vocabulary. 50304 by default. mlp_ratio (int): The ratio of MLP layers. 4 by default. attn_drop_rate (float): The dropout rate of attention module. 0.0 by default. drop_rate (float): The dropout rate of input hidden state. 0.0 by default. max_position_embeddings (int): The maximum position embeddings. 2048 by default. dtype (torch.dtype): The type of data. torch.float by default. - checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. - checkpoint_fraction (float): The proportion of layers that need to be checkpointed compared to the total number - of layers. 1.0 by default. 
+ checkpoint (float): The proportion of layers that need to be checkpointed compared to the total number + of layers. 0.0 by default. layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default. first (bool): Whether input embedding layer or not. False by default. last (bool): Whether output embedding layer or not. False by default. - embed_split_hidden (bool): Split the embedding layer in the hidden state dimention or vocabulary dimention. - True by default. embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. start_layer_idx (int): The index of start layer in the pipeline. 0 by default. @@ -766,7 +297,10 @@ class PackedFlashLlama1D(nn.Module): device (Optional[Union[str, torch.device]]): The device will be used. None by default. residual_in_fp32 (bool): Whether to use residual in fp32. False by default. norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. - use_flash_attn (bool): Whether to use flash-attn. True by default. + qk_interleaved (bool): Whether the odd and even columns of the wq and wk are normally interleaved. + dropout_selective_checkpoint (bool): Whether to selectively checkpoint dropout and norm layers. + use_scaled_init (bool): Whether to use scaled initialization for weights. + use_swiglu (bool): Whether to use SwiGLU activation in the mlp module. embedding_init_std (float): std used to init embedding weight. 0.02 by default, attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, @@ -777,28 +311,26 @@ class PackedFlashLlama1D(nn.Module): init_type (str): Initialization type. Use uniform or normal. "normal" by default, rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. norm_head (bool): Whether to use norm head. False by default. - tp_mode (str): The string value of tensor parallel mode, should be in ["mtp", "msp", "fsp", "isp"], - "mtp" by default. + mlp_layer_fusion (bool): Whether to fuse layers in the mlp module for optimization. + multiple_of (int): Ensures mlp dimensions are multiples of this value for efficient hardware utilization. 
""" def __init__( self, - num_layers: int = 12, - hidden_size: int = 768, - num_attention_heads: int = 12, - num_kv_attention_heads: int = 12, + num_layers: int = 48, + hidden_size: int = 2048, + num_attention_heads: int = 32, + num_kv_attention_heads: int = 32, vocab_size: int = 50304, - mlp_ratio: int = 4, + mlp_ratio: float = 4.0, attn_drop_rate: float = 0.0, drop_rate: float = 0.0, max_position_embeddings: int = 2048, dtype: torch.dtype = torch.float, - checkpoint: bool = False, - checkpoint_fraction: float = 1.0, + checkpoint: float = 0.0, layer_norm_epsilon: float = 1e-5, first: bool = False, last: bool = False, - embed_split_hidden: bool = False, embed_grad_scale: float = 0.1, parallel_output: bool = True, start_layer_idx: int = 0, @@ -808,12 +340,11 @@ def __init__( no_bias=False, residual_in_fp32: bool = False, norm_type: str = "rmsnorm", - adapt_hf: bool = True, + qk_interleaved: bool = False, is_reward: bool = False, dropout_selective_checkpoint: bool = True, use_scaled_init: bool = True, use_swiglu: bool = True, - use_flash_attn: bool = True, embedding_init_std: float = 0.02, attn_wqkv_init_std: float = 0.02, attn_other_init_std: float = 0.02, @@ -823,44 +354,27 @@ def __init__( init_type: str = "normal", rope_base: int = 10000, norm_head: bool = False, - tp_mode: str = "mtp", mlp_layer_fusion: bool = False, multiple_of: int = 256, ): super().__init__() - self.use_flash_attn = use_flash_attn - - if checkpoint_fraction <= 0: - checkpoint = False - if not checkpoint: - checkpoint_fraction = 0 - checkpoint_layer_num = num_layers * checkpoint_fraction - - self.tp_mode = tp_mode - if isinstance(gpc.config.parallel["tensor"], dict): - self.tp_mode = gpc.config.parallel["tensor"].get("mode", "mtp") - - if is_reward: - head_cls = RewardModelLinear - else: - head_cls = ScaleColumnParallelLinearWithNormHead + checkpoint_layer_num = int(num_layers * checkpoint) + self.embed_grad_scale = embed_grad_scale + self.parallel_output = parallel_output if first: - self.tok_embeddings = Embedding1D( - num_embeddings=vocab_size, embedding_dim=hidden_size, embed_split_hidden=embed_split_hidden - ) + self.tok_embeddings = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) + for _, param in self.tok_embeddings.named_parameters(): if init_type == "normal": normal_(std=embedding_init_std)(param) else: uniform_(std=embedding_init_std)(param) - self.embed_grad_scale = embed_grad_scale - self.layers = nn.ModuleList( [ - PackedFlashLlamaLayer1D( + InternLM2Decoder( hidden_size=hidden_size, num_attention_heads=num_attention_heads, num_kv_attention_heads=num_kv_attention_heads, @@ -882,14 +396,12 @@ def __init__( dropout_selective_checkpoint=dropout_selective_checkpoint, use_scaled_init=use_scaled_init, use_swiglu=use_swiglu, - use_flash_attn=use_flash_attn, - adapt_hf=adapt_hf, + qk_interleaved=qk_interleaved, attn_wqkv_init_std=attn_wqkv_init_std, attn_other_init_std=attn_other_init_std, ffn_uplayer_init_std=ffn_uplayer_init_std, ffn_other_init_std=ffn_other_init_std, init_type=init_type, - tp_mode=self.tp_mode, rope_base=rope_base, mlp_layer_fusion=mlp_layer_fusion, multiple_of=multiple_of, @@ -900,23 +412,16 @@ def __init__( if last: if not apply_post_layer_norm: - if norm_type == "rmsnorm": - self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon) - else: - self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) + self.norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) - if norm_head and not issubclass(head_cls, ScaleColumnParallelLinearWithNormHead): - raise 
TypeError( - "Parameter ``norm_head`` should only be True when head_cls is " - f"``ScaleColumnParallelLinearWithNormHead``, instead of {head_cls}." - ) - self.output = head_cls( # pylint: disable=E1123 + self.output = new_linear( + name="output", in_features=hidden_size, out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size, - process_group=gpc.get_group(ParallelMode.TENSOR), bias=False, device=device, dtype=dtype, + is_reward=is_reward, weight_scale=embed_grad_scale, norm_head=norm_head, ) @@ -926,9 +431,7 @@ def __init__( else: uniform_(std=out_head_init_std)(param) - self.parallel_output = parallel_output - - def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None): + def forward(self, hidden_states=None, input_ids=None, **kwargs): # attention_mask: compute attention on the places where the value is 1 if hasattr(self, "tok_embeddings") and input_ids is not None: hidden_states = self.tok_embeddings(input_ids) @@ -936,210 +439,13 @@ def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=N hidden_states = ( self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach() ) - if isinstance(cu_seqlens, list): - assert len(cu_seqlens) == 1 - cu_seqlens = cu_seqlens[0].to(hidden_states.device) - - if cu_seqlens is not None: - cu_seqlens = cu_seqlens.squeeze(0) - - if indexes is not None: - assert len(indexes) == 1 - # The indexes are used to indicate the actual position IDs of each token in the packed input. - indexes = indexes[0] - # if the sequence parallel mode is 'isp', the indexes should also be split in sequence dimension. - if gpc.config.parallel.sequence_parallel and self.tp_mode == "isp": - indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0) - - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None for _, block in enumerate(self.layers): - hidden_states = block( - hidden_states, - residual=None, - cu_seqlens=cu_seqlens, - indexes=indexes, - inference_params=inference_params, - max_seqlen=max_seqlen, - ) + hidden_states = block(hidden_states, residual=None, **kwargs) if hasattr(self, "norm"): hidden_states = self.norm(hidden_states.float()) if hasattr(self, "output"): - hidden_states = self.output(hidden_states, gather_dim=1, tp_mode=self.tp_mode) - - if not self.parallel_output and gpc.is_pipeline_last_stage(): - hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1) + hidden_states = self.output(hidden_states) return hidden_states - - -def _build_generic_model_1d(num_layers, num_chunks, **kwargs): - """ - build generic model 1d - - Args: - num_layers (int): The number of layer. - num_chunks (int): The number of partitions in pipeline parallel. - device (Optional[Union[str, torch.device]]): The device will be used. internlm_accelerator.device() by default. 
- - """ - device = get_current_device() - pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) - pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - - all_parts = partition_uniform(num_layers, pipeline_size, num_chunks) - parts = all_parts[pipeline_rank] - if gpc.is_rank_for_log(): - logger.info(f"The layer sharding is {all_parts}.") - - models = [] - kwargs["checkpoint_fraction"] = float(kwargs.get("checkpoint", False)) - start_idx, end_idx = 0, 0 - for start, end in parts: - start_idx, end_idx = start, end - kwargs["num_layers"] = end - start - kwargs["first"] = start == 0 - # If there is no content in the final layer, assign the last layer. - kwargs["last"] = end == num_layers and len(all_parts[-1]) != 0 - kwargs["device"] = device - kwargs["start_layer_idx"] = start - chunk = PackedFlashLlama1D(**filter_kwargs(PackedFlashLlama1D.__init__, kwargs)).to(device) - - models.append(chunk) - torch.distributed.barrier() - if len(models) == 1: - model = models[0] - else: - model = nn.ModuleList(models) - setattr(model, "first_layer", start_idx) - setattr(model, "last_layer", end_idx) - return model - - -@MODEL_INITIALIZER.register_module(module_name=MODEL_TYPE) -def build_model_with_cfg( - num_chunks=1, - checkpoint=False, - dtype=torch.float, - embed_split_hidden=False, - num_layers=48, - hidden_size=2048, - vocab_size=50304, - embed_grad_scale=1, - parallel_output=True, - num_attention_heads=32, - num_kv_attention_heads=None, - mlp_ratio=4.0, - residual_in_fp32=False, - norm_type="rmsnorm", - adapt_hf=True, - drop_rate=0, - attn_drop_rate=0, - apply_post_layer_norm=False, # pylint: disable=W0613 - no_bias=False, - deepnorm=False, - layer_norm_epsilon=1e-5, - is_reward=False, - dropout_selective_checkpoint=True, - use_scaled_init: bool = True, - use_swiglu: bool = True, - use_flash_attn: bool = True, - embedding_init_std: float = 0.02, - attn_wqkv_init_std: float = 0.02, - attn_other_init_std: float = 0.02, - ffn_uplayer_init_std: float = 0.02, - ffn_other_init_std: float = 0.02, - out_head_init_std: float = 0.02, - init_type: str = "normal", - rope_base: int = 10000, - norm_head: bool = False, - max_position_embeddings=2048, - use_dynamic_ntk_rope=False, - mlp_layer_fusion: bool = False, - multiple_of: int = 256, -): - """ - Builde model with config - - Args: - num_chunks (int): The number of partitions in pipeline parallel. 1 by default. - checkpoint (bool): Whether to use checkpointing to save VRAM. False by default. - dtype (torch.dtype): The type of data. torch.float by default. - embed_split_hidden (bool): Split the embedding layer in the hidden state dimention or vocabulary dimention. - False by default. - num_layers (int): The number of layer. 48 by default. - hidden_size (int): The size of hidden state. 2048 by default. - vocab_size (int): The size of vocabulary. 50304 by default. - embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. - parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. - num_attention_heads (int): The number of attention head. 32 by default. - mlp_ratio (int): The ratio of MLP layers. 4.0 by default. - residual_in_fp32 (bool): Whether to use residual in fp32. False by default. It cannot be used temporarily - because this parameter requires inconsistent data types to be passed between pipelines, - which requires significant modifications to internlm. - norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. 
- drop_rate (float): The dropout rate of input hidden state. 0 by default. - attn_drop_rate (float): The dropout rate of attention module. 0 by default. - apply_post_layer_norm (bool): Whether to apply post layer norm. False by default. - layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. - is_reward (bool): Whether to use reward model. False by default. - dropout_selective_checkpoint (bool): It can only be enabled when checkpoint is disabled. True by default. - use_scaled_init (bool): Whether to use scaled init. True by default. - use_swiglu (bool): Whether to use swiglu. True by default. - use_flash_attn (bool): Whether to use flash-attn. True by default. - embedding_init_std (float): std used to init embedding weight. 0.02 by default, - attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, - attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, - ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu - otherwise init fc1 weight in ffn. 0.02 by default, - ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default, - out_head_init_std (float): std used to init output lmhead weight. 0.02 by default, - init_type (str): Initialization type. Use uniform or normal. "normal" by default, - rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. - max_position_embeddings (int): The maximum position embeddings. 2048 by default. - use_dynamic_ntk_rope (bool): Whether to use dynamic ntk rope. False by default. - """ - if deepnorm: - raise AssertionError("deepnorm will not be supported in future versions." "Use early versions if necessary.") - - cfg = dict( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - num_kv_attention_heads=num_kv_attention_heads if num_kv_attention_heads else num_attention_heads, - checkpoint=checkpoint, - dtype=dtype, - embed_split_hidden=embed_split_hidden, - vocab_size=vocab_size, - embed_grad_scale=embed_grad_scale, - parallel_output=parallel_output, - mlp_ratio=mlp_ratio, - apply_post_layer_norm=apply_post_layer_norm, - no_bias=no_bias, - residual_in_fp32=residual_in_fp32, - norm_type=norm_type, - adapt_hf=adapt_hf, - drop_rate=drop_rate, - attn_drop_rate=attn_drop_rate, - layer_norm_epsilon=layer_norm_epsilon, - is_reward=is_reward, - dropout_selective_checkpoint=dropout_selective_checkpoint, - use_scaled_init=use_scaled_init, - use_swiglu=use_swiglu, - use_flash_attn=use_flash_attn, - embedding_init_std=embedding_init_std, - attn_wqkv_init_std=attn_wqkv_init_std, - attn_other_init_std=attn_other_init_std, - ffn_uplayer_init_std=ffn_uplayer_init_std, - ffn_other_init_std=ffn_other_init_std, - out_head_init_std=out_head_init_std, - init_type=init_type, - rope_base=rope_base, - norm_head=norm_head, - max_position_embeddings=max_position_embeddings, - use_dynamic_ntk_rope=use_dynamic_ntk_rope, - mlp_layer_fusion=mlp_layer_fusion, - multiple_of=multiple_of, - ) - - return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg) diff --git a/internlm/model/modeling_llama.py b/internlm/model/modeling_llama.py index adbb9a9a..1e077d4f 100644 --- a/internlm/model/modeling_llama.py +++ b/internlm/model/modeling_llama.py @@ -2,11 +2,8 @@ from typing import Optional import torch -import torch.nn.functional as F -from einops import rearrange from torch import nn -from internlm.accelerator import AcceleratorType, get_accelerator from internlm.core.context import 
ParallelMode from internlm.core.context.parallel_context import global_context as gpc from internlm.core.naive_amp import set_output_attr_to_module @@ -16,462 +13,25 @@ scaled_init_method_uniform, uniform_, ) -from internlm.model.modules.embedding import Embedding1D, RotaryEmbedding -from internlm.model.modules.mlp import get_mlp_cls -from internlm.model.modules.multi_head_attention import ( - _update_kv_cache, - get_gqa_attn_cls, -) -from internlm.model.ops.fusion_ops_import_helper import try_import_RMSNorm -from internlm.model.ops.linear import ( - RewardModelLinear, - ScaleColumnParallelLinear, - get_linear_cls, -) -from internlm.model.utils import ( - gather_forward_split_backward, - pack_output_after_attn, - split_forward_gather_backward, - unpack_qkv_before_attn, -) +from internlm.model.modules.embedding import Embedding1D +from internlm.model.modules.linear import new_linear +from internlm.model.modules.mha import GQA +from internlm.model.modules.mlp import new_feed_forward +from internlm.model.modules.norm import new_layer_norm from internlm.solver.activation_checkpoint import activation_checkpoint -from internlm.solver.pipeline_utils import partition_uniform -from internlm.utils.common import filter_kwargs, get_current_device from internlm.utils.logger import get_logger -from internlm.utils.registry import MODEL_INITIALIZER - -MODEL_TYPE = "LLAMA2" logger = get_logger(__file__) -RMSNorm = try_import_RMSNorm() -internlm_accelerator = get_accelerator() -class MHA(nn.Module): +class Llama2Decoder(nn.Module): """ - Multi-head self-attention and cross-attention. - - Args: - embed_dim (int): The dimention of hidden state. - num_heads (int): The number of attention heads. - process_group (torch.distributed.ProcessGroup): The group of the current device for `parallel_mode`. - sequence_process_group (torch.distributed.ProcessGroup): The process group for attention calculation. - bias (boolean): Whether the bias is needed for linears. Will be used when initializing QKV matrix and - output projection. True by default. - dropout (float): The dropout rate for cross attention and self attention. 0.0 by default. - softmax_scale (float): The temperature to use for the softmax attention. - causal (boolean): Whether to apply causal attention mask. False by default. - layer_idx (int): The index of current layer. None by default. - rotary_emb_dim (int): The dimention of Rotary Embedding. 0 by default. - rotary_emb_scale_base (int): The scaling factor of Rotary Embedding. If scale_base > 0, this implements - XPos(Sun et al., https://arxiv.org/abs/2212.10554). 0 by default. - use_flash_attn (boolean): Whether to use flash attention or not.If False, vanilla attention module will be used. - False by default. - device (Optional[Union[str, torch.device]]): The device will be used. - dtype (Optional[torch.dtype]): The type of data. - use_flash_attn (bool): Whether to use flash-attn. True by default. - rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. - tp_mode (str): The string value of tensor parallel mode, should be in ["mtp", "msp", "fsp", "isp"], - "mtp" by default. 
- - """ - - def __init__( - self, - embed_dim: int, - num_heads: int, - num_kv_heads: int, - process_group: Optional[torch.distributed.ProcessGroup], - sequence_process_group: Optional[torch.distributed.ProcessGroup], - bias: bool = True, - dropout: float = 0.0, - softmax_scale: float = None, - causal: bool = False, - layer_idx: int = None, - rope_base: int = 10000, - rotary_emb_dim: int = 0, - rotary_emb_scale_base: int = 0, - use_flash_attn: bool = True, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - rot_embed_HF_impl: Optional[bool] = False, - tp_mode: str = "mtp", - ) -> None: - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - assert self.embed_dim % num_heads == 0, "embedding dim must be divisible by num_heads" - - self.head_dim = self.embed_dim // num_heads - self.num_kv_heads = num_kv_heads - self.kv_dim = self.head_dim * num_kv_heads - self.causal = causal - self.layer_idx = layer_idx - self.rotary_emb_dim = rotary_emb_dim - self.use_flash_attn = use_flash_attn - self.dtype = dtype - self.tp_mode = tp_mode - - self.rot_embed_HF_impl = rot_embed_HF_impl - sequence_parallel = gpc.config.parallel.get("sequence_parallel", False) - - if self.rotary_emb_dim > 0: - self.rotary_emb = RotaryEmbedding( - self.rotary_emb_dim, base=rope_base, scale_base=rotary_emb_scale_base, device=device - ) - - Wqkv_cls = get_linear_cls(self.tp_mode, "column") - # notice here should change bias=True - self.wq = Wqkv_cls( - embed_dim, - embed_dim, - process_group, - bias=bias, - sequence_parallel=sequence_parallel, - **factory_kwargs, - ) - self.wk = Wqkv_cls( - embed_dim, - self.kv_dim, - process_group, - bias=bias, - sequence_parallel=sequence_parallel, - **factory_kwargs, - ) - self.wv = Wqkv_cls( - embed_dim, - self.kv_dim, - process_group, - bias=bias, - sequence_parallel=sequence_parallel, - **factory_kwargs, - ) - - self.inner_attn, self.inner_cross_attn = get_gqa_attn_cls( - use_flash_attn, self.tp_mode, causal, softmax_scale, dropout, sequence_process_group - ) - self.inner_cross_attn_causal = causal - self.inner_cross_attn_softmax_scale = softmax_scale - self.inner_cross_attn_dropout = dropout - - # output projection always have the bias (for now) - out_proj_cls = get_linear_cls(self.tp_mode, "row") - self.wo = out_proj_cls( - embed_dim, - embed_dim, - process_group, - bias=bias, - sequence_parallel=sequence_parallel, - **factory_kwargs, - ) - - def forward(self, x, seqlen=None, inference_params=None, **kwargs): - if kwargs.get("indexes", None) is not None: - return self._packed_forward(x=x, inference_params=inference_params, **kwargs) - else: - return self._forward(x=x, seqlen=seqlen, inference_params=inference_params, **kwargs) - - def _forward(self, x, seqlen=None, inference_params=None, **kwargs): # pylint: disable=W0613 - """ - Arguments: - x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None. - If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we - split x during sequence parallel, we split the batch * seqlen dimension - (in case batch is small). 
- """ - bsz, _, _ = x.shape - q, k, v = self.wq(x), self.wk(x), self.wv(x) - if seqlen is None: - q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim) - k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim) - v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim) - else: - q = rearrange(q, "(b s) (h d) -> b s h d", s=seqlen, d=self.head_dim) - k = rearrange(k, "(b s) (h d) -> b s h d", s=seqlen, d=self.head_dim) - v = rearrange(v, "(b s) (h d) -> b s h d", s=seqlen, d=self.head_dim) - - if not self.rot_embed_HF_impl: - q = torch.cat([q[..., ::2], q[..., 1::2]], dim=-1) - k = torch.cat([k[..., ::2], k[..., 1::2]], dim=-1) - if inference_params is None: - if self.rotary_emb_dim > 0: - q = self.rotary_emb._single_eval_forward(q) - k = self.rotary_emb._single_eval_forward(k) - kv = torch.concat([k.unsqueeze(2), v.unsqueeze(2)], dim=2) - if self.dtype is torch.float32 and self.use_flash_attn: - if q.dtype not in [torch.float16, torch.bfloat16]: - q = q.to(torch.bfloat16) - if kv.dtype not in [torch.float16, torch.bfloat16]: - kv = kv.to(torch.bfloat16) - with internlm_accelerator.amp.autocast(dtype=torch.bfloat16): - context = self.inner_cross_attn(q=q, kv=kv).to(self.dtype) - else: - context = self.inner_cross_attn(q=q, kv=kv) - - else: - assert self.rotary_emb_dim > 0 - if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None: - empties = inference_params.attention_mask[..., -1].sum(dim=-1) - moved_q = q.clone() - moved_k = k.clone() - if inference_params.sequence_len_offset == 0: - for i in range(len(empties)): - if empties[i] != 0: - moved_q[i][: -empties[i]] = q[i][empties[i] :] - moved_k[i][: -empties[i]] = k[i][empties[i] :] - moved_q = self.rotary_emb._single_eval_forward( - moved_q, seqlen_offset=inference_params.sequence_len_offset - ) - moved_k = self.rotary_emb._single_eval_forward( - moved_k, seqlen_offset=inference_params.sequence_len_offset - ) - for i in range(len(empties)): - if empties[i] != 0: - q[i][empties[i] :] = moved_q[i][: -empties[i]] - k[i][empties[i] :] = moved_k[i][: -empties[i]] - else: - q[i] = moved_q[i] - k[i] = moved_k[i] - else: - q = self.rotary_emb._single_forward( - q, - inference_params.sequence_len_offset * torch.ones(q.size(0), dtype=torch.int, device=q.device) - - empties, - ) - k = self.rotary_emb._single_forward( - k, - inference_params.sequence_len_offset * torch.ones(k.size(0), dtype=torch.int, device=k.device) - - empties, - ) - else: - raise NotImplementedError( - "You should make sure you are aware that you are changing the method of generating." - "According to your generation function instead of inference/seq_generator_module.py, " - "You may implement here for normal running." - ) - - kv = torch.stack([k, v], dim=2) - - assert self.layer_idx is not None, "Generation requires layer_idx in the constructor" - if hasattr(inference_params, "window_size") and inference_params.window_size is not None: - if inference_params.window_size <= inference_params.sequence_len_offset: - assert kv.size(1) == 1, "update kv lenth more than 1" - inference_params.key_value_memory_dict[self.layer_idx][ - :, inference_params.keep_first : inference_params.window_size - 1, ... - ] = inference_params.key_value_memory_dict[self.layer_idx][ - :, -(inference_params.window_size - 1 - inference_params.keep_first) :, ... 
- ].clone() - inference_params.real_sequence_len_offset = inference_params.sequence_len_offset - inference_params.sequence_len_offset = inference_params.window_size - 1 - - kv = _update_kv_cache(kv, inference_params, self.layer_idx) - - inference_params.sequence_len_offset = inference_params.real_sequence_len_offset - else: - kv = _update_kv_cache(kv, inference_params, self.layer_idx) - else: - kv = _update_kv_cache(kv, inference_params, self.layer_idx) - - # When using FP16, there is a high probability of NAN in the KV. - # Since NAN cannot be removed by multiplying with and 0, it needs - # to be removed manually here. - kv = torch.where(torch.isnan(kv), 0, kv) - - if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None: - from flash_attn.flash_attn_interface import FlashAttnVarlenKVPackedFunc - - if inference_params.sequence_len_offset == 0: # First entrance, attnmask (bs*seqlen*seqlen) - attn_mask = inference_params.attention_mask[:, None, ...] - attn_mask = torch.logical_or( - torch.ones_like(attn_mask, dtype=torch.bool).triu(diagonal=1), attn_mask - ) - attn_mask4flsh = ~attn_mask[:, :, -1, :].view(bsz, -1) - cu_seqlens = torch.concat( - [ - torch.tensor([0], dtype=torch.int32, device=attn_mask4flsh.device), - attn_mask4flsh.sum(dim=-1).to(dtype=torch.int32), - ], - dim=0, - ) - cu_seqlens = cu_seqlens.cumsum(dim=0, dtype=torch.int32) - max_seqlen_q = attn_mask4flsh.shape[-1] - max_seqlen_k = attn_mask4flsh.shape[-1] - total_q = q.masked_select(attn_mask4flsh.view(bsz, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1]) - total_kv = kv.masked_select(attn_mask4flsh.view(bsz, -1, 1, 1, 1)).view( - -1, kv.shape[-3], kv.shape[-2], kv.shape[-1] - ) - - if self.dtype is torch.float32: - if total_q.dtype not in [torch.float16, torch.bfloat16]: - total_q = total_q.to(torch.bfloat16) - if total_kv.dtype not in [torch.float16, torch.bfloat16]: - total_kv = total_kv.to(torch.bfloat16) - with internlm_accelerator.amp.autocast(dtype=torch.bfloat16): - output = FlashAttnVarlenKVPackedFunc.apply( - total_q, - total_kv, - cu_seqlens, - cu_seqlens, - max_seqlen_q, - max_seqlen_k, - 0.0, - None, - True, - False, - ).to(self.dtype) - else: - output = FlashAttnVarlenKVPackedFunc.apply( - total_q, - total_kv, - cu_seqlens, - cu_seqlens, - max_seqlen_q, - max_seqlen_k, - 0.0, - None, - True, - False, - ) - - context = torch.zeros_like(q) - context = context.masked_scatter_(attn_mask4flsh.view(bsz, -1, 1, 1), output) - - else: - attn_mask = inference_params.attention_mask[:, -1, :].view(bsz, 1, 1, -1) - if hasattr(inference_params, "window_size") and inference_params.window_size is not None: - if inference_params.window_size <= inference_params.sequence_len_offset: - attn_mask = torch.concat( - [ - attn_mask[..., : inference_params.keep_first], - attn_mask[..., -(inference_params.window_size - inference_params.keep_first) :], - ], - dim=-1, - ) - - k, v = torch.chunk(kv, 2, dim=2) - k = k.squeeze(2) - v = v.squeeze(2) - sp = k.shape - expansion = q.size(2) // k.size(2) - scores = torch.einsum( - "blhd,bnhd->bhln", - q, - k.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]), - ) / math.sqrt(q.size(-1)) - scores = scores.masked_fill(attn_mask, -65000.0) - scores = F.softmax(scores, dim=-1) # bsz x h x L x L - context = torch.einsum( - "bhmn,bnhd->bmhd", - scores, - v.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]), - ) - else: - if self.dtype is torch.float32 and self.use_flash_attn: - if q.dtype not in 
[torch.float16, torch.bfloat16]: - q = q.to(torch.bfloat16) - if kv.dtype not in [torch.float16, torch.bfloat16]: - kv = kv.to(torch.bfloat16) - with internlm_accelerator.amp.autocast(dtype=torch.bfloat16): - context = self.inner_cross_attn(q=q, kv=kv, causal=True).to(self.dtype) - else: - context = self.inner_cross_attn(q=q, kv=kv, causal=True) - if seqlen is None: - context = rearrange(context, "b s h d -> b s (h d)") - else: - context = rearrange(context, "b s h d -> (b s) (h d)") - out = self.wo(context) - return out - - def _packed_forward(self, x, inference_params=None, **kwargs): - """ - we delete seqlen=None for lint check, cause this arg is not used. - - Arguments: - x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None. - If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we - split x during sequence parallel, we split the batch * seqlen dimension - (in case batch is small). - """ - assert self.use_flash_attn is True - q, k, v = self.wq(x), self.wk(x), self.wv(x) - q = rearrange(q, "b t (h d) -> b t h d", d=self.head_dim) - k = rearrange(k, "b t (h d) -> b t h d", d=self.head_dim) - v = rearrange(v, "b t (h d) -> b t h d", d=self.head_dim) - - # qkv shift - # the rotary embedding in flash attention module in performed by separating the front and back parts, while - # most of others are done by odd-even methods. - if not self.rot_embed_HF_impl: - q = torch.cat([q[..., ::2], q[..., 1::2]], dim=-1) - k = torch.cat([k[..., ::2], k[..., 1::2]], dim=-1) - - indexes = kwargs.pop("indexes") - - q = self.rotary_emb._single_forward(q, indexes=indexes) - k = self.rotary_emb._single_forward(k, indexes=indexes) - - if inference_params is None: - kv = torch.concat([k.unsqueeze(2), v.unsqueeze(2)], dim=2) - # for packed data, batch dimension with a size of 1 should be directly squeezed off. 
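For backends without varlen kernels, the deleted path unpacks the `(1, total_tokens, ...)` tensors into padded `(batch, max_seqlen, ...)` form around the attention call via `unpack_qkv_before_attn` / `pack_output_after_attn`. A rough sketch of the unpacking step, with a hypothetical helper name and made-up shapes:

    import torch

    def unpack_by_cu_seqlens(packed: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
        """Scatter a (1, total_tokens, ...) packed tensor into (batch, max_seqlen, ...) with zero padding."""
        lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        padded = packed.new_zeros(len(lengths), max(lengths), *packed.shape[2:])
        for i, chunk in enumerate(torch.split(packed.squeeze(0), lengths, dim=0)):
            padded[i, : chunk.size(0)] = chunk
        return padded

    cu_seqlens = torch.tensor([0, 3, 8, 10])
    packed_q = torch.randn(1, 10, 4, 8)        # (1, total_tokens, heads, head_dim)
    assert unpack_by_cu_seqlens(packed_q, cu_seqlens).shape == (3, 5, 4, 8)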
- if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU: - q = q.squeeze(0) - kv = kv.squeeze(0) - # since torch_npu only supports fa with no packed data currently, qkv should be unpacked - elif internlm_accelerator.get_accelerator_backend() in [AcceleratorType.NPU, AcceleratorType.DIPU]: - q = unpack_qkv_before_attn(q, kwargs["cu_seqlens"]) - kv = unpack_qkv_before_attn(kv, kwargs["cu_seqlens"]) - - if self.dtype is torch.float32: - if q.dtype not in [torch.float16, torch.bfloat16]: - q = q.to(torch.bfloat16) - if kv.dtype not in [torch.float16, torch.bfloat16]: - kv = kv.to(torch.bfloat16) - with internlm_accelerator.amp.autocast(dtype=torch.bfloat16): - context = self.inner_attn( - q=q, - kv=kv, - cu_seqlens_q=kwargs["cu_seqlens"], - cu_seqlens_k=kwargs["cu_seqlens"], - max_seqlen_q=kwargs["max_seqlen"], - max_seqlen_k=kwargs["max_seqlen"], - dropout_p=self.inner_cross_attn_dropout, - softmax_scale=self.inner_cross_attn_softmax_scale, - causal=self.inner_cross_attn_causal, - ).to(self.dtype) - else: - context = self.inner_attn( - q=q, - kv=kv, - cu_seqlens_q=kwargs["cu_seqlens"], - cu_seqlens_k=kwargs["cu_seqlens"], - max_seqlen_q=kwargs["max_seqlen"], - max_seqlen_k=kwargs["max_seqlen"], - dropout_p=self.inner_cross_attn_dropout, - softmax_scale=self.inner_cross_attn_softmax_scale, - causal=self.inner_cross_attn_causal, - ) - else: - raise RuntimeError("Not support this right now") - - if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU: - context = rearrange(context, "s h d -> s (h d)") # recover the shape - context = context.unsqueeze(0) # restore bsz dimension - elif internlm_accelerator.get_accelerator_backend() in [AcceleratorType.NPU, AcceleratorType.DIPU]: - context = rearrange(context, "b s h d -> b s (h d)") # recover the shape - context = pack_output_after_attn(context, kwargs["cu_seqlens"]) - - out = self.wo(context) - return out - - -class PackedFlashLlamaLayer1D(nn.Module): - """ - 1D Packed Flash Llama Layer. + Llama2 Decoder Layer. Args: hidden_size (int): The hidden size of model. 768 by default. num_attention_heads (int): The number of attention heads. 12 by default. + num_kv_attention_heads (int): The number of key/value attention heads. Defaults to 12. mlp_ratio (int): The ratio of MLP layers. 4 by default. attn_drop_rate (float): The dropout rate of attention module. 0 by default. drop_rate (float): The dropout rate of the input hidden state. 0.0 by default. @@ -481,8 +41,16 @@ class PackedFlashLlamaLayer1D(nn.Module): layer_idx (int): The index of current layer. 0 by default. residual_in_fp32 (bool): Whether to use residual in fp32. False by default. device (Optional[Union[str, torch.device]]): The device will be used. + apply_post_layer_norm (bool): Whether to apply layer normalization after the attention and mlp. + Defaults to False. + fused_dropout_add_ln (bool): Whether to fuse dropout, residual addition, and layer normalization. + Defaults to True. + no_bias (bool): Whether to exclude bias in attention and feed-forward networks. Defaults to False. norm_type (str): Use RMS norm or layernorm."rmsnorm" by default. - use_flash_attn (bool): Whether use flash-attn. True by default. + qk_interleaved (bool): Whether the odd and even columns of the wq and wk are normally interleaved. + dropout_selective_checkpoint (bool): Whether to selectively checkpoint dropout layers only. + use_scaled_init (bool): Whether to use scaled initialization for weights. + use_swiglu (bool): Whether to use SwiGLU activation in the mlp module. 
attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu @@ -490,8 +58,8 @@ class PackedFlashLlamaLayer1D(nn.Module): ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default, init_type (str): Initialization type. Use uniform or normal. "normal" by default, rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. - tp_mode (str): The string value of tensor parallel mode, should be in ["mtp", "msp", "fsp", "isp"], - "mtp" by default. + mlp_layer_fusion (bool): Whether to fuse layers in the mlp module for optimization. + multiple_of (int): Ensures mlp dimensions are multiples of this value for efficient hardware utilization. """ def __init__( @@ -512,18 +80,16 @@ def __init__( fused_dropout_add_ln: bool = True, no_bias: bool = False, norm_type: str = "rmsnorm", - adapt_hf: bool = False, + qk_interleaved: bool = False, dropout_selective_checkpoint: bool = True, use_scaled_init: bool = True, use_swiglu: bool = True, - use_flash_attn: bool = True, attn_wqkv_init_std: float = 0.02, attn_other_init_std: float = 0.02, ffn_uplayer_init_std: float = 0.02, ffn_other_init_std: float = 0.02, init_type: str = "normal", rope_base: int = 10000, - tp_mode: str = "mtp", mlp_layer_fusion: bool = False, multiple_of: int = 256, ): @@ -532,7 +98,6 @@ def __init__( # dropout selective checkpoint can only be enabled when checkpoint is disabled. self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False self.layer_idx = layer_idx - self.use_flash_attn = use_flash_attn self.prenorm = not apply_post_layer_norm assert not fused_dropout_add_ln, "dropout_add_layer_norm can not be used here" self.fused_dropout_add_ln = fused_dropout_add_ln @@ -542,52 +107,43 @@ def __init__( self.ffn_other_init_std = ffn_other_init_std head_dim = hidden_size // num_attention_heads - self.tp_mode = tp_mode - parallel_mode = ParallelMode.WEIGHT if self.tp_mode == "isp" else ParallelMode.TENSOR - self.attention = MHA( + self.attention = GQA( embed_dim=hidden_size, num_heads=num_attention_heads, num_kv_heads=num_kv_attention_heads, - process_group=gpc.get_group(parallel_mode), - sequence_process_group=gpc.get_group(ParallelMode.TENSOR), dropout=attn_drop_rate, softmax_scale=1 / math.sqrt(head_dim), causal=True, layer_idx=layer_idx, rotary_emb_dim=head_dim, rotary_emb_scale_base=0, - use_flash_attn=use_flash_attn, device=device, dtype=dtype, - rot_embed_HF_impl=adapt_hf, + qk_interleaved=qk_interleaved, bias=not no_bias, rope_base=rope_base, - tp_mode=self.tp_mode, + enable_qkv_fusion=False, ) self.dropout1 = nn.Dropout(drop_rate) - if norm_type == "rmsnorm": - self.attention_norm = RMSNorm(hidden_size, eps=layer_norm_epsilon) - self.ffn_norm = RMSNorm(hidden_size, eps=layer_norm_epsilon) - else: - self.attention_norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) - self.ffn_norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) + self.dropout2 = nn.Dropout(drop_rate) + self.attention_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) + self.ffn_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) - self.feed_forward = get_mlp_cls(self.tp_mode)( + self.feed_forward = new_feed_forward( hidden_size, int(hidden_size * mlp_ratio), out_features=hidden_size, - process_group=gpc.get_group(parallel_mode), bias=False, device=device, 
dtype=dtype, mlp_layer_fusion=mlp_layer_fusion, - sequence_parallel=gpc.config.parallel.get("sequence_parallel", False), multiple_of=multiple_of, + # TODO: to support more activation functions + activation_type="swiglu" if use_swiglu else "swiglu", ) - self.dropout2 = nn.Dropout(drop_rate) self.use_swiglu = use_swiglu self.use_scaled_init = use_scaled_init self.residual_in_fp32 = residual_in_fp32 # only make sense when using prenorm @@ -631,19 +187,13 @@ def reset_parameters(self): param.data ) - def forward( - self, hidden_states, residual=None, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None - ): + def forward(self, hidden_states, residual=None, **kwargs): if self.checkpoint and self.training: - return activation_checkpoint( - self._forward, False, hidden_states, residual, cu_seqlens, indexes, inference_params, max_seqlen - ) + return activation_checkpoint(self._forward, False, hidden_states, residual, **kwargs) else: - return self._forward(hidden_states, residual, cu_seqlens, indexes, inference_params, max_seqlen) + return self._forward(hidden_states, residual, **kwargs) - def _forward( - self, hidden_states=None, residual=None, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None - ): + def _forward(self, hidden_states=None, residual=None, **kwargs): r"""Pass the input through the encoder layer. Args: @@ -668,13 +218,8 @@ def _dropout_and_norm_attn(_residual, _hidden_states): if self.residual_in_fp32: residual = residual.to(torch.float32) - mixer_kwargs = { - "cu_seqlens": cu_seqlens, - "max_seqlen": max_seqlen, - "indexes": indexes, - "inference_params": inference_params, - } - hidden_states = self.attention(hidden_states, **mixer_kwargs) + + hidden_states = self.attention(hidden_states, **kwargs) if not isinstance(self.feed_forward, nn.Identity): if not self.fused_dropout_add_ln: @@ -700,13 +245,8 @@ def _dropout_and_norm_ffn(_residual, _hidden_states): return hidden_states + residual else: assert residual is None - mixer_kwargs = { - "cu_seqlens": cu_seqlens, - "max_seqlen": max_seqlen, - "indexes": indexes, - "inference_params": inference_params, - } - mixer_out = self.attention(hidden_states, **mixer_kwargs) + + mixer_out = self.attention(hidden_states, **kwargs) if self.return_residual: # mixer out is actually a pair here mixer_out, hidden_states = mixer_out hidden_states = self.attention_norm(self.dropout1(mixer_out) + hidden_states).to( @@ -722,34 +262,38 @@ def _dropout_and_norm_ffn(_residual, _hidden_states): return hidden_states -class PackedFlashLlama1D(nn.Module): +class Llama2(nn.Module): """ - 1D Packed Flash Llama. + Llama2 Model. Args: num_layers (int): The number of layer. 12 by default. hidden_size (int): The size of hidden state. 768 by default. num_attention_heads (int): The number of attention head. 12 by default. + num_kv_attention_heads (int): The number of key/value attention heads. Defaults to 12. vocab_size (int): The size of vocabulary. 50304 by default. mlp_ratio (int): The ratio of MLP layers. 4 by default. attn_drop_rate (float): The dropout rate of attention module. 0.0 by default. drop_rate (float): The dropout rate of input hidden state. 0.0 by default. dtype (torch.dtype): The type of data. torch.float by default. - checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. - checkpoint_fraction (float): The proportion of layers that need to be checkpointed compared to the total number - of layers. 1.0 by default. 
+ checkpoint (float): The proportion of layers that need to be checkpointed compared to the total number + of layers. 0.0 by default. layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default. first (bool): Whether input embedding layer or not. False by default. last (bool): Whether output embedding layer or not. False by default. - embed_split_hidden (bool): Split the embedding layer in the hidden state dimention or vocabulary dimention. - True by default. embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. start_layer_idx (int): The index of start layer in the pipeline. 0 by default. device (Optional[Union[str, torch.device]]): The device will be used. None by default. + apply_post_layer_norm (bool): Whether to apply layer normalization after the attention and mlp. + Defaults to False. + no_bias (bool): Whether to exclude bias in attention and feed-forward networks. Defaults to False. residual_in_fp32 (bool): Whether to use residual in fp32. False by default. norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. - use_flash_attn (bool): Whether to use flash-attn. True by default. + qk_interleaved (bool): Whether the odd and even columns of the wq and wk are normally interleaved. + dropout_selective_checkpoint (bool): Whether to selectively checkpoint dropout and norm layers. + use_scaled_init (bool): Whether to use scaled initialization for weights. + use_swiglu (bool): Whether to use SwiGLU activation in the mlp module. embedding_init_std (float): std used to init embedding weight. 0.02 by default, attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, @@ -759,25 +303,25 @@ class PackedFlashLlama1D(nn.Module): out_head_init_std (float): std used to init output lmhead weight. 0.02 by default, init_type (str): Initialization type. Use uniform or normal. "normal" by default, rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. + mlp_layer_fusion (bool): Whether to fuse layers in the mlp module for optimization. + multiple_of (int): Ensures mlp dimensions are multiples of this value for efficient hardware utilization. 
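Note for reviewers: since `checkpoint` changed from a bool to a fraction, here is a quick sketch of how the value maps to a layer count. Only `checkpoint_layer_num = int(num_layers * checkpoint)` appears in this patch; the per-layer flag below is an assumption about how the count is consumed.

num_layers = 48
checkpoint = 0.25                                     # checkpoint a quarter of the layers
checkpoint_layer_num = int(num_layers * checkpoint)   # -> 12
# Hypothetical per-layer rule: the first checkpoint_layer_num layers recompute
# activations during backward, the remaining layers keep them in memory.
layer_uses_checkpoint = [idx < checkpoint_layer_num for idx in range(num_layers)]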
""" def __init__( self, - num_layers: int = 12, - hidden_size: int = 768, - num_attention_heads: int = 12, - num_kv_attention_heads: int = 12, + num_layers: int = 48, + hidden_size: int = 2048, + num_attention_heads: int = 32, + num_kv_attention_heads: int = 32, vocab_size: int = 50304, - mlp_ratio: int = 4, + mlp_ratio: float = 4.0, attn_drop_rate: float = 0.0, drop_rate: float = 0.0, dtype: torch.dtype = torch.float, - checkpoint: bool = False, - checkpoint_fraction: float = 1.0, + checkpoint: float = 0.0, layer_norm_epsilon: float = 1e-5, first: bool = False, last: bool = False, - embed_split_hidden: bool = False, embed_grad_scale: float = 0.1, parallel_output: bool = True, start_layer_idx: int = 0, @@ -786,12 +330,11 @@ def __init__( no_bias=False, residual_in_fp32: bool = False, norm_type: str = "rmsnorm", - adapt_hf: bool = False, + qk_interleaved: bool = False, is_reward: bool = False, dropout_selective_checkpoint: bool = True, use_scaled_init: bool = True, use_swiglu: bool = True, - use_flash_attn: bool = True, embedding_init_std: float = 0.02, attn_wqkv_init_std: float = 0.02, attn_other_init_std: float = 0.02, @@ -805,35 +348,22 @@ def __init__( ): super().__init__() - self.use_flash_attn = use_flash_attn - if checkpoint_fraction <= 0: - checkpoint = False - if not checkpoint: - checkpoint_fraction = 0 - checkpoint_layer_num = num_layers * checkpoint_fraction - self.tp_mode = "mtp" - if isinstance(gpc.config.parallel["tensor"], dict): - self.tp_mode = gpc.config.parallel["tensor"].get("mode", "mtp") - - if is_reward: - head_cls = RewardModelLinear - else: - head_cls = ScaleColumnParallelLinear + checkpoint_layer_num = int(num_layers * checkpoint) + self.embed_grad_scale = embed_grad_scale + self.parallel_output = parallel_output if first: - self.tok_embeddings = Embedding1D( - num_embeddings=vocab_size, embedding_dim=hidden_size, embed_split_hidden=embed_split_hidden - ) + self.tok_embeddings = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) + for _, param in self.tok_embeddings.named_parameters(): if init_type == "normal": normal_(std=embedding_init_std)(param) else: uniform_(std=embedding_init_std)(param) - self.embed_grad_scale = embed_grad_scale self.layers = nn.ModuleList( [ - PackedFlashLlamaLayer1D( + Llama2Decoder( hidden_size=hidden_size, num_attention_heads=num_attention_heads, num_kv_attention_heads=num_kv_attention_heads, @@ -853,15 +383,13 @@ def __init__( dropout_selective_checkpoint=dropout_selective_checkpoint, use_scaled_init=use_scaled_init, use_swiglu=use_swiglu, - use_flash_attn=use_flash_attn, - adapt_hf=adapt_hf, + qk_interleaved=qk_interleaved, attn_wqkv_init_std=attn_wqkv_init_std, attn_other_init_std=attn_other_init_std, ffn_uplayer_init_std=ffn_uplayer_init_std, ffn_other_init_std=ffn_other_init_std, init_type=init_type, rope_base=rope_base, - tp_mode=self.tp_mode, mlp_layer_fusion=mlp_layer_fusion, multiple_of=multiple_of, ) @@ -871,18 +399,16 @@ def __init__( if last: if not apply_post_layer_norm: - if norm_type == "rmsnorm": - self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon) - else: - self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) + self.norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) - self.output = head_cls( + self.output = new_linear( + name="output", in_features=hidden_size, out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size, - process_group=gpc.get_group(ParallelMode.TENSOR), bias=False, device=device, dtype=dtype, + is_reward=is_reward, 
weight_scale=embed_grad_scale, ) set_output_attr_to_module(self.output) @@ -892,9 +418,7 @@ def __init__( else: uniform_(std=out_head_init_std)(param) - self.parallel_output = parallel_output - - def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None): + def forward(self, hidden_states=None, input_ids=None, **kwargs): # attention_mask: compute attention on the places where the value is 1 if hasattr(self, "tok_embeddings") and input_ids is not None: hidden_states = self.tok_embeddings(input_ids) @@ -902,203 +426,14 @@ def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=N hidden_states = ( self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach() ) - if isinstance(cu_seqlens, list): - assert len(cu_seqlens) == 1 - cu_seqlens = cu_seqlens[0].to(hidden_states.device) - - if cu_seqlens is not None: - cu_seqlens = cu_seqlens.squeeze(0) - - if indexes is not None: - assert len(indexes) == 1 - # The indexes are used to indicate the actual position IDs of each token in the packed input. - indexes = indexes[0] - # if the sequence parallel mode is 'isp', the indexes should also be split in sequence dimension. - if gpc.config.parallel.sequence_parallel and self.tp_mode == "isp": - indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0) - - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None for _, block in enumerate(self.layers): - hidden_states = block( - hidden_states, - residual=None, - cu_seqlens=cu_seqlens, - indexes=indexes, - inference_params=inference_params, - max_seqlen=max_seqlen, - ) + hidden_states = block(hidden_states, residual=None, **kwargs) if hasattr(self, "norm"): hidden_states = self.norm(hidden_states.float()) if hasattr(self, "output"): - hidden_states = self.output(hidden_states, gather_dim=1, tp_mode=self.tp_mode) - - if not self.parallel_output and gpc.is_pipeline_last_stage(): - hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1) + hidden_states = self.output(hidden_states) return hidden_states - - -def _build_generic_model_1d(num_layers, num_chunks, **kwargs): - """ - build generic model 1d - - Args: - num_layers (int): The number of layer. - num_chunks (int): The number of partitions in pipeline parallel. - device (Optional[Union[str, torch.device]]): The device will be used. internlm_accelerator.device() by default. - - """ - device = get_current_device() - pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) - pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - - all_parts = partition_uniform(num_layers, pipeline_size, num_chunks) - parts = all_parts[pipeline_rank] - if gpc.is_rank_for_log(): - logger.info(f"The layer sharding is {all_parts}.") - - models = [] - kwargs["checkpoint_fraction"] = float(kwargs.get("checkpoint", False)) - start_idx, end_idx = 0, 0 - for start, end in parts: - start_idx, end_idx = start, end - kwargs["num_layers"] = end - start - kwargs["first"] = start == 0 - # If there is no content in the final layer, assign the last layer. 
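For context on the removed builder above: `partition_uniform(num_layers, pipeline_size, num_chunks)` returns one list of (start, end) layer ranges per pipeline rank, which the loop then instantiates. A rough stand-in for the num_chunks == 1 case, purely as an illustration and not the real internlm helper:

def naive_partition_uniform(num_layers, pipeline_size):
    # Split num_layers as evenly as possible; earlier ranks absorb the remainder.
    base, remainder = divmod(num_layers, pipeline_size)
    parts, start = [], 0
    for rank in range(pipeline_size):
        size = base + (1 if rank < remainder else 0)
        parts.append([(start, start + size)])
        start += size
    return parts

# 48 layers over 4 pipeline ranks -> [[(0, 12)], [(12, 24)], [(24, 36)], [(36, 48)]]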
- kwargs["last"] = end == num_layers and len(all_parts[-1]) != 0 - kwargs["device"] = device - kwargs["start_layer_idx"] = start - chunk = PackedFlashLlama1D(**filter_kwargs(PackedFlashLlama1D.__init__, kwargs)).to(device) - - models.append(chunk) - torch.distributed.barrier() - if len(models) == 1: - model = models[0] - else: - model = nn.ModuleList(models) - setattr(model, "first_layer", start_idx) - setattr(model, "last_layer", end_idx) - return model - - -@MODEL_INITIALIZER.register_module(module_name=MODEL_TYPE) -def build_model_with_cfg( - num_chunks=1, - checkpoint=False, - dtype=torch.float, - embed_split_hidden=False, - num_layers=48, - hidden_size=2048, - vocab_size=50304, - embed_grad_scale=1, - parallel_output=True, - num_attention_heads=32, - num_kv_attention_heads=None, - mlp_ratio=4.0, - residual_in_fp32=False, - norm_type="rmsnorm", - adapt_hf=False, - drop_rate=0, - attn_drop_rate=0, - apply_post_layer_norm=False, # pylint: disable=W0613 - no_bias=False, - deepnorm=False, - layer_norm_epsilon=1e-5, - is_reward=False, - dropout_selective_checkpoint=True, - use_scaled_init: bool = True, - use_swiglu: bool = True, - use_flash_attn: bool = True, - embedding_init_std: float = 0.02, - attn_wqkv_init_std: float = 0.02, - attn_other_init_std: float = 0.02, - ffn_uplayer_init_std: float = 0.02, - ffn_other_init_std: float = 0.02, - out_head_init_std: float = 0.02, - init_type: str = "normal", - rope_base: int = 10000, - mlp_layer_fusion: bool = False, - multiple_of: int = 256, -): - """ - Builde model with config - - Args: - num_chunks (int): The number of partitions in pipeline parallel. 1 by default. - checkpoint (bool): Whether to use checkpointing to save VRAM. False by default. - dtype (torch.dtype): The type of data. torch.float by default. - embed_split_hidden (bool): Split the embedding layer in the hidden state dimention or vocabulary dimention. - False by default. - num_layers (int): The number of layer. 48 by default. - hidden_size (int): The size of hidden state. 2048 by default. - vocab_size (int): The size of vocabulary. 50304 by default. - embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. - parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. - num_attention_heads (int): The number of attention head. 32 by default. - mlp_ratio (int): The ratio of MLP layers. 4.0 by default. - residual_in_fp32 (bool): Whether to use residual in fp32. False by default. It cannot be used temporarily - because this parameter requires inconsistent data types to be passed between pipelines, - which requires significant modifications to internlm. - norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. - drop_rate (float): The dropout rate of input hidden state. 0 by default. - attn_drop_rate (float): The dropout rate of attention module. 0 by default. - apply_post_layer_norm (bool): Whether to apply post layer norm. False by default. - layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. - is_reward (bool): Whether to use reward model. False by default. - dropout_selective_checkpoint (bool): It can only be enabled when checkpoint is disabled. True by default. - use_scaled_init (bool): Whether to use scaled init. True by default. - use_swiglu (bool): Whether to use swiglu. True by default. - use_flash_attn (bool): Whether to use flash-attn. True by default. - embedding_init_std (float): std used to init embedding weight. 
0.02 by default, - attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, - attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, - ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu - otherwise init fc1 weight in ffn. 0.02 by default, - ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default, - out_head_init_std (float): std used to init output lmhead weight. 0.02 by default, - init_type (str): Initialization type. Use uniform or normal. "normal" by default, - rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. - """ - if deepnorm: - raise AssertionError("deepnorm will not be supported in future versions." "Use early versions if necessary.") - - cfg = dict( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - num_kv_attention_heads=num_kv_attention_heads if num_kv_attention_heads else num_attention_heads, - checkpoint=checkpoint, - dtype=dtype, - embed_split_hidden=embed_split_hidden, - vocab_size=vocab_size, - embed_grad_scale=embed_grad_scale, - parallel_output=parallel_output, - mlp_ratio=mlp_ratio, - apply_post_layer_norm=apply_post_layer_norm, - no_bias=no_bias, - residual_in_fp32=residual_in_fp32, - norm_type=norm_type, - adapt_hf=adapt_hf, - drop_rate=drop_rate, - attn_drop_rate=attn_drop_rate, - layer_norm_epsilon=layer_norm_epsilon, - is_reward=is_reward, - dropout_selective_checkpoint=dropout_selective_checkpoint, - use_scaled_init=use_scaled_init, - use_swiglu=use_swiglu, - use_flash_attn=use_flash_attn, - embedding_init_std=embedding_init_std, - attn_wqkv_init_std=attn_wqkv_init_std, - attn_other_init_std=attn_other_init_std, - ffn_uplayer_init_std=ffn_uplayer_init_std, - ffn_other_init_std=ffn_other_init_std, - out_head_init_std=out_head_init_std, - init_type=init_type, - rope_base=rope_base, - mlp_layer_fusion=mlp_layer_fusion, - multiple_of=multiple_of, - ) - - return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg) diff --git a/internlm/model/modeling_llava.py b/internlm/model/modeling_llava.py index 37246e14..57b68859 100644 --- a/internlm/model/modeling_llava.py +++ b/internlm/model/modeling_llava.py @@ -7,53 +7,42 @@ from internlm.core.context.parallel_context import global_context as gpc from internlm.core.naive_amp import set_output_attr_to_module from internlm.initialize.initialize_tensor import normal_, uniform_ -from internlm.model.modeling_llama import PackedFlashLlamaLayer1D +from internlm.model.llava.clip_builder import build_vision_tower +from internlm.model.llava.projector_builder import build_vision_projector +from internlm.model.modeling_llama import Llama2Decoder from internlm.model.modules.embedding import Embedding1D -from internlm.model.ops.fusion_ops_import_helper import try_import_RMSNorm -from internlm.model.ops.linear import RewardModelLinear, ScaleColumnParallelLinear -from internlm.model.utils import ( - gather_forward_split_backward, - split_forward_gather_backward, -) -from internlm.solver.pipeline_utils import partition_uniform -from internlm.utils.common import filter_kwargs +from internlm.model.modules.linear import new_linear +from internlm.model.modules.norm import new_layer_norm from internlm.utils.logger import get_logger -from internlm.utils.registry import MODEL_INITIALIZER - -MODEL_TYPE = "LLAVA" logger = get_logger(__file__) -RMSNorm = try_import_RMSNorm() -class PackedFlashLlava1D(nn.Module): +class Llava(nn.Module): """ 1D Packed Flash Llava. 
Args: - num_layers (int): The number of layer. 12 by default. - hidden_size (int): The size of hidden state. 768 by default. - num_attention_heads (int): The number of attention head. 12 by default. + num_layers (int): The number of layer. 48 by default. + hidden_size (int): The size of hidden state. 2048 by default. + num_attention_heads (int): The number of attention head. 32 by default. + num_kv_attention_heads (int): The number of key/value attention heads. Defaults to 32. vocab_size (int): The size of vocabulary. 50304 by default. mlp_ratio (int): The ratio of MLP layers. 4 by default. attn_drop_rate (float): The dropout rate of attention module. 0.0 by default. drop_rate (float): The dropout rate of input hidden state. 0.0 by default. dtype (torch.dtype): The type of data. torch.float by default. checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. - checkpoint_fraction (float): The proportion of layers that need to be checkpointed compared to the total number - of layers. 1.0 by default. layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default. first (bool): Whether input embedding layer or not. False by default. last (bool): Whether output embedding layer or not. False by default. - embed_split_hidden (bool): Split the embedding layer in the hidden state dimention or vocabulary dimention. - True by default. embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. start_layer_idx (int): The index of start layer in the pipeline. 0 by default. device (Optional[Union[str, torch.device]]): The device will be used. None by default. residual_in_fp32 (bool): Whether to use residual in fp32. False by default. norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. - use_flash_attn (bool): Whether to use flash-attn. True by default. + qk_interleaved (bool): Whether the odd and even columns of the wq and wk are normally interleaved. embedding_init_std (float): std used to init embedding weight. 0.02 by default, attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, attn_other_init_std (float): std used to init attn_other weight. 
0.02 by default, @@ -70,21 +59,19 @@ class PackedFlashLlava1D(nn.Module): def __init__( self, - num_layers: int = 12, - hidden_size: int = 768, - num_attention_heads: int = 12, - num_kv_attention_heads: int = 12, + num_layers: int = 48, + hidden_size: int = 2048, + num_attention_heads: int = 32, + num_kv_attention_heads: int = 32, vocab_size: int = 50304, mlp_ratio: int = 4, attn_drop_rate: float = 0.0, drop_rate: float = 0.0, dtype: torch.dtype = torch.float, checkpoint: bool = False, - checkpoint_fraction: float = 1.0, layer_norm_epsilon: float = 1e-5, first: bool = False, last: bool = False, - embed_split_hidden: bool = False, embed_grad_scale: float = 0.1, parallel_output: bool = True, start_layer_idx: int = 0, @@ -93,12 +80,11 @@ def __init__( no_bias=False, residual_in_fp32: bool = False, norm_type: str = "rmsnorm", - adapt_hf: bool = False, + qk_interleaved: bool = False, is_reward: bool = False, dropout_selective_checkpoint: bool = True, use_scaled_init: bool = True, use_swiglu: bool = True, - use_flash_attn: bool = True, embedding_init_std: float = 0.02, attn_wqkv_init_std: float = 0.02, attn_other_init_std: float = 0.02, @@ -115,39 +101,25 @@ def __init__( ): super().__init__() - self.use_flash_attn = use_flash_attn - if checkpoint_fraction <= 0: - checkpoint = False - if not checkpoint: - checkpoint_fraction = 0 - checkpoint_layer_num = num_layers * checkpoint_fraction - self.tp_mode = "mtp" + checkpoint_layer_num = num_layers * checkpoint + self.dtype = dtype self.image_token_id = image_token_id - - if isinstance(gpc.config.parallel["tensor"], dict): - self.tp_mode = gpc.config.parallel["tensor"].get("mode", "mtp") - - if is_reward: - head_cls = RewardModelLinear - else: - head_cls = ScaleColumnParallelLinear + self.embed_grad_scale = embed_grad_scale + self.parallel_output = parallel_output if first: - self.tok_embeddings = Embedding1D( - num_embeddings=vocab_size, embedding_dim=hidden_size, embed_split_hidden=embed_split_hidden - ) + self.tok_embeddings = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) for _, param in self.tok_embeddings.named_parameters(): if init_type == "normal": normal_(std=embedding_init_std)(param) else: uniform_(std=embedding_init_std)(param) - self.embed_grad_scale = embed_grad_scale self.layers = nn.ModuleList( [ - PackedFlashLlamaLayer1D( + Llama2Decoder( hidden_size=hidden_size, num_attention_heads=num_attention_heads, num_kv_attention_heads=num_kv_attention_heads, @@ -167,15 +139,13 @@ def __init__( dropout_selective_checkpoint=dropout_selective_checkpoint, use_scaled_init=use_scaled_init, use_swiglu=use_swiglu, - use_flash_attn=use_flash_attn, - adapt_hf=adapt_hf, + qk_interleaved=qk_interleaved, attn_wqkv_init_std=attn_wqkv_init_std, attn_other_init_std=attn_other_init_std, ffn_uplayer_init_std=ffn_uplayer_init_std, ffn_other_init_std=ffn_other_init_std, init_type=init_type, rope_base=rope_base, - tp_mode=self.tp_mode, mlp_layer_fusion=mlp_layer_fusion, multiple_of=multiple_of, ) @@ -185,18 +155,16 @@ def __init__( if last: if not apply_post_layer_norm: - if norm_type == "rmsnorm": - self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon) - else: - self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) + self.norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) - self.output = head_cls( + self.output = new_linear( + name="output", in_features=hidden_size, out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size, - process_group=gpc.get_group(ParallelMode.TENSOR), bias=False, 
device=device, dtype=dtype, + is_reward=is_reward, weight_scale=embed_grad_scale, ) set_output_attr_to_module(self.output) @@ -206,57 +174,42 @@ def __init__( else: uniform_(std=out_head_init_std)(param) - self.parallel_output = parallel_output - assert vit_cfg is not None if first: - from internlm.model.llava_modules.clip_builder import build_vision_tower - + assert vit_cfg is not None self.vit = build_vision_tower(vit_cfg) self.vit.requires_grad_(False) - assert vision_proj_cfg is not None - if first: - from internlm.model.llava_modules.projector_builder import ( - build_vision_projector, - ) - + assert vision_proj_cfg is not None self.vision_proj = build_vision_projector(vision_proj_cfg) # self.vision_proj.requires_grad_(False) - def forward( # pylint: disable=W0102 - self, - hidden_states=None, - images=[], - cu_seqlens=None, - input_ids=None, - indexes=None, - inference_params=None, - ): + def forward(self, hidden_states=None, images=None, input_ids=None, **kwargs): xs = [] pure_text = False - input_ids = input_ids.clone() - assert hasattr(self, "vit") - assert hasattr(self, "vision_proj") - if len(images) == 1 and len(images[0]) == 0: # make sure grad in Qformer for update - images = [torch.rand(1, 3, self.vit.image_size, self.vit.image_size).cuda().to(self.dtype)] - pure_text = True - - for image in images: - assert len(image) > 0 - if len(image) == 0: - x = [] - else: - assert not isinstance(image, list), image - x = image.to(torch.cuda.current_device()).to(self.dtype) - x = self.vit(x) - x = self.vision_proj(x) - xs.append(x) + images = [] if images is None else images + + if hasattr(self, "vit") and hasattr(self, "vision_proj") and hasattr(self, "tok_embeddings"): + # vit + if len(images) == 1 and len(images[0]) == 0: # make sure grad in Qformer for update + images = [torch.rand(1, 3, self.vit.image_size, self.vit.image_size).cuda().to(self.dtype)] + pure_text = True + + for image in images: + assert len(image) > 0 + if len(image) == 0: + x = [] + else: + assert not isinstance(image, list), image + x = image.to(torch.cuda.current_device()).to(self.dtype) + x = self.vit(x) + x = self.vision_proj(x) + xs.append(x) - # attention_mask: compute attention on the places where the value is 1 - if hasattr(self, "tok_embeddings") and input_ids is not None: + # tok embeddings org_ids = input_ids.clone() input_ids[input_ids == self.image_token_id] = 0 hidden_states = self.tok_embeddings(input_ids).clone() + if pure_text and len(xs) > 0: hidden_states = hidden_states + 0 * xs[0].sum() else: @@ -269,208 +222,14 @@ def forward( # pylint: disable=W0102 hidden_states = ( self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach() ) - if isinstance(cu_seqlens, list): - assert len(cu_seqlens) == 1 - cu_seqlens = cu_seqlens[0].to(hidden_states.device) - - if cu_seqlens is not None: - cu_seqlens = cu_seqlens.squeeze(0) - - if indexes is not None: - assert len(indexes) == 1 - # The indexes are used to indicate the actual position IDs of each token in the packed input. - indexes = indexes[0] - # if the sequence parallel mode is 'isp', the indexes should also be split in sequence dimension. 
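A note on the `pure_text` branch above: it pushes a dummy random image through the trainable vision projector and then adds `0 * xs[0].sum()` to the token embeddings, so that text-only batches still produce gradients for the projector's parameters, which distributed gradient reducers generally expect. A toy reproduction of the trick (module and tensor names here are made up):

import torch
import torch.nn as nn

proj = nn.Linear(16, 8)                        # stands in for the trainable vision projector
text_hidden = torch.randn(4, 8, requires_grad=True)

dummy_feat = proj(torch.rand(1, 16))           # dummy forward for a text-only batch
out = text_hidden + 0 * dummy_feat.sum()       # values unchanged, graph stays connected
out.sum().backward()
assert proj.weight.grad is not None            # projector still receives (zero) gradients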
- if gpc.config.parallel.sequence_parallel and self.tp_mode == "isp": - indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0) - - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None for _, block in enumerate(self.layers): - hidden_states = block( - hidden_states, - residual=None, - cu_seqlens=cu_seqlens, - indexes=indexes, - inference_params=inference_params, - max_seqlen=max_seqlen, - ) + hidden_states = block(hidden_states, residual=None, **kwargs) if hasattr(self, "norm"): hidden_states = self.norm(hidden_states.float()) if hasattr(self, "output"): - hidden_states = self.output(hidden_states, gather_dim=1, tp_mode=self.tp_mode) - - if not self.parallel_output: - hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1) + hidden_states = self.output(hidden_states) return hidden_states - - -def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"), **kwargs): - """ - build generic model 1d - - Args: - num_layers (int): The number of layer. - num_chunks (int): The number of partitions in pipeline parallel. - device (Optional[Union[str, torch.device]]): The device will be used. torch.device("cuda") by default. - - """ - pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) - pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - - all_parts = partition_uniform(num_layers, pipeline_size, num_chunks) - parts = all_parts[pipeline_rank] - if gpc.is_rank_for_log(): - logger.info(f"The layer sharding is {all_parts}.") - - models = [] - kwargs["checkpoint_fraction"] = float(kwargs.get("checkpoint", False)) - start_idx, end_idx = 0, 0 - for start, end in parts: - start_idx, end_idx = start, end - kwargs["num_layers"] = end - start - kwargs["first"] = start == 0 - # If there is no content in the final layer, assign the last layer. 
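The removed builders pass their flat config through `filter_kwargs(cls.__init__, kwargs)` before instantiating each pipeline chunk; presumably it just drops keys the constructor does not accept, roughly like this sketch (not the actual internlm.utils.common implementation):

import inspect

def demo_filter_kwargs(func, kwargs):
    # Keep only the keyword arguments that appear in func's signature.
    accepted = inspect.signature(func).parameters
    return {key: value for key, value in kwargs.items() if key in accepted}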
- kwargs["last"] = end == num_layers and len(all_parts[-1]) != 0 - kwargs["device"] = device - kwargs["start_layer_idx"] = start - chunk = PackedFlashLlava1D(**filter_kwargs(PackedFlashLlava1D.__init__, kwargs)).to(device) - - models.append(chunk) - torch.distributed.barrier() - if len(models) == 1: - model = models[0] - else: - model = nn.ModuleList(models) - setattr(model, "first_layer", start_idx) - setattr(model, "last_layer", end_idx) - return model - - -@MODEL_INITIALIZER.register_module(module_name=MODEL_TYPE) -def build_model_with_cfg( - num_chunks=1, - checkpoint=False, - dtype=torch.float, - embed_split_hidden=False, - num_layers=48, - hidden_size=2048, - vocab_size=50304, - embed_grad_scale=1, - parallel_output=True, - num_attention_heads=32, - num_kv_attention_heads=None, - mlp_ratio=4.0, - residual_in_fp32=False, - norm_type="rmsnorm", - adapt_hf=False, - drop_rate=0, - attn_drop_rate=0, - apply_post_layer_norm=False, # pylint: disable=W0613 - no_bias=False, - deepnorm=False, - layer_norm_epsilon=1e-5, - is_reward=False, - dropout_selective_checkpoint=True, - use_scaled_init: bool = True, - use_swiglu: bool = True, - use_flash_attn: bool = True, - embedding_init_std: float = 0.02, - attn_wqkv_init_std: float = 0.02, - attn_other_init_std: float = 0.02, - ffn_uplayer_init_std: float = 0.02, - ffn_other_init_std: float = 0.02, - out_head_init_std: float = 0.02, - init_type: str = "normal", - rope_base: int = 10000, - mlp_layer_fusion: bool = False, - multiple_of: int = 256, - image_token_id: int = 200000, - vit_cfg=None, - vision_proj_cfg=None, -): - """ - Builde model with config - - Args: - num_chunks (int): The number of partitions in pipeline parallel. 1 by default. - checkpoint (bool): Whether to use checkpointing to save VRAM. False by default. - dtype (torch.dtype): The type of data. torch.float by default. - embed_split_hidden (bool): Split the embedding layer in the hidden state dimention or vocabulary dimention. - False by default. - num_layers (int): The number of layer. 48 by default. - hidden_size (int): The size of hidden state. 2048 by default. - vocab_size (int): The size of vocabulary. 50304 by default. - embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. - parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. - num_attention_heads (int): The number of attention head. 32 by default. - mlp_ratio (int): The ratio of MLP layers. 4.0 by default. - residual_in_fp32 (bool): Whether to use residual in fp32. False by default. It cannot be used temporarily - because this parameter requires inconsistent data types to be passed between pipelines, - which requires significant modifications to internlm. - norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. - drop_rate (float): The dropout rate of input hidden state. 0 by default. - attn_drop_rate (float): The dropout rate of attention module. 0 by default. - apply_post_layer_norm (bool): Whether to apply post layer norm. False by default. - layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. - is_reward (bool): Whether to use reward model. False by default. - dropout_selective_checkpoint (bool): It can only be enabled when checkpoint is disabled. True by default. - use_scaled_init (bool): Whether to use scaled init. True by default. - use_swiglu (bool): Whether to use swiglu. True by default. - use_flash_attn (bool): Whether to use flash-attn. 
True by default. - embedding_init_std (float): std used to init embedding weight. 0.02 by default, - attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, - attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, - ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu - otherwise init fc1 weight in ffn. 0.02 by default, - ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default, - out_head_init_std (float): std used to init output lmhead weight. 0.02 by default, - init_type (str): Initialization type. Use uniform or normal. "normal" by default, - rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. - """ - if deepnorm: - raise AssertionError("deepnorm will not be supported in future versions." "Use early versions if necessary.") - - cfg = dict( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - num_kv_attention_heads=num_kv_attention_heads if num_kv_attention_heads else num_attention_heads, - checkpoint=checkpoint, - dtype=dtype, - embed_split_hidden=embed_split_hidden, - vocab_size=vocab_size, - embed_grad_scale=embed_grad_scale, - parallel_output=parallel_output, - mlp_ratio=mlp_ratio, - apply_post_layer_norm=apply_post_layer_norm, - no_bias=no_bias, - residual_in_fp32=residual_in_fp32, - norm_type=norm_type, - adapt_hf=adapt_hf, - drop_rate=drop_rate, - attn_drop_rate=attn_drop_rate, - layer_norm_epsilon=layer_norm_epsilon, - is_reward=is_reward, - dropout_selective_checkpoint=dropout_selective_checkpoint, - use_scaled_init=use_scaled_init, - use_swiglu=use_swiglu, - use_flash_attn=use_flash_attn, - embedding_init_std=embedding_init_std, - attn_wqkv_init_std=attn_wqkv_init_std, - attn_other_init_std=attn_other_init_std, - ffn_uplayer_init_std=ffn_uplayer_init_std, - ffn_other_init_std=ffn_other_init_std, - out_head_init_std=out_head_init_std, - init_type=init_type, - rope_base=rope_base, - mlp_layer_fusion=mlp_layer_fusion, - multiple_of=multiple_of, - image_token_id=image_token_id, - vit_cfg=vit_cfg, - vision_proj_cfg=vision_proj_cfg, - ) - - return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg) diff --git a/internlm/model/modeling_moe.py b/internlm/model/modeling_moe.py index b1e3ed8b..d15d378f 100644 --- a/internlm/model/modeling_moe.py +++ b/internlm/model/modeling_moe.py @@ -12,30 +12,24 @@ from internlm.core.naive_amp import set_fp32_attr_to_module from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal from internlm.model.modules.embedding import Embedding1D -from internlm.model.modules.mlp import get_mlp_cls -from internlm.model.modules.multi_head_attention import MHA -from internlm.model.moe import MoE -from internlm.model.ops.fusion_ops_import_helper import try_import_RMSNorm -from internlm.model.ops.linear import RewardModelLinear, ScaleColumnParallelLinear +from internlm.model.modules.linear import new_linear +from internlm.model.modules.mha import MHA +from internlm.model.modules.mlp import new_feed_forward +from internlm.model.modules.norm import new_layer_norm +from internlm.model.moe.moe import MoE from internlm.model.utils import ( - gather_forward_split_backward, - split_forward_gather_backward, + internlm1_mha_pre_load_convert, + internlm1_mha_save_convert, ) from internlm.solver.activation_checkpoint import activation_checkpoint -from internlm.solver.pipeline_utils import partition_uniform -from internlm.utils.common import filter_kwargs, 
get_current_device from internlm.utils.logger import get_logger -from internlm.utils.registry import MODEL_INITIALIZER - -MODEL_TYPE = "INTERNLM_MoE" logger = get_logger(__file__) -RMSNorm = try_import_RMSNorm() -class PackedFlashBaseLayer1D(nn.Module): +class Internlm1MoEDecoder(nn.Module): """ - 1D Packed Flash Base Layer. + InternLM1 MoE Decoder Layer. Args: hidden_size (int): The hidden size of model. 768 by default. @@ -43,18 +37,22 @@ class PackedFlashBaseLayer1D(nn.Module): mlp_ratio (int): The ratio of MLP layers. 4 by default. attn_drop_rate (float): The dropout rate of attention module. 0 by default. drop_rate (float): The dropout rate of the input hidden state. 0.0 by default. + max_position_embeddings (int): The maximum position embeddings. 2048 by default. dtype (torch.dtype): Type of data. torch.float by default. layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. layer_idx (int): The index of current layer. 0 by default. + use_dynamic_ntk_rope (bool): Whether to use dynamic ntk rope. False by default. residual_in_fp32 (bool): Whether to use residual in fp32. False by default. device (Optional[Union[str, torch.device]]): The device will be used. norm_type (str): Use RMS norm or layernorm."rmsnorm" by default. - use_flash_attn (bool): Whether use flash-attn. True by default. - num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default. - moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE - (https://arxiv.org/abs/2201.05596) layer. - moe_type (str): determine which moe impl will be used, default is GShardMoE + qk_interleaved (bool): Whether the odd and even columns of the wq and wk are normally interleaved. + dropout_selective_checkpoint (bool): Whether to selectively checkpoint dropout layers only. + use_scaled_init (bool): Whether to use scaled initialization for weights. + use_swiglu (bool): Whether to use SwiGLU activation in the mlp module. + rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. + mlp_layer_fusion (bool): Whether to fuse layers in the mlp module for optimization. + multiple_of (int): Ensures mlp dimensions are multiples of this value for efficient hardware utilization. """ def __init__( @@ -73,12 +71,11 @@ def __init__( residual_in_fp32: bool = False, device: Optional[torch.device] = None, norm_type: str = "rmsnorm", + qk_interleaved: bool = False, dropout_selective_checkpoint: bool = True, use_scaled_init: bool = True, use_swiglu: bool = True, - use_flash_attn: bool = True, num_experts: int = 1, - tp_mode: str = "mtp", mlp_layer_fusion: bool = False, multiple_of: int = 256, ): @@ -87,16 +84,12 @@ def __init__( # dropout selective checkpoint can only be enabled when checkpoint is disabled. 
self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False self.layer_idx = layer_idx - self.use_flash_attn = use_flash_attn head_dim = hidden_size // num_attention_heads - self.tp_mode = tp_mode - parallel_mode = ParallelMode.WEIGHT if self.tp_mode == "isp" else ParallelMode.TENSOR + self.mixer = MHA( embed_dim=hidden_size, num_heads=num_attention_heads, - process_group=gpc.get_group(parallel_mode), - sequence_process_group=gpc.get_group(ParallelMode.TENSOR), dropout=attn_drop_rate, max_position_embeddings=max_position_embeddings, softmax_scale=1 / math.sqrt(head_dim), @@ -105,54 +98,51 @@ def __init__( use_dynamic_ntk_rope=use_dynamic_ntk_rope, rotary_emb_dim=head_dim, rotary_emb_scale_base=0, - use_flash_attn=use_flash_attn, device=device, dtype=dtype, - tp_mode=self.tp_mode, + qk_interleaved=qk_interleaved, ) + # Compatible with the name of internlm1 Wqkv linear layer + self.mixer.register_checkpoint_compatibility_hooks(internlm1_mha_pre_load_convert, internlm1_mha_save_convert) + self.dropout1 = nn.Dropout(drop_rate) - if norm_type == "rmsnorm": - self.norm1 = RMSNorm(hidden_size, eps=layer_norm_epsilon) - self.norm2 = RMSNorm(hidden_size, eps=layer_norm_epsilon) - else: - self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) - self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) + self.dropout2 = nn.Dropout(drop_rate) + + self.norm1 = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) + self.norm2 = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) self.num_experts = num_experts ep_size = gpc.get_world_size(ParallelMode.EXPERT) if num_experts <= 1: # dense, not MoE - if use_swiglu: - mlp_cls = get_mlp_cls(self.tp_mode) - self.mlp = mlp_cls( - hidden_size, - int(hidden_size * mlp_ratio), - out_features=hidden_size, - process_group=gpc.get_group(parallel_mode), - bias=False, - device=device, - dtype=dtype, - mlp_layer_fusion=mlp_layer_fusion, - sequence_parallel=gpc.config.parallel.sequence_parallel, - multiple_of=multiple_of, - ) + self.mlp = new_feed_forward( + hidden_size, + int(hidden_size * mlp_ratio), + out_features=hidden_size, + bias=False, + device=device, + dtype=dtype, + mlp_layer_fusion=mlp_layer_fusion, + multiple_of=multiple_of, + # TODO: to support more activation functions + activation_type="swiglu" if use_swiglu else "swiglu", + ) else: # replace mlp by MoE module. The expert in MoE is a FeedForward module. - mlp_cls = get_mlp_cls(self.tp_mode) + # mlp_cls = get_mlp_cls(self.tp_mode) self.mlp = MoE( hidden_size, int(hidden_size * mlp_ratio), out_features=hidden_size, num_experts=num_experts, - ep_cls=mlp_cls, ep_group=gpc.get_group(ParallelMode.EXPERT), ep_size=ep_size, device=device, dtype=dtype, ) + # TODO: remove from model package. 
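The converters registered above come from internlm.model.utils and are not shown in this hunk; conceptually they translate between the legacy `Wqkv` parameter name and the refactored `wqkv` name when loading and saving checkpoints. A purely illustrative pair, with made-up names and signatures:

def demo_pre_load_convert(state_dict, prefix=""):
    # Rename the legacy InternLM1 key to the refactored module's key before loading.
    old_key, new_key = prefix + "Wqkv.weight", prefix + "wqkv.weight"
    if old_key in state_dict:
        state_dict[new_key] = state_dict.pop(old_key)
    return state_dict

def demo_save_convert(state_dict, prefix=""):
    # Restore the legacy key on save so older tooling keeps working.
    new_key, old_key = prefix + "wqkv.weight", prefix + "Wqkv.weight"
    if new_key in state_dict:
        state_dict[old_key] = state_dict.pop(new_key)
    return state_dict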
set_fp32_attr_to_module(self.mlp.moe_layer.gate) - self.dropout2 = nn.Dropout(drop_rate) self.use_swiglu = use_swiglu self.use_scaled_init = use_scaled_init self.residual_in_fp32 = residual_in_fp32 # only make sense when using prenorm @@ -164,7 +154,7 @@ def reset_parameters(self): for name, param in self.mixer.named_parameters(): if param.ndim == 1: param.data.zero_() - elif "Wqkv" in name: + elif "wqkv" in name: normal_(std=0.006)(param.data) elif self.use_scaled_init: scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data) @@ -186,15 +176,14 @@ def reset_parameters(self): else: normal_(std=0.006 if "fc1" in name else 0.0015)(param.data) - def forward(self, hidden_states, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None): + def forward(self, hidden_states, **kwargs): if self.checkpoint and self.training: - return activation_checkpoint( - self._forward, False, hidden_states, cu_seqlens, indexes, inference_params, max_seqlen - ) # TODO: check whether this will be affected by moe + # TODO: check whether this will be affected by moe + return activation_checkpoint(self._forward, False, hidden_states, **kwargs) else: - return self._forward(hidden_states, cu_seqlens, indexes, inference_params, max_seqlen) + return self._forward(hidden_states, **kwargs) - def _forward(self, hidden_states=None, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None): + def _forward(self, hidden_states=None, **kwargs): r"""Pass the input through the encoder layer. Args: @@ -203,12 +192,6 @@ def _forward(self, hidden_states=None, cu_seqlens=None, indexes=None, inference_ cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1 indexes: the length of index is same as hidden states, which stand for the current position """ - mixer_kwargs = { - "cu_seqlens": cu_seqlens, - "max_seqlen": max_seqlen, - "indexes": indexes, - "inference_params": inference_params, - } def _dropout_and_norm_attn(_hidden_states): _dropped = self.dropout1(_hidden_states) @@ -224,7 +207,7 @@ def _dropout_and_norm_attn(_hidden_states): if self.residual_in_fp32: residual = residual.to(torch.float32) - hidden_states = self.mixer(hidden_states, **mixer_kwargs) + hidden_states = self.mixer(hidden_states, **kwargs) def _dropout_and_norm_ffn(_residual, _hidden_states): _dropped = self.dropout2(_hidden_states) @@ -241,18 +224,18 @@ def _dropout_and_norm_ffn(_residual, _hidden_states): residual = residual.to(torch.float32) # MLP. - moe_loss = torch.tensor(0.0, device=hidden_states.device, dtype=hidden_states.dtype) if self.num_experts <= 1: # dense mlp output hidden_states = self.mlp(hidden_states) + moe_loss = torch.tensor(0.0, device=hidden_states.device, dtype=hidden_states.dtype) else: # MoE output hidden_states, moe_loss, _ = self.mlp(hidden_states) return hidden_states + residual, moe_loss -class PackedFlashInternLm1D(nn.Module): +class Internlm1MoE(nn.Module): """ - 1D Packed Flash InternLm. + InternLM1 MoE. Args: num_layers (int): The number of layer. 12 by default. @@ -262,34 +245,39 @@ class PackedFlashInternLm1D(nn.Module): mlp_ratio (int): The ratio of MLP layers. 4 by default. attn_drop_rate (float): The dropout rate of attention module. 0.0 by default. drop_rate (float): The dropout rate of input hidden state. 0.0 by default. + max_position_embeddings (int): The maximum position embeddings. 2048 by default. dtype (torch.dtype): The type of data. torch.float by default. 
checkpoint (float): The proportion of layers that need to be checkpointed compared to the total number of layers. 0.0 by default. - layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default. + layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. first (bool): Whether input embedding layer or not. False by default. last (bool): Whether output embedding layer or not. False by default. - embed_split_hidden (bool): Split the embedding layer in the hidden state dimention or vocabulary dimention. - True by default. embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. start_layer_idx (int): The index of start layer in the pipeline. 0 by default. + use_dynamic_ntk_rope (bool): Whether to use dynamic ntk rope. False by default. device (Optional[Union[str, torch.device]]): The device will be used. None by default. residual_in_fp32 (bool): Whether to use residual in fp32. False by default. norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. - use_flash_attn (bool): Whether to use flash-attn. True by default. + qk_interleaved (bool): Whether the odd and even columns of the wq and wk are normally interleaved. + dropout_selective_checkpoint (bool): Whether to selectively checkpoint dropout and norm layers. + use_scaled_init (bool): Whether to use scaled initialization for weights. + use_swiglu (bool): Whether to use SwiGLU activation in the mlp module. num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default. moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE (https://arxiv.org/abs/2201.05596) layer. moe_type (str): determine which moe impl will be used, default is GShardMoE + mlp_layer_fusion (bool): Whether to fuse layers in the mlp module for optimization. + multiple_of (int): Ensures mlp dimensions are multiples of this value for efficient hardware utilization. 
""" def __init__( self, - num_layers: int = 12, - hidden_size: int = 768, - num_attention_heads: int = 12, + num_layers: int = 48, + hidden_size: int = 2048, + num_attention_heads: int = 32, vocab_size: int = 50304, - mlp_ratio: int = 4.0, + mlp_ratio: float = 4.0, attn_drop_rate: float = 0.0, drop_rate: float = 0.0, max_position_embeddings: int = 2048, @@ -298,7 +286,6 @@ def __init__( layer_norm_epsilon: float = 1e-5, first: bool = False, last: bool = False, - embed_split_hidden: bool = False, embed_grad_scale: float = 0.1, parallel_output: bool = True, start_layer_idx: int = 0, @@ -306,37 +293,30 @@ def __init__( device: Optional[torch.device] = None, residual_in_fp32: bool = False, norm_type: str = "rmsnorm", + qk_interleaved: bool = False, is_reward: bool = False, dropout_selective_checkpoint: bool = True, use_scaled_init: bool = True, use_swiglu: bool = True, - use_flash_attn: bool = True, num_experts: bool = 1, + moe_use_residual: bool = False, # pylint: disable=W0613 + moe_type: str = None, # pylint: disable=W0613 mlp_layer_fusion: bool = False, multiple_of: int = 256, ): super().__init__() checkpoint_layer_num = int(num_layers * checkpoint) - self.tp_mode = "mtp" - if isinstance(gpc.config.parallel["tensor"], dict): - self.tp_mode = gpc.config.parallel["tensor"].get("mode", "mtp") - - if is_reward: - head_cls = RewardModelLinear - else: - head_cls = ScaleColumnParallelLinear if first: - self.embedding = Embedding1D( - num_embeddings=vocab_size, embedding_dim=hidden_size, embed_split_hidden=embed_split_hidden - ) + self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) + for _, param in self.embedding.named_parameters(): normal_(std=0.0052)(param) self.embed_grad_scale = embed_grad_scale self.blocks = nn.ModuleList( [ - PackedFlashBaseLayer1D( + Internlm1MoEDecoder( hidden_size=hidden_size, num_attention_heads=num_attention_heads, mlp_ratio=mlp_ratio, @@ -354,9 +334,8 @@ def __init__( dropout_selective_checkpoint=dropout_selective_checkpoint, use_scaled_init=use_scaled_init, use_swiglu=use_swiglu, - use_flash_attn=use_flash_attn, + qk_interleaved=qk_interleaved, num_experts=num_experts, - tp_mode=self.tp_mode, mlp_layer_fusion=mlp_layer_fusion, multiple_of=multiple_of, ) @@ -364,17 +343,15 @@ def __init__( ] ) if last: - if norm_type == "rmsnorm": - self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon) - else: - self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) - self.head = head_cls( + self.norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) + self.head = new_linear( + name="head", in_features=hidden_size, out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size, - process_group=gpc.get_group(ParallelMode.TENSOR), bias=False, device=device, dtype=dtype, + is_reward=is_reward, weight_scale=embed_grad_scale, ) for _, param in self.head.named_parameters(): @@ -382,7 +359,7 @@ def __init__( self.parallel_output = parallel_output - def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None): + def forward(self, hidden_states=None, input_ids=None, **kwargs): # attention_mask: compute attention on the places where the value is 1 # old condition may fail when use shared embedding if gpc.is_pipeline_first_stage() and input_ids is not None: @@ -391,176 +368,15 @@ def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=N hidden_states = ( self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach() ) - if 
isinstance(cu_seqlens, list): - assert len(cu_seqlens) == 1 - cu_seqlens = cu_seqlens[0].to(hidden_states.device) - - if cu_seqlens is not None: - cu_seqlens = cu_seqlens.squeeze(0) - - if indexes is not None: - assert len(indexes) == 1 - # The indexes are used to indicate the actual position IDs of each token in the packed input. - indexes = indexes[0] - # if the sequence parallel mode is 'isp', the indexes should also be split in sequence dimension. - if gpc.config.parallel.sequence_parallel and self.tp_mode == "isp": - indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0) - - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None moe_losses = [] for _, block in enumerate(self.blocks): - hidden_states, mos_loss = block( - hidden_states, - cu_seqlens=cu_seqlens, - indexes=indexes, - inference_params=inference_params, - max_seqlen=max_seqlen, - ) + hidden_states, mos_loss = block(hidden_states, **kwargs) moe_losses.append(mos_loss) if hasattr(self, "norm"): hidden_states = self.norm(hidden_states.float()) if hasattr(self, "head"): - hidden_states = self.head(hidden_states, gather_dim=1, tp_mode=self.tp_mode) + hidden_states = self.head(hidden_states) - if not self.parallel_output and gpc.is_pipeline_last_stage(): - hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1) return hidden_states, moe_losses - - -def _build_generic_model_1d(num_layers, num_chunks, **kwargs): - """ - build generic model 1d - - Args: - num_layers (int): The number of layer. - num_chunks (int): The number of partitions in pipeline parallel. - device (Optional[Union[str, torch.device]]): The device will be used. internlm_accelerator.device() by default. - - """ - device = get_current_device() - pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) - pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - - all_parts = partition_uniform(num_layers, pipeline_size, num_chunks) - parts = all_parts[pipeline_rank] - if gpc.is_rank_for_log(): - logger.info(f"The layer sharding is {all_parts}.") - - models = [] - - for start, end in parts: - kwargs["num_layers"] = end - start - kwargs["first"] = start == 0 - # If there is no content in the final layer, assign the last layer. 
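The model now returns the per-layer auxiliary losses alongside the hidden states; how they are folded into the training objective is handled outside this file, but the usual pattern is a weighted sum with the language-modeling loss, for example (the coefficient name and value below are hypothetical):

import torch

lm_loss = torch.tensor(2.31)                            # stand-in for the language-modeling loss
moe_losses = [torch.tensor(0.01), torch.tensor(0.02)]   # one auxiliary loss per MoE layer
moe_loss_coeff = 0.01                                   # hypothetical weighting coefficient
total_loss = lm_loss + moe_loss_coeff * torch.stack(moe_losses).sum()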
- kwargs["last"] = end == num_layers and len(all_parts[-1]) != 0 - kwargs["device"] = device - kwargs["start_layer_idx"] = start - chunk = PackedFlashInternLm1D(**filter_kwargs(PackedFlashInternLm1D.__init__, kwargs)).to(device) - - models.append(chunk) - torch.distributed.barrier() - if len(models) == 1: - model = models[0] - else: - model = nn.ModuleList(models) - - return model - - -@MODEL_INITIALIZER.register_module(module_name=MODEL_TYPE) -def build_model_with_moe_cfg( - num_chunks=1, - checkpoint=0.0, - dtype=torch.float, - embed_split_hidden=False, - num_layers=48, - hidden_size=2048, - vocab_size=50304, - embed_grad_scale=1, - parallel_output=True, - num_attention_heads=32, - max_position_embeddings=2048, - mlp_ratio=4.0, - residual_in_fp32=False, - use_dynamic_ntk_rope=False, - norm_type="rmsnorm", - drop_rate=0, - attn_drop_rate=0, - apply_post_layer_norm=False, # pylint: disable=W0613 - layer_norm_epsilon=1e-5, - is_reward=False, - dropout_selective_checkpoint=True, - use_scaled_init: bool = True, - use_swiglu: bool = True, - use_flash_attn: bool = True, - num_experts: int = 1, - moe_use_residual: bool = False, # pylint: disable=W0613 - moe_type: str = None, # pylint: disable=W0613 - mlp_layer_fusion: bool = False, - multiple_of: int = 256, -): - """ - Build model with config. - - Args: - num_chunks (int): The number of partitions in pipeline parallel. 1 by default. - checkpoint (bool): Whether to use checkpointing to save VRAM. False by default. - dtype (torch.dtype): The type of data. torch.float by default. - embed_split_hidden (bool): Split the embedding layer in the hidden state dimention or vocabulary dimention. - False by default. - num_layers (int): The number of layer. 48 by default. - hidden_size (int): The size of hidden state. 2048 by default. - vocab_size (int): The size of vocabulary. 50304 by default. - embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. - parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. - num_attention_heads (int): The number of attention head. 32 by default. - mlp_ratio (int): The ratio of MLP layers. 4.0 by default. - residual_in_fp32 (bool): Whether to use residual in fp32. False by default. It cannot be used temporarily - because this parameter requires inconsistent data types to be passed between pipelines, - which requires significant modifications to internlm. - norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. - drop_rate (float): The dropout rate of input hidden state. 0 by default. - attn_drop_rate (float): The dropout rate of attention module. 0 by default. - apply_post_layer_norm (bool): Whether to apply post layer norm. False by default. - layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. - is_reward (bool): Whether to use reward model. False by default. - dropout_selective_checkpoint (bool): It can only be enabled when checkpoint is disabled. True by default. - use_scaled_init (bool): Whether to use scaled init. True by default. - use_swiglu (bool): Whether to use swiglu. True by default. - use_flash_attn (bool): Whether to use flash-attn. True by default. - num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default. - moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE - (https://arxiv.org/abs/2201.05596) layer. 
- moe_type (str): determine which moe impl will be used, default is GShardMoE - """ - - cfg = dict( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - checkpoint=checkpoint, - dtype=dtype, - embed_split_hidden=embed_split_hidden, - vocab_size=vocab_size, - embed_grad_scale=embed_grad_scale, - parallel_output=parallel_output, - mlp_ratio=mlp_ratio, - residual_in_fp32=residual_in_fp32, - max_position_embeddings=max_position_embeddings, - use_dynamic_ntk_rope=use_dynamic_ntk_rope, - norm_type=norm_type, - drop_rate=drop_rate, - attn_drop_rate=attn_drop_rate, - layer_norm_epsilon=layer_norm_epsilon, - is_reward=is_reward, - dropout_selective_checkpoint=dropout_selective_checkpoint, - use_scaled_init=use_scaled_init, - use_swiglu=use_swiglu, - use_flash_attn=use_flash_attn, - num_experts=num_experts, - mlp_layer_fusion=mlp_layer_fusion, - multiple_of=multiple_of, - ) - - return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg) diff --git a/internlm/model/modules/embedding.py b/internlm/model/modules/embedding.py index 2dfea2e8..fa922daa 100644 --- a/internlm/model/modules/embedding.py +++ b/internlm/model/modules/embedding.py @@ -1,23 +1,15 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from typing import Tuple +from typing import Optional, Union import torch import torch.nn.functional as F from einops import rearrange from torch import Tensor, nn -from internlm.accelerator import get_accelerator -from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.model.ops.fusion_ops_import_helper import try_import_fused_rotary - -from ..utils import gather_forward_split_backward, split_forward_gather_backward - -internlm_accelerator = get_accelerator() - -apply_rotary_emb, apply_rotary_emb_qkv_, apply_rotary_func = None, None, None +from internlm.model.ops.rotary_emb import apply_rotary_emb class Embedding1D(nn.Module): @@ -31,8 +23,6 @@ class Embedding1D(nn.Module): therefore, the embedding vector at :attr:`padding_idx` is not updated during training, i.e. it remains as a fixed "pad". None by default. dtype (Optional[torch.dtype]): Data type None by default. - embed_split_hidden (Optional[Bool]): Whether to split the embed_dim in tensor parallel style. 
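The refactored Embedding1D in the next hunk always shards the hidden dimension by the tensor-parallel size rather than branching on embed_split_hidden; for a concrete feel of the resulting per-rank weight shape (sizes below are illustrative):

num_embeddings, embedding_dim, tp_size = 50304, 2048, 4
embed_dim_per_partition = embedding_dim // tp_size        # 512 columns per rank
weight_shape = (num_embeddings, embed_dim_per_partition)  # (50304, 512)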
- """ def __init__( @@ -42,220 +32,21 @@ def __init__( *args, padding_idx: int = None, dtype: torch.dtype = None, - embed_split_hidden: bool = True, **kwargs, ): super().__init__() self.num_embeddings = num_embeddings self.embed_dim = embedding_dim - self.embed_split_hidden = embed_split_hidden - if self.embed_split_hidden: - self.embed_split_hidden = gpc.tensor_parallel_size > 1 - - split_nums = 1 if not self.embed_split_hidden else gpc.tensor_parallel_size - embed_dim_per_partition = embedding_dim // split_nums - self.padding_idx = padding_idx self.embed_args = args self.embed_kwargs = kwargs + embed_dim_per_partition = embedding_dim // gpc.tensor_parallel_size self.weight = nn.Parameter(torch.empty((num_embeddings, embed_dim_per_partition), dtype=dtype)) def forward(self, input_: Tensor) -> Tensor: - output = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) - - if self.embed_split_hidden: - output = gather_forward_split_backward(output, ParallelMode.TENSOR, dim=-1) - - if gpc.config.parallel.sequence_parallel: - output = split_forward_gather_backward(output, ParallelMode.TENSOR, dim=1) - - return output - - -def _torch_apply_rotary_func( - x1: torch.Tensor, - x2: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - out1: torch.Tensor, - out2: torch.Tensor, - conj: bool = False, -): - assert x1.device == x2.device == cos.device == sin.device, "All inputs must be on the same device" - assert x1.dtype == x2.dtype == cos.dtype == sin.dtype, "All inputs must have the same dtype" - assert x1.size() == x2.size(), "Input x1 and x2 must have the same sizes" - assert cos.size() == sin.size(), "Input cos and sin must have the same sizes" - - x1, x2, cos, sin = x1.float(), x2.float(), cos.float(), sin.float() - - if conj: - out1.copy_(x1 * cos + x2 * sin) - out2.copy_(-x1 * sin + x2 * cos) - else: - out1.copy_(x1 * cos - x2 * sin) - out2.copy_(x1 * sin + x2 * cos) - - return out1, out2 - - -class ApplyRotaryEmb(torch.autograd.Function): - """ - ApplyRotaryEmb - """ - - @staticmethod - def forward(ctx, x, cos, sin, interleaved=False): - """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead - of 1st half and 2nd half (GPT-NeoX style). - rotary_dim must be <= headdim - Apply rotary embedding to the first rotary_dim of x. 
- """ - _, seqlen, _, headdim = x.shape - rotary_seqlen, rotary_dim = cos.shape - rotary_dim *= 2 - assert rotary_dim <= headdim - assert seqlen <= rotary_seqlen - assert sin.shape == (rotary_seqlen, rotary_dim // 2) - x_ro = x[..., :rotary_dim] - x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2]) - out = torch.empty_like(x) - out_ro = out[..., :rotary_dim] - o1, o2 = out_ro.chunk(2, dim=-1) if not interleaved else (out_ro[..., ::2], out_ro[..., 1::2]) - - apply_rotary_func( - x1, - x2, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - o1, - o2, - False, - ) - - if rotary_dim < headdim: - out[..., rotary_dim:].copy_(x[..., rotary_dim:]) - ctx.save_for_backward(cos, sin) - ctx.interleaved = interleaved - return out - - @staticmethod - def backward(ctx, do): - cos, sin = ctx.saved_tensors - _, seqlen, _, headdim = do.shape - rotary_dim = cos.shape[-1] - rotary_dim *= 2 - do_ro = do[..., :rotary_dim] - do1, do2 = do_ro.chunk(2, dim=-1) if not ctx.interleaved else (do_ro[..., ::2], do_ro[..., 1::2]) - dx = torch.empty_like(do) - dx_ro = dx[..., :rotary_dim] - dx1, dx2 = dx_ro.chunk(2, dim=-1) if not ctx.interleaved else (dx_ro[..., ::2], dx_ro[..., 1::2]) - - apply_rotary_func( - do1, - do2, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - dx1, - dx2, - True, - ) - if rotary_dim < headdim: - dx[..., rotary_dim:].copy_(do[..., rotary_dim:]) - return dx, None, None, None, None - - -class ApplyRotaryEmbQKV_(torch.autograd.Function): - """ - ApplyRotaryEmbQKV_ - """ - - @staticmethod - def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False): - """ - qkv: (total, 3, nheads, headdim) / (batch_size, seqlen, 3, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) - cos_k, sin_k: (seqlen, rotary_dim / 2), optional - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of - 1st half and 2nd half (GPT-NeoX style). - rotary_dim must be <= headdim - Apply rotary embedding *inplace* to the first rotary_dim of q and k. - """ - # len(qkv.shape) == 4 means the format of qkv is (total, 3, nheads, headdim) which is packed, - # otherwise the format of qkv is (batch_size, seqlen, 3, nheads, headdim) which is unpacked. - # We handle both packed qkv and unpacked qkv scenario in this class. 
- three = qkv.shape[1] if len(qkv.shape) == 4 else qkv.shape[2] - assert three == 3 - seqlen = None if len(qkv.shape) == 4 else qkv.shape[1] - rotary_seqlen, rotary_dim = cos.shape - if len(qkv.shape) != 4: - assert seqlen <= rotary_seqlen - headdim = qkv.shape[-1] - rotary_dim *= 2 - assert rotary_dim <= headdim - cos_k = cos if cos_k is None else cos_k - sin_k = sin if sin_k is None else sin_k - assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2) - q_ro = qkv[:, 0, :, :rotary_dim] if len(qkv.shape) == 4 else qkv[:, :, 0, :, :rotary_dim] - q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2], q_ro[..., 1::2]) - re_cos = rearrange(cos, "s d -> s 1 d") if len(qkv.shape) == 4 else rearrange(cos[:seqlen], "s d -> s 1 d") - re_sin = rearrange(sin, "s d -> s 1 d") if len(qkv.shape) == 4 else rearrange(sin[:seqlen], "s d -> s 1 d") - - apply_rotary_func(q1, q2, re_cos, re_sin, q1, q2, False) - - k_ro = qkv[:, 1, :, :rotary_dim] if len(qkv.shape) == 4 else qkv[:, :, 1, :, :rotary_dim] - k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2]) - re_cos_k = ( - rearrange(cos_k, "s d -> s 1 d") if len(qkv.shape) == 4 else rearrange(cos_k[:seqlen], "s d -> s 1 d") - ) - re_sin_k = ( - rearrange(sin_k, "s d -> s 1 d") if len(qkv.shape) == 4 else rearrange(sin_k[:seqlen], "s d -> s 1 d") - ) - - apply_rotary_func(k1, k2, re_cos_k, re_sin_k, k1, k2, False) - - ctx.save_for_backward(cos, sin, cos_k, sin_k) - ctx.interleaved = interleaved - return qkv - - @staticmethod - def backward(ctx, dqkv): - cos, sin, cos_k, sin_k = ctx.saved_tensors - seqlen = None if len(dqkv.shape) == 4 else dqkv.shape[1] - rotary_dim = cos.shape[-1] - rotary_dim *= 2 - dq_ro = dqkv[:, 0, :, :rotary_dim] if len(dqkv.shape) == 4 else dqkv[:, :, 0, :, :rotary_dim] - dq1, dq2 = dq_ro.chunk(2, dim=-1) if not ctx.interleaved else (dq_ro[..., ::2], dq_ro[..., 1::2]) - re_cos = rearrange(cos, "s d -> s 1 d") if len(dqkv.shape) == 4 else rearrange(cos[:seqlen], "s d -> s 1 d") - re_sin = rearrange(sin, "s d -> s 1 d") if len(dqkv.shape) == 4 else rearrange(sin[:seqlen], "s d -> s 1 d") - - apply_rotary_func(dq1, dq2, re_cos, re_sin, dq1, dq2, True) - - dk_ro = dqkv[:, 1, :, :rotary_dim] if len(dqkv.shape) == 4 else dqkv[:, :, 1, :, :rotary_dim] - dk1, dk2 = dk_ro.chunk(2, dim=-1) if not ctx.interleaved else (dk_ro[..., ::2], dk_ro[..., 1::2]) - re_cos_k = ( - rearrange(cos_k, "s d -> s 1 d") if len(dqkv.shape) == 4 else rearrange(cos_k[:seqlen], "s d -> s 1 d") - ) - re_sin_k = ( - rearrange(sin_k, "s d -> s 1 d") if len(dqkv.shape) == 4 else rearrange(sin_k[:seqlen], "s d -> s 1 d") - ) - - apply_rotary_func(dk1, dk2, re_cos_k, re_sin_k, dk1, dk2, True) - - return dqkv, None, None, None, None, None - - -apply_rotary_emb, apply_rotary_emb_qkv_, apply_rotary_func = try_import_fused_rotary() -if apply_rotary_emb is None: - apply_rotary_emb = ApplyRotaryEmb.apply -if apply_rotary_emb_qkv_ is None: - apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply -if apply_rotary_func is None: - apply_rotary_func = _torch_apply_rotary_func + return F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) class RotaryEmbedding(torch.nn.Module): @@ -296,12 +87,19 @@ def __init__(self, dim: int, base=10000, scale_base=0, device=None): self._cos_k_cached = None self._sin_k_cached = None - def _update_cos_sin_cache(self, x, indexes): - """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)""" - if not isinstance(indexes, int): 
- seqlen = indexes.max().item() + 1 + def _update_cos_sin_cache( + self, x: torch.Tensor, indexes: Union[int, torch.Tensor] = 0, max_seqlen: Optional[int] = None + ): + """x: (batch, seqlen, nheads, headdim)""" + if max_seqlen is not None: + seqlen = max_seqlen + elif isinstance(indexes, int): + seqlen = indexes + x.shape[1] + 1 else: - seqlen = indexes + 1 # eval_forward + # Note that this statement may cause synchronization between CPU and GPU, + # so it's best to precompute and pass in max_seqlen ahead of time + seqlen = indexes.max().item() + 1 + # Reset the tables if the sequence length has changed, # or if we're on a new device (possibly due to tracing for instance) if seqlen > self._seq_len_cached or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype: @@ -324,54 +122,78 @@ def _update_cos_sin_cache(self, x, indexes): self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype) self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype) - def forward(self, qkv: torch.Tensor, **kwargs): - if kwargs.get("indexes", None) is not None: - return self._forward(qkv, kwargs.pop("indexes")) - if kwargs.get("inference_params", None) is not None: - return self._eval_forward(qkv, seqlen_offset=kwargs.get("inference_params", None).sequence_len_offset) + def _get_slice(self, tensor: torch.Tensor, offsets: Union[int, torch.Tensor] = 0): + if isinstance(offsets, int): + return tensor[offsets:] else: - return self._eval_forward(qkv) + return tensor[offsets] - def _forward(self, qkv: torch.Tensor, indexes=0) -> Tuple[torch.Tensor, torch.Tensor]: - self._update_cos_sin_cache(qkv, indexes) - if self.scale is None: - return apply_rotary_emb_qkv_(qkv, self._cos_cached[indexes], self._sin_cached[indexes]) - else: - return apply_rotary_emb_qkv_( - qkv, - self._cos_cached[indexes], - self._sin_cached[indexes], - self._cos_k_cached[indexes], - self._sin_k_cached[indexes], - ) - - def _eval_forward(self, qkv, seqlen_offset=0): + def _convert_padding( + self, x: torch.Tensor, empties: torch.Tensor, convert_type: str = "left2right", in_place: bool = False + ): + # TODO: impl in_place = True. + assert not in_place, "in_place = True is NYI." + assert convert_type in ("left2right", "right2left"), f"Unknown convert type {convert_type}" + + ret = x.clone() + + for i in range(len(empties)): + if empties[i] == 0: + continue + + if convert_type == "left2right": + ret[i][: -empties[i]] = x[i][empties[i] :] + ret[i][empties[i] :] = x[i][: -empties[i]] + else: # right2left + ret[i][empties[i] :] = x[i][: -empties[i]] + ret[i][: -empties[i]] = x[i][empties[i] :] + + return ret + + def forward( + self, + x: torch.Tensor, + offsets: Union[int, torch.Tensor] = 0, + max_seqlen: Optional[int] = None, + cache_type: str = "query", + interleaved: bool = False, + in_place: bool = False, + left_padding_mask: Optional[torch.Tensor] = None, + ): """ - seqlen_offset: can be used in generation where the qkv being passed in is only the last - token in the batch. + Applies rotary position embeddings to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + offsets (Union[int, torch.Tensor], optional): The sequence offsets for the input. Defaults to 0. + max_seqlen (Optional[int], optional): The maximum sequence length for caching. Defaults to None. + cache_type (str, optional): Specifies whether the cache is for 'query' or 'key'. Defaults to "query". + interleaved (bool, optional): Whether the input tensor is interleaved. Defaults to False. 
+ in_place (bool, optional): Whether the operation should be done in-place. Defaults to False. + left_padding_mask (Optional[torch.Tensor], optional): A mask for left padding. Defaults to None. + + Returns: + torch.Tensor: The tensor with applied rotary position embeddings. """ - self._update_cos_sin_cache(qkv, seqlen_offset + qkv.shape[1]) - if self.scale is None: - return apply_rotary_emb_qkv_(qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:]) - else: - return apply_rotary_emb_qkv_( - qkv, - self._cos_cached[seqlen_offset:], - self._sin_cached[seqlen_offset:], - self._cos_k_cached[seqlen_offset:], - self._sin_k_cached[seqlen_offset:], - ) - - def _single_forward(self, x, indexes=0): - assert self.scale is None - self._update_cos_sin_cache(x, indexes) - ret = apply_rotary_emb(x, self._cos_cached[indexes], self._sin_cached[indexes]) - return ret + assert cache_type in ("query", "key"), f"Unknown cache type {cache_type}" + assert isinstance(offsets, (int, torch.Tensor)), f"Invalid offsets type {type(offsets)}" - def _single_eval_forward(self, x, seqlen_offset=0): - assert self.scale is None - self._update_cos_sin_cache(x, seqlen_offset + x.shape[1]) - return apply_rotary_emb(x, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:]) + if left_padding_mask is not None: + empties = left_padding_mask[..., -1].sum(dim=-1) + x = self._convert_padding(x, empties, convert_type="left2right", in_place=in_place) + + self._update_cos_sin_cache(x, offsets, max_seqlen) + + cos_cached = self._cos_k_cached if cache_type == "key" and self.scale is not None else self._cos_cached + sin_cached = self._sin_k_cached if cache_type == "key" and self.scale is not None else self._sin_cached + ret = apply_rotary_emb( + x, self._get_slice(cos_cached, offsets), self._get_slice(sin_cached, offsets), interleaved, in_place + ) + + if left_padding_mask is not None: + ret = self._convert_padding(ret, empties, convert_type="right2left", in_place=in_place) + + return ret class LinearRotaryEmbedding(RotaryEmbedding): @@ -390,11 +212,11 @@ def __init__( self.scaling_factor = scaling_factor def _update_cos_sin_cache(self, x, indexes): - """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)""" + """x: (batch, seqlen, nheads, headdim)""" if not isinstance(indexes, int): seqlen = indexes.max().item() + 1 else: - seqlen = indexes + 1 + seqlen = indexes + x.shape[1] + 1 t = torch.arange(seqlen, device=x.device, dtype=self.inv_freq.dtype) t = t / self.scaling_factor @@ -457,11 +279,11 @@ def _update(self, seqlen, x): self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype) def _update_cos_sin_cache(self, x, indexes): - """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)""" + """x: (batch, seqlen, nheads, headdim)""" if not isinstance(indexes, int): seqlen = indexes.max().item() + 1 else: - seqlen = indexes + 1 # eval_forward + seqlen = indexes + x.shape[1] + 1 # eval_forward if seqlen <= self.max_position_embeddings: # Reset the tables if the sequence length has changed, # or if we're on a new device (possibly due to tracing for instance) @@ -474,3 +296,22 @@ def _update_cos_sin_cache(self, x, indexes): self._update(seqlen, x) else: self._update(seqlen, x) + + +def new_rotary_embedding( + dim: int, + base=10000, + scale_base=0, + device=None, + max_position_embeddings=2048, + scaling_factor=1.0, + rotary_type: str = "native", +) -> RotaryEmbedding: + assert rotary_type in ("native", "linear_scale", "dynamic_ntk"), f"Unknown rotary type 
{rotary_type}" + + if rotary_type == "linear_scale": + return LinearRotaryEmbedding(dim, base, scale_base, device, max_position_embeddings, scaling_factor) + elif rotary_type == "dynamic_ntk": + return DynamicNTKScalingRotaryEmbedding(dim, base, scale_base, device, max_position_embeddings, scaling_factor) + else: # native + return RotaryEmbedding(dim, base, scale_base, device) diff --git a/internlm/model/modules/linear.py b/internlm/model/modules/linear.py new file mode 100644 index 00000000..44353970 --- /dev/null +++ b/internlm/model/modules/linear.py @@ -0,0 +1,605 @@ +""" +Linear Modules +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional, Union + +import torch +import torch.distributed as dist +from torch import nn + +from internlm.accelerator import get_accelerator +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc +from internlm.core.parallel.shard import ( + get_head_parallel_mode, + get_parallel_strategies_split_mode, + get_tensor_split_parallel_mode, +) +from internlm.model.ops.linear import linear_backward_op, linear_forward_op +from internlm.utils.logger import get_logger + +if TYPE_CHECKING: + from internlm.core.parallel.comm.isp import WPCommunicator + from internlm.core.parallel.comm.tensor import TPCommunicator + +logger = get_logger(__file__) +internlm_accelerator = get_accelerator() + +custom_bwd = internlm_accelerator.return_custom_bwd() +custom_fwd = internlm_accelerator.return_custom_fwd() + + +# adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py +class SPFusedDenseFunc(torch.autograd.Function): + "FusedDenseFunc for tensor parallel in flash-attn implementation." + + @staticmethod + @custom_fwd + def forward( + ctx, + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + communicator: TPCommunicator, + return_residual=False, + ): + ctx.compute_weight_gradient = weight.requires_grad + ctx.return_residual = return_residual + ctx.communicator = communicator + + if torch.is_autocast_enabled(): + x = x.to(dtype=torch.get_autocast_gpu_dtype()) + x = x.contiguous() + + # parallel strategy-specific communication callback 1-1. + # see more details in the communicator for different parallel strategies. + # we want to kick off the all_gather early, before weight dtype conversion. + total_x, handle_x = communicator.input_hook(x, async_op=True) + + if torch.is_autocast_enabled(): + weight = weight.to(dtype=torch.get_autocast_gpu_dtype()) + bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None + weight = weight.contiguous() + + # wait for x has been gathered. + handle_x.wait() + + batch_shape, n = total_x.shape[:-1], total_x.shape[-1] + batch_dim = batch_shape.numel() + # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 + if min(batch_dim, n, *weight.shape) > 65535 * 32: + raise RuntimeError("fused_dense only supports matrix dims <= 2M") + + output = linear_forward_op(total_x, weight, bias) + + # parallel strategy-specific communication callback 2. + # see more details in the communicator for different parallel strategies. 
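+        # (depending on the registered communicator this may be a no-op, an all-reduce,
+        # or a reduce-scatter along the sequence dimension.)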
+ output, _ = communicator.output_hook(output, async_op=False) + + saved_x = None if ctx.compute_weight_gradient is False else total_x if communicator.save_total_input() else x + ctx.save_for_backward(saved_x, weight) + + return output if not return_residual else (output, x) + + @staticmethod + @custom_bwd + def backward(ctx, grad_output, *args): + communicator: TPCommunicator = ctx.communicator + + # parallel strategy-specific communication callback 3. + # see more details in the communicator for different parallel strategies. + grad_output, _ = communicator.grad_output_hook(grad_output, async_op=False) + grad_output = grad_output.contiguous() + + if ctx.return_residual: + (grad_input,) = args + grad_input = grad_input.contiguous() + + x, weight = ctx.saved_tensors + + # parallel strategy-specific communication callback 1-2. + # see more details in the communicator for different parallel strategies. + if ctx.needs_input_grad[1]: + x, handle_x = communicator.input_hook(x, async_op=True, is_forward=False) + + batch_shape = grad_output.shape[:-1] + batch_dim = batch_shape.numel() + grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) + + if ctx.needs_input_grad[0]: + if not ctx.return_residual: + grad_input = linear_forward_op(grad_output, weight.t()) + else: + grad_input = torch.addmm( + grad_input.reshape(batch_dim, grad_input.shape[-1]), + grad_output, + weight, + ) + grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) + # parallel strategy-specific communication callback 4. + # see more details in the communicator for different parallel strategies. + grad_input, handle_grad_input = communicator.grad_input_hook(grad_input, async_op=True) + else: + grad_input = None + + # computes gradinets for weight and bias if necessary + if ctx.needs_input_grad[1]: + assert ctx.compute_weight_gradient + + # wait for x has been gathered + handle_x.wait() + + x = x.reshape(batch_dim, x.shape[-1]) + grad_weight, grad_bias = linear_backward_op(x, grad_output, ctx.needs_input_grad[2]) + else: + grad_weight = None + grad_bias = grad_output if ctx.needs_input_grad[2] else None + + # wait for grad_input has been gathered + handle_grad_input.wait() + + return grad_input, grad_weight, grad_bias, None, None, None, None, None + + +# Q: Should we unify WPFusedDenseFunc and SPFusedDenseFunc, as well as the related communicator interface? +# A: Currently, WPFusedDenseFunc and SPFusedDenseFunc have significant differences in their computation logic +# and communication interfaces, so they should not be unified. +class WPFusedDenseFunc(torch.autograd.Function): + "FusedDenseFunc for weigth parallel, which is optimized based on flash implementation." 
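+
+    # A rough sketch of the flow (descriptive comment only):
+    #   forward:  gather the full weight via communicator.weight_hook -> dense matmul -> free the gathered weight.
+    #   backward: re-gather the weight, compute grad_weight/grad_bias locally, then hand them back to the
+    #             communicator via grad_hook (asynchronously; typically a reduce-scatter for weight parallel),
+    #             overlapping that communication with the grad_input matmul before waiting on the handles.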
+ + @staticmethod + @custom_fwd + def forward( + ctx, + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + module: nn.Module, + communicator: WPCommunicator, + return_residual=False, + ): + ctx.compute_weight_gradient = weight.requires_grad + ctx.return_residual = return_residual + ctx.module = module + ctx.communicator = communicator + + if torch.is_autocast_enabled(): + x = x.to(dtype=torch.get_autocast_gpu_dtype()) + x = x.contiguous() + + total_weight = communicator.weight_hook(weight, module=module) + total_bias = bias if bias is None else communicator.weight_hook(bias, module=module, is_bias=True) + + if torch.is_autocast_enabled(): + total_weight = total_weight.to(dtype=torch.get_autocast_gpu_dtype()) + if total_bias: + total_bias.to(dtype=torch.get_autocast_gpu_dtype()) + + total_weight = total_weight.contiguous() + batch_shape, n = x.shape[:-1], x.shape[-1] + batch_dim = batch_shape.numel() + # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 + if min(batch_dim, n, *total_weight.shape) > 65535 * 32: + raise RuntimeError("fused_dense only supports matrix dims <= 2M") + + output = linear_forward_op(x, total_weight, total_bias) + + # release memory + del total_weight + del total_bias + + saved_x = None if ctx.compute_weight_gradient is False else x + ctx.save_for_backward(saved_x, weight) + + return output if not return_residual else (output, x) + + @staticmethod + @custom_bwd + def backward(ctx, grad_output, *args): + module: nn.Module = ctx.module + communicator: WPCommunicator = ctx.communicator + x, weight = ctx.saved_tensors + + grad_output = grad_output.contiguous() + if ctx.return_residual: + (grad_input,) = args + grad_input = grad_input.contiguous() + + batch_shape = grad_output.shape[:-1] + batch_dim = batch_shape.numel() + grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) + + total_weight = communicator.weight_hook(weight, module=module) + + # compute weight grad + if ctx.needs_input_grad[1]: + assert ctx.compute_weight_gradient + grad_weight, grad_bias = linear_backward_op( + x.reshape(batch_dim, x.shape[-1]), + grad_output, + ctx.needs_input_grad[2], + ) + + grad_weight, grad_weight_sync = communicator.grad_hook( + grad_weight, async_op=True, module=module, is_bias=False + ) + if grad_bias is not None: + grad_bias, grad_bias_sync = communicator.grad_hook( + grad_bias, async_op=True, module=module, is_bias=True + ) + else: + grad_weight = None + grad_bias = grad_output if ctx.needs_input_grad[2] else None + + if ctx.needs_input_grad[0]: + if not ctx.return_residual: + grad_input = linear_forward_op(grad_output, total_weight.t()) + else: + grad_input = torch.addmm( + grad_input.reshape(batch_dim, grad_input.shape[-1]), + grad_output, + total_weight, + ) + grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) + else: + grad_input = None + + del total_weight + + if ctx.needs_input_grad[1]: + grad_weight_sync.wait() + if grad_bias is not None: + grad_bias_sync.wait() + + return grad_input, grad_weight, grad_bias, None, None, None, None + + +def fused_dense_func( + x: torch.Tensor, + weight: torch.Tensor, + communicator: Union[TPCommunicator, WPCommunicator], + module: Optional[nn.Module] = None, + bias: Optional[torch.Tensor] = None, + return_residual: bool = False, +): + if communicator.communication_mode() == "wp": + return WPFusedDenseFunc.apply( + x, + weight, + bias, + module, + communicator, + return_residual, + ) + else: # mtp, msp, and fsp + 
return SPFusedDenseFunc.apply( + x, + weight, + bias, + communicator, + return_residual, + ) + + +class ParallelLinearWithCommExt(nn.Linear): + """ + Parallel linear with commuication extention. + + Args: + in_features (int): size of each input sample + out_features (int): size of each output sample + bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False + in the config. + device (Optional[Union[str, torch.device]]): The device will be used. + dtype (Optional[torch.dtype]): The type of data. + split_mode (str): The split mode. It can be "none", "column", or "row". + """ + + # class level communicator variable. + _communicator = None + + @classmethod + def register_cls_communicator(cls, communicator): + cls._communicator = communicator + + def register_communicator(self, communicator): + """ + override the class default communicator for a parallel linear instance + """ + self._communicator = communicator + + def __init__( + self, + in_features: int, + out_features: int, + parallel_mode: ParallelMode, + bias: bool = True, + multiple_of: int = 1, + device: torch.device = None, + dtype: torch.dtype = None, + split_mode: str = "none", + ) -> None: + assert split_mode in ("none", "column", "row"), f"unknown split_mode {split_mode}" + + world_size = gpc.get_world_size(parallel_mode) + rank = gpc.get_local_rank(parallel_mode) + + if split_mode != "none": + split_features = out_features if split_mode == "column" else in_features + multiple = split_features // multiple_of + # We want to split @multiple across world_size, but it could be an uneven split + div = multiple // world_size + mod = multiple % world_size + # The first @mod ranks get @div + 1 copies, the rest get @div copies + local_multiple = div + int(rank < mod) + + if split_mode == "column": + super().__init__(in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype) + elif split_mode == "row": + super().__init__(local_multiple * multiple_of, out_features, bias=bias, device=device, dtype=dtype) + else: + super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype) + + def forward(self, input: torch.Tensor) -> torch.Tensor: # pylint: disable=W0622 + _class_name = self.__class__.__name__ + assert self._communicator is not None, f"{_class_name} should register with a communicator first." + + return fused_dense_func( + input, + self.weight, + communicator=self._communicator, + module=self, + bias=self.bias, + ) + + +class ColumnParallelLinear(ParallelLinearWithCommExt): + """ + ColumnParallelLinear + + Args: + in_features (int): size of each input sample + out_features (int): size of each output sample + bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False + in the config. + device (Optional[Union[str, torch.device]]): The device will be used. + dtype (Optional[torch.dtype]): The type of data. 
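+
+    Example:
+        A rough sketch of the weight sharding (illustrative only; it assumes gpc has been initialized with a
+        tensor parallel world size of 4 and that a communicator is registered before the first forward):
+
+            w1 = ColumnParallelLinear(in_features=4096, out_features=11008, bias=False)
+            # every rank keeps 11008 // 4 = 2752 output columns, i.e. w1.weight.shape == (2752, 4096)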
+ """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + multiple_of: int = 1, + device: torch.device = None, + dtype: torch.dtype = None, + ) -> None: + if out_features % multiple_of: + raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}") + + parallel_mode = get_tensor_split_parallel_mode() + super().__init__( + in_features, out_features, parallel_mode, bias=bias, device=device, dtype=dtype, split_mode="column" + ) + + +class RowParallelLinear(ParallelLinearWithCommExt): + """ + RowParallelLinear + + Args: + in_features (int): size of each input sample + out_features (int): size of each output sample + bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False + in the config. + device (Optional[Union[str, torch.device]]): The device will be used. + dtype (Optional[torch.dtype]): The type of data. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + multiple_of: int = 1, + device: torch.device = None, + dtype: torch.dtype = None, + ) -> None: + if in_features % multiple_of: + raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}") + + parallel_mode = get_tensor_split_parallel_mode() + rank = gpc.get_local_rank(parallel_mode) + super().__init__( + in_features, + out_features, + parallel_mode, + bias=bias and rank == 0, + device=device, + dtype=dtype, + split_mode="row", + ) + + +class ScaleColumnParallelLinear(ParallelLinearWithCommExt): + """ + ScaleColumnParallelLinear. + + Args: + in_features (int): size of each input sample + out_features (int): size of each output sample + bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False + in the config. + device (Optional[Union[str, torch.device]]): The device will be used. + dtype (Optional[torch.dtype]): The type of data. + weight_scale (int): For training stability. 1 by default. + norm_head (bool): Normalize the output embedding in order to let the calculation of logits not affected by + the norm of embedding. The implementation is referred to baichuan2, + see https://huggingface.co/baichuan-inc/Baichuan2-7B-Base for more information. False by default. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + weight_scale: int = 1, + norm_head: bool = False, + ) -> None: + if norm_head: + logger.info("Notice that norm head is enabled to normalize head weight.") + + parallel_mode = get_tensor_split_parallel_mode(is_head=True) + super().__init__( + in_features, out_features, parallel_mode, bias=bias, device=device, dtype=dtype, split_mode="column" + ) + + self.weight_scale = weight_scale + self.norm_head = norm_head + self.first_eval_flag = True + self.tmp_weight = None + + def forward(self, input): # pylint: disable=W0622 + _class_name = self.__class__.__name__ + assert self._communicator is not None, f"{_class_name} should register with a communicator first." + + if self.weight_scale == 1: + weight = self.weight + else: + weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach() + + if self.norm_head: + if self.training: + if not self.first_eval_flag: + self.first_eval_flag = True + self.tmp_weight = None + # We normalized the output Embedding so that the dot product + # is not affected by the norm of embedding. 
Ref: https://arxiv.org/pdf/2309.10305.pdf + weight = nn.functional.normalize(weight) + else: + if self.first_eval_flag: + # cache l2 norm of head to accelerate infer. + self.first_eval_flag = False + self.tmp_weight = nn.functional.normalize(weight) + + weight = self.tmp_weight + + return fused_dense_func( + input, + self.weight, + communicator=self._communicator, + module=self, + bias=self.bias, + ) + + +class RewardModelLinear(ScaleColumnParallelLinear): + """ + RewardModelLinear. + + Args: + in_features (int): size of each input sample + out_features (int): size of each output sample + bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False + in the config. + device (Optional[Union[str, torch.device]]): The device will be used. + dtype (Optional[torch.dtype]): The type of data. + weight_scale (int): For training stability. 1 by default. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + weight_scale: int = 1, + ) -> None: + super().__init__(in_features, out_features, bias, device, dtype, weight_scale) + + # broadcast parameters for reward model head layer. + parallel_mode = get_head_parallel_mode() + dist.broadcast(self.weight, gpc.get_ranks_in_group(parallel_mode)[0]) + if bias: + dist.broadcast(self.bias, gpc.get_ranks_in_group(parallel_mode)[0]) + + +def new_linear( + name: str, + in_features: int, + out_features: int, + bias: bool = True, + multiple_of=1, + device=None, + dtype=None, + is_reward: bool = False, + weight_scale: int = 1, + norm_head: bool = False, + **kwargs, +) -> nn.Linear: + + name = str.lower(name) + manual_select_class: Optional[str] = kwargs.get("manual_select_class", None) + + if manual_select_class is not None: + assert manual_select_class in ( + "head", + "column", + "row", + ), f"unknown manual selection {manual_select_class} for creating a linear." + + # use caller manual selection if it is provided. + split_mode = manual_select_class if manual_select_class is not None else get_parallel_strategies_split_mode(name) + + if split_mode == "head": + if is_reward: + return RewardModelLinear( + in_features, + out_features, + bias, + device, + dtype, + weight_scale, + ) + else: + return ScaleColumnParallelLinear( + in_features, + out_features, + bias, + device, + dtype, + weight_scale=weight_scale, + norm_head=norm_head, + ) + elif split_mode == "column": + return ColumnParallelLinear( + in_features, + out_features, + bias, + multiple_of, + device, + dtype, + ) + elif split_mode == "row": + return RowParallelLinear( + in_features, + out_features, + bias, + multiple_of, + device, + dtype, + ) + else: + err_msg = ( + f"Parallel strategies for linear is unsupported, which is named as {name}.\n" + + "Consider use manual_select_class parameter to select a linear class manually." 
+ ) + + raise ValueError(err_msg) diff --git a/internlm/model/modules/mha.py b/internlm/model/modules/mha.py new file mode 100644 index 00000000..e0669726 --- /dev/null +++ b/internlm/model/modules/mha.py @@ -0,0 +1,596 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- +import math +from typing import Callable, Dict, Optional + +import torch +from einops import rearrange +from torch import nn +from torch.nn import functional as F + +from internlm.model.modules.embedding import new_rotary_embedding +from internlm.model.modules.linear import new_linear +from internlm.model.modules.utils import update_kv_cache +from internlm.model.ops.attention import CrossAttention, SelfAttention +from internlm.utils.logger import get_logger + +logger = get_logger(__file__) + + +def _convert_cu_seqlens_for_qksplited(kwargs: Dict): + cu_seqlens = kwargs.pop("cu_seqlens", None) + max_seqlen = kwargs.pop("max_seqlen", None) + + if cu_seqlens is not None: + kwargs["cu_seqlens_q"] = cu_seqlens + kwargs["cu_seqlens_k"] = cu_seqlens + kwargs["max_seqlen_q"] = max_seqlen + kwargs["max_seqlen_k"] = max_seqlen + + return kwargs + + +class MHA(nn.Module): + """ + Multi-head self-attention and cross-attention. + + Args: + embed_dim (int): The dimention of hidden state. + num_heads (int): The number of attention heads. + max_position_embeddings (int): max position embeddings, 2048 by default. + bias (bool): Whether the bias is needed for linears. True by default. + dropout (float): The dropout rate for cross attention and self attention. 0.0 by default. + softmax_scale (float): The temperature to use for the softmax attention. + causal (boolean): Whether to apply causal attention mask. False by default. + layer_idx (int): The index of current layer. None by default. + use_dynamic_ntk_rope (bool): whether use dynamic ntk rope, false by default. + rotary_emb_dim (int): The dimention of Rotary Embedding. 0 by default. + rotary_emb_scale_base (int): The scaling factor of Rotary Embedding. If scale_base > 0, this implements + XPos(Sun et al., https://arxiv.org/abs/2212.10554). 0 by default. + rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. + device (Optional[Union[str, torch.device]]): The device will be used. + dtype (Optional[torch.dtype]): The type of data. + qk_interleaved (Optional[bool]): whether the odd and even columns of wq and wk is interleaved. True by default. + enable_qkv_fusion (bool): whether wq, wk and wv lienar is fused. True by default. 
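+
+    Example:
+        A minimal usage sketch (illustrative only; it assumes the global context and the parallel linear
+        communicators have already been initialized by the training pipeline, and all sizes are placeholders):
+
+            attn = MHA(embed_dim=4096, num_heads=32, causal=True, layer_idx=0, rotary_emb_dim=128)
+            x = torch.randn(2, 1024, 4096)    # (batch, seqlen, hidden_dim)
+            out = attn(x)                     # (2, 1024, 4096) after the output projection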
+ """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + max_position_embeddings: int = 2048, + bias: bool = True, + dropout: float = 0.0, + softmax_scale: float = None, + causal: bool = False, + layer_idx: int = None, + use_dynamic_ntk_rope: bool = False, + rotary_emb_dim: int = 0, + rotary_emb_scale_base: int = 0, + rope_base: int = 10000, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + qk_interleaved: Optional[bool] = True, + enable_qkv_fusion: bool = True, + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.causal = causal + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = self.embed_dim // num_heads + self.enable_qkv_fusion = enable_qkv_fusion + + self.use_dynamic_ntk_rope = use_dynamic_ntk_rope + self.rotary_emb_dim = rotary_emb_dim + self.max_position_embeddings = max_position_embeddings + self.interleaved = qk_interleaved + + factory_kwargs = {"device": device, "dtype": dtype} + + assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads" + + if self.rotary_emb_dim > 0: + self.rotary_emb = new_rotary_embedding( + self.rotary_emb_dim, + base=rope_base, + scale_base=rotary_emb_scale_base, + device=device, + max_position_embeddings=max_position_embeddings, + scaling_factor=1.0, + rotary_type="dynamic_ntk" if self.use_dynamic_ntk_rope else "native", + ) + + if self.enable_qkv_fusion: + # bias=True is according to https://spaces.ac.cn/archives/9577 + self.wqkv = new_linear("wqkv", embed_dim, 3 * embed_dim, bias, **factory_kwargs) + else: + self.wq = new_linear("wq", embed_dim, embed_dim, bias, **factory_kwargs) + self.wk = new_linear("wk", embed_dim, self.kv_dim, bias, **factory_kwargs) + self.wv = new_linear("wv", embed_dim, self.kv_dim, bias, **factory_kwargs) + + self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) + self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) + + # output projection always have the bias (for now) + self.out_proj = new_linear("out_proj", embed_dim, embed_dim, bias=True, **factory_kwargs) + + def register_checkpoint_compatibility_hooks( + self, pre_load_hook: Optional[Callable] = None, pre_save_hook: Optional[Callable] = None + ): + # Here we explicitly expose the checkpoint compatibility interface of the module, + # hoping that model developers will make good use of it when adapting. + # Is this interface already meeting all reasonable requirements? 
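+        # Usage sketch (the hook names below are hypothetical): a model adapter may remap legacy checkpoint
+        # keys here, e.g.
+        #   attn.register_checkpoint_compatibility_hooks(convert_legacy_wqkv_keys, export_legacy_wqkv_keys)
+        # pre_load_hook follows nn.Module's load_state_dict pre-hook signature (module, state_dict, prefix, ...)
+        # and pre_save_hook follows the state_dict hook signature (module, state_dict, prefix, local_metadata).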
+ self._register_load_state_dict_pre_hook(pre_load_hook, with_module=True) + self._register_state_dict_hook(pre_save_hook) + + def forward(self, x, inference_params=None, **kwargs): + if inference_params is None: + return self._training(x=x, **kwargs) + else: + return self._inference(x=x, inference_params=inference_params, **kwargs) + + def _training(self, x, **kwargs): + """ + Arguments: + x: (batch, seqlen, hidden_dim) + """ + # wqkv + if self.enable_qkv_fusion: + qkv = self.wqkv(x) + qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim) + + q = qkv[:, :, 0].squeeze(2) + k = qkv[:, :, 1].squeeze(2) + v = qkv[:, :, 2].squeeze(2) + else: + q, k, v = self.wq(x), self.wk(x), self.wv(x) + q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim) + k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim) + v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim) + + # rotary embedding + indexes = kwargs.pop("indexes", 0) + max_seqlen = kwargs.get("max_seqlen", None) + q = self.rotary_emb( + q, offsets=indexes, cache_type="query", interleaved=self.interleaved, max_seqlen=max_seqlen, in_place=True + ) + k = self.rotary_emb( + k, offsets=indexes, cache_type="key", interleaved=self.interleaved, max_seqlen=max_seqlen, in_place=True + ) + + # self attention + kwargs = _convert_cu_seqlens_for_qksplited(kwargs) + context = self.inner_attn(q, k, v, **kwargs) + + # wo + return self.out_proj(rearrange(context, "b s h d -> b s (h d)")) + + def _convert_unpacked_qkv_to_packed( + self, q: torch.Tensor, kv: torch.Tensor, batch_size: int, attention_mask: torch.Tensor + ): + cu_seqlens = torch.concat( + [ + torch.tensor([0], dtype=torch.int32, device=attention_mask.device), + attention_mask.sum(dim=-1).to(dtype=torch.int32), + ], + dim=0, + ).cumsum(dim=0, dtype=torch.int32) + + cu_seqlens_q = cu_seqlens + cu_seqlens_k = cu_seqlens + + max_seqlen_q = attention_mask.shape[-1] + max_seqlen_k = attention_mask.shape[-1] + + q_packed = q.masked_select(attention_mask.view(batch_size, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1]) + kv_packed = kv.masked_select(attention_mask.view(batch_size, -1, 1, 1, 1)).view( + -1, kv.shape[-3], kv.shape[-2], kv.shape[-1] + ) + + return q_packed, kv_packed, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k + + def _inference(self, x, inference_params, **kwargs): # pylint: disable=W0613 + assert inference_params is not None, "inference_params is required for inference" + assert self.layer_idx is not None, "Generation requires layer_idx in the constructor" + attention_mask = inference_params.get("attention_mask", None) + sequence_len_offset = inference_params.get("sequence_len_offset", 0) + batch_size = x.shape[0] + + # wqkv, output: q, kv + if self.enable_qkv_fusion: + qkv = self.wqkv(x) + qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim) + + q = qkv[:, :, 0].squeeze(2) + kv = qkv[:, :, 1:] + else: + q, k, v = self.wq(x), self.wk(x), self.wv(x) + q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim) + k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim) + v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim) + kv = torch.stack([k, v], dim=2) + + # rotary embedding, output: q, kv + # q shape: [bsz, nheads, head_dim] + # kv shape: [bsz, seqlen, 2, nheads, head_dim] + if self.use_dynamic_ntk_rope: + # update kv cache fisrt when enable dynamic ntk rope. 
+            kv = update_kv_cache(kv, inference_params, self.layer_idx)
+
+            if sequence_len_offset != 0:
+                if sequence_len_offset > self.max_position_embeddings:
+                    logger.warning(
+                        "Notice your prompt's length is longer than model's max_position_embeddings: "
+                        f"{self.max_position_embeddings}, which will cause deviations in dynamic ntk calculations."
+                    )
+
+                if self.rotary_emb_dim > 0:
+                    q = self.rotary_emb(
+                        q, offsets=sequence_len_offset, cache_type="query", interleaved=self.interleaved
+                    )
+                    k = kv[:, :, 0].squeeze(2)
+                    self.rotary_emb(
+                        k, offsets=0, cache_type="key", interleaved=self.interleaved, in_place=True
+                    )  # in-place is important
+            else:
+                if self.rotary_emb_dim > 0:
+                    q = self.rotary_emb(q, offsets=0, cache_type="query", interleaved=self.interleaved)
+                    k = kv[:, :, 0].squeeze(2)
+                    self.rotary_emb(
+                        k, offsets=0, cache_type="key", interleaved=self.interleaved, in_place=True
+                    )  # in-place is important
+        else:
+            assert self.rotary_emb_dim > 0, "You should use rotary_emb."
+
+            k, v = kv[:, :, 0].squeeze(2), kv[:, :, 1].squeeze(2)
+
+            if attention_mask is None:
+                q = self.rotary_emb(q, offsets=sequence_len_offset, cache_type="query", interleaved=self.interleaved)
+                k = self.rotary_emb(k, offsets=sequence_len_offset, cache_type="key", interleaved=self.interleaved)
+            else:
+                if sequence_len_offset == 0:
+                    q = self.rotary_emb(
+                        q, offsets=0, cache_type="query", interleaved=self.interleaved, left_padding_mask=attention_mask
+                    )
+                    k = self.rotary_emb(
+                        k, offsets=0, cache_type="key", interleaved=self.interleaved, left_padding_mask=attention_mask
+                    )
+                else:
+                    if sequence_len_offset > self.max_position_embeddings:
+                        logger.warning(
+                            "Notice your prompt's length is longer than model's max_position_embeddings: "
+                            f"{self.max_position_embeddings}, which will cause deviations in dynamic ntk calculations."
+                        )
+
+                    empties = attention_mask[..., -1].sum(dim=-1)
+                    indexes4q = sequence_len_offset * torch.ones(q.size(0), dtype=torch.int, device=q.device) - empties
+                    indexes4k = sequence_len_offset * torch.ones(k.size(0), dtype=torch.int, device=k.device) - empties
+                    q = self.rotary_emb(q, offsets=indexes4q, cache_type="query", interleaved=self.interleaved)
+                    k = self.rotary_emb(k, offsets=indexes4k, cache_type="key", interleaved=self.interleaved)
+
+            kv = torch.stack([k, v], dim=2)
+            # update kv cache after rotary embedding when dynamic ntk rope is disabled.
+            kv = update_kv_cache(kv, inference_params, self.layer_idx)
+
+        # self-attention
+        if attention_mask is None:
+            context = self.inner_cross_attn(q, kv)
+        else:
+            if sequence_len_offset == 0:  # First entrance, attnmask (bs*seqlen*seqlen)
+                attn_mask = attention_mask[:, None, ...]
+ attn_mask = torch.logical_or(torch.ones_like(attn_mask, dtype=torch.bool).triu(diagonal=1), attn_mask) + attn_mask4flsh = ~attn_mask[:, :, -1, :].view(batch_size, -1) + + output = self.inner_attn(*self._convert_unpacked_qkv_to_packed(q, kv, batch_size, attn_mask4flsh)) + output = output.to(x.dtype) + + context = torch.zeros_like(q).masked_scatter_(attn_mask4flsh.view(batch_size, -1, 1, 1), output) + else: + attn_mask = attention_mask[:, -1, :].view(batch_size, 1, 1, -1) + + k, v = torch.chunk(kv, 2, dim=2) + k = k.squeeze(2) + v = v.squeeze(2) + sp = k.shape + scores = torch.einsum( + "blhd,bnhd->bhln", + q, + k.reshape(sp[0], sp[1], q.size(2), sp[3]), + ) / math.sqrt(q.size(-1)) + scores = scores.masked_fill(attn_mask, -65000.0) + scores = F.softmax(scores, dim=-1) # bsz x h x L x L + context = torch.einsum( + "bhmn,bnhd->bmhd", + scores, + v.reshape(sp[0], sp[1], q.size(2), sp[3]), + ) + + # wo + return self.out_proj(rearrange(context, "b s h d -> b s (h d)")) + + +class GQA(nn.Module): + """ + Multi-head self-attention and cross-attention. + + Args: + embed_dim (int): The dimention of hidden state. + num_heads (int): The number of attention heads. + num_kv_heads (int): The number of attention heads for key and value. + max_position_embeddings (int): max position embeddings, 2048 by default. + bias (bool): Whether the bias is needed for linears. Will be used when initializing QKV matrix and + output projection. False by default. + dropout (float): The dropout rate for cross attention and self attention. 0.0 by default. + softmax_scale (float): The temperature to use for the softmax attention. + causal (boolean): Whether to apply causal attention mask. False by default. + layer_idx (int): The index of current layer. None by default. + use_dynamic_ntk_rope (bool): whether use dynamic ntk rope, false by default. + rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. + rotary_emb_dim (int): The dimention of Rotary Embedding. 0 by default. + rotary_emb_scale_base (int): The scaling factor of Rotary Embedding. If scale_base > 0, this implements + XPos(Sun et al., https://arxiv.org/abs/2212.10554). 0 by default. + device (Optional[Union[str, torch.device]]): The device will be used. + dtype (Optional[torch.dtype]): The type of data. + qk_interleaved (Optional[bool]): whether the odd and even columns of wq and wk is interleaved. True by default. + enable_qkv_fusion (bool): whether wq, wk and wv lienar is fused. True by default. 
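+
+    Example:
+        A minimal usage sketch (illustrative only; it assumes the global context and the parallel linear
+        communicators are already initialized, and all sizes are placeholders):
+
+            # 32 query heads share 8 key/value heads, i.e. q_per_kv = 4
+            attn = GQA(embed_dim=4096, num_heads=32, num_kv_heads=8, causal=True, layer_idx=0, rotary_emb_dim=128)
+            # the fused wqkv projection has embed_dim + 2 * kv_dim = 4096 + 2 * 8 * 128 = 6144 output features
+            out = attn(torch.randn(2, 1024, 4096))   # (2, 1024, 4096)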
+ """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + num_kv_heads: int, + max_position_embeddings: int = 2048, + bias: bool = False, + dropout: float = 0.0, + softmax_scale: float = None, + causal: bool = False, + layer_idx: int = None, + use_dynamic_ntk_rope: bool = False, + rope_base: int = 10000, + rotary_emb_dim: int = 0, + rotary_emb_scale_base: int = 0, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + qk_interleaved: Optional[bool] = True, + enable_qkv_fusion: bool = True, + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.causal = causal + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.q_per_kv = num_heads // num_kv_heads + self.head_dim = self.embed_dim // num_heads + self.kv_dim = self.head_dim * num_kv_heads + self.enable_qkv_fusion = enable_qkv_fusion + + self.use_dynamic_ntk_rope = use_dynamic_ntk_rope + self.rotary_emb_dim = rotary_emb_dim + self.max_position_embeddings = max_position_embeddings + self.interleaved = qk_interleaved + + factory_kwargs = {"device": device, "dtype": dtype} + + assert self.use_dynamic_ntk_rope is False, "Not support dynamic ntk rope yet." + assert self.embed_dim % num_heads == 0, "embedding dim must be divisible by num_heads" + + if self.rotary_emb_dim > 0: + self.rotary_emb = new_rotary_embedding( + self.rotary_emb_dim, + base=rope_base, + scale_base=rotary_emb_scale_base, + device=device, + max_position_embeddings=max_position_embeddings, + scaling_factor=1.0, + rotary_type="dynamic_ntk" if self.use_dynamic_ntk_rope else "native", + ) + + if enable_qkv_fusion: + self.wqkv = new_linear("wqkv", embed_dim, embed_dim + 2 * self.kv_dim, bias, **factory_kwargs) + else: + self.wq = new_linear("wq", embed_dim, embed_dim, bias, **factory_kwargs) + self.wk = new_linear("wk", embed_dim, self.kv_dim, bias, **factory_kwargs) + self.wv = new_linear("wv", embed_dim, self.kv_dim, bias, **factory_kwargs) + + self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) + self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) + + self.wo = new_linear("wo", embed_dim, embed_dim, bias, **factory_kwargs) + + def register_checkpoint_compatibility_hooks( + self, pre_load_hook: Optional[Callable] = None, pre_save_hook: Optional[Callable] = None + ): + # Here we explicitly expose the checkpoint compatibility interface of the module, + # hoping that model developers will make good use of it when adapting. + # Is this interface already meeting all reasonable requirements? 
+ self._register_load_state_dict_pre_hook(pre_load_hook, with_module=True) + self._register_state_dict_hook(pre_save_hook) + + def forward(self, x, inference_params=None, **kwargs): + if inference_params is None: + return self._training(x=x, **kwargs) + else: + return self._inference(x=x, inference_params=inference_params, **kwargs) + + def _training(self, x, **kwargs): + """ + Arguments: + x: (batch, seqlen, hidden_dim) + """ + # wqkv + if self.enable_qkv_fusion: + qkv = self.wqkv(x) + qkv = rearrange(qkv, "b s (h gs d) -> b s h gs d", gs=self.q_per_kv + 2, d=self.head_dim) + q, k, v = (qkv[..., : self.q_per_kv, :], qkv[..., -2, :], qkv[..., -1, :]) + q = rearrange(q, "b s h gs d -> b s (h gs) d") + else: + q, k, v = self.wq(x), self.wk(x), self.wv(x) + q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim) + k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim) + v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim) + + kwargs = _convert_cu_seqlens_for_qksplited(kwargs) + + # rotary embedding + if self.rotary_emb_dim > 0: + indexes = kwargs.pop("indexes", 0) + max_seqlen_q = kwargs.get("max_seqlen_q", None) + max_seqlen_k = kwargs.get("max_seqlen_k", None) + + q = self.rotary_emb( + q, offsets=indexes, max_seqlen=max_seqlen_q, cache_type="query", interleaved=self.interleaved + ) + k = self.rotary_emb( + k, offsets=indexes, max_seqlen=max_seqlen_k, cache_type="key", interleaved=self.interleaved + ) + + kv = torch.concat([k.unsqueeze(2), v.unsqueeze(2)], dim=2) + + # self attention + context = self.inner_attn(q, kv, **kwargs) + + # wo + return self.wo(rearrange(context, "b s h d -> b s (h d)")) + + def _convert_unpacked_qkv_to_packed( + self, q: torch.Tensor, kv: torch.Tensor, batch_size: int, attention_mask: torch.Tensor + ): + cu_seqlens = torch.concat( + [ + torch.tensor([0], dtype=torch.int32, device=attention_mask.device), + attention_mask.sum(dim=-1).to(dtype=torch.int32), + ], + dim=0, + ).cumsum(dim=0, dtype=torch.int32) + + cu_seqlens_q = cu_seqlens + cu_seqlens_k = cu_seqlens + + max_seqlen_q = attention_mask.shape[-1] + max_seqlen_k = attention_mask.shape[-1] + + q_packed = q.masked_select(attention_mask.view(batch_size, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1]) + kv_packed = kv.masked_select(attention_mask.view(batch_size, -1, 1, 1, 1)).view( + -1, kv.shape[-3], kv.shape[-2], kv.shape[-1] + ) + + return q_packed, kv_packed, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k + + def _inference(self, x, inference_params, **kwargs): # pylint: disable=W0613 + assert inference_params is not None, "inference_params is required for inference" + assert self.layer_idx is not None, "Generation requires layer_idx in the constructor" + attention_mask = inference_params.get("attention_mask", None) + sequence_len_offset = inference_params.get("sequence_len_offset", 0) + window_size = inference_params.get("window_size", None) + + batch_size = x.shape[0] + + # wqkv, output: q, k, v + if self.enable_qkv_fusion: + qkv = self.wqkv(x) + qkv = rearrange(qkv, "b s (h gs d) -> b s h gs d", gs=self.q_per_kv + 2, d=self.head_dim) + q, k, v = (qkv[..., : self.q_per_kv, :], qkv[..., -2, :].unsqueeze(-2), qkv[..., -1, :].unsqueeze(-2)) + q = rearrange(q, "b s h gs d -> b s (h gs) d") + else: + q, k, v = self.wq(x), self.wk(x), self.wv(x) + q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim) + k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim) + v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim) + + # rotary embedding, output: q, kv + assert self.rotary_emb_dim > 0 + 
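+        # For left-padded batches, the offsets passed to rotary_emb below subtract the per-sample padding
+        # count ("empties"), so positions are counted from each sample's first real token.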
if attention_mask is None: + raise NotImplementedError( + "You should make sure you are aware that you are changing the method of generating." + "According to your generation function instead of inference/seq_generator_module.py, " + "You may implement here for normal running." + ) + else: + if inference_params.sequence_len_offset == 0: + q = self.rotary_emb( + q, offsets=0, cache_type="query", interleaved=self.interleaved, left_padding_mask=attention_mask + ) + k = self.rotary_emb( + k, offsets=0, cache_type="key", interleaved=self.interleaved, left_padding_mask=attention_mask + ) + else: + empties = attention_mask[..., -1].sum(dim=-1) + indexes4q = sequence_len_offset * torch.ones(q.size(0), dtype=torch.int, device=q.device) - empties + indexes4k = sequence_len_offset * torch.ones(k.size(0), dtype=torch.int, device=k.device) - empties + q = self.rotary_emb(q, offsets=indexes4q, cache_type="query", interleaved=self.interleaved) + k = self.rotary_emb(k, offsets=indexes4k, cache_type="key", interleaved=self.interleaved) + + kv = torch.stack([k, v], dim=2) + + if window_size is None or window_size > sequence_len_offset: + kv = update_kv_cache(kv, inference_params, self.layer_idx) + else: # window_size <= sequence_len_offset + assert kv.size(1) == 1, "update kv length more than 1" + + inference_params.key_value_memory_dict[self.layer_idx][ + :, inference_params.keep_first : inference_params.window_size - 1, ... + ] = inference_params.key_value_memory_dict[self.layer_idx][ + :, -(inference_params.window_size - 1 - inference_params.keep_first) :, ... + ].clone() + inference_params.real_sequence_len_offset = inference_params.sequence_len_offset + inference_params.sequence_len_offset = inference_params.window_size - 1 + + kv = update_kv_cache(kv, inference_params, self.layer_idx) + + inference_params.sequence_len_offset = inference_params.real_sequence_len_offset + + # When using FP16, there is a high probability of NAN in the KV. + # Since NAN cannot be removed by multiplying with and 0, it needs + # to be removed manually here. + kv = torch.where(torch.isnan(kv), 0, kv) + + # attention + if attention_mask is None: + context = self.inner_cross_attn(q, kv) + else: + if sequence_len_offset == 0: # First entrance, attnmask (bs*seqlen*seqlen) + attn_mask = attention_mask[:, None, ...] 
+ attn_mask = torch.logical_or(torch.ones_like(attn_mask, dtype=torch.bool).triu(diagonal=1), attn_mask) + attn_mask4flsh = ~attn_mask[:, :, -1, :].view(batch_size, -1) + + output = self.inner_attn(*self._convert_unpacked_qkv_to_packed(q, kv, batch_size, attn_mask4flsh)) + output = output.to(x.dtype) + + context = torch.zeros_like(q).masked_scatter_(attn_mask4flsh.view(batch_size, -1, 1, 1), output) + + else: + attn_mask = attention_mask[:, -1, :].view(batch_size, 1, 1, -1) + if window_size is not None and window_size <= sequence_len_offset: + attn_mask = torch.concat( + [ + attn_mask[..., : inference_params.keep_first], + attn_mask[..., -(window_size - inference_params.keep_first) :], + ], + dim=-1, + ) + + k, v = torch.chunk(kv, 2, dim=2) + k = k.squeeze(2) + v = v.squeeze(2) + sp = k.shape + expansion = q.size(2) // k.size(2) + scores = torch.einsum( + "blhd,bnhd->bhln", + q, + k.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]), + ) / math.sqrt(q.size(-1)) + scores = scores.masked_fill(attn_mask, -65000.0) + scores = F.softmax(scores, dim=-1) # bsz x h x L x L + context = torch.einsum( + "bhmn,bnhd->bmhd", + scores, + v.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]), + ) + + # wo + return self.wo(rearrange(context, "b s h d -> b s (h d)")) diff --git a/internlm/model/modules/mlp.py b/internlm/model/modules/mlp.py index fddc4194..897e1363 100644 --- a/internlm/model/modules/mlp.py +++ b/internlm/model/modules/mlp.py @@ -1,142 +1,60 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from typing import Callable, Dict, Optional +from typing import Dict, Optional import torch from torch import nn -from internlm.model.ops.linear import ( - ColumnParallelLinearTorch, - ISPLinear, - MegatronColumnParallelLinearTorch, - MegatronRowParallelLinearTorch, - RowParallelLinearTorch, -) -from internlm.model.utils import Silu +from internlm.model.modules.linear import new_linear +from internlm.model.modules.utils import Silu +from internlm.utils.logger import get_logger +logger = get_logger(__file__) -class BaseFeedForward(nn.Module): - """ - Base FeedForward in flash implementation. - Args: - in_features (int): size of each input sample - hidden_features (int): size of hidden state of FFN - out_features (int): size of each output sample - process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`. - bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False - in the config. - device (Optional[Union[str, torch.device]]): The device will be used. - dtype (Optional[torch.dtype]): The type of data. - multiple_of (int): For efficient training. Reset the size of hidden feature. 256 by default. - column_cls (Optional[Callable]): The column parallel class for w1 and w3. None by default. - row_cls (Optional[Callable]): The row parallel class for w2. None by default. - mlp_layer_fusion (Optional[Bool]): Some linears without bias in FFN can be fused to reduce the comm cost of SP. 
- """ +def split_fused_mlp_weight(w1_w3): + w1, w3 = torch.split(w1_w3, w1_w3.shape[0] // 2, dim=0) + return w1, w3 - def __init__( - self, - in_features: int, - hidden_features: int, - out_features: int = None, - process_group: Optional[torch.distributed.ProcessGroup] = None, - bias: bool = True, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - multiple_of: int = 256, - mlp_layer_fusion: Optional[bool] = False, - sequence_parallel: Optional[bool] = False, - column_cls: Optional[Callable] = None, - row_cls: Optional[Callable] = None, - ): - super().__init__() - self.mlp_layer_fusion = mlp_layer_fusion - hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of) - mlp_args = { - "process_group": process_group, - "bias": bias, - "sequence_parallel": sequence_parallel, - "device": device, - "dtype": dtype, - "multiple_of": 1, # TODO: check Column/RowParallelLinearTorch. - } - if not self.mlp_layer_fusion: - # gate_proj - self.w1 = column_cls(in_features, hidden_features, **mlp_args) - # down_proj - self.w2 = row_cls(hidden_features, out_features, **mlp_args) - # up_proj - self.w3 = column_cls(in_features, hidden_features, **mlp_args) - else: - assert bias is False, "Fuesd FeedForward only support bias is False." - # fused gate/up projection - self.fused_w1_w3 = column_cls(in_features, hidden_features * 2, **mlp_args) - # down_proj - self.w2 = row_cls(hidden_features, out_features, **mlp_args) - # TODO: Internal methods could change without a deprecation warning. - self._register_load_state_dict_pre_hook(BaseFeedForward._mlp_pre_load_convert, with_module=True) - self._register_state_dict_hook(BaseFeedForward._mlp_save_convert) +def _mlp_pre_load_convert( + module: "FeedForward", state_dict, prefix: str, *args, **kwargs # pylint: disable=W0613 +) -> None: + w1_name, w3_name, fused_name = f"{prefix}w1.weight", f"{prefix}w3.weight", f"{prefix}fused_w1_w3.weight" - def forward(self, x): - if not self.mlp_layer_fusion: - w1_o = self.w1(x) - w3_o = self.w3(x) - else: - fussed_out = self.fused_w1_w3(x) - w1_o, w3_o = BaseFeedForward.split_fused_mlp_output(fussed_out) - out = self.w2(Silu(w1_o, w3_o)) - return out + if module.mlp_layer_fusion and fused_name not in state_dict: + w1, w3 = state_dict.pop(w1_name), state_dict.pop(w3_name) + state_dict[fused_name] = torch.cat([w1, w3], dim=0) - @staticmethod - def split_fused_mlp_weight(w1_w3): - w1, w3 = torch.split(w1_w3, w1_w3.shape[0] // 2, dim=0) - return w1, w3 + if not module.mlp_layer_fusion and (w1_name not in state_dict or w3_name not in state_dict): + state_dict[w1_name], state_dict[w3_name] = split_fused_mlp_weight(state_dict.pop(fused_name)) - @staticmethod - def split_fused_mlp_output(w1_w3_out): - w1_o, w3_o = torch.split(w1_w3_out, w1_w3_out.shape[-1] // 2, dim=-1) - return w1_o, w3_o - def _mlp_pre_load_convert(self, state_dict, prefix, *args, **kwargs) -> None: # pylint: disable=W0613 - w1_name = f"{prefix}w1.weight" - w3_name = f"{prefix}w3.weight" - fused_w1_w3_name = f"{prefix}fused_w1_w3.weight" +def _mlp_save_convert(module: "FeedForward", state_dict, prefix: str, *args, **kwargs) -> Dict: # pylint: disable=W0613 + w1_name, w3_name, fused_name = f"{prefix}w1.weight", f"{prefix}w3.weight", f"{prefix}fused_w1_w3.weight" - if self.mlp_layer_fusion and fused_w1_w3_name not in state_dict: - w1, w3 = state_dict.pop(w1_name), state_dict.pop(w3_name) - state_dict[fused_w1_w3_name] = torch.cat([w1, w3], dim=0) - if not self.mlp_layer_fusion and (w1_name not in state_dict or w3_name 
not in state_dict): - state_dict[w1_name], state_dict[w3_name] = self.split_fused_mlp_weight(state_dict.pop(fused_w1_w3_name)) + if module.mlp_layer_fusion: + state_dict[w1_name], state_dict[w3_name] = split_fused_mlp_weight(state_dict.pop(fused_name)) - def _mlp_save_convert(self, state_dict, prefix, *args, **kwargs) -> Dict: # pylint: disable=W0613 - w1_name = f"{prefix}w1.weight" - w3_name = f"{prefix}w3.weight" - fused_w1_w3_name = f"{prefix}fused_w1_w3.weight" + return state_dict - if self.mlp_layer_fusion: - state_dict[w1_name], state_dict[w3_name] = self.split_fused_mlp_weight( - w1_w3=state_dict.pop(fused_w1_w3_name) - ) - return state_dict - - -class FeedForward(BaseFeedForward): +class FeedForward(nn.Module): """ - FeedForward in flash implementation. + Base FeedForward in flash implementation. Args: in_features (int): size of each input sample hidden_features (int): size of hidden state of FFN out_features (int): size of each output sample - process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`. bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False in the config. device (Optional[Union[str, torch.device]]): The device will be used. dtype (Optional[torch.dtype]): The type of data. multiple_of (int): For efficient training. Reset the size of hidden feature. 256 by default. + mlp_layer_fusion (Optional[Bool]): Some linears without bias in FFN can be fused to reduce the comm cost of SP. + activation_type (str): the activation function used for feed forward, "swiglu" by default. """ def __init__( @@ -144,125 +62,57 @@ def __init__( in_features: int, hidden_features: int, out_features: int = None, - process_group: Optional[torch.distributed.ProcessGroup] = None, bias: bool = True, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, multiple_of: int = 256, mlp_layer_fusion: Optional[bool] = False, - sequence_parallel: Optional[bool] = False, + activation_type: str = "swiglu", ): - super().__init__( - in_features, - hidden_features, - out_features, - process_group, - bias, - device, - dtype, - multiple_of, - mlp_layer_fusion, - sequence_parallel, - ColumnParallelLinearTorch, - RowParallelLinearTorch, - ) - + super().__init__() -class MegatronFeedForward(BaseFeedForward): - """ - FeedForward in megatron implementation. + # TODO: support gelu... + assert activation_type in ("swiglu"), f"Unsupported activation type: {activation_type}" - Args: - in_features (int): size of each input sample - hidden_features (int): size of hidden state of FFN - out_features (int): size of each output sample - process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`. - bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False - in the config. - device (Optional[Union[str, torch.device]]): The device will be used. - dtype (Optional[torch.dtype]): The type of data. - multiple_of (int): For efficient training. Reset the size of hidden feature. 256 by default. 
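The pre-load and save hooks above only reshuffle checkpoint keys, so fused and unfused checkpoints stay interchangeable: loading a fused module from an unfused checkpoint concatenates w1/w3 along dim 0, and saving splits fused_w1_w3 back into the two projections. A standalone sketch of that round trip; the helper names fuse_w1_w3 and unfuse_w1_w3 are illustrative, not part of the patch.

import torch

def fuse_w1_w3(state_dict, prefix=""):
    w1 = state_dict.pop(f"{prefix}w1.weight")
    w3 = state_dict.pop(f"{prefix}w3.weight")
    state_dict[f"{prefix}fused_w1_w3.weight"] = torch.cat([w1, w3], dim=0)
    return state_dict

def unfuse_w1_w3(state_dict, prefix=""):
    fused = state_dict.pop(f"{prefix}fused_w1_w3.weight")
    w1, w3 = torch.split(fused, fused.shape[0] // 2, dim=0)
    state_dict[f"{prefix}w1.weight"], state_dict[f"{prefix}w3.weight"] = w1, w3
    return state_dict

sd = {"w1.weight": torch.randn(8, 4), "w3.weight": torch.randn(8, 4)}
roundtrip = unfuse_w1_w3(fuse_w1_w3(dict(sd)))
assert torch.equal(roundtrip["w1.weight"], sd["w1.weight"])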
- """ + self.mlp_layer_fusion = mlp_layer_fusion - def __init__( - self, - in_features: int, - hidden_features: int, - out_features: int = None, - process_group: Optional[torch.distributed.ProcessGroup] = None, - bias: bool = True, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - multiple_of: int = 256, - mlp_layer_fusion: Optional[bool] = False, - sequence_parallel: Optional[bool] = False, - ): - super().__init__( - in_features, - hidden_features, - out_features, - process_group, - bias, - device, - dtype, - multiple_of, - mlp_layer_fusion, - sequence_parallel, - MegatronColumnParallelLinearTorch, - MegatronRowParallelLinearTorch, - ) + hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of) + if self.mlp_layer_fusion: + assert bias is False, "Fuesd FeedForward only support bias is False." -class ISPFeedForward(BaseFeedForward): - """ - FeedForward in ISP. + self.fused_w1_w3 = new_linear("w13", in_features, hidden_features * 2, bias, device=device, dtype=dtype) + self.w2 = new_linear("w2", hidden_features, out_features, bias, device=device, dtype=dtype) - Args: - in_features (int): size of each input sample - hidden_features (int): size of hidden state of FFN - out_features (int): size of each output sample - process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`. - bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False - in the config. - device (Optional[Union[str, torch.device]]): The device will be used. - dtype (Optional[torch.dtype]): The type of data. - multiple_of (int): For efficient training. Reset the size of hidden feature. 256 by default. - """ + self._register_load_state_dict_pre_hook(_mlp_pre_load_convert, with_module=True) + self._register_state_dict_hook(_mlp_save_convert) + else: + self.w1 = new_linear("w1", in_features, hidden_features, bias, device=device, dtype=dtype) + self.w2 = new_linear("w2", hidden_features, out_features, bias, device=device, dtype=dtype) + self.w3 = new_linear("w3", in_features, hidden_features, bias, device=device, dtype=dtype) - def __init__( - self, - in_features: int, - hidden_features: int, - out_features: int = None, - process_group: Optional[torch.distributed.ProcessGroup] = None, - bias: bool = True, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - multiple_of: int = 256, - mlp_layer_fusion: Optional[bool] = False, - sequence_parallel: Optional[bool] = False, - ): - super().__init__( - in_features, - hidden_features, - out_features, - process_group, - bias, - device, - dtype, - multiple_of, - mlp_layer_fusion, - sequence_parallel, - ISPLinear, - ISPLinear, - ) + def forward(self, x): + if not self.mlp_layer_fusion: + w1_o = self.w1(x) + w3_o = self.w3(x) + else: + fussed_out = self.fused_w1_w3(x) + w1_o, w3_o = torch.split(fussed_out, fussed_out.shape[-1] // 2, dim=-1) + out = self.w2(Silu(w1_o, w3_o)) + return out -def get_mlp_cls(tp_mode: str): - if tp_mode in ["mtp", "fsp"]: - mlp_cls = FeedForward - elif tp_mode == "msp": - mlp_cls = MegatronFeedForward - else: - mlp_cls = ISPFeedForward - return mlp_cls +def new_feed_forward( + in_features: int, + hidden_features: int, + out_features: int = None, + bias: bool = True, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + multiple_of: int = 256, + mlp_layer_fusion: Optional[bool] = False, + activation_type: str = "swiglu", +) -> FeedForward: + return FeedForward( 
+ in_features, hidden_features, out_features, bias, device, dtype, multiple_of, mlp_layer_fusion, activation_type + ) diff --git a/internlm/model/modules/multi_head_attention.py b/internlm/model/modules/multi_head_attention.py deleted file mode 100644 index 8067e1dd..00000000 --- a/internlm/model/modules/multi_head_attention.py +++ /dev/null @@ -1,867 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import math -import warnings -from typing import Any, Optional, Tuple - -import torch -import torch.distributed as dist -import torch.nn.functional as F -from einops import rearrange, repeat -from torch import Tensor, nn -from torch.nn import Module - -from internlm.accelerator import AcceleratorType, get_accelerator -from internlm.core.context import global_context as gpc -from internlm.model.modules.embedding import ( - DynamicNTKScalingRotaryEmbedding, - RotaryEmbedding, -) -from internlm.model.ops.linear import get_linear_cls -from internlm.model.utils import pack_output_after_attn, unpack_qkv_before_attn -from internlm.utils.common import get_current_device - -internlm_accelerator = get_accelerator() - -try: - import torch_npu -except (ImportError, ModuleNotFoundError): - pass - - -def get_gqa_attn_cls(use_flash_attn, tp_mode, causal, softmax_scale, dropout, sequence_process_group): - if use_flash_attn: - device_backend = internlm_accelerator.get_accelerator_backend() - if device_backend == AcceleratorType.GPU: - from flash_attn import flash_attn_varlen_kvpacked_func - from flash_attn.modules.mha import FlashCrossAttention - - inner_attn, inner_cross_attn_cls = flash_attn_varlen_kvpacked_func, FlashCrossAttention - elif device_backend == AcceleratorType.NPU: - from internlm.model.modules.multi_head_attention import ( - AscendFlashSelfAttention, - ) - - inner_attn_cls, inner_cross_attn_cls = AscendFlashSelfAttention, AscendFlashSelfAttention - inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) - elif device_backend == AcceleratorType.DIPU: - from deeplink_ext.internevo_ops import ( - FlashCrossAttention, - FlashSelfAttention, - ) - - inner_attn_cls, inner_cross_attn_cls = FlashSelfAttention, FlashCrossAttention - inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) - else: - raise NotImplementedError(f"Unsupport device type: {device_backend} for flash attention") - else: - inner_attn_cls, inner_cross_attn_cls = SelfAttention, CrossAttention - inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) - - inner_cross_attn = inner_cross_attn_cls(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) - - if tp_mode == "isp": - inner_attn = DistributedAttention(inner_attn, sequence_process_group=sequence_process_group) - inner_cross_attn = DistributedAttention(inner_cross_attn, sequence_process_group=sequence_process_group) - - return inner_attn, inner_cross_attn - - -class AscendFlashSelfAttention(torch.nn.Module): - """Implement the scaled dot product attention with softmax. - Arguments - --------- - softmax_scale: The temperature to use for the softmax attention. 
- (default: 1/sqrt(d_keys) where d_keys is computed at - runtime) - attention_dropout: The dropout rate to apply to the attention - (default: 0.0) - """ - - def __init__( - self, - causal: bool = True, - softmax_scale: float = None, - attention_dropout: float = 0.0, - ): - super().__init__() - assert rearrange is not None, "Please install einops first, e.g., with pip install einops" - self.causal = causal - self.softmax_scale = softmax_scale - self.shape_order = "BSND" - self.dropout_p = attention_dropout - - if self.causal: - self.sparse_mode = 0 - self.next_tockens = 0 - else: - assert False, "Ascend flash attention unsupport causal=False now!" - - def forward( - self, - qkv=None, - q=None, - k=None, - v=None, - kv=None, - cu_seqlens_q=None, # pylint: disable=W0613 - cu_seqlens_k=None, # pylint: disable=W0613 - max_seqlen_q=None, # pylint: disable=W0613 - max_seqlen_k=None, # pylint: disable=W0613 - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # pylint: disable=W0613 - alibi_slopes=None, # pylint: disable=W0613 - deterministic=False, - return_attn_probs=False, # pylint: disable=W0613 - attention_mask=None, - ): - if qkv is not None: - assert (q, k, v, kv) == (None, None, None, None) - q = qkv[:, :, 0] - k = qkv[:, :, 1] - v = qkv[:, :, 2] - else: - assert q is not None - if kv is not None: - assert (k, v) == (None, None) - k = kv[:, :, 0] - v = kv[:, :, 1] - else: - assert k is not None and v is not None - - if causal: - assert causal == self.causal - if dropout_p: - assert dropout_p == self.dropout_p - if softmax_scale: - assert softmax_scale == self.softmax_scale - - return self._forward(q, k, v, deterministic=deterministic, attention_mask=attention_mask) - - def _forward( - self, - q, - k, - v, - deterministic: bool = False, - attention_mask: Tensor = None, - actual_seq_qlen: Tensor = None, # pylint: disable=W0613 - actual_seq_kvlen: Tensor = None, # pylint: disable=W0613 - ): - """Implements the multihead softmax attention. - Arguments - --------- - q, k, v: The tensor containing the query, key, and value. 
(B, S, H, D) - """ - assert q.dtype in (torch.bfloat16, torch.float16) - - if len(q.shape) == 5: - q = q.squeeze(dim=2) - k = k.squeeze(dim=2) - v = v.squeeze(dim=2) - - B, S, N, D = q.shape[0], q.shape[1], q.shape[2], q.shape[3] # noqa: F841 # pylint: disable=W0612 - - if self.shape_order == "BSH": - q, k, v = [rearrange(x, "b s h d -> b s (h d)") for x in [q, k, v]] - elif self.shape_order == "SBH": - q, k, v = [rearrange(x, "b s h d -> s b (h d)") for x in [q, k, v]] - elif self.shape_order != "BSND": - raise ValueError("Invalid shape-order: {}, shape-order must be SBH or BSH or BSND".format(self.shape_order)) - - if attention_mask is None: - attention_mask = torch.triu(torch.ones(S, S, device=get_current_device()), 1).bool() - - output = torch_npu.npu_fusion_attention( - query=q, - key=k, - value=v, - head_num=N, - input_layout="BSND", - pse=None, - atten_mask=attention_mask, - scale=self.softmax_scale, - sparse_mode=self.sparse_mode, - pre_tockens=k.shape[1], # Used for sparse calculations, representing the left boundary of the slides window - next_tockens=self.next_tockens, - keep_prob=1 - self.dropout_p, - inner_precise=0 if not deterministic else 2, - )[0] - - if self.shape_order == "BSH": - output = rearrange(output, "b s (h d) -> b s h d", h=N) - elif self.shape_order == "SBH": - output = rearrange(output, "s b (h d) -> b s h d", h=N) - elif self.shape_order != "BSND": - raise ValueError("Invalid shape-order: {}, shape-order must be SBH or BSH or BSND".format(self.shape_order)) - - return output - - -# adpated from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/sequence/layer.py -class _SeqAllToAll(torch.autograd.Function): - "sequence alltoall" - - @staticmethod - def forward(ctx: Any, group: dist.ProcessGroup, input_: Tensor, scatter_idx: int, gather_idx: int) -> Tensor: - ctx.group = group - ctx.scatter_idx = scatter_idx - ctx.gather_idx = gather_idx - - if dist.get_world_size(group) <= 1: - return input_ - - seq_world_size = dist.get_world_size(group) - - input_list = [t.contiguous() for t in torch.tensor_split(input_, seq_world_size, scatter_idx)] - output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)] - # TODO Use all_to_all_single instead - dist.all_to_all(output_list, input_list, group=group) - return torch.cat(output_list, dim=gather_idx).contiguous() - - @staticmethod - def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: - if dist.get_world_size(ctx.group) <= 1: - return (None, *grad_output, None, None) - - return (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None) - - -# adpated from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/sequence/layer.py -class DistributedAttention(torch.nn.Module): - """Initialization. 
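The _SeqAllToAll autograd function above (removed from this file by the patch) is the core of sequence-parallel attention: before attention, each rank trades its sequence shard of all heads for a head shard of the full sequence, and the exchange is reversed afterwards. A single-process sketch of that reshuffle, with a hypothetical fake_all_to_all helper standing in for dist.all_to_all:

import torch

world_size = 4
seq, heads, dim = 8, 4, 2  # per-rank shard before the exchange: (seq / world_size, heads, dim)
shards = [torch.randn(seq // world_size, heads, dim) for _ in range(world_size)]

def fake_all_to_all(inputs, scatter_idx, gather_idx):
    # inputs[r] is rank r's local tensor; returns what each rank would hold afterwards.
    pieces = [torch.tensor_split(t, world_size, dim=scatter_idx) for t in inputs]
    return [
        torch.cat([pieces[src][dst] for src in range(world_size)], dim=gather_idx)
        for dst in range(world_size)
    ]

outputs = fake_all_to_all(shards, scatter_idx=1, gather_idx=0)
assert outputs[0].shape == (seq, heads // world_size, dim)  # full sequence, head shard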
- - Arguments: - local_attention (Module): local attention with q,k,v - sequence_process_group (ProcessGroup): sequence parallel process group - first_scatter_idx (int): scatter_idx for the first all2all comm - first_gather_idx (int): gather_idx for the first all2all comm - second_scatter_idx (int): scatter_idx for the second all2all comm - second_gather_idx (int): gather_idx for the second all2all comm - """ - - def __init__( - self, - local_attention: Module, - sequence_process_group: dist.ProcessGroup, - ) -> None: - super().__init__() - self.local_attn = local_attention - self.spg = sequence_process_group - self._scatter_gather_idx = {} - - # scatter_gather_idx contains the scatter and gather index for different data packed mode - # key is the data packed mode, which should be in ['qkv', 'kv', 'q', 'output'] - # value is the scatter and gather index in all2all - self._scatter_gather_idx["qkv"] = [2, 0] # qkv shape:[sequence, 3, head, head_dim] - self._scatter_gather_idx["kv"] = [2, 0] # kv shape: [sequence, 2, head, head_dim] - self._scatter_gather_idx["q"] = [1, 0] # q/k/v shape: [sequence, head, head_dim] - self._scatter_gather_idx["output"] = [0, 1] # output shape: [sequence, head, head_dim] - - def forward( - self, qkv: Tensor = None, kv: Tensor = None, q: Tensor = None, k: Tensor = None, v: Tensor = None, **kwargs: Any - ) -> Tensor: - if gpc.is_evaluating is True or gpc.config.data.use_packed_dataset is False: - # when conducting evaluation, the scatter and gather index should add 1. - eval_scatter_gather_idx = {key: [x + 1 for x in value] for key, value in self._scatter_gather_idx.items()} - return self._forward(qkv=qkv, kv=kv, q=q, k=k, v=v, scatter_gather=eval_scatter_gather_idx, **kwargs) - else: - return self._forward(qkv=qkv, kv=kv, q=q, k=k, v=v, scatter_gather=self._scatter_gather_idx, **kwargs) - - def _forward( - self, - qkv: Tensor = None, - kv: Tensor = None, - q: Tensor = None, - k: Tensor = None, - v: Tensor = None, - scatter_gather: dict = None, - **kwargs: Any, - ) -> Tensor: - """forward - - Arguments: - qkv (Tensor): packed qkv input to the layer - kv (Tensor): packed kv input to the layer - q (Tensor): q input to the layer - k (Tensor): k input to the layer - v (Tensor): v input to the layer - args: other args - - Returns: - * output (Tensor): context output - """ - - if qkv is not None: - qkv = _SeqAllToAll.apply(self.spg, qkv, scatter_gather["qkv"][0], scatter_gather["qkv"][1]) - context_layer = self.local_attn(qkv=qkv, **kwargs) - elif kv is not None: - q = _SeqAllToAll.apply(self.spg, q, scatter_gather["q"][0], scatter_gather["q"][1]) - kv = _SeqAllToAll.apply(self.spg, kv, scatter_gather["kv"][0], scatter_gather["kv"][1]) - context_layer = self.local_attn(q=q, kv=kv, **kwargs) - else: - q = _SeqAllToAll.apply(self.spg, q, scatter_gather["q"][0], scatter_gather["q"][1]) - k = _SeqAllToAll.apply(self.spg, k, scatter_gather["q"][0], scatter_gather["q"][1]) - v = _SeqAllToAll.apply(self.spg, v, scatter_gather["q"][0], scatter_gather["q"][1]) - context_layer = self.local_attn(q=q, k=k, v=v, **kwargs) - output = _SeqAllToAll.apply(self.spg, context_layer, scatter_gather["output"][0], scatter_gather["output"][1]) - - # out e.g., [s/p::h] - return output - - -class SelfAttention(nn.Module): - """Implement the scaled dot product attention with softmax. - Arguments - --------- - softmax_scale: The temperature to use for the softmax attention. 
- (default: 1/sqrt(d_keys) where d_keys is computed at - runtime) - attention_dropout: The dropout rate to apply to the attention - (default: 0.0) - """ - - def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): - super().__init__() - self.causal = causal - self.softmax_scale = softmax_scale - self.drop = nn.Dropout(attention_dropout) - - def forward(self, qkv, causal=None, key_padding_mask=None): - """Implements the multihead softmax attention. - Arguments - --------- - qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) - causal: if passed, will override self.causal - key_padding_mask: boolean mask to apply to the attention weights. True means to keep, - False means to mask out. (B, S) - """ - batch_size, seqlen = qkv.shape[0], qkv.shape[1] - causal = self.causal if causal is None else causal - q, k, v = qkv.unbind(dim=2) - softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1]) - scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) - if key_padding_mask is not None: - padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device) - padding_mask.masked_fill_(key_padding_mask, 0.0) - # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) - scores = scores + rearrange(padding_mask, "b s -> b 1 1 s") - if causal: - # "triu_tril_cuda_template" not implemented for 'BFloat16' - # So we have to construct the mask in float - causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1) - # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) - scores = scores + causal_mask.to(dtype=scores.dtype) - attention = torch.softmax(scores, dim=-1, dtype=v.dtype) - attention_drop = self.drop(attention) - output = torch.einsum("bhts,bshd->bthd", attention_drop, v) - return output - - -class CrossAttention(nn.Module): - """Implement the scaled dot product attention with softmax. - Arguments - --------- - softmax_scale: The temperature to use for the softmax attention. - (default: 1/sqrt(d_keys) where d_keys is computed at - runtime) - attention_dropout: The dropout rate to apply to the attention - (default: 0.0) - """ - - def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): - super().__init__() - self.causal = causal - self.softmax_scale = softmax_scale - self.drop = nn.Dropout(attention_dropout) - - def forward(self, q, kv, causal=None, key_padding_mask=None): - """Implements the multihead softmax attention. - Arguments - --------- - q: The tensor containing the query. (B, Sq, H, D) - kv: The tensor containing the key and value. (B, Sk, 2, H_k, D) - causal: if passed, will override self.causal - key_padding_mask: boolean mask to apply to the attention weights. True means to keep, - False means to mask out. (B, Sk) - """ - batch_size, seqlen_q = q.shape[0], q.shape[1] - causal = self.causal if causal is None else causal - seqlen_k = kv.shape[1] - assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3] - if kv.shape[3] != q.shape[2]: # MQA/GQA - kv = repeat(kv, "... hkv d -> ... 
(hkv g) d", g=q.shape[2] // kv.shape[3]) - k, v = kv.unbind(dim=2) - softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1]) - scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) - if key_padding_mask is not None: - padding_mask = torch.full((batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device) - padding_mask.masked_fill_(key_padding_mask, 0.0) - # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) - scores = scores + rearrange(padding_mask, "b s -> b 1 1 s") - if causal: - # causal mask needs to take into account the difference between seqlen_q and seqlen_k - row_idx = rearrange(torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1") - col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long) - sk = seqlen_k if key_padding_mask is None else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") - causal_mask = col_idx > row_idx + sk - seqlen_q - scores = scores.masked_fill(causal_mask, -10000.0) - attention = torch.softmax(scores, dim=-1, dtype=v.dtype) - attention_drop = self.drop(attention) - output = torch.einsum("bhts,bshd->bthd", attention_drop, v) - return output - - -def _update_kv_cache(kv, inference_params, layer_idx): - """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)""" - # Pre-allocate memory for key-values for inference. - num_heads, head_dim = kv.shape[-2:] - if layer_idx not in inference_params.key_value_memory_dict: - kv_cache = torch.empty( - inference_params.max_batch_size, - inference_params.max_sequence_len, - 2, - num_heads, - head_dim, - dtype=kv.dtype, - device=kv.device, - ) - inference_params.key_value_memory_dict[layer_idx] = kv_cache - else: - if not inference_params.fused_ft_kernel: - kv_cache = inference_params.key_value_memory_dict[layer_idx] - else: - # For FT, k_cache has shape (b, h, headdim / packsize, s, packsize) - # where packsize = 4 if fp32, 8 if fp16 or bf16. - # v_cache has shape (b, h, s, headdim) - k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx] - kv_cache = None - # Adjust key and value for inference - batch_start = inference_params.batch_size_offset - batch_end = batch_start + kv.shape[0] - sequence_start = inference_params.sequence_len_offset - sequence_end = sequence_start + kv.shape[1] - assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0]) - assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2]) - # Copy key and values. - if not inference_params.fused_ft_kernel: - assert kv_cache is not None - kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv - kv = kv_cache[batch_start:batch_end, :sequence_end, ...] - return kv - else: - assert inference_params.sequence_len_offset == 0 - # FT kernel requires different layouts for the k_cache and v_cache. - assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32] - packsize = 4 if kv.dtype == torch.float32 else 8 - if kv_cache is not None: - kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] 
= kv - k_cache = rearrange( - kv_cache[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize - ).contiguous() - v_cache = rearrange(kv_cache[:, :, 1], "b s h d -> b h s d").contiguous() - inference_params.key_value_memory_dict[layer_idx] = (k_cache, v_cache) - else: - k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange( - kv[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize - ) - v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(kv[:, :, 1], "b s h d -> b h s d") - return kv - - -class MHA(nn.Module): - """ - Multi-head self-attention and cross-attention. - - Args: - embed_dim (int): The dimention of hidden state. - num_heads (int): The number of attention heads. - process_group (torch.distributed.ProcessGroup): The group of the current device for `parallel_mode`. - max_position_embeddings (int): max position embeddings, 2048 by default. - dropout (float): The dropout rate for cross attention and self attention. 0.0 by default. - softmax_scale (float): The temperature to use for the softmax attention. - causal (boolean): Whether to apply causal attention mask. False by default. - layer_idx (int): The index of current layer. None by default. - use_dynamic_ntk_rope (bool): whether use dynamic ntk rope, false by default. - rotary_emb_dim (int): The dimention of Rotary Embedding. 0 by default. - rotary_emb_scale_base (int): The scaling factor of Rotary Embedding. If scale_base > 0, this implements - XPos(Sun et al., https://arxiv.org/abs/2212.10554). 0 by default. - use_flash_attn (bool): Whether to use flash-attn. True by default. - rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. - device (Optional[Union[str, torch.device]]): The device will be used. - dtype (Optional[torch.dtype]): The type of data. - - """ - - def __init__( - self, - embed_dim: int, - num_heads: int, - process_group: Optional[torch.distributed.ProcessGroup], - sequence_process_group: Optional[torch.distributed.ProcessGroup], - max_position_embeddings: int = 2048, - dropout: float = 0.0, - softmax_scale: float = None, - causal: bool = False, - layer_idx: int = None, - use_dynamic_ntk_rope: bool = False, - rotary_emb_dim: int = 0, - rotary_emb_scale_base: int = 0, - use_flash_attn: bool = True, - rope_base: int = 10000, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - tp_mode: str = "mtp", - ) -> None: - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.embed_dim = embed_dim - self.causal = causal - self.layer_idx = layer_idx - self.max_position_embeddings = max_position_embeddings - self.use_dynamic_ntk_rope = use_dynamic_ntk_rope - self.rotary_emb_dim = rotary_emb_dim - self.use_flash_attn = use_flash_attn - self.num_heads = num_heads - assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads" - self.head_dim = self.embed_dim // num_heads - self.tp_mode = tp_mode - - if self.rotary_emb_dim > 0: - if self.use_dynamic_ntk_rope: - self.rotary_emb = DynamicNTKScalingRotaryEmbedding( - self.rotary_emb_dim, - base=rope_base, - scale_base=rotary_emb_scale_base, - device=device, - max_position_embeddings=max_position_embeddings, - scaling_factor=1.0, # Currently do not support dynamic scaling. 
- ) - else: - self.rotary_emb = RotaryEmbedding( - self.rotary_emb_dim, base=rope_base, scale_base=rotary_emb_scale_base, device=device - ) - - # notice here should change bias=True - Wqkv_cls = get_linear_cls(self.tp_mode, "column") - self.Wqkv = Wqkv_cls( - embed_dim, - 3 * embed_dim, - process_group, - bias=True, - sequence_parallel=gpc.config.parallel.sequence_parallel, - **factory_kwargs, - ) # according to https://spaces.ac.cn/archives/9577 - - if gpc.config.model.use_flash_attn: - if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU: - from flash_attn.modules.mha import ( - FlashCrossAttention, - FlashSelfAttention, - ) - elif internlm_accelerator.get_accelerator_backend() == AcceleratorType.NPU: - FlashCrossAttention, FlashSelfAttention = AscendFlashSelfAttention, AscendFlashSelfAttention - elif internlm_accelerator.get_accelerator_backend() == AcceleratorType.DIPU: - from deeplink_ext.internevo_ops import ( - FlashCrossAttention, - FlashSelfAttention, - ) - - inner_attn_cls = FlashSelfAttention - inner_cross_attn_cls = FlashCrossAttention - else: - inner_attn_cls = SelfAttention - inner_cross_attn_cls = CrossAttention - - self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) - self.inner_cross_attn = inner_cross_attn_cls( - causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout - ) - if self.tp_mode == "isp": - self.inner_attn = DistributedAttention(self.inner_attn, sequence_process_group=sequence_process_group) - self.inner_cross_attn = DistributedAttention( - self.inner_cross_attn, sequence_process_group=sequence_process_group - ) - - # output projection always have the bias (for now) - out_proj_cls = get_linear_cls(self.tp_mode, "row") - self.out_proj = out_proj_cls( - embed_dim, - embed_dim, - process_group, - bias=True, - sequence_parallel=gpc.config.parallel.sequence_parallel, - **factory_kwargs, - ) - - def forward(self, x, seqlen=None, inference_params=None, **kwargs): - if kwargs.get("indexes", None) is not None: - return self._packed_forward(x=x, inference_params=inference_params, **kwargs) - else: - return self._forward(x=x, seqlen=seqlen, inference_params=inference_params, **kwargs) - - def _forward(self, x, seqlen=None, inference_params=None, **kwargs): # pylint: disable=W0613 - """ - Arguments: - x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None. - If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we - split x during sequence parallel, we split the batch * seqlen dimension - (in case batch is small). 
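For context on the removed MHA code: hidden states are projected once through the fused Wqkv linear and then rearranged into per-head query/key/value tensors, which is what the b s (three h d) -> b s three h d pattern in the following lines does. A toy sketch with assumed sizes:

import torch
from einops import rearrange

b, s, h, d = 2, 6, 4, 16
x = torch.randn(b, s, h * d)
wqkv = torch.nn.Linear(h * d, 3 * h * d, bias=True)  # one matmul produces q, k and v

qkv = rearrange(wqkv(x), "b s (three h d) -> b s three h d", three=3, d=d)
q, k, v = qkv.unbind(dim=2)
assert q.shape == (b, s, h, d)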
- """ - bsz, _, _ = x.shape - qkv = self.Wqkv(x) - if seqlen is None: - qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim) - else: - qkv = rearrange(qkv, "(b s) (three h d) -> b s three h d", s=seqlen, three=3, d=self.head_dim) - - if inference_params is None: - kwargs["inference_params"] = inference_params - qkv = self.rotary_emb(qkv, **kwargs) - if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn: - with internlm_accelerator.amp.autocast(dtype=torch.bfloat16): - if qkv.dtype not in [torch.float16, torch.bfloat16]: - qkv = qkv.to(torch.bfloat16) - context = self.inner_attn(qkv=qkv).to(x.dtype) - else: - context = self.inner_attn(qkv=qkv) - - else: - if self.use_dynamic_ntk_rope: - q = qkv[:, :, 0] - assert self.layer_idx is not None, "Generation requires layer_idx in the constructor" - kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx) - if inference_params.sequence_len_offset != 0: - # q shape: [bsz, 1, nheads, head_dim] - # kv shape: [bsz, seqlen, 2, nheads, head_dim] - bsz, seq_len, _, nheads, head_dim = kv.shape - q = torch.cat([q.new_zeros(size=(bsz, seq_len - 1, nheads, head_dim)), q], dim=1).unsqueeze(2) - qkv = torch.cat([q, kv], dim=2) - if self.rotary_emb_dim > 0: - qkv = self.rotary_emb(qkv) - q = qkv[:, [-1], 0] - kv = qkv[:, :, 1:] - else: - if inference_params.sequence_len_offset > self.max_position_embeddings: - warnings.warn( - "Notice your prompt's length is longer than model's max_position_embeddings: " - f"{self.max_position_embeddings}, which will cause deviations in dynamic ntk calculations." - ) - if self.rotary_emb_dim > 0: - kwargs["inference_params"] = inference_params - qkv = self.rotary_emb(qkv, **kwargs) - q = qkv[:, :, 0] - kv = qkv[:, :, 1:] - else: - assert self.layer_idx is not None, "Generation requires layer_idx in the constructor" - q, k, v = (x.squeeze(2) for x in qkv.chunk(chunks=3, dim=2)) - kv = torch.stack([k, v], dim=2) - assert self.rotary_emb_dim > 0, "You should use rotary_emb." - - if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None: - empties = inference_params.attention_mask[..., -1].sum(dim=-1) - if inference_params.sequence_len_offset == 0: - moved_q = q.clone() - moved_k = k.clone() - for i in range(len(empties)): - if empties[i] != 0: - moved_q[i][: -empties[i]] = q[i][empties[i] :] - moved_k[i][: -empties[i]] = k[i][empties[i] :] - moved_q = self.rotary_emb._single_eval_forward(moved_q, seqlen_offset=0) - moved_k = self.rotary_emb._single_eval_forward(moved_k, seqlen_offset=0) - for i in range(len(empties)): - if empties[i] != 0: - q[i][empties[i] :] = moved_q[i][: -empties[i]] - k[i][empties[i] :] = moved_k[i][: -empties[i]] - else: - q[i] = moved_q[i] - k[i] = moved_k[i] - elif not self.use_dynamic_ntk_rope: - if inference_params.sequence_len_offset > self.max_position_embeddings: - warnings.warn( - "Notice your prompt's length is longer than model's max_position_embeddings: " - f"{self.max_position_embeddings}, may cause deviations in dynamic ntk calculations." 
- ) - q = self.rotary_emb._single_forward( - q, - inference_params.sequence_len_offset - * torch.ones(q.size(0), dtype=torch.int, device=q.device) - - empties, - ) - k = self.rotary_emb._single_forward( - k, - inference_params.sequence_len_offset - * torch.ones(k.size(0), dtype=torch.int, device=k.device) - - empties, - ) - else: - q = self.rotary_emb._single_forward( - q, - inference_params.sequence_len_offset - * torch.ones(q.size(0), dtype=torch.int, device=q.device) - - empties, - ) - moved_k = k.clone() - for i in range(len(empties)): - if empties[i] != 0: - moved_k[i][: -empties[i]] = k[i][empties[i] :] - moved_k = self.rotary_emb._single_eval_forward(moved_k, seqlen_offset=0) - for i in range(len(empties)): - if empties[i] != 0: - k[i][empties[i] :] = moved_k[i][: -empties[i]] - else: - k[i] = moved_k[i] - else: - q = self.rotary_emb._single_forward(q, inference_params.sequence_len_offset) - k = self.rotary_emb._single_forward(k, inference_params.sequence_len_offset) - - kv = torch.stack([k, v], dim=2) - kv = _update_kv_cache(kv, inference_params, self.layer_idx) - - if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None: - if inference_params.sequence_len_offset == 0: # First entrance, attnmask (bs*seqlen*seqlen) - attn_mask = inference_params.attention_mask[:, None, ...] - attn_mask = torch.logical_or( - torch.ones_like(attn_mask, dtype=torch.bool).triu(diagonal=1), attn_mask - ) - attn_mask4flsh = ~attn_mask[:, :, -1, :].view(bsz, -1) - cu_seqlens = torch.concat( - [ - torch.tensor([0], dtype=torch.int32, device=attn_mask4flsh.device), - attn_mask4flsh.sum(dim=-1).to(dtype=torch.int32), - ], - dim=0, - ) - cu_seqlens = cu_seqlens.cumsum(dim=0, dtype=torch.int32) - max_seqlen_q = attn_mask4flsh.shape[-1] - max_seqlen_k = attn_mask4flsh.shape[-1] - total_q = q.masked_select(attn_mask4flsh.view(bsz, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1]) - total_kv = kv.masked_select(attn_mask4flsh.view(bsz, -1, 1, 1, 1)).view( - -1, kv.shape[-3], kv.shape[-2], kv.shape[-1] - ) - - if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn: - with internlm_accelerator.amp.autocast(dtype=torch.bfloat16): - if total_q.dtype not in [torch.float16, torch.bfloat16]: - total_q = total_q.to(torch.bfloat16) - if total_kv.dtype not in [torch.float16, torch.bfloat16]: - total_kv = total_kv.to(torch.bfloat16) - - try: - from flash_attn.flash_attn_interface import ( - flash_attn_unpadded_func, - ) - except ImportError: - try: - from flash_attn.flash_attn_interface import ( - flash_attn_unpadded_kvpacked_func as flash_attn_unpadded_func, - ) - except ImportError: - try: - from flash_attn.flash_attn_interface import ( - flash_attn_varlen_kvpacked_func as flash_attn_unpadded_func, - ) - except ImportError: - raise ImportError("Please check your flash_attn version >= 1.0.5.") - - output = flash_attn_unpadded_func( - total_q, - total_kv, - cu_seqlens, - cu_seqlens, - max_seqlen_q, - max_seqlen_k, - 0.0, - None, - True, - False, - ).to(x.dtype) - else: - attn_scores = torch.matmul(total_q, total_kv.transpose(-2, -1)) / (cu_seqlens**0.5) - attn_weights = F.softmax(attn_scores, dim=-1) - output = torch.matmul(attn_weights, total_kv) - - context = torch.zeros_like(q) - context = context.masked_scatter_(attn_mask4flsh.view(bsz, -1, 1, 1), output) - - else: - attn_mask = inference_params.attention_mask[:, -1, :].view(bsz, 1, 1, -1) - - k, v = torch.chunk(kv, 2, dim=2) - k = k.squeeze(2) - v = v.squeeze(2) - sp = k.shape - scores = torch.einsum( - 
"blhd,bnhd->bhln", - q, - k.reshape(sp[0], sp[1], q.size(2), sp[3]), - ) / math.sqrt(q.size(-1)) - scores = scores.masked_fill(attn_mask, -65000.0) - scores = F.softmax(scores, dim=-1) # bsz x h x L x L - context = torch.einsum( - "bhmn,bnhd->bmhd", - scores, - v.reshape(sp[0], sp[1], q.size(2), sp[3]), - ) - else: - context = self.inner_cross_attn(q, kv, causal=True) - - if seqlen is None: - context = rearrange(context, "b s h d -> b s (h d)") - else: - context = rearrange(context, "b s h d -> (b s) (h d)") - - out = self.out_proj(context) - return out - - def _packed_forward(self, x, inference_params=None, **kwargs): - """ - Arguments: - x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None. - If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we - split x during sequence parallel, we split the batch * seqlen dimension - (in case batch is small). - """ - qkv = self.Wqkv(x) # bsz x total x hsz - qkv = rearrange( - qkv, "b t (three h d) -> b t three h d", three=3, d=self.head_dim - ) # bsz x total x 3 x n_head x d - qkv = self.rotary_emb(qkv, **kwargs) - - kwargs.pop("indexes") - cu_seqlens = kwargs["cu_seqlens"] - - # for packed data, batch dimension with a size of 1 should be directly squeezed off. - if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU: - qkv = qkv.squeeze(0) - # since torch_npu only supports fa with no packed data currently, qkv should be unpacked - elif internlm_accelerator.get_accelerator_backend() in [AcceleratorType.NPU, AcceleratorType.DIPU]: - qkv = unpack_qkv_before_attn(qkv, cu_seqlens) - kwargs.pop("cu_seqlens") - kwargs.pop("max_seqlen") - - if inference_params is None: - if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn: - with internlm_accelerator.amp.autocast(dtype=torch.bfloat16): - if qkv.dtype not in [torch.float16, torch.bfloat16]: - qkv = qkv.to(torch.bfloat16) - context = self.inner_attn(qkv=qkv, **kwargs).to(x.dtype) - else: - context = self.inner_attn(qkv=qkv, **kwargs) - - else: - raise RuntimeError("Not support this right now") - - if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU: - context = rearrange(context, "s h d -> s (h d)") # recover the shape - context = context.unsqueeze(0) # restore bsz dimension - elif internlm_accelerator.get_accelerator_backend() in [AcceleratorType.NPU, AcceleratorType.DIPU]: - context = rearrange(context, "b s h d -> b s (h d)") # recover the shape - context = pack_output_after_attn(context, cu_seqlens) - - out = self.out_proj(context) - - return out diff --git a/internlm/model/modules/norm.py b/internlm/model/modules/norm.py new file mode 100644 index 00000000..b94cdd43 --- /dev/null +++ b/internlm/model/modules/norm.py @@ -0,0 +1,19 @@ +""" +layer norm modules +""" + +from typing import List, Union + +import torch +from torch import nn + +from internlm.model.ops.norm import RMSNorm + +Shape = Union[int, List[int], torch.Size] + + +def new_layer_norm(norm_type: str, normalized_shape: Shape, eps: float = 1e-5): + if norm_type == "rmsnorm": + return RMSNorm(normalized_shape, eps) + else: # default: layernorm + return nn.LayerNorm(normalized_shape, eps) diff --git a/internlm/model/modules/utils.py b/internlm/model/modules/utils.py new file mode 100644 index 00000000..dd86cb1c --- /dev/null +++ b/internlm/model/modules/utils.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch +import torch.nn.functional as F +from einops import rearrange + +from 
internlm.utils.logger import get_logger + +logger = get_logger(__file__) + + +def is_moe_param(param: torch.Tensor) -> bool: + if hasattr(param, "is_expert") and param.is_expert: + return True + return False + + +def Silu(w1_o, w2_o): + return F.silu(w1_o) * w2_o + + +Silu = torch.jit.script(Silu) + + +def update_kv_cache(kv, inference_params, layer_idx): + """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)""" + # Pre-allocate memory for key-values for inference. + num_heads, head_dim = kv.shape[-2:] + if layer_idx not in inference_params.key_value_memory_dict: + kv_cache = torch.empty( + inference_params.max_batch_size, + inference_params.max_sequence_len, + 2, + num_heads, + head_dim, + dtype=kv.dtype, + device=kv.device, + ) + inference_params.key_value_memory_dict[layer_idx] = kv_cache + else: + if not inference_params.fused_ft_kernel: + kv_cache = inference_params.key_value_memory_dict[layer_idx] + else: + # For FT, k_cache has shape (b, h, headdim / packsize, s, packsize) + # where packsize = 4 if fp32, 8 if fp16 or bf16. + # v_cache has shape (b, h, s, headdim) + k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx] + kv_cache = None + # Adjust key and value for inference + batch_start = inference_params.batch_size_offset + batch_end = batch_start + kv.shape[0] + sequence_start = inference_params.sequence_len_offset + sequence_end = sequence_start + kv.shape[1] + assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0]) + assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2]) + # Copy key and values. + if not inference_params.fused_ft_kernel: + assert kv_cache is not None + kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv + kv = kv_cache[batch_start:batch_end, :sequence_end, ...] + return kv + else: + assert inference_params.sequence_len_offset == 0 + # FT kernel requires different layouts for the k_cache and v_cache. + assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32] + packsize = 4 if kv.dtype == torch.float32 else 8 + if kv_cache is not None: + kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] 
= kv + k_cache = rearrange( + kv_cache[:, :, 0], + "b s h (d packsize) -> b h d s packsize", + packsize=packsize, + ).contiguous() + v_cache = rearrange(kv_cache[:, :, 1], "b s h d -> b h s d").contiguous() + inference_params.key_value_memory_dict[layer_idx] = (k_cache, v_cache) + else: + k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange( + kv[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize + ) + v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(kv[:, :, 1], "b s h d -> b h s d") + return kv diff --git a/internlm/model/moe/__init__.py b/internlm/model/moe/__init__.py index 9ebcea66..e69de29b 100644 --- a/internlm/model/moe/__init__.py +++ b/internlm/model/moe/__init__.py @@ -1,28 +0,0 @@ -from .gshard_layer import GShardMOELayer -from .moe import MoE - -__all__ = ["MoE", "GShardMOELayer"] - - -try: - from megablocks import ops # noqa # pylint: disable=W0611 -except ModuleNotFoundError: - pass -else: - from internlm.model.moe.megablock.megablock_moe import ( # noqa # pylint: disable=W0611 - MegaBlockMoE, - ) - - __all__ += "MegaBlockMoE" - -try: - import stk # noqa # pylint: disable=W0611 - from megablocks import ops # noqa # pylint: disable=W0611 -except ModuleNotFoundError: - pass -else: - from internlm.model.moe.megablock.megablock_dmoe import ( # noqa # pylint: disable=W0611 - MegaBlockdMoE, - ) - - __all__ += "MegaBlockdMoE" diff --git a/internlm/model/moe/gshard_layer.py b/internlm/model/moe/gshard_layer.py index ee03d781..e84abc88 100644 --- a/internlm/model/moe/gshard_layer.py +++ b/internlm/model/moe/gshard_layer.py @@ -15,9 +15,9 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc +from internlm.model.modules.mlp import new_feed_forward from internlm.utils.logger import get_logger from internlm.utils.megatron_timers import megatron_timer as timer -from internlm.utils.registry import MODEL_INITIALIZER from .base_layer import BaseMoELayer from .utils import all_to_all @@ -436,7 +436,6 @@ def forward( return gate_output -@MODEL_INITIALIZER.register_module(module_name="GShard") class GShardMOELayer(BaseMoELayer): """MOELayer module which implements MixtureOfExperts as described in Gshard_. 
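The Silu helper introduced in modules/utils.py above computes the SwiGLU gate, F.silu(w1_o) * w3_o, and the fused FeedForward path relies on the fact that stacking w1 and w3 into a single matmul and splitting the output is numerically equivalent to two separate projections. A standalone check with toy shapes (all names here are illustrative):

import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 32)                 # (batch, seq, hidden)
w1 = torch.randn(88, 32)                  # gate projection
w3 = torch.randn(88, 32)                  # up projection
w2 = torch.randn(32, 88)                  # down projection

unfused = F.linear(F.silu(F.linear(x, w1)) * F.linear(x, w3), w2)

fused_w1_w3 = torch.cat([w1, w3], dim=0)  # same layout as the fused_w1_w3 weight
gate, up = torch.split(F.linear(x, fused_w1_w3), 88, dim=-1)
fused = F.linear(F.silu(gate) * up, w2)

assert torch.allclose(unfused, fused, atol=1e-4)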
:: @@ -461,7 +460,6 @@ def __init__( hidden_features: int, out_features: int, num_experts: int, - ep_cls: Optional[Callable], ep_group: Optional[torch.distributed.ProcessGroup], ep_size: int, top_k: int = 1, @@ -496,11 +494,10 @@ def __init__( ), torch.nn.ModuleList( [ - ep_cls( + new_feed_forward( in_features, hidden_features, out_features, - process_group=gpc.get_group(ParallelMode.TENSOR), bias=False, device=device, dtype=dtype, diff --git a/internlm/model/moe/megablock/megablock_dmoe.py b/internlm/model/moe/megablock/megablock_dmoe.py index 88e7e806..f2cd6766 100644 --- a/internlm/model/moe/megablock/megablock_dmoe.py +++ b/internlm/model/moe/megablock/megablock_dmoe.py @@ -1,9 +1,7 @@ from typing import Optional, Tuple import numpy as np -import stk import torch -from megablocks import ops from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc @@ -11,10 +9,15 @@ from internlm.model.moe.megablock.megablock_moe import MegaBlockMoE from internlm.model.moe.megablock.mlp import MegaBlockGroupedFeedForward from internlm.model.moe.megablock.utils import promote_scalar -from internlm.utils.registry import MODEL_INITIALIZER + +try: + import stk + from megablocks import ops +except (ModuleNotFoundError, ImportError): + stk = None + megablocks = None -@MODEL_INITIALIZER.register_module(module_name="MegaBlock-D") class MegaBlockdMoE(MegaBlockMoE): """ Built on the paper and library Megablocks as described in @@ -111,7 +114,7 @@ def sparse_transpose( offsets_t = torch.cat([zero, nnz_per_column]) return column_indices_t, offsets_t, block_offsets_t - def topology(self, x: torch.Tensor, padded_bins: torch.Tensor) -> stk.Matrix: + def topology(self, x: torch.Tensor, padded_bins: torch.Tensor): padded_tokens, _ = x.size() assert padded_tokens % self.blocking == 0 assert self.ffn_dim_per_row % self.blocking == 0 diff --git a/internlm/model/moe/megablock/megablock_moe.py b/internlm/model/moe/megablock/megablock_moe.py index 202a5088..312e6c71 100644 --- a/internlm/model/moe/megablock/megablock_moe.py +++ b/internlm/model/moe/megablock/megablock_moe.py @@ -3,17 +3,19 @@ import numpy as np import torch import torch.nn.functional as F -from megablocks import ops from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.model.moe.base_layer import BaseMoELayer from internlm.model.moe.megablock.mlp import MegaBlockFeedForward from internlm.model.moe.utils import all_to_all -from internlm.utils.registry import MODEL_INITIALIZER + +try: + from megablocks import ops +except (ModuleNotFoundError, ImportError): + ops = None -@MODEL_INITIALIZER.register_module(module_name="MegaBlock") class MegaBlockMoE(BaseMoELayer): """ Built on the paper and library Megablocks as described in diff --git a/internlm/model/moe/megablock/mlp.py b/internlm/model/moe/megablock/mlp.py index 3ac8913b..9e911e8e 100644 --- a/internlm/model/moe/megablock/mlp.py +++ b/internlm/model/moe/megablock/mlp.py @@ -3,13 +3,13 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc +from internlm.model.modules.utils import Silu from internlm.model.moe.megablock.utils import ( act_fn, dsd_nn, sdd_nt, tensor_parallel_bmm, ) -from internlm.model.utils import Silu class MegaBlockFeedForward(nn.Module): diff --git a/internlm/model/moe/megablock/utils.py b/internlm/model/moe/megablock/utils.py index 2c890e01..857dd8b7 100644 --- a/internlm/model/moe/megablock/utils.py +++ 
b/internlm/model/moe/megablock/utils.py @@ -1,9 +1,7 @@ -import sys - import torch from internlm.accelerator import get_accelerator -from internlm.model.utils import Silu +from internlm.model.modules.utils import Silu try: import stk @@ -366,26 +364,3 @@ def act_fn(x1, x2, topo): ) return y - - -# check dependency -def check_megablock_installed(): - try: - from megablocks import ops # noqa # pylint: disable=W0611 - except ModuleNotFoundError: - print( - "MegaBlocks not found, please see " - "https://github.com/stanford-futuredata/megablocks/. " - "Note that MegaBlocks depends on mosaicml-turbo, which only " - "supports python 3.10.", - flush=True, - ) - sys.exit() - - -def check_stk_installed(): - try: - import stk # noqa # pylint: disable=W0611 - except ModuleNotFoundError: - print("STK not found: please see https://github.com/stanford-futuredata/stk", flush=True) - sys.exit() diff --git a/internlm/model/moe/moe.py b/internlm/model/moe/moe.py index 392cca89..304d8d0a 100644 --- a/internlm/model/moe/moe.py +++ b/internlm/model/moe/moe.py @@ -1,16 +1,29 @@ -from typing import Callable, Optional +from typing import Optional import torch -from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc +from internlm.model.modules.mlp import new_feed_forward +from internlm.model.moe.gshard_layer import GShardMOELayer +from internlm.model.moe.megablock.megablock_dmoe import MegaBlockdMoE +from internlm.model.moe.megablock.megablock_moe import MegaBlockMoE from internlm.utils.logger import get_logger -from internlm.utils.registry import MODEL_INITIALIZER # global llm logger logger = get_logger(__file__) +def new_moe_layer(moe_type: str, **kwargs): + if moe_type == "GShard": + return GShardMOELayer(**kwargs) + elif moe_type == "MegaBlock": + return MegaBlockMoE(**kwargs) + elif moe_type == "MegaBlock-D": + return MegaBlockdMoE(**kwargs) + else: + raise ValueError(f"Unsupported model type: {moe_type}") + + class MoE(torch.nn.Module): """Initialize an MoE layer. 
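The MegaBlock files above switch from hard imports of megablocks/stk to guarded ones, and the same guard-and-flag pattern protects torch_npu, deeplink and flash-attn in the new ops/attention.py further down. A minimal sketch of the pattern; the flag name megablocks_available and the helper are illustrative, not names from the patch.

# Try optional backends once at import time, fall back to None, and record availability
# so later code can raise a clear error instead of failing at module import.
try:
    import stk  # noqa: F401
    from megablocks import ops  # noqa: F401

    megablocks_available = True
except (ModuleNotFoundError, ImportError):
    stk, ops = None, None
    megablocks_available = False


def require_megablocks():
    if not megablocks_available:
        raise RuntimeError("megablocks/stk not installed; MegaBlock MoE layers are unavailable")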
@@ -38,7 +51,6 @@ def __init__( in_features: int, hidden_features: int, out_features: int, - ep_cls: Optional[Callable], ep_group: Optional[torch.distributed.ProcessGroup], num_experts: int = 1, ep_size=1, @@ -52,27 +64,26 @@ def __init__( if not hasattr(gpc.config, "moe"): gpc.config.moe = dict() - self.moe_layer = MODEL_INITIALIZER.get_module(module_name=gpc.config.model.moe_type)( + self.moe_layer = new_moe_layer( + moe_type=gpc.config.model.moe_type, in_features=in_features, hidden_features=hidden_features, out_features=out_features, num_experts=num_experts, - ep_cls=ep_cls, ep_group=ep_group, ep_size=ep_size, device=device, dtype=dtype, - **(gpc.config.moe) + **(gpc.config.moe), ) # residual network, see https://arxiv.org/pdf/2201.05596.pdf, seems useful for convergence self.use_residual = use_residual if self.use_residual: - self.residual_mlp = ep_cls( + self.residual_mlp = new_feed_forward( in_features=in_features, hidden_features=hidden_features, out_features=out_features, - process_group=gpc.get_group(ParallelMode.TENSOR), bias=False, device=device, dtype=dtype, diff --git a/internlm/model/ops/attention.py b/internlm/model/ops/attention.py new file mode 100644 index 00000000..9205652a --- /dev/null +++ b/internlm/model/ops/attention.py @@ -0,0 +1,847 @@ +""" +A simple operator selector, used for compatibility with different platforms such as CUDA and Ascend, +as well as whether to enable flash-attn operator optimization, may be replaced by a more comprehensive +operator compatibility layer in the future. + +This file implements support for the attention operators. +""" + +import math +from typing import Callable, Tuple + +import torch +from einops import rearrange, repeat +from torch import nn + +from internlm.accelerator import AcceleratorType, get_accelerator +from internlm.core.context import global_context as gpc +from internlm.core.parallel.comm.isp import auto_wrap_distributed_attention +from internlm.model.ops.utils import pack_output_after_attn, unpack_qkv_before_attn +from internlm.utils.common import get_current_device +from internlm.utils.utils import ( + CuSeqlenType, + QKVPackType, + check_attention_argument, + params_dispatch_with_condition, +) + +try: + from torch_npu import npu_fusion_attention as _origin_npu_fixedlen_qkvsplited_func + + is_torch_npu = True +except (ModuleNotFoundError, ImportError): + is_torch_npu = False + +try: + # TODO: add support of deeplink + from deeplink_ext.internevo_ops import FlashCrossAttention, FlashSelfAttention + + del FlashCrossAttention, FlashSelfAttention + + deeplink_flash_attn_impl = True +except (ModuleNotFoundError, ImportError): + deeplink_flash_attn_impl = False + +try: + from flash_attn.flash_attn_interface import ( + flash_attn_func as _flash_fixedlen_qkvsplited_func, + ) + from flash_attn.flash_attn_interface import ( + flash_attn_kvpacked_func as _flash_fixedlen_kvpacked_func, + ) + from flash_attn.flash_attn_interface import ( + flash_attn_qkvpacked_func as _flash_fixedlen_qkvpacked_func, + ) + from flash_attn.flash_attn_interface import ( + flash_attn_varlen_func as _flash_varlen_qkvsplited_func, + ) + from flash_attn.flash_attn_interface import ( + flash_attn_varlen_kvpacked_func as _flash_varlen_kvpacked_func, + ) + from flash_attn.flash_attn_interface import ( + flash_attn_varlen_qkvpacked_func as _flash_varlen_qkvpacked_func, + ) + + gpu_flash_attn_impl = True +except (ModuleNotFoundError, ImportError): + gpu_flash_attn_impl = False + +internlm_accelerator = get_accelerator() +device_backend = 
internlm_accelerator.get_accelerator_backend() + + +def _nyi_attn(func_name, *args, **kwargs): # pylint: disable=W0613 + assert False, f"{func_name} is not yet implemented" + + +# gpu flash attention operators + + +def _flash_float32_compatibility_wrapper(input_idxs: Tuple, flash_func: Callable, *args, **kwargs): + if gpc.config.model.dtype is torch.float32: + inputs = (args[idx] for idx in input_idxs) + input_dtype = inputs[0].dtype + other_args = [args[idx] for idx in range(len(inputs), len(args))] + + with internlm_accelerator.amp.autocast(dtype=torch.bfloat16): + for idx in input_idxs: + if inputs[idx].dtype is torch.float32: + inputs[idx] = inputs[idx].to(torch.bfloat16) + return flash_func(*inputs, *other_args, **kwargs).to(input_dtype) + + return flash_func(*args, **kwargs) + + +def _flash_varlen_qkvpacked_attn( + qkv: torch.Tensor, cu_seqlens, max_seqlen, dropout_p, softmax_scale=None, causal=False +): + # compatible data format: [1, packelen, 3, n_head, headim] + qkv = qkv.squeeze(dim=0) + + # input_idxs: 0: qkv + output = _flash_float32_compatibility_wrapper( + (0), _flash_varlen_qkvpacked_func, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal + ) + + return output.unsqueeze(dim=0) + + +def _flash_fixedlen_qkvpacked_attn(qkv: torch.Tensor, dropout_p=0.0, softmax_scale=None, causal=False): + # input_idxs: 0: qkv + return _flash_float32_compatibility_wrapper( + (0), _flash_fixedlen_qkvpacked_func, qkv, dropout_p, softmax_scale, causal + ) + + +def _flash_varlen_kvpacked_attn( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p=0.0, + softmax_scale=None, + causal=False, +): + # compatible data format: [1, packelen, 3, n_head, headim] + q, kv = q.squeeze(dim=0), kv.squeeze(dim=0) + + # input_idxs: 0: q, 1: kv + output = _flash_float32_compatibility_wrapper( + (0, 1), + _flash_varlen_kvpacked_func, + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + ) + + return output.unsqueeze(dim=0) + + +def _flash_fixedlen_kvpacked_attn(q: torch.Tensor, kv: torch.Tensor, dropout_p=0.0, softmax_scale=None, causal=False): + # input_idxs: 0: q, 1: kv + return _flash_float32_compatibility_wrapper( + (0, 1), _flash_fixedlen_kvpacked_func, q, kv, dropout_p, softmax_scale, causal + ) + + +def _flash_varlen_qkvsplited_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p=0.0, + softmax_scale=None, + causal=False, +): + # compatible data format: [1, packelen, 3, n_head, headim] + q, k, v = q.squeeze(dim=0), k.squeeze(dim=0), v.squeeze(dim=0) + + # input_idxs: 0: q, 1: k, 2: v + output = _flash_float32_compatibility_wrapper( + (0, 1, 2), + _flash_varlen_qkvsplited_func, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + ) + + return output.unsqueeze(dim=0) + + +def _flash_fixedlen_qkvsplited_attn(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False): + # input_idxs: 0: q, 1: k, 2: v + return _flash_float32_compatibility_wrapper( + (0, 1, 2), _flash_fixedlen_qkvsplited_func, q, k, v, dropout_p, softmax_scale, causal + ) + + +# npu flash attention operators +# TODO: should we add _flash_float32_compatibility_wrapper support for npu. 
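Editor's note: `_flash_float32_compatibility_wrapper` above is meant to temporarily cast selected fp32 tensor arguments to bf16 before invoking the flash kernel and cast the result back. The following is a minimal sketch of that intent only, reusing the module's own imports (`Tuple`, `Callable`, `gpc`, `internlm_accelerator`); it is not the exact wrapper used by the callers above.

# Sketch: cast the tensors named by input_idxs to bfloat16 under autocast, run the
# flash kernel, then restore the original dtype on the output. Illustrative only.
def _float32_compat_sketch(input_idxs: Tuple[int, ...], flash_func: Callable, *args, **kwargs):
    if gpc.config.model.dtype is not torch.float32:
        return flash_func(*args, **kwargs)

    args = list(args)
    input_dtype = args[input_idxs[0]].dtype
    for idx in input_idxs:
        if args[idx].dtype is torch.float32:
            args[idx] = args[idx].to(torch.bfloat16)

    with internlm_accelerator.amp.autocast(dtype=torch.bfloat16):
        return flash_func(*args, **kwargs).to(input_dtype)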
+ + +def _npu_varlen_qkvsplited_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, # pylint: disable=W0613 + max_seqlen_k, # pylint: disable=W0613 + dropout_p=0.0, + softmax_scale=None, + causal=False, +): + # TODO: support npu native varlen flash attention + packed_length = q.size(dim=1) + + q = unpack_qkv_before_attn(q, cu_seqlens=cu_seqlens_q) + k = unpack_qkv_before_attn(k, cu_seqlens=cu_seqlens_k) + v = unpack_qkv_before_attn(v, cu_seqlens=cu_seqlens_k) + + output = _npu_fixedlen_qkvsplited_attn(q, k, v, dropout_p, softmax_scale, causal) + + return pack_output_after_attn(output, cu_seqlens_q, packed_length) + + +def _npu_fixedlen_qkvsplited_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + dropout_p: float, + softmax_scale=None, + causal=False, +): + assert causal is True + assert q.dtype in (torch.bfloat16, torch.float16) + + if len(q.shape) == 5: # [batch, seqlen, 1, n_head, headdim] + q, k, v = q.squeeze(dim=2), k.squeeze(dim=2), v.squeeze(dim=2) + + _, seqlen, n_head, _ = q.shape + attention_mask = torch.triu(torch.ones(seqlen, seqlen, device=get_current_device()), 1).bool() + + return _origin_npu_fixedlen_qkvsplited_func( + query=q, + key=k, + value=v, + head_num=n_head, + input_layout="BSND", # If necessary, expose the interface + pse=None, + atten_mask=attention_mask, + scale=softmax_scale, + sparse_mode=0, # If necessary, expose the interface + pre_tockens=seqlen, # Used for sparse calculations, representing the left boundary of the slides window + next_tockens=0, # If necessary, expose the interface + keep_prob=1 - dropout_p, + inner_precise=0, # If necessary, expose the interface + ) + + +def _npu_varlen_qkvpacked_attn( + qkv: torch.Tensor, cu_seqlens, max_seqlen, dropout_p, softmax_scale=None, causal=False # pylint: disable=W0613 +): + # TODO: support npu native varlen flash attention + packed_length = qkv.size(dim=1) + + qkv = unpack_qkv_before_attn(qkv, cu_seqlens=cu_seqlens) + + output = _npu_fixedlen_qkvpacked_attn(qkv, dropout_p, softmax_scale, causal) + + return pack_output_after_attn(output, cu_seqlens, packed_length) + + +def _npu_fixedlen_qkvpacked_attn(qkv: torch.Tensor, dropout_p: float, softmax_scale=None, causal=False): + q, k, v = qkv.unbind(dim=2) + return _npu_fixedlen_qkvsplited_attn(q, k, v, dropout_p, softmax_scale, causal) + + +def _npu_varlen_kvpacked_attn( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, # pylint: disable=W0613 + max_seqlen_k, # pylint: disable=W0613 + dropout_p=0.0, + softmax_scale=None, + causal=False, +): + # TODO: support npu native varlen flash attention + packed_length = q.size(dim=1) + + q = unpack_qkv_before_attn(q, cu_seqlens=cu_seqlens_q) + kv = unpack_qkv_before_attn(kv, cu_seqlens=cu_seqlens_k) + + output = _npu_fixedlen_kvpacked_attn(q, kv, dropout_p, softmax_scale, causal) + + return pack_output_after_attn(output, cu_seqlens_q, packed_length) + + +def _npu_fixedlen_kvpacked_attn(q: torch.Tensor, kv: torch.Tensor, dropout_p: float, softmax_scale=None, causal=False): + k, v = kv.unbind(dim=2) + k, v = k.squeeze(dim=2), v.squeeze(dim=2) + return _npu_fixedlen_qkvsplited_attn(q, k, v, dropout_p, softmax_scale, causal) + + +# deeplink flash attention operators + + +def _deeplink_varlen_qkvpacked_attn(*args, **kwargs): + # TODO: support deeplink version flash attention + _nyi_attn("_deeplink_varlen_qkvpacked_attn", *args, **kwargs) + + +def _deeplink_fixedlne_qkvpacked_attn(*args, **kwargs): + # TODO: support 
deeplink version flash attention + _nyi_attn("_deeplink_fixedlne_qkvpacked_attn", *args, **kwargs) + + +def _deeplink_varlen_kvpacked_attn(*args, **kwargs): + # TODO: support deeplink version flash attention + _nyi_attn("_deeplink_varlen_kvpacked_attn", *args, **kwargs) + + +def _deeplink_fixedlen_kvpacked_attn(*args, **kwargs): + # TODO: support deeplink version flash attention + _nyi_attn("_deeplink_fixedlen_kvpacked_attn", *args, **kwargs) + + +def _deeplink_varlen_qkvsplited_attn(*args, **kwargs): + # TODO: support deeplink version flash attention + _nyi_attn("_deeplink_varlen_qkvsplited_attn", *args, **kwargs) + + +def _deeplink_fixedlen_qkvsplited_attn(*args, **kwargs): + # TODO: support deeplink version flash attention + _nyi_attn("_deeplink_fixedlen_qkvsplited_attn", *args, **kwargs) + + +# torch attention operators + + +def _torch_varlen_qkvpacked_attn(*args, **kwargs): + _nyi_attn("_torch_varlen_qkvpacked_attn", *args, **kwargs) + + +# adpated from https://github.com/Dao-AILab/flash-attention/blob/v2.2.1/flash_attn/modules/mha.py +def _torch_fixedlen_qkvpacked_attn(qkv: torch.Tensor, dropout, softmax_scale=None, causal=False, key_padding_mask=None): + batch_size, seqlen = qkv.shape[0], qkv.shape[1] + q, k, v = qkv.unbind(dim=2) + + softmax_scale = softmax_scale or 1.0 / math.sqrt(q.shape[-1]) + scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) + + if key_padding_mask is not None: + padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device) + padding_mask.masked_fill_(key_padding_mask, 0.0) + # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) + scores = scores + rearrange(padding_mask, "b s -> b 1 1 s") + + if causal: + # "triu_tril_cuda_template" not implemented for 'BFloat16' + # So we have to construct the mask in float + causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1) + # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) + scores = scores + causal_mask.to(dtype=scores.dtype) + + attention = torch.softmax(scores, dim=-1, dtype=v.dtype) + attention_drop = dropout(attention) + output = torch.einsum("bhts,bshd->bthd", attention_drop, v) + + return output + + +def _torch_varlen_kvpacked_attn(*args, **kwargs): + _nyi_attn("_torch_varlen_kvpacked_attn", *args, **kwargs) + + +# adpated from https://github.com/Dao-AILab/flash-attention/blob/v2.2.1/flash_attn/modules/mha.py +def _torch_fixedlen_kvpacked_attn( + q: torch.Tensor, kv: torch.Tensor, dropout, softmax_scale=None, causal=False, key_padding_mask=None +): + batch_size, seqlen_q = q.shape[0], q.shape[1] + seqlen_k = kv.shape[1] + + assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3] + if kv.shape[3] != q.shape[2]: # MQA/GQA + kv = repeat(kv, "... hkv d -> ... 
(hkv g) d", g=q.shape[2] // kv.shape[3]) + k, v = kv.unbind(dim=2) + softmax_scale = softmax_scale or 1.0 / math.sqrt(q.shape[-1]) + scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) + if key_padding_mask is not None: + padding_mask = torch.full((batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device) + padding_mask.masked_fill_(key_padding_mask, 0.0) + # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) + scores = scores + rearrange(padding_mask, "b s -> b 1 1 s") + + if causal: + # causal mask needs to take into account the difference between seqlen_q and seqlen_k + row_idx = rearrange(torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long) + sk = seqlen_k if key_padding_mask is None else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + causal_mask = col_idx > row_idx + sk - seqlen_q + scores = scores.masked_fill(causal_mask, -10000.0) + + attention = torch.softmax(scores, dim=-1, dtype=v.dtype) + attention_drop = dropout(attention) + output = torch.einsum("bhts,bshd->bthd", attention_drop, v) + + return output + + +def _torch_varlen_qkvsplited_attn(*args, **kwargs): + _nyi_attn("_torch_varlen_qkvsplited_attn", *args, **kwargs) + + +def _torch_fixedlen_qkvsplited_attn( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, dropout, softmax_scale=None, causal=False, key_padding_mask=None +): + kv = torch.stack([k, v], dim=2) + return _torch_fixedlen_kvpacked_attn(q, kv, dropout, softmax_scale, causal, key_padding_mask) + + +@auto_wrap_distributed_attention +class SelfAttention(nn.Module): + """Implements scaled dot-product attention with optional softmax scaling. + + This class implements the scaled dot-product attention mechanism, which can be optionally scaled + by a softmax scaling factor. It supports configurations for causal attention and applies dropout + to the attention scores. + + Arguments: + causal (bool): If True, applies causal attention to mask future tokens. Defaults to False. + softmax_scale (Optional[float]): Scaling factor for attention scores before applying softmax. + Defaults to 1/sqrt(d_keys) where d_keys is the dimension of the keys, computed at runtime. + attention_dropout (float): Dropout rate for attention scores. Defaults to 0.0. + """ + + def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): + super().__init__() + self.causal = causal + self.softmax_scale = softmax_scale + self.dropout = nn.Dropout(attention_dropout) + + if device_backend == AcceleratorType.NPU: + assert self.causal, "Ascend flash attention does not spport causal=False yet!" + + @params_dispatch_with_condition(condition=check_attention_argument) + def forward(self): + """Placeholder for multihead softmax attention implementation. + + This method serves as a placeholder and should not be reached during execution. It is expected + to be overridden by specific implementations for different attention mechanisms. + + Raises: + AssertionError: Always raised to indicate the method should not be called directly. 
+ """ + assert False, "Never arrive here" + + @forward.register(conditions=(str(QKVPackType.QKVPACKED), str(CuSeqlenType.WithOut))) + def _qkv_without_cu_seqlens(self, qkv, softmax_scale=None, causal=None, key_padding_mask=None): + softmax_scale = self.softmax_scale if softmax_scale is None else softmax_scale + causal = self.causal if causal is None else causal + + if gpc.config.model.get("use_flash_attn", False): + if device_backend == AcceleratorType.GPU and gpu_flash_attn_impl: + return _flash_fixedlen_qkvpacked_attn(qkv, self.dropout.p, softmax_scale, causal) + elif device_backend == AcceleratorType.NPU and is_torch_npu: + return _npu_fixedlen_qkvpacked_attn(qkv, self.dropout.p, softmax_scale, causal) + elif device_backend == AcceleratorType.DIPU and deeplink_flash_attn_impl: + return _deeplink_fixedlne_qkvpacked_attn(qkv, self.dropout.p, softmax_scale, causal) + else: + raise NotImplementedError(f"Unsupported device type: {device_backend} for flash attention") + else: + return _torch_fixedlen_qkvpacked_attn(qkv, self.dropout, softmax_scale, causal, key_padding_mask) + + @forward.register(conditions=(str(QKVPackType.KVPACKED), str(CuSeqlenType.WithOut))) + def _q_kv_without_cu_seqlens(self, q, kv, softmax_scale=None, causal=None, key_padding_mask=None): + softmax_scale = self.softmax_scale if softmax_scale is None else softmax_scale + causal = self.causal if causal is None else causal + + if gpc.config.model.get("use_flash_attn", False): + if device_backend == AcceleratorType.GPU and gpu_flash_attn_impl: + return _flash_fixedlen_kvpacked_attn(q, kv, self.dropout.p, softmax_scale, causal) + elif device_backend == AcceleratorType.NPU and is_torch_npu: + return _npu_fixedlen_kvpacked_attn(q, kv, self.dropout.p, softmax_scale, causal) + elif device_backend == AcceleratorType.DIPU and deeplink_flash_attn_impl: + return _deeplink_fixedlen_kvpacked_attn(q, kv, self.dropout.p, softmax_scale, causal) + else: + raise NotImplementedError(f"Unsupported device type: {device_backend} for flash attention") + else: + return _torch_fixedlen_kvpacked_attn(q, kv, self.dropout, softmax_scale, causal, key_padding_mask) + + @forward.register(conditions=(str(QKVPackType.QKVSPLITED), str(CuSeqlenType.WithOut))) + def _q_k_v_without_cu_seqlens(self, q, k, v, softmax_scale=None, causal=None, key_padding_mask=None): + softmax_scale = self.softmax_scale if softmax_scale is None else softmax_scale + causal = self.causal if causal is None else causal + + if gpc.config.model.get("use_flash_attn", False): + if device_backend == AcceleratorType.GPU and gpu_flash_attn_impl: + return _flash_fixedlen_qkvsplited_attn(q, k, v, self.dropout.p, softmax_scale, causal) + elif device_backend == AcceleratorType.NPU and is_torch_npu: + return _npu_fixedlen_qkvsplited_attn(q, k, v, self.dropout.p, softmax_scale, causal) + elif device_backend == AcceleratorType.DIPU and deeplink_flash_attn_impl: + return _deeplink_fixedlen_qkvsplited_attn(q, k, v, self.dropout.p, softmax_scale, causal) + else: + raise NotImplementedError(f"Unsupported device type: {device_backend} for flash attention") + else: + return _torch_fixedlen_qkvsplited_attn(q, k, v, self.dropout, softmax_scale, causal, key_padding_mask) + + @forward.register(conditions=(str(QKVPackType.QKVPACKED), str(CuSeqlenType.With))) + def _qkv_with_cu_seqlens( + self, + qkv, + cu_seqlens, + max_seqlen, + softmax_scale=None, + causal=None, + key_padding_mask=None, + ): + softmax_scale = self.softmax_scale if softmax_scale is None else softmax_scale + causal = self.causal if causal is 
None else causal + + if gpc.config.model.get("use_flash_attn", False): + if device_backend == AcceleratorType.GPU and gpu_flash_attn_impl: + return _flash_varlen_qkvpacked_attn(qkv, cu_seqlens, max_seqlen, self.dropout.p, softmax_scale, causal) + elif device_backend == AcceleratorType.NPU and is_torch_npu: + return _npu_varlen_qkvpacked_attn(qkv, cu_seqlens, max_seqlen, self.dropout.p, softmax_scale, causal) + elif device_backend == AcceleratorType.DIPU and deeplink_flash_attn_impl: + return _deeplink_varlen_qkvpacked_attn( + qkv, cu_seqlens, max_seqlen, self.dropout.p, softmax_scale, causal + ) + else: + raise NotImplementedError(f"Unsupported device type: {device_backend} for flash attention") + else: + return _torch_varlen_qkvpacked_attn( + qkv, cu_seqlens, max_seqlen, self.dropout, softmax_scale, causal, key_padding_mask + ) + + @forward.register(conditions=(str(QKVPackType.KVPACKED), str(CuSeqlenType.With))) + def _q_kv_with_cu_seqlens( + self, + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale=None, + causal=None, + key_padding_mask=None, + ): + softmax_scale = self.softmax_scale if softmax_scale is None else softmax_scale + causal = self.causal if causal is None else causal + + if gpc.config.model.get("use_flash_attn", False): + if device_backend == AcceleratorType.GPU and gpu_flash_attn_impl: + return _flash_varlen_kvpacked_attn( + q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, self.dropout.p, softmax_scale, causal + ) + elif device_backend == AcceleratorType.NPU and is_torch_npu: + return _npu_varlen_kvpacked_attn( + q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, self.dropout.p, softmax_scale, causal + ) + elif device_backend == AcceleratorType.DIPU and deeplink_flash_attn_impl: + return _deeplink_varlen_kvpacked_attn( + q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, self.dropout.p, softmax_scale, causal + ) + else: + raise NotImplementedError(f"Unsupported device type: {device_backend} for flash attention") + else: + return _torch_varlen_kvpacked_attn( + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + self.dropout, + softmax_scale, + causal, + key_padding_mask, + ) + + @forward.register(conditions=(str(QKVPackType.QKVSPLITED), str(CuSeqlenType.With))) + def _q_k_v_with_cu_seqlens( + self, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale=None, + causal=None, + key_padding_mask=None, + ): + softmax_scale = self.softmax_scale if softmax_scale is None else softmax_scale + causal = self.causal if causal is None else causal + + if gpc.config.model.get("use_flash_attn", False): + if device_backend == AcceleratorType.GPU and gpu_flash_attn_impl: + return _flash_varlen_qkvsplited_attn( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + self.dropout.p, + softmax_scale, + causal, + ) + elif device_backend == AcceleratorType.NPU and is_torch_npu: + return _npu_varlen_qkvsplited_attn( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + self.dropout.p, + softmax_scale, + causal, + ) + elif device_backend == AcceleratorType.DIPU and deeplink_flash_attn_impl: + return _deeplink_varlen_qkvsplited_attn( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + self.dropout.p, + softmax_scale, + causal, + ) + else: + raise NotImplementedError(f"Unsupported device type: {device_backend} for flash attention") + else: + return _torch_varlen_qkvsplited_attn( + q, 
+ k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + self.dropout, + softmax_scale, + causal, + key_padding_mask, + ) + + +@auto_wrap_distributed_attention +class CrossAttention(nn.Module): + """Implements scaled dot product attention with softmax. + + This class provides the functionality for cross attention mechanism using scaled dot product attention + with optional softmax scaling and dropout for attention weights. + + Arguments: + causal (bool): If True, applies causality to prevent tokens from attending to future tokens. Default is False. + softmax_scale (float, optional): The scaling factor to apply to the dot products before softmax. If None, + it defaults to 1/sqrt(d_keys) where d_keys is the dimension of the keys, computed at runtime. + attention_dropout (float): The dropout rate to apply to the attention. + + Raises: + AssertionError: If `device_backend` is NPU and `causal` is False, since Ascend flash attention does not + support non-causal attention yet. + """ + + def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): + super().__init__() + self.causal = causal + self.softmax_scale = softmax_scale + self.dropout = nn.Dropout(attention_dropout) + + if device_backend == AcceleratorType.NPU: + assert self.causal, "Ascend flash attention does not support causal=False yet!" + + @params_dispatch_with_condition(condition=check_attention_argument) + def forward(self): + """Placeholder for cross attention implementation. + + This method is a placeholder and should not be reached in execution as it is expected to be + overridden by specific implementations for different attention parameters. + + Raises: + AssertionError: Always raised to indicate the method should not be called directly. + """ + assert False, "Never arrive here" + + @forward.register(conditions=(str(QKVPackType.KVPACKED), str(CuSeqlenType.WithOut))) + def _q_kv_without_cu_seqlens(self, q, kv, softmax_scale=None, causal=None, key_padding_mask=None): + softmax_scale = self.softmax_scale if softmax_scale is None else softmax_scale + causal = self.causal if causal is None else causal + + if gpc.config.model.get("use_flash_attn", False): + if device_backend == AcceleratorType.GPU and gpu_flash_attn_impl: + return _flash_fixedlen_kvpacked_attn(q, kv, self.dropout.p, softmax_scale, causal) + elif device_backend == AcceleratorType.NPU and is_torch_npu: + return _npu_fixedlen_kvpacked_attn(q, kv, self.dropout.p, softmax_scale, causal) + elif device_backend == AcceleratorType.DIPU and deeplink_flash_attn_impl: + return _deeplink_fixedlen_kvpacked_attn(q, kv, self.dropout.p, softmax_scale, causal) + else: + raise NotImplementedError(f"Unsupported device type: {device_backend} for flash attention") + else: + return _torch_fixedlen_kvpacked_attn(q, kv, self.dropout, softmax_scale, causal, key_padding_mask) + + @forward.register(conditions=(str(QKVPackType.QKVSPLITED), str(CuSeqlenType.WithOut))) + def _q_k_v_without_cu_seqlens(self, q, k, v, softmax_scale=None, causal=None, key_padding_mask=None): + softmax_scale = self.softmax_scale if softmax_scale is None else softmax_scale + causal = self.causal if causal is None else causal + + if gpc.config.model.get("use_flash_attn", False): + if device_backend == AcceleratorType.GPU and gpu_flash_attn_impl: + return _flash_fixedlen_qkvsplited_attn(q, k, v, self.dropout.p, softmax_scale, causal) + elif device_backend == AcceleratorType.NPU and is_torch_npu: + return _npu_fixedlen_qkvsplited_attn(q, k, v, self.dropout.p, softmax_scale, causal) + 
elif device_backend == AcceleratorType.DIPU and deeplink_flash_attn_impl: + return _deeplink_fixedlen_qkvsplited_attn(q, k, v, self.dropout.p, softmax_scale, causal) + else: + raise NotImplementedError(f"Unsupported device type: {device_backend} for flash attention") + else: + return _torch_fixedlen_qkvsplited_attn(q, k, v, self.dropout, softmax_scale, causal, key_padding_mask) + + @forward.register(conditions=(str(QKVPackType.KVPACKED), str(CuSeqlenType.With))) + def _q_kv_with_cu_seqlens( + self, + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale=None, + causal=None, + key_padding_mask=None, + ): + softmax_scale = self.softmax_scale if softmax_scale is None else softmax_scale + causal = self.causal if causal is None else causal + + if gpc.config.model.get("use_flash_attn", False): + if device_backend == AcceleratorType.GPU and gpu_flash_attn_impl: + return _flash_varlen_kvpacked_attn( + q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, self.dropout.p, softmax_scale, causal + ) + elif device_backend == AcceleratorType.NPU and is_torch_npu: + return _npu_varlen_kvpacked_attn( + q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, self.dropout.p, softmax_scale, causal + ) + elif device_backend == AcceleratorType.DIPU and deeplink_flash_attn_impl: + return _deeplink_varlen_kvpacked_attn( + q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, self.dropout.p, softmax_scale, causal + ) + else: + raise NotImplementedError(f"Unsupported device type: {device_backend} for flash attention") + else: + return _torch_varlen_kvpacked_attn( + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + self.dropout, + softmax_scale, + causal, + key_padding_mask, + ) + + @forward.register(conditions=(str(QKVPackType.QKVSPLITED), str(CuSeqlenType.With))) + def _q_k_v_with_cu_seqlens( + self, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale=None, + causal=None, + key_padding_mask=None, + ): + softmax_scale = self.softmax_scale if softmax_scale is None else softmax_scale + causal = self.causal if causal is None else causal + + if gpc.config.model.get("use_flash_attn", False): + if device_backend == AcceleratorType.GPU and gpu_flash_attn_impl: + return _flash_varlen_qkvsplited_attn( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + self.dropout.p, + softmax_scale, + causal, + ) + elif device_backend == AcceleratorType.NPU and is_torch_npu: + return _npu_varlen_qkvsplited_attn( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + self.dropout.p, + softmax_scale, + causal, + ) + elif device_backend == AcceleratorType.DIPU and deeplink_flash_attn_impl: + return _deeplink_varlen_qkvsplited_attn( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + self.dropout.p, + softmax_scale, + causal, + ) + else: + raise NotImplementedError(f"Unsupported device type: {device_backend} for flash attention") + else: + return _torch_varlen_qkvsplited_attn( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + self.dropout, + softmax_scale, + causal, + key_padding_mask, + ) diff --git a/internlm/model/ops/cross_entropy.py b/internlm/model/ops/cross_entropy.py new file mode 100644 index 00000000..f3fdccf9 --- /dev/null +++ b/internlm/model/ops/cross_entropy.py @@ -0,0 +1,60 @@ +""" +A simple operator selector, used for compatibility with different platforms such as CUDA and Ascend, +as 
well as whether to enable flash-attn operator optimization, may be replaced by a more comprehensive +operator compatibility layer in the future. + +This file implements support for the cross entropy operators. +""" + +from torch import nn + +from internlm.accelerator import AcceleratorType, get_accelerator +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc +from internlm.utils.logger import get_logger + +try: + from flash_attn.losses.cross_entropy import ( + CrossEntropyLoss as FlashCrossEntropyLoss, + ) + + flash_cross_entropy_impl = True +except (ModuleNotFoundError, ImportError): + flash_cross_entropy_impl = False + +logger = get_logger(__file__) +internlm_accelerator = get_accelerator() + + +# TODO: ops是否需要实现更加统一的形式 +def new_cross_entropy( + ignore_index: int = -100, + reduction: str = "mean", + label_smoothing: float = 0, + parallel_output: bool = False, + **kwargs, +): + if parallel_output: + assert ( + gpc.config.model.get("use_flash_attn", False) and flash_cross_entropy_impl + ), "Only flash cross entropy support parallel_output" + assert ( + internlm_accelerator.get_accelerator_backend() is AcceleratorType.GPU + ), "flash cross entropy only support gpu backend" + + return FlashCrossEntropyLoss( + ignore_index=ignore_index, + reduction=reduction, + label_smoothing=label_smoothing, + process_group=gpc.get_group(ParallelMode.TENSOR), + ) + else: + if gpc.is_rank_for_log(): + logger.warning( + "Use nn.CrossEntropyLoss rather than flashattn CrossEntropyLoss." + "parallel_output must be set false. Please note this!" + ) + kwargs.pop("inplace_backward", None) + return nn.CrossEntropyLoss( + ignore_index=ignore_index, reduction=reduction, label_smoothing=label_smoothing, **kwargs + ) diff --git a/internlm/model/ops/fusion_ops_import_helper.py b/internlm/model/ops/fusion_ops_import_helper.py deleted file mode 100644 index f75ff889..00000000 --- a/internlm/model/ops/fusion_ops_import_helper.py +++ /dev/null @@ -1,211 +0,0 @@ -from typing import Callable, Tuple, Union - -import torch -from torch import nn - -from internlm.accelerator import AcceleratorType, get_accelerator -from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc -from internlm.utils.logger import get_logger - -logger = get_logger(__file__) - -internlm_accelerator = get_accelerator() - - -# RMSNorm -def try_import_RMSNorm(): - """ - Try import MixFusedRMSNorm from apex, if failed, return our RMSNorm - - """ - try: - device_backend = internlm_accelerator.get_accelerator_backend() - if device_backend == AcceleratorType.DIPU: - from deeplink_ext.internevo_ops import MixedFusedRMSNorm as RMSNorm - - if gpc.is_rank_for_log(): - logger.warning("Use Deeplink MixedFusedRMSNorm, Please note this!") - - return RMSNorm - else: - from apex.normalization.fused_layer_norm import MixedFusedRMSNorm as RMSNorm - - if gpc.is_rank_for_log(): - logger.warning("Use apex MixedFusedRMSNorm, Please note this!") - - return RMSNorm - except (ModuleNotFoundError, ImportError): - if gpc.is_rank_for_log(): - logger.warning("The torch implementation for MixFusedRMSNorm is slower than apex. 
Please note this!") - from internlm.model.ops.norm import RMSNormTorch as RMSNorm - - return RMSNorm - - -# RotaryEmb -def try_import_fused_rotary() -> Tuple[Union[None, Callable], Union[None, Callable], Union[None, Callable]]: - """try_import_fused_rotary - - Returns: - Tuple[Union[None, Callable], Union[None, Callable], Union[None, Callable]]: - Returns if there is a mixing operator available, otherwise returns None. - """ - try: - device_backend = internlm_accelerator.get_accelerator_backend() - if device_backend is AcceleratorType.GPU: - import rotary_emb - - if gpc.is_rank_for_log(): - logger.warning("Use flash_attn rotary_emb, Please note this!") - - return None, None, rotary_emb.apply_rotary - elif device_backend is AcceleratorType.DIPU: - from deeplink_ext.internevo_ops import ( - ApplyRotaryEmb as DeeplinkApplyRotaryEmb, - ) - from deeplink_ext.internevo_ops import ( - ApplyRotaryEmbQKV_ as DeeplinkApplyRotaryEmbQKV_, - ) - - if gpc.is_rank_for_log(): - logger.warning("Use Deeplink ApplyRotaryEmb, Please note this!") - - return DeeplinkApplyRotaryEmb.apply, DeeplinkApplyRotaryEmbQKV_.apply, None - - except (ModuleNotFoundError, ImportError): - pass - - if gpc.is_rank_for_log(): - logger.warning( - "The torch implementation for apply_rotary is slower" "than flash atten rotary_emb. Please note this!" - ) - return None, None, None - - -# CrossEntropyLoss -def internlm_init_CrossEntropyLoss( - parallel_output: bool, reduction="none", label_smoothing=0, inplace_backward=True, process_group=None, **kwargs -): - """ - Try import FlashCrossEntropyLoss from flash_attn, if failed, return our CrossEntropyLoss - - """ - if parallel_output: - try: - if internlm_accelerator.get_accelerator_backend() is AcceleratorType.GPU: - from flash_attn.losses.cross_entropy import ( - CrossEntropyLoss as FlashCrossEntropyLoss, - ) - - if process_group is None: - gpc.get_group(ParallelMode.TENSOR) - - if gpc.is_rank_for_log(): - logger.warning("Use flash_attn FlashCrossEntropyLoss, Please note this!") - - return FlashCrossEntropyLoss( - reduction=reduction, - inplace_backward=inplace_backward, - process_group=process_group, - label_smoothing=label_smoothing, - **kwargs, - ) - except (ModuleNotFoundError, ImportError): - pass - - if gpc.is_rank_for_log(): - logger.warning( - "Use nn.CrossEntropyLoss rather than CrossEntropyLoss." - "parallel_output must be set false. Please note this!" - ) - - if "process_group" in kwargs: - kwargs.pop("process_group") - if "inplace_backward" in kwargs: - kwargs.pop("inplace_backward") - - return nn.CrossEntropyLoss(reduction=reduction, label_smoothing=label_smoothing, **kwargs) - - -# Adamw -def try_import_FusedAdamW(): - """ - Try import FusedAdamW from torch_npu/torch - - """ - adam_extra_kwargs = {} - backend = internlm_accelerator.get_accelerator_backend() - try: - if backend is AcceleratorType.GPU: - if torch.__version__ >= "2.1.0": - adam_extra_kwargs["fused"] = True - - if gpc.is_rank_for_log(): - logger.warning( - "Use fused AdamaW to avoid nan grad norm when " - "model size is larger and use_fp32_norm=True, Please note this!" - ) - return adam_extra_kwargs, torch.optim.AdamW - elif backend is AcceleratorType.NPU: - - if gpc.is_rank_for_log(): - logger.warning( - "Use normal AdamaW, NPU fused_adamw currently has" - "accuracy issues and is not supported yet. Please note this!" 
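Editor's note: as a usage sketch for the `new_cross_entropy` selector added in `internlm/model/ops/cross_entropy.py` above, assuming a non-parallel output head so the plain `nn.CrossEntropyLoss` branch is taken; the shapes and vocabulary size are placeholders and `torch` is assumed to be in scope.

# Hypothetical call site; values are illustrative.
loss_fn = new_cross_entropy(
    ignore_index=-100,
    reduction="mean",
    label_smoothing=0.0,
    parallel_output=False,   # True requires flash-attn's CrossEntropyLoss on a GPU backend
)
logits = torch.randn(8, 32000)          # (tokens, vocab_size)
labels = torch.randint(0, 32000, (8,))
loss = loss_fn(logits, labels)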
- ) - # return adam_extra_kwargs, torch_npu.optim.NpuFusedAdamW - except (ModuleNotFoundError, ImportError): - pass - - if gpc.is_rank_for_log(): - logger.warning("Use torch.optim.AdamW rather than FusedAdamW. Please note this!") - return adam_extra_kwargs, torch.optim.AdamW - - -# scatter_sum -def try_import_scatter_sum(): - """ - Try import scatter_sum from cuda, if failed, return None - - """ - try: - if internlm_accelerator.get_accelerator_backend() in [AcceleratorType.GPU, AcceleratorType.DIPU]: - from torch_scatter import scatter as cuda_scatter - - if gpc.is_rank_for_log(): - logger.warning("Use cuda_scatter. Please note this!") - - return cuda_scatter - - except (ModuleNotFoundError, ImportError): - pass - - if gpc.is_rank_for_log(): - logger.warning("Use vanilla_scatter rather than cuda_scatter. Please note this!") - - return None - - -# FlashAttn -def try_import_linear_bias_wgrad(): - """ - Try import linear_bias_wgrad from flash_attn, if failed, return None - - """ - try: - if internlm_accelerator.get_accelerator_backend() is AcceleratorType.GPU: - import fused_dense_lib as fused_dense_cuda - - if gpc.is_rank_for_log(): - logger.warning("Use flash_attn linear_bias_wgrad. Please note this!") - - return fused_dense_cuda.linear_bias_wgrad - - except (ModuleNotFoundError, ImportError): - pass - - if gpc.is_rank_for_log(): - logger.warning("Use linear_bias_wgrad_torch. Please note this!") - - return None diff --git a/internlm/model/ops/linear.py b/internlm/model/ops/linear.py index 6afd1e61..eeffddc0 100644 --- a/internlm/model/ops/linear.py +++ b/internlm/model/ops/linear.py @@ -1,396 +1,63 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- +""" +A simple operator selector, used for compatibility with different platforms such as CUDA and Ascend, +as well as whether to enable flash-attn operator optimization, may be replaced by a more comprehensive +operator compatibility layer in the future. -from typing import Optional +This file implements support for the linear layer operators. +""" + +from typing import Optional, Tuple import torch -from torch import nn -from torch.distributed import ProcessGroup +from torch.nn.functional import linear as _torch_linear_forward_op -from internlm.core.context import ParallelMode +from internlm.accelerator import AcceleratorType, get_accelerator from internlm.core.context import global_context as gpc -from internlm.model.utils import ( - all_reduce, - fused_dense_func, - isp_fused_dense_func, - megatron_fused_dense_func, - reduce_scatter, -) -from internlm.utils.logger import get_logger - -logger = get_logger(__file__) - - -class BaseScaleColumnParallelLinear(nn.Linear): - """ - Base class for ScaleColumnParallelLinear. - - Args: - in_features (int): size of each input sample - out_features (int): size of each output sample - process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`. - bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False - in the config. - sequence_parallel (bool): If sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: - we do an all_gather of x before doing the matmul. - If not, then the input is already gathered. - device (Optional[Union[str, torch.device]]): The device will be used. - dtype (Optional[torch.dtype]): The type of data. - weight_scale (int): For training stability. 1 by default. 
- """ - - def __init__( - self, - in_features: int, - out_features: int, - process_group: Optional[torch.distributed.ProcessGroup], - bias: bool = True, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - weight_scale: int = 1, - ) -> None: - world_size = torch.distributed.get_world_size(process_group) - if out_features % world_size != 0: - raise ValueError(f"out_features ({out_features}) must be divisible by " f"world_size ({world_size})") - super().__init__(in_features, out_features // world_size, bias=bias, device=device, dtype=dtype) - self.process_group = process_group - self.weight_scale = weight_scale - - -class ScaleColumnParallelLinear(BaseScaleColumnParallelLinear): - """ - ScaleColumnParallelLinear in flash implementation. - """ - - def forward(self, input, gather_dim=1, tp_mode: str = "mtp"): # pylint: disable=W0622 - # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: - # we do an all_gather of x before doing the matmul. - # If not, then the input is already gathered. - if self.weight_scale != 1: - weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach() - else: - weight = self.weight - - _fused_func = fused_dense_func if tp_mode in ["mtp", "fsp", "isp"] else megatron_fused_dense_func - return _fused_func( - input, - weight, - self.bias, - process_group=self.process_group, - sequence_parallel=gpc.config.parallel.sequence_parallel, - gather_dim=gather_dim, - ) - - -class ScaleColumnParallelLinearWithNormHead(BaseScaleColumnParallelLinear): - """ - ScaleColumnParallelLinear for InternLM2. - - Args: - in_features (int): size of each input sample - out_features (int): size of each output sample - process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`. - bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False - in the config. - device (Optional[Union[str, torch.device]]): The device will be used. - dtype (Optional[torch.dtype]): The type of data. - weight_scale (int): For training stability. 1 by default. - norm_head (bool): Normalize the output embedding in order to let the calculation of logits not affected by - the norm of embedding. The implementation is referred to baichuan2, - see https://huggingface.co/baichuan-inc/Baichuan2-7B-Base for more information. False by default. - """ - - def __init__( - self, - in_features: int, - out_features: int, - process_group: Optional[torch.distributed.ProcessGroup], - bias: bool = True, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - weight_scale: int = 1, - norm_head: bool = False, - ) -> None: - super().__init__( - in_features, out_features, process_group, bias=bias, device=device, dtype=dtype, weight_scale=weight_scale - ) - - self.norm_head = norm_head - if self.norm_head: - logger.info("Notice that norm head is enabled to normalize head weight.") - self.first_eval_flag = True - self.tmp_weight = None - - def forward(self, input, gather_dim=1, tp_mode: str = "mtp"): # pylint: disable=W0622 - if self.weight_scale != 1: - weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach() - else: - weight = self.weight - if self.norm_head: - if self.training: - if not self.first_eval_flag: - self.first_eval_flag = True - self.tmp_weight = None - # We normalized the output Embedding so that the dot product - # is not affected by the norm of embedding. 
Ref: https://arxiv.org/pdf/2309.10305.pdf - weight = nn.functional.normalize(weight) - else: - if self.first_eval_flag: - # cache l2 norm of head to accelerate infer. - self.first_eval_flag = False - self.tmp_weight = nn.functional.normalize(weight) - - weight = self.tmp_weight - - _fused_func = fused_dense_func if tp_mode in ["mtp", "fsp", "isp"] else megatron_fused_dense_func - return _fused_func( - input, - weight, - self.bias, - process_group=self.process_group, - sequence_parallel=gpc.config.parallel.sequence_parallel, - gather_dim=gather_dim, - ) - - -class RewardModelLinear(BaseScaleColumnParallelLinear): - """ - RewardModelLinear. - Args: - in_features (int): size of each input sample - out_features (int): size of each output sample - process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`. - bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False - in the config. - sequence_parallel (bool): If sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: - we do an all_gather of x before doing the matmul. - If not, then the input is already gathered. - device (Optional[Union[str, torch.device]]): The device will be used. - dtype (Optional[torch.dtype]): The type of data. - weight_scale (int): For training stability. 1 by default. - """ - - def __init__( - self, - in_features: int, - out_features: int, - process_group: Optional[torch.distributed.ProcessGroup], - bias: bool = True, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - weight_scale: int = 1, - ) -> None: - super().__init__(in_features, out_features, process_group, bias, device, dtype, weight_scale) - torch.distributed.broadcast(self.weight, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], process_group) - if bias: - torch.distributed.broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], process_group) - - def forward(self, input): # pylint: disable=W0622 - # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: - # we do an all_gather of x before doing the matmul. - # If not, then the input is already gathered. - if self.weight_scale != 1: - weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach() - else: - weight = self.weight - return fused_dense_func( - input, - weight, - self.bias, - process_group=self.process_group, - sequence_parallel=gpc.config.parallel.sequence_parallel, - ) - - -class ColumnParallelLinearTorch(nn.Linear): - """ - ColumnParallelLinearTorch. - Args: - in_features (int): size of each input sample - out_features (int): size of each output sample - process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`. - bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False - in the config. - sequence_parallel (bool): If sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: - we do an all_gather of x before doing the matmul. - If not, then the input is already gathered. - device (Optional[Union[str, torch.device]]): The device will be used. - dtype (Optional[torch.dtype]): The type of data. - weight_scale (int): For training stability. 1 by default. 
- """ - - def __init__( - self, - in_features: int, - out_features: int, - process_group: ProcessGroup, - bias: bool = True, - sequence_parallel=True, - multiple_of=1, - device=None, - dtype=None, - ) -> None: - world_size = torch.distributed.get_world_size(process_group) - if out_features % multiple_of: - raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}") - multiple = out_features // multiple_of - # We want to split @multiple across world_size, but it could be an uneven split - div = multiple // world_size - mod = multiple % world_size - # The first @mod ranks get @div + 1 copies, the rest get @div copies - local_multiple = div + int(torch.distributed.get_rank(process_group) < mod) - super().__init__(in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype) - self.process_group = process_group - self.sequence_parallel = sequence_parallel - - def forward(self, x, gather_dim=1): - # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: - # we do an all_gather of x before doing the matmul. - # If not, then the input is already gathered. - return fused_dense_func( - x, - self.weight, - self.bias, - process_group=self.process_group, - sequence_parallel=self.sequence_parallel, - gather_dim=gather_dim, - ) - - -class MegatronColumnParallelLinearTorch(ColumnParallelLinearTorch): - """ - MegatronColumnParallelLinearTorch - """ - def forward(self, x, gather_dim=1): - # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: - # we do an all_gather of x before doing the matmul. - # If not, then the input is already gathered. - return megatron_fused_dense_func( - x, - self.weight, - self.bias, - process_group=self.process_group, - sequence_parallel=self.sequence_parallel, - gather_dim=gather_dim, - ) +try: + from fused_dense_lib import linear_bias_wgrad as _flash_linear_backward_op + flash_attn_impl = True +except (ModuleNotFoundError, ImportError): + flash_attn_impl = False -class RowParallelLinearTorch(nn.Linear): - """ - RowParallelLinearTorch. - Args: - in_features (int): size of each input sample - out_features (int): size of each output sample - process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`. - bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False - in the config. - sequence_parallel (bool): If sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: - we do an all_gather of x before doing the matmul. - If not, then the input is already gathered. - device (Optional[Union[str, torch.device]]): The device will be used. - dtype (Optional[torch.dtype]): The type of data. - weight_scale (int): For training stability. 1 by default. 
- """ +internlm_accelerator = get_accelerator() - def __init__( - self, - in_features: int, - out_features: int, - process_group: ProcessGroup, - bias: bool = True, - sequence_parallel=True, - multiple_of=1, - device=None, - dtype=None, - ) -> None: - world_size = torch.distributed.get_world_size(process_group) - rank = torch.distributed.get_rank(process_group) - if in_features % multiple_of: - raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}") - multiple = in_features // multiple_of - # We want to split @multiple across world_size, but it could be an uneven split - div = multiple // world_size - mod = multiple % world_size - # The first @mod ranks get @div + 1 copies, the rest get @div copies - local_multiple = div + int(torch.distributed.get_rank(process_group) < mod) - # Only rank 0 will have bias - super().__init__( - local_multiple * multiple_of, - out_features, - bias=bias and rank == 0, - device=device, - dtype=dtype, - ) - self.process_group = process_group - self.sequence_parallel = sequence_parallel - def forward(self, x, reduce_dim=1): - """ - We're doing Tensor Parallel with sequence parallelism: we do the matmul and then - a reduce_scatter of the result. - """ - out = fused_dense_func(x, self.weight, self.bias) - if self.sequence_parallel: - return reduce_scatter(out, self.process_group, reduce_dim) - else: - return all_reduce(out, self.process_group) +def _select_ops_binding(dtype: torch.dtype, is_cuda: bool = True) -> None: + dtype_eligible = dtype in (torch.float16, torch.bfloat16) or ( + dtype == torch.float32 and torch.is_autocast_enabled() + ) + use_flash_attn = gpc.config.model.get("use_flash_attn", False) + is_gpu_backend = internlm_accelerator.get_accelerator_backend() is AcceleratorType.GPU + flash_attn_eligible = flash_attn_impl and dtype_eligible and is_cuda + if use_flash_attn and is_gpu_backend and flash_attn_eligible: + return _torch_linear_forward_op, _flash_linear_backward_op + else: + return _torch_linear_forward_op, _linear_bias_wgrad_torch -class MegatronRowParallelLinearTorch(RowParallelLinearTorch): - """ - MegatronRowParallelLinearTorch. - """ - def forward(self, x, reduce_dim=1): - """ - We're doing Tensor Parallel with sequence parallelism: we do the matmul and then - a reduce_scatter of the result. - """ - out = megatron_fused_dense_func(x, self.weight, self.bias) - if self.sequence_parallel: - return reduce_scatter(out, self.process_group, reduce_dim) - else: - return all_reduce(out, self.process_group) +def _linear_bias_wgrad_torch(_input: torch.Tensor, grad_output: torch.Tensor, has_d_bias: bool): + assert _input.dtype == grad_output.dtype + grad_weight = torch.matmul(grad_output.t(), _input) + grad_bias = grad_output.sum(dim=0) if has_d_bias else None -class ISPLinear(ColumnParallelLinearTorch): - """ - Linear class for isp tensor parallel mode. - """ + return grad_weight, grad_bias - # class level communicator variable. - __communicator = None - @staticmethod - def register_communicator(communicator): - ISPLinear.__communicator = communicator +def linear_forward_op(_input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: + _is_cuda = internlm_accelerator.get_accelerator_backend() is AcceleratorType.GPU + _forward_op, _ = _select_ops_binding(_input.dtype, _is_cuda) - def forward(self, x): - assert self.__communicator is not None, "ISPLinear should be register with a communicator first." 
+ return _forward_op(_input, weight, bias) - return isp_fused_dense_func( - x, - self.weight, - module=self, - communicator=self.__communicator, - bias=self.bias, - ) +def linear_backward_op( + _input: torch.Tensor, weight: torch.Tensor, has_d_bias: bool +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + _is_cuda = internlm_accelerator.get_accelerator_backend() is AcceleratorType.GPU + _, _backward_op = _select_ops_binding(_input.dtype, _is_cuda) -def get_linear_cls(tp_mode: str, parallel_mode: str): - if parallel_mode == "column": - if tp_mode in ["mtp", "fsp"]: - cls = ColumnParallelLinearTorch - elif tp_mode == "msp": - cls = MegatronColumnParallelLinearTorch - else: - cls = ISPLinear - elif parallel_mode == "row": - if tp_mode in ["mtp", "fsp"]: - cls = RowParallelLinearTorch - elif tp_mode == "msp": - cls = MegatronRowParallelLinearTorch - else: - cls = ISPLinear - return cls + return _backward_op(_input, weight, has_d_bias) diff --git a/internlm/model/ops/norm.py b/internlm/model/ops/norm.py index 6598e178..8ade10ca 100644 --- a/internlm/model/ops/norm.py +++ b/internlm/model/ops/norm.py @@ -6,8 +6,29 @@ from torch.nn import init from torch.nn.parameter import Parameter +from internlm.accelerator import AcceleratorType, get_accelerator +from internlm.utils.logger import get_logger -def manual_rms_norm(my_input, normalized_shape, weight, eps): +logger = get_logger(__file__) +internlm_accelerator = get_accelerator() + +try: + from apex.normalization.fused_layer_norm import mixed_dtype_fused_rms_norm_affine + + apex_rmsnorm_impl = True +except (ModuleNotFoundError, ImportError): + logger.warning("The torch implementation for MixFusedRMSNorm is slower than apex. Please note this!") + apex_rmsnorm_impl = False + +try: + from deeplink_ext.internevo_ops import MixedFusedRMSNorm + + deeplink_rmsnorm_impl = True +except (ModuleNotFoundError, ImportError): + deeplink_rmsnorm_impl = False + + +def manual_rms_norm(my_input, weight, normalized_shape, eps): # layer norm should always be calculated in float32 dims = tuple(i for i in range(-1, -len(normalized_shape) - 1, -1)) variance = my_input.to(torch.float32).pow(2).mean(dims, keepdim=True) @@ -23,8 +44,8 @@ def manual_rms_norm(my_input, normalized_shape, weight, eps): return weight * my_input -class RMSNormTorch(torch.nn.Module): - """A custom PyTorch module for RMS normalization.""" +class _RMSNorm(torch.nn.Module): + """A generic module for RMS normalization.""" def __init__(self, normalized_shape, eps=1e-5): super().__init__() @@ -37,10 +58,23 @@ def __init__(self, normalized_shape, eps=1e-5): self.reset_parameters() def forward(self, _input: torch.Tensor): - return manual_rms_norm(_input, self.normalized_shape, self.weight, self.eps) + if apex_rmsnorm_impl: + _norm_func = mixed_dtype_fused_rms_norm_affine + else: + _norm_func = manual_rms_norm + + return _norm_func(_input, self.weight, self.normalized_shape, self.eps) def reset_parameters(self): init.ones_(self.weight) def extra_repr(self): - return "{normalized_shape}, eps={eps}, ".format(**self.__dict__) + return f"{self.normalized_shape}, eps={self.eps}, " + + +# TODO: Support deeplink in a more unified manner +RMSNorm = ( + MixedFusedRMSNorm + if internlm_accelerator.get_accelerator_backend() == AcceleratorType.DIPU and deeplink_rmsnorm_impl + else _RMSNorm +) diff --git a/internlm/model/ops/rotary_emb.py b/internlm/model/ops/rotary_emb.py new file mode 100644 index 00000000..58f142c7 --- /dev/null +++ b/internlm/model/ops/rotary_emb.py @@ -0,0 +1,158 @@ +""" +A simple operator 
selector, used for compatibility with different platforms such as CUDA and Ascend, +as well as whether to enable flash-attn operator optimization, may be replaced by a more comprehensive +operator compatibility layer in the future. + +This file implements support for the roatry embedding operators. +""" + +import torch +from einops import rearrange + +from internlm.accelerator import AcceleratorType, get_accelerator +from internlm.core.context import global_context as gpc + +try: + from rotary_emb import apply_rotary as _flash_apply_rotary_func + + flash_rotary_impl = True +except (ModuleNotFoundError, ImportError): + flash_rotary_impl = False + +try: + from deeplink_ext.internlm_ops import ApplyRotaryEmb as DeeplinkApplyRotaryEmb + + deeplink_rotary_impl = True +except (ModuleNotFoundError, ImportError): + deeplink_rotary_impl = False + +internlm_accelerator = get_accelerator() + + +def _torch_apply_rotary_func( + x1: torch.Tensor, + x2: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + out1: torch.Tensor, + out2: torch.Tensor, + conj: bool = False, +): + # TODO: improve perfermance. + assert x1.device == x2.device == cos.device == sin.device, "All inputs must be on the same device" + assert x1.dtype == x2.dtype == cos.dtype == sin.dtype, "All inputs must have the same dtype" + assert x1.size() == x2.size(), "Input x1 and x2 must have the same sizes" + assert cos.size() == sin.size(), "Input cos and sin must have the same sizes" + + x1, x2, cos, sin = x1.float(), x2.float(), cos.float(), sin.float() + + if conj: + out1.copy_(x1 * cos + x2 * sin) + out2.copy_(-x1 * sin + x2 * cos) + else: + out1.copy_(x1 * cos - x2 * sin) + out2.copy_(x1 * sin + x2 * cos) + + return out1, out2 + + +def _select_apply_rotary_func( + x1: torch.Tensor, + x2: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + out1: torch.Tensor, + out2: torch.Tensor, + conj: bool = False, +): + if gpc.config.model.get("use_flash_attn", False) and flash_rotary_impl: + _flash_apply_rotary_func(x1, x2, cos, sin, out1, out2, conj) + else: + _torch_apply_rotary_func(x1, x2, cos, sin, out1, out2, conj) + + +# adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py#L35 +class ApplyRotaryEmb(torch.autograd.Function): + """ + ApplyRotaryEmb + """ + + @staticmethod + def forward( + ctx, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False, in_place: bool = False + ): + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + rotary_dim must be <= headdim + Apply rotary embedding to the first rotary_dim of x. 
+ """ + *_, seqlen, _, head_dim = x.shape + rotary_seqlen, rotary_dim = cos.shape + rotary_dim *= 2 + + assert rotary_dim <= head_dim + assert seqlen <= rotary_seqlen + assert sin.shape == (rotary_seqlen, rotary_dim // 2) + + x_ro = x[..., :rotary_dim] + x1, x2 = (x_ro[..., ::2], x_ro[..., 1::2]) if interleaved else x_ro.chunk(2, dim=-1) + + if in_place: + out, o1, o2 = x, x1, x2 + else: + out = torch.empty_like(x) + out_ro = out[..., :rotary_dim] + o1, o2 = (out_ro[..., ::2], out_ro[..., 1::2]) if interleaved else out_ro.chunk(2, dim=-1) + + _select_apply_rotary_func( + x1, x2, rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d"), o1, o2, False + ) + + if rotary_dim < head_dim and not in_place: + out[..., rotary_dim:].copy_(x[..., rotary_dim:]) + + ctx.save_for_backward(cos, sin) + ctx.interleaved = interleaved + ctx.in_place = in_place + + return out + + @staticmethod + def backward(ctx, do): + cos, sin = ctx.saved_tensors + *_, seqlen, _, head_dim = do.shape + rotary_dim = cos.shape[-1] + rotary_dim *= 2 + + do_ro = do[..., :rotary_dim] + do1, do2 = (do_ro[..., ::2], do_ro[..., 1::2]) if ctx.interleaved else do_ro.chunk(2, dim=-1) + + if ctx.in_place: + dx, dx1, dx2 = do, do1, do2 + else: + dx = torch.empty_like(do) + dx_ro = dx[..., :rotary_dim] + dx1, dx2 = (dx_ro[..., ::2], dx_ro[..., 1::2]) if ctx.interleaved else dx_ro.chunk(2, dim=-1) + + _select_apply_rotary_func( + do1, do2, rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d"), dx1, dx2, True + ) + + if rotary_dim < head_dim and not ctx.in_place: + dx[..., rotary_dim:].copy_(do[..., rotary_dim:]) + + return dx, None, None, None, None + + +def apply_rotary_emb( + x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False, in_place: bool = False +): + # TODO: Support deeplink in a more unified manner + if internlm_accelerator.get_accelerator_backend() == AcceleratorType.DIPU: + # TODO: to support in_place argument + return DeeplinkApplyRotaryEmb.apply(x, cos, sin, interleaved) + else: + return ApplyRotaryEmb.apply(x, cos, sin, interleaved, in_place) diff --git a/internlm/model/ops/utils.py b/internlm/model/ops/utils.py new file mode 100644 index 00000000..04d068cd --- /dev/null +++ b/internlm/model/ops/utils.py @@ -0,0 +1,48 @@ +""" +Some hepler functions for ops package. 
+""" + +import torch +from torch.nn.utils.rnn import pad_sequence + + +def unpack_qkv_before_attn(cur_input: torch.Tensor, cu_seqlens: torch.Tensor, padding_v: int = 0): + """ + qkv: the shape is (1, packed_length, three, head_num, head_dim) + kv: the shape is (1, packed_length, two, head_num, head_dim) + q/k/v: the shape is (1, packed_length, head_num, head_dim) + + Return: + output: the shape is (micro_bsz, seq_len, three, head_num, head_dim) for qkv + (micro_bsz, seq_len, two, head_num, head_dim) for kv + (micro_bsz, seq_len, head_num, head_dim) for q/k/v + """ + assert cur_input.shape[0] == 1 + cur_input = cur_input.squeeze(0) + + sequences = [] + for i in range(len(cu_seqlens) - 1): + sequences.append(cur_input[cu_seqlens[i] : cu_seqlens[i + 1]]) + + padded_sequences = pad_sequence(sequences, batch_first=True, padding_value=padding_v) + + return padded_sequences + + +def pack_output_after_attn(cur_input: torch.Tensor, cu_seqlens: torch.Tensor, packed_length: int, padding_v: int = 0): + """ + cur_input: the shape is (micro_bsz, seq_len, head_num, head_dim) + + Return: + output: the shape is (1, packed_length, head_num, head_dim) + """ + output_shape = list(cur_input.shape) + output_shape[0] = 1 + output_shape[1] = packed_length + + output = torch.full(output_shape, fill_value=padding_v, device=cur_input.device, dtype=cur_input.dtype) + for i in range(len(cu_seqlens) - 1): + length = cu_seqlens[i + 1] - cu_seqlens[i] + output[0, cu_seqlens[i] : cu_seqlens[i + 1]] = cur_input[i, 0:length] + + return output diff --git a/internlm/utils/registry.py b/internlm/model/registry.py similarity index 71% rename from internlm/utils/registry.py rename to internlm/model/registry.py index 3ac14452..e91a2255 100644 --- a/internlm/utils/registry.py +++ b/internlm/model/registry.py @@ -1,6 +1,14 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +from typing import Callable + +from internlm.model.modeling_internlm import InternLM1 +from internlm.model.modeling_internlm2 import InternLM2 +from internlm.model.modeling_llama import Llama2 +from internlm.model.modeling_llava import Llava +from internlm.model.modeling_moe import Internlm1MoE + class Registry: """This is a registry class used to register classes and modules so that a universal @@ -12,13 +20,13 @@ class Registry: def __init__(self, name: str): self._name = name - self._registry = dict() + self._registry = {} @property def name(self): return self._name - def register_module(self, module_name: str): + def register_module(self, module_name: str, func: Callable): """Registers a module represented in `module_class`. 
Args: @@ -31,11 +39,7 @@ def register_module(self, module_name: str): assert module_name not in self._registry, f"{module_name} already registered in {self.name}" - def decorator_wrapper(original_func): - self._registry[module_name] = original_func - return original_func - - return decorator_wrapper + self._registry[module_name] = func def get_module(self, module_name: str): """Retrieves a module with name `module_name` and returns the module if it has @@ -68,4 +72,12 @@ def has(self, module_name: str): return found_flag -MODEL_INITIALIZER = Registry("model_initializer") +model_initializer = Registry("model_initializer") + + +def register_model_initializer() -> None: + model_initializer.register_module("INTERNLM", InternLM1) + model_initializer.register_module("INTERNLM2_PUBLIC", InternLM2) + model_initializer.register_module("LLAMA2", Llama2) + model_initializer.register_module("INTERNLM_MoE", Internlm1MoE) + model_initializer.register_module("LLAVA", Llava) diff --git a/internlm/model/utils.py b/internlm/model/utils.py index 7fef0f93..f63b2cde 100644 --- a/internlm/model/utils.py +++ b/internlm/model/utils.py @@ -1,715 +1,22 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- +from typing import Dict -from typing import Callable, Optional +from internlm.model.modules.mha import MHA -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor -from torch.distributed import ProcessGroup -from torch.nn.utils.rnn import pad_sequence -from internlm.accelerator import get_accelerator -from internlm.core.context import global_context as gpc -from internlm.model.ops.fusion_ops_import_helper import try_import_linear_bias_wgrad -from internlm.utils.logger import get_logger +def internlm1_mha_pre_load_convert( + model: MHA, state_dict: Dict, prefix: str, *args, **kwargs # pylint: disable=W0613 +) -> None: + if f"{prefix}wqkv.weight" not in state_dict and f"{prefix}Wqkv.weight" in state_dict: + state_dict[f"{prefix}wqkv.weight"] = state_dict.pop(f"{prefix}Wqkv.weight") -internlm_accelerator = get_accelerator() + if f"{prefix}wqkv.bias" not in state_dict and f"{prefix}Wqkv.bias" in state_dict: + state_dict[f"{prefix}wqkv.bias"] = state_dict.pop(f"{prefix}Wqkv.bias") -custom_bwd = internlm_accelerator.return_custom_bwd() -custom_fwd = internlm_accelerator.return_custom_fwd() -logger = get_logger(__file__) +def internlm1_mha_save_convert( + model: MHA, state_dict: Dict, prefix: str, *args, **kwargs # pylint: disable=W0613 +) -> None: + state_dict[f"{prefix}Wqkv.weight"] = state_dict.pop(f"{prefix}wqkv.weight") - -def linear_bias_wgrad_torch(my_input, grad_output, has_d_bias): - assert my_input.dtype == grad_output.dtype - grad_weight = torch.matmul(grad_output.t(), my_input) - grad_bias = grad_output.sum(dim=0) if has_d_bias else None - return grad_weight, grad_bias - - -linear_bias_wgrad = try_import_linear_bias_wgrad() -is_using_cuda_linear_bias_wgrad = True -if linear_bias_wgrad is None: - linear_bias_wgrad = linear_bias_wgrad_torch - is_using_cuda_linear_bias_wgrad = False - - -# Raw operation, does not support autograd, but does support async -def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): - input_ = input_.contiguous() - handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op) - return input_, handle - - -class ReduceScatterFunc(torch.autograd.Function): - """Reduce scatter the input from the sequence parallel region and concatenate.""" - - @staticmethod - def forward(ctx, input_: 
Tensor, process_group: ProcessGroup, reduce_dim: int = 0) -> Tensor: - ctx.process_group = process_group - ctx.reduce_dim = reduce_dim - output, _ = reduce_scatter_raw(input_, process_group, reduce_dim=reduce_dim) - return output - - @staticmethod - def backward(ctx, grad_output: Tensor): - gather_dim = ctx.reduce_dim - grad_input, _ = all_gather_raw(grad_output, ctx.process_group, gather_dim=gather_dim) - return grad_input, None, None - - -# Supports autograd, but does not support async -reduce_scatter = ReduceScatterFunc.apply - - -class AllReduceFunc(torch.autograd.Function): - """Gather the input from sequence parallel region and concatenate.""" - - @staticmethod - def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor: - ctx.process_group = process_group - output, _ = all_reduce_raw(input_, process_group) - return output - - @staticmethod - def backward(ctx, grad_output: Tensor): - _ = ctx # avoid lint warning W0613 - return grad_output, None - - -# Supports autograd, but does not support async -all_reduce = AllReduceFunc.apply - - -def _split(input_, parallel_mode, dim=-1): - # skip if only one rank involved - world_size = gpc.get_world_size(parallel_mode) - if world_size == 1: - return input_ - - # Split along last dimension. - dim_size = input_.size(dim) - assert dim_size % world_size == 0, ( - f"The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), " - f"cannot split tensor evenly" - ) - - tensor_list = torch.split(input_, dim_size // world_size, dim=dim) - rank = gpc.get_local_rank(parallel_mode) - output = tensor_list[rank].contiguous() - output = output.detach().clone() - - return output - - -def _gather(input_, parallel_mode, dim=-1): - # skip if only one rank involved - world_size = gpc.get_world_size(parallel_mode) - if world_size == 1: - return input_ - - # all gather - rank = gpc.get_local_rank(parallel_mode) - tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - tensor_list[rank] = input_ - group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode) - dist.all_gather(tensor_list, input_, group=group) - - # concat - output = torch.cat(tensor_list, dim=dim).contiguous() - - return output - - -class _GatherForwardSplitBackward(torch.autograd.Function): - """Gather the input from model parallel region and concatenate. - - Args: - input_: input matrix. - parallel_mode: parallel mode. - dim: dimension - """ - - @staticmethod - def symbolic(input_): - return _gather(input_, parallel_mode=None) - - @staticmethod - def forward(ctx, input_, parallel_mode, dim): - ctx.mode = parallel_mode - ctx.dim = dim - return _gather(input_, parallel_mode, dim) - - @staticmethod - def backward(ctx, grad_output): - return _split(grad_output, ctx.mode, ctx.dim), None, None - - -def gather_forward_split_backward(input_, parallel_mode, dim): - return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim) - - -class _SplitForwardGatherBackward(torch.autograd.Function): - """ - Split the input and keep only the corresponding chuck to the rank. - - Args: - input_: input matrix. - parallel_mode: parallel mode. 
- dim: dimension - """ - - @staticmethod - def symbolic(input_): - return _split(input_, parallel_mode=None) - - @staticmethod - def forward(ctx, input_, parallel_mode, dim): - ctx.mode = parallel_mode - ctx.dim = dim - return _split(input_, parallel_mode, dim) - - @staticmethod - def backward(ctx, grad_output): - return _gather(grad_output, ctx.mode, ctx.dim), None, None - - -def split_forward_gather_backward(input_, parallel_mode, dim): - return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim) - - -def all_gather_raw( - input_: Tensor, - process_group: ProcessGroup, - async_op: bool = False, - gather_dim: int = 0, - memory_pool_allocator: Callable = None, -): - world_size = dist.get_world_size(process_group) - if world_size <= 1: - return input_, None - - if memory_pool_allocator is not None: - output = memory_pool_allocator() - else: - shape = list(input_.shape) - shape[gather_dim] = shape[gather_dim] * world_size - output = torch.empty(shape, dtype=input_.dtype, device=input_.device) - - handle = dist.all_gather_into_tensor(output, input_.contiguous(), group=process_group, async_op=async_op) - return output, handle - - -def reduce_scatter_raw( - input_: Tensor, - process_group: ProcessGroup, - op=dist.ReduceOp.SUM, - async_op: bool = False, - reduce_dim: int = 0, - memory_pool_allocator: Callable = None, -): - world_size = dist.get_world_size(process_group) - assert input_.shape[reduce_dim] % world_size == 0 - - if world_size <= 1: - return input_, None - - shape_list = list(input_.shape) - shape_list[reduce_dim] = shape_list[reduce_dim] // world_size - - if memory_pool_allocator is not None: - output = memory_pool_allocator(tuple(shape_list)) - else: - output = torch.empty( - shape_list, - dtype=input_.dtype, - device=input_.device, - ).contiguous() - - handle = dist.reduce_scatter_tensor(output, input_.contiguous(), op=op, group=process_group, async_op=async_op) - return output, handle - - -# adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py -class FusedDenseFunc(torch.autograd.Function): - "FusedDenseFunc for tensor parallel in flash-attn implementation." - - @staticmethod - @custom_fwd - def forward( - ctx, - x, - weight, - bias, - return_residual=False, - process_group=None, - sequence_parallel=True, - gather_dim=0, - dtype_eligible: bool = True, - ): - """ - If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel - with sequence parallelism: we do an all_gather_raw of x before doing the matmul. 
- """ - ctx.compute_weight_gradient = weight.requires_grad - ctx.return_residual = return_residual - ctx.process_group = process_group - ctx.sequence_parallel = sequence_parallel - ctx.gather_dim = gather_dim - ctx.dtype_eligible = dtype_eligible - - if ctx.dtype_eligible is False: - global linear_bias_wgrad, is_using_cuda_linear_bias_wgrad - linear_bias_wgrad = linear_bias_wgrad_torch - is_using_cuda_linear_bias_wgrad = False - - if torch.is_autocast_enabled(): - x = x.to(dtype=torch.get_autocast_gpu_dtype()) - x = x.contiguous() - if process_group is not None and sequence_parallel: - # We want to kick off the all_gather early, before weight dtype conversion - total_x, handle_x = all_gather_raw(x, process_group, async_op=True, gather_dim=gather_dim) - else: - total_x = x - - if torch.is_autocast_enabled(): - weight = weight.to(dtype=torch.get_autocast_gpu_dtype()) - bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None - weight = weight.contiguous() - if process_group is not None and sequence_parallel and handle_x is not None: - handle_x.wait() - batch_shape, n = total_x.shape[:-1], total_x.shape[-1] - batch_dim = batch_shape.numel() - # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 - if min(batch_dim, n, *weight.shape) > 65535 * 32: - raise RuntimeError("fused_dense only supports matrix dims <= 2M") - output = F.linear(total_x, weight, bias) # pylint: disable=E1102 - if ctx.compute_weight_gradient: - ctx.save_for_backward(x, weight) - else: - ctx.save_for_backward(weight) - return output if not return_residual else (output, x) - - @staticmethod - @custom_bwd - def backward(ctx, grad_output, *args): - grad_output = grad_output.contiguous() - if ctx.return_residual: - (grad_input,) = args - grad_input = grad_input.contiguous() - process_group = ctx.process_group - sequence_parallel = ctx.sequence_parallel - gather_dim = ctx.gather_dim - if ctx.compute_weight_gradient: - x, weight = ctx.saved_tensors - if process_group is not None and sequence_parallel: - total_x, handle_x = all_gather_raw(x, process_group, async_op=True, gather_dim=gather_dim) - else: - total_x = x - else: - (weight,) = ctx.saved_tensors - total_x = None - batch_shape = grad_output.shape[:-1] - batch_dim = batch_shape.numel() - grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) - if ctx.needs_input_grad[0]: - if not ctx.return_residual: - grad_input = F.linear(grad_output, weight.t()) # pylint: disable=E1102 - else: - grad_input = torch.addmm( - grad_input.reshape(batch_dim, grad_input.shape[-1]), - grad_output, - weight, - ) - grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) - if process_group is not None: - if sequence_parallel: - grad_input, handle_grad_input = reduce_scatter_raw( - grad_input, process_group, async_op=True, reduce_dim=1 - ) - else: - grad_input, handle_grad_input = all_reduce_raw(grad_input, process_group, async_op=True) - else: - grad_input = None - if ctx.needs_input_grad[1]: - assert ctx.compute_weight_gradient - if process_group is not None and sequence_parallel and handle_x is not None: - handle_x.wait() - grad_weight, grad_bias = linear_bias_wgrad( - total_x.reshape(batch_dim, total_x.shape[-1]), - grad_output, - ctx.needs_input_grad[2], - ) - else: - grad_weight = None - grad_bias = grad_output if ctx.needs_input_grad[2] else None - if process_group is not None and ctx.needs_input_grad[0] and handle_grad_input is not None: - handle_grad_input.wait() - return 
grad_input, grad_weight, grad_bias, None, None, None, None, None - - -class MegatronFusedDenseFunc(torch.autograd.Function): - """ - FusedDenseFunc for tensor parallel in megatron implementation. - The diffenrence between the implementation of flash-attn and megatron is that the total_x could be - saved for backward in megatron, so that the all-gather in backward is ommited. - """ - - @staticmethod - @custom_fwd - def forward( - ctx, - x, - weight, - bias, - return_residual=False, - process_group=None, - sequence_parallel=True, - gather_dim=0, - dtype_eligible: bool = True, - ): - """ - If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel - with sequence parallelism: we do an all_gather_raw of x before doing the matmul. - """ - ctx.compute_weight_gradient = weight.requires_grad - ctx.return_residual = return_residual - ctx.process_group = process_group - ctx.sequence_parallel = sequence_parallel - ctx.dtype_eligible = dtype_eligible - - if ctx.dtype_eligible is False: - global linear_bias_wgrad, is_using_cuda_linear_bias_wgrad - linear_bias_wgrad = linear_bias_wgrad_torch - is_using_cuda_linear_bias_wgrad = False - - if torch.is_autocast_enabled(): - x = x.to(dtype=torch.get_autocast_gpu_dtype()) - x = x.contiguous() - if process_group is not None and sequence_parallel: - # We want to kick off the all_gather early, before weight dtype conversion - total_x, handle_x = all_gather_raw(x, process_group, async_op=True, gather_dim=gather_dim) - else: - total_x = x - - if torch.is_autocast_enabled(): - weight = weight.to(dtype=torch.get_autocast_gpu_dtype()) - bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None - weight = weight.contiguous() - if process_group is not None and sequence_parallel and handle_x is not None: - handle_x.wait() - batch_shape, n = total_x.shape[:-1], total_x.shape[-1] - batch_dim = batch_shape.numel() - # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 - if min(batch_dim, n, *weight.shape) > 65535 * 32: - raise RuntimeError("fused_dense only supports matrix dims <= 2M") - output = F.linear(total_x, weight, bias) # pylint: disable=E1102 - if ctx.compute_weight_gradient: - ctx.save_for_backward(total_x, weight) - else: - ctx.save_for_backward(weight) - return output if not return_residual else (output, x) - - @staticmethod - @custom_bwd - def backward(ctx, grad_output, *args): - grad_output = grad_output.contiguous() - if ctx.return_residual: - (grad_input,) = args - grad_input = grad_input.contiguous() - process_group = ctx.process_group - sequence_parallel = ctx.sequence_parallel - if ctx.compute_weight_gradient: - total_x, weight = ctx.saved_tensors - else: - (weight,) = ctx.saved_tensors - total_x = None - batch_shape = grad_output.shape[:-1] - batch_dim = batch_shape.numel() - grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) - if ctx.needs_input_grad[0]: - if not ctx.return_residual: - grad_input = F.linear(grad_output, weight.t()) # pylint: disable=E1102 - else: - grad_input = torch.addmm( - grad_input.reshape(batch_dim, grad_input.shape[-1]), - grad_output, - weight, - ) - grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) - if process_group is not None: - if sequence_parallel: - grad_input, handle_grad_input = reduce_scatter_raw( - grad_input, process_group, async_op=True, reduce_dim=1 - ) - else: - grad_input, handle_grad_input = all_reduce_raw(grad_input, process_group, async_op=True) - else: - 
grad_input = None - if ctx.needs_input_grad[1]: - assert ctx.compute_weight_gradient - grad_weight, grad_bias = linear_bias_wgrad( - total_x.reshape(batch_dim, total_x.shape[-1]), - grad_output, - ctx.needs_input_grad[2], - ) - else: - grad_weight = None - grad_bias = grad_output if ctx.needs_input_grad[2] else None - if process_group is not None and ctx.needs_input_grad[0] and handle_grad_input is not None: - handle_grad_input.wait() - return grad_input, grad_weight, grad_bias, None, None, None, None, None - - -class ISPFusedDenseFunc(torch.autograd.Function): - "FusedDenseFunc for ISP, which is optimized based on flash implementation." - - @staticmethod - @custom_fwd - def forward( - ctx, - x, - weight, - bias, - module, - communicator, - return_residual=False, - dtype_eligible: bool = True, - ): - ctx.compute_weight_gradient = weight.requires_grad - ctx.return_residual = return_residual - ctx.module = module - ctx.communicator = communicator - ctx.dtype_eligible = dtype_eligible - - if ctx.dtype_eligible is False: - global linear_bias_wgrad, is_using_cuda_linear_bias_wgrad - linear_bias_wgrad = linear_bias_wgrad_torch - is_using_cuda_linear_bias_wgrad = False - - if torch.is_autocast_enabled(): - x = x.to(dtype=torch.get_autocast_gpu_dtype()) - x = x.contiguous() - - total_weight = communicator.all_gather(weight, module) - total_bias = bias if bias is None else communicator.all_gather(bias, module, is_bias=True) - - if torch.is_autocast_enabled(): - total_weight = total_weight.to(dtype=torch.get_autocast_gpu_dtype()) - if total_bias: - total_bias.to(dtype=torch.get_autocast_gpu_dtype()) - - total_weight = total_weight.contiguous() - batch_shape, n = x.shape[:-1], x.shape[-1] - batch_dim = batch_shape.numel() - # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 - if min(batch_dim, n, *total_weight.shape) > 65535 * 32: - raise RuntimeError("fused_dense only supports matrix dims <= 2M") - - output = F.linear(x, total_weight, total_bias) # pylint: disable=E1102 - - # release memory - del total_weight - del total_bias - if ctx.compute_weight_gradient: - ctx.save_for_backward(x, weight) - else: - ctx.save_for_backward(weight) - return output if not return_residual else (output, x) - - @staticmethod - @custom_bwd - def backward(ctx, grad_output, *args): - module = ctx.module - communicator = ctx.communicator - grad_output = grad_output.contiguous() - if ctx.return_residual: - (grad_input,) = args - grad_input = grad_input.contiguous() - - if ctx.compute_weight_gradient: - x, weight = ctx.saved_tensors - else: - x, weight = (None, *ctx.saved_tensors) - - batch_shape = grad_output.shape[:-1] - batch_dim = batch_shape.numel() - grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) - - total_weight = communicator.all_gather(weight, module) - - # compute weight grad - if ctx.needs_input_grad[1]: - assert ctx.compute_weight_gradient - grad_weight, grad_bias = linear_bias_wgrad( - x.reshape(batch_dim, x.shape[-1]), - grad_output, - ctx.needs_input_grad[2], - ) - - grad_weight, grad_weight_sync = communicator.reduce_scatter(grad_weight, module, op=dist.ReduceOp.AVG) - if grad_bias is not None: - grad_bias, grad_bias_sync = communicator.reduce_scatter( - grad_bias, module, op=dist.ReduceOp.AVG, is_bias=True - ) - else: - grad_weight = None - grad_bias = grad_output if ctx.needs_input_grad[2] else None - - if ctx.needs_input_grad[0]: - if not ctx.return_residual: - grad_input = F.linear(grad_output, total_weight.t()) # 
pylint: disable=E1102 - else: - grad_input = torch.addmm( - grad_input.reshape(batch_dim, grad_input.shape[-1]), - grad_output, - total_weight, - ) - grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) - else: - grad_input = None - - del total_weight - - if ctx.needs_input_grad[1]: - if grad_weight_sync: - grad_weight_sync.wait() - if grad_bias is not None and grad_bias_sync is not None: - grad_bias_sync.wait() - - return grad_input, grad_weight, grad_bias, None, None, None, None - - -def fused_dense_func( - x: Tensor, - weight: Tensor, - bias: Optional[Tensor] = None, - return_residual: bool = False, - process_group: Optional[ProcessGroup] = None, - sequence_parallel: bool = True, - gather_dim: int = 0, -): - dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or ( - x.dtype == torch.float32 and torch.is_autocast_enabled() - ) - return FusedDenseFunc.apply( - x, - weight, - bias, - return_residual, - process_group, - sequence_parallel, - gather_dim, - dtype_eligible, - ) - - -def megatron_fused_dense_func( - x: Tensor, - weight: Tensor, - bias: Optional[Tensor] = None, - return_residual: bool = False, - process_group: Optional[ProcessGroup] = None, - sequence_parallel: bool = True, - gather_dim: int = 0, -): - dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or ( - x.dtype == torch.float32 and torch.is_autocast_enabled() - ) - return MegatronFusedDenseFunc.apply( - x, - weight, - bias, - return_residual, - process_group, - sequence_parallel, - gather_dim, - dtype_eligible, - ) - - -def isp_fused_dense_func( - x: Tensor, - weight: Tensor, - module, - communicator, - bias: Optional[Tensor] = None, - return_residual: bool = False, -): - dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or ( - x.dtype == torch.float32 and torch.is_autocast_enabled() - ) - return ISPFusedDenseFunc.apply( - x, - weight, - bias, - module, - communicator, - return_residual, - dtype_eligible, - ) - - -def is_moe_param(param: torch.Tensor) -> bool: - if hasattr(param, "is_expert") and param.is_expert: - return True - return False - - -def Silu(w1_o, w2_o): - return F.silu(w1_o) * w2_o - - -Silu = torch.jit.script(Silu) - - -def unpack_qkv_before_attn(cur_input=None, cu_seqlens=None, padding_v: int = 0): - """ - qkv: the shape is (1, packed_length, three, head_num, head_dim) - kv: the shape is (1, packed_length, two, head_num, head_dim) - q/k/v: the shape is (1, packed_length, head_num, head_dim) - - Return: - output: the shape is (micro_bsz, seq_len, three, head_num, head_dim) for qkv - (micro_bsz, seq_len, two, head_num, head_dim) for kv - (micro_bsz, seq_len, head_num, head_dim) for q/k/v - """ - if cu_seqlens is None or cur_input is None: - raise ValueError("cu_seqlens and cur_input must be provided.") - - assert cur_input.shape[0] == 1 - cur_input = cur_input.squeeze(0) - - sequences = [] - for i in range(len(cu_seqlens) - 1): - sequences.append(cur_input[cu_seqlens[i] : cu_seqlens[i + 1]]) - - padded_sequences = pad_sequence(sequences, batch_first=True, padding_value=padding_v) - - return padded_sequences - - -def pack_output_after_attn(cur_input=None, cu_seqlens=None, padding_v: int = 0): - """ - cur_input: the shape is (micro_bsz, seq_len, hidden_size) - - Return: - output: the shape is (1, packed_length, hidden_size) - """ - if cu_seqlens is None or cur_input is None: - raise ValueError("cu_seqlens and cur_input must be provided.") - - packed_len_ = gpc.config.data.micro_bsz * gpc.config.data.seq_len - output_shape = list(cur_input.shape) - output_shape[0] = 1 - 
output_shape[1] = packed_len_ - - output = torch.full(output_shape, fill_value=padding_v, device=cur_input.device, dtype=cur_input.dtype) - for i in range(len(cu_seqlens) - 1): - length = cu_seqlens[i + 1] - cu_seqlens[i] - output[0, cu_seqlens[i] : cu_seqlens[i + 1]] = cur_input[i, 0:length] - - return output + if f"{prefix}wqkv.bias" in state_dict: + state_dict[f"{prefix}Wqkv.bias"] = state_dict.pop(f"{prefix}wqkv.bias") diff --git a/internlm/solver/optimizer/compatible_adamw.py b/internlm/solver/optimizer/compatible_adamw.py new file mode 100644 index 00000000..bca8c274 --- /dev/null +++ b/internlm/solver/optimizer/compatible_adamw.py @@ -0,0 +1,52 @@ +from typing import Tuple + +import torch + +from internlm.accelerator import AcceleratorType, get_accelerator +from internlm.core.context import global_context as gpc +from internlm.utils.logger import get_logger + +logger = get_logger(__file__) +internlm_accelerator = get_accelerator() + +try: + from torch_npu.optim import NpuFusedAdamW + + del NpuFusedAdamW + + npu_adamw_impl = True +except (ModuleNotFoundError, ImportError): + npu_adamw_impl = False + + +# TODO: 给上次一个统一的接口,这些接口都能被下层的各种实现支持,哪些参数应该保留,那些参数应该省略? +def new_compatible_adamw(params, lr: float = 0.001, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-8): + """ + return a compatibel adamw instance. + """ + adam_extra_kwargs = {} + backend = internlm_accelerator.get_accelerator_backend() + + if backend is AcceleratorType.GPU and torch.__version__ >= "2.1.0": + if gpc.is_rank_for_log(): + logger.warning( + "Use fused AdamaW to avoid nan grad norm when " + "model size is larger and use_fp32_norm=True, Please note this!" + ) + adam_extra_kwargs["fused"] = True + elif backend is AcceleratorType.NPU: + if gpc.is_rank_for_log(): + logger.warning( + "Use normal AdamaW, NPU fused_adamw currently has" + "accuracy issues and is not supported yet. Please note this!" + ) + # TODO: support npu version adamw + elif backend is AcceleratorType.DIPU: + if gpc.is_rank_for_log(): + logger.warning("Use torch.optim.AdamW rather than deeplink adamw. Please note this!") + # TODO: support deeplink version adamw + else: + if gpc.is_rank_for_log(): + logger.warning("Use torch.optim.AdamW rather than FusedAdamW. 
Please note this!") + + return torch.optim.AdamW(params, lr=lr, betas=betas, eps=eps, **adam_extra_kwargs) diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index c4d36be2..1d26bd87 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -11,7 +11,6 @@ from torch.optim import Optimizer from internlm.accelerator import get_accelerator -from internlm.core.communication.utils import ParamAsyncBcastHandler from internlm.core.context import Config, ParallelMode from internlm.core.context import global_context as gpc from internlm.core.context.parallel_context import ( @@ -21,6 +20,7 @@ IS_TENSOR_ZERO_PARALLEL, IS_WEIGHT_ZERO_PARALLEL, ) +from internlm.core.parallel.comm.zero import ParamAsyncBcastHandler from internlm.monitor import send_alert_message from internlm.solver.optimizer.store import ( BucketStore, diff --git a/internlm/solver/pipeline_utils.py b/internlm/solver/pipeline_utils.py deleted file mode 100644 index c57765e4..00000000 --- a/internlm/solver/pipeline_utils.py +++ /dev/null @@ -1,34 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from internlm.utils.logger import get_logger - -logger = get_logger(__file__) - - -def partition_uniform(num_items, pipeline_parallel_size, num_chunks): - assert ( - num_items % num_chunks == 0 - ), "Layer length should be divided by the number of chunks, otherwise parameter method is recomended" - - parts = [[] for _ in range(pipeline_parallel_size)] - partition_items = num_items // num_chunks - for idx in range(num_chunks): - base_idx = idx * partition_items - chunk_size = partition_items // pipeline_parallel_size - left = pipeline_parallel_size - partition_items % pipeline_parallel_size - if chunk_size == 0: - raise ValueError("Some nodes in Pipeline have no requests") - - for p in range(pipeline_parallel_size): - st = base_idx - base_idx += chunk_size + (p >= left) - parts[p].append((st, base_idx)) - - indexes = [] - for _parts in parts: - for s, e in _parts: - indexes.extend(list(range(s, e))) - assert len(indexes) == len(set(indexes)), indexes # should have no duplicates - assert set(indexes) == set(list(range(num_items))), (indexes, num_items) # should have the same indexes as expected - return parts diff --git a/internlm/train/__init__.py b/internlm/train/__init__.py index 9fc111ff..bf020924 100644 --- a/internlm/train/__init__.py +++ b/internlm/train/__init__.py @@ -1,9 +1,9 @@ from .pipeline import ( get_scheduler_hooks, - initialize_isp_communicator, initialize_llm_profile, initialize_model, initialize_optimizer, + initialize_parallel_communicator, load_new_batch, record_current_batch_training_metrics, set_fp32_attr_for_model, @@ -14,7 +14,7 @@ __all__ = [ "initialize_llm_profile", "initialize_model", - "initialize_isp_communicator", + "initialize_parallel_communicator", "initialize_optimizer", "load_new_batch", "record_current_batch_training_metrics", diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py index 3152af0d..70a30baf 100644 --- a/internlm/train/pipeline.py +++ b/internlm/train/pipeline.py @@ -4,7 +4,7 @@ import functools import math import time -from typing import Callable, Iterable, List, Optional, Union +from typing import Callable, Iterable, List, Optional, Tuple, TypeVar, Union import torch from torch import nn @@ -17,12 +17,6 @@ from torch.utils.data import DataLoader from internlm.accelerator import AcceleratorType, get_accelerator -from internlm.core.communication.isp import 
( - ISPCommModelConfig, - ISPCommunicator, - ISPCommunicatorSchedulerHook, -) -from internlm.core.communication.utils import ParamAsyncBcastHandler from internlm.core.context import ( IS_REPLICA_ZERO_PARALLEL, IS_TENSOR_DATA_PARALLEL, @@ -33,34 +27,53 @@ ) from internlm.core.context import global_context as gpc from internlm.core.context.random import set_mode -from internlm.core.naive_amp import NaiveAMPModel, set_fp32_attr_to_module +from internlm.core.naive_amp import ( + NaiveAMPModel, + set_fp32_attr_to_module, + unwrap_naive_amp, +) +from internlm.core.parallel.comm.isp import ( + ISPCommModelConfig, + ISPCommunicator, + ISPCommunicatorSchedulerHook, +) +from internlm.core.parallel.comm.tensor import ( + EmbbedingSequenceParallelCommunicator, + EmbbedingTensorParallelCommunicator, + HeadSequenceParallelCommunicator, + HeadTensorParallelCommunicator, + LinearRole, + MoESequenceParallelCommunicator, + SequenceParallelCommunicator, + TensorParallelCommunicator, +) +from internlm.core.parallel.comm.zero import ParamAsyncBcastHandler from internlm.core.trainer import TrainState -from internlm.data.utils import unpack_data +from internlm.data.utils import unpack_type_ids +from internlm.model.builder import create_model from internlm.model.metrics import SchedulerMetricHook from internlm.model.modules.embedding import Embedding1D +from internlm.model.modules.linear import ( + ColumnParallelLinear, + ParallelLinearWithCommExt, + RewardModelLinear, + RowParallelLinear, + ScaleColumnParallelLinear, +) +from internlm.model.modules.mha import GQA, MHA from internlm.model.modules.mlp import FeedForward -from internlm.model.modules.multi_head_attention import MHA +from internlm.model.modules.utils import is_moe_param from internlm.model.moe.megablock.mlp import ( MegaBlockFeedForward, MegaBlockGroupedFeedForward, ) from internlm.model.moe.moe import MoE -from internlm.model.ops.fusion_ops_import_helper import ( - try_import_FusedAdamW, - try_import_RMSNorm, -) -from internlm.model.ops.linear import ( - BaseScaleColumnParallelLinear, - ColumnParallelLinearTorch, - ISPLinear, - RewardModelLinear, - RowParallelLinearTorch, - ScaleColumnParallelLinear, -) -from internlm.model.utils import is_moe_param +from internlm.model.ops.norm import RMSNorm +from internlm.model.registry import register_model_initializer from internlm.monitor import set_env_var from internlm.monitor.monitor import monitor_manager as mm from internlm.solver.optimizer import FSDPadaptOptimizer, HybridZeroOptimizer +from internlm.solver.optimizer.compatible_adamw import new_compatible_adamw from internlm.solver.schedulers.beta2_scheduler import Beta2Scheduler from internlm.solver.schedulers.lr_scheduler import FineTuneCosineAnnealingWarmupLR from internlm.train.utils import create_param_groups @@ -77,7 +90,6 @@ sync_model_param, sync_model_replica_param_group, ) -from internlm.utils.registry import MODEL_INITIALIZER from internlm.utils.timeout import llm_timeout try: @@ -85,7 +97,6 @@ except (ImportError, ModuleNotFoundError): pass -RMSNorm = try_import_RMSNorm() logger = get_logger(__file__) internlm_accelerator = get_accelerator() @@ -112,9 +123,8 @@ def _check_module(name, module): setattr(param, IS_REPLICA_ZERO_PARALLEL, True) # embedding and head - embedding_head_cls = (Embedding1D, BaseScaleColumnParallelLinear) - if isinstance(module, embedding_head_cls): + if isinstance(module, (Embedding1D, ScaleColumnParallelLinear)): for param in module.parameters(): if gpc.is_initialized(ParallelMode.TENSOR) and is_using_isp(): 
setattr(param, IS_TENSOR_DATA_PARALLEL, True) @@ -124,7 +134,7 @@ def _check_module(name, module): # for linear module if isinstance( module, - (ColumnParallelLinearTorch, RowParallelLinearTorch, MegaBlockFeedForward, MegaBlockGroupedFeedForward), + (ParallelLinearWithCommExt, MegaBlockFeedForward, MegaBlockGroupedFeedForward), ): for param in module.parameters(): if gpc.is_initialized(ParallelMode.EXPERT_DATA) and is_moe_param(param): @@ -138,18 +148,9 @@ def _check_module(name, module): # for vit and vit project if "vision_tower" in name.lower() or "vision_proj" in name.lower(): for param in module.parameters(): - if gpc.is_initialized(ParallelMode.TENSOR) and is_using_isp(): - setattr(param, IS_TENSOR_DATA_PARALLEL, True) - elif gpc.is_initialized(ParallelMode.TENSOR) and not is_using_isp(): - setattr(param, IS_TENSOR_ZERO_PARALLEL, True) - - if not isinstance(model, nn.ModuleList): - model = [model] - - for _chunk in model: - if isinstance(_chunk, NaiveAMPModel): - _chunk = _chunk.model + setattr(param, IS_REPLICA_ZERO_PARALLEL, True) + for _chunk in unwrap_naive_amp(model): # set param parallel attribute for name, module in _chunk.named_modules(): _check_module(name, module) @@ -175,7 +176,11 @@ def initialize_model(pre_process_func: Optional[Callable] = None, post_process_f """ if pre_process_func: pre_process_output = pre_process_func() - model = MODEL_INITIALIZER.get_module(module_name=gpc.config.model_type)(**(gpc.config.model)) + + register_model_initializer() + + model = create_model(model_type=gpc.config.model_type, **(gpc.config.model)) + if post_process_func: post_process_func(pre_process_output) @@ -221,20 +226,21 @@ def initialize_model(pre_process_func: Optional[Callable] = None, post_process_f # if fsdp enabled, wrap the model model = wrap_FSDP_model(model) + # TODO: add a checker to ensure model only use ours linear, expect fsdp. + return model def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]): - if gpc.config.parallel.zero1.fsdp and gpc.config.model.use_flash_attn: - from flash_attn.modules.embedding import ParallelGPT2Embeddings + if gpc.config.parallel.zero1.fsdp: # set wrap_policy for fsdp wrap transformer_wrap_policy = functools.partial( transformer_auto_wrap_policy, transformer_layer_cls={ Embedding1D, - ParallelGPT2Embeddings, MHA, + GQA, RMSNorm, FeedForward, RewardModelLinear, @@ -258,7 +264,19 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]): return model -def initialize_isp_communicator(model: Union[nn.Module, nn.ModuleList]): +_T = TypeVar("_T") + + +def _submodule_filter(model: Union[nn.Module, nn.ModuleList], target_cls: Union[_T, Tuple[_T]]) -> Iterable[_T]: + for _chunk in unwrap_naive_amp(model): + for _module in _chunk.modules(): + if not isinstance(_module, target_cls): + continue + + yield _module + + +def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]): """ Initialize communicator for isp tensor parallel mode. @@ -269,6 +287,8 @@ def initialize_isp_communicator(model: Union[nn.Module, nn.ModuleList]): An isp communicator for managing comp/comm overlap and memory pool. """ isp_communicator = None + _retain_out_sharded = gpc.config.model.get("parallel_output", True) + if is_using_isp(): isp_communicator = ISPCommunicator( model, @@ -281,8 +301,73 @@ def initialize_isp_communicator(model: Union[nn.Module, nn.ModuleList]): gpc.config.parallel.weight.memory_pool, gpc.get_group(ParallelMode.WEIGHT), ) - # register communicator for isp linear. 
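Model construction above now goes through an explicit registry instead of the old MODEL_INITIALIZER decorator: register_model_initializer() fills the model_initializer registry with (model_type, builder) pairs, and initialize_model resolves the builder via create_model(model_type=gpc.config.model_type, ...). Below is a minimal standalone sketch of that pattern; the Registry here is a simplified copy and build_toy_model is an invented stand-in for the real builders.

from typing import Callable, Dict


class Registry:
    """Maps a model_type string to a builder callable (same idea as internlm.model.registry)."""

    def __init__(self, name: str):
        self._name = name
        self._registry: Dict[str, Callable] = {}

    def register_module(self, module_name: str, func: Callable) -> None:
        assert module_name not in self._registry, f"{module_name} already registered in {self._name}"
        self._registry[module_name] = func

    def get_module(self, module_name: str) -> Callable:
        return self._registry[module_name]


def build_toy_model(num_layers: int = 2, hidden_size: int = 8):  # stand-in for InternLM1 / InternLM2 / ...
    return {"num_layers": num_layers, "hidden_size": hidden_size}


model_initializer = Registry("model_initializer")
model_initializer.register_module("TOY", build_toy_model)

# create_model() in the new internlm/model/builder.py presumably performs this lookup
# before instantiating the model with the config kwargs.
model = model_initializer.get_module("TOY")(num_layers=4, hidden_size=16)
assert model["num_layers"] == 4
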
- ISPLinear.register_communicator(isp_communicator) + # register communicator for isp column parallel linear. + ColumnParallelLinear.register_cls_communicator(isp_communicator) + # row parallel linear will not be used. + RowParallelLinear.register_cls_communicator(None) + _head_communicator = HeadSequenceParallelCommunicator(ParallelMode.TENSOR, _retain_out_sharded) + _embedding_communicator = EmbbedingSequenceParallelCommunicator(ParallelMode.TENSOR) + + # register communictor for mtp/msp/fsp linear. + + # tensor parallel + if gpc.config.parallel.tensor.mode == "mtp": + ColumnParallelLinear.register_cls_communicator( + TensorParallelCommunicator(process_group=gpc.get_group(ParallelMode.TENSOR), role=LinearRole.COLUMN) + ) + RowParallelLinear.register_cls_communicator( + TensorParallelCommunicator(process_group=gpc.get_group(ParallelMode.TENSOR), role=LinearRole.ROW) + ) + _head_communicator = HeadTensorParallelCommunicator(ParallelMode.TENSOR, _retain_out_sharded) + _embedding_communicator = EmbbedingTensorParallelCommunicator(ParallelMode.TENSOR) + # sequence parallel + if gpc.config.parallel.tensor.mode in ("msp", "fsp"): + save_total_input_as_activation = gpc.config.parallel.tensor.mode == "msp" + + ColumnParallelLinear.register_cls_communicator( + SequenceParallelCommunicator( + process_group=gpc.get_group(ParallelMode.TENSOR), + role=LinearRole.COLUMN, + save_total_input_as_activation=save_total_input_as_activation, + ) + ) + RowParallelLinear.register_cls_communicator( + SequenceParallelCommunicator( + gpc.get_group(ParallelMode.TENSOR), + role=LinearRole.ROW, + save_total_input_as_activation=save_total_input_as_activation, + ) + ) + + _head_communicator = HeadSequenceParallelCommunicator( + ParallelMode.TENSOR, _retain_out_sharded, save_total_input_as_activation + ) + _embedding_communicator = EmbbedingSequenceParallelCommunicator(ParallelMode.TENSOR) + + # MoE sequence parallel + if gpc.config.model.get("num_experts", 1) > 1: + _column_communicator = TensorParallelCommunicator( + process_group=gpc.get_group(ParallelMode.TENSOR), role=LinearRole.COLUMN + ) + _row_communicator = TensorParallelCommunicator( + process_group=gpc.get_group(ParallelMode.TENSOR), role=LinearRole.ROW + ) + for moe in _submodule_filter(model, MoE): + # 1. the linear in MoE degrades the parallel communication pattern from sp to tp + for column_linear in _submodule_filter(moe, ColumnParallelLinear): + column_linear.register_communicator(_column_communicator) + for row_linear in _submodule_filter(moe, RowParallelLinear): + row_linear.register_communicator(_row_communicator) + # 2. register MoESequenceParallelCommunicator for MoE layer + MoESequenceParallelCommunicator(ParallelMode.TENSOR).register_module_hook(moe) + + # register communitorc for embedding layer. + for embedding in _submodule_filter(model, Embedding1D): + _embedding_communicator.register_module_hook(embedding) + + # register communictor for head layer. 
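The block above is the heart of the new communication layer: linears no longer hard-code their parallel behaviour, they consult a communicator that initialize_parallel_communicator registers per class (ColumnParallelLinear, RowParallelLinear, ScaleColumnParallelLinear, ...) according to the tensor-parallel mode (isp, mtp, msp, fsp). The snippet below is a toy sketch of the class-level registration mechanism only; the class and method bodies are invented and far simpler than the real communicators.

import torch
from torch import nn


class NoOpCommunicator:
    """Stand-in for TensorParallelCommunicator / SequenceParallelCommunicator."""

    def input_hook(self, x: torch.Tensor) -> torch.Tensor:
        return x  # a real communicator would all-gather or split the input here

    def output_hook(self, x: torch.Tensor) -> torch.Tensor:
        return x  # ... and reduce-scatter or all-reduce the output here


class ToyParallelLinear(nn.Linear):
    _communicator = None  # shared by every instance of the class

    @classmethod
    def register_cls_communicator(cls, communicator) -> None:
        cls._communicator = communicator

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self._communicator.input_hook(x)
        return self._communicator.output_hook(super().forward(x))


ToyParallelLinear.register_cls_communicator(NoOpCommunicator())
layer = ToyParallelLinear(8, 8)
assert layer(torch.randn(2, 8)).shape == (2, 8)
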
+ ScaleColumnParallelLinear.register_cls_communicator(_head_communicator) + RewardModelLinear.register_cls_communicator(_head_communicator) return isp_communicator @@ -305,15 +390,11 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList], isp_communicato params = create_param_groups(model, adam_cfg.weight_decay) - # TODO(caikun): add DIPU backend adamw - adam_extra_kwargs, internlm_adamw = try_import_FusedAdamW() - - naive_optimizer = internlm_adamw( + naive_optimizer = new_compatible_adamw( params=params, lr=adam_cfg.lr, betas=(adam_cfg.adam_beta1, adam_cfg.adam_beta2), eps=adam_cfg.adam_eps, - **adam_extra_kwargs, ) if ( @@ -411,7 +492,7 @@ def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: Trai if batch[0].get("type_ids", None) is not None: # if use_packed_dataset is False, we need to unpack type_ids if not gpc.config.data.use_packed_dataset: - batch[0]["type_ids"] = unpack_data(batch[0]["type_ids"], batch[0]["cu_seqlens"], is_type_ids=True) + batch[0]["type_ids"] = unpack_type_ids(batch[0]["type_ids"], batch[0]["cu_seqlens"]) return batch, train_iter diff --git a/internlm/train/utils.py b/internlm/train/utils.py index 26cf7acf..47890d78 100644 --- a/internlm/train/utils.py +++ b/internlm/train/utils.py @@ -4,7 +4,7 @@ from internlm.core.context.parallel_context import ParallelMode from internlm.core.context.parallel_context import global_context as gpc -from internlm.model.utils import is_moe_param +from internlm.model.modules.utils import is_moe_param from internlm.utils.parallel import is_tensor_data_parallel_parameter, is_using_isp diff --git a/internlm/utils/common.py b/internlm/utils/common.py index 956f8e16..df4583d4 100644 --- a/internlm/utils/common.py +++ b/internlm/utils/common.py @@ -45,45 +45,15 @@ def move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Te return norm -def _move_tensor(element): - if not torch.is_tensor(element): - # we expecte the data type if a list of dictionaries - for idx, item in enumerate(element): - if isinstance(item, dict): - for key, value in item.items(): - assert value.device.type == "cpu" - item[key] = value.to(get_current_device()).detach() - elif isinstance(item, list): - for index, value in enumerate(item): - assert value.device.type == "cpu" - item[index] = value.to(get_current_device()).detach() - elif torch.is_tensor(item): - if item.device.type == "cpu": - element[idx] = item.to(get_current_device()).detach() - else: - assert False, f"{type(item)}, {item}" - else: - assert torch.is_tensor(element), f"element should be of type tensor, but got {type(element)}" - if element.device.type == "cpu": - element = element.to(get_current_device()).detach() - return element - - def move_to_device(data): if isinstance(data, torch.Tensor): - data = data.to(get_current_device()) + if data.device.type == "cpu": + data = data.to(get_current_device()).detach() elif isinstance(data, (list, tuple)): - data_to_return = [] - for element in data: - if isinstance(element, dict): - data_to_return.append({k: _move_tensor(v) for k, v in element.items()}) - else: - data_to_return.append(_move_tensor(element)) - data = data_to_return + data = [move_to_device(x) for x in data] elif isinstance(data, dict): - data = {k: _move_tensor(v) for k, v in data.items()} - else: - raise TypeError(f"Expected batch data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}") + data = {k: move_to_device(v) for k, v in data.items()} + return data @@ -247,11 +217,14 @@ def get_megatron_flops( def 
enable_pytorch_expandable_segments(): if torch.__version__ >= "2.1.0" and AcceleratorType.GPU == internlm_accelerator.get_accelerator_backend(): - _alloc_setting = "expandable_segments:True" - assert ( - os.getenv("PYTORCH_CUDA_ALLOC_CONF", None) is None - ), "PYTORCH_CUDA_ALLOC_CONF should not be set when using expandable_segments" - internlm_accelerator.memory._set_allocator_settings(_alloc_setting) + _expandable_segments_conf = "expandable_segments:True" + _alloc_conf = os.getenv("PYTORCH_CUDA_ALLOC_CONF", None) + if _alloc_conf is None: + _alloc_conf = _expandable_segments_conf + elif "max_split_size_mb" not in _alloc_conf: + _alloc_conf = _alloc_conf + "," + _expandable_segments_conf + + internlm_accelerator.memory._set_allocator_settings(_alloc_conf) else: logger.warning("To support the 'expandable_segments' configuration, please upgrade torch to version 2.1.0.") diff --git a/internlm/utils/parallel.py b/internlm/utils/parallel.py index 4de457cd..1b92974d 100644 --- a/internlm/utils/parallel.py +++ b/internlm/utils/parallel.py @@ -12,9 +12,6 @@ ParallelMode, ) from internlm.core.context import global_context as gpc -from internlm.model.ops.fusion_ops_import_helper import try_import_RMSNorm - -RMSNorm = try_import_RMSNorm() def is_using_sequence_parallel(): @@ -74,7 +71,6 @@ def sync_model_param(model): Args: model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency. """ - sync_moe_param = gpc.is_using_parallel_mode(ParallelMode.EXPERT_DATA) sync_parallel_mode = ParallelMode.WEIGHT_DATA if is_using_isp() else ParallelMode.DATA for param in model.parameters(): diff --git a/internlm/utils/utils.py b/internlm/utils/utils.py index 9a30eb26..e8e76e70 100644 --- a/internlm/utils/utils.py +++ b/internlm/utils/utils.py @@ -1,5 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. +import types from contextlib import contextmanager +from enum import Enum, IntEnum +from functools import update_wrapper +from typing import Callable, Tuple + +import torch @contextmanager @@ -16,3 +22,106 @@ def read_base(): .. _tutorial: https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta # pylint: disable=line-too-long """ # noqa: E501 yield + + +class QKVPackType(IntEnum): + QKVPACKED = 2 + KVPACKED = 3 + QKVSPLITED = 4 + + def __str__(self) -> str: + return str(self.value) + + +class CuSeqlenType(Enum): + With = True + WithOut = False + + def __str__(self) -> str: + return str(self.value) + + +def check_attention_argument(*args, **kwargs) -> str: + # self, qkv, ... + # self, q, kv, .... + # self, q, k, v, ... + # self, qkv, cu_seqlens, max_seqlen, ... + # self, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, ... + # self, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, ... 
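+    # The checker returns a (qkv_pack_type, with_cu_seqlens) pair of strings, e.g. ("2", "True") for a
+    # packed-qkv variable-length call; params_dispatch_with_condition below joins that pair with "-" and
+    # uses it as the registry key, e.g. to select a matching attention implementation.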
+ def __qkv_checker(num_args: int): + if num_args < 2: + return "qkv" in kwargs + else: + # qkv: [batch, seqlen, 3, n_head, headdim] + return len(args[1].shape) == 5 + + def __kv_checker(num_args: int): + if num_args < 3: + return "kv" in kwargs + else: + # kv: [batch, seqlen, 3, n_head, headdim] + return len(args[2].shape) == 5 + + def __cu_seqlens_checker(num_args: int, check_idx: int): + if num_args < (check_idx + 1): + if check_idx == 2: + return "cu_seqlens" in kwargs and kwargs["cu_seqlens"] is not None + else: + return "cu_seqlens_q" in kwargs and kwargs["cu_seqlens_q"] is not None + else: + return isinstance(num_args[check_idx], torch.Tensor) + + if __qkv_checker(len(args)): + # qkv packed, and we should check cu_seqlens with index 2 + qkv_pack_type = int(QKVPackType.QKVPACKED) + elif __kv_checker(len(args)): + # kv packed, and we should check cu_seqlens with index 3 + qkv_pack_type = int(QKVPackType.KVPACKED) + else: + # qkv splited, and we should check cu_seqlens with index 4 + qkv_pack_type = int(QKVPackType.QKVSPLITED) + + with_cu_seqlens = __cu_seqlens_checker(len(args), qkv_pack_type) + + return str(qkv_pack_type), str(with_cu_seqlens) + + +def params_dispatch_with_condition(condition: Callable, func: Callable = None): + + if func is None: + # create a params dispatch wrapper + return lambda f: params_dispatch_with_condition(condition, f) + + registry = {} + funcname = getattr(func, "__name__", "params_dispatch_with_condition function") + + def dispatch(_type: str) -> Callable: + return registry[_type] + + def register(conditions: Tuple[str], func: Callable = None) -> None: + if func is None: + # create a register wrapper + return lambda f: register(conditions, f) + + _type = "-".join(conditions) + + assert _type not in registry, f"Repeatedly register dispatch functions for pattern {_type}" + + registry[_type] = func + + return func + + def wrapper(*args, **kwargs): + if not args: + raise TypeError(f"{funcname} requires at least " "1 positional argument") + + _type = "-".join(condition(*args, **kwargs)) + + return dispatch(_type)(*args, **kwargs) + + registry[""] = func + wrapper.register = register + wrapper.dispatch = dispatch + wrapper.registry = types.MappingProxyType(registry) + update_wrapper(wrapper, func) + return wrapper diff --git a/tests/common_fixture.py b/tests/common_fixture.py index ed1d914e..6f3def53 100644 --- a/tests/common_fixture.py +++ b/tests/common_fixture.py @@ -9,7 +9,7 @@ from internlm.accelerator import get_accelerator from internlm.core.context import global_context as gpc from internlm.core.context.parallel_context import Config -from internlm.data.utils import unpack_data +from internlm.data.utils import unpack_type_ids from internlm.initialize.launch import args_sanity_check internlm_accelerator = get_accelerator() @@ -149,6 +149,6 @@ def load_new_batch(train_dl, train_iter): if batch[0].get("type_ids", None) is not None: # if use_flash_attn is False, we need to unpack type_ids if not gpc.config.model.use_flash_attn: - batch[0]["type_ids"] = unpack_data(batch[0]["type_ids"], batch[0]["cu_seqlens"], is_type_ids=True) + batch[0]["type_ids"] = unpack_type_ids(batch[0]["type_ids"], batch[0]["cu_seqlens"]) return batch, train_iter diff --git a/tests/test_core/utils.py b/tests/test_core/utils.py index 5c767914..5ccaccaf 100644 --- a/tests/test_core/utils.py +++ b/tests/test_core/utils.py @@ -11,13 +11,13 @@ from internlm.core.context import global_context as gpc from internlm.core.engine import Engine from internlm.core.gradient_handler import 
PipelineSharedModuleGradientHandler +from internlm.core.parallel.shard import partition_uniform from internlm.core.scheduler import ( InterleavedPipelineScheduler, NonPipelineScheduler, PipelineScheduler, ) from internlm.model.metrics import SchedulerMetricHook -from internlm.solver.pipeline_utils import partition_uniform from internlm.train import initialize_optimizer from internlm.utils.common import get_current_device @@ -41,7 +41,7 @@ def forward( ): # pylint: disable=W0613 if self.model_type != "torch" and self.part[0] != 0: input_ids = hidden_states - + # Simulate Embedding. if self.embedding: if len(input_ids.shape) == 2: diff --git a/tests/test_model/test_feed_forward.py b/tests/test_model/test_feed_forward.py index e4aab9ec..311f30d7 100644 --- a/tests/test_model/test_feed_forward.py +++ b/tests/test_model/test_feed_forward.py @@ -1,7 +1,7 @@ import pytest import torch -from internlm.model.modules.mlp import BaseFeedForward +from internlm.model.modules.mlp import new_feed_forward, split_fused_mlp_weight from internlm.utils.common import get_current_device SEQ_LEN = 64 @@ -9,20 +9,6 @@ MLP_RATIO = 8 / 3 -class InternLMLinear(torch.nn.Linear): - def __init__( - self, - in_features: int, - out_features: int, - *args, # pylint: disable=W0613 - bias: bool = True, - device=None, - dtype=None, - **kwargs, # pylint: disable=W0613 - ) -> None: - super().__init__(in_features, out_features, bias, device, dtype) - - mlp_args = { "in_features": HIDDEN_SIZE, "hidden_features": int(HIDDEN_SIZE * MLP_RATIO), @@ -30,8 +16,6 @@ def __init__( "bias": False, "device": get_current_device(), "dtype": torch.bfloat16, - "column_cls": InternLMLinear, - "row_cls": InternLMLinear, } @@ -43,13 +27,13 @@ def check_param(a1, a2, b1, b2): def init_mlp(): - mlp_no_fused = BaseFeedForward(**mlp_args) - mlp_fused = BaseFeedForward(mlp_layer_fusion=True, **mlp_args) + mlp_no_fused = new_feed_forward(**mlp_args) + mlp_fused = new_feed_forward(mlp_layer_fusion=True, **mlp_args) for _, param in mlp_fused.named_parameters(): torch.nn.init.normal_(param.data, std=0.02) - w1, w3 = BaseFeedForward.split_fused_mlp_weight(mlp_fused.fused_w1_w3.weight) + w1, w3 = split_fused_mlp_weight(mlp_fused.fused_w1_w3.weight) mlp_no_fused.w1.weight.data = w1.data mlp_no_fused.w3.weight.data = w3.data mlp_no_fused.w2.weight.data = mlp_fused.w2.weight.data @@ -99,7 +83,7 @@ def test_mlp_layer_fusion_loss(): l2.backward() assert torch.allclose(mlp_no_fused.w2.weight.grad, mlp_fused.w2.weight.grad, rtol=1e-4, atol=1e-5) - w1_g, w3_g = BaseFeedForward.split_fused_mlp_weight(mlp_fused.fused_w1_w3.weight.grad) + w1_g, w3_g = split_fused_mlp_weight(mlp_fused.fused_w1_w3.weight.grad) assert torch.allclose(mlp_no_fused.w1.weight.grad, w1_g, rtol=1e-4, atol=1e-5) assert torch.allclose(mlp_no_fused.w3.weight.grad, w3_g, rtol=1e-4, atol=1e-5) diff --git a/tests/test_model/test_fused_precision/test_fused_precision.py b/tests/test_model/test_fused_precision/test_fused_precision.py index 54959ecb..d0b79aae 100644 --- a/tests/test_model/test_fused_precision/test_fused_precision.py +++ b/tests/test_model/test_fused_precision/test_fused_precision.py @@ -6,7 +6,8 @@ from torch import nn from internlm.core.naive_amp import NaiveAMPModel, set_fp32_attr_to_module -from internlm.model.modeling_internlm import PackedFlashBaseLayer1D +from internlm.model.modeling_internlm import InternLM1Decoder +from internlm.train.pipeline import initialize_parallel_communicator from internlm.train.utils import create_param_groups from internlm.utils.common import 
get_current_device from tests.common_fixture import find_free_port @@ -33,7 +34,7 @@ def check_fused_precision(args): # fix seed seed_all(1024) # define model - model = PackedFlashBaseLayer1D( + model = InternLM1Decoder( hidden_size=16, # 768 num_attention_heads=2, # 12 mlp_ratio=2, @@ -58,6 +59,7 @@ def check_fused_precision(args): dtype=torch.half, sync_buffer=False, ) + _ = initialize_parallel_communicator(model) model.model.norm1.register_forward_pre_hook(partial(_pre_forward_hook_for_check)) model.model.norm1.register_forward_hook(partial(_post_forward_hook_for_check)) diff --git a/tests/test_model/test_model_internlm.py b/tests/test_model/test_model_internlm.py index 4ed8f535..c33f188c 100644 --- a/tests/test_model/test_model_internlm.py +++ b/tests/test_model/test_model_internlm.py @@ -11,9 +11,19 @@ from internlm.core.context import ParallelMode from internlm.core.context.parallel_context import Config from internlm.core.context.parallel_context import global_context as gpc -from internlm.model.modeling_internlm import PackedFlashBaseLayer1D -from internlm.model.ops.linear import RewardModelLinear, ScaleColumnParallelLinear -from internlm.model.utils import gather_forward_split_backward +from internlm.core.parallel.comm.tensor import ( + HeadTensorParallelCommunicator, + LinearRole, + TensorParallelCommunicator, +) +from internlm.core.parallel.comm.utils import gather_forward_split_backward +from internlm.model.modeling_internlm import InternLM1Decoder +from internlm.model.modules.linear import ( + ColumnParallelLinear, + RowParallelLinear, + ScaleColumnParallelLinear, + new_linear, +) from internlm.utils.common import get_current_device from tests.common_fixture import find_free_port @@ -101,10 +111,18 @@ def check_block(args): # fix seed seed_all(1024) + ColumnParallelLinear.register_cls_communicator( + TensorParallelCommunicator(process_group=gpc.get_group(ParallelMode.TENSOR), role=LinearRole.COLUMN) + ) + + RowParallelLinear.register_cls_communicator( + TensorParallelCommunicator(process_group=gpc.get_group(ParallelMode.TENSOR), role=LinearRole.ROW) + ) + # define block blocks = nn.ModuleList( [ - PackedFlashBaseLayer1D( + InternLM1Decoder( hidden_size=4, # 768 num_attention_heads=2, # 12 mlp_ratio=2, @@ -215,9 +233,12 @@ def check_head(args): # fix seed seed_all(1024) + _retain_out_sharded = gpc.config.model.get("parallel_output", True) + _head_comminucator = HeadTensorParallelCommunicator(ParallelMode.TENSOR, _retain_out_sharded) + ScaleColumnParallelLinear.register_cls_communicator(_head_comminucator) + # load standard if is_reward: - head_cls = RewardModelLinear standard_result = torch.tensor([[3.5938], [1.0703], [3.6250], [3.6250]], dtype=torch.bfloat16).to(device) standard_grad = torch.tensor( [ @@ -229,7 +250,6 @@ def check_head(args): dtype=torch.bfloat16, ).to(device) else: - head_cls = ScaleColumnParallelLinear standard_result = torch.tensor( [ [3.5938, -2.2188, 2.0312, 3.5625], @@ -250,13 +270,14 @@ def check_head(args): ).to(device) # define head - head = head_cls( + head = new_linear( + name="head", in_features=hidden_size, out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size, - process_group=gpc.get_group(ParallelMode.TENSOR), bias=False, device=device, dtype=torch.bfloat16, + is_reward=is_reward, weight_scale=embed_grad_scale, ) diff --git a/tests/test_model/test_norm.py b/tests/test_model/test_norm.py index 0f5a3a4c..83861b36 100644 --- a/tests/test_model/test_norm.py +++ b/tests/test_model/test_norm.py @@ -3,13 +3,11 @@ import 
pytest import torch -from internlm.model.ops.fusion_ops_import_helper import try_import_RMSNorm +from internlm.model.modules.norm import new_layer_norm from internlm.utils.common import get_current_device from tests.common_fixture import find_free_port from tests.test_model.test_model_internlm import build_environment, seed_all -RMSNorm = try_import_RMSNorm() def check_norm(args): # init @@ -24,7 +22,7 @@ seed_all(1024) # define norm - norm = RMSNorm(hidden_size, eps=layer_norm_epsilon) + norm = new_layer_norm(norm_type="rmsnorm", normalized_shape=hidden_size, eps=layer_norm_epsilon) norm = norm.to(device) # create input diff --git a/tests/test_model/test_npu_ops.py b/tests/test_model/test_npu_ops.py index 7d31bc6d..31a8ba61 100644 --- a/tests/test_model/test_npu_ops.py +++ b/tests/test_model/test_npu_ops.py @@ -16,9 +16,6 @@ CrossAttention, SelfAttention, ) -from internlm.model.ops.fusion_ops_import_helper import try_import_RMSNorm - -RMSNorm = try_import_RMSNorm() HEAD_NUM = 32 HIDDEN_SZIE = 4096 @@ -88,6 +85,7 @@ def do_cmp_attn( softmax_scale=softmax_scale, attention_dropout=attention_dropout, ).to(dtype) + # TODO: fix it. npu_flash_attn = AscendFlashSelfAttention( causal=True, softmax_scale=softmax_scale, diff --git a/tests/test_solver/test_optimizer.py b/tests/test_solver/test_optimizer.py index 0738ddb3..2c7a93c2 100644 --- a/tests/test_solver/test_optimizer.py +++ b/tests/test_solver/test_optimizer.py @@ -11,8 +11,8 @@ import internlm from internlm.accelerator import get_accelerator -from internlm.core.communication.utils import ParamAsyncBcastHandler from internlm.core.context.parallel_context import Config, ParallelMode +from internlm.core.parallel.comm.zero import ParamAsyncBcastHandler from internlm.solver.optimizer import HybridZeroOptimizer from internlm.utils.common import get_current_device diff --git a/tests/test_training/test_forward_output_no_fa.py b/tests/test_training/test_forward_output_no_fa.py index a1f18201..00025a7b 100644 --- a/tests/test_training/test_forward_output_no_fa.py +++ b/tests/test_training/test_forward_output_no_fa.py @@ -16,7 +16,11 @@ from internlm.initialize.launch import args_sanity_check from internlm.model.losses import FlashGPTLMLoss from internlm.model.metrics import AccPerplex, SchedulerMetricHook -from internlm.train import initialize_model, initialize_optimizer +from internlm.train import ( + initialize_model, + initialize_optimizer, + initialize_parallel_communicator, +) from internlm.utils.common import get_current_device from internlm.utils.logger import get_logger @@ -165,6 +169,7 @@ def train_check_output(args): # initialize model model = initialize_model() + _ = initialize_parallel_communicator(model) # initialize loss function criterion = FlashGPTLMLoss(parallel_output=False, label_smoothing=gpc.config.loss.label_smoothing) @@ -228,6 +233,7 @@ def train_check_output(args): logger.info("Outputs are totally equal") else: logger.warning("Outputs are not totally equal") + print(f"tensor1: {tensor1}, tensor2: {tensor2}", flush=True) max_diff, index_max_diff = (tensor1 - tensor2).abs().max(dim=0) max_diff = max_diff.item() index_max_diff = index_max_diff.item() diff --git a/tests/test_training/test_load_ckpt_loss.py b/tests/test_training/test_load_ckpt_loss.py index a09191f9..0cd22145 100644 --- a/tests/test_training/test_load_ckpt_loss.py +++ b/tests/test_training/test_load_ckpt_loss.py @@ -46,6 +46,7 @@ from internlm.train import ( # noqa: E402 #pylint: disable=wrong-import-position initialize_model,
initialize_optimizer, + initialize_parallel_communicator, load_new_batch, ) from internlm.utils.common import ( # noqa: E402 #pylint: disable=wrong-import-position @@ -67,7 +68,7 @@ zero1=dict(size=-1, fsdp=False), pipeline=dict(size=1, interleaved_overlap=False), sequence_parallel=False, - tensor=1, + tensor=dict(size=1, mode="mtp"), ), data=dict( seq_len=2048, @@ -218,6 +219,7 @@ def train_model(args): # initialize model model = initialize_model() + _ = initialize_parallel_communicator(model) # initialize loss function criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing) diff --git a/tests/test_training/test_loss.py b/tests/test_training/test_loss.py index cd97da86..b757b4c0 100644 --- a/tests/test_training/test_loss.py +++ b/tests/test_training/test_loss.py @@ -16,9 +16,9 @@ from internlm.model.metrics import AccPerplex from internlm.train import ( get_scheduler_hooks, - initialize_isp_communicator, initialize_model, initialize_optimizer, + initialize_parallel_communicator, load_new_batch, ) from internlm.utils.common import BatchSkipper, get_current_device, launch_time @@ -134,7 +134,7 @@ def train( model = initialize_model() # initialize isp communicator - isp_communicator = initialize_isp_communicator(model) + isp_communicator = initialize_parallel_communicator(model) # initialize loss function criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing) diff --git a/tests/test_training/test_no_fa_train_temp.py b/tests/test_training/test_no_fa_train_temp.py index 419d08c1..d430c16a 100644 --- a/tests/test_training/test_no_fa_train_temp.py +++ b/tests/test_training/test_no_fa_train_temp.py @@ -11,9 +11,9 @@ from internlm.model.metrics import AccPerplex from internlm.train import ( get_scheduler_hooks, - initialize_isp_communicator, initialize_model, initialize_optimizer, + initialize_parallel_communicator, ) from internlm.utils.logger import get_logger from tests.common_fixture import ( @@ -54,7 +54,7 @@ def train_check(args): model = initialize_model() # initialize isp communicator - isp_communicator = initialize_isp_communicator(model) + isp_communicator = initialize_parallel_communicator(model) # initialize loss function criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing) diff --git a/tests/test_training/test_norm_weight.py b/tests/test_training/test_norm_weight.py index 5c45677e..848cf740 100644 --- a/tests/test_training/test_norm_weight.py +++ b/tests/test_training/test_norm_weight.py @@ -14,9 +14,9 @@ from internlm.model.metrics import AccPerplex from internlm.train import ( get_scheduler_hooks, - initialize_isp_communicator, initialize_model, initialize_optimizer, + initialize_parallel_communicator, ) from internlm.utils.common import get_current_device from internlm.utils.logger import get_logger @@ -74,7 +74,7 @@ def train_check_norm_weight(args): model = initialize_model() # initialize isp communicator - isp_communicator = initialize_isp_communicator(model) + isp_communicator = initialize_parallel_communicator(model) # initialize loss function criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing) diff --git a/tests/test_training/test_swap_nb_loss_and_gradnorm.py b/tests/test_training/test_swap_nb_loss_and_gradnorm.py index 48f05f47..f6e52382 100644 --- a/tests/test_training/test_swap_nb_loss_and_gradnorm.py +++ b/tests/test_training/test_swap_nb_loss_and_gradnorm.py @@ -22,7 +22,11 @@ from internlm.initialize.launch 
import args_sanity_check from internlm.model.losses import FlashGPTLMLoss from internlm.model.metrics import AccPerplex, SchedulerMetricHook -from internlm.train import initialize_model, initialize_optimizer +from internlm.train import ( + initialize_model, + initialize_optimizer, + initialize_parallel_communicator, +) from internlm.utils.common import get_current_device from internlm.utils.logger import get_logger @@ -266,6 +270,7 @@ def exam_loss(args): # initialize model model = initialize_model() + _ = initialize_parallel_communicator(model) # initialize loss function criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing) diff --git a/tests/test_training/train_CI.py b/tests/test_training/train_CI.py index 381a8b7c..6a3549b0 100644 --- a/tests/test_training/train_CI.py +++ b/tests/test_training/train_CI.py @@ -38,6 +38,7 @@ initialize_llm_profile, initialize_model, initialize_optimizer, + initialize_parallel_communicator, record_current_batch_training_metrics, ) from internlm.utils.common import ( # noqa: E402 @@ -62,7 +63,12 @@ def check_model_weights(model, ckpt_path, total_equal=False): model1_dict = torch.load(ckpt_path, map_location="cuda") model2_dict = model.state_dict() - for key in model2_dict.keys(): + copy_of_ordered_dict = model2_dict.copy() + + for key in copy_of_ordered_dict.keys(): + if "wqkv" in key: + model2_dict[key.replace("wqkv", "Wqkv")] = model2_dict.pop(key) + key = key.replace("wqkv", "Wqkv") if key not in model1_dict: assert False, f"Error: The key {key} for current model dose not exist in standard ckpt!" @@ -109,6 +115,7 @@ def main(args): # initialize model model = initialize_model() + _ = initialize_parallel_communicator(model) with open(args.config, "r") as f: config_lines = f.readlines() diff --git a/tests/test_utils/common_fixture.py b/tests/test_utils/common_fixture.py index 499692e9..b8b56ec6 100644 --- a/tests/test_utils/common_fixture.py +++ b/tests/test_utils/common_fixture.py @@ -7,6 +7,9 @@ from internlm.core.context import global_context as gpc from internlm.core.context.parallel_context import Config +from internlm.core.naive_amp import NaiveAMPModel +from internlm.model.builder import create_model +from internlm.model.registry import register_model_initializer from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer from internlm.train.utils import create_param_groups from internlm.utils.storage_manager import SingletonMeta @@ -87,13 +90,8 @@ def init_naive_model(): - # let MODEL_INITIALIZER to work - import internlm.model.modeling_internlm # noqa # pylint: disable=unused-import - import internlm.model.modeling_moe # noqa # pylint: disable=unused-import - from internlm.core.naive_amp import NaiveAMPModel - from internlm.utils.registry import MODEL_INITIALIZER - - model = MODEL_INITIALIZER.get_module(module_name=gpc.config.model_type)(**(init_config.model)) + register_model_initializer() + model = create_model(model_type=gpc.config.model_type, **(init_config.model)) model = NaiveAMPModel( model=model, output_to_fp32=False, diff --git a/tools/load_internlm_model.py b/tools/load_internlm_model.py index 98e6ad53..3de52c22 100644 --- a/tools/load_internlm_model.py +++ b/tools/load_internlm_model.py @@ -10,8 +10,8 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.initialize.launch import launch_from_torch +from internlm.model.registry import model_initializer from internlm.train import initialize_model -from 
internlm.utils.registry import MODEL_INITIALIZER from internlm.utils.storage_manager import get_fns, init_storage_manager, llm_load from tools.interface import GenerationConfig @@ -172,7 +172,8 @@ def initialize_internlm_model( model_config["dtype"] = param_dtype model_config["parallel_output"] = False - match_fn_signature(MODEL_INITIALIZER.get_module(model_type), model_config) + # FIXME: fix it. + match_fn_signature(model_initializer.get_module(model_type), model_config) if gpc.is_rank_for_log(): logger.info(f"model_config: {model_config}.") launch_from_torch( diff --git a/train.py b/train.py index 985ac5de..f94506d4 100644 --- a/train.py +++ b/train.py @@ -26,10 +26,10 @@ from internlm.monitor.monitor import monitor_manager as mm from internlm.train import ( get_scheduler_hooks, - initialize_isp_communicator, initialize_llm_profile, initialize_model, initialize_optimizer, + initialize_parallel_communicator, load_new_batch, record_current_batch_training_metrics, ) @@ -87,7 +87,7 @@ def main(args): model = initialize_model() # initialize isp communicator - isp_communicator = initialize_isp_communicator(model) + isp_communicator = initialize_parallel_communicator(model) with open(args.config, "r") as f: config_lines = f.readlines()
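
Note for downstream users updating their own tests and scripts against this refactor: the snippet below is a minimal sketch of the new entry points exercised by the test changes in this patch. It only strings together calls that appear above (register_cls_communicator, new_layer_norm, new_linear, initialize_model, initialize_parallel_communicator); the sizes hidden_size and vocab_size, the eps value, and weight_scale are placeholder values, and it assumes the distributed context (gpc) has already been launched and configured as in the tests.

import torch

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.parallel.comm.tensor import (
    HeadTensorParallelCommunicator,
    LinearRole,
    TensorParallelCommunicator,
)
from internlm.model.modules.linear import (
    ColumnParallelLinear,
    RowParallelLinear,
    ScaleColumnParallelLinear,
    new_linear,
)
from internlm.model.modules.norm import new_layer_norm
from internlm.train import initialize_model, initialize_parallel_communicator
from internlm.utils.common import get_current_device

# Parallel linears no longer take a process_group argument; a communicator is
# registered on the class instead. The updated check_block/check_head tests do
# this by hand, as shown here.
tp_group = gpc.get_group(ParallelMode.TENSOR)
ColumnParallelLinear.register_cls_communicator(
    TensorParallelCommunicator(process_group=tp_group, role=LinearRole.COLUMN)
)
RowParallelLinear.register_cls_communicator(
    TensorParallelCommunicator(process_group=tp_group, role=LinearRole.ROW)
)
retain_out_sharded = gpc.config.model.get("parallel_output", True)
ScaleColumnParallelLinear.register_cls_communicator(
    HeadTensorParallelCommunicator(ParallelMode.TENSOR, retain_out_sharded)
)

# Placeholder sizes for illustration only.
hidden_size, vocab_size = 4096, 103168

# Norms are built through a factory instead of importing RMSNorm via a fusion helper.
norm = new_layer_norm(norm_type="rmsnorm", normalized_shape=hidden_size, eps=1e-5)

# Heads (normal or reward) go through new_linear(name="head", ...) instead of
# instantiating ScaleColumnParallelLinear / RewardModelLinear directly.
head = new_linear(
    name="head",
    in_features=hidden_size,
    out_features=vocab_size,
    bias=False,
    device=get_current_device(),
    dtype=torch.bfloat16,
    is_reward=False,
    weight_scale=1,  # placeholder for embed_grad_scale
)

# For a full model, the per-class registration above is handled in one call:
model = initialize_model()
_ = initialize_parallel_communicator(model)  # replaces initialize_isp_communicator

Registering the communicator at the class level keeps the module constructors free of process-group plumbing, which is why the tests above can drop the process_group and column_cls/row_cls arguments.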