Skip to content

Commit

Permalink
Refactor test_fugw_barycenter to use normalized geometry embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
pbarbarant committed Sep 19, 2024
1 parent f80b5c2 commit f76fef8
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions tests/mappings/test_sparse_barycenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"device, callback",
product(devices, callbacks),
)
def test_fugw_barycenter(device, callback):
def test_fugw_sparse_barycenter(device, callback):
np.random.seed(0)
n_subjects = 4
n_voxels = 100
Expand All @@ -39,18 +39,21 @@ def test_fugw_barycenter(device, callback):
weights_list.append(weights)
features_list.append(features)

fugw_barycenter = FUGWSparseBarycenter()
geometry_embedding_normalized = (
geometry_embedding / geometry_embedding.norm()
)
fugw_sparse_barycenter = FUGWSparseBarycenter()

# Fit the barycenter
(
barycenter_weights,
barycenter_features,
plans,
losses_each_bar_step,
) = fugw_barycenter.fit(
) = fugw_sparse_barycenter.fit(
weights_list,
features_list,
geometry_embedding,
geometry_embedding_normalized,
mesh_sample=mesh_sample,
coarse_mapping_solver_params={"nits_bcd": 2, "nits_uot": 5},
fine_mapping_solver_params={"nits_bcd": 2, "nits_uot": 5},
Expand Down

0 comments on commit f76fef8

Please sign in to comment.