From 26c9ed52ea001e887e7e7884371dea3f3926f6f1 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Fri, 11 Oct 2024 21:54:49 +0200 Subject: [PATCH 01/12] Add a value error if no geometry init in the fixed support case --- src/fugw/mappings/barycenter.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/fugw/mappings/barycenter.py b/src/fugw/mappings/barycenter.py index 08ae808..a097213 100644 --- a/src/fugw/mappings/barycenter.py +++ b/src/fugw/mappings/barycenter.py @@ -271,7 +271,12 @@ def fit( init_barycenter_features, device=device ) - if init_barycenter_geometry is None: + if init_barycenter_geometry is None and self.learn_geometry is False: + raise ValueError( + "In the fixed support case, init_barycenter_geometry must be" + " provided." + ) + elif init_barycenter_geometry is None and self.learn_geometry is True: barycenter_geometry = ( torch.ones((barycenter_size, barycenter_size)).to(device) / barycenter_size From 994f842c0864ffc95f3acf76c10d6dd48cb7ee02 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Sat, 12 Oct 2024 01:42:26 +0200 Subject: [PATCH 02/12] Add identity test for the barycenter --- tests/mappings/test_barycenter.py | 55 +++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/tests/mappings/test_barycenter.py b/tests/mappings/test_barycenter.py index 3410b84..c1e0407 100644 --- a/tests/mappings/test_barycenter.py +++ b/tests/mappings/test_barycenter.py @@ -12,6 +12,7 @@ devices.append(torch.device("cuda:0")) callbacks = [None, lambda x: x["plans"]] +alphas = [0.0, 0.5, 1.0] @pytest.mark.parametrize( @@ -67,3 +68,57 @@ def test_fugw_barycenter(device, callback): assert barycenter_geometry.shape == (n_voxels, n_voxels) assert len(plans) == n_subjects assert len(losses_each_bar_step) == nits_barycenter + + +@pytest.mark.parametrize( + "alpha", + alphas, +) +def test_identity_case(alpha): + """Test the case where all subjects are the same.""" + n_subjects = 3 + n_features = 10 + n_voxels = 5 + nits_barycenter = 2 + + geometry = _init_mock_distribution(n_features, n_voxels)[2] + # features = torch.rand(n_features, n_voxels) + features = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5]]) + + geometry_list = [geometry for _ in range(n_subjects)] + features_list = [features for _ in range(n_subjects)] + weights_list = [torch.ones(n_voxels) / n_voxels for _ in range(n_subjects)] + + fugw_barycenter = FUGWBarycenter(alpha=alpha, eps=1e-6, rho=float("inf")) + ( + barycenter_weights, + barycenter_features, + barycenter_geometry, + plans, + _, + _, + ) = fugw_barycenter.fit( + weights_list, + features_list, + geometry_list, + solver_params={"nits_bcd": 5, "nits_uot": 100}, + nits_barycenter=nits_barycenter, + device=torch.device("cpu"), + init_barycenter_geometry=geometry_list[0], + init_barycenter_features=features_list[0], + ) + + # Check that the barycenter is the same as the input + print(barycenter_features) + assert torch.allclose(barycenter_weights, torch.ones(n_voxels) / n_voxels) + assert torch.allclose(barycenter_geometry, geometry_list[0]) + + # In the case alpha=1.0, the features can be permuted + # since the GW distance is invariant under isometries + if alpha != 1.0: + assert torch.allclose(barycenter_features, features) + + # Check that all the plans are the identity matrix divided + # by the number of voxels + for plan in plans: + assert torch.allclose(plan, torch.eye(n_voxels) / n_voxels) From 41cd3cf3d602c09dde6c4bfd9bfcd759529576ac Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Sat, 12 Oct 2024 01:43:49 +0200 Subject: [PATCH 03/12] Reduce default eps value in FUGWBarycenter --- src/fugw/mappings/barycenter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fugw/mappings/barycenter.py b/src/fugw/mappings/barycenter.py index a097213..ed89ca5 100644 --- a/src/fugw/mappings/barycenter.py +++ b/src/fugw/mappings/barycenter.py @@ -11,7 +11,7 @@ def __init__( self, alpha=0.5, rho=1, - eps=1e-2, + eps=1e-4, reg_mode="joint", force_psd=False, learn_geometry=False, From 737bc5dd21812a3a2288be5227350b0546147f6e Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Sat, 12 Oct 2024 01:44:09 +0200 Subject: [PATCH 04/12] Fix update_barycenter_features method --- src/fugw/mappings/barycenter.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/src/fugw/mappings/barycenter.py b/src/fugw/mappings/barycenter.py index ed89ca5..faabcc8 100644 --- a/src/fugw/mappings/barycenter.py +++ b/src/fugw/mappings/barycenter.py @@ -79,11 +79,12 @@ def update_barycenter_geometry( return barycenter_geometry @staticmethod - def update_barycenter_features(plans, weights_list, features_list, device): - for i, (pi, weights, features) in enumerate( - zip(plans, weights_list, features_list) + def update_barycenter_features( + plans, subject_weights, features_list, device + ): + for i, (pi, w, features) in enumerate( + zip(plans, subject_weights, features_list) ): - w = _make_tensor(weights, device=device) f = _make_tensor(features, device=device) if features is not None: acc = w * pi.T @ f.T / (pi.sum(0).reshape(-1, 1) + 1e-16) @@ -93,12 +94,6 @@ def update_barycenter_features(plans, weights_list, features_list, device): else: barycenter_features += acc - # Normalize barycenter features - min_val = barycenter_features.min(dim=0, keepdim=True).values - max_val = barycenter_features.max(dim=0, keepdim=True).values - barycenter_features = ( - 2 * (barycenter_features - min_val) / (max_val - min_val) - 1 - ) return barycenter_features.T @staticmethod From be3df27e25b45e0b2351dda28a3eba34405614a9 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Sat, 12 Oct 2024 01:45:07 +0200 Subject: [PATCH 05/12] Add subject_weights parameter --- src/fugw/mappings/barycenter.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/fugw/mappings/barycenter.py b/src/fugw/mappings/barycenter.py index faabcc8..1497949 100644 --- a/src/fugw/mappings/barycenter.py +++ b/src/fugw/mappings/barycenter.py @@ -178,6 +178,7 @@ def fit( weights_list, features_list, geometry_list, + subject_weights=None, barycenter_size=None, init_barycenter_weights=None, init_barycenter_features=None, @@ -204,6 +205,9 @@ def fit( or just one kernel matrix if it's shared across individuals barycenter_size (int, optional): Size of computed barycentric features and geometry. Defaults to None. + subject_weights (list of float, optional): Weights of each individual. + If None, all individuals will have the same weight. + Defaults to None. init_barycenter_weights (np.array, optional): Distribution weights of barycentric points. If None, points will have uniform weights. Defaults to None. @@ -281,6 +285,9 @@ def fit( init_barycenter_geometry, device=device ) + if subject_weights is None: + subject_weights = [1 / len(weights_list)] * len(weights_list) + plans = None duals = None losses_each_bar_step = [] @@ -311,7 +318,7 @@ def fit( # Update barycenter features and geometry barycenter_features = self.update_barycenter_features( - plans, weights_list, features_list, device + plans, subject_weights, features_list, device ) if self.learn_geometry: barycenter_geometry = self.update_barycenter_geometry( From 5bf7c559f937d3b03630e583879985111de7f520 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Sat, 12 Oct 2024 01:50:10 +0200 Subject: [PATCH 06/12] Add init_barycenter_geometry parameter to test_fugw_barycenter method --- tests/mappings/test_barycenter.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/mappings/test_barycenter.py b/tests/mappings/test_barycenter.py index c1e0407..e0f4041 100644 --- a/tests/mappings/test_barycenter.py +++ b/tests/mappings/test_barycenter.py @@ -58,6 +58,7 @@ def test_fugw_barycenter(device, callback): nits_barycenter=nits_barycenter, device=device, callback_barycenter=callback, + init_barycenter_geometry=geometry_list[0], ) assert isinstance(barycenter_weights, torch.Tensor) @@ -76,14 +77,14 @@ def test_fugw_barycenter(device, callback): ) def test_identity_case(alpha): """Test the case where all subjects are the same.""" + torch.manual_seed(0) n_subjects = 3 n_features = 10 n_voxels = 5 nits_barycenter = 2 geometry = _init_mock_distribution(n_features, n_voxels)[2] - # features = torch.rand(n_features, n_voxels) - features = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5]]) + features = torch.rand(n_features, n_voxels) geometry_list = [geometry for _ in range(n_subjects)] features_list = [features for _ in range(n_subjects)] @@ -109,7 +110,6 @@ def test_identity_case(alpha): ) # Check that the barycenter is the same as the input - print(barycenter_features) assert torch.allclose(barycenter_weights, torch.ones(n_voxels) / n_voxels) assert torch.allclose(barycenter_geometry, geometry_list[0]) @@ -117,8 +117,7 @@ def test_identity_case(alpha): # since the GW distance is invariant under isometries if alpha != 1.0: assert torch.allclose(barycenter_features, features) - - # Check that all the plans are the identity matrix divided - # by the number of voxels - for plan in plans: - assert torch.allclose(plan, torch.eye(n_voxels) / n_voxels) + # Check that all the plans are the identity matrix divided + # by the number of voxels + for plan in plans: + assert torch.allclose(plan, torch.eye(n_voxels) / n_voxels) From 2bc7c755481c0d34ebad243b5aabb0f8946a2d13 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Sat, 12 Oct 2024 01:52:46 +0200 Subject: [PATCH 07/12] Increase the number of voxels in tests --- tests/mappings/test_barycenter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/mappings/test_barycenter.py b/tests/mappings/test_barycenter.py index e0f4041..cda4311 100644 --- a/tests/mappings/test_barycenter.py +++ b/tests/mappings/test_barycenter.py @@ -80,7 +80,7 @@ def test_identity_case(alpha): torch.manual_seed(0) n_subjects = 3 n_features = 10 - n_voxels = 5 + n_voxels = 100 nits_barycenter = 2 geometry = _init_mock_distribution(n_features, n_voxels)[2] From 585f7d6d760c606a23e949888c096a6516ed9809 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Sat, 12 Oct 2024 02:22:21 +0200 Subject: [PATCH 08/12] Add comparison to POT --- tests/mappings/test_barycenter.py | 64 +++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/tests/mappings/test_barycenter.py b/tests/mappings/test_barycenter.py index cda4311..2f87f83 100644 --- a/tests/mappings/test_barycenter.py +++ b/tests/mappings/test_barycenter.py @@ -3,6 +3,7 @@ import numpy as np import pytest import torch +import ot from fugw.mappings import FUGWBarycenter from fugw.utils import _init_mock_distribution @@ -121,3 +122,66 @@ def test_identity_case(alpha): # by the number of voxels for plan in plans: assert torch.allclose(plan, torch.eye(n_voxels) / n_voxels) + + +@pytest.mark.parametrize( + "alpha", + alphas, +) +def test_fgw_barycenter(alpha): + """Tests the FUGW barycenter in the case rho=inf and compare with POT.""" + torch.manual_seed(0) + n_subjects = 3 + n_features = 1 + n_voxels = 5 + nits_barycenter = 2 + + geometry = _init_mock_distribution( + n_features, n_voxels, should_normalize=True + )[2] + geometry_list = [geometry for _ in range(n_subjects)] + weights_list = [torch.ones(n_voxels) / n_voxels] * n_subjects + features_list = [torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])] * n_subjects + + fugw_barycenter = FUGWBarycenter( + alpha=alpha, + rho=float("inf"), + eps=1e-6, + ) + + fugw_barycenter = FUGWBarycenter(alpha=alpha, eps=1e-6, rho=float("inf")) + ( + fugw_bary_weights, + fugw_bary_features, + fugw_bary_geometry, + _, + _, + _, + ) = fugw_barycenter.fit( + weights_list, + features_list, + geometry_list, + solver_params={"nits_bcd": 5, "nits_uot": 100}, + nits_barycenter=nits_barycenter, + device=torch.device("cpu"), + init_barycenter_geometry=geometry_list[0], + init_barycenter_features=features_list[0], + ) + + # Compare the barycenter with the one obtained with POT + pot_bary_features, pot_bary_geometry, log = ot.gromov.fgw_barycenters( + n_voxels, + [features.T for features in features_list], + geometry_list, + weights_list, + alpha=1 - alpha, + log=True, + fixed_structure=True, + init_C=geometry_list[0], + ) + + assert torch.allclose(fugw_bary_geometry, pot_bary_geometry) + assert torch.allclose(fugw_bary_weights, log["p"]) + + if alpha != 1.0: + assert torch.allclose(fugw_bary_features, pot_bary_features.T) From 99fb4a9813e90efedc2f0b4996fe81b4ba892375 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Sun, 13 Oct 2024 21:52:48 +0200 Subject: [PATCH 09/12] Refactor update_barycenter_features method and add subject_weights parameter --- src/fugw/mappings/sparse_barycenter.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/fugw/mappings/sparse_barycenter.py b/src/fugw/mappings/sparse_barycenter.py index 646d9c9..d0fa8a1 100644 --- a/src/fugw/mappings/sparse_barycenter.py +++ b/src/fugw/mappings/sparse_barycenter.py @@ -33,11 +33,12 @@ def __init__( self.selection_radius = selection_radius @staticmethod - def update_barycenter_features(plans, weights_list, features_list, device): - for i, (pi, weights, features) in enumerate( - zip(plans, weights_list, features_list) + def update_barycenter_features( + plans, subject_weights, features_list, device + ): + for i, (pi, w, features) in enumerate( + zip(plans, subject_weights, features_list) ): - w = _make_tensor(weights, device=device) f = _make_tensor(features, device=device) if features is not None: @@ -156,6 +157,7 @@ def fit( weights_list, features_list, geometry_embedding, + subject_weights=None, barycenter_size=None, init_barycenter_weights=None, init_barycenter_features=None, @@ -181,6 +183,9 @@ def fit( have the same number of features n_features. geometry_embedding (np.array or torch.Tensor): Common geometry embedding of all individuals and barycenter. + subject_weights (list of float, optional): Weights of each individual. + If None, all individuals will have the same weight. + Defaults to None. barycenter_size (int), optional: Size of computed barycentric features and geometry. Defaults to None. @@ -259,6 +264,9 @@ def fit( geometry_embedding, device=device ) + if subject_weights is None: + subject_weights = [1 / len(weights_list)] * len(weights_list) + plans = None sparsity_mask = None losses_each_bar_step = [] @@ -291,7 +299,7 @@ def fit( # Update barycenter features and geometry barycenter_features = self.update_barycenter_features( - plans, weights_list, features_list, device + plans, subject_weights, features_list, device ) if callback_barycenter is not None: From 225d13ac132ca4dbfaf4c71dc4af1f0456499795 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Mon, 14 Oct 2024 00:29:54 +0200 Subject: [PATCH 10/12] Remove unused callback call --- src/fugw/mappings/barycenter.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/fugw/mappings/barycenter.py b/src/fugw/mappings/barycenter.py index 1497949..f160297 100644 --- a/src/fugw/mappings/barycenter.py +++ b/src/fugw/mappings/barycenter.py @@ -122,7 +122,6 @@ def compute_all_ot_plans( barycenter_geometry, solver, solver_params, - callback_barycenter, device, verbose, ): @@ -309,7 +308,6 @@ def fit( barycenter_geometry, solver, solver_params, - callback_barycenter, device, verbose, ) From e1c6719070b7cf9dc8ff8ddec63ed5f54c9f43be Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Mon, 14 Oct 2024 00:31:28 +0200 Subject: [PATCH 11/12] Use uniform subject weights across individuals --- src/fugw/mappings/barycenter.py | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/src/fugw/mappings/barycenter.py b/src/fugw/mappings/barycenter.py index f160297..9b1e43c 100644 --- a/src/fugw/mappings/barycenter.py +++ b/src/fugw/mappings/barycenter.py @@ -79,15 +79,13 @@ def update_barycenter_geometry( return barycenter_geometry @staticmethod - def update_barycenter_features( - plans, subject_weights, features_list, device - ): - for i, (pi, w, features) in enumerate( - zip(plans, subject_weights, features_list) - ): + def update_barycenter_features(plans, features_list, device): + for i, (pi, features) in enumerate(zip(plans, features_list)): + # Use uniform weights across subjects + weight = 1 / len(features_list) f = _make_tensor(features, device=device) if features is not None: - acc = w * pi.T @ f.T / (pi.sum(0).reshape(-1, 1) + 1e-16) + acc = weight * pi.T @ f.T / (pi.sum(0).reshape(-1, 1) + 1e-16) if i == 0: barycenter_features = acc @@ -177,7 +175,6 @@ def fit( weights_list, features_list, geometry_list, - subject_weights=None, barycenter_size=None, init_barycenter_weights=None, init_barycenter_features=None, @@ -204,9 +201,6 @@ def fit( or just one kernel matrix if it's shared across individuals barycenter_size (int, optional): Size of computed barycentric features and geometry. Defaults to None. - subject_weights (list of float, optional): Weights of each individual. - If None, all individuals will have the same weight. - Defaults to None. init_barycenter_weights (np.array, optional): Distribution weights of barycentric points. If None, points will have uniform weights. Defaults to None. @@ -284,9 +278,6 @@ def fit( init_barycenter_geometry, device=device ) - if subject_weights is None: - subject_weights = [1 / len(weights_list)] * len(weights_list) - plans = None duals = None losses_each_bar_step = [] @@ -316,7 +307,7 @@ def fit( # Update barycenter features and geometry barycenter_features = self.update_barycenter_features( - plans, subject_weights, features_list, device + plans, features_list, device ) if self.learn_geometry: barycenter_geometry = self.update_barycenter_geometry( From b5fc56dce3a65106787237aea2c992a1858198f5 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Mon, 14 Oct 2024 00:31:50 +0200 Subject: [PATCH 12/12] Use unif weights for sparse barycenter --- src/fugw/mappings/sparse_barycenter.py | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/src/fugw/mappings/sparse_barycenter.py b/src/fugw/mappings/sparse_barycenter.py index d0fa8a1..619a430 100644 --- a/src/fugw/mappings/sparse_barycenter.py +++ b/src/fugw/mappings/sparse_barycenter.py @@ -33,20 +33,16 @@ def __init__( self.selection_radius = selection_radius @staticmethod - def update_barycenter_features( - plans, subject_weights, features_list, device - ): - for i, (pi, w, features) in enumerate( - zip(plans, subject_weights, features_list) - ): + def update_barycenter_features(plans, features_list, device): + for i, (pi, features) in enumerate(zip(plans, features_list)): f = _make_tensor(features, device=device) - + weight = 1 / len(features_list) if features is not None: pi_sum = ( torch.sparse.sum(pi, dim=0).to_dense().reshape(-1, 1) + 1e-16 ) - acc = w * pi.T @ f.T / pi_sum + acc = weight * pi.T @ f.T / pi_sum if i == 0: barycenter_features = acc @@ -157,7 +153,6 @@ def fit( weights_list, features_list, geometry_embedding, - subject_weights=None, barycenter_size=None, init_barycenter_weights=None, init_barycenter_features=None, @@ -183,9 +178,6 @@ def fit( have the same number of features n_features. geometry_embedding (np.array or torch.Tensor): Common geometry embedding of all individuals and barycenter. - subject_weights (list of float, optional): Weights of each individual. - If None, all individuals will have the same weight. - Defaults to None. barycenter_size (int), optional: Size of computed barycentric features and geometry. Defaults to None. @@ -264,9 +256,6 @@ def fit( geometry_embedding, device=device ) - if subject_weights is None: - subject_weights = [1 / len(weights_list)] * len(weights_list) - plans = None sparsity_mask = None losses_each_bar_step = [] @@ -299,7 +288,7 @@ def fit( # Update barycenter features and geometry barycenter_features = self.update_barycenter_features( - plans, subject_weights, features_list, device + plans, features_list, device ) if callback_barycenter is not None: