Skip to content

Commit

Permalink
Merge pull request #68 from pbarbarant/feat/fugw-datasets
Browse files Browse the repository at this point in the history
[FEAT] Add caching options for geometry
  • Loading branch information
pbarbarant authored Sep 16, 2024
2 parents 776fa1d + d4b3e36 commit e96384b
Show file tree
Hide file tree
Showing 8 changed files with 539 additions and 4 deletions.
7 changes: 7 additions & 0 deletions src/fugw/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .surf_geometry import fetch_surf_geometry
from .vol_geometry import fetch_vol_geometry

__all__ = [
"fetch_surf_geometry",
"fetch_vol_geometry",
]
165 changes: 165 additions & 0 deletions src/fugw/datasets/surf_geometry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import gdist
import numpy as np

from typing import Tuple

from joblib import Memory
from nilearn import surface, datasets
from scipy.spatial import distance_matrix

from fugw.scripts import coarse_to_fine, lmds


# Create a Memory object to handle the caching
fugw_data = "~/fugw_data"
memory = Memory(fugw_data, verbose=0)


def _check_mesh(mesh: str) -> None:
"""Check if the mesh is valid."""
valid_meshes = ["pial_left", "pial_right", "infl_left", "infl_right"]
if mesh not in valid_meshes:
raise ValueError(
f"Unknown mesh {mesh}. Valid meshes include {valid_meshes}."
)


def _check_resolution(resolution: str) -> None:
"""Check if the resolution is valid."""
valid_resolutions = [
"fsaverage3",
"fsaverage4",
"fsaverage5",
"fsaverage6",
"fsaverage7",
"fsaverage",
]
if resolution not in valid_resolutions:
raise ValueError(
f"Unknown resolution {resolution}. Valid resolutions include"
f" {valid_resolutions}."
)


@memory.cache
def _fetch_geometry_full_rank(
mesh: str, resolution: str, method: str = "geodesic"
) -> Tuple[np.ndarray, float]:
"""Returns the normalized full-rank distance matrix for the
given mesh and the maximum distance between two points in the mesh.
"""
mesh_path = datasets.fetch_surf_fsaverage(mesh=resolution)[mesh]
(coordinates, triangles) = surface.load_surf_mesh(mesh_path)
if method == "geodesic":
# Return geodesic distance matrix
geometry = gdist.local_gdist_matrix(
coordinates.astype(np.float64), triangles.astype(np.int32)
).toarray()

elif method == "euclidean":
# Return euclidean distance matrix
geometry = distance_matrix(coordinates, coordinates)

# Normalize the distance matrix
d_max = geometry.max()
geometry = geometry / d_max

return geometry, d_max


@memory.cache
def _fetch_geometry_low_rank(
mesh: str,
resolution: str,
method: str = "geodesic",
rank: int = 3,
n_landmarks: int = 100,
n_jobs: int = 2,
verbose: bool = True,
) -> Tuple[np.ndarray, float]:
"""Returns the normalized low-rank distance matrix for the
given mesh and the maximum distance between two points in the mesh.
"""
if method == "euclidean":
raise NotImplementedError(
"Low-rank embedding is not implemented for L2 distance matrices."
)

mesh_path = datasets.fetch_surf_fsaverage(mesh=resolution)[mesh]
(coordinates, triangles) = surface.load_surf_mesh(mesh_path)
geometry_embedding = lmds.compute_lmds_mesh(
coordinates,
triangles,
n_landmarks=n_landmarks,
k=rank,
n_jobs=n_jobs,
verbose=verbose,
)
(
geometry_embedding_normalized,
d_max,
) = coarse_to_fine.random_normalizing(geometry_embedding)

return geometry_embedding_normalized.cpu().numpy(), d_max


def fetch_surf_geometry(
mesh: str,
resolution: str,
method: str = "geodesic",
rank: int = -1,
n_landmarks: int = 100,
n_jobs: int = 2,
verbose: bool = True,
) -> Tuple[np.ndarray, float]:
"""Returns either the normalized full-rank or low-rank embedding
of the distance matrix for the given mesh and the maximum distance
between two points in the mesh.
Parameters
----------
mesh : str
Input mesh name. Valid meshes include "pial_left", "pial_right",
"infl_left", and "infl_right".
resolution : str
Input resolution name. Valid resolutions include "fsaverage3",
"fsaverage4", "fsaverage5", "fsaverage6", "fsaverage7", and
"fsaverage".
method : str, optional
Method used to compute distances, either "geodesic" or "euclidean",
by default "geodesic".
rank : int, optional
Dimension of embedding, -1 for full-rank embedding
and rank < n_vertices for low-rank embedding, by default -1
n_landmarks : int, optional
Number of vertices to sample on mesh to approximate embedding,
by default 100
n_jobs : int, optional,
Relative tolerance used to check intermediate results, by default 2
verbose : bool, optional
Enable logging, by default True
Returns
-------
Tuple[np.ndarray, float]
Full-rank or low-rank embedding of the distance matrix of size
(n_vertices, n_vertices) or (n_vertices, rank) and the maximum
distance encountered in the mesh.
"""
_check_mesh(mesh)
_check_resolution(resolution)

if rank == -1:
return _fetch_geometry_full_rank(
mesh=mesh, resolution=resolution, method=method
)
else:
return _fetch_geometry_low_rank(
mesh=mesh,
resolution=resolution,
method=method,
rank=rank,
n_landmarks=n_landmarks,
n_jobs=n_jobs,
verbose=verbose,
)
164 changes: 164 additions & 0 deletions src/fugw/datasets/vol_geometry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import numpy as np

from typing import Tuple, Any

