Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Add coarse-to-fine barycenter estimation #45

Merged
merged 53 commits into from
May 27, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
e4d16fe
Add sparse barycenter script
pbarbarant Apr 17, 2024
c563c77
Add FUGWSparseBarycenter class to mappings/__init__.py
pbarbarant Apr 17, 2024
3695757
Update fit function in coarse_to_fine.py to include init_plan parameter
pbarbarant Apr 17, 2024
853141f
Remove callback in fit fn
pbarbarant Apr 17, 2024
4fa8542
Fix division error with sparse vectors in sparse barycenter
pbarbarant Apr 17, 2024
b137c47
Add test for FUGWSparseBarycenter in test_sparse_barycenter.py
pbarbarant Apr 17, 2024
b98eb10
Add sparse barycenter script
pbarbarant Apr 17, 2024
1e32dc6
Add FUGWSparseBarycenter class to mappings/__init__.py
pbarbarant Apr 17, 2024
df47bd6
Update fit function in coarse_to_fine.py to include init_plan parameter
pbarbarant Apr 17, 2024
23f23d2
Remove callback in fit fn
pbarbarant Apr 17, 2024
d854efd
Fix division error with sparse vectors in sparse barycenter
pbarbarant Apr 17, 2024
1b05384
Add test for FUGWSparseBarycenter in test_sparse_barycenter.py
pbarbarant Apr 17, 2024
952920e
Merge branch 'feat_coarse_to_fine_bary' of github.com:pbarbarant/fugw…
pbarbarant Apr 18, 2024
4f4d7dc
Update fit function in coarse_to_fine.py to include mask parameter
pbarbarant Apr 19, 2024
6fc8006
Update selection_radius parameter in FUGWSparseBarycenter class
pbarbarant Apr 19, 2024
a07507a
Update fit function in coarse_to_fine.py to optimize mask computation
pbarbarant Apr 19, 2024
6ee4535
Update selection_radius parameter in FUGWSparseBarycenter class
pbarbarant May 2, 2024
1e30527
Update selection_radius parameter in FUGWSparseBarycenter class
pbarbarant May 2, 2024
b1e0af9
Update test_fugw_barycenter function in test_sparse_barycenter.py
pbarbarant May 2, 2024
005a6d9
Add sparse barycenter script
pbarbarant Apr 17, 2024
77a7d85
Add FUGWSparseBarycenter class to mappings/__init__.py
pbarbarant Apr 17, 2024
3c60c64
Update fit function in coarse_to_fine.py to include init_plan parameter
pbarbarant Apr 17, 2024
513bfb5
Remove callback in fit fn
pbarbarant Apr 17, 2024
5257488
Fix division error with sparse vectors in sparse barycenter
pbarbarant Apr 17, 2024
ecacff0
Add test for FUGWSparseBarycenter in test_sparse_barycenter.py
pbarbarant Apr 17, 2024
b7fe9db
Update fit function in coarse_to_fine.py to include mask parameter
pbarbarant Apr 19, 2024
a371bcd
Update selection_radius parameter in FUGWSparseBarycenter class
pbarbarant Apr 19, 2024
14c1027
Update fit function in coarse_to_fine.py to optimize mask computation
pbarbarant Apr 19, 2024
692905e
Update selection_radius parameter in FUGWSparseBarycenter class
pbarbarant May 2, 2024
1c78b9d
Update selection_radius parameter in FUGWSparseBarycenter class
pbarbarant May 2, 2024
d9e5359
Update test_fugw_barycenter function in test_sparse_barycenter.py
pbarbarant May 2, 2024
160b787
Merge branch 'alexisthual:main' into feat_coarse_to_fine_bary
pbarbarant May 6, 2024
748df35
Merge branch 'feat_coarse_to_fine_bary' of github.com:pbarbarant/fugw…
pbarbarant May 6, 2024
ac7fa1f
Fix failing mac os mkl test
pbarbarant May 6, 2024
e5c8f7f
Update fit function in coarse_to_fine.py to include mask parameter
pbarbarant May 6, 2024
b770dd3
Update test_coarse_to_fine to take into account mask parameter
pbarbarant May 6, 2024
70badb1
Update fit function in coarse_to_fine.py to remove storing_device par…
pbarbarant May 6, 2024
2dfd979
Update storing_device assignment in FUGWSparse class
pbarbarant May 7, 2024
0b6bf86
Update fit function in coarse_to_fine.py to include storing_device pa…
pbarbarant May 7, 2024
a67237a
Update test_fugw_barycenter function in test_sparse_barycenter.py
pbarbarant May 7, 2024
fd2d8d6
Refactor FUGWSparseBarycenter class to remove unused variables
pbarbarant May 7, 2024
c346ef0
Update pytest configuration to ignore sparse CSR tensor warning
pbarbarant May 7, 2024
ee557b3
Update test_sparse_barycenter.py with assertions
pbarbarant May 7, 2024
9297776
Refactor test_sparse_barycenter.py with assertions and type checks
pbarbarant May 7, 2024
6015c8e
Refactor FUGWSparseBarycenter class to handle NaN values in fine plan
pbarbarant May 22, 2024
038fec6
Add nan val check for barycenter features
pbarbarant May 22, 2024
4809deb
Add nan check in dense barycenters
pbarbarant May 22, 2024
f84926b
Refactor test_sparse_barycenter.py to include nan check for barycente…
pbarbarant May 22, 2024
5249d8f
Refactor FUGWBarycenter class to remove unused variables
pbarbarant May 22, 2024
0929035
Add small reg to avoid zero-div
pbarbarant May 22, 2024
470822f
Add reg to dense barycenters to avoid NaNs
pbarbarant May 22, 2024
ac75b42
Fix docstring indentation
pbarbarant May 23, 2024
ef00241
Fix docstrings and flake8
May 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/fugw/mappings/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .barycenter import FUGWBarycenter
from .sparse_barycenter import FUGWSparseBarycenter
from .dense import FUGW
from .sparse import FUGWSparse

