Skip to content

Commit

Permalink
add auxiliar method, not sort by default
Browse files Browse the repository at this point in the history
  • Loading branch information
h-mayorquin committed Aug 14, 2024
1 parent c0cc715 commit 532d7f7
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 17 deletions.
28 changes: 15 additions & 13 deletions src/pynwb/ndx_binned_spikes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,19 +186,21 @@ def __init__(self, **kwargs):
name = kwargs.pop("name")
super().__init__(name=name)

# # Sort the data by the timestamps
# timestamps = kwargs["timestamps"]
# event_indices = kwargs["event_indices"]
# data = kwargs["data"]
timestamps = kwargs["timestamps"]
event_indices = kwargs["event_indices"]
data = kwargs["data"]

# sorted_indices = np.argsort(timestamps)
# data = data[:, sorted_indices, :]
# timestamps = timestamps[sorted_indices]
# event_indices = event_indices[sorted_indices]
assert data.shape[1] == timestamps.shape[0], "The number of timestamps must match the second axis of data."
assert event_indices.shape[0] == timestamps.shape[0], "The number of timestamps must match the event_indices."

# kwargs["data"] = data
# kwargs["timestamps"] = timestamps
# kwargs["event_indices"] = event_indices
# Assert timestamps are monotonically increasing
if not np.all(np.diff(kwargs["timestamps"]) >= 0):
error_msg = (
"The timestamps must be monotonically increasing and the data and event_indices "
"must be sorted by timestamps. Use the `sort_data_by_timestamps` method to do this "
"automatically before passing the data to the constructor."
)
raise ValueError(error_msg)

for key in kwargs:
setattr(self, key, kwargs[key])
Expand All @@ -220,12 +222,12 @@ def get_timestamps_for_stimuli(self, event_index):
return timestamps

@staticmethod
def sort_data_by_time(
def sort_data_by_timestamps(
data: np.ndarray,
timestamps: np.ndarray,
event_indices: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:

sorted_indices = np.argsort(timestamps)
data = data[:, sorted_indices, :]
timestamps = timestamps[sorted_indices]
Expand Down
26 changes: 22 additions & 4 deletions src/pynwb/tests/test_aggregated_binned_aligned_spikes.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,22 @@ def setUp(self):
def test_constructor(self):
"""Test that the constructor for AggregatedBinnedAlignedSpikes sets values as expected."""

data, timestamps, event_indices = AggregatedBinnedAlignedSpikes.sort_data_by_time(
with self.assertRaises(ValueError):
AggregatedBinnedAlignedSpikes(
bin_width_in_milliseconds=self.bin_width_in_milliseconds,
milliseconds_from_event_to_first_bin=self.milliseconds_from_event_to_first_bin,
data=self.data,
timestamps=self.timestamps,
event_indices=self.event_indices,
)


data, timestamps, event_indices = AggregatedBinnedAlignedSpikes.sort_data_by_timestamps(
self.data,
self.timestamps,
self.event_indices,
)

aggregated_binnned_align_spikes = AggregatedBinnedAlignedSpikes(
bin_width_in_milliseconds=self.bin_width_in_milliseconds,
milliseconds_from_event_to_first_bin=self.milliseconds_from_event_to_first_bin,
Expand All @@ -104,12 +115,19 @@ def test_constructor(self):

def test_get_single_event_data_methods(self):


data, timestamps, event_indices = AggregatedBinnedAlignedSpikes.sort_data_by_timestamps(
self.data,
self.timestamps,
self.event_indices,
)

aggregated_binnned_align_spikes = AggregatedBinnedAlignedSpikes(
bin_width_in_milliseconds=self.bin_width_in_milliseconds,
milliseconds_from_event_to_first_bin=self.milliseconds_from_event_to_first_bin,
data=self.data,
timestamps=self.timestamps,
event_indices=self.event_indices,
data=data,
timestamps=timestamps,
event_indices=event_indices,
)

data_for_stimuli_1 = aggregated_binnned_align_spikes.get_data_for_stimuli(event_index=0)
Expand Down

0 comments on commit 532d7f7

Please sign in to comment.