From 4ef3cebaea3624e50922f2dac2824fc7bce5f351 Mon Sep 17 00:00:00 2001 From: Usman Aziz Date: Fri, 8 Nov 2024 21:10:42 +0000 Subject: [PATCH] Skip the train module only. --- tests/models/autoencoder_linear/test_autoencoder_linear.py | 3 ++- tests/models/beit/test_beit_image_classification.py | 3 ++- tests/models/clip/test_clip.py | 3 ++- tests/models/deit/test_deit.py | 3 ++- tests/models/hardnet/test_hardnet.py | 3 ++- tests/models/mlpmixer/test_mlpmixer.py | 3 ++- tests/models/mnist/test_mnist.py | 3 ++- tests/models/openpose/test_openpose_v2.py | 3 ++- tests/models/resnet/test_resnet.py | 3 ++- tests/models/resnet50/test_resnet50.py | 3 ++- tests/models/segformer/test_segformer.py | 3 ++- tests/models/speecht5_tts/test_speecht5_tts.py | 2 +- tests/models/unet/test_unet.py | 3 ++- tests/models/unet_brain/test_unet_brain.py | 3 ++- tests/models/unet_carvana/test_unet_carvana.py | 3 ++- 15 files changed, 29 insertions(+), 15 deletions(-) diff --git a/tests/models/autoencoder_linear/test_autoencoder_linear.py b/tests/models/autoencoder_linear/test_autoencoder_linear.py index 4473b7f1..376bd562 100644 --- a/tests/models/autoencoder_linear/test_autoencoder_linear.py +++ b/tests/models/autoencoder_linear/test_autoencoder_linear.py @@ -84,7 +84,8 @@ def _load_inputs(self): ["train", "eval"], ) def test_autoencoder_linear(record_property, mode): - pytest.skip("module has train variant.") + if mode == "train": + pytest.skip() model_name = "Autoencoder (linear)" record_property("model_name", model_name) record_property("mode", mode) diff --git a/tests/models/beit/test_beit_image_classification.py b/tests/models/beit/test_beit_image_classification.py index 617f9e5e..8da99516 100644 --- a/tests/models/beit/test_beit_image_classification.py +++ b/tests/models/beit/test_beit_image_classification.py @@ -43,7 +43,8 @@ def get_results_train(self, model, inputs, outputs): ["microsoft/beit-base-patch16-224", "microsoft/beit-large-patch16-224"], ) def test_beit_image_classification(record_property, model_name, mode): - pytest.skip("module has train variant.") + if mode == "train": + pytest.skip() record_property("model_name", model_name) record_property("mode", mode) diff --git a/tests/models/clip/test_clip.py b/tests/models/clip/test_clip.py index b4ab722e..633eed37 100644 --- a/tests/models/clip/test_clip.py +++ b/tests/models/clip/test_clip.py @@ -59,7 +59,8 @@ def get_results_train(self, model, inputs, outputs): ], ) def test_clip(record_property, mode): - pytest.skip("module has train variant.") + if mode == "train": + pytest.skip() model_name = "CLIP" record_property("model_name", model_name) record_property("mode", mode) diff --git a/tests/models/deit/test_deit.py b/tests/models/deit/test_deit.py index c41e9fff..2965915e 100644 --- a/tests/models/deit/test_deit.py +++ b/tests/models/deit/test_deit.py @@ -47,7 +47,8 @@ def get_results_train(self, model, inputs, outputs): ) @pytest.mark.parametrize("model_name", ["facebook/deit-base-patch16-224"]) def test_deit(record_property, model_name, mode): - pytest.skip("module has train variant.") + if mode == "train": + pytest.skip() record_property("model_name", model_name) record_property("mode", mode) diff --git a/tests/models/hardnet/test_hardnet.py b/tests/models/hardnet/test_hardnet.py index 62c909b7..43beca18 100644 --- a/tests/models/hardnet/test_hardnet.py +++ b/tests/models/hardnet/test_hardnet.py @@ -50,7 +50,8 @@ def _load_inputs(self): ["train", "eval"], ) def test_hardnet(record_property, mode): - pytest.skip("module has train variant.") + if mode == "train": + pytest.skip() model_name = "HardNet" record_property("model_name", model_name) record_property("mode", mode) diff --git a/tests/models/mlpmixer/test_mlpmixer.py b/tests/models/mlpmixer/test_mlpmixer.py index 12607df0..bcaa97c4 100644 --- a/tests/models/mlpmixer/test_mlpmixer.py +++ b/tests/models/mlpmixer/test_mlpmixer.py @@ -34,7 +34,8 @@ def _load_inputs(self): ["train", "eval"], ) def test_mlpmixer(record_property, mode): - pytest.skip("module has train variant.") + if mode == "train": + pytest.skip() model_name = "MLPMixer" record_property("model_name", model_name) record_property("mode", mode) diff --git a/tests/models/mnist/test_mnist.py b/tests/models/mnist/test_mnist.py index b3572654..040c11c6 100644 --- a/tests/models/mnist/test_mnist.py +++ b/tests/models/mnist/test_mnist.py @@ -60,7 +60,8 @@ def _load_inputs(self): ["train", "eval"], ) def test_mnist_train(record_property, mode): - pytest.skip("module has train variant.") + if mode == "train": + pytest.skip() model_name = "Mnist" record_property("model_name", model_name) record_property("mode", mode) diff --git a/tests/models/openpose/test_openpose_v2.py b/tests/models/openpose/test_openpose_v2.py index 93cb0299..4f90b936 100644 --- a/tests/models/openpose/test_openpose_v2.py +++ b/tests/models/openpose/test_openpose_v2.py @@ -50,7 +50,8 @@ def _load_inputs(self): ["train", "eval"], ) def test_openpose_v2(record_property, mode): - pytest.skip("module has train variant.") + if mode == "train": + pytest.skip() model_name = "OpenPose V2" record_property("model_name", model_name) record_property("mode", mode) diff --git a/tests/models/resnet/test_resnet.py b/tests/models/resnet/test_resnet.py index cfaa3202..54c157fd 100644 --- a/tests/models/resnet/test_resnet.py +++ b/tests/models/resnet/test_resnet.py @@ -24,7 +24,8 @@ def _load_inputs(self): ["train", "eval"], ) def test_resnet(record_property, mode): - pytest.skip("module has train variant.") + if mode == "train": + pytest.skip() model_name = "ResNet18" record_property("model_name", model_name) record_property("mode", mode) diff --git a/tests/models/resnet50/test_resnet50.py b/tests/models/resnet50/test_resnet50.py index 33e96189..194b81d6 100644 --- a/tests/models/resnet50/test_resnet50.py +++ b/tests/models/resnet50/test_resnet50.py @@ -37,7 +37,8 @@ def _load_inputs(self): ["train", "eval"], ) def test_resnet(record_property, mode): - pytest.skip("module has train variant.") + if mode == "train": + pytest.skip() model_name = "ResNet50" record_property("model_name", model_name) record_property("mode", mode) diff --git a/tests/models/segformer/test_segformer.py b/tests/models/segformer/test_segformer.py index b7131672..2a25225b 100644 --- a/tests/models/segformer/test_segformer.py +++ b/tests/models/segformer/test_segformer.py @@ -45,7 +45,8 @@ def get_results_train(self, model, inputs, outputs): ["train", "eval"], ) def test_segformer(record_property, mode): - pytest.skip("module has train variant.") + if mode == "train": + pytest.skip() model_name = "SegFormer" record_property("model_name", model_name) record_property("mode", mode) diff --git a/tests/models/speecht5_tts/test_speecht5_tts.py b/tests/models/speecht5_tts/test_speecht5_tts.py index 001486ff..d6ecb9cb 100644 --- a/tests/models/speecht5_tts/test_speecht5_tts.py +++ b/tests/models/speecht5_tts/test_speecht5_tts.py @@ -48,7 +48,7 @@ def set_model_eval(self, model): ["eval"], ) def test_speecht5_tts(record_property, mode): - pytest.skip("crashes in lowering to stable hlo.") + pytest.skip() # crashes in lowering to stable hlo model_name = "speecht5-tts" record_property("model_name", model_name) record_property("mode", mode) diff --git a/tests/models/unet/test_unet.py b/tests/models/unet/test_unet.py index 2c963eb9..3bb689d3 100644 --- a/tests/models/unet/test_unet.py +++ b/tests/models/unet/test_unet.py @@ -46,7 +46,8 @@ def _load_inputs(self): ["train", "eval"], ) def test_unet(record_property, mode): - pytest.skip("module has train variant.") + if mode == "train": + pytest.skip() model_name = "U-Net" record_property("model_name", model_name) record_property("mode", mode) diff --git a/tests/models/unet_brain/test_unet_brain.py b/tests/models/unet_brain/test_unet_brain.py index 316ba282..5a6bd4fa 100644 --- a/tests/models/unet_brain/test_unet_brain.py +++ b/tests/models/unet_brain/test_unet_brain.py @@ -56,7 +56,8 @@ def _load_inputs(self): ["train", "eval"], ) def test_unet_brain(record_property, mode): - pytest.skip("module has train variant.") + if mode == "train": + pytest.skip() model_name = "Unet-brain" record_property("model_name", model_name) record_property("mode", mode) diff --git a/tests/models/unet_carvana/test_unet_carvana.py b/tests/models/unet_carvana/test_unet_carvana.py index 93afafe5..7ce2312e 100644 --- a/tests/models/unet_carvana/test_unet_carvana.py +++ b/tests/models/unet_carvana/test_unet_carvana.py @@ -32,7 +32,8 @@ def _load_inputs(self): ["train", "eval"], ) def test_unet_carvana(record_property, mode): - pytest.skip("module has train variant.") + if mode == "train": + pytest.skip() model_name = "Unet-carvana" record_property("model_name", model_name) record_property("mode", mode)