__all__ = ["FUGW", "FUGWBarycenter", "FUGWSparse"]
__all__ = ["FUGW", "FUGWBarycenter", "FUGWSparse", "FUGWSparseBarycenter"]
285 changes: 285 additions & 0 deletions src/fugw/mappings/sparse_barycenter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,285 @@
import torch

from fugw.mappings.dense import FUGW
from fugw.mappings.sparse import FUGWSparse
from fugw.scripts import coarse_to_fine
from fugw.utils import _make_tensor


class FUGWSparseBarycenter:
"""FUGW sparse barycenters"""

def __init__(
self,
alpha_coarse=0.5,
alpha_fine=0.5,
rho_coarse=1,
rho_fine=1e-2,
eps_coarse=1e-4,
eps_fine=1e-4,
selection_radius=0.1,
reg_mode="joint",
force_psd=False,
learn_geometry=False,
):
# Save model arguments
self.alpha_coarse = alpha_coarse
self.alpha_fine = alpha_fine
self.rho_coarse = rho_coarse
self.rho_fine = rho_fine
self.eps_coarse = eps_coarse
self.eps_fine = eps_fine
self.reg_mode = reg_mode
self.force_psd = force_psd
self.learn_geometry = learn_geometry
self.selection_radius = selection_radius

@staticmethod
def update_barycenter_features(plans, weights_list, features_list, device):
barycenter_features = 0
for i, (pi, weights, features) in enumerate(
zip(plans, weights_list, features_list)
):
w = _make_tensor(weights, device=device)
f = _make_tensor(features, device=device)

if features is not None:
pi_sum = torch.sparse.sum(pi, dim=0).to_dense()
acc = w * pi.T @ f.T / pi_sum.unsqueeze(1)

if i == 0:
barycenter_features = acc
else:
barycenter_features += acc

return barycenter_features.T

@staticmethod
def get_dim(C):
if isinstance(C, tuple):
return C[0].shape[0]
elif torch.is_tensor(C):
return C.shape[0]

@staticmethod
def get_device_dtype(C):
if isinstance(C, tuple):
return C[0].device, C[0].dtype
elif torch.is_tensor(C):
return C.device, C.dtype

def compute_all_ot_plans(
self,
plans,
weights_list,
features_list,
geometry_list,
barycenter_weights,
barycenter_features,
barycenter_geometry_embedding,
mesh_sample,
solver,
coarse_mapping_solver_params,
fine_mapping_solver_params,
selection_radius,
mask,
device,
verbose,
):
new_plans = []
new_losses = []

for i, (features, weights) in enumerate(
zip(features_list, weights_list)
):
if len(geometry_list) == 1 and len(weights_list) > 1:
G = geometry_list[0]
else:
G = geometry_list[i]

