diff --git a/cvat/apps/engine/serializers.py b/cvat/apps/engine/serializers.py index 395bf59108b9..6c760b42ba65 100644 --- a/cvat/apps/engine/serializers.py +++ b/cvat/apps/engine/serializers.py @@ -1111,7 +1111,7 @@ def _to_abs_frame(rel_frame: int) -> int: ) if bulk_context: - active_validation_frame_counts = bulk_context.active_validation_frame_counts + frame_selector = bulk_context.honeypot_frame_selector else: active_validation_frame_counts = { validation_frame: 0 for validation_frame in task_active_validation_frames @@ -1121,7 +1121,8 @@ def _to_abs_frame(rel_frame: int) -> int: if real_frame in task_active_validation_frames: active_validation_frame_counts[real_frame] += 1 - frame_selector = HoneypotFrameSelector(active_validation_frame_counts) + frame_selector = HoneypotFrameSelector(active_validation_frame_counts) + requested_frames = frame_selector.select_next_frames(segment_honeypots_count) requested_frames = list(map(_to_abs_frame, requested_frames)) else: @@ -1368,7 +1369,7 @@ def __init__( honeypot_frames: list[int], all_validation_frames: list[int], active_validation_frames: list[int], - validation_frame_counts: dict[int, int] | None = None + honeypot_frame_selector: HoneypotFrameSelector | None = None ): self.updated_honeypots: dict[int, models.Image] = {} self.updated_segments: list[int] = [] @@ -1380,7 +1381,7 @@ def __init__( self.honeypot_frames = honeypot_frames self.all_validation_frames = all_validation_frames self.active_validation_frames = active_validation_frames - self.active_validation_frame_counts = validation_frame_counts + self.honeypot_frame_selector = honeypot_frame_selector class TaskValidationLayoutWriteSerializer(serializers.Serializer): disabled_frames = serializers.ListField( @@ -1495,7 +1496,9 @@ def update(self, instance: models.Task, validated_data: dict[str, Any]) -> model ) elif frame_selection_method == models.JobFrameSelectionMethod.RANDOM_UNIFORM: # Reset distribution for active validation frames - bulk_context.active_validation_frame_counts = { f: 0 for f in active_validation_frames } + active_validation_frame_counts = { f: 0 for f in active_validation_frames } + frame_selector = HoneypotFrameSelector(active_validation_frame_counts) + bulk_context.honeypot_frame_selector = frame_selector # Could be done using Django ORM, but using order_by() and filter() # would result in an extra DB request diff --git a/cvat/apps/engine/task_validation.py b/cvat/apps/engine/task_validation.py index fe76b4e99408..4734c153e8b4 100644 --- a/cvat/apps/engine/task_validation.py +++ b/cvat/apps/engine/task_validation.py @@ -2,26 +2,109 @@ # # SPDX-License-Identifier: MIT -from collections.abc import Mapping, Sequence -from typing import Generic, TypeVar +from __future__ import annotations +from typing import Callable, Generic, Iterable, Mapping, Sequence, TypeVar + +import attrs import numpy as np -_T = TypeVar("_T") +_K = TypeVar("_K") + + +@attrs.define +class _BaggedCounter(Generic[_K]): + # Stores items with count = k in a single "bag". Bags are stored in the ascending order + bags: dict[ + int, + dict[_K, None], + # dict is used instead of a set to preserve item order. It's also more performant + ] + + @staticmethod + def from_dict(item_counts: Mapping[_K, int]) -> _BaggedCounter: + return _BaggedCounter.from_counts(item_counts, item_count=item_counts.__getitem__) + + @staticmethod + def from_counts(items: Sequence[_K], item_count: Callable[[_K], int]) -> _BaggedCounter: + bags = {} + for item in items: + count = item_count(item) + bags.setdefault(count, dict())[item] = None + + return _BaggedCounter(bags=bags) + + def __attrs_post_init__(self): + self._sort_bags() + + def _sort_bags(self): + self.bags = dict(sorted(self.bags.items(), key=lambda e: e[0])) + + def shuffle(self, *, rng: np.random.Generator | None): + if not rng: + rng = np.random.default_rng() + + for count, bag in self.bags.items(): + items = list(bag.items()) + rng.shuffle(items) + self.bags[count] = dict(items) + + def use_item(self, item: _K, *, count: int | None = None, bag: dict | None = None): + if count is not None: + if bag is None: + bag = self.bags[count] + elif count is None and bag is None: + count, bag = next((c, b) for c, b in self.bags.items() if item in b) + else: + raise AssertionError("'bag' can only be used together with 'count'") + bag.pop(item) -class HoneypotFrameSelector(Generic[_T]): + if not bag: + self.bags.pop(count) + + next_bag = self.bags.get(count + 1) + if next_bag is None: + next_bag = {} + self.bags[count + 1] = next_bag + self._sort_bags() # the new bag can be added in the wrong position if there were gaps + + next_bag[item] = None + + def __iter__(self) -> Iterable[tuple[int, _K, dict]]: + for count, bag in self.bags.items(): # bags must be ordered + for item in bag: + yield (count, item, bag) + + def select_next_least_used(self, count: int) -> Sequence[_K]: + pick = [None] * count + pick_original_use_counts = [(None, None)] * count + for i, (use_count, item, bag) in zip(range(count), self): + pick[i] = item + pick_original_use_counts[i] = (use_count, bag) + + for item, (use_count, bag) in zip(pick, pick_original_use_counts): + self.use_item(item, count=use_count, bag=bag) + + return pick + + +class HoneypotFrameSelector(Generic[_K]): def __init__( - self, validation_frame_counts: Mapping[_T, int], *, rng: np.random.Generator | None = None + self, + validation_frame_counts: Mapping[_K, int], + *, + rng: np.random.Generator | None = None, ): - self.validation_frame_counts = validation_frame_counts - if not rng: rng = np.random.default_rng() self.rng = rng - def select_next_frames(self, count: int) -> Sequence[_T]: + self._counter = _BaggedCounter.from_dict(validation_frame_counts) + self._counter.shuffle(rng=rng) + + def select_next_frames(self, count: int) -> Sequence[_K]: # This approach guarantees that: # - every GT frame is used # - GT frames are used uniformly (at most min count + 1) @@ -29,20 +112,8 @@ def select_next_frames(self, count: int) -> Sequence[_T]: # - honeypot sets are different in jobs # - honeypot sets are random # if possible (if the job and GT counts allow this). - pick = [] - - for random_number in self.rng.random(count): - least_count = min(c for f, c in self.validation_frame_counts.items() if f not in pick) - least_used_frames = tuple( - f - for f, c in self.validation_frame_counts.items() - if f not in pick - if c == least_count - ) - - selected_item = int(random_number * len(least_used_frames)) - selected_frame = least_used_frames[selected_item] - pick.append(selected_frame) - self.validation_frame_counts[selected_frame] += 1 - - return pick + # Picks must be reproducible for a given rng state. + """ + Selects 'count' least used items randomly, without repetition + """ + return self._counter.select_next_least_used(count)