From 585f7d6d760c606a23e949888c096a6516ed9809 Mon Sep 17 00:00:00 2001 From: Pierre-Louis Barbarant Date: Sat, 12 Oct 2024 02:22:21 +0200 Subject: [PATCH] Add comparison to POT --- tests/mappings/test_barycenter.py | 64 +++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/tests/mappings/test_barycenter.py b/tests/mappings/test_barycenter.py index cda4311..2f87f83 100644 --- a/tests/mappings/test_barycenter.py +++ b/tests/mappings/test_barycenter.py @@ -3,6 +3,7 @@ import numpy as np import pytest import torch +import ot from fugw.mappings import FUGWBarycenter from fugw.utils import _init_mock_distribution @@ -121,3 +122,66 @@ def test_identity_case(alpha): # by the number of voxels for plan in plans: assert torch.allclose(plan, torch.eye(n_voxels) / n_voxels) + + +@pytest.mark.parametrize( + "alpha", + alphas, +) +def test_fgw_barycenter(alpha): + """Tests the FUGW barycenter in the case rho=inf and compare with POT.""" + torch.manual_seed(0) + n_subjects = 3 + n_features = 1 + n_voxels = 5 + nits_barycenter = 2 + + geometry = _init_mock_distribution( + n_features, n_voxels, should_normalize=True + )[2] + geometry_list = [geometry for _ in range(n_subjects)] + weights_list = [torch.ones(n_voxels) / n_voxels] * n_subjects + features_list = [torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])] * n_subjects + + fugw_barycenter = FUGWBarycenter( + alpha=alpha, + rho=float("inf"), + eps=1e-6, + ) + + fugw_barycenter = FUGWBarycenter(alpha=alpha, eps=1e-6, rho=float("inf")) + ( + fugw_bary_weights, + fugw_bary_features, + fugw_bary_geometry, + _, + _, + _, + ) = fugw_barycenter.fit( + weights_list, + features_list, + geometry_list, + solver_params={"nits_bcd": 5, "nits_uot": 100}, + nits_barycenter=nits_barycenter, + device=torch.device("cpu"), + init_barycenter_geometry=geometry_list[0], + init_barycenter_features=features_list[0], + ) + + # Compare the barycenter with the one obtained with POT + pot_bary_features, pot_bary_geometry, log = ot.gromov.fgw_barycenters( + n_voxels, + [features.T for features in features_list], + geometry_list, + weights_list, + alpha=1 - alpha, + log=True, + fixed_structure=True, + init_C=geometry_list[0], + ) + + assert torch.allclose(fugw_bary_geometry, pot_bary_geometry) + assert torch.allclose(fugw_bary_weights, log["p"]) + + if alpha != 1.0: + assert torch.allclose(fugw_bary_features, pot_bary_features.T)