Skip to content

Commit

Permalink
Move OCR script to ml file
Browse files Browse the repository at this point in the history
  • Loading branch information
raphodn committed Jan 15, 2025
1 parent b68e5fc commit 0101a4b
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 96 deletions.
2 changes: 1 addition & 1 deletion open_prices/proofs/management/commands/run_ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from django.conf import settings
from django.core.management.base import BaseCommand

from open_prices.proofs.utils import fetch_and_save_ocr_data
from open_prices.proofs.ml import fetch_and_save_ocr_data


class Command(BaseCommand):
Expand Down
91 changes: 91 additions & 0 deletions open_prices/proofs/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,21 @@
- extract data from PriceTags with Gemini
"""

import base64
import enum
import gzip
import json
import logging
import time
from pathlib import Path
from typing import Any

import google.generativeai as genai
import typing_extensions as typing
from django.conf import settings
from openfoodfacts.ml.image_classification import ImageClassifier
from openfoodfacts.ml.object_detection import ObjectDetectionRawResult, ObjectDetector
from openfoodfacts.utils import http_session
from PIL import Image

from . import constants
Expand All @@ -23,6 +28,14 @@
logger = logging.getLogger(__name__)


GOOGLE_CLOUD_VISION_OCR_API_URL = "https://vision.googleapis.com/v1/images:annotate"
GOOGLE_CLOUD_VISION_OCR_FEATURES = [
"TEXT_DETECTION",
"LOGO_DETECTION",
"LABEL_DETECTION",
"SAFE_SEARCH_DETECTION",
"FACE_DETECTION",
]
PROOF_CLASSIFICATION_LABEL_NAMES = [
"OTHER",
"PRICE_TAG",
Expand Down Expand Up @@ -301,6 +314,84 @@ def detect_price_tags(
)


def run_ocr_on_image(image_path: Path | str, api_key: str) -> dict[str, Any] | None:
"""Run Google Cloud Vision OCR on the image stored at the given path.
:param image_path: the path to the image
:param api_key: the Google Cloud Vision API key
:return: the OCR data as a dict or None if an error occurred
This is similar to the run_ocr.py script in openfoodfacts-server:
https://github.com/openfoodfacts/openfoodfacts-server/blob/main/scripts/run_ocr.py
"""
with open(image_path, "rb") as f:
image_bytes = f.read()

base64_content = base64.b64encode(image_bytes).decode("utf-8")
url = f"{GOOGLE_CLOUD_VISION_OCR_API_URL}?key={api_key}"
data = {
"requests": [
{
"features": [
{"type": feature} for feature in GOOGLE_CLOUD_VISION_OCR_FEATURES
],
"image": {"content": base64_content},
}
]
}
response = http_session.post(url, json=data)

if not response.ok:
logger.debug(
"Error running OCR on image %s, HTTP %s\n%s",
image_path,
response.status_code,
response.text,
)
return response.json()


def fetch_and_save_ocr_data(image_path: Path | str, override: bool = False) -> bool:
"""Run OCR on the image stored at the given path and save the result to a
JSON file.
The JSON file will be saved in the same directory as the image, with the
same name but a `.json` extension.
:param image_path: the path to the image
:param override: whether to override existing OCR data, default to False
:return: True if the OCR data was saved, False otherwise
"""
image_path = Path(image_path)

if image_path.suffix not in (".jpg", ".jpeg", ".png", ".webp"):
logger.debug("Skipping %s, not a supported image type", image_path)
return False

if not settings.GOOGLE_CLOUD_VISION_API_KEY:
logger.error("No Google Cloud Vision API key found")
return False

ocr_json_path = image_path.with_suffix(".json.gz")

if ocr_json_path.exists() and not override:
logger.info("OCR data already exists for %s", image_path)
return False

data = run_ocr_on_image(image_path, settings.GOOGLE_CLOUD_VISION_API_KEY)

if data is None:
return False

data["created_at"] = int(time.time())

with gzip.open(ocr_json_path, "wt") as f:
f.write(json.dumps(data))

logger.debug("OCR data saved to %s", ocr_json_path)
return True


def run_and_save_price_tag_extraction_from_id(price_tag_id: int) -> None:
"""Extract information from a single price tag using the Gemini model and
save the predictions in the database.
Expand Down
2 changes: 1 addition & 1 deletion open_prices/proofs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def proof_post_save_run_ocr(sender, instance, created, **kwargs):
if not settings.TESTING:
if created:
async_task(
"open_prices.proofs.utils.fetch_and_save_ocr_data",
"open_prices.proofs.ml.fetch_and_save_ocr_data",
f"{settings.IMAGES_DIR}/{instance.file_path}",
)

