Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(Proofs): move OCR script in the ML file #679

Merged
merged 5 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from django.db.models import Q
from openfoodfacts.utils import get_logger

from open_prices.proofs.constants import PRICE_TAG_EXTRACTION_TYPE, TYPE_PRICE_TAG
from open_prices.proofs import constants as proof_constants
from open_prices.proofs.ml import (
PRICE_TAG_DETECTOR_MODEL_NAME,
PROOF_CLASSIFICATION_MODEL_NAME,
Expand Down Expand Up @@ -93,19 +93,21 @@ def handle_proof_prediction_job(self, types: list[str], limit: int) -> None:

for proof in tqdm.tqdm(proofs):
self.stdout.write(f"Processing proof {proof.id}...")
run_and_save_proof_prediction(proof.id)
run_and_save_proof_prediction(proof)
self.stdout.write("Done.")

def handle_price_tag_extraction_job(self, limit: int) -> None:
# Get all proofs of type PRICE_TAG
proofs = Proof.objects.filter(type=TYPE_PRICE_TAG).order_by("-id")
proofs = Proof.objects.filter(type=proof_constants.TYPE_PRICE_TAG).order_by(
"-id"
)

added = 0
for proof in tqdm.tqdm(proofs):
for price_tag in proof.price_tags.all():
# Check if the price tag already has a prediction
if not PriceTagPrediction.objects.filter(
type=PRICE_TAG_EXTRACTION_TYPE, price_tag=price_tag
type=proof_constants.PRICE_TAG_EXTRACTION_TYPE, price_tag=price_tag
).exists():
self.stdout.write(
f"Processing price tag {price_tag.id} (proof {proof.id})..."
Expand Down
2 changes: 1 addition & 1 deletion open_prices/proofs/management/commands/run_ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from django.conf import settings
from django.core.management.base import BaseCommand

from open_prices.proofs.utils import fetch_and_save_ocr_data
from open_prices.proofs.ml import fetch_and_save_ocr_data


class Command(BaseCommand):
Expand Down
119 changes: 106 additions & 13 deletions open_prices/proofs/ml.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,41 @@
"""
Proof ML/AI
- predict Proof type with triton
- detect Proof's PriceTags with triton
- extract data from PriceTags with Gemini
- run OCR on Proof images with Google Cloud Vision
"""

import base64
import enum
import gzip
import json
import logging
import time
from pathlib import Path
from typing import Any

import google.generativeai as genai
import typing_extensions as typing
from django.conf import settings
from openfoodfacts.ml.image_classification import ImageClassifier
from openfoodfacts.ml.object_detection import ObjectDetectionRawResult, ObjectDetector
from openfoodfacts.utils import http_session
from PIL import Image

from . import constants
from . import constants as proof_constants
from .models import PriceTag, PriceTagPrediction, Proof, ProofPrediction

logger = logging.getLogger(__name__)


GOOGLE_CLOUD_VISION_OCR_API_URL = "https://vision.googleapis.com/v1/images:annotate"
GOOGLE_CLOUD_VISION_OCR_FEATURES = [
"TEXT_DETECTION",
"LOGO_DETECTION",
"LABEL_DETECTION",
"SAFE_SEARCH_DETECTION",
"FACE_DETECTION",
]
PROOF_CLASSIFICATION_LABEL_NAMES = [
"OTHER",
"PRICE_TAG",
Expand Down Expand Up @@ -294,6 +314,84 @@ def detect_price_tags(
)


def run_ocr_on_image(image_path: Path | str, api_key: str) -> dict[str, Any] | None:
    """Run Google Cloud Vision OCR on the image stored at the given path.

    The image is read from disk, base64-encoded and sent in a single
    annotate request asking for every feature listed in
    `GOOGLE_CLOUD_VISION_OCR_FEATURES`.

    :param image_path: the path to the image
    :param api_key: the Google Cloud Vision API key
    :return: the OCR data as a dict or None if an error occurred

    This is similar to the run_ocr.py script in openfoodfacts-server:
    https://github.com/openfoodfacts/openfoodfacts-server/blob/main/scripts/run_ocr.py
    """
    with open(image_path, "rb") as f:
        image_bytes = f.read()

    # The API expects the raw image bytes as a base64-encoded string.
    base64_content = base64.b64encode(image_bytes).decode("utf-8")
    url = f"{GOOGLE_CLOUD_VISION_OCR_API_URL}?key={api_key}"
    data = {
        "requests": [
            {
                "features": [
                    {"type": feature} for feature in GOOGLE_CLOUD_VISION_OCR_FEATURES
                ],
                "image": {"content": base64_content},
            }
        ]
    }
    response = http_session.post(url, json=data)

    if not response.ok:
        logger.debug(
            "Error running OCR on image %s, HTTP %s\n%s",
            image_path,
            response.status_code,
            response.text,
        )
        # Honour the documented contract: callers test for None to detect a
        # failure, so an error payload must never be returned as OCR data.
        return None

    return response.json()


def fetch_and_save_ocr_data(image_path: Path | str, override: bool = False) -> bool:
    """Run OCR on the image stored at the given path and save the result to a
    gzip-compressed JSON file.

    The file is saved in the same directory as the image, with the same name
    but a `.json.gz` extension. A `created_at` key (Unix timestamp, in
    seconds) is added to the OCR data before it is written.

    Nothing is saved (and False is returned) when:
    - the file extension is not one of `.jpg`, `.jpeg`, `.png`, `.webp`
    - `settings.GOOGLE_CLOUD_VISION_API_KEY` is not set
    - the OCR file already exists and `override` is False
    - `run_ocr_on_image` returns None

    :param image_path: the path to the image
    :param override: whether to override existing OCR data, default to False
    :return: True if the OCR data was saved, False otherwise
    """
    image_path = Path(image_path)

    if image_path.suffix not in (".jpg", ".jpeg", ".png", ".webp"):
        logger.debug("Skipping %s, not a supported image type", image_path)
        return False

    if not settings.GOOGLE_CLOUD_VISION_API_KEY:
        logger.error("No Google Cloud Vision API key found")
        return False

    ocr_json_path = image_path.with_suffix(".json.gz")

    if ocr_json_path.exists() and not override:
        logger.info("OCR data already exists for %s", image_path)
        return False

    data = run_ocr_on_image(image_path, settings.GOOGLE_CLOUD_VISION_API_KEY)

    if data is None:
        return False

    data["created_at"] = int(time.time())

    # Explicit encoding so the written file does not depend on the locale of
    # the process running the task.
    with gzip.open(ocr_json_path, "wt", encoding="utf-8") as f:
        json.dump(data, f)

    logger.debug("OCR data saved to %s", ocr_json_path)
    return True


def run_and_save_price_tag_extraction_from_id(price_tag_id: int) -> None:
"""Extract information from a single price tag using the Gemini model and
save the predictions in the database.
Expand Down Expand Up @@ -341,7 +439,7 @@ def run_and_save_price_tag_extraction(
label = extract_from_price_tag(cropped_image)
prediction = PriceTagPrediction.objects.create(
price_tag=price_tag,
type=constants.PRICE_TAG_EXTRACTION_TYPE,
type=proof_constants.PRICE_TAG_EXTRACTION_TYPE,
model_name=GEMINI_MODEL_NAME,
model_version=GEMINI_MODEL_VERSION,
data=label,
Expand Down Expand Up @@ -369,7 +467,7 @@ def update_price_tag_extraction(price_tag_id: int) -> PriceTagPrediction:
return []

price_tag_prediction = PriceTagPrediction.objects.filter(
price_tag=price_tag, type=constants.PRICE_TAG_EXTRACTION_TYPE
price_tag=price_tag, type=proof_constants.PRICE_TAG_EXTRACTION_TYPE
).first()

if not price_tag_prediction:
Expand Down Expand Up @@ -473,7 +571,7 @@ def run_and_save_price_tag_detection(
PRICE_TAG_DETECTOR_MODEL_NAME,
)
if (
proof.type == constants.TYPE_PRICE_TAG
proof.type == proof_constants.TYPE_PRICE_TAG
and not PriceTag.objects.filter(proof=proof).exists()
):
logger.debug(
Expand All @@ -494,14 +592,14 @@ def run_and_save_price_tag_detection(

proof_prediction = ProofPrediction.objects.create(
proof=proof,
type=constants.PROOF_PREDICTION_OBJECT_DETECTION_TYPE,
type=proof_constants.PROOF_PREDICTION_OBJECT_DETECTION_TYPE,
model_name=PRICE_TAG_DETECTOR_MODEL_NAME,
model_version=PRICE_TAG_DETECTOR_MODEL_VERSION,
data={"objects": detections},
value=None,
max_confidence=max_confidence,
)
if proof.type == constants.TYPE_PRICE_TAG:
if proof.type == proof_constants.TYPE_PRICE_TAG:
create_price_tags_from_proof_prediction(
proof, proof_prediction, run_extraction=run_extraction
)
Expand Down Expand Up @@ -543,7 +641,7 @@ def run_and_save_proof_type_prediction(
proof_type = max(prediction, key=lambda x: x[1])[0]
return ProofPrediction.objects.create(
proof=proof,
type=constants.PROOF_PREDICTION_CLASSIFICATION_TYPE,
type=proof_constants.PROOF_PREDICTION_CLASSIFICATION_TYPE,
model_name=PROOF_CLASSIFICATION_MODEL_NAME,
model_version=PROOF_CLASSIFICATION_MODEL_VERSION,
data={
Expand All @@ -558,7 +656,7 @@ def run_and_save_proof_type_prediction(


def run_and_save_proof_prediction(
proof_id: int, run_price_tag_extraction: bool = True
proof: Proof, run_price_tag_extraction: bool = True
) -> None:
"""Run all ML models on a specific proof, and save the predictions in DB.

Expand All @@ -571,11 +669,6 @@ def run_and_save_proof_prediction(
:param run_price_tag_extraction: whether to run the price tag extraction
model on the detected price tags, defaults to True
"""
proof = Proof.objects.filter(id=proof_id).first()
if not proof:
logger.error("Proof with id %s not found", proof_id)
return

file_path_full = proof.file_path_full

if file_path_full is None or not Path(file_path_full).exists():
Expand Down
31 changes: 16 additions & 15 deletions open_prices/proofs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,11 +321,26 @@ def proof_post_save_run_ocr(sender, instance, created, **kwargs):
if not settings.TESTING:
if created:
async_task(
"open_prices.proofs.utils.fetch_and_save_ocr_data",
"open_prices.proofs.ml.fetch_and_save_ocr_data",
f"{settings.IMAGES_DIR}/{instance.file_path}",
)


@receiver(signals.post_save, sender=Proof)
def proof_post_save_run_ml_models(sender, instance, created, **kwargs):
    """
    Queue the ML prediction task for a proof that has just been created.

    The task `open_prices.proofs.ml.run_and_save_proof_prediction` is
    enqueued with the saved instance as its argument. Nothing is enqueued
    when running tests, when `settings.ENABLE_ML_PREDICTIONS` is falsy, or
    when the save is an update of an existing proof.
    """
    ml_enabled = not settings.TESTING and settings.ENABLE_ML_PREDICTIONS
    if not (ml_enabled and created):
        return

    async_task("open_prices.proofs.ml.run_and_save_proof_prediction", instance)


@receiver(signals.post_save, sender=Proof)
def proof_post_save_update_prices(sender, instance, created, **kwargs):
if not created:
Expand Down Expand Up @@ -404,20 +419,6 @@ def __str__(self):
return f"{self.proof} - {self.model_name} - {self.model_version}"


@receiver(signals.post_save, sender=Proof)
def proof_post_save_run_ml_models(sender, instance, created, **kwargs):
    """Queue the ML prediction task for a proof that has just been created.

    The task `open_prices.proofs.ml.run_and_save_proof_prediction` is
    enqueued with the id of the saved proof. Nothing is enqueued when running
    tests, when `settings.ENABLE_ML_PREDICTIONS` is falsy, or when the save
    is an update of an existing proof.
    """
    if settings.TESTING or not settings.ENABLE_ML_PREDICTIONS:
        return
    if not created:
        return

    task_path = "open_prices.proofs.ml.run_and_save_proof_prediction"
    async_task(task_path, instance.id)


class PriceTagQuerySet(models.QuerySet):
def status_unknown(self):
return self.filter(status=None)
Expand Down
20 changes: 5 additions & 15 deletions open_prices/proofs/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,13 @@
PROOF_CLASSIFICATION_MODEL_VERSION,
ObjectDetectionRawResult,
create_price_tags_from_proof_prediction,
fetch_and_save_ocr_data,
run_and_save_price_tag_detection,
run_and_save_proof_prediction,
run_and_save_proof_type_prediction,
)
from open_prices.proofs.models import PriceTag, Proof
from open_prices.proofs.utils import fetch_and_save_ocr_data, select_proof_image_dir
from open_prices.proofs.utils import select_proof_image_dir

LOCATION_OSM_NODE_652825274 = {
"type": location_constants.TYPE_OSM,
Expand Down Expand Up @@ -341,7 +342,7 @@ def test_fetch_and_save_ocr_data_success(self):
with self.settings(GOOGLE_CLOUD_VISION_API_KEY="test_api_key"):
# mock call to run_ocr_on_image
with unittest.mock.patch(
"open_prices.proofs.utils.run_ocr_on_image",
"open_prices.proofs.ml.run_ocr_on_image",
return_value=response_data,
) as mock_run_ocr_on_image:
with tempfile.TemporaryDirectory() as tmpdirname:
Expand Down Expand Up @@ -377,20 +378,11 @@ def test_fetch_and_save_ocr_data_invalid_extension(self):


class MLModelTest(TestCase):
def test_run_and_save_proof_prediction_proof_does_not_exist(self):
    """An unknown proof id yields None and a single ERROR log record."""
    expected_logs = ["ERROR:open_prices.proofs.ml:Proof with id 1 not found"]

    # Capture ERROR-level records emitted by the ML module during the call.
    with self.assertLogs("open_prices.proofs.ml", level="ERROR") as captured:
        result = run_and_save_proof_prediction(1)
        self.assertIsNone(result)

    self.assertEqual(captured.output, expected_logs)

def test_run_and_save_proof_prediction_proof_file_not_found(self):
proof = ProofFactory()
# check that we emit an error log
with self.assertLogs("open_prices.proofs.ml", level="ERROR") as cm:
self.assertIsNone(run_and_save_proof_prediction(proof.id))
self.assertIsNone(run_and_save_proof_prediction(proof))
self.assertEqual(
cm.output,
[
Expand Down Expand Up @@ -436,9 +428,7 @@ def test_run_and_save_proof_prediction_proof(self):
return_value=detect_price_tags_response,
) as mock_detect_price_tags,
):
run_and_save_proof_prediction(
proof.id, run_price_tag_extraction=False
)
run_and_save_proof_prediction(proof, run_price_tag_extraction=False)
mock_predict_proof_type.assert_called_once()
mock_detect_price_tags.assert_called_once()

Expand Down
Loading
Loading