Skip to content

Commit

Permalink
Merge pull request #76 from pbarbarant/fix/fix-barycenter-tests
Browse files Browse the repository at this point in the history
[BUGFIX] Fixes CI + other tests
  • Loading branch information
pbarbarant authored Sep 23, 2024
2 parents e96384b + 37476c2 commit dc8b819
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 9 deletions.
11 changes: 7 additions & 4 deletions src/fugw/scripts/coarse_to_fine.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,10 +316,13 @@ def compute_sparsity_mask(
device = torch.device("cuda", 0)
else:
device = torch.device("cpu")

# Convert coarse plan to numpy
coarse_plan = coarse_mapping.pi.to("cpu").numpy()
if method == "quantile":
# Method 1: keep first percentile
threshold = np.percentile(coarse_mapping.pi, 99.95)
rows, cols = np.nonzero(coarse_mapping.pi > threshold)
threshold = np.percentile(coarse_plan, 99.95)
rows, cols = np.nonzero(coarse_plan > threshold)

elif method == "topk":
# Method 2: keep topk indices per line and per column
Expand All @@ -328,12 +331,12 @@ def compute_sparsity_mask(
rows = np.concatenate(
[
np.arange(source_sample.shape[0]),
np.argmax(coarse_mapping.pi, axis=0),
np.argmax(coarse_plan, axis=0),
]
)
cols = np.concatenate(
[
np.argmax(coarse_mapping.pi, axis=1),
np.argmax(coarse_plan, axis=1),
np.arange(target_sample.shape[0]),
]
)
Expand Down
24 changes: 23 additions & 1 deletion tests/mappings/test_barycenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
product(devices, callbacks),
)
def test_fugw_barycenter(device, callback):
"""Tests the FUGW barycenter fitting on toy data."""
np.random.seed(0)
n_subjects = 4
n_voxels = 100
n_features = 10
nits_barycenter = 3

# Generate random training data for n subjects
features_list = []
Expand All @@ -38,10 +40,30 @@ def test_fugw_barycenter(device, callback):
geometry_list.append(geometry)

fugw_barycenter = FUGWBarycenter()
fugw_barycenter.fit(

# Fit the barycenter
(
barycenter_weights,
barycenter_features,
barycenter_geometry,
plans,
_,
losses_each_bar_step,
) = fugw_barycenter.fit(
weights_list,
features_list,
geometry_list,
solver_params={"nits_bcd": 2, "nits_uot": 5},
nits_barycenter=nits_barycenter,
device=device,
callback_barycenter=callback,
)

assert isinstance(barycenter_weights, torch.Tensor)
assert barycenter_weights.shape == (n_voxels,)
assert isinstance(barycenter_features, torch.Tensor)
assert barycenter_features.shape == (n_features, n_voxels)
assert isinstance(barycenter_geometry, torch.Tensor)
assert barycenter_geometry.shape == (n_voxels, n_voxels)
assert len(plans) == n_subjects
assert len(losses_each_bar_step) == nits_barycenter
13 changes: 9 additions & 4 deletions tests/mappings/test_sparse_barycenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
"device, callback",
product(devices, callbacks),
)
def test_fugw_barycenter(device, callback):
def test_fugw_sparse_barycenter(device, callback):
"""Tests the FUGW sparse barycenter fitting on toy data."""
np.random.seed(0)
n_subjects = 4
n_voxels = 100
Expand All @@ -39,23 +40,27 @@ def test_fugw_barycenter(device, callback):
weights_list.append(weights)
features_list.append(features)

fugw_barycenter = FUGWSparseBarycenter()
geometry_embedding_normalized = (
geometry_embedding / geometry_embedding.norm()
)
fugw_sparse_barycenter = FUGWSparseBarycenter()

# Fit the barycenter
(
barycenter_weights,
barycenter_features,
plans,
losses_each_bar_step,
) = fugw_barycenter.fit(
) = fugw_sparse_barycenter.fit(
weights_list,
features_list,
geometry_embedding,
geometry_embedding_normalized,
mesh_sample=mesh_sample,
coarse_mapping_solver_params={"nits_bcd": 2, "nits_uot": 5},
fine_mapping_solver_params={"nits_bcd": 2, "nits_uot": 5},
nits_barycenter=nits_barycenter,
device=device,
callback_barycenter=callback,
)

assert isinstance(barycenter_weights, torch.Tensor)
Expand Down

0 comments on commit dc8b819

Please sign in to comment.