Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add convinience properties #19

Merged
merged 3 commits into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ data = np.array(
[2, 7, 4, 1], # Bin counts around the third timestamp
],
],
dtype="uint64",
)

event_timestamps = np.array([0.25, 5.0, 12.25]) # The timestamps to which we align the counts
Expand Down Expand Up @@ -249,7 +250,7 @@ data_for_first_stimuli = np.array(
],
)

# Also two units and 4 bins but this event appeared three times
# Also two units and 4 bins but this condition occurred three times
data_for_second_stimuli = np.array(
[
# Unit 1
Expand Down
19 changes: 19 additions & 0 deletions src/pynwb/ndx_binned_spikes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,25 @@ def sort_data_by_event_timestamps(

return data, event_timestamps, condition_indices

@property
def number_of_units(self):
return self.data.shape[0]

@property
def number_of_events(self):
return self.data.shape[1]

@property
def number_of_bins(self):
return self.data.shape[2]


@property
def number_of_conditions(self):
if self.has_multiple_conditions:
return np.unique(self.condition_indices).size
else:
return 1

# Remove these functions from the package
del load_namespaces, get_class
6 changes: 3 additions & 3 deletions src/pynwb/ndx_binned_spikes/testing/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def mock_BinnedAlignedSpikes(
number_of_units, number_of_events, number_of_bins = data.shape
else:
rng = np.random.default_rng(seed=seed)
data = rng.integers(low=0, high=100, size=(number_of_units, number_of_events, number_of_bins))
data = rng.integers(low=0, high=100, size=(number_of_units, number_of_events, number_of_bins), dtype="uint64")

# Assert data shapes
assertion_msg = (
Expand All @@ -121,8 +121,8 @@ def mock_BinnedAlignedSpikes(
number_of_conditions < number_of_events
), "The number of conditions should be less than the number of events."

condition_indices = np.zeros(number_of_events, dtype=int)
all_indices = np.arange(number_of_conditions, dtype=int)
condition_indices = np.zeros(number_of_events, dtype="uint64")
all_indices = np.arange(number_of_conditions, dtype='uint64')

# Ensure all conditions indices appear at least once
condition_indices[:number_of_conditions] = rng.choice(all_indices, size=number_of_conditions, replace=False)
Expand Down
82 changes: 50 additions & 32 deletions src/pynwb/tests/test_binned_aligned_spikes.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,15 @@ def test_constructor(self):

np.testing.assert_array_equal(binned_aligned_spikes.data, self.data)
np.testing.assert_array_equal(binned_aligned_spikes.event_timestamps, self.event_timestamps)

self.assertEqual(binned_aligned_spikes.bin_width_in_milliseconds, self.bin_width_in_milliseconds)
self.assertEqual(
binned_aligned_spikes.milliseconds_from_event_to_first_bin, self.milliseconds_from_event_to_first_bin
)

self.assertEqual(binned_aligned_spikes.data.shape[0], self.number_of_units)
self.assertEqual(binned_aligned_spikes.data.shape[1], self.number_of_events)
self.assertEqual(binned_aligned_spikes.data.shape[2], self.number_of_bins)
self.assertEqual(binned_aligned_spikes.number_of_units, self.number_of_units)
self.assertEqual(binned_aligned_spikes.number_of_events, self.number_of_events)
self.assertEqual(binned_aligned_spikes.number_of_bins, self.number_of_bins)

def test_constructor_units_region(self):

Expand Down Expand Up @@ -121,7 +122,7 @@ def setUp(self):
self.bin_width_in_milliseconds = 20.0
self.milliseconds_from_event_to_first_bin = -100.0

# Two units in total and 4 bins, and event with two timestamps
# Two units in total and 4 bins, and condition with two timestamps
self.data_for_first_condition = np.array(
[
# Unit 1 data
Expand All @@ -135,9 +136,11 @@ def setUp(self):
[12, 13, 14, 15], # Bin counts around the second timestamp
],
],
dtype="uint64",

)

# Also two units and 4 bins but this event appeared three times
# Also two units and 4 bins but this condition appeared three times
self.data_for_second_condition = np.array(
[
# Unit 1 data
Expand All @@ -152,7 +155,8 @@ def setUp(self):
[16, 17, 18, 19], # Bin counts around the second timestamp
[20, 21, 22, 23], # Bin counts around the third timestamp
],
]
],
dtype="uint64",
)

self.timestamps_first_condition = [5.0, 15.0]
Expand All @@ -167,7 +171,7 @@ def setUp(self):
self.event_timestamps = np.concatenate([self.timestamps_first_condition, self.timestamps_second_condition])

self.sorted_indices = np.argsort(self.event_timestamps)

self.condition_labels = ["first", "second"]

def test_constructor(self):
Expand All @@ -189,7 +193,7 @@ def test_constructor(self):
self.condition_indices,
)

aggregated_binnned_align_spikes = BinnedAlignedSpikes(
binnned_align_spikes = BinnedAlignedSpikes(
bin_width_in_milliseconds=self.bin_width_in_milliseconds,
milliseconds_from_event_to_first_bin=self.milliseconds_from_event_to_first_bin,
data=data,
Expand All @@ -198,27 +202,23 @@ def test_constructor(self):
condition_labels=self.condition_labels,
)

np.testing.assert_array_equal(aggregated_binnned_align_spikes.data, self.data[:, self.sorted_indices, :])
np.testing.assert_array_equal(binnned_align_spikes.data, self.data[:, self.sorted_indices, :])
np.testing.assert_array_equal(
aggregated_binnned_align_spikes.condition_indices, self.condition_indices[self.sorted_indices]
binnned_align_spikes.condition_indices, self.condition_indices[self.sorted_indices]
)
np.testing.assert_array_equal(
aggregated_binnned_align_spikes.event_timestamps, self.event_timestamps[self.sorted_indices]
)

np.testing.assert_array_equal(
aggregated_binnned_align_spikes.condition_labels, self.condition_labels
)

self.assertEqual(aggregated_binnned_align_spikes.bin_width_in_milliseconds, self.bin_width_in_milliseconds)
np.testing.assert_array_equal(binnned_align_spikes.event_timestamps, self.event_timestamps[self.sorted_indices])

np.testing.assert_array_equal(binnned_align_spikes.condition_labels, self.condition_labels)

self.assertEqual(binnned_align_spikes.bin_width_in_milliseconds, self.bin_width_in_milliseconds)
self.assertEqual(
aggregated_binnned_align_spikes.milliseconds_from_event_to_first_bin,
binnned_align_spikes.milliseconds_from_event_to_first_bin,
self.milliseconds_from_event_to_first_bin,
)

self.assertEqual(aggregated_binnned_align_spikes.data.shape[0], self.number_of_units)
self.assertEqual(aggregated_binnned_align_spikes.data.shape[1], self.number_of_events)
self.assertEqual(aggregated_binnned_align_spikes.data.shape[2], self.number_of_bins)
self.assertEqual(binnned_align_spikes.number_of_units, self.number_of_units)
self.assertEqual(binnned_align_spikes.number_of_events, self.number_of_events)
self.assertEqual(binnned_align_spikes.number_of_bins, self.number_of_bins)

def test_get_single_condition_data_methods(self):

Expand All @@ -228,24 +228,24 @@ def test_get_single_condition_data_methods(self):
self.condition_indices,
)

aggregated_binnned_align_spikes = BinnedAlignedSpikes(
binnned_align_spikes = BinnedAlignedSpikes(
bin_width_in_milliseconds=self.bin_width_in_milliseconds,
milliseconds_from_event_to_first_bin=self.milliseconds_from_event_to_first_bin,
data=data,
event_timestamps=event_timestamps,
condition_indices=condition_indices,
)

data_condition1 = aggregated_binnned_align_spikes.get_data_for_condition(condition_index=0)
data_condition1 = binnned_align_spikes.get_data_for_condition(condition_index=0)
np.testing.assert_allclose(data_condition1, self.data_for_first_condition)

data_condition2 = aggregated_binnned_align_spikes.get_data_for_condition(condition_index=1)
data_condition2 = binnned_align_spikes.get_data_for_condition(condition_index=1)
np.testing.assert_allclose(data_condition2, self.data_for_second_condition)

timestamps_condition1 = aggregated_binnned_align_spikes.get_event_timestamps_for_condition(condition_index=0)
timestamps_condition1 = binnned_align_spikes.get_event_timestamps_for_condition(condition_index=0)
np.testing.assert_allclose(timestamps_condition1, self.timestamps_first_condition)

timestamps_condition2 = aggregated_binnned_align_spikes.get_event_timestamps_for_condition(condition_index=1)
timestamps_condition2 = binnned_align_spikes.get_event_timestamps_for_condition(condition_index=1)
np.testing.assert_allclose(timestamps_condition2, self.timestamps_second_condition)


Expand All @@ -265,9 +265,20 @@ def test_roundtrip_acquisition(self):
Add a BinnedAlignedSpikes to an NWBFile, write it to file, read the file
and test that the BinnedAlignedSpikes from the file matches the original BinnedAlignedSpikes.
"""

# Testing here
self.binned_aligned_spikes = mock_BinnedAlignedSpikes(number_of_conditions=3, condition_labels=["a", "b", "c"])
number_of_units = 5
number_of_bins = 10
number_of_events = 100
number_of_conditions = 3
condition_labels = ["a", "b", "c"]

self.binned_aligned_spikes = mock_BinnedAlignedSpikes(
number_of_units=number_of_units,
number_of_bins=number_of_bins,
number_of_events=number_of_events,
number_of_conditions=number_of_conditions,
condition_labels=condition_labels,
)

self.nwbfile.add_acquisition(self.binned_aligned_spikes)

Expand All @@ -276,8 +287,13 @@ def test_roundtrip_acquisition(self):

with NWBHDF5IO(self.path, mode="r", load_namespaces=True) as io:
read_nwbfile = io.read()
read_container = read_nwbfile.acquisition["BinnedAlignedSpikes"]
self.assertContainerEqual(self.binned_aligned_spikes, read_container)
read_binned_aligned_spikes = read_nwbfile.acquisition["BinnedAlignedSpikes"]
self.assertContainerEqual(self.binned_aligned_spikes, read_binned_aligned_spikes)

assert read_binned_aligned_spikes.number_of_units == number_of_units
assert read_binned_aligned_spikes.number_of_bins == number_of_bins
assert read_binned_aligned_spikes.number_of_events == number_of_events
assert read_binned_aligned_spikes.number_of_conditions == number_of_conditions

def test_roundtrip_processing_module(self):
self.binned_aligned_spikes = mock_BinnedAlignedSpikes()
Expand Down Expand Up @@ -312,3 +328,5 @@ def test_roundtrip_with_units_table(self):
read_nwbfile = io.read()
read_container = read_nwbfile.acquisition["BinnedAlignedSpikes"]
self.assertContainerEqual(binned_aligned_spikes_with_region, read_container)


Loading