-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #68 from pbarbarant/feat/fugw-datasets
[FEAT] Add caching options for geometry
- Loading branch information
Showing
8 changed files
with
539 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from .surf_geometry import fetch_surf_geometry | ||
from .vol_geometry import fetch_vol_geometry | ||
|
||
__all__ = [ | ||
"fetch_surf_geometry", | ||
"fetch_vol_geometry", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
import gdist | ||
import numpy as np | ||
|
||
from typing import Tuple | ||
|
||
from joblib import Memory | ||
from nilearn import surface, datasets | ||
from scipy.spatial import distance_matrix | ||
|
||
from fugw.scripts import coarse_to_fine, lmds | ||
|
||
|
||
# Create a Memory object to handle the caching | ||
fugw_data = "~/fugw_data" | ||
memory = Memory(fugw_data, verbose=0) | ||
|
||
|
||
def _check_mesh(mesh: str) -> None: | ||
"""Check if the mesh is valid.""" | ||
valid_meshes = ["pial_left", "pial_right", "infl_left", "infl_right"] | ||
if mesh not in valid_meshes: | ||
raise ValueError( | ||
f"Unknown mesh {mesh}. Valid meshes include {valid_meshes}." | ||
) | ||
|
||
|
||
def _check_resolution(resolution: str) -> None: | ||
"""Check if the resolution is valid.""" | ||
valid_resolutions = [ | ||
"fsaverage3", | ||
"fsaverage4", | ||
"fsaverage5", | ||
"fsaverage6", | ||
"fsaverage7", | ||
"fsaverage", | ||
] | ||
if resolution not in valid_resolutions: | ||
raise ValueError( | ||
f"Unknown resolution {resolution}. Valid resolutions include" | ||
f" {valid_resolutions}." | ||
) | ||
|
||
|
||
@memory.cache | ||
def _fetch_geometry_full_rank( | ||
mesh: str, resolution: str, method: str = "geodesic" | ||
) -> Tuple[np.ndarray, float]: | ||
"""Returns the normalized full-rank distance matrix for the | ||
given mesh and the maximum distance between two points in the mesh. | ||
""" | ||
mesh_path = datasets.fetch_surf_fsaverage(mesh=resolution)[mesh] | ||
(coordinates, triangles) = surface.load_surf_mesh(mesh_path) | ||
if method == "geodesic": | ||
# Return geodesic distance matrix | ||
geometry = gdist.local_gdist_matrix( | ||
coordinates.astype(np.float64), triangles.astype(np.int32) | ||
).toarray() | ||
|
||
elif method == "euclidean": | ||
# Return euclidean distance matrix | ||
geometry = distance_matrix(coordinates, coordinates) | ||
|
||
# Normalize the distance matrix | ||
d_max = geometry.max() | ||
geometry = geometry / d_max | ||
|
||
return geometry, d_max | ||
|
||
|
||
@memory.cache | ||
def _fetch_geometry_low_rank( | ||
mesh: str, | ||
resolution: str, | ||
method: str = "geodesic", | ||
rank: int = 3, | ||
n_landmarks: int = 100, | ||
n_jobs: int = 2, | ||
verbose: bool = True, | ||
) -> Tuple[np.ndarray, float]: | ||
"""Returns the normalized low-rank distance matrix for the | ||
given mesh and the maximum distance between two points in the mesh. | ||
""" | ||
if method == "euclidean": | ||
raise NotImplementedError( | ||
"Low-rank embedding is not implemented for L2 distance matrices." | ||
) | ||
|
||
mesh_path = datasets.fetch_surf_fsaverage(mesh=resolution)[mesh] | ||
(coordinates, triangles) = surface.load_surf_mesh(mesh_path) | ||
geometry_embedding = lmds.compute_lmds_mesh( | ||
coordinates, | ||
triangles, | ||
n_landmarks=n_landmarks, | ||
k=rank, | ||
n_jobs=n_jobs, | ||
verbose=verbose, | ||
) | ||
( | ||
geometry_embedding_normalized, | ||
d_max, | ||
) = coarse_to_fine.random_normalizing(geometry_embedding) | ||
|
||
return geometry_embedding_normalized.cpu().numpy(), d_max | ||
|
||
|
||
def fetch_surf_geometry( | ||
mesh: str, | ||
resolution: str, | ||
method: str = "geodesic", | ||
rank: int = -1, | ||
n_landmarks: int = 100, | ||
n_jobs: int = 2, | ||
verbose: bool = True, | ||
) -> Tuple[np.ndarray, float]: | ||
"""Returns either the normalized full-rank or low-rank embedding | ||
of the distance matrix for the given mesh and the maximum distance | ||
between two points in the mesh. | ||
Parameters | ||
---------- | ||
mesh : str | ||
Input mesh name. Valid meshes include "pial_left", "pial_right", | ||
"infl_left", and "infl_right". | ||
resolution : str | ||
Input resolution name. Valid resolutions include "fsaverage3", | ||
"fsaverage4", "fsaverage5", "fsaverage6", "fsaverage7", and | ||
"fsaverage". | ||
method : str, optional | ||
Method used to compute distances, either "geodesic" or "euclidean", | ||
by default "geodesic". | ||
rank : int, optional | ||
Dimension of embedding, -1 for full-rank embedding | ||
and rank < n_vertices for low-rank embedding, by default -1 | ||
n_landmarks : int, optional | ||
Number of vertices to sample on mesh to approximate embedding, | ||
by default 100 | ||
n_jobs : int, optional, | ||
Relative tolerance used to check intermediate results, by default 2 | ||
verbose : bool, optional | ||
Enable logging, by default True | ||
Returns | ||
------- | ||
Tuple[np.ndarray, float] | ||
Full-rank or low-rank embedding of the distance matrix of size | ||
(n_vertices, n_vertices) or (n_vertices, rank) and the maximum | ||
distance encountered in the mesh. | ||
""" | ||
_check_mesh(mesh) | ||
_check_resolution(resolution) | ||
|
||
if rank == -1: | ||
return _fetch_geometry_full_rank( | ||
mesh=mesh, resolution=resolution, method=method | ||
) | ||
else: | ||
return _fetch_geometry_low_rank( | ||
mesh=mesh, | ||
resolution=resolution, | ||
method=method, | ||
rank=rank, | ||
n_landmarks=n_landmarks, | ||
n_jobs=n_jobs, | ||
verbose=verbose, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
import numpy as np | ||
|
||
from typing import Tuple, Any | ||
|
||
from joblib import Memory | ||
from nilearn import datasets, masking | ||
from scipy.spatial import distance_matrix | ||
|
||
from fugw.scripts import coarse_to_fine, lmds | ||
|
||
|
||
# Create a Memory object to handle the caching | ||
fugw_data = "~/fugw_data" | ||
memory = Memory(fugw_data, verbose=0) | ||
|
||
|
||
def _check_masker(mask: str) -> None: | ||
"""Check if the mask is valid.""" | ||
valid_masks = ["mni152_gm_mask", "mni152_brain_mask"] | ||
if mask not in valid_masks: | ||
raise ValueError( | ||
f"Unknown mask {mask}. Valid masks include {valid_masks}." | ||
) | ||
|
||
|
||
def _compute_connected_segmentation(mask_img: Any) -> np.ndarray: | ||
"""Compute the connected segmentation of the mask from a 3D Nifti image.""" | ||
return ( | ||
masking.compute_background_mask(mask_img, connected=True).get_fdata() | ||
> 0 | ||
) | ||
|
||
|
||
@memory.cache | ||
def _fetch_geometry_full_rank( | ||
mask: str, resolution: int, method: str = "euclidean" | ||
) -> Tuple[np.ndarray, float]: | ||
"""Returns the normalized full-rank distance matrix for the | ||
given mesh and the maximum distance between two points in the volume. | ||
""" | ||
if mask == "mni152_gm_mask": | ||
mask_img = datasets.load_mni152_gm_mask(resolution=resolution) | ||
elif mask == "mni152_brain_mask": | ||
mask_img = datasets.load_mni152_brain_mask(resolution=resolution) | ||
segmentation = _compute_connected_segmentation(mask_img) | ||
|
||
if method == "geodesic": | ||
raise NotImplementedError( | ||
"Geodesic distance computation is not implemented for volume data" | ||
" in the full-rank setting." | ||
) | ||
|
||
elif method == "euclidean": | ||
# Return euclidean distance matrix | ||
coordinates = np.array(np.where(segmentation)).T | ||
geometry = distance_matrix(coordinates, coordinates) | ||
|
||
# Normalize the distance matrix | ||
d_max = geometry.max() | ||
geometry = geometry / d_max | ||
|
||
return geometry, d_max | ||
|
||
|
||
@memory.cache | ||
def _fetch_geometry_low_rank( | ||
mask: str, | ||
resolution: str, | ||
method: str = "euclidean", | ||
rank: int = 3, | ||
n_landmarks: int = 100, | ||
n_jobs: int = 2, | ||
verbose: bool = True, | ||
) -> Tuple[np.ndarray, float]: | ||
"""Returns the normalized low-rank distance matrix for the | ||
given mesh and the maximum distance between two points in the mesh. | ||
""" | ||
if mask == "mni152_gm_mask": | ||
mask_img = datasets.load_mni152_gm_mask(resolution=resolution) | ||
elif mask == "mni152_brain_mask": | ||
mask_img = datasets.load_mni152_brain_mask(resolution=resolution) | ||
segmentation = _compute_connected_segmentation(mask_img) | ||
|
||
# Get the anisotropy of the 3D mask | ||
anisotropy = np.abs(mask_img.header.get_zooms()[:3]).tolist() | ||
|
||
geometry_embedding = lmds.compute_lmds_volume( | ||
segmentation, | ||
method=method, | ||
k=rank, | ||
n_landmarks=n_landmarks, | ||
anisotropy=anisotropy, | ||
n_jobs=n_jobs, | ||
verbose=verbose, | ||
).nan_to_num() | ||
|
||
( | ||
geometry_embedding_normalized, | ||
d_max, | ||
) = coarse_to_fine.random_normalizing(geometry_embedding) | ||
|
||
return ( | ||
geometry_embedding_normalized.cpu().numpy(), | ||
d_max, | ||
) | ||
|
||
|
||
def fetch_vol_geometry( | ||
mask: str, | ||
resolution: int, | ||
method: str = "euclidean", | ||
rank: int = -1, | ||
n_landmarks: int = 100, | ||
n_jobs: int = 2, | ||
verbose: bool = True, | ||
) -> Tuple[np.ndarray, float]: | ||
"""Returns either the normalized full-rank or low-rank embedding | ||
of the distance matrix for the given mesh and the maximum distance | ||
between two points in the volume. | ||
Parameters | ||
---------- | ||
mask : str | ||
Input mask name. Valid masks include "mni152_gm_mask" and | ||
"mni152_brain_mask". | ||
resolution : int | ||
Input resolution name. | ||
method : str, optional | ||
Method used to compute distances, either "geodesic" or "euclidean", | ||
by default "geodesic". | ||
rank : int, optional | ||
Dimension of embedding, -1 for full-rank embedding | ||
and rank < n_vertices for low-rank embedding, by default -1 | ||
n_landmarks : int, optional | ||
Number of vertices to sample on mesh to approximate embedding, | ||
by default 100 | ||
n_jobs : int, optional, | ||
Relative tolerance used to check intermediate results, by default 2 | ||
verbose : bool, optional | ||
Enable logging, by default True | ||
Returns | ||
------- | ||
Tuple[np.ndarray, float] | ||
Full-rank or low-rank embedding of the distance matrix of size | ||
(n_vertices, n_vertices) or (n_vertices, rank) and the maximum | ||
distance encountered in the volume. | ||
""" | ||
_check_masker(mask) | ||
|
||
if rank == -1: | ||
return _fetch_geometry_full_rank( | ||
mask=mask, resolution=resolution, method=method | ||
) | ||
else: | ||
return _fetch_geometry_low_rank( | ||
mask=mask, | ||
resolution=resolution, | ||
method=method, | ||
rank=rank, | ||
n_landmarks=n_landmarks, | ||
n_jobs=n_jobs, | ||
verbose=verbose, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.