coarse_mapping = FUGW(
alpha=self.alpha_coarse,
rho=self.rho_coarse,
eps=self.eps_coarse,
reg_mode=self.reg_mode,
)

fine_mapping = FUGWSparse(
alpha=self.alpha_fine,
rho=self.rho_fine,
eps=self.eps_fine,
reg_mode=self.reg_mode,
)

_, _, mask = coarse_to_fine.fit(
source_features=features,
target_features=barycenter_features,
source_geometry_embeddings=G,
target_geometry_embeddings=barycenter_geometry_embedding,
source_sample=mesh_sample,
target_sample=mesh_sample,
Comment on lines +121 to +122
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might be wrong, but I feel that you might need different values for mesh_sample across individuals if their data lie on different meshes. The barycenter might need its own mesh_sample as well?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are definitely right. However we agreed with @bthirion to keep the same geometry/mesh as for now. Moreover I'm afraid of the large memory/speed slowdown that we could get if we individualize the geometry of each individual. Maybe a NotImplementedError would be the best fit ? I'll open an issue...

coarse_mapping=coarse_mapping,
source_weights=weights,
target_weights=barycenter_weights,
coarse_mapping_solver=solver,
coarse_mapping_solver_params=coarse_mapping_solver_params,
coarse_pairs_selection_method="topk",
source_selection_radius=selection_radius,
target_selection_radius=selection_radius,
fine_mapping=fine_mapping,
fine_mapping_solver=solver,
fine_mapping_solver_params=fine_mapping_solver_params,
init_plan=plans[i] if plans is not None else None,
mask=mask,
device=device,
verbose=verbose,
)

new_plans.append(fine_mapping.pi)
new_losses.append(
(
fine_mapping.loss,
fine_mapping.loss_steps,
fine_mapping.loss_times,
)
)

return new_plans, new_losses

def fit(
self,
weights_list,
features_list,
geometry_list,
barycenter_size=None,
init_barycenter_weights=None,
init_barycenter_features=None,
init_barycenter_geometry=None,
solver="sinkhorn",
coarse_mapping_solver_params={},
fine_mapping_solver_params={},
mesh_sample=None,
nits_barycenter=5,
device="auto",
verbose=False,
):
"""Compute barycentric features and geometry
minimizing FUGW loss to list of distributions given as input.
In this documentation, we refer to a single distribution as
an a subject's or an individual's distribution.

Parameters
----------
weights_list (list of np.array): List of weights. Different individuals
can have weights with different sizes.
features_list (list of np.array): List of features. Individuals should
have the same number of features n_features.
geometry_list (list of np.array or np.array): List of kernel matrices
or just one kernel matrix if it's shared across individuals
barycenter_size (int, optional): Size of computed
pbarbarant marked this conversation as resolved.
Show resolved Hide resolved
barycentric features and geometry. Defaults to None.
init_barycenter_weights (np.array, optional): Distribution weights
of barycentric points. If None, points will have uniform
weights. Defaults to None.
mesh_sample (np.array, optional): Sample points on which to compute
the barycenter. Defaults to None.
init_barycenter_features (np.array, optional): np.array of size
(barycenter_size, n_features). Defaults to None.
init_barycenter_geometry (np.array, optional): np.array of size
(barycenter_size, barycenter_size). Defaults to None.
device: "auto" or torch.device
if "auto": use first available gpu if it's available,
cpu otherwise.

Returns
-------
barycenter_weights: np.array of size (barycenter_size)
barycenter_features: np.array of size (barycenter_size, n_features)
barycenter_geometry: np.array of size
(barycenter_size, barycenter_size)
plans: list of arrays
duals: list of (array, array)
losses_each_bar_step: list such that l[s][i]
is a tuple containing:
- loss
- loss_steps
- loss_times
for individual i at barycenter computation step s
"""
if device == "auto":
if torch.cuda.is_available():
device = torch.device("cuda", 0)
else:
device = torch.device("cpu")

if barycenter_size is None:
barycenter_size = weights_list[0].shape[0]

# Initialize barycenter weights, features and geometry
if init_barycenter_weights is None:
barycenter_weights = (
torch.ones(barycenter_size) / barycenter_size
).to(device)
else:
barycenter_weights = _make_tensor(
init_barycenter_weights, device=device
)

