Skip to content

Commit

Permalink
Fix special augmentations with no bboxes (#232)
Browse files Browse the repository at this point in the history
  • Loading branch information
kozlov721 authored Jan 22, 2025
1 parent 27657d6 commit a04af9a
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 5 deletions.
14 changes: 9 additions & 5 deletions luxonis_ml/data/augmentations/albumentations_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing_extensions import TypeAlias, override

from luxonis_ml.data.utils.task_utils import get_task_name, task_is_metadata
from luxonis_ml.typing import ConfigItem, LoaderOutput, Params, TaskType
from luxonis_ml.typing import ConfigItem, LoaderOutput, Params

from .base_engine import AugmentationEngine
from .batch_compose import BatchCompose
Expand Down Expand Up @@ -245,7 +245,7 @@ def __init__(
self,
height: int,
width: int,
targets: Dict[str, TaskType],
targets: Dict[str, str],
config: Iterable[Params],
keep_aspect_ratio: bool = True,
is_validation_pipeline: bool = False,
Expand Down Expand Up @@ -526,9 +526,13 @@ def postprocess(
task = self.target_names_to_tasks[target_name]
task_name = get_task_name(task)

bbox_ordering = bboxes_indices.get(
task_name, np.array([], dtype=int)
)
if task_name not in bboxes_indices:
if "bboxes" in self.targets.values():
bbox_ordering = np.array([], dtype=int)
else:
bbox_ordering = np.arange(array.shape[0])
else:
bbox_ordering = bboxes_indices[task_name]

if target_type == "mask":
out_labels[task] = postprocess_mask(array)
Expand Down
26 changes: 26 additions & 0 deletions tests/test_data/test_augmentations/test_special.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import numpy as np

from luxonis_ml.data import AlbumentationsEngine


def test_metadata_no_boxes():
config = [
{
"name": "Defocus",
"params": {"p": 1.0},
},
{
"name": "Mosaic4",
"params": {"p": 1.0, "out_width": 640, "out_height": 640},
},
]
augmentations = AlbumentationsEngine(
256, 256, {"/metadata/id": "metadata/id"}, config
)
_, labels = augmentations.apply(
[
(np.zeros((3, 256, 256)), {"/metadata/id": np.array([i])})
for i in range(4)
]
)
assert labels["/metadata/id"].tolist() == [0, 1, 2, 3]

0 comments on commit a04af9a

Please sign in to comment.