diff --git a/src/spikeinterface_pipelines/curation/curation.py b/src/spikeinterface_pipelines/curation/curation.py
index c02c775..5c60f06 100644
--- a/src/spikeinterface_pipelines/curation/curation.py
+++ b/src/spikeinterface_pipelines/curation/curation.py
@@ -31,7 +31,7 @@ def curate(
         Path to the scratch folder
     results_folder: Path
         Path to the results folder
-    
+
     Returns
     -------
     si.BaseSorting | None
diff --git a/src/spikeinterface_pipelines/curation/params.py b/src/spikeinterface_pipelines/curation/params.py
index 2d17206..a66e116 100644
--- a/src/spikeinterface_pipelines/curation/params.py
+++ b/src/spikeinterface_pipelines/curation/params.py
@@ -5,10 +5,11 @@ class CurationParams(BaseModel):
     """
     Curation parameters.
     """
+
     curation_query: str = Field(
         default="isi_violations_ratio < 0.5 and amplitude_cutoff < 0.1 and presence_ratio > 0.8",
         description=(
             "Query to select units to keep after curation. "
             "Default is 'isi_violations_ratio < 0.5 and amplitude_cutoff < 0.1 and presence_ratio > 0.8'."
-        )
-    )
\ No newline at end of file
+        ),
+    )
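# Illustrative sketch (not part of the patch): how a pandas-style query like the new
# `curation_query` default can select passing units from a quality-metrics table.
# The toy DataFrame below is made up; in the pipeline the metrics are computed during
# postprocessing, and the column names must match the metric names used there.
import pandas as pd

metrics = pd.DataFrame(
    {
        "isi_violations_ratio": [0.1, 0.9, 0.2],
        "amplitude_cutoff": [0.05, 0.02, 0.3],
        "presence_ratio": [0.95, 0.99, 0.5],
    },
    index=[101, 102, 103],  # unit ids
)

curation_query = "isi_violations_ratio < 0.5 and amplitude_cutoff < 0.1 and presence_ratio > 0.8"
keep_unit_ids = metrics.query(curation_query).index.to_list()
print(keep_unit_ids)  # -> [101]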
diff --git a/src/spikeinterface_pipelines/pipeline.py b/src/spikeinterface_pipelines/pipeline.py
index 0f492b0..a080d4a 100644
--- a/src/spikeinterface_pipelines/pipeline.py
+++ b/src/spikeinterface_pipelines/pipeline.py
@@ -54,8 +54,7 @@ def run_pipeline(
         preprocessing_params = PreprocessingParams(**preprocessing_params)
     if isinstance(spikesorting_params, dict):
         spikesorting_params = SpikeSortingParams(
-            sorter_name=spikesorting_params['sorter_name'],
-            sorter_kwargs=spikesorting_params['sorter_kwargs']
+            sorter_name=spikesorting_params["sorter_name"], sorter_kwargs=spikesorting_params["sorter_kwargs"]
         )
     if isinstance(postprocessing_params, dict):
         postprocessing_params = PostprocessingParams(**postprocessing_params)
@@ -84,6 +83,7 @@ def run_pipeline(
 
     # Spike Sorting
     if run_spikesorting:
+        # TODO: turn off sorter motion correction if motion correction is already done
         sorting = spikesort(
             recording=recording_preprocessed,
             scratch_folder=scratch_folder,
@@ -126,7 +126,6 @@ def run_pipeline(
         waveform_extractor = None
         sorting_curated = None
 
-
     # Visualization
     visualization_output = None
     if run_visualization:
diff --git a/src/spikeinterface_pipelines/postprocessing/params.py b/src/spikeinterface_pipelines/postprocessing/params.py
index ecb970c..93c0911 100644
--- a/src/spikeinterface_pipelines/postprocessing/params.py
+++ b/src/spikeinterface_pipelines/postprocessing/params.py
@@ -97,8 +97,16 @@ class QMParams(BaseModel):
 
 class QualityMetrics(BaseModel):
     qm_params: QMParams = Field(default=QMParams(), description="Quality metric parameters.")
     metric_names: List[str] = Field(
-        default=["presence_ratio", "snr", "isi_violation", "rp_violation", "sliding_rp_violation", "amplitude_cutoff", "amplitude_median"],
-        description="List of metric names to compute. If None, all available metrics are computed."
+        default=[
+            "presence_ratio",
+            "snr",
+            "isi_violation",
+            "rp_violation",
+            "sliding_rp_violation",
+            "amplitude_cutoff",
+            "amplitude_median",
+        ],
+        description="List of metric names to compute. If None, all available metrics are computed.",
     )
     n_jobs: int = Field(default=1, description="Number of jobs.")
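# Illustrative sketch (not part of the patch): overriding the default quality-metric
# list shown above. The import path assumes this repository's package layout.
from spikeinterface_pipelines.postprocessing.params import QualityMetrics

qm = QualityMetrics(
    metric_names=["presence_ratio", "snr", "amplitude_cutoff"],  # subset of the defaults
    n_jobs=4,
)
print(qm.metric_names)
print(qm.qm_params.model_dump())  # nested QMParams defaults as a plain dict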
diff --git a/src/spikeinterface_pipelines/preprocessing/params.py b/src/spikeinterface_pipelines/preprocessing/params.py
index f7f4876..7edeb39 100644
--- a/src/spikeinterface_pipelines/preprocessing/params.py
+++ b/src/spikeinterface_pipelines/preprocessing/params.py
@@ -1,6 +1,7 @@
 from pydantic import BaseModel, Field
 from typing import Optional, Union, List, Literal
 from enum import Enum
+import numpy as np
 
 
 class PreprocessingStrategy(str, Enum):
@@ -52,7 +53,9 @@ class MCDetectKwargs(BaseModel):
 
 class MCLocalizeCenterOfMass(BaseModel):
     radius_um: float = Field(default=75.0, description="Radius in um for channel sparsity.")
-    feature: str = Field(default="ptp", description="'ptp', 'mean', 'energy' or 'peak_voltage'. Feature to consider for computation")
+    feature: str = Field(
+        default="ptp", description="'ptp', 'mean', 'energy' or 'peak_voltage'. Feature to consider for computation"
+    )
 
 
 class MCLocalizeMonopolarTriangulation(BaseModel):
@@ -60,17 +63,25 @@ class MCLocalizeMonopolarTriangulation(BaseModel):
     max_distance_um: float = Field(default=150.0, description="Boundary for distance estimation.")
     optimizer: str = Field(default="minimize_with_log_penality", description="")
     enforce_decrease: bool = Field(default=True, description="Enforce spatial decreasingness for PTP vectors")
-    feature: str = Field(default="ptp", description="'ptp', 'energy' or 'peak_voltage'. The available features to consider for estimating the position via monopolar triangulation are peak-to-peak amplitudes (ptp, default), energy ('energy', as L2 norm) or voltages at the center of the waveform (peak_voltage)")
+    feature: str = Field(
+        default="ptp",
+        description="'ptp', 'energy' or 'peak_voltage'. The available features to consider for estimating the position via monopolar triangulation are peak-to-peak amplitudes (ptp, default), energy ('energy', as L2 norm) or voltages at the center of the waveform (peak_voltage)",
+    )
 
 
 class MCLocalizeGridConvolution(BaseModel):
     radius_um: float = Field(default=40.0, description="Radius in um for channel sparsity.")
     upsampling_um: float = Field(default=5.0, description="Upsampling resolution for the grid of templates.")
-    sigma_um: List[float] = Field(default=[5.0, 25.0, 5], description="Spatial decays of the fake templates.")
+    weight_method: dict = Field(
+        default={"mode": "gaussian_2d", "sigma_list_um": np.linspace(5, 25, 5)}, description="Weighting strategy."
+    )
     sigma_ms: float = Field(default=0.25, description="The temporal decay of the fake templates.")
     margin_um: float = Field(default=30.0, description="The margin for the grid of fake templates.")
-    percentile: float = Field(default=10.0, description="The percentage in [0, 100] of the best scalar products kept to estimate the position.")
-    sparsity_threshold: float = Field(default=0.01, description="The sparsity threshold (in [0, 1]) below which weights should be considered as 0.")
+    percentile: float = Field(
+        default=10.0,
+        description="The percentage in [0, 100] of the best scalar products kept to estimate the position.",
+    )
+    prototype: Optional[list] = Field(default=None, description="Fake waveforms for the templates.")
 
 
 class MCEstimateMotionDecentralized(BaseModel):
@@ -117,45 +128,81 @@ class MCEstimateMotionIterativeTemplate(BaseModel):
 
 
 class MCInterpolateMotionKwargs(BaseModel):
-    direction: int = Field(default=1, description="0 | 1 | 2. Dimension along which channel_locations are shifted (0 - x, 1 - y, 2 - z).")
-    border_mode: str = Field(default="remove_channels", description="'remove_channels' | 'force_extrapolate' | 'force_zeros'. Control how channels are handled on border.")
-    spatial_interpolation_method: str = Field(default="idw", description="The spatial interpolation method used to interpolate the channel locations.")
+    direction: int = Field(
+        default=1, description="0 | 1 | 2. Dimension along which channel_locations are shifted (0 - x, 1 - y, 2 - z)."
+    )
+    border_mode: str = Field(
+        default="remove_channels",
+        description="'remove_channels' | 'force_extrapolate' | 'force_zeros'. Control how channels are handled on border.",
+    )
+    spatial_interpolation_method: str = Field(
+        default="idw", description="The spatial interpolation method used to interpolate the channel locations."
+    )
     sigma_um: float = Field(default=20.0, description="Used in the 'kriging' formula")
     p: int = Field(default=1, description="Used in the 'kriging' formula")
-    num_closest: int = Field(default=3, description="Number of closest channels used by 'idw' method for interpolation.")
+    num_closest: int = Field(
+        default=3, description="Number of closest channels used by 'idw' method for interpolation."
+    )
 
 
 class MCNonrigidAccurate(BaseModel):
     detect_kwargs: MCDetectKwargs = Field(default=MCDetectKwargs(), description="")
-    localize_peaks_kwargs: MCLocalizeMonopolarTriangulation = Field(default=MCLocalizeMonopolarTriangulation(), description="")
-    estimate_motion_kwargs: MCEstimateMotionDecentralized = Field(default=MCEstimateMotionDecentralized(), description="")
+    localize_peaks_kwargs: MCLocalizeMonopolarTriangulation = Field(
+        default=MCLocalizeMonopolarTriangulation(), description=""
+    )
+    estimate_motion_kwargs: MCEstimateMotionDecentralized = Field(
+        default=MCEstimateMotionDecentralized(), description=""
+    )
+    interpolate_motion_kwargs: MCInterpolateMotionKwargs = Field(default=MCInterpolateMotionKwargs(), description="")
+
+
+class MCNonrigidFastAndAccurate(BaseModel):
+    detect_kwargs: MCDetectKwargs = Field(default=MCDetectKwargs(), description="")
+    localize_peaks_kwargs: MCLocalizeGridConvolution = Field(default=MCLocalizeGridConvolution(), description="")
+    estimate_motion_kwargs: MCEstimateMotionDecentralized = Field(
+        default=MCEstimateMotionDecentralized(), description=""
+    )
     interpolate_motion_kwargs: MCInterpolateMotionKwargs = Field(default=MCInterpolateMotionKwargs(), description="")
 
 
 class MCRigidFast(BaseModel):
     detect_kwargs: MCDetectKwargs = Field(default=MCDetectKwargs(), description="")
     localize_peaks_kwargs: MCLocalizeCenterOfMass = Field(default=MCLocalizeCenterOfMass(), description="")
-    estimate_motion_kwargs: MCEstimateMotionDecentralized = Field(default=MCEstimateMotionDecentralized(bin_duration_s=10.0, rigid=True), description="")
+    estimate_motion_kwargs: MCEstimateMotionDecentralized = Field(
+        default=MCEstimateMotionDecentralized(bin_duration_s=10.0, rigid=True), description=""
+    )
     interpolate_motion_kwargs: MCInterpolateMotionKwargs = Field(default=MCInterpolateMotionKwargs(), description="")
 
 
 class MCKilosortLike(BaseModel):
     detect_kwargs: MCDetectKwargs = Field(default=MCDetectKwargs(), description="")
     localize_peaks_kwargs: MCLocalizeGridConvolution = Field(default=MCLocalizeGridConvolution(), description="")
-    estimate_motion_kwargs: MCEstimateMotionIterativeTemplate = Field(default=MCEstimateMotionIterativeTemplate(), description="")
-    interpolate_motion_kwargs: MCInterpolateMotionKwargs = Field(default=MCInterpolateMotionKwargs(border_mode="force_extrapolate", spatial_interpolation_method="kriging"), description="")
+    estimate_motion_kwargs: MCEstimateMotionIterativeTemplate = Field(
+        default=MCEstimateMotionIterativeTemplate(), description=""
+    )
+    interpolate_motion_kwargs: MCInterpolateMotionKwargs = Field(
+        default=MCInterpolateMotionKwargs(border_mode="force_extrapolate", spatial_interpolation_method="kriging"),
+        description="",
+    )
 
 
 class MCPreset(str, Enum):
     nonrigid_accurate = "nonrigid_accurate"
+    nonrigid_fast_and_accurate = "nonrigid_fast_and_accurate"
     rigid_fast = "rigid_fast"
     kilosort_like = "kilosort_like"
 
 
 class MotionCorrection(BaseModel):
-    strategy: Literal["skip", "compute", "apply"] = Field(default="compute", description="What strategy to use for motion correction")
-    preset: MCPreset = Field(default=MCPreset.nonrigid_accurate.value, description="Preset for motion correction")
-    motion_kwargs: Union[MCNonrigidAccurate, MCRigidFast, MCKilosortLike] = Field(default=MCNonrigidAccurate(), description="Motion correction parameters")
+    strategy: Literal["skip", "compute", "apply"] = Field(
+        default="compute", description="What strategy to use for motion correction"
+    )
+    preset: MCPreset = Field(
+        default=MCPreset.nonrigid_fast_and_accurate.value, description="Preset for motion correction"
+    )
+    motion_kwargs: Union[MCNonrigidAccurate, MCNonrigidFastAndAccurate, MCRigidFast, MCKilosortLike] = Field(
+        default=MCNonrigidFastAndAccurate(), description="Motion correction parameters"
+    )
 
 
 # Preprocessing params ---------------------------------------------------------------
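# Illustrative sketch (not part of the patch): the grid-convolution localization now
# carries a numpy array inside its `weight_method` default, so a plain model_dump()
# is not directly JSON-serializable. The override values below are arbitrary; the
# class and import path come from the file above.
import json

import numpy as np

from spikeinterface_pipelines.preprocessing.params import MCLocalizeGridConvolution

loc = MCLocalizeGridConvolution(
    weight_method={"mode": "gaussian_2d", "sigma_list_um": np.linspace(10.0, 30.0, 3)}
)
dumped = loc.model_dump()
dumped["weight_method"]["sigma_list_um"] = np.asarray(dumped["weight_method"]["sigma_list_um"]).tolist()
print(json.dumps(dumped, indent=2))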
diff --git a/src/spikeinterface_pipelines/preprocessing/preprocessing.py b/src/spikeinterface_pipelines/preprocessing/preprocessing.py
index e11bb9c..4bfe791 100644
--- a/src/spikeinterface_pipelines/preprocessing/preprocessing.py
+++ b/src/spikeinterface_pipelines/preprocessing/preprocessing.py
@@ -6,11 +6,18 @@ import spikeinterface.preprocessing as spre
 
 from ..logger import logger
-from .params import PreprocessingParams, MCNonrigidAccurate, MCRigidFast, MCKilosortLike
+from .params import PreprocessingParams, MCNonrigidAccurate, MCNonrigidFastAndAccurate, MCRigidFast, MCKilosortLike
 
 warnings.filterwarnings("ignore")
 
+_motion_correction_presets_to_params = dict(
+    nonrigid_accurate=MCNonrigidAccurate,
+    nonrigid_fast_and_accurate=MCNonrigidFastAndAccurate,
+    rigid_fast=MCRigidFast,
+    kilosort_like=MCKilosortLike,
+)
+
 
 def preprocess(
     recording: si.BaseRecording,
@@ -102,12 +109,9 @@ def preprocess(
     # Motion correction
     if preprocessing_params.motion_correction.strategy != "skip":
         preset = preprocessing_params.motion_correction.preset
-        if preset == "nonrigid_accurate":
-            motion_correction_kwargs = MCNonrigidAccurate(**preprocessing_params.motion_correction.motion_kwargs.model_dump())
-        elif preset == "rigid_fast":
-            motion_correction_kwargs = MCRigidFast(**preprocessing_params.motion_correction.motion_kwargs.model_dump())
-        elif preset == "kilosort_like":
-            motion_correction_kwargs = MCKilosortLike(**preprocessing_params.motion_correction.motion_kwargs.model_dump())
+        motion_correction_kwargs = _motion_correction_presets_to_params[preset](
+            **preprocessing_params.motion_correction.motion_kwargs.model_dump()
+        )
         logger.info(f"[Preprocessing] \tComputing motion correction with preset: {preset}")
         motion_folder = results_folder / "motion_correction"
         recording_corrected = spre.correct_motion(
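# Illustrative sketch (not part of the patch): how the preset-to-params mapping above
# resolves the default MotionCorrection configuration, mirroring what preprocess() now
# does before calling spre.correct_motion(). Import path assumes this repo's layout.
from spikeinterface_pipelines.preprocessing.params import (
    MCKilosortLike,
    MCNonrigidAccurate,
    MCNonrigidFastAndAccurate,
    MCRigidFast,
    MotionCorrection,
)

presets_to_params = dict(
    nonrigid_accurate=MCNonrigidAccurate,
    nonrigid_fast_and_accurate=MCNonrigidFastAndAccurate,
    rigid_fast=MCRigidFast,
    kilosort_like=MCKilosortLike,
)

mc = MotionCorrection()  # defaults to the new "nonrigid_fast_and_accurate" preset
kwargs_model = presets_to_params[mc.preset](**mc.motion_kwargs.model_dump())
print(type(kwargs_model).__name__)  # -> MCNonrigidFastAndAccurate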
diff --git a/src/spikeinterface_pipelines/spikesorting/params.py b/src/spikeinterface_pipelines/spikesorting/params.py
index 2ad7295..239fe35 100644
--- a/src/spikeinterface_pipelines/spikesorting/params.py
+++ b/src/spikeinterface_pipelines/spikesorting/params.py
@@ -11,7 +11,7 @@ class SorterName(str, Enum):
 
 
 class Kilosort25Model(BaseModel):
-    model_config = ConfigDict(extra='forbid')
+    model_config = ConfigDict(extra="forbid")
     detect_threshold: float = Field(default=6, description="Threshold for spike detection")
     projection_threshold: List[float] = Field(default=[10, 4], description="Threshold on projections")
     preclust_threshold: float = Field(
@@ -29,7 +29,10 @@ class Kilosort25Model(BaseModel):
     sig: float = Field(default=20, description="spatial smoothness constant for registration")
     freq_min: float = Field(default=150, description="High-pass filter cutoff frequency")
     sigmaMask: float = Field(default=30, description="Spatial constant in um for computing residual variance of spike")
-    lam: float = Field(default=10.0, description="The importance of the amplitude penalty (like in Kilosort1: 0 means not used, 10 is average, 50 is a lot)")
+    lam: float = Field(
+        default=10.0,
+        description="The importance of the amplitude penalty (like in Kilosort1: 0 means not used, 10 is average, 50 is a lot)",
+    )
     nPCs: int = Field(default=3, description="Number of PCA dimensions")
     ntbuff: int = Field(default=64, description="Samples of symmetrical buffer for whitening and spike detection")
     nfilt_factor: int = Field(default=4, description="Max number of clusters per good channel (even temporary ones) 4")
@@ -50,22 +53,18 @@ class Kilosort25Model(BaseModel):
 
 
 class Kilosort3Model(BaseModel):
-    model_config = ConfigDict(extra='forbid')
+    model_config = ConfigDict(extra="forbid")
     pass
 
 
 class IronClustModel(BaseModel):
-    model_config = ConfigDict(extra='forbid')
+    model_config = ConfigDict(extra="forbid")
     pass
 
 
 class MountainSort5Model(BaseModel):
-    model_config = ConfigDict(extra='forbid')
-    scheme: str = Field(
-        default='2',
-        description="Sorting scheme",
-        json_schema_extra={'options': ["1", "2", "3"]}
-    )
+    model_config = ConfigDict(extra="forbid")
+    scheme: str = Field(default="2", description="Sorting scheme", json_schema_extra={"options": ["1", "2", "3"]})
     detect_threshold: float = Field(default=5.5, description="Threshold for spike detection")
     detect_sign: int = Field(default=-1, description="Sign of the peak")
     detect_time_radius_msec: float = Field(default=0.5, description="Time radius in milliseconds")
@@ -77,9 +76,13 @@ class MountainSort5Model(BaseModel):
     scheme1_detect_channel_radius: int = Field(default=150, description="Scheme 1 detect channel radius")
     scheme2_phase1_detect_channel_radius: int = Field(default=200, description="Scheme 2 phase 1 detect channel radius")
     scheme2_detect_channel_radius: int = Field(default=50, description="Scheme 2 detect channel radius")
-    scheme2_max_num_snippets_per_training_batch: int = Field(default=200, description="Scheme 2 max number of snippets per training batch")
+    scheme2_max_num_snippets_per_training_batch: int = Field(
+        default=200, description="Scheme 2 max number of snippets per training batch"
+    )
     scheme2_training_duration_sec: int = Field(default=300, description="Scheme 2 training duration in seconds")
-    scheme2_training_recording_sampling_mode: str = Field(default='uniform', description="Scheme 2 training recording sampling mode")
+    scheme2_training_recording_sampling_mode: str = Field(
+        default="uniform", description="Scheme 2 training recording sampling mode"
+    )
     scheme3_block_duration_sec: int = Field(default=1800, description="Scheme 3 block duration in seconds")
     freq_min: int = Field(default=300, description="High-pass filter cutoff frequency")
     freq_max: int = Field(default=6000, description="Low-pass filter cutoff frequency")
@@ -90,7 +93,8 @@ class MountainSort5Model(BaseModel):
 class SpikeSortingParams(BaseModel):
     sorter_name: SorterName = Field(description="Name of the sorter to use.")
     sorter_kwargs: Union[Kilosort25Model, Kilosort3Model, IronClustModel, MountainSort5Model] = Field(
-        description="Sorter specific kwargs.",
-        union_mode='left_to_right'
+        description="Sorter specific kwargs.", union_mode="left_to_right"
+    )
+    spikesort_by_group: bool = Field(
+        default=False, description="If True, spike sorting is run for each group separately."
     )
-    spikesort_by_group: bool = Field(default=False, description="If True, spike sorting is run for each group separately.")
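# Illustrative sketch (not part of the patch): building the sorting params with the new
# spikesort_by_group flag, and what extra="forbid" buys you when a kwarg is misspelled.
# The import path assumes this repo's layout, and "mountainsort5" is assumed to be a
# valid SorterName value; check the enum defined above.
from pydantic import ValidationError

from spikeinterface_pipelines.spikesorting.params import MountainSort5Model, SpikeSortingParams

sorting_params = SpikeSortingParams(
    sorter_name="mountainsort5",
    sorter_kwargs=MountainSort5Model(scheme="2", detect_threshold=5.0),
    spikesort_by_group=True,  # one sorting job per channel group, e.g. per shank
)
print(sorting_params.model_dump()["spikesort_by_group"])  # -> True

try:
    MountainSort5Model(detect_treshold=5.0)  # typo: rejected instead of silently ignored
except ValidationError as err:
    print(err.errors()[0]["type"])  # -> "extra_forbidden"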
diff --git a/src/spikeinterface_pipelines/visualization/__init__.py b/src/spikeinterface_pipelines/visualization/__init__.py
index 49d093d..9aae88e 100644
--- a/src/spikeinterface_pipelines/visualization/__init__.py
+++ b/src/spikeinterface_pipelines/visualization/__init__.py
@@ -1,2 +1,2 @@
 from .visualization import visualize
-from .params import VisualizationParams
\ No newline at end of file
+from .params import VisualizationParams
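# Illustrative sketch (not part of the patch): the names re-exported by this __init__,
# and the nested dict shape that visualize() later reads (e.g. params["drift"]).
# Import path assumes this repo's layout.
from spikeinterface_pipelines.visualization import VisualizationParams, visualize  # noqa: F401

params_dict = VisualizationParams().model_dump()
print(sorted(params_dict))  # -> ['recording', 'sorting_summary']
print(params_dict["recording"]["drift"]["decimation_factor"])  # -> 30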
diff --git a/src/spikeinterface_pipelines/visualization/params.py b/src/spikeinterface_pipelines/visualization/params.py
index 19b920b..a6b1e9d 100644
--- a/src/spikeinterface_pipelines/visualization/params.py
+++ b/src/spikeinterface_pipelines/visualization/params.py
@@ -9,6 +9,7 @@ class TracesParams(BaseModel):
     """
     Traces parameters.
     """
+
     n_snippets_per_segment: int = Field(default=2, description="Number of snippets per segment to visualize.")
     snippet_duration_s: float = Field(default=0.5, description="Duration of each snippet in seconds.")
     skip: bool = Field(default=False, description="Skip traces visualization.")
@@ -18,6 +19,7 @@ class DetectionParams(BaseModel):
     """
     Detection parameters.
     """
+
     peak_sign: Literal["neg", "pos", "both"] = Field(default="neg", description="Peak sign for peak detection.")
     detect_threshold: float = Field(default=5.0, description="Threshold for peak detection.")
     exclude_sweep_ms: float = Field(default=0.1, description="Exclude sweep in ms around peak detection.")
@@ -27,6 +29,7 @@ class LocalizationParams(BaseModel):
     """
     Localization parameters.
     """
+
     ms_before: float = Field(default=0.1, description="Time before peak in ms.")
     ms_after: float = Field(default=0.3, description="Time after peak in ms.")
     radius_um: float = Field(default=100.0, description="Radius in um for sparsifying waveforms before localization.")
@@ -36,17 +39,18 @@ class DriftParams(BaseModel):
     """
     Drift parameters.
     """
+
     detection: DetectionParams = Field(
         default=DetectionParams(),
-        description="Detection parameters (only used if spike localization was not performed in postprocessing)"
+        description="Detection parameters (only used if spike localization was not performed in postprocessing)",
     )
     localization: LocalizationParams = Field(
         default=LocalizationParams(),
-        description="Localization parameters (only used if spike localization was not performed in postprocessing)"
+        description="Localization parameters (only used if spike localization was not performed in postprocessing)",
     )
     decimation_factor: int = Field(
         default=30,
-        description="The decimation factor for drift visualization. E.g. 30 means that 1 out of 30 spikes is plotted."
+        description="The decimation factor for drift visualization. E.g. 30 means that 1 out of 30 spikes is plotted.",
     )
     alpha: float = Field(default=0.15, description="Alpha for scatter plot.")
     vmin: float = Field(default=-200, description="Min value for colormap.")
@@ -59,50 +63,35 @@ class SortingSummaryVisualizationParams(BaseModel):
     """
     Sorting summary visualization parameters.
     """
+
     unit_table_properties: list = Field(
-        default=["default_qc"],
-        description="List of properties to show in the unit table."
-    )
-    curation: bool = Field(
-        default=True,
-        description="Whether to show curation buttons."
+        default=["default_qc"], description="List of properties to show in the unit table."
     )
+    curation: bool = Field(default=True, description="Whether to show curation buttons.")
     label_choices: list = Field(
-        default=["SUA", "MUA", "noise"],
-        description="List of labels to choose from (if `curation=True`)"
-    )
-    label: str = Field(
-        default="Sorting summary from SI pipelines",
-        description="Label for the sorting summary."
+        default=["SUA", "MUA", "noise"], description="List of labels to choose from (if `curation=True`)"
     )
+    label: str = Field(default="Sorting summary from SI pipelines", description="Label for the sorting summary.")
 
 
 class RecordingVisualizationParams(BaseModel):
     """
     Recording visualization parameters.
     """
-    timeseries: TracesParams = Field(
-        default=TracesParams(),
-        description="Traces visualization parameters."
-    )
-    drift: DriftParams = Field(
-        default=DriftParams(),
-        description="Drift visualization parameters."
-    )
-    label: str = Field(
-        default="Recording visualization from SI pipelines",
-        description="Label for the recording."
-    )
+
+    timeseries: TracesParams = Field(default=TracesParams(), description="Traces visualization parameters.")
+    drift: DriftParams = Field(default=DriftParams(), description="Drift visualization parameters.")
+    label: str = Field(default="Recording visualization from SI pipelines", description="Label for the recording.")
+
 
 class VisualizationParams(BaseModel):
     """
     Visualization parameters.
     """
+
     recording: RecordingVisualizationParams = Field(
-        default=RecordingVisualizationParams(),
-        description="Recording visualization parameters."
+        default=RecordingVisualizationParams(), description="Recording visualization parameters."
     )
     sorting_summary: SortingSummaryVisualizationParams = Field(
-        default=SortingSummaryVisualizationParams(),
-        description="Sorting summary visualization parameters."
+        default=SortingSummaryVisualizationParams(), description="Sorting summary visualization parameters."
     )
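# Illustrative sketch (not part of the patch): partially overriding the nested
# visualization defaults defined above; unspecified fields keep their defaults.
# Import path assumes this repo's layout.
from spikeinterface_pipelines.visualization.params import VisualizationParams

viz_params = VisualizationParams(
    recording={"drift": {"decimation_factor": 10}},  # plot 1 out of 10 spikes instead of 1 out of 30
    sorting_summary={"label": "Example session - sorting summary"},
)
print(viz_params.recording.drift.decimation_factor)  # -> 10
print(viz_params.recording.timeseries.n_snippets_per_segment)  # -> 2 (default retained)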
diff --git a/src/spikeinterface_pipelines/visualization/visualization.py b/src/spikeinterface_pipelines/visualization/visualization.py
index 9e0fa33..f228cff 100644
--- a/src/spikeinterface_pipelines/visualization/visualization.py
+++ b/src/spikeinterface_pipelines/visualization/visualization.py
@@ -1,4 +1,3 @@
-
 from __future__ import annotations
 
 from pathlib import Path
@@ -31,7 +30,7 @@ def visualize(
     scratch_folder: Path = Path("./scratch/"),
     results_folder: Path = Path("./results/visualization/"),
 ) -> dict | None:
-    """ 
+    """
     Generate visualization of preprocessing, spikesorting and curation results.
 
     Parameters
@@ -52,7 +51,7 @@ def visualize(
     Returns
     -------
     visualization_output: dict
-        The visualization output dictionary 
+        The visualization output dictionary
     """
     logger.info("[Visualization] \tRunning Visualization stage")
     visualization_output = {}
@@ -68,9 +67,7 @@ def visualize(
 
     # Recording visualization
     cmap = plt.get_cmap(recording_params["drift"]["cmap"])
-    norm = Normalize(
-        vmin=recording_params["drift"]["vmin"], vmax=recording_params["drift"]["vmax"], clip=True
-    )
+    norm = Normalize(vmin=recording_params["drift"]["vmin"], vmax=recording_params["drift"]["vmax"], clip=True)
     decimation_factor = recording_params["drift"]["decimation_factor"]
     alpha = recording_params["drift"]["alpha"]
@@ -143,9 +140,7 @@ def visualize(
     fig_drift.savefig(fig_drift_folder / f"drift.png", dpi=300)
 
     # make a sorting view View
-    v_drift = svv.TabLayoutItem(
-        label=f"Drift map", view=svv.Image(image_path=str(fig_drift_folder / f"drift.png"))
-    )
+    v_drift = svv.TabLayoutItem(label=f"Drift map", view=svv.Image(image_path=str(fig_drift_folder / f"drift.png")))
 
     # timeseries
     if not recording_params["timeseries"]["skip"]:
@@ -212,8 +207,7 @@ def visualize(
         for prop in unit_table_properties:
             if prop not in waveform_extractor.sorting.get_property_keys():
                 logger.info(
-                    f"[Visualization] \tProperty {prop} not found in sorting object. "
-                    "Not adding to unit table"
+                    f"[Visualization] \tProperty {prop} not found in sorting object. " "Not adding to unit table"
                 )
                 unit_table_properties.remove(prop)
         v_sorting = sw.plot_sorting_summary(
@@ -222,14 +216,12 @@ def visualize(
             curation=sorting_summary_params["unit_table_properties"],
             label_choices=sorting_summary_params["label_choices"],
             backend="sortingview",
-            generate_url=False
+            generate_url=False,
         ).view
 
         try:
             # pre-generate gh for curation
-            url = v_sorting.url(
-                label=sorting_summary_params["label"]
-            )
+            url = v_sorting.url(label=sorting_summary_params["label"])
             print(f"\n{url}\n")
             visualization_output["sorting_summary"] = url
         except Exception as e:
diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
index c991d30..a6243f2 100644
--- a/tests/test_pipeline.py
+++ b/tests/test_pipeline.py
@@ -223,15 +223,15 @@ def test_pipeline(tmp_path, generate_recording):
     recording, sorting, waveform_extractor = _generate_gt_recording()
 
     print("TEST PREPROCESSING")
-    test_preprocessing(tmp_folder, (recording, sorting))
+    test_preprocessing(tmp_folder, (recording, sorting, waveform_extractor))
     print("TEST SPIKESORTING")
-    test_spikesorting(tmp_folder, (recording, sorting))
+    test_spikesorting(tmp_folder, (recording, sorting, waveform_extractor))
     print("TEST POSTPROCESSING")
-    test_postprocessing(tmp_folder, (recording, sorting))
+    test_postprocessing(tmp_folder, (recording, sorting, waveform_extractor))
     print("TEST CURATION")
     test_curation(tmp_folder, (recording, sorting, waveform_extractor))
     print("TEST VISUALIZATION")
     test_visualization(tmp_folder, (recording, sorting, waveform_extractor))
     print("TEST PIPELINE")
-    test_pipeline(tmp_folder, (recording, sorting))
+    test_pipeline(tmp_folder, (recording, sorting, waveform_extractor))
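# Illustrative sketch (not part of the patch): passing the dict form of the sorting
# params, which run_pipeline() converts to SpikeSortingParams as shown earlier in this
# diff. The toy recording, the import path, and the assumption that the remaining
# run_pipeline arguments have usable defaults are mine; check its signature before use.
import spikeinterface.full as si

from spikeinterface_pipelines.pipeline import run_pipeline

recording = si.generate_recording(durations=[60.0], num_channels=32)  # stand-in for real data
results = run_pipeline(
    recording=recording,
    spikesorting_params={"sorter_name": "mountainsort5", "sorter_kwargs": {"scheme": "2"}},
)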