diff --git a/storage-app/src/shared/archive_helpers.py b/storage-app/src/shared/archive_helpers.py
index 2c9d42c..b686ae8 100644
--- a/storage-app/src/shared/archive_helpers.py
+++ b/storage-app/src/shared/archive_helpers.py
@@ -238,13 +238,13 @@ def __init__(
         self.queue = queue
         self.max_concurrent = max_concurrent
         self._done = False
+        self.iter_count = 0
 
     @property
     def ready(self) -> bool:
         return self._done
 
     async def produce(self):
         tasks = []
-        iter_count = 0
         while await self.object_set.fetch_next:
             file = self.object_set.next_object()
@@ -254,8 +254,8 @@ async def produce(self):
             await wait(tasks, return_when=FIRST_COMPLETED)
             tasks = [task for task in tasks if not task.done()]
 
-            iter_count += 1
-            if not iter_count % GC_FREQ: gc_collect()
+            self.iter_count += 1
+            if not self.iter_count % GC_FREQ: gc_collect()
 
         await gather(*tasks)
         self.queue.put(None)
diff --git a/storage-app/src/shared/settings.py b/storage-app/src/shared/settings.py
index 699327b..66dd30a 100644
--- a/storage-app/src/shared/settings.py
+++ b/storage-app/src/shared/settings.py
@@ -19,8 +19,8 @@
 
 assert STORAGE_PORT
 
-ASYNC_PRODUCER_MAX_CONCURRENT: int = 1_000
-ASYNC_PRODUCER_GC_FREQ: int = 100
+ASYNC_PRODUCER_MAX_CONCURRENT: int = 256
+ASYNC_PRODUCER_GC_FREQ: int = 256
 APP_BACKEND_URL: str = "http://" + getenv("APP_BACKEND_URL", "127.0.0.1")
 SECRET_KEY: str = getenv("SECRET_KEY", "")
 SECRET_ALGO: str = getenv("SECRET_ALGO", "HS256")
diff --git a/storage-app/src/shared/worker_services.py b/storage-app/src/shared/worker_services.py
index 2aab557..231ed4c 100644
--- a/storage-app/src/shared/worker_services.py
+++ b/storage-app/src/shared/worker_services.py
@@ -15,6 +15,7 @@
 from .hasher import VHash, IHash
 from queue import Queue
 from .archive_helpers import FileProducer, ZipConsumer, ZipWriter
+from celery import Task
 
 
 class EmbeddingStatus(Enum):
@@ -34,13 +35,18 @@ class Zipper:
     written: bool = False
     archive_extension: str = "zip"
 
-    def __init__(self, bucket_name: str, file_ids: list[str]) -> None:
+    def __init__(
+        self,
+        bucket_name: str,
+        file_ids: list[str],
+        task: Task
+    ) -> None:
         self.object_set = Bucket(bucket_name).get_download_objects(file_ids)
+        self.bucket_name = bucket_name
+        self._task = task
         self._get_annotation(bucket_name, file_ids)
-        self.bucket_name = bucket_name
-
     async def archive_objects(self) -> Optional[bool]:
         json_data: Any = (
             "annotation.json",
             dumps(self.annotation, indent=4).encode("utf-8")
         )
@@ -61,7 +67,9 @@ async def archive_objects(self) -> Optional[bool]:
             if wait_list.task.ready:
                 wait_list = wait_list.next
                 continue
-            await async_stall_for(1)
+
+            self._task.update_state(state="PROGRESS")
+            await async_stall_for(5)
 
         await producer_task
         await self.object_set.close()
diff --git a/storage-app/src/worker.py b/storage-app/src/worker.py
index dc60567..72ca3d6 100644
--- a/storage-app/src/worker.py
+++ b/storage-app/src/worker.py
@@ -1,7 +1,7 @@
 from celery import Celery
 from shared.settings import BROKER_URL, RESULT_URL, CELERY_CONFIG
 from shared.worker_services import Zipper, Hasher, EmbeddingStatus
-from asyncio import get_event_loop
+from asyncio import run
 from typing import Optional, Any
 from json import JSONEncoder, loads, dumps
 from kombu.serialization import register
@@ -26,10 +26,10 @@ def default(self, o) -> Any:
         return getattr(o, "__json__", super().default)(o)
 
 
 )
 
-@worker.task(name="produce_download_task")
-def produce_download_task(bucket_name: str, file_ids: list[str]) -> str | None:
-    task = Zipper(bucket_name, file_ids)
-    get_event_loop().run_until_complete(task.archive_objects())
+@worker.task(bind=True, name="produce_download_task")
+def produce_download_task(self, bucket_name: str, file_ids: list[str]) -> str | None:
+    task = Zipper(bucket_name, file_ids, self)
+    run(task.archive_objects())
     return task.archive_id
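
Note on the archive_helpers.py hunks: moving iter_count onto the instance keeps the GC counter on the producer object, where it persists across calls and is visible to other code, instead of being a local of a single produce() call. Below is a minimal, self-contained sketch of the same bounded-concurrency-plus-periodic-gc pattern; Producer, handle, and items are illustrative names, not the project's API.

import asyncio
import gc

MAX_CONCURRENT = 256  # stands in for ASYNC_PRODUCER_MAX_CONCURRENT
GC_FREQ = 256         # stands in for ASYNC_PRODUCER_GC_FREQ


class Producer:
    def __init__(self) -> None:
        # Instance attribute, as in the patch: survives across calls
        # and can be inspected from outside the coroutine.
        self.iter_count = 0

    async def handle(self, item: int) -> None:
        await asyncio.sleep(0)  # placeholder for real per-file work

    async def produce(self, items) -> None:
        tasks: set[asyncio.Task] = set()
        for item in items:
            tasks.add(asyncio.create_task(self.handle(item)))
            if len(tasks) >= MAX_CONCURRENT:
                # Block until at least one in-flight task finishes, then
                # keep only the still-pending ones, bounding memory use.
                _done, tasks = await asyncio.wait(
                    tasks, return_when=asyncio.FIRST_COMPLETED
                )
            self.iter_count += 1
            if not self.iter_count % GC_FREQ:
                gc.collect()  # reclaim buffers eagerly, as gc_collect() does in the patch

        await asyncio.gather(*tasks)


asyncio.run(Producer().produce(range(10_000)))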
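
Note on the bind=True change in worker.py: binding injects the Task instance as the first argument of the task function, which Zipper then uses to push a PROGRESS state while the archive loop stalls. A sketch of that wiring follows, with an assumed Redis broker/backend and illustrative names (do_archive, the URLs, the task name) rather than the project's settings.

from asyncio import run, sleep

from celery import Celery, Task

app = Celery(
    "sketch",
    broker="redis://localhost:6379/0",   # assumption, not the project's BROKER_URL
    backend="redis://localhost:6379/1",  # assumption, not the project's RESULT_URL
)


async def do_archive(task: Task) -> str:
    # Stand-in for Zipper.archive_objects: report progress while waiting.
    for _ in range(3):
        task.update_state(state="PROGRESS")
        await sleep(5)
    return "archive-id"


@app.task(bind=True, name="produce_download_task_sketch")
def produce_download_task_sketch(self, bucket_name: str, file_ids: list[str]) -> str:
    # bind=True makes `self` the Task instance; asyncio.run() replaces the
    # deprecated get_event_loop().run_until_complete() pattern.
    return run(do_archive(self))

A caller holding the task id can then observe the intermediate state with app.AsyncResult(task_id).state, which reads "PROGRESS" until the task finishes.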