diff --git a/README.md b/README.md index ba465db..cc4f199 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,7 @@ data = np.array( [2, 7, 4, 1], # Bin counts around the third timestamp ], ], + dtype="uint64", ) event_timestamps = np.array([0.25, 5.0, 12.25]) # The timestamps to which we align the counts @@ -249,7 +250,7 @@ data_for_first_stimuli = np.array( ], ) -# Also two units and 4 bins but this event appeared three times +# Also two units and 4 bins but this condition occurred three times data_for_second_stimuli = np.array( [ # Unit 1 diff --git a/src/pynwb/ndx_binned_spikes/__init__.py b/src/pynwb/ndx_binned_spikes/__init__.py index 53dd29c..bf83b3d 100644 --- a/src/pynwb/ndx_binned_spikes/__init__.py +++ b/src/pynwb/ndx_binned_spikes/__init__.py @@ -185,6 +185,25 @@ def sort_data_by_event_timestamps( return data, event_timestamps, condition_indices + @property + def number_of_units(self): + return self.data.shape[0] + + @property + def number_of_events(self): + return self.data.shape[1] + + @property + def number_of_bins(self): + return self.data.shape[2] + + + @property + def number_of_conditions(self): + if self.has_multiple_conditions: + return np.unique(self.condition_indices).size + else: + return 1 # Remove these functions from the package del load_namespaces, get_class diff --git a/src/pynwb/ndx_binned_spikes/testing/mock.py b/src/pynwb/ndx_binned_spikes/testing/mock.py index 939b928..798a6b5 100644 --- a/src/pynwb/ndx_binned_spikes/testing/mock.py +++ b/src/pynwb/ndx_binned_spikes/testing/mock.py @@ -99,7 +99,7 @@ def mock_BinnedAlignedSpikes( number_of_units, number_of_events, number_of_bins = data.shape else: rng = np.random.default_rng(seed=seed) - data = rng.integers(low=0, high=100, size=(number_of_units, number_of_events, number_of_bins)) + data = rng.integers(low=0, high=100, size=(number_of_units, number_of_events, number_of_bins), dtype="uint64") # Assert data shapes assertion_msg = ( @@ -121,8 +121,8 @@ def mock_BinnedAlignedSpikes( number_of_conditions < number_of_events ), "The number of conditions should be less than the number of events." - condition_indices = np.zeros(number_of_events, dtype=int) - all_indices = np.arange(number_of_conditions, dtype=int) + condition_indices = np.zeros(number_of_events, dtype="uint64") + all_indices = np.arange(number_of_conditions, dtype='uint64') # Ensure all conditions indices appear at least once condition_indices[:number_of_conditions] = rng.choice(all_indices, size=number_of_conditions, replace=False) diff --git a/src/pynwb/tests/test_binned_aligned_spikes.py b/src/pynwb/tests/test_binned_aligned_spikes.py index 2582987..e2322be 100644 --- a/src/pynwb/tests/test_binned_aligned_spikes.py +++ b/src/pynwb/tests/test_binned_aligned_spikes.py @@ -50,14 +50,15 @@ def test_constructor(self): np.testing.assert_array_equal(binned_aligned_spikes.data, self.data) np.testing.assert_array_equal(binned_aligned_spikes.event_timestamps, self.event_timestamps) + self.assertEqual(binned_aligned_spikes.bin_width_in_milliseconds, self.bin_width_in_milliseconds) self.assertEqual( binned_aligned_spikes.milliseconds_from_event_to_first_bin, self.milliseconds_from_event_to_first_bin ) - self.assertEqual(binned_aligned_spikes.data.shape[0], self.number_of_units) - self.assertEqual(binned_aligned_spikes.data.shape[1], self.number_of_events) - self.assertEqual(binned_aligned_spikes.data.shape[2], self.number_of_bins) + self.assertEqual(binned_aligned_spikes.number_of_units, self.number_of_units) + self.assertEqual(binned_aligned_spikes.number_of_events, self.number_of_events) + self.assertEqual(binned_aligned_spikes.number_of_bins, self.number_of_bins) def test_constructor_units_region(self): @@ -121,7 +122,7 @@ def setUp(self): self.bin_width_in_milliseconds = 20.0 self.milliseconds_from_event_to_first_bin = -100.0 - # Two units in total and 4 bins, and event with two timestamps + # Two units in total and 4 bins, and condition with two timestamps self.data_for_first_condition = np.array( [ # Unit 1 data @@ -135,9 +136,11 @@ def setUp(self): [12, 13, 14, 15], # Bin counts around the second timestamp ], ], + dtype="uint64", + ) - # Also two units and 4 bins but this event appeared three times + # Also two units and 4 bins but this condition appeared three times self.data_for_second_condition = np.array( [ # Unit 1 data @@ -152,7 +155,8 @@ def setUp(self): [16, 17, 18, 19], # Bin counts around the second timestamp [20, 21, 22, 23], # Bin counts around the third timestamp ], - ] + ], + dtype="uint64", ) self.timestamps_first_condition = [5.0, 15.0] @@ -167,7 +171,7 @@ def setUp(self): self.event_timestamps = np.concatenate([self.timestamps_first_condition, self.timestamps_second_condition]) self.sorted_indices = np.argsort(self.event_timestamps) - + self.condition_labels = ["first", "second"] def test_constructor(self): @@ -189,7 +193,7 @@ def test_constructor(self): self.condition_indices, ) - aggregated_binnned_align_spikes = BinnedAlignedSpikes( + binnned_align_spikes = BinnedAlignedSpikes( bin_width_in_milliseconds=self.bin_width_in_milliseconds, milliseconds_from_event_to_first_bin=self.milliseconds_from_event_to_first_bin, data=data, @@ -198,27 +202,23 @@ def test_constructor(self): condition_labels=self.condition_labels, ) - np.testing.assert_array_equal(aggregated_binnned_align_spikes.data, self.data[:, self.sorted_indices, :]) + np.testing.assert_array_equal(binnned_align_spikes.data, self.data[:, self.sorted_indices, :]) np.testing.assert_array_equal( - aggregated_binnned_align_spikes.condition_indices, self.condition_indices[self.sorted_indices] + binnned_align_spikes.condition_indices, self.condition_indices[self.sorted_indices] ) - np.testing.assert_array_equal( - aggregated_binnned_align_spikes.event_timestamps, self.event_timestamps[self.sorted_indices] - ) - - np.testing.assert_array_equal( - aggregated_binnned_align_spikes.condition_labels, self.condition_labels - ) - - self.assertEqual(aggregated_binnned_align_spikes.bin_width_in_milliseconds, self.bin_width_in_milliseconds) + np.testing.assert_array_equal(binnned_align_spikes.event_timestamps, self.event_timestamps[self.sorted_indices]) + + np.testing.assert_array_equal(binnned_align_spikes.condition_labels, self.condition_labels) + + self.assertEqual(binnned_align_spikes.bin_width_in_milliseconds, self.bin_width_in_milliseconds) self.assertEqual( - aggregated_binnned_align_spikes.milliseconds_from_event_to_first_bin, + binnned_align_spikes.milliseconds_from_event_to_first_bin, self.milliseconds_from_event_to_first_bin, ) - self.assertEqual(aggregated_binnned_align_spikes.data.shape[0], self.number_of_units) - self.assertEqual(aggregated_binnned_align_spikes.data.shape[1], self.number_of_events) - self.assertEqual(aggregated_binnned_align_spikes.data.shape[2], self.number_of_bins) + self.assertEqual(binnned_align_spikes.number_of_units, self.number_of_units) + self.assertEqual(binnned_align_spikes.number_of_events, self.number_of_events) + self.assertEqual(binnned_align_spikes.number_of_bins, self.number_of_bins) def test_get_single_condition_data_methods(self): @@ -228,7 +228,7 @@ def test_get_single_condition_data_methods(self): self.condition_indices, ) - aggregated_binnned_align_spikes = BinnedAlignedSpikes( + binnned_align_spikes = BinnedAlignedSpikes( bin_width_in_milliseconds=self.bin_width_in_milliseconds, milliseconds_from_event_to_first_bin=self.milliseconds_from_event_to_first_bin, data=data, @@ -236,16 +236,16 @@ def test_get_single_condition_data_methods(self): condition_indices=condition_indices, ) - data_condition1 = aggregated_binnned_align_spikes.get_data_for_condition(condition_index=0) + data_condition1 = binnned_align_spikes.get_data_for_condition(condition_index=0) np.testing.assert_allclose(data_condition1, self.data_for_first_condition) - data_condition2 = aggregated_binnned_align_spikes.get_data_for_condition(condition_index=1) + data_condition2 = binnned_align_spikes.get_data_for_condition(condition_index=1) np.testing.assert_allclose(data_condition2, self.data_for_second_condition) - timestamps_condition1 = aggregated_binnned_align_spikes.get_event_timestamps_for_condition(condition_index=0) + timestamps_condition1 = binnned_align_spikes.get_event_timestamps_for_condition(condition_index=0) np.testing.assert_allclose(timestamps_condition1, self.timestamps_first_condition) - timestamps_condition2 = aggregated_binnned_align_spikes.get_event_timestamps_for_condition(condition_index=1) + timestamps_condition2 = binnned_align_spikes.get_event_timestamps_for_condition(condition_index=1) np.testing.assert_allclose(timestamps_condition2, self.timestamps_second_condition) @@ -265,9 +265,20 @@ def test_roundtrip_acquisition(self): Add a BinnedAlignedSpikes to an NWBFile, write it to file, read the file and test that the BinnedAlignedSpikes from the file matches the original BinnedAlignedSpikes. """ - # Testing here - self.binned_aligned_spikes = mock_BinnedAlignedSpikes(number_of_conditions=3, condition_labels=["a", "b", "c"]) + number_of_units = 5 + number_of_bins = 10 + number_of_events = 100 + number_of_conditions = 3 + condition_labels = ["a", "b", "c"] + + self.binned_aligned_spikes = mock_BinnedAlignedSpikes( + number_of_units=number_of_units, + number_of_bins=number_of_bins, + number_of_events=number_of_events, + number_of_conditions=number_of_conditions, + condition_labels=condition_labels, + ) self.nwbfile.add_acquisition(self.binned_aligned_spikes) @@ -276,8 +287,13 @@ def test_roundtrip_acquisition(self): with NWBHDF5IO(self.path, mode="r", load_namespaces=True) as io: read_nwbfile = io.read() - read_container = read_nwbfile.acquisition["BinnedAlignedSpikes"] - self.assertContainerEqual(self.binned_aligned_spikes, read_container) + read_binned_aligned_spikes = read_nwbfile.acquisition["BinnedAlignedSpikes"] + self.assertContainerEqual(self.binned_aligned_spikes, read_binned_aligned_spikes) + + assert read_binned_aligned_spikes.number_of_units == number_of_units + assert read_binned_aligned_spikes.number_of_bins == number_of_bins + assert read_binned_aligned_spikes.number_of_events == number_of_events + assert read_binned_aligned_spikes.number_of_conditions == number_of_conditions def test_roundtrip_processing_module(self): self.binned_aligned_spikes = mock_BinnedAlignedSpikes() @@ -312,3 +328,5 @@ def test_roundtrip_with_units_table(self): read_nwbfile = io.read() read_container = read_nwbfile.acquisition["BinnedAlignedSpikes"] self.assertContainerEqual(binned_aligned_spikes_with_region, read_container) + +