From 6d2d41a17ad133b01bcdf6d324ca8abe80dc559a Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Mon, 27 May 2024 14:41:02 +0200 Subject: [PATCH 01/10] chore: Remove storing_device option from FUGW mapping class --- src/fugw/mappings/dense.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/fugw/mappings/dense.py b/src/fugw/mappings/dense.py index 0115c1e..3fc4eb1 100644 --- a/src/fugw/mappings/dense.py +++ b/src/fugw/mappings/dense.py @@ -27,7 +27,6 @@ def fit( solver_params={}, callback_bcd=None, device="auto", - storing_device="cpu", verbose=False, ): """ @@ -101,8 +100,6 @@ def fit( device: "auto" or torch.device if "auto": use first available gpu if it's available, cpu otherwise. - storing_device: torch.device, default="cpu" - Device on which to store the computed transport plan. verbose: bool, optional, defaults to False Log solving process. @@ -258,7 +255,7 @@ def fit( ) # Store variables of interest in model - self.pi = res["pi"].detach().to(device=storing_device) + self.pi = res["pi"].detach() self.loss = res["loss"] self.loss_steps = res["loss_steps"] self.loss_times = res["loss_times"] From 38b6ec6530f0a601110b972b23439f54349ce813 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Mon, 27 May 2024 14:41:15 +0200 Subject: [PATCH 02/10] Remove storing_device from SparseMapping --- src/fugw/mappings/sparse.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/fugw/mappings/sparse.py b/src/fugw/mappings/sparse.py index cffd9eb..2f1e042 100644 --- a/src/fugw/mappings/sparse.py +++ b/src/fugw/mappings/sparse.py @@ -33,7 +33,6 @@ def fit( solver_params={}, callback_bcd=None, device="auto", - storing_device="cpu", verbose=False, ): """ @@ -107,8 +106,6 @@ def fit( device: "auto" or torch.device if "auto": use first available gpu if it's available, cpu otherwise. - storing_device: torch.device, default="cpu" - Device on which to store the computed transport plan. verbose: bool, optional, defaults to False Log solving process. @@ -295,7 +292,7 @@ def fit( verbose=verbose, ) - self.pi = res["pi"].to_sparse_coo().detach().to(device=storing_device) + self.pi = res["pi"].to_sparse_coo().detach() self.loss = res["loss"] self.loss_val = res["loss_val"] self.loss_steps = res["loss_steps"] From 59a4c3c19c26fcf582c3ac1ad3568c7fa0b8af21 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Mon, 27 May 2024 14:43:16 +0200 Subject: [PATCH 03/10] Remove storing_device option from FUGW barycenter mapping class --- src/fugw/mappings/barycenter.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/fugw/mappings/barycenter.py b/src/fugw/mappings/barycenter.py index 7b48aff..35c455c 100644 --- a/src/fugw/mappings/barycenter.py +++ b/src/fugw/mappings/barycenter.py @@ -155,7 +155,6 @@ def compute_all_ot_plans( solver=solver, solver_params=solver_params, device=device, - storing_device=device, verbose=verbose, ) From 601e4b718caf02bb109a46ad1f267fb3548d4dc9 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Mon, 27 May 2024 14:43:48 +0200 Subject: [PATCH 04/10] Add storing_device option to save_mapping utils --- src/fugw/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/fugw/utils.py b/src/fugw/utils.py index 67f066f..5d8686e 100644 --- a/src/fugw/utils.py +++ b/src/fugw/utils.py @@ -252,7 +252,7 @@ def init_plan_dense( return plan -def save_mapping(mapping, fname): +def save_mapping(mapping, fname, storing_device="cpu"): """Save mapping in pickle file, separating hyperparams and weights. Parameters @@ -261,7 +261,10 @@ def save_mapping(mapping, fname): FUGW mapping to save fname: str or pathlib.Path Path to pickle file to save + storing_device: torch.device, default="cpu" + Device on which to store the computed transport plan. """ + mapping.pi = mapping.pi.to(storing_device) with open(fname, "wb") as f: # Dump hyperparams first pickle.dump(mapping, f) From cbc87fd462fb780be0856033bee215cdb7cf643c Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Mon, 27 May 2024 14:44:36 +0200 Subject: [PATCH 05/10] Add storing_device option to load_mapping --- src/fugw/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/fugw/utils.py b/src/fugw/utils.py index 5d8686e..9cc9cde 100644 --- a/src/fugw/utils.py +++ b/src/fugw/utils.py @@ -272,7 +272,7 @@ def save_mapping(mapping, fname, storing_device="cpu"): pickle.dump(mapping.pi, f) -def load_mapping(fname, load_weights=True): +def load_mapping(fname, load_weights=True, storing_device="cpu"): """Load mapping from pickle file, optionally loading weights. Parameters @@ -281,6 +281,8 @@ def load_mapping(fname, load_weights=True): Path to pickle file to load load_weights: bool, optional, defaults to True If True, load mapping weights from pickle file. + storing_device: torch.device, default="cpu" + Device on which to store the computed transport plan. Returns ------- @@ -290,5 +292,6 @@ def load_mapping(fname, load_weights=True): mapping = pickle.load(f) if load_weights: mapping.pi = pickle.load(f) + mapping.pi = mapping.pi.to(storing_device) return mapping From 64846aa1bae741efa908b58cfd93206db908611f Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Mon, 27 May 2024 14:44:55 +0200 Subject: [PATCH 06/10] Add device to unit tests --- tests/test_utils.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 9020d23..10879e3 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -117,12 +117,16 @@ def test_saving_and_loading(device, return_numpy, solver): with TemporaryDirectory() as tmpdir: fname = tmpdir + "/mapping.pkl" - save_mapping(fugw, fname) + save_mapping(fugw, fname, device) - mapping_without_weights = load_mapping(fname, load_weights=False) + mapping_without_weights = load_mapping( + fname, load_weights=False, storing_device=device + ) assert mapping_without_weights.pi is None - mapping_with_weights = load_mapping(fname, load_weights=True) + mapping_with_weights = load_mapping( + fname, load_weights=True, storing_device=device + ) assert mapping_with_weights.pi.shape == ( n_voxels_source, n_voxels_target, From 49977acd429d3d30c5863181a0771f53dd1292e5 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Mon, 27 May 2024 14:51:41 +0200 Subject: [PATCH 07/10] chore: Remove storing_device option from save_mapping --- src/fugw/utils.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/fugw/utils.py b/src/fugw/utils.py index 9cc9cde..1b5f41a 100644 --- a/src/fugw/utils.py +++ b/src/fugw/utils.py @@ -252,7 +252,7 @@ def init_plan_dense( return plan -def save_mapping(mapping, fname, storing_device="cpu"): +def save_mapping(mapping, fname): """Save mapping in pickle file, separating hyperparams and weights. Parameters @@ -261,10 +261,8 @@ def save_mapping(mapping, fname, storing_device="cpu"): FUGW mapping to save fname: str or pathlib.Path Path to pickle file to save - storing_device: torch.device, default="cpu" - Device on which to store the computed transport plan. """ - mapping.pi = mapping.pi.to(storing_device) + mapping.pi = mapping.pi.to("cpu") with open(fname, "wb") as f: # Dump hyperparams first pickle.dump(mapping, f) @@ -272,7 +270,7 @@ def save_mapping(mapping, fname, storing_device="cpu"): pickle.dump(mapping.pi, f) -def load_mapping(fname, load_weights=True, storing_device="cpu"): +def load_mapping(fname, load_weights=True, device="cpu"): """Load mapping from pickle file, optionally loading weights. Parameters @@ -281,7 +279,7 @@ def load_mapping(fname, load_weights=True, storing_device="cpu"): Path to pickle file to load load_weights: bool, optional, defaults to True If True, load mapping weights from pickle file. - storing_device: torch.device, default="cpu" + device: torch.device, default="cpu" Device on which to store the computed transport plan. Returns @@ -292,6 +290,6 @@ def load_mapping(fname, load_weights=True, storing_device="cpu"): mapping = pickle.load(f) if load_weights: mapping.pi = pickle.load(f) - mapping.pi = mapping.pi.to(storing_device) + mapping.pi = mapping.pi.to(device) return mapping From 35fe1196af12a6d01895cdcc5bc25768383ed095 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Mon, 27 May 2024 14:51:56 +0200 Subject: [PATCH 08/10] chore: Refactor save_mapping and load_mapping functions --- tests/test_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 10879e3..2410bf1 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -117,15 +117,15 @@ def test_saving_and_loading(device, return_numpy, solver): with TemporaryDirectory() as tmpdir: fname = tmpdir + "/mapping.pkl" - save_mapping(fugw, fname, device) + save_mapping(fugw, fname) mapping_without_weights = load_mapping( - fname, load_weights=False, storing_device=device + fname, load_weights=False, device=device ) assert mapping_without_weights.pi is None mapping_with_weights = load_mapping( - fname, load_weights=True, storing_device=device + fname, load_weights=True, device=device ) assert mapping_with_weights.pi.shape == ( n_voxels_source, From 1aab9a01918913f8a274244830d42b35faa199eb Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Mon, 27 May 2024 14:53:09 +0200 Subject: [PATCH 09/10] Add comment --- src/fugw/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/fugw/utils.py b/src/fugw/utils.py index 1b5f41a..bb72f60 100644 --- a/src/fugw/utils.py +++ b/src/fugw/utils.py @@ -262,6 +262,7 @@ def save_mapping(mapping, fname): fname: str or pathlib.Path Path to pickle file to save """ + # Move mapping weights to CPU before saving mapping.pi = mapping.pi.to("cpu") with open(fname, "wb") as f: # Dump hyperparams first From 1beb96844e9a9532568d24f104bfe37305ea641e Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Mon, 27 May 2024 15:24:29 +0200 Subject: [PATCH 10/10] chore: Send coarse mapping to CPU for numpy operations --- src/fugw/scripts/coarse_to_fine.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/fugw/scripts/coarse_to_fine.py b/src/fugw/scripts/coarse_to_fine.py index 8f1379e..2406f27 100644 --- a/src/fugw/scripts/coarse_to_fine.py +++ b/src/fugw/scripts/coarse_to_fine.py @@ -419,6 +419,9 @@ def fit( verbose=verbose, ) + # Send coarse mapping to cpu to handle numpy operations + coarse_mapping.pi = coarse_mapping.pi.cpu() + # 2. Build sparsity mask # Select best pairs of source and target vertices from coarse alignment