Skip to content

Commit

Permalink
Update Nilearn requirement to 0.10.3 (PennLINC#396)
Browse files Browse the repository at this point in the history
* Try fixing nilearn parcellation.

* Update pyproject.toml
  • Loading branch information
tsalo authored Mar 28, 2024
1 parent f84987e commit 12e1580
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 53 deletions.
140 changes: 87 additions & 53 deletions aslprep/interfaces/parcellation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import pandas as pd
from nilearn.maskers import NiftiLabelsMasker
from nipype import logging
from nipype.interfaces.base import (
BaseInterfaceInputSpec,
File,
Expand All @@ -14,7 +15,7 @@
)
from nipype.utils.filemanip import fname_presuffix

from aslprep import config
LOGGER = logging.getLogger("nipype.interface")


class _ParcellateCBFInputSpec(BaseInterfaceInputSpec):
Expand All @@ -35,8 +36,8 @@ class _ParcellateCBFInputSpec(BaseInterfaceInputSpec):


class _ParcellateCBFOutputSpec(TraitedSpec):
timeseries = File(exists=True, mandatory=True, desc="Parcellated time series file.")
coverage = File(exists=True, mandatory=True, desc="Parcel-wise coverage file.")
timeseries = File(exists=True, desc="Parcellated time series file.")
coverage = File(exists=True, desc="Parcel-wise coverage file.")


class ParcellateCBF(SimpleInterface):
Expand All @@ -51,65 +52,50 @@ class ParcellateCBF(SimpleInterface):
output_spec = _ParcellateCBFOutputSpec

def _run_interface(self, runtime):
in_file = self.inputs.in_file
mask = self.inputs.mask
atlas = self.inputs.atlas
atlas_labels = self.inputs.atlas_labels
min_coverage = self.inputs.min_coverage

node_labels_df = pd.read_table(atlas_labels, index_col="index")
node_labels_df.sort_index(inplace=True) # ensure index is in order

# Explicitly remove label corresponding to background (index=0), if present.
if 0 in node_labels_df.index:
config.loggers.interface.warning(
"Index value of 0 found in atlas labels file. "
"Will assume this describes the background and ignore it."
)
node_labels_df = node_labels_df.drop(index=[0])
node_labels_df = pd.read_table(self.inputs.atlas_labels, index_col="index")

# Fix any nonsequential values or mismatch between atlas and DataFrame.
atlas_img, node_labels_df = _sanitize_nifti_atlas(atlas, node_labels_df)
node_labels = node_labels_df["label"].tolist()

self._results["timeseries"] = fname_presuffix(
"timeseries.tsv",
newpath=runtime.cwd,
use_ext=True,
)
self._results["coverage"] = fname_presuffix(
"coverage.tsv",
newpath=runtime.cwd,
use_ext=True,
)
# prepend "background" to node labels to satisfy NiftiLabelsMasker
# The background "label" won't be present in the output timeseries.
masker_labels = ["background"] + node_labels

# Before anything, we need to measure coverage
atlas_img = nb.load(atlas)
atlas_data = atlas_img.get_fdata()
atlas_data_bin = (atlas_data > 0).astype(np.float32)
atlas_img_bin = nb.Nifti1Image(atlas_data_bin, atlas_img.affine, atlas_img.header)
atlas_img_bin = nb.Nifti1Image(
(atlas_img.get_fdata() > 0).astype(np.uint8),
atlas_img.affine,
atlas_img.header,
)

sum_masker_masked = NiftiLabelsMasker(
labels_img=atlas,
labels=node_labels,
labels_img=atlas_img,
labels=masker_labels,
background_label=0,
mask_img=mask,
smoothing_fwhm=None,
standardize=False,
strategy="sum",
resampling_target=None, # they should be in the same space/resolution already
keep_masked_labels=True,
)
sum_masker_unmasked = NiftiLabelsMasker(
labels_img=atlas,
labels=node_labels,
labels_img=atlas_img,
labels=masker_labels,
background_label=0,
smoothing_fwhm=None,
standardize=False,
strategy="sum",
resampling_target=None, # they should be in the same space/resolution already
keep_masked_labels=True,
)
n_voxels_in_masked_parcels = sum_masker_masked.fit_transform(atlas_img_bin)
n_voxels_in_parcels = sum_masker_unmasked.fit_transform(atlas_img_bin)
parcel_coverage = np.squeeze(n_voxels_in_masked_parcels / n_voxels_in_parcels)
coverage_thresholded = parcel_coverage < min_coverage
del sum_masker_masked, sum_masker_unmasked, n_voxels_in_masked_parcels, n_voxels_in_parcels

n_nodes = len(node_labels)
n_found_nodes = coverage_thresholded.size
Expand All @@ -122,48 +108,48 @@ def _run_interface(self, runtime):
)

