Skip to content

Commit

Permalink
Merge pull request #96 from CINPLA/protect-phy-and-flip
Browse files Browse the repository at this point in the history
Various fixes
  • Loading branch information
alejoe91 authored Nov 21, 2024
2 parents b74e709 + f8963fa commit d031d06
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 45 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ dependencies = [
"tbb>=2021.11.0; platform_system != 'Darwin'",
"pynapple>=0.5.1",
"lxml",
"spatial_maps"
]

[project.urls]
Expand Down
80 changes: 41 additions & 39 deletions src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-
from functools import partial

import warnings
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -164,26 +163,26 @@ def __init__(
step=0.01,
description="Bin size:",
)
self.flip_y_axis = widgets.Checkbox(
value=False,
description="Flip y-axis",
)
spatial_series_label = widgets.Label("Spatial Series:")
top_panel = widgets.VBox(
[
self.unit_list,
self.unit_name_text,
self.unit_info_text,
widgets.HBox(
[
spatial_series_label,
self.spatial_series_selector,
]
),
widgets.HBox([self.smoothing_slider, self.bin_size_slider]),
widgets.HBox([spatial_series_label, self.spatial_series_selector]),
widgets.HBox([self.smoothing_slider, self.bin_size_slider, self.flip_y_axis]),
]
)
self.controls = dict(
unit_index=self.unit_list,
spatial_series_selector=self.spatial_series_selector,
smoothing_slider=self.smoothing_slider,
bin_size_slider=self.bin_size_slider,
flip_y_axis=self.flip_y_axis,
)
self.rate_maps, self.extent = None, None
self.compute_rate_maps()
Expand Down Expand Up @@ -222,15 +221,7 @@ def get_spatial_series(self):

def compute_rate_maps(self):
import pynapple as nap
try:
from spatial_maps import SpatialMap
HAVE_SPATIAL_MAPS = True
except:
warnings.warn(
"spatial_maps not installed. Please install it to compute rate maps:\n"
">>> pip install git+https://github.com/CINPLA/spatial-maps.git"
)
HAVE_SPATIAL_MAPS = False
from spatial_maps import SpatialMap

spatial_series = self.spatial_series[self.spatial_series_selector.value]
x, y = spatial_series.data[:].T
Expand Down Expand Up @@ -265,30 +256,36 @@ def compute_rate_maps(self):
self.nap_position = nap_position
self.nap_units = nap_units

if HAVE_SPATIAL_MAPS:
sm = SpatialMap(
bin_size=self.bin_size_slider.value,
smoothing=self.smoothing_slider.value,
)
rate_maps = []
for unit_index in self.units.id.data:
rate_map = sm.rate_map(x, y, t, unit_spike_times[unit_index])
rate_maps.append(rate_map)
self.rate_maps = np.array(rate_maps)
else:
self.rate_maps = None
sm = SpatialMap(
bin_size=self.bin_size_slider.value,
smoothing=self.smoothing_slider.value,
)
rate_maps = []
for unit_index in self.units.id.data:
rate_map = sm.rate_map(x, y, t, unit_spike_times[unit_index])
rate_maps.append(rate_map)
self.rate_maps = np.array(rate_maps)

def on_spatial_series_change(self, change):
self.compute_rate_maps()
self.show_unit_rate_maps(self.unit_list.value)

def on_bin_size_change(self, change):
self.compute_rate_maps()
self.show_unit_rate_maps(self.unit_list.value)

def on_smoothing_change(self, change):
self.compute_rate_maps()
self.show_unit_rate_maps(self.unit_list.value)

