Skip to content

Commit

Permalink
ref worker
Browse files Browse the repository at this point in the history
  • Loading branch information
githubering182 committed Dec 12, 2024
1 parent 0a2c57e commit 59b6c73
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
8 changes: 5 additions & 3 deletions storage-app/src/shared/worker_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,19 @@ def __init__(
file_ids: list[str],
task: Task
) -> None:
self.object_set = Bucket(bucket_name).get_download_objects(file_ids)
self.file_ids = file_ids
self.bucket_name = bucket_name
self._task = task

self._get_annotation(bucket_name, file_ids)

async def archive_objects(self) -> Optional[bool]:
json_data: Any = ("annotation.json", dumps(self.annotation, indent=4).encode("utf-8"))
object_set = Bucket(self.bucket_name).get_download_objects(self.file_ids)

queue = Queue()

producer = FileProducer(self.object_set, queue, MAX_CONCURENT)
producer = FileProducer(object_set, queue, MAX_CONCURENT)
writer = ZipWriter(f"{self.bucket_name}_dataset")
consumer = ZipConsumer(queue, [json_data], writer)

Expand All @@ -68,11 +69,12 @@ async def archive_objects(self) -> Optional[bool]:
wait_list = wait_list.next
continue

print(f"ZIP WORK STALL, {producer.iter_count}")
self._task.update_state(state="PROGRESS")
await async_stall_for(5)

await producer_task
await self.object_set.close()
await object_set.close()
consumer.join()
writer.join()

Expand Down
4 changes: 2 additions & 2 deletions storage-app/src/worker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from celery import Celery
from shared.settings import BROKER_URL, RESULT_URL, CELERY_CONFIG
from shared.worker_services import Zipper, Hasher, EmbeddingStatus
from asyncio import get_event_loop
from asyncio import run
from typing import Optional, Any
from json import JSONEncoder, loads, dumps
from kombu.serialization import register
Expand Down Expand Up @@ -35,7 +35,7 @@ def default(self, o) -> Any: return getattr(o, "__json__", super().default)(o)
)
def produce_download_task(self, bucket_name: str, file_ids: list[str]) -> str | None:
task = Zipper(bucket_name, file_ids, self)
get_event_loop().run_until_complete(task.archive_objects())
run(task.archive_objects())
return task.archive_id


Expand Down

0 comments on commit 59b6c73

Please sign in to comment.