if n_found_nodes != n_nodes:
config.loggers.interface.warning(
LOGGER.warning(
f"{n_nodes - n_found_nodes}/{n_nodes} of parcels not found in atlas file."
)

if n_bad_nodes:
config.loggers.interface.warning(
f"{n_bad_nodes}/{n_nodes} of parcels have 0% coverage."
)
LOGGER.warning(f"{n_bad_nodes}/{n_nodes} of parcels have 0% coverage.")

if n_poor_parcels:
config.loggers.interface.warning(
LOGGER.warning(
f"{n_poor_parcels}/{n_nodes} of parcels have <50% coverage. "
"These parcels' time series will be replaced with zeros."
)

if n_partial_parcels:
config.loggers.interface.warning(
LOGGER.warning(
f"{n_partial_parcels}/{n_nodes} of parcels have at least one uncovered "
"voxel, but have enough good voxels to be useable. "
"The bad voxels will be ignored and the parcels' time series will be "
"calculated from the remaining voxels."
)

masker = NiftiLabelsMasker(
labels_img=atlas,
labels=node_labels,
labels_img=atlas_img,
labels=masker_labels,
background_label=0,
mask_img=mask,
smoothing_fwhm=None,
standardize=False,
resampling_target=None, # they should be in the same space/resolution already
keep_masked_labels=True,
)

# Use nilearn for time_series
timeseries_arr = masker.fit_transform(in_file)
# Use nilearn to parcellate the file
timeseries_arr = masker.fit_transform(self.inputs.in_file)
assert timeseries_arr.shape[1] == n_found_nodes
masker_labels = masker.labels_[:]
del masker

# Apply the coverage mask
timeseries_arr[:, coverage_thresholded] = np.nan

# Region indices in the atlas may not be sequential, so we map them to sequential ints.
seq_mapper = {idx: i for i, idx in enumerate(node_labels_df.index.tolist())}
seq_mapper = {idx: i for i, idx in enumerate(node_labels_df["sanitized_index"].tolist())}

if n_found_nodes != n_nodes: # parcels lost by warping/downsampling atlas
# Fill in any missing nodes in the timeseries array with NaNs.
Expand All @@ -173,24 +159,72 @@ def _run_interface(self, runtime):
dtype=timeseries_arr.dtype,
)
for col in range(timeseries_arr.shape[1]):
label_col = seq_mapper[masker.labels_[col]]
label_col = seq_mapper[masker_labels[col]]
new_timeseries_arr[:, label_col] = timeseries_arr[:, col]

timeseries_arr = new_timeseries_arr
del new_timeseries_arr

# Fill in any missing nodes in the coverage array with zero.
new_parcel_coverage = np.zeros(n_nodes, dtype=parcel_coverage.dtype)
for row in range(parcel_coverage.shape[0]):
label_row = seq_mapper[masker.labels_[row]]
label_row = seq_mapper[masker_labels[row]]
new_parcel_coverage[label_row] = parcel_coverage[row]

parcel_coverage = new_parcel_coverage
del new_parcel_coverage

# The time series file is tab-delimited, with node names included in the first row.
self._results["timeseries"] = fname_presuffix(
"timeseries.tsv",
newpath=runtime.cwd,
use_ext=True,
)
timeseries_df = pd.DataFrame(data=timeseries_arr, columns=node_labels)
coverage_df = pd.DataFrame(data=parcel_coverage, index=node_labels, columns=["coverage"])

timeseries_df.to_csv(self._results["timeseries"], sep="\t", na_rep="n/a", index=False)
coverage_df.to_csv(self._results["coverage"], sep="\t", index_label="Node")

# Save out the coverage tsv
coverage_df = pd.DataFrame(
data=parcel_coverage.astype(np.float32),
index=node_labels,
columns=["coverage"],
)
self._results["coverage"] = fname_presuffix(
"coverage.tsv",
newpath=runtime.cwd,
use_ext=True,
)
coverage_df.to_csv(self._results["coverage"], sep="\t", na_rep="n/a", index_label="Node")

return runtime


def _sanitize_nifti_atlas(atlas, df):
    """Relabel a NIfTI atlas so that its parcel values are sequential integers.

    The parcel values listed in ``df`` are sorted and mapped to ``1..N``.
    The same mapping is applied to the voxels of the atlas image, and it is
    recorded in a new ``"sanitized_index"`` column of the returned DataFrame.

    Parameters
    ----------
    atlas
        Atlas file to load with ``nb.load`` (presumably a path to a NIfTI
        file; confirm against the caller).
    df : pandas.DataFrame
        Label table indexed by the parcel values used in ``atlas``.
        A row with index 0, if present, is treated as background and dropped.
        The input DataFrame is left unmodified.

    Returns
    -------
    new_atlas_img : nibabel.Nifti1Image
        Atlas image with sequential parcel values, built with the affine and
        header of the input atlas. Background voxels stay 0.
    df : pandas.DataFrame
        Copy of the label table without the background row, sorted by index,
        with the added ``"sanitized_index"`` column.

    Raises
    ------
    ValueError
        If the atlas contains a nonzero value that is not in ``df.index``.
    """
    atlas_img = nb.load(atlas)
    # Cast to int32 rather than int16 so that label values above 32767 do not
    # silently wrap around and get mistaken for missing or different labels.
    atlas_data = atlas_img.get_fdata().astype(np.int32)

    # Work on a copy: drop the background row (if any) and order the labels,
    # without touching the DataFrame owned by the caller.
    df = df.drop(index=[0], errors="ignore").sort_index()
    expected_values = df.index.values

    # Check that every (non-background) value in the NIfTI file is described
    # by the DataFrame. Labels that are in the DataFrame but absent from the
    # image are allowed here.
    found_values = np.unique(atlas_data)
    found_values = found_values[found_values != 0]  # drop the background value
    if not np.all(np.isin(found_values, expected_values)):
        raise ValueError("Atlas file contains values that are not present in the DataFrame.")

    # Map the labels in the DataFrame to sequential values, starting at 1
    # because 0 is reserved for the background.
    label_mapper = {value: i + 1 for i, value in enumerate(expected_values)}
    df["sanitized_index"] = [label_mapper[value] for value in expected_values]

    # Keep int16 output whenever the sequential values fit in it, and only
    # widen the dtype for atlases with more labels than int16 can represent.
    n_labels = len(expected_values)
    out_dtype = np.int16 if n_labels <= np.iinfo(np.int16).max else np.int32

    # Map the values in the atlas image to sequential values.
    new_atlas_data = np.zeros(atlas_data.shape, dtype=out_dtype)
    for old_value, new_value in label_mapper.items():
        new_atlas_data[atlas_data == old_value] = new_value

    new_atlas_img = nb.Nifti1Image(new_atlas_data, atlas_img.affine, atlas_img.header)

    return new_atlas_img, df
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ dependencies = [
"looseversion",
"networkx ~= 3.2.1", # nipype needs networkx, but 3+ isn't compatible with nipype 1.8.5
"nibabel >= 4.0.1",
"nilearn ~= 0.10.3",
"nipype >= 1.8.5",
"nitransforms >= 21.0.0",
"niworkflows ~= 1.10.0",
Expand Down

0 comments on commit 12e1580

Please sign in to comment.