Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add nonrigid_fast_and_accurate #13

Merged
merged 3 commits into from
Mar 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/spikeinterface_pipelines/curation/curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def curate(
Path to the scratch folder
results_folder: Path
Path to the results folder

Returns
-------
si.BaseSorting | None
Expand Down
5 changes: 3 additions & 2 deletions src/spikeinterface_pipelines/curation/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ class CurationParams(BaseModel):
"""
Curation parameters.
"""

curation_query: str = Field(
default="isi_violations_ratio < 0.5 and amplitude_cutoff < 0.1 and presence_ratio > 0.8",
description=(
"Query to select units to keep after curation. "
"Default is 'isi_violations_ratio < 0.5 and amplitude_cutoff < 0.1 and presence_ratio > 0.8'."
)
)
),
)
5 changes: 2 additions & 3 deletions src/spikeinterface_pipelines/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ def run_pipeline(
preprocessing_params = PreprocessingParams(**preprocessing_params)
if isinstance(spikesorting_params, dict):
spikesorting_params = SpikeSortingParams(
sorter_name=spikesorting_params['sorter_name'],
sorter_kwargs=spikesorting_params['sorter_kwargs']
sorter_name=spikesorting_params["sorter_name"], sorter_kwargs=spikesorting_params["sorter_kwargs"]
)
if isinstance(postprocessing_params, dict):
postprocessing_params = PostprocessingParams(**postprocessing_params)
Expand Down Expand Up @@ -84,6 +83,7 @@ def run_pipeline(

# Spike Sorting
if run_spikesorting:
# TODO: turn off sorter motion correction if motion correction is already done
sorting = spikesort(
recording=recording_preprocessed,
scratch_folder=scratch_folder,
Expand Down Expand Up @@ -126,7 +126,6 @@ def run_pipeline(
waveform_extractor = None
sorting_curated = None


# Visualization
visualization_output = None
if run_visualization:
Expand Down
12 changes: 10 additions & 2 deletions src/spikeinterface_pipelines/postprocessing/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,16 @@ class QMParams(BaseModel):
class QualityMetrics(BaseModel):
qm_params: QMParams = Field(default=QMParams(), description="Quality metric parameters.")
metric_names: List[str] = Field(
default=["presence_ratio", "snr", "isi_violation", "rp_violation", "sliding_rp_violation", "amplitude_cutoff", "amplitude_median"],
description="List of metric names to compute. If None, all available metrics are computed."
default=[
"presence_ratio",
"snr",
"isi_violation",
"rp_violation",
"sliding_rp_violation",
"amplitude_cutoff",
"amplitude_median",
],
description="List of metric names to compute. If None, all available metrics are computed.",
)
n_jobs: int = Field(default=1, description="Number of jobs.")

Expand Down
81 changes: 64 additions & 17 deletions src/spikeinterface_pipelines/preprocessing/params.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pydantic import BaseModel, Field
from typing import Optional, Union, List, Literal
from enum import Enum
import numpy as np


class PreprocessingStrategy(str, Enum):
Expand Down Expand Up @@ -52,25 +53,35 @@ class MCDetectKwargs(BaseModel):

class MCLocalizeCenterOfMass(BaseModel):
radius_um: float = Field(default=75.0, description="Radius in um for channel sparsity.")
feature: str = Field(default="ptp", description="'ptp', 'mean', 'energy' or 'peak_voltage'. Feature to consider for computation")
feature: str = Field(
default="ptp", description="'ptp', 'mean', 'energy' or 'peak_voltage'. Feature to consider for computation"
)


class MCLocalizeMonopolarTriangulation(BaseModel):
radius_um: float = Field(default=75.0, description="Radius in um for channel sparsity.")
max_distance_um: float = Field(default=150.0, description="Boundary for distance estimation.")
optimizer: str = Field(default="minimize_with_log_penality", description="")
enforce_decrease: bool = Field(default=True, description="Enforce spatial decreasingness for PTP vectors")
feature: str = Field(default="ptp", description="'ptp', 'energy' or 'peak_voltage'. The available features to consider for estimating the position via monopolar triangulation are peak-to-peak amplitudes (ptp, default), energy ('energy', as L2 norm) or voltages at the center of the waveform (peak_voltage)")
feature: str = Field(
default="ptp",
description="'ptp', 'energy' or 'peak_voltage'. The available features to consider for estimating the position via monopolar triangulation are peak-to-peak amplitudes (ptp, default), energy ('energy', as L2 norm) or voltages at the center of the waveform (peak_voltage)",
)


class MCLocalizeGridConvolution(BaseModel):
radius_um: float = Field(default=40.0, description="Radius in um for channel sparsity.")
upsampling_um: float = Field(default=5.0, description="Upsampling resolution for the grid of templates.")
sigma_um: List[float] = Field(default=[5.0, 25.0, 5], description="Spatial decays of the fake templates.")
weight_method: dict = Field(
default={"mode": "gaussian_2d", "sigma_list_um": np.linspace(5, 25, 5)}, description="Weighting strategy."
)
sigma_ms: float = Field(default=0.25, description="The temporal decay of the fake templates.")
margin_um: float = Field(default=30.0, description="The margin for the grid of fake templates.")
percentile: float = Field(default=10.0, description="The percentage in [0, 100] of the best scalar products kept to estimate the position.")
sparsity_threshold: float = Field(default=0.01, description="The sparsity threshold (in [0, 1]) below which weights should be considered as 0.")
percentile: float = Field(
default=10.0,
description="The percentage in [0, 100] of the best scalar products kept to estimate the position.",
)
prototype: Optional[list] = Field(default=None, description="Fake waveforms for the templates.")


class MCEstimateMotionDecentralized(BaseModel):
Expand Down Expand Up @@ -117,45 +128,81 @@ class MCEstimateMotionIterativeTemplate(BaseModel):


class MCInterpolateMotionKwargs(BaseModel):
direction: int = Field(default=1, description="0 | 1 | 2. Dimension along which channel_locations are shifted (0 - x, 1 - y, 2 - z).")
border_mode: str = Field(default="remove_channels", description="'remove_channels' | 'force_extrapolate' | 'force_zeros'. Control how channels are handled on border.")
spatial_interpolation_method: str = Field(default="idw", description="The spatial interpolation method used to interpolate the channel locations.")
direction: int = Field(
default=1, description="0 | 1 | 2. Dimension along which channel_locations are shifted (0 - x, 1 - y, 2 - z)."
)
border_mode: str = Field(
default="remove_channels",
description="'remove_channels' | 'force_extrapolate' | 'force_zeros'. Control how channels are handled on border.",
)
spatial_interpolation_method: str = Field(
default="idw", description="The spatial interpolation method used to interpolate the channel locations."
)
sigma_um: float = Field(default=20.0, description="Used in the 'kriging' formula")
p: int = Field(default=1, description="Used in the 'kriging' formula")
num_closest: int = Field(default=3, description="Number of closest channels used by 'idw' method for interpolation.")
num_closest: int = Field(
default=3, description="Number of closest channels used by 'idw' method for interpolation."
)


class MCNonrigidAccurate(BaseModel):
detect_kwargs: MCDetectKwargs = Field(default=MCDetectKwargs(), description="")
localize_peaks_kwargs: MCLocalizeMonopolarTriangulation = Field(default=MCLocalizeMonopolarTriangulation(), description="")
estimate_motion_kwargs: MCEstimateMotionDecentralized = Field(default=MCEstimateMotionDecentralized(), description="")
localize_peaks_kwargs: MCLocalizeMonopolarTriangulation = Field(
default=MCLocalizeMonopolarTriangulation(), description=""
)
estimate_motion_kwargs: MCEstimateMotionDecentralized = Field(
default=MCEstimateMotionDecentralized(), description=""
)
interpolate_motion_kwargs: MCInterpolateMotionKwargs = Field(default=MCInterpolateMotionKwargs(), description="")


class MCNonrigidFastAndAccurate(BaseModel):
detect_kwargs: MCDetectKwargs = Field(default=MCDetectKwargs(), description="")
localize_peaks_kwargs: MCLocalizeGridConvolution = Field(default=MCLocalizeGridConvolution(), description="")
estimate_motion_kwargs: MCEstimateMotionDecentralized = Field(
default=MCEstimateMotionDecentralized(), description=""
)
interpolate_motion_kwargs: MCInterpolateMotionKwargs = Field(default=MCInterpolateMotionKwargs(), description="")


class MCRigidFast(BaseModel):
detect_kwargs: MCDetectKwargs = Field(default=MCDetectKwargs(), description="")
localize_peaks_kwargs: MCLocalizeCenterOfMass = Field(default=MCLocalizeCenterOfMass(), description="")
estimate_motion_kwargs: MCEstimateMotionDecentralized = Field(default=MCEstimateMotionDecentralized(bin_duration_s=10.0, rigid=True), description="")
estimate_motion_kwargs: MCEstimateMotionDecentralized = Field(
default=MCEstimateMotionDecentralized(bin_duration_s=10.0, rigid=True), description=""
)
interpolate_motion_kwargs: MCInterpolateMotionKwargs = Field(default=MCInterpolateMotionKwargs(), description="")


class MCKilosortLike(BaseModel):
detect_kwargs: MCDetectKwargs = Field(default=MCDetectKwargs(), description="")
localize_peaks_kwargs: MCLocalizeGridConvolution = Field(default=MCLocalizeGridConvolution(), description="")
estimate_motion_kwargs: MCEstimateMotionIterativeTemplate = Field(default=MCEstimateMotionIterativeTemplate(), description="")
interpolate_motion_kwargs: MCInterpolateMotionKwargs = Field(default=MCInterpolateMotionKwargs(border_mode="force_extrapolate", spatial_interpolation_method="kriging"), description="")
estimate_motion_kwargs: MCEstimateMotionIterativeTemplate = Field(
default=MCEstimateMotionIterativeTemplate(), description=""
)
interpolate_motion_kwargs: MCInterpolateMotionKwargs = Field(
default=MCInterpolateMotionKwargs(border_mode="force_extrapolate", spatial_interpolation_method="kriging"),
description="",
)


class MCPreset(str, Enum):
nonrigid_accurate = "nonrigid_accurate"
nonrigid_fast_and_accurate = "nonrigid_fast_and_accurate"
rigid_fast = "rigid_fast"
kilosort_like = "kilosort_like"


class MotionCorrection(BaseModel):
strategy: Literal["skip", "compute", "apply"] = Field(default="compute", description="What strategy to use for motion correction")
preset: MCPreset = Field(default=MCPreset.nonrigid_accurate.value, description="Preset for motion correction")
motion_kwargs: Union[MCNonrigidAccurate, MCRigidFast, MCKilosortLike] = Field(default=MCNonrigidAccurate(), description="Motion correction parameters")
strategy: Literal["skip", "compute", "apply"] = Field(
default="compute", description="What strategy to use for motion correction"
)
preset: MCPreset = Field(
default=MCPreset.nonrigid_fast_and_accurate.value, description="Preset for motion correction"
)
motion_kwargs: Union[MCNonrigidAccurate, MCNonrigidFastAndAccurate, MCRigidFast, MCKilosortLike] = Field(
default=MCNonrigidFastAndAccurate(), description="Motion correction parameters"
)


# Preprocessing params ---------------------------------------------------------------
Expand Down
18 changes: 11 additions & 7 deletions src/spikeinterface_pipelines/preprocessing/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,18 @@
import spikeinterface.preprocessing as spre

from ..logger import logger
from .params import PreprocessingParams, MCNonrigidAccurate, MCRigidFast, MCKilosortLike
from .params import PreprocessingParams, MCNonrigidAccurate, MCNonrigidFastAndAccurate, MCRigidFast, MCKilosortLike


warnings.filterwarnings("ignore")

_motion_correction_presets_to_params = dict(
nonrigid_accurate=MCNonrigidAccurate,
nonrigid_fast_and_accurate=MCNonrigidFastAndAccurate,
rigid_fast=MCKilosortLike,
kilosort_like=MCKilosortLike,
)


def preprocess(
recording: si.BaseRecording,
Expand Down Expand Up @@ -102,12 +109,9 @@ def preprocess(
# Motion correction
if preprocessing_params.motion_correction.strategy != "skip":
preset = preprocessing_params.motion_correction.preset
if preset == "nonrigid_accurate":
motion_correction_kwargs = MCNonrigidAccurate(**preprocessing_params.motion_correction.motion_kwargs.model_dump())
elif preset == "rigid_fast":
motion_correction_kwargs = MCRigidFast(**preprocessing_params.motion_correction.motion_kwargs.model_dump())
elif preset == "kilosort_like":
motion_correction_kwargs = MCKilosortLike(**preprocessing_params.motion_correction.motion_kwargs.model_dump())
motion_correction_kwargs = _motion_correction_presets_to_params[preset](
**preprocessing_params.motion_correction.motion_kwargs.model_dump()
)
logger.info(f"[Preprocessing] \tComputing motion correction with preset: {preset}")
motion_folder = results_folder / "motion_correction"
recording_corrected = spre.correct_motion(
Expand Down
34 changes: 19 additions & 15 deletions src/spikeinterface_pipelines/spikesorting/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class SorterName(str, Enum):


class Kilosort25Model(BaseModel):
model_config = ConfigDict(extra='forbid')
model_config = ConfigDict(extra="forbid")
detect_threshold: float = Field(default=6, description="Threshold for spike detection")
projection_threshold: List[float] = Field(default=[10, 4], description="Threshold on projections")
preclust_threshold: float = Field(
Expand All @@ -29,7 +29,10 @@ class Kilosort25Model(BaseModel):
sig: float = Field(default=20, description="spatial smoothness constant for registration")
freq_min: float = Field(default=150, description="High-pass filter cutoff frequency")
sigmaMask: float = Field(default=30, description="Spatial constant in um for computing residual variance of spike")
lam: float = Field(default=10.0, description="The importance of the amplitude penalty (like in Kilosort1: 0 means not used, 10 is average, 50 is a lot)")
lam: float = Field(
default=10.0,
description="The importance of the amplitude penalty (like in Kilosort1: 0 means not used, 10 is average, 50 is a lot)",
)
nPCs: int = Field(default=3, description="Number of PCA dimensions")
ntbuff: int = Field(default=64, description="Samples of symmetrical buffer for whitening and spike detection")
nfilt_factor: int = Field(default=4, description="Max number of clusters per good channel (even temporary ones) 4")
Expand All @@ -50,22 +53,18 @@ class Kilosort25Model(BaseModel):


class Kilosort3Model(BaseModel):
model_config = ConfigDict(extra='forbid')
model_config = ConfigDict(extra="forbid")
pass


class IronClustModel(BaseModel):
model_config = ConfigDict(extra='forbid')
model_config = ConfigDict(extra="forbid")
pass


class MountainSort5Model(BaseModel):
model_config = ConfigDict(extra='forbid')
scheme: str = Field(
default='2',
description="Sorting scheme",
json_schema_extra={'options': ["1", "2", "3"]}
)
model_config = ConfigDict(extra="forbid")
scheme: str = Field(default="2", description="Sorting scheme", json_schema_extra={"options": ["1", "2", "3"]})
detect_threshold: float = Field(default=5.5, description="Threshold for spike detection")
detect_sign: int = Field(default=-1, description="Sign of the peak")
detect_time_radius_msec: float = Field(default=0.5, description="Time radius in milliseconds")
Expand All @@ -77,9 +76,13 @@ class MountainSort5Model(BaseModel):
scheme1_detect_channel_radius: int = Field(default=150, description="Scheme 1 detect channel radius")
scheme2_phase1_detect_channel_radius: int = Field(default=200, description="Scheme 2 phase 1 detect channel radius")
scheme2_detect_channel_radius: int = Field(default=50, description="Scheme 2 detect channel radius")
scheme2_max_num_snippets_per_training_batch: int = Field(default=200, description="Scheme 2 max number of snippets per training batch")
scheme2_max_num_snippets_per_training_batch: int = Field(
default=200, description="Scheme 2 max number of snippets per training batch"
)
scheme2_training_duration_sec: int = Field(default=300, description="Scheme 2 training duration in seconds")
scheme2_training_recording_sampling_mode: str = Field(default='uniform', description="Scheme 2 training recording sampling mode")
scheme2_training_recording_sampling_mode: str = Field(
default="uniform", description="Scheme 2 training recording sampling mode"
)
scheme3_block_duration_sec: int = Field(default=1800, description="Scheme 3 block duration in seconds")
freq_min: int = Field(default=300, description="High-pass filter cutoff frequency")
freq_max: int = Field(default=6000, description="Low-pass filter cutoff frequency")
Expand All @@ -90,7 +93,8 @@ class MountainSort5Model(BaseModel):
class SpikeSortingParams(BaseModel):
sorter_name: SorterName = Field(description="Name of the sorter to use.")
sorter_kwargs: Union[Kilosort25Model, Kilosort3Model, IronClustModel, MountainSort5Model] = Field(
description="Sorter specific kwargs.",
union_mode='left_to_right'
description="Sorter specific kwargs.", union_mode="left_to_right"
)
spikesort_by_group: bool = Field(
default=False, description="If True, spike sorting is run for each group separately."
)
spikesort_by_group: bool = Field(default=False, description="If True, spike sorting is run for each group separately.")
2 changes: 1 addition & 1 deletion src/spikeinterface_pipelines/visualization/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .visualization import visualize
from .params import VisualizationParams
from .params import VisualizationParams
Loading
Loading