Skip to content

Commit

Permalink
Merge pull request #7 from catalystneuro/add_init_validator
Browse files Browse the repository at this point in the history
Add init validation
  • Loading branch information
h-mayorquin authored Mar 21, 2024
2 parents 20adc71 + 583edd8 commit 9055d08
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 3 deletions.
84 changes: 83 additions & 1 deletion src/pynwb/ndx_binned_spikes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import os
import numpy as np

from pynwb import load_namespaces, get_class
from pynwb import register_class
from pynwb.core import NWBDataInterface
from hdmf.utils import docval, popargs_to_dict

try:
from importlib.resources import files
Expand All @@ -18,7 +23,84 @@
# Load the namespace
load_namespaces(str(__spec_path))

BinnedAlignedSpikes = get_class("BinnedAlignedSpikes", "ndx-binned-spikes")
# BinnedAlignedSpikes = get_class("BinnedAlignedSpikes", "ndx-binned-spikes")


@register_class(neurodata_type="BinnedAlignedSpikes", namespace="ndx-binned-spikes") #noqa
class BinnedAlignedSpikes(NWBDataInterface):
__nwbfields__ = (
"name",
"bin_width_in_milliseconds",
"milliseconds_from_event_to_first_bin",
"data",
"event_timestamps",
"units",
)

DEFAULT_NAME = "BinnedAlignedSpikes"

@docval(
{
"name": "name",
"type": str,
"doc": "The name of this container",
"default": DEFAULT_NAME,
},
{
"name": "bin_width_in_milliseconds",
"type": float,
"doc": "The length in milliseconds of the bins",
},
{
"name": "milliseconds_from_event_to_first_bin",
"type": float,
"doc": (
"The time in milliseconds from the event (e.g. a stimuli or the beginning of a trial),"
"to the first bin. Note that this is a negative number if the first bin is before the event."
),
"default": 0.0,
},
{
"name": "data",
"type": "array_data",
"shape": [(None, None, None)],
"doc": "The source of the data",
},
{
"name": "event_timestamps",
"type": "array_data",
"doc": "The timestamps at which the event occurred.",
"shape": (None,),
},
{
"name": "units",
"type": ("DynamicTableRegion"),
"doc": "A reference to the Units table region that contains the units of the data.",
"default": None,
},
)
def __init__(self, **kwargs):

keys_to_set = ("bin_width_in_milliseconds", "milliseconds_from_event_to_first_bin", "units")
args_to_set = popargs_to_dict(keys_to_set, kwargs)

keys_to_process = ("data", "event_timestamps")
args_to_process = popargs_to_dict(keys_to_process, kwargs)
super().__init__(**kwargs)

# Set the values
for key, val in args_to_set.items():
setattr(self, key, val)

# Post-process / post_init
data = args_to_process["data"]
event_timestamps = args_to_process["event_timestamps"]

if data.shape[1] != event_timestamps.shape[0]:
raise ValueError("The number of event timestamps must match the number of event repetitions in the data.")

self.fields["data"] = data
self.fields["event_timestamps"] = event_timestamps


# Remove these functions from the package
Expand Down
14 changes: 12 additions & 2 deletions src/pynwb/tests/test_binned_aligned_spikes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ def setUp(self):
self.number_of_event_repetitions = 4
self.bin_width_in_milliseconds = 20.0
self.milliseconds_from_event_to_first_bin = -100.0
rng = np.random.default_rng(seed=0)
self.rng = np.random.default_rng(seed=0)

self.data = rng.integers(
self.data = self.rng.integers(
low=0,
high=100,
size=(
Expand Down Expand Up @@ -99,6 +99,16 @@ def test_constructor_units_region(self):
expected_names = [unit_name_a, unit_name_c]
self.assertListEqual(unit_table_names, expected_names)

def test_constructor_inconsistent_timestamps_and_data_error(self):
shorter_timestamps = self.event_timestamps[:-1]

with self.assertRaises(ValueError):
BinnedAlignedSpikes(
bin_width_in_milliseconds=self.bin_width_in_milliseconds,
milliseconds_from_event_to_first_bin=self.milliseconds_from_event_to_first_bin,
data=self.data,
event_timestamps=shorter_timestamps,
)

class TestBinnedAlignedSpikesSimpleRoundtrip(TestCase):
"""Simple roundtrip test for BinnedAlignedSpikes."""
Expand Down

0 comments on commit 9055d08

Please sign in to comment.