Skip to content

Commit

Permalink
Merge pull request #310 from bendichter/validate_read_probe_dicts
Browse files Browse the repository at this point in the history
add json validation to tests
  • Loading branch information
alejoe91 authored Jan 7, 2025
2 parents 1020965 + 11e1242 commit e39afea
Show file tree
Hide file tree
Showing 10 changed files with 107 additions and 7 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ package-dir = {"probeinterface" = "src/probeinterface"}
[project.optional-dependencies]

test = [
"jsonschema",
"pytest",
"pytest-cov",
"matplotlib",
Expand Down
8 changes: 6 additions & 2 deletions resources/probe.json.schema
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@
"annotations": {
"type": "object",
"properties": {
"name": { "type": "string" },
"model_name": { "type": "string" },
"manufacturer": { "type": "string" }
},
"required": ["name", "manufacturer"],
"required": ["model_name", "manufacturer"],
"additionalProperties": true
},
"contact_annotations": {
Expand Down Expand Up @@ -101,6 +101,10 @@
"shank_ids": {
"type": "array",
"items": { "type": "string" }
},
"device_channel_indices": {
"type": "array",
"items": { "type": "integer" }
}
},
"required": [
Expand Down
1 change: 0 additions & 1 deletion src/probeinterface/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import numpy as np
from typing import Optional
from pathlib import Path
import json

from .shank import Shank

Expand Down
13 changes: 13 additions & 0 deletions src/probeinterface/testing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import json
from pathlib import Path

from probeinterface import __version__ as version
import jsonschema

json_schema_file = Path(__file__).absolute().parent.parent.parent / "resources" / "probe.json.schema"
schema = json.load(open(json_schema_file, "r"))


def validate_probe_dict(probe_dict):
instance = dict(specification="probeinterface", version=version, probes=[probe_dict])
jsonschema.validate(instance=instance, schema=schema)
13 changes: 13 additions & 0 deletions tests/test_io/test_3brain.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,24 @@
import glob
from pathlib import Path
import numpy as np

import pytest

from probeinterface import read_3brain

from probeinterface.testing import validate_probe_dict


data_path = Path(__file__).absolute().parent.parent / "data" / "3brain"
brw_files = glob.glob(str(data_path / "*.brw"))


@pytest.mark.parametrize("file_", brw_files)
def test_valid_probe_dict(file_: str):
probe = read_3brain(data_path / file_)
probe_dict = probe.to_dict(array_as_list=True)
probe_dict["annotations"].update(model_name="placeholder")
validate_probe_dict(probe_dict)


def test_3brain():
Expand Down
11 changes: 11 additions & 0 deletions tests/test_io/test_imro.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
import glob
from pathlib import Path

import pytest
import numpy as np

from probeinterface import read_imro, write_imro
from probeinterface.testing import validate_probe_dict

data_path = Path(__file__).absolute().parent.parent / "data" / "imro"
imro_files = glob.glob(str(data_path / "*.imro"))

imro_files.pop(imro_files.index(str(data_path / "test_non_standard.imro")))


@pytest.mark.parametrize("file_", imro_files)
def test_valid_probe_dict(file_: str):
probe = read_imro(data_path / file_)
validate_probe_dict(probe.to_dict(array_as_list=True))


def test_reading_multishank_imro(tmp_path):
Expand Down
9 changes: 9 additions & 0 deletions tests/test_io/test_maxwell.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,19 @@
import pytest

from probeinterface import read_maxwell
from probeinterface.testing import validate_probe_dict

data_path = Path(__file__).absolute().parent.parent / "data" / "maxwell"


def test_valid_probe_dict():
file_ = "data.raw.h5"
probe = read_maxwell(data_path / file_)
probe_dict = probe.to_dict(array_as_list=True)
probe_dict["annotations"].update(model_name="placeholder")
validate_probe_dict(probe_dict)


def test_maxwell():
"""Basic file taken from the ephys data repository and provided by Alessio Buccino"""

Expand Down
43 changes: 41 additions & 2 deletions tests/test_io/test_openephys.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
from pathlib import Path

import numpy as np
import glob

import pytest

from probeinterface import read_openephys
from probeinterface.testing import validate_probe_dict

data_path = Path(__file__).absolute().parent.parent / "data" / "openephys"


def test_NP2_OE_1_0():
# NP2 1-shank
probeA = read_openephys(data_path / "OE_1.0_Neuropix-PXI-multi-probe" / "settings.xml", probe_name="ProbeA")
probe_dict = probeA.to_dict(array_as_list=True)
validate_probe_dict(probe_dict)
assert probeA.get_shank_count() == 1
assert "2.0" in probeA.model_name
assert probeA.get_contact_count() == 384
Expand All @@ -20,13 +24,17 @@ def test_NP2_OE_1_0():
def test_NP2():
# NP2
probe = read_openephys(data_path / "OE_Neuropix-PXI" / "settings.xml")
probe_dict = probe.to_dict(array_as_list=True)
validate_probe_dict(probe_dict)
assert probe.get_shank_count() == 1
assert "2.0 - Single Shank" in probe.model_name


def test_NP2_four_shank():
# NP2
probe = read_openephys(data_path / "OE_Neuropix-PXI-NP2-4shank" / "settings.xml")
probe_dict = probe.to_dict(array_as_list=True)
validate_probe_dict(probe_dict)
# on this case, only shanks 2-3 are used
assert probe.get_shank_count() == 2
assert "2.0 - Four Shank" in probe.model_name
Expand All @@ -38,6 +46,8 @@ def test_NP_Ultra():
data_path / "OE_Neuropix-PXI-NP-Ultra" / "settings.xml",
probe_name="ProbeA",
)
probe_dict = probeA.to_dict(array_as_list=True)
validate_probe_dict(probe_dict)
assert "Ultra" in probeA.model_name
assert probeA.get_shank_count() == 1
assert probeA.get_contact_count() == 384
Expand All @@ -46,6 +56,8 @@ def test_NP_Ultra():
data_path / "OE_Neuropix-PXI-NP-Ultra" / "settings.xml",
probe_name="ProbeB",
)
probe_dict = probeB.to_dict(array_as_list=True)
validate_probe_dict(probe_dict)
assert "Ultra" in probeB.model_name
assert probeB.get_shank_count() == 1
assert probeB.get_contact_count() == 384
Expand All @@ -54,6 +66,8 @@ def test_NP_Ultra():
data_path / "OE_Neuropix-PXI-NP-Ultra" / "settings.xml",
probe_name="ProbeF",
)
probe_dict = probeF.to_dict(array_as_list=True)
validate_probe_dict(probe_dict)
assert "Ultra" in probeF.model_name
assert probeF.get_shank_count() == 1
assert probeF.get_contact_count() == 384
Expand All @@ -62,6 +76,8 @@ def test_NP_Ultra():
data_path / "OE_Neuropix-PXI-NP-Ultra" / "settings.xml",
probe_name="ProbeD",
)
probe_dict = probeD.to_dict(array_as_list=True)
validate_probe_dict(probe_dict)
assert "Ultra" in probeD.model_name and "Type 2" in probeD.model_name
assert probeD.get_shank_count() == 1
assert probeD.get_contact_count() == 384
Expand All @@ -72,12 +88,16 @@ def test_NP_Ultra():
def test_NP1_subset():
# NP1 - 200 channels selected by recording_state in Record Node
probe_ap = read_openephys(data_path / "OE_Neuropix-PXI-subset" / "settings.xml", stream_name="ProbeA-AP")
probe_dict = probe_ap.to_dict(array_as_list=True)
validate_probe_dict(probe_dict)

