From f80b5c274e768df962cb27f77e1fd939512e2605 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Thu, 19 Sep 2024 16:22:48 +0200 Subject: [PATCH 1/5] Convert coarse plan to np --- src/fugw/scripts/coarse_to_fine.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/fugw/scripts/coarse_to_fine.py b/src/fugw/scripts/coarse_to_fine.py index 29906b7..f1a7371 100644 --- a/src/fugw/scripts/coarse_to_fine.py +++ b/src/fugw/scripts/coarse_to_fine.py @@ -316,10 +316,13 @@ def compute_sparsity_mask( device = torch.device("cuda", 0) else: device = torch.device("cpu") + + # Convert coarse plan to numpy + coarse_plan = coarse_mapping.pi.to("cpu").numpy() if method == "quantile": # Method 1: keep first percentile - threshold = np.percentile(coarse_mapping.pi, 99.95) - rows, cols = np.nonzero(coarse_mapping.pi > threshold) + threshold = np.percentile(coarse_plan, 99.95) + rows, cols = np.nonzero(coarse_plan > threshold) elif method == "topk": # Method 2: keep topk indices per line and per column @@ -328,12 +331,12 @@ def compute_sparsity_mask( rows = np.concatenate( [ np.arange(source_sample.shape[0]), - np.argmax(coarse_mapping.pi, axis=0), + np.argmax(coarse_plan, axis=0), ] ) cols = np.concatenate( [ - np.argmax(coarse_mapping.pi, axis=1), + np.argmax(coarse_plan, axis=1), np.arange(target_sample.shape[0]), ] ) From f76fef85f573c18b0f3752c78ccfc6976d1e4dc8 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Thu, 19 Sep 2024 23:50:31 +0200 Subject: [PATCH 2/5] Refactor test_fugw_barycenter to use normalized geometry embedding --- tests/mappings/test_sparse_barycenter.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/mappings/test_sparse_barycenter.py b/tests/mappings/test_sparse_barycenter.py index d067a15..b7cb8c7 100644 --- a/tests/mappings/test_sparse_barycenter.py +++ b/tests/mappings/test_sparse_barycenter.py @@ -19,7 +19,7 @@ "device, callback", product(devices, callbacks), ) -def test_fugw_barycenter(device, callback): +def test_fugw_sparse_barycenter(device, callback): np.random.seed(0) n_subjects = 4 n_voxels = 100 @@ -39,7 +39,10 @@ def test_fugw_barycenter(device, callback): weights_list.append(weights) features_list.append(features) - fugw_barycenter = FUGWSparseBarycenter() + geometry_embedding_normalized = ( + geometry_embedding / geometry_embedding.norm() + ) + fugw_sparse_barycenter = FUGWSparseBarycenter() # Fit the barycenter ( @@ -47,10 +50,10 @@ def test_fugw_barycenter(device, callback): barycenter_features, plans, losses_each_bar_step, - ) = fugw_barycenter.fit( + ) = fugw_sparse_barycenter.fit( weights_list, features_list, - geometry_embedding, + geometry_embedding_normalized, mesh_sample=mesh_sample, coarse_mapping_solver_params={"nits_bcd": 2, "nits_uot": 5}, fine_mapping_solver_params={"nits_bcd": 2, "nits_uot": 5}, From f387a09c526ba95a9c1fbf638c16219f60107bbb Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Mon, 23 Sep 2024 17:32:14 +0200 Subject: [PATCH 3/5] Add oneline docstring to barycenter test --- tests/mappings/test_barycenter.py | 1 + tests/mappings/test_sparse_barycenter.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/mappings/test_barycenter.py b/tests/mappings/test_barycenter.py index e9f1af7..229b4c5 100644 --- a/tests/mappings/test_barycenter.py +++ b/tests/mappings/test_barycenter.py @@ -19,6 +19,7 @@ product(devices, callbacks), ) def test_fugw_barycenter(device, callback): + """Tests the FUGW barycenter fitting on toy data.""" np.random.seed(0) n_subjects = 4 n_voxels = 100 diff --git a/tests/mappings/test_sparse_barycenter.py b/tests/mappings/test_sparse_barycenter.py index b7cb8c7..7f4f925 100644 --- a/tests/mappings/test_sparse_barycenter.py +++ b/tests/mappings/test_sparse_barycenter.py @@ -20,6 +20,7 @@ product(devices, callbacks), ) def test_fugw_sparse_barycenter(device, callback): + """Tests the FUGW sparse barycenter fitting on toy data.""" np.random.seed(0) n_subjects = 4 n_voxels = 100 From 36464eb68f01322bc06f64ac009c9a62c43a653d Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Mon, 23 Sep 2024 17:32:55 +0200 Subject: [PATCH 4/5] Add callback testing to sparse barycenter --- tests/mappings/test_sparse_barycenter.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/mappings/test_sparse_barycenter.py b/tests/mappings/test_sparse_barycenter.py index 7f4f925..cb1a6b1 100644 --- a/tests/mappings/test_sparse_barycenter.py +++ b/tests/mappings/test_sparse_barycenter.py @@ -60,6 +60,7 @@ def test_fugw_sparse_barycenter(device, callback): fine_mapping_solver_params={"nits_bcd": 2, "nits_uot": 5}, nits_barycenter=nits_barycenter, device=device, + callback_barycenter=callback, ) assert isinstance(barycenter_weights, torch.Tensor) From 37476c26f03a2ba97f014f19265a503ef60f6b0d Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Mon, 23 Sep 2024 17:45:51 +0200 Subject: [PATCH 5/5] Add assertions to test_barycenter --- tests/mappings/test_barycenter.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/tests/mappings/test_barycenter.py b/tests/mappings/test_barycenter.py index 229b4c5..3410b84 100644 --- a/tests/mappings/test_barycenter.py +++ b/tests/mappings/test_barycenter.py @@ -24,6 +24,7 @@ def test_fugw_barycenter(device, callback): n_subjects = 4 n_voxels = 100 n_features = 10 + nits_barycenter = 3 # Generate random training data for n subjects features_list = [] @@ -39,10 +40,30 @@ def test_fugw_barycenter(device, callback): geometry_list.append(geometry) fugw_barycenter = FUGWBarycenter() - fugw_barycenter.fit( + + # Fit the barycenter + ( + barycenter_weights, + barycenter_features, + barycenter_geometry, + plans, + _, + losses_each_bar_step, + ) = fugw_barycenter.fit( weights_list, features_list, geometry_list, + solver_params={"nits_bcd": 2, "nits_uot": 5}, + nits_barycenter=nits_barycenter, device=device, callback_barycenter=callback, ) + + assert isinstance(barycenter_weights, torch.Tensor) + assert barycenter_weights.shape == (n_voxels,) + assert isinstance(barycenter_features, torch.Tensor) + assert barycenter_features.shape == (n_features, n_voxels) + assert isinstance(barycenter_geometry, torch.Tensor) + assert barycenter_geometry.shape == (n_voxels, n_voxels) + assert len(plans) == n_subjects + assert len(losses_each_bar_step) == nits_barycenter