Skip to content

Commit

Permalink
feat(op): support varlen npu flash attention (#209)
Browse files Browse the repository at this point in the history
  • Loading branch information
SolenoidWGT authored Jul 18, 2024
1 parent 7cd091c commit 57b7cd5
Show file tree
Hide file tree
Showing 3 changed files with 363 additions and 160 deletions.
7 changes: 6 additions & 1 deletion internlm/core/parallel/comm/isp.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,7 +900,12 @@ def auto_wrap_distributed_attention(cls: nn.Module) -> Callable[[bool, Any, floa
def _attetion_constructor(
local_attn_cls: type, causal=False, softmax_scale=None, attention_dropout=0.0
) -> nn.Module:
if gpc.config.parallel["tensor"].get("mode", "mtp") != "isp":
try:
tp_mode = gpc.config.parallel["tensor"].get("mode", "mtp")
except AttributeError:
tp_mode = "mtp"

if tp_mode != "isp":
return local_attn_cls(causal, softmax_scale, attention_dropout)
else:
return DistributedAttention(
Expand Down
224 changes: 184 additions & 40 deletions internlm/model/ops/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def _nyi_attn(func_name, *args, **kwargs): # pylint: disable=W0613

def _flash_float32_compatibility_wrapper(input_idxs: Tuple, flash_func: Callable, *args, **kwargs):
if gpc.config.model.dtype is torch.float32:
inputs = (args[idx] for idx in input_idxs)
inputs = [args[idx] for idx in input_idxs]
input_dtype = inputs[0].dtype
other_args = [args[idx] for idx in range(len(inputs), len(args))]

Expand Down Expand Up @@ -194,10 +194,35 @@ def _flash_fixedlen_qkvsplited_attn(q, k, v, dropout_p=0.0, softmax_scale=None,


# npu flash attention operators
# TODO: should we add _flash_float32_compatibility_wrapper support for npu.
def _npu_varlen_qkvsplited_attn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,  # pylint: disable=W0613
    max_seqlen_k,  # pylint: disable=W0613
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
):
    """Entry point for NPU variable-length attention with separate q, k and v tensors.

    All arguments are forwarded, positionally and in signature order, to
    ``_npu_varlen_qkvsplited_func`` through ``_flash_float32_compatibility_wrapper``.
    Positions 0, 1 and 2 (q, k, v) are declared as the tensor inputs the wrapper
    should handle when the configured model dtype is float32.

    NOTE(review): the wrapper indexes its ``*args`` by position, so the arguments
    must stay positional and in this exact order.
    """
    tensor_arg_positions = (0, 1, 2)
    forwarded_args = (
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
    )
    return _flash_float32_compatibility_wrapper(tensor_arg_positions, _npu_varlen_qkvsplited_func, *forwarded_args)


def _npu_varlen_qkvsplited_attn(
def _npu_varlen_qkvsplited_func(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
Expand All @@ -208,17 +233,32 @@ def _npu_varlen_qkvsplited_attn(
dropout_p=0.0,
softmax_scale=None,
causal=False,
use_fixlen=False,
):
# TODO: support npu native varlen flash attention
"""Support Huawei Ascend's torch_npu flash attention.
Tested version:
torch: 2.1.0+cpu
torch_npu: 2.1.0.post3+git7c4136d
cann: 8.0.RC1.alpha003
"""
packed_length = q.size(dim=1)
softmax_scale = softmax_scale or 1.0 / math.sqrt(q.shape[-1])

q = unpack_qkv_before_attn(q, cu_seqlens=cu_seqlens_q)
k = unpack_qkv_before_attn(k, cu_seqlens=cu_seqlens_k)
v = unpack_qkv_before_attn(v, cu_seqlens=cu_seqlens_k)
if use_fixlen:

output = _npu_fixedlen_qkvsplited_attn(q, k, v, dropout_p, softmax_scale, causal)
q = unpack_qkv_before_attn(q, cu_seqlens=cu_seqlens_q)
k = unpack_qkv_before_attn(k, cu_seqlens=cu_seqlens_k)
v = unpack_qkv_before_attn(v, cu_seqlens=cu_seqlens_k)

return pack_output_after_attn(output, cu_seqlens_q, packed_length)
output = _npu_fixedlen_qkvsplited_attn(q, k, v, dropout_p, softmax_scale, causal)

output = pack_output_after_attn(output, cu_seqlens_q, packed_length)
else:
output = _npu_fused_varlen_qkvsplited_attn(
q, k, v, dropout_p, softmax_scale, causal, max_seqlen_q, max_seqlen_k, cu_seqlens_q, cu_seqlens_k
)

return output


def _npu_fixedlen_qkvsplited_attn(
Expand All @@ -236,6 +276,7 @@ def _npu_fixedlen_qkvsplited_attn(
q, k, v = q.squeeze(dim=2), k.squeeze(dim=2), v.squeeze(dim=2)

_, seqlen, n_head, _ = q.shape
sparse_mode = 0
attention_mask = torch.triu(torch.ones(seqlen, seqlen, device=get_current_device()), 1).bool()

return _origin_npu_fixedlen_qkvsplited_func(
Expand All @@ -247,25 +288,71 @@ def _npu_fixedlen_qkvsplited_attn(
pse=None,
atten_mask=attention_mask,
scale=softmax_scale,
sparse_mode=0, # If necessary, expose the interface
sparse_mode=sparse_mode, # If necessary, expose the interface
pre_tockens=seqlen, # Used for sparse calculations, representing the left boundary of the slides window
next_tockens=0, # If necessary, expose the interface
keep_prob=1 - dropout_p,
inner_precise=0, # If necessary, expose the interface
)
)[0]


def _npu_varlen_qkvpacked_attn(
qkv: torch.Tensor, cu_seqlens, max_seqlen, dropout_p, softmax_scale=None, causal=False # pylint: disable=W0613
def _npu_fused_varlen_qkvsplited_attn(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
dropout_p: float,
softmax_scale=None,
causal=False,
max_seqlen_q: int = None,
max_seqlen_k: int = None,
cu_seqlens_q=None,
cu_seqlens_kv=None,
deterministic=False,
):
# TODO: support npu native varlen flash attention
packed_length = qkv.size(dim=1)
assert causal is True
assert q.dtype in (torch.bfloat16, torch.float16)

qkv = unpack_qkv_before_attn(qkv, cu_seqlens=cu_seqlens)
if len(q.shape) == 4: # [1, packedseqlen, n_head, headdim]
q, k, v = q.squeeze(dim=0), k.squeeze(dim=0), v.squeeze(dim=0)

output = _npu_fixedlen_qkvpacked_attn(qkv, dropout_p, softmax_scale, causal)
S, N = max(max_seqlen_q, max_seqlen_k), q.shape[1]
device = get_current_device()
sparse_mode = 0

return pack_output_after_attn(output, cu_seqlens, packed_length)
if max_seqlen_k > 2048 and max_seqlen_q > 2048:
sparse_mode = 2
max_seqlen_k = 2048
max_seqlen_q = 2048

attention_mask = torch.triu(torch.ones(max_seqlen_q, max_seqlen_k, device=device), 1).bool()
cu_seqlens_q = cu_seqlens_q[1:].tolist()
cu_seqlens_kv = cu_seqlens_kv[1:].tolist()

return _origin_npu_fixedlen_qkvsplited_func(
query=q,
key=k,
value=v,
head_num=N,
input_layout="TND",
pse=None,
atten_mask=attention_mask,
scale=softmax_scale,
sparse_mode=sparse_mode,
pre_tockens=S, # Used for sparse calculations, representing the left boundary of the slides window
next_tockens=0,
keep_prob=1 - dropout_p,
inner_precise=0 if not deterministic else 2,
actual_seq_kvlen=cu_seqlens_kv,
actual_seq_qlen=cu_seqlens_q,
)[0].unsqueeze(dim=0)


def _npu_varlen_qkvpacked_attn(
    qkv: torch.Tensor, cu_seqlens, max_seqlen, dropout_p, softmax_scale=None, causal=False  # pylint: disable=W0613
):
    """Run NPU variable-length attention on a packed QKV tensor.

    The packed tensor is split along dim 2 into q, k and v, which are then handed
    to ``_npu_varlen_qkvsplited_attn``.

    Args:
        qkv: packed tensor holding q, k and v stacked along dim 2
            (assumes size 3 on that dim, as required by the 3-way unpack).
        cu_seqlens: cumulative sequence lengths, shared by queries and keys.
        max_seqlen: maximum sequence length, shared by queries and keys.
        dropout_p: attention dropout probability.
        softmax_scale: optional softmax scaling factor.
        causal: whether to apply causal masking.

    Returns:
        Whatever ``_npu_varlen_qkvsplited_attn`` returns for the split tensors.
    """
    # TODO: support npu native varlen flash attention
    q, k, v = qkv.unbind(dim=2)
    # The split-qkv entry point takes separate q/k values for both cu_seqlens and
    # max_seqlen. With packed qkv they are identical, so each is passed twice;
    # passing them once shifts every later positional argument into the wrong slot.
    return _npu_varlen_qkvsplited_attn(
        q,
        k,
        v,
        cu_seqlens,
        cu_seqlens,
        max_seqlen,
        max_seqlen,
        dropout_p,
        softmax_scale,
        causal,
    )


def _npu_fixedlen_qkvpacked_attn(qkv: torch.Tensor, dropout_p: float, softmax_scale=None, causal=False):
Expand All @@ -285,14 +372,20 @@ def _npu_varlen_kvpacked_attn(
causal=False,
):
# TODO: support npu native varlen flash attention
packed_length = q.size(dim=1)

q = unpack_qkv_before_attn(q, cu_seqlens=cu_seqlens_q)
kv = unpack_qkv_before_attn(kv, cu_seqlens=cu_seqlens_k)

output = _npu_fixedlen_kvpacked_attn(q, kv, dropout_p, softmax_scale, causal)

return pack_output_after_attn(output, cu_seqlens_q, packed_length)
k, v = kv.unbind(dim=2)
k, v = k.squeeze(dim=2), v.squeeze(dim=2)
return _npu_varlen_qkvsplited_attn(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
causal,
)


def _npu_fixedlen_kvpacked_attn(q: torch.Tensor, kv: torch.Tensor, dropout_p: float, softmax_scale=None, causal=False):
Expand Down Expand Up @@ -335,12 +428,6 @@ def _deeplink_fixedlen_qkvsplited_attn(*args, **kwargs):


# torch attention operators


def _torch_varlen_qkvpacked_attn(*args, **kwargs):
_nyi_attn("_torch_varlen_qkvpacked_attn", *args, **kwargs)


# adapted from https://github.com/Dao-AILab/flash-attention/blob/v2.2.1/flash_attn/modules/mha.py
def _torch_fixedlen_qkvpacked_attn(qkv: torch.Tensor, dropout, softmax_scale=None, causal=False, key_padding_mask=None):
batch_size, seqlen = qkv.shape[0], qkv.shape[1]
Expand Down Expand Up @@ -369,10 +456,6 @@ def _torch_fixedlen_qkvpacked_attn(qkv: torch.Tensor, dropout, softmax_scale=Non
return output


def _torch_varlen_kvpacked_attn(*args, **kwargs):
_nyi_attn("_torch_varlen_kvpacked_attn", *args, **kwargs)


# adapted from https://github.com/Dao-AILab/flash-attention/blob/v2.2.1/flash_attn/modules/mha.py
def _torch_fixedlen_kvpacked_attn(
q: torch.Tensor, kv: torch.Tensor, dropout, softmax_scale=None, causal=False, key_padding_mask=None
Expand Down Expand Up @@ -407,17 +490,78 @@ def _torch_fixedlen_kvpacked_attn(
return output


def _torch_varlen_qkvsplited_attn(*args, **kwargs):
_nyi_attn("_torch_varlen_qkvsplited_attn", *args, **kwargs)


def _torch_fixedlen_qkvsplited_attn(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, dropout, softmax_scale=None, causal=False, key_padding_mask=None
):
    """Fixed-length torch attention for separate q, k and v tensors.

    k and v are stacked along a new dim 2 to form a packed ``kv`` tensor, and the
    computation is delegated to ``_torch_fixedlen_kvpacked_attn``.
    """
    stacked_kv = torch.stack([k, v], dim=2)
    return _torch_fixedlen_kvpacked_attn(
        q,
        stacked_kv,
        dropout,
        softmax_scale=softmax_scale,
        causal=causal,
        key_padding_mask=key_padding_mask,
    )


def _torch_varlen_qkvsplited_attn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,  # pylint: disable=W0613
    max_seqlen_k,  # pylint: disable=W0613
    dropout,
    softmax_scale=None,
    causal=False,
    key_padding_mask=None,
):
    """Variable-length torch attention for separate q, k and v tensors.

    Steps: stack k and v along a new dim 2, unpack q and the stacked kv with their
    respective cumulative sequence lengths, run ``_torch_fixedlen_kvpacked_attn``,
    then pack the result back using ``cu_seqlens_q`` and the original size of q
    along dim 1.

    ``max_seqlen_q`` and ``max_seqlen_k`` are not used in this function.
    """
    # Remember the packed length before q is rebound to its unpacked form.
    total_packed_len = q.size(dim=1)
    stacked_kv = torch.stack([k, v], dim=2)

    unpacked_q = unpack_qkv_before_attn(q, cu_seqlens=cu_seqlens_q)
    unpacked_kv = unpack_qkv_before_attn(stacked_kv, cu_seqlens=cu_seqlens_k)

    attn_out = _torch_fixedlen_kvpacked_attn(
        unpacked_q,
        unpacked_kv,
        dropout,
        softmax_scale=softmax_scale,
        causal=causal,
        key_padding_mask=key_padding_mask,
    )

    return pack_output_after_attn(attn_out, cu_seqlens_q, total_packed_len)


def _torch_varlen_qkvpacked_attn(
    qkv: torch.Tensor,
    cu_seqlens,
    max_seqlen,  # pylint: disable=W0613
    dropout,
    softmax_scale=None,
    causal=False,
    key_padding_mask=None,
):
    """Variable-length torch attention for a packed qkv tensor.

    The packed input is unpacked with ``cu_seqlens``, run through
    ``_torch_fixedlen_qkvpacked_attn``, and the result is packed back to the
    original size of ``qkv`` along dim 1.

    ``max_seqlen`` is not used in this function.
    """
    # Capture the packed length before qkv is replaced by its unpacked form.
    total_packed_len = qkv.size(dim=1)
    unpacked_qkv = unpack_qkv_before_attn(qkv, cu_seqlens=cu_seqlens)

    attn_out = _torch_fixedlen_qkvpacked_attn(
        unpacked_qkv,
        dropout,
        softmax_scale=softmax_scale,
        causal=causal,
        key_padding_mask=key_padding_mask,
    )

    return pack_output_after_attn(attn_out, cu_seqlens, total_packed_len)


def _torch_varlen_kvpacked_attn(
    q: torch.Tensor,
    kv: torch.Tensor,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,  # pylint: disable=W0613
    max_seqlen_k,  # pylint: disable=W0613
    dropout,
    softmax_scale=None,
    causal=False,
    key_padding_mask=None,
):
    """Variable-length torch attention for q plus a packed kv tensor.

    q and kv are unpacked with their respective cumulative sequence lengths, run
    through ``_torch_fixedlen_kvpacked_attn``, and the result is packed back using
    ``cu_seqlens_q`` and the original size of q along dim 1.

    ``max_seqlen_q`` and ``max_seqlen_k`` are not used in this function.
    """
    # Capture the packed length before q is replaced by its unpacked form.
    total_packed_len = q.size(dim=1)

    unpacked_q = unpack_qkv_before_attn(q, cu_seqlens=cu_seqlens_q)
    unpacked_kv = unpack_qkv_before_attn(kv, cu_seqlens=cu_seqlens_k)

    attn_out = _torch_fixedlen_kvpacked_attn(
        unpacked_q,
        unpacked_kv,
        dropout,
        softmax_scale=softmax_scale,
        causal=causal,
        key_padding_mask=key_padding_mask,
    )

    return pack_output_after_attn(attn_out, cu_seqlens_q, total_packed_len)


@auto_wrap_distributed_attention
class SelfAttention(nn.Module):
"""Implements scaled dot-product attention with optional softmax scaling.
Expand Down
Loading

0 comments on commit 57b7cd5

Please sign in to comment.