diff --git a/.github/workflows/full.yml b/.github/workflows/full.yml index e587e26aa1b8..e42380de5ead 100644 --- a/.github/workflows/full.yml +++ b/.github/workflows/full.yml @@ -156,7 +156,7 @@ jobs: - name: Install SDK run: | pip3 install -r ./tests/python/requirements.txt \ - -e './cvat-sdk[pytorch]' -e ./cvat-cli \ + -e './cvat-sdk[masks,pytorch]' -e ./cvat-cli \ --extra-index-url https://download.pytorch.org/whl/cpu - name: Running REST API and SDK tests diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index f4e3f11d1052..becca0218f94 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -166,7 +166,7 @@ jobs: - name: Install SDK run: | pip3 install -r ./tests/python/requirements.txt \ - -e './cvat-sdk[pytorch]' -e ./cvat-cli \ + -e './cvat-sdk[masks,pytorch]' -e ./cvat-cli \ --extra-index-url https://download.pytorch.org/whl/cpu - name: Run REST API and SDK tests diff --git a/.vscode/launch.json b/.vscode/launch.json index af93ae24c007..78f24c96ca83 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -4,6 +4,7 @@ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 "version": "0.2.0", "configurations": [ + { "name": "REST API tests: Attach to server", "type": "debugpy", @@ -168,7 +169,7 @@ "CVAT_SERVERLESS": "1", "ALLOWED_HOSTS": "*", "DJANGO_LOG_SERVER_HOST": "localhost", - "DJANGO_LOG_SERVER_PORT": "8282" + "DJANGO_LOG_SERVER_PORT": "8282", }, "args": [ "runserver", @@ -178,7 +179,7 @@ ], "django": true, "cwd": "${workspaceFolder}", - "console": "internalConsole" + "console": "internalConsole", }, { "name": "server: chrome", @@ -360,6 +361,28 @@ }, "console": "internalConsole" }, + { + "name": "server: RQ - chunks", + "type": "debugpy", + "request": "launch", + "stopOnEntry": false, + "justMyCode": false, + "python": "${command:python.interpreterPath}", + "program": "${workspaceFolder}/manage.py", + "args": [ + "rqworker", + "chunks", + "--worker-class", + "cvat.rqworker.SimpleWorker" + ], + "django": true, + "cwd": "${workspaceFolder}", + "env": { + "DJANGO_LOG_SERVER_HOST": "localhost", + "DJANGO_LOG_SERVER_PORT": "8282" + }, + "console": "internalConsole" + }, { "name": "server: migrate", "type": "debugpy", @@ -553,7 +576,8 @@ "server: RQ - scheduler", "server: RQ - quality reports", "server: RQ - analytics reports", - "server: RQ - cleaning" + "server: RQ - cleaning", + "server: RQ - chunks", ] } ] diff --git a/.vscode/settings.json b/.vscode/settings.json index a0caaf036765..baf7dc5b3879 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -29,6 +29,15 @@ "database": "${workspaceFolder:cvat}/db.sqlite3" } ], + "python.analysis.exclude": [ + // VS Code defaults + "**/node_modules", + "**/__pycache__", + ".git", + + "cvat-cli/build", + "cvat-sdk/build", + ], "python.defaultInterpreterPath": "${workspaceFolder}/.env/", "python.testing.pytestArgs": [ "--rootdir","${workspaceFolder}/tests/" diff --git a/CHANGELOG.md b/CHANGELOG.md index dd8854040e02..a9143f436f05 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,112 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 + +## \[2.23.0\] - 2024-11-29 + +### Added + +- Support for direct .json file import in Datumaro format + () + +- \[SDK, CLI\] Added a `conf_threshold` parameter to + `cvat_sdk.auto_annotation.annotate_task`, which is passed as-is to the AA + function object via the context. The CLI equivalent is `auto-annotate + --conf-threshold`. This makes it easier to write and use AA functions that + support object filtering based on confidence levels + () + +- \[SDK\] Built-in auto-annotation functions now support object filtering by + confidence level + () + +- New events (create|update|delete):(membership|webhook) and (create|delete):invitation + () + +- \[SDK\] Added new auto-annotation helpers (`mask`, `polygon`, `encode_mask`) + to support AA functions that return masks or polygons + () + +- \[SDK\] Added a new built-in auto-annotation function, + `torchvision_instance_segmentation` + () + +- \[SDK, CLI\] Added a new auto-annotation parameter, `conv_mask_to_poly` + (`--conv-mask-to-poly` in the CLI) + () + +- A user may undo or redo changes, made by an annotations actions using general approach (e.g. Ctrl+Z, Ctrl+Y) + () + +- Basically, annotations actions now support any kinds of objects (shapes, tracks, tags) + () + +- A user may run annotations actions on a certain object (added corresponding object menu item) + () + +- A shortcut to open annotations actions modal for a currently selected object + () + +- A default role if IAM_TYPE='LDAP' and if the user is not a member of any group in 'DJANGO_AUTH_LDAP_GROUPS' () + +- The `POST /api/lambda/requests` endpoint now has a `conv_mask_to_poly` + parameter with the same semantics as the old `convMaskToPoly` parameter + () + +- \[SDK\] Model instances can now be pickled + () + +### Changed + +- Chunks are now prepared in a separate worker process + () + +- \[Helm\] Traefik sticky sessions for the backend service are disabled + () + +- Payload for events (create|update|delete):(shapes|tags|tracks) does not include frame and attributes anymore + () + +### Deprecated + +- The `convMaskToPoly` parameter of the `POST /api/lambda/requests` endpoint + is deprecated; use `conv_mask_to_poly` instead + () + +### Removed + +- It it no longer possible to run lambda functions on compressed images; + original images will always be used + () + +### Fixed + +- Export without images in Datumaro format should include image info + () + +- Inconsistent zOrder behavior on job open + () + +- Ground truth annotations can be shown in standard mode + () + +- Keybinds in UI allow drawing disabled shape types + () + +- Style issues on the Quality page when browser zoom is applied + () +- Flickering of masks in review mode, even when no conflicts are highlighted + () + +- Fixed security header duplication in HTTP responses from the backend + () + +- The error occurs when trying to copy/paste a mask on a video after opening the job + () + +- Attributes do not get copied when copy/paste a mask + () + ## \[2.22.0\] - 2024-11-11 @@ -87,7 +193,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 () - Tags in ground truth job couldn't be deleted via `x` button - () + () - Exception 'Canvas is busy' when change frame during drag/resize a track () @@ -377,7 +483,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Deprecated - Client events `upload:annotations`, `lock:object`, `change:attribute`, `change:label` - () + () ### Removed @@ -404,7 +510,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 () - Sometimes it is not possible to switch workspace because active control broken after -trying to create a tag with a shortcut () + trying to create a tag with a shortcut + () ## \[2.16.3\] - 2024-08-13 @@ -445,13 +552,14 @@ trying to create a tag with a shortcut () + **Asset is already related to another guide** + () - Undo can't be done when a shape is rotated () - Exporting a skeleton track in a format defined for shapes raises error -`operands could not be broadcast together with shapes (X, ) (Y, )` + `operands could not be broadcast together with shapes (X, ) (Y, )` () - Delete label modal window does not have cancellation button @@ -470,10 +578,11 @@ trying to create a tag with a shortcut () - API call to run automatic annotations fails on a model with attributes - when mapping not provided in the request () + when mapping not provided in the request + () - Fixed a label collision issue where labels with similar prefixes -and numeric suffixes could conflict, causing error on export. + and numeric suffixes could conflict, causing error on export. () @@ -510,9 +619,9 @@ and numeric suffixes could conflict, causing error on export. ### Added - Set of features to track background activities: importing/exporting datasets, annotations or backups, creating tasks. -Now you may find these processes on Requests page, it allows a user to understand current status of these activities -and enhances user experience, not losing progress when the browser tab is closed -() + Now you may find these processes on Requests page, it allows a user to understand current status of these activities + and enhances user experience, not losing progress when the browser tab is closed + () - User now may update a job state from the corresponding task page () @@ -523,7 +632,8 @@ and enhances user experience, not losing progress when the browser tab is closed ### Changed - "Finish the job" button on annotation view now only sets state to 'completed'. - The job stage keeps unchanged () + The job stage keeps unchanged + () - Log files for individual backend processes are now stored in ephemeral storage of each backend container rather than in the `cvat_logs` volume @@ -535,7 +645,7 @@ and enhances user experience, not losing progress when the browser tab is closed ### Removed - Renew the job button in annotation menu was removed - () + () ### Fixed @@ -583,10 +693,12 @@ and enhances user experience, not losing progress when the browser tab is closed () - Exception 'this.el.node.getScreenCTM() is null' occuring in Firefox when -a user resizes window during skeleton dragging/resizing () + a user resizes window during skeleton dragging/resizing + () - Exception 'Edge's nodeFrom M or nodeTo N do not to refer to any node' -occuring when a user resizes window during skeleton dragging/resizing () + occuring when a user resizes window during skeleton dragging/resizing + () - Slightly broken layout when running attributed face detection model () @@ -644,7 +756,8 @@ occuring when a user resizes window during skeleton dragging/resizing () - When use route `/auth/login-with-token/` without `next` query parameter -the page reloads infinitely () + the page reloads infinitely + () - Fixed kvrocks port naming for istio () @@ -815,7 +928,7 @@ the page reloads infinitely () - Opening update CS page sends infinite requests when CS id does not exist () -Uploading files with TUS immediately failed when one of the requests failed +- Uploading files with TUS immediately failed when one of the requests failed () - Longer analytics report calculation because of inefficient requests to analytics db @@ -985,7 +1098,7 @@ Uploading files with TUS immediately failed when one of the requests failed () - 90 deg-rotated video was added with "Prefer Zip Chunks" disabled -was warped, fixed using the static cropImage function. + was warped, fixed using the static cropImage function. () @@ -1023,7 +1136,7 @@ was warped, fixed using the static cropImage function. ### Added - Single shape annotation mode allowing to easily annotate scenarious where a user -only needs to draw one object on one image () + only needs to draw one object on one image () ### Fixed @@ -1151,7 +1264,7 @@ only needs to draw one object on one image () - \[Compose, Helm\] Updated Clickhouse to version 23.11.* @@ -1200,11 +1313,11 @@ longer accepted automatically. Instead, the invitee can now review the invitatio () - Error message `Edge's nodeFrom ${dataNodeFrom} or nodeTo ${dataNodeTo} do not to refer to any node` - when upload a file with some abscent skeleton nodes () + when upload a file with some abscent skeleton nodes () - Wrong context menu position in skeleton configurator (Firefox only) - () + () - Fixed console error `(Error: attribute width: A negative value is not valid` - appearing when skeleton with all outside elements is created () + appearing when skeleton with all outside elements is created () - Updating cloud storage attached to CVAT using Azure connection string () @@ -1215,7 +1328,7 @@ longer accepted automatically. Instead, the invitee can now review the invitatio ### Added - Introduced CVAT actions. Actions allow performing different - predefined scenarios on annotations automatically (e.g. shape converters) + predefined scenarios on annotations automatically (e.g. shape converters) () - The UI will now retry requests that were rejected due to rate limiting diff --git a/Dockerfile.ui b/Dockerfile.ui index 170ee1a76633..da9c36d38960 100644 --- a/Dockerfile.ui +++ b/Dockerfile.ui @@ -1,11 +1,5 @@ FROM node:lts-slim AS cvat-ui -ARG WA_PAGE_VIEW_HIT -ARG UI_APP_CONFIG -ARG CLIENT_PLUGINS -ARG DISABLE_SOURCE_MAPS -ARG SOURCE_MAPS_TOKEN - ENV TERM=xterm \ LANG='C.UTF-8' \ LC_ALL='C.UTF-8' @@ -29,6 +23,13 @@ COPY cvat-core/ /tmp/cvat-core/ COPY cvat-canvas3d/ /tmp/cvat-canvas3d/ COPY cvat-canvas/ /tmp/cvat-canvas/ COPY cvat-ui/ /tmp/cvat-ui/ + +ARG WA_PAGE_VIEW_HIT +ARG UI_APP_CONFIG +ARG CLIENT_PLUGINS +ARG DISABLE_SOURCE_MAPS +ARG SOURCE_MAPS_TOKEN + RUN CLIENT_PLUGINS="${CLIENT_PLUGINS}" \ DISABLE_SOURCE_MAPS="${DISABLE_SOURCE_MAPS}" \ UI_APP_CONFIG="${UI_APP_CONFIG}" \ diff --git a/cvat-canvas/src/typescript/canvasView.ts b/cvat-canvas/src/typescript/canvasView.ts index 4c346b4d6735..d1bab2369521 100644 --- a/cvat-canvas/src/typescript/canvasView.ts +++ b/cvat-canvas/src/typescript/canvasView.ts @@ -2877,6 +2877,9 @@ export class CanvasViewImpl implements CanvasView, Listener { const shapeView = window.document.getElementById(`cvat_canvas_shape_${clientID}`); if (shapeView) shapeView.classList.remove(this.getHighlightClassname()); }); + const redrawMasks = (highlightedElements.elementsIDs.length !== 0 || + this.highlightedElements.elementsIDs.length !== 0); + if (highlightedElements.elementsIDs.length) { this.highlightedElements = { ...highlightedElements }; this.canvas.classList.add('cvat-canvas-highlight-enabled'); @@ -2891,9 +2894,11 @@ export class CanvasViewImpl implements CanvasView, Listener { }; this.canvas.classList.remove('cvat-canvas-highlight-enabled'); } - const masks = Object.values(this.drawnStates).filter((state) => state.shapeType === 'mask'); - this.deleteObjects(masks); - this.addObjects(masks); + if (redrawMasks) { + const masks = Object.values(this.drawnStates).filter((state) => state.shapeType === 'mask'); + this.deleteObjects(masks); + this.addObjects(masks); + } if (this.highlightedElements.elementsIDs.length) { this.deactivate(); const clientID = this.highlightedElements.elementsIDs[0]; diff --git a/cvat-canvas/src/typescript/masksHandler.ts b/cvat-canvas/src/typescript/masksHandler.ts index ca6e5e469a63..7f6a4e313fb3 100644 --- a/cvat-canvas/src/typescript/masksHandler.ts +++ b/cvat-canvas/src/typescript/masksHandler.ts @@ -404,6 +404,10 @@ export class MasksHandlerImpl implements MasksHandler { rle.push(wrappingBbox.left, wrappingBbox.top, wrappingBbox.right, wrappingBbox.bottom); this.onDrawDone({ + occluded: this.drawData.initialState.occluded, + attributes: { ...this.drawData.initialState.attributes }, + color: this.drawData.initialState.color, + objectType: this.drawData.initialState.objectType, shapeType: this.drawData.shapeType, points: rle, label: this.drawData.initialState.label, diff --git a/cvat-cli/requirements/base.txt b/cvat-cli/requirements/base.txt index e9be53974d91..5f27832efdb7 100644 --- a/cvat-cli/requirements/base.txt +++ b/cvat-cli/requirements/base.txt @@ -1,3 +1,3 @@ -cvat-sdk~=2.22.0 +cvat-sdk~=2.23.0 Pillow>=10.3.0 setuptools>=70.0.0 # not directly required, pinned by Snyk to avoid a vulnerability diff --git a/cvat-cli/src/cvat_cli/_internal/commands.py b/cvat-cli/src/cvat_cli/_internal/commands.py index e86ef3b6350f..324d427a64b8 100644 --- a/cvat-cli/src/cvat_cli/_internal/commands.py +++ b/cvat-cli/src/cvat_cli/_internal/commands.py @@ -20,7 +20,13 @@ from cvat_sdk.core.proxies.tasks import ResourceType from .command_base import CommandGroup -from .parsers import BuildDictAction, parse_function_parameter, parse_label_arg, parse_resource_type +from .parsers import ( + BuildDictAction, + parse_function_parameter, + parse_label_arg, + parse_resource_type, + parse_threshold, +) COMMANDS = CommandGroup(description="Perform common operations related to CVAT tasks.") @@ -463,6 +469,19 @@ def configure_parser(self, parser: argparse.ArgumentParser) -> None: help="Allow the function to declare labels not configured in the task", ) + parser.add_argument( + "--conf-threshold", + type=parse_threshold, + help="Confidence threshold for filtering detections", + default=None, + ) + + parser.add_argument( + "--conv-mask-to-poly", + action="store_true", + help="Convert mask shapes to polygon shapes", + ) + def execute( self, client: Client, @@ -473,6 +492,8 @@ def execute( function_parameters: dict[str, Any], clear_existing: bool = False, allow_unmatched_labels: bool = False, + conf_threshold: Optional[float], + conv_mask_to_poly: bool, ) -> None: if function_module is not None: function = importlib.import_module(function_module) @@ -497,4 +518,6 @@ def execute( pbar=DeferredTqdmProgressReporter(), clear_existing=clear_existing, allow_unmatched_labels=allow_unmatched_labels, + conf_threshold=conf_threshold, + conv_mask_to_poly=conv_mask_to_poly, ) diff --git a/cvat-cli/src/cvat_cli/_internal/parsers.py b/cvat-cli/src/cvat_cli/_internal/parsers.py index a66710a09f47..97dcb5b2668a 100644 --- a/cvat-cli/src/cvat_cli/_internal/parsers.py +++ b/cvat-cli/src/cvat_cli/_internal/parsers.py @@ -53,6 +53,17 @@ def parse_function_parameter(s: str) -> tuple[str, Any]: return (key, value) +def parse_threshold(s: str) -> float: + try: + value = float(s) + except ValueError as e: + raise argparse.ArgumentTypeError("must be a number") from e + + if not 0 <= value <= 1: + raise argparse.ArgumentTypeError("must be between 0 and 1") + return value + + class BuildDictAction(argparse.Action): def __init__(self, option_strings, dest, default=None, **kwargs): super().__init__(option_strings, dest, default=default or {}, **kwargs) diff --git a/cvat-cli/src/cvat_cli/version.py b/cvat-cli/src/cvat_cli/version.py index b2829a54b105..9b4fa879ca12 100644 --- a/cvat-cli/src/cvat_cli/version.py +++ b/cvat-cli/src/cvat_cli/version.py @@ -1 +1 @@ -VERSION = "2.22.0" +VERSION = "2.23.0" diff --git a/cvat-core/package.json b/cvat-core/package.json index a769b74bf78c..6b9039673812 100644 --- a/cvat-core/package.json +++ b/cvat-core/package.json @@ -1,6 +1,6 @@ { "name": "cvat-core", - "version": "15.2.1", + "version": "15.3.0", "type": "module", "description": "Part of Computer Vision Tool which presents an interface for client-side integration", "main": "src/api.ts", diff --git a/cvat-core/src/annotations-actions.ts b/cvat-core/src/annotations-actions.ts deleted file mode 100644 index 43d3ef29a910..000000000000 --- a/cvat-core/src/annotations-actions.ts +++ /dev/null @@ -1,320 +0,0 @@ -// Copyright (C) 2023-2024 CVAT.ai Corporation -// -// SPDX-License-Identifier: MIT - -import { omit, range, throttle } from 'lodash'; -import { ArgumentError } from './exceptions'; -import { SerializedCollection, SerializedShape } from './server-response-types'; -import { Job, Task } from './session'; -import { EventScope, ObjectType } from './enums'; -import ObjectState from './object-state'; -import { getAnnotations, getCollection } from './annotations'; -import { propagateShapes } from './object-utils'; - -export interface SingleFrameActionInput { - collection: Omit; - frameData: { - width: number; - height: number; - number: number; - }; -} - -export interface SingleFrameActionOutput { - collection: Omit; -} - -export enum ActionParameterType { - SELECT = 'select', - NUMBER = 'number', -} - -// For SELECT values should be a list of possible options -// For NUMBER values should be a list with [min, max, step], -// or a callback ({ instance }: { instance: Job | Task }) => [min, max, step] -type ActionParameters = Record string[]); - defaultValue: string | (({ instance }: { instance: Job | Task }) => string); -}>; - -export enum FrameSelectionType { - SEGMENT = 'segment', - CURRENT_FRAME = 'current_frame', -} - -export default class BaseSingleFrameAction { - /* eslint-disable @typescript-eslint/no-unused-vars */ - public async init( - sessionInstance: Job | Task, - parameters: Record, - ): Promise { - throw new Error('Method not implemented'); - } - - public async destroy(): Promise { - throw new Error('Method not implemented'); - } - - public async run(sessionInstance: Job | Task, input: SingleFrameActionInput): Promise { - throw new Error('Method not implemented'); - } - - public get name(): string { - throw new Error('Method not implemented'); - } - - public get parameters(): ActionParameters | null { - throw new Error('Method not implemented'); - } - - public get frameSelection(): FrameSelectionType { - return FrameSelectionType.SEGMENT; - } -} - -class RemoveFilteredShapes extends BaseSingleFrameAction { - public async init(): Promise { - // nothing to init - } - - public async destroy(): Promise { - // nothing to destroy - } - - public async run(): Promise { - return { collection: { shapes: [] } }; - } - - public get name(): string { - return 'Remove filtered shapes'; - } - - public get parameters(): ActionParameters | null { - return null; - } -} - -class PropagateShapes extends BaseSingleFrameAction { - #targetFrame: number; - - public async init(instance, parameters): Promise { - this.#targetFrame = parameters['Target frame']; - } - - public async destroy(): Promise { - // nothing to destroy - } - - public async run( - instance: Job | Task, - { collection: { shapes }, frameData: { number } }, - ): Promise { - if (number === this.#targetFrame) { - return { collection: { shapes } }; - } - - const frameNumbers = instance instanceof Job ? await instance.frames.frameNumbers() : range(0, instance.size); - const propagatedShapes = propagateShapes(shapes, number, this.#targetFrame, frameNumbers); - return { collection: { shapes: [...shapes, ...propagatedShapes] } }; - } - - public get name(): string { - return 'Propagate shapes'; - } - - public get parameters(): ActionParameters | null { - return { - 'Target frame': { - type: ActionParameterType.NUMBER, - values: ({ instance }) => { - if (instance instanceof Job) { - return [instance.startFrame, instance.stopFrame, 1].map((val) => val.toString()); - } - return [0, instance.size - 1, 1].map((val) => val.toString()); - }, - defaultValue: ({ instance }) => { - if (instance instanceof Job) { - return instance.stopFrame.toString(); - } - return (instance.size - 1).toString(); - }, - }, - }; - } - - public get frameSelection(): FrameSelectionType { - return FrameSelectionType.CURRENT_FRAME; - } -} - -const registeredActions: BaseSingleFrameAction[] = []; - -export async function listActions(): Promise { - return [...registeredActions]; -} - -export async function registerAction(action: BaseSingleFrameAction): Promise { - if (!(action instanceof BaseSingleFrameAction)) { - throw new ArgumentError('Provided action is not instance of BaseSingleFrameAction'); - } - - const { name } = action; - if (registeredActions.map((_action) => _action.name).includes(name)) { - throw new ArgumentError(`Action name must be unique. Name "${name}" is already exists`); - } - - registeredActions.push(action); -} - -registerAction(new RemoveFilteredShapes()); -registerAction(new PropagateShapes()); - -async function runSingleFrameChain( - instance: Job | Task, - actionsChain: BaseSingleFrameAction[], - actionParameters: Record[], - frameFrom: number, - frameTo: number, - filters: string[], - onProgress: (message: string, progress: number) => void, - cancelled: () => boolean, -): Promise { - type IDsToHandle = { shapes: number[] }; - const event = await instance.logger.log(EventScope.annotationsAction, { - from: frameFrom, - to: frameTo, - chain: actionsChain.map((action) => action.name).join(' => '), - }, true); - - // if called too fast, it will freeze UI, so, add throttling here - const wrappedOnProgress = throttle(onProgress, 100, { leading: true, trailing: true }); - const showMessageWithPause = async (message: string, progress: number, duration: number): Promise => { - // wrapper that gives a chance to abort action - wrappedOnProgress(message, progress); - await new Promise((resolve) => setTimeout(resolve, duration)); - }; - - try { - await showMessageWithPause('Actions initialization', 0, 500); - if (cancelled()) { - return; - } - - await Promise.all(actionsChain.map((action, idx) => { - const declaredParameters = action.parameters; - if (!declaredParameters) { - return action.init(instance, {}); - } - - const setupValues = actionParameters[idx]; - const parameters = Object.entries(declaredParameters).reduce((acc, [name, { type, defaultValue }]) => { - if (type === ActionParameterType.NUMBER) { - acc[name] = +(Object.hasOwn(setupValues, name) ? setupValues[name] : defaultValue); - } else { - acc[name] = (Object.hasOwn(setupValues, name) ? setupValues[name] : defaultValue); - } - return acc; - }, {} as Record); - - return action.init(instance, parameters); - })); - - const exportedCollection = getCollection(instance).export(); - const handledCollection: SingleFrameActionInput['collection'] = { shapes: [] }; - const modifiedCollectionIDs: IDsToHandle = { shapes: [] }; - - // Iterate over frames - const totalFrames = frameTo - frameFrom + 1; - for (let frame = frameFrom; frame <= frameTo; frame++) { - const frameData = await Object.getPrototypeOf(instance).frames - .get.implementation.call(instance, frame); - - // Ignore deleted frames - if (!frameData.deleted) { - // Get annotations according to filter - const states: ObjectState[] = await getAnnotations(instance, frame, false, filters); - const frameCollectionIDs = states.reduce((acc, val) => { - if (val.objectType === ObjectType.SHAPE) { - acc.shapes.push(val.clientID as number); - } - return acc; - }, { shapes: [] }); - - // Pick frame collection according to filtered IDs - let frameCollection = { - shapes: exportedCollection.shapes.filter((shape) => frameCollectionIDs - .shapes.includes(shape.clientID as number)), - }; - - // Iterate over actions on each not deleted frame - for await (const action of actionsChain) { - ({ collection: frameCollection } = await action.run(instance, { - collection: frameCollection, - frameData: { - width: frameData.width, - height: frameData.height, - number: frameData.number, - }, - })); - } - - const progress = Math.ceil(+(((frame - frameFrom) / totalFrames) * 100)); - wrappedOnProgress('Actions are running', progress); - if (cancelled()) { - return; - } - - handledCollection.shapes.push(...frameCollection.shapes.map((shape) => omit(shape, 'id'))); - modifiedCollectionIDs.shapes.push(...frameCollectionIDs.shapes); - } - } - - await showMessageWithPause('Commiting handled objects', 100, 1500); - if (cancelled()) { - return; - } - - exportedCollection.shapes.forEach((shape) => { - if (Number.isInteger(shape.clientID) && !modifiedCollectionIDs.shapes.includes(shape.clientID as number)) { - handledCollection.shapes.push(shape); - } - }); - - await instance.annotations.clear(); - await instance.actions.clear(); - await instance.annotations.import({ - ...handledCollection, - tracks: exportedCollection.tracks, - tags: exportedCollection.tags, - }); - - event.close(); - } finally { - wrappedOnProgress('Finalizing', 100); - await Promise.all(actionsChain.map((action) => action.destroy())); - } -} - -export async function runActions( - instance: Job | Task, - actionsChain: BaseSingleFrameAction[], - actionParameters: Record[], - frameFrom: number, - frameTo: number, - filters: string[], - onProgress: (message: string, progress: number) => void, - cancelled: () => boolean, -): Promise { - // there will be another function for MultiFrameChains (actions handling tracks) - return runSingleFrameChain( - instance, - actionsChain, - actionParameters, - frameFrom, - frameTo, - filters, - onProgress, - cancelled, - ); -} diff --git a/cvat-core/src/annotations-actions/annotations-actions.ts b/cvat-core/src/annotations-actions/annotations-actions.ts new file mode 100644 index 000000000000..172b8cd88e3d --- /dev/null +++ b/cvat-core/src/annotations-actions/annotations-actions.ts @@ -0,0 +1,113 @@ +// Copyright (C) 2024 CVAT.ai Corporation +// +// SPDX-License-Identifier: MIT + +import ObjectState from '../object-state'; +import { ArgumentError } from '../exceptions'; +import { Job, Task } from '../session'; +import { BaseAction } from './base-action'; +import { + BaseShapesAction, run as runShapesAction, call as callShapesAction, +} from './base-shapes-action'; +import { + BaseCollectionAction, run as runCollectionAction, call as callCollectionAction, +} from './base-collection-action'; + +import { RemoveFilteredShapes } from './remove-filtered-shapes'; +import { PropagateShapes } from './propagate-shapes'; + +const registeredActions: BaseAction[] = []; + +export async function listActions(): Promise { + return [...registeredActions]; +} + +export async function registerAction(action: BaseAction): Promise { + if (!(action instanceof BaseAction)) { + throw new ArgumentError('Provided action must inherit one of base classes'); + } + + const { name } = action; + if (registeredActions.map((_action) => _action.name).includes(name)) { + throw new ArgumentError(`Action name must be unique. Name "${name}" is already exists`); + } + + registeredActions.push(action); +} + +registerAction(new RemoveFilteredShapes()); +registerAction(new PropagateShapes()); + +export async function runAction( + instance: Job | Task, + action: BaseAction, + actionParameters: Record, + frameFrom: number, + frameTo: number, + filters: object[], + onProgress: (message: string, progress: number) => void, + cancelled: () => boolean, +): Promise { + if (action instanceof BaseShapesAction) { + return runShapesAction( + instance, + action, + actionParameters, + frameFrom, + frameTo, + filters, + onProgress, + cancelled, + ); + } + + if (action instanceof BaseCollectionAction) { + return runCollectionAction( + instance, + action, + actionParameters, + frameFrom, + filters, + onProgress, + cancelled, + ); + } + + return Promise.resolve(); +} + +export async function callAction( + instance: Job | Task, + action: BaseAction, + actionParameters: Record, + frame: number, + states: ObjectState[], + onProgress: (message: string, progress: number) => void, + cancelled: () => boolean, +): Promise { + if (action instanceof BaseShapesAction) { + return callShapesAction( + instance, + action, + actionParameters, + frame, + states, + onProgress, + cancelled, + ); + } + + if (action instanceof BaseCollectionAction) { + return callCollectionAction( + instance, + action, + actionParameters, + frame, + states, + onProgress, + cancelled, + ); + } + + return Promise.resolve(); +} diff --git a/cvat-core/src/annotations-actions/base-action.ts b/cvat-core/src/annotations-actions/base-action.ts new file mode 100644 index 000000000000..3246261d2c9a --- /dev/null +++ b/cvat-core/src/annotations-actions/base-action.ts @@ -0,0 +1,60 @@ +// Copyright (C) 2024 CVAT.ai Corporation +// +// SPDX-License-Identifier: MIT + +import { SerializedCollection } from 'server-response-types'; +import ObjectState from '../object-state'; +import { Job, Task } from '../session'; + +export enum ActionParameterType { + SELECT = 'select', + NUMBER = 'number', +} + +// For SELECT values should be a list of possible options +// For NUMBER values should be a list with [min, max, step], +// or a callback ({ instance }: { instance: Job | Task }) => [min, max, step] +export type ActionParameters = Record string[]); + defaultValue: string | (({ instance }: { instance: Job | Task }) => string); +}>; + +export abstract class BaseAction { + public abstract init(sessionInstance: Job | Task, parameters: Record): Promise; + public abstract destroy(): Promise; + public abstract run(input: unknown): Promise; + public abstract applyFilter(input: unknown): unknown; + public abstract isApplicableForObject(objectState: ObjectState): boolean; + + public abstract get name(): string; + public abstract get parameters(): ActionParameters | null; +} + +export function prepareActionParameters(declared: ActionParameters, defined: object): Record { + if (!declared) { + return {}; + } + + return Object.entries(declared).reduce((acc, [name, { type, defaultValue }]) => { + if (type === ActionParameterType.NUMBER) { + acc[name] = +(Object.hasOwn(defined, name) ? defined[name] : defaultValue); + } else { + acc[name] = (Object.hasOwn(defined, name) ? defined[name] : defaultValue); + } + return acc; + }, {} as Record); +} + +export function validateClientIDs(collection: Partial) { + [].concat( + collection.shapes ?? [], + collection.tracks ?? [], + collection.tags ?? [], + ).forEach((object) => { + // clientID is required to correct collection filtering and commiting in annotations actions logic + if (typeof object.clientID !== 'number') { + throw new Error('ClientID is undefined when running annotations action, but required'); + } + }); +} diff --git a/cvat-core/src/annotations-actions/base-collection-action.ts b/cvat-core/src/annotations-actions/base-collection-action.ts new file mode 100644 index 000000000000..c48135694566 --- /dev/null +++ b/cvat-core/src/annotations-actions/base-collection-action.ts @@ -0,0 +1,178 @@ +// Copyright (C) 2024 CVAT.ai Corporation +// +// SPDX-License-Identifier: MIT + +import { throttle } from 'lodash'; + +import ObjectState from '../object-state'; +import AnnotationsFilter from '../annotations-filter'; +import { Job, Task } from '../session'; +import { + SerializedCollection, SerializedShape, + SerializedTag, SerializedTrack, +} from '../server-response-types'; +import { EventScope, ObjectType } from '../enums'; +import { getCollection } from '../annotations'; +import { BaseAction, prepareActionParameters, validateClientIDs } from './base-action'; + +export interface CollectionActionInput { + onProgress(message: string, percent: number): void; + cancelled(): boolean; + collection: Pick; + frameData: { + width: number; + height: number; + number: number; + }; +} + +export interface CollectionActionOutput { + created: CollectionActionInput['collection']; + deleted: CollectionActionInput['collection']; +} + +export abstract class BaseCollectionAction extends BaseAction { + public abstract run(input: CollectionActionInput): Promise; + public abstract applyFilter( + input: Pick, + ): CollectionActionInput['collection']; +} + +export async function run( + instance: Job | Task, + action: BaseCollectionAction, + actionParameters: Record, + frame: number, + filters: object[], + onProgress: (message: string, progress: number) => void, + cancelled: () => boolean, +): Promise { + const event = await instance.logger.log(EventScope.annotationsAction, { + from: frame, + to: frame, + name: action.name, + }, true); + + const wrappedOnProgress = throttle(onProgress, 100, { leading: true, trailing: true }); + const showMessageWithPause = async (message: string, progress: number, duration: number): Promise => { + // wrapper that gives a chance to abort action + wrappedOnProgress(message, progress); + await new Promise((resolve) => setTimeout(resolve, duration)); + }; + + try { + await showMessageWithPause('Action initialization', 0, 500); + if (cancelled()) { + return; + } + + await action.init(instance, prepareActionParameters(action.parameters, actionParameters)); + + const frameData = await Object.getPrototypeOf(instance).frames + .get.implementation.call(instance, frame); + const exportedCollection = getCollection(instance).export(); + + // Apply action filter first + const filteredByAction = action.applyFilter({ collection: exportedCollection, frameData }); + validateClientIDs(filteredByAction); + + let mapID2Obj = [].concat(filteredByAction.shapes, filteredByAction.tags, filteredByAction.tracks) + .reduce((acc, object) => { + acc[object.clientID as number] = object; + return acc; + }, {}); + + // Then apply user filter + const annotationsFilter = new AnnotationsFilter(); + const filteredCollectionIDs = annotationsFilter + .filterSerializedCollection(filteredByAction, instance.labels, filters); + const filteredByUser = { + shapes: filteredCollectionIDs.shapes.map((clientID) => mapID2Obj[clientID]), + tags: filteredCollectionIDs.tags.map((clientID) => mapID2Obj[clientID]), + tracks: filteredCollectionIDs.tracks.map((clientID) => mapID2Obj[clientID]), + }; + mapID2Obj = [].concat(filteredByUser.shapes, filteredByUser.tags, filteredByUser.tracks) + .reduce((acc, object) => { + acc[object.clientID as number] = object; + return acc; + }, {}); + + const { created, deleted } = await action.run({ + collection: filteredByUser, + frameData: { + width: frameData.width, + height: frameData.height, + number: frameData.number, + }, + onProgress: wrappedOnProgress, + cancelled, + }); + + await instance.annotations.commit(created, deleted, frame); + event.close(); + } finally { + wrappedOnProgress('Finalizing', 100); + await action.destroy(); + } +} + +export async function call( + instance: Job | Task, + action: BaseCollectionAction, + actionParameters: Record, + frame: number, + states: ObjectState[], + onProgress: (message: string, progress: number) => void, + cancelled: () => boolean, +): Promise { + const event = await instance.logger.log(EventScope.annotationsAction, { + from: frame, + to: frame, + name: action.name, + }, true); + + const throttledOnProgress = throttle(onProgress, 100, { leading: true, trailing: true }); + try { + await action.init(instance, prepareActionParameters(action.parameters, actionParameters)); + const exportedStates = await Promise.all(states.map((state) => state.export())); + const exportedCollection = exportedStates.reduce((acc, value, idx) => { + if (states[idx].objectType === ObjectType.SHAPE) { + acc.shapes.push(value as SerializedShape); + } + + if (states[idx].objectType === ObjectType.TAG) { + acc.tags.push(value as SerializedTag); + } + + if (states[idx].objectType === ObjectType.TRACK) { + acc.tracks.push(value as SerializedTrack); + } + + return acc; + }, { shapes: [], tags: [], tracks: [] }); + + const frameData = await Object.getPrototypeOf(instance).frames.get.implementation.call(instance, frame); + const filteredByAction = action.applyFilter({ collection: exportedCollection, frameData }); + validateClientIDs(filteredByAction); + + const processedCollection = await action.run({ + onProgress: throttledOnProgress, + cancelled, + collection: filteredByAction, + frameData: { + width: frameData.width, + height: frameData.height, + number: frameData.number, + }, + }); + + await instance.annotations.commit( + processedCollection.created, + processedCollection.deleted, + frame, + ); + event.close(); + } finally { + await action.destroy(); + } +} diff --git a/cvat-core/src/annotations-actions/base-shapes-action.ts b/cvat-core/src/annotations-actions/base-shapes-action.ts new file mode 100644 index 000000000000..80d2b4fee78b --- /dev/null +++ b/cvat-core/src/annotations-actions/base-shapes-action.ts @@ -0,0 +1,196 @@ +// Copyright (C) 2024 CVAT.ai Corporation +// +// SPDX-License-Identifier: MIT + +import { throttle } from 'lodash'; + +import ObjectState from '../object-state'; +import AnnotationsFilter from '../annotations-filter'; +import { Job, Task } from '../session'; +import { SerializedCollection, SerializedShape } from '../server-response-types'; +import { EventScope, ObjectType } from '../enums'; +import { getCollection } from '../annotations'; +import { BaseAction, prepareActionParameters, validateClientIDs } from './base-action'; + +export interface ShapesActionInput { + onProgress(message: string, percent: number): void; + cancelled(): boolean; + collection: Pick; + frameData: { + width: number; + height: number; + number: number; + }; +} + +export interface ShapesActionOutput { + created: ShapesActionInput['collection']; + deleted: ShapesActionInput['collection']; +} + +export abstract class BaseShapesAction extends BaseAction { + public abstract run(input: ShapesActionInput): Promise; + public abstract applyFilter( + input: Pick + ): ShapesActionInput['collection']; +} + +export async function run( + instance: Job | Task, + action: BaseShapesAction, + actionParameters: Record, + frameFrom: number, + frameTo: number, + filters: object[], + onProgress: (message: string, progress: number) => void, + cancelled: () => boolean, +): Promise { + const event = await instance.logger.log(EventScope.annotationsAction, { + from: frameFrom, + to: frameTo, + name: action.name, + }, true); + + const throttledOnProgress = throttle(onProgress, 100, { leading: true, trailing: true }); + const showMessageWithPause = async (message: string, progress: number, duration: number): Promise => { + // wrapper that gives a chance to abort action + throttledOnProgress(message, progress); + await new Promise((resolve) => setTimeout(resolve, duration)); + }; + + try { + await showMessageWithPause('Actions initialization', 0, 500); + if (cancelled()) { + return; + } + + await action.init(instance, prepareActionParameters(action.parameters, actionParameters)); + + const exportedCollection = getCollection(instance).export(); + validateClientIDs(exportedCollection); + + const annotationsFilter = new AnnotationsFilter(); + const filteredShapeIDs = annotationsFilter.filterSerializedCollection({ + shapes: exportedCollection.shapes, + tags: [], + tracks: [], + }, instance.labels, filters).shapes; + + const filteredShapesByFrame = exportedCollection.shapes.reduce((acc, shape) => { + if (shape.frame >= frameFrom && shape.frame <= frameTo && filteredShapeIDs.includes(shape.clientID)) { + acc[shape.frame] = acc[shape.frame] ?? []; + acc[shape.frame].push(shape); + } + return acc; + }, {} as Record); + + const totalUpdates = { created: { shapes: [] }, deleted: { shapes: [] } }; + // Iterate over frames + const totalFrames = frameTo - frameFrom + 1; + for (let frame = frameFrom; frame <= frameTo; frame++) { + const frameData = await Object.getPrototypeOf(instance).frames + .get.implementation.call(instance, frame); + + // Ignore deleted frames + if (!frameData.deleted) { + const frameShapes = filteredShapesByFrame[frame] ?? []; + if (!frameShapes.length) { + continue; + } + + // finally apply the own filter of the action + const filteredByAction = action.applyFilter({ + collection: { + shapes: frameShapes, + }, + frameData, + }); + validateClientIDs(filteredByAction); + + const { created, deleted } = await action.run({ + onProgress: throttledOnProgress, + cancelled, + collection: { shapes: filteredByAction.shapes }, + frameData: { + width: frameData.width, + height: frameData.height, + number: frameData.number, + }, + }); + + Array.prototype.push.apply(totalUpdates.created.shapes, created.shapes); + Array.prototype.push.apply(totalUpdates.deleted.shapes, deleted.shapes); + + const progress = Math.ceil(+(((frame - frameFrom) / totalFrames) * 100)); + throttledOnProgress('Actions are running', progress); + if (cancelled()) { + return; + } + } + } + + await showMessageWithPause('Commiting handled objects', 100, 1500); + if (cancelled()) { + return; + } + + await instance.annotations.commit( + { shapes: totalUpdates.created.shapes, tags: [], tracks: [] }, + { shapes: totalUpdates.deleted.shapes, tags: [], tracks: [] }, + frameFrom, + ); + + event.close(); + } finally { + throttledOnProgress('Finalizing', 100); + await action.destroy(); + } +} + +export async function call( + instance: Job | Task, + action: BaseShapesAction, + actionParameters: Record, + frame: number, + states: ObjectState[], + onProgress: (message: string, progress: number) => void, + cancelled: () => boolean, +): Promise { + const event = await instance.logger.log(EventScope.annotationsAction, { + from: frame, + to: frame, + name: action.name, + }, true); + + const throttledOnProgress = throttle(onProgress, 100, { leading: true, trailing: true }); + try { + await action.init(instance, prepareActionParameters(action.parameters, actionParameters)); + + const exported = await Promise.all(states.filter((state) => state.objectType === ObjectType.SHAPE) + .map((state) => state.export())) as SerializedShape[]; + const frameData = await Object.getPrototypeOf(instance).frames.get.implementation.call(instance, frame); + const filteredByAction = action.applyFilter({ collection: { shapes: exported }, frameData }); + validateClientIDs(filteredByAction); + + const processedCollection = await action.run({ + onProgress: throttledOnProgress, + cancelled, + collection: { shapes: filteredByAction.shapes }, + frameData: { + width: frameData.width, + height: frameData.height, + number: frameData.number, + }, + }); + + await instance.annotations.commit( + { shapes: processedCollection.created.shapes, tags: [], tracks: [] }, + { shapes: processedCollection.deleted.shapes, tags: [], tracks: [] }, + frame, + ); + + event.close(); + } finally { + await action.destroy(); + } +} diff --git a/cvat-core/src/annotations-actions/propagate-shapes.ts b/cvat-core/src/annotations-actions/propagate-shapes.ts new file mode 100644 index 000000000000..ee68b9600f4f --- /dev/null +++ b/cvat-core/src/annotations-actions/propagate-shapes.ts @@ -0,0 +1,85 @@ +// Copyright (C) 2024 CVAT.ai Corporation +// +// SPDX-License-Identifier: MIT + +import { range } from 'lodash'; + +import ObjectState from '../object-state'; +import { Job, Task } from '../session'; +import { SerializedShape } from '../server-response-types'; +import { propagateShapes } from '../object-utils'; +import { ObjectType } from '../enums'; + +import { ActionParameterType, ActionParameters } from './base-action'; +import { BaseCollectionAction, CollectionActionInput, CollectionActionOutput } from './base-collection-action'; + +export class PropagateShapes extends BaseCollectionAction { + #instance: Task | Job; + #targetFrame: number; + + public async init(instance: Job | Task, parameters): Promise { + this.#instance = instance; + this.#targetFrame = parameters['Target frame']; + } + + public async destroy(): Promise { + // nothing to destroy + } + + public async run(input: CollectionActionInput): Promise { + const { collection, frameData: { number } } = input; + if (number === this.#targetFrame) { + return { + created: { shapes: [], tags: [], tracks: [] }, + deleted: { shapes: [], tags: [], tracks: [] }, + }; + } + + const frameNumbers = this.#instance instanceof Job ? + await this.#instance.frames.frameNumbers() : range(0, this.#instance.size); + const propagatedShapes = propagateShapes( + collection.shapes, number, this.#targetFrame, frameNumbers, + ); + + return { + created: { shapes: propagatedShapes, tags: [], tracks: [] }, + deleted: { shapes: [], tags: [], tracks: [] }, + }; + } + + public applyFilter(input: CollectionActionInput): CollectionActionInput['collection'] { + return { + shapes: input.collection.shapes.filter((shape) => shape.frame === input.frameData.number), + tags: [], + tracks: [], + }; + } + + public isApplicableForObject(objectState: ObjectState): boolean { + return objectState.objectType === ObjectType.SHAPE; + } + + public get name(): string { + return 'Propagate shapes'; + } + + public get parameters(): ActionParameters | null { + return { + 'Target frame': { + type: ActionParameterType.NUMBER, + values: ({ instance }) => { + if (instance instanceof Job) { + return [instance.startFrame, instance.stopFrame, 1].map((val) => val.toString()); + } + return [0, instance.size - 1, 1].map((val) => val.toString()); + }, + defaultValue: ({ instance }) => { + if (instance instanceof Job) { + return instance.stopFrame.toString(); + } + return (instance.size - 1).toString(); + }, + }, + }; + } +} diff --git a/cvat-core/src/annotations-actions/remove-filtered-shapes.ts b/cvat-core/src/annotations-actions/remove-filtered-shapes.ts new file mode 100644 index 000000000000..ab2a30964fad --- /dev/null +++ b/cvat-core/src/annotations-actions/remove-filtered-shapes.ts @@ -0,0 +1,41 @@ +// Copyright (C) 2024 CVAT.ai Corporation +// +// SPDX-License-Identifier: MIT + +import { BaseShapesAction, ShapesActionInput, ShapesActionOutput } from './base-shapes-action'; +import { ActionParameters } from './base-action'; + +export class RemoveFilteredShapes extends BaseShapesAction { + public async init(): Promise { + // nothing to init + } + + public async destroy(): Promise { + // nothing to destroy + } + + public async run(input: ShapesActionInput): Promise { + return { + created: { shapes: [] }, + deleted: input.collection, + }; + } + + public applyFilter(input: ShapesActionInput): ShapesActionInput['collection'] { + const { collection } = input; + return collection; + } + + public isApplicableForObject(): boolean { + // remove action does not make sense when running on one object + return false; + } + + public get name(): string { + return 'Remove filtered shapes'; + } + + public get parameters(): ActionParameters | null { + return null; + } +} diff --git a/cvat-core/src/annotations-collection.ts b/cvat-core/src/annotations-collection.ts index 14879e86bcd8..25496dfe69a7 100644 --- a/cvat-core/src/annotations-collection.ts +++ b/cvat-core/src/annotations-collection.ts @@ -157,9 +157,68 @@ export default class Collection { return result; } - public export(): Omit { + public commit( + appended: Omit, + removed: Omit, + frame: number, + ): { tags: Tag[]; shapes: Shape[]; tracks: Track[]; } { + const isCollectionConsistent = [].concat(removed.shapes, removed.tags, removed.tracks) + .every((object) => typeof object.clientID === 'number' && + Object.prototype.hasOwnProperty.call(this.objects, object.clientID)); + + if (!isCollectionConsistent) { + throw new ArgumentError('Objects required to be deleted were not found in the collection'); + } + + const removedCollection: (Shape | Tag | Track)[] = [].concat(removed.shapes, removed.tags, removed.tracks) + .map((object) => this.objects[object.clientID as number]); + + const imported = this.import(appended); + const appendedCollection = ([] as (Shape | Tag | Track)[]) + .concat(imported.shapes, imported.tags, imported.tracks); + if (!(appendedCollection.length > 0 || removedCollection.length > 0)) { + // nothing to commit + return; + } + + let prevRemoved = []; + removedCollection.forEach((collectionObject) => { + prevRemoved.push(collectionObject.removed); + collectionObject.removed = true; + }); + + this.history.do( + HistoryActions.COMMIT_ANNOTATIONS, + () => { + removedCollection.forEach((collectionObject, idx) => { + collectionObject.removed = prevRemoved[idx]; + }); + prevRemoved = []; + appendedCollection.forEach((collectionObject) => { + collectionObject.removed = true; + }); + }, + () => { + removedCollection.forEach((collectionObject) => { + prevRemoved.push(collectionObject.removed); + collectionObject.removed = true; + }); + appendedCollection.forEach((collectionObject) => { + collectionObject.removed = false; + }); + }, + [].concat( + removedCollection.map((object) => object.clientID), + appendedCollection.map((object) => object.clientID), + ), + frame, + ); + } + + public export(): Pick { const data = { - tracks: this.tracks.filter((track) => !track.removed).map((track) => track.toJSON() as SerializedTrack), + tracks: this.tracks.filter((track) => !track.removed) + .map((track) => track.toJSON() as SerializedTrack), shapes: Object.values(this.shapes) .reduce((accumulator, frameShapes) => { accumulator.push(...frameShapes); @@ -201,7 +260,7 @@ export default class Collection { } const objectStates = []; - const filtered = this.annotationsFilter.filter(visible, filters); + const filtered = this.annotationsFilter.filterSerializedObjectStates(visible, filters); visible.forEach((stateData) => { if (!filters.length || filtered.includes(stateData.clientID)) { @@ -1338,7 +1397,7 @@ export default class Collection { statesData.push(...tracks.map((track) => track.get(frame)).filter((state) => !state.outside)); // Filtering - const filtered = this.annotationsFilter.filter(statesData, annotationsFilters); + const filtered = this.annotationsFilter.filterSerializedObjectStates(statesData, annotationsFilters); if (filtered.length) { return frame; } diff --git a/cvat-core/src/annotations-filter.ts b/cvat-core/src/annotations-filter.ts index 58c9e82a63e5..fa7b8e739f5a 100644 --- a/cvat-core/src/annotations-filter.ts +++ b/cvat-core/src/annotations-filter.ts @@ -6,15 +6,74 @@ import jsonLogic from 'json-logic-js'; import { SerializedData } from './object-state'; import { AttributeType, ObjectType, ShapeType } from './enums'; +import { SerializedCollection } from './server-response-types'; +import { Attribute, Label } from './labels'; function adjustName(name): string { return name.replace(/\./g, '\u2219'); } +function getDimensions(points: number[], shapeType: ShapeType): { + width: number | null; + height: number | null; +} { + let [width, height]: (number | null)[] = [null, null]; + if (shapeType === ShapeType.MASK) { + const [xtl, ytl, xbr, ybr] = points.slice(-4); + [width, height] = [xbr - xtl + 1, ybr - ytl + 1]; + } else if (shapeType === ShapeType.ELLIPSE) { + const [cx, cy, rightX, topY] = points; + width = Math.abs(rightX - cx) * 2; + height = Math.abs(cy - topY) * 2; + } else { + let xtl = Number.MAX_SAFE_INTEGER; + let xbr = Number.MIN_SAFE_INTEGER; + let ytl = Number.MAX_SAFE_INTEGER; + let ybr = Number.MIN_SAFE_INTEGER; + + points.forEach((coord, idx) => { + if (idx % 2) { + // y + ytl = Math.min(ytl, coord); + ybr = Math.max(ybr, coord); + } else { + // x + xtl = Math.min(xtl, coord); + xbr = Math.max(xbr, coord); + } + }); + [width, height] = [xbr - xtl, ybr - ytl]; + } + + return { + width, + height, + }; +} + +function convertAttribute(id: number, value: string, attributesSpec: Record): [ + string, + number | boolean | string, +] { + const spec = attributesSpec[id]; + const name = adjustName(spec.name); + if (spec.inputType === AttributeType.NUMBER) { + return [name, +value]; + } + + if (spec.inputType === AttributeType.CHECKBOX) { + return [name, value === 'true']; + } + + return [name, value]; +} + +type ConvertedAttributes = Record; + interface ConvertedObjectData { width: number | null; height: number | null; - attr: Record>; + attr: Record; label: string; serverID: number; objectID: number; @@ -24,7 +83,7 @@ interface ConvertedObjectData { } export default class AnnotationsFilter { - _convertObjects(statesData: SerializedData[]): ConvertedObjectData[] { + private _convertSerializedObjectStates(statesData: SerializedData[]): ConvertedObjectData[] { const objects = statesData.map((state) => { const labelAttributes = state.label.attributes.reduce((acc, attr) => { acc[attr.id] = attr; @@ -33,50 +92,26 @@ export default class AnnotationsFilter { let [width, height]: (number | null)[] = [null, null]; if (state.objectType !== ObjectType.TAG) { - if (state.shapeType === ShapeType.MASK) { - const [xtl, ytl, xbr, ybr] = state.points.slice(-4); - [width, height] = [xbr - xtl + 1, ybr - ytl + 1]; - } else { - let xtl = Number.MAX_SAFE_INTEGER; - let xbr = Number.MIN_SAFE_INTEGER; - let ytl = Number.MAX_SAFE_INTEGER; - let ybr = Number.MIN_SAFE_INTEGER; - - const points = state.points || state.elements.reduce((acc, val) => { - acc.push(val.points); - return acc; - }, []).flat(); - points.forEach((coord, idx) => { - if (idx % 2) { - // y - ytl = Math.min(ytl, coord); - ybr = Math.max(ybr, coord); - } else { - // x - xtl = Math.min(xtl, coord); - xbr = Math.max(xbr, coord); - } - }); - [width, height] = [xbr - xtl, ybr - ytl]; - } + const points = state.shapeType === ShapeType.SKELETON ? state.elements.reduce((acc, val) => { + acc.push(val.points); + return acc; + }, []).flat() : state.points; + + ({ width, height } = getDimensions(points, state.shapeType as ShapeType)); } - const attributes = Object.keys(state.attributes).reduce>((acc, key) => { - const attr = labelAttributes[key]; - let value = state.attributes[key]; - if (attr.inputType === AttributeType.NUMBER) { - value = +value; - } else if (attr.inputType === AttributeType.CHECKBOX) { - value = value === 'true'; - } - acc[adjustName(attr.name)] = value; + const attributes = Object.keys(state.attributes).reduce((acc, key) => { + const [name, value] = convertAttribute(+key, state.attributes[key], labelAttributes); + acc[name] = value; return acc; - }, {}); + }, {} as Record); return { width, height, - attr: Object.fromEntries([[adjustName(state.label.name), attributes]]), + attr: { + [adjustName(state.label.name)]: attributes, + }, label: state.label.name, serverID: state.serverID, objectID: state.clientID, @@ -89,11 +124,119 @@ export default class AnnotationsFilter { return objects; } - filter(statesData: SerializedData[], filters: object[]): number[] { - if (!filters.length) return statesData.map((stateData): number => stateData.clientID); - const converted = this._convertObjects(statesData); + private _convertSerializedCollection( + collection: Omit, + labelsSpec: Label[], + ): { shapes: ConvertedObjectData[]; tags: ConvertedObjectData[]; tracks: ConvertedObjectData[]; } { + const labelByID = labelsSpec.reduce>((acc, label) => ({ + [label.id]: label, + ...acc, + }), {}); + + const attributeById = labelsSpec.map((label) => label.attributes).flat().reduce((acc, attribute) => ({ + ...acc, + [attribute.id]: attribute, + }), {} as Record); + + const convertAttributes = ( + attributes: SerializedCollection['shapes'][0]['attributes'], + ): ConvertedAttributes => attributes.reduce((acc, { spec_id, value }) => { + const [name, adjustedValue] = convertAttribute(spec_id, value, attributeById); + acc[name] = adjustedValue; + return acc; + }, {} as Record); + + return { + shapes: collection.shapes.map((shape) => { + const label = labelByID[shape.label_id]; + const points = shape.type === ShapeType.SKELETON ? + shape.elements.map((el) => el.points).flat() : shape.points; + let [width, height]: (number | null)[] = [null, null]; + ({ width, height } = getDimensions(points, shape.type)); + + return { + width, + height, + attr: { + [adjustName(label.name)]: convertAttributes(shape.attributes), + }, + label: label.name, + serverID: shape.id ?? null, + type: ObjectType.SHAPE, + shape: shape.type, + occluded: shape.occluded, + objectID: shape.clientID ?? null, + }; + }), + tags: collection.tags.map((tag) => { + const label = labelByID[tag.label_id]; + + return { + width: null, + height: null, + attr: { + [adjustName(label.name)]: convertAttributes(tag.attributes), + }, + label: labelByID[tag.label_id]?.name ?? null, + serverID: tag.id ?? null, + type: ObjectType.SHAPE, + shape: null, + occluded: false, + objectID: tag.clientID ?? null, + }; + }), + tracks: collection.tracks.map((track) => { + const label = labelByID[track.label_id]; + + return { + width: null, + height: null, + attr: { + [adjustName(label.name)]: convertAttributes(track.attributes), + }, + label: labelByID[track.label_id]?.name ?? null, + serverID: track.id, + type: ObjectType.TRACK, + shape: track.shapes[0]?.type ?? null, + occluded: null, + objectID: track.clientID ?? null, + }; + }), + }; + } + + public filterSerializedObjectStates(statesData: SerializedData[], filters: object[]): number[] { + if (!filters.length) { + return statesData.map((stateData): number => stateData.clientID); + } + + const converted = this._convertSerializedObjectStates(statesData); return converted .map((state) => state.objectID) .filter((_, index) => jsonLogic.apply(filters[0], converted[index])); } + + public filterSerializedCollection( + collection: Omit, + labelsSpec: Label[], + filters: object[], + ): { shapes: number[]; tags: number[]; tracks: number[]; } { + if (!filters.length) { + return { + shapes: collection.shapes.map((shape) => shape.clientID), + tags: collection.tags.map((tag) => tag.clientID), + tracks: collection.tracks.map((track) => track.clientID), + }; + } + + const converted = this._convertSerializedCollection(collection, labelsSpec); + return { + shapes: converted.shapes.map((shape) => shape.objectID) + .filter((_, index) => jsonLogic.apply(filters[0], converted.shapes[index])), + tags: converted.tags.map((shape) => shape.objectID) + .filter((_, index) => jsonLogic.apply(filters[0], converted.tags[index])), + tracks: converted.tracks.map((shape) => shape.objectID) + .filter((_, index) => jsonLogic.apply(filters[0], converted.tracks[index])), + }; + } } diff --git a/cvat-core/src/annotations-history.ts b/cvat-core/src/annotations-history.ts index 748d55bcf93d..2e59db96ea1f 100644 --- a/cvat-core/src/annotations-history.ts +++ b/cvat-core/src/annotations-history.ts @@ -5,7 +5,7 @@ import { HistoryActions } from './enums'; -const MAX_HISTORY_LENGTH = 128; +const MAX_HISTORY_LENGTH = 32; interface ActionItem { action: HistoryActions; diff --git a/cvat-core/src/annotations-objects.ts b/cvat-core/src/annotations-objects.ts index defcf7dbbada..ab7e32de9784 100644 --- a/cvat-core/src/annotations-objects.ts +++ b/cvat-core/src/annotations-objects.ts @@ -150,17 +150,12 @@ class Annotation { injection.groups.max = Math.max(injection.groups.max, this.group); } - protected withContext(frame: number): { - __internal: { - save: (data: ObjectState) => ObjectState; - delete: Annotation['delete']; - }; + // eslint-disable-next-line @typescript-eslint/no-unused-vars + protected withContext(_: number): { + delete: Annotation['delete']; } { return { - __internal: { - save: (this as any).save.bind(this, frame), - delete: this.delete.bind(this), - }, + delete: this.delete.bind(this), }; } @@ -530,6 +525,17 @@ export class Shape extends Drawn { this.zOrder = data.z_order; } + protected withContext(frame: number): ReturnType & { + save: (data: ObjectState) => ObjectState; + export: () => SerializedShape; + } { + return { + ...super.withContext(frame), + save: this.save.bind(this, frame), + export: this.toJSON.bind(this) as () => SerializedShape, + }; + } + // Method is used to export data to the server public toJSON(): SerializedShape | SerializedShape['elements'][0] { const result: SerializedShape = { @@ -592,7 +598,7 @@ export class Shape extends Drawn { pinned: this.pinned, frame, source: this.source, - ...this.withContext(frame), + __internal: this.withContext(frame), }; if (typeof this.outside !== 'undefined') { @@ -838,6 +844,17 @@ export class Track extends Drawn { }, {}); } + protected withContext(frame: number): ReturnType & { + save: (data: ObjectState) => ObjectState; + export: () => SerializedTrack; + } { + return { + ...super.withContext(frame), + save: this.save.bind(this, frame), + export: this.toJSON.bind(this) as () => SerializedTrack, + }; + } + // Method is used to export data to the server public toJSON(): SerializedTrack | SerializedTrack['elements'][0] { const labelAttributes = attrsAsAnObject(this.label.attributes); @@ -931,7 +948,7 @@ export class Track extends Drawn { }, frame, source: this.source, - ...this.withContext(frame), + __internal: this.withContext(frame), }; } @@ -1405,6 +1422,17 @@ export class Track extends Drawn { } export class Tag extends Annotation { + protected withContext(frame: number): ReturnType & { + save: (data: ObjectState) => ObjectState; + export: () => SerializedTag; + } { + return { + ...super.withContext(frame), + save: this.save.bind(this, frame), + export: this.toJSON.bind(this) as () => SerializedTag, + }; + } + // Method is used to export data to the server public toJSON(): SerializedTag { const result: SerializedTag = { @@ -1451,7 +1479,7 @@ export class Tag extends Annotation { updated: this.updated, frame, source: this.source, - ...this.withContext(frame), + __internal: this.withContext(frame), }; } @@ -2022,7 +2050,7 @@ export class SkeletonShape extends Shape { hidden: elements.every((el) => el.hidden), frame, source: this.source, - ...this.withContext(frame), + __internal: this.withContext(frame), }; } @@ -3064,7 +3092,7 @@ export class SkeletonTrack extends Track { occluded: elements.every((el) => el.occluded), lock: elements.every((el) => el.lock), hidden: elements.every((el) => el.hidden), - ...this.withContext(frame), + __internal: this.withContext(frame), }; } diff --git a/cvat-core/src/api-implementation.ts b/cvat-core/src/api-implementation.ts index 0e9f400ad499..c9e53a2e1e0d 100644 --- a/cvat-core/src/api-implementation.ts +++ b/cvat-core/src/api-implementation.ts @@ -39,7 +39,9 @@ import QualityConflict, { ConflictSeverity } from './quality-conflict'; import QualitySettings from './quality-settings'; import { getFramesMeta } from './frames'; import AnalyticsReport from './analytics-report'; -import { listActions, registerAction, runActions } from './annotations-actions'; +import { + callAction, listActions, registerAction, runAction, +} from './annotations-actions/annotations-actions'; import { convertDescriptions, getServerAPISchema } from './server-schema'; import { JobType } from './enums'; import { PaginatedResource } from './core-types'; @@ -54,7 +56,8 @@ export default function implementAPI(cvat: CVATCore): CVATCore { implementationMixin(cvat.plugins.register, PluginRegistry.register.bind(cvat)); implementationMixin(cvat.actions.list, listActions); implementationMixin(cvat.actions.register, registerAction); - implementationMixin(cvat.actions.run, runActions); + implementationMixin(cvat.actions.run, runAction); + implementationMixin(cvat.actions.call, callAction); implementationMixin(cvat.lambda.list, lambdaManager.list.bind(lambdaManager)); implementationMixin(cvat.lambda.run, lambdaManager.run.bind(lambdaManager)); diff --git a/cvat-core/src/api.ts b/cvat-core/src/api.ts index ca33f431c43e..f4eb5d8b23fd 100644 --- a/cvat-core/src/api.ts +++ b/cvat-core/src/api.ts @@ -21,7 +21,9 @@ import CloudStorage from './cloud-storage'; import Organization from './organization'; import Webhook from './webhook'; import AnnotationGuide from './guide'; -import BaseSingleFrameAction from './annotations-actions'; +import { BaseAction } from './annotations-actions/base-action'; +import { BaseCollectionAction } from './annotations-actions/base-collection-action'; +import { BaseShapesAction } from './annotations-actions/base-shapes-action'; import QualityReport from './quality-report'; import QualityConflict from './quality-conflict'; import QualitySettings from './quality-settings'; @@ -191,14 +193,14 @@ function build(): CVATCore { const result = await PluginRegistry.apiWrapper(cvat.actions.list); return result; }, - async register(action: BaseSingleFrameAction) { + async register(action: BaseAction) { const result = await PluginRegistry.apiWrapper(cvat.actions.register, action); return result; }, async run( instance: Job | Task, - actionsChain: BaseSingleFrameAction[], - actionsParameters: Record[], + actions: BaseAction, + actionsParameters: Record, frameFrom: number, frameTo: number, filters: string[], @@ -211,7 +213,7 @@ function build(): CVATCore { const result = await PluginRegistry.apiWrapper( cvat.actions.run, instance, - actionsChain, + actions, actionsParameters, frameFrom, frameTo, @@ -221,6 +223,30 @@ function build(): CVATCore { ); return result; }, + async call( + instance: Job | Task, + actions: BaseAction, + actionsParameters: Record, + frame: number, + states: ObjectState[], + onProgress: ( + message: string, + progress: number, + ) => void, + cancelled: () => boolean, + ) { + const result = await PluginRegistry.apiWrapper( + cvat.actions.call, + instance, + actions, + actionsParameters, + frame, + states, + onProgress, + cancelled, + ); + return result; + }, }, lambda: { async list() { @@ -420,7 +446,8 @@ function build(): CVATCore { Organization, Webhook, AnnotationGuide, - BaseSingleFrameAction, + BaseShapesAction, + BaseCollectionAction, QualitySettings, AnalyticsReport, QualityConflict, diff --git a/cvat-core/src/enums.ts b/cvat-core/src/enums.ts index 1b291662d213..25fdf815fa20 100644 --- a/cvat-core/src/enums.ts +++ b/cvat-core/src/enums.ts @@ -148,6 +148,7 @@ export enum HistoryActions { REMOVED_OBJECT = 'Removed object', REMOVED_FRAME = 'Removed frame', RESTORED_FRAME = 'Restored frame', + COMMIT_ANNOTATIONS = 'Commit annotations', } export enum ModelKind { diff --git a/cvat-core/src/index.ts b/cvat-core/src/index.ts index 8a4c9e8bfb53..79ce8b305a9f 100644 --- a/cvat-core/src/index.ts +++ b/cvat-core/src/index.ts @@ -34,7 +34,14 @@ import AnalyticsReport from './analytics-report'; import AnnotationGuide from './guide'; import { JobValidationLayout, TaskValidationLayout } from './validation-layout'; import { Request } from './request'; -import BaseSingleFrameAction, { listActions, registerAction, runActions } from './annotations-actions'; +import { + runAction, + callAction, + listActions, + registerAction, +} from './annotations-actions/annotations-actions'; +import { BaseCollectionAction } from './annotations-actions/base-collection-action'; +import { BaseShapesAction } from './annotations-actions/base-shapes-action'; import { ArgumentError, DataError, Exception, ScriptingError, ServerError, } from './exceptions'; @@ -165,7 +172,8 @@ export default interface CVATCore { actions: { list: typeof listActions; register: typeof registerAction; - run: typeof runActions; + run: typeof runAction; + call: typeof callAction; }; logger: typeof logger; config: { @@ -209,7 +217,8 @@ export default interface CVATCore { Organization: typeof Organization; Webhook: typeof Webhook; AnnotationGuide: typeof AnnotationGuide; - BaseSingleFrameAction: typeof BaseSingleFrameAction; + BaseShapesAction: typeof BaseShapesAction; + BaseCollectionAction: typeof BaseCollectionAction; QualityReport: typeof QualityReport; QualityConflict: typeof QualityConflict; QualitySettings: typeof QualitySettings; diff --git a/cvat-core/src/object-state.ts b/cvat-core/src/object-state.ts index 9b35736a08a1..28993a0d114c 100644 --- a/cvat-core/src/object-state.ts +++ b/cvat-core/src/object-state.ts @@ -1,5 +1,5 @@ // Copyright (C) 2019-2022 Intel Corporation -// Copyright (C) 2022-2023 CVAT.ai Corporation +// Copyright (C) 2022-2024 CVAT.ai Corporation // // SPDX-License-Identifier: MIT @@ -8,6 +8,7 @@ import PluginRegistry from './plugins'; import { ArgumentError } from './exceptions'; import { Label } from './labels'; import { isEnum } from './common'; +import { SerializedShape, SerializedTag, SerializedTrack } from './server-response-types'; interface UpdateFlags { label: boolean; @@ -516,10 +517,15 @@ export default class ObjectState { const result = await PluginRegistry.apiWrapper.call(this, ObjectState.prototype.delete, frame, force); return result; } + + async export(): Promise { + const result = await PluginRegistry.apiWrapper.call(this, ObjectState.prototype.export); + return result; + } } Object.defineProperty(ObjectState.prototype.save, 'implementation', { - value: function save(): ObjectState { + value: function saveImplementation(): ObjectState { if (this.__internal && this.__internal.save) { return this.__internal.save(this); } @@ -529,8 +535,19 @@ Object.defineProperty(ObjectState.prototype.save, 'implementation', { writable: false, }); +Object.defineProperty(ObjectState.prototype.export, 'implementation', { + value: function exportImplementation(): ObjectState { + if (this.__internal && this.__internal.export) { + return this.__internal.export(this); + } + + return this; + }, + writable: false, +}); + Object.defineProperty(ObjectState.prototype.delete, 'implementation', { - value: function remove(frame: number, force: boolean): boolean { + value: function deleteImplementation(frame: number, force: boolean): boolean { if (this.__internal && this.__internal.delete) { if (!Number.isInteger(+frame) || +frame < 0) { throw new ArgumentError('Frame argument must be a non negative integer'); diff --git a/cvat-core/src/session-implementation.ts b/cvat-core/src/session-implementation.ts index 904899831abf..7ea9e326fb8b 100644 --- a/cvat-core/src/session-implementation.ts +++ b/cvat-core/src/session-implementation.ts @@ -519,6 +519,18 @@ export function implementJob(Job: typeof JobClass): typeof JobClass { }, }); + Object.defineProperty(Job.prototype.annotations.commit, 'implementation', { + value: function commitAnnotationsImplementation( + this: JobClass, + added: Parameters[0], + removed: Parameters[1], + frame: Parameters[2], + ): ReturnType { + getCollection(this).commit(added, removed, frame); + return Promise.resolve(); + }, + }); + Object.defineProperty(Job.prototype.annotations.upload, 'implementation', { value: async function uploadAnnotationsImplementation( this: JobClass, @@ -1208,6 +1220,18 @@ export function implementTask(Task: typeof TaskClass): typeof TaskClass { }, }); + Object.defineProperty(Task.prototype.annotations.commit, 'implementation', { + value: function commitAnnotationsImplementation( + this: TaskClass, + added: Parameters[0], + removed: Parameters[1], + frame: Parameters[2], + ): ReturnType { + getCollection(this).commit(added, removed, frame); + return Promise.resolve(); + }, + }); + Object.defineProperty(Task.prototype.annotations.exportDataset, 'implementation', { value: async function exportDatasetImplementation( this: TaskClass, diff --git a/cvat-core/src/session.ts b/cvat-core/src/session.ts index a2bc2008aef0..b3269ee78076 100644 --- a/cvat-core/src/session.ts +++ b/cvat-core/src/session.ts @@ -172,6 +172,17 @@ function buildDuplicatedAPI(prototype) { return result; }, + async commit(added, removed, frame) { + const result = await PluginRegistry.apiWrapper.call( + this, + prototype.annotations.commit, + added, + removed, + frame, + ); + return result; + }, + async exportDataset( format: string, saveImages: boolean, @@ -332,7 +343,7 @@ export class Session { delTrackKeyframesOnly?: boolean; }) => Promise; save: ( - onUpdate ?: (message: string) => void, + onUpdate?: (message: string) => void, ) => Promise; search: ( frameFrom: number, @@ -361,6 +372,11 @@ export class Session { }>; import: (data: Omit) => Promise; export: () => Promise>; + commit: ( + added: Omit, + removed: Omit, + frame: number, + ) => Promise; statistics: () => Promise; hasUnsavedChanges: () => boolean; exportDataset: ( @@ -431,6 +447,7 @@ export class Session { select: Object.getPrototypeOf(this).annotations.select.bind(this), import: Object.getPrototypeOf(this).annotations.import.bind(this), export: Object.getPrototypeOf(this).annotations.export.bind(this), + commit: Object.getPrototypeOf(this).annotations.commit.bind(this), statistics: Object.getPrototypeOf(this).annotations.statistics.bind(this), hasUnsavedChanges: Object.getPrototypeOf(this).annotations.hasUnsavedChanges.bind(this), exportDataset: Object.getPrototypeOf(this).annotations.exportDataset.bind(this), diff --git a/cvat-sdk/README.md b/cvat-sdk/README.md index fa68c0e5d40d..89702c02abd4 100644 --- a/cvat-sdk/README.md +++ b/cvat-sdk/README.md @@ -20,7 +20,14 @@ To install a prebuilt package, run the following command in the terminal: pip install cvat-sdk ``` -To use the PyTorch adapter, request the `pytorch` extra: +To use the `cvat_sdk.masks` module, request the `masks` extra: + +```bash +pip install "cvat-sdk[masks]" +``` + +To use the PyTorch adapter or the built-in PyTorch-based auto-annotation functions, +request the `pytorch` extra: ```bash pip install "cvat-sdk[pytorch]" diff --git a/cvat-sdk/cvat_sdk/auto_annotation/__init__.py b/cvat-sdk/cvat_sdk/auto_annotation/__init__.py index e5dbdf9fcc42..adbb6007e125 100644 --- a/cvat-sdk/cvat_sdk/auto_annotation/__init__.py +++ b/cvat-sdk/cvat_sdk/auto_annotation/__init__.py @@ -10,8 +10,27 @@ keypoint, keypoint_spec, label_spec, + mask, + polygon, rectangle, shape, skeleton, skeleton_label_spec, ) + +__all__ = [ + "annotate_task", + "BadFunctionError", + "DetectionFunction", + "DetectionFunctionContext", + "DetectionFunctionSpec", + "keypoint_spec", + "keypoint", + "label_spec", + "mask", + "polygon", + "rectangle", + "shape", + "skeleton_label_spec", + "skeleton", +] diff --git a/cvat-sdk/cvat_sdk/auto_annotation/driver.py b/cvat-sdk/cvat_sdk/auto_annotation/driver.py index 0f3d82ea32ea..5ffdb36f5bee 100644 --- a/cvat-sdk/cvat_sdk/auto_annotation/driver.py +++ b/cvat-sdk/cvat_sdk/auto_annotation/driver.py @@ -99,9 +99,11 @@ def __init__( ds_labels: Sequence[models.ILabel], *, allow_unmatched_labels: bool, + conv_mask_to_poly: bool, ) -> None: self._logger = logger self._allow_unmatched_labels = allow_unmatched_labels + self._conv_mask_to_poly = conv_mask_to_poly ds_labels_by_name = {ds_label.name: ds_label for ds_label in ds_labels} @@ -217,12 +219,19 @@ def validate_and_remap(self, shapes: list[models.LabeledShapeRequest], ds_frame: if getattr(shape, "elements", None): raise BadFunctionError("function output non-skeleton shape with elements") + if shape.type.value == "mask" and self._conv_mask_to_poly: + raise BadFunctionError( + "function output mask shape despite conv_mask_to_poly=True" + ) + shapes[:] = new_shapes -@attrs.frozen +@attrs.frozen(kw_only=True) class _DetectionFunctionContextImpl(DetectionFunctionContext): frame_name: str + conf_threshold: Optional[float] = None + conv_mask_to_poly: bool = False def annotate_task( @@ -233,6 +242,8 @@ def annotate_task( pbar: Optional[ProgressReporter] = None, clear_existing: bool = False, allow_unmatched_labels: bool = False, + conf_threshold: Optional[float] = None, + conv_mask_to_poly: bool = False, ) -> None: """ Downloads data for the task with the given ID, applies the given function to it @@ -264,11 +275,21 @@ def annotate_task( function declares a label in its spec that has no corresponding label in the task. If it's set to true, then such labels are allowed, and any annotations returned by the function that refer to this label are ignored. Otherwise, BadFunctionError is raised. + + The conf_threshold parameter must be None or a number between 0 and 1. It will be passed + to the AA function as the conf_threshold attribute of the context object. + + The conv_mask_to_poly parameter will be passed to the AA function as the conv_mask_to_poly + attribute of the context object. If it's true, and the AA function returns any mask shapes, + BadFunctionError will be raised. """ if pbar is None: pbar = NullProgressReporter() + if conf_threshold is not None and not 0 <= conf_threshold <= 1: + raise ValueError("conf_threshold must be None or a number between 0 and 1") + dataset = TaskDataset(client, task_id, load_annotations=False) assert isinstance(function.spec, DetectionFunctionSpec) @@ -278,6 +299,7 @@ def annotate_task( function.spec.labels, dataset.labels, allow_unmatched_labels=allow_unmatched_labels, + conv_mask_to_poly=conv_mask_to_poly, ) shapes = [] @@ -285,12 +307,17 @@ def annotate_task( with pbar.task(total=len(dataset.samples), unit="samples"): for sample in pbar.iter(dataset.samples): frame_shapes = function.detect( - _DetectionFunctionContextImpl(sample.frame_name), sample.media.load_image() + _DetectionFunctionContextImpl( + frame_name=sample.frame_name, + conf_threshold=conf_threshold, + conv_mask_to_poly=conv_mask_to_poly, + ), + sample.media.load_image(), ) mapper.validate_and_remap(frame_shapes, sample.frame_index) shapes.extend(frame_shapes) - client.logger.info("Uploading annotations to task %d", task_id) + client.logger.info("Uploading annotations to task %d...", task_id) if clear_existing: client.tasks.api.update_annotations( @@ -302,3 +329,5 @@ def annotate_task( task_id, patched_labeled_data_request=models.PatchedLabeledDataRequest(shapes=shapes), ) + + client.logger.info("Upload complete") diff --git a/cvat-sdk/cvat_sdk/auto_annotation/functions/_torchvision.py b/cvat-sdk/cvat_sdk/auto_annotation/functions/_torchvision.py new file mode 100644 index 000000000000..9fa88e0a7c07 --- /dev/null +++ b/cvat-sdk/cvat_sdk/auto_annotation/functions/_torchvision.py @@ -0,0 +1,26 @@ +# Copyright (C) 2024 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +from functools import cached_property + +import torchvision.models + +import cvat_sdk.auto_annotation as cvataa + + +class TorchvisionFunction: + def __init__(self, model_name: str, weights_name: str = "DEFAULT", **kwargs) -> None: + weights_enum = torchvision.models.get_model_weights(model_name) + self._weights = weights_enum[weights_name] + self._transforms = self._weights.transforms() + self._model = torchvision.models.get_model(model_name, weights=self._weights, **kwargs) + self._model.eval() + + @cached_property + def spec(self) -> cvataa.DetectionFunctionSpec: + return cvataa.DetectionFunctionSpec( + labels=[ + cvataa.label_spec(cat, i) for i, cat in enumerate(self._weights.meta["categories"]) + ] + ) diff --git a/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_detection.py b/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_detection.py index d257cb7ec889..b16e4d8874ae 100644 --- a/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_detection.py +++ b/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_detection.py @@ -2,38 +2,26 @@ # # SPDX-License-Identifier: MIT -from functools import cached_property - import PIL.Image -import torchvision.models import cvat_sdk.auto_annotation as cvataa import cvat_sdk.models as models +from ._torchvision import TorchvisionFunction -class _TorchvisionDetectionFunction: - def __init__(self, model_name: str, weights_name: str = "DEFAULT", **kwargs) -> None: - weights_enum = torchvision.models.get_model_weights(model_name) - self._weights = weights_enum[weights_name] - self._transforms = self._weights.transforms() - self._model = torchvision.models.get_model(model_name, weights=self._weights, **kwargs) - self._model.eval() - - @cached_property - def spec(self) -> cvataa.DetectionFunctionSpec: - return cvataa.DetectionFunctionSpec( - labels=[ - cvataa.label_spec(cat, i) for i, cat in enumerate(self._weights.meta["categories"]) - ] - ) - def detect(self, context, image: PIL.Image.Image) -> list[models.LabeledShapeRequest]: +class _TorchvisionDetectionFunction(TorchvisionFunction): + def detect( + self, context: cvataa.DetectionFunctionContext, image: PIL.Image.Image + ) -> list[models.LabeledShapeRequest]: + conf_threshold = context.conf_threshold or 0 results = self._model([self._transforms(image)]) return [ cvataa.rectangle(label.item(), [x.item() for x in box]) for result in results - for box, label in zip(result["boxes"], result["labels"]) + for box, label, score in zip(result["boxes"], result["labels"], result["scores"]) + if score >= conf_threshold ] diff --git a/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_instance_segmentation.py b/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_instance_segmentation.py new file mode 100644 index 000000000000..6aa891811f5b --- /dev/null +++ b/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_instance_segmentation.py @@ -0,0 +1,70 @@ +# Copyright (C) 2024 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +import math +from collections.abc import Iterator + +import numpy as np +import PIL.Image +from skimage import measure +from torch import Tensor + +import cvat_sdk.auto_annotation as cvataa +import cvat_sdk.models as models +from cvat_sdk.masks import encode_mask + +from ._torchvision import TorchvisionFunction + + +def _is_positively_oriented(contour: np.ndarray) -> bool: + ys, xs = contour.T + + # This is the shoelace formula, except we only need the sign of the result, + # so we compare instead of subtracting. Compared to the typical formula, + # the sign is inverted, because the Y axis points downwards. + return np.sum(xs * np.roll(ys, -1)) < np.sum(ys * np.roll(xs, -1)) + + +def _generate_shapes( + context: cvataa.DetectionFunctionContext, box: Tensor, mask: Tensor, label: Tensor +) -> Iterator[models.LabeledShapeRequest]: + LEVEL = 0.5 + + if context.conv_mask_to_poly: + # Since we treat mask values of exactly LEVEL as true, we'd like them + # to also be considered high by find_contours. And for that, the level + # parameter must be slightly less than LEVEL. + contours = measure.find_contours(mask[0].detach().numpy(), level=math.nextafter(LEVEL, 0)) + + for contour in contours: + if len(contour) < 3 or _is_positively_oriented(contour): + continue + + contour = measure.approximate_polygon(contour, tolerance=2.5) + + yield cvataa.polygon(label.item(), contour[:, ::-1].ravel().tolist()) + + else: + yield cvataa.mask(label.item(), encode_mask(mask[0] >= LEVEL, box.tolist())) + + +class _TorchvisionInstanceSegmentationFunction(TorchvisionFunction): + def detect( + self, context: cvataa.DetectionFunctionContext, image: PIL.Image.Image + ) -> list[models.LabeledShapeRequest]: + conf_threshold = context.conf_threshold or 0 + results = self._model([self._transforms(image)]) + + return [ + shape + for result in results + for box, mask, label, score in zip( + result["boxes"], result["masks"], result["labels"], result["scores"] + ) + if score >= conf_threshold + for shape in _generate_shapes(context, box, mask, label) + ] + + +create = _TorchvisionInstanceSegmentationFunction diff --git a/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_keypoint_detection.py b/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_keypoint_detection.py index c7199b67738b..4d2250d61c35 100644 --- a/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_keypoint_detection.py +++ b/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_keypoint_detection.py @@ -5,20 +5,14 @@ from functools import cached_property import PIL.Image -import torchvision.models import cvat_sdk.auto_annotation as cvataa import cvat_sdk.models as models +from ._torchvision import TorchvisionFunction -class _TorchvisionKeypointDetectionFunction: - def __init__(self, model_name: str, weights_name: str = "DEFAULT", **kwargs) -> None: - weights_enum = torchvision.models.get_model_weights(model_name) - self._weights = weights_enum[weights_name] - self._transforms = self._weights.transforms() - self._model = torchvision.models.get_model(model_name, weights=self._weights, **kwargs) - self._model.eval() +class _TorchvisionKeypointDetectionFunction(TorchvisionFunction): @cached_property def spec(self) -> cvataa.DetectionFunctionSpec: return cvataa.DetectionFunctionSpec( @@ -35,7 +29,10 @@ def spec(self) -> cvataa.DetectionFunctionSpec: ] ) - def detect(self, context, image: PIL.Image.Image) -> list[models.LabeledShapeRequest]: + def detect( + self, context: cvataa.DetectionFunctionContext, image: PIL.Image.Image + ) -> list[models.LabeledShapeRequest]: + conf_threshold = context.conf_threshold or 0 results = self._model([self._transforms(image)]) return [ @@ -51,7 +48,10 @@ def detect(self, context, image: PIL.Image.Image) -> list[models.LabeledShapeReq ], ) for result in results - for keypoints, label in zip(result["keypoints"], result["labels"]) + for keypoints, label, score in zip( + result["keypoints"], result["labels"], result["scores"] + ) + if score >= conf_threshold ] diff --git a/cvat-sdk/cvat_sdk/auto_annotation/interface.py b/cvat-sdk/cvat_sdk/auto_annotation/interface.py index 20a21fe4a5cf..f95cb50b4f2d 100644 --- a/cvat-sdk/cvat_sdk/auto_annotation/interface.py +++ b/cvat-sdk/cvat_sdk/auto_annotation/interface.py @@ -4,7 +4,7 @@ import abc from collections.abc import Sequence -from typing import Protocol +from typing import Optional, Protocol import attrs import PIL.Image @@ -50,7 +50,33 @@ def frame_name(self) -> str: The file name of the frame that the current image corresponds to in the dataset. """ - ... + + @property + @abc.abstractmethod + def conf_threshold(self) -> Optional[float]: + """ + The confidence threshold that the function should use for filtering + detections. + + If the function is able to estimate confidence levels, then: + + * If this value is None, the function may apply a default threshold at its discretion. + + * Otherwise, it will be a number between 0 and 1. The function must only return + objects with confidence levels greater than or equal to this value. + + If the function is not able to estimate confidence levels, it can ignore this value. + """ + + @property + @abc.abstractmethod + def conv_mask_to_poly(self) -> bool: + """ + If this is true, the function must convert any mask shapes to polygon shapes + before returning them. + + If the function does not return any mask shapes, then it can ignore this value. + """ class DetectionFunction(Protocol): @@ -152,6 +178,21 @@ def rectangle(label_id: int, points: Sequence[float], **kwargs) -> models.Labele return shape(label_id, type="rectangle", points=points, **kwargs) +def polygon(label_id: int, points: Sequence[float], **kwargs) -> models.LabeledShapeRequest: + """Helper factory function for LabeledShapeRequest with frame=0 and type="polygon".""" + return shape(label_id, type="polygon", points=points, **kwargs) + + +def mask(label_id: int, points: Sequence[float], **kwargs) -> models.LabeledShapeRequest: + """ + Helper factory function for LabeledShapeRequest with frame=0 and type="mask". + + It's recommended to use the cvat.masks.encode_mask function to build the + points argument. + """ + return shape(label_id, type="mask", points=points, **kwargs) + + def skeleton( label_id: int, elements: Sequence[models.SubLabeledShapeRequest], **kwargs ) -> models.LabeledShapeRequest: diff --git a/cvat-sdk/cvat_sdk/masks.py b/cvat-sdk/cvat_sdk/masks.py new file mode 100644 index 000000000000..f623aec7d043 --- /dev/null +++ b/cvat-sdk/cvat_sdk/masks.py @@ -0,0 +1,44 @@ +# Copyright (C) 2024 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +import math +from collections.abc import Sequence + +import numpy as np +from numpy.typing import ArrayLike + + +def encode_mask(bitmap: ArrayLike, /, bbox: Sequence[float]) -> list[float]: + """ + Encodes an image mask into an array of numbers suitable for the "points" + attribute of a LabeledShapeRequest object of type "mask". + + bitmap must be a boolean array of shape (H, W), where H is the height and + W is the width of the image that the mask applies to. + + bbox must have the form [x1, y1, x2, y2], where (0, 0) <= (x1, y1) < (x2, y2) <= (W, H). + The mask will be limited to points between (x1, y1) and (x2, y2). + """ + + bitmap = np.asanyarray(bitmap) + if bitmap.ndim != 2: + raise ValueError("bitmap must have 2 dimensions") + if bitmap.dtype != np.bool_: + raise ValueError("bitmap must have boolean items") + + x1, y1 = map(math.floor, bbox[0:2]) + x2, y2 = map(math.ceil, bbox[2:4]) + + if not (0 <= x1 < x2 <= bitmap.shape[1] and 0 <= y1 < y2 <= bitmap.shape[0]): + raise ValueError("bbox has invalid coordinates") + + flat = bitmap[y1:y2, x1:x2].ravel() + + (run_indices,) = np.diff(flat, prepend=[not flat[0]], append=[not flat[-1]]).nonzero() + if flat[0]: + run_lengths = np.diff(run_indices, prepend=[0]) + else: + run_lengths = np.diff(run_indices) + + return run_lengths.tolist() + [x1, y1, x2 - 1, y2 - 1] diff --git a/cvat-sdk/gen/generate.sh b/cvat-sdk/gen/generate.sh index ca9a08be98fe..60875c499496 100755 --- a/cvat-sdk/gen/generate.sh +++ b/cvat-sdk/gen/generate.sh @@ -8,7 +8,7 @@ set -e GENERATOR_VERSION="v6.0.1" -VERSION="2.22.0" +VERSION="2.23.0" LIB_NAME="cvat_sdk" LAYER1_LIB_NAME="${LIB_NAME}/api_client" DST_DIR="$(cd "$(dirname -- "$0")/.." && pwd)" diff --git a/cvat-sdk/gen/templates/openapi-generator/model_templates/method_from_openapi_data_composed.mustache b/cvat-sdk/gen/templates/openapi-generator/model_templates/method_from_openapi_data_composed.mustache index 97d3cb930c27..e56437b401ee 100644 --- a/cvat-sdk/gen/templates/openapi-generator/model_templates/method_from_openapi_data_composed.mustache +++ b/cvat-sdk/gen/templates/openapi-generator/model_templates/method_from_openapi_data_composed.mustache @@ -22,7 +22,6 @@ {{name}} ({{{dataType}}}):{{#description}} {{{.}}}.{{/description}} [optional]{{#defaultValue}} if omitted the server will use the default value of {{{.}}}{{/defaultValue}} # noqa: E501 {{/optionalVars}} """ - from {{packageName}}.configuration import Configuration {{#requiredVars}} {{#defaultValue}} @@ -32,7 +31,7 @@ _check_type = kwargs.pop('_check_type', True) _spec_property_naming = kwargs.pop('_spec_property_naming', False) _path_to_item = kwargs.pop('_path_to_item', ()) - _configuration = kwargs.pop('_configuration', Configuration()) + _configuration = kwargs.pop('_configuration', None) _visited_composed_classes = kwargs.pop('_visited_composed_classes', ()) self = super(OpenApiModel, cls).__new__(cls) diff --git a/cvat-sdk/gen/templates/openapi-generator/model_templates/method_from_openapi_data_shared.mustache b/cvat-sdk/gen/templates/openapi-generator/model_templates/method_from_openapi_data_shared.mustache index 4c149f22ce88..12dbba9ac641 100644 --- a/cvat-sdk/gen/templates/openapi-generator/model_templates/method_from_openapi_data_shared.mustache +++ b/cvat-sdk/gen/templates/openapi-generator/model_templates/method_from_openapi_data_shared.mustache @@ -27,7 +27,6 @@ {{/optionalVars}} {{> model_templates/docstring_init_required_kwargs }} """ - from {{packageName}}.configuration import Configuration {{#requiredVars}} {{#defaultValue}} @@ -37,7 +36,7 @@ _check_type = kwargs.pop('_check_type', True) _spec_property_naming = kwargs.pop('_spec_property_naming', True) _path_to_item = kwargs.pop('_path_to_item', ()) - _configuration = kwargs.pop('_configuration', Configuration()) + _configuration = kwargs.pop('_configuration', None) _visited_composed_classes = kwargs.pop('_visited_composed_classes', ()) self = super(OpenApiModel, cls).__new__(cls) diff --git a/cvat-sdk/gen/templates/openapi-generator/model_templates/method_from_openapi_data_simple.mustache b/cvat-sdk/gen/templates/openapi-generator/model_templates/method_from_openapi_data_simple.mustache index 853532e9f5ca..e8daa85e829c 100644 --- a/cvat-sdk/gen/templates/openapi-generator/model_templates/method_from_openapi_data_simple.mustache +++ b/cvat-sdk/gen/templates/openapi-generator/model_templates/method_from_openapi_data_simple.mustache @@ -12,7 +12,6 @@ value ({{{dataType}}}):{{#description}} {{{.}}}.{{/description}}{{#defaultValue}} if omitted defaults to {{{.}}}{{/defaultValue}}{{#allowableValues}}, must be one of [{{#enumVars}}{{{value}}}, {{/enumVars}}]{{/allowableValues}} # noqa: E501 {{> model_templates/docstring_init_required_kwargs }} """ - from {{packageName}}.configuration import Configuration # required up here when default value is not given _path_to_item = kwargs.pop('_path_to_item', ()) @@ -39,7 +38,7 @@ _check_type = kwargs.pop('_check_type', True) _spec_property_naming = kwargs.pop('_spec_property_naming', False) - _configuration = kwargs.pop('_configuration', Configuration()) + _configuration = kwargs.pop('_configuration', None) _visited_composed_classes = kwargs.pop('_visited_composed_classes', ()) {{> model_templates/invalid_pos_args }} diff --git a/cvat-sdk/gen/templates/openapi-generator/model_templates/method_init_shared.mustache b/cvat-sdk/gen/templates/openapi-generator/model_templates/method_init_shared.mustache index 998b4841b7e7..c7d402a6cc52 100644 --- a/cvat-sdk/gen/templates/openapi-generator/model_templates/method_init_shared.mustache +++ b/cvat-sdk/gen/templates/openapi-generator/model_templates/method_init_shared.mustache @@ -30,7 +30,6 @@ {{/optionalVars}} {{> model_templates/docstring_init_required_kwargs }} """ - from {{packageName}}.configuration import Configuration {{#requiredVars}} {{^isReadOnly}} @@ -42,7 +41,7 @@ _check_type = kwargs.pop('_check_type', True) _spec_property_naming = kwargs.pop('_spec_property_naming', False) _path_to_item = kwargs.pop('_path_to_item', ()) - _configuration = kwargs.pop('_configuration', Configuration()) + _configuration = kwargs.pop('_configuration', None) _visited_composed_classes = kwargs.pop('_visited_composed_classes', ()) {{> model_templates/invalid_pos_args }} diff --git a/cvat-sdk/gen/templates/openapi-generator/model_templates/method_init_simple.mustache b/cvat-sdk/gen/templates/openapi-generator/model_templates/method_init_simple.mustache index 8c8b42ce1f49..424b1d439c62 100644 --- a/cvat-sdk/gen/templates/openapi-generator/model_templates/method_init_simple.mustache +++ b/cvat-sdk/gen/templates/openapi-generator/model_templates/method_init_simple.mustache @@ -20,7 +20,6 @@ value ({{{dataType}}}):{{#description}} {{{.}}}.{{/description}}{{#defaultValue}} if omitted defaults to {{{.}}}{{/defaultValue}}{{#allowableValues}}, must be one of [{{#enumVars}}{{{value}}}, {{/enumVars}}]{{/allowableValues}} # noqa: E501 {{> model_templates/docstring_init_required_kwargs }} """ - from {{packageName}}.configuration import Configuration # required up here when default value is not given _path_to_item = kwargs.pop('_path_to_item', ()) @@ -45,7 +44,7 @@ _check_type = kwargs.pop('_check_type', True) _spec_property_naming = kwargs.pop('_spec_property_naming', False) - _configuration = kwargs.pop('_configuration', Configuration()) + _configuration = kwargs.pop('_configuration', None) _visited_composed_classes = kwargs.pop('_visited_composed_classes', ()) {{> model_templates/invalid_pos_args }} diff --git a/cvat-sdk/gen/templates/openapi-generator/model_utils.mustache b/cvat-sdk/gen/templates/openapi-generator/model_utils.mustache index c9e2b70d77bb..cc3c03dbce77 100644 --- a/cvat-sdk/gen/templates/openapi-generator/model_utils.mustache +++ b/cvat-sdk/gen/templates/openapi-generator/model_utils.mustache @@ -354,6 +354,13 @@ class OpenApiModel(object): new_inst = new_cls._new_from_openapi_data(*args, **kwargs) return new_inst + def __setstate__(self, state): + # This is the same as the default implementation. We override it, + # because unpickling attempts to access `obj.__setstate__` on an uninitialized + # object, and if this method is not defined, it results in a call to `__getattr__`. + # This fails, because `__getattr__` relies on `self._data_store`, which doesn't + # exist in an uninitialized object. + self.__dict__.update(state) class ModelSimple(OpenApiModel): """the parent class of models whose type != object in their @@ -1084,7 +1091,7 @@ def deserialize_file(response_data, configuration, content_disposition=None): (file_type): the deserialized file which is open The user is responsible for closing and reading the file """ - fd, path = tempfile.mkstemp(dir=configuration.temp_folder_path) + fd, path = tempfile.mkstemp(dir=configuration.temp_folder_path if configuration else None) os.close(fd) os.remove(path) @@ -1263,27 +1270,21 @@ def validate_and_convert_types(input_value, required_types_mixed, path_to_item, input_class_simple = get_simple_class(input_value) valid_type = is_valid_type(input_class_simple, valid_classes) if not valid_type: - if (configuration - or (input_class_simple == dict - and dict not in valid_classes)): - # if input_value is not valid_type try to convert it - converted_instance = attempt_convert_item( - input_value, - valid_classes, - path_to_item, - configuration, - spec_property_naming, - key_type=False, - must_convert=True, - check_type=_check_type - ) - return converted_instance - else: - raise get_type_error(input_value, path_to_item, valid_classes, - key_type=False) + # if input_value is not valid_type try to convert it + converted_instance = attempt_convert_item( + input_value, + valid_classes, + path_to_item, + configuration, + spec_property_naming, + key_type=False, + must_convert=True, + check_type=_check_type + ) + return converted_instance # input_value's type is in valid_classes - if len(valid_classes) > 1 and configuration: + if len(valid_classes) > 1: # there are valid classes which are not the current class valid_classes_coercible = remove_uncoercible( valid_classes, input_value, spec_property_naming, must_convert=False) diff --git a/cvat-sdk/gen/templates/openapi-generator/setup.mustache b/cvat-sdk/gen/templates/openapi-generator/setup.mustache index eb89f5d20554..e0379cabd06e 100644 --- a/cvat-sdk/gen/templates/openapi-generator/setup.mustache +++ b/cvat-sdk/gen/templates/openapi-generator/setup.mustache @@ -77,7 +77,8 @@ setup( python_requires="{{{generatorLanguageVersion}}}", install_requires=BASE_REQUIREMENTS, extras_require={ - "pytorch": ['torch', 'torchvision'], + "masks": ["numpy>=2"], + "pytorch": ['torch', 'torchvision', 'scikit-image>=0.24', 'cvat_sdk[masks]'], }, package_dir={"": "."}, packages=find_packages(include=["cvat_sdk*"]), diff --git a/cvat-ui/package.json b/cvat-ui/package.json index a74485fa107d..703718121cd1 100644 --- a/cvat-ui/package.json +++ b/cvat-ui/package.json @@ -1,6 +1,6 @@ { "name": "cvat-ui", - "version": "1.66.4", + "version": "1.67.0", "description": "CVAT single-page application", "main": "src/index.tsx", "scripts": { diff --git a/cvat-ui/src/actions/annotation-actions.ts b/cvat-ui/src/actions/annotation-actions.ts index 0cc8f3052bc0..115470429990 100644 --- a/cvat-ui/src/actions/annotation-actions.ts +++ b/cvat-ui/src/actions/annotation-actions.ts @@ -1081,7 +1081,7 @@ export function finishCurrentJobAsync(): ThunkAction { export function rememberObject(createParams: { activeObjectType?: ObjectType; activeLabelID?: number; - activeShapeType?: ShapeType; + activeShapeType?: ShapeType | null; activeNumOfPoints?: number; activeRectDrawingMethod?: RectDrawingMethod; activeCuboidDrawingMethod?: CuboidDrawingMethod; diff --git a/cvat-ui/src/components/annotation-page/annotations-actions/annotations-actions-modal.tsx b/cvat-ui/src/components/annotation-page/annotations-actions/annotations-actions-modal.tsx index 27898da9fa2a..f33dd9bf231a 100644 --- a/cvat-ui/src/components/annotation-page/annotations-actions/annotations-actions-modal.tsx +++ b/cvat-ui/src/components/annotation-page/annotations-actions/annotations-actions-modal.tsx @@ -7,6 +7,7 @@ import './styles.scss'; import React, { useEffect, useReducer, useRef, useState, } from 'react'; +import { createRoot } from 'react-dom/client'; import Button from 'antd/lib/button'; import { Col, Row } from 'antd/lib/grid'; import Progress from 'antd/lib/progress'; @@ -22,28 +23,27 @@ import { useIsMounted } from 'utils/hooks'; import { createAction, ActionUnion } from 'utils/redux'; import { getCVATStore } from 'cvat-store'; import { - BaseSingleFrameAction, FrameSelectionType, Job, getCore, + BaseCollectionAction, BaseAction, Job, getCore, + ObjectState, } from 'cvat-core-wrapper'; import { Canvas } from 'cvat-canvas-wrapper'; -import { fetchAnnotationsAsync, saveAnnotationsAsync } from 'actions/annotation-actions'; -import { switchAutoSave } from 'actions/settings-actions'; +import { fetchAnnotationsAsync } from 'actions/annotation-actions'; import { clamp } from 'utils/math'; const core = getCore(); interface State { - actions: BaseSingleFrameAction[]; - activeAction: BaseSingleFrameAction | null; + actions: BaseAction[]; + activeAction: BaseAction | null; fetching: boolean; progress: number | null; progressMessage: string | null; cancelled: boolean; - autoSaveEnabled: boolean; - jobHasBeenSaved: boolean; frameFrom: number; frameTo: number; actionParameters: Record; modalVisible: boolean; + targetObjectState?: ObjectState | null; } enum ReducerActionType { @@ -53,8 +53,6 @@ enum ReducerActionType { RESET_BEFORE_RUN = 'RESET_BEFORE_RUN', RESET_AFTER_RUN = 'RESET_AFTER_RUN', CANCEL_ACTION = 'CANCEL_ACTION', - SET_AUTOSAVE_DISABLED_FLAG = 'SET_AUTOSAVE_DISABLED_FLAG', - SET_JOB_WAS_SAVED_FLAG = 'SET_JOB_WAS_SAVED_FLAG', UPDATE_FRAME_FROM = 'UPDATE_FRAME_FROM', UPDATE_FRAME_TO = 'UPDATE_FRAME_TO', UPDATE_ACTION_PARAMETER = 'UPDATE_ACTION_PARAMETER', @@ -62,10 +60,10 @@ enum ReducerActionType { } export const reducerActions = { - setAnnotationsActions: (actions: BaseSingleFrameAction[]) => ( + setAnnotationsActions: (actions: BaseAction[]) => ( createAction(ReducerActionType.SET_ANNOTATIONS_ACTIONS, { actions }) ), - setActiveAnnotationsAction: (activeAction: BaseSingleFrameAction) => ( + setActiveAnnotationsAction: (activeAction: BaseAction) => ( createAction(ReducerActionType.SET_ACTIVE_ANNOTATIONS_ACTION, { activeAction }) ), updateProgress: (progress: number | null, progressMessage: string | null) => ( @@ -80,12 +78,6 @@ export const reducerActions = { cancelAction: () => ( createAction(ReducerActionType.CANCEL_ACTION) ), - setAutoSaveDisabledFlag: () => ( - createAction(ReducerActionType.SET_AUTOSAVE_DISABLED_FLAG) - ), - setJobSavedFlag: (jobHasBeenSaved: boolean) => ( - createAction(ReducerActionType.SET_JOB_WAS_SAVED_FLAG, { jobHasBeenSaved }) - ), updateFrameFrom: (frameFrom: number) => ( createAction(ReducerActionType.UPDATE_FRAME_FROM, { frameFrom }) ), @@ -100,19 +92,54 @@ export const reducerActions = { ), }; +const KEEP_LATEST = 5; +let lastSelectedActions: [string, Record][] = []; +function updateLatestActions(name: string, parameters: Record = {}): void { + const idx = lastSelectedActions.findIndex((el) => el[0] === name); + if (idx === -1) { + lastSelectedActions = [[name, parameters], ...lastSelectedActions]; + } else { + lastSelectedActions = [ + [name, parameters], + ...lastSelectedActions.slice(0, idx), + ...lastSelectedActions.slice(idx + 1), + ]; + } + + lastSelectedActions = lastSelectedActions.slice(-KEEP_LATEST); +} + const reducer = (state: State, action: ActionUnion): State => { if (action.type === ReducerActionType.SET_ANNOTATIONS_ACTIONS) { + const { actions } = action.payload; + const list = state.targetObjectState ? actions + .filter((_action) => _action.isApplicableForObject(state.targetObjectState as ObjectState)) : actions; + + let activeAction = null; + let activeActionParameters = {}; + for (const item of lastSelectedActions) { + const [actionName, actionParameters] = item; + const candidate = list.find((el) => el.name === actionName); + if (candidate) { + activeAction = candidate; + activeActionParameters = actionParameters; + break; + } + } + return { ...state, - actions: action.payload.actions, - activeAction: action.payload.actions[0] || null, - actionParameters: {}, + actions: list, + activeAction: activeAction ?? list[0] ?? null, + actionParameters: activeActionParameters, }; } if (action.type === ReducerActionType.SET_ACTIVE_ANNOTATIONS_ACTION) { - const { frameSelection } = action.payload.activeAction; - if (frameSelection === FrameSelectionType.CURRENT_FRAME) { + const { activeAction } = action.payload; + updateLatestActions(activeAction.name, {}); + + if (action.payload.activeAction instanceof BaseCollectionAction) { const storage = getCVATStore(); const currentFrame = storage.getState().annotation.player.frame.number; return { @@ -123,6 +150,7 @@ const reducer = (state: State, action: ActionUnion): Stat actionParameters: {}, }; } + return { ...state, activeAction: action.payload.activeAction, @@ -163,20 +191,6 @@ const reducer = (state: State, action: ActionUnion): Stat }; } - if (action.type === ReducerActionType.SET_AUTOSAVE_DISABLED_FLAG) { - return { - ...state, - autoSaveEnabled: false, - }; - } - - if (action.type === ReducerActionType.SET_JOB_WAS_SAVED_FLAG) { - return { - ...state, - jobHasBeenSaved: action.payload.jobHasBeenSaved, - }; - } - if (action.type === ReducerActionType.UPDATE_FRAME_FROM) { return { ...state, @@ -194,12 +208,16 @@ const reducer = (state: State, action: ActionUnion): Stat } if (action.type === ReducerActionType.UPDATE_ACTION_PARAMETER) { + const updatedActionParameters = { + ...state.actionParameters, + [action.payload.name]: action.payload.value, + }; + + updateLatestActions((state.activeAction as BaseAction).name, updatedActionParameters); + return { ...state, - actionParameters: { - ...state.actionParameters, - [action.payload.name]: action.payload.value, - }, + actionParameters: updatedActionParameters, }; } @@ -213,9 +231,9 @@ const reducer = (state: State, action: ActionUnion): Stat return state; }; -type Props = NonNullable[keyof BaseSingleFrameAction['parameters']]; +type ActionParameterProps = NonNullable[keyof BaseAction['parameters']]; -function ActionParameterComponent(props: Props & { onChange: (value: string) => void }): JSX.Element { +function ActionParameterComponent(props: ActionParameterProps & { onChange: (value: string) => void }): JSX.Element { const { defaultValue, type, values, onChange, } = props; @@ -262,8 +280,13 @@ function ActionParameterComponent(props: Props & { onChange: (value: string) => ); } -function AnnotationsActionsModalContent(props: { onClose: () => void; }): JSX.Element { - const { onClose } = props; +interface Props { + onClose: () => void; + targetObjectState?: ObjectState; +} + +function AnnotationsActionsModalContent(props: Props): JSX.Element { + const { onClose, targetObjectState: defaultTargetObjectState } = props; const isMounted = useIsMounted(); const storage = getCVATStore(); const cancellationRef = useRef(false); @@ -276,29 +299,27 @@ function AnnotationsActionsModalContent(props: { onClose: () => void; }): JSX.El progress: null, progressMessage: null, cancelled: false, - autoSaveEnabled: storage.getState().settings.workspace.autoSave, - jobHasBeenSaved: true, frameFrom: jobInstance.startFrame, frameTo: jobInstance.stopFrame, actionParameters: {}, modalVisible: true, + targetObjectState: defaultTargetObjectState ?? null, }); useEffect(() => { - core.actions.list().then((list: BaseSingleFrameAction[]) => { + core.actions.list().then((list: BaseAction[]) => { if (isMounted()) { - dispatch(reducerActions.setJobSavedFlag(!jobInstance.annotations.hasUnsavedChanges())); dispatch(reducerActions.setAnnotationsActions(list)); } }); }, []); const { - actions, activeAction, fetching, autoSaveEnabled, jobHasBeenSaved, + actions, activeAction, fetching, targetObjectState, progress, progressMessage, frameFrom, frameTo, actionParameters, modalVisible, } = state; - const currentFrameAction = activeAction?.frameSelection === FrameSelectionType.CURRENT_FRAME; + const currentFrameAction = activeAction instanceof BaseCollectionAction || targetObjectState !== null; return ( void; }): JSX.El - Actions allow executing certain algorithms on - - - filtered - - - annotations. - It affects only the local browser state. - Once an action has finished, - it cannot be reverted. - You may reload the page to get annotations from the server. - It is strongly recommended to review the changes - before saving annotations to the server. - + targetObjectState ? ( + Selected action will be applied to the current object + ) : ( +
+ Actions allow executing certain algorithms on + + + filtered + + + annotations. +
+ ) )} type='info' showIcon /> - {!jobHasBeenSaved ? ( - - - Recommendation: - - - )} - type='warning' - showIcon - /> - - ) : null} - - {autoSaveEnabled ? ( - - - Recommendation: - - - )} - type='warning' - showIcon - /> - - ) : null} - - 1. Select action + Select action
@@ -406,7 +376,7 @@ function AnnotationsActionsModalContent(props: { onClose: () => void; }): JSX.El }} > {actions.map( - (annotationFunction: BaseSingleFrameAction): JSX.Element => ( + (annotationFunction: BaseAction): JSX.Element => ( void; }): JSX.El
- {activeAction ? ( + {activeAction && !currentFrameAction ? ( <> - 2. Specify frames to apply the action + Specify frames to apply the action
- { - currentFrameAction ? ( - Running the action is only allowed on current frame - ) : ( - <> - Starting from frame - { - if (typeof value === 'number') { - dispatch(reducerActions.updateFrameFrom( - clamp( - Math.round(value), - jobInstance.startFrame, - frameTo, - ), - )); - } - }} - /> - up to frame - { - if (typeof value === 'number') { - dispatch(reducerActions.updateFrameTo( - clamp( - Math.round(value), - frameFrom, - jobInstance.stopFrame, - ), - )); - } - }} - /> - - - ) - } + Starting from frame + { + if (typeof value === 'number') { + dispatch(reducerActions.updateFrameFrom( + clamp( + Math.round(value), + jobInstance.startFrame, + frameTo, + ), + )); + } + }} + /> + up to frame + { + if (typeof value === 'number') { + dispatch(reducerActions.updateFrameTo( + clamp( + Math.round(value), + frameFrom, + jobInstance.stopFrame, + ), + )); + } + }} + />
@@ -534,7 +495,7 @@ function AnnotationsActionsModalContent(props: { onClose: () => void; }): JSX.El - 3. Setup action parameters + Setup action parameters
{Object.entries(activeAction.parameters) @@ -545,7 +506,7 @@ function AnnotationsActionsModalContent(props: { onClose: () => void; }): JSX.El onChange={(value: string) => { dispatch(reducerActions.updateActionParameter(name, value)); }} - defaultValue={defaultValue} + defaultValue={actionParameters[name] ?? defaultValue} type={type} values={values} /> @@ -593,24 +554,40 @@ function AnnotationsActionsModalContent(props: { onClose: () => void; }): JSX.El if (activeAction) { cancellationRef.current = false; dispatch(reducerActions.resetBeforeRun()); + const updateProgressWrapper = (_message: string, _progress: number): void => { + if (isMounted()) { + dispatch(reducerActions.updateProgress(_progress, _message)); + } + }; - core.actions.run( + const actionPromise = targetObjectState ? core.actions.call( + jobInstance, + activeAction, + actionParameters, + storage.getState().annotation.player.frame.number, + [targetObjectState], + updateProgressWrapper, + () => cancellationRef.current, + ) : core.actions.run( jobInstance, - [activeAction], - [actionParameters], + activeAction, + actionParameters, frameFrom, frameTo, storage.getState().annotation.annotations.filters, - (_message: string, _progress: number) => { - if (isMounted()) { - dispatch(reducerActions.updateProgress(_progress, _message)); - } - }, + updateProgressWrapper, () => cancellationRef.current, - ).then(() => { + ); + + actionPromise.then(() => { if (!cancellationRef.current) { canvasInstance.setup(frameData, []); storage.dispatch(fetchAnnotationsAsync()); + if (isMounted()) { + if (targetObjectState !== null) { + onClose(); + } + } } }).finally(() => { if (isMounted()) { @@ -634,4 +611,19 @@ function AnnotationsActionsModalContent(props: { onClose: () => void; }): JSX.El ); } -export default React.memo(AnnotationsActionsModalContent); +const MemoizedAnnotationsActionsModalContent = React.memo(AnnotationsActionsModalContent); + +export function openAnnotationsActionModal(objectState?: ObjectState): void { + const div = window.document.createElement('div'); + window.document.body.append(div); + const root = createRoot(div); + root.render( + { + root.unmount(); + div.remove(); + }} + />, + ); +} diff --git a/cvat-ui/src/components/annotation-page/annotations-actions/styles.scss b/cvat-ui/src/components/annotation-page/annotations-actions/styles.scss index 787d5685ff37..b7eae1e50242 100644 --- a/cvat-ui/src/components/annotation-page/annotations-actions/styles.scss +++ b/cvat-ui/src/components/annotation-page/annotations-actions/styles.scss @@ -15,10 +15,6 @@ margin-top: $grid-unit-size * 2; } -.cvat-action-runner-info:not(:first-child) { - margin-top: $grid-unit-size * 2; -} - .cvat-action-runner-info { .ant-alert { text-align: justify; diff --git a/cvat-ui/src/components/annotation-page/canvas/views/canvas2d/canvas-wrapper.tsx b/cvat-ui/src/components/annotation-page/canvas/views/canvas2d/canvas-wrapper.tsx index df98f6b5c4cc..322a345efea7 100644 --- a/cvat-ui/src/components/annotation-page/canvas/views/canvas2d/canvas-wrapper.tsx +++ b/cvat-ui/src/components/annotation-page/canvas/views/canvas2d/canvas-wrapper.tsx @@ -653,7 +653,8 @@ class CanvasWrapperComponent extends React.PureComponent { const { state, duration } = event.detail; const isDrawnFromScratch = !state.label; - state.objectType = state.objectType || activeObjectType; + state.objectType = state.shapeType === ShapeType.MASK ? + ObjectType.SHAPE : state.objectType ?? activeObjectType; state.label = state.label || jobInstance.labels.filter((label: any) => label.id === activeLabelID)[0]; state.frame = frame; state.rotation = state.rotation || 0; diff --git a/cvat-ui/src/components/annotation-page/single-shape-workspace/single-shape-sidebar/single-shape-sidebar.tsx b/cvat-ui/src/components/annotation-page/single-shape-workspace/single-shape-sidebar/single-shape-sidebar.tsx index b7f7d60097af..fb2dae58e154 100644 --- a/cvat-ui/src/components/annotation-page/single-shape-workspace/single-shape-sidebar/single-shape-sidebar.tsx +++ b/cvat-ui/src/components/annotation-page/single-shape-workspace/single-shape-sidebar/single-shape-sidebar.tsx @@ -21,6 +21,7 @@ import message from 'antd/lib/message'; import { ActiveControl, CombinedState, NavigationType, ObjectType, } from 'reducers'; +import { labelShapeType } from 'reducers/annotation-reducer'; import { Canvas, CanvasMode } from 'cvat-canvas-wrapper'; import { Job, Label, LabelType, ShapeType, @@ -259,6 +260,7 @@ function SingleShapeSidebar(): JSX.Element { appDispatch(rememberObject({ activeObjectType: ObjectType.SHAPE, activeLabelID: state.label.id, + activeShapeType: labelShapeType(state.label), })); canvas.draw({ diff --git a/cvat-ui/src/components/annotation-page/standard-workspace/controls-side-bar/tools-control.tsx b/cvat-ui/src/components/annotation-page/standard-workspace/controls-side-bar/tools-control.tsx index a31307277e68..dc73360d1f1d 100644 --- a/cvat-ui/src/components/annotation-page/standard-workspace/controls-side-bar/tools-control.tsx +++ b/cvat-ui/src/components/annotation-page/standard-workspace/controls-side-bar/tools-control.tsx @@ -1254,8 +1254,8 @@ export class ToolsControlComponent extends React.PureComponent { try { this.setState({ mode: 'detection', fetching: true }); - // The function call endpoint doesn't support the cleanup and convMaskToPoly parameters. - const { cleanup, convMaskToPoly, ...restOfBody } = body; + // The function call endpoint doesn't support the cleanup and conv_mask_to_poly parameters. + const { cleanup, conv_mask_to_poly: convMaskToPoly, ...restOfBody } = body; const result = await core.lambda.call(jobInstance.taskId, model, { ...restOfBody, frame, job: jobInstance.id, diff --git a/cvat-ui/src/components/annotation-page/standard-workspace/objects-side-bar/object-item-basics.tsx b/cvat-ui/src/components/annotation-page/standard-workspace/objects-side-bar/object-item-basics.tsx index aee51de644c0..078da4b82669 100644 --- a/cvat-ui/src/components/annotation-page/standard-workspace/objects-side-bar/object-item-basics.tsx +++ b/cvat-ui/src/components/annotation-page/standard-workspace/objects-side-bar/object-item-basics.tsx @@ -37,6 +37,7 @@ interface Props { toForegroundShortcut: string; removeShortcut: string; sliceShortcut: string; + runAnnotationsActionShortcut: string; changeColor(color: string): void; changeLabel(label: any): void; copy(): void; @@ -47,6 +48,7 @@ interface Props { toBackground(): void; toForeground(): void; resetCuboidPerspective(): void; + runAnnotationAction(): void; edit(): void; slice(): void; } @@ -72,6 +74,7 @@ function ItemTopComponent(props: Props): JSX.Element { toForegroundShortcut, removeShortcut, sliceShortcut, + runAnnotationsActionShortcut, isGroundTruth, changeColor, changeLabel, @@ -83,6 +86,7 @@ function ItemTopComponent(props: Props): JSX.Element { toBackground, toForeground, resetCuboidPerspective, + runAnnotationAction, edit, slice, jobInstance, @@ -154,6 +158,7 @@ function ItemTopComponent(props: Props): JSX.Element { toForegroundShortcut, removeShortcut, sliceShortcut, + runAnnotationsActionShortcut, changeColor, copy, remove, @@ -166,6 +171,7 @@ function ItemTopComponent(props: Props): JSX.Element { setColorPickerVisible, edit, slice, + runAnnotationAction, })} > diff --git a/cvat-ui/src/components/annotation-page/standard-workspace/objects-side-bar/object-item-menu.tsx b/cvat-ui/src/components/annotation-page/standard-workspace/objects-side-bar/object-item-menu.tsx index 30b239d8187a..3a18f035f4a6 100644 --- a/cvat-ui/src/components/annotation-page/standard-workspace/objects-side-bar/object-item-menu.tsx +++ b/cvat-ui/src/components/annotation-page/standard-workspace/objects-side-bar/object-item-menu.tsx @@ -8,6 +8,7 @@ import Button from 'antd/lib/button'; import { MenuProps } from 'antd/lib/menu'; import Icon, { LinkOutlined, CopyOutlined, BlockOutlined, RetweetOutlined, DeleteOutlined, EditOutlined, + FunctionOutlined, } from '@ant-design/icons'; import { @@ -34,6 +35,7 @@ interface Props { toBackgroundShortcut: string; toForegroundShortcut: string; removeShortcut: string; + runAnnotationsActionShortcut: string; changeColor(value: string): void; copy(): void; remove(): void; @@ -46,6 +48,7 @@ interface Props { setColorPickerVisible(visible: boolean): void; edit(): void; slice(): void; + runAnnotationAction(): void; jobInstance: Job; } @@ -232,6 +235,23 @@ function RemoveItem(props: ItemProps): JSX.Element { ); } +function RunAnnotationActionItem(props: ItemProps): JSX.Element { + const { toolProps } = props; + const { runAnnotationsActionShortcut, runAnnotationAction } = toolProps; + return ( + + + + ); +} + export default function ItemMenu(props: Props): MenuProps { const { readonly, shapeType, objectType, colorBy, jobInstance, @@ -249,6 +269,7 @@ export default function ItemMenu(props: Props): MenuProps { REMOVE_ITEM = 'remove_item', EDIT_MASK = 'edit_mask', SLICE_ITEM = 'slice_item', + RUN_ANNOTATION_ACTION = 'run_annotation_action', } const is2D = jobInstance.dimension === DimensionType.DIMENSION_2D; @@ -326,6 +347,13 @@ export default function ItemMenu(props: Props): MenuProps { }); } + if (!readonly) { + items.push({ + key: MenuKeys.RUN_ANNOTATION_ACTION, + label: , + }); + } + return { items, selectable: false, diff --git a/cvat-ui/src/components/annotation-page/standard-workspace/objects-side-bar/object-item.tsx b/cvat-ui/src/components/annotation-page/standard-workspace/objects-side-bar/object-item.tsx index 30811abad1cd..7ae46a7a71a3 100644 --- a/cvat-ui/src/components/annotation-page/standard-workspace/objects-side-bar/object-item.tsx +++ b/cvat-ui/src/components/annotation-page/standard-workspace/objects-side-bar/object-item.tsx @@ -41,6 +41,7 @@ interface Props { changeLabel(label: any): void; changeColor(color: string): void; resetCuboidPerspective(): void; + runAnnotationAction(): void; edit(): void; slice(): void; } @@ -73,6 +74,7 @@ function ObjectItemComponent(props: Props): JSX.Element { changeLabel, changeColor, resetCuboidPerspective, + runAnnotationAction, edit, slice, jobInstance, @@ -121,6 +123,7 @@ function ObjectItemComponent(props: Props): JSX.Element { removeShortcut={normalizedKeyMap.DELETE_OBJECT_STANDARD_WORKSPACE} changeColorShortcut={normalizedKeyMap.CHANGE_OBJECT_COLOR} sliceShortcut={normalizedKeyMap.SWITCH_SLICE_MODE} + runAnnotationsActionShortcut={normalizedKeyMap.RUN_ANNOTATIONS_ACTION} changeLabel={changeLabel} changeColor={changeColor} copy={copy} @@ -133,6 +136,7 @@ function ObjectItemComponent(props: Props): JSX.Element { resetCuboidPerspective={resetCuboidPerspective} edit={edit} slice={slice} + runAnnotationAction={runAnnotationAction} /> {!!attributes.length && ( diff --git a/cvat-ui/src/components/annotation-page/top-bar/annotation-menu.tsx b/cvat-ui/src/components/annotation-page/top-bar/annotation-menu.tsx index f845b30233df..522f5f978b74 100644 --- a/cvat-ui/src/components/annotation-page/top-bar/annotation-menu.tsx +++ b/cvat-ui/src/components/annotation-page/top-bar/annotation-menu.tsx @@ -6,7 +6,6 @@ import React, { useCallback, useState } from 'react'; import { useSelector, useDispatch } from 'react-redux'; import { useHistory } from 'react-router'; -import { createRoot } from 'react-dom/client'; import Modal from 'antd/lib/modal'; import Text from 'antd/lib/typography/Text'; import InputNumber from 'antd/lib/input-number'; @@ -22,7 +21,7 @@ import { MainMenuIcon } from 'icons'; import { Job, JobState } from 'cvat-core-wrapper'; import CVATTooltip from 'components/common/cvat-tooltip'; -import AnnotationsActionsModalContent from 'components/annotation-page/annotations-actions/annotations-actions-modal'; +import { openAnnotationsActionModal } from 'components/annotation-page/annotations-actions/annotations-actions-modal'; import { CombinedState } from 'reducers'; import { updateCurrentJobAsync, finishCurrentJobAsync, @@ -179,17 +178,7 @@ function AnnotationMenuComponent(): JSX.Element { key: Actions.RUN_ACTIONS, label: 'Run actions', onClick: () => { - const div = window.document.createElement('div'); - window.document.body.append(div); - const root = createRoot(div); - root.render( - { - root.unmount(); - div.remove(); - }} - />, - ); + openAnnotationsActionModal(); }, }); diff --git a/cvat-ui/src/components/model-runner-modal/detector-runner.tsx b/cvat-ui/src/components/model-runner-modal/detector-runner.tsx index d6c92826b662..ab3393b2c290 100644 --- a/cvat-ui/src/components/model-runner-modal/detector-runner.tsx +++ b/cvat-ui/src/components/model-runner-modal/detector-runner.tsx @@ -40,7 +40,7 @@ type ServerMapping = Record { } }; + private runAnnotationAction = (): void => { + const { objectState } = this.props; + openAnnotationsActionModal(objectState); + }; + private commit(): void { const { objectState, readonly, updateState } = this.props; if (!readonly) { @@ -426,6 +432,7 @@ class ObjectItemContainer extends React.PureComponent { edit={this.edit} slice={this.slice} resetCuboidPerspective={this.resetCuboidPerspective} + runAnnotationAction={this.runAnnotationAction} /> ); } diff --git a/cvat-ui/src/containers/annotation-page/standard-workspace/objects-side-bar/objects-list.tsx b/cvat-ui/src/containers/annotation-page/standard-workspace/objects-side-bar/objects-list.tsx index 16ccdc08bff7..5df7b556ff34 100644 --- a/cvat-ui/src/containers/annotation-page/standard-workspace/objects-side-bar/objects-list.tsx +++ b/cvat-ui/src/containers/annotation-page/standard-workspace/objects-side-bar/objects-list.tsx @@ -34,6 +34,7 @@ import { filterAnnotations } from 'utils/filter-annotations'; import { registerComponentShortcuts } from 'actions/shortcuts-actions'; import { ShortcutScope } from 'utils/enums'; import { subKeyMap } from 'utils/component-subkeymap'; +import { openAnnotationsActionModal } from 'components/annotation-page/annotations-actions/annotations-actions-modal'; interface OwnProps { readonly: boolean; @@ -148,6 +149,12 @@ const componentShortcuts = { sequences: ['ctrl+c'], scope: ShortcutScope.OBJECTS_SIDEBAR, }, + RUN_ANNOTATIONS_ACTION: { + name: 'Run annotations action', + description: 'Opens a dialog with annotations actions', + sequences: ['ctrl+e'], + scope: ShortcutScope.OBJECTS_SIDEBAR, + }, PROPAGATE_OBJECT: { name: 'Propagate object', description: 'Make a copy of the object on the following frames', @@ -588,6 +595,16 @@ class ObjectsListContainer extends React.PureComponent { copyShape(state); } }, + RUN_ANNOTATIONS_ACTION: () => { + const state = activatedState(true); + if (!readonly) { + if (state) { + openAnnotationsActionModal(state); + } else { + openAnnotationsActionModal(); + } + } + }, PROPAGATE_OBJECT: (event: KeyboardEvent | undefined) => { preventDefault(event); const state = activatedState(); diff --git a/cvat-ui/src/cvat-core-wrapper.ts b/cvat-ui/src/cvat-core-wrapper.ts index 52f71d6044bc..ba7b47fcfa54 100644 --- a/cvat-ui/src/cvat-core-wrapper.ts +++ b/cvat-ui/src/cvat-core-wrapper.ts @@ -26,8 +26,8 @@ import QualitySettings, { TargetMetric } from 'cvat-core/src/quality-settings'; import { FramesMetaData, FrameData } from 'cvat-core/src/frames'; import { ServerError, RequestError } from 'cvat-core/src/exceptions'; import { - ShapeType, LabelType, ModelKind, ModelProviders, - ModelReturnType, DimensionType, JobType, + ShapeType, ObjectType, LabelType, ModelKind, ModelProviders, + ModelReturnType, DimensionType, JobType, Source, JobStage, JobState, RQStatus, StorageLocation, } from 'cvat-core/src/enums'; import { Storage, StorageData } from 'cvat-core/src/storage'; @@ -41,7 +41,9 @@ import AnalyticsReport, { AnalyticsEntryViewType, AnalyticsEntry } from 'cvat-co import { Dumper } from 'cvat-core/src/annotation-formats'; import { Event } from 'cvat-core/src/event'; import { APIWrapperEnterOptions } from 'cvat-core/src/plugins'; -import BaseSingleFrameAction, { ActionParameterType, FrameSelectionType } from 'cvat-core/src/annotations-actions'; +import { BaseShapesAction } from 'cvat-core/src/annotations-actions/base-shapes-action'; +import { BaseCollectionAction } from 'cvat-core/src/annotations-actions/base-collection-action'; +import { ActionParameterType, BaseAction } from 'cvat-core/src/annotations-actions/base-action'; import { Request, RequestOperation } from 'cvat-core/src/request'; const cvat: CVATCore = _cvat; @@ -69,6 +71,8 @@ export { AnnotationGuide, Attribute, ShapeType, + Source, + ObjectType, LabelType, Storage, Webhook, @@ -89,7 +93,9 @@ export { JobStage, JobState, RQStatus, - BaseSingleFrameAction, + BaseAction, + BaseShapesAction, + BaseCollectionAction, QualityReport, QualityConflict, QualitySettings, @@ -105,7 +111,6 @@ export { Event, FrameData, ActionParameterType, - FrameSelectionType, Request, JobValidationLayout, TaskValidationLayout, diff --git a/cvat-ui/src/reducers/annotation-reducer.ts b/cvat-ui/src/reducers/annotation-reducer.ts index c21aff497548..311f54c0fe96 100644 --- a/cvat-ui/src/reducers/annotation-reducer.ts +++ b/cvat-ui/src/reducers/annotation-reducer.ts @@ -10,7 +10,9 @@ import { AuthActionTypes } from 'actions/auth-actions'; import { BoundariesActionTypes } from 'actions/boundaries-actions'; import { Canvas, CanvasMode } from 'cvat-canvas-wrapper'; import { Canvas3d } from 'cvat-canvas3d-wrapper'; -import { DimensionType, JobStage, LabelType } from 'cvat-core-wrapper'; +import { + DimensionType, JobStage, Label, LabelType, +} from 'cvat-core-wrapper'; import { clamp } from 'utils/math'; import { @@ -29,6 +31,16 @@ function updateActivatedStateID(newStates: any[], prevActivatedStateID: number | null; } +export function labelShapeType(label?: Label): ShapeType | null { + if (label && Object.values(ShapeType).includes(label.type as any)) { + return label.type as unknown as ShapeType; + } + if (label?.type === LabelType.TAG) { + return null; + } + return ShapeType.RECTANGLE; +} + const defaultState: AnnotationState = { activities: { loads: {}, @@ -183,12 +195,11 @@ export default (state = defaultState, action: AnyAction): AnnotationState => { const isReview = job.stage === JobStage.VALIDATION; let workspaceSelected = null; let activeObjectType; - let activeShapeType; + let activeShapeType = null; if (defaultLabel?.type === LabelType.TAG) { activeObjectType = ObjectType.TAG; } else { - activeShapeType = defaultLabel && defaultLabel.type !== 'any' ? - defaultLabel.type : ShapeType.RECTANGLE; + activeShapeType = labelShapeType(defaultLabel); activeObjectType = job.mode === 'interpolation' ? ObjectType.TRACK : ObjectType.SHAPE; } @@ -235,6 +246,10 @@ export default (state = defaultState, action: AnyAction): AnnotationState => { annotations: { ...state.annotations, filters, + zLayer: { + ...state.annotations.zLayer, + cur: Number.MAX_SAFE_INTEGER, + }, }, player: { ...state.player, diff --git a/cvat-ui/src/reducers/index.ts b/cvat-ui/src/reducers/index.ts index 14196846a393..337ef29927b2 100644 --- a/cvat-ui/src/reducers/index.ts +++ b/cvat-ui/src/reducers/index.ts @@ -769,7 +769,7 @@ export interface AnnotationState { drawing: { activeInteractor?: MLModel | OpenCVTool; activeInteractorParameters?: MLModel['params']['canvas']; - activeShapeType: ShapeType; + activeShapeType: ShapeType | null; activeRectDrawingMethod?: RectDrawingMethod; activeCuboidDrawingMethod?: CuboidDrawingMethod; activeNumOfPoints?: number; diff --git a/cvat-ui/src/reducers/settings-reducer.ts b/cvat-ui/src/reducers/settings-reducer.ts index 0c662908d767..2a9e5ca79db7 100644 --- a/cvat-ui/src/reducers/settings-reducer.ts +++ b/cvat-ui/src/reducers/settings-reducer.ts @@ -444,8 +444,11 @@ export default (state = defaultState, action: AnyAction): SettingsState => { return { ...state, - imageFilters: filters, + shapes: { + ...state.shapes, + showGroundTruth: false, + }, }; } case AnnotationActionTypes.INTERACT_WITH_CANVAS: { diff --git a/cvat/__init__.py b/cvat/__init__.py index 10ef426963f5..7474e260b0e1 100644 --- a/cvat/__init__.py +++ b/cvat/__init__.py @@ -4,6 +4,6 @@ from cvat.utils.version import get_version -VERSION = (2, 22, 0, "final", 0) +VERSION = (2, 23, 0, "final", 0) __version__ = get_version(VERSION) diff --git a/cvat/apps/dataset_manager/annotation.py b/cvat/apps/dataset_manager/annotation.py index 3971d5536919..0f0dd2a329b5 100644 --- a/cvat/apps/dataset_manager/annotation.py +++ b/cvat/apps/dataset_manager/annotation.py @@ -6,7 +6,8 @@ from copy import copy, deepcopy import math -from typing import Container, Optional, Sequence +from collections.abc import Container, Sequence +from typing import Optional import numpy as np from itertools import chain from scipy.optimize import linear_sum_assignment diff --git a/cvat/apps/dataset_manager/bindings.py b/cvat/apps/dataset_manager/bindings.py index 1c70520a7090..9b01dced2a94 100644 --- a/cvat/apps/dataset_manager/bindings.py +++ b/cvat/apps/dataset_manager/bindings.py @@ -8,21 +8,19 @@ import os.path as osp import re import sys +from collections import OrderedDict, defaultdict +from collections.abc import Iterable, Iterator, Mapping, Sequence from functools import reduce from operator import add from pathlib import Path from types import SimpleNamespace -from typing import (Any, Callable, DefaultDict, Dict, Iterable, Iterator, List, Literal, Mapping, - NamedTuple, Optional, OrderedDict, Sequence, Set, Tuple, Union) +from typing import Any, Callable, Literal, NamedTuple, Optional, Union from attrs.converters import to_bool import datumaro as dm import defusedxml.ElementTree as ET import rq from attr import attrib, attrs -from datumaro.components.media import PointCloud -from datumaro.components.environment import Environment -from datumaro.components.extractor import Importer from datumaro.components.format_detection import RejectionReason from django.db.models import QuerySet from django.utils import timezone @@ -280,12 +278,12 @@ def __init__(self, self._create_callback = create_callback self._MAX_ANNO_SIZE = 30000 self._frame_info = {} - self._frame_mapping: Dict[str, int] = {} + self._frame_mapping: dict[str, int] = {} self._frame_step = db_task.data.get_frame_step() self._db_data: models.Data = db_task.data self._use_server_track_ids = use_server_track_ids self._required_frames = included_frames - self._initialized_included_frames: Optional[Set[int]] = None + self._initialized_included_frames: Optional[set[int]] = None self._db_subset = db_task.subset super().__init__(db_task) @@ -963,9 +961,9 @@ class LabeledShape: type: str = attrib() frame: int = attrib() label: str = attrib() - points: List[float] = attrib() + points: list[float] = attrib() occluded: bool = attrib() - attributes: List[InstanceLabelData.Attribute] = attrib() + attributes: list[InstanceLabelData.Attribute] = attrib() source: str = attrib(default='manual') group: int = attrib(default=0) rotation: int = attrib(default=0) @@ -973,40 +971,40 @@ class LabeledShape: task_id: int = attrib(default=None) subset: str = attrib(default=None) outside: bool = attrib(default=False) - elements: List['ProjectData.LabeledShape'] = attrib(default=[]) + elements: list['ProjectData.LabeledShape'] = attrib(default=[]) @attrs class TrackedShape: type: str = attrib() frame: int = attrib() - points: List[float] = attrib() + points: list[float] = attrib() occluded: bool = attrib() outside: bool = attrib() keyframe: bool = attrib() - attributes: List[InstanceLabelData.Attribute] = attrib() + attributes: list[InstanceLabelData.Attribute] = attrib() rotation: int = attrib(default=0) source: str = attrib(default='manual') group: int = attrib(default=0) z_order: int = attrib(default=0) label: str = attrib(default=None) track_id: int = attrib(default=0) - elements: List['ProjectData.TrackedShape'] = attrib(default=[]) + elements: list['ProjectData.TrackedShape'] = attrib(default=[]) @attrs class Track: label: str = attrib() - shapes: List['ProjectData.TrackedShape'] = attrib() + shapes: list['ProjectData.TrackedShape'] = attrib() source: str = attrib(default='manual') group: int = attrib(default=0) task_id: int = attrib(default=None) subset: str = attrib(default=None) - elements: List['ProjectData.Track'] = attrib(default=[]) + elements: list['ProjectData.Track'] = attrib(default=[]) @attrs class Tag: frame: int = attrib() label: str = attrib() - attributes: List[InstanceLabelData.Attribute] = attrib() + attributes: list[InstanceLabelData.Attribute] = attrib() source: str = attrib(default='manual') group: int = attrib(default=0) task_id: int = attrib(default=None) @@ -1020,8 +1018,8 @@ class Frame: name: str = attrib() width: int = attrib() height: int = attrib() - labeled_shapes: List[Union['ProjectData.LabeledShape', 'ProjectData.TrackedShape']] = attrib() - tags: List['ProjectData.Tag'] = attrib() + labeled_shapes: list[Union['ProjectData.LabeledShape', 'ProjectData.TrackedShape']] = attrib() + tags: list['ProjectData.Tag'] = attrib() task_id: int = attrib(default=None) subset: str = attrib(default=None) @@ -1040,12 +1038,12 @@ def __init__(self, self._host = host self._soft_attribute_import = False self._project_annotation = project_annotation - self._tasks_data: Dict[int, TaskData] = {} - self._frame_info: Dict[Tuple[int, int], Literal["path", "width", "height", "subset"]] = dict() + self._tasks_data: dict[int, TaskData] = {} + self._frame_info: dict[tuple[int, int], Literal["path", "width", "height", "subset"]] = dict() # (subset, path): (task id, frame number) - self._frame_mapping: Dict[Tuple[str, str], Tuple[int, int]] = dict() - self._frame_steps: Dict[int, int] = {} - self.new_tasks: Set[int] = set() + self._frame_mapping: dict[tuple[str, str], tuple[int, int]] = dict() + self._frame_steps: dict[int, int] = {} + self.new_tasks: set[int] = set() self._use_server_track_ids = use_server_track_ids InstanceLabelData.__init__(self, db_project) @@ -1083,12 +1081,12 @@ def _init_tasks(self): subsets = set() for task in self._db_tasks.values(): subsets.add(task.subset) - self._subsets: List[str] = list(subsets) + self._subsets: list[str] = list(subsets) - self._frame_steps: Dict[int, int] = {task.id: task.data.get_frame_step() for task in self._db_tasks.values()} + self._frame_steps: dict[int, int] = {task.id: task.data.get_frame_step() for task in self._db_tasks.values()} def _init_task_frame_offsets(self): - self._task_frame_offsets: Dict[int, int] = dict() + self._task_frame_offsets: dict[int, int] = dict() s = 0 subset = None @@ -1103,7 +1101,7 @@ def _init_task_frame_offsets(self): def _init_frame_info(self): self._frame_info = dict() self._deleted_frames = { (task.id, frame): True for task in self._db_tasks.values() for frame in task.data.deleted_frames } - original_names = DefaultDict[Tuple[str, str], int](int) + original_names = defaultdict[tuple[str, str], int](int) for task in self._db_tasks.values(): defaulted_subset = get_defaulted_subset(task.subset, self._subsets) if hasattr(task.data, 'video'): @@ -1257,7 +1255,7 @@ def _export_track(self, track: dict, task_id: int, task_size: int, idx: int): ) def group_by_frame(self, include_empty: bool = False): - frames: Dict[Tuple[str, int], ProjectData.Frame] = {} + frames: dict[tuple[str, int], ProjectData.Frame] = {} def get_frame(task_id: int, idx: int) -> ProjectData.Frame: frame_info = self._frame_info[(task_id, idx)] abs_frame = self.abs_frame_id(task_id, idx) @@ -1368,7 +1366,7 @@ def db_project(self): return self._db_project @property - def subsets(self) -> List[str]: + def subsets(self) -> list[str]: return self._subsets @property @@ -1450,7 +1448,7 @@ def split_dataset(self, dataset: dm.Dataset): subset_dataset: dm.Dataset = dataset.subsets()[task_data.db_instance.subset].as_dataset() yield subset_dataset, task_data - def add_labels(self, labels: List[dict]): + def add_labels(self, labels: list[dict]): attributes = [] _labels = [] for label in labels: @@ -1463,19 +1461,22 @@ def add_task(self, task, files): self._project_annotation.add_task(task, files, self) @attrs(frozen=True, auto_attribs=True) -class ImageSource: +class MediaSource: db_task: Task - is_video: bool = attrib(kw_only=True) -class ImageProvider: - def __init__(self, sources: Dict[int, ImageSource]) -> None: + @property + def is_video(self) -> bool: + return self.db_task.mode == 'interpolation' + +class MediaProvider: + def __init__(self, sources: dict[int, MediaSource]) -> None: self._sources = sources def unload(self) -> None: pass -class ImageProvider2D(ImageProvider): - def __init__(self, sources: Dict[int, ImageSource]) -> None: +class MediaProvider2D(MediaProvider): + def __init__(self, sources: dict[int, MediaSource]) -> None: super().__init__(sources) self._current_source_id = None self._frame_provider = None @@ -1483,7 +1484,7 @@ def __init__(self, sources: Dict[int, ImageSource]) -> None: def unload(self) -> None: self._unload_source() - def get_image_for_frame(self, source_id: int, frame_index: int, **image_kwargs): + def get_media_for_frame(self, source_id: int, frame_index: int, **image_kwargs) -> dm.Image: source = self._sources[source_id] if source.is_video: @@ -1510,7 +1511,7 @@ def image_loader(_): return dm.ByteImage(data=image_loader, **image_kwargs) - def _load_source(self, source_id: int, source: ImageSource) -> None: + def _load_source(self, source_id: int, source: MediaSource) -> None: if self._current_source_id == source_id: return @@ -1525,8 +1526,8 @@ def _unload_source(self) -> None: self._current_source_id = None -class ImageProvider3D(ImageProvider): - def __init__(self, sources: Dict[int, ImageSource]) -> None: +class MediaProvider3D(MediaProvider): + def __init__(self, sources: dict[int, MediaSource]) -> None: super().__init__(sources) self._images_per_source = { source_id: { @@ -1536,7 +1537,7 @@ def __init__(self, sources: Dict[int, ImageSource]) -> None: for source_id, source in sources.items() } - def get_image_for_frame(self, source_id: int, frame_id: int, **image_kwargs): + def get_media_for_frame(self, source_id: int, frame_id: int, **image_kwargs) -> dm.PointCloud: source = self._sources[source_id] point_cloud_path = osp.join( @@ -1546,17 +1547,17 @@ def get_image_for_frame(self, source_id: int, frame_id: int, **image_kwargs): image = self._images_per_source[source_id][frame_id] related_images = [ - path + dm.Image(path=path) for rf in image.related_files.all() for path in [osp.realpath(str(rf.path))] if osp.isfile(path) ] - return point_cloud_path, related_images + return dm.PointCloud(point_cloud_path, extra_images=related_images) -IMAGE_PROVIDERS_BY_DIMENSION = { - DimensionType.DIM_3D: ImageProvider3D, - DimensionType.DIM_2D: ImageProvider2D, +MEDIA_PROVIDERS_BY_DIMENSION: dict[DimensionType, MediaProvider] = { + DimensionType.DIM_3D: MediaProvider3D, + DimensionType.DIM_2D: MediaProvider2D, } class CVATDataExtractorMixin: @@ -1565,21 +1566,21 @@ def __init__(self, *, ): self.convert_annotations = convert_annotations or convert_cvat_anno_to_dm - self._image_provider: Optional[ImageProvider] = None + self._media_provider: Optional[MediaProvider] = None def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback) -> None: - if self._image_provider: - self._image_provider.unload() + if self._media_provider: + self._media_provider.unload() def categories(self) -> dict: raise NotImplementedError() @staticmethod def _load_categories(labels: list): - categories: Dict[dm.AnnotationType, + categories: dict[dm.AnnotationType, dm.Categories] = {} label_categories = dm.LabelCategories(attributes=['occluded']) @@ -1639,7 +1640,7 @@ def __init__( instance_meta = instance_data.meta[instance_data.META_FIELD] dm.SourceExtractor.__init__( self, - media_type=dm.Image if dimension == DimensionType.DIM_2D else PointCloud, + media_type=dm.Image if dimension == DimensionType.DIM_2D else dm.PointCloud, subset=instance_meta['subset'], ) CVATDataExtractorMixin.__init__(self, **kwargs) @@ -1648,7 +1649,6 @@ def __init__( self._user = self._load_user_info(instance_meta) if dimension == DimensionType.DIM_3D else {} self._dimension = dimension self._format_type = format_type - dm_items = [] is_video = instance_meta['mode'] == 'interpolation' ext = '' @@ -1663,46 +1663,61 @@ def __init__( else: assert False - self._image_provider = IMAGE_PROVIDERS_BY_DIMENSION[dimension]( - {0: ImageSource(db_task, is_video=is_video)} + self._media_provider = MEDIA_PROVIDERS_BY_DIMENSION[dimension]( + {0: MediaSource(db_task)} ) + dm_items: list[dm.DatasetItem] = [] for frame_data in instance_data.group_by_frame(include_empty=True): - image_args = { - 'path': frame_data.name + ext, - 'size': (frame_data.height, frame_data.width), - } - + dm_media_args = { 'path': frame_data.name + ext } if dimension == DimensionType.DIM_3D: - dm_image = self._image_provider.get_image_for_frame(0, frame_data.id, **image_args) - elif include_images: - dm_image = self._image_provider.get_image_for_frame(0, frame_data.idx, **image_args) + dm_media: dm.PointCloud = self._media_provider.get_media_for_frame( + 0, frame_data.id, **dm_media_args + ) + + if not include_images: + dm_media_args["extra_images"] = [ + dm.Image(path=osp.basename(image.path)) + for image in dm_media.extra_images + ] + dm_media = dm.PointCloud(**dm_media_args) else: - dm_image = dm.Image(**image_args) + dm_media_args['size'] = (frame_data.height, frame_data.width) + if include_images: + dm_media: dm.Image = self._media_provider.get_media_for_frame( + 0, frame_data.idx, **dm_media_args + ) + else: + dm_media = dm.Image(**dm_media_args) + dm_anno = self._read_cvat_anno(frame_data, instance_meta['labels']) + dm_attributes = {'frame': frame_data.frame} + if dimension == DimensionType.DIM_2D: dm_item = dm.DatasetItem( - id=osp.splitext(frame_data.name)[0], - annotations=dm_anno, media=dm_image, - subset=frame_data.subset, - attributes={'frame': frame_data.frame - }) + id=osp.splitext(frame_data.name)[0], + subset=frame_data.subset, + annotations=dm_anno, + media=dm_media, + attributes=dm_attributes, + ) elif dimension == DimensionType.DIM_3D: - attributes = {'frame': frame_data.frame} if format_type == "sly_pointcloud": - attributes["name"] = self._user["name"] - attributes["createdAt"] = self._user["createdAt"] - attributes["updatedAt"] = self._user["updatedAt"] - attributes["labels"] = [] + dm_attributes["name"] = self._user["name"] + dm_attributes["createdAt"] = self._user["createdAt"] + dm_attributes["updatedAt"] = self._user["updatedAt"] + dm_attributes["labels"] = [] for (idx, (_, label)) in enumerate(instance_meta['labels']): - attributes["labels"].append({"label_id": idx, "name": label["name"], "color": label["color"], "type": label["type"]}) - attributes["track_id"] = -1 + dm_attributes["labels"].append({"label_id": idx, "name": label["name"], "color": label["color"], "type": label["type"]}) + dm_attributes["track_id"] = -1 dm_item = dm.DatasetItem( id=osp.splitext(osp.split(frame_data.name)[-1])[0], - annotations=dm_anno, media=PointCloud(dm_image[0]), related_images=dm_image[1], - attributes=attributes, subset=frame_data.subset, + subset=frame_data.subset, + annotations=dm_anno, + media=dm_media, + attributes=dm_attributes, ) dm_items.append(dm_item) @@ -1732,7 +1747,7 @@ def __init__( **kwargs ): dm.Extractor.__init__( - self, media_type=dm.Image if dimension == DimensionType.DIM_2D else PointCloud + self, media_type=dm.Image if dimension == DimensionType.DIM_2D else dm.PointCloud ) CVATDataExtractorMixin.__init__(self, **kwargs) @@ -1741,59 +1756,71 @@ def __init__( self._dimension = dimension self._format_type = format_type - dm_items: List[dm.DatasetItem] = [] - if self._dimension == DimensionType.DIM_3D or include_images: - self._image_provider = IMAGE_PROVIDERS_BY_DIMENSION[self._dimension]( + self._media_provider = MEDIA_PROVIDERS_BY_DIMENSION[self._dimension]( { - task.id: ImageSource(task, is_video=task.mode == 'interpolation') + task.id: MediaSource(task) for task in project_data.tasks } ) - ext_per_task: Dict[int, str] = { + ext_per_task: dict[int, str] = { task.id: TaskFrameProvider.VIDEO_FRAME_EXT if is_video else '' for task in project_data.tasks for is_video in [task.mode == 'interpolation'] } + dm_items: list[dm.DatasetItem] = [] for frame_data in project_data.group_by_frame(include_empty=True): - image_args = { - 'path': frame_data.name + ext_per_task[frame_data.task_id], - 'size': (frame_data.height, frame_data.width), - } + dm_media_args = { 'path': frame_data.name + ext_per_task[frame_data.task_id] } if self._dimension == DimensionType.DIM_3D: - dm_image = self._image_provider.get_image_for_frame( - frame_data.task_id, frame_data.id, **image_args) - elif include_images: - dm_image = self._image_provider.get_image_for_frame( - frame_data.task_id, frame_data.idx, **image_args) + dm_media: dm.PointCloud = self._media_provider.get_media_for_frame( + frame_data.task_id, frame_data.id, **dm_media_args + ) + + if not include_images: + dm_media_args["extra_images"] = [ + dm.Image(path=osp.basename(image.path)) + for image in dm_media.extra_images + ] + dm_media = dm.PointCloud(**dm_media_args) else: - dm_image = dm.Image(**image_args) + dm_media_args['size'] = (frame_data.height, frame_data.width) + if include_images: + dm_media: dm.Image = self._media_provider.get_media_for_frame( + frame_data.task_id, frame_data.idx, **dm_media_args + ) + else: + dm_media = dm.Image(**dm_media_args) + dm_anno = self._read_cvat_anno(frame_data, project_data.meta[project_data.META_FIELD]['labels']) + + dm_attributes = {'frame': frame_data.frame} + if self._dimension == DimensionType.DIM_2D: dm_item = dm.DatasetItem( id=osp.splitext(frame_data.name)[0], - annotations=dm_anno, media=dm_image, + annotations=dm_anno, media=dm_media, subset=frame_data.subset, - attributes={'frame': frame_data.frame} + attributes=dm_attributes, ) - else: - attributes = {'frame': frame_data.frame} + elif self._dimension == DimensionType.DIM_3D: if format_type == "sly_pointcloud": - attributes["name"] = self._user["name"] - attributes["createdAt"] = self._user["createdAt"] - attributes["updatedAt"] = self._user["updatedAt"] - attributes["labels"] = [] + dm_attributes["name"] = self._user["name"] + dm_attributes["createdAt"] = self._user["createdAt"] + dm_attributes["updatedAt"] = self._user["updatedAt"] + dm_attributes["labels"] = [] for (idx, (_, label)) in enumerate(project_data.meta[project_data.META_FIELD]['labels']): - attributes["labels"].append({"label_id": idx, "name": label["name"], "color": label["color"], "type": label["type"]}) - attributes["track_id"] = -1 + dm_attributes["labels"].append({"label_id": idx, "name": label["name"], "color": label["color"], "type": label["type"]}) + dm_attributes["track_id"] = -1 dm_item = dm.DatasetItem( id=osp.splitext(osp.split(frame_data.name)[-1])[0], - annotations=dm_anno, media=PointCloud(dm_image[0]), related_images=dm_image[1], - attributes=attributes, subset=frame_data.subset + annotations=dm_anno, media=dm_media, + subset=frame_data.subset, + attributes=dm_attributes, ) + dm_items.append(dm_item) self._items = dm_items @@ -1855,7 +1882,7 @@ def _clean_display_message(self) -> str: message = "Dataset must contain a file:" + message return re.sub(r' +', " ", message) -def mangle_image_name(name: str, subset: str, names: DefaultDict[Tuple[str, str], int]) -> str: +def mangle_image_name(name: str, subset: str, names: defaultdict[tuple[str, str], int]) -> str: name, ext = name.rsplit(osp.extsep, maxsplit=1) if not names[(subset, name)]: @@ -1876,7 +1903,7 @@ def mangle_image_name(name: str, subset: str, names: DefaultDict[Tuple[str, str] i += 1 raise Exception('Cannot mangle image name') -def get_defaulted_subset(subset: str, subsets: List[str]) -> str: +def get_defaulted_subset(subset: str, subsets: list[str]) -> str: if subset: return subset else: @@ -2038,7 +2065,7 @@ def _convert_shape(self, return results - def _convert_shapes(self, shapes: List[CommonData.LabeledShape]) -> Iterable[dm.Annotation]: + def _convert_shapes(self, shapes: list[CommonData.LabeledShape]) -> Iterable[dm.Annotation]: dm_anno = [] self.num_of_tracks = reduce( @@ -2052,7 +2079,7 @@ def _convert_shapes(self, shapes: List[CommonData.LabeledShape]) -> Iterable[dm. return dm_anno - def convert(self) -> List[dm.Annotation]: + def convert(self) -> list[dm.Annotation]: dm_anno = [] dm_anno.extend(self._convert_tags(self.cvat_frame_anno.tags)) dm_anno.extend(self._convert_shapes(self.cvat_frame_anno.labeled_shapes)) @@ -2065,7 +2092,7 @@ def convert_cvat_anno_to_dm( map_label, format_name=None, dimension=DimensionType.DIM_2D -) -> List[dm.Annotation]: +) -> list[dm.Annotation]: converter = CvatToDmAnnotationConverter( cvat_frame_anno=cvat_frame_anno, label_attrs=label_attrs, @@ -2442,18 +2469,27 @@ def load_dataset_data(project_annotation, dataset: dm.Dataset, project_data): project_annotation.add_task(task_fields, dataset_files, project_data) -def detect_dataset(dataset_dir: str, format_name: str, importer: Importer) -> None: +class NoMediaInAnnotationFileError(CvatImportError): + def __str__(self) -> str: + return ( + "Can't import media data from the annotation file. " + "Please upload full dataset as a zip archive." + ) + +def detect_dataset(dataset_dir: str, format_name: str, importer: dm.Importer) -> None: not_found_error_instance = CvatDatasetNotFoundError() - def not_found_error(_, reason, human_message): + def _handle_rejection(format_name: str, reason: RejectionReason, human_message: str) -> None: not_found_error_instance.format_name = format_name not_found_error_instance.reason = reason not_found_error_instance.message = human_message - detection_env = Environment() + detection_env = dm.Environment() detection_env.importers.items.clear() detection_env.importers.register(format_name, importer) - detected = detection_env.detect_dataset(dataset_dir, depth=4, rejection_callback=not_found_error) + detected = detection_env.detect_dataset( + dataset_dir, depth=4, rejection_callback=_handle_rejection + ) if not detected and not_found_error_instance.reason != RejectionReason.detection_unsupported: raise not_found_error_instance diff --git a/cvat/apps/dataset_manager/formats/coco.py b/cvat/apps/dataset_manager/formats/coco.py index 6d63aeb0360f..1d1a8ce4d0d5 100644 --- a/cvat/apps/dataset_manager/formats/coco.py +++ b/cvat/apps/dataset_manager/formats/coco.py @@ -9,8 +9,9 @@ from datumaro.components.annotation import AnnotationType from datumaro.plugins.coco_format.importer import CocoImporter -from cvat.apps.dataset_manager.bindings import GetCVATDataExtractor, detect_dataset, \ - import_dm_annotations +from cvat.apps.dataset_manager.bindings import ( + GetCVATDataExtractor, NoMediaInAnnotationFileError, import_dm_annotations, detect_dataset +) from cvat.apps.dataset_manager.util import make_zip_archive from .registry import dm_env, exporter, importer @@ -35,6 +36,9 @@ def _import(src_file, temp_dir, instance_data, load_data_callback=None, **kwargs load_data_callback(dataset, instance_data) import_dm_annotations(dataset, instance_data) else: + if load_data_callback: + raise NoMediaInAnnotationFileError() + dataset = Dataset.import_from(src_file.name, 'coco_instances', env=dm_env) import_dm_annotations(dataset, instance_data) @@ -52,6 +56,8 @@ def _export(dst_file, temp_dir, instance_data, save_images=False): def _import(src_file, temp_dir, instance_data, load_data_callback=None, **kwargs): def remove_extra_annotations(dataset): for item in dataset: + # Boxes would have invalid (skeleton) labels, so remove them + # TODO: find a way to import boxes annotations = [ann for ann in item.annotations if ann.type != AnnotationType.bbox] item.annotations = annotations @@ -66,7 +72,9 @@ def remove_extra_annotations(dataset): load_data_callback(dataset, instance_data) import_dm_annotations(dataset, instance_data) else: - dataset = Dataset.import_from(src_file.name, - 'coco_person_keypoints', env=dm_env) + if load_data_callback: + raise NoMediaInAnnotationFileError() + + dataset = Dataset.import_from(src_file.name, 'coco_person_keypoints', env=dm_env) remove_extra_annotations(dataset) import_dm_annotations(dataset, instance_data) diff --git a/cvat/apps/dataset_manager/formats/cvat.py b/cvat/apps/dataset_manager/formats/cvat.py index 03ef389599e8..fa46b58813bf 100644 --- a/cvat/apps/dataset_manager/formats/cvat.py +++ b/cvat/apps/dataset_manager/formats/cvat.py @@ -22,10 +22,16 @@ from datumaro.util.image import Image from defusedxml import ElementTree -from cvat.apps.dataset_manager.bindings import (ProjectData, TaskData, JobData, detect_dataset, - get_defaulted_subset, - import_dm_annotations, - match_dm_item) +from cvat.apps.dataset_manager.bindings import ( + NoMediaInAnnotationFileError, + ProjectData, + TaskData, + JobData, + detect_dataset, + get_defaulted_subset, + import_dm_annotations, + match_dm_item +) from cvat.apps.dataset_manager.util import make_zip_archive from cvat.apps.engine.frame_provider import FrameQuality, FrameOutputType, make_frame_provider @@ -1456,4 +1462,7 @@ def _import(src_file, temp_dir, instance_data, load_data_callback=None, **kwargs for p in anno_paths: load_anno(p, instance_data) else: + if load_data_callback: + raise NoMediaInAnnotationFileError() + load_anno(src_file, instance_data) diff --git a/cvat/apps/dataset_manager/formats/datumaro.py b/cvat/apps/dataset_manager/formats/datumaro.py index 090397b7a471..4fc1d246dd47 100644 --- a/cvat/apps/dataset_manager/formats/datumaro.py +++ b/cvat/apps/dataset_manager/formats/datumaro.py @@ -3,43 +3,40 @@ # # SPDX-License-Identifier: MIT +import zipfile from datumaro.components.dataset import Dataset -from datumaro.components.extractor import ItemTransform -from datumaro.util.image import Image -from pyunpack import Archive - -from cvat.apps.dataset_manager.bindings import (GetCVATDataExtractor, detect_dataset, - import_dm_annotations) +from cvat.apps.dataset_manager.bindings import ( + GetCVATDataExtractor, import_dm_annotations, NoMediaInAnnotationFileError, detect_dataset +) from cvat.apps.dataset_manager.util import make_zip_archive from cvat.apps.engine.models import DimensionType from .registry import dm_env, exporter, importer -class DeleteImagePath(ItemTransform): - def transform_item(self, item): - image = None - if item.has_image and item.image.has_data: - image = Image(data=item.image.data, size=item.image.size) - return item.wrap(image=image, point_cloud='', related_images=[]) - @exporter(name="Datumaro", ext="ZIP", version="1.0") def _export(dst_file, temp_dir, instance_data, save_images=False): - with GetCVATDataExtractor(instance_data=instance_data, include_images=save_images) as extractor: + with GetCVATDataExtractor( + instance_data=instance_data, include_images=save_images + ) as extractor: dataset = Dataset.from_extractors(extractor, env=dm_env) - if not save_images: - dataset.transform(DeleteImagePath) dataset.export(temp_dir, 'datumaro', save_images=save_images) make_zip_archive(temp_dir, dst_file) -@importer(name="Datumaro", ext="ZIP", version="1.0") +@importer(name="Datumaro", ext="JSON, ZIP", version="1.0") def _import(src_file, temp_dir, instance_data, load_data_callback=None, **kwargs): - Archive(src_file.name).extractall(temp_dir) + if zipfile.is_zipfile(src_file): + zipfile.ZipFile(src_file).extractall(temp_dir) - detect_dataset(temp_dir, format_name='datumaro', importer=dm_env.importers.get('datumaro')) - dataset = Dataset.import_from(temp_dir, 'datumaro', env=dm_env) + detect_dataset(temp_dir, format_name='datumaro', importer=dm_env.importers.get('datumaro')) + dataset = Dataset.import_from(temp_dir, 'datumaro', env=dm_env) + else: + if load_data_callback: + raise NoMediaInAnnotationFileError() + + dataset = Dataset.import_from(src_file.name, 'datumaro', env=dm_env) if load_data_callback is not None: load_data_callback(dataset, instance_data) @@ -52,19 +49,22 @@ def _export(dst_file, temp_dir, instance_data, save_images=False): dimension=DimensionType.DIM_3D, ) as extractor: dataset = Dataset.from_extractors(extractor, env=dm_env) - - if not save_images: - dataset.transform(DeleteImagePath) dataset.export(temp_dir, 'datumaro', save_images=save_images) make_zip_archive(temp_dir, dst_file) -@importer(name="Datumaro 3D", ext="ZIP", version="1.0", dimension=DimensionType.DIM_3D) +@importer(name="Datumaro 3D", ext="JSON, ZIP", version="1.0", dimension=DimensionType.DIM_3D) def _import(src_file, temp_dir, instance_data, load_data_callback=None, **kwargs): - Archive(src_file.name).extractall(temp_dir) + if zipfile.is_zipfile(src_file): + zipfile.ZipFile(src_file).extractall(temp_dir) + + detect_dataset(temp_dir, format_name='datumaro', importer=dm_env.importers.get('datumaro')) + dataset = Dataset.import_from(temp_dir, 'datumaro', env=dm_env) + else: + if load_data_callback: + raise NoMediaInAnnotationFileError() - detect_dataset(temp_dir, format_name='datumaro', importer=dm_env.importers.get('datumaro')) - dataset = Dataset.import_from(temp_dir, 'datumaro', env=dm_env) + dataset = Dataset.import_from(src_file.name, 'datumaro', env=dm_env) if load_data_callback is not None: load_data_callback(dataset, instance_data) diff --git a/cvat/apps/dataset_manager/project.py b/cvat/apps/dataset_manager/project.py index 759483b10a06..93ac651cf477 100644 --- a/cvat/apps/dataset_manager/project.py +++ b/cvat/apps/dataset_manager/project.py @@ -4,9 +4,10 @@ # SPDX-License-Identifier: MIT import os +from collections.abc import Mapping from tempfile import TemporaryDirectory import rq -from typing import Any, Callable, List, Mapping, Tuple +from typing import Any, Callable from datumaro.components.errors import DatasetError, DatasetImportError, DatasetNotFoundError from django.db import transaction @@ -109,7 +110,7 @@ def split_name(file): project_data.new_tasks.add(db_task.id) project_data.init() - def add_labels(self, labels: List[models.Label], attributes: List[Tuple[str, models.AttributeSpec]] = None): + def add_labels(self, labels: list[models.Label], attributes: list[tuple[str, models.AttributeSpec]] = None): for label in labels: label.project = self.db_project # We need label_id here, so we can't use bulk_create here diff --git a/cvat/apps/dataset_manager/tests/utils.py b/cvat/apps/dataset_manager/tests/utils.py index 9a134b887bf7..6e3b51a878d9 100644 --- a/cvat/apps/dataset_manager/tests/utils.py +++ b/cvat/apps/dataset_manager/tests/utils.py @@ -6,7 +6,7 @@ import tempfile import unittest from types import TracebackType -from typing import Optional, Type +from typing import Optional from datumaro.util.os_util import rmfile, rmtree @@ -23,7 +23,7 @@ def __enter__(self) -> str: def __exit__( self, - exc_type: Optional[Type[BaseException]], + exc_type: Optional[type[BaseException]], exc_value: Optional[BaseException], traceback: Optional[TracebackType], ) -> None: diff --git a/cvat/apps/dataset_manager/util.py b/cvat/apps/dataset_manager/util.py index 0193748446f3..2f1029049bbf 100644 --- a/cvat/apps/dataset_manager/util.py +++ b/cvat/apps/dataset_manager/util.py @@ -8,11 +8,12 @@ import os.path as osp import re import zipfile +from collections.abc import Generator, Sequence from contextlib import contextmanager from copy import deepcopy from datetime import timedelta from threading import Lock -from typing import Any, Generator, Optional, Sequence +from typing import Any, Optional import attrs import django_rq diff --git a/cvat/apps/engine/cache.py b/cvat/apps/engine/cache.py index 295e405a41da..197c10f14d71 100644 --- a/cvat/apps/engine/cache.py +++ b/cvat/apps/engine/cache.py @@ -10,6 +10,7 @@ import os.path import pickle # nosec import tempfile +import time import zipfile import zlib from contextlib import ExitStack, closing @@ -29,11 +30,18 @@ overload, ) +import attrs import av import cv2 +import django_rq import PIL.Image import PIL.ImageOps +import rq +from django.conf import settings from django.core.cache import caches +from django.db import models as django_models +from django.utils import timezone as django_tz +from redis.exceptions import LockError from rest_framework.exceptions import NotFound, ValidationError from cvat.apps.engine import models @@ -54,74 +62,254 @@ ZipChunkWriter, ZipCompressedChunkWriter, ) -from cvat.apps.engine.utils import load_image, md5_hash +from cvat.apps.engine.rq_job_handler import RQJobMetaField +from cvat.apps.engine.utils import ( + CvatChunkTimestampMismatchError, + get_rq_lock_for_job, + load_image, + md5_hash, +) from utils.dataset_manifest import ImageManifestManager slogger = ServerLogManager(__name__) DataWithMime = Tuple[io.BytesIO, str] -_CacheItem = Tuple[io.BytesIO, str, int] +_CacheItem = Tuple[io.BytesIO, str, int, Union[datetime, None]] + + +def enqueue_create_chunk_job( + queue: rq.Queue, + rq_job_id: str, + create_callback: Callback, + *, + blocking_timeout: int = 50, + rq_job_result_ttl: int = 60, + rq_job_failure_ttl: int = 3600 * 24 * 14, # 2 weeks +) -> rq.job.Job: + try: + with get_rq_lock_for_job(queue, rq_job_id, blocking_timeout=blocking_timeout): + rq_job = queue.fetch_job(rq_job_id) + + if not rq_job: + rq_job = queue.enqueue( + create_callback, + job_id=rq_job_id, + result_ttl=rq_job_result_ttl, + failure_ttl=rq_job_failure_ttl, + ) + except LockError: + raise TimeoutError(f"Cannot acquire lock for {rq_job_id}") + + return rq_job + + +def wait_for_rq_job(rq_job: rq.job.Job): + retries = settings.CVAT_CHUNK_CREATE_TIMEOUT // settings.CVAT_CHUNK_CREATE_CHECK_INTERVAL or 1 + while retries > 0: + job_status = rq_job.get_status() + if job_status in ("finished",): + return + elif job_status in ("failed",): + job_meta = rq_job.get_meta() + exc_type = job_meta.get(RQJobMetaField.EXCEPTION_TYPE, Exception) + exc_args = job_meta.get(RQJobMetaField.EXCEPTION_ARGS, ("Cannot create chunk",)) + raise exc_type(*exc_args) + + time.sleep(settings.CVAT_CHUNK_CREATE_CHECK_INTERVAL) + retries -= 1 + + raise TimeoutError(f"Chunk processing takes too long {rq_job.id}") + + +def _is_run_inside_rq() -> bool: + return rq.get_current_job() is not None + + +def _convert_args_for_callback(func_args: list[Any]) -> list[Any]: + result = [] + for func_arg in func_args: + if _is_run_inside_rq(): + result.append(func_arg) + else: + if isinstance( + func_arg, + django_models.Model, + ): + result.append(func_arg.id) + elif isinstance(func_arg, list): + result.append(_convert_args_for_callback(func_arg)) + else: + result.append(func_arg) + + return result + + +@attrs.frozen +class Callback: + _callable: Callable[..., DataWithMime] = attrs.field( + validator=attrs.validators.is_callable(), + ) + _args: list[Any] = attrs.field( + factory=list, + validator=attrs.validators.instance_of(list), + converter=_convert_args_for_callback, + ) + _kwargs: dict[str, Union[bool, int, float, str, None]] = attrs.field( + factory=dict, + validator=attrs.validators.deep_mapping( + key_validator=attrs.validators.instance_of(str), + value_validator=attrs.validators.instance_of((bool, int, float, str, type(None))), + mapping_validator=attrs.validators.instance_of(dict), + ), + ) + + def __call__(self) -> DataWithMime: + return self._callable(*self._args, **self._kwargs) class MediaCache: - def __init__(self) -> None: - self._cache = caches["media"] + _QUEUE_NAME = settings.CVAT_QUEUES.CHUNKS.value + _QUEUE_JOB_PREFIX_TASK = "chunks:prepare-item-" + _CACHE_NAME = "media" + _PREVIEW_TTL = settings.CVAT_PREVIEW_CACHE_TTL - def _get_checksum(self, value: bytes) -> int: + @staticmethod + def _cache(): + return caches[MediaCache._CACHE_NAME] + + @staticmethod + def _get_checksum(value: bytes) -> int: return zlib.crc32(value) def _get_or_set_cache_item( - self, key: str, create_callback: Callable[[], DataWithMime] + self, + key: str, + create_callback: Callback, + *, + cache_item_ttl: Optional[int] = None, ) -> _CacheItem: - def create_item() -> _CacheItem: - slogger.glob.info(f"Starting to prepare chunk: key {key}") - item_data = create_callback() - slogger.glob.info(f"Ending to prepare chunk: key {key}") + item = self._get_cache_item(key) + if item: + return item - item_data_bytes = item_data[0].getvalue() - item = (item_data[0], item_data[1], self._get_checksum(item_data_bytes)) - if item_data_bytes: - self._cache.set(key, item) + return self._create_cache_item( + key, + create_callback, + cache_item_ttl=cache_item_ttl, + ) - return item + def _get_queue(self) -> rq.Queue: + return django_rq.get_queue(self._QUEUE_NAME) - item = self._get_cache_item(key) - if not item: - item = create_item() + def _make_queue_job_id(self, key: str) -> str: + return f"{self._QUEUE_JOB_PREFIX_TASK}{key}" + + @staticmethod + def _drop_return_value(func: Callable[..., DataWithMime], *args: Any, **kwargs: Any): + func(*args, **kwargs) + + @classmethod + def _create_and_set_cache_item( + cls, + key: str, + create_callback: Callback, + cache_item_ttl: Optional[int] = None, + ) -> DataWithMime: + timestamp = django_tz.now() + item_data = create_callback() + item_data_bytes = item_data[0].getvalue() + item = (item_data[0], item_data[1], cls._get_checksum(item_data_bytes), timestamp) + if item_data_bytes: + cache = cls._cache() + cache.set(key, item, timeout=cache_item_ttl or cache.default_timeout) + + return item + + def _create_cache_item( + self, + key: str, + create_callback: Callback, + *, + cache_item_ttl: Optional[int] = None, + ) -> _CacheItem: + + queue = self._get_queue() + rq_id = self._make_queue_job_id(key) + + slogger.glob.info(f"Starting to prepare chunk: key {key}") + if _is_run_inside_rq(): + with get_rq_lock_for_job(queue, rq_id, timeout=None, blocking_timeout=None): + item = self._create_and_set_cache_item( + key, + create_callback, + cache_item_ttl=cache_item_ttl, + ) else: - # compare checksum - item_data = item[0].getbuffer() if isinstance(item[0], io.BytesIO) else item[0] - item_checksum = item[2] if len(item) == 3 else None - if item_checksum != self._get_checksum(item_data): - slogger.glob.info(f"Recreating cache item {key} due to checksum mismatch") - item = create_item() + rq_job = enqueue_create_chunk_job( + queue=queue, + rq_job_id=rq_id, + create_callback=Callback( + callable=self._drop_return_value, + args=[ + self._create_and_set_cache_item, + key, + create_callback, + ], + kwargs={ + "cache_item_ttl": cache_item_ttl, + }, + ), + ) + wait_for_rq_job(rq_job) + item = self._get_cache_item(key) + + slogger.glob.info(f"Ending to prepare chunk: key {key}") return item def _delete_cache_item(self, key: str): try: - self._cache.delete(key) + self._cache().delete(key) slogger.glob.info(f"Removed chunk from the cache: key {key}") except pickle.UnpicklingError: slogger.glob.error(f"Failed to remove item from the cache: key {key}", exc_info=True) def _get_cache_item(self, key: str) -> Optional[_CacheItem]: - slogger.glob.info(f"Starting to get chunk from cache: key {key}") try: - item = self._cache.get(key) + item = self._cache().get(key) except pickle.UnpicklingError: slogger.glob.error(f"Unable to get item from cache: key {key}", exc_info=True) item = None - slogger.glob.info(f"Ending to get chunk from cache: key {key}, is_cached {bool(item)}") + + if not item: + return None + + item_data = item[0].getbuffer() if isinstance(item[0], io.BytesIO) else item[0] + item_checksum = item[2] if len(item) == 4 else None + if item_checksum != self._get_checksum(item_data): + slogger.glob.info(f"Cache item {key} checksum mismatch") + return None return item - def _has_key(self, key: str) -> bool: - return self._cache.has_key(key) + def _validate_cache_item_timestamp( + self, item: _CacheItem, expected_timestamp: datetime + ) -> _CacheItem: + if item[3] < expected_timestamp: + raise CvatChunkTimestampMismatchError( + f"Cache timestamp mismatch. Item_ts: {item[3]}, expected_ts: {expected_timestamp}" + ) + + return item + @classmethod + def _has_key(cls, key: str) -> bool: + return cls._cache().has_key(key) + + @staticmethod def _make_cache_key_prefix( - self, obj: Union[models.Task, models.Segment, models.Job, models.CloudStorage] + obj: Union[models.Task, models.Segment, models.Job, models.CloudStorage] ) -> str: if isinstance(obj, models.Task): return f"task_{obj.id}" @@ -134,14 +322,15 @@ def _make_cache_key_prefix( else: assert False, f"Unexpected object type {type(obj)}" + @classmethod def _make_chunk_key( - self, + cls, db_obj: Union[models.Task, models.Segment, models.Job], chunk_number: int, *, quality: FrameQuality, ) -> str: - return f"{self._make_cache_key_prefix(db_obj)}_chunk_{chunk_number}_{quality}" + return f"{cls._make_cache_key_prefix(db_obj)}_chunk_{chunk_number}_{quality}" def _make_preview_key(self, db_obj: Union[models.Segment, models.CloudStorage]) -> str: return f"{self._make_cache_key_prefix(db_obj)}_preview" @@ -173,35 +362,47 @@ def _to_data_with_mime(self, cache_item: Optional[_CacheItem]) -> Optional[DataW def get_or_set_segment_chunk( self, db_segment: models.Segment, chunk_number: int, *, quality: FrameQuality ) -> DataWithMime: + + item = self._get_or_set_cache_item( + self._make_chunk_key(db_segment, chunk_number, quality=quality), + Callback( + callable=self.prepare_segment_chunk, + args=[db_segment, chunk_number], + kwargs={"quality": quality}, + ), + ) + db_segment.refresh_from_db(fields=["chunks_updated_date"]) + return self._to_data_with_mime( - self._get_or_set_cache_item( - key=self._make_chunk_key(db_segment, chunk_number, quality=quality), - create_callback=lambda: self.prepare_segment_chunk( - db_segment, chunk_number, quality=quality - ), - ) + self._validate_cache_item_timestamp(item, db_segment.chunks_updated_date) ) def get_task_chunk( self, db_task: models.Task, chunk_number: int, *, quality: FrameQuality ) -> Optional[DataWithMime]: return self._to_data_with_mime( - self._get_cache_item(key=self._make_chunk_key(db_task, chunk_number, quality=quality)) + self._get_cache_item( + key=self._make_chunk_key(db_task, chunk_number, quality=quality), + ) ) def get_or_set_task_chunk( self, db_task: models.Task, chunk_number: int, + set_callback: Callback, *, quality: FrameQuality, - set_callback: Callable[[], DataWithMime], ) -> DataWithMime: + + item = self._get_or_set_cache_item( + self._make_chunk_key(db_task, chunk_number, quality=quality), + set_callback, + ) + db_task.refresh_from_db(fields=["segment_set"]) + return self._to_data_with_mime( - self._get_or_set_cache_item( - key=self._make_chunk_key(db_task, chunk_number, quality=quality), - create_callback=set_callback, - ) + self._validate_cache_item_timestamp(item, db_task.get_chunks_updated_date()) ) def get_segment_task_chunk( @@ -209,7 +410,7 @@ def get_segment_task_chunk( ) -> Optional[DataWithMime]: return self._to_data_with_mime( self._get_cache_item( - key=self._make_segment_task_chunk_key(db_segment, chunk_number, quality=quality) + key=self._make_segment_task_chunk_key(db_segment, chunk_number, quality=quality), ) ) @@ -219,13 +420,17 @@ def get_or_set_segment_task_chunk( chunk_number: int, *, quality: FrameQuality, - set_callback: Callable[[], DataWithMime], + set_callback: Callback, ) -> DataWithMime: + + item = self._get_or_set_cache_item( + self._make_segment_task_chunk_key(db_segment, chunk_number, quality=quality), + set_callback, + ) + db_segment.refresh_from_db(fields=["chunks_updated_date"]) + return self._to_data_with_mime( - self._get_or_set_cache_item( - key=self._make_segment_task_chunk_key(db_segment, chunk_number, quality=quality), - create_callback=set_callback, - ) + self._validate_cache_item_timestamp(item, db_segment.chunks_updated_date), ) def get_or_set_selective_job_chunk( @@ -233,9 +438,13 @@ def get_or_set_selective_job_chunk( ) -> DataWithMime: return self._to_data_with_mime( self._get_or_set_cache_item( - key=self._make_chunk_key(db_job, chunk_number, quality=quality), - create_callback=lambda: self.prepare_masked_range_segment_chunk( - db_job.segment, chunk_number, quality=quality + self._make_chunk_key(db_job, chunk_number, quality=quality), + Callback( + callable=self.prepare_masked_range_segment_chunk, + args=[db_job.segment, chunk_number], + kwargs={ + "quality": quality, + }, ), ) ) @@ -244,7 +453,11 @@ def get_or_set_segment_preview(self, db_segment: models.Segment) -> DataWithMime return self._to_data_with_mime( self._get_or_set_cache_item( self._make_preview_key(db_segment), - create_callback=lambda: self._prepare_segment_preview(db_segment), + Callback( + callable=self._prepare_segment_preview, + args=[db_segment], + ), + cache_item_ttl=self._PREVIEW_TTL, ) ) @@ -262,7 +475,11 @@ def get_or_set_cloud_preview(self, db_storage: models.CloudStorage) -> DataWithM return self._to_data_with_mime( self._get_or_set_cache_item( self._make_preview_key(db_storage), - create_callback=lambda: self._prepare_cloud_preview(db_storage), + Callback( + callable=self._prepare_cloud_preview, + args=[db_storage], + ), + cache_item_ttl=self._PREVIEW_TTL, ) ) @@ -271,13 +488,16 @@ def get_or_set_frame_context_images_chunk( ) -> DataWithMime: return self._to_data_with_mime( self._get_or_set_cache_item( - key=self._make_context_image_preview_key(db_data, frame_number), - create_callback=lambda: self.prepare_context_images_chunk(db_data, frame_number), + self._make_context_image_preview_key(db_data, frame_number), + Callback( + callable=self.prepare_context_images_chunk, + args=[db_data, frame_number], + ), ) ) + @staticmethod def _read_raw_images( - self, db_task: models.Task, frame_ids: Sequence[int], *, @@ -361,9 +581,13 @@ def _read_raw_images( yield from media + @staticmethod def _read_raw_frames( - self, db_task: models.Task, frame_ids: Sequence[int] + db_task: Union[models.Task, int], frame_ids: Sequence[int] ) -> Generator[Tuple[Union[av.VideoFrame, PIL.Image.Image], str, str], None, None]: + if isinstance(db_task, int): + db_task = models.Task.objects.get(pk=db_task) + for prev_frame, cur_frame in pairwise(frame_ids): assert ( prev_frame <= cur_frame @@ -400,11 +624,14 @@ def _read_raw_frames( for frame_tuple in reader.iterate_frames(frame_filter=frame_ids): yield frame_tuple else: - yield from self._read_raw_images(db_task, frame_ids, manifest_path=manifest_path) + yield from MediaCache._read_raw_images(db_task, frame_ids, manifest_path=manifest_path) def prepare_segment_chunk( - self, db_segment: models.Segment, chunk_number: int, *, quality: FrameQuality + self, db_segment: Union[models.Segment, int], chunk_number: int, *, quality: FrameQuality ) -> DataWithMime: + if isinstance(db_segment, int): + db_segment = models.Segment.objects.get(pk=db_segment) + if db_segment.type == models.SegmentType.RANGE: return self.prepare_range_segment_chunk(db_segment, chunk_number, quality=quality) elif db_segment.type == models.SegmentType.SPECIFIC_FRAMES: @@ -427,10 +654,11 @@ def prepare_range_segment_chunk( return self.prepare_custom_range_segment_chunk(db_task, chunk_frame_ids, quality=quality) + @classmethod def prepare_custom_range_segment_chunk( - self, db_task: models.Task, frame_ids: Sequence[int], *, quality: FrameQuality + cls, db_task: models.Task, frame_ids: Sequence[int], *, quality: FrameQuality ) -> DataWithMime: - with closing(self._read_raw_frames(db_task, frame_ids=frame_ids)) as frame_iter: + with closing(cls._read_raw_frames(db_task, frame_ids=frame_ids)) as frame_iter: return prepare_chunk(frame_iter, quality=quality, db_task=db_task) def prepare_masked_range_segment_chunk( @@ -448,15 +676,19 @@ def prepare_masked_range_segment_chunk( db_task, chunk_frame_ids, chunk_number, quality=quality ) + @classmethod def prepare_custom_masked_range_segment_chunk( - self, - db_task: models.Task, + cls, + db_task: Union[models.Task, int], frame_ids: Collection[int], chunk_number: int, *, quality: FrameQuality, insert_placeholders: bool = False, ) -> DataWithMime: + if isinstance(db_task, int): + db_task = models.Task.objects.get(pk=db_task) + db_data = db_task.data frame_step = db_data.get_frame_step() @@ -493,8 +725,8 @@ def prepare_custom_masked_range_segment_chunk( if not list(chunk_frames): continue - chunk_available = self._has_key( - self._make_chunk_key(db_segment, i, quality=quality) + chunk_available = cls._has_key( + cls._make_chunk_key(db_segment, i, quality=quality) ) available_chunks.append(chunk_available) @@ -521,7 +753,7 @@ def get_frames(): frame_range = frame_ids if not use_cached_data: - frames_gen = self._read_raw_frames(db_task, frame_ids) + frames_gen = cls._read_raw_frames(db_task, frame_ids) frames_iter = iter(es.enter_context(closing(frames_gen))) for abs_frame_idx in frame_range: @@ -569,7 +801,10 @@ def get_frames(): buff.seek(0) return buff, get_chunk_mime_type_for_writer(writer) - def _prepare_segment_preview(self, db_segment: models.Segment) -> DataWithMime: + def _prepare_segment_preview(self, db_segment: Union[models.Segment, int]) -> DataWithMime: + if isinstance(db_segment, int): + db_segment = models.Segment.objects.get(pk=db_segment) + if db_segment.task.dimension == models.DimensionType.DIM_3D: # TODO preview = PIL.Image.open( @@ -591,7 +826,10 @@ def _prepare_segment_preview(self, db_segment: models.Segment) -> DataWithMime: return prepare_preview_image(preview) - def _prepare_cloud_preview(self, db_storage: models.CloudStorage) -> DataWithMime: + def _prepare_cloud_preview(self, db_storage: Union[models.CloudStorage, int]) -> DataWithMime: + if isinstance(db_storage, int): + db_storage = models.CloudStorage.objects.get(pk=db_storage) + storage = db_storage_to_storage_instance(db_storage) if not db_storage.manifests.count(): raise ValidationError("Cannot get the cloud storage preview. There is no manifest file") @@ -631,7 +869,12 @@ def _prepare_cloud_preview(self, db_storage: models.CloudStorage) -> DataWithMim image = PIL.Image.open(buff) return prepare_preview_image(image) - def prepare_context_images_chunk(self, db_data: models.Data, frame_number: int) -> DataWithMime: + def prepare_context_images_chunk( + self, db_data: Union[models.Data, int], frame_number: int + ) -> DataWithMime: + if isinstance(db_data, int): + db_data = models.Data.objects.get(pk=db_data) + zip_buffer = io.BytesIO() related_images = db_data.related_files.filter(images__frame=frame_number).all() diff --git a/cvat/apps/engine/default_settings.py b/cvat/apps/engine/default_settings.py index 826fe1c9bef2..15e1b3fd8c32 100644 --- a/cvat/apps/engine/default_settings.py +++ b/cvat/apps/engine/default_settings.py @@ -14,3 +14,13 @@ When enabled, this option can increase data access speed and reduce server load, but significantly increase disk space occupied by tasks. """ + +CVAT_CHUNK_CREATE_TIMEOUT = 50 +""" +Sets the chunk preparation timeout in seconds after which the backend will respond with 429 code. +""" + +CVAT_CHUNK_CREATE_CHECK_INTERVAL = 0.2 +""" +Sets the frequency of checking the readiness of the chunk +""" diff --git a/cvat/apps/engine/frame_provider.py b/cvat/apps/engine/frame_provider.py index 1787d84aac40..2da1741b5bc7 100644 --- a/cvat/apps/engine/frame_provider.py +++ b/cvat/apps/engine/frame_provider.py @@ -10,6 +10,7 @@ import math from abc import ABCMeta, abstractmethod from bisect import bisect +from collections import OrderedDict from dataclasses import dataclass from enum import Enum, auto from io import BytesIO @@ -36,7 +37,7 @@ from rest_framework.exceptions import ValidationError from cvat.apps.engine import models -from cvat.apps.engine.cache import DataWithMime, MediaCache, prepare_chunk +from cvat.apps.engine.cache import Callback, DataWithMime, MediaCache, prepare_chunk from cvat.apps.engine.media_extractors import ( FrameQuality, IMediaReader, @@ -310,38 +311,60 @@ def get_chunk( # The requested frames match one of the job chunks, we can use it directly return segment_frame_provider.get_chunk(matching_chunk_index, quality=quality) - def _set_callback() -> DataWithMime: - # Create and return a joined / cleaned chunk - task_chunk_frames = {} - for db_segment in matching_segments: - segment_frame_provider = SegmentFrameProvider(db_segment) - segment_frame_set = db_segment.frame_set - - for task_chunk_frame_id in sorted(task_chunk_frame_set): - if ( - task_chunk_frame_id not in segment_frame_set - or task_chunk_frame_id in task_chunk_frames - ): - continue - - frame, frame_name, _ = segment_frame_provider._get_raw_frame( - self.get_rel_frame_number(task_chunk_frame_id), quality=quality - ) - task_chunk_frames[task_chunk_frame_id] = (frame, frame_name, None) - - return prepare_chunk( - task_chunk_frames.values(), - quality=quality, - db_task=self._db_task, - dump_unchanged=True, - ) - buffer, mime_type = cache.get_or_set_task_chunk( - self._db_task, chunk_number, quality=quality, set_callback=_set_callback + self._db_task, + chunk_number, + quality=quality, + set_callback=Callback( + callable=self._get_chunk_create_callback, + args=[ + self._db_task, + matching_segments, + {f: self.get_rel_frame_number(f) for f in task_chunk_frame_set}, + quality, + ], + ), ) return return_type(data=buffer, mime=mime_type) + @staticmethod + def _get_chunk_create_callback( + db_task: Union[models.Task, int], + matching_segments: list[models.Segment], + task_chunk_frames_with_rel_numbers: dict[int, int], + quality: FrameQuality, + ) -> DataWithMime: + # Create and return a joined / cleaned chunk + task_chunk_frames = OrderedDict() + for db_segment in matching_segments: + if isinstance(db_segment, int): + db_segment = models.Segment.objects.get(pk=db_segment) + segment_frame_provider = SegmentFrameProvider(db_segment) + segment_frame_set = db_segment.frame_set + + for task_chunk_frame_id in sorted(task_chunk_frames_with_rel_numbers.keys()): + if ( + task_chunk_frame_id not in segment_frame_set + or task_chunk_frame_id in task_chunk_frames + ): + continue + + frame, frame_name, _ = segment_frame_provider._get_raw_frame( + task_chunk_frames_with_rel_numbers[task_chunk_frame_id], quality=quality + ) + task_chunk_frames[task_chunk_frame_id] = (frame, frame_name, None) + + if isinstance(db_task, int): + db_task = models.Task.objects.get(pk=db_task) + + return prepare_chunk( + task_chunk_frames.values(), + quality=quality, + db_task=db_task, + dump_unchanged=True, + ) + def get_frame( self, frame_number: int, @@ -664,35 +687,55 @@ def get_chunk( if matching_chunk is not None: return self.get_chunk(matching_chunk, quality=quality) - def _set_callback() -> DataWithMime: - # Create and return a joined / cleaned chunk - segment_chunk_frame_ids = sorted( - task_chunk_frame_set.intersection(self._db_segment.frame_set) - ) - - if self._db_segment.type == models.SegmentType.RANGE: - return cache.prepare_custom_range_segment_chunk( - db_task=self._db_segment.task, - frame_ids=segment_chunk_frame_ids, - quality=quality, - ) - elif self._db_segment.type == models.SegmentType.SPECIFIC_FRAMES: - return cache.prepare_custom_masked_range_segment_chunk( - db_task=self._db_segment.task, - frame_ids=segment_chunk_frame_ids, - chunk_number=chunk_number, - quality=quality, - insert_placeholders=True, - ) - else: - assert False + segment_chunk_frame_ids = sorted( + task_chunk_frame_set.intersection(self._db_segment.frame_set) + ) buffer, mime_type = cache.get_or_set_segment_task_chunk( - self._db_segment, chunk_number, quality=quality, set_callback=_set_callback + self._db_segment, + chunk_number, + quality=quality, + set_callback=Callback( + callable=self._get_chunk_create_callback, + args=[ + self._db_segment, + segment_chunk_frame_ids, + chunk_number, + quality, + ], + ), ) return return_type(data=buffer, mime=mime_type) + @staticmethod + def _get_chunk_create_callback( + db_segment: Union[models.Segment, int], + segment_chunk_frame_ids: list[int], + chunk_number: int, + quality: FrameQuality, + ) -> DataWithMime: + # Create and return a joined / cleaned chunk + if isinstance(db_segment, int): + db_segment = models.Segment.objects.get(pk=db_segment) + + if db_segment.type == models.SegmentType.RANGE: + return MediaCache.prepare_custom_range_segment_chunk( + db_task=db_segment.task, + frame_ids=segment_chunk_frame_ids, + quality=quality, + ) + elif db_segment.type == models.SegmentType.SPECIFIC_FRAMES: + return MediaCache.prepare_custom_masked_range_segment_chunk( + db_task=db_segment.task, + frame_ids=segment_chunk_frame_ids, + chunk_number=chunk_number, + quality=quality, + insert_placeholders=True, + ) + else: + assert False + @overload def make_frame_provider(data_source: models.Job) -> JobFrameProvider: ... diff --git a/cvat/apps/engine/rq_job_handler.py b/cvat/apps/engine/rq_job_handler.py index 25900fba20a9..bef7d94eaa69 100644 --- a/cvat/apps/engine/rq_job_handler.py +++ b/cvat/apps/engine/rq_job_handler.py @@ -28,7 +28,8 @@ class RQJobMetaField: # export specific fields RESULT_URL = 'result_url' FUNCTION_ID = 'function_id' - + EXCEPTION_TYPE = 'exc_type' + EXCEPTION_ARGS = 'exc_args' def is_rq_job_owner(rq_job: RQJob, user_id: int) -> bool: return rq_job.meta.get(RQJobMetaField.USER, {}).get('id') == user_id diff --git a/cvat/apps/engine/serializers.py b/cvat/apps/engine/serializers.py index 5b3845f8260e..f8678248d2b8 100644 --- a/cvat/apps/engine/serializers.py +++ b/cvat/apps/engine/serializers.py @@ -11,6 +11,7 @@ import re import shutil import string +import django_rq import rq.defaults as rq_defaults from tempfile import NamedTemporaryFile @@ -22,14 +23,15 @@ from decimal import Decimal from rest_framework import serializers, exceptions +from django.conf import settings from django.contrib.auth.models import User, Group from django.db import transaction from django.utils import timezone from numpy import random from cvat.apps.dataset_manager.formats.utils import get_label_color -from cvat.apps.engine.frame_provider import TaskFrameProvider -from cvat.apps.engine.utils import format_list, parse_exception_message +from cvat.apps.engine.frame_provider import TaskFrameProvider, FrameQuality +from cvat.apps.engine.utils import format_list, parse_exception_message, CvatChunkTimestampMismatchError from cvat.apps.engine import field_validation, models from cvat.apps.engine.cloud_provider import get_cloud_storage_instance, Credentials, Status from cvat.apps.engine.log import ServerLogManager @@ -980,8 +982,8 @@ def validate(self, attrs): @transaction.atomic def update(self, instance: models.Job, validated_data: dict[str, Any]) -> models.Job: - from cvat.apps.engine.cache import MediaCache - from cvat.apps.engine.frame_provider import FrameQuality, JobFrameProvider, prepare_chunk + from cvat.apps.engine.cache import MediaCache, Callback, enqueue_create_chunk_job, wait_for_rq_job + from cvat.apps.engine.frame_provider import JobFrameProvider from cvat.apps.dataset_manager.task import JobAnnotation, AnnotationManager db_job = instance @@ -1129,7 +1131,6 @@ def _to_rel_frame(abs_frame: int) -> int: job_annotation.delete(job_annotation_manager.data) # Update chunks - task_frame_provider = TaskFrameProvider(db_task) job_frame_provider = JobFrameProvider(db_job) updated_segment_chunk_ids = set( job_frame_provider.get_chunk_number(updated_segment_frame_id) @@ -1138,7 +1139,7 @@ def _to_rel_frame(abs_frame: int) -> int: segment_frames = sorted(segment_frame_set) segment_frame_map = dict(zip(segment_honeypots, requested_frames)) - media_cache = MediaCache() + queue = django_rq.get_queue(settings.CVAT_QUEUES.CHUNKS.value) for chunk_id in sorted(updated_segment_chunk_ids): chunk_frames = segment_frames[ chunk_id * db_data.chunk_size : @@ -1146,36 +1147,26 @@ def _to_rel_frame(abs_frame: int) -> int: ] for quality in FrameQuality.__members__.values(): - def _write_updated_static_chunk(): - def _iterate_chunk_frames(): - for chunk_frame in chunk_frames: - db_frame = all_task_frames[chunk_frame] - chunk_real_frame = segment_frame_map.get(chunk_frame, chunk_frame) - yield ( - task_frame_provider.get_frame( - chunk_real_frame, quality=quality - ).data, - os.path.basename(db_frame.path), - chunk_frame, - ) - - with closing(_iterate_chunk_frames()) as frame_iter: - chunk, _ = prepare_chunk( - frame_iter, quality=quality, db_task=db_task, dump_unchanged=True, - ) - - get_chunk_path = { - FrameQuality.COMPRESSED: db_data.get_compressed_segment_chunk_path, - FrameQuality.ORIGINAL: db_data.get_original_segment_chunk_path, - }[quality] - - with open(get_chunk_path(chunk_id, db_segment.id), 'wb') as f: - f.write(chunk.getvalue()) - if db_data.storage_method == models.StorageMethodChoice.FILE_SYSTEM: - _write_updated_static_chunk() + rq_id = f"segment_{db_segment.id}_write_chunk_{chunk_id}_{quality}" + rq_job = enqueue_create_chunk_job( + queue=queue, + rq_job_id=rq_id, + create_callback=Callback( + callable=self._write_updated_static_chunk, + args=[ + db_segment.id, + chunk_id, + chunk_frames, + quality, + {chunk_frame: all_task_frames[chunk_frame].path for chunk_frame in chunk_frames}, + segment_frame_map, + ], + ), + ) + wait_for_rq_job(rq_job) - media_cache.remove_segment_chunk(db_segment, chunk_id, quality=quality) + MediaCache().remove_segment_chunk(db_segment, chunk_id, quality=quality) db_segment.chunks_updated_date = timezone.now() db_segment.save(update_fields=['chunks_updated_date']) @@ -1199,6 +1190,54 @@ def _iterate_chunk_frames(): return instance + @staticmethod + def _write_updated_static_chunk( + db_segment_id: int, + chunk_id: int, + chunk_frames: list[int], + quality: FrameQuality, + frame_path_map: dict[int, str], + segment_frame_map: dict[int,int], + ): + from cvat.apps.engine.frame_provider import prepare_chunk + + db_segment = models.Segment.objects.select_related("task").get(pk=db_segment_id) + initial_chunks_updated_date = db_segment.chunks_updated_date + db_task = db_segment.task + task_frame_provider = TaskFrameProvider(db_task) + db_data = db_task.data + + def _iterate_chunk_frames(): + for chunk_frame in chunk_frames: + db_frame_path = frame_path_map[chunk_frame] + chunk_real_frame = segment_frame_map.get(chunk_frame, chunk_frame) + yield ( + task_frame_provider.get_frame( + chunk_real_frame, quality=quality + ).data, + os.path.basename(db_frame_path), + chunk_frame, + ) + + with closing(_iterate_chunk_frames()) as frame_iter: + chunk, _ = prepare_chunk( + frame_iter, quality=quality, db_task=db_task, dump_unchanged=True, + ) + + get_chunk_path = { + FrameQuality.COMPRESSED: db_data.get_compressed_segment_chunk_path, + FrameQuality.ORIGINAL: db_data.get_original_segment_chunk_path, + }[quality] + + db_segment.refresh_from_db(fields=["chunks_updated_date"]) + if db_segment.chunks_updated_date > initial_chunks_updated_date: + raise CvatChunkTimestampMismatchError( + "Attempting to write an out of date static chunk, " + f"segment.chunks_updated_date: {db_segment.chunks_updated_date}, expected_ts: {initial_chunks_updated_date}" + ) + with open(get_chunk_path(chunk_id, db_segment_id), 'wb') as f: + f.write(chunk.getvalue()) + class JobValidationLayoutReadSerializer(serializers.Serializer): honeypot_count = serializers.IntegerField(min_value=0, required=False) honeypot_frames = serializers.ListField( diff --git a/cvat/apps/engine/utils.py b/cvat/apps/engine/utils.py index b45cb1baf020..72cb52eb5168 100644 --- a/cvat/apps/engine/utils.py +++ b/cvat/apps/engine/utils.py @@ -97,6 +97,9 @@ def execute_python_code(source_code, global_vars=None, local_vars=None): line_number = traceback.extract_tb(tb)[-1][1] raise InterpreterError("{} at line {}: {}".format(error_class, line_number, details)) +class CvatChunkTimestampMismatchError(Exception): + pass + def av_scan_paths(*paths): if 'yes' == os.environ.get('CLAM_AV'): command = ['clamscan', '--no-summary', '-i', '-o'] @@ -198,14 +201,22 @@ def define_dependent_job( return Dependency(jobs=[sorted(user_jobs, key=lambda job: job.created_at)[-1]], allow_failure=True) if user_jobs else None -def get_rq_lock_by_user(queue: DjangoRQ, user_id: int) -> Union[Lock, nullcontext]: +def get_rq_lock_by_user(queue: DjangoRQ, user_id: int, *, timeout: Optional[int] = 30, blocking_timeout: Optional[int] = None) -> Union[Lock, nullcontext]: if settings.ONE_RUNNING_JOB_IN_QUEUE_PER_USER: - return queue.connection.lock(f'{queue.name}-lock-{user_id}', timeout=30) + return queue.connection.lock( + name=f'{queue.name}-lock-{user_id}', + timeout=timeout, + blocking_timeout=blocking_timeout, + ) return nullcontext() -def get_rq_lock_for_job(queue: DjangoRQ, rq_id: str) -> Lock: +def get_rq_lock_for_job(queue: DjangoRQ, rq_id: str, *, timeout: Optional[int] = 60, blocking_timeout: Optional[int] = None) -> Lock: # lock timeout corresponds to the nginx request timeout (proxy_read_timeout) - return queue.connection.lock(f'lock-for-job-{rq_id}'.lower(), timeout=60) + return queue.connection.lock( + name=f'lock-for-job-{rq_id}'.lower(), + timeout=timeout, + blocking_timeout=blocking_timeout, + ) def get_rq_job_meta( request: HttpRequest, diff --git a/cvat/apps/engine/views.py b/cvat/apps/engine/views.py index ac046b1d0b26..a73cf9449a60 100644 --- a/cvat/apps/engine/views.py +++ b/cvat/apps/engine/views.py @@ -106,7 +106,7 @@ from .log import ServerLogManager from cvat.apps.iam.filters import ORGANIZATION_OPEN_API_PARAMETERS from cvat.apps.iam.permissions import PolicyEnforcer, IsAuthenticatedOrReadPublicResource -from cvat.apps.engine.cache import MediaCache +from cvat.apps.engine.cache import MediaCache, CvatChunkTimestampMismatchError, LockError from cvat.apps.engine.permissions import (CloudStoragePermission, CommentPermission, IssuePermission, JobPermission, LabelPermission, ProjectPermission, TaskPermission, UserPermission) @@ -118,6 +118,7 @@ _DATA_CHECKSUM_HEADER_NAME = 'X-Checksum' _DATA_UPDATED_DATE_HEADER_NAME = 'X-Updated-Date' +_RETRY_AFTER_TIMEOUT = 10 @extend_schema(tags=['server']) class ServerViewSet(viewsets.ViewSet): @@ -723,6 +724,11 @@ def __call__(self): msg = str(ex) if not isinstance(ex, ValidationError) else \ '\n'.join([str(d) for d in ex.detail]) return Response(data=msg, status=ex.status_code) + except (TimeoutError, CvatChunkTimestampMismatchError, LockError): + return Response( + status=status.HTTP_429_TOO_MANY_REQUESTS, + headers={'Retry-After': _RETRY_AFTER_TIMEOUT}, + ) @abstractmethod def _get_chunk_response_headers(self, chunk_data: DataWithMeta) -> dict[str, str]: ... @@ -806,20 +812,26 @@ def __call__(self): # Reproduce the task chunk indexing frame_provider = self._get_frame_provider() - if self.index is not None: - data = frame_provider.get_chunk( - self.index, quality=self.quality, is_task_chunk=False + try: + if self.index is not None: + data = frame_provider.get_chunk( + self.index, quality=self.quality, is_task_chunk=False + ) + else: + data = frame_provider.get_chunk( + self.number, quality=self.quality, is_task_chunk=True + ) + + return HttpResponse( + data.data.getvalue(), + content_type=data.mime, + headers=self._get_chunk_response_headers(data), ) - else: - data = frame_provider.get_chunk( - self.number, quality=self.quality, is_task_chunk=True + except (TimeoutError, CvatChunkTimestampMismatchError, LockError): + return Response( + status=status.HTTP_429_TOO_MANY_REQUESTS, + headers={'Retry-After': _RETRY_AFTER_TIMEOUT}, ) - - return HttpResponse( - data.data.getvalue(), - content_type=data.mime, - headers=self._get_chunk_response_headers(data), - ) else: return super().__call__() @@ -2968,6 +2980,11 @@ def preview(self, request, pk): '\n'.join([str(d) for d in ex.detail]) slogger.cloud_storage[pk].info(msg) return Response(data=msg, status=ex.status_code) + except (TimeoutError, CvatChunkTimestampMismatchError, LockError): + return Response( + status=status.HTTP_429_TOO_MANY_REQUESTS, + headers={'Retry-After': _RETRY_AFTER_TIMEOUT}, + ) except Exception as ex: slogger.glob.error(str(ex)) return Response("An internal error has occurred", @@ -3254,6 +3271,9 @@ def perform_destroy(self, instance): def rq_exception_handler(rq_job, exc_type, exc_value, tb): rq_job.meta[RQJobMetaField.FORMATTED_EXCEPTION] = "".join( traceback.format_exception_only(exc_type, exc_value)) + if rq_job.origin == settings.CVAT_QUEUES.CHUNKS.value: + rq_job.meta[RQJobMetaField.EXCEPTION_TYPE] = exc_type + rq_job.meta[RQJobMetaField.EXCEPTION_ARGS] = exc_value.args rq_job.save_meta() return True diff --git a/cvat/apps/events/event.py b/cvat/apps/events/event.py index ae519b568644..a4afff968549 100644 --- a/cvat/apps/events/event.py +++ b/cvat/apps/events/event.py @@ -20,6 +20,8 @@ class EventScopes: "task": ["create", "update", "delete"], "job": ["create", "update", "delete"], "organization": ["create", "update", "delete"], + "membership": ["create", "update", "delete"], + "invitation": ["create", "delete"], "user": ["create", "update", "delete"], "cloudstorage": ["create", "update", "delete"], "issue": ["create", "update", "delete"], @@ -28,6 +30,7 @@ class EventScopes: "label": ["create", "update", "delete"], "dataset": ["export", "import"], "function": ["call"], + "webhook": ["create", "update", "delete"], } @classmethod diff --git a/cvat/apps/events/handlers.py b/cvat/apps/events/handlers.py index f2d3f7577617..8f29f91d9a1a 100644 --- a/cvat/apps/events/handlers.py +++ b/cvat/apps/events/handlers.py @@ -4,7 +4,7 @@ import datetime import traceback -from typing import Optional, Union +from typing import Any, Optional, Union import rq from crum import get_current_request, get_current_user @@ -26,6 +26,8 @@ MembershipReadSerializer, OrganizationReadSerializer) from cvat.apps.engine.rq_job_handler import RQJobMetaField +from cvat.apps.webhooks.models import Webhook +from cvat.apps.webhooks.serializers import WebhookReadSerializer from .cache import get_cache from .event import event_scope, record_server_event @@ -66,6 +68,7 @@ def task_id(instance): except Exception: return None + def job_id(instance): if isinstance(instance, Job): return instance.id @@ -78,6 +81,7 @@ def job_id(instance): except Exception: return None + def get_user(instance=None): # Try to get current user from request user = get_current_user() @@ -97,6 +101,7 @@ def get_user(instance=None): return None + def get_request(instance=None): request = get_current_request() if request is not None: @@ -111,6 +116,7 @@ def get_request(instance=None): return None + def _get_value(obj, key): if obj is not None: if isinstance(obj, dict): @@ -119,22 +125,27 @@ def _get_value(obj, key): return None + def request_id(instance=None): request = get_request(instance) return _get_value(request, "uuid") + def user_id(instance=None): current_user = get_user(instance) return _get_value(current_user, "id") + def user_name(instance=None): current_user = get_user(instance) return _get_value(current_user, "username") + def user_email(instance=None): current_user = get_user(instance) return _get_value(current_user, "email") or None + def organization_slug(instance): if isinstance(instance, Organization): return instance.slug @@ -147,6 +158,7 @@ def organization_slug(instance): except Exception: return None + def get_instance_diff(old_data, data): ignore_related_fields = ( "labels", @@ -164,7 +176,8 @@ def get_instance_diff(old_data, data): return diff -def _cleanup_fields(obj): + +def _cleanup_fields(obj: dict[str, Any]) -> dict[str, Any]: fields=( "slug", "id", @@ -183,6 +196,7 @@ def _cleanup_fields(obj): "url", "issues", "attributes", + "key", ) subfields=( "url", @@ -198,6 +212,7 @@ def _cleanup_fields(obj): data[k] = v return data + def _get_object_name(instance): if isinstance(instance, Organization) or \ isinstance(instance, Project) or \ @@ -217,34 +232,32 @@ def _get_object_name(instance): return None + +SERIALIZERS = [ + (Organization, OrganizationReadSerializer), + (Project, ProjectReadSerializer), + (Task, TaskReadSerializer), + (Job, JobReadSerializer), + (User, BasicUserSerializer), + (CloudStorage, CloudStorageReadSerializer), + (Issue, IssueReadSerializer), + (Comment, CommentReadSerializer), + (Label, LabelSerializer), + (Membership, MembershipReadSerializer), + (Invitation, InvitationReadSerializer), + (Webhook, WebhookReadSerializer), +] + + def get_serializer(instance): context = { "request": get_current_request() } serializer = None - if isinstance(instance, Organization): - serializer = OrganizationReadSerializer(instance=instance, context=context) - if isinstance(instance, Project): - serializer = ProjectReadSerializer(instance=instance, context=context) - if isinstance(instance, Task): - serializer = TaskReadSerializer(instance=instance, context=context) - if isinstance(instance, Job): - serializer = JobReadSerializer(instance=instance, context=context) - if isinstance(instance, User): - serializer = BasicUserSerializer(instance=instance, context=context) - if isinstance(instance, CloudStorage): - serializer = CloudStorageReadSerializer(instance=instance, context=context) - if isinstance(instance, Issue): - serializer = IssueReadSerializer(instance=instance, context=context) - if isinstance(instance, Comment): - serializer = CommentReadSerializer(instance=instance, context=context) - if isinstance(instance, Label): - serializer = LabelSerializer(instance=instance, context=context) - if isinstance(instance, Membership): - serializer = MembershipReadSerializer(instance=instance, context=context) - if isinstance(instance, Invitation): - serializer = InvitationReadSerializer(instance=instance, context=context) + for model, serializer_class in SERIALIZERS: + if isinstance(instance, model): + serializer = serializer_class(instance=instance, context=context) return serializer @@ -254,6 +267,7 @@ def get_serializer_without_url(instance): serializer.fields.pop("url", None) return serializer + def handle_create(scope, instance, **kwargs): oid = organization_id(instance) oslug = organization_slug(instance) @@ -288,6 +302,7 @@ def handle_create(scope, instance, **kwargs): payload=payload, ) + def handle_update(scope, instance, old_instance, **kwargs): oid = organization_id(instance) oslug = organization_slug(instance) @@ -322,12 +337,14 @@ def handle_update(scope, instance, old_instance, **kwargs): payload={"old_value": change["old_value"]}, ) + def handle_delete(scope, instance, store_in_deletion_cache=False, **kwargs): deletion_cache = get_cache() + instance_id = getattr(instance, "id", None) if store_in_deletion_cache: deletion_cache.set( instance.__class__, - instance.id, + instance_id, { "oid": organization_id(instance), "oslug": organization_slug(instance), @@ -338,7 +355,7 @@ def handle_delete(scope, instance, store_in_deletion_cache=False, **kwargs): ) return - instance_meta_info = deletion_cache.pop(instance.__class__, instance.id) + instance_meta_info = deletion_cache.pop(instance.__class__, instance_id) if instance_meta_info: oid = instance_meta_info["oid"] oslug = instance_meta_info["oslug"] @@ -360,7 +377,7 @@ def handle_delete(scope, instance, store_in_deletion_cache=False, **kwargs): scope=scope, request_id=request_id(), on_commit=True, - obj_id=getattr(instance, 'id', None), + obj_id=instance_id, obj_name=_get_object_name(instance), org_id=oid, org_slug=oslug, @@ -372,15 +389,12 @@ def handle_delete(scope, instance, store_in_deletion_cache=False, **kwargs): user_email=uemail, ) + def handle_annotations_change(instance, annotations, action, **kwargs): def filter_data(data): filtered_data = { "id": data["id"], - "frame": data["frame"], - "attributes": data["attributes"], } - if label_id := data.get("label_id"): - filtered_data["label_id"] = label_id return filtered_data diff --git a/cvat/apps/events/signals.py b/cvat/apps/events/signals.py index 25d320c35e1d..c304fc69b61c 100644 --- a/cvat/apps/events/signals.py +++ b/cvat/apps/events/signals.py @@ -2,26 +2,30 @@ # # SPDX-License-Identifier: MIT -from django.dispatch import receiver -from django.db.models.signals import pre_save, post_save, post_delete from django.core.exceptions import ObjectDoesNotExist +from django.db.models.signals import post_delete, post_save, pre_save +from django.dispatch import receiver from cvat.apps.engine.models import ( - TimestampedModel, - Project, - Task, - Job, - User, CloudStorage, - Issue, Comment, + Issue, + Job, Label, + Project, + Task, + TimestampedModel, + User, ) -from cvat.apps.organizations.models import Organization +from cvat.apps.organizations.models import Invitation, Membership, Organization +from cvat.apps.webhooks.models import Webhook -from .handlers import handle_update, handle_create, handle_delete from .event import EventScopeChoice, event_scope +from .handlers import handle_create, handle_delete, handle_update + +@receiver(pre_save, sender=Webhook, dispatch_uid="webhook:update_receiver") +@receiver(pre_save, sender=Membership, dispatch_uid="membership:update_receiver") @receiver(pre_save, sender=Organization, dispatch_uid="organization:update_receiver") @receiver(pre_save, sender=Project, dispatch_uid="project:update_receiver") @receiver(pre_save, sender=Task, dispatch_uid="task:update_receiver") @@ -34,7 +38,8 @@ def resource_update(sender, *, instance, update_fields, **kwargs): if ( isinstance(instance, TimestampedModel) - and update_fields and list(update_fields) == ["updated_date"] + and update_fields + and list(update_fields) == ["updated_date"] ): # This is an optimization for the common case where only the date is bumped # (see `TimestampedModel.touch`). Since the actual update of the field will @@ -57,6 +62,10 @@ def resource_update(sender, *, instance, update_fields, **kwargs): handle_update(scope=scope, instance=instance, old_instance=old_instance, **kwargs) + +@receiver(post_save, sender=Webhook, dispatch_uid="webhook:create_receiver") +@receiver(post_save, sender=Membership, dispatch_uid="membership:create_receiver") +@receiver(post_save, sender=Invitation, dispatch_uid="invitation:create_receiver") @receiver(post_save, sender=Organization, dispatch_uid="organization:create_receiver") @receiver(post_save, sender=Project, dispatch_uid="project:create_receiver") @receiver(post_save, sender=Task, dispatch_uid="task:create_receiver") @@ -78,6 +87,10 @@ def resource_create(sender, instance, created, **kwargs): handle_create(scope=scope, instance=instance, **kwargs) + +@receiver(post_delete, sender=Webhook, dispatch_uid="webhook:delete_receiver") +@receiver(post_delete, sender=Membership, dispatch_uid="membership:delete_receiver") +@receiver(post_delete, sender=Invitation, dispatch_uid="invitation:delete_receiver") @receiver(post_delete, sender=Organization, dispatch_uid="organization:delete_receiver") @receiver(post_delete, sender=Project, dispatch_uid="project:delete_receiver") @receiver(post_delete, sender=Task, dispatch_uid="task:delete_receiver") diff --git a/cvat/apps/events/tests/test_events.py b/cvat/apps/events/tests/test_events.py index 990daa1ea325..81b054171dce 100644 --- a/cvat/apps/events/tests/test_events.py +++ b/cvat/apps/events/tests/test_events.py @@ -5,7 +5,7 @@ import json import unittest from datetime import datetime, timedelta, timezone -from typing import List, Optional +from typing import Optional from django.contrib.auth import get_user_model from django.test import RequestFactory @@ -42,7 +42,7 @@ def _working_time(event: dict) -> int: return payload["working_time"] @staticmethod - def _deserialize(events: List[dict], previous_event: Optional[dict] = None) -> List[dict]: + def _deserialize(events: list[dict], previous_event: Optional[dict] = None) -> list[dict]: request = RequestFactory().post("/api/events") request.user = get_user_model()(id=100, username="testuser", email="testuser@example.org") request.iam_context = { diff --git a/cvat/apps/iam/permissions.py b/cvat/apps/iam/permissions.py index bb2ab44a414f..d4925426724a 100644 --- a/cvat/apps/iam/permissions.py +++ b/cvat/apps/iam/permissions.py @@ -8,9 +8,10 @@ import importlib import operator from abc import ABCMeta, abstractmethod +from collections.abc import Sequence from enum import Enum from pathlib import Path -from typing import Any, Dict, List, Optional, Sequence, TypeVar +from typing import Any, Optional, TypeVar from attrs import define, field from django.apps import AppConfig @@ -33,7 +34,7 @@ def __str__(self) -> str: @define class PermissionResult: allow: bool - reasons: List[str] = field(factory=list) + reasons: list[str] = field(factory=list) def get_organization(request, obj): @@ -83,7 +84,7 @@ def build_iam_context(request, organization: Optional[Organization], membership: } -def get_iam_context(request, obj) -> Dict[str, Any]: +def get_iam_context(request, obj) -> dict[str, Any]: organization = get_organization(request, obj) membership = get_membership(request, organization) diff --git a/cvat/apps/iam/rules/tests/generate_tests.py b/cvat/apps/iam/rules/tests/generate_tests.py index 254930e73d61..729de6732eb2 100755 --- a/cvat/apps/iam/rules/tests/generate_tests.py +++ b/cvat/apps/iam/rules/tests/generate_tests.py @@ -7,9 +7,10 @@ import subprocess import sys from argparse import ArgumentParser, Namespace +from collections.abc import Sequence from concurrent.futures import ThreadPoolExecutor from functools import partial -from typing import Optional, Sequence +from typing import Optional from pathlib import Path REPO_ROOT = Path(__file__).resolve().parents[5] diff --git a/cvat/apps/iam/serializers.py b/cvat/apps/iam/serializers.py index 862712454de0..967b696a4f21 100644 --- a/cvat/apps/iam/serializers.py +++ b/cvat/apps/iam/serializers.py @@ -19,7 +19,7 @@ from django.contrib.auth.models import User from drf_spectacular.utils import extend_schema_field -from typing import Optional, Union, Dict +from typing import Optional, Union from cvat.apps.iam.forms import ResetPasswordFormEx from cvat.apps.iam.utils import get_dummy_user @@ -32,11 +32,11 @@ class RegisterSerializerEx(RegisterSerializer): key = serializers.SerializerMethodField() @extend_schema_field(serializers.BooleanField) - def get_email_verification_required(self, obj: Union[Dict, User]) -> bool: + def get_email_verification_required(self, obj: Union[dict, User]) -> bool: return allauth_settings.EMAIL_VERIFICATION == allauth_settings.EmailVerificationMethod.MANDATORY @extend_schema_field(serializers.CharField(allow_null=True)) - def get_key(self, obj: Union[Dict, User]) -> Optional[str]: + def get_key(self, obj: Union[dict, User]) -> Optional[str]: key = None if isinstance(obj, User) and allauth_settings.EMAIL_VERIFICATION != \ allauth_settings.EmailVerificationMethod.MANDATORY: diff --git a/cvat/apps/iam/signals.py b/cvat/apps/iam/signals.py index 28159cddc745..73f919a1a4a4 100644 --- a/cvat/apps/iam/signals.py +++ b/cvat/apps/iam/signals.py @@ -42,6 +42,9 @@ def create_user(sender, user=None, ldap_user=None, **kwargs): if role == settings.IAM_ADMIN_ROLE: user.is_staff = user.is_superuser = True break + # add default group if no other group has been assigned + if not len(user_groups): + user_groups.append(Group.objects.get(name=settings.IAM_DEFAULT_ROLE)) # It is important to save the user before adding groups. Please read # https://django-auth-ldap.readthedocs.io/en/latest/users.html#populating-users diff --git a/cvat/apps/iam/utils.py b/cvat/apps/iam/utils.py index a13de3367336..8095902769f3 100644 --- a/cvat/apps/iam/utils.py +++ b/cvat/apps/iam/utils.py @@ -1,5 +1,4 @@ from pathlib import Path -from typing import Tuple import functools import hashlib import importlib @@ -14,7 +13,7 @@ } @functools.lru_cache(maxsize=None) -def get_opa_bundle() -> Tuple[bytes, str]: +def get_opa_bundle() -> tuple[bytes, str]: bundle_file = io.BytesIO() with tarfile.open(fileobj=bundle_file, mode='w:gz') as tar: diff --git a/cvat/apps/lambda_manager/serializers.py b/cvat/apps/lambda_manager/serializers.py index 4108b4e97ad9..ab8809bd7cc8 100644 --- a/cvat/apps/lambda_manager/serializers.py +++ b/cvat/apps/lambda_manager/serializers.py @@ -24,13 +24,11 @@ class FunctionCallRequestSerializer(serializers.Serializer): function = serializers.CharField(help_text="The name of the function to execute") task = serializers.IntegerField(help_text="The id of the task to be annotated") job = serializers.IntegerField(required=False, help_text="The id of the job to be annotated") - quality = serializers.ChoiceField(choices=['compressed', 'original'], default="original", - help_text="The quality of the images to use in the model run" - ) max_distance = serializers.IntegerField(required=False) threshold = serializers.FloatField(required=False) cleanup = serializers.BooleanField(help_text="Whether existing annotations should be removed", default=False) - convMaskToPoly = serializers.BooleanField(default=False) # TODO: use lowercase naming + convMaskToPoly = serializers.BooleanField(required=False, source="conv_mask_to_poly", write_only=True, help_text="Deprecated; use conv_mask_to_poly instead") + conv_mask_to_poly = serializers.BooleanField(required=False, help_text="Convert mask shapes to polygons") mapping = serializers.DictField(child=LabelMappingEntrySerializer(), required=False, help_text="Label mapping from the model to the task labels" ) diff --git a/cvat/apps/lambda_manager/tests/test_lambda.py b/cvat/apps/lambda_manager/tests/test_lambda.py index 794ef8cefabe..f9292b278b45 100644 --- a/cvat/apps/lambda_manager/tests/test_lambda.py +++ b/cvat/apps/lambda_manager/tests/test_lambda.py @@ -5,7 +5,7 @@ from collections import Counter, OrderedDict from itertools import groupby -from typing import Dict, Optional +from typing import Optional from unittest import mock, skip import json import os @@ -368,7 +368,6 @@ def test_api_v2_lambda_requests_read(self): "task": self.main_task["id"], "cleanup": True, "threshold": 55, - "quality": "original", "mapping": { "car": { "name": "car" }, }, @@ -447,7 +446,6 @@ def test_api_v2_lambda_requests_create(self): "task": self.main_task["id"], "cleanup": True, "threshold": 55, - "quality": "original", "mapping": { "car": { "name": "car" }, }, @@ -456,7 +454,6 @@ def test_api_v2_lambda_requests_create(self): "function": id_func, "task": self.assigneed_to_user_task["id"], "cleanup": False, - "quality": "compressed", "max_distance": 70, "mapping": { "car": { "name": "car" }, @@ -769,7 +766,6 @@ def test_api_v2_lambda_functions_create_reid(self): OrderedDict([('attributes', []), ('frame', 1), ('group', None), ('id', 11260), ('label_id', 8), ('occluded', False), ('points', [1076.0, 199.0, 1218.0, 593.0]), ('source', 'auto'), ('type', 'rectangle'), ('z_order', 0)]), OrderedDict([('attributes', []), ('frame', 1), ('group', None), ('id', 11261), ('label_id', 8), ('occluded', False), ('points', [924.0, 177.0, 1090.0, 615.0]), ('source', 'auto'), ('type', 'rectangle'), ('z_order', 0)]), ], - "quality": None, "threshold": 0.5, "max_distance": 55, } @@ -785,7 +781,6 @@ def test_api_v2_lambda_functions_create_reid(self): OrderedDict([('attributes', []), ('frame', 1), ('group', None), ('id', 11260), ('label_id', 8), ('occluded', False), ('points', [1076.0, 199.0, 1218.0, 593.0]), ('source', 'auto'), ('type', 'rectangle'), ('z_order', 0)]), OrderedDict([('attributes', []), ('frame', 1), ('group', 0), ('id', 11398), ('label_id', 8), ('occluded', False), ('points', [184.3935546875, 211.5048828125, 331.64968722073354, 97.27792672028772, 445.87667560321825, 126.17873100983161, 454.13404825737416, 691.8087578194827, 180.26452189455085]), ('source', 'manual'), ('type', 'polygon'), ('z_order', 0)]), ], - "quality": None, } response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_reid_with_response_data}", self.admin, data_main_task) @@ -829,42 +824,11 @@ def test_api_v2_lambda_functions_create_negative(self): self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR) - def test_api_v2_lambda_functions_create_quality(self): - qualities = [None, "original", "compressed"] - - for quality in qualities: - data = { - "task": self.main_task["id"], - "frame": 0, - "cleanup": True, - "quality": quality, - "mapping": { - "car": { "name": "car" }, - }, - } - - response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data) - self.assertEqual(response.status_code, status.HTTP_200_OK) - - data = { - "task": self.main_task["id"], - "frame": 0, - "cleanup": True, - "quality": "test-error-quality", - "mapping": { - "car": { "name": "car" }, - }, - } - - response = self._post_request(f"{LAMBDA_FUNCTIONS_PATH}/{id_function_detector}", self.admin, data) - self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) - def test_api_v2_lambda_functions_convert_mask_to_rle(self): data_main_task = { "function": id_function_detector, "task": self.main_task["id"], "cleanup": True, - "quality": "original", "mapping": { "car": { "name": "car" }, }, @@ -1476,7 +1440,7 @@ class Issue4996_Cases(_LambdaTestCaseBase): # We need to check that job assignee can call functions in the assigned jobs # This requires to pass the job id in the call request. - def _create_org(self, *, owner: int, members: Dict[int, str] = None) -> dict: + def _create_org(self, *, owner: int, members: dict[int, str] = None) -> dict: org = self._post_request('/api/organizations', user=owner, data={ "slug": "testorg", "name": "test Org", diff --git a/cvat/apps/lambda_manager/views.py b/cvat/apps/lambda_manager/views.py index 143537985fd7..559ef29813b5 100644 --- a/cvat/apps/lambda_manager/views.py +++ b/cvat/apps/lambda_manager/views.py @@ -12,7 +12,7 @@ from copy import deepcopy from datetime import timedelta from functools import wraps -from typing import Any, Dict, Optional +from typing import Any, Optional import datumaro.util.mask_tools as mask_tools import django_rq @@ -32,7 +32,7 @@ from rest_framework.request import Request import cvat.apps.dataset_manager as dm -from cvat.apps.engine.frame_provider import FrameQuality, TaskFrameProvider +from cvat.apps.engine.frame_provider import TaskFrameProvider from cvat.apps.engine.models import ( Job, ShapeType, SourceType, Task, Label, RequestAction, RequestTarget ) @@ -231,7 +231,7 @@ def to_dict(self): def invoke( self, db_task: Task, - data: Dict[str, Any], + data: dict[str, Any], *, db_job: Optional[Job] = None, is_interactive: Optional[bool] = False, @@ -257,13 +257,12 @@ def mandatory_arg(name: str) -> Any: threshold = data.get("threshold") if threshold: payload.update({ "threshold": threshold }) - quality = data.get("quality") mapping = data.get("mapping", {}) model_labels = self.labels task_labels = db_task.get_labels(prefetch=True) - def labels_compatible(model_label: Dict, task_label: Label) -> bool: + def labels_compatible(model_label: dict, task_label: Label) -> bool: model_type = model_label['type'] db_type = task_label.type compatible_types = [[ShapeType.MASK, ShapeType.POLYGON]] @@ -387,19 +386,19 @@ def validate_attributes_mapping(attributes_mapping, model_attributes, db_attribu if self.kind == FunctionKind.DETECTOR: payload.update({ - "image": self._get_image(db_task, mandatory_arg("frame"), quality) + "image": self._get_image(db_task, mandatory_arg("frame")) }) elif self.kind == FunctionKind.INTERACTOR: payload.update({ - "image": self._get_image(db_task, mandatory_arg("frame"), quality), + "image": self._get_image(db_task, mandatory_arg("frame")), "pos_points": mandatory_arg("pos_points"), "neg_points": mandatory_arg("neg_points"), "obj_bbox": data.get("obj_bbox", None) }) elif self.kind == FunctionKind.REID: payload.update({ - "image0": self._get_image(db_task, mandatory_arg("frame0"), quality), - "image1": self._get_image(db_task, mandatory_arg("frame1"), quality), + "image0": self._get_image(db_task, mandatory_arg("frame0")), + "image1": self._get_image(db_task, mandatory_arg("frame1")), "boxes0": mandatory_arg("boxes0"), "boxes1": mandatory_arg("boxes1") }) @@ -410,7 +409,7 @@ def validate_attributes_mapping(attributes_mapping, model_attributes, db_attribu }) elif self.kind == FunctionKind.TRACKER: payload.update({ - "image": self._get_image(db_task, mandatory_arg("frame"), quality), + "image": self._get_image(db_task, mandatory_arg("frame")), "shapes": data.get("shapes", []), "states": data.get("states", []) }) @@ -487,19 +486,9 @@ def transform_attributes(input_attributes, attr_mapping, db_attributes): return response - def _get_image(self, db_task, frame, quality): - if quality is None or quality == "original": - quality = FrameQuality.ORIGINAL - elif quality == "compressed": - quality = FrameQuality.COMPRESSED - else: - raise ValidationError( - '`{}` lambda function was run '.format(self.id) + - 'with wrong arguments (quality={})'.format(quality), - code=status.HTTP_400_BAD_REQUEST) - + def _get_image(self, db_task, frame): frame_provider = TaskFrameProvider(db_task) - image = frame_provider.get_frame(frame, quality=quality) + image = frame_provider.get_frame(frame) return base64.b64encode(image.data.getvalue()).decode('utf-8') @@ -523,7 +512,7 @@ def get_jobs(self): return [LambdaJob(job) for job in jobs if job and job.meta.get("lambda")] def enqueue(self, - lambda_func, threshold, task, quality, mapping, cleanup, conv_mask_to_poly, max_distance, request, + lambda_func, threshold, task, mapping, cleanup, conv_mask_to_poly, max_distance, request, *, job: Optional[int] = None ) -> LambdaJob: @@ -576,7 +565,6 @@ def enqueue(self, "threshold": threshold, "task": task, "job": job, - "quality": quality, "cleanup": cleanup, "conv_mask_to_poly": conv_mask_to_poly, "mapping": mapping, @@ -666,10 +654,9 @@ def _call_detector( cls, function: LambdaFunction, db_task: Task, - labels: Dict[str, Dict[str, Any]], - quality: str, + labels: dict[str, dict[str, Any]], threshold: float, - mapping: Optional[Dict[str, str]], + mapping: Optional[dict[str, str]], conv_mask_to_poly: bool, *, db_job: Optional[Job] = None @@ -799,7 +786,7 @@ def _map(sublabel_body): continue annotations = function.invoke(db_task, db_job=db_job, data={ - "frame": frame, "quality": quality, "mapping": mapping, + "frame": frame, "mapping": mapping, "threshold": threshold }) @@ -854,7 +841,6 @@ def _call_reid( cls, function: LambdaFunction, db_task: Task, - quality: str, threshold: float, max_distance: int, *, @@ -887,7 +873,7 @@ def _call_reid( boxes1 = boxes_by_frame[frame1] if boxes0 and boxes1: matching = function.invoke(db_task, db_job=db_job, data={ - "frame0": frame0, "frame1": frame1, "quality": quality, + "frame0": frame0, "frame1": frame1, "boxes0": boxes0, "boxes1": boxes1, "threshold": threshold, "max_distance": max_distance}) @@ -947,7 +933,7 @@ def _call_reid( dm.task.put_task_data(db_task.id, serializer.data) @classmethod - def __call__(cls, function, task: int, quality: str, cleanup: bool, **kwargs): + def __call__(cls, function, task: int, cleanup: bool, **kwargs): # TODO: need logging db_job = None if job := kwargs.get('job'): @@ -977,11 +963,11 @@ def convert_labels(db_labels): labels = convert_labels(db_task.get_labels(prefetch=True)) if function.kind == FunctionKind.DETECTOR: - cls._call_detector(function, db_task, labels, quality, + cls._call_detector(function, db_task, labels, kwargs.get("threshold"), kwargs.get("mapping"), kwargs.get("conv_mask_to_poly"), db_job=db_job) elif function.kind == FunctionKind.REID: - cls._call_reid(function, db_task, quality, + cls._call_reid(function, db_task, kwargs.get("threshold"), kwargs.get("max_distance"), db_job=db_job) def return_response(success_code=status.HTTP_200_OK): @@ -1176,9 +1162,8 @@ def create(self, request): threshold = request_data.get('threshold') task = request_data['task'] job = request_data.get('job', None) - quality = request_data.get("quality") cleanup = request_data.get('cleanup', False) - conv_mask_to_poly = request_data.get('convMaskToPoly', False) + conv_mask_to_poly = request_data.get('conv_mask_to_poly', False) mapping = request_data.get('mapping') max_distance = request_data.get('max_distance') except KeyError as err: @@ -1190,7 +1175,7 @@ def create(self, request): gateway = LambdaGateway() queue = LambdaQueue() lambda_func = gateway.get(function) - rq_job = queue.enqueue(lambda_func, threshold, task, quality, + rq_job = queue.enqueue(lambda_func, threshold, task, mapping, cleanup, conv_mask_to_poly, max_distance, request, job=job) handle_function_call(function, job or task, category="batch") diff --git a/cvat/apps/quality_control/models.py b/cvat/apps/quality_control/models.py index b8cf76873597..a5359e4fe944 100644 --- a/cvat/apps/quality_control/models.py +++ b/cvat/apps/quality_control/models.py @@ -4,9 +4,10 @@ from __future__ import annotations +from collections.abc import Sequence from copy import deepcopy from enum import Enum -from typing import Any, Sequence +from typing import Any from django.core.exceptions import ValidationError from django.db import models diff --git a/cvat/apps/quality_control/quality_reports.py b/cvat/apps/quality_control/quality_reports.py index f5e527468aa3..627c4dc7b978 100644 --- a/cvat/apps/quality_control/quality_reports.py +++ b/cvat/apps/quality_control/quality_reports.py @@ -7,10 +7,11 @@ import itertools import math from collections import Counter +from collections.abc import Hashable, Sequence from copy import deepcopy from datetime import timedelta from functools import cached_property, partial -from typing import Any, Callable, Dict, Hashable, List, Optional, Sequence, Tuple, Union, cast +from typing import Any, Callable, Optional, Union, cast import datumaro as dm import datumaro.util.mask_tools @@ -77,7 +78,7 @@ def _value_serializer(self, v): def to_dict(self) -> dict: return self._value_serializer(self._fields_dict()) - def _fields_dict(self, *, include_properties: Optional[List[str]] = None) -> dict: + def _fields_dict(self, *, include_properties: Optional[list[str]] = None) -> dict: d = asdict(self, recurse=False) for field_name in include_properties or []: @@ -117,7 +118,7 @@ def from_dict(cls, d: dict): class AnnotationConflict(_Serializable): frame_id: int type: AnnotationConflictType - annotation_ids: List[AnnotationId] + annotation_ids: list[AnnotationId] @property def severity(self) -> AnnotationConflictSeverity: @@ -146,7 +147,7 @@ def _value_serializer(self, v): else: return super()._value_serializer(v) - def _fields_dict(self, *, include_properties: Optional[List[str]] = None) -> dict: + def _fields_dict(self, *, include_properties: Optional[list[str]] = None) -> dict: return super()._fields_dict(include_properties=include_properties or ["severity"]) @classmethod @@ -160,7 +161,7 @@ def from_dict(cls, d: dict): @define(kw_only=True) class ComparisonParameters(_Serializable): - included_annotation_types: List[dm.AnnotationType] = [ + included_annotation_types: list[dm.AnnotationType] = [ dm.AnnotationType.bbox, dm.AnnotationType.points, dm.AnnotationType.mask, @@ -176,7 +177,7 @@ class ComparisonParameters(_Serializable): compare_attributes: bool = True "Enables or disables attribute checks" - ignored_attributes: List[str] = [] + ignored_attributes: list[str] = [] iou_threshold: float = 0.4 "Used for distinction between matched / unmatched shapes" @@ -238,7 +239,7 @@ def from_dict(cls, d: dict): @define(kw_only=True) class ConfusionMatrix(_Serializable): - labels: List[str] + labels: list[str] rows: np.ndarray precision: np.ndarray recall: np.ndarray @@ -255,7 +256,7 @@ def _value_serializer(self, v): else: return super()._value_serializer(v) - def _fields_dict(self, *, include_properties: Optional[List[str]] = None) -> dict: + def _fields_dict(self, *, include_properties: Optional[list[str]] = None) -> dict: return super()._fields_dict(include_properties=include_properties or ["axes"]) @classmethod @@ -305,7 +306,7 @@ def accumulate(self, other: ComparisonReportAnnotationsSummary): ]: setattr(self, field, getattr(self, field) + getattr(other, field)) - def _fields_dict(self, *, include_properties: Optional[List[str]] = None) -> dict: + def _fields_dict(self, *, include_properties: Optional[list[str]] = None) -> dict: return super()._fields_dict( include_properties=include_properties or ["accuracy", "precision", "recall"] ) @@ -348,7 +349,7 @@ def accumulate(self, other: ComparisonReportAnnotationShapeSummary): ]: setattr(self, field, getattr(self, field) + getattr(other, field)) - def _fields_dict(self, *, include_properties: Optional[List[str]] = None) -> dict: + def _fields_dict(self, *, include_properties: Optional[list[str]] = None) -> dict: return super()._fields_dict(include_properties=include_properties or ["accuracy"]) @classmethod @@ -378,7 +379,7 @@ def accumulate(self, other: ComparisonReportAnnotationLabelSummary): for field in ["valid_count", "total_count", "invalid_count"]: setattr(self, field, getattr(self, field) + getattr(other, field)) - def _fields_dict(self, *, include_properties: Optional[List[str]] = None) -> dict: + def _fields_dict(self, *, include_properties: Optional[list[str]] = None) -> dict: return super()._fields_dict(include_properties=include_properties or ["accuracy"]) @classmethod @@ -410,7 +411,7 @@ def from_dict(cls, d: dict): @define(kw_only=True) class ComparisonReportComparisonSummary(_Serializable): frame_share: float - frames: List[str] + frames: list[str] @property def mean_conflict_count(self) -> float: @@ -419,7 +420,7 @@ def mean_conflict_count(self) -> float: conflict_count: int warning_count: int error_count: int - conflicts_by_type: Dict[AnnotationConflictType, int] + conflicts_by_type: dict[AnnotationConflictType, int] annotations: ComparisonReportAnnotationsSummary annotation_components: ComparisonReportAnnotationComponentsSummary @@ -434,7 +435,7 @@ def _value_serializer(self, v): else: return super()._value_serializer(v) - def _fields_dict(self, *, include_properties: Optional[List[str]] = None) -> dict: + def _fields_dict(self, *, include_properties: Optional[list[str]] = None) -> dict: return super()._fields_dict( include_properties=include_properties or [ @@ -466,7 +467,7 @@ def from_dict(cls, d: dict): @define(kw_only=True, init=False) class ComparisonReportFrameSummary(_Serializable): - conflicts: List[AnnotationConflict] + conflicts: list[AnnotationConflict] @cached_property def conflict_count(self) -> int: @@ -481,7 +482,7 @@ def error_count(self) -> int: return len([c for c in self.conflicts if c.severity == AnnotationConflictSeverity.ERROR]) @cached_property - def conflicts_by_type(self) -> Dict[AnnotationConflictType, int]: + def conflicts_by_type(self) -> dict[AnnotationConflictType, int]: return Counter(c.type for c in self.conflicts) annotations: ComparisonReportAnnotationsSummary @@ -503,7 +504,7 @@ def __init__(self, *args, **kwargs): self.__attrs_init__(*args, **kwargs) - def _fields_dict(self, *, include_properties: Optional[List[str]] = None) -> dict: + def _fields_dict(self, *, include_properties: Optional[list[str]] = None) -> dict: return super()._fields_dict(include_properties=include_properties or self._CACHED_FIELDS) @classmethod @@ -534,14 +535,14 @@ def from_dict(cls, d: dict): class ComparisonReport(_Serializable): parameters: ComparisonParameters comparison_summary: ComparisonReportComparisonSummary - frame_results: Dict[int, ComparisonReportFrameSummary] + frame_results: dict[int, ComparisonReportFrameSummary] @property - def conflicts(self) -> List[AnnotationConflict]: + def conflicts(self) -> list[AnnotationConflict]: return list(itertools.chain.from_iterable(r.conflicts for r in self.frame_results.values())) @classmethod - def from_dict(cls, d: Dict[str, Any]) -> ComparisonReport: + def from_dict(cls, d: dict[str, Any]) -> ComparisonReport: return cls( parameters=ComparisonParameters.from_dict(d["parameters"]), comparison_summary=ComparisonReportComparisonSummary.from_dict(d["comparison_summary"]), @@ -632,7 +633,7 @@ def get_source_ann( def clear(self): self._annotation_mapping.clear() - def __call__(self, *args, **kwargs) -> List[dm.Annotation]: + def __call__(self, *args, **kwargs) -> list[dm.Annotation]: converter = _MemoizingAnnotationConverter(*args, factory=self, **kwargs) return converter.convert() @@ -861,7 +862,7 @@ def _compare_lines(self, a: np.ndarray, b: np.ndarray) -> float: return sum(np.exp(-(dists**2) / (2 * scale * (2 * self.torso_r) ** 2))) / len(a) @classmethod - def approximate_points(cls, a: np.ndarray, b: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + def approximate_points(cls, a: np.ndarray, b: np.ndarray) -> tuple[np.ndarray, np.ndarray]: """ Creates 2 polylines with the same numbers of points, the points are placed on the original lines with the same step. @@ -959,7 +960,7 @@ def __init__( self, categories: dm.CategoriesInfo, *, - included_ann_types: Optional[List[dm.AnnotationType]] = None, + included_ann_types: Optional[list[dm.AnnotationType]] = None, return_distances: bool = False, iou_threshold: float = 0.5, # https://cocodataset.org/#keypoints-eval @@ -994,7 +995,7 @@ def __init__( def _instance_bbox( self, instance_anns: Sequence[dm.Annotation] - ) -> Tuple[float, float, float, float]: + ) -> tuple[float, float, float, float]: return dm.ops.max_bbox( a.get_bbox() if isinstance(a, dm.Skeleton) else a for a in instance_anns @@ -1141,7 +1142,7 @@ def _find_instances(annotations): return instances, instance_map def _get_compiled_mask( - anns: Sequence[dm.Annotation], *, instance_ids: Dict[int, int] + anns: Sequence[dm.Annotation], *, instance_ids: dict[int, int] ) -> dm.CompiledMask: if not anns: return None @@ -1583,7 +1584,7 @@ def match_attrs(self, ann_a: dm.Annotation, ann_b: dm.Annotation): def find_groups( self, item: dm.DatasetItem - ) -> Tuple[Dict[int, List[dm.Annotation]], Dict[int, int]]: + ) -> tuple[dict[int, list[dm.Annotation]], dict[int, int]]: ann_groups = dm.ops.find_instances( [ ann @@ -1632,7 +1633,7 @@ def _group_distance(gt_group_id, ds_group_id): return ds_to_gt_groups - def find_covered(self, item: dm.DatasetItem) -> List[dm.Annotation]: + def find_covered(self, item: dm.DatasetItem) -> list[dm.Annotation]: # Get annotations that can cover or be covered spatial_types = { dm.AnnotationType.polygon, @@ -1707,7 +1708,7 @@ def __init__( self._ds_dataset = self._ds_data_provider.dm_dataset self._gt_dataset = self._gt_data_provider.dm_dataset - self._frame_results: Dict[int, ComparisonReportFrameSummary] = {} + self._frame_results: dict[int, ComparisonReportFrameSummary] = {} self.comparator = _Comparator(self._gt_dataset.categories(), settings=settings) @@ -1744,7 +1745,7 @@ def _find_gt_conflicts(self): def _process_frame( self, ds_item: dm.DatasetItem, gt_item: dm.DatasetItem - ) -> List[AnnotationConflict]: + ) -> list[AnnotationConflict]: frame_id = self._dm_item_to_frame_id(ds_item, self._ds_dataset) frame_results = self.comparator.match_annotations(gt_item, ds_item) @@ -1756,7 +1757,7 @@ def _process_frame( def _generate_frame_annotation_conflicts( self, frame_id: str, frame_results, *, gt_item: dm.DatasetItem, ds_item: dm.DatasetItem - ) -> List[AnnotationConflict]: + ) -> list[AnnotationConflict]: conflicts = [] matches, mismatches, gt_unmatched, ds_unmatched, _ = frame_results["all_ann_types"] @@ -2017,7 +2018,7 @@ def _find_closest_unmatched_shape(shape: dm.Annotation): # row/column index in the confusion matrix corresponding to unmatched annotations _UNMATCHED_IDX = -1 - def _make_zero_confusion_matrix(self) -> Tuple[List[str], np.ndarray, Dict[int, int]]: + def _make_zero_confusion_matrix(self) -> tuple[list[str], np.ndarray, dict[int, int]]: label_id_idx_map = {} label_names = [] for label_id, label in enumerate(self._gt_dataset.categories()[dm.AnnotationType.label]): @@ -2033,7 +2034,7 @@ def _make_zero_confusion_matrix(self) -> Tuple[List[str], np.ndarray, Dict[int, return label_names, confusion_matrix, label_id_idx_map def _compute_annotations_summary( - self, confusion_matrix: np.ndarray, confusion_matrix_labels: List[str] + self, confusion_matrix: np.ndarray, confusion_matrix_labels: list[str] ) -> ComparisonReportAnnotationsSummary: matched_ann_counts = np.diag(confusion_matrix) ds_ann_counts = np.sum(confusion_matrix, axis=1) @@ -2076,7 +2077,7 @@ def _compute_annotations_summary( ) def _generate_frame_annotations_summary( - self, confusion_matrix: np.ndarray, confusion_matrix_labels: List[str] + self, confusion_matrix: np.ndarray, confusion_matrix_labels: list[str] ) -> ComparisonReportAnnotationsSummary: summary = self._compute_annotations_summary(confusion_matrix, confusion_matrix_labels) @@ -2090,8 +2091,8 @@ def _generate_frame_annotations_summary( return summary def _generate_dataset_annotations_summary( - self, frame_summaries: Dict[int, ComparisonReportFrameSummary] - ) -> Tuple[ComparisonReportAnnotationsSummary, ComparisonReportAnnotationComponentsSummary]: + self, frame_summaries: dict[int, ComparisonReportFrameSummary] + ) -> tuple[ComparisonReportAnnotationsSummary, ComparisonReportAnnotationComponentsSummary]: # accumulate stats annotation_components = ComparisonReportAnnotationComponentsSummary( shape=ComparisonReportAnnotationShapeSummary( @@ -2372,7 +2373,7 @@ def _compute_reports(self, task_id: int) -> int: in active_validation_frames ) - jobs: List[Job] = [j for j in job_queryset if j.type == JobType.ANNOTATION] + jobs: list[Job] = [j for j in job_queryset if j.type == JobType.ANNOTATION] job_data_providers = { job.id: JobDataProvider( job.id, @@ -2384,7 +2385,7 @@ def _compute_reports(self, task_id: int) -> int: quality_params = self._get_task_quality_params(task) - job_comparison_reports: Dict[int, ComparisonReport] = {} + job_comparison_reports: dict[int, ComparisonReport] = {} for job in jobs: job_data_provider = job_data_providers[job.id] comparator = DatasetComparator( @@ -2449,14 +2450,14 @@ def _get_current_job(self): return get_current_job() def _compute_task_report( - self, task: Task, job_reports: Dict[int, ComparisonReport] + self, task: Task, job_reports: dict[int, ComparisonReport] ) -> ComparisonReport: # The task dataset can be different from any jobs' dataset because of frame overlaps # between jobs, from which annotations are merged to get the task annotations. # Thus, a separate report could be computed for the task. Instead, here we only # compute the combined summary of the job reports. task_intersection_frames = set() - task_conflicts: List[AnnotationConflict] = [] + task_conflicts: list[AnnotationConflict] = [] task_annotations_summary = None task_ann_components_summary = None task_mean_shape_ious = [] @@ -2533,7 +2534,7 @@ def _compute_task_report( return task_report_data - def _save_reports(self, *, task_report: Dict, job_reports: List[Dict]) -> models.QualityReport: + def _save_reports(self, *, task_report: dict, job_reports: list[dict]) -> models.QualityReport: # TODO: add validation (e.g. ann id count for different types of conflicts) db_task_report = models.QualityReport( diff --git a/cvat/nginx.conf b/cvat/nginx.conf index 392c49d61a30..9cf14332abed 100644 --- a/cvat/nginx.conf +++ b/cvat/nginx.conf @@ -41,14 +41,27 @@ http { # CVAT Settings ## + # Only add security headers if the upstream server does not already provide them. + map $upstream_http_referrer_policy $hdr_referrer_policy { + '' "strict-origin-when-cross-origin"; + } + + map $upstream_http_x_content_type_options $hdr_x_content_type_options { + '' "nosniff"; + } + + map $upstream_http_x_frame_options $hdr_x_frame_options { + '' "deny"; + } + server { listen 8080; # previously used value client_max_body_size 1G; - add_header X-Frame-Options deny; - add_header Referrer-Policy "strict-origin-when-cross-origin" always; - add_header X-Content-Type-Options "nosniff" always; + add_header Referrer-Policy $hdr_referrer_policy always; + add_header X-Content-Type-Options $hdr_x_content_type_options always; + add_header X-Frame-Options $hdr_x_frame_options always; server_name _; diff --git a/cvat/schema.yml b/cvat/schema.yml index 1938cabc5071..ad8809e93111 100644 --- a/cvat/schema.yml +++ b/cvat/schema.yml @@ -1,7 +1,7 @@ openapi: 3.0.3 info: title: CVAT REST API - version: 2.22.0 + version: 2.23.0 description: REST API for Computer Vision Annotation Tool (CVAT) termsOfService: https://www.google.com/policies/terms/ contact: @@ -8049,15 +8049,6 @@ components: job: type: integer description: The id of the job to be annotated - quality: - allOf: - - $ref: '#/components/schemas/QualityEnum' - default: original - description: |- - The quality of the images to use in the model run - - * `compressed` - compressed - * `original` - original max_distance: type: integer threshold: @@ -8069,7 +8060,11 @@ components: description: Whether existing annotations should be removed convMaskToPoly: type: boolean - default: false + writeOnly: true + description: Deprecated; use conv_mask_to_poly instead + conv_mask_to_poly: + type: boolean + description: Convert mask shapes to polygons mapping: type: object additionalProperties: @@ -10020,14 +10015,6 @@ components: * `AZURE_CONTAINER` - AZURE_CONTAINER * `GOOGLE_DRIVE` - GOOGLE_DRIVE * `GOOGLE_CLOUD_STORAGE` - GOOGLE_CLOUD_STORAGE - QualityEnum: - enum: - - compressed - - original - type: string - description: |- - * `compressed` - compressed - * `original` - original QualityReport: type: object properties: diff --git a/cvat/settings/base.py b/cvat/settings/base.py index 404628fa555e..0f6147dc4bf0 100644 --- a/cvat/settings/base.py +++ b/cvat/settings/base.py @@ -276,6 +276,7 @@ class CVAT_QUEUES(Enum): QUALITY_REPORTS = 'quality_reports' ANALYTICS_REPORTS = 'analytics_reports' CLEANING = 'cleaning' + CHUNKS = 'chunks' redis_inmem_host = os.getenv('CVAT_REDIS_INMEM_HOST', 'localhost') redis_inmem_port = os.getenv('CVAT_REDIS_INMEM_PORT', 6379) @@ -321,6 +322,10 @@ class CVAT_QUEUES(Enum): **shared_queue_settings, 'DEFAULT_TIMEOUT': '1h', }, + CVAT_QUEUES.CHUNKS.value: { + **shared_queue_settings, + 'DEFAULT_TIMEOUT': '5m', + }, } NUCLIO = { @@ -539,14 +544,20 @@ class CVAT_QUEUES(Enum): redis_ondisk_port = os.getenv('CVAT_REDIS_ONDISK_PORT', 6666) redis_ondisk_password = os.getenv('CVAT_REDIS_ONDISK_PASSWORD', '') +# Sets the timeout for the expiration of data chunk in redis_ondisk +CVAT_CHUNK_CACHE_TTL = 3600 * 24 # 1 day + +# Sets the timeout for the expiration of preview image in redis_ondisk +CVAT_PREVIEW_CACHE_TTL = 3600 * 24 * 7 # 7 days + CACHES = { - 'default': { + 'default': { 'BACKEND': 'django.core.cache.backends.locmem.LocMemCache', }, 'media': { - 'BACKEND' : 'django.core.cache.backends.redis.RedisCache', - "LOCATION": f"redis://:{urllib.parse.quote(redis_ondisk_password)}@{redis_ondisk_host}:{redis_ondisk_port}", - 'TIMEOUT' : 3600 * 24, # 1 day + 'BACKEND' : 'django.core.cache.backends.redis.RedisCache', + "LOCATION": f'redis://:{urllib.parse.quote(redis_ondisk_password)}@{redis_ondisk_host}:{redis_ondisk_port}', + 'TIMEOUT' : CVAT_CHUNK_CACHE_TTL, } } @@ -574,6 +585,8 @@ class CVAT_QUEUES(Enum): # How django uses X-Forwarded-Proto - https://docs.djangoproject.com/en/2.2/ref/settings/#secure-proxy-ssl-header SECURE_PROXY_SSL_HEADER = ('HTTP_X_FORWARDED_PROTO', 'https') +SECURE_REFERRER_POLICY = 'strict-origin-when-cross-origin' + # Forwarded host - https://docs.djangoproject.com/en/4.0/ref/settings/#std:setting-USE_X_FORWARDED_HOST # Is used in TUS uploads to provide correct upload endpoint USE_X_FORWARDED_HOST = True diff --git a/dev/check_changelog_fragments.py b/dev/check_changelog_fragments.py index e837842efaf0..437e3fb02cdb 100755 --- a/dev/check_changelog_fragments.py +++ b/dev/check_changelog_fragments.py @@ -29,15 +29,28 @@ def complain(message): for fragment_path in REPO_ROOT.glob("changelog.d/*.md"): with open(fragment_path) as fragment_file: for line_index, line in enumerate(fragment_file): - if not line.startswith(md_header_prefix): - # The first line should be a header, and all headers should be of appropriate level. - if line_index == 0 or line.startswith("#"): - complain(f"line should start with {md_header_prefix!r}") - continue - - category = line.removeprefix(md_header_prefix).strip() - if category not in categories: - complain(f"unknown category: {category}") + line = line.rstrip("\n") + + if line_index == 0: + # The first line should always be a header. + if not line.startswith("#"): + complain("line should be a header") + elif ( + line + and not line.startswith("#") + and not line.startswith("-") + and not line.startswith(" ") + ): + complain("line should be a header, a list item, or indented") + + if line.startswith("#"): + if line.startswith(md_header_prefix): + category = line.removeprefix(md_header_prefix).strip() + if category not in categories: + complain(f"unknown category: {category}") + else: + # All headers should be of the same level. + complain(f"header should start with {md_header_prefix!r}") sys.exit(0 if success else 1) diff --git a/dev/format_python_code.sh b/dev/format_python_code.sh index 2e70e5b0ea4e..27fda9eff4ff 100755 --- a/dev/format_python_code.sh +++ b/dev/format_python_code.sh @@ -32,6 +32,7 @@ for paths in \ "cvat/apps/engine/model_utils.py" \ "cvat/apps/dataset_manager/tests/test_annotation.py" \ "cvat/apps/dataset_manager/tests/utils.py" \ + "cvat/apps/events/signals.py" \ ; do ${BLACK} -- ${paths} ${ISORT} -- ${paths} diff --git a/docker-compose.yml b/docker-compose.yml index bec741c5536f..bed3fdae6255 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -79,7 +79,7 @@ services: cvat_server: container_name: cvat_server - image: cvat/server:${CVAT_VERSION:-v2.22.0} + image: cvat/server:${CVAT_VERSION:-v2.23.0} restart: always depends_on: <<: *backend-deps @@ -113,7 +113,7 @@ services: cvat_utils: container_name: cvat_utils - image: cvat/server:${CVAT_VERSION:-v2.22.0} + image: cvat/server:${CVAT_VERSION:-v2.23.0} restart: always depends_on: *backend-deps environment: @@ -130,7 +130,7 @@ services: cvat_worker_import: container_name: cvat_worker_import - image: cvat/server:${CVAT_VERSION:-v2.22.0} + image: cvat/server:${CVAT_VERSION:-v2.23.0} restart: always depends_on: *backend-deps environment: @@ -146,7 +146,7 @@ services: cvat_worker_export: container_name: cvat_worker_export - image: cvat/server:${CVAT_VERSION:-v2.22.0} + image: cvat/server:${CVAT_VERSION:-v2.23.0} restart: always depends_on: *backend-deps environment: @@ -162,7 +162,7 @@ services: cvat_worker_annotation: container_name: cvat_worker_annotation - image: cvat/server:${CVAT_VERSION:-v2.22.0} + image: cvat/server:${CVAT_VERSION:-v2.23.0} restart: always depends_on: *backend-deps environment: @@ -178,7 +178,7 @@ services: cvat_worker_webhooks: container_name: cvat_worker_webhooks - image: cvat/server:${CVAT_VERSION:-v2.22.0} + image: cvat/server:${CVAT_VERSION:-v2.23.0} restart: always depends_on: *backend-deps environment: @@ -194,7 +194,7 @@ services: cvat_worker_quality_reports: container_name: cvat_worker_quality_reports - image: cvat/server:${CVAT_VERSION:-v2.22.0} + image: cvat/server:${CVAT_VERSION:-v2.23.0} restart: always depends_on: *backend-deps environment: @@ -210,7 +210,7 @@ services: cvat_worker_analytics_reports: container_name: cvat_worker_analytics_reports - image: cvat/server:${CVAT_VERSION:-v2.22.0} + image: cvat/server:${CVAT_VERSION:-v2.23.0} restart: always depends_on: *backend-deps environment: @@ -224,9 +224,25 @@ services: networks: - cvat + cvat_worker_chunks: + container_name: cvat_worker_chunks + image: cvat/server:${CVAT_VERSION:-v2.23.0} + restart: always + depends_on: *backend-deps + environment: + <<: *backend-env + NUMPROCS: 2 + command: run worker.chunks + volumes: + - cvat_data:/home/django/data + - cvat_keys:/home/django/keys + - cvat_logs:/home/django/logs + networks: + - cvat + cvat_ui: container_name: cvat_ui - image: cvat/ui:${CVAT_VERSION:-v2.22.0} + image: cvat/ui:${CVAT_VERSION:-v2.23.0} restart: always depends_on: - cvat_server diff --git a/helm-chart/templates/cvat_backend/worker_chunks/deployment.yml b/helm-chart/templates/cvat_backend/worker_chunks/deployment.yml new file mode 100644 index 000000000000..74e80b1b185d --- /dev/null +++ b/helm-chart/templates/cvat_backend/worker_chunks/deployment.yml @@ -0,0 +1,96 @@ +{{- $localValues := .Values.cvat.backend.worker.chunks -}} + +apiVersion: apps/v1 +kind: Deployment +metadata: + name: {{ .Release.Name }}-backend-worker-chunks + namespace: {{ .Release.Namespace }} + labels: + app: cvat-app + tier: backend + component: worker-chunks + {{- include "cvat.labels" . | nindent 4 }} + {{- with merge $localValues.labels .Values.cvat.backend.labels }} + {{- toYaml . | nindent 4 }} + {{- end }} + {{- with merge $localValues.annotations .Values.cvat.backend.annotations }} + annotations: + {{- toYaml . | nindent 4 }} + {{- end }} +spec: + replicas: {{ $localValues.replicas }} + strategy: + type: Recreate + selector: + matchLabels: + {{- include "cvat.labels" . | nindent 6 }} + {{- with merge $localValues.labels .Values.cvat.backend.labels }} + {{- toYaml . | nindent 6 }} + {{- end }} + app: cvat-app + tier: backend + component: worker-chunks + template: + metadata: + labels: + app: cvat-app + tier: backend + component: worker-chunks + {{- include "cvat.labels" . | nindent 8 }} + {{- with merge $localValues.labels .Values.cvat.backend.labels }} + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with merge $localValues.annotations .Values.cvat.backend.annotations }} + annotations: + {{- toYaml . | nindent 8 }} + {{- end }} + spec: + serviceAccountName: {{ include "cvat.backend.serviceAccountName" . }} + containers: + - name: cvat-backend + image: {{ .Values.cvat.backend.image }}:{{ .Values.cvat.backend.tag }} + imagePullPolicy: {{ .Values.cvat.backend.imagePullPolicy }} + {{- with merge $localValues.resources .Values.cvat.backend.resources }} + resources: + {{- toYaml . | nindent 12 }} + {{- end }} + args: ["run", "worker.chunks"] + env: + {{ include "cvat.sharedBackendEnv" . | indent 10 }} + {{- with concat .Values.cvat.backend.additionalEnv $localValues.additionalEnv }} + {{- toYaml . | nindent 10 }} + {{- end }} + {{- $probeArgs := list "chunks" -}} + {{- $probeConfig := dict "args" $probeArgs "livenessProbe" $.Values.cvat.backend.worker.livenessProbe -}} + {{ include "cvat.backend.worker.livenessProbe" $probeConfig | indent 10 }} + volumeMounts: + - mountPath: /home/django/data + name: cvat-backend-data + subPath: data + - mountPath: /home/django/logs + name: cvat-backend-data + subPath: logs + {{- with concat .Values.cvat.backend.additionalVolumeMounts $localValues.additionalVolumeMounts }} + {{- toYaml . | nindent 10 }} + {{- end }} + {{- with merge $localValues.affinity .Values.cvat.backend.affinity }} + affinity: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with concat .Values.cvat.backend.tolerations $localValues.tolerations }} + tolerations: + {{- toYaml . | nindent 8 }} + {{- end }} + volumes: + {{- if .Values.cvat.backend.defaultStorage.enabled }} + - name: cvat-backend-data + persistentVolumeClaim: + claimName: "{{ .Release.Name }}-backend-data" + {{- end }} + {{- with concat .Values.cvat.backend.additionalVolumes $localValues.additionalVolumes }} + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.imagePullSecrets }} + imagePullSecrets: + {{- toYaml . | nindent 8 }} + {{- end }} diff --git a/helm-chart/test.values.yaml b/helm-chart/test.values.yaml index 73edaa815d70..350cc384c178 100644 --- a/helm-chart/test.values.yaml +++ b/helm-chart/test.values.yaml @@ -14,6 +14,15 @@ cvat: value: cvat.settings.testing_rest worker: import: + replicas: 1 + additionalVolumeMounts: + - mountPath: /home/django/share + name: cvat-backend-data + subPath: share + export: + replicas: 1 + chunks: + replicas: 1 additionalVolumeMounts: - mountPath: /home/django/share name: cvat-backend-data @@ -22,6 +31,8 @@ cvat: additionalEnv: - name: DJANGO_SETTINGS_MODULE value: cvat.settings.testing_rest + annotation: + replicas: 0 # Images are already present in the node imagePullPolicy: Never frontend: diff --git a/helm-chart/values.yaml b/helm-chart/values.yaml index b99625f1a104..2e36950b3f2a 100644 --- a/helm-chart/values.yaml +++ b/helm-chart/values.yaml @@ -117,6 +117,16 @@ cvat: additionalEnv: [] additionalVolumes: [] additionalVolumeMounts: [] + chunks: + replicas: 2 + labels: {} + annotations: {} + resources: {} + affinity: {} + tolerations: [] + additionalEnv: [] + additionalVolumes: [] + additionalVolumeMounts: [] utils: replicas: 1 labels: {} @@ -129,13 +139,12 @@ cvat: additionalVolumeMounts: [] replicas: 1 image: cvat/server - tag: v2.22.0 + tag: v2.23.0 imagePullPolicy: Always permissionFix: enabled: true service: - annotations: - traefik.ingress.kubernetes.io/service.sticky.cookie: "true" + annotations: {} spec: type: ClusterIP ports: @@ -153,7 +162,7 @@ cvat: frontend: replicas: 1 image: cvat/ui - tag: v2.22.0 + tag: v2.23.0 imagePullPolicy: Always labels: {} # test: test diff --git a/serverless/deploy_cpu.sh b/serverless/deploy_cpu.sh index 03d6f17bad67..9f37ea020a6b 100755 --- a/serverless/deploy_cpu.sh +++ b/serverless/deploy_cpu.sh @@ -25,7 +25,10 @@ do echo "Deploying $func_rel_path function..." nuctl deploy --project-name cvat --path "$func_root" \ - --file "$func_config" --platform local + --file "$func_config" --platform local \ + --env CVAT_REDIS_HOST=$(echo ${CVAT_REDIS_INMEM_HOST:-cvat_redis_ondisk}) \ + --env CVAT_REDIS_PORT=$(echo ${CVAT_REDIS_INMEM_PORT:-6666}) \ + --env CVAT_REDIS_PASSWORD=$(echo ${CVAT_REDIS_INMEM_PASSWORD}) done nuctl get function --platform local diff --git a/serverless/deploy_gpu.sh b/serverless/deploy_gpu.sh index c813a8232ad4..9c8e1515b73b 100755 --- a/serverless/deploy_gpu.sh +++ b/serverless/deploy_gpu.sh @@ -17,7 +17,10 @@ do echo "Deploying $func_rel_path function..." nuctl deploy --project-name cvat --path "$func_root" \ - --file "$func_config" --platform local + --file "$func_config" --platform local \ + --env CVAT_REDIS_HOST=$(echo ${CVAT_REDIS_INMEM_HOST:-cvat_redis_ondisk}) \ + --env CVAT_REDIS_PORT=$(echo ${CVAT_REDIS_INMEM_PORT:-6666}) \ + --env CVAT_REDIS_PASSWORD=$(echo ${CVAT_REDIS_INMEM_PASSWORD}) done nuctl get function --platform local diff --git a/site/content/en/docs/administration/advanced/analytics.md b/site/content/en/docs/administration/advanced/analytics.md index 445c212f9687..b99ca0a824c2 100644 --- a/site/content/en/docs/administration/advanced/analytics.md +++ b/site/content/en/docs/administration/advanced/analytics.md @@ -135,6 +135,12 @@ Server events: - `call:function` +- `create:membership`, `update:membership`, `delete:membership` + +- `create:webhook`, `update:webhook`, `delete:webhook` + +- `create:invitation`, `delete:invitation` + Client events: - `load:cvat` diff --git a/site/content/en/docs/api_sdk/sdk/_index.md b/site/content/en/docs/api_sdk/sdk/_index.md index e9683583ab0e..e855dadd979f 100644 --- a/site/content/en/docs/api_sdk/sdk/_index.md +++ b/site/content/en/docs/api_sdk/sdk/_index.md @@ -42,7 +42,14 @@ To install an [official release of CVAT SDK](https://pypi.org/project/cvat-sdk/) pip install cvat-sdk ``` -To use the PyTorch adapter, request the `pytorch` extra: +To use the `cvat_sdk.masks` module, request the `masks` extra: + +```bash +pip install "cvat-sdk[masks]" +``` + +To use the PyTorch adapter or the built-in PyTorch-based auto-annotation functions, +request the `pytorch` extra: ```bash pip install "cvat-sdk[pytorch]" diff --git a/site/content/en/docs/api_sdk/sdk/auto-annotation.md b/site/content/en/docs/api_sdk/sdk/auto-annotation.md index 24e16c7e6218..d8401955da7f 100644 --- a/site/content/en/docs/api_sdk/sdk/auto-annotation.md +++ b/site/content/en/docs/api_sdk/sdk/auto-annotation.md @@ -68,7 +68,12 @@ class TorchvisionDetectionFunction: ] ) - def detect(self, context, image: PIL.Image.Image) -> List[models.LabeledShapeRequest]: + def detect( + self, context: cvataa.DetectionFunctionContext, image: PIL.Image.Image + ) -> list[models.LabeledShapeRequest]: + # determine the threshold for filtering results + conf_threshold = context.conf_threshold or 0 + # convert the input into a form the model can understand transformed_image = [self._transforms(image)] @@ -79,7 +84,8 @@ class TorchvisionDetectionFunction: return [ cvataa.rectangle(label.item(), [x.item() for x in box]) for result in results - for box, label in zip(result['boxes'], result['labels']) + for box, label, score in zip(result["boxes"], result["labels"], result["scores"]) + if score >= conf_threshold ] # log into the CVAT server @@ -112,9 +118,13 @@ that these objects must follow. `detect` must be a function/method accepting two parameters: - `context` (`DetectionFunctionContext`). - Contains information about the current image. - Currently `DetectionFunctionContext` only contains a single field, `frame_name`, - which contains the file name of the frame on the CVAT server. + Contains invocation parameters and information about the current image. + The following fields are available: + + - `frame_name` (`str`). The file name of the frame on the CVAT server. + - `conf_threshold` (`float | None`). The confidence threshold that the function + should use to filter objects. If `None`, the function may apply a default + threshold at its discretion. - `image` (`PIL.Image.Image`). Contains image data. @@ -171,10 +181,23 @@ The following helpers are available for use in `detect`: | Name | Model type | Fixed attributes | |-------------|--------------------------|-------------------------------| | `shape` | `LabeledShapeRequest` | `frame=0` | +| `mask` | `LabeledShapeRequest` | `frame=0`, `type="mask"` | +| `polygon` | `LabeledShapeRequest` | `frame=0`, `type="polygon"` | | `rectangle` | `LabeledShapeRequest` | `frame=0`, `type="rectangle"` | | `skeleton` | `LabeledShapeRequest` | `frame=0`, `type="skeleton"` | | `keypoint` | `SubLabeledShapeRequest` | `frame=0`, `type="points"` | +For `mask`, it is recommended to create the points list using +the `cvat.masks.encode_mask` function, which will convert a bitmap into a +list in the format that CVAT expects. For example: + +```python +cvataa.mask(my_label, encode_mask( + my_mask, # boolean 2D array, same size as the input image + [x1, y1, x2, y2], # top left and bottom right coordinates of the mask +)) +``` + ## Auto-annotation driver The `annotate_task` function uses an AA function to annotate a CVAT task. @@ -195,6 +218,9 @@ If you use `allow_unmatched_label=True`, then such labels will be ignored, and any shapes referring to them will be dropped. Same logic applies to sub-label IDs. +It's possible to pass a custom confidence threshold to the function via the +`conf_threshold` parameter. + `annotate_task` will raise a `BadFunctionError` exception if it detects that the function violated the AA function protocol. @@ -244,10 +270,18 @@ The `create` function accepts the following parameters: It also accepts arbitrary additional parameters, which are passed directly to the model constructor. +### `cvat_sdk.auto_annotation.functions.torchvision_instance_segmentation` + +This AA function is analogous to `torchvision_detection`, +except it uses torchvision's instance segmentation models and produces mask +or polygon annotations (depending on the value of `conv_mask_to_poly`). + +Refer to that function's description for usage instructions and parameter information. + ### `cvat_sdk.auto_annotation.functions.torchvision_keypoint_detection` This AA function is analogous to `torchvision_detection`, except it uses torchvision's keypoint detection models and produces skeleton annotations. Keypoints which the model marks as invisible will be marked as occluded in CVAT. -Refer to the previous section for usage instructions and parameter information. +Refer to that function's description for usage instructions and parameter information. diff --git a/supervisord/worker.chunks.conf b/supervisord/worker.chunks.conf new file mode 100644 index 000000000000..9eccd41e8cba --- /dev/null +++ b/supervisord/worker.chunks.conf @@ -0,0 +1,29 @@ +[unix_http_server] +file = /tmp/supervisord/supervisor.sock + +[supervisorctl] +serverurl = unix:///tmp/supervisord/supervisor.sock + + +[rpcinterface:supervisor] +supervisor.rpcinterface_factory = supervisor.rpcinterface:make_main_rpcinterface + +[supervisord] +nodaemon=true +logfile=%(ENV_HOME)s/logs/supervisord.log ; supervisord log file +logfile_maxbytes=50MB ; maximum size of logfile before rotation +logfile_backups=10 ; number of backed up logfiles +loglevel=debug ; info, debug, warn, trace +pidfile=/tmp/supervisord/supervisord.pid ; pidfile location + +[program:rqworker-chunks] +command=%(ENV_HOME)s/wait_for_deps.sh + python3 %(ENV_HOME)s/manage.py rqworker -v 3 chunks + --worker-class cvat.rqworker.DefaultWorker +environment=VECTOR_EVENT_HANDLER="SynchronousLogstashHandler",CVAT_POSTGRES_APPLICATION_NAME="cvat:worker:chunks" +numprocs=%(ENV_NUMPROCS)s +process_name=%(program_name)s-%(process_num)d +autorestart=true + +[program:smokescreen] +command=smokescreen --listen-ip=127.0.0.1 %(ENV_SMOKESCREEN_OPTS)s diff --git a/tests/cypress/e2e/features/annotations_actions.js b/tests/cypress/e2e/features/annotations_actions.js index cda91f9c33ba..55fe7542c680 100644 --- a/tests/cypress/e2e/features/annotations_actions.js +++ b/tests/cypress/e2e/features/annotations_actions.js @@ -86,47 +86,6 @@ context('Testing annotations actions workflow', () => { cy.closeAnnotationsActionsModal(); }); - - it('Recommendation to save the job appears if there are unsaved changes', () => { - cy.createRectangle({ - points: 'By 2 Points', - type: 'Shape', - labelName: taskPayload.labels[0].name, - firstX: 250, - firstY: 350, - secondX: 350, - secondY: 450, - }); - - cy.openAnnotationsActionsModal(); - cy.intercept(`/api/jobs/${jobID}/annotations?**action=create**`).as('createAnnotationsRequest'); - cy.get('.cvat-action-runner-save-job-recommendation').should('exist').and('be.visible').click(); - cy.wait('@createAnnotationsRequest').its('response.statusCode').should('equal', 200); - cy.get('.cvat-action-runner-save-job-recommendation').should('not.exist'); - - cy.closeAnnotationsActionsModal(); - }); - - it('Recommendation to disable automatic saving appears in modal if automatic saving is enabled', () => { - cy.openSettings(); - cy.contains('Workspace').click(); - cy.get('.cvat-workspace-settings-auto-save').within(() => { - cy.get('[type="checkbox"]').check(); - }); - cy.closeSettings(); - - cy.openAnnotationsActionsModal(); - cy.get('.cvat-action-runner-disable-autosave-recommendation').should('exist').and('be.visible').click(); - cy.get('.cvat-action-runner-disable-autosave-recommendation').should('not.exist'); - cy.closeAnnotationsActionsModal(); - - cy.openSettings(); - cy.contains('Workspace').click(); - cy.get('.cvat-workspace-settings-auto-save').within(() => { - cy.get('[type="checkbox"]').should('not.be.checked'); - }); - cy.closeSettings(); - }); }); describe('Test action: "Remove filtered shapes"', () => { @@ -374,7 +333,7 @@ context('Testing annotations actions workflow', () => { cy.goCheckFrameNumber(latestFrameNumber); cy.get('.cvat_canvas_shape').should('have.length', 1); - cy.saveJob('PUT', 200, 'saveJob'); + cy.saveJob('PATCH', 200, 'saveJob'); const exportAnnotation = { as: 'exportAnnotations', type: 'annotations', diff --git a/tests/cypress/e2e/features/ground_truth_jobs.js b/tests/cypress/e2e/features/ground_truth_jobs.js index 0753d59839cc..482a940c3d68 100644 --- a/tests/cypress/e2e/features/ground_truth_jobs.js +++ b/tests/cypress/e2e/features/ground_truth_jobs.js @@ -4,12 +4,11 @@ /// +import { defaultTaskSpec } from '../../support/default-specs'; + context('Ground truth jobs', () => { - const caseId = 'Ground truth jobs'; const labelName = 'car'; - const taskName = `Annotation task for Case ${caseId}`; - const attrName = `Attr for Case ${caseId}`; - const textDefaultValue = 'Some default value for type Text'; + const taskName = 'Annotation task for Ground truth jobs'; const jobOptions = { jobType: 'Ground truth', @@ -17,6 +16,12 @@ context('Ground truth jobs', () => { fromTaskPage: true, }; + const defaultValidationParams = { + frameCount: 3, + mode: 'gt', + frameSelectionMethod: 'random_uniform', + }; + const groundTruthRectangles = [ { id: 1, @@ -64,8 +69,8 @@ context('Ground truth jobs', () => { let jobID = null; let taskID = null; - // With seed = 1, frameCount = 4, totalFrames = 10 - predifined ground truth frames are: - const groundTruthFrames = [0, 1, 5, 6]; + // With seed = 1, frameCount = 4, totalFrames = 100 - predifined ground truth frames are: + const groundTruthFrames = [10, 23, 71, 87]; function checkRectangleAndObjectMenu(rectangle, isGroundTruthJob = false) { if (isGroundTruthJob) { @@ -97,36 +102,33 @@ context('Ground truth jobs', () => { cy.get('.cvat-quality-control-management-tab').should('exist').and('be.visible'); } + function createAndOpenTask(serverFiles, validationParams = null) { + const { taskSpec, dataSpec, extras } = defaultTaskSpec({ + taskName, serverFiles, labelName, validationParams, + }); + return cy.headlessCreateTask(taskSpec, dataSpec, extras).then((taskResponse) => { + taskID = taskResponse.taskID; + if (validationParams) { + [groundTruthJobID, jobID] = taskResponse.jobIDs; + } else { + [jobID] = taskResponse.jobIDs; + } + }).then(() => { + cy.visit(`/tasks/${taskID}`); + cy.get('.cvat-task-details').should('exist').and('be.visible'); + }); + } + before(() => { cy.visit('auth/login'); cy.login(); }); describe('Testing ground truth basics', () => { - const imagesCount = 10; - const imageFileName = 'ground_truth_1'; - const width = 800; - const height = 800; - const posX = 10; - const posY = 10; - const color = 'gray'; - const archiveName = `${imageFileName}.zip`; - const archivePath = `cypress/fixtures/${archiveName}`; - const imagesFolder = `cypress/fixtures/${imageFileName}`; - const directoryToArchive = imagesFolder; + const serverFiles = ['bigArchive.zip']; before(() => { - cy.visit('/tasks'); - cy.imageGenerator(imagesFolder, imageFileName, width, height, color, posX, posY, labelName, imagesCount); - cy.createZipArchive(directoryToArchive, archivePath); - cy.createAnnotationTask(taskName, labelName, attrName, textDefaultValue, archiveName); - cy.openTask(taskName); - cy.url().then((url) => { - taskID = Number(url.split('/').slice(-1)[0].split('?')[0]); - }); - cy.get('.cvat-job-item').first().invoke('attr', 'data-row-id').then((val) => { - jobID = val; - }); + createAndOpenTask(serverFiles); }); after(() => { @@ -196,35 +198,80 @@ context('Ground truth jobs', () => { }); }); + describe('Testing creating task with quality params', () => { + const imagesCount = 3; + const imageFileName = `image_${taskName.replace(' ', '_').toLowerCase()}`; + const width = 800; + const height = 800; + const posX = 10; + const posY = 10; + const color = 'gray'; + const archiveName = `${imageFileName}.zip`; + const archivePath = `cypress/fixtures/${archiveName}`; + const imagesFolder = `cypress/fixtures/${imageFileName}`; + const directoryToArchive = imagesFolder; + const attrName = 'gt_attr'; + const defaultAttrValue = 'GT attr'; + const multiAttrParams = false; + const forProject = false; + const attachToProject = false; + const projectName = null; + const expectedResult = 'success'; + const projectSubsetFieldValue = null; + const advancedConfigurationParams = false; + + before(() => { + cy.contains('.cvat-header-button', 'Tasks').should('be.visible').click(); + cy.url().should('include', '/tasks'); + cy.imageGenerator(imagesFolder, imageFileName, width, height, color, posX, posY, labelName, imagesCount); + cy.createZipArchive(directoryToArchive, archivePath); + }); + + afterEach(() => { + cy.goToTaskList(); + cy.deleteTask(taskName); + }); + + function createTaskWithQualityParams(qualityParams) { + cy.createAnnotationTask( + taskName, + labelName, + attrName, + defaultAttrValue, + archiveName, + multiAttrParams, + advancedConfigurationParams, + forProject, + attachToProject, + projectName, + expectedResult, + projectSubsetFieldValue, + qualityParams, + ); + cy.openTask(taskName); + cy.get('.cvat-job-item').first() + .find('.ant-tag') + .should('have.text', 'Ground truth'); + } + + it('Create task with ground truth job', () => { + createTaskWithQualityParams({ + validationMode: 'Ground Truth', + }); + }); + + it('Create task with honeypots', () => { + createTaskWithQualityParams({ + validationMode: 'Honeypots', + }); + }); + }); + describe('Testing ground truth management basics', () => { const serverFiles = ['images/image_1.jpg', 'images/image_2.jpg', 'images/image_3.jpg']; before(() => { - cy.headlessCreateTask({ - labels: [{ name: labelName, attributes: [], type: 'any' }], - name: taskName, - project_id: null, - source_storage: { location: 'local' }, - target_storage: { location: 'local' }, - }, { - server_files: serverFiles, - image_quality: 70, - use_zip_chunks: true, - use_cache: true, - sorting_method: 'lexicographical', - }).then((taskResponse) => { - taskID = taskResponse.taskID; - [jobID] = taskResponse.jobIDs; - }).then(() => ( - cy.headlessCreateJob({ - task_id: taskID, - frame_count: 3, - type: 'ground_truth', - frame_selection_method: 'random_uniform', - }) - )).then((jobResponse) => { - groundTruthJobID = jobResponse.jobID; - }).then(() => { + createAndOpenTask(serverFiles, defaultValidationParams).then(() => { cy.visit(`/tasks/${taskID}/quality-control#management`); cy.get('.cvat-quality-control-management-tab').should('exist').and('be.visible'); cy.get('.cvat-annotations-quality-allocation-table-summary').should('exist').and('be.visible'); @@ -312,35 +359,10 @@ context('Ground truth jobs', () => { }); describe('Regression tests', () => { - const imagesCount = 20; - const imageFileName = 'ground_truth_2'; - const width = 100; - const height = 100; - const posX = 10; - const posY = 10; - const color = 'gray'; - const archiveName = `${imageFileName}.zip`; - const archivePath = `cypress/fixtures/${archiveName}`; - const imagesFolder = `cypress/fixtures/${imageFileName}`; - const directoryToArchive = imagesFolder; + const serverFiles = ['bigArchive.zip']; - before(() => { - cy.visit('/tasks'); - cy.imageGenerator(imagesFolder, imageFileName, width, height, color, posX, posY, labelName, imagesCount); - cy.createZipArchive(directoryToArchive, archivePath); - cy.createAnnotationTask( - taskName, - labelName, - attrName, - textDefaultValue, - archiveName, - false, - { multiJobs: true, segmentSize: 1 }, - ); - cy.openTask(taskName); - cy.url().then((url) => { - taskID = Number(url.split('/').slice(-1)[0].split('?')[0]); - }); + beforeEach(() => { + createAndOpenTask(serverFiles); }); afterEach(() => { @@ -378,5 +400,51 @@ context('Ground truth jobs', () => { jobID = Number(url.split('/').slice(-1)[0].split('?')[0]); }).should('match', /\/tasks\/\d+\/jobs\/\d+/); }); + + it('Check GT annotations can not be shown in standard annotation view', () => { + cy.headlessCreateJob({ + task_id: taskID, + frame_count: 4, + type: 'ground_truth', + frame_selection_method: 'random_uniform', + seed: 1, + }).then((jobResponse) => { + groundTruthJobID = jobResponse.jobID; + return cy.headlessCreateObjects(groundTruthFrames.map((frame, index) => { + const gtRect = groundTruthRectangles[index]; + return { + labelName, + objectType: 'shape', + shapeType: 'rectangle', + occluded: false, + frame, + points: [gtRect.firstX, gtRect.firstY, gtRect.secondX, gtRect.secondY], + }; + }), groundTruthJobID); + }).then(() => { + cy.visit(`/tasks/${taskID}/jobs/${jobID}`); + cy.get('.cvat-canvas-container').should('exist'); + + cy.changeWorkspace('Review'); + cy.get('.cvat-objects-sidebar-show-ground-truth').click(); + cy.get('.cvat-objects-sidebar-show-ground-truth').should( + 'have.class', 'cvat-objects-sidebar-show-ground-truth-active', + ); + groundTruthFrames.forEach((frame, index) => { + cy.goCheckFrameNumber(frame); + checkRectangleAndObjectMenu(groundTruthRectangles[index]); + }); + + cy.interactMenu('Open the task'); + cy.get('.cvat-task-job-list').within(() => { + cy.contains('a', `Job #${jobID}`).click(); + }); + groundTruthFrames.forEach((frame) => { + cy.goCheckFrameNumber(frame); + cy.get('.cvat_canvas_shape').should('not.exist'); + cy.get('.cvat-objects-sidebar-state-item').should('not.exist'); + }); + }); + }); }); }); diff --git a/tests/cypress/support/commands.js b/tests/cypress/support/commands.js index d3988c2e56a0..f0f085260cbf 100644 --- a/tests/cypress/support/commands.js +++ b/tests/cypress/support/commands.js @@ -178,6 +178,7 @@ Cypress.Commands.add( projectName = '', expectedResult = 'success', projectSubsetFieldValue = 'Test', + qualityConfigurationParams = null, ) => { cy.url().then(() => { cy.get('.cvat-create-task-dropdown').click(); @@ -215,6 +216,9 @@ Cypress.Commands.add( if (advancedConfigurationParams) { cy.advancedConfiguration(advancedConfigurationParams); } + if (qualityConfigurationParams) { + cy.configureTaskQualityMode(qualityConfigurationParams); + } cy.get('.cvat-submit-continue-task-button').scrollIntoView(); cy.get('.cvat-submit-continue-task-button').click(); if (expectedResult === 'success') { @@ -291,7 +295,7 @@ Cypress.Commands.add('headlessCreateObjects', (objects, jobID) => { }); }); -Cypress.Commands.add('headlessCreateTask', (taskSpec, dataSpec) => { +Cypress.Commands.add('headlessCreateTask', (taskSpec, dataSpec, extras) => { cy.window().then(async ($win) => { const task = new $win.cvat.classes.Task({ ...taskSpec, @@ -310,7 +314,7 @@ Cypress.Commands.add('headlessCreateTask', (taskSpec, dataSpec) => { task.remoteFiles = dataSpec.remote_files; } - const result = await task.save(); + const result = await task.save(extras || {}); return cy.wrap({ taskID: result.id, jobIDs: result.jobs.map((job) => job.id) }); }); }); @@ -897,6 +901,15 @@ Cypress.Commands.add('advancedConfiguration', (advancedConfigurationParams) => { } }); +Cypress.Commands.add('configureTaskQualityMode', (qualityConfigurationParams) => { + cy.contains('Quality').click(); + if (qualityConfigurationParams.validationMode) { + cy.get('#validationMode').within(() => { + cy.contains(qualityConfigurationParams.validationMode).click(); + }); + } +}); + Cypress.Commands.add('removeAnnotations', () => { cy.contains('.cvat-annotation-header-button', 'Menu').click(); cy.get('.cvat-annotation-menu').within(() => { diff --git a/tests/cypress/support/default-specs.js b/tests/cypress/support/default-specs.js new file mode 100644 index 000000000000..ea07bab747b2 --- /dev/null +++ b/tests/cypress/support/default-specs.js @@ -0,0 +1,63 @@ +// Copyright (C) 2024 CVAT.ai Corporation +// +// SPDX-License-Identifier: MIT + +function defaultTaskSpec({ + labelName, + taskName, + serverFiles, + validationParams, +}) { + const taskSpec = { + labels: [ + { name: labelName, attributes: [], type: 'any' }, + ], + name: taskName, + project_id: null, + source_storage: { location: 'local' }, + target_storage: { location: 'local' }, + }; + + const dataSpec = { + server_files: serverFiles, + image_quality: 70, + use_zip_chunks: true, + use_cache: true, + sorting_method: (validationParams && validationParams.mode === 'gt_pool') ? 'random' : 'lexicographical', + }; + + const extras = {}; + if (validationParams) { + const convertedParams = {}; + if (validationParams.frames) { + convertedParams.frames = validationParams.frames; + } + if (validationParams.frameSelectionMethod) { + convertedParams.frame_selection_method = validationParams.frameSelectionMethod; + } + if (validationParams.frameCount) { + convertedParams.frame_count = validationParams.frameCount; + } + if (validationParams.framesPerJobCount) { + convertedParams.frames_per_job_count = validationParams.framesPerJobCount; + } + if (validationParams.mode) { + convertedParams.mode = validationParams.mode; + } + if (validationParams.randomSeed) { + convertedParams.random_seed = validationParams.randomSeed; + } + + extras.validation_params = convertedParams; + } + + return { + taskSpec, + dataSpec, + extras, + }; +} + +module.exports = { + defaultTaskSpec, +}; diff --git a/tests/docker-compose.file_share.yml b/tests/docker-compose.file_share.yml index 3ceeb355f687..bca485ad48c8 100644 --- a/tests/docker-compose.file_share.yml +++ b/tests/docker-compose.file_share.yml @@ -5,3 +5,6 @@ services: cvat_server: volumes: - ./tests/mounted_file_share:/home/django/share:rw + cvat_worker_chunks: + volumes: + - ./tests/mounted_file_share:/home/django/share:rw diff --git a/tests/docker-compose.minio.yml b/tests/docker-compose.minio.yml index 6f82aadd1806..6089aa69f8bf 100644 --- a/tests/docker-compose.minio.yml +++ b/tests/docker-compose.minio.yml @@ -8,6 +8,7 @@ services: cvat_server: *allow-minio cvat_worker_export: *allow-minio cvat_worker_import: *allow-minio + cvat_worker_chunks: *allow-minio minio: image: quay.io/minio/minio:RELEASE.2022-09-17T00-09-45Z diff --git a/tests/python/cli/cmtp_function.py b/tests/python/cli/cmtp_function.py new file mode 100644 index 000000000000..2ae5cb26f663 --- /dev/null +++ b/tests/python/cli/cmtp_function.py @@ -0,0 +1,22 @@ +# Copyright (C) 2024 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +import cvat_sdk.auto_annotation as cvataa +import cvat_sdk.models as models +import PIL.Image + +spec = cvataa.DetectionFunctionSpec( + labels=[ + cvataa.label_spec("car", 0), + ], +) + + +def detect( + context: cvataa.DetectionFunctionContext, image: PIL.Image.Image +) -> list[models.LabeledShapeRequest]: + if context.conv_mask_to_poly: + return [cvataa.polygon(0, [0, 0, 0, 1, 1, 1])] + else: + return [cvataa.mask(0, [1, 0, 0, 0, 0])] diff --git a/tests/python/cli/conf_threshold_function.py b/tests/python/cli/conf_threshold_function.py new file mode 100644 index 000000000000..bcb1add2d660 --- /dev/null +++ b/tests/python/cli/conf_threshold_function.py @@ -0,0 +1,21 @@ +# Copyright (C) 2024 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +import cvat_sdk.auto_annotation as cvataa +import cvat_sdk.models as models +import PIL.Image + +spec = cvataa.DetectionFunctionSpec( + labels=[ + cvataa.label_spec("car", 0), + ], +) + + +def detect( + context: cvataa.DetectionFunctionContext, image: PIL.Image.Image +) -> list[models.LabeledShapeRequest]: + return [ + cvataa.rectangle(0, [context.conf_threshold, 1, 1, 1]), + ] diff --git a/tests/python/cli/test_cli.py b/tests/python/cli/test_cli.py index d6b19cfe0a3c..8008f44270ab 100644 --- a/tests/python/cli/test_cli.py +++ b/tests/python/cli/test_cli.py @@ -27,6 +27,8 @@ class TestCLI: def setup( self, restore_db_per_function, # force fixture call order to allow DB setup + restore_redis_inmem_per_function, + restore_redis_ondisk_per_function, fxt_stdout: io.StringIO, tmp_path: Path, admin_user: str, @@ -347,3 +349,39 @@ def test_auto_annotate_with_parameters(self, fxt_new_task: Task): annotations = fxt_new_task.get_annotations() assert annotations.shapes + + def test_auto_annotate_with_threshold(self, fxt_new_task: Task): + annotations = fxt_new_task.get_annotations() + assert not annotations.shapes + + self.run_cli( + "auto-annotate", + str(fxt_new_task.id), + f"--function-module={__package__}.conf_threshold_function", + "--conf-threshold=0.75", + ) + + annotations = fxt_new_task.get_annotations() + assert annotations.shapes[0].points[0] == 0.75 + + def test_auto_annotate_with_cmtp(self, fxt_new_task: Task): + self.run_cli( + "auto-annotate", + str(fxt_new_task.id), + f"--function-module={__package__}.cmtp_function", + "--clear-existing", + ) + + annotations = fxt_new_task.get_annotations() + assert annotations.shapes[0].type.value == "mask" + + self.run_cli( + "auto-annotate", + str(fxt_new_task.id), + f"--function-module={__package__}.cmtp_function", + "--clear-existing", + "--conv-mask-to-poly", + ) + + annotations = fxt_new_task.get_annotations() + assert annotations.shapes[0].type.value == "polygon" diff --git a/tests/python/requirements.txt b/tests/python/requirements.txt index 6ef44c0f5edb..5dfad3d6f7fb 100644 --- a/tests/python/requirements.txt +++ b/tests/python/requirements.txt @@ -4,9 +4,9 @@ pytest-cases==3.6.13 pytest-timeout==2.1.0 pytest-cov==4.1.0 requests==2.32.2 -deepdiff==5.6.0 +deepdiff==7.0.1 boto3==1.17.61 Pillow==10.3.0 python-dateutil==2.8.2 pyyaml==6.0.0 -numpy==1.22.0 \ No newline at end of file +numpy==2.0.0 diff --git a/tests/python/rest_api/test_jobs.py b/tests/python/rest_api/test_jobs.py index 5057f652030c..e7b405dce9e9 100644 --- a/tests/python/rest_api/test_jobs.py +++ b/tests/python/rest_api/test_jobs.py @@ -691,6 +691,7 @@ def test_get_gt_job_in_org_task( @pytest.mark.usefixtures("restore_db_per_class") @pytest.mark.usefixtures("restore_redis_ondisk_per_class") +@pytest.mark.usefixtures("restore_redis_inmem_per_class") class TestGetGtJobData: def _delete_gt_job(self, user, gt_job_id): with make_api_client(user) as api_client: diff --git a/tests/python/rest_api/test_projects.py b/tests/python/rest_api/test_projects.py index abfccd5f6b03..d3d807d68088 100644 --- a/tests/python/rest_api/test_projects.py +++ b/tests/python/rest_api/test_projects.py @@ -19,7 +19,7 @@ from typing import Optional, Union import pytest -from cvat_sdk.api_client import ApiClient, Configuration, models +from cvat_sdk.api_client import ApiClient, Configuration, exceptions, models from cvat_sdk.api_client.api_client import Endpoint from cvat_sdk.api_client.exceptions import ForbiddenException from cvat_sdk.core.helpers import get_paginated_collection @@ -37,8 +37,10 @@ from shared.utils.helpers import generate_image_files from .utils import ( + DATUMARO_FORMAT_FOR_DIMENSION, CollectionSimpleFilterTestBase, create_task, + export_dataset, export_project_backup, export_project_dataset, ) @@ -991,6 +993,68 @@ def test_can_export_and_import_dataset_after_deleting_related_storage( self._test_import_project(admin_user, project_id, "CVAT 1.1", import_data) + @pytest.mark.parametrize( + "dimension, format_name", + [ + *DATUMARO_FORMAT_FOR_DIMENSION.items(), + ("2d", "CVAT 1.1"), + ("3d", "CVAT 1.1"), + ("2d", "COCO 1.0"), + ], + ) + def test_cant_import_annotations_as_project(self, admin_user, tasks, format_name, dimension): + task = next(t for t in tasks if t.get("size") if t["dimension"] == dimension) + + def _export_task(task_id: int, format_name: str) -> io.BytesIO: + with make_api_client(admin_user) as api_client: + return io.BytesIO( + export_dataset( + api_client.tasks_api, + api_version=2, + id=task_id, + format=format_name, + save_images=False, + ) + ) + + if format_name in list(DATUMARO_FORMAT_FOR_DIMENSION.values()): + with zipfile.ZipFile(_export_task(task["id"], format_name)) as zip_file: + annotations = zip_file.read("annotations/default.json") + + dataset_file = io.BytesIO(annotations) + dataset_file.name = "annotations.json" + elif format_name == "CVAT 1.1": + with zipfile.ZipFile(_export_task(task["id"], "CVAT for images 1.1")) as zip_file: + annotations = zip_file.read("annotations.xml") + + dataset_file = io.BytesIO(annotations) + dataset_file.name = "annotations.xml" + elif format_name == "COCO 1.0": + with zipfile.ZipFile(_export_task(task["id"], format_name)) as zip_file: + annotations = zip_file.read("annotations/instances_default.json") + + dataset_file = io.BytesIO(annotations) + dataset_file.name = "annotations.json" + else: + assert False + + with make_api_client(admin_user) as api_client: + project, _ = api_client.projects_api.create( + project_write_request=models.ProjectWriteRequest( + name=f"test_annotations_import_as_project {format_name}" + ) + ) + + import_data = {"dataset_file": dataset_file} + + with pytest.raises(exceptions.ApiException, match="Dataset file should be zip archive"): + self._test_import_project( + admin_user, + project.id, + format_name=format_name, + data=import_data, + ) + @pytest.mark.parametrize( "export_format, subset_path_template", [ @@ -1045,10 +1109,7 @@ def test_creates_subfolders_for_subsets_on_export( len([f for f in zip_file.namelist() if f.startswith(folder_prefix)]) > 0 ), f"No {folder_prefix} in {zip_file.namelist()}" - def test_export_project_with_honeypots( - self, - admin_user: str, - ): + def test_export_project_with_honeypots(self, admin_user: str): project_spec = { "name": "Project with honeypots", "labels": [{"name": "cat"}], diff --git a/tests/python/rest_api/test_tasks.py b/tests/python/rest_api/test_tasks.py index a61683d981e2..be49c9d43ca1 100644 --- a/tests/python/rest_api/test_tasks.py +++ b/tests/python/rest_api/test_tasks.py @@ -64,9 +64,11 @@ ) from .utils import ( + DATUMARO_FORMAT_FOR_DIMENSION, CollectionSimpleFilterTestBase, compare_annotations, create_task, + export_dataset, export_task_backup, export_task_dataset, parse_frame_step, @@ -888,6 +890,7 @@ def test_can_export_task_to_coco_format(self, admin_user: str, tid: int, api_ver @pytest.mark.parametrize("api_version", (1, 2)) @pytest.mark.usefixtures("restore_db_per_function") + @pytest.mark.usefixtures("restore_redis_ondisk_per_function") def test_can_download_task_with_special_chars_in_name(self, admin_user: str, api_version: int): # Control characters in filenames may conflict with the Content-Disposition header # value restrictions, as it needs to include the downloaded file name. @@ -969,11 +972,52 @@ def test_uses_subset_name( subset_path in path for path in zip_file.namelist() ), f"No {subset_path} in {zip_file.namelist()}" + @pytest.mark.parametrize( + "dimension, mode", [("2d", "annotation"), ("2d", "interpolation"), ("3d", "annotation")] + ) + def test_datumaro_export_without_annotations_includes_image_info( + self, admin_user, tasks, mode, dimension + ): + task = next( + t for t in tasks if t.get("size") if t["mode"] == mode if t["dimension"] == dimension + ) + + with make_api_client(admin_user) as api_client: + dataset_file = io.BytesIO( + export_dataset( + api_client.tasks_api, + api_version=2, + id=task["id"], + format=DATUMARO_FORMAT_FOR_DIMENSION[dimension], + save_images=False, + ) + ) + + with zipfile.ZipFile(dataset_file) as zip_file: + annotations = json.loads(zip_file.read("annotations/default.json")) + + assert annotations["items"] + for item in annotations["items"]: + assert "media" not in item + + if dimension == "2d": + assert osp.splitext(item["image"]["path"])[0] == item["id"] + assert not Path(item["image"]["path"]).is_absolute() + assert tuple(item["image"]["size"]) > (0, 0) + elif dimension == "3d": + assert osp.splitext(osp.basename(item["point_cloud"]["path"]))[0] == item["id"] + assert not Path(item["point_cloud"]["path"]).is_absolute() + for related_image in item["related_images"]: + assert not Path(related_image["path"]).is_absolute() + if "size" in related_image: + assert tuple(related_image["size"]) > (0, 0) + @pytest.mark.usefixtures("restore_db_per_function") @pytest.mark.usefixtures("restore_cvat_data_per_function") @pytest.mark.usefixtures("restore_redis_ondisk_per_function") @pytest.mark.usefixtures("restore_redis_ondisk_after_class") +@pytest.mark.usefixtures("restore_redis_inmem_per_function") class TestPostTaskData: _USERNAME = "admin1" @@ -2683,8 +2727,9 @@ def read_frame(self, i: int) -> Image.Image: @pytest.mark.usefixtures("restore_db_per_class") @pytest.mark.usefixtures("restore_cvat_data_per_class") -@pytest.mark.usefixtures("restore_redis_ondisk_per_class") +@pytest.mark.usefixtures("restore_redis_ondisk_per_function") @pytest.mark.usefixtures("restore_redis_ondisk_after_class") +@pytest.mark.usefixtures("restore_redis_inmem_per_function") class TestTaskData: _USERNAME = "admin1" @@ -3774,6 +3819,7 @@ def test_admin_can_add_skeleton(self, tasks, admin_user): @pytest.mark.usefixtures("restore_db_per_function") @pytest.mark.usefixtures("restore_cvat_data_per_function") @pytest.mark.usefixtures("restore_redis_ondisk_per_function") +@pytest.mark.usefixtures("restore_redis_inmem_per_function") class TestWorkWithTask: _USERNAME = "admin1" @@ -4654,7 +4700,7 @@ def test_task_unassigned_cannot_see_task_preview( self._test_assigned_users_cannot_see_task_preview(tasks, users, is_task_staff) -@pytest.mark.usefixtures("restore_redis_ondisk_per_class") +@pytest.mark.usefixtures("restore_redis_ondisk_per_function") @pytest.mark.usefixtures("restore_redis_ondisk_after_class") class TestUnequalJobs: @pytest.fixture(autouse=True) @@ -5181,6 +5227,47 @@ def test_import_annotations_after_deleting_related_cloud_storage( task.import_annotations(self.import_format, file_path) self._check_annotations(task_id) + @pytest.mark.parametrize("dimension", ["2d", "3d"]) + def test_can_import_datumaro_json(self, admin_user, tasks, dimension): + task = next( + t + for t in tasks + if t.get("size") + if t["dimension"] == dimension and t.get("validation_mode") != "gt_pool" + ) + + with make_api_client(admin_user) as api_client: + original_annotations = json.loads( + api_client.tasks_api.retrieve_annotations(task["id"])[1].data + ) + + dataset_archive = io.BytesIO( + export_dataset( + api_client.tasks_api, + api_version=2, + id=task["id"], + format=DATUMARO_FORMAT_FOR_DIMENSION[dimension], + save_images=False, + ) + ) + + with zipfile.ZipFile(dataset_archive) as zip_file: + annotations = zip_file.read("annotations/default.json") + + with TemporaryDirectory() as tempdir: + annotations_path = Path(tempdir) / "annotations.json" + annotations_path.write_bytes(annotations) + self.client.tasks.retrieve(task["id"]).import_annotations( + DATUMARO_FORMAT_FOR_DIMENSION[dimension], annotations_path + ) + + with make_api_client(admin_user) as api_client: + updated_annotations = json.loads( + api_client.tasks_api.retrieve_annotations(task["id"])[1].data + ) + + assert compare_annotations(original_annotations, updated_annotations) == {} + @pytest.mark.parametrize( "format_name", [ diff --git a/tests/python/rest_api/utils.py b/tests/python/rest_api/utils.py index 434c3705ddc3..aa747d169e9d 100644 --- a/tests/python/rest_api/utils.py +++ b/tests/python/rest_api/utils.py @@ -573,7 +573,7 @@ def create_task(username, spec, data, content_type="application/json", **kwargs) return task.id, response_.headers.get("X-Request-Id") -def compare_annotations(a, b): +def compare_annotations(a: dict, b: dict) -> dict: def _exclude_cb(obj, path): return path.endswith("['elements']") and not obj @@ -593,5 +593,11 @@ def _exclude_cb(obj, path): ) +DATUMARO_FORMAT_FOR_DIMENSION = { + "2d": "Datumaro 1.0", + "3d": "Datumaro 3D 1.0", +} + + def parse_frame_step(frame_filter: str) -> int: return int((frame_filter or "step=1").split("=")[1]) diff --git a/tests/python/sdk/test_api_wrappers.py b/tests/python/sdk/test_api_wrappers.py index 84ec919c9ba2..f324637b78e9 100644 --- a/tests/python/sdk/test_api_wrappers.py +++ b/tests/python/sdk/test_api_wrappers.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: MIT +import pickle from copy import deepcopy from cvat_sdk import models @@ -112,3 +113,12 @@ def test_models_do_not_return_internal_collections(): model_data2 = model.to_dict() assert DeepDiff(model_data1_original, model_data2) == {} + + +def test_models_are_pickleable(): + model = models.PatchedLabelRequest(id=5, name="person") + pickled_model = pickle.dumps(model) + unpickled_model = pickle.loads(pickled_model) + + assert unpickled_model.id == model.id + assert unpickled_model.name == model.name diff --git a/tests/python/sdk/test_auto_annotation.py b/tests/python/sdk/test_auto_annotation.py index ae4a0d711774..0d22100cfb15 100644 --- a/tests/python/sdk/test_auto_annotation.py +++ b/tests/python/sdk/test_auto_annotation.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: MIT import io +import math from logging import Logger from pathlib import Path from types import SimpleNamespace as namespace @@ -29,6 +30,7 @@ def _common_setup( fxt_login: tuple[Client, str], fxt_logger: tuple[Logger, io.StringIO], restore_redis_ondisk_per_function, + restore_redis_inmem_per_function, ): logger = fxt_logger[0] client = fxt_login[0] @@ -269,6 +271,77 @@ def detect(context, image: PIL.Image.Image) -> list[models.LabeledShapeRequest]: assert shapes[i].points == [5, 6, 7, 8] assert shapes[i].rotation == 10 + def test_conf_threshold(self): + spec = cvataa.DetectionFunctionSpec(labels=[]) + + received_threshold = None + + def detect( + context: cvataa.DetectionFunctionContext, image: PIL.Image.Image + ) -> list[models.LabeledShapeRequest]: + nonlocal received_threshold + received_threshold = context.conf_threshold + return [] + + cvataa.annotate_task( + self.client, + self.task.id, + namespace(spec=spec, detect=detect), + conf_threshold=0.75, + ) + + assert received_threshold == 0.75 + + cvataa.annotate_task( + self.client, + self.task.id, + namespace(spec=spec, detect=detect), + ) + + assert received_threshold is None + + for bad_threshold in [-0.1, 1.1]: + with pytest.raises(ValueError): + cvataa.annotate_task( + self.client, + self.task.id, + namespace(spec=spec, detect=detect), + conf_threshold=bad_threshold, + ) + + def test_conv_mask_to_poly(self): + spec = cvataa.DetectionFunctionSpec( + labels=[ + cvataa.label_spec("car", 123), + ], + ) + + received_cmtp = None + + def detect(context, image: PIL.Image.Image) -> list[models.LabeledShapeRequest]: + nonlocal received_cmtp + received_cmtp = context.conv_mask_to_poly + return [cvataa.mask(123, [1, 0, 0, 0, 0])] + + cvataa.annotate_task( + self.client, + self.task.id, + namespace(spec=spec, detect=detect), + conv_mask_to_poly=False, + ) + + assert received_cmtp is False + + with pytest.raises(cvataa.BadFunctionError, match=".*conv_mask_to_poly.*"): + cvataa.annotate_task( + self.client, + self.task.id, + namespace(spec=spec, detect=detect), + conv_mask_to_poly=True, + ) + + assert received_cmtp is True + def _test_bad_function_spec(self, spec: cvataa.DetectionFunctionSpec, exc_match: str) -> None: def detect(context, image): assert False @@ -575,8 +648,9 @@ def forward(self, images: list[torch.Tensor]) -> list[dict]: return [ { - "boxes": torch.tensor([[1, 2, 3, 4]]), - "labels": torch.tensor([self._label_id]), + "boxes": torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]), + "labels": torch.tensor([self._label_id, self._label_id]), + "scores": torch.tensor([0.75, 0.74]), } ] @@ -587,6 +661,60 @@ def fake_get_detection_model(name: str, weights, test_param): return FakeTorchvisionDetector(label_id=car_label_id) + class FakeTorchvisionInstanceSegmenter(nn.Module): + def __init__(self, label_id: int) -> None: + super().__init__() + self._label_id = label_id + + def forward(self, images: list[torch.Tensor]) -> list[dict]: + assert isinstance(images, list) + assert all(isinstance(t, torch.Tensor) for t in images) + + def make_box(im, a1, a2): + return [im.shape[2] * a1, im.shape[1] * a1, im.shape[2] * a2, im.shape[1] * a2] + + def make_mask(im, a1, a2): + # creates a rectangular mask with a hole + mask = torch.full((1, im.shape[1], im.shape[2]), 0.49) + mask[ + 0, + math.ceil(im.shape[1] * a1) : math.floor(im.shape[1] * a2), + math.ceil(im.shape[2] * a1) : math.floor(im.shape[2] * a2), + ] = 0.5 + mask[ + 0, + math.ceil(im.shape[1] * a1) + 3 : math.floor(im.shape[1] * a2) - 3, + math.ceil(im.shape[2] * a1) + 3 : math.floor(im.shape[2] * a2) - 3, + ] = 0.49 + return mask + + return [ + { + "labels": torch.tensor([self._label_id, self._label_id]), + "boxes": torch.tensor( + [ + make_box(im, 1 / 6, 1 / 3), + make_box(im, 2 / 3, 5 / 6), + ] + ), + "masks": torch.stack( + [ + make_mask(im, 1 / 6, 1 / 3), + make_mask(im, 2 / 3, 5 / 6), + ] + ), + "scores": torch.tensor([0.75, 0.74]), + } + for im in images + ] + + def fake_get_instance_segmentation_model(name: str, weights, test_param): + assert test_param == "expected_value" + + car_label_id = weights.meta["categories"].index("car") + + return FakeTorchvisionInstanceSegmenter(label_id=car_label_id) + class FakeTorchvisionKeypointDetector(nn.Module): def __init__(self, label_id: int, keypoint_names: list[str]) -> None: super().__init__() @@ -599,15 +727,17 @@ def forward(self, images: list[torch.Tensor]) -> list[dict]: return [ { - "labels": torch.tensor([self._label_id]), + "labels": torch.tensor([self._label_id, self._label_id]), "keypoints": torch.tensor( [ [ [hash(name) % 100, 0, 1 if name.startswith("right_") else 0] for i, name in enumerate(self._keypoint_names) - ] + ], + [[0, 0, 1] for i, name in enumerate(self._keypoint_names)], ] ), + "scores": torch.tensor([0.75, 0.74]), } ] @@ -672,6 +802,7 @@ def test_torchvision_detection(self, monkeypatch: pytest.MonkeyPatch): self.task.id, td.create("fasterrcnn_resnet50_fpn_v2", "COCO_V1", test_param="expected_value"), allow_unmatched_labels=True, + conf_threshold=0.75, ) annotations = self.task.get_annotations() @@ -681,6 +812,54 @@ def test_torchvision_detection(self, monkeypatch: pytest.MonkeyPatch): assert annotations.shapes[0].type.value == "rectangle" assert annotations.shapes[0].points == [1, 2, 3, 4] + def test_torchvision_instance_segmentation(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(torchvision_models, "get_model", fake_get_instance_segmentation_model) + + import cvat_sdk.auto_annotation.functions.torchvision_instance_segmentation as tis + from cvat_sdk.masks import encode_mask + + cvataa.annotate_task( + self.client, + self.task.id, + tis.create("maskrcnn_resnet50_fpn_v2", "COCO_V1", test_param="expected_value"), + allow_unmatched_labels=True, + conf_threshold=0.75, + ) + + annotations = self.task.get_annotations() + + assert len(annotations.shapes) == 1 + assert self.task_labels_by_id[annotations.shapes[0].label_id].name == "car" + + expected_bitmap = torch.zeros((100, 100), dtype=torch.bool) + expected_bitmap[17:33, 17:33] = True + expected_bitmap[20:30, 20:30] = False + + assert annotations.shapes[0].type.value == "mask" + assert annotations.shapes[0].points == encode_mask(expected_bitmap, [16, 16, 34, 34]) + + cvataa.annotate_task( + self.client, + self.task.id, + tis.create("maskrcnn_resnet50_fpn_v2", "COCO_V1", test_param="expected_value"), + allow_unmatched_labels=True, + conf_threshold=0.75, + conv_mask_to_poly=True, + clear_existing=True, + ) + + annotations = self.task.get_annotations() + + assert len(annotations.shapes) == 1 + assert self.task_labels_by_id[annotations.shapes[0].label_id].name == "car" + assert annotations.shapes[0].type.value == "polygon" + + # We shouldn't rely on the exact result of polygon conversion, + # since it depends on a 3rd-party library. Instead, we'll just + # check that all points are within the expected area. + for x, y in zip(*[iter(annotations.shapes[0].points)] * 2): + assert expected_bitmap[round(y), round(x)] + def test_torchvision_keypoint_detection(self, monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr(torchvision_models, "get_model", fake_get_keypoint_detection_model) @@ -691,6 +870,7 @@ def test_torchvision_keypoint_detection(self, monkeypatch: pytest.MonkeyPatch): self.task.id, tkd.create("keypointrcnn_resnet50_fpn", "COCO_V1", test_param="expected_value"), allow_unmatched_labels=True, + conf_threshold=0.75, ) annotations = self.task.get_annotations() diff --git a/tests/python/sdk/test_datasets.py b/tests/python/sdk/test_datasets.py index 525082d0eae3..7f13e75ea92f 100644 --- a/tests/python/sdk/test_datasets.py +++ b/tests/python/sdk/test_datasets.py @@ -23,6 +23,7 @@ def _common_setup( fxt_login: tuple[Client, str], fxt_logger: tuple[Logger, io.StringIO], restore_redis_ondisk_per_function, + restore_redis_inmem_per_function, ): logger = fxt_logger[0] client = fxt_login[0] diff --git a/tests/python/sdk/test_masks.py b/tests/python/sdk/test_masks.py new file mode 100644 index 000000000000..46e8b9f214cc --- /dev/null +++ b/tests/python/sdk/test_masks.py @@ -0,0 +1,71 @@ +# Copyright (C) 2024 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +import pytest + +try: + import numpy as np + from cvat_sdk.masks import encode_mask + +except ModuleNotFoundError as e: + if e.name.split(".")[0] != "numpy": + raise + + encode_mask = None + + +@pytest.mark.skipif(encode_mask is None, reason="NumPy is not installed") +class TestMasks: + def test_encode_mask(self): + bitmap = np.array( + [ + np.fromstring("0 0 1 1 1 0", sep=" "), + np.fromstring("0 1 1 0 0 0", sep=" "), + ], + dtype=np.bool_, + ) + bbox = [2.9, 0.9, 4.1, 1.1] # will get rounded to [2, 0, 5, 2] + + # There's slightly different logic for when the cropped mask starts with + # 0 and 1, so test both. + # This one starts with 1: + # 111 + # 100 + + assert encode_mask(bitmap, bbox) == [0, 4, 2, 2, 0, 4, 1] + + bbox = [1, 0, 5, 2] + + # This one starts with 0: + # 0111 + # 1100 + + assert encode_mask(bitmap, bbox) == [1, 5, 2, 1, 0, 4, 1] + + # Edge case: full image + bbox = [0, 0, 6, 2] + assert encode_mask(bitmap, bbox) == [2, 3, 2, 2, 3, 0, 0, 5, 1] + + def test_encode_mask_invalid_dim(self): + with pytest.raises(ValueError, match="bitmap must have 2 dimensions"): + encode_mask([True], [0, 0, 1, 1]) + + def test_encode_mask_invalid_dtype(self): + with pytest.raises(ValueError, match="bitmap must have boolean items"): + encode_mask([[1]], [0, 0, 1, 1]) + + @pytest.mark.parametrize( + "bbox", + [ + [-0.1, 0, 1, 1], + [0, -0.1, 1, 1], + [0, 0, 1.1, 1], + [0, 0, 1, 1.1], + [1, 0, 0, 1], + [0, 1, 1, 0], + ], + ) + def test_encode_mask_invalid_bbox(self, bbox): + with pytest.raises(ValueError, match="bbox has invalid coordinates"): + encode_mask([[True]], bbox) diff --git a/tests/python/sdk/test_pytorch.py b/tests/python/sdk/test_pytorch.py index 8e6918abf301..1427a070d46b 100644 --- a/tests/python/sdk/test_pytorch.py +++ b/tests/python/sdk/test_pytorch.py @@ -36,6 +36,7 @@ def _common_setup( fxt_login: tuple[Client, str], fxt_logger: tuple[Logger, io.StringIO], restore_redis_ondisk_per_function, + restore_redis_inmem_per_function, ): logger = fxt_logger[0] client = fxt_login[0]