diff --git a/tests/tests_nn/test_transformers.py b/tests/tests_nn/test_transformers.py
index 22902711..145dafbe 100644
--- a/tests/tests_nn/test_transformers.py
+++ b/tests/tests_nn/test_transformers.py
@@ -15,91 +15,91 @@ def create_input(shape):
     return data
 
 
-# @pytest.mark.parametrize(
-#     "shape",
-#     [
-#         [3, 2, 32, 32],
-#         [3, 2, 16, 16],
-#     ],
-# )
-# @pytest.mark.parametrize(
-#     "embedding_dim",
-#     [20],
-# )
-# @pytest.mark.parametrize(
-#     "patch_size",
-#     [140],
-# )
-# @pytest.mark.parametrize(
-#     "encoder_depths, encoder_num_heads, bottleneck_depth, bottleneck_num_heads",
-#     [
-#         [(2, 2, 2), (1, 2, 4), 1, 8],
-#         [(2, 2, 2, 2), (1, 2, 4, 8), 2, 8],
-#     ],
-# )
-# @pytest.mark.parametrize(
-#     "patch_norm",
-#     [True, False],
-# )
-# @pytest.mark.parametrize(
-#     "win_size",
-#     [8],
-# )
-# @pytest.mark.parametrize(
-#     "mlp_ratio",
-#     [2],
-# )
-# @pytest.mark.parametrize(
-#     "qkv_bias",
-#     [True, False],
-# )
-# @pytest.mark.parametrize(
-#     "qk_scale",
-#     [None, 0.5],
-# )
-# @pytest.mark.parametrize(
-#     "token_projection",
-#     [AttentionTokenProjectionType.LINEAR, AttentionTokenProjectionType.CONV],
-# )
-# @pytest.mark.parametrize(
-#     "token_mlp",
-#     [LeWinTransformerMLPTokenType.FFN, LeWinTransformerMLPTokenType.MLP, LeWinTransformerMLPTokenType.LEFF],
-# )
-# def test_uformer(
-#     shape,
-#     patch_size,
-#     embedding_dim,
-#     encoder_depths,
-#     encoder_num_heads,
-#     bottleneck_depth,
-#     bottleneck_num_heads,
-#     win_size,
-#     mlp_ratio,
-#     patch_norm,
-#     qkv_bias,
-#     qk_scale,
-#     token_projection,
-#     token_mlp,
-# ):
-#     model = UFormerModel(
-#         patch_size=patch_size,
-#         in_channels=2,
-#         embedding_dim=embedding_dim,
-#         encoder_depths=encoder_depths,
-#         encoder_num_heads=encoder_num_heads,
-#         bottleneck_depth=bottleneck_depth,
-#         bottleneck_num_heads=bottleneck_num_heads,
-#         win_size=win_size,
-#         mlp_ratio=mlp_ratio,
-#         qkv_bias=qkv_bias,
-#         qk_scale=qk_scale,
-#         patch_norm=patch_norm,
-#         token_projection=token_projection,
-#         token_mlp=token_mlp,
-#     )
-#     data = create_input(shape).cpu()
-#     out = model(data)
-#     assert list(out.shape) == shape
+@pytest.mark.parametrize(
+    "shape",
+    [
+        [3, 2, 32, 32],
+        [3, 2, 16, 16],
+    ],
+)
+@pytest.mark.parametrize(
+    "embedding_dim",
+    [20],
+)
+@pytest.mark.parametrize(
+    "patch_size",
+    [140],
+)
+@pytest.mark.parametrize(
+    "encoder_depths, encoder_num_heads, bottleneck_depth, bottleneck_num_heads",
+    [
+        [(2, 2, 2), (1, 2, 4), 1, 8],
+        [(2, 2, 2, 2), (1, 2, 4, 8), 2, 8],
+    ],
+)
+@pytest.mark.parametrize(
+    "patch_norm",
+    [True, False],
+)
+@pytest.mark.parametrize(
+    "win_size",
+    [8],
+)
+@pytest.mark.parametrize(
+    "mlp_ratio",
+    [2],
+)
+@pytest.mark.parametrize(
+    "qkv_bias",
+    [True, False],
+)
+@pytest.mark.parametrize(
+    "qk_scale",
+    [None, 0.5],
+)
+@pytest.mark.parametrize(
+    "token_projection",
+    [AttentionTokenProjectionType.LINEAR, AttentionTokenProjectionType.CONV],
+)
+@pytest.mark.parametrize(
+    "token_mlp",
+    [LeWinTransformerMLPTokenType.FFN, LeWinTransformerMLPTokenType.MLP, LeWinTransformerMLPTokenType.LEFF],
+)
+def test_uformer(
+    shape,
+    patch_size,
+    embedding_dim,
+    encoder_depths,
+    encoder_num_heads,
+    bottleneck_depth,
+    bottleneck_num_heads,
+    win_size,
+    mlp_ratio,
+    patch_norm,
+    qkv_bias,
+    qk_scale,
+    token_projection,
+    token_mlp,
+):
+    model = UFormerModel(
+        patch_size=patch_size,
+        in_channels=2,
+        embedding_dim=embedding_dim,
+        encoder_depths=encoder_depths,
+        encoder_num_heads=encoder_num_heads,
+        bottleneck_depth=bottleneck_depth,
+        bottleneck_num_heads=bottleneck_num_heads,
+        win_size=win_size,
+        mlp_ratio=mlp_ratio,
+        qkv_bias=qkv_bias,
+        qk_scale=qk_scale,
+        patch_norm=patch_norm,
+        token_projection=token_projection,
+        token_mlp=token_mlp,
+    )
+    data = create_input(shape).cpu()
+    out = model(data)
+    assert list(out.shape) == shape
 
 
 @pytest.mark.parametrize(