diff --git a/pyproject.toml b/pyproject.toml index c647e31..9886313 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "tbb>=2021.11.0; platform_system != 'Darwin'", "pynapple>=0.5.1", "lxml", + "spatial_maps" ] [project.urls] diff --git a/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py b/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py index dcc9d8e..f921ae8 100644 --- a/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py +++ b/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- from functools import partial -import warnings import ipywidgets as widgets import matplotlib.pyplot as plt import numpy as np @@ -164,19 +163,18 @@ def __init__( step=0.01, description="Bin size:", ) + self.flip_y_axis = widgets.Checkbox( + value=False, + description="Flip y-axis", + ) spatial_series_label = widgets.Label("Spatial Series:") top_panel = widgets.VBox( [ self.unit_list, self.unit_name_text, self.unit_info_text, - widgets.HBox( - [ - spatial_series_label, - self.spatial_series_selector, - ] - ), - widgets.HBox([self.smoothing_slider, self.bin_size_slider]), + widgets.HBox([spatial_series_label, self.spatial_series_selector]), + widgets.HBox([self.smoothing_slider, self.bin_size_slider, self.flip_y_axis]), ] ) self.controls = dict( @@ -184,6 +182,7 @@ def __init__( spatial_series_selector=self.spatial_series_selector, smoothing_slider=self.smoothing_slider, bin_size_slider=self.bin_size_slider, + flip_y_axis=self.flip_y_axis, ) self.rate_maps, self.extent = None, None self.compute_rate_maps() @@ -222,15 +221,7 @@ def get_spatial_series(self): def compute_rate_maps(self): import pynapple as nap - try: - from spatial_maps import SpatialMap - HAVE_SPATIAL_MAPS = True - except: - warnings.warn( - "spatial_maps not installed. Please install it to compute rate maps:\n" - ">>> pip install git+https://github.com/CINPLA/spatial-maps.git" - ) - HAVE_SPATIAL_MAPS = False + from spatial_maps import SpatialMap spatial_series = self.spatial_series[self.spatial_series_selector.value] x, y = spatial_series.data[:].T @@ -265,30 +256,36 @@ def compute_rate_maps(self): self.nap_position = nap_position self.nap_units = nap_units - if HAVE_SPATIAL_MAPS: - sm = SpatialMap( - bin_size=self.bin_size_slider.value, - smoothing=self.smoothing_slider.value, - ) - rate_maps = [] - for unit_index in self.units.id.data: - rate_map = sm.rate_map(x, y, t, unit_spike_times[unit_index]) - rate_maps.append(rate_map) - self.rate_maps = np.array(rate_maps) - else: - self.rate_maps = None + sm = SpatialMap( + bin_size=self.bin_size_slider.value, + smoothing=self.smoothing_slider.value, + ) + rate_maps = [] + for unit_index in self.units.id.data: + rate_map = sm.rate_map(x, y, t, unit_spike_times[unit_index]) + rate_maps.append(rate_map) + self.rate_maps = np.array(rate_maps) def on_spatial_series_change(self, change): self.compute_rate_maps() + self.show_unit_rate_maps(self.unit_list.value) def on_bin_size_change(self, change): self.compute_rate_maps() + self.show_unit_rate_maps(self.unit_list.value) def on_smoothing_change(self, change): self.compute_rate_maps() + self.show_unit_rate_maps(self.unit_list.value) def show_unit_rate_maps( - self, unit_index=None, spatial_series_selector=None, smoothing_slider=None, bin_size_slider=None, axs=None + self, + unit_index=None, + spatial_series_selector=None, + smoothing_slider=None, + bin_size_slider=None, + flip_y_axis=None, + axs=None, ): """ Shows unit rate maps. @@ -305,21 +302,26 @@ def show_unit_rate_maps( figsize = (10, 7) if axs is None: - fig, axs = plt.subplots(figsize=figsize, ncols=2) + fig, axs = plt.subplots(figsize=figsize, ncols=2, sharex=True, sharey=True) if hasattr(fig, "canvas"): fig.canvas.header_visible = False else: legend_kwargs.update(bbox_to_anchor=(1.01, 1)) - if self.rate_maps is not None: - axs[0].imshow(self.rate_maps[unit_index], cmap="viridis", origin="lower", aspect="auto", extent=self.extent) - axs[0].set_xlabel("x") - axs[0].set_ylabel("y") - else: - axs[0].set_title("Rate maps not computed (spatial_maps not installed)") + origin = "lower" if self.flip_y_axis.value else "upper" + axs[0].imshow(self.rate_maps[unit_index], cmap="viridis", origin=origin, aspect="auto", extent=self.extent) + axs[0].set_xlabel("x") + axs[0].set_ylabel("y") - axs[1].plot(self.nap_position["y"], self.nap_position["x"], color="grey") + tracking_x = self.nap_position["y"] + tracking_y = self.nap_position["x"] spk_pos = self.nap_units[unit_index].value_from(self.nap_position) - axs[1].plot(spk_pos["y"], spk_pos["x"], "o", color="red", markersize=5, alpha=0.5) + spike_pos_x = spk_pos["y"] + spike_pos_y = spk_pos["x"] + if not self.flip_y_axis.value: + tracking_y = 1 - tracking_y + spike_pos_y = 1 - spike_pos_y + axs[1].plot(tracking_x, tracking_y, color="grey") + axs[1].plot(spike_pos_x, spike_pos_y, "o", color="red", markersize=5, alpha=0.5) axs[1].set_xlabel("x") axs[1].set_ylabel("y") fig.tight_layout() diff --git a/src/expipe_plugin_cinpla/scripts/utils.py b/src/expipe_plugin_cinpla/scripts/utils.py index ab61a41..7399772 100644 --- a/src/expipe_plugin_cinpla/scripts/utils.py +++ b/src/expipe_plugin_cinpla/scripts/utils.py @@ -297,12 +297,14 @@ def compute_and_set_unit_groups(sorting, recording): unit_groups = recording.get_channel_groups()[np.array(list(extremum_channel_indices.values()))] sorting.set_property("group", unit_groups) else: - unit_groups = sorting.get_property("group").astype("str") + unit_groups = sorting.get_property("group").astype("U20") # if there are units without group, we need to compute them unit_ids_without_group = np.array(sorting.unit_ids)[np.where(unit_groups == "nan")[0]] if len(unit_ids_without_group) > 0: sorting_no_group = sorting.select_units(unit_ids=unit_ids_without_group) - we_mem = si.extract_waveforms(recording, sorting_no_group, folder=None, mode="memory", sparse=False) + we_mem = si.extract_waveforms( + recording, sorting_no_group, folder=None, mode="memory", sparse=False, progress_bar=False + ) extremum_channel_indices = si.get_template_extremum_channel(we_mem, outputs="index") unit_groups[sorting.ids_to_indices(unit_ids_without_group)] = recording.get_channel_groups()[ np.array(list(extremum_channel_indices.values())) diff --git a/src/expipe_plugin_cinpla/tools/data_processing.py b/src/expipe_plugin_cinpla/tools/data_processing.py index 3347c94..65bbb3d 100644 --- a/src/expipe_plugin_cinpla/tools/data_processing.py +++ b/src/expipe_plugin_cinpla/tools/data_processing.py @@ -5,6 +5,7 @@ import expipe import numpy as np +import spatial_maps as sp from expipe_plugin_cinpla.data_loader import ( get_channel_groups, diff --git a/src/expipe_plugin_cinpla/widgets/curation.py b/src/expipe_plugin_cinpla/widgets/curation.py index b6b4a43..2ff9af0 100644 --- a/src/expipe_plugin_cinpla/widgets/curation.py +++ b/src/expipe_plugin_cinpla/widgets/curation.py @@ -104,9 +104,10 @@ def __init__(self, project): load_from_phy = ipywidgets.Button(description="Load from Phy", layout={"width": "initial"}) load_from_phy.style.button_color = "pink" - restore_phy = ipywidgets.ToggleButton( + self.restore_phy_clicked = False + restore_phy = ipywidgets.Button( value=False, - description="Restore", + description="Restore (click twice to restore)", disabled=False, button_style="", # 'success', 'info', 'warning', 'danger' or '' tooltip="Restore unsorted clusters", @@ -205,6 +206,14 @@ def on_action(change): sorters = [p.name for p in si_path.iterdir() if p.is_dir()] sorter_list.options = sorters if len(sorter_list.value) == 1: + units_raw = self.sorting_curator.load_raw_units(sorter_list.value[0]) + if units_raw is not None: + w = nwb2widget(units_raw, custom_raw_unit_vis) + units_viewers["raw"] = w + units_main = self.sorting_curator.load_main_units() + if units_main is not None: + w = nwb2widget(units_main, custom_main_unit_vis) + units_viewers["main"] = w if strategy.value == "Sortingview": sv_visualization_link.value = self.sorting_curator.get_sortingview_link(sorter_list.value[0]) elif strategy.value == "Phy": @@ -313,7 +322,15 @@ def on_restore_phy(change): if len(sorter_list.value) > 1: print("Select one spike sorting output at a time") else: - self.sorting_curator.restore_phy(sorter_list.value[0]) + if self.restore_phy_clicked: + self.sorting_curator.restore_phy(sorter_list.value[0]) + self.restore_phy_clicked = False + restore_phy.description = "Restore (click twice to restore)" + restore_phy.button_style = "primary" + else: + self.restore_phy_clicked = True + restore_phy.description = "Restore (click again to confirm)" + restore_phy.button_style = "danger" @self.output.capture() def on_apply_qm_curation(change): @@ -335,7 +352,7 @@ def on_save_to_nwb(change): actions_list.observe(on_action) load_from_phy.on_click(on_load_phy) - restore_phy.observe(on_restore_phy) + restore_phy.on_click(on_restore_phy) run_save.on_click(on_save_to_nwb) strategy.observe(on_change_strategy) sorter_list.observe(on_sorter)