diff --git a/luxonis_ml/data/__init__.py b/luxonis_ml/data/__init__.py
index 66ba4505..d7ea1263 100644
--- a/luxonis_ml/data/__init__.py
+++ b/luxonis_ml/data/__init__.py
@@ -7,6 +7,7 @@ from .datasets import (
     DATASETS_REGISTRY,
     BaseDataset,
+    Category,
     DatasetIterator,
     LuxonisComponent,
     LuxonisDataset,
@@ -43,6 +44,7 @@ def load_loader_plugins() -> None:  # pragma: no cover
     "BucketType",
     "DatasetIterator",
     "DATASETS_REGISTRY",
+    "Category",
     "LOADERS_REGISTRY",
     "ImageType",
     "LuxonisComponent",
diff --git a/luxonis_ml/data/datasets/__init__.py b/luxonis_ml/data/datasets/__init__.py
index ecefc974..b4322c6c 100644
--- a/luxonis_ml/data/datasets/__init__.py
+++ b/luxonis_ml/data/datasets/__init__.py
@@ -2,6 +2,7 @@
     Annotation,
     ArrayAnnotation,
     BBoxAnnotation,
+    Category,
     DatasetRecord,
     Detection,
     KeypointAnnotation,
@@ -15,6 +16,7 @@
     "BaseDataset",
     "DatasetIterator",
     "DatasetRecord",
+    "Category",
     "LuxonisDataset",
     "LuxonisComponent",
     "LuxonisSource",
diff --git a/luxonis_ml/data/datasets/annotation.py b/luxonis_ml/data/datasets/annotation.py
index 975e2f0b..4320bba0 100644
--- a/luxonis_ml/data/datasets/annotation.py
+++ b/luxonis_ml/data/datasets/annotation.py
@@ -9,8 +9,15 @@
 import numpy as np
 import pycocotools.mask
 from PIL import Image, ImageDraw
-from pydantic import Field, field_serializer, field_validator, model_validator
+from pydantic import (
+    Field,
+    GetCoreSchemaHandler,
+    field_serializer,
+    field_validator,
+    model_validator,
+)
 from pydantic.types import FilePath, PositiveInt
+from pydantic_core import core_schema
 from typeguard import check_type
 from typing_extensions import Annotated, Self, TypeAlias, override
@@ -25,11 +32,19 @@
 1]."""


+class Category(str):
+    @classmethod
+    def __get_pydantic_core_schema__(
+        cls, source_type: Any, handler: GetCoreSchemaHandler
+    ) -> core_schema.CoreSchema:
+        return core_schema.is_instance_schema(cls)
+
+
 class Detection(BaseModelExtraForbid):
     class_name: Optional[str] = Field(None, alias="class")
     instance_id: int = -1
-    metadata: Dict[str, Union[int, float, str]] = {}
+    metadata: Dict[str, Union[int, float, str, Category]] = {}

     boundingbox: Optional["BBoxAnnotation"] = None
     keypoints: Optional["KeypointAnnotation"] = None
diff --git a/luxonis_ml/data/datasets/luxonis_dataset.py b/luxonis_ml/data/datasets/luxonis_dataset.py
index 3d3526d0..4af357b1 100644
--- a/luxonis_ml/data/datasets/luxonis_dataset.py
+++ b/luxonis_ml/data/datasets/luxonis_dataset.py
@@ -6,6 +6,7 @@
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import suppress
+from copy import deepcopy
 from functools import cached_property
 from pathlib import Path, PurePosixPath
 from typing import (
@@ -49,7 +50,7 @@
     make_progress_bar,
 )

-from .annotation import DatasetRecord
+from .annotation import Category, DatasetRecord
 from .base_dataset import BaseDataset, DatasetIterator
 from .source import LuxonisSource
 from .utils import find_filepath_uuid, get_dir, get_file
@@ -86,6 +87,8 @@ class Metadata(TypedDict):
     classes: Dict[str, List[str]]
     tasks: Dict[str, List[str]]
     skeletons: Dict[str, Skeletons]
+    categorical_encodings: Dict[str, Dict[str, int]]
+    metadata_types: Dict[str, Literal["float", "int", "str", "Category"]]


 class LuxonisDataset(BaseDataset):
@@ -158,7 +161,7 @@ def __init__(
            with FileLock(
                str(_lock_metadata)
            ):  # DDP GCS training - multiple processes
-                self.metadata = cast(
+                self._metadata = cast(
                    Metadata, defaultdict(dict, self._get_metadata())
                )
@@ -166,7 +169,7 @@
            logger.warning(
                f"LDF versions do not match. The current `luxonis-ml` "
                f"installation supports LDF v{LDF_VERSION}, but the "
-                f"`{self.identifier}` dataset is in v{self.metadata['ldf_version']}. "
+                f"`{self.identifier}` dataset is in v{self._metadata['ldf_version']}. "
                "Internal migration will be performed. Note that some parts "
                "and new features might not work correctly unless you "
                "manually re-create the dataset using the latest version "
@@ -174,17 +177,45 @@ def __init__(
            )

        self.progress = make_progress_bar()

+    @property
+    def metadata(self) -> Metadata:
+        """Returns the metadata of the dataset.
+
+        The metadata is a dictionary containing the following keys:
+            - source: L{LuxonisSource}
+            - ldf_version: str
+            - classes: Dict[task_name, List[class_name]]
+            - tasks: Dict[task_name, List[task_type]]
+            - skeletons: Dict[task_name, Skeletons]
+                - Skeletons is a dictionary with keys 'labels' and 'edges'
+                    - labels: List[str]
+                    - edges: List[Tuple[int, int]]
+            - categorical_encodings: Dict[task, Dict[metadata_value, int]]
+                - Encodings for string metadata values, keyed by
+                  '<task_name>/metadata/<field_name>'
+                - Example::
+
+                    {
+                        "vehicle/metadata/color": {"red": 0, "green": 1, "blue": 2},
+                        "vehicle/metadata/brand": {"audi": 0, "bmw": 1, "mercedes": 2},
+                    }
+            - metadata_types: Dict[task, Literal["float", "int", "str", "Category"]]
+                - Types of the metadata values, keyed the same way
+
+        @type: L{Metadata}
+        """
+        return deepcopy(self._metadata)
+
     @cached_property
     def version(self) -> Version:
         return Version.parse(
-            self.metadata["ldf_version"], optional_minor_and_patch=True
+            self._metadata["ldf_version"], optional_minor_and_patch=True
         )

     @property
     def source(self) -> LuxonisSource:
-        if "source" not in self.metadata:
+        if "source" not in self._metadata:
             raise ValueError("Source not found in metadata")
-        return LuxonisSource.from_document(self.metadata["source"])
+        return LuxonisSource.from_document(self._metadata["source"])

     @property
     @override
@@ -272,11 +303,11 @@ def _save_df_offline(self, pl_df: pl.DataFrame) -> None:

     def _merge_metadata_with(self, other: "LuxonisDataset") -> None:
         """Merges relevant metadata from `other` into `self`."""
-        for key, value in other.metadata.items():
-            if key not in self.metadata:
-                self.metadata[key] = value
+        for key, value in other._metadata.items():
+            if key not in self._metadata:
+                self._metadata[key] = value
             else:
-                existing_val = self.metadata[key]
+                existing_val = self._metadata[key]
                 if isinstance(existing_val, dict) and isinstance(value, dict):
                     if key == "classes":
@@ -292,7 +323,7 @@ def _merge_metadata_with(self, other: "LuxonisDataset") -> None:
                 else:
                     existing_val.update(value)
             else:
-                self.metadata[key] = value
+                self._metadata[key] = value

         self._write_metadata()

     def clone(
@@ -328,15 +359,15 @@ def clone(
         )
         new_dataset._init_paths()

-        new_dataset.metadata = defaultdict(dict, self._get_metadata())
+        new_dataset._metadata = defaultdict(dict, self._get_metadata())

-        new_dataset.metadata["original_dataset"] = self.dataset_name
+        new_dataset._metadata["original_dataset"] = self.dataset_name

         if self.is_remote and push_to_cloud:
             new_dataset.sync_to_cloud()

         path = self.metadata_path / "metadata.json"
-        path.write_text(json.dumps(self.metadata, indent=4))
+        path.write_text(json.dumps(self._metadata, indent=4))

         return new_dataset
@@ -645,7 +676,7 @@ def _write_index(

     def _write_metadata(self) -> None:
         path = self.metadata_path / "metadata.json"
-        path.write_text(json.dumps(self.metadata, indent=4))
+        path.write_text(json.dumps(self._metadata, indent=4))
         with suppress(shutil.SameFileError):
             self.fs.put_file(path, "metadata/metadata.json")
@@ -691,6 +722,8 @@ def _get_metadata(self) -> Metadata:
                "classes": {},
                "tasks": {},
                "skeletons": {},
+                "categorical_encodings": {},
"categorical_encodings": {}, + "metadata_types": {}, } def _migrate_metadata( @@ -734,25 +767,25 @@ def update_source(self, source: LuxonisSource) -> None: @param source: The new C{LuxonisSource} to replace the old one. """ - self.metadata["source"] = source.to_document() + self._metadata["source"] = source.to_document() self._write_metadata() @override def set_classes( self, classes: List[str], task: Optional[str] = None ) -> None: - if task is not None: - self.metadata["classes"][task] = classes + if task is None: + tasks = self.get_task_names() else: - raise NotImplementedError( - "Setting classes for all tasks not yet supported. " - "Set classes individually for each task" - ) + tasks = [task] + + for task in tasks: + self._metadata["classes"][task] = classes self._write_metadata() @override def get_classes(self) -> Dict[str, List[str]]: - return self.metadata["classes"] + return self._metadata["classes"] @override def set_skeletons( @@ -769,7 +802,7 @@ def set_skeletons( else: tasks = [task] for task in tasks: - self.metadata["skeletons"][task] = { + self._metadata["skeletons"][task] = { "labels": labels or [], "edges": edges or [], } @@ -781,12 +814,22 @@ def get_skeletons( ) -> Dict[str, Tuple[List[str], List[Tuple[int, int]]]]: return { task: (skel["labels"], skel["edges"]) - for task, skel in self.metadata["skeletons"].items() + for task, skel in self._metadata["skeletons"].items() } @override def get_tasks(self) -> Dict[str, List[str]]: - return self.metadata.get("tasks", {}) + return self._metadata["tasks"] + + def get_categorical_encodings( + self, + ) -> Dict[str, Dict[str, int]]: + return self._metadata["categorical_encodings"] + + def get_metadata_types( + self, + ) -> Dict[str, Literal["float", "int", "str", "Category"]]: + return self._metadata["metadata_types"] def sync_from_cloud( self, update_mode: UpdateMode = UpdateMode.IF_EMPTY @@ -952,6 +995,8 @@ def add( classes_per_task: Dict[str, OrderedSet[str]] = defaultdict( lambda: OrderedSet([]) ) + categorical_encodings = defaultdict(dict) + metadata_types = {} num_kpts_per_task: Dict[str, int] = {} annotations_path = get_dir( @@ -982,6 +1027,28 @@ def add( num_kpts_per_task[record.task] = len( ann.keypoints.keypoints ) + for name, value in ann.metadata.items(): + task = f"{record.task}/metadata/{name}" + typ = type(value).__name__ + if ( + task in metadata_types + and metadata_types[task] != typ + ): + if {typ, metadata_types[task]} == {"int", "float"}: + metadata_types[task] = "float" + else: + raise ValueError( + f"Metadata type mismatch for {task}: {metadata_types[task]} and {typ}" + ) + else: + metadata_types[task] = typ + + if not isinstance(value, Category): + continue + if value not in categorical_encodings[task]: + categorical_encodings[task][value] = len( + categorical_encodings[task] + ) data_batch.append(record) if i % batch_size == 0: @@ -1013,6 +1080,9 @@ def add( task=task, ) + self._metadata["categorical_encodings"] = dict(categorical_encodings) + self._metadata["metadata_types"] = metadata_types + with tempfile.NamedTemporaryFile(delete=False) as tmp_file: self._write_index(index, new_index, path=tmp_file.name) @@ -1034,7 +1104,7 @@ def _save_tasks_to_metadata(self) -> None: .iter_rows() ): tasks[task_name].append(task_type) - self.metadata["tasks"] = tasks + self._metadata["tasks"] = tasks self._write_metadata() def _warn_on_duplicates(self) -> None: diff --git a/luxonis_ml/data/loaders/luxonis_loader.py b/luxonis_ml/data/loaders/luxonis_loader.py index 814dee3f..a7a24fc4 100644 --- 
+++ b/luxonis_ml/data/loaders/luxonis_loader.py
@@ -17,6 +17,7 @@
 )
 from luxonis_ml.data.datasets import (
     Annotation,
+    Category,
     LuxonisDataset,
     UpdateMode,
     load_annotation,
 )
@@ -49,6 +50,7 @@ def __init__(
        exclude_empty_annotations: bool = False,
        color_space: Literal["RGB", "BGR"] = "RGB",
        *,
+        keep_categorical_as_strings: bool = False,
        update_mode: UpdateMode = UpdateMode.ALWAYS,
    ) -> None:
        """A loader class used for loading data from L{LuxonisDataset}.
@@ -96,6 +98,11 @@
            empty annotations from the final label dictionary.
            Defaults to C{False} (i.e. include empty annotations).

+        @type keep_categorical_as_strings: bool
+        @param keep_categorical_as_strings: Whether to keep categorical
+            metadata labels as strings.
+            Defaults to C{False} (i.e. convert categorical labels to integers).
+
        @type update_mode: UpdateMode
        @param update_mode: Enum that determines the sync mode:
            - UpdateMode.ALWAYS: Force a fresh download
@@ -108,6 +115,7 @@
        self.dataset = dataset
        self.sync_mode = self.dataset.is_remote
+        self.keep_categorical_as_strings = keep_categorical_as_strings

        if self.sync_mode:
            self.dataset.sync_from_cloud(update_mode=update_mode)
@@ -289,7 +297,7 @@ def _load_data(self, idx: int) -> Tuple[np.ndarray, Labels]:
        labels_by_task: Dict[str, List[Annotation]] = defaultdict(list)
        class_ids_by_task: Dict[str, List[int]] = defaultdict(list)
        instance_ids_by_task: Dict[str, List[int]] = defaultdict(list)
-        metadata_by_task: Dict[str, List[Union[str, int, float]]] = (
+        metadata_by_task: Dict[str, List[Union[str, int, float, Category]]] = (
            defaultdict(list)
        )
@@ -329,7 +337,10 @@ def _load_data(self, idx: int) -> Tuple[np.ndarray, Labels]:
            instance_ids_by_task[full_task_name].append(instance_id)

        labels: Labels = {}
+        encodings = self.dataset.get_categorical_encodings()
        for task, metadata in metadata_by_task.items():
+            if not self.keep_categorical_as_strings and task in encodings:
+                metadata = [encodings[task][m] for m in metadata]  # type: ignore
            labels[task] = np.array(metadata)

        for task, anns in labels_by_task.items():
diff --git a/luxonis_ml/data/utils/data_utils.py b/luxonis_ml/data/utils/data_utils.py
index 884e49e0..fcf5d223 100644
--- a/luxonis_ml/data/utils/data_utils.py
+++ b/luxonis_ml/data/utils/data_utils.py
@@ -5,6 +5,8 @@
 import numpy as np
 import polars as pl

+from luxonis_ml.data.utils.task_utils import task_is_metadata
+
 logger = logging.getLogger(__name__)
@@ -162,14 +164,15 @@ def warn_on_duplicates(df: pl.LazyFrame) -> None:
    for (
        file_name,
-        task,
+        task_name,
        task_type,
        annotation,
        count,
    ) in duplicate_annotation.iter_rows():
        if task_type == "segmentation":
            annotation = ""
-        logger.warning(
-            f"File '{file_name}' has the same '{task_type}' annotation "
-            f"'{annotation}' ({task=}) added {count} times."
-        )
+        if not task_is_metadata(task_type):
+            logger.warning(
+                f"File '{file_name}' has the same '{task_type}' annotation "
+                f"'{annotation}' ({task_name=}) added {count} times."
+            )
diff --git a/tests/test_data/test_dataset.py b/tests/test_data/test_dataset.py
index 74b53f65..dd47bb76 100644
--- a/tests/test_data/test_dataset.py
+++ b/tests/test_data/test_dataset.py
@@ -9,6 +9,7 @@
 from luxonis_ml.data import (
     BucketStorage,
+    Category,
     LuxonisDataset,
     LuxonisLoader,
     LuxonisParser,
@@ -282,9 +283,10 @@ def generator():
                "annotation": {
                    "class": "person",
                    "metadata": {
-                        "color": "red" if i % 2 == 0 else "blue",
-                        "distance": 5.0,
+                        "color": Category("red" if i % 2 == 0 else "blue"),
+                        "distance": 5.0 if i == 0 else 5,
                        "id": 127 + i,
+                        "license_plate": "xyz",
                    },
                },
            }
@@ -297,12 +299,30 @@ def generator():
            "metadata/color",
            "metadata/distance",
            "metadata/id",
+            "metadata/license_plate",
            "classification",
        } == set(labels.keys())
-        assert labels["metadata/color"].tolist() == ["red", "blue"] * 5
+        assert labels["metadata/color"].tolist() == [0, 1] * 5
        assert labels["metadata/distance"].tolist() == [5.0] * 10
        assert labels["metadata/id"].tolist() == list(range(127, 137))
+        assert labels["metadata/license_plate"].tolist() == ["xyz"] * 10
+
+    loader = LuxonisLoader(dataset, keep_categorical_as_strings=True)
+    for _, labels in loader:
+        labels = {get_task_type(k): v for k, v in labels.items()}
+        assert labels["metadata/color"].tolist() == ["red", "blue"] * 5
+
+    assert dataset.get_categorical_encodings() == {
+        "/metadata/color": {"red": 0, "blue": 1}
+    }
+
+    assert dataset.get_metadata_types() == {
+        "/metadata/color": "Category",
+        "/metadata/distance": "float",
+        "/metadata/id": "int",
+        "/metadata/license_plate": "str",
+    }


 @pytest.mark.dependency(name="test_dataset[BucketStorage.LOCAL]")
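
Note on the `Category` type (a minimal sketch, not part of the patch): `Category` subclasses `str` but validates with `core_schema.is_instance_schema`, so pydantic only accepts genuine `Category` instances where the schema asks for one; a plain `str` is rejected rather than silently treated as categorical. The `Example` model below is hypothetical:

    from pydantic import BaseModel, ValidationError

    from luxonis_ml.data import Category


    class Example(BaseModel):
        value: Category


    Example(value=Category("red"))  # ok: a genuine Category instance

    try:
        Example(value="red")  # plain str fails the is_instance check
    except ValidationError as e:
        print(e)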
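End-to-end usage (a hedged sketch based on the test above; the dataset name, image paths, and `delete_existing` flag are illustrative, and the image files are assumed to exist on disk): `Category` values receive integer codes in order of first appearance when records are added, and `LuxonisLoader` yields those codes unless `keep_categorical_as_strings=True`.

    from luxonis_ml.data import Category, LuxonisDataset, LuxonisLoader

    dataset = LuxonisDataset("example_dataset", delete_existing=True)
    dataset.add(
        {
            "file": f"img_{i}.jpg",  # assumed to exist on disk
            "annotation": {
                "class": "person",
                "metadata": {"color": Category("red" if i % 2 == 0 else "blue")},
            },
        }
        for i in range(10)
    )

    # Codes follow first appearance, keyed by "<task>/metadata/<name>",
    # e.g. {"<task>/metadata/color": {"red": 0, "blue": 1}}
    print(dataset.get_categorical_encodings())
    print(dataset.get_metadata_types())  # e.g. {"<task>/metadata/color": "Category"}

    # Default: categorical metadata arrives as integer codes.
    for _, labels in LuxonisLoader(dataset):
        pass  # labels["<task>/metadata/color"] holds np.array([0, 1, ...])

    # Opt out of encoding to get the raw strings back.
    for _, labels in LuxonisLoader(dataset, keep_categorical_as_strings=True):
        pass  # labels["<task>/metadata/color"] holds np.array(["red", "blue", ...])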