From da47b4c213e19890c958517d0054bb930d9bc6b5 Mon Sep 17 00:00:00 2001 From: Antoine Collas <22830806+antoinecollas@users.noreply.github.com> Date: Tue, 16 Apr 2024 11:52:08 +0200 Subject: [PATCH 1/6] fix storing device --- src/fugw/mappings/barycenter.py | 1 + src/fugw/mappings/dense.py | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/fugw/mappings/barycenter.py b/src/fugw/mappings/barycenter.py index 35c455c3..7b48affe 100644 --- a/src/fugw/mappings/barycenter.py +++ b/src/fugw/mappings/barycenter.py @@ -155,6 +155,7 @@ def compute_all_ot_plans( solver=solver, solver_params=solver_params, device=device, + storing_device=device, verbose=verbose, ) diff --git a/src/fugw/mappings/dense.py b/src/fugw/mappings/dense.py index 7cf2a5be..0115c1e3 100644 --- a/src/fugw/mappings/dense.py +++ b/src/fugw/mappings/dense.py @@ -27,6 +27,7 @@ def fit( solver_params={}, callback_bcd=None, device="auto", + storing_device="cpu", verbose=False, ): """ @@ -100,6 +101,8 @@ def fit( device: "auto" or torch.device if "auto": use first available gpu if it's available, cpu otherwise. + storing_device: torch.device, default="cpu" + Device on which to store the computed transport plan. verbose: bool, optional, defaults to False Log solving process. @@ -255,7 +258,7 @@ def fit( ) # Store variables of interest in model - self.pi = res["pi"].detach().cpu() + self.pi = res["pi"].detach().to(device=storing_device) self.loss = res["loss"] self.loss_steps = res["loss_steps"] self.loss_times = res["loss_times"] From 925b5a0140509165f2cf2fa6ea25d8b43ed38101 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Wed, 17 Apr 2024 14:49:03 +0200 Subject: [PATCH 2/6] Add storing device option for sparse mappings --- src/fugw/mappings/sparse.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/fugw/mappings/sparse.py b/src/fugw/mappings/sparse.py index fbc76341..cffd9eba 100644 --- a/src/fugw/mappings/sparse.py +++ b/src/fugw/mappings/sparse.py @@ -33,6 +33,7 @@ def fit( solver_params={}, callback_bcd=None, device="auto", + storing_device="cpu", verbose=False, ): """ @@ -106,6 +107,8 @@ def fit( device: "auto" or torch.device if "auto": use first available gpu if it's available, cpu otherwise. + storing_device: torch.device, default="cpu" + Device on which to store the computed transport plan. verbose: bool, optional, defaults to False Log solving process. @@ -292,7 +295,7 @@ def fit( verbose=verbose, ) - self.pi = res["pi"].to_sparse_coo().detach().cpu() + self.pi = res["pi"].to_sparse_coo().detach().to(device=storing_device) self.loss = res["loss"] self.loss_val = res["loss_val"] self.loss_steps = res["loss_steps"] From 7f8107ce251007e5ffb1b46393fa9771d8b2723f Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Thu, 2 May 2024 17:00:04 +0200 Subject: [PATCH 3/6] Add conftest.py for skipping tests without MKL support --- tests/conftest.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 tests/conftest.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..112f4111 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,17 @@ +import torch +import pytest + + +def pytest_configure(config): + config.addinivalue_line( + "markers", "skip_if_no_mkl: Skip test if MKL support is not available." + ) + + +@pytest.fixture(autouse=True) +def check_mkl_availability(request): + if ( + "skip_if_no_mkl" in request.keywords + and not torch.backends.mkl.is_available() + ): + pytest.skip("Test requires MKL support which is not available.") From f079674f6233d552bcf9ca3fb7bdfe27911104f2 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Thu, 2 May 2024 17:00:27 +0200 Subject: [PATCH 4/6] Add __init__.py --- tests/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 tests/__init__.py diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b From 754facd61ccdd9849f39cd9813a6381dc983cd9c Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Thu, 2 May 2024 17:00:33 +0200 Subject: [PATCH 5/6] Add skip_if_no_mkl marker to tests --- tests/mappings/test_sparse_mapping.py | 3 +++ tests/scripts/test_coarse_to_fine.py | 2 ++ tests/solvers/test_sparse_solver.py | 4 ++++ 3 files changed, 9 insertions(+) diff --git a/tests/mappings/test_sparse_mapping.py b/tests/mappings/test_sparse_mapping.py index 8619e924..9955dd82 100644 --- a/tests/mappings/test_sparse_mapping.py +++ b/tests/mappings/test_sparse_mapping.py @@ -27,6 +27,7 @@ callbacks = [None, lambda x: x["gamma"]] +@pytest.mark.skip_if_no_mkl @pytest.mark.parametrize( "device,return_numpy,solver,callback", product(devices, return_numpys, solvers, callbacks), @@ -95,6 +96,7 @@ def test_sparse_mapping(device, return_numpy, solver, callback): assert isinstance(target_features_on_source, torch.Tensor) +@pytest.mark.skip_if_no_mkl @pytest.mark.parametrize( "device,sparse_layout,return_numpy", product(devices, sparse_layouts, return_numpys), @@ -163,6 +165,7 @@ def test_fugw_sparse_with_init(device, sparse_layout, return_numpy): assert isinstance(target_features_on_source, torch.Tensor) +@pytest.mark.skip_if_no_mkl @pytest.mark.parametrize( "validation", ["None", "features", "geometries", "Both"] ) diff --git a/tests/scripts/test_coarse_to_fine.py b/tests/scripts/test_coarse_to_fine.py index 3da29571..6623023d 100644 --- a/tests/scripts/test_coarse_to_fine.py +++ b/tests/scripts/test_coarse_to_fine.py @@ -26,6 +26,7 @@ return_numpys = [False, True] +@pytest.mark.skip_if_no_mkl @pytest.mark.parametrize("return_numpy", product(return_numpys)) def test_random_normalizing(return_numpy): _, _, _, embeddings = _init_mock_distribution( @@ -53,6 +54,7 @@ def test_uniform_mesh_sampling(): assert np.unique(sample).shape == (n_samples,) +@pytest.mark.skip_if_no_mkl @pytest.mark.parametrize( "device,return_numpy", product(devices, return_numpys) ) diff --git a/tests/solvers/test_sparse_solver.py b/tests/solvers/test_sparse_solver.py index 3b666bc6..7e9b2481 100644 --- a/tests/solvers/test_sparse_solver.py +++ b/tests/solvers/test_sparse_solver.py @@ -18,6 +18,9 @@ # TODO: need to test sinkhorn + + +@pytest.mark.skip_if_no_mkl @pytest.mark.parametrize( "solver,device,callback,alpha", product(["sinkhorn", "mm", "ibpp"], devices, callbacks, alphas), @@ -143,6 +146,7 @@ def test_sparse_solvers(solver, device, callback, alpha): ) +@pytest.mark.skip_if_no_mkl @pytest.mark.parametrize( "validation,device", product(["None", "features", "geometries", "Both"], devices), From 04b8890d9adf7984fad4ed7a2abf682e9e56ece0 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Thu, 2 May 2024 17:18:18 +0200 Subject: [PATCH 6/6] Refactor test_sparse_solver.py to remove unnecessary white space --- tests/solvers/test_sparse_solver.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/solvers/test_sparse_solver.py b/tests/solvers/test_sparse_solver.py index 7e9b2481..95b54a49 100644 --- a/tests/solvers/test_sparse_solver.py +++ b/tests/solvers/test_sparse_solver.py @@ -18,8 +18,6 @@ # TODO: need to test sinkhorn - - @pytest.mark.skip_if_no_mkl @pytest.mark.parametrize( "solver,device,callback,alpha",