Skip to content

Commit

Permalink
Auto-gen embeddings, migrate thumbs to redis, use rabbitmq for task q…
Browse files Browse the repository at this point in the history
…ueue
  • Loading branch information
PrivateGER committed Mar 22, 2024
1 parent dbc092a commit 6920ed7
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 34 deletions.
14 changes: 13 additions & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ services:
ports:
- "5555:5555"
environment:
CELERY_BROKER_URL: redis://redis:6379/0
CELERY_BROKER_URL: "amqp://user:password@localhost:5672"
FLOWER_PORT: 5555
depends_on:
- redis
Expand All @@ -51,8 +51,20 @@ services:
ports:
- "8080:8080"

rabbitmq:
image: rabbitmq:3-management
ports:
- "5672:5672"
- "15672:15672"
environment:
RABBITMQ_DEFAULT_USER: user
RABBITMQ_DEFAULT_PASS: password
volumes:
- rabbitmq-data:/var/lib/rabbitmq

volumes:
qdrant_data:
rabbitmq-data:

configs:
qdrant_config:
Expand Down
47 changes: 24 additions & 23 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
import requests
from fastapi import FastAPI, Body
from starlette.requests import Request
from starlette.responses import StreamingResponse
from starlette.responses import StreamingResponse, Response

import config
import model_worker
import redis_conn
import s3
import tasks
from database import artwork_collection, translation_collection
Expand Down Expand Up @@ -123,41 +124,41 @@ async def imgproxy(image_id: str):
artwork = await artwork_collection.find_one({"_id": image_id})
encoded_url = base64.b64encode(f"{s3.base_url}{artwork['s3_object_name']}".encode("utf-8")).decode("utf-8")

# Check if file is cached in tempdir
# If not, fetch from imgproxy
if os.path.exists(f"/tmp/{artwork.id}.webp"):
return StreamingResponse(open(f"/tmp/{artwork.id}.webp", "rb"), media_type="image/webp")
redis = redis_conn.redis_client

# Check if file is cached in redis
if redis.exists(f"{artwork['_id']}:thumbnail"):
return Response(content=bytes(redis.get(f"{artwork['_id']}:thumbnail")), media_type="image/webp")

# return a streaming response
url = httpx.URL(f"{config.IMGPROXY_THUMBNAIL_BASE_URL}{encoded_url}.webp")
async with httpx.AsyncClient() as client:
response = await client.get(url)
# save to tempdir
with open(f"/tmp/{artwork.id}.webp", "wb") as f:
f.write(response.content)
# save to redis
redis.setex(f"{artwork['_id']}:thumbnail", 60 * 60 * 24 * 7, response.content)

return StreamingResponse(BytesIO(response.content), media_type="image/webp")
return Response(content=response.content, media_type="image/webp")


@app.get("/imgproxy/optimized/{image_id}")
async def imgproxy(image_id: str):
async def optimized_artwork(image_id: str):
artwork = await artwork_collection.find_one({"_id": image_id})
encoded_url = base64.b64encode(f"{s3.base_url}{artwork['s3_object_name']}".encode("utf-8")).decode("utf-8")

# Check if file is cached in tempdir
# If not, fetch from imgproxy
if os.path.exists(f"/tmp/{artwork.id}_orig.webp"):
return StreamingResponse(open(f"/tmp/{artwork.id}_orig.webp", "rb"), media_type="image/webp")
redis = redis_conn.redis_client

# Check if file is cached in redis
if redis.exists(f"{artwork['_id']}:optimized"):
return Response(content=redis.get(f"{artwork['_id']}:optimized"), media_type="image/webp")

# return a streaming response
url = httpx.URL(f"{config.IMGPROXY_OPTIMIZED_BASE_URL}{encoded_url}.webp")
async with httpx.AsyncClient() as client:
response = await client.get(url)
# save to tempdir
with open(f"/tmp/{artwork.id}_orig.webp", "wb") as f:
f.write(response.content)
# save to redis
redis.setex(f"{artwork['_id']}:optimized", 60 * 60 * 24, response.content)

return StreamingResponse(BytesIO(response.content), media_type="image/webp")
return Response(content=response.content, media_type="image/webp")


