Skip to content

Commit

Permalink
Merge pull request #64 from pbarbarant/feat/add_callbacks_barycenter
Browse files Browse the repository at this point in the history
Feat/add callbacks barycenter
  • Loading branch information
pbarbarant authored Jul 19, 2024
2 parents 86aeed2 + 6d753cd commit 679feaf
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 7 deletions.
13 changes: 12 additions & 1 deletion src/fugw/mappings/barycenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def compute_all_ot_plans(
barycenter_geometry,
solver,
solver_params,
callback_barycenter,
device,
verbose,
):
Expand Down Expand Up @@ -182,6 +183,7 @@ def fit(
solver_params={},
nits_barycenter=5,
device="auto",
callback_barycenter=None,
verbose=False,
):
"""Compute barycentric features and geometry
Expand Down Expand Up @@ -209,6 +211,11 @@ def fit(
device: "auto" or torch.device
if "auto": use first available gpu if it's available,
cpu otherwise.
callback_barycenter: callable or None
Callback function called at the end of each barycenter step.
It will be called with the following arguments:
- locals (dictionary containing all local variables)
Returns
-------
Expand Down Expand Up @@ -270,7 +277,7 @@ def fit(
duals = None
losses_each_bar_step = []

for _ in range(nits_barycenter):
for idx in range(nits_barycenter):
# Transport all elements
plans, losses = self.compute_all_ot_plans(
plans,
Expand All @@ -283,6 +290,7 @@ def fit(
barycenter_geometry,
solver,
solver_params,
callback_barycenter,
device,
verbose,
)
Expand All @@ -298,6 +306,9 @@ def fit(
plans, weights_list, geometry_list, self.force_psd, device
)

if callback_barycenter is not None:
callback_barycenter(locals())

return (
barycenter_weights,
barycenter_features,
Expand Down
11 changes: 10 additions & 1 deletion src/fugw/mappings/sparse_barycenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def fit(
mesh_sample=None,
nits_barycenter=5,
device="auto",
callback_barycenter=None,
verbose=False,
):
"""Compute barycentric features and geometry
Expand Down Expand Up @@ -192,6 +193,11 @@ def fit(
device: "auto" or torch.device
if "auto": use first available gpu if it's available,
cpu otherwise.
callback_barycenter: callable or None
Callback function called at the end of each barycenter step.
It will be called with the following arguments:
- locals (dictionary containing all local variables)
Returns
-------
Expand Down Expand Up @@ -248,7 +254,7 @@ def fit(
mask = None
losses_each_bar_step = []

for _ in range(nits_barycenter):
for idx in range(nits_barycenter):
# Transport all elements
plans, losses = self.compute_all_ot_plans(
plans,
Expand All @@ -274,6 +280,9 @@ def fit(
plans, weights_list, features_list, device
)

if callback_barycenter is not None:
callback_barycenter(locals())

return (
barycenter_weights,
barycenter_features,
Expand Down
17 changes: 14 additions & 3 deletions tests/mappings/test_barycenter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from itertools import product

import numpy as np
import pytest
import torch
Expand All @@ -9,9 +11,14 @@
if torch.cuda.is_available():
devices.append(torch.device("cuda:0"))

callbacks = [None, lambda x: x["plans"]]


@pytest.mark.parametrize("device", devices)
def test_fugw_barycenter(device):
@pytest.mark.parametrize(
"device, callback",
product(devices, callbacks),
)
def test_fugw_barycenter(device, callback):
np.random.seed(0)
n_subjects = 4
n_voxels = 100
Expand All @@ -32,5 +39,9 @@ def test_fugw_barycenter(device):

fugw_barycenter = FUGWBarycenter()
fugw_barycenter.fit(
weights_list, features_list, geometry_list, device=device
weights_list,
features_list,
geometry_list,
device=device,
callback_barycenter=callback,
)
11 changes: 9 additions & 2 deletions tests/mappings/test_sparse_barycenter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from itertools import product

import numpy as np
import pytest
import torch
Expand All @@ -9,10 +11,15 @@
if torch.cuda.is_available():
devices.append(torch.device("cuda:0"))

callbacks = [None, lambda x: x["plans"]]


@pytest.mark.skip_if_no_mkl
@pytest.mark.parametrize("device", devices)
def test_fugw_barycenter(device):
@pytest.mark.parametrize(
"device, callback",
product(devices, callbacks),
)
def test_fugw_barycenter(device, callback):
np.random.seed(0)
n_subjects = 4
n_voxels = 100
Expand Down

0 comments on commit 679feaf

Please sign in to comment.