diff --git a/internlm/core/parallel/comm/isp.py b/internlm/core/parallel/comm/isp.py
index ca3e5507..5406dc7e 100644
--- a/internlm/core/parallel/comm/isp.py
+++ b/internlm/core/parallel/comm/isp.py
@@ -900,7 +900,12 @@ def auto_wrap_distributed_attention(cls: nn.Module) -> Callable[[bool, Any, floa
     def _attetion_constructor(
         local_attn_cls: type, causal=False, softmax_scale=None, attention_dropout=0.0
     ) -> nn.Module:
-        if gpc.config.parallel["tensor"].get("mode", "mtp") != "isp":
+        try:
+            tp_mode = gpc.config.parallel["tensor"].get("mode", "mtp")
+        except AttributeError:
+            tp_mode = "mtp"
+
+        if tp_mode != "isp":
             return local_attn_cls(causal, softmax_scale, attention_dropout)
         else:
             return DistributedAttention(
diff --git a/internlm/model/ops/attention.py b/internlm/model/ops/attention.py
index 9205652a..9bb53806 100644
--- a/internlm/model/ops/attention.py
+++ b/internlm/model/ops/attention.py
@@ -79,7 +79,7 @@ def _nyi_attn(func_name, *args, **kwargs):  # pylint: disable=W0613

 def _flash_float32_compatibility_wrapper(input_idxs: Tuple, flash_func: Callable, *args, **kwargs):
     if gpc.config.model.dtype is torch.float32:
-        inputs = (args[idx] for idx in input_idxs)
+        inputs = [args[idx] for idx in input_idxs]
         input_dtype = inputs[0].dtype
         other_args = [args[idx] for idx in range(len(inputs), len(args))]

@@ -194,10 +194,35 @@ def _flash_fixedlen_qkvsplited_attn(q, k, v, dropout_p=0.0, softmax_scale=None,

 # npu flash attention operators
-# TODO: should we add _flash_float32_compatibility_wrapper support for npu.


+def _npu_varlen_qkvsplited_attn(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens_q,
+    cu_seqlens_k,
+    max_seqlen_q,  # pylint: disable=W0613
+    max_seqlen_k,  # pylint: disable=W0613
+    dropout_p=0.0,
+    softmax_scale=None,
+    causal=False,
+):
+    return _flash_float32_compatibility_wrapper(
+        (0, 1, 2),
+        _npu_varlen_qkvsplited_func,
+        q,
+        k,
+        v,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        dropout_p,
+        softmax_scale,
+        causal,
+    )
+
+
-def _npu_varlen_qkvsplited_attn(
+def _npu_varlen_qkvsplited_func(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
@@ -208,17 +233,32 @@ def _npu_varlen_qkvsplited_attn(
     dropout_p=0.0,
     softmax_scale=None,
     causal=False,
+    use_fixlen=False,
 ):
-    # TODO: support npu native varlen flash attention
+    """Support Huawei Ascend's torch_npu flash attention.
+    Tested version:
+    torch: 2.1.0+cpu
+    torch_npu: 2.1.0.post3+git7c4136d
+    cann: 8.0.RC1.alpha003
+    """
     packed_length = q.size(dim=1)
+    softmax_scale = softmax_scale or 1.0 / math.sqrt(q.shape[-1])

-    q = unpack_qkv_before_attn(q, cu_seqlens=cu_seqlens_q)
-    k = unpack_qkv_before_attn(k, cu_seqlens=cu_seqlens_k)
-    v = unpack_qkv_before_attn(v, cu_seqlens=cu_seqlens_k)
+    if use_fixlen:
-    output = _npu_fixedlen_qkvsplited_attn(q, k, v, dropout_p, softmax_scale, causal)
+        q = unpack_qkv_before_attn(q, cu_seqlens=cu_seqlens_q)
+        k = unpack_qkv_before_attn(k, cu_seqlens=cu_seqlens_k)
+        v = unpack_qkv_before_attn(v, cu_seqlens=cu_seqlens_k)
-    return pack_output_after_attn(output, cu_seqlens_q, packed_length)
+        output = _npu_fixedlen_qkvsplited_attn(q, k, v, dropout_p, softmax_scale, causal)
+
+        output = pack_output_after_attn(output, cu_seqlens_q, packed_length)
+    else:
+        output = _npu_fused_varlen_qkvsplited_attn(
+            q, k, v, dropout_p, softmax_scale, causal, max_seqlen_q, max_seqlen_k, cu_seqlens_q, cu_seqlens_k
+        )
+
+    return output


 def _npu_fixedlen_qkvsplited_attn(
@@ -236,6 +276,7 @@ def _npu_fixedlen_qkvsplited_attn(
     q, k, v = q.squeeze(dim=2), k.squeeze(dim=2), v.squeeze(dim=2)
     _, seqlen, n_head, _ = q.shape
+    sparse_mode = 0
     attention_mask = torch.triu(torch.ones(seqlen, seqlen, device=get_current_device()), 1).bool()

     return _origin_npu_fixedlen_qkvsplited_func(
@@ -247,25 +288,71 @@ def _npu_fixedlen_qkvsplited_attn(
         pse=None,
         atten_mask=attention_mask,
         scale=softmax_scale,
-        sparse_mode=0,  # If necessary, expose the interface
+        sparse_mode=sparse_mode,  # If necessary, expose the interface
         pre_tockens=seqlen,  # Used for sparse calculations, representing the left boundary of the slides window
         next_tockens=0,  # If necessary, expose the interface
         keep_prob=1 - dropout_p,
         inner_precise=0,  # If necessary, expose the interface
-    )
+    )[0]


-def _npu_varlen_qkvpacked_attn(
-    qkv: torch.Tensor, cu_seqlens, max_seqlen, dropout_p, softmax_scale=None, causal=False  # pylint: disable=W0613
+def _npu_fused_varlen_qkvsplited_attn(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    dropout_p: float,
+    softmax_scale=None,
+    causal=False,
+    max_seqlen_q: int = None,
+    max_seqlen_k: int = None,
+    cu_seqlens_q=None,
+    cu_seqlens_kv=None,
+    deterministic=False,
 ):
-    # TODO: support npu native varlen flash attention
-    packed_length = qkv.size(dim=1)
+    assert causal is True
+    assert q.dtype in (torch.bfloat16, torch.float16)

-    qkv = unpack_qkv_before_attn(qkv, cu_seqlens=cu_seqlens)
+    if len(q.shape) == 4:  # [1, packedseqlen, n_head, headdim]
+        q, k, v = q.squeeze(dim=0), k.squeeze(dim=0), v.squeeze(dim=0)

-    output = _npu_fixedlen_qkvpacked_attn(qkv, dropout_p, softmax_scale, causal)
+    S, N = max(max_seqlen_q, max_seqlen_k), q.shape[1]
+    device = get_current_device()
+    sparse_mode = 0

-    return pack_output_after_attn(output, cu_seqlens, packed_length)
+    if max_seqlen_k > 2048 and max_seqlen_q > 2048:
+        sparse_mode = 2
+        max_seqlen_k = 2048
+        max_seqlen_q = 2048
+
+    attention_mask = torch.triu(torch.ones(max_seqlen_q, max_seqlen_k, device=device), 1).bool()
+    cu_seqlens_q = cu_seqlens_q[1:].tolist()
+    cu_seqlens_kv = cu_seqlens_kv[1:].tolist()
+
+    return _origin_npu_fixedlen_qkvsplited_func(
+        query=q,
+        key=k,
+        value=v,
+        head_num=N,
+        input_layout="TND",
+        pse=None,
+        atten_mask=attention_mask,
+        scale=softmax_scale,
+        sparse_mode=sparse_mode,
+        pre_tockens=S,  # Used for sparse calculations, representing the left boundary of the slides window
+        next_tockens=0,
+        keep_prob=1 - dropout_p,
+        inner_precise=0 if not deterministic else 2,
+        actual_seq_kvlen=cu_seqlens_kv,
+        actual_seq_qlen=cu_seqlens_q,
+    )[0].unsqueeze(dim=0)
+
+
+def _npu_varlen_qkvpacked_attn(
+    qkv: torch.Tensor, cu_seqlens, max_seqlen, dropout_p, softmax_scale=None, causal=False  # pylint: disable=W0613
+):
+    # TODO: support npu native varlen flash attention
+    q, k, v = qkv.unbind(dim=2)
+    return _npu_varlen_qkvsplited_attn(q, k, v, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal)


 def _npu_fixedlen_qkvpacked_attn(qkv: torch.Tensor, dropout_p: float, softmax_scale=None, causal=False):
@@ -285,14 +372,20 @@ def _npu_varlen_kvpacked_attn(
     causal=False,
 ):
     # TODO: support npu native varlen flash attention
-    packed_length = q.size(dim=1)
-
-    q = unpack_qkv_before_attn(q, cu_seqlens=cu_seqlens_q)
-    kv = unpack_qkv_before_attn(kv, cu_seqlens=cu_seqlens_k)
-
-    output = _npu_fixedlen_kvpacked_attn(q, kv, dropout_p, softmax_scale, causal)
-
-    return pack_output_after_attn(output, cu_seqlens_q, packed_length)
+    k, v = kv.unbind(dim=2)
+    k, v = k.squeeze(dim=2), v.squeeze(dim=2)
+    return _npu_varlen_qkvsplited_attn(
+        q,
+        k,
+        v,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        dropout_p,
+        softmax_scale,
+        causal,
+    )


 def _npu_fixedlen_kvpacked_attn(q: torch.Tensor, kv: torch.Tensor, dropout_p: float, softmax_scale=None, causal=False):
@@ -335,12 +428,6 @@ def _deeplink_fixedlen_qkvsplited_attn(*args, **kwargs):


 # torch attention operators
-
-
-def _torch_varlen_qkvpacked_attn(*args, **kwargs):
-    _nyi_attn("_torch_varlen_qkvpacked_attn", *args, **kwargs)
-
-
 # adpated from https://github.com/Dao-AILab/flash-attention/blob/v2.2.1/flash_attn/modules/mha.py
 def _torch_fixedlen_qkvpacked_attn(qkv: torch.Tensor, dropout, softmax_scale=None, causal=False, key_padding_mask=None):
     batch_size, seqlen = qkv.shape[0], qkv.shape[1]
@@ -369,10 +456,6 @@ def _torch_fixedlen_qkvpacked_attn(qkv: torch.Tensor, dropout, softmax_scale=Non
     return output


-def _torch_varlen_kvpacked_attn(*args, **kwargs):
-    _nyi_attn("_torch_varlen_kvpacked_attn", *args, **kwargs)
-
-
 # adpated from https://github.com/Dao-AILab/flash-attention/blob/v2.2.1/flash_attn/modules/mha.py
 def _torch_fixedlen_kvpacked_attn(
     q: torch.Tensor, kv: torch.Tensor, dropout, softmax_scale=None, causal=False, key_padding_mask=None
@@ -407,10 +490,6 @@ def _torch_fixedlen_kvpacked_attn(
     return output


-def _torch_varlen_qkvsplited_attn(*args, **kwargs):
-    _nyi_attn("_torch_varlen_qkvsplited_attn", *args, **kwargs)
-
-
 def _torch_fixedlen_qkvsplited_attn(
     q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, dropout, softmax_scale=None, causal=False, key_padding_mask=None
 ):
@@ -418,6 +497,71 @@ def _torch_fixedlen_qkvsplited_attn(
     return _torch_fixedlen_kvpacked_attn(q, kv, dropout, softmax_scale, causal, key_padding_mask)


+def _torch_varlen_qkvsplited_attn(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    cu_seqlens_q,
+    cu_seqlens_k,
+    max_seqlen_q,  # pylint: disable=W0613
+    max_seqlen_k,  # pylint: disable=W0613
+    dropout,
+    softmax_scale=None,
+    causal=False,
+    key_padding_mask=None,
+):
+    kv = torch.stack([k, v], dim=2)
+    packed_length = q.size(dim=1)
+
+    q = unpack_qkv_before_attn(q, cu_seqlens=cu_seqlens_q)
+    kv = unpack_qkv_before_attn(kv, cu_seqlens=cu_seqlens_k)
+
+    output = _torch_fixedlen_kvpacked_attn(q, kv, dropout, softmax_scale, causal, key_padding_mask)
+
+    return pack_output_after_attn(output, cu_seqlens_q, packed_length)
+
+
+def _torch_varlen_qkvpacked_attn(
+    qkv: torch.Tensor,
+    cu_seqlens,
+    max_seqlen,  # pylint: disable=W0613
+    dropout,
+    softmax_scale=None,
+    causal=False,
+    key_padding_mask=None,
+):
+
+    packed_length = qkv.size(dim=1)
+    qkv = unpack_qkv_before_attn(qkv, cu_seqlens=cu_seqlens)
+
+    output = _torch_fixedlen_qkvpacked_attn(qkv, dropout, softmax_scale, causal, key_padding_mask)
+
+    return pack_output_after_attn(output, cu_seqlens, packed_length)
+
+
+def _torch_varlen_kvpacked_attn(
+    q: torch.Tensor,
+    kv: torch.Tensor,
+    cu_seqlens_q,
+    cu_seqlens_k,
+    max_seqlen_q,  # pylint: disable=W0613
+    max_seqlen_k,  # pylint: disable=W0613
+    dropout,
+    softmax_scale=None,
+    causal=False,
+    key_padding_mask=None,
+):
+
+    packed_length = q.size(dim=1)
+
+    q = unpack_qkv_before_attn(q, cu_seqlens=cu_seqlens_q)
+    kv = unpack_qkv_before_attn(kv, cu_seqlens=cu_seqlens_k)
+
+    output = _torch_fixedlen_kvpacked_attn(q, kv, dropout, softmax_scale, causal, key_padding_mask)
+
+    return pack_output_after_attn(output, cu_seqlens_q, packed_length)
+
+
 @auto_wrap_distributed_attention
 class SelfAttention(nn.Module):
     """Implements scaled dot-product attention with optional softmax scaling.
diff --git a/tests/test_model/test_npu_ops/test_flash_attention.py b/tests/test_model/test_npu_ops/test_flash_attention.py
index 31a8ba61..8ab300e5 100644
--- a/tests/test_model/test_npu_ops/test_flash_attention.py
+++ b/tests/test_model/test_npu_ops/test_flash_attention.py
@@ -3,6 +3,7 @@
 """

 import math
+import random

 import pytest
 import torch
@@ -11,152 +12,205 @@
 from torch import nn

 from internlm.accelerator import AcceleratorType, get_accelerator
-from internlm.model.modules.multi_head_attention import (
-    AscendFlashSelfAttention,
-    CrossAttention,
-    SelfAttention,
-)
+from internlm.core.context import Config
+from internlm.core.context import global_context as gpc
+from internlm.model.ops.attention import SelfAttention
+from internlm.model.ops.utils import pack_output_after_attn, unpack_qkv_before_attn
+from internlm.utils.common import get_current_device, set_random_seed

 HEAD_NUM = 32
 HIDDEN_SZIE = 4096
-SEQ_LEN = 2048
+SEQ_LEN = [2048, 4096]
 MICRO_BSZ = 1
 HEAD_DIM = HIDDEN_SZIE // HEAD_NUM
 VOCAB_SIZE = 32000
-
+NUM_KV_HEAD_LIST = [8, 32]
 MICRO_BSZ_LIST = [1, 2]
 DTYPE_LIST = [torch.bfloat16, torch.float16]
-NUM_KV_HEAD_LIST = [8, 32]
-USE_PADDING = [True, False]

 internlm_accelerator = get_accelerator()


-def check_mean_and_std(name, out1, out2):
-    named1_mean = out1.to(dtype=torch.float64).mean()
-    named1_std = out1.to(dtype=torch.float64).std()
-    named2_mean = out2.to(dtype=torch.float64).mean()
-    named2_std = out2.to(dtype=torch.float64).std()
-    check_statistic_equality(name, named1_mean, named2_mean, eq=True, is_mean=True)
-    check_statistic_equality(name, named1_std, named2_std, eq=True, is_mean=False)
-
-
-def check_statistic_equality(name, value1, value2, eq=False, is_mean=True, threshold=1e-9):
-    if (abs(value1 - value2) < threshold) ^ eq:
-        if eq:
-            print(
-                f"On {name}, "
-                f"we have {'mean' if is_mean else 'std'}s of fa_out "
-                f"very {'close' if not eq else 'different'}, "
-                f"from :{value1} "
-                f"and :{value2}",
-                flush=True,
-            )
-        else:
-            print(
-                f"On {name}, "
-                f"we have {'mean' if is_mean else 'std'}s of fa_out "
-                f"very {'close' if not eq else 'different'}, "
-                f"from :{value1} "
-                f"and :{value2}",
-                flush=True,
-            )
-
-
-def do_cmp_attn(
-    name,
-    B,  # pylint: disable=W0613
-    S,  # pylint: disable=W0613
-    N,
-    N_KV,
-    q,
-    k,
-    v,
-    dtype,
-    attention_mask,  # pylint: disable=W0613
-    softmax_scale,
-    attention_dropout=0.0,
-    **attn_args,  # pylint: disable=W0613
-):
-
-    npu_attn_cls = CrossAttention if N != N_KV else SelfAttention
-    npu_attn = npu_attn_cls(
-        causal=True,
-        softmax_scale=softmax_scale,
-        attention_dropout=attention_dropout,
-    ).to(dtype)
-    # TODO: 修复它.
-    npu_flash_attn = AscendFlashSelfAttention(
-        causal=True,
-        softmax_scale=softmax_scale,
-        attention_dropout=attention_dropout,
-    ).to(dtype)
-
-    if N == N_KV:
-        a = npu_attn(torch.concat([q, k, v], dim=2))  # pylint: disable=E1102
-    else:
-        a = npu_attn(q.squeeze(dim=2), torch.concat([k, v], dim=2))  # pylint: disable=E1102
-
-    b = npu_flash_attn(q=q, k=k, v=v)  # pylint: disable=E1102
-    assert torch.isfinite(a).all().item() and torch.isfinite(b).all().item()
-
-    if dtype == torch.bfloat16:
-        # torch_npu's equal not support bfloat16 by now.
-        assert torch.allclose(a.to(torch.float32), b.to(torch.float32), atol=5e-2, rtol=1e-4), f"{name} not pass"
-    else:
-        assert torch.allclose(a, b, atol=5e-2, rtol=1e-4), f"{name} not pass"
+def init_qkv(B, S, N_KV, dtype, device):
+    x = torch.LongTensor([[i + 1 for i in range(S)] for _ in range(B)]).to(device)
+    cu_seqlens = [0] + sorted(random.sample(list(range(x.numel())), 4))
+    if cu_seqlens[-1] != x.numel():
+        cu_seqlens.append(x.numel())
+    cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int64, device=device)
+    x = rearrange(x, "b s -> (b s)").unsqueeze(0)

-def npu_transform(B, S, N, N_KV, D, dtype, use_padding):
-    if use_padding:
-        x = torch.LongTensor([[i + 1 if i < S // 2 else 0 for i in range(S)] for _ in range(B)]).npu()  # padding S-1024
-    else:
-        x = torch.LongTensor([[i + 1 for i in range(S)] for _ in range(B)]).npu()  # no-padiing
-
-    wq = torch.zeros((N * D, N * D), dtype=dtype, device="npu")
-    wk = torch.zeros((N_KV * D, N * D), dtype=dtype, device="npu")
-    wv = torch.zeros((N_KV * D, N * D), dtype=dtype, device="npu")
-    wembed = torch.zeros((VOCAB_SIZE, HIDDEN_SZIE), dtype=dtype, device="npu")
+    KV_DIM = HEAD_DIM * N_KV
+    Q_PER_KV = HEAD_NUM // N_KV
+    wqkv = torch.rand((HIDDEN_SZIE + 2 * KV_DIM, HIDDEN_SZIE), dtype=dtype, device=device)
+    wembed = torch.rand((VOCAB_SIZE, HIDDEN_SZIE), dtype=dtype, device=device)

     # It is very important to set appropriate initialization values for parameters so
     # that the values fall within an appropriate precision range to prevent overflow or underflow.
     with torch.no_grad():
-        wq = nn.init.normal_(wq.data)
-        wk = nn.init.normal_(wk.data)
-        wv = nn.init.normal_(wv.data)
+        wqkv.data = nn.init.normal_(wqkv.data)
         wembed = nn.init.normal_(wembed.data, std=0.02)

     embed_x = F.embedding(x, wembed).to(dtype)
-    q = F.linear(embed_x, wq)  # pylint: disable=E1102
-    k = F.linear(embed_x, wk)  # pylint: disable=E1102
-    v = F.linear(embed_x, wv)  # pylint: disable=E1102
-
-    q = rearrange(q, "b s (one h d) -> b s one h d", b=B, s=S, d=D, one=1)
-    k = rearrange(k, "b s (one h d) -> b s one h d", b=B, s=S, d=D, one=1)
-    v = rearrange(v, "b s (one h d) -> b s one h d", b=B, s=S, d=D, one=1)
-
-    do_cmp_attn(
-        f"B_{B}_S_{S}_N_{N}_N_KV_{N_KV}_D_{D}_{dtype}",
-        B,
-        S,
-        N,
-        N_KV,
-        q,
-        k,
-        v,
-        dtype,
-        None,
-        1 / math.sqrt(HIDDEN_SZIE // HEAD_NUM),
+    qkv = F.linear(embed_x, wqkv)  # pylint: disable=E1102
+    qkv = rearrange(qkv, "b s (h gs d) -> b s h gs d", gs=Q_PER_KV + 2, d=HEAD_DIM)
+    q, k, v = (qkv[..., :Q_PER_KV, :], qkv[..., -2, :], qkv[..., -1, :])
+    q = rearrange(q, "b t h gs d -> b t (h gs) d")
+    kv = torch.concat([k.unsqueeze(2), v.unsqueeze(2)], dim=2)
+    return q, kv, cu_seqlens
+
+
+def fixed_length_fa(q, kv, cu_seqlens, packed_len, attn_cls, use_fa=False):
+    q = unpack_qkv_before_attn(q, cu_seqlens)
+    kv = unpack_qkv_before_attn(kv, cu_seqlens)
+    gpc._config = Config(dict(model=dict(use_flash_attn=use_fa, dtype=torch.bfloat16)))
+    c = attn_cls(q=q, kv=kv)  # fix length self attention in npu
+    c = rearrange(c, "b s h d -> b s (h d)")
+    return pack_output_after_attn(c, cu_seqlens, packed_length=packed_len)
+
+
+def var_length_fa(q, kv, cu_seqlens, max_seqlen, attn_cls):
+    gpc._config = Config(dict(model=dict(use_flash_attn=True, dtype=torch.bfloat16)))
+    b = attn_cls(  # pylint: disable=E1102
+        q=q,
+        kv=kv,
+        cu_seqlens_q=cu_seqlens,
+        cu_seqlens_k=cu_seqlens,
+        max_seqlen_q=max_seqlen,
+        max_seqlen_k=max_seqlen,
+    )
+    return rearrange(b, "b s h d -> b s (h d)")
+
+
+def assert_equal(a, b, atol_bf16=5e-2, rtol_bf16=1e-4, atol_fp16=5e-2, rtol_fp16=1e-4):
+    assert a.dtype == b.dtype
+    assert torch.isfinite(a).all().item() and torch.isfinite(b).all().item()
+    if a.dtype is torch.bfloat16:
+        assert torch.allclose(a, b, atol=atol_bf16, rtol=rtol_bf16), f"a: {a}, b: {b}"
+    elif a.dtype is torch.float16:
+        assert torch.allclose(a, b, atol=atol_fp16, rtol=rtol_fp16), f"a: {a}, b: {b}"
+    else:
+        assert False
+
+
+def npu_fwd_transform(B, S, N_KV, dtype):
+
+    set_random_seed(1024)
+    softmax_scale = 1 / math.sqrt(HEAD_DIM)
+    cross_attn = SelfAttention(causal=True, softmax_scale=softmax_scale, attention_dropout=0.0).to(dtype)
+    npu_flash_attn = SelfAttention(causal=True, softmax_scale=softmax_scale, attention_dropout=0.0).to(dtype)
+
+    with torch.no_grad():
+        q, kv, cu_seqlens = init_qkv(B, S, N_KV, dtype, get_current_device())
+
+    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+
+    q, kv = q.requires_grad_(), kv.requires_grad_()
+    a = fixed_length_fa(q, kv, cu_seqlens, B * S, cross_attn, use_fa=False)
+
+    q_2, kv_2 = q.detach().clone().requires_grad_(), kv.detach().clone().requires_grad_()
+    b = fixed_length_fa(q_2, kv_2, cu_seqlens, B * S, npu_flash_attn, use_fa=True)
+
+    q_3, kv_3 = q.detach().clone().requires_grad_(), kv.detach().clone().requires_grad_()
+    c = var_length_fa(q_3, kv_3, cu_seqlens, max_seqlen, npu_flash_attn)
+
+    # assert_equal(a, b, atol_bf16=1e-1)
+    assert_equal(a, c, atol_bf16=1e-1)
+    print("test npu_fwd_transform done!", flush=True)
+
+    return a, b, c, q, q_2, q_3, kv, kv_2, kv_3
+
+
+def npu_transform(B, S, N_KV, dtype):
+    a, b, c, q, q_2, q_3, kv, kv_2, kv_3 = npu_fwd_transform(B, S, N_KV, dtype)  # pylint: disable=W0612
+    g = torch.randn_like(b)
+    g.uniform_(-2, 2)
+
+    b.backward(g.clone(), retain_graph=True)
+    a.backward(g.clone(), retain_graph=True)
+    c.backward(g.clone(), retain_graph=True)
+
+    # assert_equal(q.grad, q_2.grad, atol_bf16=1e-1)
+    assert_equal(q.grad, q_3.grad, atol_bf16=1e-1)
+    # assert_equal(kv.grad, kv_2.grad, atol_bf16=5e-1, rtol_bf16=1e-3)
+    assert_equal(kv.grad, kv_3.grad, atol_bf16=5e-1)
+
+    print("test npu_transform done!", flush=True)
+
+
+def deeplink_fwd_transform(B, S, N_KV, dtype):
+    from deeplink_ext.internevo_ops import FlashSelfAttention
+
+    from internlm.model.modules.multi_head_attention import CrossAttention
+
+    set_random_seed(1024)
+    softmax_scale = 1 / math.sqrt(HEAD_DIM)
+    cross_attn = CrossAttention(causal=True, softmax_scale=softmax_scale, attention_dropout=0.0).to(dtype)
+    dp_flash_attn = FlashSelfAttention(causal=True, softmax_scale=softmax_scale, attention_dropout=0.0).to(dtype)
+
+    with torch.no_grad():
+        q, kv, cu_seqlens = init_qkv(B, S, N_KV, dtype, get_current_device())
+
+    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+
+    q, kv = q.requires_grad_(), kv.requires_grad_()
+    a = fixed_length_fa(q, kv, cu_seqlens, B * S, cross_attn)
+
+    q_2, kv_2 = q.detach().clone().requires_grad_(), kv.detach().clone().requires_grad_()
+    b = var_length_fa(q_2, kv_2, cu_seqlens, max_seqlen, dp_flash_attn)
+
+    assert_equal(a, b)
+    print("test deeplink_fwd_transform done!", flush=True)
+
+    return a, b, q, q_2, kv, kv_2
+
+
+def deeplink_transform(B, S, N_KV, dtype):
+    a, b, q, q_2, kv, kv_2 = deeplink_fwd_transform(B, S, N_KV, dtype)
+
+    g = torch.randn_like(b)
+    g.uniform_(-2, 2)
+
+    b.backward(g.clone(), retain_graph=True)
+    a.backward(g.clone(), retain_graph=True)
+
+    assert_equal(q.grad, q_2.grad, atol_bf16=1e-1)
+    assert_equal(kv.grad, kv_2.grad, atol_bf16=1e-1)
+
+    print("test deeplink_transform done!", flush=True)
+
+
+@pytest.mark.parametrize("micro_bsz", MICRO_BSZ_LIST)
+@pytest.mark.parametrize("test_dtype", DTYPE_LIST)
+@pytest.mark.parametrize("num_kv_head", NUM_KV_HEAD_LIST)
+@pytest.mark.parametrize("seqlen", SEQ_LEN)
+def test_NPU_fa_fwd(micro_bsz, test_dtype, num_kv_head, seqlen):
+    if internlm_accelerator.get_accelerator_backend() == AcceleratorType.NPU:
+        npu_fwd_transform(micro_bsz, seqlen, num_kv_head, test_dtype)


 @pytest.mark.parametrize("micro_bsz", MICRO_BSZ_LIST)
 @pytest.mark.parametrize("test_dtype", DTYPE_LIST)
 @pytest.mark.parametrize("num_kv_head", NUM_KV_HEAD_LIST)
-@pytest.mark.parametrize("use_padding", USE_PADDING)
-def test_NPU_fa(micro_bsz, test_dtype, num_kv_head, use_padding):
+@pytest.mark.parametrize("seqlen", SEQ_LEN)
+def test_NPU_fa_bwd(micro_bsz, test_dtype, num_kv_head, seqlen):
     if internlm_accelerator.get_accelerator_backend() == AcceleratorType.NPU:
-        npu_transform(micro_bsz, SEQ_LEN, HEAD_NUM, num_kv_head, HIDDEN_SZIE // HEAD_NUM, test_dtype, use_padding)
+        npu_transform(micro_bsz, seqlen, num_kv_head, test_dtype)
+
+
+# @pytest.mark.parametrize("micro_bsz", MICRO_BSZ_LIST)
+# @pytest.mark.parametrize("test_dtype", DTYPE_LIST)
+# @pytest.mark.parametrize("num_kv_head", NUM_KV_HEAD_LIST)
+# def test_deeplink_fa_fwd(micro_bsz, test_dtype, num_kv_head):
+#     if internlm_accelerator.get_accelerator_backend() == AcceleratorType.DIPU:
+#         deeplink_fwd_transform(micro_bsz, SEQ_LEN, num_kv_head, test_dtype)
+
+
+# @pytest.mark.parametrize("micro_bsz", MICRO_BSZ_LIST)
+# @pytest.mark.parametrize("test_dtype", DTYPE_LIST)
+# @pytest.mark.parametrize("num_kv_head", NUM_KV_HEAD_LIST)
+# def test_deeplink_fa_bwd(micro_bsz, test_dtype, num_kv_head):
+#     if internlm_accelerator.get_accelerator_backend() == AcceleratorType.DIPU:
+#         deeplink_transform(micro_bsz, SEQ_LEN, num_kv_head, test_dtype)


 if __name__ == "__main__":
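
Reviewer note (not part of the patch): the new NPU path consumes a packed [1, total_tokens, n_head, head_dim] layout in which cu_seqlens marks the sequence boundaries, and the patch forwards those boundaries to the fused kernel as cumulative end offsets (cu_seqlens[1:]) via actual_seq_qlen / actual_seq_kvlen. The plain-PyTorch sketch below only illustrates what segment-wise causal attention over that packed layout computes; the helper name, shapes, and the reference loop are illustrative assumptions for review, run on CPU, and are not the torch_npu API or the patched code itself.

import torch


def reference_varlen_causal_attention(q, k, v, cu_seqlens, softmax_scale):
    # q, k, v: [1, total_tokens, n_head, head_dim]; cu_seqlens: cumulative sequence offsets, e.g. [0, 3, 8].
    out = torch.empty_like(q)
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        # Slice one sequence out of the packed batch: [seqlen, n_head, head_dim].
        qs, ks, vs = q[0, start:end], k[0, start:end], v[0, start:end]
        scores = torch.einsum("qhd,khd->hqk", qs, ks) * softmax_scale
        # Causal mask within the segment only; tokens never attend across sequence boundaries.
        causal_mask = torch.triu(torch.ones(qs.size(0), ks.size(0), dtype=torch.bool), 1)
        scores = scores.masked_fill(causal_mask, float("-inf"))
        out[0, start:end] = torch.einsum("hqk,khd->qhd", scores.softmax(dim=-1), vs)
    return out


# Two sequences of length 3 and 5 packed into one 8-token batch.
n_head, head_dim = 2, 4
cu_seqlens = torch.tensor([0, 3, 8])
q, k, v = (torch.randn(1, 8, n_head, head_dim) for _ in range(3))
out = reference_varlen_causal_attention(q, k, v, cu_seqlens, softmax_scale=head_dim**-0.5)
print(out.shape)  # torch.Size([1, 8, 2, 4])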