(wip) sync pipeline and use SI 0.101.0
alejoe91 committed Aug 27, 2024
1 parent 5e02a30 commit 2b0c23c
Showing 9 changed files with 97 additions and 30 deletions.
6 changes: 3 additions & 3 deletions pyproject.toml
@@ -8,10 +8,10 @@ authors = [
{ name = "Jeremy Magland", email = "[email protected]" },
{ name = "Luiz Tauffer", email = "[email protected]" },
]
-requires-python = ">=3.8"
+requires-python = ">=3.9"
dependencies = [
"spikeinterface[full,widgets]>=0.100.0",
"neo>=0.12.0",
"spikeinterface[full,widgets]>=0.101.0",
"neo>=0.13.0",
"pydantic>=2.4.2",
"sortingview>=0.13.1",
"kachery_cloud>=0.4.7",
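The bumped pins above move the pipeline onto Python 3.9+ and the SpikeInterface 0.101 API line. A minimal environment check along these lines (not part of this commit; it assumes the packaging library is available, as it is in most pip-based installs) can confirm the new pins are satisfied before running the pipeline:

    import sys
    from importlib.metadata import version

    from packaging.version import Version

    # minimums introduced by this commit's pyproject.toml
    assert sys.version_info >= (3, 9), "spikeinterface_pipelines now requires Python >= 3.9"
    assert Version(version("spikeinterface")) >= Version("0.101.0")
    assert Version(version("neo")) >= Version("0.13.0")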
8 changes: 4 additions & 4 deletions src/spikeinterface_pipelines/curation/curation.py
@@ -1,4 +1,4 @@
from __future__ import annotations

from pathlib import Path
import re

@@ -12,7 +12,7 @@


def curate(
-    waveform_extractor: si.WaveformExtractor,
+    waveform_extractor: si.SortingAnalyzer,
curation_params: CurationParams = CurationParams(),
scratch_folder: Path = Path("./scratch/"),
results_folder: Path = Path("./results/curation/"),
@@ -23,7 +23,7 @@ def curate(
Parameters
----------
-    waveform_extractor: si.WaveformExtractor
+    waveform_extractor: si.SortingAnalyzer
The input waveform extractor
curation_params: CurationParams
Curation parameters
@@ -39,7 +39,7 @@ def curate(
"""
# get quality metrics
if not waveform_extractor.has_extension("quality_metrics"):
logger.info(f"[Curation] \tQuality metrics not found in WaveformExtractor.")
logger.info(f"[Curation] \tQuality metrics not found in SortingAnalyzer.")
return

qm = waveform_extractor.load_extension("quality_metrics").get_data()
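The curate() step now takes a si.SortingAnalyzer, which replaces the legacy WaveformExtractor in SpikeInterface 0.101. As a rough, self-contained sketch of the analyzer workflow this function relies on (toy data and threshold values are illustrative only, not taken from this repository's CurationParams):

    import spikeinterface.full as si

    # toy recording/sorting pair stands in for real pipeline outputs
    recording, sorting = si.generate_ground_truth_recording(durations=[60.0], seed=0)

    analyzer = si.create_sorting_analyzer(sorting=sorting, recording=recording)
    analyzer.compute(
        ["random_spikes", "waveforms", "templates", "noise_levels", "spike_amplitudes", "quality_metrics"]
    )

    if analyzer.has_extension("quality_metrics"):
        qm = analyzer.get_extension("quality_metrics").get_data()  # pandas DataFrame, one row per unit
        # example query; the pipeline builds its query string from CurationParams instead
        good_unit_ids = qm.query("isi_violations_ratio < 0.5 and amplitude_cutoff < 0.1").index.values
        sorting_curated = sorting.select_units(good_unit_ids)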
4 changes: 2 additions & 2 deletions src/spikeinterface_pipelines/pipeline.py
@@ -1,4 +1,4 @@
from __future__ import annotations

from pathlib import Path
from typing import Tuple
import spikeinterface as si
@@ -47,7 +47,7 @@ def run_pipeline(
) -> Tuple[
si.BaseRecording | None,
si.BaseSorting | None,
-    si.WaveformExtractor | None,
+    si.SortingAnalyzer | None,
si.BaseSorting | None,
dict | None,
]:
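run_pipeline's return annotation now carries a si.SortingAnalyzer where the WaveformExtractor used to be. A hedged usage sketch: the five-element unpacking mirrors the annotation above, the variable names on the left are descriptive guesses, the folder keywords follow the defaults used elsewhere in this diff, and the configured sorter must be installed for the sorting stage to succeed:

    from pathlib import Path

    import spikeinterface.full as si
    from spikeinterface_pipelines import run_pipeline

    # toy recording stands in for real data
    recording, _ = si.generate_ground_truth_recording(durations=[60.0], num_channels=32, seed=0)

    recording_processed, sorting, sorting_analyzer, sorting_curated, visualization_output = run_pipeline(
        recording,
        scratch_folder=Path("./scratch/"),
        results_folder=Path("./results/"),
    )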
4 changes: 2 additions & 2 deletions src/spikeinterface_pipelines/postprocessing/postprocessing.py
@@ -20,7 +20,7 @@ def postprocess(
postprocessing_params: PostprocessingParams = PostprocessingParams(),
scratch_folder: Path = Path("./scratch/"),
results_folder: Path = Path("./results/postprocessing/"),
-) -> si.WaveformExtractor:
+) -> si.SortingAnalyzer:
"""
Postprocess preprocessed and spike sorting output
@@ -37,7 +37,7 @@
Returns
-------
-    si.WaveformExtractor
+    si.SortingAnalyzer
The waveform extractor
"""

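Because postprocess() now returns a si.SortingAnalyzer, downstream steps persist and reload results with the analyzer API rather than WaveformExtractor folders. A small sketch of that round trip in SpikeInterface 0.101 (toy data; the output folder is arbitrary):

    import spikeinterface.full as si

    recording, sorting = si.generate_ground_truth_recording(durations=[30.0], seed=0)
    # stands in for the analyzer returned by postprocess(...)
    analyzer = si.create_sorting_analyzer(sorting=sorting, recording=recording)
    analyzer.compute(["random_spikes", "waveforms", "templates"])

    # write a copy to disk, then load it back, e.g. in the visualization step
    analyzer.save_as(format="binary_folder", folder="scratch/analyzer_example")
    analyzer_loaded = si.load_sorting_analyzer("scratch/analyzer_example")
    print(analyzer_loaded)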
20 changes: 20 additions & 0 deletions src/spikeinterface_pipelines/preprocessing/params.py
@@ -9,10 +9,18 @@ class PreprocessingStrategy(str, Enum):
destripe = "destripe"


+class FilterType(str, Enum):
+    highpass = "highpass"
+    bandpass = "bandpass"
+
class HighpassFilter(BaseModel):
freq_min: float = Field(default=300.0, description="Minimum frequency for the highpass filter")
margin_ms: float = Field(default=5.0, description="Margin in milliseconds")

+class BandpassFilter(BaseModel):
+    freq_min: float = Field(default=300.0, description="Minimum frequency for the highpass filter")
+    freq_max: float = Field(default=6000.0, description="Maximum frequency for the highpass filter")
+    margin_ms: float = Field(default=5.0, description="Margin in milliseconds")

class PhaseShift(BaseModel):
margin_ms: float = Field(default=100.0, description="Margin in milliseconds for phase shift")
@@ -145,6 +153,7 @@ class MCInterpolateMotionKwargs(BaseModel):
)


+# TODO: add dredge and use dredge_Fast as default
class MCNonrigidAccurate(BaseModel):
detect_kwargs: MCDetectKwargs = Field(default=MCDetectKwargs(), description="")
localize_peaks_kwargs: MCLocalizeMonopolarTriangulation = Field(
@@ -165,6 +174,15 @@ class MCNonrigidFastAndAccurate(BaseModel):
interpolate_motion_kwargs: MCInterpolateMotionKwargs = Field(default=MCInterpolateMotionKwargs(), description="")


+class MCNonrigidFastAndAccurate(BaseModel):
+    detect_kwargs: MCDetectKwargs = Field(default=MCDetectKwargs(), description="")
+    localize_peaks_kwargs: MCLocalizeGridConvolution = Field(default=MCLocalizeGridConvolution(), description="")
+    estimate_motion_kwargs: MCEstimateMotionDecentralized = Field(
+        default=MCEstimateMotionDecentralized(), description=""
+    )
+    interpolate_motion_kwargs: MCInterpolateMotionKwargs = Field(default=MCInterpolateMotionKwargs(), description="")
+
+
class MCRigidFast(BaseModel):
detect_kwargs: MCDetectKwargs = Field(default=MCDetectKwargs(), description="")
localize_peaks_kwargs: MCLocalizeCenterOfMass = Field(default=MCLocalizeCenterOfMass(), description="")
@@ -208,7 +226,9 @@ class MotionCorrection(BaseModel):
# Preprocessing params ---------------------------------------------------------------
class PreprocessingParams(BaseModel):
preprocessing_strategy: PreprocessingStrategy = Field(default="cmr", description="Strategy for preprocessing")
+    filter_type: FilterType = Field(default="highpass", description="Type of filter")
highpass_filter: HighpassFilter = Field(default=HighpassFilter(), description="Highpass filter")
+    bandpass_filter: BandpassFilter = Field(default=BandpassFilter(), description="Bandpass filter")
phase_shift: PhaseShift = Field(default=PhaseShift(), description="Phase shift")
common_reference: CommonReference = Field(default=CommonReference(), description="Common reference")
highpass_spatial_filter: HighpassSpatialFilter = Field(
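The new FilterType enum and BandpassFilter model make the filtering stage selectable through PreprocessingParams. A sketch of how the updated params might be built (the import path assumes the package re-exports PreprocessingParams from its preprocessing subpackage; the field values are illustrative):

    from spikeinterface_pipelines.preprocessing import PreprocessingParams

    params = PreprocessingParams(
        preprocessing_strategy="cmr",
        filter_type="bandpass",
        bandpass_filter={"freq_min": 300.0, "freq_max": 6000.0, "margin_ms": 5.0},
    )
    print(params.filter_type)               # FilterType.bandpass
    print(params.bandpass_filter.freq_max)  # 6000.0

Pydantic coerces the plain strings and the nested dict into the enum and the BandpassFilter model, so JSON-style configs keep working with the new fields.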
71 changes: 57 additions & 14 deletions src/spikeinterface_pipelines/preprocessing/preprocessing.py
@@ -56,36 +56,42 @@ def preprocess(
logger.info("[Preprocessing] \tSkipping phase shift: 'inter_sample_shift' property not found")

# Highpass filter
-    recording_hp_full = spre.highpass_filter(recording, **preprocessing_params.highpass_filter.model_dump())
+    if preprocessing_params.filter_type == "highpass":
+        logger.info("[Preprocessing] \tHighpass filter")
+        recording_filt_full = spre.highpass_filter(recording, **preprocessing_params.highpass_filter.model_dump())
+    else:
+        logger.info("[Preprocessing] \tBandpass filter")
+        recording_filt_full = spre.bandpass_filter(recording, **preprocessing_params.bandpass_filter.model_dump())

# Detect and remove bad channels
_, channel_labels = spre.detect_bad_channels(
-        recording_hp_full, **preprocessing_params.detect_bad_channels.model_dump()
+        recording_filt_full, **preprocessing_params.detect_bad_channels.model_dump()
)
dead_channel_mask = channel_labels == "dead"
noise_channel_mask = channel_labels == "noise"
out_channel_mask = channel_labels == "out"
logger.info(
f"[Preprocessing] \tBad channel detection found: {np.sum(dead_channel_mask)} dead, {np.sum(noise_channel_mask)} noise, {np.sum(out_channel_mask)} out channels"
)
-    dead_channel_ids = recording_hp_full.channel_ids[dead_channel_mask]
-    noise_channel_ids = recording_hp_full.channel_ids[noise_channel_mask]
-    out_channel_ids = recording_hp_full.channel_ids[out_channel_mask]
+    dead_channel_ids = recording_filt_full.channel_ids[dead_channel_mask]
+    noise_channel_ids = recording_filt_full.channel_ids[noise_channel_mask]
+    out_channel_ids = recording_filt_full.channel_ids[out_channel_mask]
all_bad_channel_ids = np.concatenate((dead_channel_ids, noise_channel_ids, out_channel_ids))

max_bad_channel_fraction_to_remove = preprocessing_params.max_bad_channel_fraction_to_remove
if len(all_bad_channel_ids) >= int(max_bad_channel_fraction_to_remove * recording.get_num_channels()):
logger.info(
f"[Preprocessing] \tMore than {max_bad_channel_fraction_to_remove * 100}% bad channels ({len(all_bad_channel_ids)}). "
)
logger.info("[Preprocessing] \tSkipping further processing for this recording.")
-        return recording_hp_full
+        if preprocessing_params.remove_bad_channels:
+            logger.info("[Preprocessing] \tSkipping further processing for this recording.")
+            return recording_filt_full

if preprocessing_params.remove_out_channels:
logger.info(f"[Preprocessing] \tRemoving {len(out_channel_ids)} out channels")
-        recording_rm_out = recording_hp_full.remove_channels(out_channel_ids)
+        recording_rm_out = recording_filt_full.remove_channels(out_channel_ids)
else:
-        recording_rm_out = recording_hp_full
+        recording_rm_out = recording_filt_full

bad_channel_ids = np.concatenate((dead_channel_ids, noise_channel_ids))

@@ -95,10 +101,17 @@ def preprocess(
recording_rm_out, **preprocessing_params.common_reference.model_dump()
)
else:
-        recording_interp = spre.interpolate_bad_channels(recording_rm_out, bad_channel_ids)
-        recording_processed = spre.highpass_spatial_filter(
-            recording_interp, **preprocessing_params.highpass_spatial_filter.model_dump()
-        )
+        # protection against short probes
+        try:
+            recording_interp = spre.interpolate_bad_channels(recording_rm_out, bad_channel_ids)
+            recording_processed = spre.highpass_spatial_filter(
+                recording_interp, **preprocessing_params.highpass_spatial_filter.model_dump()
+            )
+        except:
+            recording_processed = recording_rm_out
+            logger.warning(
+                "[Preprocessing] \tInterpolation failed. Skipping interpolation and highpass spatial filter."
+            )

if preprocessing_params.remove_bad_channels:
logger.info(
@@ -114,8 +127,19 @@ def preprocess(
)
logger.info(f"[Preprocessing] \tComputing motion correction with preset: {preset}")
motion_folder = results_folder / "motion_correction"
+    # interpolation requires float
+    recording_processed_f = spre.astype(recording_processed, "float32")
+
+    # fix for multi-segment
+    concat_motion = False
+    if recording_processed.get_num_segments() > 1:
+        recording_processed_c = si.concatenate_recordings([recording_processed_f])
+        concat_motion = True
+    else:
+        recording_processed_c = recording_processed_f

recording_corrected = spre.correct_motion(
-        recording_processed,
+        recording_processed_c,
preset=preset,
folder=motion_folder,
verbose=False,
@@ -124,6 +148,25 @@
estimate_motion_kwargs=motion_correction_kwargs.estimate_motion_kwargs.model_dump(),
interpolate_motion_kwargs=motion_correction_kwargs.interpolate_motion_kwargs.model_dump(),
)

+    # split segments back
+    if concat_motion:
+        rec_corrected_list = []
+        for segment_index in range(recording_processed.get_num_segments()):
+            num_samples = recording_processed.get_num_samples(segment_index)
+            if segment_index == 0:
+                start_frame = 0
+            else:
+                start_frame = recording_processed.get_num_samples(segment_index - 1)
+            end_frame = start_frame + num_samples
+            rec_split_corrected = recording_corrected.frame_slice(
+                start_frame=start_frame,
+                end_frame=end_frame
+            )
+            rec_corrected_list.append(rec_split_corrected)
+        # append all segments
+        recording_corrected = si.append_recordings(rec_corrected_list)

if preprocessing_params.motion_correction.strategy == "apply":
logger.info("[Preprocessing] \tApplying motion correction")
recording_processed = recording_corrected
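The motion-correction changes above cast to float32, concatenate multi-segment recordings before spre.correct_motion, and then slice the corrected result back into per-segment recordings. A standalone sketch of that concatenate/slice-back pattern using a running frame offset (toy two-segment data; the motion correction itself is elided):

    import spikeinterface.full as si

    # two-segment toy recording stands in for recording_processed
    recording, _ = si.generate_ground_truth_recording(durations=[10.0, 20.0], seed=0)

    rec_concat = si.concatenate_recordings([recording])  # mono-segment view over all segments
    rec_corrected = rec_concat  # placeholder for the motion-corrected, concatenated recording

    # slice back into the original segments, accumulating the offset segment by segment
    segments = []
    start_frame = 0
    for segment_index in range(recording.get_num_segments()):
        end_frame = start_frame + recording.get_num_samples(segment_index)
        segments.append(rec_corrected.frame_slice(start_frame=start_frame, end_frame=end_frame))
        start_frame = end_frame

    rec_multi = si.append_recordings(segments)
    assert rec_multi.get_num_segments() == recording.get_num_segments()

Accumulating start_frame from the previous end_frame keeps the offsets correct for any number of segments.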
6 changes: 5 additions & 1 deletion src/spikeinterface_pipelines/spikesorting/params.py
@@ -6,9 +6,10 @@
class SorterName(str, Enum):
kilosort25 = "kilosort2_5"
kilosort3 = "kilosort3"
kilosort4 = "kilosort4"
mountainsort5 = "mountainsort5"
-    # spykingcircus2 = "spykingcircus2"
ironclust = "ironclust"
+    # spykingcircus2 = "spykingcircus2"


class Kilosort25Model(BaseModel):
@@ -57,6 +58,9 @@ class Kilosort3Model(BaseModel):
model_config = ConfigDict(extra="forbid")
pass

+class Kilosort4Model(BaseModel):
+    model_config = ConfigDict(extra="forbid")
+    pass

class MountainSort5Model(BaseModel):
model_config = ConfigDict(extra="forbid")
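With kilosort4 added to SorterName, Kilosort 4 becomes selectable like the other sorters. A hedged sketch of running it directly through the SpikeInterface sorter API (the kilosort Python package must be installed, or the sorter can be run via SpikeInterface's container images; the toy recording is only for illustration and is far too small for a meaningful sort):

    import spikeinterface.full as si
    import spikeinterface.sorters as ss

    recording, _ = si.generate_ground_truth_recording(durations=[60.0], num_channels=32, seed=0)

    sorting_ks4 = ss.run_sorter(
        sorter_name="kilosort4",
        recording=recording,
        folder="scratch/ks4_output",
        verbose=True,
    )
    print(sorting_ks4)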
2 changes: 1 addition & 1 deletion src/spikeinterface_pipelines/spikesorting/spikesorting.py
@@ -1,4 +1,4 @@
from __future__ import annotations

import shutil
import numpy as np
from pathlib import Path
6 changes: 3 additions & 3 deletions src/spikeinterface_pipelines/visualization/visualization.py
@@ -1,4 +1,4 @@
from __future__ import annotations


from pathlib import Path
import numpy as np
@@ -25,7 +25,7 @@
def visualize(
recording: si.BaseRecording,
sorting_curated: si.BaseSorting | None = None,
-    waveform_extractor: si.WaveformExtractor | None = None,
+    waveform_extractor: si.SortingAnalyzer | None = None,
visualization_params: VisualizationParams = VisualizationParams(),
scratch_folder: Path = Path("./scratch/"),
results_folder: Path = Path("./results/visualization/"),
@@ -39,7 +39,7 @@ def visualize(
The input processed recording
sorting_curated: si.BaseSorting | None
The input curated sorting. If None, only the recording visualization will be generated.
-    waveform_extractor: si.WaveformExtractor | None
+    waveform_extractor: si.SortingAnalyzer | None
The input waveform extractor from postprocessing. If None, only the recording visualization will be generated.
visualization_params: VisualizationParams
The visualization parameters
