From 2b0c23c4565bf78a695a07a45dc450db95947f19 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 27 Aug 2024 16:50:15 +0200 Subject: [PATCH] (wip) sync pipeline and use SI 0.101.0 --- pyproject.toml | 6 +- .../curation/curation.py | 8 +-- src/spikeinterface_pipelines/pipeline.py | 4 +- .../postprocessing/postprocessing.py | 4 +- .../preprocessing/params.py | 20 ++++++ .../preprocessing/preprocessing.py | 71 +++++++++++++++---- .../spikesorting/params.py | 6 +- .../spikesorting/spikesorting.py | 2 +- .../visualization/visualization.py | 6 +- 9 files changed, 97 insertions(+), 30 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4c3267f..6054ab9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,10 +8,10 @@ authors = [ { name = "Jeremy Magland", email = "jmagland@flatironinstitute.org" }, { name = "Luiz Tauffer", email = "luiz.tauffer@catalystneuro.com" }, ] -requires-python = ">=3.8" +requires-python = ">=3.10" dependencies = [ - "spikeinterface[full,widgets]>=0.100.0", - "neo>=0.12.0", + "spikeinterface[full,widgets]>=0.101.0", + "neo>=0.13.0", "pydantic>=2.4.2", "sortingview>=0.13.1", "kachery_cloud>=0.4.7", diff --git a/src/spikeinterface_pipelines/curation/curation.py b/src/spikeinterface_pipelines/curation/curation.py index 5c60f06..e10d699 100644 --- a/src/spikeinterface_pipelines/curation/curation.py +++ b/src/spikeinterface_pipelines/curation/curation.py @@ -1,4 +1,4 @@ -from __future__ import annotations + from pathlib import Path import re @@ -12,7 +12,7 @@ def curate( - waveform_extractor: si.WaveformExtractor, + waveform_extractor: si.SortingAnalyzer, curation_params: CurationParams = CurationParams(), scratch_folder: Path = Path("./scratch/"), results_folder: Path = Path("./results/curation/"), @@ -23,7 +23,7 @@ def curate( Parameters ---------- - waveform_extractor: si.WaveformExtractor + waveform_extractor: si.SortingAnalyzer The input waveform extractor curation_params: CurationParams Curation parameters @@ -39,7 +39,7 @@ def 
curate( """ # get quality metrics if not waveform_extractor.has_extension("quality_metrics"): - logger.info(f"[Curation] \tQuality metrics not found in WaveformExtractor.") + logger.info(f"[Curation] \tQuality metrics not found in SortingAnalyzer.") return qm = waveform_extractor.load_extension("quality_metrics").get_data() diff --git a/src/spikeinterface_pipelines/pipeline.py b/src/spikeinterface_pipelines/pipeline.py index 15ef910..838623d 100644 --- a/src/spikeinterface_pipelines/pipeline.py +++ b/src/spikeinterface_pipelines/pipeline.py @@ -1,4 +1,4 @@ -from __future__ import annotations + from pathlib import Path from typing import Tuple import spikeinterface as si @@ -47,7 +47,7 @@ def run_pipeline( ) -> Tuple[ si.BaseRecording | None, si.BaseSorting | None, - si.WaveformExtractor | None, + si.SortingAnalyzer | None, si.BaseSorting | None, dict | None, ]: diff --git a/src/spikeinterface_pipelines/postprocessing/postprocessing.py b/src/spikeinterface_pipelines/postprocessing/postprocessing.py index cb5ef26..f2e3699 100644 --- a/src/spikeinterface_pipelines/postprocessing/postprocessing.py +++ b/src/spikeinterface_pipelines/postprocessing/postprocessing.py @@ -20,7 +20,7 @@ def postprocess( postprocessing_params: PostprocessingParams = PostprocessingParams(), scratch_folder: Path = Path("./scratch/"), results_folder: Path = Path("./results/postprocessing/"), -) -> si.WaveformExtractor: +) -> si.SortingAnalyzer: """ Postprocess preprocessed and spike sorting output @@ -37,7 +37,7 @@ def postprocess( Returns ------- - si.WaveformExtractor + si.SortingAnalyzer The waveform extractor """ diff --git a/src/spikeinterface_pipelines/preprocessing/params.py b/src/spikeinterface_pipelines/preprocessing/params.py index 7edeb39..2eaf545 100644 --- a/src/spikeinterface_pipelines/preprocessing/params.py +++ b/src/spikeinterface_pipelines/preprocessing/params.py @@ -9,10 +9,18 @@ class PreprocessingStrategy(str, Enum): destripe = "destripe" +class FilterType(str, Enum): + 
highpass = "highpass" + bandpass = "bandpass" + class HighpassFilter(BaseModel): freq_min: float = Field(default=300.0, description="Minimum frequency for the highpass filter") margin_ms: float = Field(default=5.0, description="Margin in milliseconds") +class BandpassFilter(BaseModel): + freq_min: float = Field(default=300.0, description="Minimum frequency for the bandpass filter") + freq_max: float = Field(default=6000.0, description="Maximum frequency for the bandpass filter") + margin_ms: float = Field(default=5.0, description="Margin in milliseconds") class PhaseShift(BaseModel): margin_ms: float = Field(default=100.0, description="Margin in milliseconds for phase shift") @@ -145,6 +153,7 @@ class MCInterpolateMotionKwargs(BaseModel): ) +# TODO: add dredge and use dredge_fast as default class MCNonrigidAccurate(BaseModel): detect_kwargs: MCDetectKwargs = Field(default=MCDetectKwargs(), description="") localize_peaks_kwargs: MCLocalizeMonopolarTriangulation = Field( @@ -165,6 +174,15 @@ class MCNonrigidFastAndAccurate(BaseModel): interpolate_motion_kwargs: MCInterpolateMotionKwargs = Field(default=MCInterpolateMotionKwargs(), description="") +class MCNonrigidFastAndAccurate(BaseModel): + detect_kwargs: MCDetectKwargs = Field(default=MCDetectKwargs(), description="") + localize_peaks_kwargs: MCLocalizeGridConvolution = Field(default=MCLocalizeGridConvolution(), description="") + estimate_motion_kwargs: MCEstimateMotionDecentralized = Field( + default=MCEstimateMotionDecentralized(), description="" + ) + interpolate_motion_kwargs: MCInterpolateMotionKwargs = Field(default=MCInterpolateMotionKwargs(), description="") + + class MCRigidFast(BaseModel): detect_kwargs: MCDetectKwargs = Field(default=MCDetectKwargs(), description="") localize_peaks_kwargs: MCLocalizeCenterOfMass = Field(default=MCLocalizeCenterOfMass(), description="") @@ -208,7 +226,9 @@ class MotionCorrection(BaseModel): # Preprocessing params 
--------------------------------------------------------------- class PreprocessingParams(BaseModel): preprocessing_strategy: PreprocessingStrategy = Field(default="cmr", description="Strategy for preprocessing") + filter_type: FilterType = Field(default="highpass", description="Type of filter") highpass_filter: HighpassFilter = Field(default=HighpassFilter(), description="Highpass filter") + bandpass_filter: BandpassFilter = Field(default=BandpassFilter(), description="Bandpass filter") phase_shift: PhaseShift = Field(default=PhaseShift(), description="Phase shift") common_reference: CommonReference = Field(default=CommonReference(), description="Common reference") highpass_spatial_filter: HighpassSpatialFilter = Field( diff --git a/src/spikeinterface_pipelines/preprocessing/preprocessing.py b/src/spikeinterface_pipelines/preprocessing/preprocessing.py index 4bfe791..b525b6e 100644 --- a/src/spikeinterface_pipelines/preprocessing/preprocessing.py +++ b/src/spikeinterface_pipelines/preprocessing/preprocessing.py @@ -56,11 +56,16 @@ def preprocess( logger.info("[Preprocessing] \tSkipping phase shift: 'inter_sample_shift' property not found") # Highpass filter - recording_hp_full = spre.highpass_filter(recording, **preprocessing_params.highpass_filter.model_dump()) + if preprocessing_params.filter_type == "highpass": + logger.info("[Preprocessing] \tHighpass filter") + recording_filt_full = spre.highpass_filter(recording, **preprocessing_params.highpass_filter.model_dump()) + else: + logger.info("[Preprocessing] \tBandpass filter") + recording_filt_full = spre.bandpass_filter(recording, **preprocessing_params.bandpass_filter.model_dump()) # Detect and remove bad channels _, channel_labels = spre.detect_bad_channels( - recording_hp_full, **preprocessing_params.detect_bad_channels.model_dump() + recording_filt_full, **preprocessing_params.detect_bad_channels.model_dump() ) dead_channel_mask = channel_labels == "dead" noise_channel_mask = channel_labels == "noise" @@ 
-68,9 +73,9 @@ def preprocess( logger.info( f"[Preprocessing] \tBad channel detection found: {np.sum(dead_channel_mask)} dead, {np.sum(noise_channel_mask)} noise, {np.sum(out_channel_mask)} out channels" ) - dead_channel_ids = recording_hp_full.channel_ids[dead_channel_mask] - noise_channel_ids = recording_hp_full.channel_ids[noise_channel_mask] - out_channel_ids = recording_hp_full.channel_ids[out_channel_mask] + dead_channel_ids = recording_filt_full.channel_ids[dead_channel_mask] + noise_channel_ids = recording_filt_full.channel_ids[noise_channel_mask] + out_channel_ids = recording_filt_full.channel_ids[out_channel_mask] all_bad_channel_ids = np.concatenate((dead_channel_ids, noise_channel_ids, out_channel_ids)) max_bad_channel_fraction_to_remove = preprocessing_params.max_bad_channel_fraction_to_remove @@ -78,14 +83,15 @@ def preprocess( logger.info( f"[Preprocessing] \tMore than {max_bad_channel_fraction_to_remove * 100}% bad channels ({len(all_bad_channel_ids)}). " ) - logger.info("[Preprocessing] \tSkipping further processing for this recording.") - return recording_hp_full + if preprocessing_params.remove_bad_channels: + logger.info("[Preprocessing] \tSkipping further processing for this recording.") + return recording_filt_full if preprocessing_params.remove_out_channels: logger.info(f"[Preprocessing] \tRemoving {len(out_channel_ids)} out channels") - recording_rm_out = recording_hp_full.remove_channels(out_channel_ids) + recording_rm_out = recording_filt_full.remove_channels(out_channel_ids) else: - recording_rm_out = recording_hp_full + recording_rm_out = recording_filt_full bad_channel_ids = np.concatenate((dead_channel_ids, noise_channel_ids)) @@ -95,10 +101,17 @@ def preprocess( recording_rm_out, **preprocessing_params.common_reference.model_dump() ) else: - recording_interp = spre.interpolate_bad_channels(recording_rm_out, bad_channel_ids) - recording_processed = spre.highpass_spatial_filter( - recording_interp, 
**preprocessing_params.highpass_spatial_filter.model_dump() - ) + # protection against short probes + try: + recording_interp = spre.interpolate_bad_channels(recording_rm_out, bad_channel_ids) + recording_processed = spre.highpass_spatial_filter( + recording_interp, **preprocessing_params.highpass_spatial_filter.model_dump() + ) + except Exception: + recording_processed = recording_rm_out + logger.warning( + "[Preprocessing] \tInterpolation failed. Skipping interpolation and highpass spatial filter." + ) if preprocessing_params.remove_bad_channels: logger.info( @@ -114,8 +127,19 @@ def preprocess( ) logger.info(f"[Preprocessing] \tComputing motion correction with preset: {preset}") motion_folder = results_folder / "motion_correction" + # interpolation requires float + recording_processed_f = spre.astype(recording_processed, "float32") + + # fix for multi-segment + concat_motion = False + if recording_processed.get_num_segments() > 1: + recording_processed_c = si.concatenate_recordings([recording_processed_f]) + concat_motion = True + else: + recording_processed_c = recording_processed_f + recording_corrected = spre.correct_motion( - recording_processed, + recording_processed_c, preset=preset, folder=motion_folder, verbose=False, @@ -124,6 +148,25 @@ def preprocess( estimate_motion_kwargs=motion_correction_kwargs.estimate_motion_kwargs.model_dump(), interpolate_motion_kwargs=motion_correction_kwargs.interpolate_motion_kwargs.model_dump(), ) + + # split segments back + if concat_motion: + rec_corrected_list = [] + for segment_index in range(recording_processed.get_num_segments()): + num_samples = recording_processed.get_num_samples(segment_index) + if segment_index == 0: + start_frame = 0 + else: + start_frame = end_frame  # cumulative offset: resume where the previous segment ended + end_frame = start_frame + num_samples + rec_split_corrected = recording_corrected.frame_slice( + start_frame=start_frame, + end_frame=end_frame + ) + rec_corrected_list.append(rec_split_corrected) + # append all 
segments + recording_corrected = si.append_recordings(rec_corrected_list) + if preprocessing_params.motion_correction.strategy == "apply": logger.info("[Preprocessing] \tApplying motion correction") recording_processed = recording_corrected diff --git a/src/spikeinterface_pipelines/spikesorting/params.py b/src/spikeinterface_pipelines/spikesorting/params.py index 48a04a4..12a589d 100644 --- a/src/spikeinterface_pipelines/spikesorting/params.py +++ b/src/spikeinterface_pipelines/spikesorting/params.py @@ -6,9 +6,10 @@ class SorterName(str, Enum): kilosort25 = "kilosort2_5" kilosort3 = "kilosort3" + kilosort4 = "kilosort4" mountainsort5 = "mountainsort5" - # spykingcircus2 = "spykingcircus2" ironclust = "ironclust" + # spykingcircus2 = "spykingcircus2" class Kilosort25Model(BaseModel): @@ -57,6 +58,9 @@ class Kilosort3Model(BaseModel): model_config = ConfigDict(extra="forbid") pass +class Kilosort4Model(BaseModel): + model_config = ConfigDict(extra="forbid") + pass class MountainSort5Model(BaseModel): model_config = ConfigDict(extra="forbid") diff --git a/src/spikeinterface_pipelines/spikesorting/spikesorting.py b/src/spikeinterface_pipelines/spikesorting/spikesorting.py index c58ba26..2fb0160 100644 --- a/src/spikeinterface_pipelines/spikesorting/spikesorting.py +++ b/src/spikeinterface_pipelines/spikesorting/spikesorting.py @@ -1,4 +1,4 @@ -from __future__ import annotations + import shutil import numpy as np from pathlib import Path diff --git a/src/spikeinterface_pipelines/visualization/visualization.py b/src/spikeinterface_pipelines/visualization/visualization.py index 07995c1..13d929a 100644 --- a/src/spikeinterface_pipelines/visualization/visualization.py +++ b/src/spikeinterface_pipelines/visualization/visualization.py @@ -1,4 +1,4 @@ -from __future__ import annotations + from pathlib import Path import numpy as np @@ -25,7 +25,7 @@ def visualize( recording: si.BaseRecording, sorting_curated: si.BaseSorting | None = None, - waveform_extractor: 
si.WaveformExtractor | None = None, + waveform_extractor: si.SortingAnalyzer | None = None, visualization_params: VisualizationParams = VisualizationParams(), scratch_folder: Path = Path("./scratch/"), results_folder: Path = Path("./results/visualization/"), @@ -39,7 +39,7 @@ def visualize( The input processed recording sorting_curated: si.BaseSorting | None The input curated sorting. If None, only the recording visualization will be generated. - waveform_extractor: si.WaveformExtractor | None + waveform_extractor: si.SortingAnalyzer | None The input waveform extractor from postprocessing. If None, only the recording visualization will be generated. visualization_params: VisualizationParams The visualization parameters