Skip to content

Commit

Permalink
Merge pull request #54 from pbarbarant/fix/storing-mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
pbarbarant authored Jul 19, 2024
2 parents 679feaf + 1beb968 commit b1063cd
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 12 deletions.
1 change: 0 additions & 1 deletion src/fugw/mappings/barycenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,6 @@ def compute_all_ot_plans(
solver=solver,
solver_params=solver_params,
device=device,
storing_device=device,
verbose=verbose,
)

Expand Down
5 changes: 1 addition & 4 deletions src/fugw/mappings/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def fit(
solver_params={},
callback_bcd=None,
device="auto",
storing_device="cpu",
verbose=False,
):
"""
Expand Down Expand Up @@ -101,8 +100,6 @@ def fit(
device: "auto" or torch.device
if "auto": use first available gpu if it's available,
cpu otherwise.
storing_device: torch.device, default="cpu"
Device on which to store the computed transport plan.
verbose: bool, optional, defaults to False
Log solving process.
Expand Down Expand Up @@ -258,7 +255,7 @@ def fit(
)

# Store variables of interest in model
self.pi = res["pi"].detach().to(device=storing_device)
self.pi = res["pi"].detach()
self.loss = res["loss"]
self.loss_steps = res["loss_steps"]
self.loss_times = res["loss_times"]
Expand Down
5 changes: 1 addition & 4 deletions src/fugw/mappings/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ def fit(
solver_params={},
callback_bcd=None,
device="auto",
storing_device="cpu",
verbose=False,
):
"""
Expand Down Expand Up @@ -107,8 +106,6 @@ def fit(
device: "auto" or torch.device
if "auto": use first available gpu if it's available,
cpu otherwise.
storing_device: torch.device, default="cpu"
Device on which to store the computed transport plan.
verbose: bool, optional, defaults to False
Log solving process.
Expand Down Expand Up @@ -296,7 +293,7 @@ def fit(
verbose=verbose,
)

self.pi = res["pi"].to_sparse_coo().detach().to(device=storing_device)
self.pi = res["pi"].to_sparse_coo().detach()
self.loss = res["loss"]
self.loss_val = res["loss_val"]
self.loss_steps = res["loss_steps"]
Expand Down
3 changes: 3 additions & 0 deletions src/fugw/scripts/coarse_to_fine.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,9 @@ def fit(
verbose=verbose,
)

# Send coarse mapping to cpu to handle numpy operations
coarse_mapping.pi = coarse_mapping.pi.cpu()

# 2. Build sparsity mask

# Select best pairs of source and target vertices from coarse alignment
Expand Down
7 changes: 6 additions & 1 deletion src/fugw/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,14 +262,16 @@ def save_mapping(mapping, fname):
fname: str or pathlib.Path
Path to pickle file to save
"""
# Move mapping weights to CPU before saving
mapping.pi = mapping.pi.to("cpu")
with open(fname, "wb") as f:
# Dump hyperparams first
pickle.dump(mapping, f)
# Dump mapping weights
pickle.dump(mapping.pi, f)


def load_mapping(fname, load_weights=True):
def load_mapping(fname, load_weights=True, device="cpu"):
"""Load mapping from pickle file, optionally loading weights.
Parameters
Expand All @@ -278,6 +280,8 @@ def load_mapping(fname, load_weights=True):
Path to pickle file to load
load_weights: bool, optional, defaults to True
If True, load mapping weights from pickle file.
device: torch.device, default="cpu"
Device on which to store the computed transport plan.
Returns
-------
Expand All @@ -287,5 +291,6 @@ def load_mapping(fname, load_weights=True):
mapping = pickle.load(f)
if load_weights:
mapping.pi = pickle.load(f)
mapping.pi = mapping.pi.to(device)

return mapping
8 changes: 6 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,14 @@ def test_saving_and_loading(device, return_numpy, solver):

save_mapping(fugw, fname)

mapping_without_weights = load_mapping(fname, load_weights=False)
mapping_without_weights = load_mapping(
fname, load_weights=False, device=device
)
assert mapping_without_weights.pi is None

mapping_with_weights = load_mapping(fname, load_weights=True)
mapping_with_weights = load_mapping(
fname, load_weights=True, device=device
)
assert mapping_with_weights.pi.shape == (
n_voxels_source,
n_voxels_target,
Expand Down

0 comments on commit b1063cd

Please sign in to comment.