Skip to content

Commit

Permalink
Merge pull request #94 from CINPLA/spatial-map
Browse files Browse the repository at this point in the history
Use `spatial-map` package for ratemaps
  • Loading branch information
alejoe91 authored Nov 21, 2024
2 parents 2d3823d + bd4c4ea commit b74e709
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 42 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ dependencies = [
"tbb>=2021.11.0; platform_system != 'Darwin'",
"pynapple>=0.5.1",
"lxml",
"spatial_maps@git+https://github.com/CINPLA/spatial-maps",
]

[project.urls]
Expand Down
10 changes: 0 additions & 10 deletions requirements.txt

This file was deleted.

107 changes: 78 additions & 29 deletions src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
from functools import partial

import warnings
import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -105,7 +106,6 @@ def show_unit_waveforms(units: "pynwb.mis.Units", unit_index=None, ax=None):
return ax


# TODO: use SpatialMaps instead
class UnitRateMapWidget(widgets.VBox):
def __init__(
self,
Expand All @@ -124,13 +124,13 @@ def __init__(
elif len(self.spatial_series) > 1:
self.spatial_series_selector = widgets.Dropdown(
options=list(self.spatial_series.keys()),
description="Spatial Series:",
description="",
layout=dict(width="200px", display="flex", justify_content="flex-start"),
)
else:
self.spatial_series_selector = widgets.Text(
value=list(self.spatial_series.keys())[0],
description="Spatial Series:",
description="",
layout=dict(width="200px", display="flex", justify_content="flex-start"),
disabled=True,
)
Expand All @@ -150,28 +150,42 @@ def __init__(
if "original_cluster_id" in self.units.colnames:
unit_info_text += " - Phy ID: "
self.unit_info_text = widgets.Label(unit_info_text, layout=dict(width="90%"))
self.num_bins_slider = widgets.IntSlider(
value=30,
min=5,
max=100,
step=1,
description="Bins:",
self.smoothing_slider = widgets.FloatSlider(
value=0.05,
min=0,
max=0.2,
step=0.01,
description="Smoothing:",
)
self.bin_size_slider = widgets.FloatSlider(
value=0.02,
min=0.01,
max=0.2,
step=0.01,
description="Bin size:",
)
spatial_series_label = widgets.Label("Spatial Series:")
top_panel = widgets.VBox(
[
self.unit_list,
self.unit_name_text,
self.unit_info_text,
self.spatial_series_selector,
self.num_bins_slider,
widgets.HBox(
[
spatial_series_label,
self.spatial_series_selector,
]
),
widgets.HBox([self.smoothing_slider, self.bin_size_slider]),
]
)
self.controls = dict(
unit_index=self.unit_list,
spatial_series_selector=self.spatial_series_selector,
num_bins_slider=self.num_bins_slider,
smoothing_slider=self.smoothing_slider,
bin_size_slider=self.bin_size_slider,
)
self.rate_maps, self.binsxy, self.extent = None, None, None
self.rate_maps, self.extent = None, None
self.compute_rate_maps()

out_fig = interactive_output(self.show_unit_rate_maps, self.controls)
Expand All @@ -182,7 +196,8 @@ def __init__(

self.unit_list.observe(self.on_unit_change, names="value")
self.spatial_series_selector.observe(self.on_spatial_series_change, names="value")
self.num_bins_slider.observe(self.on_num_bins_change, names="value")
self.smoothing_slider.observe(self.on_smoothing_change, names="value")
self.bin_size_slider.observe(self.on_bin_size_change, names="value")
self.on_unit_change(None)

def on_unit_change(self, change):
Expand All @@ -207,16 +222,33 @@ def get_spatial_series(self):

def compute_rate_maps(self):
import pynapple as nap
try:
from spatial_maps import SpatialMap
HAVE_SPATIAL_MAPS = True
except:
warnings.warn(
"spatial_maps not installed. Please install it to compute rate maps:\n"
">>> pip install git+https://github.com/CINPLA/spatial-maps.git"
)
HAVE_SPATIAL_MAPS = False

spatial_series = self.spatial_series[self.spatial_series_selector.value]
x, y = spatial_series.data[:].T
t = spatial_series.timestamps[:]

# Remove NaNs
mask = np.logical_not(np.isnan(spatial_series.data)).T
mask_and = np.logical_and(mask[0], mask[1])
# Remove NaNs and zeros
mask_nan = np.logical_not(np.isnan(spatial_series.data)).T
mask_nan = np.logical_and(mask_nan[0], mask_nan[1])
mask_zeros = np.logical_and(x != 0, y != 0)
mask = np.logical_and(mask_nan, mask_zeros)

x = x[mask]
y = y[mask]
t = t[mask]

nap_position = nap.TsdFrame(
d=spatial_series.data[mask_and],
t=spatial_series.timestamps[mask_and],
d=spatial_series.data[mask],
t=spatial_series.timestamps[mask],
columns=["x", "y"],
)
self.extent = (
Expand All @@ -230,17 +262,34 @@ def compute_rate_maps(self):
unit_names = self.units["unit_name"][:]
unit_spike_times = self.units["spike_times"][:]
nap_units = nap.TsGroup({i: np.array(unit_spike_times[i]) for i in range(len(unit_names))})
self.rate_maps, self.binsxy = nap.compute_2d_tuning_curves(nap_units, nap_position, self.num_bins_slider.value)
self.nap_position = nap_position
self.nap_units = nap_units

if HAVE_SPATIAL_MAPS:
sm = SpatialMap(
bin_size=self.bin_size_slider.value,
smoothing=self.smoothing_slider.value,
)
rate_maps = []
for unit_index in self.units.id.data:
rate_map = sm.rate_map(x, y, t, unit_spike_times[unit_index])
rate_maps.append(rate_map)
self.rate_maps = np.array(rate_maps)
else:
self.rate_maps = None

def on_spatial_series_change(self, change):
self.compute_rate_maps()

def on_num_bins_change(self, change):
def on_bin_size_change(self, change):
self.compute_rate_maps()

def show_unit_rate_maps(self, unit_index=None, spatial_series_selector=None, num_bins_slider=None, axs=None):
def on_smoothing_change(self, change):
self.compute_rate_maps()

def show_unit_rate_maps(
self, unit_index=None, spatial_series_selector=None, smoothing_slider=None, bin_size_slider=None, axs=None
):
"""
Shows unit rate maps.
Expand All @@ -251,20 +300,22 @@ def show_unit_rate_maps(self, unit_index=None, spatial_series_selector=None, num
"""
if unit_index is None:
return
if self.rate_maps is None:
return

legend_kwargs = dict()
figsize = (10, 7)

if axs is None:
fig, axs = plt.subplots(figsize=figsize, ncols=2)
if hasattr(fig, "canvas"):
fig.canvas.header_visible = False
else:
legend_kwargs.update(bbox_to_anchor=(1.01, 1))
axs[0].imshow(self.rate_maps[unit_index], cmap="viridis", origin="lower", aspect="auto", extent=self.extent)
axs[0].set_xlabel("x")
axs[0].set_ylabel("y")
if self.rate_maps is not None:
axs[0].imshow(self.rate_maps[unit_index], cmap="viridis", origin="lower", aspect="auto", extent=self.extent)
axs[0].set_xlabel("x")
axs[0].set_ylabel("y")
else:
axs[0].set_title("Rate maps not computed (spatial_maps not installed)")

axs[1].plot(self.nap_position["y"], self.nap_position["x"], color="grey")
spk_pos = self.nap_units[unit_index].value_from(self.nap_position)
Expand Down Expand Up @@ -300,6 +351,4 @@ def get_custom_spec():

custom_neurodata_vis_spec[Units] = units_view

# TODO: add Place Fields widget

return custom_neurodata_vis_spec
4 changes: 4 additions & 0 deletions src/expipe_plugin_cinpla/scripts/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ def process_ecephys(
nwb_path = utils._get_data_path(action)
nwb_path_tmp = nwb_path.parent / "main_tmp.nwb"

# clean up tmp NWB file in case of crash
if nwb_path_tmp.is_file():
nwb_path_tmp.unlink()

si.set_global_job_kwargs(n_jobs=-1, progress_bar=False)

if overwrite:
Expand Down
1 change: 0 additions & 1 deletion src/expipe_plugin_cinpla/tools/data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import expipe
import numpy as np
import spatial_maps as sp

from expipe_plugin_cinpla.data_loader import (
get_channel_groups,
Expand Down
2 changes: 1 addition & 1 deletion src/expipe_plugin_cinpla/tools/trackunitmulticomparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def plot_matches(self, channel_group=None, figsize=(10, 3)):
The figure size
"""
if channel_group is None:
ch_groups = self.identified_units.keys()
ch_groups = sorted(self.identified_units.keys())
else:
ch_groups = [channel_group]
for ch_group in ch_groups:
Expand Down

0 comments on commit b74e709

Please sign in to comment.