diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2f4bfe33..ce044cd5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -3,6 +3,7 @@
 ### Features
 * Added a seed to dummy generators [#361](https://github.com/catalystneuro/roiextractors/pull/361)
 * Added depth_slice for VolumetricImagingExtractors [PR #363](https://github.com/catalystneuro/roiextractors/pull/363)
+* Added MinianSegmentationExtractor: [PR #368](https://github.com/catalystneuro/roiextractors/pull/368)
 
 ### Fixes
 * Added specific error message for single-frame scanimage data [PR #360](https://github.com/catalystneuro/roiextractors/pull/360)
diff --git a/src/roiextractors/extractorlist.py b/src/roiextractors/extractorlist.py
index f93c6c29..c1a2e897 100644
--- a/src/roiextractors/extractorlist.py
+++ b/src/roiextractors/extractorlist.py
@@ -28,6 +28,7 @@
 from .extractors.inscopixextractors import InscopixImagingExtractor
 from .extractors.memmapextractors import NumpyMemmapImagingExtractor
 from .extractors.memmapextractors import MemmapImagingExtractor
+from .extractors.minian import MinianSegmentationExtractor
 from .extractors.miniscopeimagingextractor import MiniscopeImagingExtractor
 from .multisegmentationextractor import MultiSegmentationExtractor
 from .multiimagingextractor import MultiImagingExtractor
@@ -62,6 +63,7 @@
     ExtractSegmentationExtractor,
     SimaSegmentationExtractor,
     CaimanSegmentationExtractor,
+    MinianSegmentationExtractor,
 ]
 
 imaging_extractor_dict = {imaging_class.extractor_name: imaging_class for imaging_class in imaging_extractor_full_list}
diff --git a/src/roiextractors/extractors/minian/__init__.py b/src/roiextractors/extractors/minian/__init__.py
new file mode 100644
index 00000000..4839a679
--- /dev/null
+++ b/src/roiextractors/extractors/minian/__init__.py
@@ -0,0 +1,14 @@
+"""A Segmentation Extractor for Minian.
+
+Modules
+-------
+miniansegmentationextractor
+    A Segmentation Extractor for Minian.
+
+Classes
+-------
+MinianSegmentationExtractor
+    A class for extracting segmentation from Minian output.
+"""
+
+from .miniansegmentationextractor import MinianSegmentationExtractor
diff --git a/src/roiextractors/extractors/minian/miniansegmentationextractor.py b/src/roiextractors/extractors/minian/miniansegmentationextractor.py
new file mode 100644
index 00000000..6c1e59c0
--- /dev/null
+++ b/src/roiextractors/extractors/minian/miniansegmentationextractor.py
@@ -0,0 +1,228 @@
+"""A SegmentationExtractor for Minian.
+
+Classes
+-------
+MinianSegmentationExtractor
+    A class for extracting segmentation from Minian output.
+"""
+
+from pathlib import Path
+
+import zarr
+import warnings
+import numpy as np
+import pandas as pd
+
+from ...extraction_tools import PathType
+from ...segmentationextractor import SegmentationExtractor
+
+
+class MinianSegmentationExtractor(SegmentationExtractor):
+    """A SegmentationExtractor for Minian.
+
+    This class inherits from the SegmentationExtractor class, having all
+    its functionality specifically applied to the dataset output from
+    the 'Minian' ROI segmentation method.
+
+    Users can extract key information such as ROI traces, image masks,
+    and timestamps from the output of the Minian pipeline.
+
+    Key features:
+    - Extracts fluorescence traces (denoised, baseline, neuropil, deconvolved) for each ROI.
+    - Retrieves ROI masks and background components.
+    - Provides access to timestamps corresponding to calcium traces.
+    - Retrieves maximum projection image.
+
+    Parameters
+    ----------
+    folder_path: str or Path
+        Path to the folder containing Minian .zarr output files.
+
+    """
+
+    extractor_name = "MinianSegmentation"
+    is_writable = True
+    mode = "file"
+
+    def __init__(self, folder_path: PathType):
+        """Initialize a MinianSegmentationExtractor instance.
+
+        Parameters
+        ----------
+        folder_path: str or Path
+            The location of the folder containing minian .zarr output.
+        """
+        SegmentationExtractor.__init__(self)
+        # Normalize to Path: sub-paths are joined with "/" below, which fails on a plain str.
+        self.folder_path = Path(folder_path)
+        self._roi_response_denoised = self._read_trace_from_zarr_filed(field="C")
+        self._roi_response_baseline = self._read_trace_from_zarr_filed(field="b0")
+        self._roi_response_neuropil = self._read_trace_from_zarr_filed(field="f")
+        self._roi_response_deconvolved = self._read_trace_from_zarr_filed(field="S")
+        self._image_maximum_projection = np.array(self._read_zarr_group("/max_proj.zarr/max_proj"))
+        self._image_masks = self._read_roi_image_mask_from_zarr_filed()
+        self._background_image_masks = self._read_background_image_mask_from_zarr_filed()
+        self._times = self._read_timestamps_from_csv()
+
+    def _read_zarr_group(self, zarr_group=""):
+        """Open a group or array of the Minian output.
+
+        Parameters
+        ----------
+        zarr_group: str
+            Path of the group or array relative to self.folder_path, e.g. "/C.zarr".
+
+        Returns
+        -------
+        zarr.Group, zarr.Array or None
+            The opened zarr object, or None (after a UserWarning) if it is not in the store.
+        """
+        if zarr_group not in zarr.open(self.folder_path, mode="r"):
+            warnings.warn(f"Group '{zarr_group}' not found in the Zarr store.", UserWarning)
+            return None
+        # Drop the leading "/" so that the join does not discard self.folder_path.
+        return zarr.open(str(self.folder_path / zarr_group.lstrip("/")), mode="r")
+
+    def _read_roi_image_mask_from_zarr_filed(self):
+        """Read the image masks from the zarr output.
+
+        Returns
+        -------
+        image_masks: numpy.ndarray or None
+            The image masks for each ROI, or None if A.zarr is not available.
+        """
+        dataset = self._read_zarr_group("/A.zarr")
+        if dataset is None or "A" not in dataset:
+            return None
+        # Move the first axis of the stored array to the end.
+        return np.transpose(dataset["A"], (1, 2, 0))
+
+    def _read_background_image_mask_from_zarr_filed(self):
+        """Read the background image masks from the zarr output.
+
+        Returns
+        -------
+        image_masks: numpy.ndarray or None
+            The image masks for each background component, or None if b.zarr is not available.
+        """
+        dataset = self._read_zarr_group("/b.zarr")
+        if dataset is None or "b" not in dataset:
+            return None
+        return np.expand_dims(dataset["b"], axis=2)
+
+    def _read_trace_from_zarr_filed(self, field):
+        """Read the traces specified by the field from the zarr object.
+
+        Parameters
+        ----------
+        field: str
+            The field to read from the zarr object.
+
+        Returns
+        -------
+        trace: numpy.ndarray or None
+            The traces specified by the field as a 2D array, or None if it is not available.
+        """
+        dataset = self._read_zarr_group(f"/{field}.zarr")
+        if dataset is None or field not in dataset:
+            return None
+
+        trace = dataset[field]
+        if trace.ndim == 2:
+            return np.transpose(trace)
+        if trace.ndim == 1:
+            return np.expand_dims(trace, axis=1)
+        warnings.warn(f"Trace '{field}' has {trace.ndim} dimensions, expected 1 or 2.", UserWarning)
+        return None
+
+    def _read_timestamps_from_csv(self):
+        """Extract timestamps corresponding to frame numbers of the stored denoised trace.
+
+        Returns
+        -------
+        np.ndarray or None
+            The timestamps of the denoised trace in seconds, or None if the frame numbers
+            of the denoised trace are not available.
+        """
+        df = pd.read_csv(self.folder_path / "timeStamps.csv")
+        frame_numbers = self._read_zarr_group("/C.zarr/frame")
+        if frame_numbers is None:
+            return None
+        is_stored_frame = df["Frame Number"].isin(frame_numbers[:])
+        # Scale only the timestamp column; the other columns of the csv are not times.
+        return df.loc[is_stored_frame, "Time Stamp (ms)"].to_numpy() * 1e-3
+
+    def get_image_size(self):
+        """Get the size of the field of view.
+
+        Returns
+        -------
+        image_size: tuple
+            The (height, width) of the images, taken from the coordinates stored in A.zarr.
+        """
+        dataset = self._read_zarr_group("/A.zarr")
+        height = dataset["height"].shape[0]
+        width = dataset["width"].shape[0]
+        return (height, width)
+
+    def get_accepted_list(self) -> list:
+        """Get a list of accepted ROI ids.
+
+        Returns
+        -------
+        accepted_list: list
+            List of accepted ROI ids.
+        """
+        return list(range(self.get_num_rois()))
+
+    def get_rejected_list(self) -> list:
+        """Get a list of rejected ROI ids.
+
+        Returns
+        -------
+        rejected_list: list
+            List of rejected ROI ids.
+        """
+        return list()
+
+    def get_roi_ids(self) -> list:
+        """Get the list of ROI ids.
+
+        Returns
+        -------
+        roi_ids: list
+            The unit ids stored in A.zarr.
+        """
+        dataset = self._read_zarr_group("/A.zarr")
+        return list(dataset["unit_id"])
+
+    def get_traces_dict(self) -> dict:
+        """Get traces as a dictionary with key as the name of the RoiResponseSeries.
+
+        Returns
+        -------
+        _roi_response_dict: dict
+            dictionary with key, values representing different types of RoiResponseSeries:
+            Denoised, Baseline, Neuropil, Deconvolved.
+        """
+        return dict(
+            denoised=self._roi_response_denoised,
+            baseline=self._roi_response_baseline,
+            neuropil=self._roi_response_neuropil,
+            deconvolved=self._roi_response_deconvolved,
+        )
+
+    def get_images_dict(self) -> dict:
+        """Get images as a dictionary with key as the name of the image.
+
+        Returns
+        -------
+        _roi_image_dict: dict
+            dictionary with key, values representing different types of Images used in segmentation:
+            Mean, Correlation, Maximum projection.
+        """
+        return dict(
+            mean=self._image_mean,
+            correlation=self._image_correlation,
+            maximum_projection=self._image_maximum_projection,
+        )
diff --git a/tests/test_miniansegmentationextractor.py b/tests/test_miniansegmentationextractor.py
new file mode 100644
index 00000000..14e349c3
--- /dev/null
+++ b/tests/test_miniansegmentationextractor.py
@@ -0,0 +1,119 @@
+import shutil
+import tempfile
+from pathlib import Path
+
+import numpy as np
+import zarr
+from hdmf.testing import TestCase
+from numpy.testing import assert_array_equal
+
+from roiextractors import MinianSegmentationExtractor
+from tests.setup_paths import OPHYS_DATA_PATH
+
+
+class TestMinianSegmentationExtractor(TestCase):
+    @classmethod
+    def setUpClass(cls):
+        folder_path = str(OPHYS_DATA_PATH / "segmentation_datasets" / "minian")
+
+        cls.folder_path = Path(folder_path)
+        extractor = MinianSegmentationExtractor(folder_path=cls.folder_path)
+        cls.extractor = extractor
+
+        cls.test_dir = Path(tempfile.mkdtemp())
+
+        # denoised traces
+        dataset = zarr.open(folder_path + "/C.zarr")
+        cls.denoised_traces = np.transpose(dataset["C"])
+        cls.num_frames = len(dataset["frame"][:])
+        # deconvolved traces
+        dataset = zarr.open(folder_path + "/S.zarr")
+        cls.deconvolved_traces = np.transpose(dataset["S"])
+        # baseline traces
+        dataset = zarr.open(folder_path + "/b0.zarr")
+        cls.baseline_traces = np.transpose(dataset["b0"])
+        # neuropil trace
+        dataset = zarr.open(folder_path + "/f.zarr")
+        cls.neuropil_trace = np.expand_dims(dataset["f"], axis=1)
+
+        # ROIs masks
+        dataset = zarr.open(folder_path + "/A.zarr")
+        cls.image_masks = np.transpose(dataset["A"], (1, 2, 0))
+        cls.image_size = (dataset["height"].shape[0], dataset["width"].shape[0])
+        cls.num_rois = dataset["unit_id"].shape[0]
+        # background mask
+        dataset = zarr.open(folder_path + "/b.zarr")
+        cls.background_image_mask = np.expand_dims(dataset["b"], axis=2)
+        # summary image: maximum projection
+        cls.maximum_projection_image = np.array(zarr.open(folder_path + "/max_proj.zarr/max_proj"))
+
+    @classmethod
+    def tearDownClass(cls):
+        # remove the temporary directory and its contents
+        shutil.rmtree(cls.test_dir)
+
+    def test_incomplete_extractor_load(self):
+        """Check extractor can be initialized when not all traces are available."""
+        # temporary directory for testing assertion when some of the files are missing
+        folders_to_copy = [
+            "A.zarr",
+            "C.zarr",
+            "b0.zarr",
+            "b.zarr",
+            "f.zarr",
+            "max_proj.zarr",
+            ".zgroup",
+            "timeStamps.csv",
+        ]
+        self.test_dir.mkdir(exist_ok=True)
+
+        for folder in folders_to_copy:
+            src = Path(self.folder_path) / folder
+            dst = self.test_dir / folder
+            if src.is_dir():
+                shutil.copytree(src, dst, dirs_exist_ok=True)
+            else:
+                shutil.copy(src, dst)
+
+        extractor = MinianSegmentationExtractor(folder_path=self.test_dir)
+        traces_dict = extractor.get_traces_dict()
+        self.assertIsNone(traces_dict["deconvolved"])
+
+    def test_image_size(self):
+        self.assertEqual(self.extractor.get_image_size(), self.image_size)
+
+    def test_num_frames(self):
+        self.assertEqual(self.extractor.get_num_frames(), self.num_frames)
+
+    def test_frame_to_time(self):
+        self.assertEqual(self.extractor.frame_to_time(frames=[0]), 0.328)
+
+    def test_num_channels(self):
+        self.assertEqual(self.extractor.get_num_channels(), 1)
+
+    def test_num_rois(self):
+        self.assertEqual(self.extractor.get_num_rois(), self.num_rois)
+
+    def test_extractor_denoised_traces(self):
+        assert_array_equal(self.extractor.get_traces(name="denoised"), self.denoised_traces)
+
+    def test_extractor_neuropil_trace(self):
+        assert_array_equal(self.extractor.get_traces(name="neuropil"), self.neuropil_trace)
+
+    def test_extractor_image_masks(self):
+        """Test that the image masks are correctly extracted."""
+        assert_array_equal(self.extractor.get_roi_image_masks(), self.image_masks)
+
+    def test_extractor_background_image_masks(self):
+        """Test that the background image masks are correctly extracted."""
+        assert_array_equal(self.extractor.get_background_image_masks(), self.background_image_mask)
+
+    def test_maximum_projection_image(self):
+        """Test that the maximum projection image is correctly loaded from the extractor."""
+        images_dict = self.extractor.get_images_dict()
+        assert_array_equal(images_dict["maximum_projection"], self.maximum_projection_image)
+
+    def test_folder_path_as_string(self):
+        """Check the extractor can be initialized with the folder path given as a string."""
+        extractor = MinianSegmentationExtractor(folder_path=str(self.folder_path))
+        assert_array_equal(extractor.frame_to_time(frames=[0]), self.extractor.frame_to_time(frames=[0]))