Skip to content

Commit

Permalink
[PyTorch] Fix get_swa_mask() for padding masks (NVIDIA#1281)
Browse files Browse the repository at this point in the history
* WIP: fix get_swa_mask for padding

Signed-off-by: Charlene Yang <[email protected]>

* fix mask type setting

Signed-off-by: Charlene Yang <[email protected]>

* fix the order of checking valid swa and changing mask type

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix lint

Signed-off-by: Charlene Yang <[email protected]>

* revamp to get full mask

Signed-off-by: Charlene Yang <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Charlene Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
cyanguwa and pre-commit-ci[bot] authored Dec 18, 2024
1 parent 83dac8c commit f033498
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 98 deletions.
28 changes: 16 additions & 12 deletions tests/pytorch/fused_attn/test_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,18 +531,22 @@ def test_dpa_bias_shapes(dtype, model_configs, model):

model_configs_swa = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"swa_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
"swa_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"),
"swa_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
"swa_2_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"),
"swa_2_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "no_bias"),
"swa_2_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"swa_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
"swa_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal_bottom_right", "no_bias"),
"swa_3_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal_bottom_right", "no_bias"),
"swa_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
"swa_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
"swa_1_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"),
"swa_1_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
"swa_2_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
"swa_2_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
"swa_3_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
"swa_3_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
"swa_4_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"swa_4_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
"swa_5_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"swa_5_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
"swa_6_0": ModelConfig(
4, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
),
"swa_6_1": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"
),
}


Expand Down
227 changes: 141 additions & 86 deletions transformer_engine/pytorch/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,61 +1024,157 @@ def swap_key_value_dict(self, batch_indices):


