align to flash attention v2.2.1
sallyjunjun committed Jan 30, 2024 · parent 8da532c · commit 4c786e9
Showing 1 changed file with 11 additions and 8 deletions: internlm/model/multi_head_attention.py
@@ -7,7 +7,7 @@

 import torch
 import torch.nn.functional as F
-from einops import rearrange
+from einops import rearrange, repeat
 from torch import nn

 from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
@@ -86,15 +86,17 @@ def forward(self, q, kv, causal=None, key_padding_mask=None):
         Arguments
         ---------
             q: The tensor containing the query. (B, Sq, H, D)
-            kv: The tensor containing the key and value. (B, Sk, 2, H, D)
+            kv: The tensor containing the key and value. (B, Sk, 2, H_k, D)
             causal: if passed, will override self.causal
             key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
                 False means to mask out. (B, Sk)
         """
         batch_size, seqlen_q = q.shape[0], q.shape[1]
         causal = self.causal if causal is None else causal
         seqlen_k = kv.shape[1]
-        assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
+        assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
+        if kv.shape[3] != q.shape[2]:  # MQA/GQA
+            kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
         k, v = kv.unbind(dim=2)
         softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
         scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
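For context only (not part of the commit), here is a minimal sketch of what the newly added `repeat` call does for MQA/GQA, assuming hypothetical shapes with 8 query heads and 2 key/value heads:

```python
import torch
from einops import repeat

# Hypothetical shapes: 8 query heads, 2 key/value heads (GQA group size 4).
B, Sk, H, H_k, D = 1, 4, 8, 2, 16
kv = torch.randn(B, Sk, 2, H_k, D)

# Same einops pattern as the diff: tile each kv head g times along the head axis.
g = H // H_k
kv_expanded = repeat(kv, "... hkv d -> ... (hkv g) d", g=g)
print(kv_expanded.shape)  # torch.Size([1, 4, 2, 8, 16]) -> matches (B, Sk, 2, H, D)
```

After the expansion, kv has the same number of heads as q, so the downstream einsum needs no change.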
@@ -104,11 +106,12 @@ def forward(self, q, kv, causal=None, key_padding_mask=None):
             # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
             scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
         if causal:
-            # "triu_tril_cuda_template" not implemented for 'BFloat16'
-            # So we have to construct the mask in float
-            causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0, device=scores.device), 1)
-            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
-            scores = scores + causal_mask.to(dtype=scores.dtype)
+            # causal mask needs to take into account the difference between seqlen_q and seqlen_k
+            row_idx = rearrange(torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1")
+            col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
+            sk = seqlen_k if key_padding_mask is None else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
+            causal_mask = col_idx > row_idx + sk - seqlen_q
+            scores = scores.masked_fill(causal_mask, -10000.0)
         attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
         attention_drop = self.drop(attention)
         output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
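Again as an aside rather than part of the commit, a small sketch of the new mask construction for the case seqlen_q != seqlen_k, assuming seqlen_q=2, seqlen_k=4 and no key padding:

```python
import torch
from einops import rearrange

# Hypothetical lengths: 2 queries attending over 4 keys, no key_padding_mask.
seqlen_q, seqlen_k = 2, 4
row_idx = rearrange(torch.arange(seqlen_q, dtype=torch.long), "s -> s 1")
col_idx = torch.arange(seqlen_k, dtype=torch.long)
sk = seqlen_k  # with a key_padding_mask this would be the per-sample key count

# True marks positions to mask out.
causal_mask = col_idx > row_idx + sk - seqlen_q
print(causal_mask)
# tensor([[False, False, False,  True],
#         [False, False, False, False]])
```

The last query row sees every key, i.e. the mask is aligned to the bottom-right corner rather than the top-left, which matches the causal convention used by recent flash-attention releases.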