def show_unit_rate_maps(
self, unit_index=None, spatial_series_selector=None, smoothing_slider=None, bin_size_slider=None, axs=None
self,
unit_index=None,
spatial_series_selector=None,
smoothing_slider=None,
bin_size_slider=None,
flip_y_axis=None,
axs=None,
):
"""
Shows unit rate maps.
Expand All @@ -305,21 +302,26 @@ def show_unit_rate_maps(
figsize = (10, 7)

if axs is None:
fig, axs = plt.subplots(figsize=figsize, ncols=2)
fig, axs = plt.subplots(figsize=figsize, ncols=2, sharex=True, sharey=True)
if hasattr(fig, "canvas"):
fig.canvas.header_visible = False
else:
legend_kwargs.update(bbox_to_anchor=(1.01, 1))
if self.rate_maps is not None:
axs[0].imshow(self.rate_maps[unit_index], cmap="viridis", origin="lower", aspect="auto", extent=self.extent)
axs[0].set_xlabel("x")
axs[0].set_ylabel("y")
else:
axs[0].set_title("Rate maps not computed (spatial_maps not installed)")
origin = "lower" if self.flip_y_axis.value else "upper"
axs[0].imshow(self.rate_maps[unit_index], cmap="viridis", origin=origin, aspect="auto", extent=self.extent)
axs[0].set_xlabel("x")
axs[0].set_ylabel("y")

axs[1].plot(self.nap_position["y"], self.nap_position["x"], color="grey")
tracking_x = self.nap_position["y"]
tracking_y = self.nap_position["x"]
spk_pos = self.nap_units[unit_index].value_from(self.nap_position)
axs[1].plot(spk_pos["y"], spk_pos["x"], "o", color="red", markersize=5, alpha=0.5)
spike_pos_x = spk_pos["y"]
spike_pos_y = spk_pos["x"]
if not self.flip_y_axis.value:
tracking_y = 1 - tracking_y
spike_pos_y = 1 - spike_pos_y
axs[1].plot(tracking_x, tracking_y, color="grey")
axs[1].plot(spike_pos_x, spike_pos_y, "o", color="red", markersize=5, alpha=0.5)
axs[1].set_xlabel("x")
axs[1].set_ylabel("y")
fig.tight_layout()
Expand Down
6 changes: 4 additions & 2 deletions src/expipe_plugin_cinpla/scripts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,12 +297,14 @@ def compute_and_set_unit_groups(sorting, recording):
unit_groups = recording.get_channel_groups()[np.array(list(extremum_channel_indices.values()))]
sorting.set_property("group", unit_groups)
else:
unit_groups = sorting.get_property("group").astype("str")
unit_groups = sorting.get_property("group").astype("U20")
# if there are units without group, we need to compute them
unit_ids_without_group = np.array(sorting.unit_ids)[np.where(unit_groups == "nan")[0]]
if len(unit_ids_without_group) > 0:
sorting_no_group = sorting.select_units(unit_ids=unit_ids_without_group)
we_mem = si.extract_waveforms(recording, sorting_no_group, folder=None, mode="memory", sparse=False)
we_mem = si.extract_waveforms(
recording, sorting_no_group, folder=None, mode="memory", sparse=False, progress_bar=False
)
extremum_channel_indices = si.get_template_extremum_channel(we_mem, outputs="index")
unit_groups[sorting.ids_to_indices(unit_ids_without_group)] = recording.get_channel_groups()[
np.array(list(extremum_channel_indices.values()))
Expand Down
1 change: 1 addition & 0 deletions src/expipe_plugin_cinpla/tools/data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import expipe
import numpy as np
import spatial_maps as sp

from expipe_plugin_cinpla.data_loader import (
get_channel_groups,
Expand Down
25 changes: 21 additions & 4 deletions src/expipe_plugin_cinpla/widgets/curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,10 @@ def __init__(self, project):
load_from_phy = ipywidgets.Button(description="Load from Phy", layout={"width": "initial"})
load_from_phy.style.button_color = "pink"

restore_phy = ipywidgets.ToggleButton(
self.restore_phy_clicked = False
restore_phy = ipywidgets.Button(
value=False,
description="Restore",
description="Restore (click twice to restore)",
disabled=False,
button_style="", # 'success', 'info', 'warning', 'danger' or ''
tooltip="Restore unsorted clusters",
Expand Down Expand Up @@ -205,6 +206,14 @@ def on_action(change):
sorters = [p.name for p in si_path.iterdir() if p.is_dir()]
sorter_list.options = sorters
if len(sorter_list.value) == 1:
units_raw = self.sorting_curator.load_raw_units(sorter_list.value[0])
if units_raw is not None:
w = nwb2widget(units_raw, custom_raw_unit_vis)
units_viewers["raw"] = w
units_main = self.sorting_curator.load_main_units()
if units_main is not None:
w = nwb2widget(units_main, custom_main_unit_vis)
units_viewers["main"] = w
if strategy.value == "Sortingview":
sv_visualization_link.value = self.sorting_curator.get_sortingview_link(sorter_list.value[0])
elif strategy.value == "Phy":
Expand Down Expand Up @@ -313,7 +322,15 @@ def on_restore_phy(change):
if len(sorter_list.value) > 1:
print("Select one spike sorting output at a time")
else:
self.sorting_curator.restore_phy(sorter_list.value[0])
if self.restore_phy_clicked:
self.sorting_curator.restore_phy(sorter_list.value[0])
self.restore_phy_clicked = False
restore_phy.description = "Restore (click twice to restore)"
restore_phy.button_style = "primary"
else:
self.restore_phy_clicked = True
restore_phy.description = "Restore (click again to confirm)"
restore_phy.button_style = "danger"

@self.output.capture()
def on_apply_qm_curation(change):
Expand All @@ -335,7 +352,7 @@ def on_save_to_nwb(change):

actions_list.observe(on_action)
load_from_phy.on_click(on_load_phy)
restore_phy.observe(on_restore_phy)
restore_phy.on_click(on_restore_phy)
run_save.on_click(on_save_to_nwb)
strategy.observe(on_change_strategy)
sorter_list.observe(on_sorter)
Expand Down

0 comments on commit d031d06

Please sign in to comment.