From 9b67f33a71073241fd1798b435639d217c31dd35 Mon Sep 17 00:00:00 2001 From: geruijun Date: Tue, 16 Jan 2024 18:18:56 +0800 Subject: [PATCH] remove dependency of flash_attn when use_flash_attn is set to false --- internlm/model/embedding.py | 183 +++++++++++++++++-- internlm/model/linear.py | 102 ++++++++++- internlm/model/loss.py | 7 +- internlm/model/metrics.py | 15 +- internlm/model/modeling_internlm.py | 10 +- internlm/model/modeling_moe.py | 10 +- internlm/model/multi_head_attention.py | 238 +++++++++++++++++++++---- internlm/model/utils.py | 124 ++++++++++++- internlm/train/training_internlm.py | 7 +- internlm/utils/gputest.py | 10 +- internlm/utils/model_checkpoint.py | 5 +- 11 files changed, 625 insertions(+), 86 deletions(-) diff --git a/internlm/model/embedding.py b/internlm/model/embedding.py index d1770538b..27ff04e93 100644 --- a/internlm/model/embedding.py +++ b/internlm/model/embedding.py @@ -3,12 +3,9 @@ from typing import Tuple -import rotary_emb import torch import torch.nn.functional as F from einops import rearrange -from flash_attn.layers.rotary import ApplyRotaryEmb as LegacyApplyRotaryEmb -from flash_attn.layers.rotary import ApplyRotaryEmbQKV_ as LegacyApplyRotaryEmbQKV_ from torch import Tensor, nn from internlm.core.context import ParallelMode @@ -63,6 +60,22 @@ def forward(self, input_: Tensor) -> Tensor: return output +def apply_rotary_torch(x1, x2, cos, sin, conj): + assert x1.device == x2.device == cos.device == sin.device, "All inputs must be on the same device" + assert x1.dtype == x2.dtype == cos.dtype == sin.dtype, "All inputs must have the same dtype" + assert x1.size() == x2.size(), "Input x1 and x2 must have the same sizes" + assert cos.size() == sin.size(), "Input cos and sin must have the same sizes" + + if conj: + out1 = x1 * cos + x2 * sin + out2 = -x1 * sin + x2 * cos + else: + out1 = x1 * cos - x2 * sin + out2 = x1 * sin + x2 * cos + + return out1, out2 + + class ApplyRotaryEmbQKV_(torch.autograd.Function): """ ApplyRotaryEmbQKV_ @@ -86,11 +99,23 @@ def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None): sin_k = sin if sin_k is None else sin_k assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2) q1, q2 = qkv[:, 0, :, :rotary_dim].chunk(2, dim=-1) - rotary_emb.apply_rotary(q1, q2, rearrange(cos, "s d -> s 1 d"), rearrange(sin, "s d -> s 1 d"), q1, q2, False) + if gpc.config.model.use_flash_attn: + import rotary_emb + + rotary_emb.apply_rotary( + q1, q2, rearrange(cos, "s d -> s 1 d"), rearrange(sin, "s d -> s 1 d"), q1, q2, False + ) + else: + q1, q2 = apply_rotary_torch(q1, q2, rearrange(cos, "s d -> s 1 d"), rearrange(sin, "s d -> s 1 d"), False) k1, k2 = qkv[:, 1, :, :rotary_dim].chunk(2, dim=-1) - rotary_emb.apply_rotary( - k1, k2, rearrange(cos_k, "s d -> s 1 d"), rearrange(sin_k, "s d -> s 1 d"), k1, k2, False - ) + if gpc.config.model.use_flash_attn: + rotary_emb.apply_rotary( + k1, k2, rearrange(cos_k, "s d -> s 1 d"), rearrange(sin_k, "s d -> s 1 d"), k1, k2, False + ) + else: + k1, k2 = apply_rotary_torch( + k1, k2, rearrange(cos_k, "s d -> s 1 d"), rearrange(sin_k, "s d -> s 1 d"), False + ) ctx.save_for_backward(cos, sin, cos_k, sin_k) return qkv @@ -100,19 +125,130 @@ def backward(ctx, dqkv): rotary_dim = cos.shape[-1] rotary_dim *= 2 dq1, dq2 = dqkv[:, 0, :, :rotary_dim].chunk(2, dim=-1) - rotary_emb.apply_rotary( - dq1, dq2, rearrange(cos, "s d -> s 1 d"), rearrange(sin, "s d -> s 1 d"), dq1, dq2, True - ) + if gpc.config.model.use_flash_attn: + import rotary_emb + + rotary_emb.apply_rotary( + dq1, dq2, rearrange(cos, "s d -> s 1 d"), rearrange(sin, "s d -> s 1 d"), dq1, dq2, True + ) + else: + dq1, dq2 = apply_rotary_torch( + dq1, dq2, rearrange(cos, "s d -> s 1 d"), rearrange(sin, "s d -> s 1 d"), True + ) dk1, dk2 = dqkv[:, 1, :, :rotary_dim].chunk(2, dim=-1) - rotary_emb.apply_rotary( - dk1, dk2, rearrange(cos_k, "s d -> s 1 d"), rearrange(sin_k, "s d -> s 1 d"), dk1, dk2, True - ) + if gpc.config.model.use_flash_attn: + rotary_emb.apply_rotary( + dk1, dk2, rearrange(cos_k, "s d -> s 1 d"), rearrange(sin_k, "s d -> s 1 d"), dk1, dk2, True + ) + else: + dk1, dk2 = apply_rotary_torch( + dk1, dk2, rearrange(cos_k, "s d -> s 1 d"), rearrange(sin_k, "s d -> s 1 d"), True + ) return dqkv, None, None, None, None +class TorchApplyRotaryEmb(torch.autograd.Function): + """ + TorchApplyRotaryEmb + """ + + @staticmethod + def forward(ctx, x, cos, sin, interleaved=False): + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + rotary_dim must be <= headdim + Apply rotary embedding to the first rotary_dim of x. + """ + _, seqlen, _, headdim = x.shape + rotary_seqlen, rotary_dim = cos.shape + rotary_dim *= 2 + assert rotary_dim <= headdim + assert seqlen <= rotary_seqlen + assert sin.shape == (rotary_seqlen, rotary_dim // 2) + x_ro = x[..., :rotary_dim] + x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2]) + x1, x2 = apply_rotary_torch( + x1, x2, rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d"), False + ) + ctx.save_for_backward(cos, sin) + ctx.interleaved = interleaved + return x + + @staticmethod + def backward(ctx, do): + cos, sin = ctx.saved_tensors + _, seqlen, _, _ = do.shape + rotary_dim = cos.shape[-1] + rotary_dim *= 2 + do_ro = do[..., :rotary_dim] + do1, do2 = do_ro.chunk(2, dim=-1) if not ctx.interleaved else (do_ro[..., ::2], do_ro[..., 1::2]) + do1, do2 = apply_rotary_torch( + do1, do2, rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d"), True + ) + return do, None, None, None, None + + +class TorchApplyRotaryEmbQKV_(torch.autograd.Function): + """ + TorchApplyRotaryEmbQKV_ + """ + + @staticmethod + def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False): + """ + qkv: (batch_size, seqlen, 3, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) + cos_k, sin_k: (seqlen, rotary_dim / 2), optional + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of + 1st half and 2nd half (GPT-NeoX style). + rotary_dim must be <= headdim + """ + _, seqlen, three, _, headdim = qkv.shape + assert three == 3 + rotary_seqlen, rotary_dim = cos.shape + rotary_dim *= 2 + assert rotary_dim <= headdim + assert seqlen <= rotary_seqlen + cos_k = cos if cos_k is None else cos_k + sin_k = sin if sin_k is None else sin_k + assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2) + q_ro = qkv[:, :, 0, :, :rotary_dim] + q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2], q_ro[..., 1::2]) + q1, q2 = apply_rotary_torch( + q1, q2, rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d"), False + ) + k_ro = qkv[:, :, 1, :, :rotary_dim] + k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2]) + k1, k2 = apply_rotary_torch( + k1, k2, rearrange(cos_k[:seqlen], "s d -> s 1 d"), rearrange(sin_k[:seqlen], "s d -> s 1 d"), False + ) + ctx.save_for_backward(cos, sin, cos_k, sin_k) + ctx.interleaved = interleaved + return qkv + + @staticmethod + def backward(ctx, dqkv): + cos, sin, cos_k, sin_k = ctx.saved_tensors + _, seqlen, _, _, _ = dqkv.shape + rotary_dim = cos.shape[-1] + rotary_dim *= 2 + dq_ro = dqkv[:, :, 0, :, :rotary_dim] + dq1, dq2 = dq_ro.chunk(2, dim=-1) if not ctx.interleaved else (dq_ro[..., ::2], dq_ro[..., 1::2]) + dq1, dq2 = apply_rotary_torch( + dq1, dq2, rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d"), True + ) + dk_ro = dqkv[:, :, 1, :, :rotary_dim] + dk1, dk2 = dk_ro.chunk(2, dim=-1) if not ctx.interleaved else (dk_ro[..., ::2], dk_ro[..., 1::2]) + dk1, dk2 = apply_rotary_torch( + dk1, dk2, rearrange(cos_k[:seqlen], "s d -> s 1 d"), rearrange(sin_k[:seqlen], "s d -> s 1 d"), True + ) + return dqkv, None, None, None, None, None + + apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply -legacy_apply_rotary_embed_qkv = LegacyApplyRotaryEmbQKV_.apply -legacy_apply_rotary_embed = LegacyApplyRotaryEmb.apply class RotaryEmbedding(torch.nn.Module): @@ -202,12 +338,27 @@ def _forward(self, qkv: torch.Tensor, indexes=0) -> Tuple[torch.Tensor, torch.Te self._sin_k_cached[indexes], ) + def _get_legacy_apply_rotary_functions(self): + if gpc.config.model.use_flash_attn: + from flash_attn.layers.rotary import ApplyRotaryEmb as LegacyApplyRotaryEmb + from flash_attn.layers.rotary import ( + ApplyRotaryEmbQKV_ as LegacyApplyRotaryEmbQKV_, + ) + + legacy_apply_rotary_embed_qkv = LegacyApplyRotaryEmbQKV_.apply + legacy_apply_rotary_embed = LegacyApplyRotaryEmb.apply + else: + legacy_apply_rotary_embed_qkv = TorchApplyRotaryEmbQKV_.apply + legacy_apply_rotary_embed = TorchApplyRotaryEmb.apply + return legacy_apply_rotary_embed_qkv, legacy_apply_rotary_embed + def _eval_forward(self, qkv, seqlen_offset=0): """ seqlen_offset: can be used in generation where the qkv being passed in is only the last token in the batch. """ self._update_cos_sin_cache(qkv, seqlen_offset + qkv.shape[1]) + legacy_apply_rotary_embed_qkv, _ = self._get_legacy_apply_rotary_functions() if self.scale is None: return legacy_apply_rotary_embed_qkv( qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:] @@ -225,12 +376,14 @@ def _single_forward(self, x, indexes=0): assert self.scale is None self._update_cos_sin_cache(x, indexes) x = x[None, ...] + _, legacy_apply_rotary_embed = self._get_legacy_apply_rotary_functions() ret = legacy_apply_rotary_embed(x, self._cos_cached[indexes], self._sin_cached[indexes]).squeeze(0) return ret def _single_eval_forward(self, x, seqlen_offset=0): assert self.scale is None self._update_cos_sin_cache(x, seqlen_offset + x.shape[1]) + _, legacy_apply_rotary_embed = self._get_legacy_apply_rotary_functions() return legacy_apply_rotary_embed(x, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:]) diff --git a/internlm/model/linear.py b/internlm/model/linear.py index d18308a89..ae2af069f 100644 --- a/internlm/model/linear.py +++ b/internlm/model/linear.py @@ -4,13 +4,17 @@ from typing import Optional import torch -from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear -from flash_attn.utils.distributed import all_reduce, reduce_scatter from torch import nn +from torch.distributed import ProcessGroup from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.model.utils import Silu, fused_dense_func_torch +from internlm.model.utils import ( + Silu, + all_reduce, + fused_dense_func_torch, + reduce_scatter, +) class ScaleColumnParallelLinear(nn.Linear): @@ -114,7 +118,47 @@ def forward(self, input): # pylint: disable=W0622 ) -class ColumnParallelLinearTorch(ColumnParallelLinear): +class ColumnParallelLinearTorch(nn.Linear): + """ + ColumnParallelLinearTorch. + Args: + in_features (int): size of each input sample + out_features (int): size of each output sample + process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`. + bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False + in the config. + sequence_parallel (bool): If sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: + we do an all_gather of x before doing the matmul. + If not, then the input is already gathered. + device (Optional[Union[str, torch.device]]): The device will be used. + dtype (Optional[torch.dtype]): The type of data. + weight_scale (int): For training stability. 1 by default. + """ + + def __init__( + self, + in_features: int, + out_features: int, + process_group: ProcessGroup, + bias: bool = True, + sequence_parallel=True, + multiple_of=1, + device=None, + dtype=None, + ) -> None: + world_size = torch.distributed.get_world_size(process_group) + if out_features % multiple_of: + raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}") + multiple = out_features // multiple_of + # We want to split @multiple across world_size, but it could be an uneven split + div = multiple // world_size + mod = multiple % world_size + # The first @mod ranks get @div + 1 copies, the rest get @div copies + local_multiple = div + int(torch.distributed.get_rank(process_group) < mod) + super().__init__(in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype) + self.process_group = process_group + self.sequence_parallel = sequence_parallel + def forward(self, x): # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: # we do an all_gather of x before doing the matmul. @@ -125,7 +169,55 @@ def forward(self, x): ) -class RowParallelLinearTorch(RowParallelLinear): +class RowParallelLinearTorch(nn.Linear): + """ + RowParallelLinearTorch. + Args: + in_features (int): size of each input sample + out_features (int): size of each output sample + process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`. + bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False + in the config. + sequence_parallel (bool): If sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: + we do an all_gather of x before doing the matmul. + If not, then the input is already gathered. + device (Optional[Union[str, torch.device]]): The device will be used. + dtype (Optional[torch.dtype]): The type of data. + weight_scale (int): For training stability. 1 by default. + """ + + def __init__( + self, + in_features: int, + out_features: int, + process_group: ProcessGroup, + bias: bool = True, + sequence_parallel=True, + multiple_of=1, + device=None, + dtype=None, + ) -> None: + world_size = torch.distributed.get_world_size(process_group) + rank = torch.distributed.get_rank(process_group) + if in_features % multiple_of: + raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}") + multiple = in_features // multiple_of + # We want to split @multiple across world_size, but it could be an uneven split + div = multiple // world_size + mod = multiple % world_size + # The first @mod ranks get @div + 1 copies, the rest get @div copies + local_multiple = div + int(torch.distributed.get_rank(process_group) < mod) + # Only rank 0 will have bias + super().__init__( + local_multiple * multiple_of, + out_features, + bias=bias and rank == 0, + device=device, + dtype=dtype, + ) + self.process_group = process_group + self.sequence_parallel = sequence_parallel + def forward(self, x): """ We're doing Tensor Parallel with sequence parallelism: we do the matmul and then diff --git a/internlm/model/loss.py b/internlm/model/loss.py index ac92b4b97..4c405d1e9 100644 --- a/internlm/model/loss.py +++ b/internlm/model/loss.py @@ -1,7 +1,6 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from flash_attn.losses.cross_entropy import CrossEntropyLoss as FlashCrossEntropyLoss from torch import nn from internlm.core.context import ParallelMode @@ -24,7 +23,11 @@ def __init__(self, parallel_output=True, label_smoothing=0): label_smoothing = 0 self.label_smoothing = label_smoothing - if parallel_output: + if gpc.config.model.use_flash_attn and parallel_output: + from flash_attn.losses.cross_entropy import ( + CrossEntropyLoss as FlashCrossEntropyLoss, + ) + self.loss_fn = FlashCrossEntropyLoss( reduction="mean", inplace_backward=True, diff --git a/internlm/model/metrics.py b/internlm/model/metrics.py index 55e0219a6..14c9902fe 100644 --- a/internlm/model/metrics.py +++ b/internlm/model/metrics.py @@ -1,7 +1,7 @@ from typing import List import torch -from flash_attn.losses.cross_entropy import CrossEntropyLoss as FlashCrossEntropyLoss +from torch import nn from torch_scatter import scatter from internlm.core.context import ParallelMode @@ -208,9 +208,16 @@ def __init__(self, device, dp_pg, dataset_types: List[str] = None) -> None: self.ds_loss = torch.zeros(self.total_type_count, dtype=torch.float, device=device) self.ds_token_num = torch.zeros(self.total_type_count, dtype=torch.float, device=device) - self.loss_fn = FlashCrossEntropyLoss( - reduction="none", inplace_backward=True, process_group=gpc.get_group(ParallelMode.TENSOR) - ) + if gpc.config.model.use_flash_attn: + from flash_attn.losses.cross_entropy import ( + CrossEntropyLoss as FlashCrossEntropyLoss, + ) + + self.loss_fn = FlashCrossEntropyLoss( + reduction="none", inplace_backward=True, process_group=gpc.get_group(ParallelMode.TENSOR) + ) + else: + self.loss_fn = nn.CrossEntropyLoss(reduction="none") def update(self, logits, labels, type_ids=None): with torch.no_grad(): diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py index a47a5cdd1..ac7b3c6b6 100644 --- a/internlm/model/modeling_internlm.py +++ b/internlm/model/modeling_internlm.py @@ -6,8 +6,6 @@ from typing import Optional import torch -from flash_attn.modules.embedding import ParallelGPT2Embeddings -from flash_attn.modules.mlp import ParallelFusedMLP from torch import nn from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode @@ -114,7 +112,7 @@ def __init__( self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) - if use_swiglu: + if use_swiglu or not use_flash_attn: self.mlp = FeedForward( hidden_size, int(hidden_size * mlp_ratio), @@ -125,6 +123,8 @@ def __init__( dtype=dtype, ) else: + from flash_attn.modules.mlp import ParallelFusedMLP + self.mlp = ParallelFusedMLP( hidden_size, int(hidden_size * mlp_ratio), @@ -311,9 +311,11 @@ def __init__( else: head_cls = ScaleColumnParallelLinear if first: - if embed_split_hidden: + if embed_split_hidden or not use_flash_attn: self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) else: + from flash_attn.modules.embedding import ParallelGPT2Embeddings + self.embedding = ParallelGPT2Embeddings( embed_dim=hidden_size, vocab_size=vocab_size, diff --git a/internlm/model/modeling_moe.py b/internlm/model/modeling_moe.py index df6c7a846..63daea68a 100644 --- a/internlm/model/modeling_moe.py +++ b/internlm/model/modeling_moe.py @@ -5,8 +5,6 @@ from typing import Optional import torch -from flash_attn.modules.embedding import ParallelGPT2Embeddings -from flash_attn.modules.mlp import ParallelFusedMLP from torch import nn from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode @@ -141,7 +139,7 @@ def __init__( self.moe_use_residual = moe_use_residual ep_size = gpc.get_world_size(ParallelMode.EXPERT) if num_experts <= 1: # dense, not MoE - if use_swiglu: + if use_swiglu or not use_flash_attn: self.mlp = FeedForward( hidden_size, int(hidden_size * mlp_ratio), @@ -152,6 +150,8 @@ def __init__( dtype=dtype, ) else: + from flash_attn.modules.mlp import ParallelFusedMLP + self.mlp = ParallelFusedMLP( hidden_size, int(hidden_size * mlp_ratio), @@ -375,9 +375,11 @@ def __init__( else: head_cls = ScaleColumnParallelLinear if first: - if embed_split_hidden: + if embed_split_hidden or not use_flash_attn: self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) else: + from flash_attn.modules.embedding import ParallelGPT2Embeddings + self.embedding = ParallelGPT2Embeddings( embed_dim=hidden_size, vocab_size=vocab_size, diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py index e28db6ac6..92179f205 100644 --- a/internlm/model/multi_head_attention.py +++ b/internlm/model/multi_head_attention.py @@ -8,29 +8,6 @@ import torch import torch.nn.functional as F from einops import rearrange - -try: - from flash_attn.flash_attn_interface import flash_attn_unpadded_func -except ImportError: - try: - from flash_attn.flash_attn_interface import ( - flash_attn_unpadded_kvpacked_func as flash_attn_unpadded_func, - ) - except ImportError: - try: - from flash_attn.flash_attn_interface import ( - flash_attn_varlen_kvpacked_func as flash_attn_unpadded_func, - ) - except ImportError: - raise ImportError("Please check your flash_attn version >= 1.0.5.") - -from flash_attn.modules.mha import ( - CrossAttention, - FlashCrossAttention, - FlashSelfAttention, - SelfAttention, - _update_kv_cache, -) from torch import nn from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode @@ -39,6 +16,162 @@ from internlm.model.linear import ColumnParallelLinearTorch, RowParallelLinearTorch +class SelfAttention(nn.Module): + """Implement the scaled dot product attention with softmax. + Arguments + --------- + softmax_scale: The temperature to use for the softmax attention. + (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.0) + """ + + def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): + super().__init__() + self.causal = causal + self.softmax_scale = softmax_scale + self.drop = nn.Dropout(attention_dropout) + + def forward(self, qkv, causal=None, key_padding_mask=None): + """Implements the multihead softmax attention. + Arguments + --------- + qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) + causal: if passed, will override self.causal + key_padding_mask: boolean mask to apply to the attention weights. True means to keep, + False means to mask out. (B, S) + """ + batch_size, seqlen = qkv.shape[0], qkv.shape[1] + causal = self.causal if causal is None else causal + q, k, v = qkv.unbind(dim=2) + softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1]) + scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) + if key_padding_mask is not None: + padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device) + padding_mask.masked_fill_(key_padding_mask, 0.0) + # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) + scores = scores + rearrange(padding_mask, "b s -> b 1 1 s") + if causal: + # "triu_tril_cuda_template" not implemented for 'BFloat16' + # So we have to construct the mask in float + causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1) + # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) + scores = scores + causal_mask.to(dtype=scores.dtype) + attention = torch.softmax(scores, dim=-1, dtype=v.dtype) + attention_drop = self.drop(attention) + output = torch.einsum("bhts,bshd->bthd", attention_drop, v) + return output + + +class CrossAttention(nn.Module): + """Implement the scaled dot product attention with softmax. + Arguments + --------- + softmax_scale: The temperature to use for the softmax attention. + (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.0) + """ + + def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): + super().__init__() + self.causal = causal + self.softmax_scale = softmax_scale + self.drop = nn.Dropout(attention_dropout) + + def forward(self, q, kv, causal=None, key_padding_mask=None): + """Implements the multihead softmax attention. + Arguments + --------- + q: The tensor containing the query. (B, Sq, H, D) + kv: The tensor containing the key and value. (B, Sk, 2, H, D) + causal: if passed, will override self.causal + key_padding_mask: boolean mask to apply to the attention weights. True means to keep, + False means to mask out. (B, Sk) + """ + batch_size, seqlen_q = q.shape[0], q.shape[1] + causal = self.causal if causal is None else causal + seqlen_k = kv.shape[1] + assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3] + k, v = kv.unbind(dim=2) + softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1]) + scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) + if key_padding_mask is not None: + padding_mask = torch.full((batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device) + padding_mask.masked_fill_(key_padding_mask, 0.0) + # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) + scores = scores + rearrange(padding_mask, "b s -> b 1 1 s") + if causal: + # "triu_tril_cuda_template" not implemented for 'BFloat16' + # So we have to construct the mask in float + causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0, device=scores.device), 1) + # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) + scores = scores + causal_mask.to(dtype=scores.dtype) + attention = torch.softmax(scores, dim=-1, dtype=v.dtype) + attention_drop = self.drop(attention) + output = torch.einsum("bhts,bshd->bthd", attention_drop, v) + return output + + +def _update_kv_cache(kv, inference_params, layer_idx): + """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)""" + # Pre-allocate memory for key-values for inference. + num_heads, head_dim = kv.shape[-2:] + if layer_idx not in inference_params.key_value_memory_dict: + kv_cache = torch.empty( + inference_params.max_batch_size, + inference_params.max_sequence_len, + 2, + num_heads, + head_dim, + dtype=kv.dtype, + device=kv.device, + ) + inference_params.key_value_memory_dict[layer_idx] = kv_cache + else: + if not inference_params.fused_ft_kernel: + kv_cache = inference_params.key_value_memory_dict[layer_idx] + else: + # For FT, k_cache has shape (b, h, headdim / packsize, s, packsize) + # where packsize = 4 if fp32, 8 if fp16 or bf16. + # v_cache has shape (b, h, s, headdim) + k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx] + kv_cache = None + # Adjust key and value for inference + batch_start = inference_params.batch_size_offset + batch_end = batch_start + kv.shape[0] + sequence_start = inference_params.sequence_len_offset + sequence_end = sequence_start + kv.shape[1] + assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0]) + assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2]) + # Copy key and values. + if not inference_params.fused_ft_kernel: + assert kv_cache is not None + kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv + kv = kv_cache[batch_start:batch_end, :sequence_end, ...] + return kv + else: + assert inference_params.sequence_len_offset == 0 + # FT kernel requires different layouts for the k_cache and v_cache. + assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32] + packsize = 4 if kv.dtype == torch.float32 else 8 + if kv_cache is not None: + kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv + k_cache = rearrange( + kv_cache[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize + ).contiguous() + v_cache = rearrange(kv_cache[:, :, 1], "b s h d -> b h s d").contiguous() + inference_params.key_value_memory_dict[layer_idx] = (k_cache, v_cache) + else: + k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange( + kv[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize + ) + v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(kv[:, :, 1], "b s h d -> b h s d") + return kv + + class MHA(nn.Module): """ Multi-head self-attention and cross-attention. @@ -47,23 +180,19 @@ class MHA(nn.Module): embed_dim (int): The dimention of hidden state. num_heads (int): The number of attention heads. process_group (torch.distributed.ProcessGroup): The group of the current device for `parallel_mode`. - bias (boolean): Whether the bias is needed for linears. Will be used when initializing QKV matrix and - output projection. True by default. + max_position_embeddings (int): max position embeddings, 2048 by default. dropout (float): The dropout rate for cross attention and self attention. 0.0 by default. softmax_scale (float): The temperature to use for the softmax attention. causal (boolean): Whether to apply causal attention mask. False by default. layer_idx (int): The index of current layer. None by default. + use_dynamic_ntk_rope (bool): whether use dynamic ntk rope, false by default. rotary_emb_dim (int): The dimention of Rotary Embedding. 0 by default. rotary_emb_scale_base (int): The scaling factor of Rotary Embedding. If scale_base > 0, this implements XPos(Sun et al., https://arxiv.org/abs/2212.10554). 0 by default. - use_flash_attn (boolean): Whether to use flash attention or not.If False, vanilla attention module will be used. - False by default. - sequence_parallel (boolean): If True, we're doing Tensor Parallel with sequence parallelism. An all_gather_raw - of x will be done before doing the matmul. - device (Optional[Union[str, torch.device]]): The device will be used. - dtype (Optional[torch.dtype]): The type of data. use_flash_attn (bool): Whether to use flash-attn. True by default. rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. + device (Optional[Union[str, torch.device]]): The device will be used. + dtype (Optional[torch.dtype]): The type of data. """ @@ -123,8 +252,14 @@ def __init__( **factory_kwargs, ) # according to https://spaces.ac.cn/archives/9577 - inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention - inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention + if gpc.config.model.use_flash_attn: + from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention + + inner_attn_cls = FlashSelfAttention + inner_cross_attn_cls = FlashCrossAttention + else: + inner_attn_cls = SelfAttention + inner_cross_attn_cls = CrossAttention self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) self.inner_cross_attn = inner_cross_attn_cls( causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout @@ -300,9 +435,40 @@ def _forward(self, x, seqlen=None, inference_params=None, **kwargs): # pylint: if total_kv.dtype not in [torch.float16, torch.bfloat16]: total_kv = total_kv.to(torch.bfloat16) - output = flash_attn_unpadded_func( - total_q, total_kv, cu_seqlens, cu_seqlens, max_seqlen_q, max_seqlen_k, 0.0, None, True, False - ).to(x.dtype) + if gpc.config.model.use_flash_attn: + try: + from flash_attn.flash_attn_interface import ( + flash_attn_unpadded_func, + ) + except ImportError: + try: + from flash_attn.flash_attn_interface import ( + flash_attn_unpadded_kvpacked_func as flash_attn_unpadded_func, + ) + except ImportError: + try: + from flash_attn.flash_attn_interface import ( + flash_attn_varlen_kvpacked_func as flash_attn_unpadded_func, + ) + except ImportError: + raise ImportError("Please check your flash_attn version >= 1.0.5.") + + output = flash_attn_unpadded_func( + total_q, + total_kv, + cu_seqlens, + cu_seqlens, + max_seqlen_q, + max_seqlen_k, + 0.0, + None, + True, + False, + ).to(x.dtype) + else: + attn_scores = torch.matmul(total_q, total_kv.transpose(-2, -1)) / (cu_seqlens**0.5) + attn_weights = F.softmax(attn_scores, dim=-1) + output = torch.matmul(attn_weights, total_kv) context = torch.zeros_like(q) context = context.masked_scatter_(attn_mask4flsh.view(bsz, -1, 1, 1), output) diff --git a/internlm/model/utils.py b/internlm/model/utils.py index 46fba5920..409d83f0a 100644 --- a/internlm/model/utils.py +++ b/internlm/model/utils.py @@ -5,14 +5,8 @@ import torch import torch.nn.functional as F -from flash_attn.ops.fused_dense import FusedDenseFunc -from flash_attn.utils.distributed import ( - all_gather_raw, - all_reduce_raw, - reduce_scatter_raw, -) from torch import Tensor -from torch.cuda.amp import custom_bwd +from torch.cuda.amp import custom_bwd, custom_fwd from torch.distributed import ProcessGroup from internlm.core.context import global_context as gpc @@ -21,6 +15,72 @@ logger = get_logger(__file__) +# Raw operation, does not support autograd, but does support async +def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): + world_size = torch.distributed.get_world_size(process_group) + output = torch.empty(world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device) + handle = torch.distributed.all_gather_into_tensor( + output, input_.contiguous(), group=process_group, async_op=async_op + ) + return output, handle + + +# Raw operation, does not support autograd, but does support async +def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): + world_size = torch.distributed.get_world_size(process_group) + assert input_.shape[0] % world_size == 0 + output = torch.empty(input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device) + handle = torch.distributed.reduce_scatter_tensor( + output, input_.contiguous(), group=process_group, async_op=async_op + ) + return output, handle + + +# Raw operation, does not support autograd, but does support async +def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): + input_ = input_.contiguous() + handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op) + return input_, handle + + +class ReduceScatterFunc(torch.autograd.Function): + """Reduce scatter the input from the sequence parallel region and concatenate.""" + + @staticmethod + def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor: + ctx.process_group = process_group + output, _ = reduce_scatter_raw(input_, process_group) + return output + + @staticmethod + def backward(ctx, grad_output: Tensor): + grad_input, _ = all_gather_raw(grad_output, ctx.process_group) + return grad_input, None + + +# Supports autograd, but does not support async +reduce_scatter = ReduceScatterFunc.apply + + +class AllReduceFunc(torch.autograd.Function): + """Gather the input from sequence parallel region and concatenate.""" + + @staticmethod + def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor: + ctx.process_group = process_group + output, _ = all_reduce_raw(input_, process_group) + return output + + @staticmethod + def backward(ctx, grad_output: Tensor): + _ = ctx # avoid lint warning W0613 + return grad_output, None + + +# Supports autograd, but does not support async +all_reduce = AllReduceFunc.apply + + def _split(input_, parallel_mode, dim=-1): # skip if only one rank involved world_size = gpc.get_world_size(parallel_mode) @@ -96,9 +156,47 @@ def linear_bias_wgrad_torch(my_input, grad_output, has_d_bias): # adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py -class FusedDenseFuncTorch(FusedDenseFunc): +class FusedDenseFuncTorch(torch.autograd.Function): """A custom PyTorch module extending FusedDenseFunc.""" + @staticmethod + @custom_fwd + def forward(ctx, x, weight, bias, return_residual=False, process_group=None, sequence_parallel=True): + """ + If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel + with sequence parallelism: we do an all_gather_raw of x before doing the matmul. + """ + ctx.compute_weight_gradient = weight.requires_grad + ctx.return_residual = return_residual + ctx.process_group = process_group + ctx.sequence_parallel = sequence_parallel + + if torch.is_autocast_enabled(): + x = x.to(dtype=torch.get_autocast_gpu_dtype()) + x = x.contiguous() + if process_group is not None and sequence_parallel: + # We want to kick off the all_gather early, before weight dtype conversion + total_x, handle_x = all_gather_raw(x, process_group, async_op=True) + else: + total_x = x + + if torch.is_autocast_enabled(): + weight = weight.to(dtype=torch.get_autocast_gpu_dtype()) + bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None + weight = weight.contiguous() + if process_group is not None and sequence_parallel: + handle_x.wait() + batch_shape, n = total_x.shape[:-1], total_x.shape[-1] + batch_dim = batch_shape.numel() + if min(batch_dim, n, *weight.shape) > 65535 * 32: + raise RuntimeError("fused_dense only supports matrix dims <= 2M") + output = F.linear(total_x, weight, bias) + if ctx.compute_weight_gradient: + ctx.save_for_backward(x, weight) + else: + ctx.save_for_backward(weight) + return output if not return_residual else (output, x) + @staticmethod @custom_bwd def backward(ctx, grad_output, *args): @@ -158,7 +256,15 @@ def fused_dense_func_torch( dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or ( x.dtype == torch.float32 and torch.is_autocast_enabled() ) - if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible: + if ( + gpc.config.model.use_flash_attn + and x.is_cuda + and weight.is_cuda + and (bias is None or bias.is_cuda) + and dtype_eligible + ): + from flash_attn.ops.fused_dense import FusedDenseFunc + return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group, sequence_parallel) else: return FusedDenseFuncTorch.apply(x, weight, bias, return_residual, process_group, sequence_parallel) diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py index 474bfd2a9..d838e8b80 100644 --- a/internlm/train/training_internlm.py +++ b/internlm/train/training_internlm.py @@ -10,8 +10,6 @@ import torch import torch.distributed as dist -from flash_attn.modules.embedding import ParallelGPT2Embeddings -from flash_attn.modules.mlp import ParallelFusedMLP from torch import nn from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp.fully_sharded_data_parallel import ( @@ -122,7 +120,10 @@ def initialize_model(): def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]): - if gpc.config.parallel.zero1.fsdp: + if gpc.config.parallel.zero1.fsdp and gpc.config.model.use_flash_attn: + from flash_attn.modules.embedding import ParallelGPT2Embeddings + from flash_attn.modules.mlp import ParallelFusedMLP + # set wrap_policy for fsdp wrap transformer_wrap_policy = functools.partial( transformer_auto_wrap_policy, diff --git a/internlm/utils/gputest.py b/internlm/utils/gputest.py index 48ec0e350..4085b8790 100644 --- a/internlm/utils/gputest.py +++ b/internlm/utils/gputest.py @@ -7,9 +7,9 @@ import torch import torch.distributed as dist -from flash_attn.modules.mha import FlashSelfAttention, SelfAttention from torch.utils import benchmark +from internlm.model.multi_head_attention import SelfAttention from internlm.monitor import send_alert_message from internlm.utils.logger import get_logger from internlm.utils.megatron_timers import megatron_timer as timer @@ -233,7 +233,13 @@ def bench_gpu(use_flash_attn=True): batch_size, seqlen = 2, 1024 nheads = dim // headdim - inner_attn = FlashSelfAttention if use_flash_attn else SelfAttention + if use_flash_attn: + from flash_attn.modules.mha import FlashSelfAttention + + inner_attn = FlashSelfAttention + else: + inner_attn = SelfAttention + inner_attn = inner_attn(causal=True, softmax_scale=None, attention_dropout=0) qkv = torch.randn( diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py index b9326deb8..464b9878a 100644 --- a/internlm/utils/model_checkpoint.py +++ b/internlm/utils/model_checkpoint.py @@ -157,13 +157,14 @@ def get_model_topology(model): concatenated along the dimension 'dim'. """ - from flash_attn.modules.embedding import VocabParallelEmbedding + if gpc.config.model.use_flash_attn: + from flash_attn.modules.embedding import VocabParallelEmbedding topos = {} for name, module in model.named_modules(): # If it does not meet these conditions, it is shared between various tp/dp, and it is necessary to assert # that they are consistent. - if isinstance(module, VocabParallelEmbedding): + if gpc.config.model.use_flash_attn and isinstance(module, VocabParallelEmbedding): topos[name] = {"dim": 0} return topos