Skip to content

Commit

Permalink
Multi segment handling in ensure_time_bins
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelgarcia committed Jan 30, 2025
1 parent cbe5471 commit a0d658d
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 20 deletions.
37 changes: 29 additions & 8 deletions src/spikeinterface/sortingcomponents/motion/motion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,10 +587,14 @@ def ensure_time_bins(time_bin_centers_s=None, time_bin_edges_s=None):
Going from centers to edges is done by taking midpoints and padding with the
left and rightmost centers.
To handle multi segment, this function is working both:
* array/array input
* list[array]/list[array] input
Parameters
----------
time_bin_centers_s : None or np.array
time_bin_edges_s : None or np.array
time_bin_centers_s : None or np.array or list[np.array]
time_bin_edges_s : None or np.array or list[np.array]
Returns
-------
Expand All @@ -600,17 +604,34 @@ def ensure_time_bins(time_bin_centers_s=None, time_bin_edges_s=None):
raise ValueError("Need at least one of time_bin_centers_s or time_bin_edges_s.")

if time_bin_centers_s is None:
assert time_bin_edges_s.ndim == 1 and time_bin_edges_s.size >= 2
time_bin_centers_s = 0.5 * (time_bin_edges_s[1:] + time_bin_edges_s[:-1])
if isinstance(time_bin_edges_s, list):
# multi segment cas
time_bin_centers_s = []
for be in time_bin_edges_s:
bc, _ = ensure_time_bins(time_bin_centers_s=None, time_bin_edges_s=be)
time_bin_centers_s.append(bc)
else:
# simple segment
assert time_bin_edges_s.ndim == 1 and time_bin_edges_s.size >= 2
time_bin_centers_s = 0.5 * (time_bin_edges_s[1:] + time_bin_edges_s[:-1])

if time_bin_edges_s is None:
time_bin_edges_s = np.empty(time_bin_centers_s.shape[0] + 1, dtype=time_bin_centers_s.dtype)
time_bin_edges_s[[0, -1]] = time_bin_centers_s[[0, -1]]
if time_bin_centers_s.size > 2:
time_bin_edges_s[1:-1] = 0.5 * (time_bin_centers_s[1:] + time_bin_centers_s[:-1])
if isinstance(time_bin_centers_s, list):
# multi segment cas
time_bin_edges_s = []
for bc in time_bin_centers_s:
_, be = ensure_time_bins(time_bin_centers_s=bc, time_bin_edges_s=None)
time_bin_edges_s.append(be)
else:
# simple segment
time_bin_edges_s = np.empty(time_bin_centers_s.shape[0] + 1, dtype=time_bin_centers_s.dtype)
time_bin_edges_s[[0, -1]] = time_bin_centers_s[[0, -1]]
if time_bin_centers_s.size > 2:
time_bin_edges_s[1:-1] = 0.5 * (time_bin_centers_s[1:] + time_bin_centers_s[:-1])

return time_bin_centers_s, time_bin_edges_s



def ensure_time_bin_edges(time_bin_centers_s=None, time_bin_edges_s=None):
return ensure_time_bins(time_bin_centers_s, time_bin_edges_s)[1]
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,26 @@
interpolate_motion_on_traces,
)
from spikeinterface.sortingcomponents.tests.common import make_dataset

from spikeinterface.core import generate_ground_truth_recording

def make_fake_motion(rec):
# make a fake motion object
duration = rec.get_total_duration()

locs = rec.get_channel_locations()
temporal_bins = np.arange(0.5, duration - 0.49, 0.5)
spatial_bins = np.arange(locs[:, 1].min(), locs[:, 1].max(), 100)
displacement = np.zeros((temporal_bins.size, spatial_bins.size))
displacement[:, :] = np.linspace(-30, 30, temporal_bins.size)[:, None]

motion = Motion([displacement], [temporal_bins], spatial_bins, direction="y")
displacement = []
temporal_bins = []
for segment_index in range(rec.get_num_segments()):
duration = rec.get_duration(segment_index=segment_index)
seg_time_bins = np.arange(0.5, duration - 0.49, 0.5)
seg_disp = np.zeros((seg_time_bins.size, spatial_bins.size))
seg_disp[:, :] = np.linspace(-30, 30, seg_time_bins.size)[:, None]

temporal_bins.append(seg_time_bins)
displacement.append(seg_disp)

motion = Motion(displacement, temporal_bins, spatial_bins, direction="y")

return motion

Expand Down Expand Up @@ -176,7 +184,27 @@ def test_cross_band_interpolation():


def test_InterpolateMotionRecording():
rec, sorting = make_dataset()
# rec, sorting = make_dataset()

# 2 segments
rec, sorting = generate_ground_truth_recording(
durations=[30.0],
sampling_frequency=30000.0,
num_channels=32,
num_units=10,
generate_probe_kwargs=dict(
num_columns=2,
xpitch=20,
ypitch=20,
contact_shapes="circle",
contact_shape_params={"radius": 6},
),
generate_sorting_kwargs=dict(firing_rates=6.0, refractory_period_ms=4.0),
noise_kwargs=dict(noise_levels=5.0, strategy="on_the_fly"),
seed=2205,
)


motion = make_fake_motion(rec)

rec2 = InterpolateMotionRecording(rec, motion, border_mode="force_extrapolate")
Expand All @@ -187,15 +215,19 @@ def test_InterpolateMotionRecording():

rec2 = InterpolateMotionRecording(rec, motion, border_mode="remove_channels")
assert rec2.channel_ids.size == 24
for ch_id in (0, 1, 14, 15, 16, 17, 30, 31):
for ch_id in ("0", "1", "14", "15", "16", "17", "30", "31"):
assert ch_id not in rec2.channel_ids

traces = rec2.get_traces(segment_index=0, start_frame=0, end_frame=30000)
assert traces.shape == (30000, 24)

traces = rec2.get_traces(segment_index=0, start_frame=0, end_frame=30000, channel_ids=[3, 4])
traces = rec2.get_traces(segment_index=0, start_frame=0, end_frame=30000, channel_ids=["3", "4"])
assert traces.shape == (30000, 2)

# test dump.load when multi segments
rec2.dump("rec_motion_interp.pickle")
rec3 = sc.load("rec_motion_interp.pickle")

# import matplotlib.pyplot as plt
# import spikeinterface.widgets as sw
# fig, ax = plt.subplots()
Expand All @@ -207,7 +239,7 @@ def test_InterpolateMotionRecording():

if __name__ == "__main__":
# test_correct_motion_on_peaks()
test_interpolate_motion_on_traces()
# test_interpolate_motion_on_traces()
# test_interpolation_simple()
# test_InterpolateMotionRecording()
test_cross_band_interpolation()
test_InterpolateMotionRecording()
# test_cross_band_interpolation()

0 comments on commit a0d658d

Please sign in to comment.