Skip to content

Commit

Permalink
Add default batch size in bulk create
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiltsov-max committed Jan 24, 2025
1 parent a14ddba commit 67a6df3
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 13 deletions.
20 changes: 14 additions & 6 deletions cvat/apps/dataset_manager/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
import re
import tempfile
import zipfile
-from collections.abc import Generator, Sequence
+from collections.abc import Generator, Iterable, Sequence
from contextlib import contextmanager
from copy import deepcopy
from datetime import timedelta
from enum import Enum
from threading import Lock
-from typing import Any
+from typing import Any, TypeVar

import attrs
import django_rq
Expand All @@ -38,18 +38,26 @@ def make_zip_archive(src_path, dst_path):
archive.write(path, osp.relpath(path, src_path))


-def bulk_create(db_model, objects, flt_param):
+_ModelT = TypeVar("_ModelT", bound=models.Model)
+
+def bulk_create(
+    db_model: type[_ModelT],
+    objects: Iterable[_ModelT],
+    *,
+    flt_param: dict[str, Any] | None = None,
+    batch_size: int | None = 10000
+) -> list[_ModelT]:
     if objects:
         if flt_param:
             if "postgresql" in settings.DATABASES["default"]["ENGINE"]:
-                return db_model.objects.bulk_create(objects)
+                return db_model.objects.bulk_create(objects, batch_size=batch_size)
             else:
                 ids = list(db_model.objects.filter(**flt_param).values_list('id', flat=True))
-                db_model.objects.bulk_create(objects)
+                db_model.objects.bulk_create(objects, batch_size=batch_size)

                 return list(db_model.objects.exclude(id__in=ids).filter(**flt_param))
         else:
-            return db_model.objects.bulk_create(objects)
+            return db_model.objects.bulk_create(objects, batch_size=batch_size)

     return []

Expand Down
10 changes: 3 additions & 7 deletions cvat/apps/quality_control/quality_reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -2537,9 +2537,7 @@ def _save_reports(self, *, task_report: dict, job_reports: list[dict]) -> models
)
db_job_reports.append(db_job_report)

-db_job_reports = bulk_create(
-    db_model=models.QualityReport, objects=db_job_reports, flt_param={}
-)
+db_job_reports = bulk_create(db_model=models.QualityReport, objects=db_job_reports)

db_conflicts = []
db_report_iter = itertools.chain([db_task_report], db_job_reports)
Expand All @@ -2554,9 +2552,7 @@ def _save_reports(self, *, task_report: dict, job_reports: list[dict]) -> models
)
db_conflicts.append(db_conflict)

-db_conflicts = bulk_create(
-    db_model=models.AnnotationConflict, objects=db_conflicts, flt_param={}
-)
+db_conflicts = bulk_create(db_model=models.AnnotationConflict, objects=db_conflicts)

db_ann_ids = []
db_conflicts_iter = iter(db_conflicts)
Expand All @@ -2572,7 +2568,7 @@ def _save_reports(self, *, task_report: dict, job_reports: list[dict]) -> models
)
db_ann_ids.append(db_ann_id)

-db_ann_ids = bulk_create(db_model=models.AnnotationId, objects=db_ann_ids, flt_param={})
+db_ann_ids = bulk_create(db_model=models.AnnotationId, objects=db_ann_ids)

return db_task_report

Expand Down

0 comments on commit 67a6df3

Please sign in to comment.