From 28b68d3f945f3ca141ec0482cd30cafe851e7222 Mon Sep 17 00:00:00 2001 From: Raphael Odini Date: Wed, 15 Jan 2025 11:52:25 +0100 Subject: [PATCH] Simplify --- .../proofs/management/commands/run_ml_model.py | 2 +- open_prices/proofs/ml.py | 7 +------ open_prices/proofs/models.py | 2 +- open_prices/proofs/tests.py | 15 ++------------- 4 files changed, 5 insertions(+), 21 deletions(-) diff --git a/open_prices/proofs/management/commands/run_ml_model.py b/open_prices/proofs/management/commands/run_ml_model.py index a3a82bd4..c0500bc1 100644 --- a/open_prices/proofs/management/commands/run_ml_model.py +++ b/open_prices/proofs/management/commands/run_ml_model.py @@ -93,7 +93,7 @@ def handle_proof_prediction_job(self, types: list[str], limit: int) -> None: for proof in tqdm.tqdm(proofs): self.stdout.write(f"Processing proof {proof.id}...") - run_and_save_proof_prediction(proof.id) + run_and_save_proof_prediction(proof) self.stdout.write("Done.") def handle_price_tag_extraction_job(self, limit: int) -> None: diff --git a/open_prices/proofs/ml.py b/open_prices/proofs/ml.py index bd93cf54..dff8fb88 100644 --- a/open_prices/proofs/ml.py +++ b/open_prices/proofs/ml.py @@ -656,7 +656,7 @@ def run_and_save_proof_type_prediction( def run_and_save_proof_prediction( - proof_id: int, run_price_tag_extraction: bool = True + proof: Proof, run_price_tag_extraction: bool = True ) -> None: """Run all ML models on a specific proof, and save the predictions in DB. @@ -669,11 +669,6 @@ def run_and_save_proof_prediction( :param run_price_tag_extraction: whether to run the price tag extraction model on the detected price tags, defaults to True """ - proof = Proof.objects.filter(id=proof_id).first() - if not proof: - logger.error("Proof with id %s not found", proof_id) - return - file_path_full = proof.file_path_full if file_path_full is None or not Path(file_path_full).exists(): diff --git a/open_prices/proofs/models.py b/open_prices/proofs/models.py index efaa9c33..e5b88e4f 100644 --- a/open_prices/proofs/models.py +++ b/open_prices/proofs/models.py @@ -337,7 +337,7 @@ def proof_post_save_run_ml_models(sender, instance, created, **kwargs): if created: async_task( "open_prices.proofs.ml.run_and_save_proof_prediction", - instance.id, + instance, ) diff --git a/open_prices/proofs/tests.py b/open_prices/proofs/tests.py index 3771b44c..e6a9e1cd 100644 --- a/open_prices/proofs/tests.py +++ b/open_prices/proofs/tests.py @@ -378,20 +378,11 @@ def test_fetch_and_save_ocr_data_invalid_extension(self): class MLModelTest(TestCase): - def test_run_and_save_proof_prediction_proof_does_not_exist(self): - # check that we emit an error log - with self.assertLogs("open_prices.proofs.ml", level="ERROR") as cm: - self.assertIsNone(run_and_save_proof_prediction(1)) - self.assertEqual( - cm.output, - ["ERROR:open_prices.proofs.ml:Proof with id 1 not found"], - ) - def test_run_and_save_proof_prediction_proof_file_not_found(self): proof = ProofFactory() # check that we emit an error log with self.assertLogs("open_prices.proofs.ml", level="ERROR") as cm: - self.assertIsNone(run_and_save_proof_prediction(proof.id)) + self.assertIsNone(run_and_save_proof_prediction(proof)) self.assertEqual( cm.output, [ @@ -437,9 +428,7 @@ def test_run_and_save_proof_prediction_proof(self): return_value=detect_price_tags_response, ) as mock_detect_price_tags, ): - run_and_save_proof_prediction( - proof.id, run_price_tag_extraction=False - ) + run_and_save_proof_prediction(proof, run_price_tag_extraction=False) mock_predict_proof_type.assert_called_once() mock_detect_price_tags.assert_called_once()