diff --git a/docker-compose.yml b/docker-compose.yml
index 4f98f7b..ecc1e7e 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -24,7 +24,7 @@ services:
     ports:
       - "5555:5555"
     environment:
-      CELERY_BROKER_URL: redis://redis:6379/0
+      CELERY_BROKER_URL: "amqp://user:password@rabbitmq:5672"
       FLOWER_PORT: 5555
     depends_on:
       - redis
@@ -51,8 +51,20 @@ services:
     ports:
       - "8080:8080"
 
+  rabbitmq:
+    image: rabbitmq:3-management
+    ports:
+      - "5672:5672"
+      - "15672:15672"
+    environment:
+      RABBITMQ_DEFAULT_USER: user
+      RABBITMQ_DEFAULT_PASS: password
+    volumes:
+      - rabbitmq-data:/var/lib/rabbitmq
+
 volumes:
   qdrant_data:
+  rabbitmq-data:
 
 configs:
   qdrant_config:
diff --git a/main.py b/main.py
index 0a735b5..c9ec1ce 100644
--- a/main.py
+++ b/main.py
@@ -8,10 +8,11 @@ import requests
 from fastapi import FastAPI, Body
 from starlette.requests import Request
-from starlette.responses import StreamingResponse
+from starlette.responses import StreamingResponse, Response
 
 import config
 import model_worker
+import redis_conn
 import s3
 import tasks
 from database import artwork_collection, translation_collection
@@ -123,41 +124,41 @@ async def imgproxy(image_id: str):
     artwork = await artwork_collection.find_one({"_id": image_id})
     encoded_url = base64.b64encode(f"{s3.base_url}{artwork['s3_object_name']}".encode("utf-8")).decode("utf-8")
 
-    # Check if file is cached in tempdir
-    # If not, fetch from imgproxy
-    if os.path.exists(f"/tmp/{artwork.id}.webp"):
-        return StreamingResponse(open(f"/tmp/{artwork.id}.webp", "rb"), media_type="image/webp")
+    redis = redis_conn.redis_client
+
+    # Check if file is cached in redis
+    if redis.exists(f"{artwork['_id']}:thumbnail"):
+        return Response(content=bytes(redis.get(f"{artwork['_id']}:thumbnail")), media_type="image/webp")
 
     # return a streaming response
     url = httpx.URL(f"{config.IMGPROXY_THUMBNAIL_BASE_URL}{encoded_url}.webp")
     async with httpx.AsyncClient() as client:
         response = await client.get(url)
 
-    # save to tempdir
-    with open(f"/tmp/{artwork.id}.webp", "wb") as f:
-        f.write(response.content)
+    # save to redis
+    redis.setex(f"{artwork['_id']}:thumbnail", 60 * 60 * 24 * 7, response.content)
 
-    return StreamingResponse(BytesIO(response.content), media_type="image/webp")
+    return Response(content=response.content, media_type="image/webp")
 
 
 @app.get("/imgproxy/optimized/{image_id}")
-async def imgproxy(image_id: str):
+async def optimized_artwork(image_id: str):
     artwork = await artwork_collection.find_one({"_id": image_id})
     encoded_url = base64.b64encode(f"{s3.base_url}{artwork['s3_object_name']}".encode("utf-8")).decode("utf-8")
 
-    # Check if file is cached in tempdir
-    # If not, fetch from imgproxy
-    if os.path.exists(f"/tmp/{artwork.id}_orig.webp"):
-        return StreamingResponse(open(f"/tmp/{artwork.id}_orig.webp", "rb"), media_type="image/webp")
+    redis = redis_conn.redis_client
+
+    # Check if file is cached in redis
+    if redis.exists(f"{artwork['_id']}:optimized"):
+        return Response(content=redis.get(f"{artwork['_id']}:optimized"), media_type="image/webp")
 
     # return a streaming response
     url = httpx.URL(f"{config.IMGPROXY_OPTIMIZED_BASE_URL}{encoded_url}.webp")
     async with httpx.AsyncClient() as client:
         response = await client.get(url)
 
-    # save to tempdir
-    with open(f"/tmp/{artwork.id}_orig.webp", "wb") as f:
-        f.write(response.content)
+    # save to redis
+    redis.setex(f"{artwork['_id']}:optimized", 60 * 60 * 24, response.content)
 
-    return StreamingResponse(BytesIO(response.content), media_type="image/webp")
+    return Response(content=response.content, media_type="image/webp")
 
 
 @app.get("/search")
@@ -174,11 +175,12 @@ async def search(request: Request, query: str = None, page: int = 1, neural: boo
     split_tags = query.split(" ")
     results = []
     if neural:
-        search_task = model_worker.neural_search.delay(query, page=page, limit=25)
+        search_task = model_worker.neural_search.apply_async(args=[query], kwargs={"page": page, "limit": 25},
+                                                             priority=10)
         search_task.wait()
         results = search_task.get()
     else:
-        search_task = model_worker.tag_search.delay(split_tags, page, 25, group_sets=group_sets)
+        search_task = model_worker.tag_search.apply_async(args=[split_tags], kwargs={"page": page, "limit": 25, "group_sets": group_sets}, priority=10)
         search_task.wait()
         results = search_task.get()
@@ -272,7 +274,6 @@ async def delete_videos():
 
 @app.post("/api/clearcache")
 async def clear_cache():
-    for file in os.listdir("/tmp"):
-        if file.endswith(".webp"):
-            os.remove(f"/tmp/{file}")
+    redis = redis_conn.redis_client
+    redis.flushdb()
     return {"message": "Cleared cache."}
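The imgproxy handlers above are async, but redis_conn.redis_client is the synchronous redis-py client, so every exists/get/setex call blocks the event loop for a Redis round-trip. A minimal sketch of the same cache-aside lookup on redis.asyncio (shipped with redis-py 4.2+); the key format and TTL match the diff, while the helper names are hypothetical:

    import redis.asyncio as aioredis
    from starlette.responses import Response

    import config

    async_redis = aioredis.from_url(config.REDIS_URL)

    async def get_cached_thumbnail(image_id: str) -> Response | None:
        # Non-blocking cache probe; None tells the caller to fall through to imgproxy.
        cached = await async_redis.get(f"{image_id}:thumbnail")
        if cached is None:
            return None
        return Response(content=cached, media_type="image/webp")

    async def cache_thumbnail(image_id: str, body: bytes) -> None:
        # 7-day TTL, matching the setex call in the diff.
        await async_redis.setex(f"{image_id}:thumbnail", 60 * 60 * 24 * 7, body)

Relatedly, clear_cache now calls flushdb(), which deletes every key in the database, including Celery results, since the result backend points at the same config.REDIS_URL; iterating redis.scan_iter(match="*:thumbnail") (and "*:optimized") and deleting the matches would clear only the image cache.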
@app.get("/search") @@ -174,11 +175,12 @@ async def search(request: Request, query: str = None, page: int = 1, neural: boo split_tags = query.split(" ") results = [] if neural: - search_task = model_worker.neural_search.delay(query, page=page, limit=25) + search_task = model_worker.neural_search.apply_async(args=[query], kwargs={"page": page, "limit": 25}, + priority=10) search_task.wait() results = search_task.get() else: - search_task = model_worker.tag_search.delay(split_tags, page, 25, group_sets=group_sets) + search_task = model_worker.tag_search.apply_async(args=[split_tags], kwargs={"page": page, "limit": 25, "group_sets": group_sets}, priority=10) search_task.wait() results = search_task.get() @@ -272,7 +274,6 @@ async def delete_videos(): @app.post("/api/clearcache") async def clear_cache(): - for file in os.listdir("/tmp"): - if file.endswith(".webp"): - os.remove(f"/tmp/{file}") + redis = redis_conn.redis_client + redis.flushdb() return {"message": "Cleared cache."} diff --git a/model_worker.py b/model_worker.py index a814ff8..dbf2d32 100644 --- a/model_worker.py +++ b/model_worker.py @@ -18,9 +18,17 @@ from PIL import Image import open_clip from qdrant_client.models import VectorParams, Distance, PointStruct, Filter, FieldCondition, MatchText, IsNullCondition, MatchValue, PayloadField +from kombu import Exchange, Queue -app = Celery('model_worker', broker=config.CELERY_REDIS_BROKER, backend=config.CELERY_REDIS_BROKER, result_expires=60*60*24) + +app = Celery('model_worker', broker=config.CELERY_RABBITMQ_URL, backend=config.REDIS_URL, result_expires=60 * 60 * 24) app.conf.task_default_queue = 'model_worker' +app.conf.task_queues = [ + Queue('model_worker', routing_key='model_tasks.#', queue_arguments={'x-max-priority': 10}, max_priority=10), +] +app.conf.update( + worker_prefetch_multiplier=1 +) model, preprocess, tokenizer, qdrant_client = None, None, None, None loadingMutex = Lock() diff --git a/redis_conn.py b/redis_conn.py new file mode 100644 index 0000000..a5d544c --- /dev/null +++ b/redis_conn.py @@ -0,0 +1,6 @@ +import redis + +import config + +redis_client = redis.Redis.from_url(config.REDIS_URL) + diff --git a/tasks.py b/tasks.py index 7d245b3..cc08b51 100644 --- a/tasks.py +++ b/tasks.py @@ -18,8 +18,17 @@ from database import CONNECTION_STRING, Artwork from s3 import upload_file from schema import PixivDownloadBatch +from kombu import Exchange, Queue + +app = Celery('tasks', broker=config.CELERY_RABBITMQ_URL, backend=config.REDIS_URL, result_expires=60 * 60 * 24) +app.conf.task_default_queue = 'tasks' +app.conf.task_queues = [ + Queue('tasks', routing_key='tasks.#', queue_arguments={'x-max-priority': 10, 'celeryd_prefetch_multiplier': 1}, max_priority=10), +] +app.conf.update( + worker_prefetch_multiplier=1 +) -app = Celery('tasks', broker='redis://localhost', backend='redis://localhost', result_expires=60*60*24) gelbooru = Gelbooru() @app.task @@ -84,8 +93,7 @@ def downloadPixivImage(pixivImage: PixivDownloadBatch): else: artwork.s3_object_name = upload_file(temp, sha256, os.path.splitext(pixivImage.url)[1][1:]) artwork_collection.insert_one(artwork.model_dump(by_alias=True)) - sleep(1) - model_worker.generate_embeddings.delay([sha256]) + model_worker.generate_embeddings.apply_async([sha256], countdown=10) @app.task @@ -129,8 +137,7 @@ def downloadGelbooru(gelbooruImage): else: artwork.s3_object_name = upload_file(temp, sha256, os.path.splitext(gelbooruImage["url"])[1][1:]) artwork_collection.insert_one(artwork.model_dump(by_alias=True)) - sleep(1) - 
diff --git a/tasks.py b/tasks.py
index 7d245b3..cc08b51 100644
--- a/tasks.py
+++ b/tasks.py
@@ -18,8 +18,17 @@ from database import CONNECTION_STRING, Artwork
 from s3 import upload_file
 from schema import PixivDownloadBatch
+from kombu import Queue
+
+app = Celery('tasks', broker=config.CELERY_RABBITMQ_URL, backend=config.REDIS_URL, result_expires=60 * 60 * 24)
+app.conf.task_default_queue = 'tasks'
+app.conf.task_queues = [
+    Queue('tasks', routing_key='tasks.#', queue_arguments={'x-max-priority': 10}, max_priority=10),
+]
+app.conf.update(
+    worker_prefetch_multiplier=1
+)
 
-app = Celery('tasks', broker='redis://localhost', backend='redis://localhost', result_expires=60*60*24)
 gelbooru = Gelbooru()
 
 @app.task
@@ -84,8 +93,7 @@ def downloadPixivImage(pixivImage: PixivDownloadBatch):
     else:
         artwork.s3_object_name = upload_file(temp, sha256, os.path.splitext(pixivImage.url)[1][1:])
     artwork_collection.insert_one(artwork.model_dump(by_alias=True))
-    sleep(1)
-    model_worker.generate_embeddings.delay([sha256])
+    model_worker.generate_embeddings.apply_async(args=[[sha256]], countdown=10)
 
 
 @app.task
@@ -129,8 +137,7 @@ def downloadGelbooru(gelbooruImage):
     else:
         artwork.s3_object_name = upload_file(temp, sha256, os.path.splitext(gelbooruImage["url"])[1][1:])
     artwork_collection.insert_one(artwork.model_dump(by_alias=True))
-    sleep(1)
-    model_worker.generate_embeddings.delay([sha256])
+    model_worker.generate_embeddings.apply_async(args=[[sha256]], countdown=10)
 
 
 @async_task(app, bind=True)
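Swapping sleep(1) + delay(...) for countdown=10 publishes the embedding task immediately with an ETA instead of stalling the download worker for a second per image. A small usage sketch; the hash is a placeholder:

    import model_worker

    # Publish now, execute no earlier than ~10 s from now (the ETA is a floor, not an exact time).
    result = model_worker.generate_embeddings.apply_async(args=[["<sha256 hash>"]], countdown=10)

    # Results live in the Redis backend until result_expires (24 h in this config).
    print(result.get(timeout=60))

Note the argument shape: generate_embeddings takes a list of hashes, so the apply_async equivalent of delay([sha256]) is args=[[sha256]] — a one-element args tuple whose element is the list.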

diff --git a/templates/gallery.html b/templates/gallery.html
index 3dafe78..f18ec85 100644
--- a/templates/gallery.html
+++ b/templates/gallery.html
@@ -5,14 +5,14 @@
     Recently added sets
 
     {% for image in images %}
         {{ image.title }}
         {% if not image.vitEmbedding %}
-            No embeddings available
+            Not analyzed yet
         {% endif %}
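A hypothetical end-to-end check of the new cache path, assuming the API serves on localhost:8000, Redis is on its default local port, and the thumbnail route mirrors the /imgproxy/optimized/ route above (all three are assumptions, as is the artwork id):

    import redis
    import requests

    r = redis.Redis.from_url("redis://localhost:6379/0")  # assumed value of config.REDIS_URL
    image_id = "<artwork _id>"  # placeholder document id

    # First request populates the cache; setex gives the key a 7-day TTL.
    requests.get(f"http://localhost:8000/imgproxy/thumbnail/{image_id}")
    assert r.exists(f"{image_id}:thumbnail")

    # clearcache flushes the whole Redis database, so the key disappears.
    requests.post("http://localhost:8000/api/clearcache")
    assert not r.exists(f"{image_id}:thumbnail")

After docker compose up, the RabbitMQ management UI on localhost:15672 (credentials from the compose file) shows the model_worker and tasks queues with their x-max-priority arguments, which makes the new routing easy to verify.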