From f76fef85f573c18b0f3752c78ccfc6976d1e4dc8 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Thu, 19 Sep 2024 23:50:31 +0200 Subject: [PATCH] Refactor test_fugw_barycenter to use normalized geometry embedding --- tests/mappings/test_sparse_barycenter.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/mappings/test_sparse_barycenter.py b/tests/mappings/test_sparse_barycenter.py index d067a15..b7cb8c7 100644 --- a/tests/mappings/test_sparse_barycenter.py +++ b/tests/mappings/test_sparse_barycenter.py @@ -19,7 +19,7 @@ "device, callback", product(devices, callbacks), ) -def test_fugw_barycenter(device, callback): +def test_fugw_sparse_barycenter(device, callback): np.random.seed(0) n_subjects = 4 n_voxels = 100 @@ -39,7 +39,10 @@ def test_fugw_barycenter(device, callback): weights_list.append(weights) features_list.append(features) - fugw_barycenter = FUGWSparseBarycenter() + geometry_embedding_normalized = ( + geometry_embedding / geometry_embedding.norm() + ) + fugw_sparse_barycenter = FUGWSparseBarycenter() # Fit the barycenter ( @@ -47,10 +50,10 @@ def test_fugw_barycenter(device, callback): barycenter_features, plans, losses_each_bar_step, - ) = fugw_barycenter.fit( + ) = fugw_sparse_barycenter.fit( weights_list, features_list, - geometry_embedding, + geometry_embedding_normalized, mesh_sample=mesh_sample, coarse_mapping_solver_params={"nits_bcd": 2, "nits_uot": 5}, fine_mapping_solver_params={"nits_bcd": 2, "nits_uot": 5},