@torch.no_grad()
def get_swa_mask(
window_size: Tuple[int, int],
def get_full_mask(
max_seqlen_q: int,
max_seqlen_kv: int,
attn_mask_type: str = "no_mask",
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None,
window_size: Tuple[int, int] = None,
attention_type: str = "self",
bottom_right_alignment: bool = True,
) -> torch.Tensor:
"""
Convert sliding window `window_size` to an equivalent "`arbitrary`" mask.
For "`causal`" mask type, the sliding window diagonal is aligned to the top left corner,
and for other mask types, the bottom right corner.
Get full attention mask in [..., max_seqlen_q, max_seqlen_kv] shape, based on `attn_mask_type`,
`attention_mask`, and `window_size`. For sliding window attention, the diagonal alignment depends
on both `attn_mask_type` and `bottom_right_alignment`, as detailed below.::
attn_mask_type output shape diagonal alignment
--------------------------------------------------------------------------------------------
no_mask [1, 1, max_seqlen_q, max_seqlen_kv] follow bottom_right_alignment
causal [1, 1, max_seqlen_q, max_seqlen_kv] always top left
causal_bottom_right [1, 1, max_seqlen_q, max_seqlen_kv] always bottom right
padding [batch_size, 1, max_seqlen_q, max_seqlen_kv] follow bottom_right_alignment
padding_causal [batch_size, 1, max_seqlen_q, max_seqlen_kv] always top left
padding_causal_bottom_right [batch_size, 1, max_seqlen_q, max_seqlen_kv] always bottom right
arbitrary same as attention_mask follow bottom_right_alignment
.. note::
For "padding_bottom_right" mask, or "padding" mask with `bottom_right_alignment` = True, the bottom right
diagonal comes from the bottom right corner of the [actual_seqlens_q[i], actual_seqlens_kv[i]] matrix,
i = 0,...,batch_size-1, not the [max_seqlen_q, max_seqlen_kv] matrix. For example, with max_seqlen_q = 4,
max_seqlen_kv = 4, attn_mask_type = "padding", attention_type = "cross", and attention_mask = (
[[False, False, True, True], [False, False, False, False]],
[[False, False, False, True], [False, True, True, True]]), the returned full attention mask has [2, 4, 4]
shape and is,::
[[[False, False, False, True],
[False, False, False, True],
[ True, True, True, True],
[ True, True, True, True]],
[[False, True, True, True],
[False, True, True, True],
[False, True, True, True],
[False, True, True, True]]]
Parameters
----------
window_size: Tuple[int, int]
Sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
`attn_mask_type`.
max_seqlen_q: int
Maximum sequence length for queries.
max_seqlen_kv: int
Maximum sequence length for keys and values.
attn_mask_type: str, default = `no_mask`
Attention mask type, {"`no_mask`", "`padding`", "`causal`", "`padding_causal`",
"`causal_bottom_right`", "`padding_causal_bottom_right`", "`arbitrary`"}
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]],
attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
default = `None`
Boolean tensor(s) used to mask out attention softmax input.
Boolean tensor(s) used to mask out attention softmax input. Please see DotProductAttention
for the requirements of `attention_mask` for different `attn_mask_type`s.
window_size: Tuple[int, int], default = `None`
Sliding window size for local attention, where query at position i attends to keys
in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q
+ window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding
window and causal mask specifically. Both `causal` and `causal_bottom_right` masks
map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on
`attn_mask_type`.
attention_type: str, default = "self"
Attention type, {"self", "cross"}
bottom_right_alignment: bool, default = `True`
Whether to align the diagonal of the sliding window attention to the bottom right (`True`)
or top left (`False`) corner of the softmax matrix. Ignored if `attn_mask_type` explicitly
specifies "causal" or "causal_bottom_right".
Returns
----------
attn_mask_type: str
For sliding window attention (>=0, >0), "arbitrary"; otherwise, the same as input `attn_mask_type`
attention_mask: torch.Tensor
Combined `attention_mask` (input) and sliding window attention mask.
The shape is [max_seqlen_q, max_seqlen_kv] when input `attention_mask` is None;
else, the same shape as input `attention_mask`.
The full attention mask based on `attn_mask_type`, `attention_mask` and `window_size`
actual_seqlens_q: torch.Tensor
For padding masks, the actual sequence lengths for queries, in shape [batch_size].
For other masks, `None`.
actual_seqlens_kv: Optional[torch.Tensor], default = `None`
For padding masks, the actual sequence lengths for keys and values, in shape [batch_size].
For other masks, `None`.
"""
mask = torch.ones(max_seqlen_q, max_seqlen_kv, dtype=torch.bool, device="cuda")
if attn_mask_type in ["causal"]:
left = window_size[0] if window_size[0] != -1 else max_seqlen_q
right = window_size[1] if window_size[1] != -1 else max_seqlen_q
mask_upper = torch.triu(mask, diagonal=-left)
mask_lower = torch.tril(mask_upper, diagonal=right)
else:
left = window_size[0] if window_size[0] != -1 else max_seqlen_kv
right = window_size[1] if window_size[1] != -1 else max_seqlen_kv
mask_upper = torch.triu(mask, diagonal=max_seqlen_kv - max_seqlen_q - left)
mask_lower = torch.tril(mask_upper, diagonal=max_seqlen_kv - max_seqlen_q + right)
attn_mask_type = "arbitrary"
mask = mask_lower.logical_not()
# perform basic checks
change_type = window_size is not None and (
window_size[0] != -1 or window_size[1] not in [-1, 0]
)
if window_size is None:
window_size = (-1, -1)
if "causal" in attn_mask_type:
window_size = (window_size[0], 0)
window_size = (
max_seqlen_kv if window_size[0] == -1 else window_size[0],
max_seqlen_q if window_size[1] == -1 else window_size[1],
)

# apply padding mask
actual_seqlens_q = None
actual_seqlens_kv = None
if "padding" in attn_mask_type:
if attention_type == "self":
attention_mask = torch.logical_or(
attention_mask.squeeze(1).unsqueeze(3), attention_mask
)
else:
attention_mask = torch.logical_or(
attention_mask[0].squeeze(1).unsqueeze(3), attention_mask[1]
)
m = attention_mask.logical_not()
actual_seqlens_q = m[:, 0, :, 0].sum(dim=1)
actual_seqlens_kv = m[:, 0, 0, :].sum(dim=1)

# apply SWA mask
mask = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view(
1, 1, max_seqlen_q, 1
) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(1, 1, 1, max_seqlen_kv)
swa_left = None
swa_right = None
if attn_mask_type == "causal_bottom_right" or (
attn_mask_type in ["no_mask", "arbitrary"] and bottom_right_alignment
):
swa_left = mask + max_seqlen_kv - max_seqlen_q - window_size[0]
swa_right = mask + max_seqlen_kv - max_seqlen_q + window_size[1]
elif attn_mask_type in ["causal", "padding_causal"] or (
attn_mask_type in ["no_mask", "padding", "arbitrary"] and not bottom_right_alignment
):
swa_left = mask - window_size[0]
swa_right = mask + window_size[1]
elif attn_mask_type == "padding_causal_bottom_right" or (
attn_mask_type == "padding" and bottom_right_alignment
):
batch_size = attention_mask.shape[0]
swa_left = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + (
actual_seqlens_kv - actual_seqlens_q - window_size[0]
).view(batch_size, 1, 1, 1)
swa_right = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + (
actual_seqlens_kv - actual_seqlens_q + window_size[1]
).view(batch_size, 1, 1, 1)
swa_mask = torch.logical_not(
torch.where(swa_left <= 0, 1, 0) - torch.where(swa_right < 0, 1, 0)
)
if attention_mask is not None:
mask = torch.logical_and(attention_mask, mask)
return attn_mask_type, mask
attention_mask = torch.logical_or(swa_mask, attention_mask)
else:
attention_mask = swa_mask

# change mask type
if change_type:
attn_mask_type = "arbitrary"

return attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv


@torch.no_grad()
Expand Down Expand Up @@ -4733,6 +4829,7 @@ def forward(
cu_seqlens_kv: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
attn_mask_type: str = "causal",
attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None,
window_size: Optional[Tuple[int, int]] = None,
core_attention_bias_type: str = "no_bias",
core_attention_bias: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
Expand All @@ -4752,53 +4849,15 @@ def forward(
query_layer.shape[0],
key_layer.shape[0],
)
if "padding" in attn_mask_type:
if self.attention_type == "self":
assert attention_mask.shape == (
batch_size,
1,
1,
max_seqlen_q,
), "attention_mask should be a single tensor with [b, 1, 1, sq] shape!"
attention_mask = torch.logical_or(
attention_mask.squeeze(1).unsqueeze(3), attention_mask
)
else:
assert (
len(attention_mask) == 2
and attention_mask[0].shape == (batch_size, 1, 1, max_seqlen_q)
and attention_mask[1].shape == (batch_size, 1, 1, max_seqlen_kv)
), (
"attention_mask should be a tuple of two tensors with shapes "
"[b, 1, 1, sq] and [b, 1, 1, skv]!"
)
attention_mask = torch.logical_or(
attention_mask[0].squeeze(1).unsqueeze(3), attention_mask[1]
)
mask = attention_mask.squeeze(1).logical_not()
actual_seqlens_q = mask[:, :, 0].sum(dim=1)
actual_seqlens_kv = mask[:, 0, :].sum(dim=1)
mask = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view(
1, 1, max_seqlen_q, 1
) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(
1, 1, 1, max_seqlen_kv
)
if attn_mask_type == "padding_causal":
attention_mask = torch.logical_or(
torch.where(mask.view(1, 1, max_seqlen_q, max_seqlen_kv) < 0, 1, 0),
attention_mask,
)
if attn_mask_type == "padding_causal_bottom_right":
attention_mask = torch.logical_or(
torch.where(
mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv)
+ (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1)
< 0,
1,
0,
),
attention_mask,
)

attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = get_full_mask(
max_seqlen_q,
max_seqlen_kv,
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
window_size=window_size,
attention_type=self.attention_type,
)

batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]
apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16
Expand Down Expand Up @@ -8274,12 +8333,6 @@ def forward(
)

if use_unfused_attention:
if window_size is not None and (
window_size[0] != -1 or window_size[1] not in [-1, 0]
):
attn_mask_type, attention_mask = get_swa_mask(
window_size, max_seqlen_q, max_seqlen_kv, attn_mask_type, attention_mask
)
if checkpoint_core_attention:
return self._checkpointed_attention_forward(
self.unfused_attention,
Expand All @@ -8291,6 +8344,7 @@ def forward(
cu_seqlens_kv=cu_seqlens_kv,
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
window_size=window_size,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
Expand All @@ -8304,6 +8358,7 @@ def forward(
cu_seqlens_kv=cu_seqlens_kv,
attn_mask_type=attn_mask_type,
attention_mask=attention_mask,
window_size=window_size,
core_attention_bias_type=core_attention_bias_type,
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
Expand Down

0 comments on commit f033498

Please sign in to comment.