Expand Down
5 changes: 3 additions & 2 deletions open_prices/proofs/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,13 @@
PROOF_CLASSIFICATION_MODEL_VERSION,
ObjectDetectionRawResult,
create_price_tags_from_proof_prediction,
fetch_and_save_ocr_data,
run_and_save_price_tag_detection,
run_and_save_proof_prediction,
run_and_save_proof_type_prediction,
)
from open_prices.proofs.models import PriceTag, Proof
from open_prices.proofs.utils import fetch_and_save_ocr_data, select_proof_image_dir
from open_prices.proofs.utils import select_proof_image_dir

LOCATION_OSM_NODE_652825274 = {
"type": location_constants.TYPE_OSM,
Expand Down Expand Up @@ -341,7 +342,7 @@ def test_fetch_and_save_ocr_data_success(self):
with self.settings(GOOGLE_CLOUD_VISION_API_KEY="test_api_key"):
# mock call to run_ocr_on_image
with unittest.mock.patch(
"open_prices.proofs.utils.run_ocr_on_image",
"open_prices.proofs.ml.run_ocr_on_image",
return_value=response_data,
) as mock_run_ocr_on_image:
with tempfile.TemporaryDirectory() as tmpdirname:
Expand Down
92 changes: 0 additions & 92 deletions open_prices/proofs/utils.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,12 @@
import base64
import gzip
import json
import logging
import random
import string
import time
from decimal import Decimal
from mimetypes import guess_extension
from pathlib import Path
from typing import Any

from django.conf import settings
from django.core.files.uploadedfile import InMemoryUploadedFile, TemporaryUploadedFile
from openfoodfacts.utils import http_session
from PIL import Image, ImageOps

from open_prices.prices.constants import TYPE_CATEGORY, TYPE_PRODUCT
Expand Down Expand Up @@ -154,92 +148,6 @@ def select_proof_image_dir(images_dir: Path, max_images_per_dir: int = 1_000) ->
return current_dir


def run_ocr_on_image(image_path: Path | str, api_key: str) -> dict[str, Any] | None:
"""Run Google Cloud Vision OCR on the image stored at the given path.
:param image_path: the path to the image
:param api_key: the Google Cloud Vision API key
:return: the OCR data as a dict or None if an error occurred
This is similar to the run_ocr.py script in openfoodfacts-server:
https://github.com/openfoodfacts/openfoodfacts-server/blob/main/scripts/run_ocr.py
"""
with open(image_path, "rb") as f:
image_bytes = f.read()

base64_content = base64.b64encode(image_bytes).decode("utf-8")
url = f"https://vision.googleapis.com/v1/images:annotate?key={api_key}"
r = http_session.post(
url,
json={
"requests": [
{
"features": [
{"type": "TEXT_DETECTION"},
{"type": "LOGO_DETECTION"},
{"type": "LABEL_DETECTION"},
{"type": "SAFE_SEARCH_DETECTION"},
{"type": "FACE_DETECTION"},
],
"image": {"content": base64_content},
}
]
},
)

if not r.ok:
logger.debug(
"Error running OCR on image %s, HTTP %s\n%s",
image_path,
r.status_code,
r.text,
)
return r.json()


def fetch_and_save_ocr_data(image_path: Path | str, override: bool = False) -> bool:
"""Run OCR on the image stored at the given path and save the result to a
JSON file.
The JSON file will be saved in the same directory as the image, with the
same name but a `.json` extension.
:param image_path: the path to the image
:param override: whether to override existing OCR data, default to False
:return: True if the OCR data was saved, False otherwise
"""
image_path = Path(image_path)

if image_path.suffix not in (".jpg", ".jpeg", ".png", ".webp"):
logger.debug("Skipping %s, not a supported image type", image_path)
return False

api_key = settings.GOOGLE_CLOUD_VISION_API_KEY

if not api_key:
logger.error("No Google Cloud Vision API key found")
return False

ocr_json_path = image_path.with_suffix(".json.gz")

if ocr_json_path.exists() and not override:
logger.info("OCR data already exists for %s", image_path)
return False

data = run_ocr_on_image(image_path, api_key)

if data is None:
return False

data["created_at"] = int(time.time())

with gzip.open(ocr_json_path, "wt") as f:
f.write(json.dumps(data))

logger.debug("OCR data saved to %s", ocr_json_path)
return True


def match_decimal_with_float(price_decimal: Decimal, price_float: float) -> bool:
return float(price_decimal) == price_float

Expand Down

0 comments on commit 0101a4b

Please sign in to comment.