diff --git a/src/spikeinterface_pipelines/pipeline.py b/src/spikeinterface_pipelines/pipeline.py index 957d43c..a080d4a 100644 --- a/src/spikeinterface_pipelines/pipeline.py +++ b/src/spikeinterface_pipelines/pipeline.py @@ -54,8 +54,7 @@ def run_pipeline( preprocessing_params = PreprocessingParams(**preprocessing_params) if isinstance(spikesorting_params, dict): spikesorting_params = SpikeSortingParams( - sorter_name=spikesorting_params['sorter_name'], - sorter_kwargs=spikesorting_params['sorter_kwargs'] + sorter_name=spikesorting_params["sorter_name"], sorter_kwargs=spikesorting_params["sorter_kwargs"] ) if isinstance(postprocessing_params, dict): postprocessing_params = PostprocessingParams(**postprocessing_params) diff --git a/src/spikeinterface_pipelines/spikesorting/params.py b/src/spikeinterface_pipelines/spikesorting/params.py index 9e3e792..239fe35 100644 --- a/src/spikeinterface_pipelines/spikesorting/params.py +++ b/src/spikeinterface_pipelines/spikesorting/params.py @@ -11,7 +11,7 @@ class SorterName(str, Enum): class Kilosort25Model(BaseModel): - model_config = ConfigDict(extra='forbid') + model_config = ConfigDict(extra="forbid") detect_threshold: float = Field(default=6, description="Threshold for spike detection") projection_threshold: List[float] = Field(default=[10, 4], description="Threshold on projections") preclust_threshold: float = Field( @@ -53,22 +53,18 @@ class Kilosort25Model(BaseModel): class Kilosort3Model(BaseModel): - model_config = ConfigDict(extra='forbid') + model_config = ConfigDict(extra="forbid") pass class IronClustModel(BaseModel): - model_config = ConfigDict(extra='forbid') + model_config = ConfigDict(extra="forbid") pass class MountainSort5Model(BaseModel): - model_config = ConfigDict(extra='forbid') - scheme: str = Field( - default='2', - description="Sorting scheme", - json_schema_extra={'options': ["1", "2", "3"]} - ) + model_config = ConfigDict(extra="forbid") + scheme: str = Field(default="2", description="Sorting scheme", json_schema_extra={"options": ["1", "2", "3"]}) detect_threshold: float = Field(default=5.5, description="Threshold for spike detection") detect_sign: int = Field(default=-1, description="Sign of the peak") detect_time_radius_msec: float = Field(default=0.5, description="Time radius in milliseconds") @@ -80,9 +76,13 @@ class MountainSort5Model(BaseModel): scheme1_detect_channel_radius: int = Field(default=150, description="Scheme 1 detect channel radius") scheme2_phase1_detect_channel_radius: int = Field(default=200, description="Scheme 2 phase 1 detect channel radius") scheme2_detect_channel_radius: int = Field(default=50, description="Scheme 2 detect channel radius") - scheme2_max_num_snippets_per_training_batch: int = Field(default=200, description="Scheme 2 max number of snippets per training batch") + scheme2_max_num_snippets_per_training_batch: int = Field( + default=200, description="Scheme 2 max number of snippets per training batch" + ) scheme2_training_duration_sec: int = Field(default=300, description="Scheme 2 training duration in seconds") - scheme2_training_recording_sampling_mode: str = Field(default='uniform', description="Scheme 2 training recording sampling mode") + scheme2_training_recording_sampling_mode: str = Field( + default="uniform", description="Scheme 2 training recording sampling mode" + ) scheme3_block_duration_sec: int = Field(default=1800, description="Scheme 3 block duration in seconds") freq_min: int = Field(default=300, description="High-pass filter cutoff frequency") freq_max: int = Field(default=6000, description="Low-pass filter cutoff frequency") @@ -93,7 +93,8 @@ class MountainSort5Model(BaseModel): class SpikeSortingParams(BaseModel): sorter_name: SorterName = Field(description="Name of the sorter to use.") sorter_kwargs: Union[Kilosort25Model, Kilosort3Model, IronClustModel, MountainSort5Model] = Field( - description="Sorter specific kwargs.", - union_mode='left_to_right' + description="Sorter specific kwargs.", union_mode="left_to_right" + ) + spikesort_by_group: bool = Field( + default=False, description="If True, spike sorting is run for each group separately." ) - spikesort_by_group: bool = Field(default=False, description="If True, spike sorting is run for each group separately.")