Skip to content

Commit

Permalink
Additional cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
raphodn committed Jan 15, 2025
1 parent 0101a4b commit 396aed8
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
8 changes: 5 additions & 3 deletions open_prices/proofs/management/commands/run_ml_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from django.db.models import Q
from openfoodfacts.utils import get_logger

from open_prices.proofs.constants import PRICE_TAG_EXTRACTION_TYPE, TYPE_PRICE_TAG
from open_prices.proofs import constants as proof_constants
from open_prices.proofs.ml import (
PRICE_TAG_DETECTOR_MODEL_NAME,
PROOF_CLASSIFICATION_MODEL_NAME,
Expand Down Expand Up @@ -98,14 +98,16 @@ def handle_proof_prediction_job(self, types: list[str], limit: int) -> None:

def handle_price_tag_extraction_job(self, limit: int) -> None:
# Get all proofs of type PRICE_TAG
proofs = Proof.objects.filter(type=TYPE_PRICE_TAG).order_by("-id")
proofs = Proof.objects.filter(type=proof_constants.TYPE_PRICE_TAG).order_by(
"-id"
)

added = 0
for proof in tqdm.tqdm(proofs):
for price_tag in proof.price_tags.all():
# Check if the price tag already has a prediction
if not PriceTagPrediction.objects.filter(
type=PRICE_TAG_EXTRACTION_TYPE, price_tag=price_tag
type=proof_constants.PRICE_TAG_EXTRACTION_TYPE, price_tag=price_tag
).exists():
self.stdout.write(
f"Processing price tag {price_tag.id} (proof {proof.id})..."
Expand Down
14 changes: 7 additions & 7 deletions open_prices/proofs/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from openfoodfacts.utils import http_session
from PIL import Image

from . import constants
from . import constants as proof_constants
from .models import PriceTag, PriceTagPrediction, Proof, ProofPrediction

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -439,7 +439,7 @@ def run_and_save_price_tag_extraction(
label = extract_from_price_tag(cropped_image)
prediction = PriceTagPrediction.objects.create(
price_tag=price_tag,
type=constants.PRICE_TAG_EXTRACTION_TYPE,
type=proof_constants.PRICE_TAG_EXTRACTION_TYPE,
model_name=GEMINI_MODEL_NAME,
model_version=GEMINI_MODEL_VERSION,
data=label,
Expand Down Expand Up @@ -467,7 +467,7 @@ def update_price_tag_extraction(price_tag_id: int) -> PriceTagPrediction:
return []

price_tag_prediction = PriceTagPrediction.objects.filter(
price_tag=price_tag, type=constants.PRICE_TAG_EXTRACTION_TYPE
price_tag=price_tag, type=proof_constants.PRICE_TAG_EXTRACTION_TYPE
).first()

if not price_tag_prediction:
Expand Down Expand Up @@ -571,7 +571,7 @@ def run_and_save_price_tag_detection(
PRICE_TAG_DETECTOR_MODEL_NAME,
)
if (
proof.type == constants.TYPE_PRICE_TAG
proof.type == proof_constants.TYPE_PRICE_TAG
and not PriceTag.objects.filter(proof=proof).exists()
):
logger.debug(
Expand All @@ -592,14 +592,14 @@ def run_and_save_price_tag_detection(

proof_prediction = ProofPrediction.objects.create(
proof=proof,
type=constants.PROOF_PREDICTION_OBJECT_DETECTION_TYPE,
type=proof_constants.PROOF_PREDICTION_OBJECT_DETECTION_TYPE,
model_name=PRICE_TAG_DETECTOR_MODEL_NAME,
model_version=PRICE_TAG_DETECTOR_MODEL_VERSION,
data={"objects": detections},
value=None,
max_confidence=max_confidence,
)
if proof.type == constants.TYPE_PRICE_TAG:
if proof.type == proof_constants.TYPE_PRICE_TAG:
create_price_tags_from_proof_prediction(
proof, proof_prediction, run_extraction=run_extraction
)
Expand Down Expand Up @@ -641,7 +641,7 @@ def run_and_save_proof_type_prediction(
proof_type = max(prediction, key=lambda x: x[1])[0]
return ProofPrediction.objects.create(
proof=proof,
type=constants.PROOF_PREDICTION_CLASSIFICATION_TYPE,
type=proof_constants.PROOF_PREDICTION_CLASSIFICATION_TYPE,
model_name=PROOF_CLASSIFICATION_MODEL_NAME,
model_version=PROOF_CLASSIFICATION_MODEL_VERSION,
data={
Expand Down

0 comments on commit 396aed8

Please sign in to comment.