Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various fixes #96

Merged
merged 7 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ dependencies = [
"tbb>=2021.11.0; platform_system != 'Darwin'",
"pynapple>=0.5.1",
"lxml",
"spatial_maps"
]

[project.urls]
Expand Down
80 changes: 41 additions & 39 deletions src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-
from functools import partial

import warnings
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -164,26 +163,26 @@ def __init__(
step=0.01,
description="Bin size:",
)
self.flip_y_axis = widgets.Checkbox(
value=False,
description="Flip y-axis",
)
spatial_series_label = widgets.Label("Spatial Series:")
top_panel = widgets.VBox(
[
self.unit_list,
self.unit_name_text,
self.unit_info_text,
widgets.HBox(
[
spatial_series_label,
self.spatial_series_selector,
]
),
widgets.HBox([self.smoothing_slider, self.bin_size_slider]),
widgets.HBox([spatial_series_label, self.spatial_series_selector]),
widgets.HBox([self.smoothing_slider, self.bin_size_slider, self.flip_y_axis]),
]
)
self.controls = dict(
unit_index=self.unit_list,
spatial_series_selector=self.spatial_series_selector,
smoothing_slider=self.smoothing_slider,
bin_size_slider=self.bin_size_slider,
flip_y_axis=self.flip_y_axis,
)
self.rate_maps, self.extent = None, None
self.compute_rate_maps()
Expand Down Expand Up @@ -222,15 +221,7 @@ def get_spatial_series(self):

def compute_rate_maps(self):
import pynapple as nap
try:
from spatial_maps import SpatialMap
HAVE_SPATIAL_MAPS = True
except:
warnings.warn(
"spatial_maps not installed. Please install it to compute rate maps:\n"
">>> pip install git+https://github.com/CINPLA/spatial-maps.git"
)
HAVE_SPATIAL_MAPS = False
from spatial_maps import SpatialMap

spatial_series = self.spatial_series[self.spatial_series_selector.value]
x, y = spatial_series.data[:].T
Expand Down Expand Up @@ -265,30 +256,36 @@ def compute_rate_maps(self):
self.nap_position = nap_position
self.nap_units = nap_units

if HAVE_SPATIAL_MAPS:
sm = SpatialMap(
bin_size=self.bin_size_slider.value,
smoothing=self.smoothing_slider.value,
)
rate_maps = []
for unit_index in self.units.id.data:
rate_map = sm.rate_map(x, y, t, unit_spike_times[unit_index])
rate_maps.append(rate_map)
self.rate_maps = np.array(rate_maps)
else:
self.rate_maps = None
sm = SpatialMap(
bin_size=self.bin_size_slider.value,
smoothing=self.smoothing_slider.value,
)
rate_maps = []
for unit_index in self.units.id.data:
rate_map = sm.rate_map(x, y, t, unit_spike_times[unit_index])
rate_maps.append(rate_map)
self.rate_maps = np.array(rate_maps)

def on_spatial_series_change(self, change):
self.compute_rate_maps()
self.show_unit_rate_maps(self.unit_list.value)

def on_bin_size_change(self, change):
self.compute_rate_maps()
self.show_unit_rate_maps(self.unit_list.value)

def on_smoothing_change(self, change):
self.compute_rate_maps()
self.show_unit_rate_maps(self.unit_list.value)

