diff --git a/.env b/.env
index b3e66921..d7ce1b3c 100644
--- a/.env
+++ b/.env
@@ -46,10 +46,10 @@ GUNICORN_WORKERS=1
 # It works because we added the special `host.docker.internal:host-gateway`
 # host in dev.yml for all services
 # Triton is the ML inference server used at Open Food Facts
-TRITON_URI=host.docker.internal:5004
+TRITON_URI=host.docker.internal:5504
 # By default, don't enable ML predictions, as we don't necessarily have a Triton
 # server running.
 # During local development, to enable ML predictions, set this to True and make sure
-# you have Triton running on port 5004.
+# you have Triton running on port 5504.
 ENABLE_ML_PREDICTIONS=False
diff --git a/config/settings.py b/config/settings.py
index a4ef7da1..80a46a86 100644
--- a/config/settings.py
+++ b/config/settings.py
@@ -293,5 +293,5 @@
 # Triton Inference Server (ML)
 # ------------------------------------------------------------------------------
-TRITON_URI = os.getenv("TRITON_URI", "localhost:5004")
+TRITON_URI = os.getenv("TRITON_URI", "localhost:5504")
 ENABLE_ML_PREDICTIONS = os.getenv("ENABLE_ML_PREDICTIONS") == "True"
diff --git a/open_prices/proofs/management/commands/run_ml_model.py b/open_prices/proofs/management/commands/run_ml_model.py
new file mode 100644
index 00000000..6267ee87
--- /dev/null
+++ b/open_prices/proofs/management/commands/run_ml_model.py
@@ -0,0 +1,53 @@
+import argparse
+
+import tqdm
+from django.core.management.base import BaseCommand, CommandError
+
+from open_prices.proofs.ml.image_classifier import (
+    PROOF_CLASSIFICATION_MODEL_NAME,
+    run_and_save_proof_prediction,
+)
+from open_prices.proofs.models import Proof
+
+
+class Command(BaseCommand):
+    help = """Run ML models on images without proof predictions, and save the predictions
+    in DB."""
+    _allowed_types = ["proof_classification"]
+
+    def add_arguments(self, parser: argparse.ArgumentParser) -> None:
+        parser.add_argument(
+            "--limit", type=int, help="Limit the number of proofs to process."
+        )
+        parser.add_argument("type", type=str, help="Type of model to run.", nargs="+")
+
+    def handle(self, *args, **options) -> None:  # type: ignore
+        self.stdout.write(
+            "Running ML models on images without proof predictions for this model..."
+        )
+        limit = options["limit"]
+        types = options["type"]
+
+        if not all(t in self._allowed_types for t in types):
+            raise CommandError(
+                f"Invalid type(s) provided: {types}, allowed: {self._allowed_types}"
+            )
+
+        if "proof_classification" in types:
+            # Get proofs that don't have a proof prediction with
+            # model_name = PROOF_CLASSIFICATION_MODEL_NAME by performing an
+            # outer join on the Proof and Prediction tables.
+            proofs = (
+                Proof.objects.filter(predictions__model_name__isnull=True)
+                | Proof.objects.exclude(
+                    predictions__model_name=PROOF_CLASSIFICATION_MODEL_NAME
+                )
+            ).distinct()
+
+            if limit:
+                proofs = proofs[:limit]
+
+            for proof in tqdm.tqdm(proofs):
+                self.stdout.write(f"Processing proof {proof.id}...")
+                run_and_save_proof_prediction(proof.id)
+            self.stdout.write("Done.")