if init_barycenter_features is None:
barycenter_features = torch.ones(
(features_list[0].shape[0], barycenter_size)
).to(device)
barycenter_features = barycenter_features / torch.norm(
barycenter_features, dim=1
).reshape(-1, 1)
else:
barycenter_features = _make_tensor(
init_barycenter_features, device=device
)

if init_barycenter_geometry is None:
barycenter_geometry_embedding = geometry_list[0]
else:
barycenter_geometry_embedding = _make_tensor(
init_barycenter_geometry, device=device
)

plans = None
duals = None
mask = None
losses_each_bar_step = []

for _ in range(nits_barycenter):
# Transport all elements
plans, losses = self.compute_all_ot_plans(
plans,
weights_list,
features_list,
geometry_list,
barycenter_weights,
barycenter_features,
barycenter_geometry_embedding,
mesh_sample,
solver,
coarse_mapping_solver_params,
fine_mapping_solver_params,
self.selection_radius,
mask,
device,
verbose,
)

losses_each_bar_step.append(losses)

# Update barycenter features and geometry
barycenter_features = self.update_barycenter_features(
plans, weights_list, features_list, device
)

return (
barycenter_weights,
barycenter_features,
plans,
duals,
losses_each_bar_step,
)
57 changes: 34 additions & 23 deletions src/fugw/scripts/coarse_to_fine.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,8 @@ def fit(
target_geometry_embeddings=None,
source_weights=None,
target_weights=None,
init_plan=None,
mask=None,
device="auto",
verbose=False,
):
Expand Down Expand Up @@ -348,6 +350,12 @@ def fit(
Distribution weights of target nodes.
Should sum to 1. If None, each node's weight
will be set to 1 / m.
init_plan: torch.sparse_coo_tensor or None
Initial transport plan to use when fitting the fine mapping.
If None, a random plan will be used.
mask: torch.Tensor or None
Sparsity mask to use when fitting the fine mapping.
If None, a mask will be computed from the coarse mapping.
Comment on lines +356 to +358
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feel free to ignore this comment, but would rather name this variable sparsity_mask so that it's not confused with a nilearn masker.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point

device: "auto" or torch.device
if "auto": use first available gpu if it's available,
cpu otherwise.
Expand Down Expand Up @@ -446,31 +454,33 @@ def fit(
]
)

# Compute mask as a matrix product between:
# a. neighbourhood matrices that encode
# which vertex is close to which sampled point
N_source = get_neighbourhood_matrix(
source_geometry_embeddings, source_sample, source_selection_radius
)
N_target = get_neighbourhood_matrix(
target_geometry_embeddings, target_sample, target_selection_radius
)
# b. cluster matrices that encode
# which sampled point belongs to which cluster
C_source = get_cluster_matrix(rows, source_sample.shape[0])
C_target = get_cluster_matrix(cols, target_sample.shape[0])
if mask is None:
# Compute mask as a matrix product between:
# a. neighbourhood matrices that encode
# which vertex is close to which sampled point
N_source = get_neighbourhood_matrix(
source_geometry_embeddings, source_sample, source_selection_radius
)
N_target = get_neighbourhood_matrix(
target_geometry_embeddings, target_sample, target_selection_radius
)
# b. cluster matrices that encode
# which sampled point belongs to which cluster
C_source = get_cluster_matrix(rows, source_sample.shape[0])
C_target = get_cluster_matrix(cols, target_sample.shape[0])

mask = (N_source @ C_source) @ (N_target @ C_target).T
mask = (N_source @ C_source) @ (N_target @ C_target).T

# Define init plan from spasity mask
init_plan = torch.sparse_coo_tensor(
mask.indices(),
torch.ones_like(mask.values()) / mask.values().shape[0],
(
source_geometry_embeddings.shape[0],
target_geometry_embeddings.shape[0],
),
).coalesce()
if init_plan is None:
init_plan = torch.sparse_coo_tensor(
mask.indices(),
torch.ones_like(mask.values()) / mask.values().shape[0],
(
source_geometry_embeddings.shape[0],
target_geometry_embeddings.shape[0],
),
).coalesce()

# 3. Fit fine-grained mapping
fine_mapping.fit(
Expand All @@ -486,6 +496,7 @@ def fit(
solver=fine_mapping_solver,
solver_params=fine_mapping_solver_params,
callback_bcd=fine_callback_bcd,
storing_device=device,
)

return source_sample, target_sample
return source_sample, target_sample, mask
Loading
Loading