diff --git a/.github/workflows/isort.yml b/.github/workflows/isort.yml
index bf90604cbb2f..620dc6c85d79 100644
--- a/.github/workflows/isort.yml
+++ b/.github/workflows/isort.yml
@@ -5,35 +5,11 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- - id: files
- uses: tj-actions/changed-files@v41.0.0
- with:
- files: |
- cvat-sdk/**/*.py
- cvat-cli/**/*.py
- tests/python/**/*.py
- cvat/apps/quality_control/**/*.py
- cvat/apps/analytics_report/**/*.py
- dir_names: true
- name: Run checks
run: |
- # If different modules use different isort configs,
- # we need to run isort for each python component group separately.
- # Otherwise, they all will use the same config.
+ pipx install $(grep "^isort" ./dev/requirements.txt)
- UPDATED_DIRS="${{steps.files.outputs.all_changed_files}}"
+ echo "isort version: $(isort --version-number)"
- if [[ ! -z $UPDATED_DIRS ]]; then
- pipx install $(grep "^isort" ./dev/requirements.txt)
-
- echo "isort version: $(isort --version-number)"
- echo "The dirs will be checked: $UPDATED_DIRS"
- EXIT_CODE=0
- for DIR in $UPDATED_DIRS; do
- isort --check $DIR || EXIT_CODE=$(($? | $EXIT_CODE)) || true
- done
- exit $EXIT_CODE
- else
- echo "No files with the \"py\" extension found"
- fi
+ isort --check --diff --resolve-all-configs .
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 31a4aae3db7f..a18a0284f814 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,6 +16,38 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
+
+## \[2.25.0\] - 2025-01-09
+
+### Added
+
+- \[CLI\] Added commands for working with native functions
+ ()
+
+- Ultralytics YOLO formats now support tracks
+ ()
+
+### Changed
+
+- YOLOv8 formats renamed to Ultralytics YOLO formats
+ ()
+
+- The `match_empty_frames` quality setting is changed to `empty_is_annotated`.
+ The updated option includes any empty frames in the final metrics instead of only
+ matching empty frames. This makes metrics such as Precision much more representative and useful.
+ ()
+
+### Fixed
+
+- Changing rotation after export/import in Ultralytics YOLO Oriented Boxes format
+ ()
+
+- Export to yolo formats if both Train and default dataset are present
+ ()
+
+- Issue with deleting frames
+ ()
+
## \[2.24.0\] - 2024-12-20
diff --git a/changelog.d/20241212_193004_roman_cli_agent.md b/changelog.d/20241212_193004_roman_cli_agent.md
deleted file mode 100644
index f7fd8c0a5be4..000000000000
--- a/changelog.d/20241212_193004_roman_cli_agent.md
+++ /dev/null
@@ -1,4 +0,0 @@
-### Added
-
-- \[CLI\] Added commands for working with native functions
- ()
diff --git a/changelog.d/20241224_145339_dmitrii.lavrukhin_yolov8_renaming.md b/changelog.d/20241224_145339_dmitrii.lavrukhin_yolov8_renaming.md
deleted file mode 100644
index 957ff9666951..000000000000
--- a/changelog.d/20241224_145339_dmitrii.lavrukhin_yolov8_renaming.md
+++ /dev/null
@@ -1,4 +0,0 @@
-### Changed
-
-- YOLOv8 formats renamed to Ultralytics YOLO formats
- ()
diff --git a/cvat-cli/requirements/base.txt b/cvat-cli/requirements/base.txt
index a53fd13b472e..adb737df6706 100644
--- a/cvat-cli/requirements/base.txt
+++ b/cvat-cli/requirements/base.txt
@@ -1,4 +1,4 @@
-cvat-sdk==2.24.1
+cvat-sdk==2.25.1
attrs>=24.2.0
Pillow>=10.3.0
diff --git a/cvat-cli/src/cvat_cli/version.py b/cvat-cli/src/cvat_cli/version.py
index c176a6b233ec..9cd1b45bc73a 100644
--- a/cvat-cli/src/cvat_cli/version.py
+++ b/cvat-cli/src/cvat_cli/version.py
@@ -1 +1 @@
-VERSION = "2.24.1"
+VERSION = "2.25.1"
diff --git a/cvat-core/src/quality-settings.ts b/cvat-core/src/quality-settings.ts
index 7c591e371cc4..bc553105c181 100644
--- a/cvat-core/src/quality-settings.ts
+++ b/cvat-core/src/quality-settings.ts
@@ -38,7 +38,7 @@ export default class QualitySettings {
#objectVisibilityThreshold: number;
#panopticComparison: boolean;
#compareAttributes: boolean;
- #matchEmptyFrames: boolean;
+ #emptyIsAnnotated: boolean;
#descriptions: Record;
constructor(initialData: SerializedQualitySettingsData) {
@@ -60,7 +60,7 @@ export default class QualitySettings {
this.#objectVisibilityThreshold = initialData.object_visibility_threshold;
this.#panopticComparison = initialData.panoptic_comparison;
this.#compareAttributes = initialData.compare_attributes;
- this.#matchEmptyFrames = initialData.match_empty_frames;
+ this.#emptyIsAnnotated = initialData.empty_is_annotated;
this.#descriptions = initialData.descriptions;
}
@@ -200,12 +200,12 @@ export default class QualitySettings {
this.#maxValidationsPerJob = newVal;
}
- get matchEmptyFrames(): boolean {
- return this.#matchEmptyFrames;
+ get emptyIsAnnotated(): boolean {
+ return this.#emptyIsAnnotated;
}
- set matchEmptyFrames(newVal: boolean) {
- this.#matchEmptyFrames = newVal;
+ set emptyIsAnnotated(newVal: boolean) {
+ this.#emptyIsAnnotated = newVal;
}
get descriptions(): Record {
@@ -236,7 +236,7 @@ export default class QualitySettings {
target_metric: this.#targetMetric,
target_metric_threshold: this.#targetMetricThreshold,
max_validations_per_job: this.#maxValidationsPerJob,
- match_empty_frames: this.#matchEmptyFrames,
+ empty_is_annotated: this.#emptyIsAnnotated,
};
return result;
diff --git a/cvat-core/src/server-response-types.ts b/cvat-core/src/server-response-types.ts
index ea97c0730aaa..ef635d12004e 100644
--- a/cvat-core/src/server-response-types.ts
+++ b/cvat-core/src/server-response-types.ts
@@ -258,7 +258,7 @@ export interface SerializedQualitySettingsData {
object_visibility_threshold?: number;
panoptic_comparison?: boolean;
compare_attributes?: boolean;
- match_empty_frames?: boolean;
+ empty_is_annotated?: boolean;
descriptions?: Record;
}
diff --git a/cvat-sdk/gen/generate.sh b/cvat-sdk/gen/generate.sh
index 939ac9d65b44..d37fa252e5bb 100755
--- a/cvat-sdk/gen/generate.sh
+++ b/cvat-sdk/gen/generate.sh
@@ -8,7 +8,7 @@ set -e
GENERATOR_VERSION="v6.0.1"
-VERSION="2.24.1"
+VERSION="2.25.1"
LIB_NAME="cvat_sdk"
LAYER1_LIB_NAME="${LIB_NAME}/api_client"
DST_DIR="$(cd "$(dirname -- "$0")/.." && pwd)"
diff --git a/cvat-sdk/pyproject.toml b/cvat-sdk/pyproject.toml
index ce8cba3ffba6..8d3fb7787504 100644
--- a/cvat-sdk/pyproject.toml
+++ b/cvat-sdk/pyproject.toml
@@ -7,3 +7,4 @@ profile = "black"
forced_separate = ["tests"]
line_length = 100
skip_gitignore = true # align tool behavior with Black
+known_first_party = ["cvat_sdk"]
diff --git a/cvat-ui/src/components/quality-control/quality-control-page.tsx b/cvat-ui/src/components/quality-control/quality-control-page.tsx
index cbaa26a8dd09..afa166f6f5fa 100644
--- a/cvat-ui/src/components/quality-control/quality-control-page.tsx
+++ b/cvat-ui/src/components/quality-control/quality-control-page.tsx
@@ -223,7 +223,7 @@ function QualityControlPage(): JSX.Element {
settings.lowOverlapThreshold = values.lowOverlapThreshold / 100;
settings.iouThreshold = values.iouThreshold / 100;
settings.compareAttributes = values.compareAttributes;
- settings.matchEmptyFrames = values.matchEmptyFrames;
+ settings.emptyIsAnnotated = values.emptyIsAnnotated;
settings.oksSigma = values.oksSigma / 100;
settings.pointSizeBase = values.pointSizeBase;
diff --git a/cvat-ui/src/components/quality-control/task-quality/quality-settings-form.tsx b/cvat-ui/src/components/quality-control/task-quality/quality-settings-form.tsx
index 87a727f9772b..b5218475b418 100644
--- a/cvat-ui/src/components/quality-control/task-quality/quality-settings-form.tsx
+++ b/cvat-ui/src/components/quality-control/task-quality/quality-settings-form.tsx
@@ -34,7 +34,7 @@ export default function QualitySettingsForm(props: Readonly): JSX.Element
lowOverlapThreshold: settings.lowOverlapThreshold * 100,
iouThreshold: settings.iouThreshold * 100,
compareAttributes: settings.compareAttributes,
- matchEmptyFrames: settings.matchEmptyFrames,
+ emptyIsAnnotated: settings.emptyIsAnnotated,
oksSigma: settings.oksSigma * 100,
pointSizeBase: settings.pointSizeBase,
@@ -81,7 +81,7 @@ export default function QualitySettingsForm(props: Readonly): JSX.Element
{makeTooltipFragment('Target metric', targetMetricDescription)}
{makeTooltipFragment('Target metric threshold', settings.descriptions.targetMetricThreshold)}
{makeTooltipFragment('Compare attributes', settings.descriptions.compareAttributes)}
- {makeTooltipFragment('Match empty frames', settings.descriptions.matchEmptyFrames)}
+ {makeTooltipFragment('Empty frames are annotated', settings.descriptions.emptyIsAnnotated)}
>,
);
@@ -198,12 +198,12 @@ export default function QualitySettingsForm(props: Readonly): JSX.Element
- Match empty frames
+ Empty frames are annotated
diff --git a/cvat/__init__.py b/cvat/__init__.py
index cd11fa1758cc..c28191c91a0c 100644
--- a/cvat/__init__.py
+++ b/cvat/__init__.py
@@ -4,6 +4,6 @@
from cvat.utils.version import get_version
-VERSION = (2, 24, 1, "alpha", 0)
+VERSION = (2, 25, 1, "alpha", 0)
__version__ = get_version(VERSION)
diff --git a/cvat/apps/dataset_manager/annotation.py b/cvat/apps/dataset_manager/annotation.py
index 4ea10ba9619d..943e53d003e3 100644
--- a/cvat/apps/dataset_manager/annotation.py
+++ b/cvat/apps/dataset_manager/annotation.py
@@ -3,19 +3,19 @@
#
# SPDX-License-Identifier: MIT
-from copy import copy, deepcopy
-
import math
from collections.abc import Container, Sequence
+from copy import copy, deepcopy
+from itertools import chain
from typing import Optional
+
import numpy as np
-from itertools import chain
from scipy.optimize import linear_sum_assignment
from shapely import geometry
-from cvat.apps.engine.models import ShapeType, DimensionType
-from cvat.apps.engine.serializers import LabeledDataSerializer
from cvat.apps.dataset_manager.util import faster_deepcopy
+from cvat.apps.engine.models import DimensionType, ShapeType
+from cvat.apps.engine.serializers import LabeledDataSerializer
class AnnotationIR:
diff --git a/cvat/apps/dataset_manager/bindings.py b/cvat/apps/dataset_manager/bindings.py
index 8b759f7b6316..7fddcb198f35 100644
--- a/cvat/apps/dataset_manager/bindings.py
+++ b/cvat/apps/dataset_manager/bindings.py
@@ -16,29 +16,39 @@
from types import SimpleNamespace
from typing import Any, Callable, Literal, NamedTuple, Optional, Union
-from attrs.converters import to_bool
import datumaro as dm
import defusedxml.ElementTree as ET
import rq
from attr import attrib, attrs
+from attrs.converters import to_bool
from datumaro.components.format_detection import RejectionReason
+from django.conf import settings
from django.db.models import Prefetch, QuerySet
from django.utils import timezone
-from django.conf import settings
from cvat.apps.dataset_manager.formats.utils import get_label_color
from cvat.apps.dataset_manager.util import add_prefetch_fields
from cvat.apps.engine import models
-from cvat.apps.engine.frame_provider import TaskFrameProvider, FrameQuality, FrameOutputType
-from cvat.apps.engine.models import (AttributeSpec, AttributeType, DimensionType, Job,
- JobType, Label, LabelType, Project, SegmentType, ShapeType,
- Task)
-from cvat.apps.engine.rq_job_handler import RQJobMetaField
+from cvat.apps.engine.frame_provider import FrameOutputType, FrameQuality, TaskFrameProvider
from cvat.apps.engine.lazy_list import LazyList
+from cvat.apps.engine.models import (
+ AttributeSpec,
+ AttributeType,
+ DimensionType,
+ Job,
+ JobType,
+ Label,
+ LabelType,
+ Project,
+ SegmentType,
+ ShapeType,
+ Task,
+)
+from cvat.apps.engine.rq_job_handler import RQJobMetaField
-from .annotation import AnnotationIR, AnnotationManager, TrackManager
-from .formats.transformations import MaskConverter, EllipsesToMasks
from ..engine.log import ServerLogManager
+from .annotation import AnnotationIR, AnnotationManager, TrackManager
+from .formats.transformations import EllipsesToMasks, MaskConverter
slogger = ServerLogManager(__name__)
@@ -2175,7 +2185,11 @@ def import_dm_annotations(dm_dataset: dm.Dataset, instance_data: Union[ProjectDa
'coco',
'coco_instances',
'coco_person_keypoints',
- 'voc'
+ 'voc',
+ 'yolo_ultralytics_detection',
+ 'yolo_ultralytics_segmentation',
+ 'yolo_ultralytics_oriented_boxes',
+ 'yolo_ultralytics_pose',
]
label_cat = dm_dataset.categories()[dm.AnnotationType.label]
diff --git a/cvat/apps/dataset_manager/formats/camvid.py b/cvat/apps/dataset_manager/formats/camvid.py
index 75cea9e98bd4..e995c5f1075d 100644
--- a/cvat/apps/dataset_manager/formats/camvid.py
+++ b/cvat/apps/dataset_manager/formats/camvid.py
@@ -6,12 +6,11 @@
from datumaro.components.dataset import Dataset
from pyunpack import Archive
-from cvat.apps.dataset_manager.bindings import (GetCVATDataExtractor,
- import_dm_annotations)
+from cvat.apps.dataset_manager.bindings import GetCVATDataExtractor, import_dm_annotations
from cvat.apps.dataset_manager.util import make_zip_archive
-from .transformations import MaskToPolygonTransformation, RotatedBoxesToPolygons
from .registry import dm_env, exporter, importer
+from .transformations import MaskToPolygonTransformation, RotatedBoxesToPolygons
from .utils import make_colormap
diff --git a/cvat/apps/dataset_manager/formats/cityscapes.py b/cvat/apps/dataset_manager/formats/cityscapes.py
index ea39578ea3f3..dce977b94d1a 100644
--- a/cvat/apps/dataset_manager/formats/cityscapes.py
+++ b/cvat/apps/dataset_manager/formats/cityscapes.py
@@ -6,15 +6,18 @@
import os.path as osp
from datumaro.components.dataset import Dataset
-from datumaro.plugins.cityscapes_format import write_label_map
+from datumaro.plugins.data_formats.cityscapes import write_label_map
from pyunpack import Archive
-from cvat.apps.dataset_manager.bindings import (GetCVATDataExtractor, detect_dataset,
- import_dm_annotations)
+from cvat.apps.dataset_manager.bindings import (
+ GetCVATDataExtractor,
+ detect_dataset,
+ import_dm_annotations,
+)
from cvat.apps.dataset_manager.util import make_zip_archive
-from .transformations import MaskToPolygonTransformation, RotatedBoxesToPolygons
from .registry import dm_env, exporter, importer
+from .transformations import MaskToPolygonTransformation, RotatedBoxesToPolygons
from .utils import make_colormap
diff --git a/cvat/apps/dataset_manager/formats/coco.py b/cvat/apps/dataset_manager/formats/coco.py
index 1d1a8ce4d0d5..cab74bcb42e1 100644
--- a/cvat/apps/dataset_manager/formats/coco.py
+++ b/cvat/apps/dataset_manager/formats/coco.py
@@ -5,17 +5,21 @@
import zipfile
-from datumaro.components.dataset import Dataset
from datumaro.components.annotation import AnnotationType
-from datumaro.plugins.coco_format.importer import CocoImporter
+from datumaro.components.dataset import Dataset
+from datumaro.plugins.data_formats.coco.importer import CocoImporter
from cvat.apps.dataset_manager.bindings import (
- GetCVATDataExtractor, NoMediaInAnnotationFileError, import_dm_annotations, detect_dataset
+ GetCVATDataExtractor,
+ NoMediaInAnnotationFileError,
+ detect_dataset,
+ import_dm_annotations,
)
from cvat.apps.dataset_manager.util import make_zip_archive
from .registry import dm_env, exporter, importer
+
@exporter(name='COCO', ext='ZIP', version='1.0')
def _export(dst_file, temp_dir, instance_data, save_images=False):
with GetCVATDataExtractor(instance_data, include_images=save_images) as extractor:
diff --git a/cvat/apps/dataset_manager/formats/cvat.py b/cvat/apps/dataset_manager/formats/cvat.py
index fa46b58813bf..f5c7dc18fcda 100644
--- a/cvat/apps/dataset_manager/formats/cvat.py
+++ b/cvat/apps/dataset_manager/formats/cvat.py
@@ -11,29 +11,34 @@
from io import BufferedWriter
from typing import Callable, Union
-from datumaro.components.annotation import (AnnotationType, Bbox, Label,
- LabelCategories, Points, Polygon,
- PolyLine, Skeleton)
+from datumaro.components.annotation import (
+ AnnotationType,
+ Bbox,
+ Label,
+ LabelCategories,
+ Points,
+ Polygon,
+ PolyLine,
+ Skeleton,
+)
from datumaro.components.dataset import Dataset, DatasetItem
-from datumaro.components.extractor import (DEFAULT_SUBSET_NAME, Extractor,
- Importer)
-from datumaro.plugins.cvat_format.extractor import CvatImporter as _CvatImporter
-
+from datumaro.components.extractor import DEFAULT_SUBSET_NAME, Extractor, Importer
+from datumaro.plugins.data_formats.cvat.base import CvatImporter as _CvatImporter
from datumaro.util.image import Image
from defusedxml import ElementTree
from cvat.apps.dataset_manager.bindings import (
+ JobData,
NoMediaInAnnotationFileError,
ProjectData,
TaskData,
- JobData,
detect_dataset,
get_defaulted_subset,
import_dm_annotations,
- match_dm_item
+ match_dm_item,
)
from cvat.apps.dataset_manager.util import make_zip_archive
-from cvat.apps.engine.frame_provider import FrameQuality, FrameOutputType, make_frame_provider
+from cvat.apps.engine.frame_provider import FrameOutputType, FrameQuality, make_frame_provider
from .registry import dm_env, exporter, importer
diff --git a/cvat/apps/dataset_manager/formats/datumaro.py b/cvat/apps/dataset_manager/formats/datumaro.py
index 4fc1d246dd47..81f86cb32065 100644
--- a/cvat/apps/dataset_manager/formats/datumaro.py
+++ b/cvat/apps/dataset_manager/formats/datumaro.py
@@ -4,10 +4,14 @@
# SPDX-License-Identifier: MIT
import zipfile
+
from datumaro.components.dataset import Dataset
from cvat.apps.dataset_manager.bindings import (
- GetCVATDataExtractor, import_dm_annotations, NoMediaInAnnotationFileError, detect_dataset
+ GetCVATDataExtractor,
+ NoMediaInAnnotationFileError,
+ detect_dataset,
+ import_dm_annotations,
)
from cvat.apps.dataset_manager.util import make_zip_archive
from cvat.apps.engine.models import DimensionType
diff --git a/cvat/apps/dataset_manager/formats/icdar.py b/cvat/apps/dataset_manager/formats/icdar.py
index 5d031eef82b0..c72f9708fe11 100644
--- a/cvat/apps/dataset_manager/formats/icdar.py
+++ b/cvat/apps/dataset_manager/formats/icdar.py
@@ -5,17 +5,15 @@
import zipfile
-from datumaro.components.annotation import (AnnotationType, Caption, Label,
- LabelCategories)
+from datumaro.components.annotation import AnnotationType, Caption, Label, LabelCategories
from datumaro.components.dataset import Dataset
from datumaro.components.extractor import ItemTransform
-from cvat.apps.dataset_manager.bindings import (GetCVATDataExtractor,
- import_dm_annotations)
+from cvat.apps.dataset_manager.bindings import GetCVATDataExtractor, import_dm_annotations
from cvat.apps.dataset_manager.util import make_zip_archive
-from .transformations import MaskToPolygonTransformation, RotatedBoxesToPolygons
from .registry import dm_env, exporter, importer
+from .transformations import MaskToPolygonTransformation, RotatedBoxesToPolygons
class AddLabelToAnns(ItemTransform):
diff --git a/cvat/apps/dataset_manager/formats/imagenet.py b/cvat/apps/dataset_manager/formats/imagenet.py
index fd5e9a99a176..273f47616bc1 100644
--- a/cvat/apps/dataset_manager/formats/imagenet.py
+++ b/cvat/apps/dataset_manager/formats/imagenet.py
@@ -9,8 +9,7 @@
from datumaro.components.dataset import Dataset
-from cvat.apps.dataset_manager.bindings import GetCVATDataExtractor, \
- import_dm_annotations
+from cvat.apps.dataset_manager.bindings import GetCVATDataExtractor, import_dm_annotations
from cvat.apps.dataset_manager.util import make_zip_archive
from .registry import dm_env, exporter, importer
diff --git a/cvat/apps/dataset_manager/formats/kitti.py b/cvat/apps/dataset_manager/formats/kitti.py
index 01e1cd3fc4bc..631f903f7289 100644
--- a/cvat/apps/dataset_manager/formats/kitti.py
+++ b/cvat/apps/dataset_manager/formats/kitti.py
@@ -6,15 +6,18 @@
import os.path as osp
from datumaro.components.dataset import Dataset
-from datumaro.plugins.kitti_format.format import KittiPath, write_label_map
-
+from datumaro.plugins.data_formats.kitti.format import KittiPath, write_label_map
from pyunpack import Archive
-from cvat.apps.dataset_manager.bindings import (GetCVATDataExtractor, detect_dataset, import_dm_annotations)
+from cvat.apps.dataset_manager.bindings import (
+ GetCVATDataExtractor,
+ detect_dataset,
+ import_dm_annotations,
+)
from cvat.apps.dataset_manager.util import make_zip_archive
-from .transformations import MaskToPolygonTransformation, RotatedBoxesToPolygons
from .registry import dm_env, exporter, importer
+from .transformations import MaskToPolygonTransformation, RotatedBoxesToPolygons
from .utils import make_colormap
diff --git a/cvat/apps/dataset_manager/formats/labelme.py b/cvat/apps/dataset_manager/formats/labelme.py
index be9679f268e8..179fb320f322 100644
--- a/cvat/apps/dataset_manager/formats/labelme.py
+++ b/cvat/apps/dataset_manager/formats/labelme.py
@@ -6,8 +6,11 @@
from datumaro.components.dataset import Dataset
from pyunpack import Archive
-from cvat.apps.dataset_manager.bindings import (GetCVATDataExtractor, detect_dataset,
- import_dm_annotations)
+from cvat.apps.dataset_manager.bindings import (
+ GetCVATDataExtractor,
+ detect_dataset,
+ import_dm_annotations,
+)
from cvat.apps.dataset_manager.formats.transformations import MaskToPolygonTransformation
from cvat.apps.dataset_manager.util import make_zip_archive
diff --git a/cvat/apps/dataset_manager/formats/lfw.py b/cvat/apps/dataset_manager/formats/lfw.py
index 0af356332bb5..407240c5e0a3 100644
--- a/cvat/apps/dataset_manager/formats/lfw.py
+++ b/cvat/apps/dataset_manager/formats/lfw.py
@@ -6,8 +6,11 @@
from datumaro.components.dataset import Dataset
from pyunpack import Archive
-from cvat.apps.dataset_manager.bindings import (GetCVATDataExtractor, detect_dataset,
- import_dm_annotations)
+from cvat.apps.dataset_manager.bindings import (
+ GetCVATDataExtractor,
+ detect_dataset,
+ import_dm_annotations,
+)
from cvat.apps.dataset_manager.util import make_zip_archive
from .registry import dm_env, exporter, importer
diff --git a/cvat/apps/dataset_manager/formats/market1501.py b/cvat/apps/dataset_manager/formats/market1501.py
index 6be8b2fcf75f..e9d46a095bc8 100644
--- a/cvat/apps/dataset_manager/formats/market1501.py
+++ b/cvat/apps/dataset_manager/formats/market1501.py
@@ -5,17 +5,20 @@
import zipfile
-from datumaro.components.annotation import (AnnotationType, Label,
- LabelCategories)
+from datumaro.components.annotation import AnnotationType, Label, LabelCategories
from datumaro.components.dataset import Dataset
from datumaro.components.extractor import ItemTransform
-from cvat.apps.dataset_manager.bindings import (GetCVATDataExtractor, detect_dataset,
- import_dm_annotations)
+from cvat.apps.dataset_manager.bindings import (
+ GetCVATDataExtractor,
+ detect_dataset,
+ import_dm_annotations,
+)
from cvat.apps.dataset_manager.util import make_zip_archive
from .registry import dm_env, exporter, importer
+
class AttrToLabelAttr(ItemTransform):
def __init__(self, extractor, label):
super().__init__(extractor)
diff --git a/cvat/apps/dataset_manager/formats/mask.py b/cvat/apps/dataset_manager/formats/mask.py
index f003f68383e7..eab4238f4242 100644
--- a/cvat/apps/dataset_manager/formats/mask.py
+++ b/cvat/apps/dataset_manager/formats/mask.py
@@ -6,14 +6,18 @@
from datumaro.components.dataset import Dataset
from pyunpack import Archive
-from cvat.apps.dataset_manager.bindings import (GetCVATDataExtractor, detect_dataset,
- import_dm_annotations)
+from cvat.apps.dataset_manager.bindings import (
+ GetCVATDataExtractor,
+ detect_dataset,
+ import_dm_annotations,
+)
from cvat.apps.dataset_manager.util import make_zip_archive
-from .transformations import MaskToPolygonTransformation, RotatedBoxesToPolygons
from .registry import dm_env, exporter, importer
+from .transformations import MaskToPolygonTransformation, RotatedBoxesToPolygons
from .utils import make_colormap
+
@exporter(name='Segmentation mask', ext='ZIP', version='1.1')
def _export(dst_file, temp_dir, instance_data, save_images=False):
with GetCVATDataExtractor(instance_data, include_images=save_images) as extractor:
diff --git a/cvat/apps/dataset_manager/formats/mots.py b/cvat/apps/dataset_manager/formats/mots.py
index 9ed156e6cd4e..736ccb1ce0f8 100644
--- a/cvat/apps/dataset_manager/formats/mots.py
+++ b/cvat/apps/dataset_manager/formats/mots.py
@@ -8,12 +8,16 @@
from datumaro.components.extractor import ItemTransform
from pyunpack import Archive
-from cvat.apps.dataset_manager.bindings import (GetCVATDataExtractor, detect_dataset,
- find_dataset_root, match_dm_item)
+from cvat.apps.dataset_manager.bindings import (
+ GetCVATDataExtractor,
+ detect_dataset,
+ find_dataset_root,
+ match_dm_item,
+)
from cvat.apps.dataset_manager.util import make_zip_archive
-from .transformations import MaskToPolygonTransformation, RotatedBoxesToPolygons
from .registry import dm_env, exporter, importer
+from .transformations import MaskToPolygonTransformation, RotatedBoxesToPolygons
class KeepTracks(ItemTransform):
diff --git a/cvat/apps/dataset_manager/formats/openimages.py b/cvat/apps/dataset_manager/formats/openimages.py
index 51fcee29a2fb..2ae544238ee2 100644
--- a/cvat/apps/dataset_manager/formats/openimages.py
+++ b/cvat/apps/dataset_manager/formats/openimages.py
@@ -7,16 +7,21 @@
import os.path as osp
from datumaro.components.dataset import Dataset, DatasetItem
-from datumaro.plugins.open_images_format import OpenImagesPath
+from datumaro.plugins.data_formats.open_images import OpenImagesPath
from datumaro.util.image import DEFAULT_IMAGE_META_FILE_NAME
from pyunpack import Archive
-from cvat.apps.dataset_manager.bindings import (GetCVATDataExtractor, detect_dataset,
- find_dataset_root, import_dm_annotations, match_dm_item)
+from cvat.apps.dataset_manager.bindings import (
+ GetCVATDataExtractor,
+ detect_dataset,
+ find_dataset_root,
+ import_dm_annotations,
+ match_dm_item,
+)
from cvat.apps.dataset_manager.util import make_zip_archive
-from .transformations import MaskToPolygonTransformation, RotatedBoxesToPolygons
from .registry import dm_env, exporter, importer
+from .transformations import MaskToPolygonTransformation, RotatedBoxesToPolygons
def find_item_ids(path):
diff --git a/cvat/apps/dataset_manager/formats/pascal_voc.py b/cvat/apps/dataset_manager/formats/pascal_voc.py
index a0d84b745d73..3b55928e1f90 100644
--- a/cvat/apps/dataset_manager/formats/pascal_voc.py
+++ b/cvat/apps/dataset_manager/formats/pascal_voc.py
@@ -11,7 +11,11 @@
from datumaro.components.dataset import Dataset
from pyunpack import Archive
-from cvat.apps.dataset_manager.bindings import (GetCVATDataExtractor, detect_dataset, import_dm_annotations)
+from cvat.apps.dataset_manager.bindings import (
+ GetCVATDataExtractor,
+ detect_dataset,
+ import_dm_annotations,
+)
from cvat.apps.dataset_manager.formats.transformations import MaskToPolygonTransformation
from cvat.apps.dataset_manager.util import make_zip_archive
diff --git a/cvat/apps/dataset_manager/formats/pointcloud.py b/cvat/apps/dataset_manager/formats/pointcloud.py
index 6ddfbb495427..8743c6eb8f3c 100644
--- a/cvat/apps/dataset_manager/formats/pointcloud.py
+++ b/cvat/apps/dataset_manager/formats/pointcloud.py
@@ -7,8 +7,11 @@
from datumaro.components.dataset import Dataset
-from cvat.apps.dataset_manager.bindings import (GetCVATDataExtractor, detect_dataset,
- import_dm_annotations)
+from cvat.apps.dataset_manager.bindings import (
+ GetCVATDataExtractor,
+ detect_dataset,
+ import_dm_annotations,
+)
from cvat.apps.dataset_manager.util import make_zip_archive
from cvat.apps.engine.models import DimensionType
diff --git a/cvat/apps/dataset_manager/formats/transformations.py b/cvat/apps/dataset_manager/formats/transformations.py
index 99d754252378..496786126709 100644
--- a/cvat/apps/dataset_manager/formats/transformations.py
+++ b/cvat/apps/dataset_manager/formats/transformations.py
@@ -4,12 +4,12 @@
# SPDX-License-Identifier: MIT
import math
-import cv2
-import numpy as np
from itertools import chain
-from pycocotools import mask as mask_utils
+import cv2
import datumaro as dm
+import numpy as np
+from pycocotools import mask as mask_utils
class RotatedBoxesToPolygons(dm.ItemTransform):
@@ -37,6 +37,7 @@ def transform_item(self, item):
return item.wrap(annotations=annotations)
+
class MaskConverter:
@staticmethod
def cvat_rle_to_dm_rle(shape, img_h: int, img_w: int) -> dm.RleMask:
@@ -100,6 +101,7 @@ def rle(cls, arr: np.ndarray) -> list[int]:
return cvat_rle
+
class EllipsesToMasks:
@staticmethod
def convert_ellipse(ellipse, img_h, img_w):
@@ -115,6 +117,7 @@ def convert_ellipse(ellipse, img_h, img_w):
return dm.RleMask(rle=rle, label=ellipse.label, z_order=ellipse.z_order,
attributes=ellipse.attributes, group=ellipse.group)
+
class MaskToPolygonTransformation:
"""
Manages common logic for mask to polygons conversion in dataset import.
@@ -130,3 +133,13 @@ def convert_dataset(cls, dataset, **kwargs):
if kwargs.get('conv_mask_to_poly', True):
dataset.transform('masks_to_polygons')
return dataset
+
+
+class SetKeyframeForEveryTrackShape(dm.ItemTransform):
+ def transform_item(self, item):
+ annotations = []
+ for ann in item.annotations:
+ if "track_id" in ann.attributes:
+ ann = ann.wrap(attributes=dict(ann.attributes, keyframe=True))
+ annotations.append(ann)
+ return item.wrap(annotations=annotations)
diff --git a/cvat/apps/dataset_manager/formats/utils.py b/cvat/apps/dataset_manager/formats/utils.py
index 7811fbbfc902..f565c0aed687 100644
--- a/cvat/apps/dataset_manager/formats/utils.py
+++ b/cvat/apps/dataset_manager/formats/utils.py
@@ -2,13 +2,14 @@
#
# SPDX-License-Identifier: MIT
-import os.path as osp
-from hashlib import blake2s
import itertools
import operator
+import os.path as osp
+from hashlib import blake2s
from datumaro.util.os_util import make_file_name
+
def get_color_from_index(index):
def get_bit(number, index):
return (number >> index) & 1
diff --git a/cvat/apps/dataset_manager/formats/velodynepoint.py b/cvat/apps/dataset_manager/formats/velodynepoint.py
index 9912d0b1d67b..d6051bf6fce8 100644
--- a/cvat/apps/dataset_manager/formats/velodynepoint.py
+++ b/cvat/apps/dataset_manager/formats/velodynepoint.py
@@ -8,14 +8,16 @@
from datumaro.components.dataset import Dataset
from datumaro.components.extractor import ItemTransform
-from cvat.apps.dataset_manager.bindings import GetCVATDataExtractor, detect_dataset, \
- import_dm_annotations
-from .registry import dm_env
-
+from cvat.apps.dataset_manager.bindings import (
+ GetCVATDataExtractor,
+ detect_dataset,
+ import_dm_annotations,
+)
from cvat.apps.dataset_manager.util import make_zip_archive
from cvat.apps.engine.models import DimensionType
-from .registry import exporter, importer
+from .registry import dm_env, exporter, importer
+
class RemoveTrackingInformation(ItemTransform):
def transform_item(self, item):
diff --git a/cvat/apps/dataset_manager/formats/vggface2.py b/cvat/apps/dataset_manager/formats/vggface2.py
index 642171f0f8d9..aa172f947db3 100644
--- a/cvat/apps/dataset_manager/formats/vggface2.py
+++ b/cvat/apps/dataset_manager/formats/vggface2.py
@@ -7,8 +7,12 @@
from datumaro.components.dataset import Dataset
-from cvat.apps.dataset_manager.bindings import GetCVATDataExtractor, TaskData, detect_dataset, \
- import_dm_annotations
+from cvat.apps.dataset_manager.bindings import (
+ GetCVATDataExtractor,
+ TaskData,
+ detect_dataset,
+ import_dm_annotations,
+)
from cvat.apps.dataset_manager.util import make_zip_archive
from .registry import dm_env, exporter, importer
diff --git a/cvat/apps/dataset_manager/formats/widerface.py b/cvat/apps/dataset_manager/formats/widerface.py
index 12a9bf0d21e5..99480bf1f8f5 100644
--- a/cvat/apps/dataset_manager/formats/widerface.py
+++ b/cvat/apps/dataset_manager/formats/widerface.py
@@ -7,8 +7,11 @@
from datumaro.components.dataset import Dataset
-from cvat.apps.dataset_manager.bindings import GetCVATDataExtractor, detect_dataset, \
- import_dm_annotations
+from cvat.apps.dataset_manager.bindings import (
+ GetCVATDataExtractor,
+ detect_dataset,
+ import_dm_annotations,
+)
from cvat.apps.dataset_manager.util import make_zip_archive
from .registry import dm_env, exporter, importer
diff --git a/cvat/apps/dataset_manager/formats/yolo.py b/cvat/apps/dataset_manager/formats/yolo.py
index 1a138557c862..2bcfdfca1325 100644
--- a/cvat/apps/dataset_manager/formats/yolo.py
+++ b/cvat/apps/dataset_manager/formats/yolo.py
@@ -4,28 +4,40 @@
# SPDX-License-Identifier: MIT
import os.path as osp
from glob import glob
+from typing import Callable, Optional
+from datumaro.components.annotation import AnnotationType
+from datumaro.components.extractor import DatasetItem
+from datumaro.components.project import Dataset
from pyunpack import Archive
from cvat.apps.dataset_manager.bindings import (
+ CommonData,
GetCVATDataExtractor,
+ ProjectData,
detect_dataset,
+ find_dataset_root,
import_dm_annotations,
match_dm_item,
- find_dataset_root,
)
from cvat.apps.dataset_manager.util import make_zip_archive
-from datumaro.components.annotation import AnnotationType
-from datumaro.components.extractor import DatasetItem
-from datumaro.components.project import Dataset
from .registry import dm_env, exporter, importer
+from .transformations import SetKeyframeForEveryTrackShape
-def _export_common(dst_file, temp_dir, instance_data, format_name, *, save_images=False):
+def _export_common(
+ dst_file: str,
+ temp_dir: str,
+ instance_data: ProjectData | CommonData,
+ format_name: str,
+ *,
+ save_images: bool = False,
+ **kwargs
+):
with GetCVATDataExtractor(instance_data, include_images=save_images) as extractor:
dataset = Dataset.from_extractors(extractor, env=dm_env)
- dataset.export(temp_dir, format_name, save_images=save_images)
+ dataset.export(temp_dir, format_name, save_images=save_images, **kwargs)
make_zip_archive(temp_dir, dst_file)
@@ -37,12 +49,12 @@ def _export_yolo(*args, **kwargs):
def _import_common(
src_file,
- temp_dir,
- instance_data,
- format_name,
+ temp_dir: str,
+ instance_data: ProjectData | CommonData,
+ format_name: str,
*,
- load_data_callback=None,
- import_kwargs=None,
+ load_data_callback: Optional[Callable] = None,
+ import_kwargs: dict | None = None,
**kwargs
):
Archive(src_file.name).extractall(temp_dir)
@@ -67,6 +79,7 @@ def _import_common(
detect_dataset(temp_dir, format_name=format_name, importer=dm_env.importers.get(format_name))
dataset = Dataset.import_from(temp_dir, format_name,
env=dm_env, image_info=image_info, **(import_kwargs or {}))
+ dataset = dataset.transform(SetKeyframeForEveryTrackShape)
if load_data_callback is not None:
load_data_callback(dataset, instance_data)
import_dm_annotations(dataset, instance_data)
@@ -82,6 +95,11 @@ def _export_yolo_ultralytics_detection(*args, **kwargs):
_export_common(*args, format_name='yolo_ultralytics_detection', **kwargs)
+@exporter(name='Ultralytics YOLO Detection Track', ext='ZIP', version='1.0')
+def _export_yolo_ultralytics_detection_track(*args, **kwargs):
+ _export_common(*args, format_name='yolo_ultralytics_detection', write_track_id=True, **kwargs)
+
+
@exporter(name='Ultralytics YOLO Oriented Bounding Boxes', ext='ZIP', version='1.0')
def _export_yolo_ultralytics_oriented_boxes(*args, **kwargs):
_export_common(*args, format_name='yolo_ultralytics_oriented_boxes', **kwargs)
diff --git a/cvat/apps/dataset_manager/project.py b/cvat/apps/dataset_manager/project.py
index 93ac651cf477..ad51370b04e1 100644
--- a/cvat/apps/dataset_manager/project.py
+++ b/cvat/apps/dataset_manager/project.py
@@ -6,22 +6,22 @@
import os
from collections.abc import Mapping
from tempfile import TemporaryDirectory
-import rq
from typing import Any, Callable
-from datumaro.components.errors import DatasetError, DatasetImportError, DatasetNotFoundError
-from django.db import transaction
+import rq
+from datumaro.components.errors import DatasetError, DatasetImportError, DatasetNotFoundError
from django.conf import settings
+from django.db import transaction
+from cvat.apps.dataset_manager.task import TaskAnnotation
from cvat.apps.engine import models
from cvat.apps.engine.log import DatasetLogManager
+from cvat.apps.engine.rq_job_handler import RQJobMetaField
from cvat.apps.engine.serializers import DataSerializer, TaskWriteSerializer
from cvat.apps.engine.task import _create_thread as create_task
-from cvat.apps.engine.rq_job_handler import RQJobMetaField
-from cvat.apps.dataset_manager.task import TaskAnnotation
from .annotation import AnnotationIR
-from .bindings import CvatDatasetNotFoundError, ProjectData, load_dataset_data, CvatImportError
+from .bindings import CvatDatasetNotFoundError, CvatImportError, ProjectData, load_dataset_data
from .formats.registry import make_exporter, make_importer
dlogger = DatasetLogManager()
diff --git a/cvat/apps/dataset_manager/task.py b/cvat/apps/dataset_manager/task.py
index 83886d7e9cf1..74f035d40787 100644
--- a/cvat/apps/dataset_manager/task.py
+++ b/cvat/apps/dataset_manager/task.py
@@ -10,27 +10,34 @@
from enum import Enum
from tempfile import TemporaryDirectory
from typing import Optional, Union
-from datumaro.components.errors import DatasetError, DatasetImportError, DatasetNotFoundError
+from datumaro.components.errors import DatasetError, DatasetImportError, DatasetNotFoundError
+from django.conf import settings
from django.db import transaction
from django.db.models.query import Prefetch, QuerySet
-from django.conf import settings
from rest_framework.exceptions import ValidationError
+from cvat.apps.dataset_manager.annotation import AnnotationIR, AnnotationManager
+from cvat.apps.dataset_manager.bindings import (
+ CvatDatasetNotFoundError,
+ CvatImportError,
+ JobData,
+ TaskData,
+)
+from cvat.apps.dataset_manager.formats.registry import make_exporter, make_importer
+from cvat.apps.dataset_manager.util import (
+ add_prefetch_fields,
+ bulk_create,
+ faster_deepcopy,
+ get_cached,
+)
from cvat.apps.engine import models, serializers
-from cvat.apps.engine.plugins import plugin_decorator
from cvat.apps.engine.log import DatasetLogManager
+from cvat.apps.engine.plugins import plugin_decorator
from cvat.apps.engine.utils import take_by
from cvat.apps.events.handlers import handle_annotations_change
from cvat.apps.profiler import silk_profile
-from cvat.apps.dataset_manager.annotation import AnnotationIR, AnnotationManager
-from cvat.apps.dataset_manager.bindings import TaskData, JobData, CvatImportError, CvatDatasetNotFoundError
-from cvat.apps.dataset_manager.formats.registry import make_exporter, make_importer
-from cvat.apps.dataset_manager.util import (
- add_prefetch_fields, bulk_create, get_cached, faster_deepcopy
-)
-
dlogger = DatasetLogManager()
class dotdict(OrderedDict):
diff --git a/cvat/apps/dataset_manager/tests/assets/annotations.json b/cvat/apps/dataset_manager/tests/assets/annotations.json
index a0c9e8ff96d5..2a1d7f70696c 100644
--- a/cvat/apps/dataset_manager/tests/assets/annotations.json
+++ b/cvat/apps/dataset_manager/tests/assets/annotations.json
@@ -1008,6 +1008,54 @@
],
"tracks": []
},
+ "Ultralytics YOLO Detection Track 1.0": {
+ "version": 0,
+ "tags": [],
+ "shapes": [
+ {
+ "type": "rectangle",
+ "occluded": false,
+ "z_order": 0,
+ "points": [0.3, 0.1, 0.2, 0.8],
+ "frame": 0,
+ "label_id": null,
+ "group": 0,
+ "source": "manual",
+ "attributes": []
+ }
+ ],
+ "tracks": [
+ {
+ "frame": 0,
+ "label_id": null,
+ "group": 0,
+ "source": "manual",
+ "shapes": [
+ {
+ "type": "rectangle",
+ "occluded": false,
+ "z_order": 0,
+ "points": [0.2, 0.1, 0.2, 0.8],
+ "frame": 0,
+ "outside": false,
+ "attributes": [],
+ "keyframe": true
+ },
+ {
+ "type": "rectangle",
+ "occluded": false,
+ "z_order": 0,
+ "points": [0.4, 0.1, 0.2, 0.8],
+ "frame": 1,
+ "outside": true,
+ "attributes": [],
+ "keyframe": true
+ }
+ ],
+ "attributes": []
+ }
+ ]
+ },
"Ultralytics YOLO Oriented Bounding Boxes 1.0": {
"version": 0,
"tags": [],
diff --git a/cvat/apps/dataset_manager/tests/test_formats.py b/cvat/apps/dataset_manager/tests/test_formats.py
index e6ba111f29f9..097884092de0 100644
--- a/cvat/apps/dataset_manager/tests/test_formats.py
+++ b/cvat/apps/dataset_manager/tests/test_formats.py
@@ -4,28 +4,33 @@
#
# SPDX-License-Identifier: MIT
-import numpy as np
import os.path as osp
import tempfile
import zipfile
from io import BytesIO
import datumaro
-from datumaro.components.dataset import Dataset, DatasetItem
+import numpy as np
from datumaro.components.annotation import Mask
+from datumaro.components.dataset import Dataset, DatasetItem
from django.contrib.auth.models import Group, User
-
from rest_framework import status
import cvat.apps.dataset_manager as dm
from cvat.apps.dataset_manager.annotation import AnnotationIR
-from cvat.apps.dataset_manager.bindings import (CvatTaskOrJobDataExtractor,
- TaskData, find_dataset_root)
+from cvat.apps.dataset_manager.bindings import (
+ CvatTaskOrJobDataExtractor,
+ TaskData,
+ find_dataset_root,
+)
from cvat.apps.dataset_manager.task import TaskAnnotation
from cvat.apps.dataset_manager.util import make_zip_archive
from cvat.apps.engine.models import Task
from cvat.apps.engine.tests.utils import (
- get_paginated_collection, ForceLogin, generate_image_file, ApiTestBase
+ ApiTestBase,
+ ForceLogin,
+ generate_image_file,
+ get_paginated_collection,
)
@@ -295,6 +300,7 @@ def test_export_formats_query(self):
'Ultralytics YOLO Classification 1.0',
'Ultralytics YOLO Oriented Bounding Boxes 1.0',
'Ultralytics YOLO Detection 1.0',
+ 'Ultralytics YOLO Detection Track 1.0',
'Ultralytics YOLO Pose 1.0',
'Ultralytics YOLO Segmentation 1.0',
})
diff --git a/cvat/apps/dataset_manager/tests/test_rest_api_formats.py b/cvat/apps/dataset_manager/tests/test_rest_api_formats.py
index f3640b835bcb..fe1addd2cbc5 100644
--- a/cvat/apps/dataset_manager/tests/test_rest_api_formats.py
+++ b/cvat/apps/dataset_manager/tests/test_rest_api_formats.py
@@ -6,11 +6,9 @@
import copy
import itertools
import json
-import os.path as osp
-import os
import multiprocessing
-import av
-import numpy as np
+import os
+import os.path as osp
import random
import shutil
import xml.etree.ElementTree as ET
@@ -22,8 +20,11 @@
from tempfile import TemporaryDirectory
from time import sleep
from typing import Any, Callable, ClassVar, Optional, overload
-from unittest.mock import MagicMock, patch, DEFAULT as MOCK_DEFAULT
+from unittest.mock import DEFAULT as MOCK_DEFAULT
+from unittest.mock import MagicMock, patch
+import av
+import numpy as np
from attr import define, field
from datumaro.components.dataset import Dataset
from datumaro.components.operations import ExactComparator
@@ -38,7 +39,7 @@
from cvat.apps.dataset_manager.util import get_export_cache_lock
from cvat.apps.dataset_manager.views import clear_export_cache, export, parse_export_file_path
from cvat.apps.engine.models import Task
-from cvat.apps.engine.tests.utils import get_paginated_collection, ApiTestBase, ForceLogin
+from cvat.apps.engine.tests.utils import ApiTestBase, ForceLogin, get_paginated_collection
projects_path = osp.join(osp.dirname(__file__), 'assets', 'projects.json')
with open(projects_path) as file:
@@ -58,6 +59,7 @@
"Ultralytics YOLO Classification 1.0",
"YOLO 1.1",
"Ultralytics YOLO Detection 1.0",
+ "Ultralytics YOLO Detection Track 1.0",
"Ultralytics YOLO Segmentation 1.0",
"Ultralytics YOLO Oriented Bounding Boxes 1.0",
"Ultralytics YOLO Pose 1.0",
@@ -979,6 +981,8 @@ def test_api_v2_rewriting_annotations(self):
if dump_format_name == "CVAT for images 1.1" or dump_format_name == "CVAT for video 1.1":
dump_format_name = "CVAT 1.1"
+ elif dump_format_name == "Ultralytics YOLO Detection Track 1.0":
+ dump_format_name = "Ultralytics YOLO Detection 1.0"
url = self._generate_url_upload_tasks_annotations(task_id, dump_format_name)
with open(file_zip_name, 'rb') as binary_file:
@@ -1092,6 +1096,8 @@ def test_api_v2_tasks_annotations_dump_and_upload_with_datumaro(self):
# upload annotations
if dump_format_name in ["CVAT for images 1.1", "CVAT for video 1.1"]:
upload_format_name = "CVAT 1.1"
+ elif dump_format_name in ['Ultralytics YOLO Detection Track 1.0']:
+ upload_format_name = 'Ultralytics YOLO Detection 1.0'
else:
upload_format_name = dump_format_name
url = self._generate_url_upload_tasks_annotations(task_id, upload_format_name)
@@ -1451,8 +1457,8 @@ def _export(*_, task_id: int):
import sys
from os import replace as original_replace
from os.path import exists as original_exists
- from cvat.apps.dataset_manager.task import export_task as original_export_task
+ from cvat.apps.dataset_manager.task import export_task as original_export_task
from cvat.apps.dataset_manager.views import log_exception as original_log_exception
def patched_log_exception(logger=None, exc_info=True):
diff --git a/cvat/apps/dataset_manager/views.py b/cvat/apps/dataset_manager/views.py
index 52bc9cd15f7a..4dcd8304e43d 100644
--- a/cvat/apps/dataset_manager/views.py
+++ b/cvat/apps/dataset_manager/views.py
@@ -8,10 +8,10 @@
import os.path as osp
import tempfile
from datetime import timedelta
+from os.path import exists as osp_exists
import django_rq
import rq
-from os.path import exists as osp_exists
from django.conf import settings
from django.utils import timezone
from rq_scheduler import Scheduler
@@ -20,18 +20,20 @@
import cvat.apps.dataset_manager.task as task
from cvat.apps.engine.log import ServerLogManager
from cvat.apps.engine.models import Job, Project, Task
-from cvat.apps.engine.utils import get_rq_lock_by_user
from cvat.apps.engine.rq_job_handler import RQMeta
+from cvat.apps.engine.utils import get_rq_lock_by_user
from .formats.registry import EXPORT_FORMATS, IMPORT_FORMATS
+from .util import EXPORT_CACHE_DIR_NAME # pylint: disable=unused-import
from .util import (
LockNotAvailableError,
- current_function_name, get_export_cache_lock,
- get_export_cache_dir, make_export_filename,
- parse_export_file_path, extend_export_file_lifetime
+ current_function_name,
+ extend_export_file_lifetime,
+ get_export_cache_dir,
+ get_export_cache_lock,
+ make_export_filename,
+ parse_export_file_path,
)
-from .util import EXPORT_CACHE_DIR_NAME # pylint: disable=unused-import
-
slogger = ServerLogManager(__name__)
diff --git a/cvat/apps/dataset_repo/migrations/0001_initial.py b/cvat/apps/dataset_repo/migrations/0001_initial.py
index 2ecf9c17c9b9..fa02f8c54b5d 100644
--- a/cvat/apps/dataset_repo/migrations/0001_initial.py
+++ b/cvat/apps/dataset_repo/migrations/0001_initial.py
@@ -1,7 +1,7 @@
# Generated by Django 2.1.3 on 2018-12-05 13:24
-from django.db import migrations, models
import django.db.models.deletion
+from django.db import migrations, models
class Migration(migrations.Migration):
@@ -9,23 +9,31 @@ class Migration(migrations.Migration):
initial = True
dependencies = [
- ('engine', '0014_job_max_shape_id'),
+ ("engine", "0014_job_max_shape_id"),
]
- replaces = [('git', '0001_initial')]
+ replaces = [("git", "0001_initial")]
operations = [
migrations.CreateModel(
- name='GitData',
+ name="GitData",
fields=[
- ('task', models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, primary_key=True, serialize=False, to='engine.Task')),
- ('url', models.URLField(max_length=2000)),
- ('path', models.CharField(max_length=256)),
- ('sync_date', models.DateTimeField(auto_now_add=True)),
- ('status', models.CharField(default='!sync', max_length=20)),
+ (
+ "task",
+ models.OneToOneField(
+ on_delete=django.db.models.deletion.CASCADE,
+ primary_key=True,
+ serialize=False,
+ to="engine.Task",
+ ),
+ ),
+ ("url", models.URLField(max_length=2000)),
+ ("path", models.CharField(max_length=256)),
+ ("sync_date", models.DateTimeField(auto_now_add=True)),
+ ("status", models.CharField(default="!sync", max_length=20)),
],
options={
- 'db_table': 'git_gitdata',
+ "db_table": "git_gitdata",
},
),
]
diff --git a/cvat/apps/dataset_repo/migrations/0002_auto_20190123_1305.py b/cvat/apps/dataset_repo/migrations/0002_auto_20190123_1305.py
index 13fb92b8658e..ce0be5cbbc39 100644
--- a/cvat/apps/dataset_repo/migrations/0002_auto_20190123_1305.py
+++ b/cvat/apps/dataset_repo/migrations/0002_auto_20190123_1305.py
@@ -6,15 +6,15 @@
class Migration(migrations.Migration):
dependencies = [
- ('dataset_repo', '0001_initial'),
+ ("dataset_repo", "0001_initial"),
]
- replaces = [('git', '0002_auto_20190123_1305')]
+ replaces = [("git", "0002_auto_20190123_1305")]
operations = [
migrations.AlterField(
- model_name='gitdata',
- name='status',
- field=models.CharField(default='!sync', max_length=20),
+ model_name="gitdata",
+ name="status",
+ field=models.CharField(default="!sync", max_length=20),
),
]
diff --git a/cvat/apps/dataset_repo/migrations/0003_gitdata_lfs.py b/cvat/apps/dataset_repo/migrations/0003_gitdata_lfs.py
index b42ebd30db29..1e845e48a108 100644
--- a/cvat/apps/dataset_repo/migrations/0003_gitdata_lfs.py
+++ b/cvat/apps/dataset_repo/migrations/0003_gitdata_lfs.py
@@ -6,15 +6,15 @@
class Migration(migrations.Migration):
dependencies = [
- ('dataset_repo', '0002_auto_20190123_1305'),
+ ("dataset_repo", "0002_auto_20190123_1305"),
]
- replaces = [('git', '0003_gitdata_lfs')]
+ replaces = [("git", "0003_gitdata_lfs")]
operations = [
migrations.AddField(
- model_name='gitdata',
- name='lfs',
+ model_name="gitdata",
+ name="lfs",
field=models.BooleanField(default=True),
),
]
diff --git a/cvat/apps/dataset_repo/migrations/0004_rename.py b/cvat/apps/dataset_repo/migrations/0004_rename.py
index 9629165722d1..94b820dcaa56 100644
--- a/cvat/apps/dataset_repo/migrations/0004_rename.py
+++ b/cvat/apps/dataset_repo/migrations/0004_rename.py
@@ -1,16 +1,18 @@
from django.db import migrations
+
def update_contenttypes_table(apps, schema_editor):
- content_type_model = apps.get_model('contenttypes', 'ContentType')
- content_type_model.objects.filter(app_label='git').update(app_label='dataset_repo')
+ content_type_model = apps.get_model("contenttypes", "ContentType")
+ content_type_model.objects.filter(app_label="git").update(app_label="dataset_repo")
+
class Migration(migrations.Migration):
dependencies = [
- ('dataset_repo', '0003_gitdata_lfs'),
+ ("dataset_repo", "0003_gitdata_lfs"),
]
operations = [
- migrations.AlterModelTable('gitdata', 'dataset_repo_gitdata'),
+ migrations.AlterModelTable("gitdata", "dataset_repo_gitdata"),
migrations.RunPython(update_contenttypes_table),
]
diff --git a/cvat/apps/dataset_repo/migrations/0005_auto_20201019_1100.py b/cvat/apps/dataset_repo/migrations/0005_auto_20201019_1100.py
index f26c280b7f84..8c07d05d29f3 100644
--- a/cvat/apps/dataset_repo/migrations/0005_auto_20201019_1100.py
+++ b/cvat/apps/dataset_repo/migrations/0005_auto_20201019_1100.py
@@ -6,12 +6,12 @@
class Migration(migrations.Migration):
dependencies = [
- ('dataset_repo', '0004_rename'),
+ ("dataset_repo", "0004_rename"),
]
operations = [
migrations.AlterModelTable(
- name='gitdata',
+ name="gitdata",
table=None,
),
]
diff --git a/cvat/apps/dataset_repo/migrations/0006_gitdata_format.py b/cvat/apps/dataset_repo/migrations/0006_gitdata_format.py
index 641d246743eb..1b42f2d3caea 100644
--- a/cvat/apps/dataset_repo/migrations/0006_gitdata_format.py
+++ b/cvat/apps/dataset_repo/migrations/0006_gitdata_format.py
@@ -4,21 +4,27 @@
def update_default_format_field(apps, schema_editor):
- GitData = apps.get_model('dataset_repo', 'GitData')
+ GitData = apps.get_model("dataset_repo", "GitData")
for git_data in GitData.objects.all():
if not git_data.format:
- git_data.format = 'CVAT for images 1.1' if git_data.task.mode == 'annotation' else 'CVAT for video 1.1'
+ git_data.format = (
+ "CVAT for images 1.1"
+ if git_data.task.mode == "annotation"
+ else "CVAT for video 1.1"
+ )
git_data.save()
+
+
class Migration(migrations.Migration):
dependencies = [
- ('dataset_repo', '0005_auto_20201019_1100'),
+ ("dataset_repo", "0005_auto_20201019_1100"),
]
operations = [
migrations.AddField(
- model_name='gitdata',
- name='format',
+ model_name="gitdata",
+ name="format",
field=models.CharField(blank=True, max_length=256),
),
migrations.RunPython(update_default_format_field),
diff --git a/cvat/apps/engine/admin.py b/cvat/apps/engine/admin.py
index 05e4b40a0f9b..712e67fa5582 100644
--- a/cvat/apps/engine/admin.py
+++ b/cvat/apps/engine/admin.py
@@ -4,8 +4,21 @@
# SPDX-License-Identifier: MIT
from django.contrib import admin
-from .models import Task, Segment, Job, Label, AttributeSpec, Project, \
- CloudStorage, Storage, Data, AnnotationGuide, Asset
+
+from .models import (
+ AnnotationGuide,
+ Asset,
+ AttributeSpec,
+ CloudStorage,
+ Data,
+ Job,
+ Label,
+ Project,
+ Segment,
+ Storage,
+ Task,
+)
+
class JobInline(admin.TabularInline):
model = Job
diff --git a/cvat/apps/engine/apps.py b/cvat/apps/engine/apps.py
index bcad84510f5d..1cea639842c8 100644
--- a/cvat/apps/engine/apps.py
+++ b/cvat/apps/engine/apps.py
@@ -20,6 +20,7 @@ def ready(self):
# Required to define signals in application
import cvat.apps.engine.signals
+
# Required in order to silent "unused-import" in pyflake
assert cvat.apps.engine.signals
diff --git a/cvat/apps/engine/backup.py b/cvat/apps/engine/backup.py
index 3c8ba5678c24..f3790427f5ba 100644
--- a/cvat/apps/engine/backup.py
+++ b/cvat/apps/engine/backup.py
@@ -21,37 +21,57 @@
from django.conf import settings
from django.db import transaction
from django.utils import timezone
-
from rest_framework import serializers, status
+from rest_framework.exceptions import ValidationError
from rest_framework.parsers import JSONParser
from rest_framework.renderers import JSONRenderer
from rest_framework.response import Response
-from rest_framework.exceptions import ValidationError
import cvat.apps.dataset_manager as dm
+from cvat.apps.dataset_manager.bindings import CvatImportError
+from cvat.apps.dataset_manager.views import get_export_cache_dir, log_exception
from cvat.apps.engine import models
+from cvat.apps.engine.cloud_provider import import_resource_from_cloud_storage
+from cvat.apps.engine.location import StorageType, get_location_configuration
from cvat.apps.engine.log import ServerLogManager
-from cvat.apps.engine.serializers import (AttributeSerializer, DataSerializer, JobWriteSerializer,
- LabelSerializer, AnnotationGuideWriteSerializer, AssetWriteSerializer,
- LabeledDataSerializer, SegmentSerializer, SimpleJobSerializer, TaskReadSerializer,
- ProjectReadSerializer, ProjectFileSerializer, TaskFileSerializer, RqIdSerializer,
- ValidationParamsSerializer)
-from cvat.apps.engine.utils import (
- av_scan_paths, process_failed_job,
- get_rq_job_meta, import_resource_with_clean_up_after,
- define_dependent_job, get_rq_lock_by_user,
+from cvat.apps.engine.models import (
+ DataChoice,
+ Location,
+ Project,
+ RequestAction,
+ RequestSubresource,
+ RequestTarget,
+ StorageChoice,
+ StorageMethodChoice,
)
+from cvat.apps.engine.permissions import get_cloud_storage_for_import_or_export
from cvat.apps.engine.rq_job_handler import RQId, RQJobMetaField
-from cvat.apps.engine.models import (
- StorageChoice, StorageMethodChoice, DataChoice, Project, Location,
- RequestAction, RequestTarget, RequestSubresource,
+from cvat.apps.engine.serializers import (
+ AnnotationGuideWriteSerializer,
+ AssetWriteSerializer,
+ AttributeSerializer,
+ DataSerializer,
+ JobWriteSerializer,
+ LabeledDataSerializer,
+ LabelSerializer,
+ ProjectFileSerializer,
+ ProjectReadSerializer,
+ RqIdSerializer,
+ SegmentSerializer,
+ SimpleJobSerializer,
+ TaskFileSerializer,
+ TaskReadSerializer,
+ ValidationParamsSerializer,
)
from cvat.apps.engine.task import JobFileMapping, _create_thread
-from cvat.apps.engine.cloud_provider import import_resource_from_cloud_storage
-from cvat.apps.engine.location import StorageType, get_location_configuration
-from cvat.apps.engine.permissions import get_cloud_storage_for_import_or_export
-from cvat.apps.dataset_manager.views import get_export_cache_dir, log_exception
-from cvat.apps.dataset_manager.bindings import CvatImportError
+from cvat.apps.engine.utils import (
+ av_scan_paths,
+ define_dependent_job,
+ get_rq_job_meta,
+ get_rq_lock_by_user,
+ import_resource_with_clean_up_after,
+ process_failed_job,
+)
slogger = ServerLogManager(__name__)
diff --git a/cvat/apps/engine/cache.py b/cvat/apps/engine/cache.py
index 43c2be7bc57e..ffe8fe0cb920 100644
--- a/cvat/apps/engine/cache.py
+++ b/cvat/apps/engine/cache.py
@@ -218,17 +218,19 @@ def _create_and_set_cache_item(
item_data = create_callback()
item_data_bytes = item_data[0].getvalue()
item = (item_data[0], item_data[1], cls._get_checksum(item_data_bytes), timestamp)
- if item_data_bytes:
- cache = cls._cache()
- with get_rq_lock_for_job(
- cls._get_queue(),
- key,
- ):
- cached_item = cache.get(key)
- if cached_item is not None and timestamp <= cached_item[3]:
- item = cached_item
- else:
- cache.set(key, item, timeout=cache_item_ttl or cache.default_timeout)
+
+ # allow empty data to be set in cache to prevent
+ # future rq jobs from being enqueued to prepare the item
+ cache = cls._cache()
+ with get_rq_lock_for_job(
+ cls._get_queue(),
+ key,
+ ):
+ cached_item = cache.get(key)
+ if cached_item is not None and timestamp <= cached_item[3]:
+ item = cached_item
+ else:
+ cache.set(key, item, timeout=cache_item_ttl or cache.default_timeout)
return item
@@ -353,11 +355,18 @@ def _make_frame_context_images_chunk_key(self, db_data: models.Data, frame_numbe
def _to_data_with_mime(self, cache_item: _CacheItem) -> DataWithMime: ...
@overload
- def _to_data_with_mime(self, cache_item: Optional[_CacheItem]) -> Optional[DataWithMime]: ...
+ def _to_data_with_mime(
+ self, cache_item: Optional[_CacheItem], *, allow_none: bool = False
+ ) -> Optional[DataWithMime]: ...
- def _to_data_with_mime(self, cache_item: Optional[_CacheItem]) -> Optional[DataWithMime]:
+ def _to_data_with_mime(
+ self, cache_item: Optional[_CacheItem], *, allow_none: bool = False
+ ) -> Optional[DataWithMime]:
if not cache_item:
- return None
+ if allow_none:
+ return None
+
+ raise ValueError("A cache item is not allowed to be None")
return cache_item[:2]
@@ -385,7 +394,8 @@ def get_task_chunk(
return self._to_data_with_mime(
self._get_cache_item(
key=self._make_chunk_key(db_task, chunk_number, quality=quality),
- )
+ ),
+ allow_none=True,
)
def get_or_set_task_chunk(
@@ -413,7 +423,8 @@ def get_segment_task_chunk(
return self._to_data_with_mime(
self._get_cache_item(
key=self._make_segment_task_chunk_key(db_segment, chunk_number, quality=quality),
- )
+ ),
+ allow_none=True,
)
def get_or_set_segment_task_chunk(
@@ -510,7 +521,9 @@ def remove_context_images_chunks(self, params: Sequence[dict[str, Any]]) -> None
self._bulk_delete_cache_items(keys_to_remove)
def get_cloud_preview(self, db_storage: models.CloudStorage) -> Optional[DataWithMime]:
- return self._to_data_with_mime(self._get_cache_item(self._make_preview_key(db_storage)))
+ return self._to_data_with_mime(
+ self._get_cache_item(self._make_preview_key(db_storage)), allow_none=True
+ )
def get_or_set_cloud_preview(self, db_storage: models.CloudStorage) -> DataWithMime:
return self._to_data_with_mime(
diff --git a/cvat/apps/engine/cloud_provider.py b/cvat/apps/engine/cloud_provider.py
index b810304d73f9..06b2496ce16b 100644
--- a/cvat/apps/engine/cloud_provider.py
+++ b/cvat/apps/engine/cloud_provider.py
@@ -5,14 +5,14 @@
import functools
import json
-import os
import math
+import os
from abc import ABC, abstractmethod
from collections.abc import Iterator
-from concurrent.futures import ThreadPoolExecutor, wait, FIRST_EXCEPTION
+from concurrent.futures import FIRST_EXCEPTION, ThreadPoolExecutor, wait
from enum import Enum
from io import BytesIO
-from typing import Optional, Any, Callable, TypeVar
+from typing import Any, Callable, Optional, TypeVar
import boto3
from azure.core.exceptions import HttpResponseError, ResourceExistsError
@@ -27,14 +27,14 @@
from google.cloud.exceptions import Forbidden as GoogleCloudForbidden
from google.cloud.exceptions import NotFound as GoogleCloudNotFound
from PIL import Image, ImageFile
-from rest_framework.exceptions import (NotFound, PermissionDenied,
- ValidationError)
+from rest_framework.exceptions import NotFound, PermissionDenied, ValidationError
from cvat.apps.engine.log import ServerLogManager
from cvat.apps.engine.models import CloudProviderChoice, CredentialsTypeChoice
from cvat.apps.engine.utils import get_cpu_number, take_by
from cvat.utils.http import PROXIES_FOR_UNTRUSTED_URLS
+
class NamedBytesIO(BytesIO):
@property
def filename(self) -> Optional[str]:
diff --git a/cvat/apps/engine/filters.py b/cvat/apps/engine/filters.py
index 32355629d06d..6a80e94ad6cc 100644
--- a/cvat/apps/engine/filters.py
+++ b/cvat/apps/engine/filters.py
@@ -3,25 +3,25 @@
#
# SPDX-License-Identifier: MIT
-from collections.abc import Iterator, Iterable
+import json
+import operator
+from collections.abc import Iterable, Iterator
from functools import reduce
+from textwrap import dedent
from typing import Any, Optional
-import operator
-import json
+from django.db.models import Q
+from django.db.models.query import QuerySet
+from django.utils.encoding import force_str
+from django.utils.translation import gettext_lazy as _
from django_filters import FilterSet
from django_filters import filters as djf
from django_filters.filterset import BaseFilterSet
from django_filters.rest_framework import DjangoFilterBackend
-from django.db.models import Q
-from django.db.models.query import QuerySet
-from django.utils.translation import gettext_lazy as _
-from django.utils.encoding import force_str
-from rest_framework.request import Request
from rest_framework import filters
from rest_framework.compat import coreapi, coreschema
from rest_framework.exceptions import ValidationError
-from textwrap import dedent
+from rest_framework.request import Request
DEFAULT_FILTER_FIELDS_ATTR = 'filter_fields'
DEFAULT_LOOKUP_MAP_ATTR = 'lookup_fields'
diff --git a/cvat/apps/engine/frame_provider.py b/cvat/apps/engine/frame_provider.py
index 6b756543c7f3..1a5fd1f40ebd 100644
--- a/cvat/apps/engine/frame_provider.py
+++ b/cvat/apps/engine/frame_provider.py
@@ -15,15 +15,7 @@
from dataclasses import dataclass
from enum import Enum, auto
from io import BytesIO
-from typing import (
- Any,
- Callable,
- Generic,
- Optional,
- TypeVar,
- Union,
- overload,
-)
+from typing import Any, Callable, Generic, Optional, TypeVar, Union, overload
import av
import cv2
diff --git a/cvat/apps/engine/handlers.py b/cvat/apps/engine/handlers.py
index d686bbf0ba5c..0a831a44827b 100644
--- a/cvat/apps/engine/handlers.py
+++ b/cvat/apps/engine/handlers.py
@@ -4,7 +4,9 @@
from pathlib import Path
from time import time
+
from django.conf import settings
+
from cvat.apps.engine.log import ServerLogManager
slogger = ServerLogManager(__name__)
diff --git a/cvat/apps/engine/location.py b/cvat/apps/engine/location.py
index c9e216e24627..deea541f09d3 100644
--- a/cvat/apps/engine/location.py
+++ b/cvat/apps/engine/location.py
@@ -3,9 +3,10 @@
# SPDX-License-Identifier: MIT
from enum import Enum
-from typing import Any, Union, Optional
+from typing import Any, Optional, Union
+
+from cvat.apps.engine.models import Job, Location, Project, Task
-from cvat.apps.engine.models import Location, Project, Task, Job
class StorageType(str, Enum):
TARGET = 'target_storage'
diff --git a/cvat/apps/engine/log.py b/cvat/apps/engine/log.py
index 6f1740e74fd4..3cc2cecff37b 100644
--- a/cvat/apps/engine/log.py
+++ b/cvat/apps/engine/log.py
@@ -4,12 +4,15 @@
# SPDX-License-Identifier: MIT
import logging
-import sys
import os.path as osp
+import sys
from contextlib import contextmanager
-from cvat.apps.engine.utils import directory_tree
+
from django.conf import settings
+from cvat.apps.engine.utils import directory_tree
+
+
class _LoggerAdapter(logging.LoggerAdapter):
def process(self, msg: str, kwargs):
if msg_prefix := self.extra.get("msg_prefix"):
diff --git a/cvat/apps/engine/management/commands/runperiodicjob.py b/cvat/apps/engine/management/commands/runperiodicjob.py
new file mode 100644
index 000000000000..765f16541cfd
--- /dev/null
+++ b/cvat/apps/engine/management/commands/runperiodicjob.py
@@ -0,0 +1,23 @@
+from argparse import ArgumentParser
+
+from django.conf import settings
+from django.core.management.base import BaseCommand, CommandError
+from django.utils.module_loading import import_string
+
+
+class Command(BaseCommand):
+ help = "Run a configured periodic job immediately"
+
+ def add_arguments(self, parser: ArgumentParser) -> None:
+ parser.add_argument("job_id", help="ID of the job to run")
+
+ def handle(self, *args, **options):
+ job_id = options["job_id"]
+
+ for job_definition in settings.PERIODIC_RQ_JOBS:
+ if job_definition["id"] == job_id:
+ job_func = import_string(job_definition["func"])
+ job_func()
+ return
+
+ raise CommandError(f"Job with ID {job_id} not found")
diff --git a/cvat/apps/engine/management/commands/syncperiodicjobs.py b/cvat/apps/engine/management/commands/syncperiodicjobs.py
index 097f468b337f..d78d3f247179 100644
--- a/cvat/apps/engine/management/commands/syncperiodicjobs.py
+++ b/cvat/apps/engine/management/commands/syncperiodicjobs.py
@@ -5,10 +5,10 @@
from argparse import ArgumentParser
from collections import defaultdict
-from django.core.management.base import BaseCommand
+import django_rq
from django.conf import settings
+from django.core.management.base import BaseCommand
-import django_rq
class Command(BaseCommand):
help = "Synchronize periodic jobs in Redis with the project configuration"
diff --git a/cvat/apps/engine/media_extractors.py b/cvat/apps/engine/media_extractors.py
index ae1c7b9f7da8..09c2ce2876de 100644
--- a/cvat/apps/engine/media_extractors.py
+++ b/cvat/apps/engine/media_extractors.py
@@ -5,20 +5,21 @@
from __future__ import annotations
+import io
+import itertools
import os
+import shutil
+import struct
import sysconfig
import tempfile
-import shutil
import zipfile
-import io
-import itertools
-import struct
from abc import ABC, abstractmethod
from bisect import bisect
from collections.abc import Generator, Iterable, Iterator, Sequence
from contextlib import AbstractContextManager, ExitStack, closing, contextmanager
from dataclasses import dataclass
from enum import IntEnum
+from random import shuffle
from typing import Any, Callable, Optional, Protocol, TypeVar, Union
import av
@@ -27,19 +28,19 @@
import av.video.stream
import numpy as np
from natsort import os_sorted
-from pyunpack import Archive
from PIL import Image, ImageFile, ImageOps
-from random import shuffle
-from cvat.apps.engine.utils import rotate_image
-from cvat.apps.engine.models import DimensionType, SortingMethod
+from pyunpack import Archive
from rest_framework.exceptions import ValidationError
+from cvat.apps.engine.models import DimensionType, SortingMethod
+from cvat.apps.engine.utils import rotate_image
+
# fixes: "OSError:broken data stream" when executing line 72 while loading images downloaded from the web
# see: https://stackoverflow.com/questions/42462431/oserror-broken-data-stream-when-reading-image-file
ImageFile.LOAD_TRUNCATED_IMAGES = True
from cvat.apps.engine.mime_types import mimetypes
-from utils.dataset_manifest import VideoManifestManager, ImageManifestManager
+from utils.dataset_manifest import ImageManifestManager, VideoManifestManager
ORIENTATION_EXIF_TAG = 274
diff --git a/cvat/apps/engine/middleware.py b/cvat/apps/engine/middleware.py
index f2b990a14b50..2e8f116f4ecd 100644
--- a/cvat/apps/engine/middleware.py
+++ b/cvat/apps/engine/middleware.py
@@ -4,6 +4,7 @@
from uuid import uuid4
+
class RequestTrackingMiddleware:
def __init__(self, get_response):
self.get_response = get_response
diff --git a/cvat/apps/engine/migrations/0001_release_v0_1_0.py b/cvat/apps/engine/migrations/0001_release_v0_1_0.py
index 64d030cc81c6..59edc03104f4 100644
--- a/cvat/apps/engine/migrations/0001_release_v0_1_0.py
+++ b/cvat/apps/engine/migrations/0001_release_v0_1_0.py
@@ -5,9 +5,9 @@
# Generated by Django 2.0.3 on 2018-05-23 11:51
+import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
-import django.db.models.deletion
class Migration(migrations.Migration):
diff --git a/cvat/apps/engine/migrations/0002_labeledpoints_labeledpointsattributeval_labeledpolygon_labeledpolygonattributeval_labeledpolyline_la.py b/cvat/apps/engine/migrations/0002_labeledpoints_labeledpointsattributeval_labeledpolygon_labeledpolygonattributeval_labeledpolyline_la.py
index 0e7820999c38..fa3e6fe79b94 100644
--- a/cvat/apps/engine/migrations/0002_labeledpoints_labeledpointsattributeval_labeledpolygon_labeledpolygonattributeval_labeledpolyline_la.py
+++ b/cvat/apps/engine/migrations/0002_labeledpoints_labeledpointsattributeval_labeledpolygon_labeledpolygonattributeval_labeledpolyline_la.py
@@ -5,8 +5,8 @@
# Generated by Django 2.0.3 on 2018-05-30 09:53
-from django.db import migrations, models
import django.db.models.deletion
+from django.db import migrations, models
class Migration(migrations.Migration):
diff --git a/cvat/apps/engine/migrations/0008_auto_20180917_1424.py b/cvat/apps/engine/migrations/0008_auto_20180917_1424.py
index cf6b45500d90..a32051d585e4 100644
--- a/cvat/apps/engine/migrations/0008_auto_20180917_1424.py
+++ b/cvat/apps/engine/migrations/0008_auto_20180917_1424.py
@@ -1,8 +1,8 @@
# Generated by Django 2.0.3 on 2018-09-17 11:24
+import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
-import django.db.models.deletion
class Migration(migrations.Migration):
diff --git a/cvat/apps/engine/migrations/0011_add_task_source_and_safecharfield.py b/cvat/apps/engine/migrations/0011_add_task_source_and_safecharfield.py
index bb96c1b588dd..4b168322d486 100644
--- a/cvat/apps/engine/migrations/0011_add_task_source_and_safecharfield.py
+++ b/cvat/apps/engine/migrations/0011_add_task_source_and_safecharfield.py
@@ -1,8 +1,9 @@
# Generated by Django 2.0.9 on 2018-10-24 10:50
-import cvat.apps.engine.models
from django.db import migrations
+import cvat.apps.engine.models
+
class Migration(migrations.Migration):
diff --git a/cvat/apps/engine/migrations/0013_auth_no_default_permissions.py b/cvat/apps/engine/migrations/0013_auth_no_default_permissions.py
index bc735269eed6..2dabe07fe9a0 100644
--- a/cvat/apps/engine/migrations/0013_auth_no_default_permissions.py
+++ b/cvat/apps/engine/migrations/0013_auth_no_default_permissions.py
@@ -1,8 +1,8 @@
# Generated by Django 2.0.9 on 2018-11-07 12:25
+import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
-import django.db.models.deletion
class Migration(migrations.Migration):
diff --git a/cvat/apps/engine/migrations/0015_db_redesign_20190217.py b/cvat/apps/engine/migrations/0015_db_redesign_20190217.py
index db9589d8b807..accac35b8187 100644
--- a/cvat/apps/engine/migrations/0015_db_redesign_20190217.py
+++ b/cvat/apps/engine/migrations/0015_db_redesign_20190217.py
@@ -1,11 +1,13 @@
# Generated by Django 2.1.5 on 2019-02-17 19:32
-from django.conf import settings
-from django.db import migrations, models
import django.db.migrations.operations.special
import django.db.models.deletion
+from django.conf import settings
+from django.db import migrations, models
+
import cvat.apps.engine.models
+
def set_segment_size(apps, schema_editor):
Task = apps.get_model('engine', 'Task')
for task in Task.objects.all():
diff --git a/cvat/apps/engine/migrations/0016_attribute_spec_20190217.py b/cvat/apps/engine/migrations/0016_attribute_spec_20190217.py
index 27d273af2790..ac060ad69326 100644
--- a/cvat/apps/engine/migrations/0016_attribute_spec_20190217.py
+++ b/cvat/apps/engine/migrations/0016_attribute_spec_20190217.py
@@ -1,12 +1,15 @@
+import csv
import os
import re
-import csv
from io import StringIO
-from PIL import Image
-from django.db import migrations
+
from django.conf import settings
+from django.db import migrations
+from PIL import Image
+
from cvat.apps.engine.media_extractors import get_mime
+
def parse_attribute(value):
match = re.match(r'^([~@])(\w+)=(\w+):(.+)?$', value)
if match:
diff --git a/cvat/apps/engine/migrations/0017_db_redesign_20190221.py b/cvat/apps/engine/migrations/0017_db_redesign_20190221.py
index 22b7e5d28881..d30d5fa0a73a 100644
--- a/cvat/apps/engine/migrations/0017_db_redesign_20190221.py
+++ b/cvat/apps/engine/migrations/0017_db_redesign_20190221.py
@@ -1,11 +1,13 @@
# Generated by Django 2.1.5 on 2019-02-21 12:25
-import cvat.apps.engine.models
-from django.db import migrations, models
import django.db.models.deletion
from django.conf import settings
+from django.db import migrations, models
+
+import cvat.apps.engine.models
from cvat.apps.dataset_manager.task import merge_table_rows as _merge_table_rows
+
# some modified functions to transfer annotation
def _bulk_create(db_model, db_alias, objects, flt_param):
if objects:
diff --git a/cvat/apps/engine/migrations/0018_jobcommit.py b/cvat/apps/engine/migrations/0018_jobcommit.py
index c526cb896435..b25187c50f60 100644
--- a/cvat/apps/engine/migrations/0018_jobcommit.py
+++ b/cvat/apps/engine/migrations/0018_jobcommit.py
@@ -1,8 +1,8 @@
# Generated by Django 2.1.7 on 2019-04-17 09:25
+import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
-import django.db.models.deletion
class Migration(migrations.Migration):
diff --git a/cvat/apps/engine/migrations/0020_remove_task_flipped.py b/cvat/apps/engine/migrations/0020_remove_task_flipped.py
index 7ca57e880417..7744def2b302 100644
--- a/cvat/apps/engine/migrations/0020_remove_task_flipped.py
+++ b/cvat/apps/engine/migrations/0020_remove_task_flipped.py
@@ -1,14 +1,15 @@
# Generated by Django 2.1.7 on 2019-06-18 11:08
-from django.db import migrations
+import os
+from ast import literal_eval
+
from django.conf import settings
+from django.db import migrations
+from PIL import Image
-from cvat.apps.engine.models import Job, ShapeType
from cvat.apps.engine.media_extractors import get_mime
+from cvat.apps.engine.models import Job, ShapeType
-from PIL import Image
-from ast import literal_eval
-import os
def make_image_meta_cache(db_task):
with open(db_task.get_image_meta_cache_path(), 'w') as meta_file:
diff --git a/cvat/apps/engine/migrations/0022_auto_20191004_0817.py b/cvat/apps/engine/migrations/0022_auto_20191004_0817.py
index b48a24f583db..6fd0ca45d8c3 100644
--- a/cvat/apps/engine/migrations/0022_auto_20191004_0817.py
+++ b/cvat/apps/engine/migrations/0022_auto_20191004_0817.py
@@ -1,9 +1,10 @@
# Generated by Django 2.2.3 on 2019-10-04 08:17
-import cvat.apps.engine.models
+import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
-import django.db.models.deletion
+
+import cvat.apps.engine.models
class Migration(migrations.Migration):
diff --git a/cvat/apps/engine/migrations/0023_auto_20200113_1323.py b/cvat/apps/engine/migrations/0023_auto_20200113_1323.py
index 4089eb1a1a66..33c586398323 100644
--- a/cvat/apps/engine/migrations/0023_auto_20200113_1323.py
+++ b/cvat/apps/engine/migrations/0023_auto_20200113_1323.py
@@ -1,8 +1,9 @@
# Generated by Django 2.2.8 on 2020-01-13 13:23
-import cvat.apps.engine.models
from django.db import migrations
+import cvat.apps.engine.models
+
class Migration(migrations.Migration):
diff --git a/cvat/apps/engine/migrations/0024_auto_20191023_1025.py b/cvat/apps/engine/migrations/0024_auto_20191023_1025.py
index 1946e08e47e2..945879ef7552 100644
--- a/cvat/apps/engine/migrations/0024_auto_20191023_1025.py
+++ b/cvat/apps/engine/migrations/0024_auto_20191023_1025.py
@@ -1,24 +1,32 @@
# Generated by Django 2.2.4 on 2019-10-23 10:25
+import glob
+import itertools
+import multiprocessing
import os
import re
import shutil
-import glob
import sys
-import traceback
-import itertools
-import multiprocessing
import time
+import traceback
-from django.db import migrations, models
import django.db.models.deletion
from django.conf import settings
+from django.db import migrations, models
-from cvat.apps.engine.media_extractors import (VideoReader, ArchiveReader, ZipReader,
- PdfReader , ImageListReader, Mpeg4ChunkWriter,
- ZipChunkWriter, ZipCompressedChunkWriter, get_mime)
-from cvat.apps.engine.models import DataChoice
from cvat.apps.engine.log import get_migration_logger
+from cvat.apps.engine.media_extractors import (
+ ArchiveReader,
+ ImageListReader,
+ Mpeg4ChunkWriter,
+ PdfReader,
+ VideoReader,
+ ZipChunkWriter,
+ ZipCompressedChunkWriter,
+ ZipReader,
+ get_mime,
+)
+from cvat.apps.engine.models import DataChoice
MIGRATION_THREAD_COUNT = 2
diff --git a/cvat/apps/engine/migrations/0028_labelcolor.py b/cvat/apps/engine/migrations/0028_labelcolor.py
index af30fbabd8d2..eda6215ecdd6 100644
--- a/cvat/apps/engine/migrations/0028_labelcolor.py
+++ b/cvat/apps/engine/migrations/0028_labelcolor.py
@@ -1,7 +1,9 @@
# Generated by Django 2.2.13 on 2020-08-11 11:26
from django.db import migrations, models
+
from cvat.apps.dataset_manager.formats.utils import get_label_color
+
def alter_label_colors(apps, schema_editor):
Label = apps.get_model('engine', 'Label')
Task = apps.get_model('engine', 'Task')
diff --git a/cvat/apps/engine/migrations/0029_data_storage_method.py b/cvat/apps/engine/migrations/0029_data_storage_method.py
index 1c1aa814e4cd..e5ee36f33f06 100644
--- a/cvat/apps/engine/migrations/0029_data_storage_method.py
+++ b/cvat/apps/engine/migrations/0029_data_storage_method.py
@@ -1,12 +1,15 @@
# Generated by Django 2.2.13 on 2020-08-13 05:49
-from cvat.apps.engine.media_extractors import _is_archive, _is_zip
-import cvat.apps.engine.models
+import os
+
from django.conf import settings
from django.db import migrations, models
-import os
from pyunpack import Archive
+import cvat.apps.engine.models
+from cvat.apps.engine.media_extractors import _is_archive, _is_zip
+
+
def unzip(apps, schema_editor):
Data = apps.get_model("engine", "Data")
data_q_set = Data.objects.all()
diff --git a/cvat/apps/engine/migrations/0033_projects_adjastment.py b/cvat/apps/engine/migrations/0033_projects_adjastment.py
index e57bd0e6c568..8af73e6d1da5 100644
--- a/cvat/apps/engine/migrations/0033_projects_adjastment.py
+++ b/cvat/apps/engine/migrations/0033_projects_adjastment.py
@@ -1,7 +1,7 @@
# Generated by Django 3.1.1 on 2020-09-24 12:44
-from django.db import migrations, models
import django.db.models.deletion
+from django.db import migrations, models
class Migration(migrations.Migration):
diff --git a/cvat/apps/engine/migrations/0034_auto_20201125_1426.py b/cvat/apps/engine/migrations/0034_auto_20201125_1426.py
index 311b21655b9d..d02582342893 100644
--- a/cvat/apps/engine/migrations/0034_auto_20201125_1426.py
+++ b/cvat/apps/engine/migrations/0034_auto_20201125_1426.py
@@ -1,9 +1,11 @@
# Generated by Django 3.1.1 on 2020-11-25 14:26
-import cvat.apps.engine.models
+import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
-import django.db.models.deletion
+
+import cvat.apps.engine.models
+
def create_profile(apps, schema_editor):
User = apps.get_model('auth', 'User')
diff --git a/cvat/apps/engine/migrations/0035_data_storage.py b/cvat/apps/engine/migrations/0035_data_storage.py
index 5a8a9903784f..075d7ce38015 100644
--- a/cvat/apps/engine/migrations/0035_data_storage.py
+++ b/cvat/apps/engine/migrations/0035_data_storage.py
@@ -1,8 +1,9 @@
# Generated by Django 3.1.1 on 2020-12-02 06:47
-import cvat.apps.engine.models
from django.db import migrations, models
+import cvat.apps.engine.models
+
class Migration(migrations.Migration):
diff --git a/cvat/apps/engine/migrations/0036_auto_20201216_0943.py b/cvat/apps/engine/migrations/0036_auto_20201216_0943.py
index 6f2fde01250f..52cb5faca2a5 100644
--- a/cvat/apps/engine/migrations/0036_auto_20201216_0943.py
+++ b/cvat/apps/engine/migrations/0036_auto_20201216_0943.py
@@ -1,8 +1,9 @@
# Generated by Django 3.1.1 on 2020-12-16 09:43
-import cvat.apps.engine.models
-from django.db import migrations, models
import django.db.models.deletion
+from django.db import migrations, models
+
+import cvat.apps.engine.models
class Migration(migrations.Migration):
diff --git a/cvat/apps/engine/migrations/0038_manifest.py b/cvat/apps/engine/migrations/0038_manifest.py
index 002a0326c2dc..33208ad4cf19 100644
--- a/cvat/apps/engine/migrations/0038_manifest.py
+++ b/cvat/apps/engine/migrations/0038_manifest.py
@@ -9,9 +9,8 @@
from django.db import migrations
from cvat.apps.engine.log import get_logger
-from cvat.apps.engine.models import (DimensionType, StorageChoice,
- StorageMethodChoice)
from cvat.apps.engine.media_extractors import get_mime
+from cvat.apps.engine.models import DimensionType, StorageChoice, StorageMethodChoice
from utils.dataset_manifest import ImageManifestManager, VideoManifestManager
MIGRATION_NAME = os.path.splitext(os.path.basename(__file__))[0]
diff --git a/cvat/apps/engine/migrations/0039_auto_training.py b/cvat/apps/engine/migrations/0039_auto_training.py
index a9f22ea7a03a..4594942d801e 100644
--- a/cvat/apps/engine/migrations/0039_auto_training.py
+++ b/cvat/apps/engine/migrations/0039_auto_training.py
@@ -1,7 +1,7 @@
# Generated by Django 3.1.7 on 2021-04-02 13:17
-from django.db import migrations, models
import django.db.models.deletion
+from django.db import migrations, models
class Migration(migrations.Migration):
diff --git a/cvat/apps/engine/migrations/0040_cloud_storage.py b/cvat/apps/engine/migrations/0040_cloud_storage.py
index c73609fd9fef..f7ecac010d19 100644
--- a/cvat/apps/engine/migrations/0040_cloud_storage.py
+++ b/cvat/apps/engine/migrations/0040_cloud_storage.py
@@ -1,9 +1,10 @@
# Generated by Django 3.1.8 on 2021-05-07 06:42
-import cvat.apps.engine.models
+import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
-import django.db.models.deletion
+
+import cvat.apps.engine.models
class Migration(migrations.Migration):
diff --git a/cvat/apps/engine/migrations/0042_auto_20210830_1056.py b/cvat/apps/engine/migrations/0042_auto_20210830_1056.py
index 7b5a496af97c..69866f2c788a 100644
--- a/cvat/apps/engine/migrations/0042_auto_20210830_1056.py
+++ b/cvat/apps/engine/migrations/0042_auto_20210830_1056.py
@@ -1,7 +1,7 @@
# Generated by Django 3.1.13 on 2021-08-30 10:56
-from django.db import migrations, models
import django.db.models.deletion
+from django.db import migrations, models
class Migration(migrations.Migration):
diff --git a/cvat/apps/engine/migrations/0046_data_sorting_method.py b/cvat/apps/engine/migrations/0046_data_sorting_method.py
index f3880482fc33..cb58bce9ed69 100644
--- a/cvat/apps/engine/migrations/0046_data_sorting_method.py
+++ b/cvat/apps/engine/migrations/0046_data_sorting_method.py
@@ -1,8 +1,9 @@
# Generated by Django 3.1.13 on 2021-12-03 08:06
-import cvat.apps.engine.models
from django.db import migrations, models
+import cvat.apps.engine.models
+
class Migration(migrations.Migration):
replaces = [('engine', '0045_data_sorting_method')]
diff --git a/cvat/apps/engine/migrations/0047_auto_20211110_1938.py b/cvat/apps/engine/migrations/0047_auto_20211110_1938.py
index 69434115f269..0113b1816c67 100644
--- a/cvat/apps/engine/migrations/0047_auto_20211110_1938.py
+++ b/cvat/apps/engine/migrations/0047_auto_20211110_1938.py
@@ -1,8 +1,9 @@
# Generated by Django 3.2.8 on 2021-11-10 19:38
-import cvat.apps.engine.models
-from django.db import migrations, models
import django.db.models.deletion
+from django.db import migrations, models
+
+import cvat.apps.engine.models
class Migration(migrations.Migration):
diff --git a/cvat/apps/engine/migrations/0048_auto_20211112_1918.py b/cvat/apps/engine/migrations/0048_auto_20211112_1918.py
index e1c54ab1206b..6c2106624397 100644
--- a/cvat/apps/engine/migrations/0048_auto_20211112_1918.py
+++ b/cvat/apps/engine/migrations/0048_auto_20211112_1918.py
@@ -1,8 +1,8 @@
# Generated by Django 3.2.8 on 2021-11-12 19:18
+import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
-import django.db.models.deletion
class Migration(migrations.Migration):
diff --git a/cvat/apps/engine/migrations/0053_data_deleted_frames.py b/cvat/apps/engine/migrations/0053_data_deleted_frames.py
index 8bbf49792f49..e1421a0a2c1f 100644
--- a/cvat/apps/engine/migrations/0053_data_deleted_frames.py
+++ b/cvat/apps/engine/migrations/0053_data_deleted_frames.py
@@ -1,8 +1,9 @@
# Generated by Django 3.2.12 on 2022-05-20 09:21
-import cvat.apps.engine.models
from django.db import migrations
+import cvat.apps.engine.models
+
class Migration(migrations.Migration):
diff --git a/cvat/apps/engine/migrations/0054_auto_20220610_1829.py b/cvat/apps/engine/migrations/0054_auto_20220610_1829.py
index 1c7ae1a802ec..25ed5b9c0617 100644
--- a/cvat/apps/engine/migrations/0054_auto_20220610_1829.py
+++ b/cvat/apps/engine/migrations/0054_auto_20220610_1829.py
@@ -1,8 +1,9 @@
# Generated by Django 3.2.12 on 2022-06-10 18:29
-import cvat.apps.engine.models
-from django.db import migrations, models
import django.db.models.deletion
+from django.db import migrations, models
+
+import cvat.apps.engine.models
class Migration(migrations.Migration):
diff --git a/cvat/apps/engine/migrations/0055_jobs_directories.py b/cvat/apps/engine/migrations/0055_jobs_directories.py
index ec97f2c8d3d5..89d7cd300b24 100644
--- a/cvat/apps/engine/migrations/0055_jobs_directories.py
+++ b/cvat/apps/engine/migrations/0055_jobs_directories.py
@@ -3,8 +3,9 @@
import os
import shutil
-from django.db import migrations
from django.conf import settings
+from django.db import migrations
+
from cvat.apps.engine.log import get_logger
MIGRATION_NAME = os.path.splitext(os.path.basename(__file__))[0]
diff --git a/cvat/apps/engine/migrations/0056_jobs_previews.py b/cvat/apps/engine/migrations/0056_jobs_previews.py
index b8722018f92b..f3e6235fc780 100644
--- a/cvat/apps/engine/migrations/0056_jobs_previews.py
+++ b/cvat/apps/engine/migrations/0056_jobs_previews.py
@@ -2,8 +2,10 @@
import os
import shutil
-from django.db import migrations
+
from django.conf import settings
+from django.db import migrations
+
from cvat.apps.engine.log import get_logger
MIGRATION_NAME = os.path.splitext(os.path.basename(__file__))[0]
diff --git a/cvat/apps/engine/migrations/0057_auto_20220726_0926.py b/cvat/apps/engine/migrations/0057_auto_20220726_0926.py
index 459dbad6e783..22cd9f15e70b 100644
--- a/cvat/apps/engine/migrations/0057_auto_20220726_0926.py
+++ b/cvat/apps/engine/migrations/0057_auto_20220726_0926.py
@@ -1,8 +1,9 @@
# Generated by Django 3.2.14 on 2022-07-26 09:26
-import cvat.apps.engine.models
-from django.db import migrations, models
import django.db.models.deletion
+from django.db import migrations, models
+
+import cvat.apps.engine.models
class Migration(migrations.Migration):
diff --git a/cvat/apps/engine/migrations/0058_auto_20220809_1236.py b/cvat/apps/engine/migrations/0058_auto_20220809_1236.py
index 8a7eb002d0af..aafb9a3bfab0 100644
--- a/cvat/apps/engine/migrations/0058_auto_20220809_1236.py
+++ b/cvat/apps/engine/migrations/0058_auto_20220809_1236.py
@@ -1,7 +1,7 @@
# Generated by Django 3.2.15 on 2022-08-09 12:36
-from django.db import migrations, models
import django.db.models.deletion
+from django.db import migrations, models
class Migration(migrations.Migration):
diff --git a/cvat/apps/engine/migrations/0060_alter_label_parent.py b/cvat/apps/engine/migrations/0060_alter_label_parent.py
index 5eb698343413..a5e8a8df31f3 100644
--- a/cvat/apps/engine/migrations/0060_alter_label_parent.py
+++ b/cvat/apps/engine/migrations/0060_alter_label_parent.py
@@ -1,7 +1,7 @@
# Generated by Django 3.2.15 on 2022-09-09 09:00
-from django.db import migrations, models
import django.db.models.deletion
+from django.db import migrations, models
class Migration(migrations.Migration):
diff --git a/cvat/apps/engine/migrations/0062_delete_previews.py b/cvat/apps/engine/migrations/0062_delete_previews.py
index da986be097fb..ccf5e8f9f176 100644
--- a/cvat/apps/engine/migrations/0062_delete_previews.py
+++ b/cvat/apps/engine/migrations/0062_delete_previews.py
@@ -2,10 +2,12 @@
import sys
import traceback
-from django.db import migrations
from django.conf import settings
+from django.db import migrations
+
from cvat.apps.engine.log import get_migration_logger
+
def delete_previews(apps, schema_editor):
migration_name = os.path.splitext(os.path.basename(__file__))[0]
with get_migration_logger(migration_name) as log:
diff --git a/cvat/apps/engine/migrations/0064_delete_or_rename_wrong_labels.py b/cvat/apps/engine/migrations/0064_delete_or_rename_wrong_labels.py
index 63c167381529..97cad2c4f565 100644
--- a/cvat/apps/engine/migrations/0064_delete_or_rename_wrong_labels.py
+++ b/cvat/apps/engine/migrations/0064_delete_or_rename_wrong_labels.py
@@ -1,8 +1,10 @@
import os
from django.db import migrations
+
from cvat.apps.engine.log import get_migration_logger
+
def delete_or_rename_wrong_labels(apps, schema_editor):
migration_name = os.path.splitext(os.path.basename(__file__))[0]
with get_migration_logger(migration_name) as log:
diff --git a/cvat/apps/engine/migrations/0070_add_job_type_created_date.py b/cvat/apps/engine/migrations/0070_add_job_type_created_date.py
index 034a6b275ae9..62d0293245cf 100644
--- a/cvat/apps/engine/migrations/0070_add_job_type_created_date.py
+++ b/cvat/apps/engine/migrations/0070_add_job_type_created_date.py
@@ -1,6 +1,7 @@
-import cvat.apps.engine.models
-from django.db import migrations, models
import django.utils.timezone
+from django.db import migrations, models
+
+import cvat.apps.engine.models
def add_created_date_to_existing_jobs(apps, schema_editor):
diff --git a/cvat/apps/engine/migrations/0071_annotationguide_asset.py b/cvat/apps/engine/migrations/0071_annotationguide_asset.py
index 1060c4576aba..a6b50c50861b 100644
--- a/cvat/apps/engine/migrations/0071_annotationguide_asset.py
+++ b/cvat/apps/engine/migrations/0071_annotationguide_asset.py
@@ -1,9 +1,10 @@
# Generated by Django 3.2.18 on 2023-06-13 13:14
+import uuid
+
+import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
-import django.db.models.deletion
-import uuid
class Migration(migrations.Migration):
diff --git a/cvat/apps/engine/migrations/0072_alter_issue_updated_date.py b/cvat/apps/engine/migrations/0072_alter_issue_updated_date.py
index 4c549be10aa5..344036d12f65 100644
--- a/cvat/apps/engine/migrations/0072_alter_issue_updated_date.py
+++ b/cvat/apps/engine/migrations/0072_alter_issue_updated_date.py
@@ -2,6 +2,7 @@
from django.db import migrations, models
+
def forwards_func(apps, schema_editor):
Issue = apps.get_model("engine", "Issue")
diff --git a/cvat/apps/engine/migrations/0076_remove_storages_that_refer_to_deleted_cloud_storages.py b/cvat/apps/engine/migrations/0076_remove_storages_that_refer_to_deleted_cloud_storages.py
index 50c1461319a7..41c902bb2500 100644
--- a/cvat/apps/engine/migrations/0076_remove_storages_that_refer_to_deleted_cloud_storages.py
+++ b/cvat/apps/engine/migrations/0076_remove_storages_that_refer_to_deleted_cloud_storages.py
@@ -1,6 +1,7 @@
# Generated by Django 4.2.6 on 2023-11-17 10:10
from django.db import migrations, models
+
from cvat.apps.engine.models import Location
diff --git a/cvat/apps/engine/migrations/0077_auto_20231121_1952.py b/cvat/apps/engine/migrations/0077_auto_20231121_1952.py
index 8b5c3648e068..831e83c8712a 100644
--- a/cvat/apps/engine/migrations/0077_auto_20231121_1952.py
+++ b/cvat/apps/engine/migrations/0077_auto_20231121_1952.py
@@ -1,7 +1,7 @@
# Generated by Django 4.2.6 on 2023-11-21 19:52
-from django.db import migrations, models
import django.db.models.deletion
+from django.db import migrations, models
class Migration(migrations.Migration):
diff --git a/cvat/apps/engine/migrations/0079_alter_labeledimageattributeval_image_and_more.py b/cvat/apps/engine/migrations/0079_alter_labeledimageattributeval_image_and_more.py
index ccafa6086b5e..58921bc97c92 100644
--- a/cvat/apps/engine/migrations/0079_alter_labeledimageattributeval_image_and_more.py
+++ b/cvat/apps/engine/migrations/0079_alter_labeledimageattributeval_image_and_more.py
@@ -1,7 +1,7 @@
# Generated by Django 4.2.13 on 2024-07-09 11:08
-from django.db import migrations, models
import django.db.models.deletion
+from django.db import migrations, models
class Migration(migrations.Migration):
diff --git a/cvat/apps/engine/migrations/0080_alter_trackedshape_track.py b/cvat/apps/engine/migrations/0080_alter_trackedshape_track.py
index d5997d15ff91..8266dbf4ba38 100644
--- a/cvat/apps/engine/migrations/0080_alter_trackedshape_track.py
+++ b/cvat/apps/engine/migrations/0080_alter_trackedshape_track.py
@@ -1,7 +1,7 @@
# Generated by Django 4.2.13 on 2024-07-12 19:01
-from django.db import migrations, models
import django.db.models.deletion
+from django.db import migrations, models
class Migration(migrations.Migration):
diff --git a/cvat/apps/engine/migrations/0082_alter_labeledimage_job_and_more.py b/cvat/apps/engine/migrations/0082_alter_labeledimage_job_and_more.py
index 50b91829b213..ecbc9d76f60d 100644
--- a/cvat/apps/engine/migrations/0082_alter_labeledimage_job_and_more.py
+++ b/cvat/apps/engine/migrations/0082_alter_labeledimage_job_and_more.py
@@ -1,7 +1,7 @@
# Generated by Django 4.2.14 on 2024-07-22 07:27
-from django.db import migrations, models
import django.db.models.deletion
+from django.db import migrations, models
class Migration(migrations.Migration):
diff --git a/cvat/apps/engine/migrations/0085_segment_chunks_updated_date.py b/cvat/apps/engine/migrations/0085_segment_chunks_updated_date.py
index 52342d7db774..6fed44b22a6a 100644
--- a/cvat/apps/engine/migrations/0085_segment_chunks_updated_date.py
+++ b/cvat/apps/engine/migrations/0085_segment_chunks_updated_date.py
@@ -1,6 +1,7 @@
# Generated by Django 4.2.15 on 2024-09-25 13:52
from datetime import datetime
+
from django.db import migrations, models
diff --git a/cvat/apps/engine/mime_types.py b/cvat/apps/engine/mime_types.py
index 8e70c5cc4193..fad18ba6b6f8 100644
--- a/cvat/apps/engine/mime_types.py
+++ b/cvat/apps/engine/mime_types.py
@@ -2,9 +2,8 @@
#
# SPDX-License-Identifier: MIT
-import os
import mimetypes
-
+import os
_SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__))
MEDIA_MIMETYPES_FILES = [
diff --git a/cvat/apps/engine/mixins.py b/cvat/apps/engine/mixins.py
index 39f50ed31db4..9e69ffdd5ccb 100644
--- a/cvat/apps/engine/mixins.py
+++ b/cvat/apps/engine/mixins.py
@@ -12,9 +12,9 @@
from dataclasses import asdict, dataclass
from pathlib import Path
from tempfile import NamedTemporaryFile
-from unittest import mock
from textwrap import dedent
-from typing import Optional, Callable, Any
+from typing import Any, Callable, Optional
+from unittest import mock
from urllib.parse import urljoin
import django_rq
@@ -22,20 +22,18 @@
from django.conf import settings
from django.http import HttpRequest
from drf_spectacular.types import OpenApiTypes
-from drf_spectacular.utils import (OpenApiParameter, OpenApiResponse,
- extend_schema)
+from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
from rest_framework import mixins, status
-from rest_framework.decorators import action
from rest_framework.authentication import SessionAuthentication
+from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.views import APIView
-from cvat.apps.engine.background import (BackupExportManager,
- DatasetExportManager)
+from cvat.apps.engine.background import BackupExportManager, DatasetExportManager
from cvat.apps.engine.handlers import clear_import_cache
from cvat.apps.engine.location import StorageType, get_location_configuration
from cvat.apps.engine.log import ServerLogManager
-from cvat.apps.engine.models import Location, RequestAction, RequestTarget, RequestSubresource
+from cvat.apps.engine.models import Location, RequestAction, RequestSubresource, RequestTarget
from cvat.apps.engine.rq_job_handler import RQId
from cvat.apps.engine.serializers import DataSerializer, RqIdSerializer
from cvat.apps.engine.utils import is_dataset_export
diff --git a/cvat/apps/engine/pagination.py b/cvat/apps/engine/pagination.py
index 2bb417f5c0d1..6a1dd499b893 100644
--- a/cvat/apps/engine/pagination.py
+++ b/cvat/apps/engine/pagination.py
@@ -3,8 +3,10 @@
# SPDX-License-Identifier: MIT
import sys
+
from rest_framework.pagination import PageNumberPagination
+
class CustomPagination(PageNumberPagination):
page_size_query_param = "page_size"
diff --git a/cvat/apps/engine/parsers.py b/cvat/apps/engine/parsers.py
index d0cecc4b02d0..03b4ebd45da8 100644
--- a/cvat/apps/engine/parsers.py
+++ b/cvat/apps/engine/parsers.py
@@ -4,6 +4,7 @@
from rest_framework.parsers import BaseParser
+
class TusUploadParser(BaseParser):
# The media type is sent by TUS protocol (tus.io) for uploading files
media_type = 'application/offset+octet-stream'
diff --git a/cvat/apps/engine/permissions.py b/cvat/apps/engine/permissions.py
index c5ddd4799c4c..a180410142cd 100644
--- a/cvat/apps/engine/permissions.py
+++ b/cvat/apps/engine/permissions.py
@@ -7,20 +7,23 @@
from collections.abc import Sequence
from typing import Any, Optional, Union, cast
-from django.shortcuts import get_object_or_404
from django.conf import settings
-
-from rest_framework.exceptions import ValidationError, PermissionDenied
+from django.shortcuts import get_object_or_404
+from rest_framework.exceptions import PermissionDenied, ValidationError
from rq.job import Job as RQJob
from cvat.apps.engine.rq_job_handler import is_rq_job_owner
+from cvat.apps.engine.utils import is_dataset_export
from cvat.apps.iam.permissions import (
- OpenPolicyAgentPermission, StrEnum, get_iam_context, get_membership
+ OpenPolicyAgentPermission,
+ StrEnum,
+ get_iam_context,
+ get_membership,
)
from cvat.apps.organizations.models import Organization
from .models import AnnotationGuide, CloudStorage, Issue, Job, Label, Project, Task
-from cvat.apps.engine.utils import is_dataset_export
+
def _get_key(d: dict[str, Any], key_path: Union[str, Sequence[str]]) -> Optional[Any]:
"""
diff --git a/cvat/apps/engine/renderers.py b/cvat/apps/engine/renderers.py
index f56eb4d39808..542a322048ed 100644
--- a/cvat/apps/engine/renderers.py
+++ b/cvat/apps/engine/renderers.py
@@ -4,5 +4,6 @@
from rest_framework.renderers import JSONRenderer
+
class CVATAPIRenderer(JSONRenderer):
media_type = 'application/vnd.cvat+json'
diff --git a/cvat/apps/engine/rq_job_handler.py b/cvat/apps/engine/rq_job_handler.py
index c5b31336ecdc..b4f146197afc 100644
--- a/cvat/apps/engine/rq_job_handler.py
+++ b/cvat/apps/engine/rq_job_handler.py
@@ -4,13 +4,14 @@
from __future__ import annotations
-import attrs
-
-from typing import Optional, Union, Any
+from typing import Any, Optional, Union
from uuid import UUID
+
+import attrs
from rq.job import Job as RQJob
-from .models import RequestAction, RequestTarget, RequestSubresource
+from .models import RequestAction, RequestSubresource, RequestTarget
+
class RQMeta:
@staticmethod
diff --git a/cvat/apps/engine/serializers.py b/cvat/apps/engine/serializers.py
index 9f772cd24e6d..6c760b42ba65 100644
--- a/cvat/apps/engine/serializers.py
+++ b/cvat/apps/engine/serializers.py
@@ -5,48 +5,55 @@
from __future__ import annotations
+import os
+import re
+import shutil
+import string
+import textwrap
+import warnings
from collections import OrderedDict
from collections.abc import Iterable, Sequence
from contextlib import closing
-import warnings
from copy import copy
from datetime import timedelta
from decimal import Decimal
from inspect import isclass
-import os
-import re
-import shutil
-import string
from tempfile import NamedTemporaryFile
-import textwrap
from typing import Any, Optional, Union
import django_rq
+import rq.defaults as rq_defaults
from django.conf import settings
-from django.contrib.auth.models import User, Group
+from django.contrib.auth.models import Group, User
from django.db import transaction
-from django.db.models import prefetch_related_objects, Prefetch
+from django.db.models import Prefetch, prefetch_related_objects
from django.utils import timezone
+from drf_spectacular.utils import OpenApiExample, extend_schema_field, extend_schema_serializer
from numpy import random
-from rest_framework import serializers, exceptions
-import rq.defaults as rq_defaults
-from rq.job import Job as RQJob, JobStatus as RQJobStatus
+from rest_framework import exceptions, serializers
+from rq.job import Job as RQJob
+from rq.job import JobStatus as RQJobStatus
from cvat.apps.dataset_manager.formats.utils import get_label_color
from cvat.apps.engine import field_validation, models
-from cvat.apps.engine.frame_provider import TaskFrameProvider, FrameQuality
-from cvat.apps.engine.cloud_provider import get_cloud_storage_instance, Credentials, Status
+from cvat.apps.engine.cloud_provider import Credentials, Status, get_cloud_storage_instance
+from cvat.apps.engine.frame_provider import FrameQuality, TaskFrameProvider
from cvat.apps.engine.log import ServerLogManager
from cvat.apps.engine.permissions import TaskPermission
+from cvat.apps.engine.rq_job_handler import RQId, RQJobMetaField
from cvat.apps.engine.task_validation import HoneypotFrameSelector
-from cvat.apps.engine.rq_job_handler import RQJobMetaField, RQId
from cvat.apps.engine.utils import (
- format_list, grouped, parse_exception_message, CvatChunkTimestampMismatchError,
- parse_specific_attributes, build_field_filter_params, get_list_view_name, reverse, take_by
+ CvatChunkTimestampMismatchError,
+ build_field_filter_params,
+ format_list,
+ get_list_view_name,
+ grouped,
+ parse_exception_message,
+ parse_specific_attributes,
+ reverse,
+ take_by,
)
-from drf_spectacular.utils import OpenApiExample, extend_schema_field, extend_schema_serializer
-
slogger = ServerLogManager(__name__)
class WriteOnceMixin:
@@ -996,7 +1003,10 @@ def validate(self, attrs):
@transaction.atomic
def update(self, instance: models.Job, validated_data: dict[str, Any]) -> models.Job:
from cvat.apps.engine.cache import (
- MediaCache, Callback, enqueue_create_chunk_job, wait_for_rq_job
+ Callback,
+ MediaCache,
+ enqueue_create_chunk_job,
+ wait_for_rq_job,
)
from cvat.apps.engine.frame_provider import JobFrameProvider
@@ -1101,7 +1111,7 @@ def _to_abs_frame(rel_frame: int) -> int:
)
if bulk_context:
- active_validation_frame_counts = bulk_context.active_validation_frame_counts
+ frame_selector = bulk_context.honeypot_frame_selector
else:
active_validation_frame_counts = {
validation_frame: 0 for validation_frame in task_active_validation_frames
@@ -1111,7 +1121,8 @@ def _to_abs_frame(rel_frame: int) -> int:
if real_frame in task_active_validation_frames:
active_validation_frame_counts[real_frame] += 1
- frame_selector = HoneypotFrameSelector(active_validation_frame_counts)
+ frame_selector = HoneypotFrameSelector(active_validation_frame_counts)
+
requested_frames = frame_selector.select_next_frames(segment_honeypots_count)
requested_frames = list(map(_to_abs_frame, requested_frames))
else:
@@ -1358,7 +1369,7 @@ def __init__(
honeypot_frames: list[int],
all_validation_frames: list[int],
active_validation_frames: list[int],
- validation_frame_counts: dict[int, int] | None = None
+ honeypot_frame_selector: HoneypotFrameSelector | None = None
):
self.updated_honeypots: dict[int, models.Image] = {}
self.updated_segments: list[int] = []
@@ -1370,7 +1381,7 @@ def __init__(
self.honeypot_frames = honeypot_frames
self.all_validation_frames = all_validation_frames
self.active_validation_frames = active_validation_frames
- self.active_validation_frame_counts = validation_frame_counts
+ self.honeypot_frame_selector = honeypot_frame_selector
class TaskValidationLayoutWriteSerializer(serializers.Serializer):
disabled_frames = serializers.ListField(
@@ -1485,7 +1496,9 @@ def update(self, instance: models.Task, validated_data: dict[str, Any]) -> model
)
elif frame_selection_method == models.JobFrameSelectionMethod.RANDOM_UNIFORM:
# Reset distribution for active validation frames
- bulk_context.active_validation_frame_counts = { f: 0 for f in active_validation_frames }
+ active_validation_frame_counts = { f: 0 for f in active_validation_frames }
+ frame_selector = HoneypotFrameSelector(active_validation_frame_counts)
+ bulk_context.honeypot_frame_selector = frame_selector
# Could be done using Django ORM, but using order_by() and filter()
# would result in an extra DB request
diff --git a/cvat/apps/engine/signals.py b/cvat/apps/engine/signals.py
index 3a964d90c2cc..456c6f228081 100644
--- a/cvat/apps/engine/signals.py
+++ b/cvat/apps/engine/signals.py
@@ -11,8 +11,7 @@
from django.db.models.signals import m2m_changed, post_delete, post_save
from django.dispatch import receiver
-from .models import CloudStorage, Data, Job, Profile, Project, StatusChoice, Task, Asset
-
+from .models import Asset, CloudStorage, Data, Job, Profile, Project, StatusChoice, Task
# TODO: need to log any problems reported by shutil.rmtree when the new
# analytics feature is available. Now the log system can write information
diff --git a/cvat/apps/engine/task.py b/cvat/apps/engine/task.py
index 0f36674299fc..7aa92acba2fd 100644
--- a/cvat/apps/engine/task.py
+++ b/cvat/apps/engine/task.py
@@ -4,24 +4,24 @@
# SPDX-License-Identifier: MIT
import concurrent.futures
-import itertools
import fnmatch
+import itertools
import os
import re
-import rq
import shutil
from collections.abc import Iterator, Sequence
-from copy import deepcopy
from contextlib import closing
+from copy import deepcopy
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, NamedTuple, Optional, Union
from urllib import parse as urlparse
from urllib import request as urlrequest
-import av
import attrs
+import av
import django_rq
+import rq
from django.conf import settings
from django.db import transaction
from django.forms.models import model_to_dict
@@ -29,25 +29,39 @@
from rest_framework.serializers import ValidationError
from cvat.apps.engine import models
-from cvat.apps.engine.log import ServerLogManager
from cvat.apps.engine.frame_provider import TaskFrameProvider
+from cvat.apps.engine.log import ServerLogManager
from cvat.apps.engine.media_extractors import (
- MEDIA_TYPES, CachingMediaIterator, IMediaReader, ImageListReader,
- Mpeg4ChunkWriter, Mpeg4CompressedChunkWriter, RandomAccessIterator,
- ValidateDimension, ZipChunkWriter, ZipCompressedChunkWriter, get_mime, sort,
+ MEDIA_TYPES,
+ CachingMediaIterator,
+ ImageListReader,
+ IMediaReader,
+ Mpeg4ChunkWriter,
+ Mpeg4CompressedChunkWriter,
+ RandomAccessIterator,
+ ValidateDimension,
+ ZipChunkWriter,
+ ZipCompressedChunkWriter,
+ get_mime,
load_image,
+ sort,
)
from cvat.apps.engine.models import RequestAction, RequestTarget
-from cvat.apps.engine.utils import (
- av_scan_paths, format_list, get_rq_job_meta,
- define_dependent_job, get_rq_lock_by_user, take_by
-)
from cvat.apps.engine.rq_job_handler import RQId
from cvat.apps.engine.task_validation import HoneypotFrameSelector
-from cvat.utils.http import make_requests_session, PROXIES_FOR_UNTRUSTED_URLS
+from cvat.apps.engine.utils import (
+ av_scan_paths,
+ define_dependent_job,
+ format_list,
+ get_rq_job_meta,
+ get_rq_lock_by_user,
+ take_by,
+)
+from cvat.utils.http import PROXIES_FOR_UNTRUSTED_URLS, make_requests_session
from utils.dataset_manifest import ImageManifestManager, VideoManifestManager, is_manifest
from utils.dataset_manifest.core import VideoManifestValidator, is_dataset_manifest
from utils.dataset_manifest.utils import detect_related_images
+
from .cloud_provider import db_storage_to_storage_instance
slogger = ServerLogManager(__name__)
diff --git a/cvat/apps/engine/task_validation.py b/cvat/apps/engine/task_validation.py
index fe76b4e99408..4734c153e8b4 100644
--- a/cvat/apps/engine/task_validation.py
+++ b/cvat/apps/engine/task_validation.py
@@ -2,26 +2,109 @@
#
# SPDX-License-Identifier: MIT
-from collections.abc import Mapping, Sequence
-from typing import Generic, TypeVar
+from __future__ import annotations
+from typing import Callable, Generic, Iterable, Mapping, Sequence, TypeVar
+
+import attrs
import numpy as np
-_T = TypeVar("_T")
+_K = TypeVar("_K")
+
+
+@attrs.define
+class _BaggedCounter(Generic[_K]):
+ # Stores items with count = k in a single "bag". Bags are stored in the ascending order
+ bags: dict[
+ int,
+ dict[_K, None],
+ # dict is used instead of a set to preserve item order. It's also more performant
+ ]
+
+ @staticmethod
+ def from_dict(item_counts: Mapping[_K, int]) -> _BaggedCounter:
+ return _BaggedCounter.from_counts(item_counts, item_count=item_counts.__getitem__)
+
+ @staticmethod
+ def from_counts(items: Sequence[_K], item_count: Callable[[_K], int]) -> _BaggedCounter:
+ bags = {}
+ for item in items:
+ count = item_count(item)
+ bags.setdefault(count, dict())[item] = None
+
+ return _BaggedCounter(bags=bags)
+
+ def __attrs_post_init__(self):
+ self._sort_bags()
+
+ def _sort_bags(self):
+ self.bags = dict(sorted(self.bags.items(), key=lambda e: e[0]))
+
+ def shuffle(self, *, rng: np.random.Generator | None):
+ if not rng:
+ rng = np.random.default_rng()
+
+ for count, bag in self.bags.items():
+ items = list(bag.items())
+ rng.shuffle(items)
+ self.bags[count] = dict(items)
+
+ def use_item(self, item: _K, *, count: int | None = None, bag: dict | None = None):
+ if count is not None:
+ if bag is None:
+ bag = self.bags[count]
+ elif count is None and bag is None:
+ count, bag = next((c, b) for c, b in self.bags.items() if item in b)
+ else:
+ raise AssertionError("'bag' can only be used together with 'count'")
+ bag.pop(item)
-class HoneypotFrameSelector(Generic[_T]):
+ if not bag:
+ self.bags.pop(count)
+
+ next_bag = self.bags.get(count + 1)
+ if next_bag is None:
+ next_bag = {}
+ self.bags[count + 1] = next_bag
+ self._sort_bags() # the new bag can be added in the wrong position if there were gaps
+
+ next_bag[item] = None
+
+ def __iter__(self) -> Iterable[tuple[int, _K, dict]]:
+ for count, bag in self.bags.items(): # bags must be ordered
+ for item in bag:
+ yield (count, item, bag)
+
+ def select_next_least_used(self, count: int) -> Sequence[_K]:
+ pick = [None] * count
+ pick_original_use_counts = [(None, None)] * count
+ for i, (use_count, item, bag) in zip(range(count), self):
+ pick[i] = item
+ pick_original_use_counts[i] = (use_count, bag)
+
+ for item, (use_count, bag) in zip(pick, pick_original_use_counts):
+ self.use_item(item, count=use_count, bag=bag)
+
+ return pick
+
+
+class HoneypotFrameSelector(Generic[_K]):
def __init__(
- self, validation_frame_counts: Mapping[_T, int], *, rng: np.random.Generator | None = None
+ self,
+ validation_frame_counts: Mapping[_K, int],
+ *,
+ rng: np.random.Generator | None = None,
):
- self.validation_frame_counts = validation_frame_counts
-
if not rng:
rng = np.random.default_rng()
self.rng = rng
- def select_next_frames(self, count: int) -> Sequence[_T]:
+ self._counter = _BaggedCounter.from_dict(validation_frame_counts)
+ self._counter.shuffle(rng=rng)
+
+ def select_next_frames(self, count: int) -> Sequence[_K]:
# This approach guarantees that:
# - every GT frame is used
# - GT frames are used uniformly (at most min count + 1)
@@ -29,20 +112,8 @@ def select_next_frames(self, count: int) -> Sequence[_T]:
# - honeypot sets are different in jobs
# - honeypot sets are random
# if possible (if the job and GT counts allow this).
- pick = []
-
- for random_number in self.rng.random(count):
- least_count = min(c for f, c in self.validation_frame_counts.items() if f not in pick)
- least_used_frames = tuple(
- f
- for f, c in self.validation_frame_counts.items()
- if f not in pick
- if c == least_count
- )
-
- selected_item = int(random_number * len(least_used_frames))
- selected_frame = least_used_frames[selected_item]
- pick.append(selected_frame)
- self.validation_frame_counts[selected_frame] += 1
-
- return pick
+ # Picks must be reproducible for a given rng state.
+ """
+ Selects 'count' least used items randomly, without repetition
+ """
+ return self._counter.select_next_least_used(count)
diff --git a/cvat/apps/engine/tests/test_lazy_list.py b/cvat/apps/engine/tests/test_lazy_list.py
index 6ba4b07dd38f..2a021f89b94a 100644
--- a/cvat/apps/engine/tests/test_lazy_list.py
+++ b/cvat/apps/engine/tests/test_lazy_list.py
@@ -1,9 +1,9 @@
-import unittest
import copy
import pickle
+import unittest
from typing import TypeVar
-from cvat.apps.engine.lazy_list import LazyList
+from cvat.apps.engine.lazy_list import LazyList
T = TypeVar('T')
diff --git a/cvat/apps/engine/tests/test_rest_api.py b/cvat/apps/engine/tests/test_rest_api.py
index b0c5500eda4c..d59c310e5a3c 100644
--- a/cvat/apps/engine/tests/test_rest_api.py
+++ b/cvat/apps/engine/tests/test_rest_api.py
@@ -3,10 +3,10 @@
#
# SPDX-License-Identifier: MIT
-from contextlib import ExitStack
-from datetime import timedelta
+import copy
import io
-from itertools import product
+import json
+import logging
import os
import random
import shutil
@@ -15,40 +15,56 @@
import xml.etree.ElementTree as ET
import zipfile
from collections import defaultdict
+from contextlib import ExitStack
+from datetime import timedelta
from enum import Enum
from glob import glob
from io import BytesIO, IOBase
-from unittest import mock
+from itertools import product
from time import sleep
-import logging
-import copy
-import json
+from unittest import mock
import av
import django_rq
import numpy as np
-from pdf2image import convert_from_bytes
-from pyunpack import Archive
from django.conf import settings
from django.contrib.auth.models import Group, User
from django.http import HttpResponse
+from pdf2image import convert_from_bytes
from PIL import Image
from pycocotools import coco as coco_loader
+from pyunpack import Archive
from rest_framework import status
from rest_framework.test import APIClient
from cvat.apps.dataset_manager.tests.utils import TestDir
from cvat.apps.dataset_manager.util import current_function_name
-from cvat.apps.engine.models import (AttributeSpec, AttributeType, Data, Job,
- Project, Segment, StageChoice, StatusChoice, Task, Label, StorageMethodChoice,
- StorageChoice, DimensionType, SortingMethod)
from cvat.apps.engine.media_extractors import ValidateDimension, sort
-from cvat.apps.engine.tests.utils import get_paginated_collection
+from cvat.apps.engine.models import (
+ AttributeSpec,
+ AttributeType,
+ Data,
+ DimensionType,
+ Job,
+ Label,
+ Project,
+ Segment,
+ SortingMethod,
+ StageChoice,
+ StatusChoice,
+ StorageChoice,
+ StorageMethodChoice,
+ Task,
+)
+from cvat.apps.engine.tests.utils import (
+ ApiTestBase,
+ ForceLogin,
+ generate_image_file,
+ generate_video_file,
+ get_paginated_collection,
+)
from utils.dataset_manifest import ImageManifestManager, VideoManifestManager
-from cvat.apps.engine.tests.utils import (ApiTestBase, ForceLogin,
- generate_image_file, generate_video_file)
-
#suppress av warnings
logging.getLogger('libav').setLevel(logging.ERROR)
@@ -6382,6 +6398,9 @@ def _get_initial_annotation(annotation_format):
formats['CVAT for video 1.1'] = 'CVAT 1.1'
if 'CVAT for images 1.1' in export_formats:
formats['CVAT for images 1.1'] = 'CVAT 1.1'
+ if 'Ultralytics YOLO Detection 1.0' in import_formats:
+ if 'Ultralytics YOLO Detection Track 1.0' in export_formats:
+ formats['Ultralytics YOLO Detection Track 1.0'] = 'Ultralytics YOLO Detection 1.0'
if set(import_formats) ^ set(export_formats):
# NOTE: this may not be an error, so we should not fail
print("The following import formats have no pair:",
diff --git a/cvat/apps/engine/tests/test_rest_api_3D.py b/cvat/apps/engine/tests/test_rest_api_3D.py
index 67791c3c113c..087448c90dd2 100644
--- a/cvat/apps/engine/tests/test_rest_api_3D.py
+++ b/cvat/apps/engine/tests/test_rest_api_3D.py
@@ -4,7 +4,9 @@
# SPDX-License-Identifier: MIT
+import copy
import io
+import itertools
import os
import os.path as osp
import tempfile
@@ -13,18 +15,15 @@
from collections import defaultdict
from glob import glob
from io import BytesIO
-import copy
from shutil import copyfile
-import itertools
from django.contrib.auth.models import Group, User
from rest_framework import status
+from cvat.apps.dataset_manager.task import TaskAnnotation
from cvat.apps.dataset_manager.tests.utils import TestDir
from cvat.apps.engine.media_extractors import ValidateDimension
-from cvat.apps.dataset_manager.task import TaskAnnotation
-
-from cvat.apps.engine.tests.utils import get_paginated_collection, ApiTestBase, ForceLogin
+from cvat.apps.engine.tests.utils import ApiTestBase, ForceLogin, get_paginated_collection
CREATE_ACTION = "create"
UPDATE_ACTION = "update"
diff --git a/cvat/apps/engine/tests/utils.py b/cvat/apps/engine/tests/utils.py
index 910323cac1f7..09fd850b2c19 100644
--- a/cvat/apps/engine/tests/utils.py
+++ b/cvat/apps/engine/tests/utils.py
@@ -2,22 +2,22 @@
#
# SPDX-License-Identifier: MIT
+import itertools
+import logging
+import os
from collections.abc import Iterator, Sequence
from contextlib import contextmanager
from io import BytesIO
from typing import Any, Callable, TypeVar
-import itertools
-import logging
-import os
+import av
+import django_rq
+import numpy as np
from django.conf import settings
from django.core.cache import caches
from django.http.response import HttpResponse
from PIL import Image
from rest_framework.test import APITestCase
-import av
-import django_rq
-import numpy as np
T = TypeVar('T')
diff --git a/cvat/apps/engine/urls.py b/cvat/apps/engine/urls.py
index 1755197ebcdf..1380ae5f7961 100644
--- a/cvat/apps/engine/urls.py
+++ b/cvat/apps/engine/urls.py
@@ -3,14 +3,13 @@
#
# SPDX-License-Identifier: MIT
-from django.urls import path, include
-from . import views
-from rest_framework import routers
-
-from django.views.generic import RedirectView
from django.conf import settings
-
+from django.urls import include, path
+from django.views.generic import RedirectView
from drf_spectacular.views import SpectacularAPIView, SpectacularRedocView, SpectacularSwaggerView
+from rest_framework import routers
+
+from . import views
router = routers.DefaultRouter(trailing_slash=False)
router.register('projects', views.ProjectViewSet)
diff --git a/cvat/apps/engine/utils.py b/cvat/apps/engine/utils.py
index dd4533538f5a..b3e3d48f69d6 100644
--- a/cvat/apps/engine/utils.py
+++ b/cvat/apps/engine/utils.py
@@ -4,42 +4,39 @@
# SPDX-License-Identifier: MIT
import ast
-from itertools import islice
-import cv2 as cv
-from collections import namedtuple
-from collections.abc import Generator, Iterable, Iterator, Mapping, Sequence
import hashlib
import importlib
+import logging
+import os
+import platform
+import re
+import subprocess
import sys
import traceback
-from contextlib import suppress, nullcontext
-from typing import Any, Callable, Optional, TypeVar, Union
-import subprocess
-import os
import urllib.parse
-import re
-import logging
-import platform
+from collections import namedtuple
+from collections.abc import Generator, Iterable, Iterator, Mapping, Sequence
+from contextlib import nullcontext, suppress
+from itertools import islice
+from multiprocessing import cpu_count
+from pathlib import Path
+from typing import Any, Callable, Optional, TypeVar, Union
+import cv2 as cv
from attr.converters import to_bool
+from av import VideoFrame
from datumaro.util.os_util import walk
-from rq.job import Job, Dependency
-from django_rq.queues import DjangoRQ
-from pathlib import Path
-
+from django.conf import settings
+from django.core.exceptions import ValidationError
from django.http.request import HttpRequest
from django.utils import timezone
from django.utils.http import urlencode
-from rest_framework.reverse import reverse as _reverse
-
-from av import VideoFrame
-from PIL import Image
-from multiprocessing import cpu_count
-
-from django.core.exceptions import ValidationError
+from django_rq.queues import DjangoRQ
from django_sendfile import sendfile as _sendfile
-from django.conf import settings
+from PIL import Image
from redis.lock import Lock
+from rest_framework.reverse import reverse as _reverse
+from rq.job import Dependency, Job
Import = namedtuple("Import", ["module", "name", "alias"])
@@ -230,8 +227,8 @@ def get_rq_job_meta(
result_url: Optional[str] = None,
):
# to prevent circular import
- from cvat.apps.webhooks.signals import project_id, organization_id
- from cvat.apps.events.handlers import task_id, job_id, organization_slug
+ from cvat.apps.events.handlers import job_id, organization_slug, task_id
+ from cvat.apps.webhooks.signals import organization_id, project_id
oid = organization_id(db_obj)
oslug = organization_slug(db_obj)
diff --git a/cvat/apps/engine/view_utils.py b/cvat/apps/engine/view_utils.py
index 6f5dc298a7b6..dbac90720b43 100644
--- a/cvat/apps/engine/view_utils.py
+++ b/cvat/apps/engine/view_utils.py
@@ -9,11 +9,11 @@
from django.db.models.query import QuerySet
from django.http.request import HttpRequest
from django.http.response import HttpResponse
+from drf_spectacular.utils import extend_schema
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.serializers import Serializer
from rest_framework.viewsets import GenericViewSet
-from drf_spectacular.utils import extend_schema
from cvat.apps.engine.mixins import UploadMixin
from cvat.apps.engine.parsers import TusUploadParser
diff --git a/cvat/apps/engine/views.py b/cvat/apps/engine/views.py
index eb39f6732c18..6b70836d53a1 100644
--- a/cvat/apps/engine/views.py
+++ b/cvat/apps/engine/views.py
@@ -13,105 +13,173 @@
import traceback
import zlib
from abc import ABCMeta, abstractmethod
-from contextlib import suppress
-from PIL import Image
-from types import SimpleNamespace
-from typing import Optional, Any, Union, cast, Callable
from collections import namedtuple
-from collections.abc import Mapping, Iterable
+from collections.abc import Iterable, Mapping
+from contextlib import suppress
from copy import copy
from datetime import datetime
-from redis.exceptions import ConnectionError as RedisConnectionError
+from pathlib import Path
from tempfile import NamedTemporaryFile
+from types import SimpleNamespace
+from typing import Any, Callable, Optional, Union, cast
import django_rq
from attr.converters import to_bool
from django.conf import settings
from django.contrib.auth.models import User
-from django.db import IntegrityError, transaction
+from django.db import IntegrityError
from django.db import models as django_models
+from django.db import transaction
from django.db.models.query import Prefetch
-from django.http import HttpResponse, HttpRequest, HttpResponseNotFound, HttpResponseBadRequest
+from django.http import HttpRequest, HttpResponse, HttpResponseBadRequest, HttpResponseNotFound
from django.utils import timezone
from django.utils.decorators import method_decorator
from django.views.decorators.cache import never_cache
from django_rq.queues import DjangoRQ
-
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import (
- OpenApiExample, OpenApiParameter, OpenApiResponse, PolymorphicProxySerializer,
- extend_schema_view, extend_schema
+ OpenApiExample,
+ OpenApiParameter,
+ OpenApiResponse,
+ PolymorphicProxySerializer,
+ extend_schema,
+ extend_schema_view,
)
-
-from pathlib import Path
+from PIL import Image
+from redis.exceptions import ConnectionError as RedisConnectionError
from rest_framework import mixins, serializers, status, viewsets
from rest_framework.decorators import action
-from rest_framework.exceptions import APIException, NotFound, ValidationError, PermissionDenied
+from rest_framework.exceptions import APIException, NotFound, PermissionDenied, ValidationError
from rest_framework.parsers import MultiPartParser
from rest_framework.permissions import SAFE_METHODS
from rest_framework.response import Response
from rest_framework.settings import api_settings
-
-from rq.job import Job as RQJob, JobStatus as RQJobStatus
+from rq.job import Job as RQJob
+from rq.job import JobStatus as RQJobStatus
import cvat.apps.dataset_manager as dm
import cvat.apps.dataset_manager.views # pylint: disable=unused-import
-from cvat.apps.engine.cloud_provider import db_storage_to_storage_instance, import_resource_from_cloud_storage
-from cvat.apps.events.handlers import handle_dataset_import
from cvat.apps.dataset_manager.bindings import CvatImportError
from cvat.apps.dataset_manager.serializers import DatasetFormatsSerializer
+from cvat.apps.engine import backup
+from cvat.apps.engine.cache import CvatChunkTimestampMismatchError, LockError, MediaCache
+from cvat.apps.engine.cloud_provider import (
+ db_storage_to_storage_instance,
+ import_resource_from_cloud_storage,
+)
+from cvat.apps.engine.filters import (
+ NonModelJsonLogicFilter,
+ NonModelOrderingFilter,
+ NonModelSimpleFilter,
+)
from cvat.apps.engine.frame_provider import (
- DataWithMeta, IFrameProvider, TaskFrameProvider, JobFrameProvider, FrameQuality
+ DataWithMeta,
+ FrameQuality,
+ IFrameProvider,
+ JobFrameProvider,
+ TaskFrameProvider,
)
-from cvat.apps.engine.filters import NonModelSimpleFilter, NonModelOrderingFilter, NonModelJsonLogicFilter
+from cvat.apps.engine.location import StorageType, get_location_configuration
from cvat.apps.engine.media_extractors import get_mime
-from cvat.apps.engine.permissions import AnnotationGuidePermission, get_iam_context
+from cvat.apps.engine.mixins import (
+ BackupMixin,
+ CsrfWorkaroundMixin,
+ DatasetMixin,
+ PartialUpdateModelMixin,
+ UploadMixin,
+)
+from cvat.apps.engine.models import AnnotationGuide, Asset, ClientFile, CloudProviderChoice
+from cvat.apps.engine.models import CloudStorage as CloudStorageModel
from cvat.apps.engine.models import (
- ClientFile, Job, JobType, Label, Task, Project, Issue, Data,
- Comment, StorageMethodChoice, StorageChoice,
- CloudProviderChoice, Location, CloudStorage as CloudStorageModel,
- Asset, AnnotationGuide, RequestStatus, RequestAction, RequestTarget, RequestSubresource
+ Comment,
+ Data,
+ Issue,
+ Job,
+ JobType,
+ Label,
+ Location,
+ Project,
+ RequestAction,
+ RequestStatus,
+ RequestSubresource,
+ RequestTarget,
+ StorageChoice,
+ StorageMethodChoice,
+ Task,
+)
+from cvat.apps.engine.permissions import (
+ AnnotationGuidePermission,
+ CloudStoragePermission,
+ CommentPermission,
+ IssuePermission,
+ JobPermission,
+ LabelPermission,
+ ProjectPermission,
+ TaskPermission,
+ UserPermission,
+ get_cloud_storage_for_import_or_export,
+ get_iam_context,
)
+from cvat.apps.engine.rq_job_handler import RQId, RQJobMetaField, is_rq_job_owner
from cvat.apps.engine.serializers import (
- AboutSerializer, AnnotationFileSerializer, BasicUserSerializer,
- DataMetaReadSerializer, DataMetaWriteSerializer, DataSerializer, FileInfoSerializer,
- JobDataMetaWriteSerializer, JobReadSerializer, JobWriteSerializer,
- JobValidationLayoutReadSerializer, JobValidationLayoutWriteSerializer,
- LabelSerializer, LabeledDataSerializer,
- ProjectReadSerializer, ProjectWriteSerializer,
- RqStatusSerializer, TaskReadSerializer, TaskValidationLayoutReadSerializer, TaskValidationLayoutWriteSerializer, TaskWriteSerializer,
- UserSerializer, PluginsSerializer, IssueReadSerializer,
- AnnotationGuideReadSerializer, AnnotationGuideWriteSerializer,
- AssetReadSerializer, AssetWriteSerializer,
- IssueWriteSerializer, CommentReadSerializer, CommentWriteSerializer, CloudStorageWriteSerializer,
- CloudStorageReadSerializer, DatasetFileSerializer,
- ProjectFileSerializer, TaskFileSerializer, RqIdSerializer, CloudStorageContentSerializer,
+ AboutSerializer,
+ AnnotationFileSerializer,
+ AnnotationGuideReadSerializer,
+ AnnotationGuideWriteSerializer,
+ AssetReadSerializer,
+ AssetWriteSerializer,
+ BasicUserSerializer,
+ CloudStorageContentSerializer,
+ CloudStorageReadSerializer,
+ CloudStorageWriteSerializer,
+ CommentReadSerializer,
+ CommentWriteSerializer,
+ DataMetaReadSerializer,
+ DataMetaWriteSerializer,
+ DataSerializer,
+ DatasetFileSerializer,
+ FileInfoSerializer,
+ IssueReadSerializer,
+ IssueWriteSerializer,
+ JobDataMetaWriteSerializer,
+ JobReadSerializer,
+ JobValidationLayoutReadSerializer,
+ JobValidationLayoutWriteSerializer,
+ JobWriteSerializer,
+ LabeledDataSerializer,
+ LabelSerializer,
+ PluginsSerializer,
+ ProjectFileSerializer,
+ ProjectReadSerializer,
+ ProjectWriteSerializer,
RequestSerializer,
+ RqIdSerializer,
+ RqStatusSerializer,
+ TaskFileSerializer,
+ TaskReadSerializer,
+ TaskValidationLayoutReadSerializer,
+ TaskValidationLayoutWriteSerializer,
+ TaskWriteSerializer,
+ UserSerializer,
)
-from cvat.apps.engine.permissions import get_cloud_storage_for_import_or_export
-
-from utils.dataset_manifest import ImageManifestManager
from cvat.apps.engine.utils import (
- av_scan_paths, process_failed_job,
- parse_exception_message, get_rq_job_meta,
- import_resource_with_clean_up_after, sendfile, define_dependent_job, get_rq_lock_by_user,
-)
-from cvat.apps.engine.rq_job_handler import RQId, is_rq_job_owner, RQJobMetaField
-from cvat.apps.engine import backup
-from cvat.apps.engine.mixins import (
- PartialUpdateModelMixin, UploadMixin, DatasetMixin, BackupMixin, CsrfWorkaroundMixin
+ av_scan_paths,
+ define_dependent_job,
+ get_rq_job_meta,
+ get_rq_lock_by_user,
+ import_resource_with_clean_up_after,
+ parse_exception_message,
+ process_failed_job,
+ sendfile,
)
-from cvat.apps.engine.location import get_location_configuration, StorageType
+from cvat.apps.engine.view_utils import tus_chunk_action
+from cvat.apps.events.handlers import handle_dataset_import
+from cvat.apps.iam.filters import ORGANIZATION_OPEN_API_PARAMETERS
+from cvat.apps.iam.permissions import IsAuthenticatedOrReadPublicResource, PolicyEnforcer
+from utils.dataset_manifest import ImageManifestManager
from . import models, task
from .log import ServerLogManager
-from cvat.apps.iam.filters import ORGANIZATION_OPEN_API_PARAMETERS
-from cvat.apps.iam.permissions import PolicyEnforcer, IsAuthenticatedOrReadPublicResource
-from cvat.apps.engine.cache import MediaCache, CvatChunkTimestampMismatchError, LockError
-from cvat.apps.engine.permissions import (CloudStoragePermission,
- CommentPermission, IssuePermission, JobPermission, LabelPermission, ProjectPermission,
- TaskPermission, UserPermission)
-from cvat.apps.engine.view_utils import tus_chunk_action
slogger = ServerLogManager(__name__)
diff --git a/cvat/apps/events/apps.py b/cvat/apps/events/apps.py
index f700758ad204..c4a7b0a3d9b4 100644
--- a/cvat/apps/events/apps.py
+++ b/cvat/apps/events/apps.py
@@ -6,10 +6,11 @@
class EventsConfig(AppConfig):
- name = 'cvat.apps.events'
+ name = "cvat.apps.events"
def ready(self):
- from . import signals # pylint: disable=unused-import
-
from cvat.apps.iam.permissions import load_app_permissions
+
load_app_permissions(self)
+
+ from . import signals # pylint: disable=unused-import
diff --git a/cvat/apps/events/cache.py b/cvat/apps/events/cache.py
index 30d1e67b8fc1..d17a8e703bc1 100644
--- a/cvat/apps/events/cache.py
+++ b/cvat/apps/events/cache.py
@@ -4,36 +4,42 @@
_caches = {}
-class DeleteCache():
+
+class DeleteCache:
def __init__(self, cache_id):
- from cvat.apps.engine.models import Task, Job, Issue, Comment
- self._cache = _caches.setdefault(cache_id, {
- Task: {},
- Job: {},
- Issue: {},
- Comment: {},
- })
+ from cvat.apps.engine.models import Comment, Issue, Job, Task
+
+ self._cache = _caches.setdefault(
+ cache_id,
+ {
+ Task: {},
+ Job: {},
+ Issue: {},
+ Comment: {},
+ },
+ )
def set(self, instance_class, instance_id, value):
self._cache[instance_class][instance_id] = value
def pop(self, instance_class, instance_id, default=None):
- if instance_class in self._cache and \
- instance_id in self._cache[instance_class]:
+ if instance_class in self._cache and instance_id in self._cache[instance_class]:
return self._cache[instance_class].pop(instance_id, default)
def has_key(self, instance_class, instance_id):
- if instance_class in self._cache and \
- instance_id in self._cache[instance_class]:
+ if instance_class in self._cache and instance_id in self._cache[instance_class]:
return True
return False
def clear(self):
self._cache.clear()
+
def get_cache():
from .handlers import request_id
+
return DeleteCache(request_id())
+
def clear_cache():
get_cache().clear()
diff --git a/cvat/apps/events/const.py b/cvat/apps/events/const.py
index 35df51c7adb0..9291d9397be3 100644
--- a/cvat/apps/events/const.py
+++ b/cvat/apps/events/const.py
@@ -6,5 +6,5 @@
MAX_EVENT_DURATION = datetime.timedelta(seconds=100)
WORKING_TIME_RESOLUTION = datetime.timedelta(milliseconds=1)
-WORKING_TIME_SCOPE = 'send:working_time'
+WORKING_TIME_SCOPE = "send:working_time"
COMPRESSED_EVENT_SCOPES = frozenset(("change:frame",))
diff --git a/cvat/apps/events/event.py b/cvat/apps/events/event.py
index a4afff968549..5368367b70d7 100644
--- a/cvat/apps/events/event.py
+++ b/cvat/apps/events/event.py
@@ -2,14 +2,15 @@
#
# SPDX-License-Identifier: MIT
-from rest_framework.renderers import JSONRenderer
from datetime import datetime, timezone
from typing import Optional
from django.db import transaction
+from rest_framework.renderers import JSONRenderer
from cvat.apps.engine.log import vlogger
+
def event_scope(action, resource):
return f"{action}:{resource}"
@@ -41,6 +42,7 @@ def select(cls, resources):
for action in cls.RESOURCES.get(resource, [])
]
+
def record_server_event(
*,
scope: str,
@@ -63,11 +65,11 @@ def record_server_event(
"scope": scope,
"timestamp": str(datetime.now(timezone.utc).timestamp()),
"source": "server",
- "payload": JSONRenderer().render(payload_with_request_id).decode('UTF-8'),
+ "payload": JSONRenderer().render(payload_with_request_id).decode("UTF-8"),
**kwargs,
}
- rendered_data = JSONRenderer().render(data).decode('UTF-8')
+ rendered_data = JSONRenderer().render(data).decode("UTF-8")
if on_commit:
transaction.on_commit(lambda: vlogger.info(rendered_data), robust=True)
@@ -80,6 +82,7 @@ class EventScopeChoice:
def choices(cls):
return sorted((val, val.upper()) for val in AllEvents.events)
+
class AllEvents:
events = list(
event_scope(action, resource)
diff --git a/cvat/apps/events/export.py b/cvat/apps/events/export.py
index 9225f1141162..770f84dda054 100644
--- a/cvat/apps/events/export.py
+++ b/cvat/apps/events/export.py
@@ -2,50 +2,49 @@
#
# SPDX-License-Identifier: MIT
-from logging import Logger
-import os
import csv
-from datetime import datetime, timedelta, timezone
-from dateutil import parser
+import os
import uuid
+from datetime import datetime, timedelta, timezone
+from logging import Logger
+import clickhouse_connect
import django_rq
+from dateutil import parser
from django.conf import settings
-import clickhouse_connect
-
-
from rest_framework import serializers, status
from rest_framework.response import Response
from cvat.apps.dataset_manager.views import log_exception
from cvat.apps.engine.log import ServerLogManager
-from cvat.apps.engine.utils import sendfile
from cvat.apps.engine.rq_job_handler import RQJobMetaField
+from cvat.apps.engine.utils import sendfile
slogger = ServerLogManager(__name__)
DEFAULT_CACHE_TTL = timedelta(hours=1)
+
def _create_csv(query_params, output_filename, cache_ttl):
try:
- clickhouse_settings = settings.CLICKHOUSE['events']
+ clickhouse_settings = settings.CLICKHOUSE["events"]
time_filter = {
- 'from': query_params.pop('from'),
- 'to': query_params.pop('to'),
+ "from": query_params.pop("from"),
+ "to": query_params.pop("to"),
}
query = "SELECT * FROM events"
conditions = []
parameters = {}
- if time_filter['from']:
+ if time_filter["from"]:
conditions.append(f"timestamp >= {{from:DateTime64}}")
- parameters['from'] = time_filter['from']
+ parameters["from"] = time_filter["from"]
- if time_filter['to']:
+ if time_filter["to"]:
conditions.append(f"timestamp <= {{to:DateTime64}}")
- parameters['to'] = time_filter['to']
+ parameters["to"] = time_filter["to"]
for param, value in query_params.items():
if value:
@@ -58,22 +57,23 @@ def _create_csv(query_params, output_filename, cache_ttl):
query += " ORDER BY timestamp ASC"
with clickhouse_connect.get_client(
- host=clickhouse_settings['HOST'],
- database=clickhouse_settings['NAME'],
- port=clickhouse_settings['PORT'],
- username=clickhouse_settings['USER'],
- password=clickhouse_settings['PASSWORD'],
+ host=clickhouse_settings["HOST"],
+ database=clickhouse_settings["NAME"],
+ port=clickhouse_settings["PORT"],
+ username=clickhouse_settings["USER"],
+ password=clickhouse_settings["PASSWORD"],
) as client:
result = client.query(query, parameters=parameters)
- with open(output_filename, 'w', encoding='UTF8') as f:
+ with open(output_filename, "w", encoding="UTF8") as f:
writer = csv.writer(f)
writer.writerow(result.column_names)
writer.writerows(result.result_rows)
archive_ctime = os.path.getctime(output_filename)
scheduler = django_rq.get_scheduler(settings.CVAT_QUEUES.EXPORT_DATA.value)
- cleaning_job = scheduler.enqueue_in(time_delta=cache_ttl,
+ cleaning_job = scheduler.enqueue_in(
+ time_delta=cache_ttl,
func=_clear_export_cache,
file_path=output_filename,
file_ctime=archive_ctime,
@@ -89,36 +89,37 @@ def _create_csv(query_params, output_filename, cache_ttl):
log_exception(slogger.glob)
raise
+
def export(request, filter_query, queue_name):
- action = request.query_params.get('action', None)
- filename = request.query_params.get('filename', None)
+ action = request.query_params.get("action", None)
+ filename = request.query_params.get("filename", None)
query_params = {
- 'org_id': filter_query.get('org_id', None),
- 'project_id': filter_query.get('project_id', None),
- 'task_id': filter_query.get('task_id', None),
- 'job_id': filter_query.get('job_id', None),
- 'user_id': filter_query.get('user_id', None),
- 'from': filter_query.get('from', None),
- 'to': filter_query.get('to', None),
+ "org_id": filter_query.get("org_id", None),
+ "project_id": filter_query.get("project_id", None),
+ "task_id": filter_query.get("task_id", None),
+ "job_id": filter_query.get("job_id", None),
+ "user_id": filter_query.get("user_id", None),
+ "from": filter_query.get("from", None),
+ "to": filter_query.get("to", None),
}
try:
- if query_params['from']:
- query_params['from'] = parser.parse(query_params['from']).timestamp()
+ if query_params["from"]:
+ query_params["from"] = parser.parse(query_params["from"]).timestamp()
except parser.ParserError:
raise serializers.ValidationError(
f"Cannot parse 'from' datetime parameter: {query_params['from']}"
)
try:
- if query_params['to']:
- query_params['to'] = parser.parse(query_params['to']).timestamp()
+ if query_params["to"]:
+ query_params["to"] = parser.parse(query_params["to"]).timestamp()
except parser.ParserError:
raise serializers.ValidationError(
f"Cannot parse 'to' datetime parameter: {query_params['to']}"
)
- if query_params['from'] and query_params['to'] and query_params['from'] > query_params['to']:
+ if query_params["from"] and query_params["to"] and query_params["from"] > query_params["to"]:
raise serializers.ValidationError("'from' must be before than 'to'")
# Set the default time interval to last 30 days
@@ -126,14 +127,13 @@ def export(request, filter_query, queue_name):
query_params["to"] = datetime.now(timezone.utc)
query_params["from"] = query_params["to"] - timedelta(days=30)
- if action not in (None, 'download'):
- raise serializers.ValidationError(
- "Unexpected action specified for the request")
+ if action not in (None, "download"):
+ raise serializers.ValidationError("Unexpected action specified for the request")
- query_id = request.query_params.get('query_id', None) or uuid.uuid4()
+ query_id = request.query_params.get("query_id", None) or uuid.uuid4()
rq_id = f"export:csv-logs-{query_id}-by-{request.user}"
response_data = {
- 'query_id': query_id,
+ "query_id": query_id,
}
queue = django_rq.get_queue(queue_name)
@@ -147,16 +147,14 @@ def export(request, filter_query, queue_name):
timestamp = datetime.strftime(datetime.now(), "%Y_%m_%d_%H_%M_%S")
filename = filename or f"logs_{timestamp}.csv"
- return sendfile(request, file_path, attachment=True,
- attachment_filename=filename)
+ return sendfile(request, file_path, attachment=True, attachment_filename=filename)
else:
if os.path.exists(file_path):
return Response(status=status.HTTP_201_CREATED)
elif rq_job.is_failed:
exc_info = rq_job.meta.get(RQJobMetaField.FORMATTED_EXCEPTION, str(rq_job.exc_info))
rq_job.delete()
- return Response(exc_info,
- status=status.HTTP_500_INTERNAL_SERVER_ERROR)
+ return Response(exc_info, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
else:
return Response(data=response_data, status=status.HTTP_202_ACCEPTED)
@@ -167,18 +165,19 @@ def export(request, filter_query, queue_name):
args=(query_params, output_filename, DEFAULT_CACHE_TTL),
job_id=rq_id,
meta={},
- result_ttl=ttl, failure_ttl=ttl)
+ result_ttl=ttl,
+ failure_ttl=ttl,
+ )
return Response(data=response_data, status=status.HTTP_202_ACCEPTED)
+
def _clear_export_cache(file_path: str, file_ctime: float, logger: Logger) -> None:
try:
if os.path.exists(file_path) and os.path.getctime(file_path) == file_ctime:
os.remove(file_path)
- logger.info(
- "Export cache file '{}' successfully removed" \
- .format(file_path))
+ logger.info("Export cache file '{}' successfully removed".format(file_path))
except Exception:
log_exception(logger)
raise
diff --git a/cvat/apps/events/handlers.py b/cvat/apps/events/handlers.py
index 753eb84dd0da..69dd4b11cdd8 100644
--- a/cvat/apps/events/handlers.py
+++ b/cvat/apps/events/handlers.py
@@ -11,26 +11,40 @@
from rest_framework.exceptions import NotAuthenticated
from rest_framework.views import exception_handler
-from cvat.apps.engine.models import (CloudStorage, Comment, Issue, Job, Label,
- Project, ShapeType, Task, User)
-from cvat.apps.engine.serializers import (BasicUserSerializer,
- CloudStorageReadSerializer,
- CommentReadSerializer,
- IssueReadSerializer,
- JobReadSerializer, LabelSerializer,
- ProjectReadSerializer,
- TaskReadSerializer)
-from cvat.apps.organizations.models import Invitation, Membership, Organization
-from cvat.apps.organizations.serializers import (InvitationReadSerializer,
- MembershipReadSerializer,
- OrganizationReadSerializer)
+from cvat.apps.engine.models import (
+ CloudStorage,
+ Comment,
+ Issue,
+ Job,
+ Label,
+ Project,
+ ShapeType,
+ Task,
+ User,
+)
from cvat.apps.engine.rq_job_handler import RQJobMetaField
+from cvat.apps.engine.serializers import (
+ BasicUserSerializer,
+ CloudStorageReadSerializer,
+ CommentReadSerializer,
+ IssueReadSerializer,
+ JobReadSerializer,
+ LabelSerializer,
+ ProjectReadSerializer,
+ TaskReadSerializer,
+)
+from cvat.apps.organizations.models import Invitation, Membership, Organization
+from cvat.apps.organizations.serializers import (
+ InvitationReadSerializer,
+ MembershipReadSerializer,
+ OrganizationReadSerializer,
+)
from cvat.apps.webhooks.models import Webhook
from cvat.apps.webhooks.serializers import WebhookReadSerializer
from .cache import get_cache
-from .event import event_scope, record_server_event
from .const import WORKING_TIME_RESOLUTION, WORKING_TIME_SCOPE
+from .event import event_scope, record_server_event
from .utils import compute_working_time_per_ids
@@ -161,9 +175,7 @@ def organization_slug(instance):
def get_instance_diff(old_data, data):
- ignore_related_fields = (
- "labels",
- )
+ ignore_related_fields = ("labels",)
diff = {}
for prop, value in data.items():
if prop in ignore_related_fields:
@@ -179,7 +191,7 @@ def get_instance_diff(old_data, data):
def _cleanup_fields(obj: dict[str, Any]) -> dict[str, Any]:
- fields=(
+ fields = (
"slug",
"id",
"name",
@@ -199,9 +211,7 @@ def _cleanup_fields(obj: dict[str, Any]) -> dict[str, Any]:
"attributes",
"key",
)
- subfields=(
- "url",
- )
+ subfields = ("url",)
data = {}
for k, v in obj.items():
@@ -215,11 +225,13 @@ def _cleanup_fields(obj: dict[str, Any]) -> dict[str, Any]:
def _get_object_name(instance):
- if isinstance(instance, Organization) or \
- isinstance(instance, Project) or \
- isinstance(instance, Task) or \
- isinstance(instance, Job) or \
- isinstance(instance, Label):
+ if (
+ isinstance(instance, Organization)
+ or isinstance(instance, Project)
+ or isinstance(instance, Task)
+ or isinstance(instance, Job)
+ or isinstance(instance, Label)
+ ):
return getattr(instance, "name", None)
if isinstance(instance, User):
@@ -251,9 +263,7 @@ def _get_object_name(instance):
def get_serializer(instance):
- context = {
- "request": get_current_request()
- }
+ context = {"request": get_current_request()}
serializer = None
for model, serializer_class in SERIALIZERS:
@@ -262,6 +272,7 @@ def get_serializer(instance):
return serializer
+
def get_serializer_without_url(instance):
serializer = get_serializer(instance)
if serializer:
@@ -290,7 +301,7 @@ def handle_create(scope, instance, **kwargs):
scope=scope,
request_id=request_id(),
on_commit=True,
- obj_id=getattr(instance, 'id', None),
+ obj_id=getattr(instance, "id", None),
obj_name=_get_object_name(instance),
org_id=oid,
org_slug=oslug,
@@ -325,7 +336,7 @@ def handle_update(scope, instance, old_instance, **kwargs):
request_id=request_id(),
on_commit=True,
obj_name=prop,
- obj_id=getattr(instance, f'{prop}_id', None),
+ obj_id=getattr(instance, f"{prop}_id", None),
obj_val=str(change["new_value"]),
org_id=oid,
org_slug=oslug,
@@ -480,6 +491,7 @@ def filter_track(track):
payload={"tracks": tracks},
)
+
def handle_dataset_io(
instance: Union[Project, Task, Job],
action: str,
@@ -488,7 +500,7 @@ def handle_dataset_io(
cloud_storage_id: Optional[int],
**payload_fields,
) -> None:
- payload={"format": format_name, **payload_fields}
+ payload = {"format": format_name, **payload_fields}
if cloud_storage_id:
payload["cloud_storage"] = {"id": cloud_storage_id}
@@ -507,6 +519,7 @@ def handle_dataset_io(
payload=payload,
)
+
def handle_dataset_export(
instance: Union[Project, Task, Job],
*,
@@ -514,8 +527,14 @@ def handle_dataset_export(
cloud_storage_id: Optional[int],
save_images: bool,
) -> None:
- handle_dataset_io(instance, "export",
- format_name=format_name, cloud_storage_id=cloud_storage_id, save_images=save_images)
+ handle_dataset_io(
+ instance,
+ "export",
+ format_name=format_name,
+ cloud_storage_id=cloud_storage_id,
+ save_images=save_images,
+ )
+
def handle_dataset_import(
instance: Union[Project, Task, Job],
@@ -523,7 +542,10 @@ def handle_dataset_import(
format_name: str,
cloud_storage_id: Optional[int],
) -> None:
- handle_dataset_io(instance, "import", format_name=format_name, cloud_storage_id=cloud_storage_id)
+ handle_dataset_io(
+ instance, "import", format_name=format_name, cloud_storage_id=cloud_storage_id
+ )
+
def handle_function_call(
function_id: str,
@@ -545,6 +567,7 @@ def handle_function_call(
},
)
+
def handle_rq_exception(rq_job, exc_type, exc_value, tb):
oid = rq_job.meta.get(RQJobMetaField.ORG_ID, None)
oslug = rq_job.meta.get(RQJobMetaField.ORG_SLUG, None)
@@ -558,7 +581,7 @@ def handle_rq_exception(rq_job, exc_type, exc_value, tb):
payload = {
"message": tb_strings[-1].rstrip("\n"),
- "stack": ''.join(tb_strings),
+ "stack": "".join(tb_strings),
}
record_server_event(
@@ -578,10 +601,11 @@ def handle_rq_exception(rq_job, exc_type, exc_value, tb):
return False
+
def handle_viewset_exception(exc, context):
response = exception_handler(exc, context)
- IGNORED_EXCEPTION_CLASSES = (NotAuthenticated, )
+ IGNORED_EXCEPTION_CLASSES = (NotAuthenticated,)
if isinstance(exc, IGNORED_EXCEPTION_CLASSES):
return response
# the standard DRF exception handler only handle APIException, Http404 and PermissionDenied
@@ -604,7 +628,7 @@ def handle_viewset_exception(exc, context):
"method": request.method,
},
"message": tb_strings[-1].rstrip("\n"),
- "stack": ''.join(tb_strings),
+ "stack": "".join(tb_strings),
"status_code": status_code,
}
diff --git a/cvat/apps/events/permissions.py b/cvat/apps/events/permissions.py
index a1b049cbbd4b..18d30f63ff65 100644
--- a/cvat/apps/events/permissions.py
+++ b/cvat/apps/events/permissions.py
@@ -4,21 +4,21 @@
# SPDX-License-Identifier: MIT
from django.conf import settings
-
from rest_framework.exceptions import PermissionDenied
from cvat.apps.iam.permissions import OpenPolicyAgentPermission, StrEnum
from cvat.utils.http import make_requests_session
+
class EventsPermission(OpenPolicyAgentPermission):
class Scopes(StrEnum):
- SEND_EVENTS = 'send:events'
- DUMP_EVENTS = 'dump:events'
+ SEND_EVENTS = "send:events"
+ DUMP_EVENTS = "dump:events"
@classmethod
def create(cls, request, view, obj, iam_context):
permissions = []
- if view.basename == 'events':
+ if view.basename == "events":
for scope in cls.get_scopes(request, view, obj):
self = cls.create_base_perm(request, view, scope, iam_context, obj)
permissions.append(self)
@@ -27,19 +27,21 @@ def create(cls, request, view, obj, iam_context):
def __init__(self, **kwargs):
super().__init__(**kwargs)
- self.url = settings.IAM_OPA_DATA_URL + '/events/allow'
+ self.url = settings.IAM_OPA_DATA_URL + "/events/allow"
def filter(self, query_params):
- url = self.url.replace('/allow', '/filter')
+ url = self.url.replace("/allow", "/filter")
with make_requests_session() as session:
- r = session.post(url, json=self.payload).json()['result']
+ r = session.post(url, json=self.payload).json()["result"]
filter_params = query_params.copy()
for query in r:
for attr, value in query.items():
if filter_params.get(attr, value) != value:
- raise PermissionDenied(f"You don't have permission to view events with {attr}={filter_params.get(attr)}")
+ raise PermissionDenied(
+ f"You don't have permission to view events with {attr}={filter_params.get(attr)}"
+ )
else:
filter_params[attr] = value
return filter_params
@@ -47,10 +49,12 @@ def filter(self, query_params):
@staticmethod
def get_scopes(request, view, obj):
Scopes = __class__.Scopes
- return [{
- ('create', 'POST'): Scopes.SEND_EVENTS,
- ('list', 'GET'): Scopes.DUMP_EVENTS,
- }[(view.action, request.method)]]
+ return [
+ {
+ ("create", "POST"): Scopes.SEND_EVENTS,
+ ("list", "GET"): Scopes.DUMP_EVENTS,
+ }[(view.action, request.method)]
+ ]
def get_resource(self):
return None
diff --git a/cvat/apps/events/rules/tests/generators/events_test.gen.rego.py b/cvat/apps/events/rules/tests/generators/events_test.gen.rego.py
index dee2d4a68963..a345c8369f9e 100644
--- a/cvat/apps/events/rules/tests/generators/events_test.gen.rego.py
+++ b/cvat/apps/events/rules/tests/generators/events_test.gen.rego.py
@@ -83,13 +83,15 @@ def get_data(scope, context, ownership, privilege, membership, resource, same_or
"scope": scope,
"auth": {
"user": {"id": random.randrange(0, 100), "privilege": privilege},
- "organization": {
- "id": random.randrange(100, 200),
- "owner": {"id": random.randrange(200, 300)},
- "user": {"role": membership},
- }
- if context == "organization"
- else None,
+ "organization": (
+ {
+ "id": random.randrange(100, 200),
+ "owner": {"id": random.randrange(200, 300)},
+ "user": {"role": membership},
+ }
+ if context == "organization"
+ else None
+ ),
},
"resource": resource,
}
diff --git a/cvat/apps/events/serializers.py b/cvat/apps/events/serializers.py
index 9b70f17429c9..f634fef20b87 100644
--- a/cvat/apps/events/serializers.py
+++ b/cvat/apps/events/serializers.py
@@ -30,19 +30,40 @@ class EventSerializer(serializers.Serializer):
class ClientEventsSerializer(serializers.Serializer):
ALLOWED_SCOPES = {
- 'client': frozenset((
- 'load:cvat', 'load:job', 'save:job','load:workspace',
- 'upload:annotations', # TODO: remove in next releases
- 'lock:object', # TODO: remove in next releases
- 'change:attribute', # TODO: remove in next releases
- 'change:label', # TODO: remove in next releases
- 'send:exception', 'join:objects', 'change:frame',
- 'draw:object', 'paste:object', 'copy:object', 'propagate:object',
- 'drag:object', 'resize:object', 'delete:object',
- 'merge:objects', 'split:objects', 'group:objects', 'slice:object',
- 'zoom:image', 'fit:image', 'rotate:image', 'action:undo', 'action:redo',
- 'debug:info', 'run:annotations_action', 'click:element',
- )),
+ "client": frozenset(
+ (
+ "load:cvat",
+ "load:job",
+ "save:job",
+ "load:workspace",
+ "upload:annotations", # TODO: remove in next releases
+ "lock:object", # TODO: remove in next releases
+ "change:attribute", # TODO: remove in next releases
+ "change:label", # TODO: remove in next releases
+ "send:exception",
+ "join:objects",
+ "change:frame",
+ "draw:object",
+ "paste:object",
+ "copy:object",
+ "propagate:object",
+ "drag:object",
+ "resize:object",
+ "delete:object",
+ "merge:objects",
+ "split:objects",
+ "group:objects",
+ "slice:object",
+ "zoom:image",
+ "fit:image",
+ "rotate:image",
+ "action:undo",
+ "action:redo",
+ "debug:info",
+ "run:annotations_action",
+ "click:element",
+ )
+ ),
}
events = EventSerializer(many=True, default=[])
@@ -72,18 +93,24 @@ def to_internal_value(self, data):
scope = event["scope"]
source = event.get("source", "client")
if scope not in ClientEventsSerializer.ALLOWED_SCOPES.get(source, []):
- raise serializers.ValidationError({"scope": f"Event scope **{scope}** is not allowed from {source}"})
+ raise serializers.ValidationError(
+ {"scope": f"Event scope **{scope}** is not allowed from {source}"}
+ )
try:
payload = json.loads(event.get("payload", "{}"))
except json.JSONDecodeError:
- raise serializers.ValidationError({ "payload": "JSON payload is not valid in passed event" })
+ raise serializers.ValidationError(
+ {"payload": "JSON payload is not valid in passed event"}
+ )
- event.update({
- "timestamp": event["timestamp"] + time_correction,
- "source": source,
- "payload": json.dumps(payload),
- **(user_and_org_data if source == 'client' else {})
- })
+ event.update(
+ {
+ "timestamp": event["timestamp"] + time_correction,
+ "source": source,
+ "payload": json.dumps(payload),
+ **(user_and_org_data if source == "client" else {}),
+ }
+ )
return data
diff --git a/cvat/apps/events/tests/test_events.py b/cvat/apps/events/tests/test_events.py
index 71dcc1ea89b0..3b9c4a6c832c 100644
--- a/cvat/apps/events/tests/test_events.py
+++ b/cvat/apps/events/tests/test_events.py
@@ -9,10 +9,11 @@
from django.contrib.auth import get_user_model
from django.test import RequestFactory
-from cvat.apps.events.serializers import ClientEventsSerializer
-from cvat.apps.organizations.models import Organization
from cvat.apps.events.const import MAX_EVENT_DURATION, WORKING_TIME_RESOLUTION
+from cvat.apps.events.serializers import ClientEventsSerializer
from cvat.apps.events.utils import compute_working_time_per_ids, is_contained
+from cvat.apps.organizations.models import Organization
+
class WorkingTimeTestCase(unittest.TestCase):
_START_TIMESTAMP = datetime(2024, 1, 1, 12)
@@ -37,22 +38,20 @@ def _compressed_event(timestamp: datetime, duration: timedelta) -> dict:
"duration": duration // WORKING_TIME_RESOLUTION,
}
-
@staticmethod
def _get_actual_working_times(data: dict) -> list[int]:
data_copy = data.copy()
working_times = []
- for event in data['events']:
- data_copy['events'] = [event]
+ for event in data["events"]:
+ data_copy["events"] = [event]
event_working_time = compute_working_time_per_ids(data_copy)
for working_time in event_working_time.values():
- working_times.append((working_time['value'] // WORKING_TIME_RESOLUTION))
- if data_copy['previous_event'] and is_contained(event, data_copy['previous_event']):
+ working_times.append((working_time["value"] // WORKING_TIME_RESOLUTION))
+ if data_copy["previous_event"] and is_contained(event, data_copy["previous_event"]):
continue
- data_copy['previous_event'] = event
+ data_copy["previous_event"] = event
return working_times
-
@staticmethod
def _deserialize(events: list[dict], previous_event: Optional[dict] = None) -> dict:
request = RequestFactory().post("/api/events")
@@ -65,7 +64,7 @@ def _deserialize(events: list[dict], previous_event: Optional[dict] = None) -> d
data={
"events": events,
"previous_event": previous_event,
- "timestamp": datetime.now(timezone.utc)
+ "timestamp": datetime.now(timezone.utc),
},
context={"request": request},
)
@@ -75,103 +74,118 @@ def _deserialize(events: list[dict], previous_event: Optional[dict] = None) -> d
return s.validated_data
def test_instant(self):
- data = self._deserialize([
- self._instant_event(self._START_TIMESTAMP),
- ])
+ data = self._deserialize(
+ [
+ self._instant_event(self._START_TIMESTAMP),
+ ]
+ )
event_times = self._get_actual_working_times(data)
self.assertEqual(event_times[0], 0)
def test_compressed(self):
- data = self._deserialize([
- self._compressed_event(self._START_TIMESTAMP, self._LONG_GAP),
- ])
+ data = self._deserialize(
+ [
+ self._compressed_event(self._START_TIMESTAMP, self._LONG_GAP),
+ ]
+ )
event_times = self._get_actual_working_times(data)
self.assertEqual(event_times[0], self._LONG_GAP_INT)
def test_instants_with_short_gap(self):
- data = self._deserialize([
- self._instant_event(self._START_TIMESTAMP),
- self._instant_event(self._START_TIMESTAMP + self._SHORT_GAP),
- ])
+ data = self._deserialize(
+ [
+ self._instant_event(self._START_TIMESTAMP),
+ self._instant_event(self._START_TIMESTAMP + self._SHORT_GAP),
+ ]
+ )
event_times = self._get_actual_working_times(data)
self.assertEqual(event_times[0], 0)
self.assertEqual(event_times[1], self._SHORT_GAP_INT)
def test_instants_with_long_gap(self):
- data = self._deserialize([
- self._instant_event(self._START_TIMESTAMP),
- self._instant_event(self._START_TIMESTAMP + self._LONG_GAP),
- ])
+ data = self._deserialize(
+ [
+ self._instant_event(self._START_TIMESTAMP),
+ self._instant_event(self._START_TIMESTAMP + self._LONG_GAP),
+ ]
+ )
event_times = self._get_actual_working_times(data)
self.assertEqual(event_times[0], 0)
self.assertEqual(event_times[1], 0)
def test_compressed_with_short_gap(self):
- data = self._deserialize([
- self._compressed_event(self._START_TIMESTAMP, timedelta(seconds=1)),
- self._compressed_event(
- self._START_TIMESTAMP + timedelta(seconds=1) + self._SHORT_GAP,
- timedelta(seconds=5)
- ),
- ])
+ data = self._deserialize(
+ [
+ self._compressed_event(self._START_TIMESTAMP, timedelta(seconds=1)),
+ self._compressed_event(
+ self._START_TIMESTAMP + timedelta(seconds=1) + self._SHORT_GAP,
+ timedelta(seconds=5),
+ ),
+ ]
+ )
event_times = self._get_actual_working_times(data)
self.assertEqual(event_times[0], 1000)
self.assertEqual(event_times[1], self._SHORT_GAP_INT + 5000)
def test_compressed_with_long_gap(self):
- data = self._deserialize([
- self._compressed_event(self._START_TIMESTAMP, timedelta(seconds=1)),
- self._compressed_event(
- self._START_TIMESTAMP + timedelta(seconds=1) + self._LONG_GAP,
- timedelta(seconds=5)
- ),
- ])
+ data = self._deserialize(
+ [
+ self._compressed_event(self._START_TIMESTAMP, timedelta(seconds=1)),
+ self._compressed_event(
+ self._START_TIMESTAMP + timedelta(seconds=1) + self._LONG_GAP,
+ timedelta(seconds=5),
+ ),
+ ]
+ )
event_times = self._get_actual_working_times(data)
self.assertEqual(event_times[0], 1000)
self.assertEqual(event_times[1], 5000)
def test_compressed_contained(self):
- data = self._deserialize([
- self._compressed_event(self._START_TIMESTAMP, timedelta(seconds=5)),
- self._compressed_event(
- self._START_TIMESTAMP + timedelta(seconds=3),
- timedelta(seconds=1)
- ),
- ])
+ data = self._deserialize(
+ [
+ self._compressed_event(self._START_TIMESTAMP, timedelta(seconds=5)),
+ self._compressed_event(
+ self._START_TIMESTAMP + timedelta(seconds=3), timedelta(seconds=1)
+ ),
+ ]
+ )
event_times = self._get_actual_working_times(data)
self.assertEqual(event_times[0], 5000)
self.assertEqual(event_times[1], 0)
def test_compressed_overlapping(self):
- data = self._deserialize([
- self._compressed_event(self._START_TIMESTAMP, timedelta(seconds=5)),
- self._compressed_event(
- self._START_TIMESTAMP + timedelta(seconds=3),
- timedelta(seconds=6)
- ),
- ])
+ data = self._deserialize(
+ [
+ self._compressed_event(self._START_TIMESTAMP, timedelta(seconds=5)),
+ self._compressed_event(
+ self._START_TIMESTAMP + timedelta(seconds=3), timedelta(seconds=6)
+ ),
+ ]
+ )
event_times = self._get_actual_working_times(data)
self.assertEqual(event_times[0], 5000)
self.assertEqual(event_times[1], 4000)
def test_instant_inside_compressed(self):
- data = self._deserialize([
- self._compressed_event(self._START_TIMESTAMP, timedelta(seconds=5)),
- self._instant_event(self._START_TIMESTAMP + timedelta(seconds=3)),
- self._instant_event(self._START_TIMESTAMP + timedelta(seconds=6)),
- ])
+ data = self._deserialize(
+ [
+ self._compressed_event(self._START_TIMESTAMP, timedelta(seconds=5)),
+ self._instant_event(self._START_TIMESTAMP + timedelta(seconds=3)),
+ self._instant_event(self._START_TIMESTAMP + timedelta(seconds=6)),
+ ]
+ )
event_times = self._get_actual_working_times(data)
self.assertEqual(event_times[0], 5000)
self.assertEqual(event_times[1], 0)
self.assertEqual(event_times[2], 1000)
-
def test_previous_instant_short_gap(self):
data = self._deserialize(
[self._instant_event(self._START_TIMESTAMP + self._SHORT_GAP)],
diff --git a/cvat/apps/events/urls.py b/cvat/apps/events/urls.py
index 832c86ac396b..cdb0d2032e68 100644
--- a/cvat/apps/events/urls.py
+++ b/cvat/apps/events/urls.py
@@ -1,4 +1,3 @@
-
# Copyright (C) 2023 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
@@ -8,6 +7,6 @@
from . import views
router = routers.DefaultRouter(trailing_slash=False)
-router.register('events', views.EventsViewSet, basename='events')
+router.register("events", views.EventsViewSet, basename="events")
urlpatterns = router.urls
diff --git a/cvat/apps/events/utils.py b/cvat/apps/events/utils.py
index 745bb8fde316..31c7f83c1791 100644
--- a/cvat/apps/events/utils.py
+++ b/cvat/apps/events/utils.py
@@ -4,16 +4,15 @@
import datetime
-
-from .const import MAX_EVENT_DURATION, COMPRESSED_EVENT_SCOPES
from .cache import clear_cache
+from .const import COMPRESSED_EVENT_SCOPES, MAX_EVENT_DURATION
def _prepare_objects_to_delete(object_to_delete):
- from cvat.apps.engine.models import Project, Task, Segment, Job, Issue, Comment
+ from cvat.apps.engine.models import Comment, Issue, Job, Project, Segment, Task
relation_chain = (Project, Task, Segment, Job, Issue, Comment)
- related_field_names = ('task_set', 'segment_set', 'job_set', 'issues', 'comments')
+ related_field_names = ("task_set", "segment_set", "job_set", "issues", "comments")
field_names = tuple(m._meta.model_name for m in relation_chain)
# Find object Model
@@ -26,25 +25,21 @@ def _prepare_objects_to_delete(object_to_delete):
# Fill filter param
filter_params = {
- f'{object_to_delete.__class__._meta.model_name}_id': object_to_delete.id,
+ f"{object_to_delete.__class__._meta.model_name}_id": object_to_delete.id,
}
# Fill prefetch
prefetch = []
if index < len(relation_chain) - 1:
- forward_prefetch = '__'.join(related_field_names[index:])
+ forward_prefetch = "__".join(related_field_names[index:])
prefetch.append(forward_prefetch)
if index > 0:
- backward_prefetch = '__'.join(reversed(field_names[:index]))
+ backward_prefetch = "__".join(reversed(field_names[:index]))
prefetch.append(backward_prefetch)
# make queryset
- objects = relation_chain[index].objects.filter(
- **filter_params
- ).prefetch_related(
- *prefetch
- )
+ objects = relation_chain[index].objects.filter(**filter_params).prefetch_related(*prefetch)
# list of objects which will be deleted with current object
objects_to_delete = list(objects)
@@ -56,9 +51,11 @@ def _prepare_objects_to_delete(object_to_delete):
return objects_to_delete
+
def cache_deleted(method):
def wrap(self, *args, **kwargs):
from .signals import resource_delete
+
objects = _prepare_objects_to_delete(self)
try:
for obj in objects:
@@ -67,6 +64,7 @@ def wrap(self, *args, **kwargs):
method(self, *args, **kwargs)
finally:
clear_cache()
+
return wrap
@@ -75,8 +73,10 @@ def get_end_timestamp(event: dict) -> datetime.datetime:
return event["timestamp"] + datetime.timedelta(milliseconds=event["duration"])
return event["timestamp"]
+
def is_contained(event1: dict, event2: dict) -> bool:
- return event1['timestamp'] < get_end_timestamp(event2)
+ return event1["timestamp"] < get_end_timestamp(event2)
+
def compute_working_time_per_ids(data: dict) -> dict:
def read_ids(event: dict) -> tuple[int | None, int | None, int | None]:
diff --git a/cvat/apps/events/views.py b/cvat/apps/events/views.py
index 31914a829c3b..e910dabdc3be 100644
--- a/cvat/apps/events/views.py
+++ b/cvat/apps/events/views.py
@@ -4,8 +4,7 @@
from django.conf import settings
from drf_spectacular.types import OpenApiTypes
-from drf_spectacular.utils import (OpenApiParameter, OpenApiResponse,
- extend_schema)
+from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
from rest_framework import status, viewsets
from rest_framework.renderers import JSONRenderer
from rest_framework.response import Response
@@ -22,59 +21,114 @@
class EventsViewSet(viewsets.ViewSet):
serializer_class = None
- @extend_schema(summary='Log client events',
- methods=['POST'],
- description='Sends logs to the Clickhouse if it is connected',
+ @extend_schema(
+ summary="Log client events",
+ methods=["POST"],
+ description="Sends logs to the Clickhouse if it is connected",
parameters=ORGANIZATION_OPEN_API_PARAMETERS,
request=ClientEventsSerializer(),
responses={
- '201': ClientEventsSerializer(),
- })
+ "201": ClientEventsSerializer(),
+ },
+ )
def create(self, request):
serializer = ClientEventsSerializer(data=request.data, context={"request": request})
serializer.is_valid(raise_exception=True)
handle_client_events_push(request, serializer.validated_data)
for event in serializer.validated_data["events"]:
- message = JSONRenderer().render({
- **event,
- 'timestamp': str(event["timestamp"].timestamp())
- }).decode('UTF-8')
+ message = (
+ JSONRenderer()
+ .render({**event, "timestamp": str(event["timestamp"].timestamp())})
+ .decode("UTF-8")
+ )
vlogger.info(message)
return Response(serializer.validated_data, status=status.HTTP_201_CREATED)
- @extend_schema(summary='Get an event log',
- methods=['GET'],
- description='The log is returned in the CSV format.',
+ @extend_schema(
+ summary="Get an event log",
+ methods=["GET"],
+ description="The log is returned in the CSV format.",
parameters=[
- OpenApiParameter('org_id', location=OpenApiParameter.QUERY, type=OpenApiTypes.INT, required=False,
- description="Filter events by organization ID"),
- OpenApiParameter('project_id', location=OpenApiParameter.QUERY, type=OpenApiTypes.INT, required=False,
- description="Filter events by project ID"),
- OpenApiParameter('task_id', location=OpenApiParameter.QUERY, type=OpenApiTypes.INT, required=False,
- description="Filter events by task ID"),
- OpenApiParameter('job_id', location=OpenApiParameter.QUERY, type=OpenApiTypes.INT, required=False,
- description="Filter events by job ID"),
- OpenApiParameter('user_id', location=OpenApiParameter.QUERY, type=OpenApiTypes.INT, required=False,
- description="Filter events by user ID"),
- OpenApiParameter('from', location=OpenApiParameter.QUERY, type=OpenApiTypes.DATETIME, required=False,
- description="Filter events after the datetime. If no 'from' or 'to' parameters are passed, the last 30 days will be set."),
- OpenApiParameter('to', location=OpenApiParameter.QUERY, type=OpenApiTypes.DATETIME, required=False,
- description="Filter events before the datetime. If no 'from' or 'to' parameters are passed, the last 30 days will be set."),
- OpenApiParameter('filename', description='Desired output file name',
- location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False),
- OpenApiParameter('action', location=OpenApiParameter.QUERY,
- description='Used to start downloading process after annotation file had been created',
- type=OpenApiTypes.STR, required=False, enum=['download']),
- OpenApiParameter('query_id', location=OpenApiParameter.QUERY, type=OpenApiTypes.STR, required=False,
- description="ID of query request that need to check or download"),
+ OpenApiParameter(
+ "org_id",
+ location=OpenApiParameter.QUERY,
+ type=OpenApiTypes.INT,
+ required=False,
+ description="Filter events by organization ID",
+ ),
+ OpenApiParameter(
+ "project_id",
+ location=OpenApiParameter.QUERY,
+ type=OpenApiTypes.INT,
+ required=False,
+ description="Filter events by project ID",
+ ),
+ OpenApiParameter(
+ "task_id",
+ location=OpenApiParameter.QUERY,
+ type=OpenApiTypes.INT,
+ required=False,
+ description="Filter events by task ID",
+ ),
+ OpenApiParameter(
+ "job_id",
+ location=OpenApiParameter.QUERY,
+ type=OpenApiTypes.INT,
+ required=False,
+ description="Filter events by job ID",
+ ),
+ OpenApiParameter(
+ "user_id",
+ location=OpenApiParameter.QUERY,
+ type=OpenApiTypes.INT,
+ required=False,
+ description="Filter events by user ID",
+ ),
+ OpenApiParameter(
+ "from",
+ location=OpenApiParameter.QUERY,
+ type=OpenApiTypes.DATETIME,
+ required=False,
+ description="Filter events after the datetime. If no 'from' or 'to' parameters are passed, the last 30 days will be set.",
+ ),
+ OpenApiParameter(
+ "to",
+ location=OpenApiParameter.QUERY,
+ type=OpenApiTypes.DATETIME,
+ required=False,
+ description="Filter events before the datetime. If no 'from' or 'to' parameters are passed, the last 30 days will be set.",
+ ),
+ OpenApiParameter(
+ "filename",
+ description="Desired output file name",
+ location=OpenApiParameter.QUERY,
+ type=OpenApiTypes.STR,
+ required=False,
+ ),
+ OpenApiParameter(
+ "action",
+ location=OpenApiParameter.QUERY,
+ description="Used to start downloading process after annotation file had been created",
+ type=OpenApiTypes.STR,
+ required=False,
+ enum=["download"],
+ ),
+ OpenApiParameter(
+ "query_id",
+ location=OpenApiParameter.QUERY,
+ type=OpenApiTypes.STR,
+ required=False,
+ description="ID of query request that need to check or download",
+ ),
],
responses={
- '200': OpenApiResponse(description='Download of file started'),
- '201': OpenApiResponse(description='CSV log file is ready for downloading'),
- '202': OpenApiResponse(description='Creating a CSV log file has been started'),
- })
+ "200": OpenApiResponse(description="Download of file started"),
+ "201": OpenApiResponse(description="CSV log file is ready for downloading"),
+ "202": OpenApiResponse(description="Creating a CSV log file has been started"),
+ },
+ )
def list(self, request):
perm = EventsPermission.create_scope_list(request)
filter_query = perm.filter(request.query_params)
diff --git a/cvat/apps/health/apps.py b/cvat/apps/health/apps.py
index a457048b87c9..ae38010ff7b2 100644
--- a/cvat/apps/health/apps.py
+++ b/cvat/apps/health/apps.py
@@ -3,12 +3,13 @@
# SPDX-License-Identifier: MIT
from django.apps import AppConfig
-
from health_check.plugins import plugin_dir
+
class HealthConfig(AppConfig):
- name = 'cvat.apps.health'
+ name = "cvat.apps.health"
def ready(self):
from .backends import OPAHealthCheck
+
plugin_dir.register(OPAHealthCheck)
diff --git a/cvat/apps/health/backends.py b/cvat/apps/health/backends.py
index 2f361117173a..0ba37cb23195 100644
--- a/cvat/apps/health/backends.py
+++ b/cvat/apps/health/backends.py
@@ -3,19 +3,18 @@
# SPDX-License-Identifier: MIT
import requests
-
+from django.conf import settings
from health_check.backends import BaseHealthCheckBackend
from health_check.exceptions import HealthCheckException
-from django.conf import settings
-
from cvat.utils.http import make_requests_session
+
class OPAHealthCheck(BaseHealthCheckBackend):
critical_service = True
def check_status(self):
- opa_health_url = f'{settings.IAM_OPA_HOST}/health?bundles'
+ opa_health_url = f"{settings.IAM_OPA_HOST}/health?bundles"
try:
with make_requests_session() as session:
response = session.get(opa_health_url)
diff --git a/cvat/apps/health/management/commands/workerprobe.py b/cvat/apps/health/management/commands/workerprobe.py
index fc8b6cf7077a..af9d663a1a29 100644
--- a/cvat/apps/health/management/commands/workerprobe.py
+++ b/cvat/apps/health/management/commands/workerprobe.py
@@ -1,10 +1,11 @@
import os
import platform
from datetime import datetime, timedelta
-from django.core.management.base import BaseCommand, CommandError
+
+import django_rq
from django.conf import settings
+from django.core.management.base import BaseCommand, CommandError
from rq.worker import Worker
-import django_rq
class Command(BaseCommand):
@@ -20,13 +21,21 @@ def handle(self, *args, **options):
raise CommandError(f"Queue {queue_name} is not defined")
connection = django_rq.get_connection(queue_name)
- workers = [w for w in Worker.all(connection) if queue_name in w.queue_names() and w.hostname == hostname]
+ workers = [
+ w
+ for w in Worker.all(connection)
+ if queue_name in w.queue_names() and w.hostname == hostname
+ ]
expected_workers = int(os.getenv("NUMPROCS", 1))
if len(workers) != expected_workers:
- raise CommandError("Number of registered workers does not match the expected number, " \
- f"actual: {len(workers)}, expected: {expected_workers}")
+ raise CommandError(
+ "Number of registered workers does not match the expected number, "
+ f"actual: {len(workers)}, expected: {expected_workers}"
+ )
for worker in workers:
if datetime.now() - worker.last_heartbeat > timedelta(seconds=worker.worker_ttl):
- raise CommandError(f"It seems that worker {worker.name}, pid: {worker.pid} is dead")
+ raise CommandError(
+ f"It seems that worker {worker.name}, pid: {worker.pid} is dead"
+ )
diff --git a/cvat/apps/iam/adapters.py b/cvat/apps/iam/adapters.py
index 703bec48743f..50ff2812c3a5 100644
--- a/cvat/apps/iam/adapters.py
+++ b/cvat/apps/iam/adapters.py
@@ -2,10 +2,10 @@
#
# SPDX-License-Identifier: MIT
-from django.http import HttpResponseRedirect
+from allauth.account.adapter import DefaultAccountAdapter
from django.conf import settings
+from django.http import HttpResponseRedirect
-from allauth.account.adapter import DefaultAccountAdapter
class DefaultAccountAdapterEx(DefaultAccountAdapter):
def respond_email_verification_sent(self, request, user):
diff --git a/cvat/apps/iam/admin.py b/cvat/apps/iam/admin.py
index 648e15dc2da4..bf6efafe9a34 100644
--- a/cvat/apps/iam/admin.py
+++ b/cvat/apps/iam/admin.py
@@ -4,8 +4,8 @@
# SPDX-License-Identifier: MIT
from django.contrib import admin
-from django.contrib.auth.models import Group, User
from django.contrib.auth.admin import GroupAdmin, UserAdmin
+from django.contrib.auth.models import Group, User
from django.utils.translation import gettext_lazy as _
from cvat.apps.engine.models import Profile
@@ -14,20 +14,27 @@
class ProfileInline(admin.StackedInline):
model = Profile
- fieldsets = (
- (None, {'fields': ('has_analytics_access', )}),
- )
+ fieldsets = ((None, {"fields": ("has_analytics_access",)}),)
class CustomUserAdmin(UserAdmin):
inlines = (ProfileInline,)
list_display = ("username", "email", "first_name", "last_name", "is_active", "is_staff")
fieldsets = (
- (None, {'fields': ('username', 'password')}),
- (_('Personal info'), {'fields': ('first_name', 'last_name', 'email')}),
- (_('Permissions'), {'fields': ('is_active', 'is_staff', 'is_superuser',
- 'groups',)}),
- (_('Important dates'), {'fields': ('last_login', 'date_joined')}),
+ (None, {"fields": ("username", "password")}),
+ (_("Personal info"), {"fields": ("first_name", "last_name", "email")}),
+ (
+ _("Permissions"),
+ {
+ "fields": (
+ "is_active",
+ "is_staff",
+ "is_superuser",
+ "groups",
+ )
+ },
+ ),
+ (_("Important dates"), {"fields": ("last_login", "date_joined")}),
)
add_fieldsets = (
(
@@ -40,21 +47,17 @@ class CustomUserAdmin(UserAdmin):
)
actions = ["user_activate", "user_deactivate"]
- @admin.action(
- permissions=["change"], description=_("Mark selected users as active")
- )
+ @admin.action(permissions=["change"], description=_("Mark selected users as active"))
def user_activate(self, request, queryset):
queryset.update(is_active=True)
- @admin.action(
- permissions=["change"], description=_("Mark selected users as not active")
- )
+ @admin.action(permissions=["change"], description=_("Mark selected users as not active"))
def user_deactivate(self, request, queryset):
queryset.update(is_active=False)
class CustomGroupAdmin(GroupAdmin):
- fieldsets = ((None, {'fields': ('name',)}),)
+ fieldsets = ((None, {"fields": ("name",)}),)
admin.site.unregister(User)
diff --git a/cvat/apps/iam/apps.py b/cvat/apps/iam/apps.py
index 97bdc3ca05fd..00f051a75c08 100644
--- a/cvat/apps/iam/apps.py
+++ b/cvat/apps/iam/apps.py
@@ -5,9 +5,11 @@
from django.apps import AppConfig
+
class IAMConfig(AppConfig):
- name = 'cvat.apps.iam'
+ name = "cvat.apps.iam"
def ready(self):
from .signals import register_signals
+
register_signals(self)
diff --git a/cvat/apps/iam/authentication.py b/cvat/apps/iam/authentication.py
index 412806380389..74ec6f5424b7 100644
--- a/cvat/apps/iam/authentication.py
+++ b/cvat/apps/iam/authentication.py
@@ -2,21 +2,23 @@
#
# SPDX-License-Identifier: MIT
+import hashlib
+
+from django.contrib.auth import get_user_model
from django.core import signing
+from furl import furl
from rest_framework import exceptions
from rest_framework.authentication import BaseAuthentication
-from django.contrib.auth import get_user_model
-from furl import furl
-import hashlib
+
# Got implementation ideas in https://github.com/marcgibbons/drf_signed_auth
class Signer:
- QUERY_PARAM = 'sign'
+ QUERY_PARAM = "sign"
MAX_AGE = 30
@classmethod
def get_salt(cls, url):
- normalized_url = furl(url).remove(cls.QUERY_PARAM).url.encode('utf-8')
+ normalized_url = furl(url).remove(cls.QUERY_PARAM).url.encode("utf-8")
salt = hashlib.sha256(normalized_url).hexdigest()
return salt
@@ -24,10 +26,7 @@ def sign(self, user, url):
"""
Create a signature for a user object.
"""
- data = {
- 'user_id': user.pk,
- 'username': user.get_username()
- }
+ data = {"user_id": user.pk, "username": user.get_username()}
return signing.dumps(data, salt=self.get_salt(url))
@@ -36,24 +35,24 @@ def unsign(self, signature, url):
Return a user object for a valid signature.
"""
User = get_user_model()
- data = signing.loads(signature, salt=self.get_salt(url),
- max_age=self.MAX_AGE)
+ data = signing.loads(signature, salt=self.get_salt(url), max_age=self.MAX_AGE)
if not isinstance(data, dict):
raise signing.BadSignature()
try:
- return User.objects.get(**{
- 'pk': data.get('user_id'),
- User.USERNAME_FIELD: data.get('username')
- })
+ return User.objects.get(
+ **{"pk": data.get("user_id"), User.USERNAME_FIELD: data.get("username")}
+ )
except User.DoesNotExist:
raise signing.BadSignature()
+
class SignatureAuthentication(BaseAuthentication):
"""
Authentication backend for signed URLs.
"""
+
def authenticate(self, request):
"""
Returns authenticated user if URL signature is valid.
@@ -66,10 +65,10 @@ def authenticate(self, request):
try:
user = signer.unsign(sign, request.build_absolute_uri())
except signing.SignatureExpired:
- raise exceptions.AuthenticationFailed('This URL has expired.')
+ raise exceptions.AuthenticationFailed("This URL has expired.")
except signing.BadSignature:
- raise exceptions.AuthenticationFailed('Invalid signature.')
+ raise exceptions.AuthenticationFailed("Invalid signature.")
if not user.is_active:
- raise exceptions.AuthenticationFailed('User inactive or deleted.')
+ raise exceptions.AuthenticationFailed("User inactive or deleted.")
return (user, None)
diff --git a/cvat/apps/iam/filters.py b/cvat/apps/iam/filters.py
index 6fd62d8d05e5..c99da171bac7 100644
--- a/cvat/apps/iam/filters.py
+++ b/cvat/apps/iam/filters.py
@@ -2,29 +2,29 @@
#
# SPDX-License-Identifier: MIT
-from rest_framework.filters import BaseFilterBackend
-from django.db.models import Q
from collections.abc import Iterable
+from django.db.models import Q
from drf_spectacular.utils import OpenApiParameter
+from rest_framework.filters import BaseFilterBackend
ORGANIZATION_OPEN_API_PARAMETERS = [
OpenApiParameter(
- name='org',
+ name="org",
type=str,
required=False,
location=OpenApiParameter.QUERY,
description="Organization unique slug",
),
OpenApiParameter(
- name='org_id',
+ name="org_id",
type=int,
required=False,
location=OpenApiParameter.QUERY,
description="Organization identifier",
),
OpenApiParameter(
- name='X-Organization',
+ name="X-Organization",
type=str,
required=False,
location=OpenApiParameter.HEADER,
@@ -32,13 +32,14 @@
),
]
+
class OrganizationFilterBackend(BaseFilterBackend):
def _parameter_is_provided(self, request):
for parameter in ORGANIZATION_OPEN_API_PARAMETERS:
- if parameter.location == 'header' and parameter.name in request.headers:
+ if parameter.location == "header" and parameter.name in request.headers:
return True
- elif parameter.location == 'query' and parameter.name in request.query_params:
+ elif parameter.location == "query" and parameter.name in request.query_params:
return True
return False
@@ -62,34 +63,35 @@ def _construct_filter_query(self, organization_fields, org_id):
return Q()
-
def filter_queryset(self, request, queryset, view):
# Filter works only for "list" requests and allows to return
# only non-organization objects if org isn't specified
if (
- view.detail or not view.iam_organization_field or
+ view.detail
+ or not view.iam_organization_field
+ or
# FIXME: It should be handled in another way. For example, if we try to get information for a specific job
# and org isn't specified, we need to return the full list of labels, issues, comments.
# Allow crowdsourcing users to get labels/issues/comments related to specific job.
# Crowdsourcing user always has worker group and isn't a member of an organization.
(
- view.__class__.__name__ in ('LabelViewSet', 'IssueViewSet', 'CommentViewSet') and
- request.query_params.get('job_id') and
- request.iam_context.get('organization') is None and
- request.iam_context.get('privilege') == 'worker'
+ view.__class__.__name__ in ("LabelViewSet", "IssueViewSet", "CommentViewSet")
+ and request.query_params.get("job_id")
+ and request.iam_context.get("organization") is None
+ and request.iam_context.get("privilege") == "worker"
)
):
return queryset
visibility = None
- org = request.iam_context['organization']
+ org = request.iam_context["organization"]
if org:
- visibility = {'organization': org.id}
+ visibility = {"organization": org.id}
elif not org and self._parameter_is_provided(request):
- visibility = {'organization': None}
+ visibility = {"organization": None}
if visibility:
org_id = visibility.pop("organization")
@@ -108,15 +110,17 @@ def get_schema_operation_parameters(self, view):
parameter_type = None
if parameter.type == int:
- parameter_type = 'integer'
+ parameter_type = "integer"
elif parameter.type == str:
- parameter_type = 'string'
-
- parameters.append({
- 'name': parameter.name,
- 'in': parameter.location,
- 'description': parameter.description,
- 'schema': {'type': parameter_type}
- })
+ parameter_type = "string"
+
+ parameters.append(
+ {
+ "name": parameter.name,
+ "in": parameter.location,
+ "description": parameter.description,
+ "schema": {"type": parameter_type},
+ }
+ )
return parameters
diff --git a/cvat/apps/iam/forms.py b/cvat/apps/iam/forms.py
index c1668b924387..af619a563f38 100644
--- a/cvat/apps/iam/forms.py
+++ b/cvat/apps/iam/forms.py
@@ -2,22 +2,27 @@
#
# SPDX-License-Identifier: MIT
-from django.contrib.sites.shortcuts import get_current_site
-from django.contrib.auth import get_user_model
-
+from allauth.account.adapter import get_adapter
from allauth.account.forms import default_token_generator
from allauth.account.utils import user_pk_to_url_str
-from allauth.account.adapter import get_adapter
from dj_rest_auth.forms import AllAuthPasswordResetForm
+from django.contrib.auth import get_user_model
+from django.contrib.sites.shortcuts import get_current_site
UserModel = get_user_model()
-class ResetPasswordFormEx(AllAuthPasswordResetForm):
- def save(self, request=None, domain_override=None,
- email_template_prefix='authentication/password_reset_key',
- use_https=False, token_generator=default_token_generator,
- extra_email_context=None, **kwargs):
+class ResetPasswordFormEx(AllAuthPasswordResetForm):
+ def save(
+ self,
+ request=None,
+ domain_override=None,
+ email_template_prefix="authentication/password_reset_key",
+ use_https=False,
+ token_generator=default_token_generator,
+ extra_email_context=None,
+ **kwargs,
+ ):
"""
Generate a one-use only link for resetting password and send it to the
user.
@@ -33,16 +38,16 @@ def save(self, request=None, domain_override=None,
for user in self.users:
user_email = getattr(user, email_field_name)
context = {
- 'email': user_email,
- 'domain': domain,
- 'site_name': site_name,
- 'uid': user_pk_to_url_str(user),
- 'user': user,
- 'token': token_generator.make_token(user),
- 'protocol': 'https' if use_https else 'http',
+ "email": user_email,
+ "domain": domain,
+ "site_name": site_name,
+ "uid": user_pk_to_url_str(user),
+ "user": user,
+ "token": token_generator.make_token(user),
+ "protocol": "https" if use_https else "http",
**(extra_email_context or {}),
}
get_adapter(request).send_mail(email_template_prefix, email, context)
- return self.cleaned_data['email']
+ return self.cleaned_data["email"]
diff --git a/cvat/apps/iam/middleware.py b/cvat/apps/iam/middleware.py
index f2f1a4bae2e0..c09c5eeb96b6 100644
--- a/cvat/apps/iam/middleware.py
+++ b/cvat/apps/iam/middleware.py
@@ -5,10 +5,10 @@
from datetime import timedelta
from typing import Callable
-from django.utils.functional import SimpleLazyObject
-from rest_framework.exceptions import ValidationError, NotFound
from django.conf import settings
from django.http import HttpRequest, HttpResponse
+from django.utils.functional import SimpleLazyObject
+from rest_framework.exceptions import NotFound, ValidationError
def get_organization(request):
@@ -22,31 +22,32 @@ def get_organization(request):
organization = None
try:
- org_slug = request.GET.get('org')
- org_id = request.GET.get('org_id')
- org_header = request.headers.get('X-Organization')
+ org_slug = request.GET.get("org")
+ org_id = request.GET.get("org_id")
+ org_header = request.headers.get("X-Organization")
if org_id is not None and (org_slug is not None or org_header is not None):
- raise ValidationError('You cannot specify "org_id" query parameter with '
- '"org" query parameter or "X-Organization" HTTP header at the same time.')
+ raise ValidationError(
+ 'You cannot specify "org_id" query parameter with '
+ '"org" query parameter or "X-Organization" HTTP header at the same time.'
+ )
if org_slug is not None and org_header is not None and org_slug != org_header:
- raise ValidationError('You cannot specify "org" query parameter and '
- '"X-Organization" HTTP header with different values.')
+ raise ValidationError(
+ 'You cannot specify "org" query parameter and '
+ '"X-Organization" HTTP header with different values.'
+ )
org_slug = org_slug if org_slug is not None else org_header
if org_slug:
- organization = Organization.objects.select_related('owner').get(slug=org_slug)
+ organization = Organization.objects.select_related("owner").get(slug=org_slug)
elif org_id:
- organization = Organization.objects.select_related('owner').get(id=int(org_id))
+ organization = Organization.objects.select_related("owner").get(id=int(org_id))
except Organization.DoesNotExist:
- raise NotFound(f'{org_slug or org_id} organization does not exist.')
+ raise NotFound(f"{org_slug or org_id} organization does not exist.")
- context = {
- "organization": organization,
- "privilege": getattr(privilege, 'name', None)
- }
+ context = {"organization": organization, "privilege": getattr(privilege, "name", None)}
return context
@@ -62,6 +63,7 @@ def __call__(self, request):
return self.get_response(request)
+
class SessionRefreshMiddleware:
"""
Implements behavior similar to SESSION_SAVE_EVERY_REQUEST=True, but instead of
diff --git a/cvat/apps/iam/migrations/0001_remove_business_group.py b/cvat/apps/iam/migrations/0001_remove_business_group.py
index 2bf1a56b4065..aa64d4a56d6d 100644
--- a/cvat/apps/iam/migrations/0001_remove_business_group.py
+++ b/cvat/apps/iam/migrations/0001_remove_business_group.py
@@ -2,13 +2,12 @@
from django.conf import settings
from django.db import migrations
-
BUSINESS_GROUP_NAME = "business"
USER_GROUP_NAME = "user"
def delete_business_group(apps, schema_editor):
- Group = apps.get_model('auth', 'Group')
+ Group = apps.get_model("auth", "Group")
User = apps.get_model(settings.AUTH_USER_MODEL)
if user_group := Group.objects.filter(name=USER_GROUP_NAME).first():
diff --git a/cvat/apps/iam/models.py b/cvat/apps/iam/models.py
index b1220197cf2a..f7c3408e3d12 100644
--- a/cvat/apps/iam/models.py
+++ b/cvat/apps/iam/models.py
@@ -1,4 +1,3 @@
# Copyright (C) 2021-2022 Intel Corporation
#
# SPDX-License-Identifier: MIT
-
diff --git a/cvat/apps/iam/permissions.py b/cvat/apps/iam/permissions.py
index d4925426724a..f13d6be377ce 100644
--- a/cvat/apps/iam/permissions.py
+++ b/cvat/apps/iam/permissions.py
@@ -44,21 +44,21 @@ def get_organization(request, obj):
if obj:
try:
- organization_id = getattr(obj, 'organization_id')
+ organization_id = getattr(obj, "organization_id")
except AttributeError as exc:
# Skip initialization of organization for those objects that don't related with organization
- view = request.parser_context.get('view')
+ view = request.parser_context.get("view")
if view and view.basename in settings.OBJECTS_NOT_RELATED_WITH_ORG:
- return request.iam_context['organization']
+ return request.iam_context["organization"]
raise exc
try:
- return Organization.objects.select_related('owner').get(id=organization_id)
+ return Organization.objects.select_related("owner").get(id=organization_id)
except Organization.DoesNotExist:
return None
- return request.iam_context['organization']
+ return request.iam_context["organization"]
def get_membership(request, organization):
@@ -66,21 +66,20 @@ def get_membership(request, organization):
return None
return Membership.objects.filter(
- organization=organization,
- user=request.user,
- is_active=True
+ organization=organization, user=request.user, is_active=True
).first()
-def build_iam_context(request, organization: Optional[Organization], membership: Optional[Membership]):
+def build_iam_context(
+ request, organization: Optional[Organization], membership: Optional[Membership]
+):
return {
- 'user_id': request.user.id,
- 'group_name': request.iam_context['privilege'],
- 'org_id': getattr(organization, 'id', None),
- 'org_slug': getattr(organization, 'slug', None),
- 'org_owner_id': getattr(organization.owner, 'id', None)
- if organization else None,
- 'org_role': getattr(membership, 'role', None),
+ "user_id": request.user.id,
+ "group_name": request.iam_context["privilege"],
+ "org_id": getattr(organization, "id", None),
+ "org_slug": getattr(organization, "slug", None),
+ "org_owner_id": getattr(organization.owner, "id", None) if organization else None,
+ "org_role": getattr(membership, "role", None),
}
@@ -103,23 +102,19 @@ class OpenPolicyAgentPermission(metaclass=ABCMeta):
@classmethod
@abstractmethod
- def create(cls, request, view, obj, iam_context) -> Sequence[OpenPolicyAgentPermission]:
- ...
+ def create(cls, request, view, obj, iam_context) -> Sequence[OpenPolicyAgentPermission]: ...
@classmethod
def create_base_perm(cls, request, view, scope, iam_context, obj=None, **kwargs):
if not iam_context and request:
iam_context = get_iam_context(request, obj)
- return cls(
- scope=scope,
- obj=obj,
- **iam_context, **kwargs)
+ return cls(scope=scope, obj=obj, **iam_context, **kwargs)
@classmethod
def create_scope_list(cls, request, iam_context=None):
if not iam_context and request:
iam_context = get_iam_context(request, None)
- return cls(**iam_context, scope='list')
+ return cls(**iam_context, scope="list")
def __init__(self, **kwargs):
self.obj = None
@@ -127,27 +122,31 @@ def __init__(self, **kwargs):
setattr(self, name, val)
self.payload = {
- 'input': {
- 'scope': self.scope,
- 'auth': {
- 'user': {
- 'id': self.user_id,
- 'privilege': self.group_name,
+ "input": {
+ "scope": self.scope,
+ "auth": {
+ "user": {
+ "id": self.user_id,
+ "privilege": self.group_name,
},
- 'organization': {
- 'id': self.org_id,
- 'owner': {
- 'id': self.org_owner_id,
- },
- 'user': {
- 'role': self.org_role,
- },
- } if self.org_id is not None else None
- }
+ "organization": (
+ {
+ "id": self.org_id,
+ "owner": {
+ "id": self.org_owner_id,
+ },
+ "user": {
+ "role": self.org_role,
+ },
+ }
+ if self.org_id is not None
+ else None
+ ),
+ },
}
}
- self.payload['input']['resource'] = self.get_resource()
+ self.payload["input"]["resource"] = self.get_resource()
@abstractmethod
def get_resource(self):
@@ -156,13 +155,13 @@ def get_resource(self):
def check_access(self) -> PermissionResult:
with make_requests_session() as session:
response = session.post(self.url, json=self.payload)
- output = response.json()['result']
+ output = response.json()["result"]
allow = False
reasons = []
if isinstance(output, dict):
- allow = output['allow']
- reasons = output.get('reasons', [])
+ allow = output["allow"]
+ reasons = output.get("reasons", [])
elif isinstance(output, bool):
allow = output
else:
@@ -171,21 +170,21 @@ def check_access(self) -> PermissionResult:
return PermissionResult(allow=allow, reasons=reasons)
def filter(self, queryset):
- url = self.url.replace('/allow', '/filter')
+ url = self.url.replace("/allow", "/filter")
with make_requests_session() as session:
- r = session.post(url, json=self.payload).json()['result']
+ r = session.post(url, json=self.payload).json()["result"]
q_objects = []
ops_dict = {
- '|': operator.or_,
- '&': operator.and_,
- '~': operator.not_,
+ "|": operator.or_,
+ "&": operator.and_,
+ "~": operator.not_,
}
for item in r:
if isinstance(item, str):
val1 = q_objects.pop()
- if item == '~':
+ if item == "~":
q_objects.append(ops_dict[item](val1))
else:
val2 = q_objects.pop()
@@ -211,7 +210,7 @@ def get_per_field_update_scopes(cls, request, scopes_per_field):
request body fields are associated with different scopes.
"""
- assert request.method == 'PATCH'
+ assert request.method == "PATCH"
# Even if no fields are modified, a PATCH request typically returns the
# new state of the object, so we need to make sure the user has permissions
@@ -226,7 +225,7 @@ def get_per_field_update_scopes(cls, request, scopes_per_field):
return scopes
-T = TypeVar('T', bound=Model)
+T = TypeVar("T", bound=Model)
def is_public_obj(obj: T) -> bool:
@@ -257,22 +256,23 @@ def has_permission(self, request, view):
if not view.detail:
return self.check_permission(request, view, None)
else:
- return True # has_object_permission will be called later
+ return True # has_object_permission will be called later
def has_object_permission(self, request, view, obj):
return self.check_permission(request, view, obj)
@staticmethod
def is_metadata_request(request, view):
- return request.method == 'OPTIONS' \
- or (request.method == 'POST' and view.action == 'metadata' and len(request.data) == 0)
+ return request.method == "OPTIONS" or (
+ request.method == "POST" and view.action == "metadata" and len(request.data) == 0
+ )
class IsAuthenticatedOrReadPublicResource(BasePermission):
def has_object_permission(self, request, view, obj) -> bool:
return bool(
- (request.user and request.user.is_authenticated) or
- (request.method == 'GET' and is_public_obj(obj))
+ (request.user and request.user.is_authenticated)
+ or (request.method == "GET" and is_public_obj(obj))
)
diff --git a/cvat/apps/iam/rules/tests/generate_tests.py b/cvat/apps/iam/rules/tests/generate_tests.py
index 729de6732eb2..92b4a0e699a9 100755
--- a/cvat/apps/iam/rules/tests/generate_tests.py
+++ b/cvat/apps/iam/rules/tests/generate_tests.py
@@ -10,11 +10,12 @@
from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor
from functools import partial
-from typing import Optional
from pathlib import Path
+from typing import Optional
REPO_ROOT = Path(__file__).resolve().parents[5]
+
def create_arg_parser() -> ArgumentParser:
parser = ArgumentParser(add_help=True)
parser.add_argument(
@@ -36,7 +37,7 @@ def parse_args(args: Optional[Sequence[str]] = None) -> Namespace:
def call_generator(generator_path: Path, gen_params: Namespace) -> None:
rules_dir = generator_path.parents[2]
subprocess.check_call(
- [sys.executable, generator_path.relative_to(rules_dir), 'tests/configs'], cwd=rules_dir
+ [sys.executable, generator_path.relative_to(rules_dir), "tests/configs"], cwd=rules_dir
)
@@ -53,7 +54,7 @@ def main(args: Optional[Sequence[str]] = None) -> int:
partial(call_generator, gen_params=args),
generator_paths,
):
- pass # consume all results in order to propagate exceptions
+ pass # consume all results in order to propagate exceptions
if __name__ == "__main__":
diff --git a/cvat/apps/iam/schema.py b/cvat/apps/iam/schema.py
index 46f9e31052c1..7f54c5597a95 100644
--- a/cvat/apps/iam/schema.py
+++ b/cvat/apps/iam/schema.py
@@ -9,7 +9,6 @@
from drf_spectacular.authentication import SessionScheme, TokenScheme
from drf_spectacular.extensions import OpenApiAuthenticationExtension
from drf_spectacular.openapi import AutoSchema
-
from rest_framework import serializers
@@ -18,35 +17,37 @@ class SignatureAuthenticationScheme(OpenApiAuthenticationExtension):
Adds the signature auth method to schema
"""
- target_class = 'cvat.apps.iam.authentication.SignatureAuthentication'
- name = 'signatureAuth' # name used in the schema
+ target_class = "cvat.apps.iam.authentication.SignatureAuthentication"
+ name = "signatureAuth" # name used in the schema
def get_security_definition(self, auto_schema):
return {
- 'type': 'apiKey',
- 'in': 'query',
- 'name': 'sign',
- 'description': 'Can be used to share URLs to private links',
+ "type": "apiKey",
+ "in": "query",
+ "name": "sign",
+ "description": "Can be used to share URLs to private links",
}
+
class TokenAuthenticationScheme(TokenScheme):
"""
Adds the token auth method to schema. The description includes extra info
comparing to what is generated by default.
"""
- name = 'tokenAuth'
+ name = "tokenAuth"
priority = 0
match_subclasses = True
def get_security_requirement(self, auto_schema):
# These schemes must be used together
- return {'sessionAuth': [], 'csrfAuth': [], self.name: []}
+ return {"sessionAuth": [], "csrfAuth": [], self.name: []}
def get_security_definition(self, auto_schema):
schema = super().get_security_definition(auto_schema)
- schema['x-token-prefix'] = self.target.keyword
- schema['description'] = textwrap.dedent(f"""
+ schema["x-token-prefix"] = self.target.keyword
+ schema["description"] = textwrap.dedent(
+ f"""
To authenticate using a token (or API key), you need to have 3 components in a request:
- the 'sessionid' cookie
- the 'csrftoken' cookie or 'X-CSRFTOKEN' header
@@ -54,16 +55,18 @@ def get_security_definition(self, auto_schema):
You can obtain an API key (the token) from the server response on
the basic auth request.
- """)
+ """
+ )
return schema
+
class CookieAuthenticationScheme(SessionScheme):
"""
This class adds csrftoken cookie into security sections. It must be used together with
the 'sessionid' cookie.
"""
- name = ['sessionAuth', 'csrfAuth']
+ name = ["sessionAuth", "csrfAuth"]
priority = 0
def get_security_requirement(self, auto_schema):
@@ -73,13 +76,14 @@ def get_security_requirement(self, auto_schema):
def get_security_definition(self, auto_schema):
sessionid_schema = super().get_security_definition(auto_schema)
csrftoken_schema = {
- 'type': 'apiKey',
- 'in': 'cookie',
- 'name': 'csrftoken',
- 'description': 'Can be sent as a cookie or as the X-CSRFTOKEN header'
+ "type": "apiKey",
+ "in": "cookie",
+ "name": "csrftoken",
+ "description": "Can be sent as a cookie or as the X-CSRFTOKEN header",
}
return [sessionid_schema, csrftoken_schema]
+
class CustomAutoSchema(AutoSchema):
def get_operation_id(self):
# Change style of operation ids to [viewset _ action _ object]
@@ -87,20 +91,20 @@ def get_operation_id(self):
tokenized_path = self._tokenize_path()
# replace dashes as they can be problematic later in code generation
- tokenized_path = [t.replace('-', '_') for t in tokenized_path]
+ tokenized_path = [t.replace("-", "_") for t in tokenized_path]
- if self.method == 'GET' and self._is_list_view():
- action = 'list'
+ if self.method == "GET" and self._is_list_view():
+ action = "list"
else:
action = self.method_mapping[self.method.lower()]
if not tokenized_path:
- tokenized_path.append('root')
+ tokenized_path.append("root")
- if re.search(r'<drf_format_suffix\w*:\w+>', self.path_regex):
- tokenized_path.append('formatted')
+ if re.search(r"<drf_format_suffix\w*:\w+>", self.path_regex):
+ tokenized_path.append("formatted")
- return '_'.join([tokenized_path[0]] + [action] + tokenized_path[1:])
+ return "_".join([tokenized_path[0]] + [action] + tokenized_path[1:])
def _get_request_for_media_type(self, serializer, *args, **kwargs):
# Enables support for required=False serializers in request body specification
diff --git a/cvat/apps/iam/serializers.py b/cvat/apps/iam/serializers.py
index 967b696a4f21..7de9919e3ab3 100644
--- a/cvat/apps/iam/serializers.py
+++ b/cvat/apps/iam/serializers.py
@@ -3,23 +3,21 @@
#
# SPDX-License-Identifier: MIT
-from dj_rest_auth.registration.serializers import RegisterSerializer
-from dj_rest_auth.serializers import PasswordResetSerializer, LoginSerializer
-from django.core.exceptions import ValidationError as DjangoValidationError
-from rest_framework.exceptions import ValidationError
-from rest_framework import serializers
+from typing import Optional, Union
+
from allauth.account import app_settings as allauth_settings
-from allauth.account.utils import filter_users_by_email
from allauth.account.adapter import get_adapter
-from allauth.account.utils import setup_user_email
from allauth.account.models import EmailAddress
-
+from allauth.account.utils import filter_users_by_email, setup_user_email
+from dj_rest_auth.registration.serializers import RegisterSerializer
+from dj_rest_auth.serializers import LoginSerializer, PasswordResetSerializer
from django.conf import settings
from django.contrib.auth import get_user_model
from django.contrib.auth.models import User
-
+from django.core.exceptions import ValidationError as DjangoValidationError
from drf_spectacular.utils import extend_schema_field
-from typing import Optional, Union
+from rest_framework import serializers
+from rest_framework.exceptions import ValidationError
from cvat.apps.iam.forms import ResetPasswordFormEx
from cvat.apps.iam.utils import get_dummy_user
@@ -33,22 +31,30 @@ class RegisterSerializerEx(RegisterSerializer):
@extend_schema_field(serializers.BooleanField)
def get_email_verification_required(self, obj: Union[dict, User]) -> bool:
- return allauth_settings.EMAIL_VERIFICATION == allauth_settings.EmailVerificationMethod.MANDATORY
+ return (
+ allauth_settings.EMAIL_VERIFICATION
+ == allauth_settings.EmailVerificationMethod.MANDATORY
+ )
@extend_schema_field(serializers.CharField(allow_null=True))
def get_key(self, obj: Union[dict, User]) -> Optional[str]:
key = None
- if isinstance(obj, User) and allauth_settings.EMAIL_VERIFICATION != \
- allauth_settings.EmailVerificationMethod.MANDATORY:
+ if (
+ isinstance(obj, User)
+ and allauth_settings.EMAIL_VERIFICATION
+ != allauth_settings.EmailVerificationMethod.MANDATORY
+ ):
key = obj.auth_token.key
return key
def get_cleaned_data(self):
data = super().get_cleaned_data()
- data.update({
- 'first_name': self.validated_data.get('first_name', ''),
- 'last_name': self.validated_data.get('last_name', ''),
- })
+ data.update(
+ {
+ "first_name": self.validated_data.get("first_name", ""),
+ "last_name": self.validated_data.get("last_name", ""),
+ }
+ )
return data
@@ -57,7 +63,7 @@ def email_address_exists(email) -> bool:
if EmailAddress.objects.filter(email__iexact=email).exists():
return True
- if (email_field := allauth_settings.USER_MODEL_EMAIL_FIELD):
+ if email_field := allauth_settings.USER_MODEL_EMAIL_FIELD:
users = get_user_model().objects
return users.filter(**{email_field + "__iexact": email}).exists()
return False
@@ -68,7 +74,7 @@ def email_address_exists(email) -> bool:
user = get_dummy_user(email)
if not user:
raise serializers.ValidationError(
- ('A user is already registered with this e-mail address.'),
+ ("A user is already registered with this e-mail address."),
)
return email
@@ -84,11 +90,9 @@ def save(self, request):
user = adapter.save_user(request, user, self, commit=False)
if "password1" in self.cleaned_data:
try:
- adapter.clean_password(self.cleaned_data['password1'], user=user)
+ adapter.clean_password(self.cleaned_data["password1"], user=user)
except DjangoValidationError as exc:
- raise serializers.ValidationError(
- detail=serializers.as_serializer_error(exc)
- )
+ raise serializers.ValidationError(detail=serializers.as_serializer_error(exc))
user.save()
self.custom_signup(request, user)
@@ -104,35 +108,42 @@ def password_reset_form_class(self):
def get_email_options(self):
domain = None
- if hasattr(settings, 'UI_HOST') and settings.UI_HOST:
+ if hasattr(settings, "UI_HOST") and settings.UI_HOST:
domain = settings.UI_HOST
- if hasattr(settings, 'UI_PORT') and settings.UI_PORT:
- domain += ':{}'.format(settings.UI_PORT)
- return {
- 'domain_override': domain
- }
+ if hasattr(settings, "UI_PORT") and settings.UI_PORT:
+ domain += ":{}".format(settings.UI_PORT)
+ return {"domain_override": domain}
+
class LoginSerializerEx(LoginSerializer):
def get_auth_user_using_allauth(self, username, email, password):
def is_email_authentication():
- return settings.ACCOUNT_AUTHENTICATION_METHOD == allauth_settings.AuthenticationMethod.EMAIL
+ return (
+ settings.ACCOUNT_AUTHENTICATION_METHOD
+ == allauth_settings.AuthenticationMethod.EMAIL
+ )
def is_username_authentication():
- return settings.ACCOUNT_AUTHENTICATION_METHOD == allauth_settings.AuthenticationMethod.USERNAME
+ return (
+ settings.ACCOUNT_AUTHENTICATION_METHOD
+ == allauth_settings.AuthenticationMethod.USERNAME
+ )
# check that the server settings match the request
if is_username_authentication() and not username and email:
raise ValidationError(
- 'Attempt to authenticate with email/password. '
- 'But username/password are used for authentication on the server. '
- 'Please check your server configuration ACCOUNT_AUTHENTICATION_METHOD.')
+ "Attempt to authenticate with email/password. "
+ "But username/password are used for authentication on the server. "
+ "Please check your server configuration ACCOUNT_AUTHENTICATION_METHOD."
+ )
if is_email_authentication() and not email and username:
raise ValidationError(
- 'Attempt to authenticate with username/password. '
- 'But email/password are used for authentication on the server. '
- 'Please check your server configuration ACCOUNT_AUTHENTICATION_METHOD.')
+ "Attempt to authenticate with username/password. "
+ "But email/password are used for authentication on the server. "
+ "Please check your server configuration ACCOUNT_AUTHENTICATION_METHOD."
+ )
# Authentication through email
if settings.ACCOUNT_AUTHENTICATION_METHOD == allauth_settings.AuthenticationMethod.EMAIL:
@@ -146,6 +157,6 @@ def is_username_authentication():
if email:
users = filter_users_by_email(email)
if not users or len(users) > 1:
- raise ValidationError('Unable to login with provided credentials')
+ raise ValidationError("Unable to login with provided credentials")
return self._validate_username_email(username, email, password)
diff --git a/cvat/apps/iam/signals.py b/cvat/apps/iam/signals.py
index 73f919a1a4a4..b8bbf643dab8 100644
--- a/cvat/apps/iam/signals.py
+++ b/cvat/apps/iam/signals.py
@@ -3,8 +3,8 @@
# SPDX-License-Identifier: MIT
from django.conf import settings
-from django.contrib.auth.models import User, Group
-from django.db.models.signals import post_save, post_migrate
+from django.contrib.auth.models import Group, User
+from django.db.models.signals import post_migrate, post_save
def register_groups(sender, **kwargs):
@@ -12,7 +12,9 @@ def register_groups(sender, **kwargs):
for role in settings.IAM_ROLES:
Group.objects.get_or_create(name=role)
-if settings.IAM_TYPE == 'BASIC':
+
+if settings.IAM_TYPE == "BASIC":
+
def create_user(sender, instance, created, **kwargs):
from allauth.account import app_settings as allauth_settings
from allauth.account.models import EmailAddress
@@ -23,14 +25,16 @@ def create_user(sender, instance, created, **kwargs):
# create and verify EmailAddress for superuser accounts
if allauth_settings.EMAIL_REQUIRED:
- EmailAddress.objects.get_or_create(user=instance,
- email=instance.email, primary=True, verified=True)
- else: # don't need to add default groups for superuser
- if created and not getattr(instance, 'skip_group_assigning', None):
+ EmailAddress.objects.get_or_create(
+ user=instance, email=instance.email, primary=True, verified=True
+ )
+ else: # don't need to add default groups for superuser
+ if created and not getattr(instance, "skip_group_assigning", None):
db_group = Group.objects.get(name=settings.IAM_DEFAULT_ROLE)
instance.groups.add(db_group)
-elif settings.IAM_TYPE == 'LDAP':
+elif settings.IAM_TYPE == "LDAP":
+
def create_user(sender, user=None, ldap_user=None, **kwargs):
user_groups = []
for role in settings.IAM_ROLES:
@@ -56,11 +60,12 @@ def create_user(sender, user=None, ldap_user=None, **kwargs):
def register_signals(app_config):
post_migrate.connect(register_groups, app_config)
- if settings.IAM_TYPE == 'BASIC':
+ if settings.IAM_TYPE == "BASIC":
# Add default groups and add admin rights to super users.
post_save.connect(create_user, sender=User)
- elif settings.IAM_TYPE == 'LDAP':
+ elif settings.IAM_TYPE == "LDAP":
import django_auth_ldap.backend
+
# Map groups from LDAP to roles, convert a user to super user if he/she
# has an admin group.
django_auth_ldap.backend.populate_user.connect(create_user)
diff --git a/cvat/apps/iam/tests/test_rest_api.py b/cvat/apps/iam/tests/test_rest_api.py
index d3de9fd6f1df..db0745e999d3 100644
--- a/cvat/apps/iam/tests/test_rest_api.py
+++ b/cvat/apps/iam/tests/test_rest_api.py
@@ -3,25 +3,30 @@
#
# SPDX-License-Identifier: MIT
-from django.urls import reverse
+from allauth.account.views import EmailVerificationSentView
+from django.test import override_settings
+from django.urls import path, re_path, reverse
from rest_framework import status
-from rest_framework.test import APITestCase, APIClient
from rest_framework.authtoken.models import Token
-from django.test import override_settings
-from django.urls import path, re_path
-from allauth.account.views import EmailVerificationSentView
+from rest_framework.test import APIClient, APITestCase
from cvat.apps.iam.urls import urlpatterns as iam_url_patterns
from cvat.apps.iam.views import ConfirmEmailViewEx
-
urlpatterns = iam_url_patterns + [
- re_path(r'^account-confirm-email/(?P<key>[-:\w]+)/$', ConfirmEmailViewEx.as_view(),
- name='account_confirm_email'),
- path('register/account-email-verification-sent', EmailVerificationSentView.as_view(),
- name='account_email_verification_sent'),
+ re_path(
+ r"^account-confirm-email/(?P<key>[-:\w]+)/$",
+ ConfirmEmailViewEx.as_view(),
+ name="account_confirm_email",
+ ),
+ path(
+ "register/account-email-verification-sent",
+ EmailVerificationSentView.as_view(),
+ name="account_email_verification_sent",
+ ),
]
+
class ForceLogin:
def __init__(self, user, client):
self.user = user
@@ -29,7 +34,7 @@ def __init__(self, user, client):
def __enter__(self):
if self.user:
- self.client.force_login(self.user, backend='django.contrib.auth.backends.ModelBackend')
+ self.client.force_login(self.user, backend="django.contrib.auth.backends.ModelBackend")
return self
@@ -37,57 +42,91 @@ def __exit__(self, exception_type, exception_value, traceback):
if self.user:
self.client.logout()
+
class UserRegisterAPITestCase(APITestCase):
- user_data = {'first_name': 'test_first', 'last_name': 'test_last', 'username': 'test_username',
- 'email': 'test_email@test.com', 'password1': '$Test357Test%', 'password2': '$Test357Test%',
- 'confirmations': []}
+ user_data = {
+ "first_name": "test_first",
+ "last_name": "test_last",
+ "username": "test_username",
+ "email": "test_email@test.com",
+ "password1": "$Test357Test%",
+ "password2": "$Test357Test%",
+ "confirmations": [],
+ }
def setUp(self):
self.client = APIClient()
def _run_api_v2_user_register(self, data):
- url = reverse('rest_register')
- response = self.client.post(url, data, format='json')
+ url = reverse("rest_register")
+ response = self.client.post(url, data, format="json")
return response
def _check_response(self, response, data):
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(response.data, data)
- @override_settings(ACCOUNT_EMAIL_VERIFICATION='none')
+ @override_settings(ACCOUNT_EMAIL_VERIFICATION="none")
def test_api_v2_user_register_with_email_verification_none(self):
"""
Ensure we can register a user and get auth token key when email verification is none
"""
response = self._run_api_v2_user_register(self.user_data)
- user_token = Token.objects.get(user__username=response.data['username'])
- self._check_response(response, {'first_name': 'test_first', 'last_name': 'test_last',
- 'username': 'test_username', 'email': 'test_email@test.com',
- 'email_verification_required': False, 'key': user_token.key})
+ user_token = Token.objects.get(user__username=response.data["username"])
+ self._check_response(
+ response,
+ {
+ "first_name": "test_first",
+ "last_name": "test_last",
+ "username": "test_username",
+ "email": "test_email@test.com",
+ "email_verification_required": False,
+ "key": user_token.key,
+ },
+ )
# Since URLConf is executed before running the tests, so we have to manually configure the url patterns for
# the tests and pass it using ROOT_URLCONF in the override settings decorator
- @override_settings(ACCOUNT_EMAIL_VERIFICATION='optional', ROOT_URLCONF=__name__)
+ @override_settings(ACCOUNT_EMAIL_VERIFICATION="optional", ROOT_URLCONF=__name__)
def test_api_v2_user_register_with_email_verification_optional(self):
"""
Ensure we can register a user and get auth token key when email verification is optional
"""
response = self._run_api_v2_user_register(self.user_data)
- user_token = Token.objects.get(user__username=response.data['username'])
- self._check_response(response, {'first_name': 'test_first', 'last_name': 'test_last',
- 'username': 'test_username', 'email': 'test_email@test.com',
- 'email_verification_required': False, 'key': user_token.key})
-
- @override_settings(ACCOUNT_EMAIL_REQUIRED=True, ACCOUNT_EMAIL_VERIFICATION='mandatory',
- EMAIL_BACKEND='django.core.mail.backends.console.EmailBackend', ROOT_URLCONF=__name__)
+ user_token = Token.objects.get(user__username=response.data["username"])
+ self._check_response(
+ response,
+ {
+ "first_name": "test_first",
+ "last_name": "test_last",
+ "username": "test_username",
+ "email": "test_email@test.com",
+ "email_verification_required": False,
+ "key": user_token.key,
+ },
+ )
+
+ @override_settings(
+ ACCOUNT_EMAIL_REQUIRED=True,
+ ACCOUNT_EMAIL_VERIFICATION="mandatory",
+ EMAIL_BACKEND="django.core.mail.backends.console.EmailBackend",
+ ROOT_URLCONF=__name__,
+ )
def test_register_account_with_email_verification_mandatory(self):
"""
Ensure we can register a user and it does not return auth token key when email verification is mandatory
"""
response = self._run_api_v2_user_register(self.user_data)
- self._check_response(response, {'first_name': 'test_first', 'last_name': 'test_last',
- 'username': 'test_username', 'email': 'test_email@test.com',
- 'email_verification_required': True, 'key': None})
-
+ self._check_response(
+ response,
+ {
+ "first_name": "test_first",
+ "last_name": "test_last",
+ "username": "test_username",
+ "email": "test_email@test.com",
+ "email_verification_required": True,
+ "key": None,
+ },
+ )
diff --git a/cvat/apps/iam/urls.py b/cvat/apps/iam/urls.py
index 8b8135fc2d9a..8f66f48f22b1 100644
--- a/cvat/apps/iam/urls.py
+++ b/cvat/apps/iam/urls.py
@@ -3,46 +3,55 @@
#
# SPDX-License-Identifier: MIT
-from django.urls import path, re_path
+from allauth.account import app_settings as allauth_settings
+from dj_rest_auth.views import (
+ LogoutView,
+ PasswordChangeView,
+ PasswordResetConfirmView,
+ PasswordResetView,
+)
from django.conf import settings
+from django.urls import path, re_path
from django.urls.conf import include
-from dj_rest_auth.views import (
- LogoutView, PasswordChangeView,
- PasswordResetView, PasswordResetConfirmView)
-from allauth.account import app_settings as allauth_settings
from cvat.apps.iam.views import (
- SigningView, RegisterViewEx, RulesView,
- ConfirmEmailViewEx, LoginViewEx
+ ConfirmEmailViewEx,
+ LoginViewEx,
+ RegisterViewEx,
+ RulesView,
+ SigningView,
)
-BASIC_LOGIN_PATH_NAME = 'rest_login'
-BASIC_REGISTER_PATH_NAME = 'rest_register'
+BASIC_LOGIN_PATH_NAME = "rest_login"
+BASIC_REGISTER_PATH_NAME = "rest_register"
urlpatterns = [
- path('login', LoginViewEx.as_view(), name=BASIC_LOGIN_PATH_NAME),
- path('logout', LogoutView.as_view(), name='rest_logout'),
- path('signing', SigningView.as_view(), name='signing'),
- path('rules', RulesView.as_view(), name='rules'),
+ path("login", LoginViewEx.as_view(), name=BASIC_LOGIN_PATH_NAME),
+ path("logout", LogoutView.as_view(), name="rest_logout"),
+ path("signing", SigningView.as_view(), name="signing"),
+ path("rules", RulesView.as_view(), name="rules"),
]
-if settings.IAM_TYPE == 'BASIC':
+if settings.IAM_TYPE == "BASIC":
urlpatterns += [
- path('register', RegisterViewEx.as_view(), name=BASIC_REGISTER_PATH_NAME),
+ path("register", RegisterViewEx.as_view(), name=BASIC_REGISTER_PATH_NAME),
# password
- path('password/reset', PasswordResetView.as_view(),
- name='rest_password_reset'),
- path('password/reset/confirm', PasswordResetConfirmView.as_view(),
- name='rest_password_reset_confirm'),
- path('password/change', PasswordChangeView.as_view(),
- name='rest_password_change'),
+ path("password/reset", PasswordResetView.as_view(), name="rest_password_reset"),
+ path(
+ "password/reset/confirm",
+ PasswordResetConfirmView.as_view(),
+ name="rest_password_reset_confirm",
+ ),
+ path("password/change", PasswordChangeView.as_view(), name="rest_password_change"),
]
- if allauth_settings.EMAIL_VERIFICATION != \
- allauth_settings.EmailVerificationMethod.NONE:
+ if allauth_settings.EMAIL_VERIFICATION != allauth_settings.EmailVerificationMethod.NONE:
# emails
urlpatterns += [
- re_path(r'^account-confirm-email/(?P<key>[-:\w]+)/$', ConfirmEmailViewEx.as_view(),
- name='account_confirm_email'),
+ re_path(
+ r"^account-confirm-email/(?P<key>[-:\w]+)/$",
+ ConfirmEmailViewEx.as_view(),
+ name="account_confirm_email",
+ ),
]
-urlpatterns = [path('auth/', include(urlpatterns))]
+urlpatterns = [path("auth/", include(urlpatterns))]
diff --git a/cvat/apps/iam/utils.py b/cvat/apps/iam/utils.py
index 8095902769f3..9b911e48ea7c 100644
--- a/cvat/apps/iam/utils.py
+++ b/cvat/apps/iam/utils.py
@@ -1,37 +1,40 @@
-from pathlib import Path
import functools
import hashlib
import importlib
import io
import tarfile
+from pathlib import Path
from django.conf import settings
from django.contrib.sessions.backends.base import SessionBase
_OPA_RULES_PATHS = {
- Path(__file__).parent / 'rules',
+ Path(__file__).parent / "rules",
}
+
@functools.lru_cache(maxsize=None)
def get_opa_bundle() -> tuple[bytes, str]:
bundle_file = io.BytesIO()
- with tarfile.open(fileobj=bundle_file, mode='w:gz') as tar:
+ with tarfile.open(fileobj=bundle_file, mode="w:gz") as tar:
for p in _OPA_RULES_PATHS:
- for f in p.glob('*[!.gen].rego'):
+ for f in p.glob("*[!.gen].rego"):
tar.add(name=f, arcname=f.relative_to(p.parent))
bundle = bundle_file.getvalue()
etag = hashlib.blake2b(bundle).hexdigest()
return bundle, etag
+
def add_opa_rules_path(path: Path) -> None:
_OPA_RULES_PATHS.add(path)
get_opa_bundle.cache_clear()
+
def get_dummy_user(email):
- from allauth.account.models import EmailAddress
from allauth.account import app_settings
+ from allauth.account.models import EmailAddress
from allauth.account.utils import filter_users_by_email
users = filter_users_by_email(email)
@@ -40,13 +43,13 @@ def get_dummy_user(email):
user = users[0]
if user.has_usable_password():
return None
- if app_settings.EMAIL_VERIFICATION == \
- app_settings.EmailVerificationMethod.MANDATORY:
+ if app_settings.EMAIL_VERIFICATION == app_settings.EmailVerificationMethod.MANDATORY:
email = EmailAddress.objects.get_for_user(user, email)
if email.verified:
return None
return user
+
def clean_up_sessions() -> None:
SessionStore: type[SessionBase] = importlib.import_module(settings.SESSION_ENGINE).SessionStore
SessionStore.clear_expired()
diff --git a/cvat/apps/iam/views.py b/cvat/apps/iam/views.py
index 928d170c3bc4..d9bf960e426c 100644
--- a/cvat/apps/iam/views.py
+++ b/cvat/apps/iam/views.py
@@ -5,49 +5,55 @@
import functools
-from django.http import Http404, HttpResponseBadRequest, HttpResponseRedirect
-from rest_framework import views, serializers
-from rest_framework.exceptions import ValidationError
-from rest_framework.permissions import AllowAny
-from django.conf import settings
-from django.http import HttpResponse
-from django.views.decorators.http import etag as django_etag
-from rest_framework.response import Response
+from allauth.account import app_settings as allauth_settings
+from allauth.account.utils import complete_signup, has_verified_email, send_email_confirmation
+from allauth.account.views import ConfirmEmailView
from dj_rest_auth.app_settings import api_settings as dj_rest_auth_settings
from dj_rest_auth.registration.views import RegisterView
from dj_rest_auth.utils import jwt_encode
from dj_rest_auth.views import LoginView
-from allauth.account import app_settings as allauth_settings
-from allauth.account.views import ConfirmEmailView
-from allauth.account.utils import complete_signup, has_verified_email, send_email_confirmation
-
-from furl import furl
-
-from drf_spectacular.types import OpenApiTypes
-from drf_spectacular.utils import OpenApiResponse, extend_schema, inline_serializer, extend_schema_view
+from django.conf import settings
+from django.http import Http404, HttpResponse, HttpResponseBadRequest, HttpResponseRedirect
+from django.views.decorators.http import etag as django_etag
from drf_spectacular.contrib.rest_auth import get_token_serializer_class
+from drf_spectacular.types import OpenApiTypes
+from drf_spectacular.utils import (
+ OpenApiResponse,
+ extend_schema,
+ extend_schema_view,
+ inline_serializer,
+)
+from furl import furl
+from rest_framework import serializers, views
+from rest_framework.exceptions import ValidationError
+from rest_framework.permissions import AllowAny
+from rest_framework.response import Response
from .authentication import Signer
from .utils import get_opa_bundle
-@extend_schema(tags=['auth'])
-@extend_schema_view(post=extend_schema(
- summary='This method signs URL for access to the server',
- description='Signed URL contains a token which authenticates a user on the server.'
- 'Signed URL is valid during 30 seconds since signing.',
- request=inline_serializer(
- name='Signing',
- fields={
- 'url': serializers.CharField(),
- }
- ),
- responses={'200': OpenApiResponse(response=OpenApiTypes.STR, description='text URL')}))
+
+@extend_schema(tags=["auth"])
+@extend_schema_view(
+ post=extend_schema(
+ summary="This method signs URL for access to the server",
+ description="Signed URL contains a token which authenticates a user on the server."
+ "Signed URL is valid during 30 seconds since signing.",
+ request=inline_serializer(
+ name="Signing",
+ fields={
+ "url": serializers.CharField(),
+ },
+ ),
+ responses={"200": OpenApiResponse(response=OpenApiTypes.STR, description="text URL")},
+ )
+)
class SigningView(views.APIView):
def post(self, request):
- url = request.data.get('url')
+ url = request.data.get("url")
if not url:
- raise ValidationError('Please provide `url` parameter')
+ raise ValidationError("Please provide `url` parameter")
signer = Signer()
url = self.request.build_absolute_uri(url)
@@ -56,6 +62,7 @@ def post(self, request):
url = furl(url).add({Signer.QUERY_PARAM: sign}).url
return Response(url)
+
class LoginViewEx(LoginView):
"""
Check the credentials and return the REST Token
@@ -68,6 +75,7 @@ class LoginViewEx(LoginView):
Accept the following POST parameters: username, email, password
Return the REST Framework Token Object's key.
"""
+
@extend_schema(responses=get_token_serializer_class())
def post(self, request, *args, **kwargs):
self.request = request
@@ -76,9 +84,9 @@ def post(self, request, *args, **kwargs):
self.serializer.is_valid(raise_exception=True)
except ValidationError:
user = self.serializer.get_auth_user(
- self.serializer.data.get('username'),
- self.serializer.data.get('email'),
- self.serializer.data.get('password')
+ self.serializer.data.get("username"),
+ self.serializer.data.get("email"),
+ self.serializer.data.get("password"),
)
if not user:
raise
@@ -90,13 +98,14 @@ def post(self, request, *args, **kwargs):
# we cannot use redirect to ACCOUNT_EMAIL_VERIFICATION_SENT_REDIRECT_URL here
# because redirect will make a POST request and we'll get a 404 code
# (although in the browser request method will be displayed like GET)
- return HttpResponseBadRequest('Unverified email')
- except Exception: # nosec
+ return HttpResponseBadRequest("Unverified email")
+ except Exception: # nosec
pass
self.login()
return self.get_response()
+
class RegisterViewEx(RegisterView):
def get_response_data(self, user):
serializer = self.get_serializer(user)
@@ -117,20 +126,24 @@ def get_response_data(self, user):
# Link to the issue: https://github.com/iMerica/dj-rest-auth/issues/604
def perform_create(self, serializer):
user = serializer.save(self.request)
- if allauth_settings.EMAIL_VERIFICATION != \
- allauth_settings.EmailVerificationMethod.MANDATORY:
+ if (
+ allauth_settings.EMAIL_VERIFICATION
+ != allauth_settings.EmailVerificationMethod.MANDATORY
+ ):
if dj_rest_auth_settings.USE_JWT:
self.access_token, self.refresh_token = jwt_encode(user)
elif self.token_model:
dj_rest_auth_settings.TOKEN_CREATOR(self.token_model, user, serializer)
complete_signup(
- self.request._request, user,
+ self.request._request,
+ user,
allauth_settings.EMAIL_VERIFICATION,
None,
)
return user
+
def _etag(etag_func):
"""
Decorator to support conditional retrieval (or change)
@@ -138,6 +151,7 @@ def _etag(etag_func):
It calls Django's original decorator but pass correct request object to it.
Django's original decorator doesn't work with DRF request object.
"""
+
def decorator(func):
@functools.wraps(func)
def wrapper(obj_self, request, *args, **kwargs):
@@ -150,9 +164,12 @@ def patched_viewset_method(*_args, **_kwargs):
return func(obj_self, drf_request, *args, **kwargs)
return patched_viewset_method(wsgi_request, *args, **kwargs)
+
return wrapper
+
return decorator
+
class RulesView(views.APIView):
serializer_class = None
permission_classes = [AllowAny]
@@ -161,10 +178,11 @@ class RulesView(views.APIView):
@_etag(lambda request: get_opa_bundle()[1])
def get(self, request):
- return HttpResponse(get_opa_bundle()[0], content_type='application/x-tar')
+ return HttpResponse(get_opa_bundle()[0], content_type="application/x-tar")
+
class ConfirmEmailViewEx(ConfirmEmailView):
- template_name = 'account/email/email_confirmation_signup_message.html'
+ template_name = "account/email/email_confirmation_signup_message.html"
def get(self, *args, **kwargs):
try:
diff --git a/cvat/apps/lambda_manager/apps.py b/cvat/apps/lambda_manager/apps.py
index 1bbc515522ad..974e32dc74a4 100644
--- a/cvat/apps/lambda_manager/apps.py
+++ b/cvat/apps/lambda_manager/apps.py
@@ -7,8 +7,9 @@
class LambdaManagerConfig(AppConfig):
- name = 'cvat.apps.lambda_manager'
+ name = "cvat.apps.lambda_manager"
def ready(self) -> None:
from cvat.apps.iam.permissions import load_app_permissions
+
load_app_permissions(self)
diff --git a/cvat/apps/lambda_manager/models.py b/cvat/apps/lambda_manager/models.py
index 47d732c41dd1..f6e684a1cc0f 100644
--- a/cvat/apps/lambda_manager/models.py
+++ b/cvat/apps/lambda_manager/models.py
@@ -5,6 +5,7 @@
import django.db.models as models
+
class FunctionKind(models.TextChoices):
DETECTOR = "detector"
INTERACTOR = "interactor"
diff --git a/cvat/apps/lambda_manager/permissions.py b/cvat/apps/lambda_manager/permissions.py
index 94800f0edd5d..a2192cdd4914 100644
--- a/cvat/apps/lambda_manager/permissions.py
+++ b/cvat/apps/lambda_manager/permissions.py
@@ -8,27 +8,28 @@
from cvat.apps.engine.permissions import JobPermission, TaskPermission
from cvat.apps.iam.permissions import OpenPolicyAgentPermission, StrEnum
+
class LambdaPermission(OpenPolicyAgentPermission):
class Scopes(StrEnum):
- LIST = 'list'
- VIEW = 'view'
- CALL_ONLINE = 'call:online'
- CALL_OFFLINE = 'call:offline'
- LIST_OFFLINE = 'list:offline'
+ LIST = "list"
+ VIEW = "view"
+ CALL_ONLINE = "call:online"
+ CALL_OFFLINE = "call:offline"
+ LIST_OFFLINE = "list:offline"
@classmethod
def create(cls, request, view, obj, iam_context):
permissions = []
- if view.basename == 'lambda_function' or view.basename == 'lambda_request':
+ if view.basename == "lambda_function" or view.basename == "lambda_request":
scopes = cls.get_scopes(request, view, obj)
for scope in scopes:
self = cls.create_base_perm(request, view, scope, iam_context, obj)
permissions.append(self)
- if job_id := request.data.get('job'):
+ if job_id := request.data.get("job"):
perm = JobPermission.create_scope_view_data(iam_context, job_id)
permissions.append(perm)
- elif task_id := request.data.get('task'):
+ elif task_id := request.data.get("task"):
perm = TaskPermission.create_scope_view_data(iam_context, task_id)
permissions.append(perm)
@@ -36,20 +37,22 @@ def create(cls, request, view, obj, iam_context):
def __init__(self, **kwargs):
super().__init__(**kwargs)
- self.url = settings.IAM_OPA_DATA_URL + '/lambda/allow'
+ self.url = settings.IAM_OPA_DATA_URL + "/lambda/allow"
@staticmethod
def get_scopes(request, view, obj):
Scopes = __class__.Scopes
- return [{
- ('lambda_function', 'list'): Scopes.LIST,
- ('lambda_function', 'retrieve'): Scopes.VIEW,
- ('lambda_function', 'call'): Scopes.CALL_ONLINE,
- ('lambda_request', 'create'): Scopes.CALL_OFFLINE,
- ('lambda_request', 'list'): Scopes.LIST_OFFLINE,
- ('lambda_request', 'retrieve'): Scopes.CALL_OFFLINE,
- ('lambda_request', 'destroy'): Scopes.CALL_OFFLINE,
- }[(view.basename, view.action)]]
+ return [
+ {
+ ("lambda_function", "list"): Scopes.LIST,
+ ("lambda_function", "retrieve"): Scopes.VIEW,
+ ("lambda_function", "call"): Scopes.CALL_ONLINE,
+ ("lambda_request", "create"): Scopes.CALL_OFFLINE,
+ ("lambda_request", "list"): Scopes.LIST_OFFLINE,
+ ("lambda_request", "retrieve"): Scopes.CALL_OFFLINE,
+ ("lambda_request", "destroy"): Scopes.CALL_OFFLINE,
+ }[(view.basename, view.action)]
+ ]
def get_resource(self):
return None
diff --git a/cvat/apps/lambda_manager/rules/tests/generators/lambda_test.gen.rego.py b/cvat/apps/lambda_manager/rules/tests/generators/lambda_test.gen.rego.py
index 94f694988a38..f506fda56a07 100644
--- a/cvat/apps/lambda_manager/rules/tests/generators/lambda_test.gen.rego.py
+++ b/cvat/apps/lambda_manager/rules/tests/generators/lambda_test.gen.rego.py
@@ -77,13 +77,15 @@ def get_data(scope, context, ownership, privilege, membership, resource):
"scope": scope,
"auth": {
"user": {"id": random.randrange(0, 100), "privilege": privilege},
- "organization": {
- "id": random.randrange(100, 200),
- "owner": {"id": random.randrange(200, 300)},
- "user": {"role": membership},
- }
- if context == "organization"
- else None,
+ "organization": (
+ {
+ "id": random.randrange(100, 200),
+ "owner": {"id": random.randrange(200, 300)},
+ "user": {"role": membership},
+ }
+ if context == "organization"
+ else None
+ ),
},
"resource": resource,
}
diff --git a/cvat/apps/lambda_manager/serializers.py b/cvat/apps/lambda_manager/serializers.py
index ab8809bd7cc8..8daf3a53642b 100644
--- a/cvat/apps/lambda_manager/serializers.py
+++ b/cvat/apps/lambda_manager/serializers.py
@@ -5,20 +5,25 @@
from drf_spectacular.utils import extend_schema_serializer
from rest_framework import serializers
+
class SublabelMappingEntrySerializer(serializers.Serializer):
name = serializers.CharField()
attributes = serializers.DictField(child=serializers.CharField(), required=False)
+
class LabelMappingEntrySerializer(serializers.Serializer):
name = serializers.CharField()
attributes = serializers.DictField(child=serializers.CharField(), required=False)
- sublabels = serializers.DictField(child=SublabelMappingEntrySerializer(), required=False,
- help_text="Label mapping for from the model to the task sublabels within a parent label"
+ sublabels = serializers.DictField(
+ child=SublabelMappingEntrySerializer(),
+ required=False,
+ help_text="Label mapping for from the model to the task sublabels within a parent label",
)
+
@extend_schema_serializer(
# The "Request" suffix is added by drf-spectacular automatically
- component_name='FunctionCall'
+ component_name="FunctionCall"
)
class FunctionCallRequestSerializer(serializers.Serializer):
function = serializers.CharField(help_text="The name of the function to execute")
@@ -26,13 +31,25 @@ class FunctionCallRequestSerializer(serializers.Serializer):
job = serializers.IntegerField(required=False, help_text="The id of the job to be annotated")
max_distance = serializers.IntegerField(required=False)
threshold = serializers.FloatField(required=False)
- cleanup = serializers.BooleanField(help_text="Whether existing annotations should be removed", default=False)
- convMaskToPoly = serializers.BooleanField(required=False, source="conv_mask_to_poly", write_only=True, help_text="Deprecated; use conv_mask_to_poly instead")
- conv_mask_to_poly = serializers.BooleanField(required=False, help_text="Convert mask shapes to polygons")
- mapping = serializers.DictField(child=LabelMappingEntrySerializer(), required=False,
- help_text="Label mapping from the model to the task labels"
+ cleanup = serializers.BooleanField(
+ help_text="Whether existing annotations should be removed", default=False
+ )
+ convMaskToPoly = serializers.BooleanField(
+ required=False,
+ source="conv_mask_to_poly",
+ write_only=True,
+ help_text="Deprecated; use conv_mask_to_poly instead",
+ )
+ conv_mask_to_poly = serializers.BooleanField(
+ required=False, help_text="Convert mask shapes to polygons"
+ )
+ mapping = serializers.DictField(
+ child=LabelMappingEntrySerializer(),
+ required=False,
+ help_text="Label mapping from the model to the task labels",
)
+
class FunctionCallParamsSerializer(serializers.Serializer):
id = serializers.CharField(allow_null=True, help_text="The name of the function")
@@ -41,6 +58,7 @@ class FunctionCallParamsSerializer(serializers.Serializer):
threshold = serializers.FloatField(allow_null=True)
+
class FunctionCallSerializer(serializers.Serializer):
id = serializers.CharField(help_text="Request id")
diff --git a/cvat/apps/lambda_manager/tests/test_lambda.py b/cvat/apps/lambda_manager/tests/test_lambda.py
index f9292b278b45..38e812b25ff0 100644
--- a/cvat/apps/lambda_manager/tests/test_lambda.py
+++ b/cvat/apps/lambda_manager/tests/test_lambda.py
@@ -3,12 +3,12 @@
#
# SPDX-License-Identifier: MIT
+import json
+import os
from collections import Counter, OrderedDict
from itertools import groupby
from typing import Optional
from unittest import mock, skip
-import json
-import os
import requests
from django.contrib.auth.models import Group, User
@@ -16,16 +16,22 @@
from rest_framework import status
from cvat.apps.engine.tests.utils import (
- ApiTestBase, filter_dict, ForceLogin, generate_image_file, get_paginated_collection
+ ApiTestBase,
+ ForceLogin,
+ filter_dict,
+ generate_image_file,
+ get_paginated_collection,
)
-LAMBDA_ROOT_PATH = '/api/lambda'
-LAMBDA_FUNCTIONS_PATH = f'{LAMBDA_ROOT_PATH}/functions'
-LAMBDA_REQUESTS_PATH = f'{LAMBDA_ROOT_PATH}/requests'
+LAMBDA_ROOT_PATH = "/api/lambda"
+LAMBDA_FUNCTIONS_PATH = f"{LAMBDA_ROOT_PATH}/functions"
+LAMBDA_REQUESTS_PATH = f"{LAMBDA_ROOT_PATH}/requests"
id_function_detector = "test-openvino-omz-public-yolo-v3-tf"
id_function_reid_with_response_data = "test-openvino-omz-intel-person-reidentification-retail-0300"
-id_function_reid_with_no_response_data = "test-openvino-omz-intel-person-reidentification-retail-1234"
+id_function_reid_with_no_response_data = (
+ "test-openvino-omz-intel-person-reidentification-retail-1234"
+)
id_function_interactor = "test-openvino-dextr"
id_function_tracker = "test-pth-foolwood-siammask"
id_function_non_type = "test-model-has-non-type"
@@ -36,29 +42,47 @@
id_function_state_error = "test-model-has-state-error"
expected_keys_in_response_all_functions = ["id", "kind", "labels_v2", "description", "name"]
-expected_keys_in_response_function_interactor = expected_keys_in_response_all_functions + ["min_pos_points", "startswith_box"]
-expected_keys_in_response_requests = ["id", "function", "status", "progress", "enqueued", "started", "ended", "exc_info"]
-
-path = os.path.join(os.path.dirname(__file__), 'assets', 'tasks.json')
+expected_keys_in_response_function_interactor = expected_keys_in_response_all_functions + [
+ "min_pos_points",
+ "startswith_box",
+]
+expected_keys_in_response_requests = [
+ "id",
+ "function",
+ "status",
+ "progress",
+ "enqueued",
+ "started",
+ "ended",
+ "exc_info",
+]
+
+path = os.path.join(os.path.dirname(__file__), "assets", "tasks.json")
with open(path) as f:
tasks = json.load(f)
# removed unnecessary data
-path = os.path.join(os.path.dirname(__file__), 'assets', 'functions.json')
+path = os.path.join(os.path.dirname(__file__), "assets", "functions.json")
with open(path) as f:
functions = json.load(f)
+
class _LambdaTestCaseBase(ApiTestBase):
def setUp(self):
super().setUp()
self.client = self.client_class(raise_request_exception=False)
- http_patcher = mock.patch('cvat.apps.lambda_manager.views.LambdaGateway._http', side_effect = self._get_data_from_lambda_manager_http)
+ http_patcher = mock.patch(
+ "cvat.apps.lambda_manager.views.LambdaGateway._http",
+ side_effect=self._get_data_from_lambda_manager_http,
+ )
self.addCleanup(http_patcher.stop)
http_patcher.start()
- invoke_patcher = mock.patch('cvat.apps.lambda_manager.views.LambdaGateway.invoke', side_effect = self._invoke_function)
+ invoke_patcher = mock.patch(
+ "cvat.apps.lambda_manager.views.LambdaGateway.invoke", side_effect=self._invoke_function
+ )
self.addCleanup(invoke_patcher.stop)
invoke_patcher.start()
@@ -72,13 +96,13 @@ def _get_data_from_lambda_manager_http(self, **kwargs):
if func_id in [id_function_state_building, id_function_state_error]:
r = requests.RequestException()
r.response = HttpResponseServerError()
- raise r # raise 500 Internal_Server error
+ raise r # raise 500 Internal_Server error
return functions["positive"][func_id]
else:
r = requests.HTTPError()
r.response = HttpResponseNotFound()
- raise r # raise 404 Not Found error
+ raise r # raise 404 Not Found error
def _invoke_function(self, func, payload):
data = []
@@ -135,27 +159,32 @@ def _create_db_users(cls):
(group_admin, _) = Group.objects.get_or_create(name="admin")
(group_user, _) = Group.objects.get_or_create(name="user")
- user_admin = User.objects.create_superuser(username="admin", email="",
- password="admin")
+ user_admin = User.objects.create_superuser(username="admin", email="", password="admin")
user_admin.groups.add(group_admin)
- user_dummy = User.objects.create_user(username="user", password="user",
- email="user@example.com")
+ user_dummy = User.objects.create_user(
+ username="user", password="user", email="user@example.com"
+ )
user_dummy.groups.add(group_user)
cls.admin = user_admin
cls.user = user_dummy
-
def _create_task(self, task_spec, data, *, owner=None, org_id=None):
with ForceLogin(owner or self.admin, self.client):
- response = self.client.post('/api/tasks', data=task_spec, format="json",
- QUERY_STRING=f'org_id={org_id}' if org_id is not None else None)
+ response = self.client.post(
+ "/api/tasks",
+ data=task_spec,
+ format="json",
+ QUERY_STRING=f"org_id={org_id}" if org_id is not None else None,
+ )
assert response.status_code == status.HTTP_201_CREATED, response.status_code
tid = response.data["id"]
- response = self.client.post("/api/tasks/%s/data" % tid,
+ response = self.client.post(
+ "/api/tasks/%s/data" % tid,
data=data,
- QUERY_STRING=f'org_id={org_id}' if org_id is not None else None)
+ QUERY_STRING=f"org_id={org_id}" if org_id is not None else None,
+ )
assert response.status_code == status.HTTP_202_ACCEPTED, response.status_code
rq_id = response.json()["rq_id"]
@@ -163,65 +192,72 @@ def _create_task(self, task_spec, data, *, owner=None, org_id=None):
assert response.status_code == status.HTTP_200_OK, response.status_code
assert response.json()["status"] == "finished", response.json().get("status")
- response = self.client.get("/api/tasks/%s" % tid,
- QUERY_STRING=f'org_id={org_id}' if org_id is not None else None)
+ response = self.client.get(
+ "/api/tasks/%s" % tid,
+ QUERY_STRING=f"org_id={org_id}" if org_id is not None else None,
+ )
task = response.data
return task
-
- def _generate_task_images(self, count): # pylint: disable=no-self-use
+ def _generate_task_images(self, count): # pylint: disable=no-self-use
images = {
- "client_files[%d]" % i: generate_image_file("image_%d.jpg" % i)
- for i in range(count)
+ "client_files[%d]" % i: generate_image_file("image_%d.jpg" % i) for i in range(count)
}
images["image_quality"] = 75
return images
-
@classmethod
def setUpTestData(cls):
cls._create_db_users()
-
def _get_request(self, path, user, *, org_id=None):
with ForceLogin(user, self.client):
- response = self.client.get(path,
- QUERY_STRING=f'org_id={org_id}' if org_id is not None else '')
+ response = self.client.get(
+ path, QUERY_STRING=f"org_id={org_id}" if org_id is not None else ""
+ )
return response
-
def _delete_request(self, path, user, *, org_id=None):
with ForceLogin(user, self.client):
- response = self.client.delete(path,
- QUERY_STRING=f'org_id={org_id}' if org_id is not None else '')
+ response = self.client.delete(
+ path, QUERY_STRING=f"org_id={org_id}" if org_id is not None else ""
+ )
return response
-
def _post_request(self, path, user, data, *, org_id=None):
data = json.dumps(data)
with ForceLogin(user, self.client):
- response = self.client.post(path, data=data, content_type='application/json',
- QUERY_STRING=f'org_id={org_id}' if org_id is not None else '')
+ response = self.client.post(
+ path,
+ data=data,
+ content_type="application/json",
+ QUERY_STRING=f"org_id={org_id}" if org_id is not None else "",
+ )
return response
-
def _patch_request(self, path, user, data, *, org_id=None):
data = json.dumps(data)
with ForceLogin(user, self.client):
- response = self.client.patch(path, data=data, content_type='application/json',
- QUERY_STRING=f'org_id={org_id}' if org_id is not None else '')
+ response = self.client.patch(
+ path,
+ data=data,
+ content_type="application/json",
+ QUERY_STRING=f"org_id={org_id}" if org_id is not None else "",
+ )
return response
-
def _put_request(self, path, user, data, *, org_id=None):
data = json.dumps(data)
with ForceLogin(user, self.client):
- response = self.client.put(path, data=data, content_type='application/json',
- QUERY_STRING=f'org_id={org_id}' if org_id is not None else '')
+ response = self.client.put(
+ path,
+ data=data,
+ content_type="application/json",
+ QUERY_STRING=f"org_id={org_id}" if org_id is not None else "",
+ )
return response
-
def _check_expected_keys_in_response_function(self, data):
kind = data["kind"]
if kind == "interactor":
@@ -232,7 +268,7 @@ def _check_expected_keys_in_response_function(self, data):
self.assertIn(key, data)
def _delete_lambda_request(self, request_id: str, user: Optional[User] = None) -> None:
- response = self._delete_request(f'{LAMBDA_REQUESTS_PATH}/{request_id}', user or self.admin)
+ response = self._delete_request(f"{LAMBDA_REQUESTS_PATH}/{request_id}", user or self.admin)
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
@@ -261,8 +297,7 @@ def test_api_v2_lambda_functions_list(self):
response = self._get_request(LAMBDA_FUNCTIONS_PATH, None)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
-
- @mock.patch('cvat.apps.lambda_manager.views.LambdaGateway._http', return_value = {})
+ @mock.patch("cvat.apps.lambda_manager.views.LambdaGateway._http", return_value={})
def test_api_v2_lambda_functions_list_empty(self, mock_http):
response = self._get_request(LAMBDA_FUNCTIONS_PATH, self.admin)
self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -275,10 +310,12 @@ def test_api_v2_lambda_functions_list_empty(self, mock_http):
response = self._get_request(LAMBDA_FUNCTIONS_PATH, None)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
-
@mock.patch(
- 'cvat.apps.lambda_manager.views.LambdaGateway._http',
- return_value={**functions["negative"], id_function_detector: functions["positive"][id_function_detector]}
+ "cvat.apps.lambda_manager.views.LambdaGateway._http",
+ return_value={
+ **functions["negative"],
+ id_function_detector: functions["positive"][id_function_detector],
+ },
)
def test_api_v2_lambda_functions_list_negative(self, mock_http):
response = self._get_request(LAMBDA_FUNCTIONS_PATH, self.admin)
@@ -289,11 +326,15 @@ def test_api_v2_lambda_functions_list_negative(self, mock_http):
self.assertEqual(visible_ids, {id_function_detector})
def test_api_v2_lambda_functions_read(self):
- ids_functions = [id_function_detector, id_function_interactor,\
- id_function_tracker, id_function_reid_with_response_data]
+ ids_functions = [
+ id_function_detector,
+ id_function_interactor,
+ id_function_tracker,
+ id_function_reid_with_response_data,
+ ]
for id_func in ids_functions:
- path = f'{LAMBDA_FUNCTIONS_PATH}/{id_func}'
+ path = f"{LAMBDA_FUNCTIONS_PATH}/{id_func}"
response = self._get_request(path, self.admin)
self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -306,32 +347,31 @@ def test_api_v2_lambda_functions_read(self):
response = self._get_request(path, None)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
-
def test_api_v2_lambda_functions_read_wrong_id(self):
id_wrong_function = "test-functions-wrong-id"
- response = self._get_request(f'{LAMBDA_FUNCTIONS_PATH}/{id_wrong_function}', self.admin)
+ response = self._get_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_wrong_function}", self.admin)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
- response = self._get_request(f'{LAMBDA_FUNCTIONS_PATH}/{id_wrong_function}', self.user)
+ response = self._get_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_wrong_function}", self.user)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
- response = self._get_request(f'{LAMBDA_FUNCTIONS_PATH}/{id_wrong_function}', None)
+ response = self._get_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_wrong_function}", None)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
-
def test_api_v2_lambda_functions_read_negative(self):
for id_func in [
- id_function_non_type, id_function_wrong_type, id_function_unknown_type,
+ id_function_non_type,
+ id_function_wrong_type,
+ id_function_unknown_type,
id_function_non_unique_labels,
]:
with mock.patch(
- 'cvat.apps.lambda_manager.views.LambdaGateway._http',
- return_value=functions["negative"][id_func]
+ "cvat.apps.lambda_manager.views.LambdaGateway._http",
+ return_value=functions["negative"][id_func],
):
- response = self._get_request(f'{LAMBDA_FUNCTIONS_PATH}/{id_func}', self.admin)
+ response = self._get_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_func}", self.admin)
self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR)
-
@skip("Fail: add mock")
def test_api_v2_lambda_requests_list(self):
response = self._get_request(LAMBDA_REQUESTS_PATH, self.admin)
@@ -347,7 +387,6 @@ def test_api_v2_lambda_requests_list(self):
response = self._get_request(LAMBDA_REQUESTS_PATH, None)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
-
def test_api_v2_lambda_requests_list_empty(self):
response = self._get_request(LAMBDA_REQUESTS_PATH, self.admin)
self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -360,7 +399,6 @@ def test_api_v2_lambda_requests_list_empty(self):
response = self._get_request(LAMBDA_REQUESTS_PATH, None)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
-
def test_api_v2_lambda_requests_read(self):
# create request
data_main_task = {
@@ -369,76 +407,78 @@ def test_api_v2_lambda_requests_read(self):
"cleanup": True,
"threshold": 55,
"mapping": {
- "car": { "name": "car" },
+ "car": {"name": "car"},
},
}
response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data_main_task)
self.assertEqual(response.status_code, status.HTTP_200_OK)
id_request = response.data["id"]
- response = self._get_request(f'{LAMBDA_REQUESTS_PATH}/{id_request}', self.admin)
+ response = self._get_request(f"{LAMBDA_REQUESTS_PATH}/{id_request}", self.admin)
self.assertEqual(response.status_code, status.HTTP_200_OK)
for key in expected_keys_in_response_requests:
self.assertIn(key, response.data)
- response = self._get_request(f'{LAMBDA_REQUESTS_PATH}/{id_request}', self.user)
+ response = self._get_request(f"{LAMBDA_REQUESTS_PATH}/{id_request}", self.user)
self.assertEqual(response.status_code, status.HTTP_200_OK)
for key in expected_keys_in_response_requests:
self.assertIn(key, response.data)
- response = self._get_request(f'{LAMBDA_REQUESTS_PATH}/{id_request}', None)
+ response = self._get_request(f"{LAMBDA_REQUESTS_PATH}/{id_request}", None)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
-
def test_api_v2_lambda_requests_read_wrong_id(self):
id_request = "cf343b95-afeb-475e-ab53-8d7e64991d30-wrong-id"
- response = self._get_request(f'{LAMBDA_REQUESTS_PATH}/{id_request}', self.admin)
+ response = self._get_request(f"{LAMBDA_REQUESTS_PATH}/{id_request}", self.admin)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
- response = self._get_request(f'{LAMBDA_REQUESTS_PATH}/{id_request}', self.user)
+ response = self._get_request(f"{LAMBDA_REQUESTS_PATH}/{id_request}", self.user)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
- response = self._get_request(f'{LAMBDA_REQUESTS_PATH}/{id_request}', None)
+ response = self._get_request(f"{LAMBDA_REQUESTS_PATH}/{id_request}", None)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
-
def test_api_v2_lambda_requests_delete_finished_request(self):
data = {
"function": id_function_detector,
"task": self.main_task["id"],
"cleanup": True,
"mapping": {
- "car": { "name": "car" },
+ "car": {"name": "car"},
},
}
response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data)
id_request = response.data["id"]
- response = self._delete_request(f'{LAMBDA_REQUESTS_PATH}/{id_request}', None)
+ response = self._delete_request(f"{LAMBDA_REQUESTS_PATH}/{id_request}", None)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
- response = self._delete_request(f'{LAMBDA_REQUESTS_PATH}/{id_request}', self.admin)
+ response = self._delete_request(f"{LAMBDA_REQUESTS_PATH}/{id_request}", self.admin)
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
- response = self._get_request(f'{LAMBDA_REQUESTS_PATH}/{id_request}', self.admin)
+ response = self._get_request(f"{LAMBDA_REQUESTS_PATH}/{id_request}", self.admin)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data)
id_request = response.data["id"]
- response = self._delete_request(f'{LAMBDA_REQUESTS_PATH}/{id_request}', self.user)
+ response = self._delete_request(f"{LAMBDA_REQUESTS_PATH}/{id_request}", self.user)
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
- response = self._get_request(f'{LAMBDA_REQUESTS_PATH}/{id_request}', self.user)
+ response = self._get_request(f"{LAMBDA_REQUESTS_PATH}/{id_request}", self.user)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
-
@skip("Fail: add mock")
def test_api_v2_lambda_requests_delete_not_finished_request(self):
pass
-
def test_api_v2_lambda_requests_create(self):
- ids_functions = [id_function_detector, id_function_interactor, id_function_tracker, \
- id_function_reid_with_response_data, id_function_detector, id_function_reid_with_no_response_data]
+ ids_functions = [
+ id_function_detector,
+ id_function_interactor,
+ id_function_tracker,
+ id_function_reid_with_response_data,
+ id_function_detector,
+ id_function_reid_with_no_response_data,
+ ]
for id_func in ids_functions:
data_main_task = {
@@ -447,7 +487,7 @@ def test_api_v2_lambda_requests_create(self):
"cleanup": True,
"threshold": 55,
"mapping": {
- "car": { "name": "car" },
+ "car": {"name": "car"},
},
}
data_assigneed_to_user_task = {
@@ -456,7 +496,7 @@ def test_api_v2_lambda_requests_create(self):
"cleanup": False,
"max_distance": 70,
"mapping": {
- "car": { "name": "car" },
+ "car": {"name": "car"},
},
}
@@ -467,7 +507,9 @@ def test_api_v2_lambda_requests_create(self):
self._delete_lambda_request(response.data["id"])
- response = self._post_request(LAMBDA_REQUESTS_PATH, self.user, data_assigneed_to_user_task)
+ response = self._post_request(
+ LAMBDA_REQUESTS_PATH, self.user, data_assigneed_to_user_task
+ )
self.assertEqual(response.status_code, status.HTTP_200_OK)
for key in expected_keys_in_response_requests:
self.assertIn(key, response.data)
@@ -480,10 +522,11 @@ def test_api_v2_lambda_requests_create(self):
response = self._post_request(LAMBDA_REQUESTS_PATH, None, data_main_task)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
-
def test_api_v2_lambda_requests_create_negative(self):
for id_func in [
- id_function_non_type, id_function_wrong_type, id_function_unknown_type,
+ id_function_non_type,
+ id_function_wrong_type,
+ id_function_unknown_type,
id_function_non_unique_labels,
]:
data = {
@@ -491,49 +534,45 @@ def test_api_v2_lambda_requests_create_negative(self):
"task": self.main_task["id"],
"cleanup": True,
"mapping": {
- "car": { "name": "car" },
+ "car": {"name": "car"},
},
}
with mock.patch(
- 'cvat.apps.lambda_manager.views.LambdaGateway._http',
+ "cvat.apps.lambda_manager.views.LambdaGateway._http",
return_value=functions["negative"][id_func],
):
response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data)
self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR)
-
def test_api_v2_lambda_requests_create_empty_data(self):
data = {}
response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
-
def test_api_v2_lambda_requests_create_without_function(self):
data = {
"task": self.main_task["id"],
"cleanup": True,
"mapping": {
- "car": { "name": "car" },
+ "car": {"name": "car"},
},
}
response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
-
def test_api_v2_lambda_requests_create_wrong_id_function(self):
data = {
"function": "test-requests-wrong-id",
"task": self.main_task["id"],
"cleanup": True,
"mapping": {
- "car": { "name": "car" },
+ "car": {"name": "car"},
},
}
response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
-
@skip("Fail: add mock")
def test_api_v2_lambda_requests_create_two_requests(self):
data = {
@@ -541,10 +580,10 @@ def test_api_v2_lambda_requests_create_two_requests(self):
"task": self.main_task["id"],
"cleanup": True,
"mapping": {
- "car": { "name": "car" },
+ "car": {"name": "car"},
},
}
- request_id = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data).data['id']
+ request_id = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data).data["id"]
response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data)
self.assertEqual(response.status_code, status.HTTP_409_CONFLICT)
@@ -569,7 +608,7 @@ def test_api_v2_lambda_requests_create_without_cleanup(self):
"function": id_function_detector,
"task": self.main_task["id"],
"mapping": {
- "car": { "name": "car" },
+ "car": {"name": "car"},
},
}
response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data)
@@ -597,26 +636,24 @@ def test_api_v2_lambda_requests_create_without_task(self):
"function": id_function_detector,
"cleanup": True,
"mapping": {
- "car": { "name": "car" },
+ "car": {"name": "car"},
},
}
response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
-
def test_api_v2_lambda_requests_create_wrong_id_task(self):
data = {
"function": id_function_detector,
"task": 12345,
"cleanup": True,
"mapping": {
- "car": { "name": "car" },
+ "car": {"name": "car"},
},
}
response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
-
def test_api_v2_lambda_requests_create_is_not_ready(self):
ids_functions = [id_function_state_building, id_function_state_error]
@@ -626,14 +663,13 @@ def test_api_v2_lambda_requests_create_is_not_ready(self):
"task": self.main_task["id"],
"cleanup": True,
"mapping": {
- "car": { "name": "car" },
+ "car": {"name": "car"},
},
}
response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data)
self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR)
-
def test_api_v2_lambda_functions_create_detector(self):
data_main_task = {
"task": self.main_task["id"],
@@ -641,7 +677,7 @@ def test_api_v2_lambda_functions_create_detector(self):
"cleanup": True,
"threshold": 0.55,
"mapping": {
- "car": { "name": "car" },
+ "car": {"name": "car"},
},
}
data_assigneed_to_user_task = {
@@ -649,122 +685,199 @@ def test_api_v2_lambda_functions_create_detector(self):
"frame": 0,
"cleanup": True,
"mapping": {
- "car": { "name": "car" },
+ "car": {"name": "car"},
},
}
- response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data_main_task)
+ response = self._post_request(
+ f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data_main_task
+ )
self.assertEqual(response.status_code, status.HTTP_200_OK)
- response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.user, data_assigneed_to_user_task)
+ response = self._post_request(
+ f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}",
+ self.user,
+ data_assigneed_to_user_task,
+ )
self.assertEqual(response.status_code, status.HTTP_200_OK)
- response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", None, data_main_task)
+ response = self._post_request(
+ f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", None, data_main_task
+ )
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
- @skip("Fail: expected result != actual result") # TODO move test to test_api_v2_lambda_functions_create
+ @skip(
+ "Fail: expected result != actual result"
+ ) # TODO move test to test_api_v2_lambda_functions_create
def test_api_v2_lambda_functions_create_user_assigned_to_no_user(self):
data = {
"task": self.main_task["id"],
"frame": 0,
"cleanup": True,
"mapping": {
- "car": { "name": "car" },
+ "car": {"name": "car"},
},
}
- response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.user, data)
+ response = self._post_request(
+ f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.user, data
+ )
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
-
def test_api_v2_lambda_functions_create_interactor(self):
data_main_task = {
"task": self.main_task["id"],
"frame": 0,
"pos_points": [
- [3.45, 6.78],
- [12.1, 12.1],
+ [3.45, 6.78],
+ [12.1, 12.1],
[34.1, 41.0],
- [43.01, 43.99],
+ [43.01, 43.99],
],
"neg_points": [
- [3.25, 6.58],
- [11.1, 11.0],
- [35.5, 44.44],
- [45.01, 45.99],
- ],
+ [3.25, 6.58],
+ [11.1, 11.0],
+ [35.5, 44.44],
+ [45.01, 45.99],
+ ],
}
data_assigneed_to_user_task = {
"task": self.assigneed_to_user_task["id"],
"frame": 0,
"threshold": 0.1,
"pos_points": [
- [3.45, 6.78],
- [12.1, 12.1],
+ [3.45, 6.78],
+ [12.1, 12.1],
[34.1, 41.0],
- [43.01, 43.99],
+ [43.01, 43.99],
],
"neg_points": [
- [3.25, 6.58],
- [11.1, 11.0],
- [35.5, 44.44],
- [45.01, 45.99],
- ],
+ [3.25, 6.58],
+ [11.1, 11.0],
+ [35.5, 44.44],
+ [45.01, 45.99],
+ ],
}
- response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_interactor}", self.admin, data_main_task)
+ response = self._post_request(
+ f"{LAMBDA_FUNCTIONS_PATH}/{id_function_interactor}", self.admin, data_main_task
+ )
self.assertEqual(response.status_code, status.HTTP_200_OK)
- response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_interactor}", self.user, data_assigneed_to_user_task)
+ response = self._post_request(
+ f"{LAMBDA_FUNCTIONS_PATH}/{id_function_interactor}",
+ self.user,
+ data_assigneed_to_user_task,
+ )
self.assertEqual(response.status_code, status.HTTP_200_OK)
- response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_interactor}", None, data_main_task)
+ response = self._post_request(
+ f"{LAMBDA_FUNCTIONS_PATH}/{id_function_interactor}", None, data_main_task
+ )
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
-
def test_api_v2_lambda_functions_create_tracker(self):
data_main_task = {
"task": self.main_task["id"],
"frame": 0,
"shape": [
- 12.12,
- 34.45,
- 54.0,
- 76.12,
- ],
+ 12.12,
+ 34.45,
+ 54.0,
+ 76.12,
+ ],
}
data_assigneed_to_user_task = {
"task": self.assigneed_to_user_task["id"],
"frame": 0,
"shape": [
- 12.12,
- 34.45,
- 54.0,
- 76.12,
- ],
+ 12.12,
+ 34.45,
+ 54.0,
+ 76.12,
+ ],
}
- response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_tracker}", self.admin, data_main_task)
+ response = self._post_request(
+ f"{LAMBDA_FUNCTIONS_PATH}/{id_function_tracker}", self.admin, data_main_task
+ )
self.assertEqual(response.status_code, status.HTTP_200_OK)
- response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_tracker}", self.user, data_assigneed_to_user_task)
+ response = self._post_request(
+ f"{LAMBDA_FUNCTIONS_PATH}/{id_function_tracker}", self.user, data_assigneed_to_user_task
+ )
self.assertEqual(response.status_code, status.HTTP_200_OK)
- response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_tracker}", None, data_main_task)
+ response = self._post_request(
+ f"{LAMBDA_FUNCTIONS_PATH}/{id_function_tracker}", None, data_main_task
+ )
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
-
def test_api_v2_lambda_functions_create_reid(self):
data_main_task = {
"task": self.main_task["id"],
"frame0": 0,
"frame1": 1,
"boxes0": [
- OrderedDict([('attributes', []), ('frame', 0), ('group', None), ('id', 11258), ('label_id', 8), ('occluded', False), ('path_id', 0), ('points', [137.0, 129.0, 457.0, 676.0]), ('source', 'auto'), ('type', 'rectangle'), ('z_order', 0)]),
- OrderedDict([('attributes', []), ('frame', 0), ('group', None), ('id', 11259), ('label_id', 8), ('occluded', False), ('path_id', 1), ('points', [1511.0, 224.0, 1537.0, 437.0]), ('source', 'auto'), ('type', 'rectangle'), ('z_order', 0)]),
+ OrderedDict(
+ [
+ ("attributes", []),
+ ("frame", 0),
+ ("group", None),
+ ("id", 11258),
+ ("label_id", 8),
+ ("occluded", False),
+ ("path_id", 0),
+ ("points", [137.0, 129.0, 457.0, 676.0]),
+ ("source", "auto"),
+ ("type", "rectangle"),
+ ("z_order", 0),
+ ]
+ ),
+ OrderedDict(
+ [
+ ("attributes", []),
+ ("frame", 0),
+ ("group", None),
+ ("id", 11259),
+ ("label_id", 8),
+ ("occluded", False),
+ ("path_id", 1),
+ ("points", [1511.0, 224.0, 1537.0, 437.0]),
+ ("source", "auto"),
+ ("type", "rectangle"),
+ ("z_order", 0),
+ ]
+ ),
],
"boxes1": [
- OrderedDict([('attributes', []), ('frame', 1), ('group', None), ('id', 11260), ('label_id', 8), ('occluded', False), ('points', [1076.0, 199.0, 1218.0, 593.0]), ('source', 'auto'), ('type', 'rectangle'), ('z_order', 0)]),
- OrderedDict([('attributes', []), ('frame', 1), ('group', None), ('id', 11261), ('label_id', 8), ('occluded', False), ('points', [924.0, 177.0, 1090.0, 615.0]), ('source', 'auto'), ('type', 'rectangle'), ('z_order', 0)]),
+ OrderedDict(
+ [
+ ("attributes", []),
+ ("frame", 1),
+ ("group", None),
+ ("id", 11260),
+ ("label_id", 8),
+ ("occluded", False),
+ ("points", [1076.0, 199.0, 1218.0, 593.0]),
+ ("source", "auto"),
+ ("type", "rectangle"),
+ ("z_order", 0),
+ ]
+ ),
+ OrderedDict(
+ [
+ ("attributes", []),
+ ("frame", 1),
+ ("group", None),
+ ("id", 11261),
+ ("label_id", 8),
+ ("occluded", False),
+ ("points", [924.0, 177.0, 1090.0, 615.0]),
+ ("source", "auto"),
+ ("type", "rectangle"),
+ ("z_order", 0),
+ ]
+ ),
],
"threshold": 0.5,
"max_distance": 55,
@@ -774,63 +887,154 @@ def test_api_v2_lambda_functions_create_reid(self):
"frame0": 0,
"frame1": 1,
"boxes0": [
- OrderedDict([('attributes', []), ('frame', 0), ('group', None), ('id', 11258), ('label_id', 8), ('occluded', False), ('path_id', 0), ('points', [137.0, 129.0, 457.0, 676.0]), ('source', 'auto'), ('type', 'rectangle'), ('z_order', 0)]),
- OrderedDict([('attributes', []), ('frame', 0), ('group', None), ('id', 11259), ('label_id', 8), ('occluded', False), ('path_id', 1), ('points', [1511.0, 224.0, 1537.0, 437.0]), ('source', 'auto'), ('type', 'rectangle'), ('z_order', 0)]),
+ OrderedDict(
+ [
+ ("attributes", []),
+ ("frame", 0),
+ ("group", None),
+ ("id", 11258),
+ ("label_id", 8),
+ ("occluded", False),
+ ("path_id", 0),
+ ("points", [137.0, 129.0, 457.0, 676.0]),
+ ("source", "auto"),
+ ("type", "rectangle"),
+ ("z_order", 0),
+ ]
+ ),
+ OrderedDict(
+ [
+ ("attributes", []),
+ ("frame", 0),
+ ("group", None),
+ ("id", 11259),
+ ("label_id", 8),
+ ("occluded", False),
+ ("path_id", 1),
+ ("points", [1511.0, 224.0, 1537.0, 437.0]),
+ ("source", "auto"),
+ ("type", "rectangle"),
+ ("z_order", 0),
+ ]
+ ),
],
"boxes1": [
- OrderedDict([('attributes', []), ('frame', 1), ('group', None), ('id', 11260), ('label_id', 8), ('occluded', False), ('points', [1076.0, 199.0, 1218.0, 593.0]), ('source', 'auto'), ('type', 'rectangle'), ('z_order', 0)]),
- OrderedDict([('attributes', []), ('frame', 1), ('group', 0), ('id', 11398), ('label_id', 8), ('occluded', False), ('points', [184.3935546875, 211.5048828125, 331.64968722073354, 97.27792672028772, 445.87667560321825, 126.17873100983161, 454.13404825737416, 691.8087578194827, 180.26452189455085]), ('source', 'manual'), ('type', 'polygon'), ('z_order', 0)]),
+ OrderedDict(
+ [
+ ("attributes", []),
+ ("frame", 1),
+ ("group", None),
+ ("id", 11260),
+ ("label_id", 8),
+ ("occluded", False),
+ ("points", [1076.0, 199.0, 1218.0, 593.0]),
+ ("source", "auto"),
+ ("type", "rectangle"),
+ ("z_order", 0),
+ ]
+ ),
+ OrderedDict(
+ [
+ ("attributes", []),
+ ("frame", 1),
+ ("group", 0),
+ ("id", 11398),
+ ("label_id", 8),
+ ("occluded", False),
+ (
+ "points",
+ [
+ 184.3935546875,
+ 211.5048828125,
+ 331.64968722073354,
+ 97.27792672028772,
+ 445.87667560321825,
+ 126.17873100983161,
+ 454.13404825737416,
+ 691.8087578194827,
+ 180.26452189455085,
+ ],
+ ),
+ ("source", "manual"),
+ ("type", "polygon"),
+ ("z_order", 0),
+ ]
+ ),
],
}
- response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_reid_with_response_data}", self.admin, data_main_task)
+ response = self._post_request(
+ f"{LAMBDA_FUNCTIONS_PATH}/{id_function_reid_with_response_data}",
+ self.admin,
+ data_main_task,
+ )
self.assertEqual(response.status_code, status.HTTP_200_OK)
- response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_reid_with_response_data}", self.user, data_assigneed_to_user_task)
+ response = self._post_request(
+ f"{LAMBDA_FUNCTIONS_PATH}/{id_function_reid_with_response_data}",
+ self.user,
+ data_assigneed_to_user_task,
+ )
self.assertEqual(response.status_code, status.HTTP_200_OK)
- response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_reid_with_response_data}", None, data_main_task)
+ response = self._post_request(
+ f"{LAMBDA_FUNCTIONS_PATH}/{id_function_reid_with_response_data}", None, data_main_task
+ )
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
- response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_reid_with_no_response_data}", self.admin, data_main_task)
+ response = self._post_request(
+ f"{LAMBDA_FUNCTIONS_PATH}/{id_function_reid_with_no_response_data}",
+ self.admin,
+ data_main_task,
+ )
self.assertEqual(response.status_code, status.HTTP_200_OK)
- response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_reid_with_no_response_data}", self.user, data_assigneed_to_user_task)
+ response = self._post_request(
+ f"{LAMBDA_FUNCTIONS_PATH}/{id_function_reid_with_no_response_data}",
+ self.user,
+ data_assigneed_to_user_task,
+ )
self.assertEqual(response.status_code, status.HTTP_200_OK)
- response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_reid_with_no_response_data}", None, data_main_task)
+ response = self._post_request(
+ f"{LAMBDA_FUNCTIONS_PATH}/{id_function_reid_with_no_response_data}",
+ None,
+ data_main_task,
+ )
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
-
def test_api_v2_lambda_functions_create_negative(self):
data = {
"task": self.main_task["id"],
"frame": 0,
"cleanup": True,
"mapping": {
- "car": { "name": "car" },
+ "car": {"name": "car"},
},
}
for id_func in [
- id_function_non_type, id_function_wrong_type, id_function_unknown_type,
+ id_function_non_type,
+ id_function_wrong_type,
+ id_function_unknown_type,
id_function_non_unique_labels,
]:
with mock.patch(
- 'cvat.apps.lambda_manager.views.LambdaGateway._http',
- return_value=functions["negative"][id_func]
+ "cvat.apps.lambda_manager.views.LambdaGateway._http",
+ return_value=functions["negative"][id_func],
):
- response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_func}", self.admin, data)
+ response = self._post_request(
+ f"{LAMBDA_FUNCTIONS_PATH}/{id_func}", self.admin, data
+ )
self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR)
-
def test_api_v2_lambda_functions_convert_mask_to_rle(self):
data_main_task = {
"function": id_function_detector,
"task": self.main_task["id"],
"cleanup": True,
"mapping": {
- "car": { "name": "car" },
+ "car": {"name": "car"},
},
}
response = self._post_request(LAMBDA_REQUESTS_PATH, self.admin, data_main_task)
@@ -839,7 +1043,7 @@ def test_api_v2_lambda_functions_convert_mask_to_rle(self):
request_status = "started"
while request_status != "finished" and request_status != "failed":
- response = self._get_request(f'{LAMBDA_REQUESTS_PATH}/{id_request}', self.admin)
+ response = self._get_request(f"{LAMBDA_REQUESTS_PATH}/{id_request}", self.admin)
self.assertEqual(response.status_code, status.HTTP_200_OK)
request_status = response.json().get("status")
self.assertEqual(request_status, "finished")
@@ -854,13 +1058,13 @@ def test_api_v2_lambda_functions_convert_mask_to_rle(self):
# [1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0] -> [0, 2, 2, 2, 2, 2, 2]
self.assertEqual(masks[0].get("points"), [0, 2, 2, 2, 2, 2, 2, 0, 0, 2, 3])
-
def test_api_v2_lambda_functions_create_empty_data(self):
data = {}
- response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data)
+ response = self._post_request(
+ f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data
+ )
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
-
def test_api_v2_lambda_functions_create_detector_empty_mapping(self):
data = {
"task": self.main_task["id"],
@@ -868,82 +1072,89 @@ def test_api_v2_lambda_functions_create_detector_empty_mapping(self):
"cleanup": True,
"mapping": {},
}
- response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data)
+ response = self._post_request(
+ f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data
+ )
self.assertEqual(response.status_code, status.HTTP_200_OK)
-
def test_api_v2_lambda_functions_create_detector_without_cleanup(self):
data = {
"task": self.main_task["id"],
"frame": 0,
"mapping": {
- "car": { "name": "car" },
+ "car": {"name": "car"},
},
}
- response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data)
+ response = self._post_request(
+ f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data
+ )
self.assertEqual(response.status_code, status.HTTP_200_OK)
-
def test_api_v2_lambda_functions_create_detector_without_mapping(self):
data = {
"task": self.main_task["id"],
"frame": 0,
"cleanup": True,
}
- response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data)
+ response = self._post_request(
+ f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data
+ )
self.assertEqual(response.status_code, status.HTTP_200_OK)
-
def test_api_v2_lambda_functions_create_detector_without_task(self):
data = {
"frame": 0,
"cleanup": True,
"mapping": {
- "car": { "name": "car" },
+ "car": {"name": "car"},
},
}
- response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data)
+ response = self._post_request(
+ f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data
+ )
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
-
def test_api_v2_lambda_functions_create_detector_without_id_frame(self):
data = {
"task": self.main_task["id"],
"cleanup": True,
"mapping": {
- "car": { "name": "car" },
+ "car": {"name": "car"},
},
}
- response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data)
+ response = self._post_request(
+ f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data
+ )
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
-
def test_api_v2_lambda_functions_create_wrong_id_function(self):
data = {
"task": self.main_task["id"],
"frame": 0,
"cleanup": True,
"mapping": {
- "car": { "name": "car" },
+ "car": {"name": "car"},
},
}
- response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/test-functions-wrong-id", self.admin, data)
+ response = self._post_request(
+ f"{LAMBDA_FUNCTIONS_PATH}/test-functions-wrong-id", self.admin, data
+ )
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
-
def test_api_v2_lambda_functions_create_wrong_id_task(self):
data = {
"task": 12345,
"frame": 0,
"cleanup": True,
"mapping": {
- "car": { "name": "car" },
+ "car": {"name": "car"},
},
}
- response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data)
+ response = self._post_request(
+ f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data
+ )
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
-
@skip("Fail: expected result != actual result, issue #2770")
def test_api_v2_lambda_functions_create_detector_wrong_id_frame(self):
data = {
@@ -951,13 +1162,14 @@ def test_api_v2_lambda_functions_create_detector_wrong_id_frame(self):
"frame": 12345,
"cleanup": True,
"mapping": {
- "car": { "name": "car" },
+ "car": {"name": "car"},
},
}
- response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data)
+ response = self._post_request(
+ f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data
+ )
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
-
@skip("Fail: add mock and expected result != actual result")
def test_api_v2_lambda_functions_create_two_functions(self):
data = {
@@ -965,27 +1177,32 @@ def test_api_v2_lambda_functions_create_two_functions(self):
"frame": 0,
"cleanup": True,
"mapping": {
- "car": { "name": "car" },
+ "car": {"name": "car"},
},
}
self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data)
- response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data)
+ response = self._post_request(
+ f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data
+ )
self.assertEqual(response.status_code, status.HTTP_409_CONFLICT)
-
def test_api_v2_lambda_functions_create_function_is_not_ready(self):
data = {
"task": self.main_task["id"],
"frame": 0,
"cleanup": True,
"mapping": {
- "car": { "name": "car" },
+ "car": {"name": "car"},
},
}
- response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_state_building}", self.admin, data)
+ response = self._post_request(
+ f"{LAMBDA_FUNCTIONS_PATH}/{id_function_state_building}", self.admin, data
+ )
self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR)
- response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_state_error}", self.admin, data)
+ response = self._post_request(
+ f"{LAMBDA_FUNCTIONS_PATH}/{id_function_state_error}", self.admin, data
+ )
self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR)
@@ -1038,29 +1255,27 @@ def setUp(self):
self.task = self._create_task(
task_spec={
- 'name': 'test_task',
- 'labels': [{'name': 'car'}],
- 'segment_size': segment_size
+ "name": "test_task",
+ "labels": [{"name": "car"}],
+ "segment_size": segment_size,
},
data=data,
- owner=self.user
+ owner=self.user,
)
self.task_rel_frame_range = range(len(range(start_frame, stop_frame, frame_step)))
self.start_frame = start_frame
self.frame_step = frame_step
self.segment_size = segment_size
- self.labels = get_paginated_collection(lambda page:
- self._get_request(
- f"/api/labels?task_id={self.task['id']}&page={page}&sort=id",
- self.admin
+ self.labels = get_paginated_collection(
+ lambda page: self._get_request(
+ f"/api/labels?task_id={self.task['id']}&page={page}&sort=id", self.admin
)
)
- self.jobs = get_paginated_collection(lambda page:
- self._get_request(
- f"/api/jobs?task_id={self.task['id']}&page={page}",
- self.admin
+ self.jobs = get_paginated_collection(
+ lambda page: self._get_request(
+ f"/api/jobs?task_id={self.task['id']}&page={page}", self.admin
)
)
@@ -1068,7 +1283,7 @@ def setUp(self):
self.reid_function_id = id_function_reid_with_response_data
self.common_request_data = {
- "task": self.task['id'],
+ "task": self.task["id"],
"cleanup": True,
}
@@ -1085,14 +1300,14 @@ def _run_offline_function(self, function_id, data, user):
def _wait_request(self, request_id: str) -> str:
request_status = "started"
while request_status != "finished" and request_status != "failed":
- response = self._get_request(f'{LAMBDA_REQUESTS_PATH}/{request_id}', self.admin)
+ response = self._get_request(f"{LAMBDA_REQUESTS_PATH}/{request_id}", self.admin)
self.assertEqual(response.status_code, status.HTTP_200_OK)
request_status = response.json().get("status")
return request_status
def _run_online_function(self, function_id, data, user):
- response = self._post_request(f'{LAMBDA_FUNCTIONS_PATH}/{function_id}', user, data)
+ response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{function_id}", user, data)
return response
def test_can_run_offline_detector_function_on_whole_task(self):
@@ -1108,40 +1323,39 @@ def test_can_run_offline_detector_function_on_whole_task(self):
requested_frame_range = self.task_rel_frame_range
self.assertEqual(
- {
- frame: 1 for frame in requested_frame_range
- },
+ {frame: 1 for frame in requested_frame_range},
{
frame: len(list(group))
for frame, group in groupby(annotations["shapes"], key=lambda a: a["frame"])
- }
+ },
)
def test_can_run_offline_reid_function_on_whole_task(self):
# Add starting shapes to be tracked on following frames
requested_frame_range = self.task_rel_frame_range
shape_template = {
- 'attributes': [],
- 'group': None,
- 'label_id': self.labels[0]["id"],
- 'occluded': False,
- 'points': [0, 5, 5, 0],
- 'source': 'manual',
- 'type': 'rectangle',
- 'z_order': 0,
+ "attributes": [],
+ "group": None,
+ "label_id": self.labels[0]["id"],
+ "occluded": False,
+ "points": [0, 5, 5, 0],
+ "source": "manual",
+ "type": "rectangle",
+ "z_order": 0,
}
- response = self._put_request(f'/api/tasks/{self.task["id"]}/annotations', self.admin, data={
- 'tags': [],
- 'shapes': [
- { 'frame': frame, **shape_template }
- for frame in requested_frame_range
- ],
- 'tracks': []
- })
+ response = self._put_request(
+ f'/api/tasks/{self.task["id"]}/annotations',
+ self.admin,
+ data={
+ "tags": [],
+ "shapes": [{"frame": frame, **shape_template} for frame in requested_frame_range],
+ "tracks": [],
+ },
+ )
self.assertEqual(response.status_code, status.HTTP_200_OK)
data = self.common_request_data.copy()
- data["cleanup"] = False # cleanup is not compatible with reid
+ data["cleanup"] = False # cleanup is not compatible with reid
self._run_offline_function(self.reid_function_id, data, self.user)
response = self._get_request(f'/api/tasks/{self.task["id"]}/annotations', self.admin)
@@ -1154,25 +1368,24 @@ def test_can_run_offline_reid_function_on_whole_task(self):
[
# The single track will be split by job segments
{
- 'frame': job["start_frame"],
- 'shapes': [
- { 'frame': frame, 'outside': frame > job["stop_frame"] }
+ "frame": job["start_frame"],
+ "shapes": [
+ {"frame": frame, "outside": frame > job["stop_frame"]}
for frame in requested_frame_range
if frame in range(job["start_frame"], job["stop_frame"] + self.segment_size)
- ]
+ ],
}
for job in sorted(self.jobs, key=lambda j: j["start_frame"])
],
[
{
- 'frame': track['frame'],
- 'shapes': [
- filter_dict(shape, keep=['frame', 'outside'])
- for shape in track["shapes"]
- ]
+ "frame": track["frame"],
+ "shapes": [
+ filter_dict(shape, keep=["frame", "outside"]) for shape in track["shapes"]
+ ],
}
- for track in annotations['tracks']
- ]
+ for track in annotations["tracks"]
+ ],
)
def test_can_run_offline_detector_function_on_whole_job(self):
@@ -1190,13 +1403,11 @@ def test_can_run_offline_detector_function_on_whole_job(self):
requested_frame_range = range(job["start_frame"], job["stop_frame"] + 1)
self.assertEqual(
- {
- frame: 1 for frame in requested_frame_range
- },
+ {frame: 1 for frame in requested_frame_range},
{
frame: len(list(group))
for frame, group in groupby(annotations["shapes"], key=lambda a: a["frame"])
- }
+ },
)
def test_can_run_offline_reid_function_on_whole_job(self):
@@ -1205,27 +1416,28 @@ def test_can_run_offline_reid_function_on_whole_job(self):
# Add starting shapes to be tracked on following frames
shape_template = {
- 'attributes': [],
- 'group': None,
- 'label_id': self.labels[0]["id"],
- 'occluded': False,
- 'points': [0, 5, 5, 0],
- 'source': 'manual',
- 'type': 'rectangle',
- 'z_order': 0,
+ "attributes": [],
+ "group": None,
+ "label_id": self.labels[0]["id"],
+ "occluded": False,
+ "points": [0, 5, 5, 0],
+ "source": "manual",
+ "type": "rectangle",
+ "z_order": 0,
}
- response = self._put_request(f'/api/jobs/{job["id"]}/annotations', self.admin, data={
- 'tags': [],
- 'shapes': [
- { 'frame': frame, **shape_template }
- for frame in requested_frame_range
- ],
- 'tracks': []
- })
+ response = self._put_request(
+ f'/api/jobs/{job["id"]}/annotations',
+ self.admin,
+ data={
+ "tags": [],
+ "shapes": [{"frame": frame, **shape_template} for frame in requested_frame_range],
+ "tracks": [],
+ },
+ )
self.assertEqual(response.status_code, status.HTTP_200_OK)
data = self.common_request_data.copy()
- data["cleanup"] = False # cleanup is not compatible with reid
+ data["cleanup"] = False # cleanup is not compatible with reid
data["job"] = job["id"]
self._run_offline_function(self.reid_function_id, data, self.user)
@@ -1238,34 +1450,37 @@ def test_can_run_offline_reid_function_on_whole_job(self):
self.assertEqual(
[
{
- 'frame': job["start_frame"],
- 'shapes': [
- { 'frame': frame, 'outside': frame > job["stop_frame"] }
+ "frame": job["start_frame"],
+ "shapes": [
+ {"frame": frame, "outside": frame > job["stop_frame"]}
for frame in requested_frame_range
if frame in range(job["start_frame"], job["stop_frame"] + self.segment_size)
- ]
+ ],
}
],
[
{
- 'frame': track['frame'],
- 'shapes': [
- filter_dict(shape, keep=['frame', 'outside'])
- for shape in track["shapes"]
- ]
+ "frame": track["frame"],
+ "shapes": [
+ filter_dict(shape, keep=["frame", "outside"]) for shape in track["shapes"]
+ ],
}
- for track in annotations['tracks']
- ]
+ for track in annotations["tracks"]
+ ],
)
def test_can_run_offline_detector_function_on_whole_gt_job(self):
requested_frame_range = self.task_rel_frame_range[::3]
- response = self._post_request("/api/jobs", self.admin, data={
- "type": "ground_truth",
- "task_id": self.task["id"],
- "frame_selection_method": "manual",
- "frames": list(requested_frame_range),
- })
+ response = self._post_request(
+ "/api/jobs",
+ self.admin,
+ data={
+ "type": "ground_truth",
+ "task_id": self.task["id"],
+ "frame_selection_method": "manual",
+ "frames": list(requested_frame_range),
+ },
+ )
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
job = response.json()
@@ -1281,49 +1496,54 @@ def test_can_run_offline_detector_function_on_whole_gt_job(self):
self.assertEqual(len(annotations["tracks"]), 0)
self.assertEqual(
- { frame: 1 for frame in requested_frame_range },
- Counter(a["frame"] for a in annotations["shapes"])
+ {frame: 1 for frame in requested_frame_range},
+ Counter(a["frame"] for a in annotations["shapes"]),
)
response = self._get_request(f'/api/tasks/{self.task["id"]}/annotations', self.admin)
self.assertEqual(response.status_code, status.HTTP_200_OK)
annotations = response.json()
- self.assertEqual(annotations, {'version': 0, 'tags': [], 'shapes': [], 'tracks': []})
+ self.assertEqual(annotations, {"version": 0, "tags": [], "shapes": [], "tracks": []})
def test_can_run_offline_reid_function_on_whole_gt_job(self):
requested_frame_range = self.task_rel_frame_range[::3]
- response = self._post_request("/api/jobs", self.admin, data={
- "type": "ground_truth",
- "task_id": self.task["id"],
- "frame_selection_method": "manual",
- "frames": list(requested_frame_range),
- })
+ response = self._post_request(
+ "/api/jobs",
+ self.admin,
+ data={
+ "type": "ground_truth",
+ "task_id": self.task["id"],
+ "frame_selection_method": "manual",
+ "frames": list(requested_frame_range),
+ },
+ )
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
job = response.json()
# Add starting shapes to be tracked on following frames
shape_template = {
- 'attributes': [],
- 'group': None,
- 'label_id': self.labels[0]["id"],
- 'occluded': False,
- 'points': [0, 5, 5, 0],
- 'source': 'manual',
- 'type': 'rectangle',
- 'z_order': 0,
+ "attributes": [],
+ "group": None,
+ "label_id": self.labels[0]["id"],
+ "occluded": False,
+ "points": [0, 5, 5, 0],
+ "source": "manual",
+ "type": "rectangle",
+ "z_order": 0,
}
- response = self._put_request(f'/api/jobs/{job["id"]}/annotations', self.admin, data={
- 'tags': [],
- 'shapes': [
- { 'frame': frame, **shape_template }
- for frame in requested_frame_range
- ],
- 'tracks': []
- })
+ response = self._put_request(
+ f'/api/jobs/{job["id"]}/annotations',
+ self.admin,
+ data={
+ "tags": [],
+ "shapes": [{"frame": frame, **shape_template} for frame in requested_frame_range],
+ "tracks": [],
+ },
+ )
self.assertEqual(response.status_code, status.HTTP_200_OK)
data = self.common_request_data.copy()
- data["cleanup"] = False # cleanup is not compatible with reid
+ data["cleanup"] = False # cleanup is not compatible with reid
data["job"] = job["id"]
self._run_offline_function(self.reid_function_id, data, self.user)
@@ -1336,38 +1556,41 @@ def test_can_run_offline_reid_function_on_whole_gt_job(self):
self.assertEqual(
[
{
- 'frame': job["start_frame"],
- 'shapes': [
- { 'frame': frame, 'outside': frame > job["stop_frame"] }
+ "frame": job["start_frame"],
+ "shapes": [
+ {"frame": frame, "outside": frame > job["stop_frame"]}
for frame in requested_frame_range
if frame in range(job["start_frame"], job["stop_frame"] + self.segment_size)
- ]
+ ],
}
],
[
{
- 'frame': track['frame'],
- 'shapes': [
- filter_dict(shape, keep=['frame', 'outside'])
- for shape in track["shapes"]
- ]
+ "frame": track["frame"],
+ "shapes": [
+ filter_dict(shape, keep=["frame", "outside"]) for shape in track["shapes"]
+ ],
}
- for track in annotations['tracks']
- ]
+ for track in annotations["tracks"]
+ ],
)
response = self._get_request(f'/api/tasks/{self.task["id"]}/annotations', self.admin)
self.assertEqual(response.status_code, status.HTTP_200_OK)
annotations = response.json()
- self.assertEqual(annotations, {'version': 0, 'tags': [], 'shapes': [], 'tracks': []})
+ self.assertEqual(annotations, {"version": 0, "tags": [], "shapes": [], "tracks": []})
def test_offline_function_run_on_task_does_not_affect_gt_job(self):
- response = self._post_request("/api/jobs", self.admin, data={
- "type": "ground_truth",
- "task_id": self.task["id"],
- "frame_selection_method": "manual",
- "frames": list(self.task_rel_frame_range[::3]),
- })
+ response = self._post_request(
+ "/api/jobs",
+ self.admin,
+ data={
+ "type": "ground_truth",
+ "task_id": self.task["id"],
+ "frame_selection_method": "manual",
+ "frames": list(self.task_rel_frame_range[::3]),
+ },
+ )
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
job = response.json()
@@ -1383,14 +1606,14 @@ def test_offline_function_run_on_task_does_not_affect_gt_job(self):
requested_frame_range = self.task_rel_frame_range
self.assertEqual(
- { frame: 1 for frame in requested_frame_range },
- Counter(a["frame"] for a in annotations["shapes"])
+ {frame: 1 for frame in requested_frame_range},
+ Counter(a["frame"] for a in annotations["shapes"]),
)
response = self._get_request(f'/api/jobs/{job["id"]}/annotations', self.admin)
self.assertEqual(response.status_code, status.HTTP_200_OK)
annotations = response.json()
- self.assertEqual(annotations, {'version': 0, 'tags': [], 'shapes': [], 'tracks': []})
+ self.assertEqual(annotations, {"version": 0, "tags": [], "shapes": [], "tracks": []})
def test_can_run_online_function_on_valid_task_frame(self):
data = self.common_request_data.copy()
@@ -1441,70 +1664,87 @@ class Issue4996_Cases(_LambdaTestCaseBase):
# This requires to pass the job id in the call request.
def _create_org(self, *, owner: int, members: dict[int, str] = None) -> dict:
- org = self._post_request('/api/organizations', user=owner, data={
- "slug": "testorg",
- "name": "test Org",
- })
+ org = self._post_request(
+ "/api/organizations",
+ user=owner,
+ data={
+ "slug": "testorg",
+ "name": "test Org",
+ },
+ )
assert org.status_code == status.HTTP_201_CREATED
org = org.json()
for uid, role in members.items():
- user = self._get_request('/api/users/self', user=uid)
+ user = self._get_request("/api/users/self", user=uid)
assert user.status_code == status.HTTP_200_OK
user = user.json()
- invitation = self._post_request('/api/invitations', user=owner, data={
- 'email': user['email'],
- 'role': role,
- }, org_id=org['id'])
+ invitation = self._post_request(
+ "/api/invitations",
+ user=owner,
+ data={
+ "email": user["email"],
+ "role": role,
+ },
+ org_id=org["id"],
+ )
assert invitation.status_code == status.HTTP_201_CREATED
return org
- def _set_task_assignee(self, task: int, assignee: Optional[int], *,
- org_id: Optional[int] = None):
- response = self._patch_request(f'/api/tasks/{task}', user=self.admin, data={
- 'assignee_id': assignee,
- }, org_id=org_id)
+ def _set_task_assignee(
+ self, task: int, assignee: Optional[int], *, org_id: Optional[int] = None
+ ):
+ response = self._patch_request(
+ f"/api/tasks/{task}",
+ user=self.admin,
+ data={
+ "assignee_id": assignee,
+ },
+ org_id=org_id,
+ )
assert response.status_code == status.HTTP_200_OK
- def _set_job_assignee(self, job: int, assignee: Optional[int], *,
- org_id: Optional[int] = None):
- response = self._patch_request(f'/api/jobs/{job}', user=self.admin, data={
- 'assignee': assignee,
- }, org_id=org_id)
+ def _set_job_assignee(self, job: int, assignee: Optional[int], *, org_id: Optional[int] = None):
+ response = self._patch_request(
+ f"/api/jobs/{job}",
+ user=self.admin,
+ data={
+ "assignee": assignee,
+ },
+ org_id=org_id,
+ )
assert response.status_code == status.HTTP_200_OK
def setUp(self):
super().setUp()
- self.org = self._create_org(owner=self.admin, members={self.user: 'worker'})
+ self.org = self._create_org(owner=self.admin, members={self.user: "worker"})
- task = self._create_task(task_spec={
- 'name': 'test_task',
- 'labels': [{'name': 'car'}],
- 'segment_size': 2
- },
+ task = self._create_task(
+ task_spec={"name": "test_task", "labels": [{"name": "car"}], "segment_size": 2},
data=self._generate_task_images(6),
owner=self.admin,
- org_id=self.org['id'],
+ org_id=self.org["id"],
)
self.task = task
- jobs = get_paginated_collection(lambda page:
- self._get_request(
+ jobs = get_paginated_collection(
+ lambda page: self._get_request(
f"/api/jobs?task_id={self.task['id']}&page={page}",
- self.admin, org_id=self.org['id']
+ self.admin,
+ org_id=self.org["id"],
)
)
self.job = jobs[1]
self.common_request_data = {
- "task": self.task['id'],
+ "task": self.task["id"],
"frame": 0,
"cleanup": True,
"mapping": {
- "car": { "name": "car" },
+ "car": {"name": "car"},
},
}
@@ -1512,75 +1752,70 @@ def setUp(self):
def _get_valid_job_request_data(self):
data = self.common_request_data.copy()
- data.update({
- "job": self.job['id'],
- "frame": 2
- })
+ data.update({"job": self.job["id"], "frame": 2})
return data
def _get_invalid_job_request_data(self):
data = self.common_request_data.copy()
- data.update({
- "job": self.job['id'],
- "frame": 0
- })
+ data.update({"job": self.job["id"], "frame": 0})
return data
- def test_can_call_function_for_job_worker_in_org__deny_unassigned_worker_with_task_request(self):
+ def test_can_call_function_for_job_worker_in_org__deny_unassigned_worker_with_task_request(
+ self,
+ ):
data = self.common_request_data.copy()
with self.subTest(job=None, assignee=None):
- response = self._post_request(self.function_url, self.user, data,
- org_id=self.org['id'])
+ response = self._post_request(self.function_url, self.user, data, org_id=self.org["id"])
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_can_call_function_for_job_worker_in_org__deny_unassigned_worker_with_job_request(self):
data = self._get_valid_job_request_data()
- with self.subTest(job='defined', assignee=None):
- response = self._post_request(self.function_url, self.user, data,
- org_id=self.org['id'])
+ with self.subTest(job="defined", assignee=None):
+ response = self._post_request(self.function_url, self.user, data, org_id=self.org["id"])
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
- def test_can_call_function_for_job_worker_in_org__allow_task_assigned_worker_with_task_request(self):
- self._set_task_assignee(self.task['id'], self.user.id, org_id=self.org['id'])
+ def test_can_call_function_for_job_worker_in_org__allow_task_assigned_worker_with_task_request(
+ self,
+ ):
+ self._set_task_assignee(self.task["id"], self.user.id, org_id=self.org["id"])
data = self.common_request_data.copy()
- with self.subTest(job=None, assignee='task'):
- response = self._post_request(self.function_url, self.user, data,
- org_id=self.org['id'])
+ with self.subTest(job=None, assignee="task"):
+ response = self._post_request(self.function_url, self.user, data, org_id=self.org["id"])
self.assertEqual(response.status_code, status.HTTP_200_OK)
- def test_can_call_function_for_job_worker_in_org__deny_job_assigned_worker_with_task_request(self):
- self._set_job_assignee(self.job['id'], self.user.id, org_id=self.org['id'])
+ def test_can_call_function_for_job_worker_in_org__deny_job_assigned_worker_with_task_request(
+ self,
+ ):
+ self._set_job_assignee(self.job["id"], self.user.id, org_id=self.org["id"])
data = self.common_request_data.copy()
- with self.subTest(job=None, assignee='job'):
- response = self._post_request(self.function_url, self.user, data,
- org_id=self.org['id'])
+ with self.subTest(job=None, assignee="job"):
+ response = self._post_request(self.function_url, self.user, data, org_id=self.org["id"])
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
- def test_can_call_function_for_job_worker_in_org__allow_job_assigned_worker_with_job_request(self):
- self._set_job_assignee(self.job['id'], self.user.id, org_id=self.org['id'])
+ def test_can_call_function_for_job_worker_in_org__allow_job_assigned_worker_with_job_request(
+ self,
+ ):
+ self._set_job_assignee(self.job["id"], self.user.id, org_id=self.org["id"])
data = self._get_valid_job_request_data()
- with self.subTest(job='defined', assignee='job'):
- response = self._post_request(self.function_url, self.user, data,
- org_id=self.org['id'])
+ with self.subTest(job="defined", assignee="job"):
+ response = self._post_request(self.function_url, self.user, data, org_id=self.org["id"])
self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_can_check_job_boundaries_in_function_call__fail_for_frame_outside_job(self):
- self._set_job_assignee(self.job['id'], self.user.id, org_id=self.org['id'])
+ self._set_job_assignee(self.job["id"], self.user.id, org_id=self.org["id"])
data = self._get_invalid_job_request_data()
- with self.subTest(job='defined', frame='outside'):
- response = self._post_request(self.function_url, self.user, data,
- org_id=self.org['id'])
+ with self.subTest(job="defined", frame="outside"):
+ response = self._post_request(self.function_url, self.user, data, org_id=self.org["id"])
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
def test_can_check_job_boundaries_in_function_call__ok_for_frame_inside_job(self):
- self._set_job_assignee(self.job['id'], self.user.id, org_id=self.org['id'])
+ self._set_job_assignee(self.job["id"], self.user.id, org_id=self.org["id"])
data = self._get_valid_job_request_data()
- with self.subTest(job='defined', frame='inside'):
- response = self._post_request(self.function_url, self.user, data,
- org_id=self.org['id'])
+ with self.subTest(job="defined", frame="inside"):
+ response = self._post_request(self.function_url, self.user, data, org_id=self.org["id"])
self.assertEqual(response.status_code, status.HTTP_200_OK)
diff --git a/cvat/apps/lambda_manager/urls.py b/cvat/apps/lambda_manager/urls.py
index 6dae0edaca76..261592a9f469 100644
--- a/cvat/apps/lambda_manager/urls.py
+++ b/cvat/apps/lambda_manager/urls.py
@@ -12,9 +12,9 @@
# I want to "call" my functions. To do that need to map my call method to
# POST (like get HTTP method is mapped to list(...)). One way is to implement
# own CustomRouter. But it is simpler just patch the router instance here.
-router.routes[2].mapping.update({'post': 'call'})
-router.register('functions', views.FunctionViewSet, basename='lambda_function')
-router.register('requests', views.RequestViewSet, basename='lambda_request')
+router.routes[2].mapping.update({"post": "call"})
+router.register("functions", views.FunctionViewSet, basename="lambda_function")
+router.register("requests", views.RequestViewSet, basename="lambda_request")
# GET /api/lambda/functions - get list of functions
# GET /api/lambda/functions/ - get information about the function
@@ -24,6 +24,4 @@
# GET /api/lambda/requests - get list of requests
# GET /api/lambda/requests/ - get status of the request
# DEL /api/lambda/requests/ - cancel a request (don't delete)
-urlpatterns = [
- path('api/lambda/', include(router.urls))
-]
+urlpatterns = [path("api/lambda/", include(router.urls))]
diff --git a/cvat/apps/lambda_manager/views.py b/cvat/apps/lambda_manager/views.py
index 559ef29813b5..465414e243a5 100644
--- a/cvat/apps/lambda_manager/views.py
+++ b/cvat/apps/lambda_manager/views.py
@@ -19,53 +19,75 @@
import numpy as np
import requests
import rq
-from cvat.apps.events.handlers import handle_function_call
-from cvat.apps.lambda_manager.signals import interactive_function_call_signal
from django.conf import settings
from django.core.exceptions import ObjectDoesNotExist, ValidationError
from drf_spectacular.types import OpenApiTypes
-from drf_spectacular.utils import (OpenApiParameter, OpenApiResponse,
- extend_schema, extend_schema_view,
- inline_serializer)
+from drf_spectacular.utils import (
+ OpenApiParameter,
+ OpenApiResponse,
+ extend_schema,
+ extend_schema_view,
+ inline_serializer,
+)
from rest_framework import serializers, status, viewsets
-from rest_framework.response import Response
from rest_framework.request import Request
+from rest_framework.response import Response
import cvat.apps.dataset_manager as dm
from cvat.apps.engine.frame_provider import TaskFrameProvider
+from cvat.apps.engine.log import ServerLogManager
from cvat.apps.engine.models import (
- Job, ShapeType, SourceType, Task, Label, RequestAction, RequestTarget
+ Job,
+ Label,
+ RequestAction,
+ RequestTarget,
+ ShapeType,
+ SourceType,
+ Task,
)
from cvat.apps.engine.rq_job_handler import RQId, RQJobMetaField
from cvat.apps.engine.serializers import LabeledDataSerializer
+from cvat.apps.engine.utils import define_dependent_job, get_rq_job_meta, get_rq_lock_by_user
+from cvat.apps.events.handlers import handle_function_call
+from cvat.apps.iam.filters import ORGANIZATION_OPEN_API_PARAMETERS
from cvat.apps.lambda_manager.models import FunctionKind
from cvat.apps.lambda_manager.permissions import LambdaPermission
from cvat.apps.lambda_manager.serializers import (
- FunctionCallRequestSerializer, FunctionCallSerializer
+ FunctionCallRequestSerializer,
+ FunctionCallSerializer,
)
-from cvat.apps.engine.log import ServerLogManager
-from cvat.apps.engine.utils import define_dependent_job, get_rq_job_meta, get_rq_lock_by_user
+from cvat.apps.lambda_manager.signals import interactive_function_call_signal
from cvat.utils.http import make_requests_session
-from cvat.apps.iam.filters import ORGANIZATION_OPEN_API_PARAMETERS
slogger = ServerLogManager(__name__)
+
class LambdaGateway:
- NUCLIO_ROOT_URL = '/api/functions'
-
- def _http(self, method="get", scheme=None, host=None, port=None,
- function_namespace=None, url=None, headers=None, data=None):
- NUCLIO_GATEWAY = '{}://{}:{}'.format(
- scheme or settings.NUCLIO['SCHEME'],
- host or settings.NUCLIO['HOST'],
- port or settings.NUCLIO['PORT'])
- NUCLIO_FUNCTION_NAMESPACE = function_namespace or settings.NUCLIO['FUNCTION_NAMESPACE']
- NUCLIO_TIMEOUT = settings.NUCLIO['DEFAULT_TIMEOUT']
+ NUCLIO_ROOT_URL = "/api/functions"
+
+ def _http(
+ self,
+ method="get",
+ scheme=None,
+ host=None,
+ port=None,
+ function_namespace=None,
+ url=None,
+ headers=None,
+ data=None,
+ ):
+ NUCLIO_GATEWAY = "{}://{}:{}".format(
+ scheme or settings.NUCLIO["SCHEME"],
+ host or settings.NUCLIO["HOST"],
+ port or settings.NUCLIO["PORT"],
+ )
+ NUCLIO_FUNCTION_NAMESPACE = function_namespace or settings.NUCLIO["FUNCTION_NAMESPACE"]
+ NUCLIO_TIMEOUT = settings.NUCLIO["DEFAULT_TIMEOUT"]
extra_headers = {
- 'x-nuclio-project-name': 'cvat',
- 'x-nuclio-function-namespace': NUCLIO_FUNCTION_NAMESPACE,
- 'x-nuclio-invoke-via': 'domain-name',
- 'X-Nuclio-Invoke-Timeout': f"{NUCLIO_TIMEOUT}s",
+ "x-nuclio-project-name": "cvat",
+ "x-nuclio-function-namespace": NUCLIO_FUNCTION_NAMESPACE,
+ "x-nuclio-invoke-via": "domain-name",
+ "X-Nuclio-Invoke-Timeout": f"{NUCLIO_TIMEOUT}s",
}
if headers:
extra_headers.update(headers)
@@ -76,8 +98,9 @@ def _http(self, method="get", scheme=None, host=None, port=None,
url = NUCLIO_GATEWAY
with make_requests_session() as session:
- reply = session.request(method, url, headers=extra_headers,
- timeout=NUCLIO_TIMEOUT, json=data)
+ reply = session.request(
+ method, url, headers=extra_headers, timeout=NUCLIO_TIMEOUT, json=data
+ )
reply.raise_for_status()
response = reply.json()
@@ -92,32 +115,33 @@ def list(self):
slogger.glob.error("Failed to parse lambda function metadata", exc_info=True)
def get(self, func_id):
- data = self._http(url=self.NUCLIO_ROOT_URL + '/' + func_id)
+ data = self._http(url=self.NUCLIO_ROOT_URL + "/" + func_id)
response = LambdaFunction(self, data)
return response
def invoke(self, func, payload):
invoke_method = {
- 'dashboard': self._invoke_via_dashboard,
- 'direct': self._invoke_directly,
+ "dashboard": self._invoke_via_dashboard,
+ "direct": self._invoke_directly,
}
- return invoke_method[settings.NUCLIO['INVOKE_METHOD']](func, payload)
+ return invoke_method[settings.NUCLIO["INVOKE_METHOD"]](func, payload)
def _invoke_via_dashboard(self, func, payload):
- return self._http(method="post", url='/api/function_invocations',
- data=payload, headers={
- 'x-nuclio-function-name': func.id,
- 'x-nuclio-path': '/'
- })
+ return self._http(
+ method="post",
+ url="/api/function_invocations",
+ data=payload,
+ headers={"x-nuclio-function-name": func.id, "x-nuclio-path": "/"},
+ )
def _invoke_directly(self, func, payload):
# host.docker.internal for Linux will work only with Docker 20.10+
- NUCLIO_TIMEOUT = settings.NUCLIO['DEFAULT_TIMEOUT']
- if os.path.exists('/.dockerenv'): # inside a docker container
- url = f'http://host.docker.internal:{func.port}'
+ NUCLIO_TIMEOUT = settings.NUCLIO["DEFAULT_TIMEOUT"]
+ if os.path.exists("/.dockerenv"): # inside a docker container
+ url = f"http://host.docker.internal:{func.port}"
else:
- url = f'http://localhost:{func.port}'
+ url = f"http://localhost:{func.port}"
with make_requests_session() as session:
reply = session.post(url, timeout=NUCLIO_TIMEOUT, json=payload)
@@ -126,105 +150,119 @@ def _invoke_directly(self, func, payload):
return response
+
class InvalidFunctionMetadataError(Exception):
pass
+
class LambdaFunction:
FRAME_PARAMETERS = (
- ('frame', 'frame'),
- ('frame0', 'start frame'),
- ('frame1', 'end frame'),
+ ("frame", "frame"),
+ ("frame0", "start frame"),
+ ("frame1", "end frame"),
)
def __init__(self, gateway, data):
# ID of the function (e.g. omz.public.yolo-v3)
- self.id = data['metadata']['name']
+ self.id = data["metadata"]["name"]
# type of the function (e.g. detector, interactor)
- meta_anno = data['metadata']['annotations']
- kind = meta_anno.get('type')
+ meta_anno = data["metadata"]["annotations"]
+ kind = meta_anno.get("type")
try:
self.kind = FunctionKind(kind)
except ValueError as e:
raise InvalidFunctionMetadataError(
- f"{self.id} lambda function has unknown type: {kind!r}") from e
+ f"{self.id} lambda function has unknown type: {kind!r}"
+ ) from e
# dictionary of labels for the function (e.g. car, person)
- spec = json.loads(meta_anno.get('spec') or '[]')
+ spec = json.loads(meta_anno.get("spec") or "[]")
def parse_labels(spec):
def parse_attributes(attrs_spec):
- parsed_attributes = [{
- 'name': attr['name'],
- 'input_type': attr['input_type'],
- 'values': attr['values'],
- } for attr in attrs_spec]
-
- if len(parsed_attributes) != len({attr['name'] for attr in attrs_spec}):
+ parsed_attributes = [
+ {
+ "name": attr["name"],
+ "input_type": attr["input_type"],
+ "values": attr["values"],
+ }
+ for attr in attrs_spec
+ ]
+
+ if len(parsed_attributes) != len({attr["name"] for attr in attrs_spec}):
raise InvalidFunctionMetadataError(
- f"{self.id} lambda function has non-unique attributes")
+ f"{self.id} lambda function has non-unique attributes"
+ )
return parsed_attributes
parsed_labels = []
for label in spec:
parsed_label = {
- 'name': label['name'],
- 'type': label.get('type', 'unknown'),
- 'attributes': parse_attributes(label.get('attributes', []))
+ "name": label["name"],
+ "type": label.get("type", "unknown"),
+ "attributes": parse_attributes(label.get("attributes", [])),
}
- if parsed_label['type'] == 'skeleton':
- parsed_label.update({
- 'sublabels': parse_labels(label['sublabels']),
- 'svg': label['svg']
- })
+ if parsed_label["type"] == "skeleton":
+ parsed_label.update(
+ {"sublabels": parse_labels(label["sublabels"]), "svg": label["svg"]}
+ )
parsed_labels.append(parsed_label)
- if len(parsed_labels) != len({label['name'] for label in spec}):
+ if len(parsed_labels) != len({label["name"] for label in spec}):
raise InvalidFunctionMetadataError(
- f"{self.id} lambda function has non-unique labels")
+ f"{self.id} lambda function has non-unique labels"
+ )
return parsed_labels
self.labels = parse_labels(spec)
# mapping of labels and corresponding supported attributes
- self.func_attributes = {item['name']: item.get('attributes', []) for item in spec}
+ self.func_attributes = {item["name"]: item.get("attributes", []) for item in spec}
for label, attributes in self.func_attributes.items():
- if len([attr['name'] for attr in attributes]) != len(set([attr['name'] for attr in attributes])):
+ if len([attr["name"] for attr in attributes]) != len(
+ set([attr["name"] for attr in attributes])
+ ):
raise InvalidFunctionMetadataError(
- "`{}` lambda function has non-unique attributes for label {}".format(self.id, label))
+ "`{}` lambda function has non-unique attributes for label {}".format(
+ self.id, label
+ )
+ )
# description of the function
- self.description = data['spec']['description']
+ self.description = data["spec"]["description"]
# http port to access the serverless function
self.port = data["status"].get("httpPort")
# display name for the function
- self.name = meta_anno.get('name', self.id)
- self.min_pos_points = int(meta_anno.get('min_pos_points', 1))
- self.min_neg_points = int(meta_anno.get('min_neg_points', -1))
- self.startswith_box = bool(meta_anno.get('startswith_box', False))
- self.startswith_box_optional = bool(meta_anno.get('startswith_box_optional', False))
- self.animated_gif = meta_anno.get('animated_gif', '')
- self.version = int(meta_anno.get('version', '1'))
- self.help_message = meta_anno.get('help_message', '')
+ self.name = meta_anno.get("name", self.id)
+ self.min_pos_points = int(meta_anno.get("min_pos_points", 1))
+ self.min_neg_points = int(meta_anno.get("min_neg_points", -1))
+ self.startswith_box = bool(meta_anno.get("startswith_box", False))
+ self.startswith_box_optional = bool(meta_anno.get("startswith_box_optional", False))
+ self.animated_gif = meta_anno.get("animated_gif", "")
+ self.version = int(meta_anno.get("version", "1"))
+ self.help_message = meta_anno.get("help_message", "")
self.gateway = gateway
def to_dict(self):
response = {
- 'id': self.id,
- 'kind': str(self.kind),
- 'labels_v2': self.labels,
- 'description': self.description,
- 'name': self.name,
- 'version': self.version
+ "id": self.id,
+ "kind": str(self.kind),
+ "labels_v2": self.labels,
+ "description": self.description,
+ "name": self.name,
+ "version": self.version,
}
if self.kind is FunctionKind.INTERACTOR:
- response.update({
- 'min_pos_points': self.min_pos_points,
- 'min_neg_points': self.min_neg_points,
- 'startswith_box': self.startswith_box,
- 'startswith_box_optional': self.startswith_box_optional,
- 'help_message': self.help_message,
- 'animated_gif': self.animated_gif
- })
+ response.update(
+ {
+ "min_pos_points": self.min_pos_points,
+ "min_neg_points": self.min_neg_points,
+ "startswith_box": self.startswith_box,
+ "startswith_box_optional": self.startswith_box_optional,
+ "help_message": self.help_message,
+ "animated_gif": self.animated_gif,
+ }
+ )
return response
@@ -235,62 +273,75 @@ def invoke(
*,
db_job: Optional[Job] = None,
is_interactive: Optional[bool] = False,
- request: Optional[Request] = None
+ request: Optional[Request] = None,
):
if db_job is not None and db_job.get_task_id() != db_task.id:
- raise ValidationError("Job task id does not match task id",
- code=status.HTTP_400_BAD_REQUEST
+ raise ValidationError(
+ "Job task id does not match task id", code=status.HTTP_400_BAD_REQUEST
)
payload = {}
- data = {k: v for k,v in data.items() if v is not None}
+ data = {k: v for k, v in data.items() if v is not None}
def mandatory_arg(name: str) -> Any:
try:
return data[name]
except KeyError:
raise ValidationError(
- "`{}` lambda function was called without mandatory argument: {}"
- .format(self.id, name),
- code=status.HTTP_400_BAD_REQUEST)
+ "`{}` lambda function was called without mandatory argument: {}".format(
+ self.id, name
+ ),
+ code=status.HTTP_400_BAD_REQUEST,
+ )
threshold = data.get("threshold")
if threshold:
- payload.update({ "threshold": threshold })
+ payload.update({"threshold": threshold})
mapping = data.get("mapping", {})
model_labels = self.labels
task_labels = db_task.get_labels(prefetch=True)
def labels_compatible(model_label: dict, task_label: Label) -> bool:
- model_type = model_label['type']
+ model_type = model_label["type"]
db_type = task_label.type
compatible_types = [[ShapeType.MASK, ShapeType.POLYGON]]
- return model_type == db_type or \
- (db_type == 'any' and model_type != 'skeleton') or \
- (model_type == 'unknown' and db_type != 'skeleton') or \
- any([model_type in compatible and db_type in compatible for compatible in compatible_types])
+ return (
+ model_type == db_type
+ or (db_type == "any" and model_type != "skeleton")
+ or (model_type == "unknown" and db_type != "skeleton")
+ or any(
+ [
+ model_type in compatible and db_type in compatible
+ for compatible in compatible_types
+ ]
+ )
+ )
def make_default_mapping(model_labels, task_labels):
mapping_by_default = {}
for model_label in model_labels:
for task_label in task_labels:
- if task_label.name == model_label['name'] and labels_compatible(model_label, task_label):
+ if task_label.name == model_label["name"] and labels_compatible(
+ model_label, task_label
+ ):
attributes_default_mapping = {}
- for model_attr in model_label.get('attributes', {}):
+ for model_attr in model_label.get("attributes", {}):
for db_attr in task_label.attributespec_set.all():
- if db_attr.name == model_attr['name']:
- attributes_default_mapping[model_attr['name']] = db_attr.name
+ if db_attr.name == model_attr["name"]:
+ attributes_default_mapping[model_attr["name"]] = db_attr.name
- mapping_by_default[model_label['name']] = {
- 'name': task_label.name,
- 'attributes': attributes_default_mapping,
+ mapping_by_default[model_label["name"]] = {
+ "name": task_label.name,
+ "attributes": attributes_default_mapping,
}
- if model_label['type'] == 'skeleton' and task_label.type == 'skeleton':
- mapping_by_default[model_label['name']]['sublabels'] = make_default_mapping(
- model_label['sublabels'],
- task_label.sublabels.all(),
+ if model_label["type"] == "skeleton" and task_label.type == "skeleton":
+ mapping_by_default[model_label["name"]]["sublabels"] = (
+ make_default_mapping(
+ model_label["sublabels"],
+ task_label.sublabels.all(),
+ )
)
return mapping_by_default
@@ -298,39 +349,43 @@ def make_default_mapping(model_labels, task_labels):
def update_mapping(_mapping, _model_labels, _db_labels):
copy = deepcopy(_mapping)
for model_label_name, mapping_item in copy.items():
- md_label = next(filter(lambda x: x['name'] == model_label_name, _model_labels))
- db_label = next(filter(lambda x: x.name == mapping_item['name'], _db_labels))
- mapping_item.setdefault('attributes', {})
- mapping_item['md_label'] = md_label
- mapping_item['db_label'] = db_label
- if md_label['type'] == 'skeleton' and db_label.type == 'skeleton':
- mapping_item['sublabels'] = update_mapping(
- mapping_item['sublabels'],
- md_label['sublabels'],
- db_label.sublabels.all()
+ md_label = next(filter(lambda x: x["name"] == model_label_name, _model_labels))
+ db_label = next(filter(lambda x: x.name == mapping_item["name"], _db_labels))
+ mapping_item.setdefault("attributes", {})
+ mapping_item["md_label"] = md_label
+ mapping_item["db_label"] = db_label
+ if md_label["type"] == "skeleton" and db_label.type == "skeleton":
+ mapping_item["sublabels"] = update_mapping(
+ mapping_item["sublabels"], md_label["sublabels"], db_label.sublabels.all()
)
return copy
def validate_labels_mapping(_mapping, _model_labels, _db_labels):
def validate_attributes_mapping(attributes_mapping, model_attributes, db_attributes):
db_attr_names = [attr.name for attr in db_attributes]
- model_attr_names = [attr['name'] for attr in model_attributes]
+ model_attr_names = [attr["name"] for attr in model_attributes]
for model_attr in attributes_mapping:
task_attr = attributes_mapping[model_attr]
if model_attr not in model_attr_names:
- raise ValidationError(f'Invalid mapping. Unknown model attribute "{model_attr}"')
+ raise ValidationError(
+ f'Invalid mapping. Unknown model attribute "{model_attr}"'
+ )
if task_attr not in db_attr_names:
- raise ValidationError(f'Invalid mapping. Unknown db attribute "{task_attr}"')
+ raise ValidationError(
+ f'Invalid mapping. Unknown db attribute "{task_attr}"'
+ )
for model_label_name, mapping_item in _mapping.items():
- db_label_name = mapping_item['name']
+ db_label_name = mapping_item["name"]
md_label = None
db_label = None
try:
- md_label = next(x for x in _model_labels if x['name'] == model_label_name)
+ md_label = next(x for x in _model_labels if x["name"] == model_label_name)
except StopIteration:
- raise ValidationError(f'Invalid mapping. Unknown model label "{model_label_name}"')
+ raise ValidationError(
+ f'Invalid mapping. Unknown model label "{model_label_name}"'
+ )
try:
db_label = next(x for x in _db_labels if x.name == db_label_name)
@@ -339,26 +394,24 @@ def validate_attributes_mapping(attributes_mapping, model_attributes, db_attribu
if not labels_compatible(md_label, db_label):
raise ValidationError(
- f'Invalid mapping. Model label "{model_label_name}" and' + \
- f' database label "{db_label_name}" are not compatible'
+ f'Invalid mapping. Model label "{model_label_name}" and'
+ + f' database label "{db_label_name}" are not compatible'
)
validate_attributes_mapping(
- mapping_item.get('attributes', {}),
- md_label['attributes'],
- db_label.attributespec_set.all()
+ mapping_item.get("attributes", {}),
+ md_label["attributes"],
+ db_label.attributespec_set.all(),
)
- if md_label['type'] == 'skeleton' and db_label.type == 'skeleton':
- if 'sublabels' not in mapping_item:
+ if md_label["type"] == "skeleton" and db_label.type == "skeleton":
+ if "sublabels" not in mapping_item:
raise ValidationError(
f'Mapping for elements was not specified in skeleton "{model_label_name}" '
)
validate_labels_mapping(
- mapping_item['sublabels'],
- md_label['sublabels'],
- db_label.sublabels.all()
+ mapping_item["sublabels"], md_label["sublabels"], db_label.sublabels.all()
)
if not mapping:
@@ -380,44 +433,46 @@ def validate_attributes_mapping(attributes_mapping, model_attributes, db_attribu
abs_frame_id = data_start_frame + data[key] * step
if not db_job.segment.contains_frame(abs_frame_id):
- raise ValidationError(f"The {desc} is outside the job range",
- code=status.HTTP_400_BAD_REQUEST)
-
+ raise ValidationError(
+ f"The {desc} is outside the job range", code=status.HTTP_400_BAD_REQUEST
+ )
if self.kind == FunctionKind.DETECTOR:
- payload.update({
- "image": self._get_image(db_task, mandatory_arg("frame"))
- })
+ payload.update({"image": self._get_image(db_task, mandatory_arg("frame"))})
elif self.kind == FunctionKind.INTERACTOR:
- payload.update({
- "image": self._get_image(db_task, mandatory_arg("frame")),
- "pos_points": mandatory_arg("pos_points"),
- "neg_points": mandatory_arg("neg_points"),
- "obj_bbox": data.get("obj_bbox", None)
- })
+ payload.update(
+ {
+ "image": self._get_image(db_task, mandatory_arg("frame")),
+ "pos_points": mandatory_arg("pos_points"),
+ "neg_points": mandatory_arg("neg_points"),
+ "obj_bbox": data.get("obj_bbox", None),
+ }
+ )
elif self.kind == FunctionKind.REID:
- payload.update({
- "image0": self._get_image(db_task, mandatory_arg("frame0")),
- "image1": self._get_image(db_task, mandatory_arg("frame1")),
- "boxes0": mandatory_arg("boxes0"),
- "boxes1": mandatory_arg("boxes1")
- })
+ payload.update(
+ {
+ "image0": self._get_image(db_task, mandatory_arg("frame0")),
+ "image1": self._get_image(db_task, mandatory_arg("frame1")),
+ "boxes0": mandatory_arg("boxes0"),
+ "boxes1": mandatory_arg("boxes1"),
+ }
+ )
max_distance = data.get("max_distance")
if max_distance:
- payload.update({
- "max_distance": max_distance
- })
+ payload.update({"max_distance": max_distance})
elif self.kind == FunctionKind.TRACKER:
- payload.update({
- "image": self._get_image(db_task, mandatory_arg("frame")),
- "shapes": data.get("shapes", []),
- "states": data.get("states", [])
- })
+ payload.update(
+ {
+ "image": self._get_image(db_task, mandatory_arg("frame")),
+ "shapes": data.get("shapes", []),
+ "states": data.get("states", []),
+ }
+ )
else:
raise ValidationError(
- '`{}` lambda function has incorrect type: {}'
- .format(self.id, self.kind),
- code=status.HTTP_500_INTERNAL_SERVER_ERROR)
+ "`{}` lambda function has incorrect type: {}".format(self.id, self.kind),
+ code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ )
if is_interactive and request:
interactive_function_call_signal.send(sender=self, request=request)
@@ -445,41 +500,38 @@ def check_attr_value(value, db_attr):
def transform_attributes(input_attributes, attr_mapping, db_attributes):
attributes = []
for attr in input_attributes:
- if attr['name'] not in attr_mapping:
+ if attr["name"] not in attr_mapping:
continue
- db_attr_name = attr_mapping[attr['name']]
- db_attr = next(filter(lambda x: x['name'] == db_attr_name, db_attributes), None)
- if db_attr is not None and check_attr_value(attr['value'], db_attr):
- attributes.append({
- 'name': db_attr['name'],
- 'value': attr['value']
- })
+ db_attr_name = attr_mapping[attr["name"]]
+ db_attr = next(filter(lambda x: x["name"] == db_attr_name, db_attributes), None)
+ if db_attr is not None and check_attr_value(attr["value"], db_attr):
+ attributes.append({"name": db_attr["name"], "value": attr["value"]})
return attributes
if self.kind == FunctionKind.DETECTOR:
for item in response:
- item_label = item['label']
+ item_label = item["label"]
if item_label not in mapping:
continue
- db_label = mapping[item_label]['db_label']
- item['label'] = db_label.name
- item['attributes'] = transform_attributes(
- item.get('attributes', {}),
- mapping[item_label]['attributes'],
- db_label.attributespec_set.values()
+ db_label = mapping[item_label]["db_label"]
+ item["label"] = db_label.name
+ item["attributes"] = transform_attributes(
+ item.get("attributes", {}),
+ mapping[item_label]["attributes"],
+ db_label.attributespec_set.values(),
)
- if 'elements' in item:
- sublabels = mapping[item_label]['sublabels']
- item['elements'] = [x for x in item['elements'] if x['label'] in sublabels]
- for element in item['elements']:
- element_label = element['label']
- db_label = sublabels[element_label]['db_label']
- element['label'] = db_label.name
- element['attributes'] = transform_attributes(
- element.get('attributes', {}),
- sublabels[element_label]['attributes'],
- db_label.attributespec_set.values()
+ if "elements" in item:
+ sublabels = mapping[item_label]["sublabels"]
+ item["elements"] = [x for x in item["elements"] if x["label"] in sublabels]
+ for element in item["elements"]:
+ element_label = element["label"]
+ db_label = sublabels[element_label]["db_label"]
+ element["label"] = db_label.name
+ element["attributes"] = transform_attributes(
+ element.get("attributes", {}),
+ sublabels[element_label]["attributes"],
+ db_label.attributespec_set.values(),
)
response_filtered.append(item)
response = response_filtered
@@ -490,7 +542,8 @@ def _get_image(self, db_task, frame):
frame_provider = TaskFrameProvider(db_task)
image = frame_provider.get_frame(frame)
- return base64.b64encode(image.data.getvalue()).decode('utf-8')
+ return base64.b64encode(image.data.getvalue()).decode("utf-8")
+
class LambdaQueue:
RESULT_TTL = timedelta(minutes=30)
@@ -502,19 +555,29 @@ def _get_queue(self):
def get_jobs(self):
queue = self._get_queue()
# Only failed jobs are not included in the list below.
- job_ids = set(queue.get_job_ids() +
- queue.started_job_registry.get_job_ids() +
- queue.finished_job_registry.get_job_ids() +
- queue.scheduled_job_registry.get_job_ids() +
- queue.deferred_job_registry.get_job_ids())
+ job_ids = set(
+ queue.get_job_ids()
+ + queue.started_job_registry.get_job_ids()
+ + queue.finished_job_registry.get_job_ids()
+ + queue.scheduled_job_registry.get_job_ids()
+ + queue.deferred_job_registry.get_job_ids()
+ )
jobs = queue.job_class.fetch_many(job_ids, queue.connection)
return [LambdaJob(job) for job in jobs if job and job.meta.get("lambda")]
- def enqueue(self,
- lambda_func, threshold, task, mapping, cleanup, conv_mask_to_poly, max_distance, request,
+ def enqueue(
+ self,
+ lambda_func,
+ threshold,
+ task,
+ mapping,
+ cleanup,
+ conv_mask_to_poly,
+ max_distance,
+ request,
*,
- job: Optional[int] = None
+ job: Optional[int] = None,
) -> LambdaJob:
queue = self._get_queue()
rq_id = RQId(RequestAction.AUTOANNOTATE, RequestTarget.TASK, task).render()
@@ -524,8 +587,10 @@ def enqueue(self,
# protection.
rq_job = queue.fetch_job(rq_id)
- have_conflict = rq_job and \
- rq_job.get_status(refresh=False) not in {rq.job.JobStatus.FAILED, rq.job.JobStatus.FINISHED}
+ have_conflict = rq_job and rq_job.get_status(refresh=False) not in {
+ rq.job.JobStatus.FAILED,
+ rq.job.JobStatus.FINISHED,
+ }
# There could be some jobs left over from before the current naming convention was adopted.
# TODO: remove this check after a few releases.
@@ -536,7 +601,8 @@ def enqueue(self,
if have_conflict or have_legacy_conflict:
raise ValidationError(
"Only one running request is allowed for the same task #{}".format(task),
- code=status.HTTP_409_CONFLICT)
+ code=status.HTTP_409_CONFLICT,
+ )
if rq_job:
rq_job.delete()
@@ -548,14 +614,13 @@ def enqueue(self,
user_id = request.user.id
with get_rq_lock_by_user(queue, user_id):
- rq_job = queue.create_job(LambdaJob(None),
+ rq_job = queue.create_job(
+ LambdaJob(None),
job_id=rq_id,
meta={
**get_rq_job_meta(
request,
- db_obj=(
- Job.objects.get(pk=job) if job else Task.objects.get(pk=task)
- ),
+ db_obj=(Job.objects.get(pk=job) if job else Task.objects.get(pk=task)),
),
RQJobMetaField.FUNCTION_ID: lambda_func.id,
"lambda": True,
@@ -568,7 +633,7 @@ def enqueue(self,
"cleanup": cleanup,
"conv_mask_to_poly": conv_mask_to_poly,
"mapping": mapping,
- "max_distance": max_distance
+ "max_distance": max_distance,
},
depends_on=define_dependent_job(queue, user_id),
result_ttl=self.RESULT_TTL.total_seconds(),
@@ -583,36 +648,42 @@ def fetch_job(self, pk):
queue = self._get_queue()
rq_job = queue.fetch_job(pk)
if rq_job is None or not rq_job.meta.get("lambda"):
- raise ValidationError("{} lambda job is not found".format(pk),
- code=status.HTTP_404_NOT_FOUND)
+ raise ValidationError(
+ "{} lambda job is not found".format(pk), code=status.HTTP_404_NOT_FOUND
+ )
return LambdaJob(rq_job)
+
class LambdaJob:
def __init__(self, job):
self.job = job
def to_dict(self):
lambda_func = self.job.kwargs.get("function")
- dict_ = {
+ dict_ = {
"id": self.job.id,
"function": {
"id": lambda_func.id if lambda_func else None,
"threshold": self.job.kwargs.get("threshold"),
"task": self.job.kwargs.get("task"),
- **({
- "job": self.job.kwargs["job"],
- } if self.job.kwargs.get("job") else {})
+ **(
+ {
+ "job": self.job.kwargs["job"],
+ }
+ if self.job.kwargs.get("job")
+ else {}
+ ),
},
"status": self.job.get_status(),
- "progress": self.job.meta.get('progress', 0),
+ "progress": self.job.meta.get("progress", 0),
"enqueued": self.job.enqueued_at,
"started": self.job.started_at,
"ended": self.job.ended_at,
- "exc_info": self.job.exc_info
+ "exc_info": self.job.exc_info,
}
- if dict_['status'] == rq.job.JobStatus.DEFERRED:
- dict_['status'] = rq.job.JobStatus.QUEUED.value
+ if dict_["status"] == rq.job.JobStatus.DEFERRED:
+ dict_["status"] = rq.job.JobStatus.QUEUED.value
return dict_
@@ -659,7 +730,7 @@ def _call_detector(
mapping: Optional[dict[str, str]],
conv_mask_to_poly: bool,
*,
- db_job: Optional[Job] = None
+ db_job: Optional[Job] = None,
):
class Results:
def __init__(self, task_id, job_id: Optional[int] = None):
@@ -700,15 +771,16 @@ def parse_anno(anno, labels):
# Invalid label provided
return None
- attrs = [{
- 'spec_id': label['attributes'][attr['name']],
- 'value': attr['value']
- } for attr in anno.get('attributes', []) if attr['name'] in label['attributes']]
+ attrs = [
+ {"spec_id": label["attributes"][attr["name"]], "value": attr["value"]}
+ for attr in anno.get("attributes", [])
+ if attr["name"] in label["attributes"]
+ ]
if anno["type"].lower() == "tag":
return {
"frame": frame,
- "label_id": label['id'],
+ "label_id": label["id"],
"source": "auto",
"attributes": attrs,
"group": None,
@@ -716,14 +788,16 @@ def parse_anno(anno, labels):
else:
shape = {
"frame": frame,
- "label_id": label['id'],
+ "label_id": label["id"],
"source": "auto",
"attributes": attrs,
"group": anno["group_id"] if "group_id" in anno else None,
"type": anno["type"],
"occluded": False,
"outside": anno.get("outside", False),
- "points": anno.get("mask", []) if anno["type"] == "mask" else anno.get("points", []),
+ "points": (
+ anno.get("mask", []) if anno["type"] == "mask" else anno.get("points", [])
+ ),
"z_order": 0,
}
@@ -741,7 +815,7 @@ def parse_anno(anno, labels):
shape["points"] = rle
if shape["type"] == "skeleton":
- parsed_elements = [parse_anno(x, label['sublabels']) for x in anno["elements"]]
+ parsed_elements = [parse_anno(x, label["sublabels"]) for x in anno["elements"]]
# find a center to set position of missing points
center = [0, 0]
@@ -753,25 +827,26 @@ def parse_anno(anno, labels):
def _map(sublabel_body):
try:
- return next(filter(
- lambda x: x['label_id'] == sublabel_body['id'],
- parsed_elements)
+ return next(
+ filter(
+ lambda x: x["label_id"] == sublabel_body["id"], parsed_elements
+ )
)
except StopIteration:
return {
"frame": frame,
- "label_id": sublabel_body['id'],
+ "label_id": sublabel_body["id"],
"source": "auto",
"attributes": [],
"group": None,
- "type": sublabel_body['type'],
+ "type": sublabel_body["type"],
"occluded": False,
"points": center,
"outside": True,
"z_order": 0,
}
- shape["elements"] = list(map(_map, label['sublabels'].values()))
+ shape["elements"] = list(map(_map, label["sublabels"].values()))
if all(element["outside"] for element in shape["elements"]):
return None
@@ -785,10 +860,11 @@ def _map(sublabel_body):
if frame in db_task.data.deleted_frames:
continue
- annotations = function.invoke(db_task, db_job=db_job, data={
- "frame": frame, "mapping": mapping,
- "threshold": threshold
- })
+ annotations = function.invoke(
+ db_task,
+ db_job=db_job,
+ data={"frame": frame, "mapping": mapping, "threshold": threshold},
+ )
progress = (frame + 1) / db_task.data.size
if not cls._update_progress(progress):
@@ -828,8 +904,7 @@ def _get_frame_set(cls, db_task: Task, db_job: Optional[Job]):
data_start_frame = task_data.start_frame
step = task_data.get_frame_step()
frame_set = sorted(
- (abs_id - data_start_frame) // step
- for abs_id in db_job.segment.frame_set
+ (abs_id - data_start_frame) // step for abs_id in db_job.segment.frame_set
)
else:
frame_set = range(db_task.data.size)
@@ -844,7 +919,7 @@ def _call_reid(
threshold: float,
max_distance: int,
*,
- db_job: Optional[Job] = None
+ db_job: Optional[Job] = None,
):
if db_job:
data = dm.task.get_job_data(db_job.id)
@@ -872,10 +947,18 @@ def _call_reid(
boxes1 = boxes_by_frame[frame1]
if boxes0 and boxes1:
- matching = function.invoke(db_task, db_job=db_job, data={
- "frame0": frame0, "frame1": frame1,
- "boxes0": boxes0, "boxes1": boxes1, "threshold": threshold,
- "max_distance": max_distance})
+ matching = function.invoke(
+ db_task,
+ db_job=db_job,
+ data={
+ "frame0": frame0,
+ "frame1": frame1,
+ "boxes0": boxes0,
+ "boxes1": boxes1,
+ "threshold": threshold,
+ "max_distance": max_distance,
+ },
+ )
for idx0, idx1 in enumerate(matching):
if idx1 >= 0:
@@ -886,7 +969,6 @@ def _call_reid(
if not LambdaJob._update_progress((i + 1) / len(frame_set)):
break
-
for box in boxes_by_frame[frame_set[-1]]:
if "path_id" not in box:
path_id = len(paths)
@@ -896,14 +978,16 @@ def _call_reid(
tracks = []
for path_id in paths:
box0 = paths[path_id][0]
- tracks.append({
- "label_id": box0["label_id"],
- "group": None,
- "attributes": [],
- "frame": box0["frame"],
- "shapes": paths[path_id],
- "source": str(SourceType.AUTO)
- })
+ tracks.append(
+ {
+ "label_id": box0["label_id"],
+ "group": None,
+ "attributes": [],
+ "frame": box0["frame"],
+ "shapes": paths[path_id],
+ "source": str(SourceType.AUTO),
+ }
+ )
for box in tracks[-1]["shapes"]:
box.pop("id", None)
@@ -936,8 +1020,8 @@ def _call_reid(
def __call__(cls, function, task: int, cleanup: bool, **kwargs):
# TODO: need logging
db_job = None
- if job := kwargs.get('job'):
- db_job = Job.objects.select_related('segment', 'segment__task').get(pk=job)
+ if job := kwargs.get("job"):
+ db_job = Job.objects.select_related("segment", "segment__task").get(pk=job)
db_task = db_job.segment.task
else:
db_task = Task.objects.get(pk=task)
@@ -953,22 +1037,34 @@ def __call__(cls, function, task: int, cleanup: bool, **kwargs):
def convert_labels(db_labels):
labels = {}
for label in db_labels:
- labels[label.name] = {'id':label.id, 'attributes': {}, 'type': label.type}
- if label.type == 'skeleton':
- labels[label.name]['sublabels'] = convert_labels(label.sublabels.all())
+ labels[label.name] = {"id": label.id, "attributes": {}, "type": label.type}
+ if label.type == "skeleton":
+ labels[label.name]["sublabels"] = convert_labels(label.sublabels.all())
for attr in label.attributespec_set.values():
- labels[label.name]['attributes'][attr['name']] = attr['id']
+ labels[label.name]["attributes"][attr["name"]] = attr["id"]
return labels
labels = convert_labels(db_task.get_labels(prefetch=True))
if function.kind == FunctionKind.DETECTOR:
- cls._call_detector(function, db_task, labels,
- kwargs.get("threshold"), kwargs.get("mapping"), kwargs.get("conv_mask_to_poly"),
- db_job=db_job)
+ cls._call_detector(
+ function,
+ db_task,
+ labels,
+ kwargs.get("threshold"),
+ kwargs.get("mapping"),
+ kwargs.get("conv_mask_to_poly"),
+ db_job=db_job,
+ )
elif function.kind == FunctionKind.REID:
- cls._call_reid(function, db_task,
- kwargs.get("threshold"), kwargs.get("max_distance"), db_job=db_job)
+ cls._call_reid(
+ function,
+ db_task,
+ kwargs.get("threshold"),
+ kwargs.get("max_distance"),
+ db_job=db_job,
+ )
+
def return_response(success_code=status.HTTP_200_OK):
def wrap_response(func):
@@ -1000,23 +1096,28 @@ def func_wrapper(*args, **kwargs):
return Response(data=data, status=status_code)
return func_wrapper
+
return wrap_response
-@extend_schema(tags=['lambda'])
+
+@extend_schema(tags=["lambda"])
@extend_schema_view(
retrieve=extend_schema(
- operation_id='lambda_retrieve_functions',
- summary='Method returns the information about the function',
+ operation_id="lambda_retrieve_functions",
+ summary="Method returns the information about the function",
responses={
- '200': OpenApiResponse(response=OpenApiTypes.OBJECT, description='Information about the function'),
- }),
+ "200": OpenApiResponse(
+ response=OpenApiTypes.OBJECT, description="Information about the function"
+ ),
+ },
+ ),
list=extend_schema(
- operation_id='lambda_list_functions',
- summary='Method returns a list of functions')
+ operation_id="lambda_list_functions", summary="Method returns a list of functions"
+ ),
)
class FunctionViewSet(viewsets.ViewSet):
- lookup_value_regex = '[a-zA-Z0-9_.-]+'
- lookup_field = 'func_id'
+ lookup_value_regex = "[a-zA-Z0-9_.-]+"
+ lookup_field = "func_id"
iam_organization_field = None
serializer_class = None
@@ -1031,7 +1132,9 @@ def retrieve(self, request, func_id):
gateway = LambdaGateway()
return gateway.get(func_id).to_dict()
- @extend_schema(description=textwrap.dedent("""\
+ @extend_schema(
+ description=textwrap.dedent(
+ """\
Allows to execute a function for immediate computation.
Intended for short-lived executions, useful for interactive calls.
@@ -1039,44 +1142,51 @@ def retrieve(self, request, func_id):
When executed for interactive annotation, the job id must be specified
in the 'job' input field. The task id is not required in this case,
but if it is specified, it must match the job task id.
- """),
- request=inline_serializer("OnlineFunctionCall", fields={
- "job": serializers.IntegerField(required=False),
- "task": serializers.IntegerField(required=False),
- }),
- responses=OpenApiResponse(description="Returns function invocation results")
+ """
+ ),
+ request=inline_serializer(
+ "OnlineFunctionCall",
+ fields={
+ "job": serializers.IntegerField(required=False),
+ "task": serializers.IntegerField(required=False),
+ },
+ ),
+ responses=OpenApiResponse(description="Returns function invocation results"),
)
@return_response()
def call(self, request, func_id):
self.check_object_permissions(request, func_id)
try:
- job_id = request.data.get('job')
+ job_id = request.data.get("job")
job = None
if job_id is not None:
job = Job.objects.get(id=job_id)
task_id = job.get_task_id()
else:
- task_id = request.data['task']
+ task_id = request.data["task"]
db_task = Task.objects.get(pk=task_id)
except (KeyError, ObjectDoesNotExist) as err:
raise ValidationError(
- '`{}` lambda function was run '.format(func_id) +
- 'with wrong arguments ({})'.format(str(err)),
- code=status.HTTP_400_BAD_REQUEST)
+ "`{}` lambda function was run ".format(func_id)
+ + "with wrong arguments ({})".format(str(err)),
+ code=status.HTTP_400_BAD_REQUEST,
+ )
gateway = LambdaGateway()
lambda_func = gateway.get(func_id)
response = lambda_func.invoke(
db_task,
- request.data, # TODO: better to add validation via serializer for these data
+ request.data, # TODO: better to add validation via serializer for these data
db_job=job,
is_interactive=True,
- request=request
+ request=request,
)
- handle_function_call(func_id, db_task,
+ handle_function_call(
+ func_id,
+ db_task,
category="interactive",
parameters={
param_name: param_value
@@ -1088,41 +1198,44 @@ def call(self, request, func_id):
return response
-@extend_schema(tags=['lambda'])
+
+@extend_schema(tags=["lambda"])
@extend_schema_view(
retrieve=extend_schema(
- operation_id='lambda_retrieve_requests',
- summary='Method returns the status of the request',
+ operation_id="lambda_retrieve_requests",
+ summary="Method returns the status of the request",
parameters=[
- OpenApiParameter('id', location=OpenApiParameter.PATH, type=OpenApiTypes.STR,
- description='Request id'),
+ OpenApiParameter(
+ "id",
+ location=OpenApiParameter.PATH,
+ type=OpenApiTypes.STR,
+ description="Request id",
+ ),
],
- responses={
- '200': FunctionCallSerializer
- }
+ responses={"200": FunctionCallSerializer},
),
list=extend_schema(
- operation_id='lambda_list_requests',
- summary='Method returns a list of requests',
- responses={
- '200': FunctionCallSerializer(many=True)
- }
+ operation_id="lambda_list_requests",
+ summary="Method returns a list of requests",
+ responses={"200": FunctionCallSerializer(many=True)},
),
create=extend_schema(
parameters=ORGANIZATION_OPEN_API_PARAMETERS,
- summary='Method calls the function',
+ summary="Method calls the function",
request=FunctionCallRequestSerializer,
- responses={
- '200': FunctionCallSerializer
- }
+ responses={"200": FunctionCallSerializer},
),
destroy=extend_schema(
- operation_id='lambda_delete_requests',
- summary='Method cancels the request',
+ operation_id="lambda_delete_requests",
+ summary="Method cancels the request",
parameters=[
- OpenApiParameter('id', location=OpenApiParameter.PATH, type=OpenApiTypes.STR,
- description='Request id'),
- ]
+ OpenApiParameter(
+ "id",
+ location=OpenApiParameter.PATH,
+ type=OpenApiTypes.STR,
+ description="Request id",
+ ),
+ ],
),
)
class RequestViewSet(viewsets.ViewSet):
@@ -1158,25 +1271,35 @@ def create(self, request):
request_data = request_serializer.validated_data
try:
- function = request_data['function']
- threshold = request_data.get('threshold')
- task = request_data['task']
- job = request_data.get('job', None)
- cleanup = request_data.get('cleanup', False)
- conv_mask_to_poly = request_data.get('conv_mask_to_poly', False)
- mapping = request_data.get('mapping')
- max_distance = request_data.get('max_distance')
+ function = request_data["function"]
+ threshold = request_data.get("threshold")
+ task = request_data["task"]
+ job = request_data.get("job", None)
+ cleanup = request_data.get("cleanup", False)
+ conv_mask_to_poly = request_data.get("conv_mask_to_poly", False)
+ mapping = request_data.get("mapping")
+ max_distance = request_data.get("max_distance")
except KeyError as err:
raise ValidationError(
- '`{}` lambda function was run '.format(request_data.get('function', 'undefined')) +
- 'with wrong arguments ({})'.format(str(err)),
- code=status.HTTP_400_BAD_REQUEST)
+ "`{}` lambda function was run ".format(request_data.get("function", "undefined"))
+ + "with wrong arguments ({})".format(str(err)),
+ code=status.HTTP_400_BAD_REQUEST,
+ )
gateway = LambdaGateway()
queue = LambdaQueue()
lambda_func = gateway.get(function)
- rq_job = queue.enqueue(lambda_func, threshold, task,
- mapping, cleanup, conv_mask_to_poly, max_distance, request, job=job)
+ rq_job = queue.enqueue(
+ lambda_func,
+ threshold,
+ task,
+ mapping,
+ cleanup,
+ conv_mask_to_poly,
+ max_distance,
+ request,
+ job=job,
+ )
handle_function_call(function, job or task, category="batch")
diff --git a/cvat/apps/log_viewer/apps.py b/cvat/apps/log_viewer/apps.py
index 437c960e3929..a1806efc6462 100644
--- a/cvat/apps/log_viewer/apps.py
+++ b/cvat/apps/log_viewer/apps.py
@@ -6,8 +6,9 @@
class LogViewerConfig(AppConfig):
- name = 'cvat.apps.log_viewer'
+ name = "cvat.apps.log_viewer"
def ready(self) -> None:
from cvat.apps.iam.permissions import load_app_permissions
+
load_app_permissions(self)
diff --git a/cvat/apps/log_viewer/permissions.py b/cvat/apps/log_viewer/permissions.py
index d25aa7fe275a..4ad996fb7e67 100644
--- a/cvat/apps/log_viewer/permissions.py
+++ b/cvat/apps/log_viewer/permissions.py
@@ -12,12 +12,12 @@ class LogViewerPermission(OpenPolicyAgentPermission):
has_analytics_access: bool
class Scopes(StrEnum):
- VIEW = 'view'
+ VIEW = "view"
@classmethod
def create(cls, request, view, obj, iam_context):
permissions = []
- if view.basename == 'analytics':
+ if view.basename == "analytics":
for scope in cls.get_scopes(request, view, obj):
self = cls.create_base_perm(request, view, scope, iam_context, obj)
permissions.append(self)
@@ -33,20 +33,22 @@ def create_base_perm(cls, request, view, scope, iam_context, obj=None, **kwargs)
obj=obj,
has_analytics_access=request.user.profile.has_analytics_access,
**iam_context,
- **kwargs
+ **kwargs,
)
def __init__(self, has_analytics_access=False, **kwargs):
super().__init__(**kwargs)
- self.payload['input']['auth']['user']['has_analytics_access'] = has_analytics_access
- self.url = settings.IAM_OPA_DATA_URL + '/analytics/allow'
+ self.payload["input"]["auth"]["user"]["has_analytics_access"] = has_analytics_access
+ self.url = settings.IAM_OPA_DATA_URL + "/analytics/allow"
@staticmethod
def get_scopes(request, view, obj):
Scopes = __class__.Scopes
- return [{
- 'list': Scopes.VIEW,
- }[view.action]]
+ return [
+ {
+ "list": Scopes.VIEW,
+ }[view.action]
+ ]
def get_resource(self):
return None
diff --git a/cvat/apps/log_viewer/rules/tests/generators/analytics_test.gen.rego.py b/cvat/apps/log_viewer/rules/tests/generators/analytics_test.gen.rego.py
index 95d566e4b93a..12d28193cd9f 100644
--- a/cvat/apps/log_viewer/rules/tests/generators/analytics_test.gen.rego.py
+++ b/cvat/apps/log_viewer/rules/tests/generators/analytics_test.gen.rego.py
@@ -62,9 +62,15 @@ def eval_rule(scope, context, ownership, privilege, membership, data, has_analyt
)
)
rules = list(filter(lambda r: GROUPS.index(privilege) <= GROUPS.index(r["privilege"]), rules))
- rules = list(filter(lambda r: r["hasanalyticsaccess"] in ("na", str(has_analytics_access).lower()), rules))
+ rules = list(
+ filter(
+ lambda r: r["hasanalyticsaccess"] in ("na", str(has_analytics_access).lower()), rules
+ )
+ )
resource = data["resource"]
- rules = list(filter(lambda r: not r["limit"] or eval(r["limit"], {"resource": resource}), rules))
+ rules = list(
+ filter(lambda r: not r["limit"] or eval(r["limit"], {"resource": resource}), rules)
+ )
return bool(rules)
@@ -78,13 +84,15 @@ def get_data(scope, context, ownership, privilege, membership, resource, has_ana
"privilege": privilege,
"has_analytics_access": has_analytics_access,
},
- "organization": {
- "id": random.randrange(100, 200),
- "owner": {"id": random.randrange(200, 300)},
- "user": {"role": membership},
- }
- if context == "organization"
- else None,
+ "organization": (
+ {
+ "id": random.randrange(100, 200),
+ "owner": {"id": random.randrange(200, 300)},
+ "user": {"role": membership},
+ }
+ if context == "organization"
+ else None
+ ),
},
"resource": resource,
}
@@ -143,9 +151,15 @@ def gen_test_rego(name):
if not is_valid(scope, context, ownership, privilege, membership, resource):
continue
- data = get_data(scope, context, ownership, privilege, membership, resource, has_analytics_access)
- test_name = get_name(scope, context, ownership, privilege, membership, resource, has_analytics_access)
- result = eval_rule(scope, context, ownership, privilege, membership, data, has_analytics_access)
+ data = get_data(
+ scope, context, ownership, privilege, membership, resource, has_analytics_access
+ )
+ test_name = get_name(
+ scope, context, ownership, privilege, membership, resource, has_analytics_access
+ )
+ result = eval_rule(
+ scope, context, ownership, privilege, membership, data, has_analytics_access
+ )
f.write(
"{test_name} if {{\n {allow} with input as {data}\n}}\n\n".format(
test_name=test_name,
diff --git a/cvat/apps/log_viewer/urls.py b/cvat/apps/log_viewer/urls.py
index 0de56682a37e..96e88e38c9bc 100644
--- a/cvat/apps/log_viewer/urls.py
+++ b/cvat/apps/log_viewer/urls.py
@@ -1,4 +1,3 @@
-
# Copyright (C) 2018-2022 Intel Corporation
#
# SPDX-License-Identifier: MIT
@@ -8,6 +7,6 @@
from . import views
router = routers.DefaultRouter(trailing_slash=False)
-router.register('analytics', views.LogViewerAccessViewSet, basename='analytics')
+router.register("analytics", views.LogViewerAccessViewSet, basename="analytics")
urlpatterns = router.urls
diff --git a/cvat/apps/log_viewer/views.py b/cvat/apps/log_viewer/views.py
index 9e52f546c634..362f2bb97ec3 100644
--- a/cvat/apps/log_viewer/views.py
+++ b/cvat/apps/log_viewer/views.py
@@ -4,11 +4,11 @@
from django.conf import settings
from django.http import HttpResponsePermanentRedirect
+from drf_spectacular.utils import extend_schema
from rest_framework import status, viewsets
from rest_framework.decorators import action
from rest_framework.response import Response
-from drf_spectacular.utils import extend_schema
@extend_schema(exclude=True)
class LogViewerAccessViewSet(viewsets.ViewSet):
@@ -19,7 +19,7 @@ def list(self, request):
# All log view requests are proxied by Traefik in production mode which is not available in debug mode,
# In order not to duplicate settings, let's just redirect to the default page in debug mode
- @action(detail=False, url_path='dashboards')
+ @action(detail=False, url_path="dashboards")
def redirect(self, request):
if settings.DEBUG:
- return HttpResponsePermanentRedirect('http://localhost:3001/dashboards')
+ return HttpResponsePermanentRedirect("http://localhost:3001/dashboards")
diff --git a/cvat/apps/organizations/__init__.py b/cvat/apps/organizations/__init__.py
index b1220197cf2a..f7c3408e3d12 100644
--- a/cvat/apps/organizations/__init__.py
+++ b/cvat/apps/organizations/__init__.py
@@ -1,4 +1,3 @@
# Copyright (C) 2021-2022 Intel Corporation
#
# SPDX-License-Identifier: MIT
-
diff --git a/cvat/apps/organizations/admin.py b/cvat/apps/organizations/admin.py
index 756100244743..33e711189a1f 100644
--- a/cvat/apps/organizations/admin.py
+++ b/cvat/apps/organizations/admin.py
@@ -2,27 +2,29 @@
#
# SPDX-License-Identifier: MIT
-from .models import Organization, Membership
from django.contrib import admin
+from .models import Membership, Organization
+
+
class MembershipInline(admin.TabularInline):
model = Membership
extra = 0
radio_fields = {
- 'role': admin.VERTICAL,
+ "role": admin.VERTICAL,
}
- autocomplete_fields = ('user', )
+ autocomplete_fields = ("user",)
+
class OrganizationAdmin(admin.ModelAdmin):
- search_fields = ('slug', 'name', 'owner__username')
- list_display = ('id', 'slug', 'name')
+ search_fields = ("slug", "name", "owner__username")
+ list_display = ("id", "slug", "name")
+
+ autocomplete_fields = ("owner",)
- autocomplete_fields = ('owner', )
+ inlines = [MembershipInline]
- inlines = [
- MembershipInline
- ]
admin.site.register(Organization, OrganizationAdmin)
diff --git a/cvat/apps/organizations/apps.py b/cvat/apps/organizations/apps.py
index f73094af1723..ad654a0b8061 100644
--- a/cvat/apps/organizations/apps.py
+++ b/cvat/apps/organizations/apps.py
@@ -5,9 +5,11 @@
from django.apps import AppConfig
+
class OrganizationsConfig(AppConfig):
- name = 'cvat.apps.organizations'
+ name = "cvat.apps.organizations"
def ready(self) -> None:
from cvat.apps.iam.permissions import load_app_permissions
+
load_app_permissions(self)
diff --git a/cvat/apps/organizations/migrations/0001_initial.py b/cvat/apps/organizations/migrations/0001_initial.py
index 1d2689d343d1..5d4887a15fb2 100644
--- a/cvat/apps/organizations/migrations/0001_initial.py
+++ b/cvat/apps/organizations/migrations/0001_initial.py
@@ -1,8 +1,8 @@
# Generated by Django 3.2.8 on 2021-10-26 14:52
+import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
-import django.db.models.deletion
class Migration(migrations.Migration):
@@ -15,46 +15,103 @@ class Migration(migrations.Migration):
operations = [
migrations.CreateModel(
- name='Organization',
+ name="Organization",
fields=[
- ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
- ('slug', models.SlugField(max_length=16, unique=True)),
- ('name', models.CharField(blank=True, max_length=64)),
- ('description', models.TextField(blank=True)),
- ('created_date', models.DateTimeField(auto_now_add=True)),
- ('updated_date', models.DateTimeField(auto_now=True)),
- ('contact', models.JSONField(blank=True, default=dict)),
- ('owner', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='+', to=settings.AUTH_USER_MODEL)),
+ (
+ "id",
+ models.AutoField(
+ auto_created=True, primary_key=True, serialize=False, verbose_name="ID"
+ ),
+ ),
+ ("slug", models.SlugField(max_length=16, unique=True)),
+ ("name", models.CharField(blank=True, max_length=64)),
+ ("description", models.TextField(blank=True)),
+ ("created_date", models.DateTimeField(auto_now_add=True)),
+ ("updated_date", models.DateTimeField(auto_now=True)),
+ ("contact", models.JSONField(blank=True, default=dict)),
+ (
+ "owner",
+ models.ForeignKey(
+ blank=True,
+ null=True,
+ on_delete=django.db.models.deletion.SET_NULL,
+ related_name="+",
+ to=settings.AUTH_USER_MODEL,
+ ),
+ ),
],
options={
- 'default_permissions': (),
+ "default_permissions": (),
},
),
migrations.CreateModel(
- name='Membership',
+ name="Membership",
fields=[
- ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
- ('is_active', models.BooleanField(default=False)),
- ('joined_date', models.DateTimeField(null=True)),
- ('role', models.CharField(choices=[('worker', 'Worker'), ('supervisor', 'Supervisor'), ('maintainer', 'Maintainer'), ('owner', 'Owner')], max_length=16)),
- ('organization', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='members', to='organizations.organization')),
- ('user', models.ForeignKey(null=True, on_delete=django.db.models.deletion.CASCADE, related_name='memberships', to=settings.AUTH_USER_MODEL)),
+ (
+ "id",
+ models.AutoField(
+ auto_created=True, primary_key=True, serialize=False, verbose_name="ID"
+ ),
+ ),
+ ("is_active", models.BooleanField(default=False)),
+ ("joined_date", models.DateTimeField(null=True)),
+ (
+ "role",
+ models.CharField(
+ choices=[
+ ("worker", "Worker"),
+ ("supervisor", "Supervisor"),
+ ("maintainer", "Maintainer"),
+ ("owner", "Owner"),
+ ],
+ max_length=16,
+ ),
+ ),
+ (
+ "organization",
+ models.ForeignKey(
+ on_delete=django.db.models.deletion.CASCADE,
+ related_name="members",
+ to="organizations.organization",
+ ),
+ ),
+ (
+ "user",
+ models.ForeignKey(
+ null=True,
+ on_delete=django.db.models.deletion.CASCADE,
+ related_name="memberships",
+ to=settings.AUTH_USER_MODEL,
+ ),
+ ),
],
options={
- 'default_permissions': (),
- 'unique_together': {('user', 'organization')},
+ "default_permissions": (),
+ "unique_together": {("user", "organization")},
},
),
migrations.CreateModel(
- name='Invitation',
+ name="Invitation",
fields=[
- ('key', models.CharField(max_length=64, primary_key=True, serialize=False)),
- ('created_date', models.DateTimeField(auto_now_add=True)),
- ('membership', models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to='organizations.membership')),
- ('owner', models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, to=settings.AUTH_USER_MODEL)),
+ ("key", models.CharField(max_length=64, primary_key=True, serialize=False)),
+ ("created_date", models.DateTimeField(auto_now_add=True)),
+ (
+ "membership",
+ models.OneToOneField(
+ on_delete=django.db.models.deletion.CASCADE, to="organizations.membership"
+ ),
+ ),
+ (
+ "owner",
+ models.ForeignKey(
+ null=True,
+ on_delete=django.db.models.deletion.SET_NULL,
+ to=settings.AUTH_USER_MODEL,
+ ),
+ ),
],
options={
- 'default_permissions': (),
+ "default_permissions": (),
},
),
]
diff --git a/cvat/apps/organizations/models.py b/cvat/apps/organizations/models.py
index 3da77bafbebf..d582459866f8 100644
--- a/cvat/apps/organizations/models.py
+++ b/cvat/apps/organizations/models.py
@@ -4,55 +4,62 @@
# SPDX-License-Identifier: MIT
from datetime import timedelta
-from django.conf import settings
-from allauth.account.adapter import get_adapter
-from django.contrib.sites.shortcuts import get_current_site
-from drf_spectacular.types import OpenApiTypes
-from drf_spectacular.utils import extend_schema_field
-from django.db import models
+from allauth.account.adapter import get_adapter
+from django.conf import settings
from django.contrib.auth import get_user_model
+from django.contrib.sites.shortcuts import get_current_site
from django.core.exceptions import ImproperlyConfigured
+from django.db import models
from django.utils import timezone
+from drf_spectacular.types import OpenApiTypes
+from drf_spectacular.utils import extend_schema_field
from cvat.apps.engine.models import TimestampedModel
+
class Organization(TimestampedModel):
slug = models.SlugField(max_length=16, blank=False, unique=True)
name = models.CharField(max_length=64, blank=True)
description = models.TextField(blank=True)
contact = models.JSONField(blank=True, default=dict)
- owner = models.ForeignKey(get_user_model(), null=True,
- blank=True, on_delete=models.SET_NULL, related_name='+')
+ owner = models.ForeignKey(
+ get_user_model(), null=True, blank=True, on_delete=models.SET_NULL, related_name="+"
+ )
def __str__(self):
return self.slug
+
class Meta:
default_permissions = ()
+
class Membership(models.Model):
- WORKER = 'worker'
- SUPERVISOR = 'supervisor'
- MAINTAINER = 'maintainer'
- OWNER = 'owner'
-
- user = models.ForeignKey(get_user_model(), on_delete=models.CASCADE,
- null=True, related_name='memberships')
- organization = models.ForeignKey(Organization, on_delete=models.CASCADE,
- related_name='members')
+ WORKER = "worker"
+ SUPERVISOR = "supervisor"
+ MAINTAINER = "maintainer"
+ OWNER = "owner"
+
+ user = models.ForeignKey(
+ get_user_model(), on_delete=models.CASCADE, null=True, related_name="memberships"
+ )
+ organization = models.ForeignKey(Organization, on_delete=models.CASCADE, related_name="members")
is_active = models.BooleanField(default=False)
joined_date = models.DateTimeField(null=True)
- role = models.CharField(max_length=16, choices=[
- (WORKER, 'Worker'),
- (SUPERVISOR, 'Supervisor'),
- (MAINTAINER, 'Maintainer'),
- (OWNER, 'Owner'),
- ])
+ role = models.CharField(
+ max_length=16,
+ choices=[
+ (WORKER, "Worker"),
+ (SUPERVISOR, "Supervisor"),
+ (MAINTAINER, "Maintainer"),
+ (OWNER, "Owner"),
+ ],
+ )
class Meta:
default_permissions = ()
- unique_together = ('user', 'organization')
+ unique_together = ("user", "organization")
# Inspried by https://github.com/bee-keeper/django-invitations
@@ -94,16 +101,16 @@ def send(self, request):
site_name = current_site.name
domain = current_site.domain
context = {
- 'email': target_email,
- 'invitation_key': self.key,
- 'domain': domain,
- 'site_name': site_name,
- 'invitation_owner': self.owner.get_username(),
- 'organization_name': self.membership.organization.slug,
- 'protocol': 'https' if request.is_secure() else 'http',
+ "email": target_email,
+ "invitation_key": self.key,
+ "domain": domain,
+ "site_name": site_name,
+ "invitation_owner": self.owner.get_username(),
+ "organization_name": self.membership.organization.slug,
+ "protocol": "https" if request.is_secure() else "http",
}
- get_adapter(request).send_mail('invitation/invitation', target_email, context)
+ get_adapter(request).send_mail("invitation/invitation", target_email, context)
self.sent_date = timezone.now()
self.save()
diff --git a/cvat/apps/organizations/permissions.py b/cvat/apps/organizations/permissions.py
index e45b05d978c3..1e18cf5e20c5 100644
--- a/cvat/apps/organizations/permissions.py
+++ b/cvat/apps/organizations/permissions.py
@@ -9,18 +9,19 @@
from .models import Membership
+
class OrganizationPermission(OpenPolicyAgentPermission):
class Scopes(StrEnum):
- LIST = 'list'
- CREATE = 'create'
- DELETE = 'delete'
- UPDATE = 'update'
- VIEW = 'view'
+ LIST = "list"
+ CREATE = "create"
+ DELETE = "delete"
+ UPDATE = "update"
+ VIEW = "view"
@classmethod
def create(cls, request, view, obj, iam_context):
permissions = []
- if view.basename == 'organization':
+ if view.basename == "organization":
for scope in cls.get_scopes(request, view, obj):
self = cls.create_base_perm(request, view, scope, iam_context, obj)
permissions.append(self)
@@ -29,127 +30,116 @@ def create(cls, request, view, obj, iam_context):
def __init__(self, **kwargs):
super().__init__(**kwargs)
- self.url = settings.IAM_OPA_DATA_URL + '/organizations/allow'
+ self.url = settings.IAM_OPA_DATA_URL + "/organizations/allow"
@staticmethod
def get_scopes(request, view, obj):
Scopes = __class__.Scopes
- return [{
- 'list': Scopes.LIST,
- 'create': Scopes.CREATE,
- 'destroy': Scopes.DELETE,
- 'partial_update': Scopes.UPDATE,
- 'retrieve': Scopes.VIEW,
- }[view.action]]
+ return [
+ {
+ "list": Scopes.LIST,
+ "create": Scopes.CREATE,
+ "destroy": Scopes.DELETE,
+ "partial_update": Scopes.UPDATE,
+ "retrieve": Scopes.VIEW,
+ }[view.action]
+ ]
def get_resource(self):
if self.obj:
- membership = Membership.objects.filter(
- organization=self.obj, user=self.user_id).first()
+ membership = Membership.objects.filter(organization=self.obj, user=self.user_id).first()
return {
- 'id': self.obj.id,
- 'owner': {
- 'id': getattr(self.obj.owner, 'id', None)
- },
- 'user': {
- 'role': membership.role if membership else None
- }
+ "id": self.obj.id,
+ "owner": {"id": getattr(self.obj.owner, "id", None)},
+ "user": {"role": membership.role if membership else None},
}
elif self.scope.startswith(__class__.Scopes.CREATE.value):
- return {
- 'id': None,
- 'owner': {
- 'id': self.user_id
- },
- 'user': {
- 'role': 'owner'
- }
- }
+ return {"id": None, "owner": {"id": self.user_id}, "user": {"role": "owner"}}
else:
return None
+
class InvitationPermission(OpenPolicyAgentPermission):
class Scopes(StrEnum):
- LIST = 'list'
- CREATE = 'create'
- DELETE = 'delete'
- ACCEPT = 'accept'
- DECLINE = 'decline'
- RESEND = 'resend'
- VIEW = 'view'
+ LIST = "list"
+ CREATE = "create"
+ DELETE = "delete"
+ ACCEPT = "accept"
+ DECLINE = "decline"
+ RESEND = "resend"
+ VIEW = "view"
@classmethod
def create(cls, request, view, obj, iam_context):
permissions = []
- if view.basename == 'invitation':
+ if view.basename == "invitation":
for scope in cls.get_scopes(request, view, obj):
- self = cls.create_base_perm(request, view, scope, iam_context, obj,
- role=request.data.get('role'))
+ self = cls.create_base_perm(
+ request, view, scope, iam_context, obj, role=request.data.get("role")
+ )
permissions.append(self)
return permissions
def __init__(self, **kwargs):
super().__init__(**kwargs)
- self.role = kwargs.get('role')
- self.url = settings.IAM_OPA_DATA_URL + '/invitations/allow'
+ self.role = kwargs.get("role")
+ self.url = settings.IAM_OPA_DATA_URL + "/invitations/allow"
@staticmethod
def get_scopes(request, view, obj):
Scopes = __class__.Scopes
- return [{
- 'list': Scopes.LIST,
- 'create': Scopes.CREATE,
- 'destroy': Scopes.DELETE,
- 'partial_update': Scopes.ACCEPT if 'accepted' in
- request.query_params else Scopes.RESEND,
- 'retrieve': Scopes.VIEW,
- 'accept': Scopes.ACCEPT,
- 'decline': Scopes.DECLINE,
- 'resend': Scopes.RESEND,
- }[view.action]]
+ return [
+ {
+ "list": Scopes.LIST,
+ "create": Scopes.CREATE,
+ "destroy": Scopes.DELETE,
+ "partial_update": (
+ Scopes.ACCEPT if "accepted" in request.query_params else Scopes.RESEND
+ ),
+ "retrieve": Scopes.VIEW,
+ "accept": Scopes.ACCEPT,
+ "decline": Scopes.DECLINE,
+ "resend": Scopes.RESEND,
+ }[view.action]
+ ]
def get_resource(self):
data = None
if self.obj:
data = {
- 'owner': { 'id': getattr(self.obj.owner, 'id', None) },
- 'invitee': { 'id': getattr(self.obj.membership.user, 'id', None) },
- 'role': self.obj.membership.role,
- 'organization': {
- 'id': self.obj.membership.organization.id
- }
+ "owner": {"id": getattr(self.obj.owner, "id", None)},
+ "invitee": {"id": getattr(self.obj.membership.user, "id", None)},
+ "role": self.obj.membership.role,
+ "organization": {"id": self.obj.membership.organization.id},
}
elif self.scope.startswith(__class__.Scopes.CREATE.value):
data = {
- 'owner': { 'id': self.user_id },
- 'invitee': {
- 'id': None # unknown yet
- },
- 'role': self.role,
- 'organization': {
- 'id': self.org_id
- } if self.org_id is not None else None
+ "owner": {"id": self.user_id},
+ "invitee": {"id": None}, # unknown yet
+ "role": self.role,
+ "organization": {"id": self.org_id} if self.org_id is not None else None,
}
return data
+
class MembershipPermission(OpenPolicyAgentPermission):
class Scopes(StrEnum):
- LIST = 'list'
- UPDATE = 'change'
- UPDATE_ROLE = 'change:role'
- VIEW = 'view'
- DELETE = 'delete'
+ LIST = "list"
+ UPDATE = "change"
+ UPDATE_ROLE = "change:role"
+ VIEW = "view"
+ DELETE = "delete"
@classmethod
def create(cls, request, view, obj, iam_context):
permissions = []
- if view.basename == 'membership':
+ if view.basename == "membership":
for scope in cls.get_scopes(request, view, obj):
params = {}
- if scope == 'change:role':
- params['role'] = request.data.get('role')
+ if scope == "change:role":
+ params["role"] = request.data.get("role")
self = cls.create_base_perm(request, view, scope, iam_context, obj, **params)
permissions.append(self)
@@ -158,7 +148,7 @@ def create(cls, request, view, obj, iam_context):
def __init__(self, **kwargs):
super().__init__(**kwargs)
- self.url = settings.IAM_OPA_DATA_URL + '/memberships/allow'
+ self.url = settings.IAM_OPA_DATA_URL + "/memberships/allow"
@staticmethod
def get_scopes(request, view, obj):
@@ -166,16 +156,21 @@ def get_scopes(request, view, obj):
scopes = []
scope = {
- 'list': Scopes.LIST,
- 'partial_update': Scopes.UPDATE,
- 'retrieve': Scopes.VIEW,
- 'destroy': Scopes.DELETE,
+ "list": Scopes.LIST,
+ "partial_update": Scopes.UPDATE,
+ "retrieve": Scopes.VIEW,
+ "destroy": Scopes.DELETE,
}[view.action]
if scope == Scopes.UPDATE:
- scopes.extend(__class__.get_per_field_update_scopes(request, {
- 'role': Scopes.UPDATE_ROLE,
- }))
+ scopes.extend(
+ __class__.get_per_field_update_scopes(
+ request,
+ {
+ "role": Scopes.UPDATE_ROLE,
+ },
+ )
+ )
else:
scopes.append(scope)
@@ -184,10 +179,10 @@ def get_scopes(request, view, obj):
def get_resource(self):
if self.obj:
return {
- 'role': self.obj.role,
- 'is_active': self.obj.is_active,
- 'user': { 'id': self.obj.user.id },
- 'organization': { 'id': self.obj.organization.id }
+ "role": self.obj.role,
+ "is_active": self.obj.is_active,
+ "user": {"id": self.obj.user.id},
+ "organization": {"id": self.obj.organization.id},
}
else:
return None
diff --git a/cvat/apps/organizations/rules/tests/generators/invitations_test.gen.rego.py b/cvat/apps/organizations/rules/tests/generators/invitations_test.gen.rego.py
index bf7edec50713..39ff446d8eac 100644
--- a/cvat/apps/organizations/rules/tests/generators/invitations_test.gen.rego.py
+++ b/cvat/apps/organizations/rules/tests/generators/invitations_test.gen.rego.py
@@ -109,13 +109,15 @@ def get_data(scope, context, ownership, privilege, membership, resource, same_or
"scope": scope,
"auth": {
"user": {"id": random.randrange(0, 100), "privilege": privilege},
- "organization": {
- "id": random.randrange(100, 200),
- "owner": {"id": random.randrange(200, 300)},
- "user": {"role": membership},
- }
- if context == "organization"
- else None,
+ "organization": (
+ {
+ "id": random.randrange(100, 200),
+ "owner": {"id": random.randrange(200, 300)},
+ "user": {"role": membership},
+ }
+ if context == "organization"
+ else None
+ ),
},
"resource": resource,
}
diff --git a/cvat/apps/organizations/rules/tests/generators/memberships_test.gen.rego.py b/cvat/apps/organizations/rules/tests/generators/memberships_test.gen.rego.py
index c74a4a7c992b..09258163b2db 100644
--- a/cvat/apps/organizations/rules/tests/generators/memberships_test.gen.rego.py
+++ b/cvat/apps/organizations/rules/tests/generators/memberships_test.gen.rego.py
@@ -98,14 +98,14 @@ def eval_rule(scope, context, ownership, privilege, membership, data):
return False
if scope != "create" and not data["resource"]["is_active"]:
- is_staff = membership == "owner" or membership == 'maintainer'
+ is_staff = membership == "owner" or membership == "maintainer"
if is_staff:
- if scope != 'view':
+ if scope != "view":
if ORG_ROLES.index(membership) >= ORG_ROLES.index(resource["role"]):
return False
if GROUPS.index(privilege) > GROUPS.index("user"):
return False
- if resource["user"]['id'] == data["auth"]["user"]['id']:
+ if resource["user"]["id"] == data["auth"]["user"]["id"]:
return False
return True
return False
@@ -118,13 +118,15 @@ def get_data(scope, context, ownership, privilege, membership, resource, same_or
"scope": scope,
"auth": {
"user": {"id": random.randrange(0, 100), "privilege": privilege},
- "organization": {
- "id": random.randrange(100, 200),
- "owner": {"id": random.randrange(200, 300)},
- "user": {"role": membership},
- }
- if context == "organization"
- else None,
+ "organization": (
+ {
+ "id": random.randrange(100, 200),
+ "owner": {"id": random.randrange(200, 300)},
+ "user": {"role": membership},
+ }
+ if context == "organization"
+ else None
+ ),
},
"resource": resource,
}
diff --git a/cvat/apps/organizations/rules/tests/generators/organizations_test.gen.rego.py b/cvat/apps/organizations/rules/tests/generators/organizations_test.gen.rego.py
index d2a8a6fb653b..35f4fad15678 100644
--- a/cvat/apps/organizations/rules/tests/generators/organizations_test.gen.rego.py
+++ b/cvat/apps/organizations/rules/tests/generators/organizations_test.gen.rego.py
@@ -78,13 +78,15 @@ def get_data(scope, context, ownership, privilege, membership, resource):
"scope": scope,
"auth": {
"user": {"id": random.randrange(0, 100), "privilege": privilege},
- "organization": {
- "id": random.randrange(100, 200),
- "owner": {"id": random.randrange(200, 300)},
- "user": {"role": membership},
- }
- if context == "organization"
- else None,
+ "organization": (
+ {
+ "id": random.randrange(100, 200),
+ "owner": {"id": random.randrange(200, 300)},
+ "user": {"role": membership},
+ }
+ if context == "organization"
+ else None
+ ),
},
"resource": {**resource, "owner": {"id": random.randrange(300, 400)}} if resource else None,
}
diff --git a/cvat/apps/organizations/serializers.py b/cvat/apps/organizations/serializers.py
index 9cfb467aa3b9..6fe3a7a851a4 100644
--- a/cvat/apps/organizations/serializers.py
+++ b/cvat/apps/organizations/serializers.py
@@ -3,33 +3,46 @@
#
# SPDX-License-Identifier: MIT
-from attr.converters import to_bool
-from django.contrib.auth import get_user_model
from allauth.account.models import EmailAddress
-from django.core.exceptions import ObjectDoesNotExist
+from attr.converters import to_bool
from django.conf import settings
+from django.contrib.auth import get_user_model
from django.contrib.auth.models import User
+from django.core.exceptions import ObjectDoesNotExist
from django.db import transaction
-
from rest_framework import serializers
+
from cvat.apps.engine.serializers import BasicUserSerializer
from cvat.apps.iam.utils import get_dummy_user
+
from .models import Invitation, Membership, Organization
+
class OrganizationReadSerializer(serializers.ModelSerializer):
owner = BasicUserSerializer(allow_null=True)
+
class Meta:
model = Organization
- fields = ['id', 'slug', 'name', 'description', 'created_date',
- 'updated_date', 'contact', 'owner']
+ fields = [
+ "id",
+ "slug",
+ "name",
+ "description",
+ "created_date",
+ "updated_date",
+ "contact",
+ "owner",
+ ]
read_only_fields = fields
+
class BasicOrganizationSerializer(serializers.ModelSerializer):
class Meta:
model = Organization
- fields = ['id', 'slug']
+ fields = ["id", "slug"]
read_only_fields = fields
+
class OrganizationWriteSerializer(serializers.ModelSerializer):
def to_representation(self, instance):
serializer = OrganizationReadSerializer(instance, context=self.context)
@@ -37,12 +50,12 @@ def to_representation(self, instance):
class Meta:
model = Organization
- fields = ['slug', 'name', 'description', 'contact', 'owner']
+ fields = ["slug", "name", "description", "contact", "owner"]
# TODO: at the moment isn't possible to change the owner. It should
# be a separate feature. Need to change it together with corresponding
# Membership. Also such operation should be well protected.
- read_only_fields = ['owner']
+ read_only_fields = ["owner"]
def create(self, validated_data):
organization = super().create(validated_data)
@@ -51,36 +64,47 @@ def create(self, validated_data):
organization=organization,
is_active=True,
joined_date=organization.created_date,
- role=Membership.OWNER)
+ role=Membership.OWNER,
+ )
return organization
+
class InvitationReadSerializer(serializers.ModelSerializer):
- role = serializers.ChoiceField(Membership.role.field.choices,
- source='membership.role')
- user = BasicUserSerializer(source='membership.user')
+ role = serializers.ChoiceField(Membership.role.field.choices, source="membership.role")
+ user = BasicUserSerializer(source="membership.user")
organization = serializers.PrimaryKeyRelatedField(
- queryset=Organization.objects.all(),
- source='membership.organization')
- organization_info = BasicOrganizationSerializer(source='membership.organization')
+ queryset=Organization.objects.all(), source="membership.organization"
+ )
+ organization_info = BasicOrganizationSerializer(source="membership.organization")
owner = BasicUserSerializer(allow_null=True)
class Meta:
model = Invitation
- fields = ['key', 'created_date', 'owner', 'role', 'user', 'organization', 'expired', 'organization_info']
+ fields = [
+ "key",
+ "created_date",
+ "owner",
+ "role",
+ "user",
+ "organization",
+ "expired",
+ "organization_info",
+ ]
read_only_fields = fields
extra_kwargs = {
- 'expired': {
- 'allow_null': True,
+ "expired": {
+ "allow_null": True,
}
}
+
class InvitationWriteSerializer(serializers.ModelSerializer):
- role = serializers.ChoiceField(Membership.role.field.choices,
- source='membership.role')
- email = serializers.EmailField(source='membership.user.email')
+ role = serializers.ChoiceField(Membership.role.field.choices, source="membership.role")
+ email = serializers.EmailField(source="membership.user.email")
organization = serializers.PrimaryKeyRelatedField(
- source='membership.organization', read_only=True)
+ source="membership.organization", read_only=True
+ )
def to_representation(self, instance):
serializer = InvitationReadSerializer(instance, context=self.context)
@@ -88,34 +112,35 @@ def to_representation(self, instance):
class Meta:
model = Invitation
- fields = ['key', 'created_date', 'owner', 'role', 'organization', 'email']
- read_only_fields = ['key', 'created_date', 'owner', 'organization']
+ fields = ["key", "created_date", "owner", "role", "organization", "email"]
+ read_only_fields = ["key", "created_date", "owner", "organization"]
@transaction.atomic
def create(self, validated_data):
- membership_data = validated_data.pop('membership')
- organization = validated_data.pop('organization')
+ membership_data = validated_data.pop("membership")
+ organization = validated_data.pop("organization")
try:
- user = get_user_model().objects.get(
- email__iexact=membership_data['user']['email'])
- del membership_data['user']
+ user = get_user_model().objects.get(email__iexact=membership_data["user"]["email"])
+ del membership_data["user"]
except ObjectDoesNotExist:
- user_email = membership_data['user']['email']
+ user_email = membership_data["user"]["email"]
user = User.objects.create_user(username=user_email, email=user_email)
user.set_unusable_password()
# User.objects.create_user(...) normalizes passed email and user.email can be different from original user_email
- email = EmailAddress.objects.create(user=user, email=user.email, primary=True, verified=False)
+ email = EmailAddress.objects.create(
+ user=user, email=user.email, primary=True, verified=False
+ )
user.save()
email.save()
- del membership_data['user']
+ del membership_data["user"]
membership, created = Membership.objects.get_or_create(
- defaults=membership_data,
- user=user, organization=organization)
+ defaults=membership_data, user=user, organization=organization
+ )
if not created:
- raise serializers.ValidationError('The user is a member of '
- 'the organization already.')
- invitation = Invitation.objects.create(**validated_data,
- membership=membership)
+ raise serializers.ValidationError(
+ "The user is a member of " "the organization already."
+ )
+ invitation = Invitation.objects.create(**validated_data, membership=membership)
return invitation
@@ -132,20 +157,21 @@ def save(self, request, **kwargs):
return invitation
+
class MembershipReadSerializer(serializers.ModelSerializer):
user = BasicUserSerializer()
class Meta:
model = Membership
- fields = ['id', 'user', 'organization', 'is_active', 'joined_date', 'role',
- 'invitation']
+ fields = ["id", "user", "organization", "is_active", "joined_date", "role", "invitation"]
read_only_fields = fields
extra_kwargs = {
- 'invitation': {
- 'allow_null': True, # owner of an organization does not have an invitation
+ "invitation": {
+ "allow_null": True, # owner of an organization does not have an invitation
}
}
+
class MembershipWriteSerializer(serializers.ModelSerializer):
def to_representation(self, instance):
serializer = MembershipReadSerializer(instance, context=self.context)
@@ -153,8 +179,9 @@ def to_representation(self, instance):
class Meta:
model = Membership
- fields = ['id', 'user', 'organization', 'is_active', 'joined_date', 'role']
- read_only_fields = ['user', 'organization', 'is_active', 'joined_date']
+ fields = ["id", "user", "organization", "is_active", "joined_date", "role"]
+ read_only_fields = ["user", "organization", "is_active", "joined_date"]
+
class AcceptInvitationReadSerializer(serializers.Serializer):
organization_slug = serializers.CharField()
diff --git a/cvat/apps/organizations/throttle.py b/cvat/apps/organizations/throttle.py
index 438538b61d4a..342b9463170b 100644
--- a/cvat/apps/organizations/throttle.py
+++ b/cvat/apps/organizations/throttle.py
@@ -4,5 +4,6 @@
from rest_framework.throttling import UserRateThrottle
+
class ResendOrganizationInvitationThrottle(UserRateThrottle):
- rate = '5/hour'
+ rate = "5/hour"
diff --git a/cvat/apps/organizations/urls.py b/cvat/apps/organizations/urls.py
index 068f72b0968d..4ec7fdc628bc 100644
--- a/cvat/apps/organizations/urls.py
+++ b/cvat/apps/organizations/urls.py
@@ -3,11 +3,12 @@
# SPDX-License-Identifier: MIT
from rest_framework.routers import DefaultRouter
+
from .views import InvitationViewSet, MembershipViewSet, OrganizationViewSet
router = DefaultRouter(trailing_slash=False)
-router.register('organizations', OrganizationViewSet)
-router.register('invitations', InvitationViewSet)
-router.register('memberships', MembershipViewSet)
+router.register("organizations", OrganizationViewSet)
+router.register("invitations", InvitationViewSet)
+router.register("memberships", MembershipViewSet)
urlpatterns = router.urls
diff --git a/cvat/apps/organizations/views.py b/cvat/apps/organizations/views.py
index 11b92b29cad8..dbb1eeec9a9c 100644
--- a/cvat/apps/organizations/views.py
+++ b/cvat/apps/organizations/views.py
@@ -3,76 +3,87 @@
#
# SPDX-License-Identifier: MIT
-from django.utils.crypto import get_random_string
-from django.db import transaction
from django.core.exceptions import ImproperlyConfigured
-
-from rest_framework import mixins, viewsets, status
-from rest_framework.permissions import SAFE_METHODS
+from django.db import transaction
+from django.utils.crypto import get_random_string
+from drf_spectacular.utils import OpenApiResponse, extend_schema, extend_schema_view
+from rest_framework import mixins, status, viewsets
from rest_framework.decorators import action
+from rest_framework.permissions import SAFE_METHODS
from rest_framework.response import Response
-from drf_spectacular.utils import OpenApiResponse, extend_schema, extend_schema_view
-
+from cvat.apps.engine.mixins import PartialUpdateModelMixin
from cvat.apps.iam.filters import ORGANIZATION_OPEN_API_PARAMETERS
from cvat.apps.organizations.permissions import (
- InvitationPermission, MembershipPermission, OrganizationPermission)
+ InvitationPermission,
+ MembershipPermission,
+ OrganizationPermission,
+)
from cvat.apps.organizations.throttle import ResendOrganizationInvitationThrottle
-from cvat.apps.engine.mixins import PartialUpdateModelMixin
from .models import Invitation, Membership, Organization
-
from .serializers import (
- InvitationReadSerializer, InvitationWriteSerializer,
- MembershipReadSerializer, MembershipWriteSerializer,
- OrganizationReadSerializer, OrganizationWriteSerializer,
- AcceptInvitationReadSerializer)
+ AcceptInvitationReadSerializer,
+ InvitationReadSerializer,
+ InvitationWriteSerializer,
+ MembershipReadSerializer,
+ MembershipWriteSerializer,
+ OrganizationReadSerializer,
+ OrganizationWriteSerializer,
+)
+
-@extend_schema(tags=['organizations'])
+@extend_schema(tags=["organizations"])
@extend_schema_view(
retrieve=extend_schema(
- summary='Get organization details',
+ summary="Get organization details",
responses={
- '200': OrganizationReadSerializer,
- }),
+ "200": OrganizationReadSerializer,
+ },
+ ),
list=extend_schema(
- summary='List organizations',
+ summary="List organizations",
responses={
- '200': OrganizationReadSerializer(many=True),
- }),
+ "200": OrganizationReadSerializer(many=True),
+ },
+ ),
partial_update=extend_schema(
- summary='Update an organization',
+ summary="Update an organization",
request=OrganizationWriteSerializer(partial=True),
responses={
- '200': OrganizationReadSerializer, # check OrganizationWriteSerializer.to_representation
- }),
+ "200": OrganizationReadSerializer, # check OrganizationWriteSerializer.to_representation
+ },
+ ),
create=extend_schema(
- summary='Create an organization',
+ summary="Create an organization",
request=OrganizationWriteSerializer,
responses={
- '201': OrganizationReadSerializer, # check OrganizationWriteSerializer.to_representation
- }),
+ "201": OrganizationReadSerializer, # check OrganizationWriteSerializer.to_representation
+ },
+ ),
destroy=extend_schema(
- summary='Delete an organization',
+ summary="Delete an organization",
responses={
- '204': OpenApiResponse(description='The organization has been deleted'),
- })
+ "204": OpenApiResponse(description="The organization has been deleted"),
+ },
+ ),
)
-class OrganizationViewSet(viewsets.GenericViewSet,
- mixins.RetrieveModelMixin,
- mixins.ListModelMixin,
- mixins.CreateModelMixin,
- mixins.DestroyModelMixin,
- PartialUpdateModelMixin,
- ):
- queryset = Organization.objects.select_related('owner').all()
- search_fields = ('name', 'owner', 'slug')
- filter_fields = list(search_fields) + ['id']
+class OrganizationViewSet(
+ viewsets.GenericViewSet,
+ mixins.RetrieveModelMixin,
+ mixins.ListModelMixin,
+ mixins.CreateModelMixin,
+ mixins.DestroyModelMixin,
+ PartialUpdateModelMixin,
+):
+ queryset = Organization.objects.select_related("owner").all()
+ search_fields = ("name", "owner", "slug")
+ filter_fields = list(search_fields) + ["id"]
simple_filters = list(search_fields)
- lookup_fields = {'owner': 'owner__username'}
+ lookup_fields = {"owner": "owner__username"}
ordering_fields = list(filter_fields)
- ordering = '-id'
- http_method_names = ['get', 'post', 'patch', 'delete', 'head', 'options']
+ ordering = "-id"
+ http_method_names = ["get", "post", "patch", "delete", "head", "options"]
iam_organization_field = None
def get_queryset(self):
@@ -88,50 +99,60 @@ def get_serializer_class(self):
return OrganizationWriteSerializer
def perform_create(self, serializer):
- extra_kwargs = { 'owner': self.request.user }
- if not serializer.validated_data.get('name'):
- extra_kwargs.update({ 'name': serializer.validated_data['slug'] })
+ extra_kwargs = {"owner": self.request.user}
+ if not serializer.validated_data.get("name"):
+ extra_kwargs.update({"name": serializer.validated_data["slug"]})
serializer.save(**extra_kwargs)
class Meta:
model = Membership
- fields = ("user", )
+ fields = ("user",)
-@extend_schema(tags=['memberships'])
+
+@extend_schema(tags=["memberships"])
@extend_schema_view(
retrieve=extend_schema(
- summary='Get membership details',
+ summary="Get membership details",
responses={
- '200': MembershipReadSerializer,
- }),
+ "200": MembershipReadSerializer,
+ },
+ ),
list=extend_schema(
- summary='List memberships',
+ summary="List memberships",
responses={
- '200': MembershipReadSerializer(many=True),
- }),
+ "200": MembershipReadSerializer(many=True),
+ },
+ ),
partial_update=extend_schema(
- summary='Update a membership',
+ summary="Update a membership",
request=MembershipWriteSerializer(partial=True),
responses={
- '200': MembershipReadSerializer, # check MembershipWriteSerializer.to_representation
- }),
+ "200": MembershipReadSerializer, # check MembershipWriteSerializer.to_representation
+ },
+ ),
destroy=extend_schema(
- summary='Delete a membership',
+ summary="Delete a membership",
responses={
- '204': OpenApiResponse(description='The membership has been deleted'),
- })
+ "204": OpenApiResponse(description="The membership has been deleted"),
+ },
+ ),
)
-class MembershipViewSet(mixins.RetrieveModelMixin, mixins.DestroyModelMixin,
- mixins.ListModelMixin, PartialUpdateModelMixin, viewsets.GenericViewSet):
- queryset = Membership.objects.select_related('invitation', 'user').all()
- ordering = '-id'
- http_method_names = ['get', 'patch', 'delete', 'head', 'options']
- search_fields = ('user', 'role')
- filter_fields = list(search_fields) + ['id']
+class MembershipViewSet(
+ mixins.RetrieveModelMixin,
+ mixins.DestroyModelMixin,
+ mixins.ListModelMixin,
+ PartialUpdateModelMixin,
+ viewsets.GenericViewSet,
+):
+ queryset = Membership.objects.select_related("invitation", "user").all()
+ ordering = "-id"
+ http_method_names = ["get", "patch", "delete", "head", "options"]
+ search_fields = ("user", "role")
+ filter_fields = list(search_fields) + ["id"]
simple_filters = list(search_fields)
ordering_fields = list(filter_fields)
- lookup_fields = {'user': 'user__username'}
- iam_organization_field = 'organization'
+ lookup_fields = {"user": "user__username"}
+ iam_organization_field = "organization"
def get_serializer_class(self):
if self.request.method in SAFE_METHODS:
@@ -142,86 +163,98 @@ def get_serializer_class(self):
def get_queryset(self):
queryset = super().get_queryset()
- if self.action == 'list':
+ if self.action == "list":
permission = MembershipPermission.create_scope_list(self.request)
queryset = permission.filter(queryset)
return queryset
-@extend_schema(tags=['invitations'])
+
+@extend_schema(tags=["invitations"])
@extend_schema_view(
retrieve=extend_schema(
- summary='Get invitation details',
+ summary="Get invitation details",
responses={
- '200': InvitationReadSerializer,
- }),
+ "200": InvitationReadSerializer,
+ },
+ ),
list=extend_schema(
- summary='List invitations',
+ summary="List invitations",
responses={
- '200': InvitationReadSerializer(many=True),
- }),
+ "200": InvitationReadSerializer(many=True),
+ },
+ ),
partial_update=extend_schema(
- summary='Update an invitation',
+ summary="Update an invitation",
request=InvitationWriteSerializer(partial=True),
responses={
- '200': InvitationReadSerializer, # check InvitationWriteSerializer.to_representation
- }),
+ "200": InvitationReadSerializer, # check InvitationWriteSerializer.to_representation
+ },
+ ),
create=extend_schema(
- summary='Create an invitation',
+ summary="Create an invitation",
request=InvitationWriteSerializer,
parameters=ORGANIZATION_OPEN_API_PARAMETERS,
responses={
- '201': InvitationReadSerializer, # check InvitationWriteSerializer.to_representation
- }),
+ "201": InvitationReadSerializer, # check InvitationWriteSerializer.to_representation
+ },
+ ),
destroy=extend_schema(
- summary='Delete an invitation',
+ summary="Delete an invitation",
responses={
- '204': OpenApiResponse(description='The invitation has been deleted'),
- }),
+ "204": OpenApiResponse(description="The invitation has been deleted"),
+ },
+ ),
accept=extend_schema(
- operation_id='invitations_accept',
+ operation_id="invitations_accept",
request=None,
- summary='Accept an invitation',
+ summary="Accept an invitation",
responses={
- '200': OpenApiResponse(response=AcceptInvitationReadSerializer, description='The invitation is accepted'),
- '400': OpenApiResponse(description='The invitation is expired or already accepted'),
- }),
+ "200": OpenApiResponse(
+ response=AcceptInvitationReadSerializer, description="The invitation is accepted"
+ ),
+ "400": OpenApiResponse(description="The invitation is expired or already accepted"),
+ },
+ ),
decline=extend_schema(
- operation_id='invitations_decline',
+ operation_id="invitations_decline",
request=None,
- summary='Decline an invitation',
+ summary="Decline an invitation",
responses={
- '204': OpenApiResponse(description='The invitation has been declined'),
- }),
+ "204": OpenApiResponse(description="The invitation has been declined"),
+ },
+ ),
resend=extend_schema(
- operation_id='invitations_resend',
- summary='Resend an invitation',
+ operation_id="invitations_resend",
+ summary="Resend an invitation",
request=None,
responses={
- '204': OpenApiResponse(description='Invitation has been sent'),
- '400': OpenApiResponse(description='The invitation is already accepted'),
- }),
+ "204": OpenApiResponse(description="Invitation has been sent"),
+ "400": OpenApiResponse(description="The invitation is already accepted"),
+ },
+ ),
)
-class InvitationViewSet(viewsets.GenericViewSet,
- mixins.RetrieveModelMixin,
- mixins.ListModelMixin,
- PartialUpdateModelMixin,
- mixins.CreateModelMixin,
- mixins.DestroyModelMixin,
- ):
+class InvitationViewSet(
+ viewsets.GenericViewSet,
+ mixins.RetrieveModelMixin,
+ mixins.ListModelMixin,
+ PartialUpdateModelMixin,
+ mixins.CreateModelMixin,
+ mixins.DestroyModelMixin,
+):
queryset = Invitation.objects.all()
- http_method_names = ['get', 'post', 'patch', 'delete', 'head', 'options']
- iam_organization_field = 'membership__organization'
+ http_method_names = ["get", "post", "patch", "delete", "head", "options"]
+ iam_organization_field = "membership__organization"
- search_fields = ('owner',)
- filter_fields = list(search_fields) + ['user_id', 'accepted']
+ search_fields = ("owner",)
+ filter_fields = list(search_fields) + ["user_id", "accepted"]
simple_filters = list(search_fields)
- ordering_fields = list(simple_filters) + ['created_date']
- ordering = '-created_date'
+ ordering_fields = list(simple_filters) + ["created_date"]
+ ordering = "-created_date"
lookup_fields = {
- 'owner': 'owner__username',
- 'user_id': 'membership__user__id',
- 'accepted': 'membership__is_active',
+ "owner": "owner__username",
+ "user_id": "membership__user__id",
+ "accepted": "membership__is_active",
}
def get_serializer_class(self):
@@ -242,7 +275,10 @@ def create(self, request):
try:
self.perform_create(serializer)
except ImproperlyConfigured:
- return Response(status=status.HTTP_500_INTERNAL_SERVER_ERROR, data="Email backend is not configured.")
+ return Response(
+ status=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ data="Email backend is not configured.",
+ )
return Response(serializer.data, status=status.HTTP_201_CREATED)
@@ -250,51 +286,75 @@ def perform_create(self, serializer):
serializer.save(
owner=self.request.user,
key=get_random_string(length=64),
- organization=self.request.iam_context['organization'],
+ organization=self.request.iam_context["organization"],
request=self.request,
)
def perform_update(self, serializer):
- if 'accepted' in self.request.query_params:
+ if "accepted" in self.request.query_params:
serializer.instance.accept()
else:
super().perform_update(serializer)
@transaction.atomic
- @action(detail=True, methods=['POST'], url_path='accept')
+ @action(detail=True, methods=["POST"], url_path="accept")
def accept(self, request, pk):
try:
- invitation = self.get_object() # force to call check_object_permissions
+ invitation = self.get_object() # force to call check_object_permissions
if invitation.expired:
- return Response(status=status.HTTP_400_BAD_REQUEST, data="Your invitation is expired. Please contact organization owner to renew it.")
+ return Response(
+ status=status.HTTP_400_BAD_REQUEST,
+ data="Your invitation is expired. Please contact organization owner to renew it.",
+ )
if invitation.membership.is_active:
- return Response(status=status.HTTP_400_BAD_REQUEST, data="Your invitation is already accepted.")
+ return Response(
+ status=status.HTTP_400_BAD_REQUEST, data="Your invitation is already accepted."
+ )
invitation.accept()
- response_serializer = AcceptInvitationReadSerializer(data={'organization_slug': invitation.membership.organization.slug})
+ response_serializer = AcceptInvitationReadSerializer(
+ data={"organization_slug": invitation.membership.organization.slug}
+ )
response_serializer.is_valid(raise_exception=True)
return Response(status=status.HTTP_200_OK, data=response_serializer.data)
except Invitation.DoesNotExist:
- return Response(status=status.HTTP_404_NOT_FOUND, data="This invitation does not exist. Please contact organization owner.")
+ return Response(
+ status=status.HTTP_404_NOT_FOUND,
+ data="This invitation does not exist. Please contact organization owner.",
+ )
- @action(detail=True, methods=['POST'], url_path='resend', throttle_classes=[ResendOrganizationInvitationThrottle])
+ @action(
+ detail=True,
+ methods=["POST"],
+ url_path="resend",
+ throttle_classes=[ResendOrganizationInvitationThrottle],
+ )
def resend(self, request, pk):
try:
- invitation = self.get_object() # force to call check_object_permissions
+ invitation = self.get_object() # force to call check_object_permissions
if invitation.membership.is_active:
- return Response(status=status.HTTP_400_BAD_REQUEST, data="This invitation is already accepted.")
+ return Response(
+ status=status.HTTP_400_BAD_REQUEST, data="This invitation is already accepted."
+ )
invitation.send(request)
return Response(status=status.HTTP_204_NO_CONTENT)
except Invitation.DoesNotExist:
- return Response(status=status.HTTP_404_NOT_FOUND, data="This invitation does not exist.")
+ return Response(
+ status=status.HTTP_404_NOT_FOUND, data="This invitation does not exist."
+ )
except ImproperlyConfigured:
- return Response(status=status.HTTP_500_INTERNAL_SERVER_ERROR, data="Email backend is not configured.")
+ return Response(
+ status=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ data="Email backend is not configured.",
+ )
- @action(detail=True, methods=['POST'], url_path='decline')
+ @action(detail=True, methods=["POST"], url_path="decline")
def decline(self, request, pk):
try:
- invitation = self.get_object() # force to call check_object_permissions
+ invitation = self.get_object() # force to call check_object_permissions
membership = invitation.membership
membership.delete()
return Response(status=status.HTTP_204_NO_CONTENT)
except Invitation.DoesNotExist:
- return Response(status=status.HTTP_404_NOT_FOUND, data="This invitation does not exist.")
+ return Response(
+ status=status.HTTP_404_NOT_FOUND, data="This invitation does not exist."
+ )
diff --git a/cvat/apps/quality_control/migrations/0006_rename_match_empty_frames_qualitysettings_empty_is_annotated.py b/cvat/apps/quality_control/migrations/0006_rename_match_empty_frames_qualitysettings_empty_is_annotated.py
new file mode 100644
index 000000000000..ea2f74927309
--- /dev/null
+++ b/cvat/apps/quality_control/migrations/0006_rename_match_empty_frames_qualitysettings_empty_is_annotated.py
@@ -0,0 +1,18 @@
+# Generated by Django 4.2.15 on 2024-12-29 19:08
+
+from django.db import migrations
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ("quality_control", "0005_qualitysettings_match_empty"),
+ ]
+
+ operations = [
+ migrations.RenameField(
+ model_name="qualitysettings",
+ old_name="match_empty_frames",
+ new_name="empty_is_annotated",
+ ),
+ ]
diff --git a/cvat/apps/quality_control/models.py b/cvat/apps/quality_control/models.py
index a5359e4fe944..c521ac276f31 100644
--- a/cvat/apps/quality_control/models.py
+++ b/cvat/apps/quality_control/models.py
@@ -235,7 +235,7 @@ class QualitySettings(models.Model):
compare_attributes = models.BooleanField()
- match_empty_frames = models.BooleanField(default=False)
+ empty_is_annotated = models.BooleanField(default=False)
target_metric = models.CharField(
max_length=32,
diff --git a/cvat/apps/quality_control/quality_reports.py b/cvat/apps/quality_control/quality_reports.py
index 25b5c962dc26..f757aeabc61a 100644
--- a/cvat/apps/quality_control/quality_reports.py
+++ b/cvat/apps/quality_control/quality_reports.py
@@ -215,10 +215,11 @@ class ComparisonParameters(_Serializable):
panoptic_comparison: bool = True
"Use only the visible part of the masks and polygons in comparisons"
- match_empty_frames: bool = False
+ empty_is_annotated: bool = False
"""
- Consider unannotated (empty) frames as matching. If disabled, quality metrics, such as accuracy,
- will be 0 if both GT and DS frames have no annotations. When enabled, they will be 1 instead.
+ Consider unannotated (empty) frames virtually annotated as "nothing".
+ If disabled, quality metrics, such as accuracy, will be 0 if both GT and DS frames
+ have no annotations. When enabled, they will be 1 instead.
This will also add virtual annotations to empty frames in the comparison results.
"""
@@ -1977,15 +1978,20 @@ def _find_closest_unmatched_shape(shape: dm.Annotation):
gt_label_idx = label_id_map[gt_ann.label] if gt_ann else self._UNMATCHED_IDX
confusion_matrix[ds_label_idx, gt_label_idx] += 1
- if self.settings.match_empty_frames and not gt_item.annotations and not ds_item.annotations:
+ if self.settings.empty_is_annotated:
# Add virtual annotations for empty frames
- valid_labels_count = 1
- total_labels_count = 1
+ if not gt_item.annotations and not ds_item.annotations:
+ valid_labels_count = 1
+ total_labels_count = 1
- valid_shapes_count = 1
- total_shapes_count = 1
- ds_shapes_count = 1
- gt_shapes_count = 1
+ valid_shapes_count = 1
+ total_shapes_count = 1
+
+ if not ds_item.annotations:
+ ds_shapes_count = 1
+
+ if not gt_item.annotations:
+ gt_shapes_count = 1
self._frame_results[frame_id] = ComparisonReportFrameSummary(
annotations=self._generate_frame_annotations_summary(
@@ -2078,12 +2084,17 @@ def _generate_frame_annotations_summary(
) -> ComparisonReportAnnotationsSummary:
summary = self._compute_annotations_summary(confusion_matrix, confusion_matrix_labels)
- if self.settings.match_empty_frames and summary.total_count == 0:
+ if self.settings.empty_is_annotated:
# Add virtual annotations for empty frames
- summary.valid_count = 1
- summary.total_count = 1
- summary.ds_count = 1
- summary.gt_count = 1
+ if not summary.total_count:
+ summary.valid_count = 1
+ summary.total_count = 1
+
+ if not summary.ds_count:
+ summary.ds_count = 1
+
+ if not summary.gt_count:
+ summary.gt_count = 1
return summary
@@ -2108,14 +2119,26 @@ def _generate_dataset_annotations_summary(
),
)
mean_ious = []
- empty_frame_count = 0
+ empty_gt_frames = set()
+ empty_ds_frames = set()
confusion_matrix_labels, confusion_matrix, _ = self._make_zero_confusion_matrix()
- for frame_result in frame_summaries.values():
+ for frame_id, frame_result in frame_summaries.items():
confusion_matrix += frame_result.annotations.confusion_matrix.rows
- if not np.any(frame_result.annotations.confusion_matrix.rows):
- empty_frame_count += 1
+ if self.settings.empty_is_annotated and not np.any(
+ frame_result.annotations.confusion_matrix.rows[
+ np.triu_indices_from(frame_result.annotations.confusion_matrix.rows)
+ ]
+ ):
+ empty_ds_frames.add(frame_id)
+
+ if self.settings.empty_is_annotated and not np.any(
+ frame_result.annotations.confusion_matrix.rows[
+ np.tril_indices_from(frame_result.annotations.confusion_matrix.rows)
+ ]
+ ):
+ empty_gt_frames.add(frame_id)
if annotation_components is None:
annotation_components = deepcopy(frame_result.annotation_components)
@@ -2128,13 +2151,13 @@ def _generate_dataset_annotations_summary(
confusion_matrix, confusion_matrix_labels
)
- if self.settings.match_empty_frames and empty_frame_count:
+ if self.settings.empty_is_annotated:
# Add virtual annotations for empty frames,
# they are not included in the confusion matrix
- annotation_summary.valid_count += empty_frame_count
- annotation_summary.total_count += empty_frame_count
- annotation_summary.ds_count += empty_frame_count
- annotation_summary.gt_count += empty_frame_count
+ annotation_summary.valid_count += len(empty_ds_frames & empty_gt_frames)
+ annotation_summary.total_count += len(empty_ds_frames | empty_gt_frames)
+ annotation_summary.ds_count += len(empty_ds_frames)
+ annotation_summary.gt_count += len(empty_gt_frames)
# Cannot be computed in accumulate()
annotation_components.shape.mean_iou = np.mean(mean_ious)
diff --git a/cvat/apps/quality_control/serializers.py b/cvat/apps/quality_control/serializers.py
index 6164abc12200..11a5e0d8b02e 100644
--- a/cvat/apps/quality_control/serializers.py
+++ b/cvat/apps/quality_control/serializers.py
@@ -92,7 +92,7 @@ class Meta:
"object_visibility_threshold",
"panoptic_comparison",
"compare_attributes",
- "match_empty_frames",
+ "empty_is_annotated",
)
read_only_fields = (
"id",
@@ -100,7 +100,7 @@ class Meta:
)
extra_kwargs = {k: {"required": False} for k in fields}
- extra_kwargs.setdefault("match_empty_frames", {}).setdefault("default", False)
+ extra_kwargs.setdefault("empty_is_annotated", {}).setdefault("default", False)
for field_name, help_text in {
"target_metric": "The primary metric used for quality estimation",
@@ -166,9 +166,9 @@ class Meta:
Use only the visible part of the masks and polygons in comparisons
""",
"compare_attributes": "Enables or disables annotation attribute comparison",
- "match_empty_frames": """
- Count empty frames as matching. This affects target metrics like accuracy in cases
- there are no annotations. If disabled, frames without annotations
+ "empty_is_annotated": """
+ Consider empty frames annotated as "empty". This affects target metrics like
+ accuracy in cases there are no annotations. If disabled, frames without annotations
are counted as not matching (accuracy is 0). If enabled, accuracy will be 1 instead.
This will also add virtual annotations to empty frames in the comparison results.
""",
diff --git a/cvat/apps/webhooks/apps.py b/cvat/apps/webhooks/apps.py
index ac193baed755..0b4cf34198f7 100644
--- a/cvat/apps/webhooks/apps.py
+++ b/cvat/apps/webhooks/apps.py
@@ -9,7 +9,8 @@ class WebhooksConfig(AppConfig):
name = "cvat.apps.webhooks"
def ready(self):
- from . import signals # pylint: disable=unused-import
-
from cvat.apps.iam.permissions import load_app_permissions
+
load_app_permissions(self)
+
+ from . import signals # pylint: disable=unused-import
diff --git a/cvat/apps/webhooks/event_type.py b/cvat/apps/webhooks/event_type.py
index 59cdb6cf99ed..ef98e5212824 100644
--- a/cvat/apps/webhooks/event_type.py
+++ b/cvat/apps/webhooks/event_type.py
@@ -47,7 +47,11 @@ class AllEvents:
class ProjectEvents:
webhook_type = WebhookTypeChoice.PROJECT
- events = [*Events.select(["task", "job", "label", "issue", "comment"]), event_name("update", "project"), event_name("delete", "project")]
+ events = [
+ *Events.select(["task", "job", "label", "issue", "comment"]),
+ event_name("update", "project"),
+ event_name("delete", "project"),
+ ]
class OrganizationEvents:
diff --git a/cvat/apps/webhooks/migrations/0001_initial.py b/cvat/apps/webhooks/migrations/0001_initial.py
index fe8f296b0514..e3638bd6be97 100644
--- a/cvat/apps/webhooks/migrations/0001_initial.py
+++ b/cvat/apps/webhooks/migrations/0001_initial.py
@@ -1,9 +1,10 @@
# Generated by Django 3.2.15 on 2022-09-19 08:26
-import cvat.apps.webhooks.models
+import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
-import django.db.models.deletion
+
+import cvat.apps.webhooks.models
class Migration(migrations.Migration):
@@ -11,54 +12,120 @@ class Migration(migrations.Migration):
initial = True
dependencies = [
- ('engine', '0060_alter_label_parent'),
+ ("engine", "0060_alter_label_parent"),
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
- ('organizations', '0001_initial'),
+ ("organizations", "0001_initial"),
]
operations = [
migrations.CreateModel(
- name='Webhook',
+ name="Webhook",
fields=[
- ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
- ('target_url', models.URLField()),
- ('description', models.CharField(blank=True, default='', max_length=128)),
- ('events', models.CharField(default='', max_length=4096)),
- ('type', models.CharField(choices=[('organization', 'ORGANIZATION'), ('project', 'PROJECT')], max_length=16)),
- ('content_type', models.CharField(choices=[('application/json', 'JSON')], default=cvat.apps.webhooks.models.WebhookContentTypeChoice['JSON'], max_length=64)),
- ('secret', models.CharField(blank=True, default='', max_length=64)),
- ('is_active', models.BooleanField(default=True)),
- ('enable_ssl', models.BooleanField(default=True)),
- ('created_date', models.DateTimeField(auto_now_add=True)),
- ('updated_date', models.DateTimeField(auto_now=True)),
- ('organization', models.ForeignKey(null=True, on_delete=django.db.models.deletion.CASCADE, related_name='+', to='organizations.organization')),
- ('owner', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='+', to=settings.AUTH_USER_MODEL)),
- ('project', models.ForeignKey(null=True, on_delete=django.db.models.deletion.CASCADE, related_name='+', to='engine.project')),
+ (
+ "id",
+ models.AutoField(
+ auto_created=True, primary_key=True, serialize=False, verbose_name="ID"
+ ),
+ ),
+ ("target_url", models.URLField()),
+ ("description", models.CharField(blank=True, default="", max_length=128)),
+ ("events", models.CharField(default="", max_length=4096)),
+ (
+ "type",
+ models.CharField(
+ choices=[("organization", "ORGANIZATION"), ("project", "PROJECT")],
+ max_length=16,
+ ),
+ ),
+ (
+ "content_type",
+ models.CharField(
+ choices=[("application/json", "JSON")],
+ default=cvat.apps.webhooks.models.WebhookContentTypeChoice["JSON"],
+ max_length=64,
+ ),
+ ),
+ ("secret", models.CharField(blank=True, default="", max_length=64)),
+ ("is_active", models.BooleanField(default=True)),
+ ("enable_ssl", models.BooleanField(default=True)),
+ ("created_date", models.DateTimeField(auto_now_add=True)),
+ ("updated_date", models.DateTimeField(auto_now=True)),
+ (
+ "organization",
+ models.ForeignKey(
+ null=True,
+ on_delete=django.db.models.deletion.CASCADE,
+ related_name="+",
+ to="organizations.organization",
+ ),
+ ),
+ (
+ "owner",
+ models.ForeignKey(
+ blank=True,
+ null=True,
+ on_delete=django.db.models.deletion.SET_NULL,
+ related_name="+",
+ to=settings.AUTH_USER_MODEL,
+ ),
+ ),
+ (
+ "project",
+ models.ForeignKey(
+ null=True,
+ on_delete=django.db.models.deletion.CASCADE,
+ related_name="+",
+ to="engine.project",
+ ),
+ ),
],
options={
- 'default_permissions': (),
+ "default_permissions": (),
},
),
migrations.CreateModel(
- name='WebhookDelivery',
+ name="WebhookDelivery",
fields=[
- ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
- ('event', models.CharField(max_length=64)),
- ('status_code', models.CharField(max_length=128, null=True)),
- ('redelivery', models.BooleanField(default=False)),
- ('created_date', models.DateTimeField(auto_now_add=True)),
- ('updated_date', models.DateTimeField(auto_now=True)),
- ('changed_fields', models.CharField(default='', max_length=4096)),
- ('request', models.JSONField(default=dict)),
- ('response', models.JSONField(default=dict)),
- ('webhook', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='deliveries', to='webhooks.webhook')),
+ (
+ "id",
+ models.AutoField(
+ auto_created=True, primary_key=True, serialize=False, verbose_name="ID"
+ ),
+ ),
+ ("event", models.CharField(max_length=64)),
+ ("status_code", models.CharField(max_length=128, null=True)),
+ ("redelivery", models.BooleanField(default=False)),
+ ("created_date", models.DateTimeField(auto_now_add=True)),
+ ("updated_date", models.DateTimeField(auto_now=True)),
+ ("changed_fields", models.CharField(default="", max_length=4096)),
+ ("request", models.JSONField(default=dict)),
+ ("response", models.JSONField(default=dict)),
+ (
+ "webhook",
+ models.ForeignKey(
+ on_delete=django.db.models.deletion.CASCADE,
+ related_name="deliveries",
+ to="webhooks.webhook",
+ ),
+ ),
],
options={
- 'default_permissions': (),
+ "default_permissions": (),
},
),
migrations.AddConstraint(
- model_name='webhook',
- constraint=models.CheckConstraint(check=models.Q(models.Q(('project_id__isnull', False), ('type', 'project')), models.Q(('organization_id__isnull', False), ('project_id__isnull', True), ('type', 'organization')), _connector='OR'), name='webhooks_project_or_organization'),
+ model_name="webhook",
+ constraint=models.CheckConstraint(
+ check=models.Q(
+ models.Q(("project_id__isnull", False), ("type", "project")),
+ models.Q(
+ ("organization_id__isnull", False),
+ ("project_id__isnull", True),
+ ("type", "organization"),
+ ),
+ _connector="OR",
+ ),
+ name="webhooks_project_or_organization",
+ ),
),
]
diff --git a/cvat/apps/webhooks/migrations/0002_alter_webhookdelivery_status_code.py b/cvat/apps/webhooks/migrations/0002_alter_webhookdelivery_status_code.py
index fd1a2397d249..0429b1445117 100644
--- a/cvat/apps/webhooks/migrations/0002_alter_webhookdelivery_status_code.py
+++ b/cvat/apps/webhooks/migrations/0002_alter_webhookdelivery_status_code.py
@@ -6,13 +6,77 @@
class Migration(migrations.Migration):
dependencies = [
- ('webhooks', '0001_initial'),
+ ("webhooks", "0001_initial"),
]
operations = [
migrations.AlterField(
- model_name='webhookdelivery',
- name='status_code',
- field=models.IntegerField(choices=[('CONTINUE', 100), ('SWITCHING_PROTOCOLS', 101), ('PROCESSING', 102), ('OK', 200), ('CREATED', 201), ('ACCEPTED', 202), ('NON_AUTHORITATIVE_INFORMATION', 203), ('NO_CONTENT', 204), ('RESET_CONTENT', 205), ('PARTIAL_CONTENT', 206), ('MULTI_STATUS', 207), ('ALREADY_REPORTED', 208), ('IM_USED', 226), ('MULTIPLE_CHOICES', 300), ('MOVED_PERMANENTLY', 301), ('FOUND', 302), ('SEE_OTHER', 303), ('NOT_MODIFIED', 304), ('USE_PROXY', 305), ('TEMPORARY_REDIRECT', 307), ('PERMANENT_REDIRECT', 308), ('BAD_REQUEST', 400), ('UNAUTHORIZED', 401), ('PAYMENT_REQUIRED', 402), ('FORBIDDEN', 403), ('NOT_FOUND', 404), ('METHOD_NOT_ALLOWED', 405), ('NOT_ACCEPTABLE', 406), ('PROXY_AUTHENTICATION_REQUIRED', 407), ('REQUEST_TIMEOUT', 408), ('CONFLICT', 409), ('GONE', 410), ('LENGTH_REQUIRED', 411), ('PRECONDITION_FAILED', 412), ('REQUEST_ENTITY_TOO_LARGE', 413), ('REQUEST_URI_TOO_LONG', 414), ('UNSUPPORTED_MEDIA_TYPE', 415), ('REQUESTED_RANGE_NOT_SATISFIABLE', 416), ('EXPECTATION_FAILED', 417), ('MISDIRECTED_REQUEST', 421), ('UNPROCESSABLE_ENTITY', 422), ('LOCKED', 423), ('FAILED_DEPENDENCY', 424), ('UPGRADE_REQUIRED', 426), ('PRECONDITION_REQUIRED', 428), ('TOO_MANY_REQUESTS', 429), ('REQUEST_HEADER_FIELDS_TOO_LARGE', 431), ('UNAVAILABLE_FOR_LEGAL_REASONS', 451), ('INTERNAL_SERVER_ERROR', 500), ('NOT_IMPLEMENTED', 501), ('BAD_GATEWAY', 502), ('SERVICE_UNAVAILABLE', 503), ('GATEWAY_TIMEOUT', 504), ('HTTP_VERSION_NOT_SUPPORTED', 505), ('VARIANT_ALSO_NEGOTIATES', 506), ('INSUFFICIENT_STORAGE', 507), ('LOOP_DETECTED', 508), ('NOT_EXTENDED', 510), ('NETWORK_AUTHENTICATION_REQUIRED', 511)], default=None, null=True),
+ model_name="webhookdelivery",
+ name="status_code",
+ field=models.IntegerField(
+ choices=[
+ ("CONTINUE", 100),
+ ("SWITCHING_PROTOCOLS", 101),
+ ("PROCESSING", 102),
+ ("OK", 200),
+ ("CREATED", 201),
+ ("ACCEPTED", 202),
+ ("NON_AUTHORITATIVE_INFORMATION", 203),
+ ("NO_CONTENT", 204),
+ ("RESET_CONTENT", 205),
+ ("PARTIAL_CONTENT", 206),
+ ("MULTI_STATUS", 207),
+ ("ALREADY_REPORTED", 208),
+ ("IM_USED", 226),
+ ("MULTIPLE_CHOICES", 300),
+ ("MOVED_PERMANENTLY", 301),
+ ("FOUND", 302),
+ ("SEE_OTHER", 303),
+ ("NOT_MODIFIED", 304),
+ ("USE_PROXY", 305),
+ ("TEMPORARY_REDIRECT", 307),
+ ("PERMANENT_REDIRECT", 308),
+ ("BAD_REQUEST", 400),
+ ("UNAUTHORIZED", 401),
+ ("PAYMENT_REQUIRED", 402),
+ ("FORBIDDEN", 403),
+ ("NOT_FOUND", 404),
+ ("METHOD_NOT_ALLOWED", 405),
+ ("NOT_ACCEPTABLE", 406),
+ ("PROXY_AUTHENTICATION_REQUIRED", 407),
+ ("REQUEST_TIMEOUT", 408),
+ ("CONFLICT", 409),
+ ("GONE", 410),
+ ("LENGTH_REQUIRED", 411),
+ ("PRECONDITION_FAILED", 412),
+ ("REQUEST_ENTITY_TOO_LARGE", 413),
+ ("REQUEST_URI_TOO_LONG", 414),
+ ("UNSUPPORTED_MEDIA_TYPE", 415),
+ ("REQUESTED_RANGE_NOT_SATISFIABLE", 416),
+ ("EXPECTATION_FAILED", 417),
+ ("MISDIRECTED_REQUEST", 421),
+ ("UNPROCESSABLE_ENTITY", 422),
+ ("LOCKED", 423),
+ ("FAILED_DEPENDENCY", 424),
+ ("UPGRADE_REQUIRED", 426),
+ ("PRECONDITION_REQUIRED", 428),
+ ("TOO_MANY_REQUESTS", 429),
+ ("REQUEST_HEADER_FIELDS_TOO_LARGE", 431),
+ ("UNAVAILABLE_FOR_LEGAL_REASONS", 451),
+ ("INTERNAL_SERVER_ERROR", 500),
+ ("NOT_IMPLEMENTED", 501),
+ ("BAD_GATEWAY", 502),
+ ("SERVICE_UNAVAILABLE", 503),
+ ("GATEWAY_TIMEOUT", 504),
+ ("HTTP_VERSION_NOT_SUPPORTED", 505),
+ ("VARIANT_ALSO_NEGOTIATES", 506),
+ ("INSUFFICIENT_STORAGE", 507),
+ ("LOOP_DETECTED", 508),
+ ("NOT_EXTENDED", 510),
+ ("NETWORK_AUTHENTICATION_REQUIRED", 511),
+ ],
+ default=None,
+ null=True,
+ ),
),
]
diff --git a/cvat/apps/webhooks/migrations/0003_alter_webhookdelivery_status_code.py b/cvat/apps/webhooks/migrations/0003_alter_webhookdelivery_status_code.py
index 676f03a2dc9b..234a4d685d58 100644
--- a/cvat/apps/webhooks/migrations/0003_alter_webhookdelivery_status_code.py
+++ b/cvat/apps/webhooks/migrations/0003_alter_webhookdelivery_status_code.py
@@ -6,13 +6,13 @@
class Migration(migrations.Migration):
dependencies = [
- ('webhooks', '0002_alter_webhookdelivery_status_code'),
+ ("webhooks", "0002_alter_webhookdelivery_status_code"),
]
operations = [
migrations.AlterField(
- model_name='webhookdelivery',
- name='status_code',
+ model_name="webhookdelivery",
+ name="status_code",
field=models.PositiveIntegerField(default=None, null=True),
),
]
diff --git a/cvat/apps/webhooks/migrations/0004_alter_webhook_target_url.py b/cvat/apps/webhooks/migrations/0004_alter_webhook_target_url.py
index 00be6a309df2..f2f716f8cd88 100644
--- a/cvat/apps/webhooks/migrations/0004_alter_webhook_target_url.py
+++ b/cvat/apps/webhooks/migrations/0004_alter_webhook_target_url.py
@@ -6,13 +6,13 @@
class Migration(migrations.Migration):
dependencies = [
- ('webhooks', '0003_alter_webhookdelivery_status_code'),
+ ("webhooks", "0003_alter_webhookdelivery_status_code"),
]
operations = [
migrations.AlterField(
- model_name='webhook',
- name='target_url',
+ model_name="webhook",
+ name="target_url",
field=models.URLField(max_length=8192),
),
]
diff --git a/cvat/apps/webhooks/models.py b/cvat/apps/webhooks/models.py
index 104faccd60a4..650cd814fae0 100644
--- a/cvat/apps/webhooks/models.py
+++ b/cvat/apps/webhooks/models.py
@@ -53,9 +53,7 @@ class Webhook(TimestampedModel):
owner = models.ForeignKey(
User, null=True, blank=True, on_delete=models.SET_NULL, related_name="+"
)
- project = models.ForeignKey(
- Project, null=True, on_delete=models.CASCADE, related_name="+"
- )
+ project = models.ForeignKey(Project, null=True, on_delete=models.CASCADE, related_name="+")
organization = models.ForeignKey(
Organization, null=True, on_delete=models.CASCADE, related_name="+"
)
@@ -66,9 +64,7 @@ class Meta:
models.CheckConstraint(
name="webhooks_project_or_organization",
check=(
- models.Q(
- type=WebhookTypeChoice.PROJECT.value, project_id__isnull=False
- )
+ models.Q(type=WebhookTypeChoice.PROJECT.value, project_id__isnull=False)
| models.Q(
type=WebhookTypeChoice.ORGANIZATION.value,
project_id__isnull=True,
@@ -80,9 +76,7 @@ class Meta:
class WebhookDelivery(TimestampedModel):
- webhook = models.ForeignKey(
- Webhook, on_delete=models.CASCADE, related_name="deliveries"
- )
+ webhook = models.ForeignKey(Webhook, on_delete=models.CASCADE, related_name="deliveries")
event = models.CharField(max_length=64)
status_code = models.PositiveIntegerField(null=True, default=None)
diff --git a/cvat/apps/webhooks/permissions.py b/cvat/apps/webhooks/permissions.py
index e5d132c55de6..3ce72bd350a4 100644
--- a/cvat/apps/webhooks/permissions.py
+++ b/cvat/apps/webhooks/permissions.py
@@ -4,7 +4,6 @@
# SPDX-License-Identifier: MIT
from django.conf import settings
-
from rest_framework.exceptions import ValidationError
from cvat.apps.engine.models import Project
@@ -13,27 +12,29 @@
from .models import WebhookTypeChoice
+
class WebhookPermission(OpenPolicyAgentPermission):
class Scopes(StrEnum):
- CREATE = 'create'
- CREATE_IN_PROJECT = 'create@project'
- CREATE_IN_ORG = 'create@organization'
- DELETE = 'delete'
- UPDATE = 'update'
- LIST = 'list'
- VIEW = 'view'
+ CREATE = "create"
+ CREATE_IN_PROJECT = "create@project"
+ CREATE_IN_ORG = "create@organization"
+ DELETE = "delete"
+ UPDATE = "update"
+ LIST = "list"
+ VIEW = "view"
@classmethod
def create(cls, request, view, obj, iam_context):
permissions = []
- if view.basename == 'webhook':
- project_id = request.data.get('project_id')
+ if view.basename == "webhook":
+ project_id = request.data.get("project_id")
for scope in cls.get_scopes(request, view, obj):
- self = cls.create_base_perm(request, view, scope, iam_context, obj,
- project_id=project_id)
+ self = cls.create_base_perm(
+ request, view, scope, iam_context, obj, project_id=project_id
+ )
permissions.append(self)
- owner = request.data.get('owner_id') or request.data.get('owner')
+ owner = request.data.get("owner_id") or request.data.get("owner")
if owner:
perm = UserPermission.create_scope_view(iam_context, owner)
permissions.append(perm)
@@ -46,29 +47,29 @@ def create(cls, request, view, obj, iam_context):
def __init__(self, **kwargs):
super().__init__(**kwargs)
- self.url = settings.IAM_OPA_DATA_URL + '/webhooks/allow'
+ self.url = settings.IAM_OPA_DATA_URL + "/webhooks/allow"
@staticmethod
def get_scopes(request, view, obj):
Scopes = __class__.Scopes
scope = {
- ('create', 'POST'): Scopes.CREATE,
- ('destroy', 'DELETE'): Scopes.DELETE,
- ('partial_update', 'PATCH'): Scopes.UPDATE,
- ('update', 'PUT'): Scopes.UPDATE,
- ('list', 'GET'): Scopes.LIST,
- ('retrieve', 'GET'): Scopes.VIEW,
- ('ping', 'POST'): Scopes.UPDATE,
- ('deliveries', 'GET'): Scopes.VIEW,
- ('retrieve_delivery', 'GET'): Scopes.VIEW,
- ('redelivery', 'POST'): Scopes.UPDATE,
+ ("create", "POST"): Scopes.CREATE,
+ ("destroy", "DELETE"): Scopes.DELETE,
+ ("partial_update", "PATCH"): Scopes.UPDATE,
+ ("update", "PUT"): Scopes.UPDATE,
+ ("list", "GET"): Scopes.LIST,
+ ("retrieve", "GET"): Scopes.VIEW,
+ ("ping", "POST"): Scopes.UPDATE,
+ ("deliveries", "GET"): Scopes.VIEW,
+ ("retrieve_delivery", "GET"): Scopes.VIEW,
+ ("redelivery", "POST"): Scopes.UPDATE,
}[(view.action, request.method)]
scopes = []
if scope == Scopes.CREATE:
- webhook_type = request.data.get('type')
+ webhook_type = request.data.get("type")
if webhook_type in [m.value for m in WebhookTypeChoice]:
- scope = Scopes(str(scope) + f'@{webhook_type}')
+ scope = Scopes(str(scope) + f"@{webhook_type}")
scopes.append(scope)
else:
scopes.append(scope)
@@ -80,42 +81,52 @@ def get_resource(self):
if self.obj:
data = {
"id": self.obj.id,
- "owner": {"id": getattr(self.obj.owner, 'id', None) },
- 'organization': {
- "id": getattr(self.obj.organization, 'id', None)
- },
- "project": None
+ "owner": {"id": getattr(self.obj.owner, "id", None)},
+ "organization": {"id": getattr(self.obj.organization, "id", None)},
+ "project": None,
}
- if self.obj.type == 'project' and getattr(self.obj, 'project', None):
- data['project'] = {
- 'owner': {'id': getattr(self.obj.project.owner, 'id', None)}
- }
+ if self.obj.type == "project" and getattr(self.obj, "project", None):
+ data["project"] = {"owner": {"id": getattr(self.obj.project.owner, "id", None)}}
elif self.scope in [
__class__.Scopes.CREATE,
__class__.Scopes.CREATE_IN_PROJECT,
- __class__.Scopes.CREATE_IN_ORG
+ __class__.Scopes.CREATE_IN_ORG,
]:
project = None
if self.project_id:
try:
project = Project.objects.get(id=self.project_id)
except Project.DoesNotExist:
- raise ValidationError(f"Could not find project with provided id: {self.project_id}")
+ raise ValidationError(
+ f"Could not find project with provided id: {self.project_id}"
+ )
data = {
- 'id': None,
- 'owner': self.user_id,
- 'project': {
- 'owner': {
- 'id': project.owner.id,
- } if project.owner else None,
- } if project else None,
- 'organization': {
- 'id': self.org_id,
- } if self.org_id is not None else None,
- 'user': {
- 'id': self.user_id,
- }
+ "id": None,
+ "owner": self.user_id,
+ "project": (
+ {
+ "owner": (
+ {
+ "id": project.owner.id,
+ }
+ if project.owner
+ else None
+ ),
+ }
+ if project
+ else None
+ ),
+ "organization": (
+ {
+ "id": self.org_id,
+ }
+ if self.org_id is not None
+ else None
+ ),
+ "user": {
+ "id": self.user_id,
+ },
}
return data
diff --git a/cvat/apps/webhooks/rules/tests/generators/webhooks_test.gen.rego.py b/cvat/apps/webhooks/rules/tests/generators/webhooks_test.gen.rego.py
index 66417f3d096d..2913bb5a2a6a 100644
--- a/cvat/apps/webhooks/rules/tests/generators/webhooks_test.gen.rego.py
+++ b/cvat/apps/webhooks/rules/tests/generators/webhooks_test.gen.rego.py
@@ -125,13 +125,15 @@ def get_data(scope, context, ownership, privilege, membership, resource, same_or
"scope": scope,
"auth": {
"user": {"id": random.randrange(0, 100), "privilege": privilege},
- "organization": {
- "id": random.randrange(100, 200),
- "owner": {"id": random.randrange(200, 300)},
- "user": {"role": membership},
- }
- if context == "organization"
- else None,
+ "organization": (
+ {
+ "id": random.randrange(100, 200),
+ "owner": {"id": random.randrange(200, 300)},
+ "user": {"role": membership},
+ }
+ if context == "organization"
+ else None
+ ),
},
"resource": resource,
}
diff --git a/cvat/apps/webhooks/serializers.py b/cvat/apps/webhooks/serializers.py
index d2bb1f309105..bd540de55fbd 100644
--- a/cvat/apps/webhooks/serializers.py
+++ b/cvat/apps/webhooks/serializers.py
@@ -7,13 +7,8 @@
from cvat.apps.engine.models import Project
from cvat.apps.engine.serializers import BasicUserSerializer, WriteOnceMixin
-from .event_type import EventTypeChoice, ProjectEvents, OrganizationEvents
-from .models import (
- Webhook,
- WebhookContentTypeChoice,
- WebhookTypeChoice,
- WebhookDelivery,
-)
+from .event_type import EventTypeChoice, OrganizationEvents, ProjectEvents
+from .models import Webhook, WebhookContentTypeChoice, WebhookDelivery, WebhookTypeChoice
class EventTypeValidator:
@@ -35,9 +30,7 @@ def __call__(self, attrs, serializer):
webhook_type == WebhookTypeChoice.ORGANIZATION
and not events.issubset(set(OrganizationEvents.events))
):
- raise serializers.ValidationError(
- f"Invalid events list for {webhook_type} webhook"
- )
+ raise serializers.ValidationError(f"Invalid events list for {webhook_type} webhook")
class EventTypesSerializer(serializers.MultipleChoiceField):
@@ -67,9 +60,7 @@ class WebhookReadSerializer(serializers.ModelSerializer):
type = serializers.ChoiceField(choices=WebhookTypeChoice.choices())
content_type = serializers.ChoiceField(choices=WebhookContentTypeChoice.choices())
- last_status = serializers.IntegerField(
- source="deliveries.last.status_code", read_only=True
- )
+ last_status = serializers.IntegerField(source="deliveries.last.status_code", read_only=True)
last_delivery_date = serializers.DateTimeField(
source="deliveries.last.updated_date", read_only=True
@@ -104,9 +95,7 @@ class Meta:
class WebhookWriteSerializer(WriteOnceMixin, serializers.ModelSerializer):
events = EventTypesSerializer(write_only=True)
- project_id = serializers.IntegerField(
- write_only=True, allow_null=True, required=False
- )
+ project_id = serializers.IntegerField(write_only=True, allow_null=True, required=False)
def to_representation(self, instance):
serializer = WebhookReadSerializer(instance, context=self.context)
@@ -129,8 +118,8 @@ class Meta:
validators = [EventTypeValidator()]
def create(self, validated_data):
- if (project_id := validated_data.get('project_id')) is not None:
- validated_data['organization'] = Project.objects.get(pk=project_id).organization
+ if (project_id := validated_data.get("project_id")) is not None:
+ validated_data["organization"] = Project.objects.get(pk=project_id).organization
db_webhook = Webhook.objects.create(**validated_data)
return db_webhook
diff --git a/cvat/apps/webhooks/signals.py b/cvat/apps/webhooks/signals.py
index 3e17e8f3d8f6..6e08e35192dd 100644
--- a/cvat/apps/webhooks/signals.py
+++ b/cvat/apps/webhooks/signals.py
@@ -13,17 +13,21 @@
from django.conf import settings
from django.core.exceptions import ObjectDoesNotExist
from django.db import transaction
-from django.db.models.signals import (post_delete, post_save, pre_delete,
- pre_save)
+from django.db.models.signals import post_delete, post_save, pre_delete, pre_save
from django.dispatch import Signal, receiver
from cvat.apps.engine.models import Comment, Issue, Job, Project, Task
from cvat.apps.engine.serializers import BasicUserSerializer
-from cvat.apps.events.handlers import (get_request, get_serializer, get_user,
- get_instance_diff, organization_id,
- project_id)
+from cvat.apps.events.handlers import (
+ get_instance_diff,
+ get_request,
+ get_serializer,
+ get_user,
+ organization_id,
+ project_id,
+)
from cvat.apps.organizations.models import Invitation, Membership, Organization
-from cvat.utils.http import make_requests_session, PROXIES_FOR_UNTRUSTED_URLS
+from cvat.utils.http import PROXIES_FOR_UNTRUSTED_URLS, make_requests_session
from .event_type import EventTypeChoice, event_name
from .models import Webhook, WebhookDelivery, WebhookTypeChoice
@@ -34,6 +38,7 @@
signal_redelivery = Signal()
signal_ping = Signal()
+
def send_webhook(webhook, payload, redelivery=False):
headers = {}
if webhook.secret:
@@ -59,9 +64,7 @@ def send_webhook(webhook, payload, redelivery=False):
proxies=PROXIES_FOR_UNTRUSTED_URLS,
)
status_code = response.status_code
- response_body = response.raw.read(
- RESPONSE_SIZE_LIMIT + 1, decode_content=True
- )
+ response_body = response.raw.read(RESPONSE_SIZE_LIMIT + 1, decode_content=True)
except requests.ConnectionError:
status_code = HTTPStatus.BAD_GATEWAY
except requests.Timeout:
@@ -83,6 +86,7 @@ def send_webhook(webhook, payload, redelivery=False):
return delivery
+
def add_to_queue(webhook, payload, redelivery=False):
queue = django_rq.get_queue(settings.CVAT_QUEUES.WEBHOOKS.value)
queue.enqueue_call(func=send_webhook, args=(webhook, payload, redelivery))
@@ -163,6 +167,7 @@ def pre_save_resource_event(sender, instance, **kwargs):
old_serializer = get_serializer(instance=old_instance)
instance._webhooks_old_data = old_serializer.data
+
@receiver(post_save, sender=Project, dispatch_uid=__name__ + ":project:post_save")
@receiver(post_save, sender=Task, dispatch_uid=__name__ + ":task:post_save")
@receiver(post_save, sender=Job, dispatch_uid=__name__ + ":job:post_save")
@@ -196,10 +201,7 @@ def post_save_resource_event(sender, instance, **kwargs):
if not created:
if diff := get_instance_diff(old_data=old_data, data=serializer.data):
- data["before_update"] = {
- attr: value["old_value"]
- for attr, value in diff.items()
- }
+ data["before_update"] = {attr: value["old_value"] for attr, value in diff.items()}
transaction.on_commit(
lambda: batch_add_to_queue(selected_webhooks, data),
@@ -250,7 +252,11 @@ def post_delete_resource_event(sender, instance, **kwargs):
"sender": get_sender(instance),
}
- related_webhooks = [webhook for webhook in getattr(instance, "_related_webhooks", []) if webhook.id not in map(lambda a: a.id, filtered_webhooks)]
+ related_webhooks = [
+ webhook
+ for webhook in getattr(instance, "_related_webhooks", [])
+ if webhook.id not in map(lambda a: a.id, filtered_webhooks)
+ ]
transaction.on_commit(
lambda: batch_add_to_queue(filtered_webhooks + related_webhooks, data),
diff --git a/cvat/apps/webhooks/urls.py b/cvat/apps/webhooks/urls.py
index c309df746f96..26f86fc2313e 100644
--- a/cvat/apps/webhooks/urls.py
+++ b/cvat/apps/webhooks/urls.py
@@ -3,6 +3,7 @@
# SPDX-License-Identifier: MIT
from rest_framework.routers import DefaultRouter
+
from .views import WebhookViewSet
router = DefaultRouter(trailing_slash=False)
diff --git a/cvat/apps/webhooks/views.py b/cvat/apps/webhooks/views.py
index 66529bc6a7bd..b4e059c528f6 100644
--- a/cvat/apps/webhooks/views.py
+++ b/cvat/apps/webhooks/views.py
@@ -2,9 +2,13 @@
#
# SPDX-License-Identifier: MIT
-from drf_spectacular.utils import (OpenApiParameter, OpenApiResponse,
- OpenApiTypes, extend_schema,
- extend_schema_view)
+from drf_spectacular.utils import (
+ OpenApiParameter,
+ OpenApiResponse,
+ OpenApiTypes,
+ extend_schema,
+ extend_schema_view,
+)
from rest_framework import status, viewsets
from rest_framework.decorators import action
from rest_framework.permissions import SAFE_METHODS
@@ -16,8 +20,12 @@
from .event_type import AllEvents, OrganizationEvents, ProjectEvents
from .models import Webhook, WebhookDelivery, WebhookTypeChoice
from .permissions import WebhookPermission
-from .serializers import (EventsSerializer, WebhookDeliveryReadSerializer,
- WebhookReadSerializer, WebhookWriteSerializer)
+from .serializers import (
+ EventsSerializer,
+ WebhookDeliveryReadSerializer,
+ WebhookReadSerializer,
+ WebhookWriteSerializer,
+)
from .signals import signal_ping, signal_redelivery
@@ -34,24 +42,18 @@
update=extend_schema(
summary="Replace a webhook",
request=WebhookWriteSerializer,
- responses={
- "200": WebhookReadSerializer
- }, # check WebhookWriteSerializer.to_representation
+ responses={"200": WebhookReadSerializer}, # check WebhookWriteSerializer.to_representation
),
partial_update=extend_schema(
summary="Update a webhook",
request=WebhookWriteSerializer,
- responses={
- "200": WebhookReadSerializer
- }, # check WebhookWriteSerializer.to_representation
+ responses={"200": WebhookReadSerializer}, # check WebhookWriteSerializer.to_representation
),
create=extend_schema(
request=WebhookWriteSerializer,
summary="Create a webhook",
parameters=ORGANIZATION_OPEN_API_PARAMETERS,
- responses={
- "201": WebhookReadSerializer
- }, # check WebhookWriteSerializer.to_representation
+ responses={"201": WebhookReadSerializer}, # check WebhookWriteSerializer.to_representation
),
destroy=extend_schema(
summary="Delete a webhook",
@@ -71,9 +73,7 @@ class WebhookViewSet(viewsets.ModelViewSet):
iam_organization_field = "organization"
def get_serializer_class(self):
- if self.request.path.endswith("redelivery") or self.request.path.endswith(
- "ping"
- ):
+ if self.request.path.endswith("redelivery") or self.request.path.endswith("ping"):
return None
else:
if self.request.method in SAFE_METHODS:
@@ -109,7 +109,10 @@ def perform_create(self, serializer):
],
responses={"200": OpenApiResponse(EventsSerializer)},
)
- @action(detail=False, methods=["GET"], serializer_class=EventsSerializer,
+ @action(
+ detail=False,
+ methods=["GET"],
+ serializer_class=EventsSerializer,
permission_classes=[],
)
def events(self, request):
@@ -123,9 +126,7 @@ def events(self, request):
events = OrganizationEvents
if events is None:
- return Response(
- "Incorrect value of type parameter", status=status.HTTP_400_BAD_REQUEST
- )
+ return Response("Incorrect value of type parameter", status=status.HTTP_400_BAD_REQUEST)
return Response(EventsSerializer().to_representation(events))
@@ -137,10 +138,8 @@ def events(self, request):
)
@list_action(serializer_class=WebhookDeliveryReadSerializer)
def deliveries(self, request, pk):
- self.get_object() # force call of check_object_permissions()
- queryset = WebhookDelivery.objects.filter(webhook_id=pk).order_by(
- "-updated_date"
- )
+ self.get_object() # force call of check_object_permissions()
+ queryset = WebhookDelivery.objects.filter(webhook_id=pk).order_by("-updated_date")
return make_paginated_response(
queryset, viewset=self, serializer_type=self.serializer_class
) # from @action
@@ -156,11 +155,9 @@ def deliveries(self, request, pk):
serializer_class=WebhookDeliveryReadSerializer,
)
def retrieve_delivery(self, request, pk, delivery_id):
- self.get_object() # force call of check_object_permissions()
+ self.get_object() # force call of check_object_permissions()
queryset = WebhookDelivery.objects.get(webhook_id=pk, id=delivery_id)
- serializer = WebhookDeliveryReadSerializer(
- queryset, context={"request": request}
- )
+ serializer = WebhookDeliveryReadSerializer(queryset, context={"request": request})
return Response(serializer.data)
@extend_schema(
@@ -184,15 +181,11 @@ def redelivery(self, request, pk, delivery_id):
request=None,
responses={"200": WebhookDeliveryReadSerializer},
)
- @action(
- detail=True, methods=["POST"], serializer_class=WebhookDeliveryReadSerializer
- )
+ @action(detail=True, methods=["POST"], serializer_class=WebhookDeliveryReadSerializer)
def ping(self, request, pk):
- instance = self.get_object() # force call of check_object_permissions()
+ instance = self.get_object() # force call of check_object_permissions()
serializer = WebhookReadSerializer(instance, context={"request": request})
delivery = signal_ping.send(sender=self, serializer=serializer)[0][1]
- serializer = WebhookDeliveryReadSerializer(
- delivery, context={"request": request}
- )
+ serializer = WebhookDeliveryReadSerializer(delivery, context={"request": request})
return Response(serializer.data)
diff --git a/cvat/requirements/base.in b/cvat/requirements/base.in
index b3900f010dda..03d74579fb21 100644
--- a/cvat/requirements/base.in
+++ b/cvat/requirements/base.in
@@ -12,7 +12,7 @@ azure-storage-blob==12.13.0
boto3==1.17.61
clickhouse-connect==0.6.8
coreapi==2.3.3
-datumaro @ git+https://github.com/cvat-ai/datumaro.git@232c175ef1f3b7e55bd5162353df9c86a8116fde
+datumaro @ git+https://github.com/cvat-ai/datumaro.git@08e77b216080555a57e12c01625be8c8201e3131
dj-pagination==2.5.0
# Despite direct indication allauth in requirements we should keep 'with_social' for dj-rest-auth
# to avoid possible further versions conflicts (we use registration functionality)
diff --git a/cvat/requirements/base.txt b/cvat/requirements/base.txt
index fce784b9481c..f531f125ebf6 100644
--- a/cvat/requirements/base.txt
+++ b/cvat/requirements/base.txt
@@ -1,4 +1,4 @@
-# SHA1:5a3efd0a5c1892698d4394f019ef659275b10fdb
+# SHA1:3e6349d9e5e095c5a1f196eca66b3e5ba8672458
#
# This file is autogenerated by pip-compile-multi
# To update, run:
@@ -56,7 +56,7 @@ cryptography==44.0.0
# pyjwt
cycler==0.12.1
# via matplotlib
-datumaro @ git+https://github.com/cvat-ai/datumaro.git@232c175ef1f3b7e55bd5162353df9c86a8116fde
+datumaro @ git+https://github.com/cvat-ai/datumaro.git@08e77b216080555a57e12c01625be8c8201e3131
# via -r cvat/requirements/base.in
defusedxml==0.7.1
# via
@@ -147,7 +147,7 @@ idna==3.10
# via requests
importlib-metadata==8.5.0
# via clickhouse-connect
-importlib-resources==6.4.5
+importlib-resources==6.5.2
# via
# matplotlib
# nibabel
@@ -169,7 +169,7 @@ jsonschema==4.17.3
# via drf-spectacular
kiwisolver==1.4.7
# via matplotlib
-limits==3.14.1
+limits==4.0.0
# via python-logstash-async
lxml==5.3.0
# via
@@ -197,7 +197,7 @@ oauthlib==3.2.2
# via requests-oauthlib
orderedmultidict==1.0.1
# via furl
-orjson==3.10.12
+orjson==3.10.13
# via datumaro
packaging==24.2
# via
@@ -242,7 +242,7 @@ pyjwt[crypto]==2.10.1
# via django-allauth
pylogbeat==2.0.1
# via python-logstash-async
-pyparsing==3.2.0
+pyparsing==3.2.1
# via matplotlib
pyrsistent==0.20.0
# via jsonschema
@@ -308,7 +308,7 @@ rq-scheduler==0.13.1
# via -r cvat/requirements/base.in
rsa==4.9
# via google-auth
-ruamel-yaml==0.18.6
+ruamel-yaml==0.18.10
# via datumaro
ruamel-yaml-clib==0.2.12
# via ruamel-yaml
diff --git a/cvat/requirements/production.txt b/cvat/requirements/production.txt
index 155d626a6984..c65ede91ad59 100644
--- a/cvat/requirements/production.txt
+++ b/cvat/requirements/production.txt
@@ -6,7 +6,7 @@
# pip-compile-multi
#
-r base.txt
-anyio==4.7.0
+anyio==4.8.0
# via watchfiles
coverage==7.2.3
# via -r cvat/requirements/production.in
diff --git a/cvat/schema.yml b/cvat/schema.yml
index 8af068ecc8b2..4b6e9ddca0da 100644
--- a/cvat/schema.yml
+++ b/cvat/schema.yml
@@ -1,7 +1,7 @@
openapi: 3.0.3
info:
title: CVAT REST API
- version: 2.24.1
+ version: 2.25.1
description: REST API for Computer Vision Annotation Tool (CVAT)
termsOfService: https://www.google.com/policies/terms/
contact:
@@ -9775,12 +9775,12 @@ components:
compare_attributes:
type: boolean
description: Enables or disables annotation attribute comparison
- match_empty_frames:
+ empty_is_annotated:
type: boolean
default: false
description: |
- Count empty frames as matching. This affects target metrics like accuracy in cases
- there are no annotations. If disabled, frames without annotations
+ Consider empty frames annotated as "empty". This affects target metrics like
+ accuracy in cases there are no annotations. If disabled, frames without annotations
are counted as not matching (accuracy is 0). If enabled, accuracy will be 1 instead.
This will also add virtual annotations to empty frames in the comparison results.
PatchedTaskValidationLayoutWriteRequest:
@@ -10282,12 +10282,12 @@ components:
compare_attributes:
type: boolean
description: Enables or disables annotation attribute comparison
- match_empty_frames:
+ empty_is_annotated:
type: boolean
default: false
description: |
- Count empty frames as matching. This affects target metrics like accuracy in cases
- there are no annotations. If disabled, frames without annotations
+ Consider empty frames annotated as "empty". This affects target metrics like
+ accuracy in cases there are no annotations. If disabled, frames without annotations
are counted as not matching (accuracy is 0). If enabled, accuracy will be 1 instead.
This will also add virtual annotations to empty frames in the comparison results.
RegisterSerializerEx:
diff --git a/cvat/settings/base.py b/cvat/settings/base.py
index 0f6147dc4bf0..c73cb31eafa2 100644
--- a/cvat/settings/base.py
+++ b/cvat/settings/base.py
@@ -19,9 +19,9 @@
import os
import sys
import tempfile
+import urllib
from datetime import timedelta
from enum import Enum
-import urllib
from attr.converters import to_bool
from corsheaders.defaults import default_headers
@@ -74,7 +74,7 @@ def generate_secret_key():
try:
sys.path.append(BASE_DIR)
- from keys.secret_key import SECRET_KEY # pylint: disable=unused-import
+ from keys.secret_key import SECRET_KEY # pylint: disable=unused-import
except ModuleNotFoundError:
generate_secret_key()
from keys.secret_key import SECRET_KEY
@@ -740,6 +740,7 @@ class CVAT_QUEUES(Enum):
CVAT_CONCURRENT_CHUNK_PROCESSING = int(os.getenv('CVAT_CONCURRENT_CHUNK_PROCESSING', 1))
from cvat.rq_patching import update_started_job_registry_cleanup
+
update_started_job_registry_cleanup()
CLOUD_DATA_DOWNLOADING_MAX_THREADS_NUMBER = 4
diff --git a/cvat/settings/email_settings.py b/cvat/settings/email_settings.py
index ee6d08518399..c3ead7695aff 100644
--- a/cvat/settings/email_settings.py
+++ b/cvat/settings/email_settings.py
@@ -6,7 +6,6 @@
# Inherit parent config
from cvat.settings.production import * # pylint: disable=wildcard-import
-
# https://github.com/pennersr/django-allauth
ACCOUNT_AUTHENTICATION_METHOD = 'username_email'
ACCOUNT_CONFIRM_EMAIL_ON_GET = True
diff --git a/cvat/settings/testing.py b/cvat/settings/testing.py
index f7357b2005b9..0efcf95d11ae 100644
--- a/cvat/settings/testing.py
+++ b/cvat/settings/testing.py
@@ -2,10 +2,11 @@
#
# SPDX-License-Identifier: MIT
+import tempfile
+
# Inherit parent config
from .development import * # pylint: disable=wildcard-import
-import tempfile
DATABASES = {
'default': {
@@ -75,11 +76,13 @@
TEST_RUNNER = "cvat.settings.testing.PatchedDiscoverRunner"
from django.test.runner import DiscoverRunner
+
+
class PatchedDiscoverRunner(DiscoverRunner):
def __init__(self, *args, **kwargs):
# Used fakeredis for testing (don't affect production redis)
- from fakeredis import FakeRedis, FakeStrictRedis
import django_rq.queues
+ from fakeredis import FakeRedis, FakeStrictRedis
simple_redis = FakeRedis()
strict_redis = FakeStrictRedis()
django_rq.queues.get_redis_connection = lambda _, strict: strict_redis \
diff --git a/cvat/urls.py b/cvat/urls.py
index 08257a14b811..ca62b7cb03a3 100644
--- a/cvat/urls.py
+++ b/cvat/urls.py
@@ -20,7 +20,7 @@
from django.apps import apps
from django.contrib import admin
-from django.urls import path, include
+from django.urls import include, path
urlpatterns = [
path("admin/", admin.site.urls),
diff --git a/cvat/utils/http.py b/cvat/utils/http.py
index 2cb1b7498b32..ab8771aaa2ae 100644
--- a/cvat/utils/http.py
+++ b/cvat/utils/http.py
@@ -2,10 +2,9 @@
#
# SPDX-License-Identifier: MIT
-from django.conf import settings
-
import requests
import requests.utils
+from django.conf import settings
from cvat import __version__
diff --git a/dev/update_version.py b/dev/update_version.py
index fbe5da9971c0..7419a581ef4c 100755
--- a/dev/update_version.py
+++ b/dev/update_version.py
@@ -9,7 +9,6 @@
from re import Match, Pattern
from typing import Callable
-
SUCCESS_CHAR = "\u2714"
FAIL_CHAR = "\u2716"
diff --git a/pyproject.toml b/pyproject.toml
index 528bdc579fcc..b0c13a15766f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,17 +3,21 @@ profile = "black"
forced_separate = ["tests"]
line_length = 100
skip_gitignore = true # align tool behavior with Black
+extend_skip=[
+ # Correctly ordering the imports in serverless functions would
+ # require a pyproject.toml in every function; don't bother with it for now.
+ "serverless",
+ # Sorting the imports in this file causes test failures;
+ # TODO: fix them and remove this ignore.
+ "cvat/apps/dataset_manager/formats/registry.py",
+]
[tool.black]
line-length = 100
target-version = ['py39']
extend-exclude = """
# TODO: get rid of these
-^/cvat/apps/(
- dataset_manager|dataset_repo|engine|events
- |health|iam|lambda_manager|log_viewer
- |organizations|webhooks
-)/
+^/cvat/apps/(dataset_manager|engine)/
| ^/cvat/settings/
| ^/serverless/
| ^/utils/dataset_manifest/
diff --git a/rqscheduler.py b/rqscheduler.py
index 5ae76e64a7f0..b6cebe80f285 100644
--- a/rqscheduler.py
+++ b/rqscheduler.py
@@ -4,10 +4,10 @@
# implementation. This is required for correct work with CVAT queue settings and
# their access options such as login and password.
+from rq_scheduler.scripts import rqscheduler
+
# Required to initialize Django settings correctly
from cvat.asgi import application # pylint: disable=unused-import
-from rq_scheduler.scripts import rqscheduler
-
if __name__ == "__main__":
rqscheduler.main()
diff --git a/site/content/en/docs/manual/advanced/analytics-and-monitoring/auto-qa.md b/site/content/en/docs/manual/advanced/analytics-and-monitoring/auto-qa.md
index 21ebd2d99087..4a098c6545fa 100644
--- a/site/content/en/docs/manual/advanced/analytics-and-monitoring/auto-qa.md
+++ b/site/content/en/docs/manual/advanced/analytics-and-monitoring/auto-qa.md
@@ -385,7 +385,7 @@ Annotation quality settings have the following parameters:
| - | - |
| Min overlap threshold | Min overlap threshold used for the distinction between matched and unmatched shapes. Used to match all types of annotations. It corresponds to the Intersection over union (IoU) for spatial annotations, such as bounding boxes and masks. |
| Low overlap threshold | Low overlap threshold used for the distinction between strong and weak matches. Only affects _Low overlap_ warnings. It's supposed that _Min similarity threshold_ <= _Low overlap threshold_. |
-| Match empty frames | Consider frames matched if there are no annotations both on GT and regular job frames |
+| Empty frames are annotated | Consider frames annotated as "empty" if there are no annotations on a frame. If a frame is empty in both GT and job annotations, it will be considered a matching annotation. |
| _Point and Skeleton matching_ | |
| - | - |
diff --git a/tests/cypress/e2e/actions_projects_models/markdown_base_pipeline.js b/tests/cypress/e2e/actions_projects_models/markdown_base_pipeline.js
index 51143c318794..639e57ad09f1 100644
--- a/tests/cypress/e2e/actions_projects_models/markdown_base_pipeline.js
+++ b/tests/cypress/e2e/actions_projects_models/markdown_base_pipeline.js
@@ -14,14 +14,14 @@ context('Basic markdown pipeline', () => {
username: 'md_job_assignee',
firstName: 'Firstname',
lastName: 'Lastname',
- emailAddr: 'md_job_assignee@local.local',
+ email: 'md_job_assignee@local.local',
password: 'Fv5Df3#f55g',
},
taskAssignee: {
username: 'md_task_assignee',
firstName: 'Firstname',
lastName: 'Lastname',
- emailAddr: 'md_task_assignee@local.local',
+ email: 'md_task_assignee@local.local',
password: 'UfdU21!dds',
},
notAssignee: {
diff --git a/tests/cypress/e2e/actions_users/registration_involved/case_28_review_pipeline_feature.js b/tests/cypress/e2e/actions_users/registration_involved/case_28_review_pipeline_feature.js
index c77e00df3ad1..464464492832 100644
--- a/tests/cypress/e2e/actions_users/registration_involved/case_28_review_pipeline_feature.js
+++ b/tests/cypress/e2e/actions_users/registration_involved/case_28_review_pipeline_feature.js
@@ -12,14 +12,14 @@ context('Review pipeline feature', () => {
username: 'annotator',
firstName: 'Firstname',
lastName: 'Lastname',
- emailAddr: 'annotator@local.local',
+ email: 'annotator@local.local',
password: 'UfdU21!dds',
},
reviewer: {
username: 'reviewer',
firstName: 'Firstname',
lastName: 'Lastname',
- emailAddr: 'reviewer@local.local',
+ email: 'reviewer@local.local',
password: 'Fv5Df3#f55g',
},
};
diff --git a/tests/cypress/support/commands.js b/tests/cypress/support/commands.js
index 42b7d2772375..a027c260e7b0 100644
--- a/tests/cypress/support/commands.js
+++ b/tests/cypress/support/commands.js
@@ -360,8 +360,12 @@ Cypress.Commands.add('headlessCreateUser', (userSpec) => {
headers: {
'Content-type': 'application/json',
},
+ }).then((response) => {
+ expect(response.status).to.eq(201);
+ expect(response.body.username).to.eq(userSpec.username);
+ expect(response.body.email).to.eq(userSpec.email);
+ return cy.wrap();
});
- return cy.wrap();
});
Cypress.Commands.add('headlessLogout', () => {
diff --git a/tests/python/pyproject.toml b/tests/python/pyproject.toml
index ab4db6695977..6b5fba136a78 100644
--- a/tests/python/pyproject.toml
+++ b/tests/python/pyproject.toml
@@ -3,3 +3,4 @@ profile = "black"
forced_separate = ["tests"]
line_length = 100
skip_gitignore = true # align tool behavior with Black
+known_first_party = ["shared", "rest_api", "sdk", "cli"]
diff --git a/tests/python/requirements.txt b/tests/python/requirements.txt
index d43d9b61d5df..dc21498a1ec0 100644
--- a/tests/python/requirements.txt
+++ b/tests/python/requirements.txt
@@ -8,7 +8,7 @@ deepdiff==7.0.1
boto3==1.17.61
Pillow==10.3.0
python-dateutil==2.8.2
-pyyaml==6.0.0
+pyyaml==6.0.2
numpy==2.0.0
# TODO: update pytest to 7.0.0 and pytest-timeout to 2.3.1 (better debug in vscode)
\ No newline at end of file
diff --git a/tests/python/rest_api/test_quality_control.py b/tests/python/rest_api/test_quality_control.py
index d03675c9156e..56dd24bb0abb 100644
--- a/tests/python/rest_api/test_quality_control.py
+++ b/tests/python/rest_api/test_quality_control.py
@@ -1213,7 +1213,7 @@ def test_modified_task_produces_different_metrics(
"compare_line_orientation",
"panoptic_comparison",
"point_size_base",
- "match_empty_frames",
+ "empty_is_annotated",
],
)
def test_settings_affect_metrics(
@@ -1246,8 +1246,11 @@ def test_settings_affect_metrics(
)
new_report = self.create_quality_report(admin_user, task_id)
- if parameter == "match_empty_frames":
+ if parameter == "empty_is_annotated":
assert new_report["summary"]["valid_count"] != old_report["summary"]["valid_count"]
+ assert new_report["summary"]["total_count"] != old_report["summary"]["total_count"]
+ assert new_report["summary"]["ds_count"] != old_report["summary"]["ds_count"]
+ assert new_report["summary"]["gt_count"] != old_report["summary"]["gt_count"]
else:
assert (
new_report["summary"]["conflict_count"] != old_report["summary"]["conflict_count"]
diff --git a/tests/python/rest_api/test_tasks.py b/tests/python/rest_api/test_tasks.py
index 70d8a84827bb..a55bd1ded65b 100644
--- a/tests/python/rest_api/test_tasks.py
+++ b/tests/python/rest_api/test_tasks.py
@@ -45,7 +45,7 @@
from pytest_cases import fixture, fixture_ref, parametrize
import shared.utils.s3 as s3
-from shared.fixtures.init import docker_exec_cvat, kube_exec_cvat
+from shared.fixtures.init import container_exec_cvat
from shared.utils.config import (
delete_method,
get_method,
@@ -5315,12 +5315,9 @@ def test_check_import_cache_after_previous_interrupted_upload(self, tasks_with_s
number_of_files = 1
sleep(30) # wait when the cleaning job from rq worker will be started
command = ["/bin/bash", "-c", f"ls data/tasks/{task_id}/tmp | wc -l"]
- platform = request.config.getoption("--platform")
- assert platform in ("kube", "local")
- func = docker_exec_cvat if platform == "local" else kube_exec_cvat
for _ in range(12):
sleep(2)
- result, _ = func(command)
+ result, _ = container_exec_cvat(request, command)
number_of_files = int(result)
if not number_of_files:
break
diff --git a/tests/python/shared/assets/cvat_db/data.json b/tests/python/shared/assets/cvat_db/data.json
index 5b30d421cb5a..53863fa94fcc 100644
--- a/tests/python/shared/assets/cvat_db/data.json
+++ b/tests/python/shared/assets/cvat_db/data.json
@@ -18173,7 +18173,7 @@
"object_visibility_threshold": 0.05,
"panoptic_comparison": true,
"compare_attributes": true,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"target_metric": "accuracy",
"target_metric_threshold": 0.7,
"max_validations_per_job": 0
@@ -18197,7 +18197,7 @@
"object_visibility_threshold": 0.05,
"panoptic_comparison": true,
"compare_attributes": true,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"target_metric": "accuracy",
"target_metric_threshold": 0.7,
"max_validations_per_job": 0
@@ -18221,7 +18221,7 @@
"object_visibility_threshold": 0.05,
"panoptic_comparison": true,
"compare_attributes": true,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"target_metric": "accuracy",
"target_metric_threshold": 0.7,
"max_validations_per_job": 0
@@ -18245,7 +18245,7 @@
"object_visibility_threshold": 0.05,
"panoptic_comparison": true,
"compare_attributes": true,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"target_metric": "accuracy",
"target_metric_threshold": 0.7,
"max_validations_per_job": 0
@@ -18269,7 +18269,7 @@
"object_visibility_threshold": 0.05,
"panoptic_comparison": true,
"compare_attributes": true,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"target_metric": "accuracy",
"target_metric_threshold": 0.7,
"max_validations_per_job": 0
@@ -18293,7 +18293,7 @@
"object_visibility_threshold": 0.05,
"panoptic_comparison": true,
"compare_attributes": true,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"target_metric": "accuracy",
"target_metric_threshold": 0.7,
"max_validations_per_job": 0
@@ -18317,7 +18317,7 @@
"object_visibility_threshold": 0.05,
"panoptic_comparison": true,
"compare_attributes": true,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"target_metric": "accuracy",
"target_metric_threshold": 0.7,
"max_validations_per_job": 0
@@ -18341,7 +18341,7 @@
"object_visibility_threshold": 0.05,
"panoptic_comparison": true,
"compare_attributes": true,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"target_metric": "accuracy",
"target_metric_threshold": 0.7,
"max_validations_per_job": 0
@@ -18365,7 +18365,7 @@
"object_visibility_threshold": 0.05,
"panoptic_comparison": true,
"compare_attributes": true,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"target_metric": "accuracy",
"target_metric_threshold": 0.7,
"max_validations_per_job": 0
@@ -18389,7 +18389,7 @@
"object_visibility_threshold": 0.05,
"panoptic_comparison": true,
"compare_attributes": true,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"target_metric": "accuracy",
"target_metric_threshold": 0.7,
"max_validations_per_job": 0
@@ -18413,7 +18413,7 @@
"object_visibility_threshold": 0.05,
"panoptic_comparison": true,
"compare_attributes": true,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"target_metric": "accuracy",
"target_metric_threshold": 0.7,
"max_validations_per_job": 0
@@ -18437,7 +18437,7 @@
"object_visibility_threshold": 0.05,
"panoptic_comparison": true,
"compare_attributes": true,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"target_metric": "accuracy",
"target_metric_threshold": 0.7,
"max_validations_per_job": 0
@@ -18461,7 +18461,7 @@
"object_visibility_threshold": 0.05,
"panoptic_comparison": true,
"compare_attributes": true,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"target_metric": "accuracy",
"target_metric_threshold": 0.7,
"max_validations_per_job": 0
@@ -18485,7 +18485,7 @@
"object_visibility_threshold": 0.05,
"panoptic_comparison": true,
"compare_attributes": true,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"target_metric": "accuracy",
"target_metric_threshold": 0.7,
"max_validations_per_job": 0
@@ -18509,7 +18509,7 @@
"object_visibility_threshold": 0.05,
"panoptic_comparison": true,
"compare_attributes": true,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"target_metric": "accuracy",
"target_metric_threshold": 0.7,
"max_validations_per_job": 0
@@ -18533,7 +18533,7 @@
"object_visibility_threshold": 0.05,
"panoptic_comparison": true,
"compare_attributes": true,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"target_metric": "accuracy",
"target_metric_threshold": 0.7,
"max_validations_per_job": 0
@@ -18557,7 +18557,7 @@
"object_visibility_threshold": 0.05,
"panoptic_comparison": true,
"compare_attributes": true,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"target_metric": "accuracy",
"target_metric_threshold": 0.7,
"max_validations_per_job": 0
@@ -18581,7 +18581,7 @@
"object_visibility_threshold": 0.05,
"panoptic_comparison": true,
"compare_attributes": true,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"target_metric": "accuracy",
"target_metric_threshold": 0.7,
"max_validations_per_job": 0
@@ -18605,7 +18605,7 @@
"object_visibility_threshold": 0.05,
"panoptic_comparison": true,
"compare_attributes": true,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"target_metric": "accuracy",
"target_metric_threshold": 0.7,
"max_validations_per_job": 0
@@ -18629,7 +18629,7 @@
"object_visibility_threshold": 0.05,
"panoptic_comparison": true,
"compare_attributes": true,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"target_metric": "accuracy",
"target_metric_threshold": 0.7,
"max_validations_per_job": 0
@@ -18653,7 +18653,7 @@
"object_visibility_threshold": 0.05,
"panoptic_comparison": true,
"compare_attributes": true,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"target_metric": "accuracy",
"target_metric_threshold": 0.7,
"max_validations_per_job": 0
@@ -18677,7 +18677,7 @@
"object_visibility_threshold": 0.05,
"panoptic_comparison": true,
"compare_attributes": true,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"target_metric": "accuracy",
"target_metric_threshold": 0.7,
"max_validations_per_job": 0
@@ -18701,7 +18701,7 @@
"object_visibility_threshold": 0.05,
"panoptic_comparison": true,
"compare_attributes": true,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"target_metric": "accuracy",
"target_metric_threshold": 0.7,
"max_validations_per_job": 0
@@ -18725,7 +18725,7 @@
"object_visibility_threshold": 0.05,
"panoptic_comparison": true,
"compare_attributes": true,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"target_metric": "accuracy",
"target_metric_threshold": 0.7,
"max_validations_per_job": 0
diff --git a/tests/python/shared/assets/quality_settings.json b/tests/python/shared/assets/quality_settings.json
index 7ddc589bc7bf..dc56352fc1ef 100644
--- a/tests/python/shared/assets/quality_settings.json
+++ b/tests/python/shared/assets/quality_settings.json
@@ -14,7 +14,7 @@
"line_orientation_threshold": 0.1,
"line_thickness": 0.01,
"low_overlap_threshold": 0.8,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"max_validations_per_job": 0,
"object_visibility_threshold": 0.05,
"oks_sigma": 0.09,
@@ -35,7 +35,7 @@
"line_orientation_threshold": 0.1,
"line_thickness": 0.01,
"low_overlap_threshold": 0.8,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"max_validations_per_job": 0,
"object_visibility_threshold": 0.05,
"oks_sigma": 0.09,
@@ -56,7 +56,7 @@
"line_orientation_threshold": 0.1,
"line_thickness": 0.01,
"low_overlap_threshold": 0.8,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"max_validations_per_job": 0,
"object_visibility_threshold": 0.05,
"oks_sigma": 0.09,
@@ -77,7 +77,7 @@
"line_orientation_threshold": 0.1,
"line_thickness": 0.01,
"low_overlap_threshold": 0.8,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"max_validations_per_job": 0,
"object_visibility_threshold": 0.05,
"oks_sigma": 0.09,
@@ -98,7 +98,7 @@
"line_orientation_threshold": 0.1,
"line_thickness": 0.01,
"low_overlap_threshold": 0.8,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"max_validations_per_job": 0,
"object_visibility_threshold": 0.05,
"oks_sigma": 0.09,
@@ -119,7 +119,7 @@
"line_orientation_threshold": 0.1,
"line_thickness": 0.01,
"low_overlap_threshold": 0.8,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"max_validations_per_job": 0,
"object_visibility_threshold": 0.05,
"oks_sigma": 0.09,
@@ -140,7 +140,7 @@
"line_orientation_threshold": 0.1,
"line_thickness": 0.01,
"low_overlap_threshold": 0.8,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"max_validations_per_job": 0,
"object_visibility_threshold": 0.05,
"oks_sigma": 0.09,
@@ -161,7 +161,7 @@
"line_orientation_threshold": 0.1,
"line_thickness": 0.01,
"low_overlap_threshold": 0.8,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"max_validations_per_job": 0,
"object_visibility_threshold": 0.05,
"oks_sigma": 0.09,
@@ -182,7 +182,7 @@
"line_orientation_threshold": 0.1,
"line_thickness": 0.01,
"low_overlap_threshold": 0.8,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"max_validations_per_job": 0,
"object_visibility_threshold": 0.05,
"oks_sigma": 0.09,
@@ -203,7 +203,7 @@
"line_orientation_threshold": 0.1,
"line_thickness": 0.01,
"low_overlap_threshold": 0.8,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"max_validations_per_job": 0,
"object_visibility_threshold": 0.05,
"oks_sigma": 0.09,
@@ -224,7 +224,7 @@
"line_orientation_threshold": 0.1,
"line_thickness": 0.01,
"low_overlap_threshold": 0.8,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"max_validations_per_job": 0,
"object_visibility_threshold": 0.05,
"oks_sigma": 0.09,
@@ -245,7 +245,7 @@
"line_orientation_threshold": 0.1,
"line_thickness": 0.01,
"low_overlap_threshold": 0.8,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"max_validations_per_job": 0,
"object_visibility_threshold": 0.05,
"oks_sigma": 0.09,
@@ -266,7 +266,7 @@
"line_orientation_threshold": 0.1,
"line_thickness": 0.01,
"low_overlap_threshold": 0.8,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"max_validations_per_job": 0,
"object_visibility_threshold": 0.05,
"oks_sigma": 0.09,
@@ -287,7 +287,7 @@
"line_orientation_threshold": 0.1,
"line_thickness": 0.01,
"low_overlap_threshold": 0.8,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"max_validations_per_job": 0,
"object_visibility_threshold": 0.05,
"oks_sigma": 0.09,
@@ -308,7 +308,7 @@
"line_orientation_threshold": 0.1,
"line_thickness": 0.01,
"low_overlap_threshold": 0.8,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"max_validations_per_job": 0,
"object_visibility_threshold": 0.05,
"oks_sigma": 0.09,
@@ -329,7 +329,7 @@
"line_orientation_threshold": 0.1,
"line_thickness": 0.01,
"low_overlap_threshold": 0.8,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"max_validations_per_job": 0,
"object_visibility_threshold": 0.05,
"oks_sigma": 0.09,
@@ -350,7 +350,7 @@
"line_orientation_threshold": 0.1,
"line_thickness": 0.01,
"low_overlap_threshold": 0.8,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"max_validations_per_job": 0,
"object_visibility_threshold": 0.05,
"oks_sigma": 0.09,
@@ -371,7 +371,7 @@
"line_orientation_threshold": 0.1,
"line_thickness": 0.01,
"low_overlap_threshold": 0.8,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"max_validations_per_job": 0,
"object_visibility_threshold": 0.05,
"oks_sigma": 0.09,
@@ -392,7 +392,7 @@
"line_orientation_threshold": 0.1,
"line_thickness": 0.01,
"low_overlap_threshold": 0.8,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"max_validations_per_job": 0,
"object_visibility_threshold": 0.05,
"oks_sigma": 0.09,
@@ -413,7 +413,7 @@
"line_orientation_threshold": 0.1,
"line_thickness": 0.01,
"low_overlap_threshold": 0.8,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"max_validations_per_job": 0,
"object_visibility_threshold": 0.05,
"oks_sigma": 0.09,
@@ -434,7 +434,7 @@
"line_orientation_threshold": 0.1,
"line_thickness": 0.01,
"low_overlap_threshold": 0.8,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"max_validations_per_job": 0,
"object_visibility_threshold": 0.05,
"oks_sigma": 0.09,
@@ -455,7 +455,7 @@
"line_orientation_threshold": 0.1,
"line_thickness": 0.01,
"low_overlap_threshold": 0.8,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"max_validations_per_job": 0,
"object_visibility_threshold": 0.05,
"oks_sigma": 0.09,
@@ -476,7 +476,7 @@
"line_orientation_threshold": 0.1,
"line_thickness": 0.01,
"low_overlap_threshold": 0.8,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"max_validations_per_job": 0,
"object_visibility_threshold": 0.05,
"oks_sigma": 0.09,
@@ -497,7 +497,7 @@
"line_orientation_threshold": 0.1,
"line_thickness": 0.01,
"low_overlap_threshold": 0.8,
- "match_empty_frames": false,
+ "empty_is_annotated": false,
"max_validations_per_job": 0,
"object_visibility_threshold": 0.05,
"oks_sigma": 0.09,
diff --git a/tests/python/shared/fixtures/init.py b/tests/python/shared/fixtures/init.py
index 1f5d57ffc5d7..b0d5f8a84db0 100644
--- a/tests/python/shared/fixtures/init.py
+++ b/tests/python/shared/fixtures/init.py
@@ -171,6 +171,16 @@ def kube_exec_cvat(command: Union[list[str], str]):
return _run(_command)
+def container_exec_cvat(request: pytest.FixtureRequest, command: Union[list[str], str]):
+ platform = request.config.getoption("--platform")
+ if platform == "local":
+ return docker_exec_cvat(command)
+ elif platform == "kube":
+ return kube_exec_cvat(command)
+ else:
+ assert False, "unknown platform"
+
+
def kube_exec_cvat_db(command):
pod_name = _kube_get_db_pod_name()
_run(["kubectl", "exec", pod_name, "--"] + command)
diff --git a/tests/yarn.lock b/tests/yarn.lock
index a5500151360a..4b82218c14a9 100644
--- a/tests/yarn.lock
+++ b/tests/yarn.lock
@@ -1146,9 +1146,9 @@ crc32-stream@^4.0.2:
readable-stream "^3.4.0"
cross-spawn@^7.0.0, cross-spawn@^7.0.3:
- version "7.0.3"
- resolved "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz"
- integrity sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w==
+ version "7.0.6"
+ resolved "https://registry.yarnpkg.com/cross-spawn/-/cross-spawn-7.0.6.tgz#8a58fe78f00dcd70c370451759dfbfaf03e8ee9f"
+ integrity sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==
dependencies:
path-key "^3.1.0"
shebang-command "^2.0.0"
diff --git a/utils/dataset_manifest/__init__.py b/utils/dataset_manifest/__init__.py
index 74fd25ede729..7efcfcb48406 100644
--- a/utils/dataset_manifest/__init__.py
+++ b/utils/dataset_manifest/__init__.py
@@ -1,4 +1,4 @@
# Copyright (C) 2021-2022 Intel Corporation
#
# SPDX-License-Identifier: MIT
-from .core import VideoManifestManager, ImageManifestManager, is_manifest
+from .core import ImageManifestManager, VideoManifestManager, is_manifest
diff --git a/utils/dataset_manifest/core.py b/utils/dataset_manifest/core.py
index 449e70d64098..a855d170e86b 100644
--- a/utils/dataset_manifest/core.py
+++ b/utils/dataset_manifest/core.py
@@ -3,25 +3,24 @@
#
# SPDX-License-Identifier: MIT
-from enum import Enum
-from io import StringIO
-import av
import json
import os
-
from abc import ABC, abstractmethod
from collections.abc import Iterator
from contextlib import closing
+from enum import Enum
+from inspect import isgenerator
+from io import StringIO
from itertools import islice
-from PIL import Image
from json.decoder import JSONDecodeError
-from inspect import isgenerator
+from typing import Any, Callable, Optional, Union
+
+import av
+from PIL import Image
from .errors import InvalidManifestError, InvalidVideoError
-from .utils import SortingMethod, md5_hash, rotate_image, sort
from .types import NamedBytesIO
-
-from typing import Any, Union, Optional, Callable
+from .utils import SortingMethod, md5_hash, rotate_image, sort
class VideoStreamReader:
diff --git a/utils/dataset_manifest/create.py b/utils/dataset_manifest/create.py
index 64efaed60f2d..fa31300e058a 100755
--- a/utils/dataset_manifest/create.py
+++ b/utils/dataset_manifest/create.py
@@ -7,13 +7,14 @@
import argparse
import os
-import sys
import re
+import sys
from glob import glob
from tqdm import tqdm
-from utils import detect_related_images, is_image, is_video, SortingMethod
+from utils import SortingMethod, detect_related_images, is_image, is_video
+
def get_args():
parser = argparse.ArgumentParser()
@@ -98,5 +99,5 @@ def main():
if __name__ == "__main__":
base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(base_dir)
- from dataset_manifest.core import VideoManifestManager, ImageManifestManager
+ from dataset_manifest.core import ImageManifestManager, VideoManifestManager
main()
diff --git a/utils/dataset_manifest/requirements.txt b/utils/dataset_manifest/requirements.txt
index 6d3ed66aecb1..c073606622ed 100644
--- a/utils/dataset_manifest/requirements.txt
+++ b/utils/dataset_manifest/requirements.txt
@@ -13,7 +13,7 @@ numpy==1.22.4
# via opencv-python-headless
opencv-python-headless==4.10.0.84
# via -r utils/dataset_manifest/requirements.in
-pillow==11.0.0
+pillow==11.1.0
# via -r utils/dataset_manifest/requirements.in
tqdm==4.67.1
# via -r utils/dataset_manifest/requirements.in
diff --git a/utils/dataset_manifest/types.py b/utils/dataset_manifest/types.py
index 8847eee457ba..5ddcce9ad5c9 100644
--- a/utils/dataset_manifest/types.py
+++ b/utils/dataset_manifest/types.py
@@ -5,6 +5,7 @@
from io import BytesIO
from typing import Protocol
+
class Named(Protocol):
filename: str
diff --git a/utils/dataset_manifest/utils.py b/utils/dataset_manifest/utils.py
index b4eee9686b71..9cb89ce5cd4d 100644
--- a/utils/dataset_manifest/utils.py
+++ b/utils/dataset_manifest/utils.py
@@ -2,15 +2,17 @@
#
# SPDX-License-Identifier: MIT
-import os
-import re
import hashlib
import mimetypes
+import os
+import re
+from enum import Enum
+from random import shuffle
+
import cv2 as cv
from av import VideoFrame
-from enum import Enum
from natsort import os_sorted
-from random import shuffle
+
def rotate_image(image, angle):
height, width = image.shape[:2]
diff --git a/utils/dicom_converter/script.py b/utils/dicom_converter/script.py
index 3fe7ef0be6dd..a201845965f3 100644
--- a/utils/dicom_converter/script.py
+++ b/utils/dicom_converter/script.py
@@ -3,17 +3,16 @@
# SPDX-License-Identifier: MIT
-import os
import argparse
import logging
+import os
from glob import glob
import numpy as np
-from tqdm import tqdm
from PIL import Image
from pydicom import dcmread
from pydicom.pixel_data_handlers.util import convert_color_space
-
+from tqdm import tqdm
# Script configuration
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")