@app.get("/search")
Expand All @@ -174,11 +175,12 @@ async def search(request: Request, query: str = None, page: int = 1, neural: boo
split_tags = query.split(" ")
results = []
if neural:
search_task = model_worker.neural_search.delay(query, page=page, limit=25)
search_task = model_worker.neural_search.apply_async(args=[query], kwargs={"page": page, "limit": 25},
priority=10)
search_task.wait()
results = search_task.get()
else:
search_task = model_worker.tag_search.delay(split_tags, page, 25, group_sets=group_sets)
search_task = model_worker.tag_search.apply_async(args=[split_tags], kwargs={"page": page, "limit": 25, "group_sets": group_sets}, priority=10)
search_task.wait()
results = search_task.get()

Expand Down Expand Up @@ -272,7 +274,6 @@ async def delete_videos():

@app.post("/api/clearcache")
async def clear_cache():
for file in os.listdir("/tmp"):
if file.endswith(".webp"):
os.remove(f"/tmp/{file}")
redis = redis_conn.redis_client
redis.flushdb()
return {"message": "Cleared cache."}
10 changes: 9 additions & 1 deletion model_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,17 @@
from PIL import Image
import open_clip
from qdrant_client.models import VectorParams, Distance, PointStruct, Filter, FieldCondition, MatchText, IsNullCondition, MatchValue, PayloadField
from kombu import Exchange, Queue

app = Celery('model_worker', broker=config.CELERY_REDIS_BROKER, backend=config.CELERY_REDIS_BROKER, result_expires=60*60*24)

app = Celery('model_worker', broker=config.CELERY_RABBITMQ_URL, backend=config.REDIS_URL, result_expires=60 * 60 * 24)
app.conf.task_default_queue = 'model_worker'
app.conf.task_queues = [
Queue('model_worker', routing_key='model_tasks.#', queue_arguments={'x-max-priority': 10}, max_priority=10),
]
app.conf.update(
worker_prefetch_multiplier=1
)

model, preprocess, tokenizer, qdrant_client = None, None, None, None
loadingMutex = Lock()
Expand Down
6 changes: 6 additions & 0 deletions redis_conn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import redis

import config

redis_client = redis.Redis.from_url(config.REDIS_URL)

17 changes: 12 additions & 5 deletions tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,17 @@
from database import CONNECTION_STRING, Artwork
from s3 import upload_file
from schema import PixivDownloadBatch
from kombu import Exchange, Queue

app = Celery('tasks', broker=config.CELERY_RABBITMQ_URL, backend=config.REDIS_URL, result_expires=60 * 60 * 24)
app.conf.task_default_queue = 'tasks'
app.conf.task_queues = [
Queue('tasks', routing_key='tasks.#', queue_arguments={'x-max-priority': 10, 'celeryd_prefetch_multiplier': 1}, max_priority=10),
]
app.conf.update(
worker_prefetch_multiplier=1
)

app = Celery('tasks', broker='redis://localhost', backend='redis://localhost', result_expires=60*60*24)
gelbooru = Gelbooru()

@app.task
Expand Down Expand Up @@ -84,8 +93,7 @@ def downloadPixivImage(pixivImage: PixivDownloadBatch):
else:
artwork.s3_object_name = upload_file(temp, sha256, os.path.splitext(pixivImage.url)[1][1:])
artwork_collection.insert_one(artwork.model_dump(by_alias=True))
sleep(1)
model_worker.generate_embeddings.delay([sha256])
model_worker.generate_embeddings.apply_async([sha256], countdown=10)


@app.task
Expand Down Expand Up @@ -129,8 +137,7 @@ def downloadGelbooru(gelbooruImage):
else:
artwork.s3_object_name = upload_file(temp, sha256, os.path.splitext(gelbooruImage["url"])[1][1:])
artwork_collection.insert_one(artwork.model_dump(by_alias=True))
sleep(1)
model_worker.generate_embeddings.delay([sha256])
model_worker.generate_embeddings.apply_async([sha256], countdown=10)


@async_task(app, bind=True)
Expand Down
8 changes: 4 additions & 4 deletions templates/gallery.html
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ <h1 class="text-2xl text-slate-700">Recently added sets</h1>
<div class="flex flex-row flex-wrap gap-4 mt-4 justify-center items-center">
{% for image in images %}
<div class="flex-initial p-0 m-0 last:grow-0 max-w-32 lg:max-w-64 border-2 border-black relative drop-shadow-xl">
<a href="/image/{{ image.id }}">
<a {% if image.vitEmbedding %}href="/image/{{ image.id }}"{% else %}href="#"{% endif %}>
<img src="/imgproxy/thumbnail/{{ image.id }}" class="object-scale-down" alt="{{ image.title }}">
</a>
{% if not image.vitEmbedding %}
<div class="absolute z-10 top-0 left-0 right-0 bottom-0 flex justify-center items-center">
<div class="absolute z-10 top-0 left-0 right-0 bottom-0 flex justify-center items-center p-4 m-4">
<p class="text-white bg-black bg-opacity-75 text-xl text-center">
<i class="fa-solid fa-circle-exclamation"></i><br />
No embeddings available
<i class="fa-solid fa-hourglass-start"></i><br />
Not analyzed yet
</p>
</div>
{% endif %}
Expand Down

0 comments on commit 6920ed7

Please sign in to comment.