Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUGFIX] Fix dense barycenter calculations #79

Merged
merged 12 commits into from
Oct 16, 2024
30 changes: 13 additions & 17 deletions src/fugw/mappings/barycenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def __init__(
self,
alpha=0.5,
rho=1,
eps=1e-2,
eps=1e-4,
reg_mode="joint",
force_psd=False,
learn_geometry=False,
Expand Down Expand Up @@ -79,26 +79,19 @@ def update_barycenter_geometry(
return barycenter_geometry

@staticmethod
def update_barycenter_features(plans, weights_list, features_list, device):
for i, (pi, weights, features) in enumerate(
zip(plans, weights_list, features_list)
):
w = _make_tensor(weights, device=device)
def update_barycenter_features(plans, features_list, device):
for i, (pi, features) in enumerate(zip(plans, features_list)):
# Use uniform weights across subjects
weight = 1 / len(features_list)
f = _make_tensor(features, device=device)
if features is not None:
acc = w * pi.T @ f.T / (pi.sum(0).reshape(-1, 1) + 1e-16)
acc = weight * pi.T @ f.T / (pi.sum(0).reshape(-1, 1) + 1e-16)

if i == 0:
barycenter_features = acc
else:
barycenter_features += acc

# Normalize barycenter features
min_val = barycenter_features.min(dim=0, keepdim=True).values
max_val = barycenter_features.max(dim=0, keepdim=True).values
barycenter_features = (
2 * (barycenter_features - min_val) / (max_val - min_val) - 1
)
return barycenter_features.T

@staticmethod
Expand Down Expand Up @@ -127,7 +120,6 @@ def compute_all_ot_plans(
barycenter_geometry,
solver,
solver_params,
callback_barycenter,
device,
verbose,
):
Expand Down Expand Up @@ -271,7 +263,12 @@ def fit(
init_barycenter_features, device=device
)

if init_barycenter_geometry is None:
if init_barycenter_geometry is None and self.learn_geometry is False:
raise ValueError(
"In the fixed support case, init_barycenter_geometry must be"
" provided."
)
elif init_barycenter_geometry is None and self.learn_geometry is True:
barycenter_geometry = (
torch.ones((barycenter_size, barycenter_size)).to(device)
/ barycenter_size
Expand Down Expand Up @@ -302,7 +299,6 @@ def fit(
barycenter_geometry,
solver,
solver_params,
callback_barycenter,
device,
verbose,
)
Expand All @@ -311,7 +307,7 @@ def fit(

# Update barycenter features and geometry
barycenter_features = self.update_barycenter_features(
plans, weights_list, features_list, device
plans, features_list, device
)
if self.learn_geometry:
barycenter_geometry = self.update_barycenter_geometry(
Expand Down
13 changes: 5 additions & 8 deletions src/fugw/mappings/sparse_barycenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,16 @@ def __init__(
self.selection_radius = selection_radius

@staticmethod
def update_barycenter_features(plans, weights_list, features_list, device):
for i, (pi, weights, features) in enumerate(
zip(plans, weights_list, features_list)
):
w = _make_tensor(weights, device=device)
def update_barycenter_features(plans, features_list, device):
for i, (pi, features) in enumerate(zip(plans, features_list)):
f = _make_tensor(features, device=device)

weight = 1 / len(features_list)
if features is not None:
pi_sum = (
torch.sparse.sum(pi, dim=0).to_dense().reshape(-1, 1)
+ 1e-16
)
acc = w * pi.T @ f.T / pi_sum
acc = weight * pi.T @ f.T / pi_sum

if i == 0:
barycenter_features = acc
Expand Down Expand Up @@ -291,7 +288,7 @@ def fit(

# Update barycenter features and geometry
barycenter_features = self.update_barycenter_features(
plans, weights_list, features_list, device
plans, features_list, device
)

if callback_barycenter is not None:
Expand Down
118 changes: 118 additions & 0 deletions tests/mappings/test_barycenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import pytest
import torch
import ot

from fugw.mappings import FUGWBarycenter
from fugw.utils import _init_mock_distribution
Expand All @@ -12,6 +13,7 @@
devices.append(torch.device("cuda:0"))

callbacks = [None, lambda x: x["plans"]]
alphas = [0.0, 0.5, 1.0]


@pytest.mark.parametrize(
Expand Down Expand Up @@ -57,6 +59,7 @@ def test_fugw_barycenter(device, callback):
nits_barycenter=nits_barycenter,
device=device,
callback_barycenter=callback,
init_barycenter_geometry=geometry_list[0],
)

assert isinstance(barycenter_weights, torch.Tensor)
Expand All @@ -67,3 +70,118 @@ def test_fugw_barycenter(device, callback):
assert barycenter_geometry.shape == (n_voxels, n_voxels)
assert len(plans) == n_subjects
assert len(losses_each_bar_step) == nits_barycenter


@pytest.mark.parametrize(
"alpha",
alphas,
)
def test_identity_case(alpha):
"""Test the case where all subjects are the same."""
torch.manual_seed(0)
n_subjects = 3
n_features = 10
n_voxels = 100
nits_barycenter = 2

geometry = _init_mock_distribution(n_features, n_voxels)[2]
features = torch.rand(n_features, n_voxels)

geometry_list = [geometry for _ in range(n_subjects)]
features_list = [features for _ in range(n_subjects)]
weights_list = [torch.ones(n_voxels) / n_voxels for _ in range(n_subjects)]

fugw_barycenter = FUGWBarycenter(alpha=alpha, eps=1e-6, rho=float("inf"))
(
barycenter_weights,
barycenter_features,
barycenter_geometry,
plans,
_,
_,
) = fugw_barycenter.fit(
weights_list,
features_list,
geometry_list,
solver_params={"nits_bcd": 5, "nits_uot": 100},
nits_barycenter=nits_barycenter,
device=torch.device("cpu"),
init_barycenter_geometry=geometry_list[0],
init_barycenter_features=features_list[0],
)

# Check that the barycenter is the same as the input
assert torch.allclose(barycenter_weights, torch.ones(n_voxels) / n_voxels)
assert torch.allclose(barycenter_geometry, geometry_list[0])

# In the case alpha=1.0, the features can be permuted
# since the GW distance is invariant under isometries
if alpha != 1.0:
assert torch.allclose(barycenter_features, features)
# Check that all the plans are the identity matrix divided
# by the number of voxels
for plan in plans:
assert torch.allclose(plan, torch.eye(n_voxels) / n_voxels)


@pytest.mark.parametrize(
"alpha",
alphas,
)
def test_fgw_barycenter(alpha):
"""Tests the FUGW barycenter in the case rho=inf and compare with POT."""
torch.manual_seed(0)
n_subjects = 3
n_features = 1
n_voxels = 5
nits_barycenter = 2

geometry = _init_mock_distribution(
n_features, n_voxels, should_normalize=True
)[2]
geometry_list = [geometry for _ in range(n_subjects)]
weights_list = [torch.ones(n_voxels) / n_voxels] * n_subjects
features_list = [torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])] * n_subjects

fugw_barycenter = FUGWBarycenter(
alpha=alpha,
rho=float("inf"),
eps=1e-6,
)

fugw_barycenter = FUGWBarycenter(alpha=alpha, eps=1e-6, rho=float("inf"))
(
fugw_bary_weights,
fugw_bary_features,
fugw_bary_geometry,
_,
_,
_,
) = fugw_barycenter.fit(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AN old traidiotn of sklearn is that fit() return the object itself.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer to merge this PR first and open another one to fix this.

weights_list,
features_list,
geometry_list,
solver_params={"nits_bcd": 5, "nits_uot": 100},
nits_barycenter=nits_barycenter,
device=torch.device("cpu"),
init_barycenter_geometry=geometry_list[0],
init_barycenter_features=features_list[0],
)

# Compare the barycenter with the one obtained with POT
pot_bary_features, pot_bary_geometry, log = ot.gromov.fgw_barycenters(
n_voxels,
[features.T for features in features_list],
geometry_list,
weights_list,
alpha=1 - alpha,
log=True,
fixed_structure=True,
init_C=geometry_list[0],
)

assert torch.allclose(fugw_bary_geometry, pot_bary_geometry)
assert torch.allclose(fugw_bary_weights, log["p"])

if alpha != 1.0:
assert torch.allclose(fugw_bary_features, pot_bary_features.T)