Skip to content

Commit

Permalink
Add comparison to POT
Browse files Browse the repository at this point in the history
  • Loading branch information
Pierre-Louis Barbarant committed Oct 12, 2024
1 parent 2bc7c75 commit 585f7d6
Showing 1 changed file with 64 additions and 0 deletions.
64 changes: 64 additions & 0 deletions tests/mappings/test_barycenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import pytest
import torch
import ot

from fugw.mappings import FUGWBarycenter
from fugw.utils import _init_mock_distribution
Expand Down Expand Up @@ -121,3 +122,66 @@ def test_identity_case(alpha):
# by the number of voxels
for plan in plans:
assert torch.allclose(plan, torch.eye(n_voxels) / n_voxels)


@pytest.mark.parametrize(
"alpha",
alphas,
)
def test_fgw_barycenter(alpha):
"""Tests the FUGW barycenter in the case rho=inf and compare with POT."""
torch.manual_seed(0)
n_subjects = 3
n_features = 1
n_voxels = 5
nits_barycenter = 2

geometry = _init_mock_distribution(
n_features, n_voxels, should_normalize=True
)[2]
geometry_list = [geometry for _ in range(n_subjects)]
weights_list = [torch.ones(n_voxels) / n_voxels] * n_subjects
features_list = [torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])] * n_subjects

fugw_barycenter = FUGWBarycenter(
alpha=alpha,
rho=float("inf"),
eps=1e-6,
)

fugw_barycenter = FUGWBarycenter(alpha=alpha, eps=1e-6, rho=float("inf"))
(
fugw_bary_weights,
fugw_bary_features,
fugw_bary_geometry,
_,
_,
_,
) = fugw_barycenter.fit(
weights_list,
features_list,
geometry_list,
solver_params={"nits_bcd": 5, "nits_uot": 100},
nits_barycenter=nits_barycenter,
device=torch.device("cpu"),
init_barycenter_geometry=geometry_list[0],
init_barycenter_features=features_list[0],
)

# Compare the barycenter with the one obtained with POT
pot_bary_features, pot_bary_geometry, log = ot.gromov.fgw_barycenters(
n_voxels,
[features.T for features in features_list],
geometry_list,
weights_list,
alpha=1 - alpha,
log=True,
fixed_structure=True,
init_C=geometry_list[0],
)

assert torch.allclose(fugw_bary_geometry, pot_bary_geometry)
assert torch.allclose(fugw_bary_weights, log["p"])

if alpha != 1.0:
assert torch.allclose(fugw_bary_features, pot_bary_features.T)

0 comments on commit 585f7d6

Please sign in to comment.