[pull] main from NVIDIA:main #44

Merged (1 commit, Dec 20, 2024)
2 changes: 1 addition & 1 deletion 3rdparty/cudnn-frontend
Submodule cudnn-frontend updated 43 files
+1 −1 CMakeLists.txt
+10 −0 docs/operations/Attention.md
+3 −2 include/cudnn_backend_base.h
+1 −0 include/cudnn_frontend.h
+24 −2 include/cudnn_frontend/graph_helpers.h
+28 −0 include/cudnn_frontend/graph_interface.h
+32 −1 include/cudnn_frontend/graph_properties.h
+6 −0 include/cudnn_frontend/node/paged_cache_load.h
+3 −0 include/cudnn_frontend/node/resample.h
+372 −481 include/cudnn_frontend/node/scaled_dot_product_flash_attention.h
+4 −1 include/cudnn_frontend/node/sdpa_fp8.h
+5 −1 include/cudnn_frontend/node/sdpa_fp8_bwd.h
+7 −3 include/cudnn_frontend/plans.h
+387 −0 include/cudnn_frontend/utils/attn_score_modifiers.h
+3 −3 include/cudnn_frontend_EngineFallbackList.h
+3 −3 include/cudnn_frontend_ExecutionPlan.h
+3 −4 include/cudnn_frontend_Operation.h
+1 −1 include/cudnn_frontend_OperationGraph.h
+3 −4 include/cudnn_frontend_get_plan.h
+2 −0 include/cudnn_frontend_shim.h
+1 −1 include/cudnn_frontend_utils.h
+1 −1 include/cudnn_frontend_version.h
+2 −2 pyproject.toml
+1 −1 python/cudnn/__init__.py
+16 −0 python/pygraph/pygraph.cpp
+3 −0 python/pygraph/pygraph.h
+2 −2 python/pygraph/sdpa.cpp
+3 −0 samples/cpp/CMakeLists.txt
+205 −0 samples/cpp/convolution/conv_dynamic_shape_benchmark.cpp
+2 −1 samples/cpp/convolution/fp8_fprop.cpp
+4 −0 samples/cpp/convolution/fprop.cpp
+5 −1 samples/cpp/convolution/wgrads.cpp
+144 −0 samples/cpp/norm/layernorm.cpp
+207 −0 samples/cpp/sdpa/fp16_bwd_with_flexible_graphs.cpp
+198 −0 samples/cpp/sdpa/fp16_fwd_with_flexible_graphs.cpp
+1 −1 samples/cpp/utils/helpers.h
+5 −3 samples/legacy_samples/fp16_emu.cpp
+1 −1 samples/legacy_samples/helpers.cpp
+5 −0 samples/legacy_samples/test_list.cpp
+3 −1 samples/python/50_scaled_dot_product_attention.ipynb
+5 −3 samples/python/52_scaled_dot_product_attention_with_paged_caches.ipynb
+7 −0 test/python/test_conv_bias.py
+112 −60 test/python/test_mhas.py
2 changes: 1 addition & 1 deletion qa/L0_pytorch_unittest/test.sh
@@ -13,7 +13,6 @@ pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py
pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
pytest -v -s $TE_PATH/tests/pytorch/test_torch_save_load.py
@@ -22,3 +21,4 @@ pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py
pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py
pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py
pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py
NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
18 changes: 15 additions & 3 deletions tests/jax/test_fused_attn.py
@@ -170,8 +170,7 @@ def make_mask(
max_seqlen_kv = inv_mask.shape[-1]
inv_swa_mask = make_swa_mask(max_seqlen_q, max_seqlen_kv, window_size, attn_mask_type)
inv_swa_mask = jnp.broadcast_to(inv_swa_mask, inv_mask.shape)
# In inv_swa_mask and inv_mask 0 is masked out
inv_mask = jnp.where(inv_mask != 0, inv_swa_mask, inv_mask)
inv_mask = combine_masks(inv_mask, inv_swa_mask)

mask = jnp.logical_not(inv_mask)
return mask
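
For context on the make_mask change above: the replaced jnp.where expression keeps a position unmasked only when both inv_mask and inv_swa_mask are nonzero, so the new combine_masks call is equivalent only if it reduces its arguments with a logical AND (as flax.linen.combine_masks does). A minimal sketch of that semantics, under that assumption rather than taken from this diff:

```python
import jax.numpy as jnp

def combine_masks_sketch(*masks):
    """Reduce the non-None masks with logical AND; 0 / False means 'masked out'."""
    masks = [m for m in masks if m is not None]
    if not masks:
        return None
    combined = masks[0].astype(bool)
    for m in masks[1:]:
        combined = jnp.logical_and(combined, m)
    return combined

inv_mask = jnp.array([[1, 1, 0]])     # 0 = masked out, as in make_mask above
inv_swa_mask = jnp.array([[1, 0, 1]])
print(combine_masks_sketch(inv_mask, inv_swa_mask))  # [[ True False False]]
```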
@@ -315,6 +314,13 @@ def _get_max_segments_per_sequence(self):
return self.num_segments_per_seq + 1

def _check_configs(self):
# TODO(rewang): Fix THD + PADDING_CAUSAL + SWA reference
if (
self.qkv_layout.is_thd()
and self.attn_mask_type == AttnMaskType.PADDING_CAUSAL_MASK
and self.window_size is not None
):
pytest.skip("THD + PADDING_CAUSAL + SWA reference is not implemented.")
# TODO(rewang): probably adds this in is_fused_attn_available
if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding():
pytest.skip("THD format requires padding masks.")
@@ -504,7 +510,13 @@ def generate_random_segment_ids(
if self.qkv_layout.is_thd():
self.mask_for_customcall = None # THD format doesn't support mask
else:
self.mask_for_customcall = self.mask
self.mask_for_customcall = make_mask(
self.segment_ids_q,
self.segment_ids_kv,
self.segment_pos_q,
self.segment_pos_kv,
self.attn_mask_type,
)

self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
self.scaling_factor = 1.0 / sqrt(self.head_dim)
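
The custom-call mask above is now rebuilt from the per-token segment IDs and positions via make_mask instead of reusing self.mask. As a rough, hypothetical illustration of what a segment-ID-based mask encodes (not the exact make_mask implementation, which also applies the configured mask type and segment positions):

```python
import jax.numpy as jnp

def segment_mask_sketch(segment_ids_q, segment_ids_kv, pad_id=0):
    """Query i may attend to key j iff both belong to the same non-padding segment."""
    same_segment = segment_ids_q[:, :, None] == segment_ids_kv[:, None, :]
    not_pad = (segment_ids_q[:, :, None] != pad_id) & (segment_ids_kv[:, None, :] != pad_id)
    return same_segment & not_pad  # [b, sq, skv]; True = attend allowed

segment_ids_q = jnp.array([[1, 1, 2, 0]])   # 0 marks padding tokens
segment_ids_kv = jnp.array([[1, 2, 2, 0]])
print(segment_mask_sketch(segment_ids_q, segment_ids_kv).astype(jnp.int32))
```

Note the sketch uses the "inverse mask" convention from make_mask above (nonzero = attend allowed); the real helper returns the logical NOT, where True means masked out.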
186 changes: 124 additions & 62 deletions tests/pytorch/fused_attn/test_fused_attn.py
@@ -237,19 +237,18 @@ def test_dot_product_attention(
tols = dict(atol=1.5e-2, rtol=1.5e-2)
config = model_configs[model]
is_mla = config.head_dim_qk != config.head_dim_v
is_mqa_gqa = config.num_heads != config.num_gqa_groups
if qkv_layout is None:
if config.attn_type == "self":
qkv_layout = "sb3hd" if not is_mla else "sbhd_sbhd_sbhd"
qkv_layout = "sb3hd" if not is_mla and not is_mqa_gqa else "sbhd_sbhd_sbhd"
else:
qkv_layout = "bshd_bs2hd" if not is_mla else "bshd_bshd_bshd"
qkv_layout = "bshd_bs2hd" if not is_mla and not is_mqa_gqa else "bshd_bshd_bshd"
if "3" in qkv_layout and config.attn_type == "cross":
pytest.skip("No need to test this layout for cross attention")

# Test backend availability
window_size = (-1, -1)
if swa:
window_size = [2, 2]
config.window_size = check_set_window_size(config.attn_mask_type, window_size)
if config.window_size == (-1, -1) and swa:
config.window_size = [2, 2]
config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
available_backends, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
@@ -334,16 +333,16 @@ def test_dot_product_attention(
is_training,
)

if unfused_attn_supported and fused_attn_supported:
logging.info("[test_dot_product_attention]: unfused attn vs fused attn")
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
for i, _ in enumerate(unfused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
if unfused_attn_supported and flash_attn_supported:
logging.info("[test_dot_product_attention]: unfused attn vs flash attn")
torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
for i, _ in enumerate(flash_attn_bwd):
torch.testing.assert_close(unfused_attn_bwd[i], flash_attn_bwd[i], **tols)
if unfused_attn_supported and fused_attn_supported:
logging.info("[test_dot_product_attention]: unfused attn vs fused attn")
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
for i, _ in enumerate(unfused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
if fused_attn_supported and flash_attn_supported:
logging.info("[test_dot_product_attention]: fused attn vs flash attn")
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
@@ -399,30 +398,41 @@ def test_dpa_mla(dtype, model_configs, model):

model_configs_mask = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"mask_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"),
"mask_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "causal", "no_bias"),
"mask_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"mask_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
"mask_3_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"),
"mask_3_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
"mask_4_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "padding", "no_bias"),
"mask_4_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
"mask_5_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
"mask_5_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "padding_causal", "no_bias"),
"mask_6_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"mask_6_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
"mask_7_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
"mask_7_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
"mask_8_0": ModelConfig(
2, 24, 24, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
"mask_1_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
"mask_1_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"mask_1_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
"mask_2_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
"mask_2_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
"mask_2_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
"mask_3_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"mask_3_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias"),
"mask_3_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
"mask_4_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"mask_4_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"mask_4_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
"mask_5_0": ModelConfig(
2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
),
"mask_5_1": ModelConfig(
2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
),
"mask_5_2": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"
),
"mask_6_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal", "no_bias"),
"mask_6_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal", "no_bias"),
"mask_7_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal_bottom_right", "no_bias"),
"mask_7_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal_bottom_right", "no_bias"),
"mask_8_0": ModelConfig(2, 24, 24, 128, 1, 2048, 0.0, "padding", "no_bias"),
"mask_8_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "padding", "no_bias"),
"mask_9_0": ModelConfig(2, 24, 24, 128, 1, 2048, 0.0, "padding_causal", "no_bias"),
"mask_9_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "padding_causal", "no_bias"),
"mask_10_0": ModelConfig(
2, 24, 24, 128, 1, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
),
"mask_8_1": ModelConfig(
1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"
"mask_10_1": ModelConfig(
2, 16, 16, 256, 1, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
),
"mask_9_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal", "no_bias"),
"mask_9_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal", "no_bias"),
"mask_10_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal_bottom_right", "no_bias"),
"mask_10_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal_bottom_right", "no_bias"),
}
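
The positional arguments in these ModelConfig entries follow the "# test: b, h, hg, d, sq, skv, p, mask, bias" comment at the top of the dict. A hypothetical sketch of what those fields stand for (field names are illustrative and may not match the real ModelConfig class in this test file):

```python
from dataclasses import dataclass
from typing import Tuple

@dataclass
class ModelConfigSketch:
    batch_size: int           # b
    num_heads: int            # h
    num_gqa_groups: int       # hg (hg < h => MQA/GQA)
    head_dim: int             # d
    max_seqlen_q: int         # sq
    max_seqlen_kv: int        # skv
    dropout_p: float          # p
    attn_mask_type: str       # mask, e.g. "causal", "padding_causal_bottom_right"
    attn_bias_type: str       # bias, e.g. "no_bias"
    window_size: Tuple[int, int] = (-1, -1)  # (left, right); (-1, -1) = no sliding window

# e.g. "mask_1_1": 2 sequences, 24 query heads sharing 1 KV head (MQA), head dim 128,
# 2048 x 2048 tokens, no dropout, causal mask, no bias.
cfg = ModelConfigSketch(2, 24, 1, 128, 2048, 2048, 0.0, "causal", "no_bias")
```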


@@ -531,20 +541,28 @@ def test_dpa_bias_shapes(dtype, model_configs, model):

model_configs_swa = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"swa_1_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"),
"swa_1_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
"swa_2_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
"swa_2_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
"swa_3_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
"swa_3_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
"swa_4_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"swa_4_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
"swa_5_0": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"swa_5_1": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
"swa_6_0": ModelConfig(
4, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
),
"swa_1_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"),
"swa_1_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
"swa_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
"swa_2_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"swa_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
"swa_3_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
"swa_3_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"),
"swa_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"),
"swa_4_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"swa_4_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "padding", "no_bias"),
"swa_4_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
"swa_5_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"swa_5_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"swa_5_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
"swa_6_1": ModelConfig(
2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
),
"swa_6_2": ModelConfig(
2, 24, 4, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
),
"swa_6_3": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"
),
}
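
These configs exercise sliding-window attention (SWA); test_dot_product_attention above sets window_size = [2, 2] when swa is enabled. A minimal sketch of the constraint a window_size = (left, right) pair encodes, written as an illustration rather than TE's implementation (bottom-right-aligned variants shift the diagonal by seqlen_kv - seqlen_q, which the sketch ignores):

```python
import torch

def swa_mask_sketch(seqlen_q, seqlen_kv, window_size=(-1, -1)):
    """True = attend allowed; -1 on either side means 'unbounded' in that direction."""
    left, right = window_size
    q = torch.arange(seqlen_q).unsqueeze(1)   # [sq, 1]
    k = torch.arange(seqlen_kv).unsqueeze(0)  # [1, skv]
    allowed = torch.ones(seqlen_q, seqlen_kv, dtype=torch.bool)
    if left >= 0:
        allowed &= k >= q - left   # query i sees at most `left` keys behind it
    if right >= 0:
        allowed &= k <= q + right  # and at most `right` keys ahead of it
    return allowed

# window_size=(2, 0) behaves like a causal mask restricted to the last 3 positions.
print(swa_mask_sketch(5, 5, window_size=(2, 0)).int())
```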
@@ -623,18 +641,57 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"]
model_configs_layout_thd = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"layout_0_1": ModelConfig(3, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"),
"layout_0_2": ModelConfig(8, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"),
"layout_0_3": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
"layout_0_4": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
"layout_1_1": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"layout_1_2": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"layout_1_3": ModelConfig(1, 16, 1, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"layout_1_4": ModelConfig(8, 16, 1, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"layout_2_1": ModelConfig(1, 16, 16, 128, 128, 128, 0.0, "padding", "no_bias"),
"layout_2_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
"layout_2_3": ModelConfig(1, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"layout_2_4": ModelConfig(8, 16, 16, 64, 2048, 4096, 0.0, "padding_causal", "no_bias"),
"layout_0_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"layout_0_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias"),
"layout_0_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
"layout_1_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"layout_1_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"layout_1_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
"layout_2_0": ModelConfig(
2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
),
"layout_2_1": ModelConfig(
2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias"
),
"layout_2_2": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"
),
"layout_3_0": ModelConfig(
2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias", window_size=(4, 4)
),
"layout_3_1": ModelConfig(
2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias", window_size=(4, 4)
),
"layout_3_2": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias", window_size=(4, 4)
),
"layout_4_0": ModelConfig(
2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias", window_size=(4, 0)
),
"layout_4_1": ModelConfig(
2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias", window_size=(4, 0)
),
"layout_4_2": ModelConfig(
2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias", window_size=(4, 0)
),
"layout_5_0": ModelConfig(
2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias", window_size=(4, 0)
),
"layout_5_1": ModelConfig(
2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias", window_size=(4, 0)
),
"layout_5_2": ModelConfig(
2,
24,
24,
128,
2048,
4096,
0.0,
"padding_causal_bottom_right",
"no_bias",
window_size=(4, 0),
),
}
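
These entries drive the "thd" (packed, variable-length) layouts listed in qkv_layouts_thd above. A rough illustration of how a thd tensor differs from padded bshd/sbhd storage, with illustrative shapes only:

```python
import torch

seqlens = torch.tensor([5, 2, 7], dtype=torch.int32)          # true lengths of 3 sequences
cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)                  # [0, 5, 7, 14]

num_heads, head_dim = 16, 64
q_thd = torch.randn(int(cu_seqlens[-1]), num_heads, head_dim)  # [t, h, d] -- no padding rows

# Tokens of sequence i live in rows cu_seqlens[i]:cu_seqlens[i + 1].
second_seq = q_thd[int(cu_seqlens[1]):int(cu_seqlens[2])]
print(second_seq.shape)                                        # torch.Size([2, 16, 64])
```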


@@ -651,11 +708,13 @@ def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout):
config = model_configs[model]
if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
pytest.skip("qkv_layout not applicable for MQA/GQA")
logging.info("[test_dpa_qkv_layout_thd]: pad_between_seqs = True")
pad_between_seqs = True
test_dot_product_attention(
dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
)
if get_cudnn_version() >= (9, 3, 0):
logging.info("[test_dpa_qkv_layout_thd]: pad_between_seqs = False")
# cuDNN 9.3.0+ is required to run pad_between_seqs = False/True in the same run
pad_between_seqs = False
test_dot_product_attention(
@@ -695,9 +754,12 @@ def _run_dot_product_attention(
)
seqlens_kv = seqlens_q
if config.attn_type == "cross":
seqlens_q = torch.randint(
1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
)
if config.max_seqlen_q > 1:
seqlens_q = torch.randint(
1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
)
else:
seqlens_q = torch.ones([config.batch_size], dtype=torch.int32, device="cuda")
seqlens_kv = torch.randint(
1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
)
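
The max_seqlen_q > 1 guard above is needed because torch.randint requires low < high, so the previous unconditional torch.randint(1, config.max_seqlen_q, ...) would fail for the sq = 1 configs added in this diff. A small sketch of the guarded behavior (hypothetical helper name):

```python
import torch

def make_seqlens(batch_size, max_seqlen, device="cpu"):
    if max_seqlen > 1:
        # Random per-sequence lengths in [1, max_seqlen).
        return torch.randint(1, max_seqlen, [batch_size], dtype=torch.int32, device=device)
    # max_seqlen == 1: torch.randint(1, 1, ...) would raise, so every length is exactly 1.
    return torch.ones([batch_size], dtype=torch.int32, device=device)

print(make_seqlens(4, 2048))  # e.g. tensor([ 317, 1999,   52,  840], dtype=torch.int32)
print(make_seqlens(4, 1))     # tensor([1, 1, 1, 1], dtype=torch.int32)
```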
2 changes: 2 additions & 0 deletions tests/pytorch/fused_attn/test_fused_attn_with_cp.py
@@ -121,6 +121,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!")
if dtype == "fp8" and get_device_compute_capability() < (9, 0):
pytest.skip("FP8 attention is only supported on sm90+!")
if qkv_format == "thd" and get_cudnn_version() >= (9, 6, 0):
pytest.skip("THD format is not supported for cuDNN 9.6+!")

config = model_configs_fused_attn[model]
if qkv_format == "thd" and config.num_heads != config.num_gqa_groups: