Commit

Merge pull request #13 from alexisthual/feat/efficient_sparsity_mask_computatin

Compute coarse_to_fine sparsity mask from matrix product
alexisthual authored Mar 8, 2023
2 parents c605800 + 8902523 commit 3111413
Showing 5 changed files with 173 additions and 136 deletions.
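
For readers skimming the diff, here is a rough sketch of the idea behind "compute coarse_to_fine sparsity mask from matrix product", using hypothetical names and shapes that do not appear in the commit: if S marks which source vertices lie in the neighbourhood of each sampled source vertex, T does the same on the target side, and P marks which sampled pairs are kept from the coarse mapping, then the fine-grained sparsity mask can be read off a single boolean matrix product instead of an explicit Python loop over sampled pairs.

import torch

# Hypothetical shapes, for illustration only (this is not the fugw code).
n_source, n_target, n_samples = 6, 5, 3

# S[v, s] == 1 iff source vertex v lies within the selection radius of
# sampled source vertex s; T is the analogous matrix for the target.
S = torch.randint(0, 2, (n_source, n_samples))
T = torch.randint(0, 2, (n_target, n_samples))
# P[s, t] == 1 iff the coarse coupling keeps the sampled pair (s, t).
P = torch.randint(0, 2, (n_samples, n_samples))

# Keep pair (v, w) iff some kept sampled pair (s, t) has v near s and w near t.
mask = (S @ P @ T.T) > 0  # boolean, shape (n_source, n_target)
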
6 changes: 3 additions & 3 deletions doc/themes/custom.css
@@ -51,7 +51,7 @@ div[class^="highlight-"].sphx-glr-script-out {

.sphx-glr-script-out .highlight {
border-radius: 0px;
-  border-left: 4px solid var(--sd-color-primary);
+  border-left: 4px solid var(--color-background-border);
}

.sphx-glr-script-out .highlight pre {
@@ -74,11 +74,11 @@ button.copybtn {
}

.highlight-default .highlight {
-  border-left: 4px solid var(--color-background-border);
+  border-left: 4px solid var(--sd-color-primary);
}

.highlight-primary .highlight {
-  border-left: 4px solid var(--sd-color-primary);
+  border-left: 4px solid var(--color-background-border);
}

table.plotting-table {
10 changes: 7 additions & 3 deletions examples/00_basics/plot_2_2_coarse_to_fine.py
@@ -111,8 +111,8 @@
# Parametrize step 2 (selection of pairs of indices present in
# fine-grained's sparsity mask)
coarse_pairs_selection_method="topk",
-    source_selection_radius=1 / source_d_max,
-    target_selection_radius=1 / target_d_max,
+    source_selection_radius=0.5 / source_d_max,
+    target_selection_radius=0.5 / target_d_max,
# Parametrize step 3 (fine-grained alignment)
fine_mapping=fine_mapping,
fine_mapping_solver=fine_mapping_solver,
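
Note that source_d_max and target_d_max are defined earlier in the example and are not visible in this hunk. A minimal sketch of how such a value could be obtained, assuming it is simply the maximum pairwise distance of the corresponding geometry (an assumption, since the definition lies outside this diff):

import torch

# Hypothetical helpers (names assumed, definitions not shown in this diff):
# the maximum pairwise distance of the source/target geometry, either read
# directly from a dense distance matrix or recomputed from an embedding.
def max_distance_from_matrix(distance_matrix: torch.Tensor) -> float:
    return distance_matrix.max().item()

def max_distance_from_embedding(embedding: torch.Tensor) -> float:
    return torch.cdist(embedding, embedding).max().item()
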
@@ -216,7 +216,11 @@
).permute(2, 0, 1)
pi_normalized = pi / pi.sum(dim=1).reshape(-1, 1)
line_segments = LineCollection(
-    segments, alpha=pi_normalized.flatten(), colors="black", lw=1, zorder=1
+    segments,
+    alpha=pi_normalized.flatten().nan_to_num(),
+    colors="black",
+    lw=1,
+    zorder=1,
)
ax.add_collection(line_segments)

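An aside on why the added .nan_to_num() call helps, as a sketch under the assumption (not stated in the diff) that some rows of the coupling pi sum to zero: dividing by a zero row sum produces NaN alpha values, which matplotlib may reject or render unpredictably, whereas nan_to_num() maps them to 0 so the corresponding segments are simply drawn fully transparent.

import torch

# Illustrative only: a coupling row that sums to 0 becomes NaN after
# normalization; nan_to_num() maps those NaNs back to 0.
pi = torch.tensor([[0.0, 0.0], [0.2, 0.8]])
pi_normalized = pi / pi.sum(dim=1).reshape(-1, 1)
print(pi_normalized)               # first row is [nan, nan]
print(pi_normalized.nan_to_num())  # first row becomes [0., 0.]
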
6 changes: 3 additions & 3 deletions examples/01_brain_alignment/plot_1_aligning_brain_dense.py
@@ -1,8 +1,8 @@
# %%
"""
-====================================================
-Align brain surfaces of 2 individuals with fMRI data
-====================================================
+===================================================================
+Align low-resolution brain surfaces of 2 individuals with fMRI data
+===================================================================
In this example, we align 2 low-resolution left hemispheres
using 4 fMRI feature maps (z-score contrast maps).
78 changes: 24 additions & 54 deletions examples/01_brain_alignment/plot_2_aligning_brain_sparse.py
@@ -72,18 +72,6 @@
source_imgs_paths = brain_data["cmaps"][0 : len(contrasts)]
target_imgs_paths = brain_data["cmaps"][len(contrasts) : 2 * len(contrasts)]

-# %%
-# Here is what the first contrast map of the source subject looks like
-# (the following figure is interactive):
-
-contrast_index = 0
-plotting.view_img(
-    source_imgs_paths[contrast_index],
-    brain_data["anats"][0],
-    title=f"Contrast {contrast_index} (source subject)",
-    opacity=0.5,
-)
-
# %%
# Computing feature arrays
# ------------------------
@@ -358,50 +346,28 @@ def plot_surface_map(
# derive a scalable way to derive which vertices are selected.


-def vertices_in_sparcity_mask(embeddings, sample, radius):
-    n_vertices = embeddings.shape[0]
-    elligible_vertices = [
-        torch.tensor(
-            np.argwhere(
-                np.linalg.norm(
-                    embeddings - embeddings[i],
-                    axis=1,
-                )
-                <= radius
-            )
-        )
-        for i in sample
-    ]
-
-    rows = torch.concat(
-        [
-            i * torch.ones(len(elligible_vertices[i]))
-            for i in range(len(elligible_vertices))
-        ]
-    ).type(torch.int)
-    cols = torch.concat(elligible_vertices).flatten().type(torch.int)
-    values = torch.ones_like(rows)
-
-    vertex_within_radius = torch.sparse_coo_tensor(
-        torch.stack([rows, cols]),
-        values,
-        size=(n_vertices, n_vertices),
-    )
-    selected_vertices = torch.sparse.sum(
-        vertex_within_radius, dim=0
-    ).to_dense()
-
-    return selected_vertices
-
-
source_selection_radius = 7
-selected_source_vertices = vertices_in_sparcity_mask(
-    source_geometry_embeddings, source_sample, source_selection_radius
+n_neighbourhoods_per_vertex_source = (
+    torch.sparse.sum(
+        coarse_to_fine.get_neighbourhood_matrix(
+            source_geometry_embeddings, source_sample, source_selection_radius
+        ),
+        dim=1,
+    )
+    .to_dense()
+    .numpy()
)

target_selection_radius = 7
-selected_target_vertices = vertices_in_sparcity_mask(
-    target_geometry_embeddings, target_sample, target_selection_radius
+n_neighbourhoods_per_vertex_target = (
+    torch.sparse.sum(
+        coarse_to_fine.get_neighbourhood_matrix(
+            target_geometry_embeddings, target_sample, target_selection_radius
+        ),
+        dim=1,
+    )
+    .to_dense()
+    .numpy()
)

# %%
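
For intuition, the per-vertex counts computed above can be reproduced without the fugw helper. The sketch below is illustrative only and is not coarse_to_fine.get_neighbourhood_matrix, whose exact return shape is not shown in this diff; it assumes a membership matrix with one column per sampled vertex, so that summing over that axis counts how many sampled neighbourhoods each vertex belongs to.

import torch

# Illustrative sketch: for every vertex, count the sampled vertices whose
# radius-neighbourhood contains it.
def neighbourhood_counts(embeddings, sample, radius):
    # (n_vertices, n_samples) distances from every vertex to each sampled vertex
    distances = torch.cdist(embeddings, embeddings[sample])
    neighbourhood = distances <= radius  # boolean membership matrix
    return neighbourhood.sum(dim=1)      # per-vertex neighbourhood count
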
@@ -410,12 +376,16 @@ def vertices_in_sparcity_mask(embeddings, sample, radius):
# sampled, light blue for vertices which are within radius-distance
# of a sampled vertex. Vertices which won't be selected appear in white.
# The following figure is interactive.
+# **Note that, because embeddings are not very precise for short distances,
+# vertices that are very close to sampled vertices can actually
+# be absent from the mask**. In order to limit this effect, the radius
+# should generally be set to a high enough value.

source_vertices_in_mask = np.zeros(source_features.shape[1])
-source_vertices_in_mask[selected_source_vertices > 0] = 1
+source_vertices_in_mask[n_neighbourhoods_per_vertex_source > 0] = 1
source_vertices_in_mask[source_sample] = 2
target_vertices_in_mask = np.zeros(target_features.shape[1])
-target_vertices_in_mask[selected_target_vertices > 0] = 1
+target_vertices_in_mask[n_neighbourhoods_per_vertex_target > 0] = 1
target_vertices_in_mask[target_sample] = 2

# Generate figure with 2 subplots
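
As a side note, one generic way to render the 0/1/2 labels with the colours described in the comment above (white, light blue, dark blue) is a discrete colormap. This is a plain matplotlib sketch, not the plot_surface_map helper used by the example:

from matplotlib.colors import ListedColormap

# 0 -> not selected (white), 1 -> within radius of a sampled vertex (light blue),
# 2 -> sampled vertex itself (dark blue)
mask_cmap = ListedColormap(["white", "lightblue", "darkblue"])
# e.g. pass cmap=mask_cmap, vmin=0, vmax=2 to the surface plotting call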