Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] Fix save_mapping and load_mapping utils #54

Merged
merged 10 commits into from
Jul 19, 2024
1 change: 0 additions & 1 deletion src/fugw/mappings/barycenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,6 @@ def compute_all_ot_plans(
solver=solver,
solver_params=solver_params,
device=device,
storing_device=device,
verbose=verbose,
)

Expand Down
5 changes: 1 addition & 4 deletions src/fugw/mappings/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def fit(
solver_params={},
callback_bcd=None,
device="auto",
storing_device="cpu",
verbose=False,
):
"""
Expand Down Expand Up @@ -101,8 +100,6 @@ def fit(
device: "auto" or torch.device
if "auto": use first available gpu if it's available,
cpu otherwise.
storing_device: torch.device, default="cpu"
Device on which to store the computed transport plan.
verbose: bool, optional, defaults to False
Log solving process.

Expand Down Expand Up @@ -258,7 +255,7 @@ def fit(
)

# Store variables of interest in model
self.pi = res["pi"].detach().to(device=storing_device)
self.pi = res["pi"].detach()
self.loss = res["loss"]
self.loss_steps = res["loss_steps"]
self.loss_times = res["loss_times"]
Expand Down
5 changes: 1 addition & 4 deletions src/fugw/mappings/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ def fit(
solver_params={},
callback_bcd=None,
device="auto",
storing_device="cpu",
verbose=False,
):
"""
Expand Down Expand Up @@ -107,8 +106,6 @@ def fit(
device: "auto" or torch.device
if "auto": use first available gpu if it's available,
cpu otherwise.
storing_device: torch.device, default="cpu"
Device on which to store the computed transport plan.
verbose: bool, optional, defaults to False
Log solving process.

Expand Down Expand Up @@ -295,7 +292,7 @@ def fit(
verbose=verbose,
)

self.pi = res["pi"].to_sparse_coo().detach().to(device=storing_device)
self.pi = res["pi"].to_sparse_coo().detach()
self.loss = res["loss"]
self.loss_val = res["loss_val"]
self.loss_steps = res["loss_steps"]
Expand Down
3 changes: 3 additions & 0 deletions src/fugw/scripts/coarse_to_fine.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,9 @@ def fit(
verbose=verbose,
)

# Send coarse mapping to cpu to handle numpy operations
coarse_mapping.pi = coarse_mapping.pi.cpu()

# 2. Build sparsity mask

# Select best pairs of source and target vertices from coarse alignment
Expand Down
7 changes: 6 additions & 1 deletion src/fugw/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,14 +262,16 @@ def save_mapping(mapping, fname):
fname: str or pathlib.Path
Path to pickle file to save
"""
# Move mapping weights to CPU before saving
mapping.pi = mapping.pi.to("cpu")
with open(fname, "wb") as f:
# Dump hyperparams first
pickle.dump(mapping, f)
# Dump mapping weights
pickle.dump(mapping.pi, f)


def load_mapping(fname, load_weights=True):
def load_mapping(fname, load_weights=True, device="cpu"):
"""Load mapping from pickle file, optionally loading weights.

Parameters
Expand All @@ -278,6 +280,8 @@ def load_mapping(fname, load_weights=True):
Path to pickle file to load
load_weights: bool, optional, defaults to True
If True, load mapping weights from pickle file.
device: torch.device, default="cpu"
Device on which to store the computed transport plan.

Returns
-------
Expand All @@ -287,5 +291,6 @@ def load_mapping(fname, load_weights=True):
mapping = pickle.load(f)
if load_weights:
mapping.pi = pickle.load(f)
mapping.pi = mapping.pi.to(device)

return mapping
8 changes: 6 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,14 @@ def test_saving_and_loading(device, return_numpy, solver):

save_mapping(fugw, fname)

mapping_without_weights = load_mapping(fname, load_weights=False)
mapping_without_weights = load_mapping(
fname, load_weights=False, device=device
)
assert mapping_without_weights.pi is None

mapping_with_weights = load_mapping(fname, load_weights=True)
mapping_with_weights = load_mapping(
fname, load_weights=True, device=device
)
assert mapping_with_weights.pi.shape == (
n_voxels_source,
n_voxels_target,
Expand Down
Loading