Skip to content

Commit

Permalink
Fix incompatibility with LDF 1.0.0 (#233)
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlov721 authored Jan 24, 2025
1 parent a04af9a commit a770f26
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 15 deletions.
6 changes: 2 additions & 4 deletions luxonis_ml/data/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from typing_extensions import Annotated

from luxonis_ml.data import LuxonisDataset, LuxonisLoader, LuxonisParser
from luxonis_ml.data.utils.constants import LDF_VERSION
from luxonis_ml.data.utils.visualizations import visualize
from luxonis_ml.enums import DatasetType

Expand Down Expand Up @@ -76,7 +75,7 @@ def print_info(name: str) -> None:
task_table = Table(
title="Tasks", box=rich.box.ROUNDED, row_styles=["yellow", "cyan"]
)
if len(tasks) > 1 or next(iter(tasks)):
if tasks and (len(tasks) > 1 or next(iter(tasks))):
task_table.add_column(
"Task Name", header_style="magenta i", max_width=30
)
Expand Down Expand Up @@ -107,8 +106,7 @@ def get_panels():
yield ""
yield Panel.fit(get_sizes_panel(), title="Split Sizes")
yield class_table
if dataset.version == LDF_VERSION:
yield task_table
yield task_table

print(Panel.fit(get_panels(), title="Dataset Info"))

Expand Down
107 changes: 96 additions & 11 deletions luxonis_ml/data/datasets/luxonis_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,24 @@

logger = logging.getLogger(__name__)

# Mapping from LDF 1.0.0 annotation class names to the task-type
# identifiers used by the current LDF format.
LDF_1_0_0_TASK_TYPES = {
    "BBoxAnnotation": "boundingbox",
    "ClassificationAnnotation": "classification",
    "PolylineSegmentationAnnotation": "segmentation",
    "RLESegmentationAnnotation": "segmentation",
    "MaskSegmentationAnnotation": "segmentation",
    "KeypointAnnotation": "keypoints",
    "ArrayAnnotation": "array",
}

# Task names that an LDF 1.0.0 dataset could contain; during migration
# these all fold into the single "detection" task.
LDF_1_0_0_TASKS = set(LDF_1_0_0_TASK_TYPES.values())


class Skeletons(TypedDict):
labels: List[str]
Expand Down Expand Up @@ -438,25 +456,38 @@ def _load_df_offline(
self,
lazy: Literal[False] = ...,
raise_when_empty: Literal[False] = ...,
attempt_migration: bool = ...,
) -> Optional[pl.DataFrame]: ...

@overload
def _load_df_offline(
self, lazy: Literal[False] = ..., raise_when_empty: Literal[True] = ...
self,
lazy: Literal[False] = ...,
raise_when_empty: Literal[True] = ...,
attempt_migration: bool = ...,
) -> pl.DataFrame: ...

@overload
def _load_df_offline(
self, lazy: Literal[True] = ..., raise_when_empty: Literal[False] = ...
self,
lazy: Literal[True] = ...,
raise_when_empty: Literal[False] = ...,
attempt_migration: bool = ...,
) -> Optional[pl.LazyFrame]: ...

@overload
def _load_df_offline(
self, lazy: Literal[True] = ..., raise_when_empty: Literal[True] = ...
self,
lazy: Literal[True] = ...,
raise_when_empty: Literal[True] = ...,
attempt_migration: bool = ...,
) -> pl.LazyFrame: ...

def _load_df_offline(
self, lazy: bool = False, raise_when_empty: bool = False
self,
lazy: bool = False,
raise_when_empty: bool = False,
attempt_migration: bool = True,
) -> Optional[Union[pl.DataFrame, pl.LazyFrame]]:
"""Loads the dataset DataFrame **always** from the local
storage."""
Expand Down Expand Up @@ -486,18 +517,40 @@ def _load_df_offline(
if df is None and raise_when_empty:
raise FileNotFoundError(f"Dataset '{self.dataset_name}' is empty.")

if self.version == LDF_VERSION or df is None:
if not attempt_migration or self.version == LDF_VERSION or df is None:
return df

return (
df.rename({"class": "class_name"})
.with_columns(
[
pl.col("task").alias("task_type"),
pl.col("task").alias("task_name"),
pl.lit("image").alias("source_name"),
]
pl.when(pl.col("task").is_in(LDF_1_0_0_TASKS))
.then(pl.lit("detection"))
.otherwise(pl.col("task"))
.alias("task_name")
)
.with_columns(
pl.when(pl.col("type") == "BBoxAnnotation")
.then(pl.lit("boundingbox"))
.when(pl.col("type") == "ClassificationAnnotation")
.then(pl.lit("classification"))
.when(
pl.col("type").is_in(
[
"PolylineSegmentationAnnotation",
"RLESegmentationAnnotation",
"MaskSegmentationAnnotation",
]
)
)
.then(pl.lit("segmentation"))
.when(pl.col("type") == "KeypointAnnotation")
.then(pl.lit("keypoints"))
.when(pl.col("type") == "ArrayAnnotation")
.then(pl.lit("array"))
.otherwise(pl.col("type"))
.alias("task_type")
)
.with_columns(pl.lit("image").alias("source_name"))
.select(
[
"file",
Expand Down Expand Up @@ -626,7 +679,11 @@ def _get_metadata(self) -> Metadata:
self.metadata_path,
default=self.metadata_path / "metadata.json",
)
return json.loads(path.read_text())
metadata = json.loads(path.read_text())
version = Version.parse(metadata.get("ldf_version", "1.0.0"))
if version != LDF_VERSION: # pragma: no cover
metadata = self._migrate_metadata(metadata)
return metadata
else:
return {
"source": LuxonisSource().to_document(),
Expand All @@ -636,6 +693,34 @@ def _get_metadata(self) -> Metadata:
"skeletons": {},
}

def _migrate_metadata(
    self, metadata: Metadata
) -> Metadata:  # pragma: no cover
    """Migrates LDF 1.0.0 metadata to the current metadata format.

    All LDF 1.0.0 built-in task names (see ``LDF_1_0_0_TASKS``) are
    folded into a single "detection" task; custom task names are kept
    as-is. Populates ``metadata["classes"]`` (new task name -> class
    names) and ``metadata["tasks"]`` (new task name -> task types).
    The input mapping is mutated in place and also returned.

    Raises ``ValueError`` when the dataset uses custom task names but
    its annotation DataFrame is empty, since the (task, type) pairs
    needed for migration cannot then be recovered.
    """
    old_classes = metadata["classes"]
    if not old_classes:
        # Nothing to migrate. Short-circuiting here avoids the
        # `StopIteration` that `next(iter(...))` would raise on an
        # empty mapping.
        metadata["classes"] = {}
        metadata["tasks"] = {}
    elif set(old_classes.keys()) <= LDF_1_0_0_TASKS:
        # Only built-in task names are present. Merge the class lists
        # of all old tasks (taking just the first list would silently
        # drop classes that appear only under another task).
        merged: List[str] = []
        for class_names in old_classes.values():
            for class_name in class_names:
                if class_name not in merged:
                    merged.append(class_name)
        metadata["classes"] = {"detection": merged}
        metadata["tasks"] = {"detection": list(old_classes.keys())}
    else:
        # Custom task names exist, so the actual (task, type) pairs
        # must be read from the annotation DataFrame. Migration is
        # disabled for this load to avoid infinite recursion.
        df = self._load_df_offline(lazy=True, attempt_migration=False)
        if df is None:
            raise ValueError("Cannot migrate when the dataset is empty")
        tasks_df = df.select(["task", "type"]).unique().collect()
        new_classes = defaultdict(list)
        tasks = defaultdict(list)
        migrated_tasks = set()
        for task_name, task_type in tasks_df.iter_rows():
            new_task_name = (
                "detection" if task_name in LDF_1_0_0_TASKS else task_name
            )
            new_task_type = LDF_1_0_0_TASK_TYPES[task_type]
            # One task can occur with several annotation types, and
            # several old tasks can fold into "detection", so both the
            # task types and the class names must be deduplicated
            # (the naive per-row `extend` duplicated class lists).
            if new_task_type not in tasks[new_task_name]:
                tasks[new_task_name].append(new_task_type)
            if task_name not in migrated_tasks:
                migrated_tasks.add(task_name)
                for class_name in old_classes.get(task_name, []):
                    if class_name not in new_classes[new_task_name]:
                        new_classes[new_task_name].append(class_name)
        metadata["classes"] = dict(new_classes)
        metadata["tasks"] = dict(tasks)
    return metadata

@property
def is_remote(self) -> bool:
    """Whether the dataset is backed by cloud bucket storage rather
    than the local filesystem."""
    stored_locally = self.bucket_storage == BucketStorage.LOCAL
    return not stored_locally
Expand Down

0 comments on commit a770f26

Please sign in to comment.