diff --git a/nireports/interfaces/mosaic.py b/nireports/interfaces/mosaic.py index 4dbae397..d4859db0 100644 --- a/nireports/interfaces/mosaic.py +++ b/nireports/interfaces/mosaic.py @@ -112,6 +112,29 @@ def _run_interface(self, runtime): class _PlotMosaicInputSpec(_PlotBaseInputSpec): bbox_mask_file = File(exists=True, desc="brain mask") only_noise = traits.Bool(False, desc="plot only noise") + main_view = traits.Enum( + "axial", + "sagittal", + "coronal", + default="axial", + usedefault=True, + ) + addon_view1 = traits.Enum( + "sagittal", + "axial", + "coronal", + None, + default="sagittal", + usedefault=True, + ) + addon_view2 = traits.Enum( + None, + "axial", + "sagittal", + "coronal", + default=None, + usedefault=True, + ) class _PlotMosaicOutputSpec(TraitedSpec): @@ -144,6 +167,11 @@ def _run_interface(self, runtime): bbox_mask_file=mask, cmap=self.inputs.cmap, annotate=self.inputs.annotate, + views=( + self.inputs.main_view, + self.inputs.addon_view1, + self.inputs.addon_view2, + ) ) self._results["out_file"] = str((Path(runtime.cwd) / self.inputs.out_file).resolve()) return runtime diff --git a/nireports/reportlets/mosaic.py b/nireports/reportlets/mosaic.py index 05df2444..be52f901 100644 --- a/nireports/reportlets/mosaic.py +++ b/nireports/reportlets/mosaic.py @@ -495,9 +495,18 @@ def plot_mosaic( plot_sagittal=True, fig=None, zmax=128, + views=("axial", "sagittal", None), ): + """Plot a mosaic of 2D cuts.""" - if isinstance(img, (str, bytes)): + VIEW_AXES_ORDER = (2, 1, 0) + + # Error with inconsistent views input + print(views) + if views[0] is None or ((views[1] is None) and (views[2] is not None)): + raise RuntimeError("First view must not be None") + + if not hasattr(img, "shape"): nii = nb.as_closest_canonical(nb.load(img)) img_data = nii.get_fdata() zooms = nii.header.get_zooms() @@ -506,20 +515,43 @@ def plot_mosaic( zooms = [1.0, 1.0, 1.0] out_file = "mosaic.svg" + if views[1] is None and plot_sagittal: + views = (views[0], "sagittal", None) + + # Select the axis through which we cut the planes + axes_order = [ + ["sagittal", "coronal", "axial"].index(views[0]), + ["sagittal", "coronal", "axial"].index(views[1] or "sagittal"), + ] + + # If 3D, complete last axis + if img_data.ndim > 3: + raise RuntimeError("Dataset has more than three dimensions") + elif img_data.ndim == 3: + axes_order += list(set(range(3)) - set(axes_order)) + # Remove extra dimensions - img_data = np.squeeze(img_data) + img_data = np.moveaxis( + np.squeeze(img_data), + axes_order, + VIEW_AXES_ORDER[:len(axes_order)], + ) - if img_data.shape[2] > zmax and bbox_mask_file is None: + # Create mask for bounding box + if bbox_mask_file is not None: + bbox_data = np.moveaxis( + nb.as_closest_canonical(nb.load(bbox_mask_file)).get_fdata(), + axes_order, + VIEW_AXES_ORDER[:len(axes_order)], + ) + img_data = _bbox(img_data, bbox_data) + elif img_data.shape[-1] > zmax: lowthres = np.percentile(img_data, 5) mask_file = np.ones_like(img_data) mask_file[img_data <= lowthres] = 0 img_data = _bbox(img_data, mask_file) - if bbox_mask_file is not None: - bbox_data = nb.as_closest_canonical(nb.load(bbox_mask_file)).get_fdata() - img_data = _bbox(img_data, bbox_data) - - z_vals = np.array(list(range(0, img_data.shape[2]))) + z_vals = np.arange(0, img_data.shape[-1], dtype=int) # Reduce the number of slices shown if len(z_vals) > zmax: @@ -539,12 +571,15 @@ def plot_mosaic( z_vals = z_vals[::2] n_images = len(z_vals) - nrows = math.ceil(n_images / ncols) - if plot_sagittal: - nrows += 1 + extra_rows = sum(bool(v) for v in views[1:]) + nrows = math.ceil(n_images / ncols) + extra_rows if overlay_mask: - overlay_data = nb.as_closest_canonical(nb.load(overlay_mask)).get_fdata() + overlay_data = np.moveaxis( + nb.as_closest_canonical(nb.load(overlay_mask)).get_fdata(), + axes_order, + VIEW_AXES_ORDER[:len(axes_order)], + ) # create figures if fig is None: @@ -556,20 +591,22 @@ def plot_mosaic( if not vmax: vmax = est_vmax + slice_spacing = [vs for i, vs in enumerate(zooms) if i != axes_order[0]] naxis = 1 for z_val in z_vals: ax = fig.add_subplot(nrows, ncols, naxis) if overlay_mask: ax.set_rasterized(True) + plot_slice( img_data[:, :, z_val], vmin=vmin, vmax=vmax, cmap=cmap, ax=ax, - spacing=zooms[:2], - label="%d" % z_val, + spacing=slice_spacing, + label=f"{z_val:d}", annotate=annotate, ) @@ -586,31 +623,49 @@ def plot_mosaic( vmax=1, cmap=msk_cmap, ax=ax, - spacing=zooms[:2], + spacing=slice_spacing, ) naxis += 1 - if plot_sagittal: - naxis = ncols * (nrows - 1) + 1 + if views[1] is not None: + slice_spacing = [vs for i, vs in enumerate(zooms) if i != axes_order[1]] + naxis = ncols * (nrows - extra_rows) + 1 + step = max(int(img_data.shape[-2] / (ncols + 1)), 1) + start = step + stop = img_data.shape[-2] - step + + for slice_val in list(range(start, stop, step))[:ncols]: + ax = fig.add_subplot(nrows, ncols, naxis) - step = int(img_data.shape[0] / (ncols + 1)) + plot_slice( + img_data[:, slice_val, :], + vmin=vmin, + vmax=vmax, + cmap=cmap, + ax=ax, + label=f"{slice_val:d}", + spacing=slice_spacing, + ) + naxis += 1 + + if views[1] is not None and views[2] is not None: + slice_spacing = [vs for i, vs in enumerate(zooms) if i != axes_order[2]] + naxis = ncols * (nrows - extra_rows) + 1 + step = max(int(img_data.shape[0] / (ncols + 1)), 1) start = step stop = img_data.shape[0] - step - if step == 0: - step = 1 - - for x_val in list(range(start, stop, step))[:ncols]: + for slice_val in list(range(start, stop, step))[:ncols]: ax = fig.add_subplot(nrows, ncols, naxis) plot_slice( - img_data[x_val, ...], + img_data[slice_val, ...], vmin=vmin, vmax=vmax, cmap=cmap, ax=ax, - label="%d" % x_val, - spacing=[zooms[0], zooms[2]], + label=f"{slice_val:d}", + spacing=slice_spacing, ) naxis += 1 diff --git a/nireports/tests/test_reportlets.py b/nireports/tests/test_reportlets.py index 024172e4..e86f9a05 100644 --- a/nireports/tests/test_reportlets.py +++ b/nireports/tests/test_reportlets.py @@ -23,6 +23,8 @@ """Test reportlets module.""" import os from pathlib import Path +from itertools import permutations +from functools import partial import nibabel as nb import numpy as np @@ -32,6 +34,7 @@ from nireports.reportlets.modality.func import fMRIPlot from nireports.reportlets.nuisance import plot_carpet from nireports.reportlets.surface import cifti_surfaces_plot +from nireports.reportlets.mosaic import plot_mosaic from nireports.reportlets.xca import compcor_variance_plot, plot_melodic_components from nireports.tools.timeseries import cifti_timeseries as _cifti_timeseries from nireports.tools.timeseries import get_tr as _get_tr @@ -321,3 +324,37 @@ def test_nifti_carpetplot(tmp_path, testdata_path, outdir): output_file=outdir / "carpetplot_nifti.svg" if outdir is not None else None, drop_trs=0, ) + + +_views = ( + list(permutations(("axial", "sagittal", "coronal", None), 3)) + + [(v, None, None) for v in ("axial", "sagittal", "coronal")] +) + + +@pytest.mark.parametrize("views", _views) +@pytest.mark.parametrize("plot_sagittal", (True, False)) +@pytest.mark.parametrize("only_plot_noise", (True, False)) +def test_mriqc_plot_mosaic(tmp_path, testdata_path, outdir, views, plot_sagittal, only_plot_noise): + """Exercise the generation of mosaics.""" + + out_file = ( + outdir / f"mosaic_{'_'.join(views)}_{plot_sagittal:d}_{only_plot_noise:d}.svg" + ) if outdir is not None else None + + testfunc = partial( + plot_mosaic, + testdata_path / "testSpatialNormalizationRPTMovingWarpedImage.nii.gz", + views=views, + out_file=out_file, + title=( + f"A mosaic plotting example: views={views}, plot_sagittal={plot_sagittal}", + f"only_plot_noise={only_plot_noise}" + ), + ) + + if views[0] is None or ((views[1] is None) and (views[2] is not None)): + with pytest.raises(RuntimeError): + testfunc() + else: + testfunc()