Metadata improvements (#234)
kozlov721 authored Jan 24, 2025
1 parent a770f26 commit 2858ae8
Showing 7 changed files with 161 additions and 38 deletions.
2 changes: 2 additions & 0 deletions luxonis_ml/data/__init__.py
@@ -7,6 +7,7 @@
 from .datasets import (
     DATASETS_REGISTRY,
     BaseDataset,
+    Category,
     DatasetIterator,
     LuxonisComponent,
     LuxonisDataset,
@@ -43,6 +44,7 @@ def load_loader_plugins() -> None:  # pragma: no cover
     "BucketType",
     "DatasetIterator",
     "DATASETS_REGISTRY",
+    "Category",
     "LOADERS_REGISTRY",
     "ImageType",
     "LuxonisComponent",
2 changes: 2 additions & 0 deletions luxonis_ml/data/datasets/__init__.py
@@ -2,6 +2,7 @@
     Annotation,
     ArrayAnnotation,
     BBoxAnnotation,
+    Category,
     DatasetRecord,
     Detection,
     KeypointAnnotation,
@@ -15,6 +16,7 @@
     "BaseDataset",
     "DatasetIterator",
     "DatasetRecord",
+    "Category",
     "LuxonisDataset",
     "LuxonisComponent",
     "LuxonisSource",
19 changes: 17 additions & 2 deletions luxonis_ml/data/datasets/annotation.py
@@ -9,8 +9,15 @@
 import numpy as np
 import pycocotools.mask
 from PIL import Image, ImageDraw
-from pydantic import Field, field_serializer, field_validator, model_validator
+from pydantic import (
+    Field,
+    GetCoreSchemaHandler,
+    field_serializer,
+    field_validator,
+    model_validator,
+)
 from pydantic.types import FilePath, PositiveInt
+from pydantic_core import core_schema
 from typeguard import check_type
 from typing_extensions import Annotated, Self, TypeAlias, override

@@ -25,11 +32,19 @@
 1]."""


+class Category(str):
+    @classmethod
+    def __get_pydantic_core_schema__(
+        cls, source_type: Any, handler: GetCoreSchemaHandler
+    ) -> core_schema.CoreSchema:
+        return core_schema.is_instance_schema(cls)
+
+
 class Detection(BaseModelExtraForbid):
     class_name: Optional[str] = Field(None, alias="class")
     instance_id: int = -1

-    metadata: Dict[str, Union[int, float, str]] = {}
+    metadata: Dict[str, Union[int, float, str, Category]] = {}

     boundingbox: Optional["BBoxAnnotation"] = None
     keypoints: Optional["KeypointAnnotation"] = None
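Note: `Category` is a plain `str` subclass whose pydantic core schema is a bare `isinstance` check, so a `Category` value passes validation anywhere a string is accepted while staying distinguishable from ordinary `str` metadata. A minimal sketch of how a record might carry it (the record layout below is illustrative, not taken verbatim from this commit):

    from luxonis_ml.data import Category

    record = {
        "file": "images/0001.jpg",
        "annotation": {
            "class": "vehicle",
            "metadata": {
                "color": Category("red"),    # categorical: gets an integer encoding
                "max_speed": 220.0,          # numeric metadata, kept as a float
                "license_plate": "ABC-123",  # plain str: stored but never encoded
            },
        },
    }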
124 changes: 97 additions & 27 deletions luxonis_ml/data/datasets/luxonis_dataset.py
@@ -6,6 +6,7 @@
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import suppress
+from copy import deepcopy
 from functools import cached_property
 from pathlib import Path, PurePosixPath
 from typing import (
@@ -49,7 +50,7 @@
     make_progress_bar,
 )

-from .annotation import DatasetRecord
+from .annotation import Category, DatasetRecord
 from .base_dataset import BaseDataset, DatasetIterator
 from .source import LuxonisSource
 from .utils import find_filepath_uuid, get_dir, get_file
@@ -86,6 +87,8 @@ class Metadata(TypedDict):
     classes: Dict[str, List[str]]
     tasks: Dict[str, List[str]]
     skeletons: Dict[str, Skeletons]
+    categorical_encodings: Dict[str, Dict[str, int]]
+    metadata_types: Dict[str, Literal["float", "int", "str", "Category"]]


 class LuxonisDataset(BaseDataset):
@@ -158,33 +161,61 @@ def __init__(
         with FileLock(
             str(_lock_metadata)
         ):  # DDP GCS training - multiple processes
-            self.metadata = cast(
+            self._metadata = cast(
                 Metadata, defaultdict(dict, self._get_metadata())
             )

         if self.version != LDF_VERSION:
             logger.warning(
                 f"LDF versions do not match. The current `luxonis-ml` "
                 f"installation supports LDF v{LDF_VERSION}, but the "
-                f"`{self.identifier}` dataset is in v{self.metadata['ldf_version']}. "
+                f"`{self.identifier}` dataset is in v{self._metadata['ldf_version']}. "
                 "Internal migration will be performed. Note that some parts "
                 "and new features might not work correctly unless you "
                 "manually re-create the dataset using the latest version "
                 "of `luxonis-ml`."
             )
         self.progress = make_progress_bar()

+    @property
+    def metadata(self) -> Metadata:
+        """Returns the metadata of the dataset.
+
+        The metadata is a dictionary containing the following keys:
+            - source: L{LuxonisSource}
+            - ldf_version: str
+            - classes: Dict[task_name, List[class_name]]
+            - tasks: Dict[task_name, List[task_type]]
+            - skeletons: Dict[task_name, Skeletons]
+                - Skeletons is a dictionary with keys 'labels' and 'edges'
+                    - labels: List[str]
+                    - edges: List[Tuple[int, int]]
+            - categorical_encodings: Dict[task_name, Dict[metadata_name, Dict[metadata_value, int]]]
+                - Encodings for string metadata values
+                - Example::
+
+                    {
+                        "vehicle": {
+                            "color": {"red": 0, "green": 1, "blue": 2},
+                            "brand": {"audi": 0, "bmw": 1, "mercedes": 2},
+                        }
+                    }
+
+        @type: L{Metadata}
+        """
+        return deepcopy(self._metadata)
+
     @cached_property
     def version(self) -> Version:
         return Version.parse(
-            self.metadata["ldf_version"], optional_minor_and_patch=True
+            self._metadata["ldf_version"], optional_minor_and_patch=True
         )

     @property
     def source(self) -> LuxonisSource:
-        if "source" not in self.metadata:
+        if "source" not in self._metadata:
             raise ValueError("Source not found in metadata")
-        return LuxonisSource.from_document(self.metadata["source"])
+        return LuxonisSource.from_document(self._metadata["source"])

     @property
     @override
@@ -272,11 +303,11 @@ def _save_df_offline(self, pl_df: pl.DataFrame) -> None:

     def _merge_metadata_with(self, other: "LuxonisDataset") -> None:
         """Merges relevant metadata from `other` into `self`."""
-        for key, value in other.metadata.items():
-            if key not in self.metadata:
-                self.metadata[key] = value
+        for key, value in other._metadata.items():
+            if key not in self._metadata:
+                self._metadata[key] = value
             else:
-                existing_val = self.metadata[key]
+                existing_val = self._metadata[key]

                 if isinstance(existing_val, dict) and isinstance(value, dict):
                     if key == "classes":
@@ -292,7 +323,7 @@ def _merge_metadata_with(self, other: "LuxonisDataset") -> None:
                 else:
                     existing_val.update(value)
             else:
-                self.metadata[key] = value
+                self._metadata[key] = value
         self._write_metadata()

     def clone(
@@ -328,15 +359,15 @@ def clone(
         )

         new_dataset._init_paths()
-        new_dataset.metadata = defaultdict(dict, self._get_metadata())
+        new_dataset._metadata = defaultdict(dict, self._get_metadata())

-        new_dataset.metadata["original_dataset"] = self.dataset_name
+        new_dataset._metadata["original_dataset"] = self.dataset_name

         if self.is_remote and push_to_cloud:
             new_dataset.sync_to_cloud()

         path = self.metadata_path / "metadata.json"
-        path.write_text(json.dumps(self.metadata, indent=4))
+        path.write_text(json.dumps(self._metadata, indent=4))

         return new_dataset

@@ -645,7 +676,7 @@ def _write_index(

     def _write_metadata(self) -> None:
         path = self.metadata_path / "metadata.json"
-        path.write_text(json.dumps(self.metadata, indent=4))
+        path.write_text(json.dumps(self._metadata, indent=4))
         with suppress(shutil.SameFileError):
             self.fs.put_file(path, "metadata/metadata.json")

@@ -691,6 +722,8 @@ def _get_metadata(self) -> Metadata:
             "classes": {},
             "tasks": {},
             "skeletons": {},
+            "categorical_encodings": {},
+            "metadata_types": {},
         }

def _migrate_metadata(
Expand Down Expand Up @@ -734,25 +767,25 @@ def update_source(self, source: LuxonisSource) -> None:
@param source: The new C{LuxonisSource} to replace the old one.
"""

self.metadata["source"] = source.to_document()
self._metadata["source"] = source.to_document()
self._write_metadata()

@override
def set_classes(
self, classes: List[str], task: Optional[str] = None
) -> None:
if task is not None:
self.metadata["classes"][task] = classes
if task is None:
tasks = self.get_task_names()
else:
raise NotImplementedError(
"Setting classes for all tasks not yet supported. "
"Set classes individually for each task"
)
tasks = [task]

for task in tasks:
self._metadata["classes"][task] = classes
self._write_metadata()

@override
def get_classes(self) -> Dict[str, List[str]]:
return self.metadata["classes"]
return self._metadata["classes"]

@override
def set_skeletons(
Expand All @@ -769,7 +802,7 @@ def set_skeletons(
else:
tasks = [task]
for task in tasks:
self.metadata["skeletons"][task] = {
self._metadata["skeletons"][task] = {
"labels": labels or [],
"edges": edges or [],
}
Expand All @@ -781,12 +814,22 @@ def get_skeletons(
) -> Dict[str, Tuple[List[str], List[Tuple[int, int]]]]:
return {
task: (skel["labels"], skel["edges"])
for task, skel in self.metadata["skeletons"].items()
for task, skel in self._metadata["skeletons"].items()
}

@override
def get_tasks(self) -> Dict[str, List[str]]:
return self.metadata.get("tasks", {})
return self._metadata["tasks"]

def get_categorical_encodings(
self,
) -> Dict[str, Dict[str, int]]:
return self._metadata["categorical_encodings"]

def get_metadata_types(
self,
) -> Dict[str, Literal["float", "int", "str", "Category"]]:
return self._metadata["metadata_types"]

def sync_from_cloud(
self, update_mode: UpdateMode = UpdateMode.IF_EMPTY
@@ -952,6 +995,8 @@ def add(
         classes_per_task: Dict[str, OrderedSet[str]] = defaultdict(
             lambda: OrderedSet([])
        )
+        categorical_encodings = defaultdict(dict)
+        metadata_types = {}
         num_kpts_per_task: Dict[str, int] = {}

         annotations_path = get_dir(
@@ -982,6 +1027,28 @@ def add(
                         num_kpts_per_task[record.task] = len(
                             ann.keypoints.keypoints
                         )
+                    for name, value in ann.metadata.items():
+                        task = f"{record.task}/metadata/{name}"
+                        typ = type(value).__name__
+                        if (
+                            task in metadata_types
+                            and metadata_types[task] != typ
+                        ):
+                            if {typ, metadata_types[task]} == {"int", "float"}:
+                                metadata_types[task] = "float"
+                            else:
+                                raise ValueError(
+                                    f"Metadata type mismatch for {task}: {metadata_types[task]} and {typ}"
+                                )
+                        else:
+                            metadata_types[task] = typ
+
+                        if not isinstance(value, Category):
+                            continue
+                        if value not in categorical_encodings[task]:
+                            categorical_encodings[task][value] = len(
+                                categorical_encodings[task]
+                            )

                 data_batch.append(record)
                 if i % batch_size == 0:
@@ -1013,6 +1080,9 @@ def add(
                 task=task,
             )

+        self._metadata["categorical_encodings"] = dict(categorical_encodings)
+        self._metadata["metadata_types"] = metadata_types
+
         with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
             self._write_index(index, new_index, path=tmp_file.name)

@@ -1034,7 +1104,7 @@ def _save_tasks_to_metadata(self) -> None:
             .iter_rows()
         ):
             tasks[task_name].append(task_type)
-        self.metadata["tasks"] = tasks
+        self._metadata["tasks"] = tasks
         self._write_metadata()

     def _warn_on_duplicates(self) -> None:
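Note: the two new getters expose what `add()` infers in the loop above. `get_metadata_types()` maps each `"<task>/metadata/<name>"` key to `"float"`, `"int"`, `"str"`, or `"Category"` (mixed `int`/`float` fields are promoted to `"float"`; any other mismatch raises `ValueError`), and `get_categorical_encodings()` maps each `Category` value to an integer assigned in first-seen order. The public `metadata` property now also returns a deep copy, so callers cannot mutate internal state. A hedged usage sketch (dataset and task names are illustrative):

    from luxonis_ml.data import Category, LuxonisDataset

    dataset = LuxonisDataset("example_dataset")
    dataset.add(generator())  # a generator yielding records like the one sketched earlier

    print(dataset.get_metadata_types())
    # e.g. {"detection/metadata/color": "Category",
    #       "detection/metadata/max_speed": "float"}

    print(dataset.get_categorical_encodings())
    # e.g. {"detection/metadata/color": {"red": 0, "blue": 1}}

    meta = dataset.metadata          # deep copy of the internal dict
    meta["classes"]["vehicle"] = []  # modifies the copy only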
(Diffs for the remaining 3 changed files are not shown.)
