Skip to content

Commit

Permalink
Test for data type of labels images
Browse files Browse the repository at this point in the history
  • Loading branch information
dstansby committed Feb 4, 2025
1 parent b541589 commit e8efcfa
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 1 deletion.
54 changes: 53 additions & 1 deletion src/ome_zarr_models/_v05/labels.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,64 @@
from typing import Any

import numpy as np
from pydantic import Field, model_validator
from pydantic_zarr.v2 import ArraySpec, GroupSpec

from ome_zarr_models._v05.base import BaseGroupv05, BaseOMEAttrs
from ome_zarr_models.v04.labels import LabelsAttrs
from ome_zarr_models.base import BaseAttrs

__all__ = ["Labels", "LabelsAttrs"]


VALID_DTYPES: list[np.dtype[Any]] = [
np.dtype(x)
for x in [
np.uint8,
np.int8,
np.uint16,
np.int16,
np.uint32,
np.int32,
np.uint64,
np.int64,
]
]


def _check_valid_dtypes(labels: "Labels") -> "Labels":
for label_path in labels.attributes.ome.labels:
if label_path not in labels.members:
raise ValueError(f"Label path '{label_path}' not found in zarr group")
else:
spec = labels.members[label_path]
if isinstance(spec, GroupSpec):
raise ValueError(
f"Label path '{label_path}' points to a group, not an array"
)

dtype = np.dtype(spec.dtype)
if dtype not in VALID_DTYPES:
raise ValueError(
f"Data type of labels at '{label_path}' is not valid. "
f"Got {dtype}, should be one of {[str(x) for x in VALID_DTYPES]}."
)

return labels


class LabelsAttrs(BaseAttrs):
"""
Attributes for an OME-Zarr labels dataset.
"""

labels: list[str] = Field(
..., description="List of paths to labels arrays within a labels dataset."
)


class Labels(GroupSpec[BaseOMEAttrs[LabelsAttrs], ArraySpec | GroupSpec], BaseGroupv05): # type: ignore[misc]
"""
An OME-Zarr labels dataset.
"""

_check_valid_dtypes = model_validator(mode="after")(_check_valid_dtypes)
23 changes: 23 additions & 0 deletions tests/v05/test_labels.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,33 @@
import re

import numpy as np
import pytest

from ome_zarr_models._v05.labels import Labels, LabelsAttrs
from tests.v05.conftest import json_to_zarr_group


def test_labels() -> None:
zarr_group = json_to_zarr_group(json_fname="labels_example.json")
zarr_group.create_dataset("cell_space_segmentation", shape=(1, 1), dtype=np.int64)
ome_group = Labels.from_zarr(zarr_group)
assert ome_group.attributes.ome == LabelsAttrs(
labels=["cell_space_segmentation"], version="0.5"
)


def test_labels_invalid_dtype() -> None:
"""
Check that an invalid data type raises an error.
"""
zarr_group = json_to_zarr_group(json_fname="labels_example.json")
zarr_group.create_dataset("cell_space_segmentation", shape=(1, 1), dtype=np.float64)
with pytest.raises(
ValueError,
match=re.escape(
"Data type of labels at 'cell_space_segmentation' is not valid. "
"Got float64, should be one of "
"['uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64']"
),
):
Labels.from_zarr(zarr_group)

0 comments on commit e8efcfa

Please sign in to comment.