-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FEAT] Add coarse-to-fine barycenter estimation #45
Merged
pbarbarant
merged 53 commits into
alexisthual:main
from
pbarbarant:feat_coarse_to_fine_bary
May 27, 2024
Merged
Changes from 16 commits
Commits
Show all changes
53 commits
Select commit
Hold shift + click to select a range
e4d16fe
Add sparse barycenter script
pbarbarant c563c77
Add FUGWSparseBarycenter class to mappings/__init__.py
pbarbarant 3695757
Update fit function in coarse_to_fine.py to include init_plan parameter
pbarbarant 853141f
Remove callback in fit fn
pbarbarant 4fa8542
Fix division error with sparse vectors in sparse barycenter
pbarbarant b137c47
Add test for FUGWSparseBarycenter in test_sparse_barycenter.py
pbarbarant b98eb10
Add sparse barycenter script
pbarbarant 1e32dc6
Add FUGWSparseBarycenter class to mappings/__init__.py
pbarbarant df47bd6
Update fit function in coarse_to_fine.py to include init_plan parameter
pbarbarant 23f23d2
Remove callback in fit fn
pbarbarant d854efd
Fix division error with sparse vectors in sparse barycenter
pbarbarant 1b05384
Add test for FUGWSparseBarycenter in test_sparse_barycenter.py
pbarbarant 952920e
Merge branch 'feat_coarse_to_fine_bary' of github.com:pbarbarant/fugw…
pbarbarant 4f4d7dc
Update fit function in coarse_to_fine.py to include mask parameter
pbarbarant 6fc8006
Update selection_radius parameter in FUGWSparseBarycenter class
pbarbarant a07507a
Update fit function in coarse_to_fine.py to optimize mask computation
pbarbarant 6ee4535
Update selection_radius parameter in FUGWSparseBarycenter class
pbarbarant 1e30527
Update selection_radius parameter in FUGWSparseBarycenter class
pbarbarant b1e0af9
Update test_fugw_barycenter function in test_sparse_barycenter.py
pbarbarant 005a6d9
Add sparse barycenter script
pbarbarant 77a7d85
Add FUGWSparseBarycenter class to mappings/__init__.py
pbarbarant 3c60c64
Update fit function in coarse_to_fine.py to include init_plan parameter
pbarbarant 513bfb5
Remove callback in fit fn
pbarbarant 5257488
Fix division error with sparse vectors in sparse barycenter
pbarbarant ecacff0
Add test for FUGWSparseBarycenter in test_sparse_barycenter.py
pbarbarant b7fe9db
Update fit function in coarse_to_fine.py to include mask parameter
pbarbarant a371bcd
Update selection_radius parameter in FUGWSparseBarycenter class
pbarbarant 14c1027
Update fit function in coarse_to_fine.py to optimize mask computation
pbarbarant 692905e
Update selection_radius parameter in FUGWSparseBarycenter class
pbarbarant 1c78b9d
Update selection_radius parameter in FUGWSparseBarycenter class
pbarbarant d9e5359
Update test_fugw_barycenter function in test_sparse_barycenter.py
pbarbarant 160b787
Merge branch 'alexisthual:main' into feat_coarse_to_fine_bary
pbarbarant 748df35
Merge branch 'feat_coarse_to_fine_bary' of github.com:pbarbarant/fugw…
pbarbarant ac7fa1f
Fix failing mac os mkl test
pbarbarant e5c8f7f
Update fit function in coarse_to_fine.py to include mask parameter
pbarbarant b770dd3
Update test_coarse_to_fine to take into account mask parameter
pbarbarant 70badb1
Update fit function in coarse_to_fine.py to remove storing_device par…
pbarbarant 2dfd979
Update storing_device assignment in FUGWSparse class
pbarbarant 0b6bf86
Update fit function in coarse_to_fine.py to include storing_device pa…
pbarbarant a67237a
Update test_fugw_barycenter function in test_sparse_barycenter.py
pbarbarant fd2d8d6
Refactor FUGWSparseBarycenter class to remove unused variables
pbarbarant c346ef0
Update pytest configuration to ignore sparse CSR tensor warning
pbarbarant ee557b3
Update test_sparse_barycenter.py with assertions
pbarbarant 9297776
Refactor test_sparse_barycenter.py with assertions and type checks
pbarbarant 6015c8e
Refactor FUGWSparseBarycenter class to handle NaN values in fine plan
pbarbarant 038fec6
Add nan val check for barycenter features
pbarbarant 4809deb
Add nan check in dense barycenters
pbarbarant f84926b
Refactor test_sparse_barycenter.py to include nan check for barycente…
pbarbarant 5249d8f
Refactor FUGWBarycenter class to remove unused variables
pbarbarant 0929035
Add small reg to avoid zero-div
pbarbarant 470822f
Add reg to dense barycenters to avoid NaNs
pbarbarant ac75b42
Fix docstring indentation
pbarbarant ef00241
Fix docstrings and flake8
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
from .barycenter import FUGWBarycenter | ||
from .sparse_barycenter import FUGWSparseBarycenter | ||
from .dense import FUGW | ||
from .sparse import FUGWSparse | ||
|
||
__all__ = ["FUGW", "FUGWBarycenter", "FUGWSparse"] | ||
__all__ = ["FUGW", "FUGWBarycenter", "FUGWSparse", "FUGWSparseBarycenter"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,285 @@ | ||
import torch | ||
|
||
from fugw.mappings.dense import FUGW | ||
from fugw.mappings.sparse import FUGWSparse | ||
from fugw.scripts import coarse_to_fine | ||
from fugw.utils import _make_tensor | ||
|
||
|
||
class FUGWSparseBarycenter: | ||
"""FUGW sparse barycenters""" | ||
|
||
def __init__( | ||
self, | ||
alpha_coarse=0.5, | ||
alpha_fine=0.5, | ||
rho_coarse=1, | ||
rho_fine=1e-2, | ||
eps_coarse=1e-4, | ||
eps_fine=1e-4, | ||
selection_radius=0.1, | ||
reg_mode="joint", | ||
force_psd=False, | ||
learn_geometry=False, | ||
): | ||
# Save model arguments | ||
self.alpha_coarse = alpha_coarse | ||
self.alpha_fine = alpha_fine | ||
self.rho_coarse = rho_coarse | ||
self.rho_fine = rho_fine | ||
self.eps_coarse = eps_coarse | ||
self.eps_fine = eps_fine | ||
self.reg_mode = reg_mode | ||
self.force_psd = force_psd | ||
self.learn_geometry = learn_geometry | ||
self.selection_radius = selection_radius | ||
|
||
@staticmethod | ||
def update_barycenter_features(plans, weights_list, features_list, device): | ||
barycenter_features = 0 | ||
for i, (pi, weights, features) in enumerate( | ||
zip(plans, weights_list, features_list) | ||
): | ||
w = _make_tensor(weights, device=device) | ||
f = _make_tensor(features, device=device) | ||
|
||
if features is not None: | ||
pi_sum = torch.sparse.sum(pi, dim=0).to_dense() | ||
acc = w * pi.T @ f.T / pi_sum.unsqueeze(1) | ||
|
||
if i == 0: | ||
barycenter_features = acc | ||
else: | ||
barycenter_features += acc | ||
|
||
return barycenter_features.T | ||
|
||
@staticmethod | ||
def get_dim(C): | ||
if isinstance(C, tuple): | ||
return C[0].shape[0] | ||
elif torch.is_tensor(C): | ||
return C.shape[0] | ||
|
||
@staticmethod | ||
def get_device_dtype(C): | ||
if isinstance(C, tuple): | ||
return C[0].device, C[0].dtype | ||
elif torch.is_tensor(C): | ||
return C.device, C.dtype | ||
|
||
def compute_all_ot_plans( | ||
self, | ||
plans, | ||
weights_list, | ||
features_list, | ||
geometry_list, | ||
barycenter_weights, | ||
barycenter_features, | ||
barycenter_geometry_embedding, | ||
mesh_sample, | ||
solver, | ||
coarse_mapping_solver_params, | ||
fine_mapping_solver_params, | ||
selection_radius, | ||
mask, | ||
device, | ||
verbose, | ||
): | ||
new_plans = [] | ||
new_losses = [] | ||
|
||
for i, (features, weights) in enumerate( | ||
zip(features_list, weights_list) | ||
): | ||
if len(geometry_list) == 1 and len(weights_list) > 1: | ||
G = geometry_list[0] | ||
else: | ||
G = geometry_list[i] | ||
|
||
coarse_mapping = FUGW( | ||
alpha=self.alpha_coarse, | ||
rho=self.rho_coarse, | ||
eps=self.eps_coarse, | ||
reg_mode=self.reg_mode, | ||
) | ||
|
||
fine_mapping = FUGWSparse( | ||
alpha=self.alpha_fine, | ||
rho=self.rho_fine, | ||
eps=self.eps_fine, | ||
reg_mode=self.reg_mode, | ||
) | ||
|
||
_, _, mask = coarse_to_fine.fit( | ||
source_features=features, | ||
target_features=barycenter_features, | ||
source_geometry_embeddings=G, | ||
target_geometry_embeddings=barycenter_geometry_embedding, | ||
source_sample=mesh_sample, | ||
target_sample=mesh_sample, | ||
coarse_mapping=coarse_mapping, | ||
source_weights=weights, | ||
target_weights=barycenter_weights, | ||
coarse_mapping_solver=solver, | ||
coarse_mapping_solver_params=coarse_mapping_solver_params, | ||
coarse_pairs_selection_method="topk", | ||
source_selection_radius=selection_radius, | ||
target_selection_radius=selection_radius, | ||
fine_mapping=fine_mapping, | ||
fine_mapping_solver=solver, | ||
fine_mapping_solver_params=fine_mapping_solver_params, | ||
init_plan=plans[i] if plans is not None else None, | ||
mask=mask, | ||
device=device, | ||
verbose=verbose, | ||
) | ||
|
||
new_plans.append(fine_mapping.pi) | ||
new_losses.append( | ||
( | ||
fine_mapping.loss, | ||
fine_mapping.loss_steps, | ||
fine_mapping.loss_times, | ||
) | ||
) | ||
|
||
return new_plans, new_losses | ||
|
||
def fit( | ||
self, | ||
weights_list, | ||
features_list, | ||
geometry_list, | ||
barycenter_size=None, | ||
init_barycenter_weights=None, | ||
init_barycenter_features=None, | ||
init_barycenter_geometry=None, | ||
solver="sinkhorn", | ||
coarse_mapping_solver_params={}, | ||
fine_mapping_solver_params={}, | ||
mesh_sample=None, | ||
nits_barycenter=5, | ||
device="auto", | ||
verbose=False, | ||
): | ||
"""Compute barycentric features and geometry | ||
minimizing FUGW loss to list of distributions given as input. | ||
In this documentation, we refer to a single distribution as | ||
an a subject's or an individual's distribution. | ||
|
||
Parameters | ||
---------- | ||
weights_list (list of np.array): List of weights. Different individuals | ||
can have weights with different sizes. | ||
features_list (list of np.array): List of features. Individuals should | ||
have the same number of features n_features. | ||
geometry_list (list of np.array or np.array): List of kernel matrices | ||
or just one kernel matrix if it's shared across individuals | ||
barycenter_size (int, optional): Size of computed | ||
pbarbarant marked this conversation as resolved.
Show resolved
Hide resolved
|
||
barycentric features and geometry. Defaults to None. | ||
init_barycenter_weights (np.array, optional): Distribution weights | ||
of barycentric points. If None, points will have uniform | ||
weights. Defaults to None. | ||
mesh_sample (np.array, optional): Sample points on which to compute | ||
the barycenter. Defaults to None. | ||
init_barycenter_features (np.array, optional): np.array of size | ||
(barycenter_size, n_features). Defaults to None. | ||
init_barycenter_geometry (np.array, optional): np.array of size | ||
(barycenter_size, barycenter_size). Defaults to None. | ||
device: "auto" or torch.device | ||
if "auto": use first available gpu if it's available, | ||
cpu otherwise. | ||
|
||
Returns | ||
------- | ||
barycenter_weights: np.array of size (barycenter_size) | ||
barycenter_features: np.array of size (barycenter_size, n_features) | ||
barycenter_geometry: np.array of size | ||
(barycenter_size, barycenter_size) | ||
plans: list of arrays | ||
duals: list of (array, array) | ||
losses_each_bar_step: list such that l[s][i] | ||
is a tuple containing: | ||
- loss | ||
- loss_steps | ||
- loss_times | ||
for individual i at barycenter computation step s | ||
""" | ||
if device == "auto": | ||
if torch.cuda.is_available(): | ||
device = torch.device("cuda", 0) | ||
else: | ||
device = torch.device("cpu") | ||
|
||
if barycenter_size is None: | ||
barycenter_size = weights_list[0].shape[0] | ||
|
||
# Initialize barycenter weights, features and geometry | ||
if init_barycenter_weights is None: | ||
barycenter_weights = ( | ||
torch.ones(barycenter_size) / barycenter_size | ||
).to(device) | ||
else: | ||
barycenter_weights = _make_tensor( | ||
init_barycenter_weights, device=device | ||
) | ||
|
||
if init_barycenter_features is None: | ||
barycenter_features = torch.ones( | ||
(features_list[0].shape[0], barycenter_size) | ||
).to(device) | ||
barycenter_features = barycenter_features / torch.norm( | ||
barycenter_features, dim=1 | ||
).reshape(-1, 1) | ||
else: | ||
barycenter_features = _make_tensor( | ||
init_barycenter_features, device=device | ||
) | ||
|
||
if init_barycenter_geometry is None: | ||
barycenter_geometry_embedding = geometry_list[0] | ||
else: | ||
barycenter_geometry_embedding = _make_tensor( | ||
init_barycenter_geometry, device=device | ||
) | ||
|
||
plans = None | ||
duals = None | ||
mask = None | ||
losses_each_bar_step = [] | ||
|
||
for _ in range(nits_barycenter): | ||
# Transport all elements | ||
plans, losses = self.compute_all_ot_plans( | ||
plans, | ||
weights_list, | ||
features_list, | ||
geometry_list, | ||
barycenter_weights, | ||
barycenter_features, | ||
barycenter_geometry_embedding, | ||
mesh_sample, | ||
solver, | ||
coarse_mapping_solver_params, | ||
fine_mapping_solver_params, | ||
self.selection_radius, | ||
mask, | ||
device, | ||
verbose, | ||
) | ||
|
||
losses_each_bar_step.append(losses) | ||
|
||
# Update barycenter features and geometry | ||
barycenter_features = self.update_barycenter_features( | ||
plans, weights_list, features_list, device | ||
) | ||
|
||
return ( | ||
barycenter_weights, | ||
barycenter_features, | ||
plans, | ||
duals, | ||
losses_each_bar_step, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -284,6 +284,8 @@ def fit( | |
target_geometry_embeddings=None, | ||
source_weights=None, | ||
target_weights=None, | ||
init_plan=None, | ||
mask=None, | ||
device="auto", | ||
verbose=False, | ||
): | ||
|
@@ -348,6 +350,12 @@ def fit( | |
Distribution weights of target nodes. | ||
Should sum to 1. If None, each node's weight | ||
will be set to 1 / m. | ||
init_plan: torch.sparse_coo_tensor or None | ||
Initial transport plan to use when fitting the fine mapping. | ||
If None, a random plan will be used. | ||
mask: torch.Tensor or None | ||
Sparsity mask to use when fitting the fine mapping. | ||
If None, a mask will be computed from the coarse mapping. | ||
Comment on lines
+356
to
+358
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Feel free to ignore this comment, but would rather name this variable There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good point |
||
device: "auto" or torch.device | ||
if "auto": use first available gpu if it's available, | ||
cpu otherwise. | ||
|
@@ -446,31 +454,33 @@ def fit( | |
] | ||
) | ||
|
||
# Compute mask as a matrix product between: | ||
# a. neighbourhood matrices that encode | ||
# which vertex is close to which sampled point | ||
N_source = get_neighbourhood_matrix( | ||
source_geometry_embeddings, source_sample, source_selection_radius | ||
) | ||
N_target = get_neighbourhood_matrix( | ||
target_geometry_embeddings, target_sample, target_selection_radius | ||
) | ||
# b. cluster matrices that encode | ||
# which sampled point belongs to which cluster | ||
C_source = get_cluster_matrix(rows, source_sample.shape[0]) | ||
C_target = get_cluster_matrix(cols, target_sample.shape[0]) | ||
if mask is None: | ||
# Compute mask as a matrix product between: | ||
# a. neighbourhood matrices that encode | ||
# which vertex is close to which sampled point | ||
N_source = get_neighbourhood_matrix( | ||
source_geometry_embeddings, source_sample, source_selection_radius | ||
) | ||
N_target = get_neighbourhood_matrix( | ||
target_geometry_embeddings, target_sample, target_selection_radius | ||
) | ||
# b. cluster matrices that encode | ||
# which sampled point belongs to which cluster | ||
C_source = get_cluster_matrix(rows, source_sample.shape[0]) | ||
C_target = get_cluster_matrix(cols, target_sample.shape[0]) | ||
|
||
mask = (N_source @ C_source) @ (N_target @ C_target).T | ||
mask = (N_source @ C_source) @ (N_target @ C_target).T | ||
|
||
# Define init plan from spasity mask | ||
init_plan = torch.sparse_coo_tensor( | ||
mask.indices(), | ||
torch.ones_like(mask.values()) / mask.values().shape[0], | ||
( | ||
source_geometry_embeddings.shape[0], | ||
target_geometry_embeddings.shape[0], | ||
), | ||
).coalesce() | ||
if init_plan is None: | ||
init_plan = torch.sparse_coo_tensor( | ||
mask.indices(), | ||
torch.ones_like(mask.values()) / mask.values().shape[0], | ||
( | ||
source_geometry_embeddings.shape[0], | ||
target_geometry_embeddings.shape[0], | ||
), | ||
).coalesce() | ||
|
||
# 3. Fit fine-grained mapping | ||
fine_mapping.fit( | ||
|
@@ -486,6 +496,7 @@ def fit( | |
solver=fine_mapping_solver, | ||
solver_params=fine_mapping_solver_params, | ||
callback_bcd=fine_callback_bcd, | ||
storing_device=device, | ||
) | ||
|
||
return source_sample, target_sample | ||
return source_sample, target_sample, mask |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I might be wrong, but I feel that you might need different values for
mesh_sample
across individuals if their data lie on different meshes. The barycenter might need its ownmesh_sample
as well?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You are definitely right. However we agreed with @bthirion to keep the same geometry/mesh as for now. Moreover I'm afraid of the large memory/speed slowdown that we could get if we individualize the geometry of each individual. Maybe a
NotImplementedError
would be the best fit ? I'll open an issue...