From b9dbf5e41b1963db3d63308e8bafbf0da123c0c7 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 21 Nov 2024 15:59:29 +0100 Subject: [PATCH 1/7] Flip y axis and restore phy confirmation --- .../nwbutils/nwbwidgetsunitviewer.py | 12 +++++++++- src/expipe_plugin_cinpla/widgets/curation.py | 22 +++++++++++++++++-- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py b/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py index dcc9d8e..4f170a1 100644 --- a/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py +++ b/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py @@ -164,6 +164,10 @@ def __init__( step=0.01, description="Bin size:", ) + self.flip_y_axis = widgets.Checkbox( + value=False, + description="Flip y-axis", + ) spatial_series_label = widgets.Label("Spatial Series:") top_panel = widgets.VBox( [ @@ -174,6 +178,7 @@ def __init__( [ spatial_series_label, self.spatial_series_selector, + self.flip_y_axis ] ), widgets.HBox([self.smoothing_slider, self.bin_size_slider]), @@ -311,13 +316,18 @@ def show_unit_rate_maps( else: legend_kwargs.update(bbox_to_anchor=(1.01, 1)) if self.rate_maps is not None: + origin = "lower" if self.flip_y_axis.value else "upper" axs[0].imshow(self.rate_maps[unit_index], cmap="viridis", origin="lower", aspect="auto", extent=self.extent) axs[0].set_xlabel("x") axs[0].set_ylabel("y") else: axs[0].set_title("Rate maps not computed (spatial_maps not installed)") - axs[1].plot(self.nap_position["y"], self.nap_position["x"], color="grey") + tracking_x = self.nap_position["y"] + tracking_y = self.nap_position["x"] + if self.flip_y_axis.value: + tracking_y = 1 - tracking_y + axs[1].plot(tracking_x, tracking_y, color="grey") spk_pos = self.nap_units[unit_index].value_from(self.nap_position) axs[1].plot(spk_pos["y"], spk_pos["x"], "o", color="red", markersize=5, alpha=0.5) axs[1].set_xlabel("x") diff --git a/src/expipe_plugin_cinpla/widgets/curation.py b/src/expipe_plugin_cinpla/widgets/curation.py index b6b4a43..65a97c3 100644 --- a/src/expipe_plugin_cinpla/widgets/curation.py +++ b/src/expipe_plugin_cinpla/widgets/curation.py @@ -104,9 +104,10 @@ def __init__(self, project): load_from_phy = ipywidgets.Button(description="Load from Phy", layout={"width": "initial"}) load_from_phy.style.button_color = "pink" + self.restore_phy_clicked = False restore_phy = ipywidgets.ToggleButton( value=False, - description="Restore", + description="Restore (click twice to restore)", disabled=False, button_style="", # 'success', 'info', 'warning', 'danger' or '' tooltip="Restore unsorted clusters", @@ -205,6 +206,14 @@ def on_action(change): sorters = [p.name for p in si_path.iterdir() if p.is_dir()] sorter_list.options = sorters if len(sorter_list.value) == 1: + units_raw = self.sorting_curator.load_raw_units(sorter_list.value[0]) + if units_raw is not None: + w = nwb2widget(units_raw, custom_raw_unit_vis) + units_viewers["raw"] = w + units_main = self.sorting_curator.load_main_units() + if units_main is not None: + w = nwb2widget(units_main, custom_main_unit_vis) + units_viewers["main"] = w if strategy.value == "Sortingview": sv_visualization_link.value = self.sorting_curator.get_sortingview_link(sorter_list.value[0]) elif strategy.value == "Phy": @@ -212,6 +221,7 @@ def on_action(change): units_dropdown.value = "Raw" on_choose_units(None) + def on_sorter(change): required_values_filled(actions_list) if len(sorter_list.value) > 1: @@ -313,7 +323,15 @@ def on_restore_phy(change): if len(sorter_list.value) > 1: print("Select one spike sorting output at a time") else: - self.sorting_curator.restore_phy(sorter_list.value[0]) + if self.restore_phy_clicked: + self.sorting_curator.restore_phy(sorter_list.value[0]) + self.restore_phy_clicked = False + restore_phy.description = "Restore (click twice to restore)" + restore_phy.button_style = "primary" + else: + self.restore_phy_clicked = True + restore_phy.description = "Restore (click again to confirm)" + restore_phy.button_style = "danger" @self.output.capture() def on_apply_qm_curation(change): From 0144a72f5bab9e38213c35ad24107e950a76c6f1 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 21 Nov 2024 16:06:57 +0100 Subject: [PATCH 2/7] Formatting --- .../nwbutils/nwbwidgetsunitviewer.py | 10 +++------- src/expipe_plugin_cinpla/widgets/curation.py | 1 - 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py b/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py index 4f170a1..e285878 100644 --- a/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py +++ b/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py @@ -174,13 +174,7 @@ def __init__( self.unit_list, self.unit_name_text, self.unit_info_text, - widgets.HBox( - [ - spatial_series_label, - self.spatial_series_selector, - self.flip_y_axis - ] - ), + widgets.HBox([spatial_series_label, self.spatial_series_selector, self.flip_y_axis]), widgets.HBox([self.smoothing_slider, self.bin_size_slider]), ] ) @@ -227,8 +221,10 @@ def get_spatial_series(self): def compute_rate_maps(self): import pynapple as nap + try: from spatial_maps import SpatialMap + HAVE_SPATIAL_MAPS = True except: warnings.warn( diff --git a/src/expipe_plugin_cinpla/widgets/curation.py b/src/expipe_plugin_cinpla/widgets/curation.py index 65a97c3..57e35fc 100644 --- a/src/expipe_plugin_cinpla/widgets/curation.py +++ b/src/expipe_plugin_cinpla/widgets/curation.py @@ -221,7 +221,6 @@ def on_action(change): units_dropdown.value = "Raw" on_choose_units(None) - def on_sorter(change): required_values_filled(actions_list) if len(sorter_list.value) > 1: From cbc80d7e748293d37235514fb6daa5ff3224a5c4 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 21 Nov 2024 16:19:47 +0100 Subject: [PATCH 3/7] Add spatial_maps dependency --- pyproject.toml | 1 + .../nwbutils/nwbwidgetsunitviewer.py | 44 ++++++------------- src/expipe_plugin_cinpla/widgets/curation.py | 2 +- 3 files changed, 16 insertions(+), 31 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c647e31..9886313 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "tbb>=2021.11.0; platform_system != 'Darwin'", "pynapple>=0.5.1", "lxml", + "spatial_maps" ] [project.urls] diff --git a/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py b/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py index e285878..e07704e 100644 --- a/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py +++ b/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py @@ -221,17 +221,7 @@ def get_spatial_series(self): def compute_rate_maps(self): import pynapple as nap - - try: - from spatial_maps import SpatialMap - - HAVE_SPATIAL_MAPS = True - except: - warnings.warn( - "spatial_maps not installed. Please install it to compute rate maps:\n" - ">>> pip install git+https://github.com/CINPLA/spatial-maps.git" - ) - HAVE_SPATIAL_MAPS = False + from spatial_maps import SpatialMap spatial_series = self.spatial_series[self.spatial_series_selector.value] x, y = spatial_series.data[:].T @@ -266,18 +256,15 @@ def compute_rate_maps(self): self.nap_position = nap_position self.nap_units = nap_units - if HAVE_SPATIAL_MAPS: - sm = SpatialMap( - bin_size=self.bin_size_slider.value, - smoothing=self.smoothing_slider.value, - ) - rate_maps = [] - for unit_index in self.units.id.data: - rate_map = sm.rate_map(x, y, t, unit_spike_times[unit_index]) - rate_maps.append(rate_map) - self.rate_maps = np.array(rate_maps) - else: - self.rate_maps = None + sm = SpatialMap( + bin_size=self.bin_size_slider.value, + smoothing=self.smoothing_slider.value, + ) + rate_maps = [] + for unit_index in self.units.id.data: + rate_map = sm.rate_map(x, y, t, unit_spike_times[unit_index]) + rate_maps.append(rate_map) + self.rate_maps = np.array(rate_maps) def on_spatial_series_change(self, change): self.compute_rate_maps() @@ -311,13 +298,10 @@ def show_unit_rate_maps( fig.canvas.header_visible = False else: legend_kwargs.update(bbox_to_anchor=(1.01, 1)) - if self.rate_maps is not None: - origin = "lower" if self.flip_y_axis.value else "upper" - axs[0].imshow(self.rate_maps[unit_index], cmap="viridis", origin="lower", aspect="auto", extent=self.extent) - axs[0].set_xlabel("x") - axs[0].set_ylabel("y") - else: - axs[0].set_title("Rate maps not computed (spatial_maps not installed)") + origin = "lower" if self.flip_y_axis.value else "upper" + axs[0].imshow(self.rate_maps[unit_index], cmap="viridis", origin="lower", aspect="auto", extent=self.extent) + axs[0].set_xlabel("x") + axs[0].set_ylabel("y") tracking_x = self.nap_position["y"] tracking_y = self.nap_position["x"] diff --git a/src/expipe_plugin_cinpla/widgets/curation.py b/src/expipe_plugin_cinpla/widgets/curation.py index 57e35fc..3b7eeb8 100644 --- a/src/expipe_plugin_cinpla/widgets/curation.py +++ b/src/expipe_plugin_cinpla/widgets/curation.py @@ -105,7 +105,7 @@ def __init__(self, project): load_from_phy.style.button_color = "pink" self.restore_phy_clicked = False - restore_phy = ipywidgets.ToggleButton( + restore_phy = ipywidgets.Button( value=False, description="Restore (click twice to restore)", disabled=False, From 45dcc46b8ceb6eca30f654cbdacf5fa8dd584f64 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 21 Nov 2024 23:19:29 +0100 Subject: [PATCH 4/7] fix bug with all nans and flip y --- .../nwbutils/nwbwidgetsunitviewer.py | 26 ++++++++++++++----- src/expipe_plugin_cinpla/scripts/utils.py | 6 +++-- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py b/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py index e07704e..00dc198 100644 --- a/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py +++ b/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- from functools import partial -import warnings import ipywidgets as widgets import matplotlib.pyplot as plt import numpy as np @@ -183,6 +182,7 @@ def __init__( spatial_series_selector=self.spatial_series_selector, smoothing_slider=self.smoothing_slider, bin_size_slider=self.bin_size_slider, + flip_y_axis=self.flip_y_axis, ) self.rate_maps, self.extent = None, None self.compute_rate_maps() @@ -268,15 +268,24 @@ def compute_rate_maps(self): def on_spatial_series_change(self, change): self.compute_rate_maps() + self.show_unit_rate_maps(self.unit_list.value) def on_bin_size_change(self, change): self.compute_rate_maps() + self.show_unit_rate_maps(self.unit_list.value) def on_smoothing_change(self, change): self.compute_rate_maps() + self.show_unit_rate_maps(self.unit_list.value) def show_unit_rate_maps( - self, unit_index=None, spatial_series_selector=None, smoothing_slider=None, bin_size_slider=None, axs=None + self, + unit_index=None, + spatial_series_selector=None, + smoothing_slider=None, + bin_size_slider=None, + flip_y_axis=None, + axs=None, ): """ Shows unit rate maps. @@ -293,23 +302,26 @@ def show_unit_rate_maps( figsize = (10, 7) if axs is None: - fig, axs = plt.subplots(figsize=figsize, ncols=2) + fig, axs = plt.subplots(figsize=figsize, ncols=2, sharex=True, sharey=True) if hasattr(fig, "canvas"): fig.canvas.header_visible = False else: legend_kwargs.update(bbox_to_anchor=(1.01, 1)) origin = "lower" if self.flip_y_axis.value else "upper" - axs[0].imshow(self.rate_maps[unit_index], cmap="viridis", origin="lower", aspect="auto", extent=self.extent) + axs[0].imshow(self.rate_maps[unit_index], cmap="viridis", origin=origin, aspect="auto", extent=self.extent) axs[0].set_xlabel("x") axs[0].set_ylabel("y") tracking_x = self.nap_position["y"] tracking_y = self.nap_position["x"] - if self.flip_y_axis.value: + spk_pos = self.nap_units[unit_index].value_from(self.nap_position) + spike_pos_x = spk_pos["y"] + spike_pos_y = spk_pos["x"] + if not self.flip_y_axis.value: tracking_y = 1 - tracking_y + spike_pos_y = 1 - spike_pos_y axs[1].plot(tracking_x, tracking_y, color="grey") - spk_pos = self.nap_units[unit_index].value_from(self.nap_position) - axs[1].plot(spk_pos["y"], spk_pos["x"], "o", color="red", markersize=5, alpha=0.5) + axs[1].plot(spike_pos_x, spike_pos_y, "o", color="red", markersize=5, alpha=0.5) axs[1].set_xlabel("x") axs[1].set_ylabel("y") fig.tight_layout() diff --git a/src/expipe_plugin_cinpla/scripts/utils.py b/src/expipe_plugin_cinpla/scripts/utils.py index ab61a41..7399772 100644 --- a/src/expipe_plugin_cinpla/scripts/utils.py +++ b/src/expipe_plugin_cinpla/scripts/utils.py @@ -297,12 +297,14 @@ def compute_and_set_unit_groups(sorting, recording): unit_groups = recording.get_channel_groups()[np.array(list(extremum_channel_indices.values()))] sorting.set_property("group", unit_groups) else: - unit_groups = sorting.get_property("group").astype("str") + unit_groups = sorting.get_property("group").astype("U20") # if there are units without group, we need to compute them unit_ids_without_group = np.array(sorting.unit_ids)[np.where(unit_groups == "nan")[0]] if len(unit_ids_without_group) > 0: sorting_no_group = sorting.select_units(unit_ids=unit_ids_without_group) - we_mem = si.extract_waveforms(recording, sorting_no_group, folder=None, mode="memory", sparse=False) + we_mem = si.extract_waveforms( + recording, sorting_no_group, folder=None, mode="memory", sparse=False, progress_bar=False + ) extremum_channel_indices = si.get_template_extremum_channel(we_mem, outputs="index") unit_groups[sorting.ids_to_indices(unit_ids_without_group)] = recording.get_channel_groups()[ np.array(list(extremum_channel_indices.values())) From 568e7ab666b1c135b50874f93cacc410aa1e7d75 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 21 Nov 2024 23:21:43 +0100 Subject: [PATCH 5/7] small change --- src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py b/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py index 00dc198..f921ae8 100644 --- a/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py +++ b/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py @@ -173,8 +173,8 @@ def __init__( self.unit_list, self.unit_name_text, self.unit_info_text, - widgets.HBox([spatial_series_label, self.spatial_series_selector, self.flip_y_axis]), - widgets.HBox([self.smoothing_slider, self.bin_size_slider]), + widgets.HBox([spatial_series_label, self.spatial_series_selector]), + widgets.HBox([self.smoothing_slider, self.bin_size_slider, self.flip_y_axis]), ] ) self.controls = dict( From 9bba1e976489b484f8aae95c6e3acad88622aef2 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 21 Nov 2024 23:31:02 +0100 Subject: [PATCH 6/7] fix ruff --- src/expipe_plugin_cinpla/tools/data_processing.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/expipe_plugin_cinpla/tools/data_processing.py b/src/expipe_plugin_cinpla/tools/data_processing.py index 3347c94..65bbb3d 100644 --- a/src/expipe_plugin_cinpla/tools/data_processing.py +++ b/src/expipe_plugin_cinpla/tools/data_processing.py @@ -5,6 +5,7 @@ import expipe import numpy as np +import spatial_maps as sp from expipe_plugin_cinpla.data_loader import ( get_channel_groups, From f8963fa33699e1b576dc24eab368ece603a2359e Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 21 Nov 2024 23:35:03 +0100 Subject: [PATCH 7/7] fix phy restore button --- src/expipe_plugin_cinpla/widgets/curation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/expipe_plugin_cinpla/widgets/curation.py b/src/expipe_plugin_cinpla/widgets/curation.py index 3b7eeb8..2ff9af0 100644 --- a/src/expipe_plugin_cinpla/widgets/curation.py +++ b/src/expipe_plugin_cinpla/widgets/curation.py @@ -352,7 +352,7 @@ def on_save_to_nwb(change): actions_list.observe(on_action) load_from_phy.on_click(on_load_phy) - restore_phy.observe(on_restore_phy) + restore_phy.on_click(on_restore_phy) run_save.on_click(on_save_to_nwb) strategy.observe(on_change_strategy) sorter_list.observe(on_sorter)