from joblib import Memory
from nilearn import datasets, masking
from scipy.spatial import distance_matrix

from fugw.scripts import coarse_to_fine, lmds


# Create a Memory object to handle the caching
fugw_data = "~/fugw_data"
memory = Memory(fugw_data, verbose=0)


def _check_masker(mask: str) -> None:
"""Check if the mask is valid."""
valid_masks = ["mni152_gm_mask", "mni152_brain_mask"]
if mask not in valid_masks:
raise ValueError(
f"Unknown mask {mask}. Valid masks include {valid_masks}."
)


def _compute_connected_segmentation(mask_img: Any) -> np.ndarray:
"""Compute the connected segmentation of the mask from a 3D Nifti image."""
return (
masking.compute_background_mask(mask_img, connected=True).get_fdata()
> 0
)


@memory.cache
def _fetch_geometry_full_rank(
mask: str, resolution: int, method: str = "euclidean"
) -> Tuple[np.ndarray, float]:
"""Returns the normalized full-rank distance matrix for the
given mesh and the maximum distance between two points in the volume.
"""
if mask == "mni152_gm_mask":
mask_img = datasets.load_mni152_gm_mask(resolution=resolution)
elif mask == "mni152_brain_mask":
mask_img = datasets.load_mni152_brain_mask(resolution=resolution)
segmentation = _compute_connected_segmentation(mask_img)

if method == "geodesic":
raise NotImplementedError(
"Geodesic distance computation is not implemented for volume data"
" in the full-rank setting."
)

elif method == "euclidean":
# Return euclidean distance matrix
coordinates = np.array(np.where(segmentation)).T
geometry = distance_matrix(coordinates, coordinates)

# Normalize the distance matrix
d_max = geometry.max()
geometry = geometry / d_max

return geometry, d_max


@memory.cache
def _fetch_geometry_low_rank(
mask: str,
resolution: str,
method: str = "euclidean",
rank: int = 3,
n_landmarks: int = 100,
n_jobs: int = 2,
verbose: bool = True,
) -> Tuple[np.ndarray, float]:
"""Returns the normalized low-rank distance matrix for the
given mesh and the maximum distance between two points in the mesh.
"""
if mask == "mni152_gm_mask":
mask_img = datasets.load_mni152_gm_mask(resolution=resolution)
elif mask == "mni152_brain_mask":
mask_img = datasets.load_mni152_brain_mask(resolution=resolution)
segmentation = _compute_connected_segmentation(mask_img)

# Get the anisotropy of the 3D mask
anisotropy = np.abs(mask_img.header.get_zooms()[:3]).tolist()

geometry_embedding = lmds.compute_lmds_volume(
segmentation,
method=method,
k=rank,
n_landmarks=n_landmarks,
anisotropy=anisotropy,
n_jobs=n_jobs,
verbose=verbose,
).nan_to_num()

(
geometry_embedding_normalized,
d_max,
) = coarse_to_fine.random_normalizing(geometry_embedding)

return (
geometry_embedding_normalized.cpu().numpy(),
d_max,
)


def fetch_vol_geometry(
mask: str,
resolution: int,
method: str = "euclidean",
rank: int = -1,
n_landmarks: int = 100,
n_jobs: int = 2,
verbose: bool = True,
) -> Tuple[np.ndarray, float]:
"""Returns either the normalized full-rank or low-rank embedding
of the distance matrix for the given mesh and the maximum distance
between two points in the volume.
Parameters
----------
mask : str
Input mask name. Valid masks include "mni152_gm_mask" and
"mni152_brain_mask".
resolution : int
Input resolution name.
method : str, optional
Method used to compute distances, either "geodesic" or "euclidean",
by default "geodesic".
rank : int, optional
Dimension of embedding, -1 for full-rank embedding
and rank < n_vertices for low-rank embedding, by default -1
n_landmarks : int, optional
Number of vertices to sample on mesh to approximate embedding,
by default 100
n_jobs : int, optional,
Relative tolerance used to check intermediate results, by default 2
verbose : bool, optional
Enable logging, by default True
Returns
-------
Tuple[np.ndarray, float]
Full-rank or low-rank embedding of the distance matrix of size
(n_vertices, n_vertices) or (n_vertices, rank) and the maximum
distance encountered in the volume.
"""
_check_masker(mask)

if rank == -1:
return _fetch_geometry_full_rank(
mask=mask, resolution=resolution, method=method
)
else:
return _fetch_geometry_low_rank(
mask=mask,
resolution=resolution,
method=method,
rank=rank,
n_landmarks=n_landmarks,
n_jobs=n_jobs,
verbose=verbose,
)
6 changes: 6 additions & 0 deletions src/fugw/mappings/barycenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@ def update_barycenter_features(plans, weights_list, features_list, device):
else:
barycenter_features += acc

# Normalize barycenter features
min_val = barycenter_features.min(dim=0, keepdim=True).values
max_val = barycenter_features.max(dim=0, keepdim=True).values
barycenter_features = (
2 * (barycenter_features - min_val) / (max_val - min_val) - 1
)
return barycenter_features.T

@staticmethod
Expand Down
6 changes: 6 additions & 0 deletions src/fugw/mappings/sparse_barycenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ def update_barycenter_features(plans, weights_list, features_list, device):
else:
barycenter_features += acc

# Normalize barycenter features
min_val = barycenter_features.min(dim=0, keepdim=True).values
max_val = barycenter_features.max(dim=0, keepdim=True).values
barycenter_features = (
2 * (barycenter_features - min_val) / (max_val - min_val) - 1
)
return barycenter_features.T

@staticmethod
Expand Down
Loading

0 comments on commit e96384b

Please sign in to comment.