diff --git a/.github/workflows/full.yml b/.github/workflows/full.yml
index e587e26aa1b8..e42380de5ead 100644
--- a/.github/workflows/full.yml
+++ b/.github/workflows/full.yml
@@ -156,7 +156,7 @@ jobs:
- name: Install SDK
run: |
pip3 install -r ./tests/python/requirements.txt \
- -e './cvat-sdk[pytorch]' -e ./cvat-cli \
+ -e './cvat-sdk[masks,pytorch]' -e ./cvat-cli \
--extra-index-url https://download.pytorch.org/whl/cpu
- name: Running REST API and SDK tests
diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index f4e3f11d1052..becca0218f94 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -166,7 +166,7 @@ jobs:
- name: Install SDK
run: |
pip3 install -r ./tests/python/requirements.txt \
- -e './cvat-sdk[pytorch]' -e ./cvat-cli \
+ -e './cvat-sdk[masks,pytorch]' -e ./cvat-cli \
--extra-index-url https://download.pytorch.org/whl/cpu
- name: Run REST API and SDK tests
diff --git a/.vscode/launch.json b/.vscode/launch.json
index af93ae24c007..78f24c96ca83 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -4,6 +4,7 @@
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
+
{
"name": "REST API tests: Attach to server",
"type": "debugpy",
@@ -168,7 +169,7 @@
"CVAT_SERVERLESS": "1",
"ALLOWED_HOSTS": "*",
"DJANGO_LOG_SERVER_HOST": "localhost",
- "DJANGO_LOG_SERVER_PORT": "8282"
+ "DJANGO_LOG_SERVER_PORT": "8282",
},
"args": [
"runserver",
@@ -178,7 +179,7 @@
],
"django": true,
"cwd": "${workspaceFolder}",
- "console": "internalConsole"
+ "console": "internalConsole",
},
{
"name": "server: chrome",
@@ -360,6 +361,28 @@
},
"console": "internalConsole"
},
+ {
+ "name": "server: RQ - chunks",
+ "type": "debugpy",
+ "request": "launch",
+ "stopOnEntry": false,
+ "justMyCode": false,
+ "python": "${command:python.interpreterPath}",
+ "program": "${workspaceFolder}/manage.py",
+ "args": [
+ "rqworker",
+ "chunks",
+ "--worker-class",
+ "cvat.rqworker.SimpleWorker"
+ ],
+ "django": true,
+ "cwd": "${workspaceFolder}",
+ "env": {
+ "DJANGO_LOG_SERVER_HOST": "localhost",
+ "DJANGO_LOG_SERVER_PORT": "8282"
+ },
+ "console": "internalConsole"
+ },
{
"name": "server: migrate",
"type": "debugpy",
@@ -553,7 +576,8 @@
"server: RQ - scheduler",
"server: RQ - quality reports",
"server: RQ - analytics reports",
- "server: RQ - cleaning"
+ "server: RQ - cleaning",
+ "server: RQ - chunks",
]
}
]
diff --git a/.vscode/settings.json b/.vscode/settings.json
index a0caaf036765..baf7dc5b3879 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -29,6 +29,15 @@
"database": "${workspaceFolder:cvat}/db.sqlite3"
}
],
+ "python.analysis.exclude": [
+ // VS Code defaults
+ "**/node_modules",
+ "**/__pycache__",
+ ".git",
+
+ "cvat-cli/build",
+ "cvat-sdk/build",
+ ],
"python.defaultInterpreterPath": "${workspaceFolder}/.env/",
"python.testing.pytestArgs": [
"--rootdir","${workspaceFolder}/tests/"
diff --git a/CHANGELOG.md b/CHANGELOG.md
index dd8854040e02..a9143f436f05 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,6 +16,112 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
+
+## \[2.23.0\] - 2024-11-29
+
+### Added
+
+- Support for direct .json file import in Datumaro format
+ ()
+
+- \[SDK, CLI\] Added a `conf_threshold` parameter to
+ `cvat_sdk.auto_annotation.annotate_task`, which is passed as-is to the AA
+ function object via the context. The CLI equivalent is `auto-annotate
+ --conf-threshold`. This makes it easier to write and use AA functions that
+ support object filtering based on confidence levels
+ ()
+
+- \[SDK\] Built-in auto-annotation functions now support object filtering by
+ confidence level
+ ()
+
+- New events (create|update|delete):(membership|webhook) and (create|delete):invitation
+ ()
+
+- \[SDK\] Added new auto-annotation helpers (`mask`, `polygon`, `encode_mask`)
+ to support AA functions that return masks or polygons
+ ()
+
+- \[SDK\] Added a new built-in auto-annotation function,
+ `torchvision_instance_segmentation`
+ ()
+
+- \[SDK, CLI\] Added a new auto-annotation parameter, `conv_mask_to_poly`
+ (`--conv-mask-to-poly` in the CLI)
+ ()
+
+- A user may undo or redo changes, made by an annotations actions using general approach (e.g. Ctrl+Z, Ctrl+Y)
+ ()
+
+- Basically, annotations actions now support any kinds of objects (shapes, tracks, tags)
+ ()
+
+- A user may run annotations actions on a certain object (added corresponding object menu item)
+ ()
+
+- A shortcut to open annotations actions modal for a currently selected object
+ ()
+
+- A default role if IAM_TYPE='LDAP' and if the user is not a member of any group in 'DJANGO_AUTH_LDAP_GROUPS' ()
+
+- The `POST /api/lambda/requests` endpoint now has a `conv_mask_to_poly`
+ parameter with the same semantics as the old `convMaskToPoly` parameter
+ ()
+
+- \[SDK\] Model instances can now be pickled
+ ()
+
+### Changed
+
+- Chunks are now prepared in a separate worker process
+ ()
+
+- \[Helm\] Traefik sticky sessions for the backend service are disabled
+ ()
+
+- Payload for events (create|update|delete):(shapes|tags|tracks) does not include frame and attributes anymore
+ ()
+
+### Deprecated
+
+- The `convMaskToPoly` parameter of the `POST /api/lambda/requests` endpoint
+ is deprecated; use `conv_mask_to_poly` instead
+ ()
+
+### Removed
+
+- It it no longer possible to run lambda functions on compressed images;
+ original images will always be used
+ ()
+
+### Fixed
+
+- Export without images in Datumaro format should include image info
+ ()
+
+- Inconsistent zOrder behavior on job open
+ ()
+
+- Ground truth annotations can be shown in standard mode
+ ()
+
+- Keybinds in UI allow drawing disabled shape types
+ ()
+
+- Style issues on the Quality page when browser zoom is applied
+ ()
+- Flickering of masks in review mode, even when no conflicts are highlighted
+ ()
+
+- Fixed security header duplication in HTTP responses from the backend
+ ()
+
+- The error occurs when trying to copy/paste a mask on a video after opening the job
+ ()
+
+- Attributes do not get copied when copy/paste a mask
+ ()
+
## \[2.22.0\] - 2024-11-11
@@ -87,7 +193,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
()
- Tags in ground truth job couldn't be deleted via `x` button
- ()
+ ()
- Exception 'Canvas is busy' when change frame during drag/resize a track
()
@@ -377,7 +483,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Deprecated
- Client events `upload:annotations`, `lock:object`, `change:attribute`, `change:label`
- ()
+ ()
### Removed
@@ -404,7 +510,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
()
- Sometimes it is not possible to switch workspace because active control broken after
-trying to create a tag with a shortcut ()
+ trying to create a tag with a shortcut
+ ()
## \[2.16.3\] - 2024-08-13
@@ -445,13 +552,14 @@ trying to create a tag with a shortcut ()
+ **Asset is already related to another guide**
+ ()
- Undo can't be done when a shape is rotated
()
- Exporting a skeleton track in a format defined for shapes raises error
-`operands could not be broadcast together with shapes (X, ) (Y, )`
+ `operands could not be broadcast together with shapes (X, ) (Y, )`
()
- Delete label modal window does not have cancellation button
@@ -470,10 +578,11 @@ trying to create a tag with a shortcut ()
- API call to run automatic annotations fails on a model with attributes
- when mapping not provided in the request ()
+ when mapping not provided in the request
+ ()
- Fixed a label collision issue where labels with similar prefixes
-and numeric suffixes could conflict, causing error on export.
+ and numeric suffixes could conflict, causing error on export.
()
@@ -510,9 +619,9 @@ and numeric suffixes could conflict, causing error on export.
### Added
- Set of features to track background activities: importing/exporting datasets, annotations or backups, creating tasks.
-Now you may find these processes on Requests page, it allows a user to understand current status of these activities
-and enhances user experience, not losing progress when the browser tab is closed
-()
+ Now you may find these processes on Requests page, it allows a user to understand current status of these activities
+ and enhances user experience, not losing progress when the browser tab is closed
+ ()
- User now may update a job state from the corresponding task page
()
@@ -523,7 +632,8 @@ and enhances user experience, not losing progress when the browser tab is closed
### Changed
- "Finish the job" button on annotation view now only sets state to 'completed'.
- The job stage keeps unchanged ()
+ The job stage keeps unchanged
+ ()
- Log files for individual backend processes are now stored in ephemeral
storage of each backend container rather than in the `cvat_logs` volume
@@ -535,7 +645,7 @@ and enhances user experience, not losing progress when the browser tab is closed
### Removed
- Renew the job button in annotation menu was removed
- ()
+ ()
### Fixed
@@ -583,10 +693,12 @@ and enhances user experience, not losing progress when the browser tab is closed
()
- Exception 'this.el.node.getScreenCTM() is null' occuring in Firefox when
-a user resizes window during skeleton dragging/resizing ()
+ a user resizes window during skeleton dragging/resizing
+ ()
- Exception 'Edge's nodeFrom M or nodeTo N do not to refer to any node'
-occuring when a user resizes window during skeleton dragging/resizing ()
+ occuring when a user resizes window during skeleton dragging/resizing
+ ()
- Slightly broken layout when running attributed face detection model
()
@@ -644,7 +756,8 @@ occuring when a user resizes window during skeleton dragging/resizing ()
- When use route `/auth/login-with-token/` without `next` query parameter
-the page reloads infinitely ()
+ the page reloads infinitely
+ ()
- Fixed kvrocks port naming for istio
()
@@ -815,7 +928,7 @@ the page reloads infinitely ()
- Opening update CS page sends infinite requests when CS id does not exist
()
-Uploading files with TUS immediately failed when one of the requests failed
+- Uploading files with TUS immediately failed when one of the requests failed
()
- Longer analytics report calculation because of inefficient requests to analytics db
@@ -985,7 +1098,7 @@ Uploading files with TUS immediately failed when one of the requests failed
()
- 90 deg-rotated video was added with "Prefer Zip Chunks" disabled
-was warped, fixed using the static cropImage function.
+ was warped, fixed using the static cropImage function.
()
@@ -1023,7 +1136,7 @@ was warped, fixed using the static cropImage function.
### Added
- Single shape annotation mode allowing to easily annotate scenarious where a user
-only needs to draw one object on one image ()
+ only needs to draw one object on one image ()
### Fixed
@@ -1151,7 +1264,7 @@ only needs to draw one object on one image ()
- \[Compose, Helm\] Updated Clickhouse to version 23.11.*
@@ -1200,11 +1313,11 @@ longer accepted automatically. Instead, the invitee can now review the invitatio
()
- Error message `Edge's nodeFrom ${dataNodeFrom} or nodeTo ${dataNodeTo} do not to refer to any node`
- when upload a file with some abscent skeleton nodes ()
+ when upload a file with some abscent skeleton nodes ()
- Wrong context menu position in skeleton configurator (Firefox only)
- ()
+ ()
- Fixed console error `(Error: attribute width: A negative value is not valid`
- appearing when skeleton with all outside elements is created ()
+ appearing when skeleton with all outside elements is created ()
- Updating cloud storage attached to CVAT using Azure connection string
()
@@ -1215,7 +1328,7 @@ longer accepted automatically. Instead, the invitee can now review the invitatio
### Added
- Introduced CVAT actions. Actions allow performing different
- predefined scenarios on annotations automatically (e.g. shape converters)
+ predefined scenarios on annotations automatically (e.g. shape converters)
()
- The UI will now retry requests that were rejected due to rate limiting
diff --git a/Dockerfile.ui b/Dockerfile.ui
index 170ee1a76633..da9c36d38960 100644
--- a/Dockerfile.ui
+++ b/Dockerfile.ui
@@ -1,11 +1,5 @@
FROM node:lts-slim AS cvat-ui
-ARG WA_PAGE_VIEW_HIT
-ARG UI_APP_CONFIG
-ARG CLIENT_PLUGINS
-ARG DISABLE_SOURCE_MAPS
-ARG SOURCE_MAPS_TOKEN
-
ENV TERM=xterm \
LANG='C.UTF-8' \
LC_ALL='C.UTF-8'
@@ -29,6 +23,13 @@ COPY cvat-core/ /tmp/cvat-core/
COPY cvat-canvas3d/ /tmp/cvat-canvas3d/
COPY cvat-canvas/ /tmp/cvat-canvas/
COPY cvat-ui/ /tmp/cvat-ui/
+
+ARG WA_PAGE_VIEW_HIT
+ARG UI_APP_CONFIG
+ARG CLIENT_PLUGINS
+ARG DISABLE_SOURCE_MAPS
+ARG SOURCE_MAPS_TOKEN
+
RUN CLIENT_PLUGINS="${CLIENT_PLUGINS}" \
DISABLE_SOURCE_MAPS="${DISABLE_SOURCE_MAPS}" \
UI_APP_CONFIG="${UI_APP_CONFIG}" \
diff --git a/cvat-canvas/src/typescript/canvasView.ts b/cvat-canvas/src/typescript/canvasView.ts
index 4c346b4d6735..d1bab2369521 100644
--- a/cvat-canvas/src/typescript/canvasView.ts
+++ b/cvat-canvas/src/typescript/canvasView.ts
@@ -2877,6 +2877,9 @@ export class CanvasViewImpl implements CanvasView, Listener {
const shapeView = window.document.getElementById(`cvat_canvas_shape_${clientID}`);
if (shapeView) shapeView.classList.remove(this.getHighlightClassname());
});
+ const redrawMasks = (highlightedElements.elementsIDs.length !== 0 ||
+ this.highlightedElements.elementsIDs.length !== 0);
+
if (highlightedElements.elementsIDs.length) {
this.highlightedElements = { ...highlightedElements };
this.canvas.classList.add('cvat-canvas-highlight-enabled');
@@ -2891,9 +2894,11 @@ export class CanvasViewImpl implements CanvasView, Listener {
};
this.canvas.classList.remove('cvat-canvas-highlight-enabled');
}
- const masks = Object.values(this.drawnStates).filter((state) => state.shapeType === 'mask');
- this.deleteObjects(masks);
- this.addObjects(masks);
+ if (redrawMasks) {
+ const masks = Object.values(this.drawnStates).filter((state) => state.shapeType === 'mask');
+ this.deleteObjects(masks);
+ this.addObjects(masks);
+ }
if (this.highlightedElements.elementsIDs.length) {
this.deactivate();
const clientID = this.highlightedElements.elementsIDs[0];
diff --git a/cvat-canvas/src/typescript/masksHandler.ts b/cvat-canvas/src/typescript/masksHandler.ts
index ca6e5e469a63..7f6a4e313fb3 100644
--- a/cvat-canvas/src/typescript/masksHandler.ts
+++ b/cvat-canvas/src/typescript/masksHandler.ts
@@ -404,6 +404,10 @@ export class MasksHandlerImpl implements MasksHandler {
rle.push(wrappingBbox.left, wrappingBbox.top, wrappingBbox.right, wrappingBbox.bottom);
this.onDrawDone({
+ occluded: this.drawData.initialState.occluded,
+ attributes: { ...this.drawData.initialState.attributes },
+ color: this.drawData.initialState.color,
+ objectType: this.drawData.initialState.objectType,
shapeType: this.drawData.shapeType,
points: rle,
label: this.drawData.initialState.label,
diff --git a/cvat-cli/requirements/base.txt b/cvat-cli/requirements/base.txt
index e9be53974d91..5f27832efdb7 100644
--- a/cvat-cli/requirements/base.txt
+++ b/cvat-cli/requirements/base.txt
@@ -1,3 +1,3 @@
-cvat-sdk~=2.22.0
+cvat-sdk~=2.23.0
Pillow>=10.3.0
setuptools>=70.0.0 # not directly required, pinned by Snyk to avoid a vulnerability
diff --git a/cvat-cli/src/cvat_cli/_internal/commands.py b/cvat-cli/src/cvat_cli/_internal/commands.py
index e86ef3b6350f..324d427a64b8 100644
--- a/cvat-cli/src/cvat_cli/_internal/commands.py
+++ b/cvat-cli/src/cvat_cli/_internal/commands.py
@@ -20,7 +20,13 @@
from cvat_sdk.core.proxies.tasks import ResourceType
from .command_base import CommandGroup
-from .parsers import BuildDictAction, parse_function_parameter, parse_label_arg, parse_resource_type
+from .parsers import (
+ BuildDictAction,
+ parse_function_parameter,
+ parse_label_arg,
+ parse_resource_type,
+ parse_threshold,
+)
COMMANDS = CommandGroup(description="Perform common operations related to CVAT tasks.")
@@ -463,6 +469,19 @@ def configure_parser(self, parser: argparse.ArgumentParser) -> None:
help="Allow the function to declare labels not configured in the task",
)
+ parser.add_argument(
+ "--conf-threshold",
+ type=parse_threshold,
+ help="Confidence threshold for filtering detections",
+ default=None,
+ )
+
+ parser.add_argument(
+ "--conv-mask-to-poly",
+ action="store_true",
+ help="Convert mask shapes to polygon shapes",
+ )
+
def execute(
self,
client: Client,
@@ -473,6 +492,8 @@ def execute(
function_parameters: dict[str, Any],
clear_existing: bool = False,
allow_unmatched_labels: bool = False,
+ conf_threshold: Optional[float],
+ conv_mask_to_poly: bool,
) -> None:
if function_module is not None:
function = importlib.import_module(function_module)
@@ -497,4 +518,6 @@ def execute(
pbar=DeferredTqdmProgressReporter(),
clear_existing=clear_existing,
allow_unmatched_labels=allow_unmatched_labels,
+ conf_threshold=conf_threshold,
+ conv_mask_to_poly=conv_mask_to_poly,
)
diff --git a/cvat-cli/src/cvat_cli/_internal/parsers.py b/cvat-cli/src/cvat_cli/_internal/parsers.py
index a66710a09f47..97dcb5b2668a 100644
--- a/cvat-cli/src/cvat_cli/_internal/parsers.py
+++ b/cvat-cli/src/cvat_cli/_internal/parsers.py
@@ -53,6 +53,17 @@ def parse_function_parameter(s: str) -> tuple[str, Any]:
return (key, value)
+def parse_threshold(s: str) -> float:
+ try:
+ value = float(s)
+ except ValueError as e:
+ raise argparse.ArgumentTypeError("must be a number") from e
+
+ if not 0 <= value <= 1:
+ raise argparse.ArgumentTypeError("must be between 0 and 1")
+ return value
+
+
class BuildDictAction(argparse.Action):
def __init__(self, option_strings, dest, default=None, **kwargs):
super().__init__(option_strings, dest, default=default or {}, **kwargs)
diff --git a/cvat-cli/src/cvat_cli/version.py b/cvat-cli/src/cvat_cli/version.py
index b2829a54b105..9b4fa879ca12 100644
--- a/cvat-cli/src/cvat_cli/version.py
+++ b/cvat-cli/src/cvat_cli/version.py
@@ -1 +1 @@
-VERSION = "2.22.0"
+VERSION = "2.23.0"
diff --git a/cvat-core/package.json b/cvat-core/package.json
index a769b74bf78c..6b9039673812 100644
--- a/cvat-core/package.json
+++ b/cvat-core/package.json
@@ -1,6 +1,6 @@
{
"name": "cvat-core",
- "version": "15.2.1",
+ "version": "15.3.0",
"type": "module",
"description": "Part of Computer Vision Tool which presents an interface for client-side integration",
"main": "src/api.ts",
diff --git a/cvat-core/src/annotations-actions.ts b/cvat-core/src/annotations-actions.ts
deleted file mode 100644
index 43d3ef29a910..000000000000
--- a/cvat-core/src/annotations-actions.ts
+++ /dev/null
@@ -1,320 +0,0 @@
-// Copyright (C) 2023-2024 CVAT.ai Corporation
-//
-// SPDX-License-Identifier: MIT
-
-import { omit, range, throttle } from 'lodash';
-import { ArgumentError } from './exceptions';
-import { SerializedCollection, SerializedShape } from './server-response-types';
-import { Job, Task } from './session';
-import { EventScope, ObjectType } from './enums';
-import ObjectState from './object-state';
-import { getAnnotations, getCollection } from './annotations';
-import { propagateShapes } from './object-utils';
-
-export interface SingleFrameActionInput {
- collection: Omit;
- frameData: {
- width: number;
- height: number;
- number: number;
- };
-}
-
-export interface SingleFrameActionOutput {
- collection: Omit;
-}
-
-export enum ActionParameterType {
- SELECT = 'select',
- NUMBER = 'number',
-}
-
-// For SELECT values should be a list of possible options
-// For NUMBER values should be a list with [min, max, step],
-// or a callback ({ instance }: { instance: Job | Task }) => [min, max, step]
-type ActionParameters = Record string[]);
- defaultValue: string | (({ instance }: { instance: Job | Task }) => string);
-}>;
-
-export enum FrameSelectionType {
- SEGMENT = 'segment',
- CURRENT_FRAME = 'current_frame',
-}
-
-export default class BaseSingleFrameAction {
- /* eslint-disable @typescript-eslint/no-unused-vars */
- public async init(
- sessionInstance: Job | Task,
- parameters: Record,
- ): Promise {
- throw new Error('Method not implemented');
- }
-
- public async destroy(): Promise {
- throw new Error('Method not implemented');
- }
-
- public async run(sessionInstance: Job | Task, input: SingleFrameActionInput): Promise {
- throw new Error('Method not implemented');
- }
-
- public get name(): string {
- throw new Error('Method not implemented');
- }
-
- public get parameters(): ActionParameters | null {
- throw new Error('Method not implemented');
- }
-
- public get frameSelection(): FrameSelectionType {
- return FrameSelectionType.SEGMENT;
- }
-}
-
-class RemoveFilteredShapes extends BaseSingleFrameAction {
- public async init(): Promise {
- // nothing to init
- }
-
- public async destroy(): Promise {
- // nothing to destroy
- }
-
- public async run(): Promise {
- return { collection: { shapes: [] } };
- }
-
- public get name(): string {
- return 'Remove filtered shapes';
- }
-
- public get parameters(): ActionParameters | null {
- return null;
- }
-}
-
-class PropagateShapes extends BaseSingleFrameAction {
- #targetFrame: number;
-
- public async init(instance, parameters): Promise {
- this.#targetFrame = parameters['Target frame'];
- }
-
- public async destroy(): Promise {
- // nothing to destroy
- }
-
- public async run(
- instance: Job | Task,
- { collection: { shapes }, frameData: { number } },
- ): Promise {
- if (number === this.#targetFrame) {
- return { collection: { shapes } };
- }
-
- const frameNumbers = instance instanceof Job ? await instance.frames.frameNumbers() : range(0, instance.size);
- const propagatedShapes = propagateShapes(shapes, number, this.#targetFrame, frameNumbers);
- return { collection: { shapes: [...shapes, ...propagatedShapes] } };
- }
-
- public get name(): string {
- return 'Propagate shapes';
- }
-
- public get parameters(): ActionParameters | null {
- return {
- 'Target frame': {
- type: ActionParameterType.NUMBER,
- values: ({ instance }) => {
- if (instance instanceof Job) {
- return [instance.startFrame, instance.stopFrame, 1].map((val) => val.toString());
- }
- return [0, instance.size - 1, 1].map((val) => val.toString());
- },
- defaultValue: ({ instance }) => {
- if (instance instanceof Job) {
- return instance.stopFrame.toString();
- }
- return (instance.size - 1).toString();
- },
- },
- };
- }
-
- public get frameSelection(): FrameSelectionType {
- return FrameSelectionType.CURRENT_FRAME;
- }
-}
-
-const registeredActions: BaseSingleFrameAction[] = [];
-
-export async function listActions(): Promise {
- return [...registeredActions];
-}
-
-export async function registerAction(action: BaseSingleFrameAction): Promise {
- if (!(action instanceof BaseSingleFrameAction)) {
- throw new ArgumentError('Provided action is not instance of BaseSingleFrameAction');
- }
-
- const { name } = action;
- if (registeredActions.map((_action) => _action.name).includes(name)) {
- throw new ArgumentError(`Action name must be unique. Name "${name}" is already exists`);
- }
-
- registeredActions.push(action);
-}
-
-registerAction(new RemoveFilteredShapes());
-registerAction(new PropagateShapes());
-
-async function runSingleFrameChain(
- instance: Job | Task,
- actionsChain: BaseSingleFrameAction[],
- actionParameters: Record[],
- frameFrom: number,
- frameTo: number,
- filters: string[],
- onProgress: (message: string, progress: number) => void,
- cancelled: () => boolean,
-): Promise {
- type IDsToHandle = { shapes: number[] };
- const event = await instance.logger.log(EventScope.annotationsAction, {
- from: frameFrom,
- to: frameTo,
- chain: actionsChain.map((action) => action.name).join(' => '),
- }, true);
-
- // if called too fast, it will freeze UI, so, add throttling here
- const wrappedOnProgress = throttle(onProgress, 100, { leading: true, trailing: true });
- const showMessageWithPause = async (message: string, progress: number, duration: number): Promise => {
- // wrapper that gives a chance to abort action
- wrappedOnProgress(message, progress);
- await new Promise((resolve) => setTimeout(resolve, duration));
- };
-
- try {
- await showMessageWithPause('Actions initialization', 0, 500);
- if (cancelled()) {
- return;
- }
-
- await Promise.all(actionsChain.map((action, idx) => {
- const declaredParameters = action.parameters;
- if (!declaredParameters) {
- return action.init(instance, {});
- }
-
- const setupValues = actionParameters[idx];
- const parameters = Object.entries(declaredParameters).reduce((acc, [name, { type, defaultValue }]) => {
- if (type === ActionParameterType.NUMBER) {
- acc[name] = +(Object.hasOwn(setupValues, name) ? setupValues[name] : defaultValue);
- } else {
- acc[name] = (Object.hasOwn(setupValues, name) ? setupValues[name] : defaultValue);
- }
- return acc;
- }, {} as Record);
-
- return action.init(instance, parameters);
- }));
-
- const exportedCollection = getCollection(instance).export();
- const handledCollection: SingleFrameActionInput['collection'] = { shapes: [] };
- const modifiedCollectionIDs: IDsToHandle = { shapes: [] };
-
- // Iterate over frames
- const totalFrames = frameTo - frameFrom + 1;
- for (let frame = frameFrom; frame <= frameTo; frame++) {
- const frameData = await Object.getPrototypeOf(instance).frames
- .get.implementation.call(instance, frame);
-
- // Ignore deleted frames
- if (!frameData.deleted) {
- // Get annotations according to filter
- const states: ObjectState[] = await getAnnotations(instance, frame, false, filters);
- const frameCollectionIDs = states.reduce((acc, val) => {
- if (val.objectType === ObjectType.SHAPE) {
- acc.shapes.push(val.clientID as number);
- }
- return acc;
- }, { shapes: [] });
-
- // Pick frame collection according to filtered IDs
- let frameCollection = {
- shapes: exportedCollection.shapes.filter((shape) => frameCollectionIDs
- .shapes.includes(shape.clientID as number)),
- };
-
- // Iterate over actions on each not deleted frame
- for await (const action of actionsChain) {
- ({ collection: frameCollection } = await action.run(instance, {
- collection: frameCollection,
- frameData: {
- width: frameData.width,
- height: frameData.height,
- number: frameData.number,
- },
- }));
- }
-
- const progress = Math.ceil(+(((frame - frameFrom) / totalFrames) * 100));
- wrappedOnProgress('Actions are running', progress);
- if (cancelled()) {
- return;
- }
-
- handledCollection.shapes.push(...frameCollection.shapes.map((shape) => omit(shape, 'id')));
- modifiedCollectionIDs.shapes.push(...frameCollectionIDs.shapes);
- }
- }
-
- await showMessageWithPause('Commiting handled objects', 100, 1500);
- if (cancelled()) {
- return;
- }
-
- exportedCollection.shapes.forEach((shape) => {
- if (Number.isInteger(shape.clientID) && !modifiedCollectionIDs.shapes.includes(shape.clientID as number)) {
- handledCollection.shapes.push(shape);
- }
- });
-
- await instance.annotations.clear();
- await instance.actions.clear();
- await instance.annotations.import({
- ...handledCollection,
- tracks: exportedCollection.tracks,
- tags: exportedCollection.tags,
- });
-
- event.close();
- } finally {
- wrappedOnProgress('Finalizing', 100);
- await Promise.all(actionsChain.map((action) => action.destroy()));
- }
-}
-
-export async function runActions(
- instance: Job | Task,
- actionsChain: BaseSingleFrameAction[],
- actionParameters: Record[],
- frameFrom: number,
- frameTo: number,
- filters: string[],
- onProgress: (message: string, progress: number) => void,
- cancelled: () => boolean,
-): Promise {
- // there will be another function for MultiFrameChains (actions handling tracks)
- return runSingleFrameChain(
- instance,
- actionsChain,
- actionParameters,
- frameFrom,
- frameTo,
- filters,
- onProgress,
- cancelled,
- );
-}
diff --git a/cvat-core/src/annotations-actions/annotations-actions.ts b/cvat-core/src/annotations-actions/annotations-actions.ts
new file mode 100644
index 000000000000..172b8cd88e3d
--- /dev/null
+++ b/cvat-core/src/annotations-actions/annotations-actions.ts
@@ -0,0 +1,113 @@
+// Copyright (C) 2024 CVAT.ai Corporation
+//
+// SPDX-License-Identifier: MIT
+
+import ObjectState from '../object-state';
+import { ArgumentError } from '../exceptions';
+import { Job, Task } from '../session';
+import { BaseAction } from './base-action';
+import {
+ BaseShapesAction, run as runShapesAction, call as callShapesAction,
+} from './base-shapes-action';
+import {
+ BaseCollectionAction, run as runCollectionAction, call as callCollectionAction,
+} from './base-collection-action';
+
+import { RemoveFilteredShapes } from './remove-filtered-shapes';
+import { PropagateShapes } from './propagate-shapes';
+
+const registeredActions: BaseAction[] = [];
+
+export async function listActions(): Promise {
+ return [...registeredActions];
+}
+
+export async function registerAction(action: BaseAction): Promise {
+ if (!(action instanceof BaseAction)) {
+ throw new ArgumentError('Provided action must inherit one of base classes');
+ }
+
+ const { name } = action;
+ if (registeredActions.map((_action) => _action.name).includes(name)) {
+ throw new ArgumentError(`Action name must be unique. Name "${name}" is already exists`);
+ }
+
+ registeredActions.push(action);
+}
+
+registerAction(new RemoveFilteredShapes());
+registerAction(new PropagateShapes());
+
+export async function runAction(
+ instance: Job | Task,
+ action: BaseAction,
+ actionParameters: Record,
+ frameFrom: number,
+ frameTo: number,
+ filters: object[],
+ onProgress: (message: string, progress: number) => void,
+ cancelled: () => boolean,
+): Promise {
+ if (action instanceof BaseShapesAction) {
+ return runShapesAction(
+ instance,
+ action,
+ actionParameters,
+ frameFrom,
+ frameTo,
+ filters,
+ onProgress,
+ cancelled,
+ );
+ }
+
+ if (action instanceof BaseCollectionAction) {
+ return runCollectionAction(
+ instance,
+ action,
+ actionParameters,
+ frameFrom,
+ filters,
+ onProgress,
+ cancelled,
+ );
+ }
+
+ return Promise.resolve();
+}
+
+export async function callAction(
+ instance: Job | Task,
+ action: BaseAction,
+ actionParameters: Record,
+ frame: number,
+ states: ObjectState[],
+ onProgress: (message: string, progress: number) => void,
+ cancelled: () => boolean,
+): Promise {
+ if (action instanceof BaseShapesAction) {
+ return callShapesAction(
+ instance,
+ action,
+ actionParameters,
+ frame,
+ states,
+ onProgress,
+ cancelled,
+ );
+ }
+
+ if (action instanceof BaseCollectionAction) {
+ return callCollectionAction(
+ instance,
+ action,
+ actionParameters,
+ frame,
+ states,
+ onProgress,
+ cancelled,
+ );
+ }
+
+ return Promise.resolve();
+}
diff --git a/cvat-core/src/annotations-actions/base-action.ts b/cvat-core/src/annotations-actions/base-action.ts
new file mode 100644
index 000000000000..3246261d2c9a
--- /dev/null
+++ b/cvat-core/src/annotations-actions/base-action.ts
@@ -0,0 +1,60 @@
+// Copyright (C) 2024 CVAT.ai Corporation
+//
+// SPDX-License-Identifier: MIT
+
+import { SerializedCollection } from 'server-response-types';
+import ObjectState from '../object-state';
+import { Job, Task } from '../session';
+
+export enum ActionParameterType {
+ SELECT = 'select',
+ NUMBER = 'number',
+}
+
+// For SELECT values should be a list of possible options
+// For NUMBER values should be a list with [min, max, step],
+// or a callback ({ instance }: { instance: Job | Task }) => [min, max, step]
+export type ActionParameters = Record string[]);
+ defaultValue: string | (({ instance }: { instance: Job | Task }) => string);
+}>;
+
+export abstract class BaseAction {
+ public abstract init(sessionInstance: Job | Task, parameters: Record): Promise;
+ public abstract destroy(): Promise;
+ public abstract run(input: unknown): Promise;
+ public abstract applyFilter(input: unknown): unknown;
+ public abstract isApplicableForObject(objectState: ObjectState): boolean;
+
+ public abstract get name(): string;
+ public abstract get parameters(): ActionParameters | null;
+}
+
+export function prepareActionParameters(declared: ActionParameters, defined: object): Record {
+ if (!declared) {
+ return {};
+ }
+
+ return Object.entries(declared).reduce((acc, [name, { type, defaultValue }]) => {
+ if (type === ActionParameterType.NUMBER) {
+ acc[name] = +(Object.hasOwn(defined, name) ? defined[name] : defaultValue);
+ } else {
+ acc[name] = (Object.hasOwn(defined, name) ? defined[name] : defaultValue);
+ }
+ return acc;
+ }, {} as Record);
+}
+
+export function validateClientIDs(collection: Partial) {
+ [].concat(
+ collection.shapes ?? [],
+ collection.tracks ?? [],
+ collection.tags ?? [],
+ ).forEach((object) => {
+ // clientID is required to correct collection filtering and commiting in annotations actions logic
+ if (typeof object.clientID !== 'number') {
+ throw new Error('ClientID is undefined when running annotations action, but required');
+ }
+ });
+}
diff --git a/cvat-core/src/annotations-actions/base-collection-action.ts b/cvat-core/src/annotations-actions/base-collection-action.ts
new file mode 100644
index 000000000000..c48135694566
--- /dev/null
+++ b/cvat-core/src/annotations-actions/base-collection-action.ts
@@ -0,0 +1,178 @@
+// Copyright (C) 2024 CVAT.ai Corporation
+//
+// SPDX-License-Identifier: MIT
+
+import { throttle } from 'lodash';
+
+import ObjectState from '../object-state';
+import AnnotationsFilter from '../annotations-filter';
+import { Job, Task } from '../session';
+import {
+ SerializedCollection, SerializedShape,
+ SerializedTag, SerializedTrack,
+} from '../server-response-types';
+import { EventScope, ObjectType } from '../enums';
+import { getCollection } from '../annotations';
+import { BaseAction, prepareActionParameters, validateClientIDs } from './base-action';
+
+export interface CollectionActionInput {
+ onProgress(message: string, percent: number): void;
+ cancelled(): boolean;
+ collection: Pick;
+ frameData: {
+ width: number;
+ height: number;
+ number: number;
+ };
+}
+
+export interface CollectionActionOutput {
+ created: CollectionActionInput['collection'];
+ deleted: CollectionActionInput['collection'];
+}
+
+export abstract class BaseCollectionAction extends BaseAction {
+ public abstract run(input: CollectionActionInput): Promise;
+ public abstract applyFilter(
+ input: Pick,
+ ): CollectionActionInput['collection'];
+}
+
+export async function run(
+ instance: Job | Task,
+ action: BaseCollectionAction,
+ actionParameters: Record,
+ frame: number,
+ filters: object[],
+ onProgress: (message: string, progress: number) => void,
+ cancelled: () => boolean,
+): Promise {
+ const event = await instance.logger.log(EventScope.annotationsAction, {
+ from: frame,
+ to: frame,
+ name: action.name,
+ }, true);
+
+ const wrappedOnProgress = throttle(onProgress, 100, { leading: true, trailing: true });
+ const showMessageWithPause = async (message: string, progress: number, duration: number): Promise => {
+ // wrapper that gives a chance to abort action
+ wrappedOnProgress(message, progress);
+ await new Promise((resolve) => setTimeout(resolve, duration));
+ };
+
+ try {
+ await showMessageWithPause('Action initialization', 0, 500);
+ if (cancelled()) {
+ return;
+ }
+
+ await action.init(instance, prepareActionParameters(action.parameters, actionParameters));
+
+ const frameData = await Object.getPrototypeOf(instance).frames
+ .get.implementation.call(instance, frame);
+ const exportedCollection = getCollection(instance).export();
+
+ // Apply action filter first
+ const filteredByAction = action.applyFilter({ collection: exportedCollection, frameData });
+ validateClientIDs(filteredByAction);
+
+ let mapID2Obj = [].concat(filteredByAction.shapes, filteredByAction.tags, filteredByAction.tracks)
+ .reduce((acc, object) => {
+ acc[object.clientID as number] = object;
+ return acc;
+ }, {});
+
+ // Then apply user filter
+ const annotationsFilter = new AnnotationsFilter();
+ const filteredCollectionIDs = annotationsFilter
+ .filterSerializedCollection(filteredByAction, instance.labels, filters);
+ const filteredByUser = {
+ shapes: filteredCollectionIDs.shapes.map((clientID) => mapID2Obj[clientID]),
+ tags: filteredCollectionIDs.tags.map((clientID) => mapID2Obj[clientID]),
+ tracks: filteredCollectionIDs.tracks.map((clientID) => mapID2Obj[clientID]),
+ };
+ mapID2Obj = [].concat(filteredByUser.shapes, filteredByUser.tags, filteredByUser.tracks)
+ .reduce((acc, object) => {
+ acc[object.clientID as number] = object;
+ return acc;
+ }, {});
+
+ const { created, deleted } = await action.run({
+ collection: filteredByUser,
+ frameData: {
+ width: frameData.width,
+ height: frameData.height,
+ number: frameData.number,
+ },
+ onProgress: wrappedOnProgress,
+ cancelled,
+ });
+
+ await instance.annotations.commit(created, deleted, frame);
+ event.close();
+ } finally {
+ wrappedOnProgress('Finalizing', 100);
+ await action.destroy();
+ }
+}
+
+export async function call(
+ instance: Job | Task,
+ action: BaseCollectionAction,
+ actionParameters: Record,
+ frame: number,
+ states: ObjectState[],
+ onProgress: (message: string, progress: number) => void,
+ cancelled: () => boolean,
+): Promise {
+ const event = await instance.logger.log(EventScope.annotationsAction, {
+ from: frame,
+ to: frame,
+ name: action.name,
+ }, true);
+
+ const throttledOnProgress = throttle(onProgress, 100, { leading: true, trailing: true });
+ try {
+ await action.init(instance, prepareActionParameters(action.parameters, actionParameters));
+ const exportedStates = await Promise.all(states.map((state) => state.export()));
+ const exportedCollection = exportedStates.reduce((acc, value, idx) => {
+ if (states[idx].objectType === ObjectType.SHAPE) {
+ acc.shapes.push(value as SerializedShape);
+ }
+
+ if (states[idx].objectType === ObjectType.TAG) {
+ acc.tags.push(value as SerializedTag);
+ }
+
+ if (states[idx].objectType === ObjectType.TRACK) {
+ acc.tracks.push(value as SerializedTrack);
+ }
+
+ return acc;
+ }, { shapes: [], tags: [], tracks: [] });
+
+ const frameData = await Object.getPrototypeOf(instance).frames.get.implementation.call(instance, frame);
+ const filteredByAction = action.applyFilter({ collection: exportedCollection, frameData });
+ validateClientIDs(filteredByAction);
+
+ const processedCollection = await action.run({
+ onProgress: throttledOnProgress,
+ cancelled,
+ collection: filteredByAction,
+ frameData: {
+ width: frameData.width,
+ height: frameData.height,
+ number: frameData.number,
+ },
+ });
+
+ await instance.annotations.commit(
+ processedCollection.created,
+ processedCollection.deleted,
+ frame,
+ );
+ event.close();
+ } finally {
+ await action.destroy();
+ }
+}
diff --git a/cvat-core/src/annotations-actions/base-shapes-action.ts b/cvat-core/src/annotations-actions/base-shapes-action.ts
new file mode 100644
index 000000000000..80d2b4fee78b
--- /dev/null
+++ b/cvat-core/src/annotations-actions/base-shapes-action.ts
@@ -0,0 +1,196 @@
+// Copyright (C) 2024 CVAT.ai Corporation
+//
+// SPDX-License-Identifier: MIT
+
+import { throttle } from 'lodash';
+
+import ObjectState from '../object-state';
+import AnnotationsFilter from '../annotations-filter';
+import { Job, Task } from '../session';
+import { SerializedCollection, SerializedShape } from '../server-response-types';
+import { EventScope, ObjectType } from '../enums';
+import { getCollection } from '../annotations';
+import { BaseAction, prepareActionParameters, validateClientIDs } from './base-action';
+
+export interface ShapesActionInput {
+ onProgress(message: string, percent: number): void;
+ cancelled(): boolean;
+ collection: Pick;
+ frameData: {
+ width: number;
+ height: number;
+ number: number;
+ };
+}
+
+export interface ShapesActionOutput {
+ created: ShapesActionInput['collection'];
+ deleted: ShapesActionInput['collection'];
+}
+
+export abstract class BaseShapesAction extends BaseAction {
+ public abstract run(input: ShapesActionInput): Promise;
+ public abstract applyFilter(
+ input: Pick
+ ): ShapesActionInput['collection'];
+}
+
+export async function run(
+ instance: Job | Task,
+ action: BaseShapesAction,
+ actionParameters: Record,
+ frameFrom: number,
+ frameTo: number,
+ filters: object[],
+ onProgress: (message: string, progress: number) => void,
+ cancelled: () => boolean,
+): Promise {
+ const event = await instance.logger.log(EventScope.annotationsAction, {
+ from: frameFrom,
+ to: frameTo,
+ name: action.name,
+ }, true);
+
+ const throttledOnProgress = throttle(onProgress, 100, { leading: true, trailing: true });
+ const showMessageWithPause = async (message: string, progress: number, duration: number): Promise => {
+ // wrapper that gives a chance to abort action
+ throttledOnProgress(message, progress);
+ await new Promise((resolve) => setTimeout(resolve, duration));
+ };
+
+ try {
+ await showMessageWithPause('Actions initialization', 0, 500);
+ if (cancelled()) {
+ return;
+ }
+
+ await action.init(instance, prepareActionParameters(action.parameters, actionParameters));
+
+ const exportedCollection = getCollection(instance).export();
+ validateClientIDs(exportedCollection);
+
+ const annotationsFilter = new AnnotationsFilter();
+ const filteredShapeIDs = annotationsFilter.filterSerializedCollection({
+ shapes: exportedCollection.shapes,
+ tags: [],
+ tracks: [],
+ }, instance.labels, filters).shapes;
+
+ const filteredShapesByFrame = exportedCollection.shapes.reduce((acc, shape) => {
+ if (shape.frame >= frameFrom && shape.frame <= frameTo && filteredShapeIDs.includes(shape.clientID)) {
+ acc[shape.frame] = acc[shape.frame] ?? [];
+ acc[shape.frame].push(shape);
+ }
+ return acc;
+ }, {} as Record);
+
+ const totalUpdates = { created: { shapes: [] }, deleted: { shapes: [] } };
+ // Iterate over frames
+ const totalFrames = frameTo - frameFrom + 1;
+ for (let frame = frameFrom; frame <= frameTo; frame++) {
+ const frameData = await Object.getPrototypeOf(instance).frames
+ .get.implementation.call(instance, frame);
+
+ // Ignore deleted frames
+ if (!frameData.deleted) {
+ const frameShapes = filteredShapesByFrame[frame] ?? [];
+ if (!frameShapes.length) {
+ continue;
+ }
+
+ // finally apply the own filter of the action
+ const filteredByAction = action.applyFilter({
+ collection: {
+ shapes: frameShapes,
+ },
+ frameData,
+ });
+ validateClientIDs(filteredByAction);
+
+ const { created, deleted } = await action.run({
+ onProgress: throttledOnProgress,
+ cancelled,
+ collection: { shapes: filteredByAction.shapes },
+ frameData: {
+ width: frameData.width,
+ height: frameData.height,
+ number: frameData.number,
+ },
+ });
+
+ Array.prototype.push.apply(totalUpdates.created.shapes, created.shapes);
+ Array.prototype.push.apply(totalUpdates.deleted.shapes, deleted.shapes);
+
+ const progress = Math.ceil(+(((frame - frameFrom) / totalFrames) * 100));
+ throttledOnProgress('Actions are running', progress);
+ if (cancelled()) {
+ return;
+ }
+ }
+ }
+
+ await showMessageWithPause('Commiting handled objects', 100, 1500);
+ if (cancelled()) {
+ return;
+ }
+
+ await instance.annotations.commit(
+ { shapes: totalUpdates.created.shapes, tags: [], tracks: [] },
+ { shapes: totalUpdates.deleted.shapes, tags: [], tracks: [] },
+ frameFrom,
+ );
+
+ event.close();
+ } finally {
+ throttledOnProgress('Finalizing', 100);
+ await action.destroy();
+ }
+}
+
+export async function call(
+ instance: Job | Task,
+ action: BaseShapesAction,
+ actionParameters: Record,
+ frame: number,
+ states: ObjectState[],
+ onProgress: (message: string, progress: number) => void,
+ cancelled: () => boolean,
+): Promise {
+ const event = await instance.logger.log(EventScope.annotationsAction, {
+ from: frame,
+ to: frame,
+ name: action.name,
+ }, true);
+
+ const throttledOnProgress = throttle(onProgress, 100, { leading: true, trailing: true });
+ try {
+ await action.init(instance, prepareActionParameters(action.parameters, actionParameters));
+
+ const exported = await Promise.all(states.filter((state) => state.objectType === ObjectType.SHAPE)
+ .map((state) => state.export())) as SerializedShape[];
+ const frameData = await Object.getPrototypeOf(instance).frames.get.implementation.call(instance, frame);
+ const filteredByAction = action.applyFilter({ collection: { shapes: exported }, frameData });
+ validateClientIDs(filteredByAction);
+
+ const processedCollection = await action.run({
+ onProgress: throttledOnProgress,
+ cancelled,
+ collection: { shapes: filteredByAction.shapes },
+ frameData: {
+ width: frameData.width,
+ height: frameData.height,
+ number: frameData.number,
+ },
+ });
+
+ await instance.annotations.commit(
+ { shapes: processedCollection.created.shapes, tags: [], tracks: [] },
+ { shapes: processedCollection.deleted.shapes, tags: [], tracks: [] },
+ frame,
+ );
+
+ event.close();
+ } finally {
+ await action.destroy();
+ }
+}
diff --git a/cvat-core/src/annotations-actions/propagate-shapes.ts b/cvat-core/src/annotations-actions/propagate-shapes.ts
new file mode 100644
index 000000000000..ee68b9600f4f
--- /dev/null
+++ b/cvat-core/src/annotations-actions/propagate-shapes.ts
@@ -0,0 +1,85 @@
+// Copyright (C) 2024 CVAT.ai Corporation
+//
+// SPDX-License-Identifier: MIT
+
+import { range } from 'lodash';
+
+import ObjectState from '../object-state';
+import { Job, Task } from '../session';
+import { SerializedShape } from '../server-response-types';
+import { propagateShapes } from '../object-utils';
+import { ObjectType } from '../enums';
+
+import { ActionParameterType, ActionParameters } from './base-action';
+import { BaseCollectionAction, CollectionActionInput, CollectionActionOutput } from './base-collection-action';
+
+export class PropagateShapes extends BaseCollectionAction {
+ #instance: Task | Job;
+ #targetFrame: number;
+
+ public async init(instance: Job | Task, parameters): Promise {
+ this.#instance = instance;
+ this.#targetFrame = parameters['Target frame'];
+ }
+
+ public async destroy(): Promise {
+ // nothing to destroy
+ }
+
+ public async run(input: CollectionActionInput): Promise {
+ const { collection, frameData: { number } } = input;
+ if (number === this.#targetFrame) {
+ return {
+ created: { shapes: [], tags: [], tracks: [] },
+ deleted: { shapes: [], tags: [], tracks: [] },
+ };
+ }
+
+ const frameNumbers = this.#instance instanceof Job ?
+ await this.#instance.frames.frameNumbers() : range(0, this.#instance.size);
+ const propagatedShapes = propagateShapes(
+ collection.shapes, number, this.#targetFrame, frameNumbers,
+ );
+
+ return {
+ created: { shapes: propagatedShapes, tags: [], tracks: [] },
+ deleted: { shapes: [], tags: [], tracks: [] },
+ };
+ }
+
+ public applyFilter(input: CollectionActionInput): CollectionActionInput['collection'] {
+ return {
+ shapes: input.collection.shapes.filter((shape) => shape.frame === input.frameData.number),
+ tags: [],
+ tracks: [],
+ };
+ }
+
+ public isApplicableForObject(objectState: ObjectState): boolean {
+ return objectState.objectType === ObjectType.SHAPE;
+ }
+
+ public get name(): string {
+ return 'Propagate shapes';
+ }
+
+ public get parameters(): ActionParameters | null {
+ return {
+ 'Target frame': {
+ type: ActionParameterType.NUMBER,
+ values: ({ instance }) => {
+ if (instance instanceof Job) {
+ return [instance.startFrame, instance.stopFrame, 1].map((val) => val.toString());
+ }
+ return [0, instance.size - 1, 1].map((val) => val.toString());
+ },
+ defaultValue: ({ instance }) => {
+ if (instance instanceof Job) {
+ return instance.stopFrame.toString();
+ }
+ return (instance.size - 1).toString();
+ },
+ },
+ };
+ }
+}
diff --git a/cvat-core/src/annotations-actions/remove-filtered-shapes.ts b/cvat-core/src/annotations-actions/remove-filtered-shapes.ts
new file mode 100644
index 000000000000..ab2a30964fad
--- /dev/null
+++ b/cvat-core/src/annotations-actions/remove-filtered-shapes.ts
@@ -0,0 +1,41 @@
+// Copyright (C) 2024 CVAT.ai Corporation
+//
+// SPDX-License-Identifier: MIT
+
+import { BaseShapesAction, ShapesActionInput, ShapesActionOutput } from './base-shapes-action';
+import { ActionParameters } from './base-action';
+
+export class RemoveFilteredShapes extends BaseShapesAction {
+ public async init(): Promise {
+ // nothing to init
+ }
+
+ public async destroy(): Promise {
+ // nothing to destroy
+ }
+
+ public async run(input: ShapesActionInput): Promise {
+ return {
+ created: { shapes: [] },
+ deleted: input.collection,
+ };
+ }
+
+ public applyFilter(input: ShapesActionInput): ShapesActionInput['collection'] {
+ const { collection } = input;
+ return collection;
+ }
+
+ public isApplicableForObject(): boolean {
+ // remove action does not make sense when running on one object
+ return false;
+ }
+
+ public get name(): string {
+ return 'Remove filtered shapes';
+ }
+
+ public get parameters(): ActionParameters | null {
+ return null;
+ }
+}
diff --git a/cvat-core/src/annotations-collection.ts b/cvat-core/src/annotations-collection.ts
index 14879e86bcd8..25496dfe69a7 100644
--- a/cvat-core/src/annotations-collection.ts
+++ b/cvat-core/src/annotations-collection.ts
@@ -157,9 +157,68 @@ export default class Collection {
return result;
}
- public export(): Omit {
+ public commit(
+ appended: Omit,
+ removed: Omit,
+ frame: number,
+ ): { tags: Tag[]; shapes: Shape[]; tracks: Track[]; } {
+ const isCollectionConsistent = [].concat(removed.shapes, removed.tags, removed.tracks)
+ .every((object) => typeof object.clientID === 'number' &&
+ Object.prototype.hasOwnProperty.call(this.objects, object.clientID));
+
+ if (!isCollectionConsistent) {
+ throw new ArgumentError('Objects required to be deleted were not found in the collection');
+ }
+
+ const removedCollection: (Shape | Tag | Track)[] = [].concat(removed.shapes, removed.tags, removed.tracks)
+ .map((object) => this.objects[object.clientID as number]);
+
+ const imported = this.import(appended);
+ const appendedCollection = ([] as (Shape | Tag | Track)[])
+ .concat(imported.shapes, imported.tags, imported.tracks);
+ if (!(appendedCollection.length > 0 || removedCollection.length > 0)) {
+ // nothing to commit
+ return;
+ }
+
+ let prevRemoved = [];
+ removedCollection.forEach((collectionObject) => {
+ prevRemoved.push(collectionObject.removed);
+ collectionObject.removed = true;
+ });
+
+ this.history.do(
+ HistoryActions.COMMIT_ANNOTATIONS,
+ () => {
+ removedCollection.forEach((collectionObject, idx) => {
+ collectionObject.removed = prevRemoved[idx];
+ });
+ prevRemoved = [];
+ appendedCollection.forEach((collectionObject) => {
+ collectionObject.removed = true;
+ });
+ },
+ () => {
+ removedCollection.forEach((collectionObject) => {
+ prevRemoved.push(collectionObject.removed);
+ collectionObject.removed = true;
+ });
+ appendedCollection.forEach((collectionObject) => {
+ collectionObject.removed = false;
+ });
+ },
+ [].concat(
+ removedCollection.map((object) => object.clientID),
+ appendedCollection.map((object) => object.clientID),
+ ),
+ frame,
+ );
+ }
+
+ public export(): Pick {
const data = {
- tracks: this.tracks.filter((track) => !track.removed).map((track) => track.toJSON() as SerializedTrack),
+ tracks: this.tracks.filter((track) => !track.removed)
+ .map((track) => track.toJSON() as SerializedTrack),
shapes: Object.values(this.shapes)
.reduce((accumulator, frameShapes) => {
accumulator.push(...frameShapes);
@@ -201,7 +260,7 @@ export default class Collection {
}
const objectStates = [];
- const filtered = this.annotationsFilter.filter(visible, filters);
+ const filtered = this.annotationsFilter.filterSerializedObjectStates(visible, filters);
visible.forEach((stateData) => {
if (!filters.length || filtered.includes(stateData.clientID)) {
@@ -1338,7 +1397,7 @@ export default class Collection {
statesData.push(...tracks.map((track) => track.get(frame)).filter((state) => !state.outside));
// Filtering
- const filtered = this.annotationsFilter.filter(statesData, annotationsFilters);
+ const filtered = this.annotationsFilter.filterSerializedObjectStates(statesData, annotationsFilters);
if (filtered.length) {
return frame;
}
diff --git a/cvat-core/src/annotations-filter.ts b/cvat-core/src/annotations-filter.ts
index 58c9e82a63e5..fa7b8e739f5a 100644
--- a/cvat-core/src/annotations-filter.ts
+++ b/cvat-core/src/annotations-filter.ts
@@ -6,15 +6,74 @@
import jsonLogic from 'json-logic-js';
import { SerializedData } from './object-state';
import { AttributeType, ObjectType, ShapeType } from './enums';
+import { SerializedCollection } from './server-response-types';
+import { Attribute, Label } from './labels';
function adjustName(name): string {
return name.replace(/\./g, '\u2219');
}
+function getDimensions(points: number[], shapeType: ShapeType): {
+ width: number | null;
+ height: number | null;
+} {
+ let [width, height]: (number | null)[] = [null, null];
+ if (shapeType === ShapeType.MASK) {
+ const [xtl, ytl, xbr, ybr] = points.slice(-4);
+ [width, height] = [xbr - xtl + 1, ybr - ytl + 1];
+ } else if (shapeType === ShapeType.ELLIPSE) {
+ const [cx, cy, rightX, topY] = points;
+ width = Math.abs(rightX - cx) * 2;
+ height = Math.abs(cy - topY) * 2;
+ } else {
+ let xtl = Number.MAX_SAFE_INTEGER;
+ let xbr = Number.MIN_SAFE_INTEGER;
+ let ytl = Number.MAX_SAFE_INTEGER;
+ let ybr = Number.MIN_SAFE_INTEGER;
+
+ points.forEach((coord, idx) => {
+ if (idx % 2) {
+ // y
+ ytl = Math.min(ytl, coord);
+ ybr = Math.max(ybr, coord);
+ } else {
+ // x
+ xtl = Math.min(xtl, coord);
+ xbr = Math.max(xbr, coord);
+ }
+ });
+ [width, height] = [xbr - xtl, ybr - ytl];
+ }
+
+ return {
+ width,
+ height,
+ };
+}
+
+function convertAttribute(id: number, value: string, attributesSpec: Record): [
+ string,
+ number | boolean | string,
+] {
+ const spec = attributesSpec[id];
+ const name = adjustName(spec.name);
+ if (spec.inputType === AttributeType.NUMBER) {
+ return [name, +value];
+ }
+
+ if (spec.inputType === AttributeType.CHECKBOX) {
+ return [name, value === 'true'];
+ }
+
+ return [name, value];
+}
+
+type ConvertedAttributes = Record;
+
interface ConvertedObjectData {
width: number | null;
height: number | null;
- attr: Record>;
+ attr: Record;
label: string;
serverID: number;
objectID: number;
@@ -24,7 +83,7 @@ interface ConvertedObjectData {
}
export default class AnnotationsFilter {
- _convertObjects(statesData: SerializedData[]): ConvertedObjectData[] {
+ private _convertSerializedObjectStates(statesData: SerializedData[]): ConvertedObjectData[] {
const objects = statesData.map((state) => {
const labelAttributes = state.label.attributes.reduce((acc, attr) => {
acc[attr.id] = attr;
@@ -33,50 +92,26 @@ export default class AnnotationsFilter {
let [width, height]: (number | null)[] = [null, null];
if (state.objectType !== ObjectType.TAG) {
- if (state.shapeType === ShapeType.MASK) {
- const [xtl, ytl, xbr, ybr] = state.points.slice(-4);
- [width, height] = [xbr - xtl + 1, ybr - ytl + 1];
- } else {
- let xtl = Number.MAX_SAFE_INTEGER;
- let xbr = Number.MIN_SAFE_INTEGER;
- let ytl = Number.MAX_SAFE_INTEGER;
- let ybr = Number.MIN_SAFE_INTEGER;
-
- const points = state.points || state.elements.reduce((acc, val) => {
- acc.push(val.points);
- return acc;
- }, []).flat();
- points.forEach((coord, idx) => {
- if (idx % 2) {
- // y
- ytl = Math.min(ytl, coord);
- ybr = Math.max(ybr, coord);
- } else {
- // x
- xtl = Math.min(xtl, coord);
- xbr = Math.max(xbr, coord);
- }
- });
- [width, height] = [xbr - xtl, ybr - ytl];
- }
+ const points = state.shapeType === ShapeType.SKELETON ? state.elements.reduce((acc, val) => {
+ acc.push(val.points);
+ return acc;
+ }, []).flat() : state.points;
+
+ ({ width, height } = getDimensions(points, state.shapeType as ShapeType));
}
- const attributes = Object.keys(state.attributes).reduce>((acc, key) => {
- const attr = labelAttributes[key];
- let value = state.attributes[key];
- if (attr.inputType === AttributeType.NUMBER) {
- value = +value;
- } else if (attr.inputType === AttributeType.CHECKBOX) {
- value = value === 'true';
- }
- acc[adjustName(attr.name)] = value;
+ const attributes = Object.keys(state.attributes).reduce((acc, key) => {
+ const [name, value] = convertAttribute(+key, state.attributes[key], labelAttributes);
+ acc[name] = value;
return acc;
- }, {});
+ }, {} as Record);
return {
width,
height,
- attr: Object.fromEntries([[adjustName(state.label.name), attributes]]),
+ attr: {
+ [adjustName(state.label.name)]: attributes,
+ },
label: state.label.name,
serverID: state.serverID,
objectID: state.clientID,
@@ -89,11 +124,119 @@ export default class AnnotationsFilter {
return objects;
}
- filter(statesData: SerializedData[], filters: object[]): number[] {
- if (!filters.length) return statesData.map((stateData): number => stateData.clientID);
- const converted = this._convertObjects(statesData);
+ private _convertSerializedCollection(
+ collection: Omit,
+ labelsSpec: Label[],
+ ): { shapes: ConvertedObjectData[]; tags: ConvertedObjectData[]; tracks: ConvertedObjectData[]; } {
+ const labelByID = labelsSpec.reduce>((acc, label) => ({
+ [label.id]: label,
+ ...acc,
+ }), {});
+
+ const attributeById = labelsSpec.map((label) => label.attributes).flat().reduce((acc, attribute) => ({
+ ...acc,
+ [attribute.id]: attribute,
+ }), {} as Record);
+
+ const convertAttributes = (
+ attributes: SerializedCollection['shapes'][0]['attributes'],
+ ): ConvertedAttributes => attributes.reduce((acc, { spec_id, value }) => {
+ const [name, adjustedValue] = convertAttribute(spec_id, value, attributeById);
+ acc[name] = adjustedValue;
+ return acc;
+ }, {} as Record);
+
+ return {
+ shapes: collection.shapes.map((shape) => {
+ const label = labelByID[shape.label_id];
+ const points = shape.type === ShapeType.SKELETON ?
+ shape.elements.map((el) => el.points).flat() : shape.points;
+ let [width, height]: (number | null)[] = [null, null];
+ ({ width, height } = getDimensions(points, shape.type));
+
+ return {
+ width,
+ height,
+ attr: {
+ [adjustName(label.name)]: convertAttributes(shape.attributes),
+ },
+ label: label.name,
+ serverID: shape.id ?? null,
+ type: ObjectType.SHAPE,
+ shape: shape.type,
+ occluded: shape.occluded,
+ objectID: shape.clientID ?? null,
+ };
+ }),
+ tags: collection.tags.map((tag) => {
+ const label = labelByID[tag.label_id];
+
+ return {
+ width: null,
+ height: null,
+ attr: {
+ [adjustName(label.name)]: convertAttributes(tag.attributes),
+ },
+ label: labelByID[tag.label_id]?.name ?? null,
+ serverID: tag.id ?? null,
+ type: ObjectType.SHAPE,
+ shape: null,
+ occluded: false,
+ objectID: tag.clientID ?? null,
+ };
+ }),
+ tracks: collection.tracks.map((track) => {
+ const label = labelByID[track.label_id];
+
+ return {
+ width: null,
+ height: null,
+ attr: {
+ [adjustName(label.name)]: convertAttributes(track.attributes),
+ },
+ label: labelByID[track.label_id]?.name ?? null,
+ serverID: track.id,
+ type: ObjectType.TRACK,
+ shape: track.shapes[0]?.type ?? null,
+ occluded: null,
+ objectID: track.clientID ?? null,
+ };
+ }),
+ };
+ }
+
+ public filterSerializedObjectStates(statesData: SerializedData[], filters: object[]): number[] {
+ if (!filters.length) {
+ return statesData.map((stateData): number => stateData.clientID);
+ }
+
+ const converted = this._convertSerializedObjectStates(statesData);
return converted
.map((state) => state.objectID)
.filter((_, index) => jsonLogic.apply(filters[0], converted[index]));
}
+
+ public filterSerializedCollection(
+ collection: Omit,
+ labelsSpec: Label[],
+ filters: object[],
+ ): { shapes: number[]; tags: number[]; tracks: number[]; } {
+ if (!filters.length) {
+ return {
+ shapes: collection.shapes.map((shape) => shape.clientID),
+ tags: collection.tags.map((tag) => tag.clientID),
+ tracks: collection.tracks.map((track) => track.clientID),
+ };
+ }
+
+ const converted = this._convertSerializedCollection(collection, labelsSpec);
+ return {
+ shapes: converted.shapes.map((shape) => shape.objectID)
+ .filter((_, index) => jsonLogic.apply(filters[0], converted.shapes[index])),
+ tags: converted.tags.map((shape) => shape.objectID)
+ .filter((_, index) => jsonLogic.apply(filters[0], converted.tags[index])),
+ tracks: converted.tracks.map((shape) => shape.objectID)
+ .filter((_, index) => jsonLogic.apply(filters[0], converted.tracks[index])),
+ };
+ }
}
diff --git a/cvat-core/src/annotations-history.ts b/cvat-core/src/annotations-history.ts
index 748d55bcf93d..2e59db96ea1f 100644
--- a/cvat-core/src/annotations-history.ts
+++ b/cvat-core/src/annotations-history.ts
@@ -5,7 +5,7 @@
import { HistoryActions } from './enums';
-const MAX_HISTORY_LENGTH = 128;
+const MAX_HISTORY_LENGTH = 32;
interface ActionItem {
action: HistoryActions;
diff --git a/cvat-core/src/annotations-objects.ts b/cvat-core/src/annotations-objects.ts
index defcf7dbbada..ab7e32de9784 100644
--- a/cvat-core/src/annotations-objects.ts
+++ b/cvat-core/src/annotations-objects.ts
@@ -150,17 +150,12 @@ class Annotation {
injection.groups.max = Math.max(injection.groups.max, this.group);
}
- protected withContext(frame: number): {
- __internal: {
- save: (data: ObjectState) => ObjectState;
- delete: Annotation['delete'];
- };
+ // eslint-disable-next-line @typescript-eslint/no-unused-vars
+ protected withContext(_: number): {
+ delete: Annotation['delete'];
} {
return {
- __internal: {
- save: (this as any).save.bind(this, frame),
- delete: this.delete.bind(this),
- },
+ delete: this.delete.bind(this),
};
}
@@ -530,6 +525,17 @@ export class Shape extends Drawn {
this.zOrder = data.z_order;
}
+ protected withContext(frame: number): ReturnType & {
+ save: (data: ObjectState) => ObjectState;
+ export: () => SerializedShape;
+ } {
+ return {
+ ...super.withContext(frame),
+ save: this.save.bind(this, frame),
+ export: this.toJSON.bind(this) as () => SerializedShape,
+ };
+ }
+
// Method is used to export data to the server
public toJSON(): SerializedShape | SerializedShape['elements'][0] {
const result: SerializedShape = {
@@ -592,7 +598,7 @@ export class Shape extends Drawn {
pinned: this.pinned,
frame,
source: this.source,
- ...this.withContext(frame),
+ __internal: this.withContext(frame),
};
if (typeof this.outside !== 'undefined') {
@@ -838,6 +844,17 @@ export class Track extends Drawn {
}, {});
}
+ protected withContext(frame: number): ReturnType & {
+ save: (data: ObjectState) => ObjectState;
+ export: () => SerializedTrack;
+ } {
+ return {
+ ...super.withContext(frame),
+ save: this.save.bind(this, frame),
+ export: this.toJSON.bind(this) as () => SerializedTrack,
+ };
+ }
+
// Method is used to export data to the server
public toJSON(): SerializedTrack | SerializedTrack['elements'][0] {
const labelAttributes = attrsAsAnObject(this.label.attributes);
@@ -931,7 +948,7 @@ export class Track extends Drawn {
},
frame,
source: this.source,
- ...this.withContext(frame),
+ __internal: this.withContext(frame),
};
}
@@ -1405,6 +1422,17 @@ export class Track extends Drawn {
}
export class Tag extends Annotation {
+ protected withContext(frame: number): ReturnType & {
+ save: (data: ObjectState) => ObjectState;
+ export: () => SerializedTag;
+ } {
+ return {
+ ...super.withContext(frame),
+ save: this.save.bind(this, frame),
+ export: this.toJSON.bind(this) as () => SerializedTag,
+ };
+ }
+
// Method is used to export data to the server
public toJSON(): SerializedTag {
const result: SerializedTag = {
@@ -1451,7 +1479,7 @@ export class Tag extends Annotation {
updated: this.updated,
frame,
source: this.source,
- ...this.withContext(frame),
+ __internal: this.withContext(frame),
};
}
@@ -2022,7 +2050,7 @@ export class SkeletonShape extends Shape {
hidden: elements.every((el) => el.hidden),
frame,
source: this.source,
- ...this.withContext(frame),
+ __internal: this.withContext(frame),
};
}
@@ -3064,7 +3092,7 @@ export class SkeletonTrack extends Track {
occluded: elements.every((el) => el.occluded),
lock: elements.every((el) => el.lock),
hidden: elements.every((el) => el.hidden),
- ...this.withContext(frame),
+ __internal: this.withContext(frame),
};
}
diff --git a/cvat-core/src/api-implementation.ts b/cvat-core/src/api-implementation.ts
index 0e9f400ad499..c9e53a2e1e0d 100644
--- a/cvat-core/src/api-implementation.ts
+++ b/cvat-core/src/api-implementation.ts
@@ -39,7 +39,9 @@ import QualityConflict, { ConflictSeverity } from './quality-conflict';
import QualitySettings from './quality-settings';
import { getFramesMeta } from './frames';
import AnalyticsReport from './analytics-report';
-import { listActions, registerAction, runActions } from './annotations-actions';
+import {
+ callAction, listActions, registerAction, runAction,
+} from './annotations-actions/annotations-actions';
import { convertDescriptions, getServerAPISchema } from './server-schema';
import { JobType } from './enums';
import { PaginatedResource } from './core-types';
@@ -54,7 +56,8 @@ export default function implementAPI(cvat: CVATCore): CVATCore {
implementationMixin(cvat.plugins.register, PluginRegistry.register.bind(cvat));
implementationMixin(cvat.actions.list, listActions);
implementationMixin(cvat.actions.register, registerAction);
- implementationMixin(cvat.actions.run, runActions);
+ implementationMixin(cvat.actions.run, runAction);
+ implementationMixin(cvat.actions.call, callAction);
implementationMixin(cvat.lambda.list, lambdaManager.list.bind(lambdaManager));
implementationMixin(cvat.lambda.run, lambdaManager.run.bind(lambdaManager));
diff --git a/cvat-core/src/api.ts b/cvat-core/src/api.ts
index ca33f431c43e..f4eb5d8b23fd 100644
--- a/cvat-core/src/api.ts
+++ b/cvat-core/src/api.ts
@@ -21,7 +21,9 @@ import CloudStorage from './cloud-storage';
import Organization from './organization';
import Webhook from './webhook';
import AnnotationGuide from './guide';
-import BaseSingleFrameAction from './annotations-actions';
+import { BaseAction } from './annotations-actions/base-action';
+import { BaseCollectionAction } from './annotations-actions/base-collection-action';
+import { BaseShapesAction } from './annotations-actions/base-shapes-action';
import QualityReport from './quality-report';
import QualityConflict from './quality-conflict';
import QualitySettings from './quality-settings';
@@ -191,14 +193,14 @@ function build(): CVATCore {
const result = await PluginRegistry.apiWrapper(cvat.actions.list);
return result;
},
- async register(action: BaseSingleFrameAction) {
+ async register(action: BaseAction) {
const result = await PluginRegistry.apiWrapper(cvat.actions.register, action);
return result;
},
async run(
instance: Job | Task,
- actionsChain: BaseSingleFrameAction[],
- actionsParameters: Record[],
+ actions: BaseAction,
+ actionsParameters: Record,
frameFrom: number,
frameTo: number,
filters: string[],
@@ -211,7 +213,7 @@ function build(): CVATCore {
const result = await PluginRegistry.apiWrapper(
cvat.actions.run,
instance,
- actionsChain,
+ actions,
actionsParameters,
frameFrom,
frameTo,
@@ -221,6 +223,30 @@ function build(): CVATCore {
);
return result;
},
+ async call(
+ instance: Job | Task,
+ actions: BaseAction,
+ actionsParameters: Record,
+ frame: number,
+ states: ObjectState[],
+ onProgress: (
+ message: string,
+ progress: number,
+ ) => void,
+ cancelled: () => boolean,
+ ) {
+ const result = await PluginRegistry.apiWrapper(
+ cvat.actions.call,
+ instance,
+ actions,
+ actionsParameters,
+ frame,
+ states,
+ onProgress,
+ cancelled,
+ );
+ return result;
+ },
},
lambda: {
async list() {
@@ -420,7 +446,8 @@ function build(): CVATCore {
Organization,
Webhook,
AnnotationGuide,
- BaseSingleFrameAction,
+ BaseShapesAction,
+ BaseCollectionAction,
QualitySettings,
AnalyticsReport,
QualityConflict,
diff --git a/cvat-core/src/enums.ts b/cvat-core/src/enums.ts
index 1b291662d213..25fdf815fa20 100644
--- a/cvat-core/src/enums.ts
+++ b/cvat-core/src/enums.ts
@@ -148,6 +148,7 @@ export enum HistoryActions {
REMOVED_OBJECT = 'Removed object',
REMOVED_FRAME = 'Removed frame',
RESTORED_FRAME = 'Restored frame',
+ COMMIT_ANNOTATIONS = 'Commit annotations',
}
export enum ModelKind {
diff --git a/cvat-core/src/index.ts b/cvat-core/src/index.ts
index 8a4c9e8bfb53..79ce8b305a9f 100644
--- a/cvat-core/src/index.ts
+++ b/cvat-core/src/index.ts
@@ -34,7 +34,14 @@ import AnalyticsReport from './analytics-report';
import AnnotationGuide from './guide';
import { JobValidationLayout, TaskValidationLayout } from './validation-layout';
import { Request } from './request';
-import BaseSingleFrameAction, { listActions, registerAction, runActions } from './annotations-actions';
+import {
+ runAction,
+ callAction,
+ listActions,
+ registerAction,
+} from './annotations-actions/annotations-actions';
+import { BaseCollectionAction } from './annotations-actions/base-collection-action';
+import { BaseShapesAction } from './annotations-actions/base-shapes-action';
import {
ArgumentError, DataError, Exception, ScriptingError, ServerError,
} from './exceptions';
@@ -165,7 +172,8 @@ export default interface CVATCore {
actions: {
list: typeof listActions;
register: typeof registerAction;
- run: typeof runActions;
+ run: typeof runAction;
+ call: typeof callAction;
};
logger: typeof logger;
config: {
@@ -209,7 +217,8 @@ export default interface CVATCore {
Organization: typeof Organization;
Webhook: typeof Webhook;
AnnotationGuide: typeof AnnotationGuide;
- BaseSingleFrameAction: typeof BaseSingleFrameAction;
+ BaseShapesAction: typeof BaseShapesAction;
+ BaseCollectionAction: typeof BaseCollectionAction;
QualityReport: typeof QualityReport;
QualityConflict: typeof QualityConflict;
QualitySettings: typeof QualitySettings;
diff --git a/cvat-core/src/object-state.ts b/cvat-core/src/object-state.ts
index 9b35736a08a1..28993a0d114c 100644
--- a/cvat-core/src/object-state.ts
+++ b/cvat-core/src/object-state.ts
@@ -1,5 +1,5 @@
// Copyright (C) 2019-2022 Intel Corporation
-// Copyright (C) 2022-2023 CVAT.ai Corporation
+// Copyright (C) 2022-2024 CVAT.ai Corporation
//
// SPDX-License-Identifier: MIT
@@ -8,6 +8,7 @@ import PluginRegistry from './plugins';
import { ArgumentError } from './exceptions';
import { Label } from './labels';
import { isEnum } from './common';
+import { SerializedShape, SerializedTag, SerializedTrack } from './server-response-types';
interface UpdateFlags {
label: boolean;
@@ -516,10 +517,15 @@ export default class ObjectState {
const result = await PluginRegistry.apiWrapper.call(this, ObjectState.prototype.delete, frame, force);
return result;
}
+
+ async export(): Promise {
+ const result = await PluginRegistry.apiWrapper.call(this, ObjectState.prototype.export);
+ return result;
+ }
}
Object.defineProperty(ObjectState.prototype.save, 'implementation', {
- value: function save(): ObjectState {
+ value: function saveImplementation(): ObjectState {
if (this.__internal && this.__internal.save) {
return this.__internal.save(this);
}
@@ -529,8 +535,19 @@ Object.defineProperty(ObjectState.prototype.save, 'implementation', {
writable: false,
});
+Object.defineProperty(ObjectState.prototype.export, 'implementation', {
+ value: function exportImplementation(): ObjectState {
+ if (this.__internal && this.__internal.export) {
+ return this.__internal.export(this);
+ }
+
+ return this;
+ },
+ writable: false,
+});
+
Object.defineProperty(ObjectState.prototype.delete, 'implementation', {
- value: function remove(frame: number, force: boolean): boolean {
+ value: function deleteImplementation(frame: number, force: boolean): boolean {
if (this.__internal && this.__internal.delete) {
if (!Number.isInteger(+frame) || +frame < 0) {
throw new ArgumentError('Frame argument must be a non negative integer');
diff --git a/cvat-core/src/session-implementation.ts b/cvat-core/src/session-implementation.ts
index 904899831abf..7ea9e326fb8b 100644
--- a/cvat-core/src/session-implementation.ts
+++ b/cvat-core/src/session-implementation.ts
@@ -519,6 +519,18 @@ export function implementJob(Job: typeof JobClass): typeof JobClass {
},
});
+ Object.defineProperty(Job.prototype.annotations.commit, 'implementation', {
+ value: function commitAnnotationsImplementation(
+ this: JobClass,
+ added: Parameters[0],
+ removed: Parameters[1],
+ frame: Parameters[2],
+ ): ReturnType {
+ getCollection(this).commit(added, removed, frame);
+ return Promise.resolve();
+ },
+ });
+
Object.defineProperty(Job.prototype.annotations.upload, 'implementation', {
value: async function uploadAnnotationsImplementation(
this: JobClass,
@@ -1208,6 +1220,18 @@ export function implementTask(Task: typeof TaskClass): typeof TaskClass {
},
});
+ Object.defineProperty(Task.prototype.annotations.commit, 'implementation', {
+ value: function commitAnnotationsImplementation(
+ this: TaskClass,
+ added: Parameters[0],
+ removed: Parameters[1],
+ frame: Parameters[2],
+ ): ReturnType {
+ getCollection(this).commit(added, removed, frame);
+ return Promise.resolve();
+ },
+ });
+
Object.defineProperty(Task.prototype.annotations.exportDataset, 'implementation', {
value: async function exportDatasetImplementation(
this: TaskClass,
diff --git a/cvat-core/src/session.ts b/cvat-core/src/session.ts
index a2bc2008aef0..b3269ee78076 100644
--- a/cvat-core/src/session.ts
+++ b/cvat-core/src/session.ts
@@ -172,6 +172,17 @@ function buildDuplicatedAPI(prototype) {
return result;
},
+ async commit(added, removed, frame) {
+ const result = await PluginRegistry.apiWrapper.call(
+ this,
+ prototype.annotations.commit,
+ added,
+ removed,
+ frame,
+ );
+ return result;
+ },
+
async exportDataset(
format: string,
saveImages: boolean,
@@ -332,7 +343,7 @@ export class Session {
delTrackKeyframesOnly?: boolean;
}) => Promise;
save: (
- onUpdate ?: (message: string) => void,
+ onUpdate?: (message: string) => void,
) => Promise;
search: (
frameFrom: number,
@@ -361,6 +372,11 @@ export class Session {
}>;
import: (data: Omit) => Promise;
export: () => Promise>;
+ commit: (
+ added: Omit,
+ removed: Omit,
+ frame: number,
+ ) => Promise;
statistics: () => Promise;
hasUnsavedChanges: () => boolean;
exportDataset: (
@@ -431,6 +447,7 @@ export class Session {
select: Object.getPrototypeOf(this).annotations.select.bind(this),
import: Object.getPrototypeOf(this).annotations.import.bind(this),
export: Object.getPrototypeOf(this).annotations.export.bind(this),
+ commit: Object.getPrototypeOf(this).annotations.commit.bind(this),
statistics: Object.getPrototypeOf(this).annotations.statistics.bind(this),
hasUnsavedChanges: Object.getPrototypeOf(this).annotations.hasUnsavedChanges.bind(this),
exportDataset: Object.getPrototypeOf(this).annotations.exportDataset.bind(this),
diff --git a/cvat-sdk/README.md b/cvat-sdk/README.md
index fa68c0e5d40d..89702c02abd4 100644
--- a/cvat-sdk/README.md
+++ b/cvat-sdk/README.md
@@ -20,7 +20,14 @@ To install a prebuilt package, run the following command in the terminal:
pip install cvat-sdk
```
-To use the PyTorch adapter, request the `pytorch` extra:
+To use the `cvat_sdk.masks` module, request the `masks` extra:
+
+```bash
+pip install "cvat-sdk[masks]"
+```
+
+To use the PyTorch adapter or the built-in PyTorch-based auto-annotation functions,
+request the `pytorch` extra:
```bash
pip install "cvat-sdk[pytorch]"
diff --git a/cvat-sdk/cvat_sdk/auto_annotation/__init__.py b/cvat-sdk/cvat_sdk/auto_annotation/__init__.py
index e5dbdf9fcc42..adbb6007e125 100644
--- a/cvat-sdk/cvat_sdk/auto_annotation/__init__.py
+++ b/cvat-sdk/cvat_sdk/auto_annotation/__init__.py
@@ -10,8 +10,27 @@
keypoint,
keypoint_spec,
label_spec,
+ mask,
+ polygon,
rectangle,
shape,
skeleton,
skeleton_label_spec,
)
+
+__all__ = [
+ "annotate_task",
+ "BadFunctionError",
+ "DetectionFunction",
+ "DetectionFunctionContext",
+ "DetectionFunctionSpec",
+ "keypoint_spec",
+ "keypoint",
+ "label_spec",
+ "mask",
+ "polygon",
+ "rectangle",
+ "shape",
+ "skeleton_label_spec",
+ "skeleton",
+]
diff --git a/cvat-sdk/cvat_sdk/auto_annotation/driver.py b/cvat-sdk/cvat_sdk/auto_annotation/driver.py
index 0f3d82ea32ea..5ffdb36f5bee 100644
--- a/cvat-sdk/cvat_sdk/auto_annotation/driver.py
+++ b/cvat-sdk/cvat_sdk/auto_annotation/driver.py
@@ -99,9 +99,11 @@ def __init__(
ds_labels: Sequence[models.ILabel],
*,
allow_unmatched_labels: bool,
+ conv_mask_to_poly: bool,
) -> None:
self._logger = logger
self._allow_unmatched_labels = allow_unmatched_labels
+ self._conv_mask_to_poly = conv_mask_to_poly
ds_labels_by_name = {ds_label.name: ds_label for ds_label in ds_labels}
@@ -217,12 +219,19 @@ def validate_and_remap(self, shapes: list[models.LabeledShapeRequest], ds_frame:
if getattr(shape, "elements", None):
raise BadFunctionError("function output non-skeleton shape with elements")
+ if shape.type.value == "mask" and self._conv_mask_to_poly:
+ raise BadFunctionError(
+ "function output mask shape despite conv_mask_to_poly=True"
+ )
+
shapes[:] = new_shapes
-@attrs.frozen
+@attrs.frozen(kw_only=True)
class _DetectionFunctionContextImpl(DetectionFunctionContext):
frame_name: str
+ conf_threshold: Optional[float] = None
+ conv_mask_to_poly: bool = False
def annotate_task(
@@ -233,6 +242,8 @@ def annotate_task(
pbar: Optional[ProgressReporter] = None,
clear_existing: bool = False,
allow_unmatched_labels: bool = False,
+ conf_threshold: Optional[float] = None,
+ conv_mask_to_poly: bool = False,
) -> None:
"""
Downloads data for the task with the given ID, applies the given function to it
@@ -264,11 +275,21 @@ def annotate_task(
function declares a label in its spec that has no corresponding label in the task.
If it's set to true, then such labels are allowed, and any annotations returned by the
function that refer to this label are ignored. Otherwise, BadFunctionError is raised.
+
+ The conf_threshold parameter must be None or a number between 0 and 1. It will be passed
+ to the AA function as the conf_threshold attribute of the context object.
+
+ The conv_mask_to_poly parameter will be passed to the AA function as the conv_mask_to_poly
+ attribute of the context object. If it's true, and the AA function returns any mask shapes,
+ BadFunctionError will be raised.
"""
if pbar is None:
pbar = NullProgressReporter()
+ if conf_threshold is not None and not 0 <= conf_threshold <= 1:
+ raise ValueError("conf_threshold must be None or a number between 0 and 1")
+
dataset = TaskDataset(client, task_id, load_annotations=False)
assert isinstance(function.spec, DetectionFunctionSpec)
@@ -278,6 +299,7 @@ def annotate_task(
function.spec.labels,
dataset.labels,
allow_unmatched_labels=allow_unmatched_labels,
+ conv_mask_to_poly=conv_mask_to_poly,
)
shapes = []
@@ -285,12 +307,17 @@ def annotate_task(
with pbar.task(total=len(dataset.samples), unit="samples"):
for sample in pbar.iter(dataset.samples):
frame_shapes = function.detect(
- _DetectionFunctionContextImpl(sample.frame_name), sample.media.load_image()
+ _DetectionFunctionContextImpl(
+ frame_name=sample.frame_name,
+ conf_threshold=conf_threshold,
+ conv_mask_to_poly=conv_mask_to_poly,
+ ),
+ sample.media.load_image(),
)
mapper.validate_and_remap(frame_shapes, sample.frame_index)
shapes.extend(frame_shapes)
- client.logger.info("Uploading annotations to task %d", task_id)
+ client.logger.info("Uploading annotations to task %d...", task_id)
if clear_existing:
client.tasks.api.update_annotations(
@@ -302,3 +329,5 @@ def annotate_task(
task_id,
patched_labeled_data_request=models.PatchedLabeledDataRequest(shapes=shapes),
)
+
+ client.logger.info("Upload complete")
diff --git a/cvat-sdk/cvat_sdk/auto_annotation/functions/_torchvision.py b/cvat-sdk/cvat_sdk/auto_annotation/functions/_torchvision.py
new file mode 100644
index 000000000000..9fa88e0a7c07
--- /dev/null
+++ b/cvat-sdk/cvat_sdk/auto_annotation/functions/_torchvision.py
@@ -0,0 +1,26 @@
+# Copyright (C) 2024 CVAT.ai Corporation
+#
+# SPDX-License-Identifier: MIT
+
+from functools import cached_property
+
+import torchvision.models
+
+import cvat_sdk.auto_annotation as cvataa
+
+
+class TorchvisionFunction:
+ def __init__(self, model_name: str, weights_name: str = "DEFAULT", **kwargs) -> None:
+ weights_enum = torchvision.models.get_model_weights(model_name)
+ self._weights = weights_enum[weights_name]
+ self._transforms = self._weights.transforms()
+ self._model = torchvision.models.get_model(model_name, weights=self._weights, **kwargs)
+ self._model.eval()
+
+ @cached_property
+ def spec(self) -> cvataa.DetectionFunctionSpec:
+ return cvataa.DetectionFunctionSpec(
+ labels=[
+ cvataa.label_spec(cat, i) for i, cat in enumerate(self._weights.meta["categories"])
+ ]
+ )
diff --git a/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_detection.py b/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_detection.py
index d257cb7ec889..b16e4d8874ae 100644
--- a/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_detection.py
+++ b/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_detection.py
@@ -2,38 +2,26 @@
#
# SPDX-License-Identifier: MIT
-from functools import cached_property
-
import PIL.Image
-import torchvision.models
import cvat_sdk.auto_annotation as cvataa
import cvat_sdk.models as models
+from ._torchvision import TorchvisionFunction
-class _TorchvisionDetectionFunction:
- def __init__(self, model_name: str, weights_name: str = "DEFAULT", **kwargs) -> None:
- weights_enum = torchvision.models.get_model_weights(model_name)
- self._weights = weights_enum[weights_name]
- self._transforms = self._weights.transforms()
- self._model = torchvision.models.get_model(model_name, weights=self._weights, **kwargs)
- self._model.eval()
-
- @cached_property
- def spec(self) -> cvataa.DetectionFunctionSpec:
- return cvataa.DetectionFunctionSpec(
- labels=[
- cvataa.label_spec(cat, i) for i, cat in enumerate(self._weights.meta["categories"])
- ]
- )
- def detect(self, context, image: PIL.Image.Image) -> list[models.LabeledShapeRequest]:
+class _TorchvisionDetectionFunction(TorchvisionFunction):
+ def detect(
+ self, context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
+ ) -> list[models.LabeledShapeRequest]:
+ conf_threshold = context.conf_threshold or 0
results = self._model([self._transforms(image)])
return [
cvataa.rectangle(label.item(), [x.item() for x in box])
for result in results
- for box, label in zip(result["boxes"], result["labels"])
+ for box, label, score in zip(result["boxes"], result["labels"], result["scores"])
+ if score >= conf_threshold
]
diff --git a/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_instance_segmentation.py b/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_instance_segmentation.py
new file mode 100644
index 000000000000..6aa891811f5b
--- /dev/null
+++ b/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_instance_segmentation.py
@@ -0,0 +1,70 @@
+# Copyright (C) 2024 CVAT.ai Corporation
+#
+# SPDX-License-Identifier: MIT
+
+import math
+from collections.abc import Iterator
+
+import numpy as np
+import PIL.Image
+from skimage import measure
+from torch import Tensor
+
+import cvat_sdk.auto_annotation as cvataa
+import cvat_sdk.models as models
+from cvat_sdk.masks import encode_mask
+
+from ._torchvision import TorchvisionFunction
+
+
+def _is_positively_oriented(contour: np.ndarray) -> bool:
+ ys, xs = contour.T
+
+ # This is the shoelace formula, except we only need the sign of the result,
+ # so we compare instead of subtracting. Compared to the typical formula,
+ # the sign is inverted, because the Y axis points downwards.
+ return np.sum(xs * np.roll(ys, -1)) < np.sum(ys * np.roll(xs, -1))
+
+
+def _generate_shapes(
+ context: cvataa.DetectionFunctionContext, box: Tensor, mask: Tensor, label: Tensor
+) -> Iterator[models.LabeledShapeRequest]:
+ LEVEL = 0.5
+
+ if context.conv_mask_to_poly:
+ # Since we treat mask values of exactly LEVEL as true, we'd like them
+ # to also be considered high by find_contours. And for that, the level
+ # parameter must be slightly less than LEVEL.
+ contours = measure.find_contours(mask[0].detach().numpy(), level=math.nextafter(LEVEL, 0))
+
+ for contour in contours:
+ if len(contour) < 3 or _is_positively_oriented(contour):
+ continue
+
+ contour = measure.approximate_polygon(contour, tolerance=2.5)
+
+ yield cvataa.polygon(label.item(), contour[:, ::-1].ravel().tolist())
+
+ else:
+ yield cvataa.mask(label.item(), encode_mask(mask[0] >= LEVEL, box.tolist()))
+
+
+class _TorchvisionInstanceSegmentationFunction(TorchvisionFunction):
+ def detect(
+ self, context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
+ ) -> list[models.LabeledShapeRequest]:
+ conf_threshold = context.conf_threshold or 0
+ results = self._model([self._transforms(image)])
+
+ return [
+ shape
+ for result in results
+ for box, mask, label, score in zip(
+ result["boxes"], result["masks"], result["labels"], result["scores"]
+ )
+ if score >= conf_threshold
+ for shape in _generate_shapes(context, box, mask, label)
+ ]
+
+
+create = _TorchvisionInstanceSegmentationFunction
diff --git a/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_keypoint_detection.py b/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_keypoint_detection.py
index c7199b67738b..4d2250d61c35 100644
--- a/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_keypoint_detection.py
+++ b/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_keypoint_detection.py
@@ -5,20 +5,14 @@
from functools import cached_property
import PIL.Image
-import torchvision.models
import cvat_sdk.auto_annotation as cvataa
import cvat_sdk.models as models
+from ._torchvision import TorchvisionFunction
-class _TorchvisionKeypointDetectionFunction:
- def __init__(self, model_name: str, weights_name: str = "DEFAULT", **kwargs) -> None:
- weights_enum = torchvision.models.get_model_weights(model_name)
- self._weights = weights_enum[weights_name]
- self._transforms = self._weights.transforms()
- self._model = torchvision.models.get_model(model_name, weights=self._weights, **kwargs)
- self._model.eval()
+class _TorchvisionKeypointDetectionFunction(TorchvisionFunction):
@cached_property
def spec(self) -> cvataa.DetectionFunctionSpec:
return cvataa.DetectionFunctionSpec(
@@ -35,7 +29,10 @@ def spec(self) -> cvataa.DetectionFunctionSpec:
]
)
- def detect(self, context, image: PIL.Image.Image) -> list[models.LabeledShapeRequest]:
+ def detect(
+ self, context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
+ ) -> list[models.LabeledShapeRequest]:
+ conf_threshold = context.conf_threshold or 0
results = self._model([self._transforms(image)])
return [
@@ -51,7 +48,10 @@ def detect(self, context, image: PIL.Image.Image) -> list[models.LabeledShapeReq
],
)
for result in results
- for keypoints, label in zip(result["keypoints"], result["labels"])
+ for keypoints, label, score in zip(
+ result["keypoints"], result["labels"], result["scores"]
+ )
+ if score >= conf_threshold
]
diff --git a/cvat-sdk/cvat_sdk/auto_annotation/interface.py b/cvat-sdk/cvat_sdk/auto_annotation/interface.py
index 20a21fe4a5cf..f95cb50b4f2d 100644
--- a/cvat-sdk/cvat_sdk/auto_annotation/interface.py
+++ b/cvat-sdk/cvat_sdk/auto_annotation/interface.py
@@ -4,7 +4,7 @@
import abc
from collections.abc import Sequence
-from typing import Protocol
+from typing import Optional, Protocol
import attrs
import PIL.Image
@@ -50,7 +50,33 @@ def frame_name(self) -> str:
The file name of the frame that the current image corresponds to in
the dataset.
"""
- ...
+
+ @property
+ @abc.abstractmethod
+ def conf_threshold(self) -> Optional[float]:
+ """
+ The confidence threshold that the function should use for filtering
+ detections.
+
+ If the function is able to estimate confidence levels, then:
+
+ * If this value is None, the function may apply a default threshold at its discretion.
+
+ * Otherwise, it will be a number between 0 and 1. The function must only return
+ objects with confidence levels greater than or equal to this value.
+
+ If the function is not able to estimate confidence levels, it can ignore this value.
+ """
+
+ @property
+ @abc.abstractmethod
+ def conv_mask_to_poly(self) -> bool:
+ """
+ If this is true, the function must convert any mask shapes to polygon shapes
+ before returning them.
+
+ If the function does not return any mask shapes, then it can ignore this value.
+ """
class DetectionFunction(Protocol):
@@ -152,6 +178,21 @@ def rectangle(label_id: int, points: Sequence[float], **kwargs) -> models.Labele
return shape(label_id, type="rectangle", points=points, **kwargs)
+def polygon(label_id: int, points: Sequence[float], **kwargs) -> models.LabeledShapeRequest:
+ """Helper factory function for LabeledShapeRequest with frame=0 and type="polygon"."""
+ return shape(label_id, type="polygon", points=points, **kwargs)
+
+
+def mask(label_id: int, points: Sequence[float], **kwargs) -> models.LabeledShapeRequest:
+ """
+ Helper factory function for LabeledShapeRequest with frame=0 and type="mask".
+
+ It's recommended to use the cvat.masks.encode_mask function to build the
+ points argument.
+ """
+ return shape(label_id, type="mask", points=points, **kwargs)
+
+
def skeleton(
label_id: int, elements: Sequence[models.SubLabeledShapeRequest], **kwargs
) -> models.LabeledShapeRequest:
diff --git a/cvat-sdk/cvat_sdk/masks.py b/cvat-sdk/cvat_sdk/masks.py
new file mode 100644
index 000000000000..f623aec7d043
--- /dev/null
+++ b/cvat-sdk/cvat_sdk/masks.py
@@ -0,0 +1,44 @@
+# Copyright (C) 2024 CVAT.ai Corporation
+#
+# SPDX-License-Identifier: MIT
+
+import math
+from collections.abc import Sequence
+
+import numpy as np
+from numpy.typing import ArrayLike
+
+
+def encode_mask(bitmap: ArrayLike, /, bbox: Sequence[float]) -> list[float]:
+ """
+ Encodes an image mask into an array of numbers suitable for the "points"
+ attribute of a LabeledShapeRequest object of type "mask".
+
+ bitmap must be a boolean array of shape (H, W), where H is the height and
+ W is the width of the image that the mask applies to.
+
+ bbox must have the form [x1, y1, x2, y2], where (0, 0) <= (x1, y1) < (x2, y2) <= (W, H).
+ The mask will be limited to points between (x1, y1) and (x2, y2).
+ """
+
+ bitmap = np.asanyarray(bitmap)
+ if bitmap.ndim != 2:
+ raise ValueError("bitmap must have 2 dimensions")
+ if bitmap.dtype != np.bool_:
+ raise ValueError("bitmap must have boolean items")
+
+ x1, y1 = map(math.floor, bbox[0:2])
+ x2, y2 = map(math.ceil, bbox[2:4])
+
+ if not (0 <= x1 < x2 <= bitmap.shape[1] and 0 <= y1 < y2 <= bitmap.shape[0]):
+ raise ValueError("bbox has invalid coordinates")
+
+ flat = bitmap[y1:y2, x1:x2].ravel()
+
+ (run_indices,) = np.diff(flat, prepend=[not flat[0]], append=[not flat[-1]]).nonzero()
+ if flat[0]:
+ run_lengths = np.diff(run_indices, prepend=[0])
+ else:
+ run_lengths = np.diff(run_indices)
+
+ return run_lengths.tolist() + [x1, y1, x2 - 1, y2 - 1]
diff --git a/cvat-sdk/gen/generate.sh b/cvat-sdk/gen/generate.sh
index ca9a08be98fe..60875c499496 100755
--- a/cvat-sdk/gen/generate.sh
+++ b/cvat-sdk/gen/generate.sh
@@ -8,7 +8,7 @@ set -e
GENERATOR_VERSION="v6.0.1"
-VERSION="2.22.0"
+VERSION="2.23.0"
LIB_NAME="cvat_sdk"
LAYER1_LIB_NAME="${LIB_NAME}/api_client"
DST_DIR="$(cd "$(dirname -- "$0")/.." && pwd)"
diff --git a/cvat-sdk/gen/templates/openapi-generator/model_templates/method_from_openapi_data_composed.mustache b/cvat-sdk/gen/templates/openapi-generator/model_templates/method_from_openapi_data_composed.mustache
index 97d3cb930c27..e56437b401ee 100644
--- a/cvat-sdk/gen/templates/openapi-generator/model_templates/method_from_openapi_data_composed.mustache
+++ b/cvat-sdk/gen/templates/openapi-generator/model_templates/method_from_openapi_data_composed.mustache
@@ -22,7 +22,6 @@
{{name}} ({{{dataType}}}):{{#description}} {{{.}}}.{{/description}} [optional]{{#defaultValue}} if omitted the server will use the default value of {{{.}}}{{/defaultValue}} # noqa: E501
{{/optionalVars}}
"""
- from {{packageName}}.configuration import Configuration
{{#requiredVars}}
{{#defaultValue}}
@@ -32,7 +31,7 @@
_check_type = kwargs.pop('_check_type', True)
_spec_property_naming = kwargs.pop('_spec_property_naming', False)
_path_to_item = kwargs.pop('_path_to_item', ())
- _configuration = kwargs.pop('_configuration', Configuration())
+ _configuration = kwargs.pop('_configuration', None)
_visited_composed_classes = kwargs.pop('_visited_composed_classes', ())
self = super(OpenApiModel, cls).__new__(cls)
diff --git a/cvat-sdk/gen/templates/openapi-generator/model_templates/method_from_openapi_data_shared.mustache b/cvat-sdk/gen/templates/openapi-generator/model_templates/method_from_openapi_data_shared.mustache
index 4c149f22ce88..12dbba9ac641 100644
--- a/cvat-sdk/gen/templates/openapi-generator/model_templates/method_from_openapi_data_shared.mustache
+++ b/cvat-sdk/gen/templates/openapi-generator/model_templates/method_from_openapi_data_shared.mustache
@@ -27,7 +27,6 @@
{{/optionalVars}}
{{> model_templates/docstring_init_required_kwargs }}
"""
- from {{packageName}}.configuration import Configuration
{{#requiredVars}}
{{#defaultValue}}
@@ -37,7 +36,7 @@
_check_type = kwargs.pop('_check_type', True)
_spec_property_naming = kwargs.pop('_spec_property_naming', True)
_path_to_item = kwargs.pop('_path_to_item', ())
- _configuration = kwargs.pop('_configuration', Configuration())
+ _configuration = kwargs.pop('_configuration', None)
_visited_composed_classes = kwargs.pop('_visited_composed_classes', ())
self = super(OpenApiModel, cls).__new__(cls)
diff --git a/cvat-sdk/gen/templates/openapi-generator/model_templates/method_from_openapi_data_simple.mustache b/cvat-sdk/gen/templates/openapi-generator/model_templates/method_from_openapi_data_simple.mustache
index 853532e9f5ca..e8daa85e829c 100644
--- a/cvat-sdk/gen/templates/openapi-generator/model_templates/method_from_openapi_data_simple.mustache
+++ b/cvat-sdk/gen/templates/openapi-generator/model_templates/method_from_openapi_data_simple.mustache
@@ -12,7 +12,6 @@
value ({{{dataType}}}):{{#description}} {{{.}}}.{{/description}}{{#defaultValue}} if omitted defaults to {{{.}}}{{/defaultValue}}{{#allowableValues}}, must be one of [{{#enumVars}}{{{value}}}, {{/enumVars}}]{{/allowableValues}} # noqa: E501
{{> model_templates/docstring_init_required_kwargs }}
"""
- from {{packageName}}.configuration import Configuration
# required up here when default value is not given
_path_to_item = kwargs.pop('_path_to_item', ())
@@ -39,7 +38,7 @@
_check_type = kwargs.pop('_check_type', True)
_spec_property_naming = kwargs.pop('_spec_property_naming', False)
- _configuration = kwargs.pop('_configuration', Configuration())
+ _configuration = kwargs.pop('_configuration', None)
_visited_composed_classes = kwargs.pop('_visited_composed_classes', ())
{{> model_templates/invalid_pos_args }}
diff --git a/cvat-sdk/gen/templates/openapi-generator/model_templates/method_init_shared.mustache b/cvat-sdk/gen/templates/openapi-generator/model_templates/method_init_shared.mustache
index 998b4841b7e7..c7d402a6cc52 100644
--- a/cvat-sdk/gen/templates/openapi-generator/model_templates/method_init_shared.mustache
+++ b/cvat-sdk/gen/templates/openapi-generator/model_templates/method_init_shared.mustache
@@ -30,7 +30,6 @@
{{/optionalVars}}
{{> model_templates/docstring_init_required_kwargs }}
"""
- from {{packageName}}.configuration import Configuration
{{#requiredVars}}
{{^isReadOnly}}
@@ -42,7 +41,7 @@
_check_type = kwargs.pop('_check_type', True)
_spec_property_naming = kwargs.pop('_spec_property_naming', False)
_path_to_item = kwargs.pop('_path_to_item', ())
- _configuration = kwargs.pop('_configuration', Configuration())
+ _configuration = kwargs.pop('_configuration', None)
_visited_composed_classes = kwargs.pop('_visited_composed_classes', ())
{{> model_templates/invalid_pos_args }}
diff --git a/cvat-sdk/gen/templates/openapi-generator/model_templates/method_init_simple.mustache b/cvat-sdk/gen/templates/openapi-generator/model_templates/method_init_simple.mustache
index 8c8b42ce1f49..424b1d439c62 100644
--- a/cvat-sdk/gen/templates/openapi-generator/model_templates/method_init_simple.mustache
+++ b/cvat-sdk/gen/templates/openapi-generator/model_templates/method_init_simple.mustache
@@ -20,7 +20,6 @@
value ({{{dataType}}}):{{#description}} {{{.}}}.{{/description}}{{#defaultValue}} if omitted defaults to {{{.}}}{{/defaultValue}}{{#allowableValues}}, must be one of [{{#enumVars}}{{{value}}}, {{/enumVars}}]{{/allowableValues}} # noqa: E501
{{> model_templates/docstring_init_required_kwargs }}
"""
- from {{packageName}}.configuration import Configuration
# required up here when default value is not given
_path_to_item = kwargs.pop('_path_to_item', ())
@@ -45,7 +44,7 @@
_check_type = kwargs.pop('_check_type', True)
_spec_property_naming = kwargs.pop('_spec_property_naming', False)
- _configuration = kwargs.pop('_configuration', Configuration())
+ _configuration = kwargs.pop('_configuration', None)
_visited_composed_classes = kwargs.pop('_visited_composed_classes', ())
{{> model_templates/invalid_pos_args }}
diff --git a/cvat-sdk/gen/templates/openapi-generator/model_utils.mustache b/cvat-sdk/gen/templates/openapi-generator/model_utils.mustache
index c9e2b70d77bb..cc3c03dbce77 100644
--- a/cvat-sdk/gen/templates/openapi-generator/model_utils.mustache
+++ b/cvat-sdk/gen/templates/openapi-generator/model_utils.mustache
@@ -354,6 +354,13 @@ class OpenApiModel(object):
new_inst = new_cls._new_from_openapi_data(*args, **kwargs)
return new_inst
+ def __setstate__(self, state):
+ # This is the same as the default implementation. We override it,
+ # because unpickling attempts to access `obj.__setstate__` on an uninitialized
+ # object, and if this method is not defined, it results in a call to `__getattr__`.
+ # This fails, because `__getattr__` relies on `self._data_store`, which doesn't
+ # exist in an uninitialized object.
+ self.__dict__.update(state)
class ModelSimple(OpenApiModel):
"""the parent class of models whose type != object in their
@@ -1084,7 +1091,7 @@ def deserialize_file(response_data, configuration, content_disposition=None):
(file_type): the deserialized file which is open
The user is responsible for closing and reading the file
"""
- fd, path = tempfile.mkstemp(dir=configuration.temp_folder_path)
+ fd, path = tempfile.mkstemp(dir=configuration.temp_folder_path if configuration else None)
os.close(fd)
os.remove(path)
@@ -1263,27 +1270,21 @@ def validate_and_convert_types(input_value, required_types_mixed, path_to_item,
input_class_simple = get_simple_class(input_value)
valid_type = is_valid_type(input_class_simple, valid_classes)
if not valid_type:
- if (configuration
- or (input_class_simple == dict
- and dict not in valid_classes)):
- # if input_value is not valid_type try to convert it
- converted_instance = attempt_convert_item(
- input_value,
- valid_classes,
- path_to_item,
- configuration,
- spec_property_naming,
- key_type=False,
- must_convert=True,
- check_type=_check_type
- )
- return converted_instance
- else:
- raise get_type_error(input_value, path_to_item, valid_classes,
- key_type=False)
+ # if input_value is not valid_type try to convert it
+ converted_instance = attempt_convert_item(
+ input_value,
+ valid_classes,
+ path_to_item,
+ configuration,
+ spec_property_naming,
+ key_type=False,
+ must_convert=True,
+ check_type=_check_type
+ )
+ return converted_instance
# input_value's type is in valid_classes
- if len(valid_classes) > 1 and configuration:
+ if len(valid_classes) > 1:
# there are valid classes which are not the current class
valid_classes_coercible = remove_uncoercible(
valid_classes, input_value, spec_property_naming, must_convert=False)
diff --git a/cvat-sdk/gen/templates/openapi-generator/setup.mustache b/cvat-sdk/gen/templates/openapi-generator/setup.mustache
index eb89f5d20554..e0379cabd06e 100644
--- a/cvat-sdk/gen/templates/openapi-generator/setup.mustache
+++ b/cvat-sdk/gen/templates/openapi-generator/setup.mustache
@@ -77,7 +77,8 @@ setup(
python_requires="{{{generatorLanguageVersion}}}",
install_requires=BASE_REQUIREMENTS,
extras_require={
- "pytorch": ['torch', 'torchvision'],
+ "masks": ["numpy>=2"],
+ "pytorch": ['torch', 'torchvision', 'scikit-image>=0.24', 'cvat_sdk[masks]'],
},
package_dir={"": "."},
packages=find_packages(include=["cvat_sdk*"]),
diff --git a/cvat-ui/package.json b/cvat-ui/package.json
index a74485fa107d..703718121cd1 100644
--- a/cvat-ui/package.json
+++ b/cvat-ui/package.json
@@ -1,6 +1,6 @@
{
"name": "cvat-ui",
- "version": "1.66.4",
+ "version": "1.67.0",
"description": "CVAT single-page application",
"main": "src/index.tsx",
"scripts": {
diff --git a/cvat-ui/src/actions/annotation-actions.ts b/cvat-ui/src/actions/annotation-actions.ts
index 0cc8f3052bc0..115470429990 100644
--- a/cvat-ui/src/actions/annotation-actions.ts
+++ b/cvat-ui/src/actions/annotation-actions.ts
@@ -1081,7 +1081,7 @@ export function finishCurrentJobAsync(): ThunkAction {
export function rememberObject(createParams: {
activeObjectType?: ObjectType;
activeLabelID?: number;
- activeShapeType?: ShapeType;
+ activeShapeType?: ShapeType | null;
activeNumOfPoints?: number;
activeRectDrawingMethod?: RectDrawingMethod;
activeCuboidDrawingMethod?: CuboidDrawingMethod;
diff --git a/cvat-ui/src/components/annotation-page/annotations-actions/annotations-actions-modal.tsx b/cvat-ui/src/components/annotation-page/annotations-actions/annotations-actions-modal.tsx
index 27898da9fa2a..f33dd9bf231a 100644
--- a/cvat-ui/src/components/annotation-page/annotations-actions/annotations-actions-modal.tsx
+++ b/cvat-ui/src/components/annotation-page/annotations-actions/annotations-actions-modal.tsx
@@ -7,6 +7,7 @@ import './styles.scss';
import React, {
useEffect, useReducer, useRef, useState,
} from 'react';
+import { createRoot } from 'react-dom/client';
import Button from 'antd/lib/button';
import { Col, Row } from 'antd/lib/grid';
import Progress from 'antd/lib/progress';
@@ -22,28 +23,27 @@ import { useIsMounted } from 'utils/hooks';
import { createAction, ActionUnion } from 'utils/redux';
import { getCVATStore } from 'cvat-store';
import {
- BaseSingleFrameAction, FrameSelectionType, Job, getCore,
+ BaseCollectionAction, BaseAction, Job, getCore,
+ ObjectState,
} from 'cvat-core-wrapper';
import { Canvas } from 'cvat-canvas-wrapper';
-import { fetchAnnotationsAsync, saveAnnotationsAsync } from 'actions/annotation-actions';
-import { switchAutoSave } from 'actions/settings-actions';
+import { fetchAnnotationsAsync } from 'actions/annotation-actions';
import { clamp } from 'utils/math';
const core = getCore();
interface State {
- actions: BaseSingleFrameAction[];
- activeAction: BaseSingleFrameAction | null;
+ actions: BaseAction[];
+ activeAction: BaseAction | null;
fetching: boolean;
progress: number | null;
progressMessage: string | null;
cancelled: boolean;
- autoSaveEnabled: boolean;
- jobHasBeenSaved: boolean;
frameFrom: number;
frameTo: number;
actionParameters: Record;
modalVisible: boolean;
+ targetObjectState?: ObjectState | null;
}
enum ReducerActionType {
@@ -53,8 +53,6 @@ enum ReducerActionType {
RESET_BEFORE_RUN = 'RESET_BEFORE_RUN',
RESET_AFTER_RUN = 'RESET_AFTER_RUN',
CANCEL_ACTION = 'CANCEL_ACTION',
- SET_AUTOSAVE_DISABLED_FLAG = 'SET_AUTOSAVE_DISABLED_FLAG',
- SET_JOB_WAS_SAVED_FLAG = 'SET_JOB_WAS_SAVED_FLAG',
UPDATE_FRAME_FROM = 'UPDATE_FRAME_FROM',
UPDATE_FRAME_TO = 'UPDATE_FRAME_TO',
UPDATE_ACTION_PARAMETER = 'UPDATE_ACTION_PARAMETER',
@@ -62,10 +60,10 @@ enum ReducerActionType {
}
export const reducerActions = {
- setAnnotationsActions: (actions: BaseSingleFrameAction[]) => (
+ setAnnotationsActions: (actions: BaseAction[]) => (
createAction(ReducerActionType.SET_ANNOTATIONS_ACTIONS, { actions })
),
- setActiveAnnotationsAction: (activeAction: BaseSingleFrameAction) => (
+ setActiveAnnotationsAction: (activeAction: BaseAction) => (
createAction(ReducerActionType.SET_ACTIVE_ANNOTATIONS_ACTION, { activeAction })
),
updateProgress: (progress: number | null, progressMessage: string | null) => (
@@ -80,12 +78,6 @@ export const reducerActions = {
cancelAction: () => (
createAction(ReducerActionType.CANCEL_ACTION)
),
- setAutoSaveDisabledFlag: () => (
- createAction(ReducerActionType.SET_AUTOSAVE_DISABLED_FLAG)
- ),
- setJobSavedFlag: (jobHasBeenSaved: boolean) => (
- createAction(ReducerActionType.SET_JOB_WAS_SAVED_FLAG, { jobHasBeenSaved })
- ),
updateFrameFrom: (frameFrom: number) => (
createAction(ReducerActionType.UPDATE_FRAME_FROM, { frameFrom })
),
@@ -100,19 +92,54 @@ export const reducerActions = {
),
};
+const KEEP_LATEST = 5;
+let lastSelectedActions: [string, Record][] = [];
+function updateLatestActions(name: string, parameters: Record = {}): void {
+ const idx = lastSelectedActions.findIndex((el) => el[0] === name);
+ if (idx === -1) {
+ lastSelectedActions = [[name, parameters], ...lastSelectedActions];
+ } else {
+ lastSelectedActions = [
+ [name, parameters],
+ ...lastSelectedActions.slice(0, idx),
+ ...lastSelectedActions.slice(idx + 1),
+ ];
+ }
+
+ lastSelectedActions = lastSelectedActions.slice(-KEEP_LATEST);
+}
+
const reducer = (state: State, action: ActionUnion): State => {
if (action.type === ReducerActionType.SET_ANNOTATIONS_ACTIONS) {
+ const { actions } = action.payload;
+ const list = state.targetObjectState ? actions
+ .filter((_action) => _action.isApplicableForObject(state.targetObjectState as ObjectState)) : actions;
+
+ let activeAction = null;
+ let activeActionParameters = {};
+ for (const item of lastSelectedActions) {
+ const [actionName, actionParameters] = item;
+ const candidate = list.find((el) => el.name === actionName);
+ if (candidate) {
+ activeAction = candidate;
+ activeActionParameters = actionParameters;
+ break;
+ }
+ }
+
return {
...state,
- actions: action.payload.actions,
- activeAction: action.payload.actions[0] || null,
- actionParameters: {},
+ actions: list,
+ activeAction: activeAction ?? list[0] ?? null,
+ actionParameters: activeActionParameters,
};
}
if (action.type === ReducerActionType.SET_ACTIVE_ANNOTATIONS_ACTION) {
- const { frameSelection } = action.payload.activeAction;
- if (frameSelection === FrameSelectionType.CURRENT_FRAME) {
+ const { activeAction } = action.payload;
+ updateLatestActions(activeAction.name, {});
+
+ if (action.payload.activeAction instanceof BaseCollectionAction) {
const storage = getCVATStore();
const currentFrame = storage.getState().annotation.player.frame.number;
return {
@@ -123,6 +150,7 @@ const reducer = (state: State, action: ActionUnion): Stat
actionParameters: {},
};
}
+
return {
...state,
activeAction: action.payload.activeAction,
@@ -163,20 +191,6 @@ const reducer = (state: State, action: ActionUnion): Stat
};
}
- if (action.type === ReducerActionType.SET_AUTOSAVE_DISABLED_FLAG) {
- return {
- ...state,
- autoSaveEnabled: false,
- };
- }
-
- if (action.type === ReducerActionType.SET_JOB_WAS_SAVED_FLAG) {
- return {
- ...state,
- jobHasBeenSaved: action.payload.jobHasBeenSaved,
- };
- }
-
if (action.type === ReducerActionType.UPDATE_FRAME_FROM) {
return {
...state,
@@ -194,12 +208,16 @@ const reducer = (state: State, action: ActionUnion): Stat
}
if (action.type === ReducerActionType.UPDATE_ACTION_PARAMETER) {
+ const updatedActionParameters = {
+ ...state.actionParameters,
+ [action.payload.name]: action.payload.value,
+ };
+
+ updateLatestActions((state.activeAction as BaseAction).name, updatedActionParameters);
+
return {
...state,
- actionParameters: {
- ...state.actionParameters,
- [action.payload.name]: action.payload.value,
- },
+ actionParameters: updatedActionParameters,
};
}
@@ -213,9 +231,9 @@ const reducer = (state: State, action: ActionUnion): Stat
return state;
};
-type Props = NonNullable[keyof BaseSingleFrameAction['parameters']];
+type ActionParameterProps = NonNullable[keyof BaseAction['parameters']];
-function ActionParameterComponent(props: Props & { onChange: (value: string) => void }): JSX.Element {
+function ActionParameterComponent(props: ActionParameterProps & { onChange: (value: string) => void }): JSX.Element {
const {
defaultValue, type, values, onChange,
} = props;
@@ -262,8 +280,13 @@ function ActionParameterComponent(props: Props & { onChange: (value: string) =>
);
}
-function AnnotationsActionsModalContent(props: { onClose: () => void; }): JSX.Element {
- const { onClose } = props;
+interface Props {
+ onClose: () => void;
+ targetObjectState?: ObjectState;
+}
+
+function AnnotationsActionsModalContent(props: Props): JSX.Element {
+ const { onClose, targetObjectState: defaultTargetObjectState } = props;
const isMounted = useIsMounted();
const storage = getCVATStore();
const cancellationRef = useRef(false);
@@ -276,29 +299,27 @@ function AnnotationsActionsModalContent(props: { onClose: () => void; }): JSX.El
progress: null,
progressMessage: null,
cancelled: false,
- autoSaveEnabled: storage.getState().settings.workspace.autoSave,
- jobHasBeenSaved: true,
frameFrom: jobInstance.startFrame,
frameTo: jobInstance.stopFrame,
actionParameters: {},
modalVisible: true,
+ targetObjectState: defaultTargetObjectState ?? null,
});
useEffect(() => {
- core.actions.list().then((list: BaseSingleFrameAction[]) => {
+ core.actions.list().then((list: BaseAction[]) => {
if (isMounted()) {
- dispatch(reducerActions.setJobSavedFlag(!jobInstance.annotations.hasUnsavedChanges()));
dispatch(reducerActions.setAnnotationsActions(list));
}
});
}, []);
const {
- actions, activeAction, fetching, autoSaveEnabled, jobHasBeenSaved,
+ actions, activeAction, fetching, targetObjectState,
progress, progressMessage, frameFrom, frameTo, actionParameters, modalVisible,
} = state;
- const currentFrameAction = activeAction?.frameSelection === FrameSelectionType.CURRENT_FRAME;
+ const currentFrameAction = activeAction instanceof BaseCollectionAction || targetObjectState !== null;
return (
void; }): JSX.El
- Actions allow executing certain algorithms on
-
-
- filtered
-
-
- annotations.
- It affects only the local browser state.
- Once an action has finished,
- it cannot be reverted.
- You may reload the page to get annotations from the server.
- It is strongly recommended to review the changes
- before saving annotations to the server.
-
+ targetObjectState ? (
+ Selected action will be applied to the current object
+ ) : (
+
+ Actions allow executing certain algorithms on
+
+
+ filtered
+
+
+ annotations.
+