Skip to content

Commit

Permalink
Merge pull request alexisthual#66 from pbarbarant/fix/accelerate_spar…
Browse files Browse the repository at this point in the history
…se_barycenter

Fix/accelerate sparse barycenter
  • Loading branch information
pbarbarant authored Sep 11, 2024
2 parents cf73a1d + edf98ad commit 2f13445
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 98 deletions.
9 changes: 8 additions & 1 deletion src/fugw/mappings/barycenter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch

from fugw.mappings.dense import FUGW
from fugw.utils import _make_tensor
from fugw.utils import _make_tensor, console


class FUGWBarycenter:
Expand Down Expand Up @@ -131,6 +131,9 @@ def compute_all_ot_plans(
for i, (features, weights) in enumerate(
zip(features_list, weights_list)
):
if verbose:
console.log(f"Updating mapping {i + 1} / {len(weights_list)}")

if len(geometry_list) == 1 and len(weights_list) > 1:
G = geometry_list[0]
else:
Expand Down Expand Up @@ -277,6 +280,10 @@ def fit(
losses_each_bar_step = []

for idx in range(nits_barycenter):
if verbose:
console.log(
f"Barycenter iterations {idx + 1} / {nits_barycenter}"
)
# Transport all elements
plans, losses = self.compute_all_ot_plans(
plans,
Expand Down
24 changes: 16 additions & 8 deletions src/fugw/mappings/sparse_barycenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from fugw.mappings.dense import FUGW
from fugw.mappings.sparse import FUGWSparse
from fugw.scripts import coarse_to_fine
from fugw.utils import _make_tensor
from fugw.utils import _make_tensor, console


class FUGWSparseBarycenter:
Expand Down Expand Up @@ -81,7 +81,7 @@ def compute_all_ot_plans(
coarse_mapping_solver_params,
fine_mapping_solver_params,
selection_radius,
mask,
sparsity_mask,
device,
verbose,
):
Expand All @@ -91,6 +91,9 @@ def compute_all_ot_plans(
for i, (features, weights) in enumerate(
zip(features_list, weights_list)
):
if verbose:
console.log(f"Updating mapping {i + 1} / {len(weights_list)}")

coarse_mapping = FUGW(
alpha=self.alpha_coarse,
rho=self.rho_coarse,
Expand All @@ -105,7 +108,7 @@ def compute_all_ot_plans(
reg_mode=self.reg_mode,
)

_, _, mask = coarse_to_fine.fit(
_, _, sparsity_mask = coarse_to_fine.fit(
source_features=features,
target_features=barycenter_features,
source_geometry_embeddings=geometry_embedding,
Expand All @@ -124,7 +127,7 @@ def compute_all_ot_plans(
fine_mapping_solver=solver,
fine_mapping_solver_params=fine_mapping_solver_params,
init_plan=plans[i] if plans is not None else None,
mask=mask,
sparsity_mask=sparsity_mask,
device=device,
verbose=verbose,
)
Expand All @@ -140,7 +143,7 @@ def compute_all_ot_plans(
)
)

return new_plans, new_losses
return new_plans, new_losses, sparsity_mask

def fit(
self,
Expand Down Expand Up @@ -251,12 +254,17 @@ def fit(
)

plans = None
mask = None
sparsity_mask = None
losses_each_bar_step = []

for idx in range(nits_barycenter):
if verbose:
console.log(
f"Barycenter iterations {idx + 1} / {nits_barycenter}"
)

# Transport all elements
plans, losses = self.compute_all_ot_plans(
plans, losses, sparsity_mask = self.compute_all_ot_plans(
plans,
weights_list,
features_list,
Expand All @@ -268,7 +276,7 @@ def fit(
coarse_mapping_solver_params,
fine_mapping_solver_params,
self.selection_radius,
mask,
sparsity_mask,
device,
verbose,
)
Expand Down
223 changes: 137 additions & 86 deletions src/fugw/scripts/coarse_to_fine.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,90 @@ def get_neighbourhood_matrix(embeddings, sample, radius):
return neighbourhood_matrix


def compute_sparsity_mask(
    coarse_mapping,
    source_sample,
    target_sample,
    source_geometry_embeddings,
    target_geometry_embeddings,
    source_selection_radius=1,
    target_selection_radius=1,
    method="topk",
):
    """
    Compute sparsity mask from coarse mapping.

    Pairs of sampled source / target points are first selected from
    the coarse transport plan ``coarse_mapping.pi``; the mask is then
    obtained by combining the neighbourhoods of the selected points.

    Parameters
    ----------
    coarse_mapping: fugw.FUGW
        Dense mapping between subsampled meshes
    source_sample: torch.Tensor of size (n1)
        Indices of sampled points on source distribution
    target_sample: torch.Tensor of size (m1)
        Indices of sampled points on target distribution
    source_geometry_embeddings: torch.Tensor of size (n, k)
        Embeddings approximating the geodesic distance between
        source vertices
    target_geometry_embeddings: torch.Tensor of size (m, k)
        Embeddings approximating the geodesic distance between
        target vertices
    source_selection_radius: float
        Radius used to determine the neighbourhood
        of source vertices when defining sparsity mask
    target_selection_radius: float
        Radius used to determine the neighbourhood
        of target vertices when defining sparsity mask
    method: "topk" or "quantile"
        Method used to select pairs of source and target features
        whose neighbourhoods will be used to define
        the sparsity mask of the solution

    Returns
    -------
    sparsity_mask: torch.sparse_coo_tensor of size (n, m)
        Sparsity mask used to initialize the fine mapping.

    Raises
    ------
    ValueError
        If ``method`` is neither "topk" nor "quantile".
    """
    # Fail early with an explicit message: an unknown method would
    # otherwise leave ``rows`` / ``cols`` undefined and only crash
    # after the neighbourhood matrices have been computed.
    if method not in ("quantile", "topk"):
        raise ValueError(
            "Unknown method for selecting coarse pairs: "
            f"{method!r} (expected 'topk' or 'quantile')"
        )

    # Convert the coarse plan to a NumPy array once. When given a torch
    # tensor, ``np.nonzero`` dispatches to ``Tensor.nonzero()``, which
    # returns a single (z, 2) index tensor instead of a (rows, cols) tuple.
    # NOTE(review): assumes the plan is array-like and lives on CPU.
    coarse_plan = np.asarray(coarse_mapping.pi)

    if method == "quantile":
        # Method 1: keep entries strictly above the 99.95th percentile
        threshold = np.percentile(coarse_plan, 99.95)
        rows, cols = np.nonzero(coarse_plan > threshold)
    else:
        # Method 2: keep topk indices per line and per column
        # (this should be preferred as it will keep vertices
        # which are particularly unbalanced)
        rows = np.concatenate(
            [
                np.arange(source_sample.shape[0]),
                np.argmax(coarse_plan, axis=0),
            ]
        )
        cols = np.concatenate(
            [
                np.argmax(coarse_plan, axis=1),
                np.arange(target_sample.shape[0]),
            ]
        )

    # Compute mask as a matrix product between:
    # a. neighbourhood matrices that encode
    # which vertex is close to which sampled point
    N_source = get_neighbourhood_matrix(
        source_geometry_embeddings, source_sample, source_selection_radius
    )
    N_target = get_neighbourhood_matrix(
        target_geometry_embeddings, target_sample, target_selection_radius
    )
    # b. cluster matrices that encode
    # which sampled point belongs to which cluster
    C_source = get_cluster_matrix(rows, source_sample.shape[0])
    C_target = get_cluster_matrix(cols, target_sample.shape[0])

    sparsity_mask = (N_source @ C_source) @ (N_target @ C_target).T

    return sparsity_mask


def fit(
coarse_mapping=None,
coarse_mapping_solver="mm",
Expand All @@ -285,7 +369,7 @@ def fit(
source_weights=None,
target_weights=None,
init_plan=None,
mask=None,
sparsity_mask=None,
device="auto",
verbose=False,
):
Expand Down Expand Up @@ -353,7 +437,7 @@ def fit(
init_plan: torch.sparse_coo_tensor or None
Initial transport plan to use when fitting the fine mapping.
If None, a random plan will be used.
sparsity_mask: torch.Tensor of size(n, m) or None
sparsity_mask: sparse coo/csr matrix of size(n, m) or None
Sparsity mask to use when fitting the fine mapping.
If None, a mask will be computed from the coarse mapping.
device: "auto" or torch.device
Expand All @@ -374,29 +458,12 @@ def fit(
Sparsity mask used to fit the fine mapping.
"""
# 0. Parse input tensors
source_sample = _make_tensor(source_sample, dtype=torch.int64)
target_sample = _make_tensor(target_sample, dtype=torch.int64)

source_features = _make_tensor(source_features)
target_features = _make_tensor(target_features)

source_geometry_embeddings = _make_tensor(source_geometry_embeddings)
target_geometry_embeddings = _make_tensor(target_geometry_embeddings)

# Compute anatomical kernels
source_geometry_kernel = torch.cdist(
source_geometry_embeddings[source_sample],
source_geometry_embeddings[source_sample],
p=2,
)
source_geometry_kernel /= source_geometry_kernel.max()
target_geometry_kernel = torch.cdist(
target_geometry_embeddings[target_sample],
target_geometry_embeddings[target_sample],
p=2,
)
target_geometry_kernel /= target_geometry_kernel.max()

# Sampled weights
if source_weights is None:
n = source_features.shape[1]
Expand All @@ -405,82 +472,66 @@ def fit(
m = target_features.shape[1]
target_weights = torch.ones(m) / m

source_weights_sampled = _make_tensor(source_weights)[source_sample]
source_weights_sampled = (
source_weights_sampled / source_weights_sampled.sum()
)
target_weights_sampled = _make_tensor(target_weights)[target_sample]
target_weights_sampled = (
target_weights_sampled / target_weights_sampled.sum()
)

# 1. Fit coarse mapping
coarse_mapping.fit(
source_features[:, source_sample],
target_features[:, target_sample],
source_geometry=source_geometry_kernel,
target_geometry=target_geometry_kernel,
source_weights=source_weights_sampled,
target_weights=target_weights_sampled,
solver=coarse_mapping_solver,
solver_params=coarse_mapping_solver_params,
callback_bcd=coarse_callback_bcd,
device=device,
verbose=verbose,
)

# Send coarse mapping to cpu to handle numpy operations
coarse_mapping.pi = coarse_mapping.pi.cpu()

# 2. Build sparsity mask
if sparsity_mask is None:
source_sample = _make_tensor(source_sample, dtype=torch.int64)
target_sample = _make_tensor(target_sample, dtype=torch.int64)

# Select best pairs of source and target vertices from coarse alignment
if coarse_pairs_selection_method == "quantile":
# Method 1: keep first percentile
quantile = 99.95

threshold = np.percentile(coarse_mapping.pi, quantile)
rows, cols = np.nonzero(coarse_mapping.pi > threshold)

elif coarse_pairs_selection_method == "topk":
# Method 2: keep topk indices per line and per column
# (this should be preferred as it will keep vertices
# which are particularly unbalanced)
rows = np.concatenate(
[
np.arange(source_sample.shape[0]),
np.argmax(coarse_mapping.pi, axis=0),
]
source_weights_sampled = _make_tensor(source_weights)[source_sample]
source_weights_sampled = (
source_weights_sampled / source_weights_sampled.sum()
)
cols = np.concatenate(
[
np.argmax(coarse_mapping.pi, axis=1),
np.arange(target_sample.shape[0]),
]
target_weights_sampled = _make_tensor(target_weights)[target_sample]
target_weights_sampled = (
target_weights_sampled / target_weights_sampled.sum()
)

if mask is None:
# Compute mask as a matrix product between:
# a. neighbourhood matrices that encode
# which vertex is close to which sampled point
N_source = get_neighbourhood_matrix(
source_geometry_embeddings, source_sample, source_selection_radius
# Compute anatomical kernels
source_geometry_kernel = torch.cdist(
source_geometry_embeddings[source_sample],
source_geometry_embeddings[source_sample],
p=2,
)
N_target = get_neighbourhood_matrix(
target_geometry_embeddings, target_sample, target_selection_radius
source_geometry_kernel /= source_geometry_kernel.max()
target_geometry_kernel = torch.cdist(
target_geometry_embeddings[target_sample],
target_geometry_embeddings[target_sample],
p=2,
)
target_geometry_kernel /= target_geometry_kernel.max()

# 1. Fit coarse mapping
coarse_mapping.fit(
source_features[:, source_sample],
target_features[:, target_sample],
source_geometry=source_geometry_kernel,
target_geometry=target_geometry_kernel,
source_weights=source_weights_sampled,
target_weights=target_weights_sampled,
solver=coarse_mapping_solver,
solver_params=coarse_mapping_solver_params,
callback_bcd=coarse_callback_bcd,
device=device,
verbose=verbose,
)
# b. cluster matrices that encode
# which sampled point belongs to which cluster
C_source = get_cluster_matrix(rows, source_sample.shape[0])
C_target = get_cluster_matrix(cols, target_sample.shape[0])

mask = (N_source @ C_source) @ (N_target @ C_target).T
# 2. Build sparsity mask
sparsity_mask = compute_sparsity_mask(
coarse_mapping,
source_sample,
target_sample,
source_geometry_embeddings,
target_geometry_embeddings,
source_selection_radius=source_selection_radius,
target_selection_radius=target_selection_radius,
method=coarse_pairs_selection_method,
)

# Define init plan from spasity mask
# Define init plan from sparsity mask
if init_plan is None:
init_plan = torch.sparse_coo_tensor(
mask.indices(),
torch.ones_like(mask.values()) / mask.values().shape[0],
sparsity_mask.indices(),
torch.ones_like(sparsity_mask.values())
/ sparsity_mask.values().shape[0],
(
source_geometry_embeddings.shape[0],
target_geometry_embeddings.shape[0],
Expand All @@ -503,4 +554,4 @@ def fit(
callback_bcd=fine_callback_bcd,
)

return source_sample, target_sample, mask
return source_sample, target_sample, sparsity_mask
Loading

0 comments on commit 2f13445

Please sign in to comment.