assert probe_ap.get_shank_count() == 1
assert "1.0" in probe_ap.model_name
assert probe_ap.get_contact_count() == 200

probe_lf = read_openephys(data_path / "OE_Neuropix-PXI-subset" / "settings.xml", stream_name="ProbeA-LFP")
probe_dict = probe_lf.to_dict(array_as_list=True)
validate_probe_dict(probe_dict)

assert probe_lf.get_shank_count() == 1
assert "1.0" in probe_lf.model_name
Expand All @@ -92,6 +112,8 @@ def test_NP1_subset():
def test_multiple_probes():
# multiple probes
probeA = read_openephys(data_path / "OE_Neuropix-PXI-multi-probe" / "settings.xml", probe_name="ProbeA")
probe_dict = probeA.to_dict(array_as_list=True)
validate_probe_dict(probe_dict)

assert probeA.get_shank_count() == 1
assert "1.0" in probeA.model_name
Expand All @@ -100,17 +122,23 @@ def test_multiple_probes():
data_path / "OE_Neuropix-PXI-multi-probe" / "settings.xml",
stream_name="RecordNode#ProbeB",
)
probe_dict = probeB.to_dict(array_as_list=True)
validate_probe_dict(probe_dict)

assert probeB.get_shank_count() == 1

probeC = read_openephys(
data_path / "OE_Neuropix-PXI-multi-probe" / "settings.xml",
serial_number="20403311714",
)
probe_dict = probeC.to_dict(array_as_list=True)
validate_probe_dict(probe_dict)

assert probeC.get_shank_count() == 1

