From 4c786e90d47341ecf6cdfcdffc06906199f0f5cf Mon Sep 17 00:00:00 2001
From: sallyjunjun
Date: Tue, 30 Jan 2024 11:20:38 +0800
Subject: [PATCH] align to flash attention v2.2.1

---
 internlm/model/multi_head_attention.py | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py
index 92179f20..54b0458c 100644
--- a/internlm/model/multi_head_attention.py
+++ b/internlm/model/multi_head_attention.py
@@ -7,7 +7,7 @@
 
 import torch
 import torch.nn.functional as F
-from einops import rearrange
+from einops import rearrange, repeat
 from torch import nn
 
 from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
@@ -86,7 +86,7 @@ def forward(self, q, kv, causal=None, key_padding_mask=None):
         Arguments
         ---------
             q: The tensor containing the query. (B, Sq, H, D)
-            kv: The tensor containing the key and value. (B, Sk, 2, H, D)
+            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
             causal: if passed, will override self.causal
             key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                 False means to mask out. (B, Sk)
@@ -94,7 +94,9 @@ def forward(self, q, kv, causal=None, key_padding_mask=None):
         batch_size, seqlen_q = q.shape[0], q.shape[1]
         causal = self.causal if causal is None else causal
         seqlen_k = kv.shape[1]
-        assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
+        assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
+        if kv.shape[3] != q.shape[2]:  # MQA/GQA
+            kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
         k, v = kv.unbind(dim=2)
         softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
         scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
@@ -104,11 +106,12 @@ def forward(self, q, kv, causal=None, key_padding_mask=None):
             # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
             scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
         if causal:
-            # "triu_tril_cuda_template" not implemented for 'BFloat16'
-            # So we have to construct the mask in float
-            causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0, device=scores.device), 1)
-            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
-            scores = scores + causal_mask.to(dtype=scores.dtype)
+            # causal mask needs to take into account the difference between seqlen_q and seqlen_k
+            row_idx = rearrange(torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1")
+            col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
+            sk = seqlen_k if key_padding_mask is None else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
+            causal_mask = col_idx > row_idx + sk - seqlen_q
+            scores = scores.masked_fill(causal_mask, -10000.0)
         attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
         attention_drop = self.drop(attention)
         output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
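
Note (not part of the patch): the sketch below is a minimal, self-contained reference illustrating the two behaviors this diff introduces in the non-flash code path: repeating grouped key/value heads for MQA/GQA, and building a causal mask aligned to the bottom-right corner when seqlen_q differs from seqlen_k (the convention flash-attention adopted starting with v2.1). The function name naive_cross_attention, the dropout-free formulation, and the toy shapes are illustrative assumptions, not InternLM code.

import math

import torch
from einops import rearrange, repeat


def naive_cross_attention(q, kv, causal=True, key_padding_mask=None):
    # q: (B, Sq, H, D); kv: (B, Sk, 2, H_k, D); key_padding_mask: bool (B, Sk), True = keep
    batch_size, seqlen_q = q.shape[0], q.shape[1]
    seqlen_k = kv.shape[1]
    assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
    if kv.shape[3] != q.shape[2]:  # MQA/GQA: each kv head serves a contiguous group of q heads
        kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
    k, v = kv.unbind(dim=2)
    scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(q.shape[-1]))
    if key_padding_mask is not None:
        scores = scores.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), -10000.0)
    if causal:
        # Bottom-right alignment: query row i may attend to keys up to i + sk - seqlen_q,
        # so with a longer key sequence (e.g. a kv cache) the last query sees every valid key.
        row_idx = rearrange(torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1")
        col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
        sk = seqlen_k if key_padding_mask is None else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
        causal_mask = col_idx > row_idx + sk - seqlen_q
        scores = scores.masked_fill(causal_mask, -10000.0)
    attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
    return torch.einsum("bhts,bshd->bthd", attention, v)


# Example: 8 query heads sharing 2 kv heads, and a kv cache longer than the new queries.
q = torch.randn(2, 4, 8, 64)
kv = torch.randn(2, 16, 2, 2, 64)
out = naive_cross_attention(q, kv, causal=True)
print(out.shape)  # torch.Size([2, 4, 8, 64])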