Skip to content

Commit

Permalink
feat: run OCR on every new image
Browse files Browse the repository at this point in the history
  • Loading branch information
raphael0202 committed Oct 30, 2024
1 parent 2f82256 commit e17c351
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 0 deletions.
1 change: 1 addition & 0 deletions .github/workflows/container-deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ jobs:
echo "POSTGRES_USER=postgres" >> .env
echo "POSTGRES_PASSWORD=${{ secrets.POSTGRES_PASSWORD }}" >> .env
echo "ENVIRONMENT=${{ env.ENVIRONMENT }}" >> .env
echo "GOOGLE_CLOUD_VISION_API_KEY=${{ secrets.GOOGLE_CLOUD_VISION_API_KEY }}" >> .env
- name: Create Docker volumes
uses: appleboy/ssh-action@master
Expand Down
6 changes: 6 additions & 0 deletions config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,3 +275,9 @@
OAUTH2_SERVER_URL = os.getenv("OAUTH2_SERVER_URL")
SESSION_COOKIE_NAME = "opsession"
OFF_USER_AGENT = "open-prices/0.1.0"


# Google Cloud Vision API
# ------------------------------------------------------------------------------

GOOGLE_CLOUD_VISION_API_KEY = os.getenv("GOOGLE_CLOUD_VISION_API_KEY")
1 change: 1 addition & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ x-api-common: &api-common
- OAUTH2_SERVER_URL
- SENTRY_DSN
- LOG_LEVEL
- GOOGLE_CLOUD_VISION_API_KEY
networks:
- default

Expand Down
6 changes: 6 additions & 0 deletions open_prices/api/proofs/views.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from django.conf import settings
from django_filters.rest_framework import DjangoFilterBackend
from django_q.tasks import async_task
from drf_spectacular.utils import extend_schema
from rest_framework import filters, mixins, status, viewsets
from rest_framework.decorators import action
Expand Down Expand Up @@ -75,6 +77,10 @@ def upload(self, request: Request) -> Response:
status=status.HTTP_400_BAD_REQUEST,
)
file_path, mimetype, image_thumb_path = store_file(request.data.get("file"))
async_task(
"open_prices.proofs.utils.run_ocr_task",
f"{settings.IMAGES_DIR}/{file_path}",
)
proof_create_data = {
"file_path": file_path,
"mimetype": mimetype,
Expand Down
37 changes: 37 additions & 0 deletions open_prices/proofs/tests.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import gzip
import json
import tempfile
import unittest
from decimal import Decimal
from pathlib import Path

from django.core.exceptions import ValidationError
from django.test import TestCase
Expand All @@ -9,6 +14,7 @@
from open_prices.proofs import constants as proof_constants
from open_prices.proofs.factories import ProofFactory
from open_prices.proofs.models import Proof
from open_prices.proofs.utils import run_ocr_task

LOCATION_OSM_NODE_652825274 = {
"type": location_constants.TYPE_OSM,
Expand Down Expand Up @@ -302,3 +308,34 @@ def test_proof_update(self):
self.assertEqual(
self.proof_price_tag.prices.first().location, self.location_osm_2
)


class RunOCRTaskTest(TestCase):
def test_run_ocr_task_success(self):
response_data = {"responses": [{"textAnnotations": [{"description": "test"}]}]}
with self.settings(GOOGLE_CLOUD_VISION_API_KEY="test_api_key"):
# mock call to run_ocr_on_image
with unittest.mock.patch(
"open_prices.proofs.utils.run_ocr_on_image",
return_value=response_data,
) as mock_run_ocr_on_image:
with tempfile.TemporaryDirectory() as tmpdirname:
image_path = Path(f"{tmpdirname}/test.jpg")
with image_path.open("w") as f:
f.write("test")
run_ocr_task(image_path)
mock_run_ocr_on_image.assert_called_once_with(
image_path, "test_api_key"
)
ocr_path = image_path.with_suffix(".json.gz")
self.assertTrue(ocr_path.is_file())

with gzip.open(ocr_path, "rt") as f:
actual_data = json.loads(f.read())
self.assertEqual(
set(actual_data.keys()), {"responses", "created_at"}
)
self.assertIsInstance(actual_data["created_at"], int)
self.assertEqual(
actual_data["responses"], response_data["responses"]
)
79 changes: 79 additions & 0 deletions open_prices/proofs/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
import base64
import gzip
import json
import logging
import random
import string
import time
from mimetypes import guess_extension
from pathlib import Path
from typing import Any

from django.conf import settings
from django.core.files.uploadedfile import InMemoryUploadedFile, TemporaryUploadedFile
from openfoodfacts.utils import http_session
from PIL import Image, ImageOps

logger = logging.getLogger(__name__)


def get_file_extension_and_mimetype(
file: InMemoryUploadedFile | TemporaryUploadedFile,
Expand Down Expand Up @@ -124,3 +133,73 @@ def store_file(
# Build file_path
file_path = generate_relative_path(current_dir_id_str, file_stem, extension)
return (file_path, mimetype, image_thumb_path)


def run_ocr_on_image(image_path: Path | str, api_key: str) -> dict[str, Any] | None:
"""Run Google Cloud Vision OCR on the image stored at the given path.
:param image_path: the path to the image
:param api_key: the Google Cloud Vision API key
:return: the OCR data as a dict or None if an error occurred
"""
with open(image_path, "rb") as f:
image_bytes = f.read()

base64_content = base64.b64encode(image_bytes).decode("utf-8")
url = f"https://vision.googleapis.com/v1/images:annotate?key={api_key}"
r = http_session.post(
url,
json={
"requests": [
{
"features": [{"type": "TEXT_DETECTION"}],
"image": {"content": base64_content},
}
]
},
)

if not r.ok:
logger.debug(
"Error running OCR on image %s, HTTP %s\n%s",
image_path,
r.status_code,
r.text,
)
return r.json()


def run_ocr_task(image_path: Path | str, override: bool = False) -> None:
"""Run OCR on the image stored at the given path and save the result to a
JSON file.
The JSON file will be saved in the same directory as the image, with the
same name but a `.json` extension.
:param image_path: the path to the image
:param override: whether to override existing OCR data, default to False
"""
image_path = Path(image_path)
api_key = settings.GOOGLE_CLOUD_VISION_API_KEY

if api_key is None:
logger.error("No Google Cloud Vision API key found")
return

ocr_json_path = image_path.with_suffix(".json.gz")

if ocr_json_path.exists() and not override:
logger.info("OCR data already exists for %s", image_path)
return

data = run_ocr_on_image(image_path, api_key)

if data is None:
return

data["created_at"] = int(time.time())

with gzip.open(ocr_json_path, "wt") as f:
f.write(json.dumps(data))

logger.debug("OCR data saved to %s", ocr_json_path)

0 comments on commit e17c351

Please sign in to comment.