Skip to content

Commit

Permalink
refactor arch task logic
Browse files Browse the repository at this point in the history
  • Loading branch information
githubering182 committed Dec 11, 2024
1 parent 3669a1b commit b7ca813
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 14 deletions.
6 changes: 3 additions & 3 deletions storage-app/src/shared/archive_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,13 +238,13 @@ def __init__(
self.queue = queue
self.max_concurrent = max_concurrent
self._done = False
self.iter_count = 0

@property
def ready(self) -> bool: return self._done

async def produce(self):
tasks = []
iter_count = 0

while await self.object_set.fetch_next:
file = self.object_set.next_object()
Expand All @@ -254,8 +254,8 @@ async def produce(self):
await wait(tasks, return_when=FIRST_COMPLETED)
tasks = [task for task in tasks if not task.done()]

iter_count += 1
if not iter_count % GC_FREQ: gc_collect()
self.iter_count += 1
if not self.iter_count % GC_FREQ: gc_collect()

await gather(*tasks)
self.queue.put(None)
Expand Down
4 changes: 2 additions & 2 deletions storage-app/src/shared/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@

assert STORAGE_PORT

ASYNC_PRODUCER_MAX_CONCURRENT: int = 1_000
ASYNC_PRODUCER_GC_FREQ: int = 100
ASYNC_PRODUCER_MAX_CONCURRENT: int = 256
ASYNC_PRODUCER_GC_FREQ: int = 256
APP_BACKEND_URL: str = "http://" + getenv("APP_BACKEND_URL", "127.0.0.1")
SECRET_KEY: str = getenv("SECRET_KEY", "")
SECRET_ALGO: str = getenv("SECRET_ALGO", "HS256")
Expand Down
16 changes: 12 additions & 4 deletions storage-app/src/shared/worker_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .hasher import VHash, IHash
from queue import Queue
from .archive_helpers import FileProducer, ZipConsumer, ZipWriter
from celery import Task


class EmbeddingStatus(Enum):
Expand All @@ -34,13 +35,18 @@ class Zipper:
written: bool = False
archive_extension: str = "zip"

def __init__(self, bucket_name: str, file_ids: list[str]) -> None:
def __init__(
self,
bucket_name: str,
file_ids: list[str],
task: Task
) -> None:
self.object_set = Bucket(bucket_name).get_download_objects(file_ids)
self.bucket_name = bucket_name
self._task = task

self._get_annotation(bucket_name, file_ids)

self.bucket_name = bucket_name

async def archive_objects(self) -> Optional[bool]:
json_data: Any = ("annotation.json", dumps(self.annotation, indent=4).encode("utf-8"))

Expand All @@ -61,7 +67,9 @@ async def archive_objects(self) -> Optional[bool]:
if wait_list.task.ready:
wait_list = wait_list.next
continue
await async_stall_for(1)

self._task.update_state(state="PROGRESS")
await async_stall_for(5)

await producer_task
await self.object_set.close()
Expand Down
10 changes: 5 additions & 5 deletions storage-app/src/worker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from celery import Celery
from shared.settings import BROKER_URL, RESULT_URL, CELERY_CONFIG
from shared.worker_services import Zipper, Hasher, EmbeddingStatus
from asyncio import get_event_loop
from asyncio import run
from typing import Optional, Any
from json import JSONEncoder, loads, dumps
from kombu.serialization import register
Expand All @@ -26,10 +26,10 @@ def default(self, o) -> Any: return getattr(o, "__json__", super().default)(o)
)


@worker.task(name="produce_download_task")
def produce_download_task(bucket_name: str, file_ids: list[str]) -> str | None:
task = Zipper(bucket_name, file_ids)
get_event_loop().run_until_complete(task.archive_objects())
@worker.task(bind=True, name="produce_download_task")
def produce_download_task(self, bucket_name: str, file_ids: list[str]) -> str | None:
task = Zipper(bucket_name, file_ids, self)
run(task.archive_objects())
return task.archive_id


Expand Down

0 comments on commit b7ca813

Please sign in to comment.