probeD = read_openephys(data_path / "OE_Neuropix-PXI-multi-probe" / "settings.xml", probe_name="ProbeD")
probe_dict = probeD.to_dict(array_as_list=True)
validate_probe_dict(probe_dict)

assert probeD.get_shank_count() == 1

Expand Down Expand Up @@ -146,11 +174,16 @@ def test_multiple_probes_enabled():
probe = read_openephys(
data_path / "OE_6.7_enabled_disabled_Neuropix-PXI" / "settings_enabled-enabled.xml", probe_name="ProbeA"
)
probe_dict = probe.to_dict(array_as_list=True)
validate_probe_dict(probe_dict)

assert probe.get_shank_count() == 1

probe = read_openephys(
data_path / "OE_6.7_enabled_disabled_Neuropix-PXI" / "settings_enabled-enabled.xml", probe_name="ProbeB"
)
probe_dict = probe.to_dict(array_as_list=True)
validate_probe_dict(probe_dict)
assert probe.get_shank_count() == 4


Expand All @@ -159,7 +192,8 @@ def test_multiple_probes_disabled():
probe = read_openephys(
data_path / "OE_6.7_enabled_disabled_Neuropix-PXI" / "settings_enabled-disabled.xml", probe_name="ProbeA"
)

probe_dict = probe.to_dict(array_as_list=True)
validate_probe_dict(probe_dict)
assert probe.get_shank_count() == 1

# Fail as this is disabled:
Expand All @@ -173,6 +207,8 @@ def test_multiple_probes_disabled():

def test_np_opto_with_sync():
probe = read_openephys(data_path / "OE_Neuropix-PXI-opto-with-sync" / "settings.xml")
probe_dict = probe.to_dict(array_as_list=True)
validate_probe_dict(probe_dict)
assert probe.model_name == "Neuropixels Opto"
assert probe.get_shank_count() == 1
assert probe.get_contact_count() == 384
Expand All @@ -182,7 +218,8 @@ def test_older_than_06_format():
## Test with the open ephys < 0.6 format

probe = read_openephys(data_path / "OE_5_Neuropix-PXI-multi-probe" / "settings.xml", probe_name="100.0")

probe_dict = probe.to_dict(array_as_list=True)
validate_probe_dict(probe_dict)
assert probe.get_shank_count() == 4
assert "2.0 - Four Shank" in probe.model_name
ypos = probe.contact_positions[:, 1]
Expand All @@ -193,6 +230,8 @@ def test_multiple_signal_chains():
# tests that the probe information can be loaded even if the Neuropix-PXI plugin
# is not in the first signalchain
probe = read_openephys(data_path / "OE_Neuropix-PXI-multiple-signalchains" / "settings.xml")
probe_dict = probe.to_dict(array_as_list=True)
validate_probe_dict(probe_dict)
assert probe.model_name == "Neuropixels 1.0"


Expand Down
6 changes: 4 additions & 2 deletions tests/test_io/test_spikegadgets.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from pathlib import Path
from xml.etree import ElementTree

import pytest

from probeinterface import read_spikegadgets
from probeinterface.io import parse_spikegadgets_header
from probeinterface.testing import validate_probe_dict


data_path = Path(__file__).absolute().parent.parent / "data" / "spikegadgets"
test_file = "SpikeGadgets_test_data_2xNpix1.0_20240318_173658_header_only.rec"
Expand All @@ -22,6 +22,8 @@ def test_neuropixels_1_reader():
probe_group = read_spikegadgets(data_path / test_file, raise_error=False)
assert len(probe_group.probes) == 2
for probe in probe_group.probes:
probe_dict = probe.to_dict(array_as_list=True)
validate_probe_dict(probe_dict)
assert "1.0" in probe.model_name
assert probe.get_shank_count() == 1
assert probe.get_contact_count() == 384
Expand Down
9 changes: 9 additions & 0 deletions tests/test_io/test_spikeglx.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import glob
from pathlib import Path
import numpy as np

Expand All @@ -8,8 +9,16 @@
parse_spikeglx_meta,
get_saved_channel_indices_from_spikeglx_meta,
)
from probeinterface.testing import validate_probe_dict

data_path = Path(__file__).absolute().parent.parent / "data" / "spikeglx"
meta_files = glob.glob(str(data_path / "*.meta"))


@pytest.mark.parametrize("meta_file", meta_files)
def test_valid_probe_dict(meta_file: str):
probe = read_spikeglx(data_path / meta_file)
validate_probe_dict(probe.to_dict(array_as_list=True))


def test_parse_meta():
Expand Down

0 comments on commit e39afea

Please sign in to comment.