-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
76 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,64 @@ | ||
from typing import Any | ||
|
||
import numpy as np | ||
from pydantic import Field, model_validator | ||
from pydantic_zarr.v2 import ArraySpec, GroupSpec | ||
|
||
from ome_zarr_models._v05.base import BaseGroupv05, BaseOMEAttrs | ||
from ome_zarr_models.v04.labels import LabelsAttrs | ||
from ome_zarr_models.base import BaseAttrs | ||
|
||
__all__ = ["Labels", "LabelsAttrs"] | ||
|
||
|
||
VALID_DTYPES: list[np.dtype[Any]] = [ | ||
np.dtype(x) | ||
for x in [ | ||
np.uint8, | ||
np.int8, | ||
np.uint16, | ||
np.int16, | ||
np.uint32, | ||
np.int32, | ||
np.uint64, | ||
np.int64, | ||
] | ||
] | ||
|
||
|
||
def _check_valid_dtypes(labels: "Labels") -> "Labels": | ||
for label_path in labels.attributes.ome.labels: | ||
if label_path not in labels.members: | ||
raise ValueError(f"Label path '{label_path}' not found in zarr group") | ||
else: | ||
spec = labels.members[label_path] | ||
if isinstance(spec, GroupSpec): | ||
raise ValueError( | ||
f"Label path '{label_path}' points to a group, not an array" | ||
) | ||
|
||
dtype = np.dtype(spec.dtype) | ||
if dtype not in VALID_DTYPES: | ||
raise ValueError( | ||
f"Data type of labels at '{label_path}' is not valid. " | ||
f"Got {dtype}, should be one of {[str(x) for x in VALID_DTYPES]}." | ||
) | ||
|
||
return labels | ||
|
||
|
||
class LabelsAttrs(BaseAttrs): | ||
""" | ||
Attributes for an OME-Zarr labels dataset. | ||
""" | ||
|
||
labels: list[str] = Field( | ||
..., description="List of paths to labels arrays within a labels dataset." | ||
) | ||
|
||
|
||
class Labels(GroupSpec[BaseOMEAttrs[LabelsAttrs], ArraySpec | GroupSpec], BaseGroupv05): # type: ignore[misc] | ||
""" | ||
An OME-Zarr labels dataset. | ||
""" | ||
|
||
_check_valid_dtypes = model_validator(mode="after")(_check_valid_dtypes) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,33 @@ | ||
import re | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
from ome_zarr_models._v05.labels import Labels, LabelsAttrs | ||
from tests.v05.conftest import json_to_zarr_group | ||
|
||
|
||
def test_labels() -> None: | ||
zarr_group = json_to_zarr_group(json_fname="labels_example.json") | ||
zarr_group.create_dataset("cell_space_segmentation", shape=(1, 1), dtype=np.int64) | ||
ome_group = Labels.from_zarr(zarr_group) | ||
assert ome_group.attributes.ome == LabelsAttrs( | ||
labels=["cell_space_segmentation"], version="0.5" | ||
) | ||
|
||
|
||
def test_labels_invalid_dtype() -> None: | ||
""" | ||
Check that an invalid data type raises an error. | ||
""" | ||
zarr_group = json_to_zarr_group(json_fname="labels_example.json") | ||
zarr_group.create_dataset("cell_space_segmentation", shape=(1, 1), dtype=np.float64) | ||
with pytest.raises( | ||
ValueError, | ||
match=re.escape( | ||
"Data type of labels at 'cell_space_segmentation' is not valid. " | ||
"Got float64, should be one of " | ||
"['uint8', 'int8', 'uint16', 'int16', 'uint32', 'int32', 'uint64', 'int64']" | ||
), | ||
): | ||
Labels.from_zarr(zarr_group) |