Uformer test
georgeyiasemis committed Jan 3, 2025
1 parent 7a27ee8 commit d4ec613
Showing 1 changed file with 85 additions and 85 deletions.
tests/tests_nn/test_transformers.py: 85 additions & 85 deletions (170 changes)
@@ -15,91 +15,91 @@ def create_input(shape):
return data


# @pytest.mark.parametrize(
# "shape",
# [
# [3, 2, 32, 32],
# [3, 2, 16, 16],
# ],
# )
# @pytest.mark.parametrize(
# "embedding_dim",
# [20],
# )
# @pytest.mark.parametrize(
# "patch_size",
# [140],
# )
# @pytest.mark.parametrize(
# "encoder_depths, encoder_num_heads, bottleneck_depth, bottleneck_num_heads",
# [
# [(2, 2, 2), (1, 2, 4), 1, 8],
# [(2, 2, 2, 2), (1, 2, 4, 8), 2, 8],
# ],
# )
# @pytest.mark.parametrize(
# "patch_norm",
# [True, False],
# )
# @pytest.mark.parametrize(
# "win_size",
# [8],
# )
# @pytest.mark.parametrize(
# "mlp_ratio",
# [2],
# )
# @pytest.mark.parametrize(
# "qkv_bias",
# [True, False],
# )
# @pytest.mark.parametrize(
# "qk_scale",
# [None, 0.5],
# )
# @pytest.mark.parametrize(
# "token_projection",
# [AttentionTokenProjectionType.LINEAR, AttentionTokenProjectionType.CONV],
# )
# @pytest.mark.parametrize(
# "token_mlp",
# [LeWinTransformerMLPTokenType.FFN, LeWinTransformerMLPTokenType.MLP, LeWinTransformerMLPTokenType.LEFF],
# )
# def test_uformer(
# shape,
# patch_size,
# embedding_dim,
# encoder_depths,
# encoder_num_heads,
# bottleneck_depth,
# bottleneck_num_heads,
# win_size,
# mlp_ratio,
# patch_norm,
# qkv_bias,
# qk_scale,
# token_projection,
# token_mlp,
# ):
# model = UFormerModel(
# patch_size=patch_size,
# in_channels=2,
# embedding_dim=embedding_dim,
# encoder_depths=encoder_depths,
# encoder_num_heads=encoder_num_heads,
# bottleneck_depth=bottleneck_depth,
# bottleneck_num_heads=bottleneck_num_heads,
# win_size=win_size,
# mlp_ratio=mlp_ratio,
# qkv_bias=qkv_bias,
# qk_scale=qk_scale,
# patch_norm=patch_norm,
# token_projection=token_projection,
# token_mlp=token_mlp,
# )
# data = create_input(shape).cpu()
# out = model(data)
# assert list(out.shape) == shape
@pytest.mark.parametrize(
"shape",
[
[3, 2, 32, 32],
[3, 2, 16, 16],
],
)
@pytest.mark.parametrize(
"embedding_dim",
[20],
)
@pytest.mark.parametrize(
"patch_size",
[140],
)
@pytest.mark.parametrize(
"encoder_depths, encoder_num_heads, bottleneck_depth, bottleneck_num_heads",
[
[(2, 2, 2), (1, 2, 4), 1, 8],
[(2, 2, 2, 2), (1, 2, 4, 8), 2, 8],
],
)
@pytest.mark.parametrize(
"patch_norm",
[True, False],
)
@pytest.mark.parametrize(
"win_size",
[8],
)
@pytest.mark.parametrize(
"mlp_ratio",
[2],
)
@pytest.mark.parametrize(
"qkv_bias",
[True, False],
)
@pytest.mark.parametrize(
"qk_scale",
[None, 0.5],
)
@pytest.mark.parametrize(
"token_projection",
[AttentionTokenProjectionType.LINEAR, AttentionTokenProjectionType.CONV],
)
@pytest.mark.parametrize(
"token_mlp",
[LeWinTransformerMLPTokenType.FFN, LeWinTransformerMLPTokenType.MLP, LeWinTransformerMLPTokenType.LEFF],
)
def test_uformer(
shape,
patch_size,
embedding_dim,
encoder_depths,
encoder_num_heads,
bottleneck_depth,
bottleneck_num_heads,
win_size,
mlp_ratio,
patch_norm,
qkv_bias,
qk_scale,
token_projection,
token_mlp,
):
model = UFormerModel(
patch_size=patch_size,
in_channels=2,
embedding_dim=embedding_dim,
encoder_depths=encoder_depths,
encoder_num_heads=encoder_num_heads,
bottleneck_depth=bottleneck_depth,
bottleneck_num_heads=bottleneck_num_heads,
win_size=win_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
patch_norm=patch_norm,
token_projection=token_projection,
token_mlp=token_mlp,
)
data = create_input(shape).cpu()
out = model(data)
assert list(out.shape) == shape


@pytest.mark.parametrize(
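The re-enabled test_uformer is parametrized over input shapes, attention token projections, and MLP token types, so a single run exercises every combination. As a minimal sketch (assuming the file path shown in this diff and a standard pytest installation), it can be invoked on its own from Python:

    import pytest

    # Run only the Uformer test from the file touched by this commit.
    # Equivalent CLI: pytest tests/tests_nn/test_transformers.py -k test_uformer
    exit_code = pytest.main(["tests/tests_nn/test_transformers.py", "-k", "test_uformer"])
    print(exit_code)  # pytest.ExitCode.OK (0) means every parametrized case passed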