def show_unit_rate_maps(
self, unit_index=None, spatial_series_selector=None, smoothing_slider=None, bin_size_slider=None, axs=None
self,
unit_index=None,
spatial_series_selector=None,
smoothing_slider=None,
bin_size_slider=None,
flip_y_axis=None,
axs=None,
):
"""
Shows unit rate maps.
Expand All @@ -305,21 +302,26 @@ def show_unit_rate_maps(
figsize = (10, 7)

if axs is None:
fig, axs = plt.subplots(figsize=figsize, ncols=2)
fig, axs = plt.subplots(figsize=figsize, ncols=2, sharex=True, sharey=True)
if hasattr(fig, "canvas"):
fig.canvas.header_visible = False
else:
legend_kwargs.update(bbox_to_anchor=(1.01, 1))
if self.rate_maps is not None:
axs[0].imshow(self.rate_maps[unit_index], cmap="viridis", origin="lower", aspect="auto", extent=self.extent)
axs[0].set_xlabel("x")
axs[0].set_ylabel("y")
else:
axs[0].set_title("Rate maps not computed (spatial_maps not installed)")
origin = "lower" if self.flip_y_axis.value else "upper"
axs[0].imshow(self.rate_maps[unit_index], cmap="viridis", origin=origin, aspect="auto", extent=self.extent)
axs[0].set_xlabel("x")
axs[0].set_ylabel("y")

axs[1].plot(self.nap_position["y"], self.nap_position["x"], color="grey")
tracking_x = self.nap_position["y"]
tracking_y = self.nap_position["x"]
spk_pos = self.nap_units[unit_index].value_from(self.nap_position)
axs[1].plot(spk_pos["y"], spk_pos["x"], "o", color="red", markersize=5, alpha=0.5)
spike_pos_x = spk_pos["y"]
spike_pos_y = spk_pos["x"]
if not self.flip_y_axis.value:
tracking_y = 1 - tracking_y
spike_pos_y = 1 - spike_pos_y
axs[1].plot(tracking_x, tracking_y, color="grey")
axs[1].plot(spike_pos_x, spike_pos_y, "o", color="red", markersize=5, alpha=0.5)
axs[1].set_xlabel("x")
axs[1].set_ylabel("y")
fig.tight_layout()
Expand Down
6 changes: 4 additions & 2 deletions src/expipe_plugin_cinpla/scripts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,12 +297,14 @@ def compute_and_set_unit_groups(sorting, recording):
unit_groups = recording.get_channel_groups()[np.array(list(extremum_channel_indices.values()))]
sorting.set_property("group", unit_groups)
else:
unit_groups = sorting.get_property("group").astype("str")
unit_groups = sorting.get_property("group").astype("U20")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mariapfj this was the cause of the last issue. In case there are only merges/splits in Phy, all groups are nan so casting to string makes a 3-character object. When setting a value to tetrode*, it became tet...

# if there are units without group, we need to compute them
unit_ids_without_group = np.array(sorting.unit_ids)[np.where(unit_groups == "nan")[0]]
if len(unit_ids_without_group) > 0:
sorting_no_group = sorting.select_units(unit_ids=unit_ids_without_group)
we_mem = si.extract_waveforms(recording, sorting_no_group, folder=None, mode="memory", sparse=False)
we_mem = si.extract_waveforms(
recording, sorting_no_group, folder=None, mode="memory", sparse=False, progress_bar=False
)
extremum_channel_indices = si.get_template_extremum_channel(we_mem, outputs="index")
unit_groups[sorting.ids_to_indices(unit_ids_without_group)] = recording.get_channel_groups()[
np.array(list(extremum_channel_indices.values()))
Expand Down
1 change: 1 addition & 0 deletions src/expipe_plugin_cinpla/tools/data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import expipe
import numpy as np
import spatial_maps as sp

from expipe_plugin_cinpla.data_loader import (
get_channel_groups,
Expand Down
25 changes: 21 additions & 4 deletions src/expipe_plugin_cinpla/widgets/curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,10 @@ def __init__(self, project):
load_from_phy = ipywidgets.Button(description="Load from Phy", layout={"width": "initial"})
load_from_phy.style.button_color = "pink"

restore_phy = ipywidgets.ToggleButton(
self.restore_phy_clicked = False
restore_phy = ipywidgets.Button(
value=False,
description="Restore",
description="Restore (click twice to restore)",
disabled=False,
button_style="", # 'success', 'info', 'warning', 'danger' or ''
tooltip="Restore unsorted clusters",
Expand Down Expand Up @@ -205,6 +206,14 @@ def on_action(change):
sorters = [p.name for p in si_path.iterdir() if p.is_dir()]
sorter_list.options = sorters
if len(sorter_list.value) == 1:
units_raw = self.sorting_curator.load_raw_units(sorter_list.value[0])
if units_raw is not None:
w = nwb2widget(units_raw, custom_raw_unit_vis)
units_viewers["raw"] = w
units_main = self.sorting_curator.load_main_units()
if units_main is not None:
w = nwb2widget(units_main, custom_main_unit_vis)
units_viewers["main"] = w
if strategy.value == "Sortingview":
sv_visualization_link.value = self.sorting_curator.get_sortingview_link(sorter_list.value[0])
elif strategy.value == "Phy":
Expand Down Expand Up @@ -313,7 +322,15 @@ def on_restore_phy(change):
if len(sorter_list.value) > 1:
print("Select one spike sorting output at a time")
else:
self.sorting_curator.restore_phy(sorter_list.value[0])
if self.restore_phy_clicked:
self.sorting_curator.restore_phy(sorter_list.value[0])
self.restore_phy_clicked = False
restore_phy.description = "Restore (click twice to restore)"
restore_phy.button_style = "primary"
else:
self.restore_phy_clicked = True
restore_phy.description = "Restore (click again to confirm)"
restore_phy.button_style = "danger"

@self.output.capture()
def on_apply_qm_curation(change):
Expand All @@ -335,7 +352,7 @@ def on_save_to_nwb(change):

actions_list.observe(on_action)
load_from_phy.on_click(on_load_phy)
restore_phy.observe(on_restore_phy)
restore_phy.on_click(on_restore_phy)
run_save.on_click(on_save_to_nwb)
strategy.observe(on_change_strategy)
sorter_list.observe(on_sorter)
Expand Down
Loading