Skip to content

Commit

Permalink
Simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
raphodn committed Jan 15, 2025
1 parent 396aed8 commit 28b68d3
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 21 deletions.
2 changes: 1 addition & 1 deletion open_prices/proofs/management/commands/run_ml_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def handle_proof_prediction_job(self, types: list[str], limit: int) -> None:

for proof in tqdm.tqdm(proofs):
self.stdout.write(f"Processing proof {proof.id}...")
run_and_save_proof_prediction(proof.id)
run_and_save_proof_prediction(proof)
self.stdout.write("Done.")

def handle_price_tag_extraction_job(self, limit: int) -> None:
Expand Down
7 changes: 1 addition & 6 deletions open_prices/proofs/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,7 +656,7 @@ def run_and_save_proof_type_prediction(


def run_and_save_proof_prediction(
proof_id: int, run_price_tag_extraction: bool = True
proof: Proof, run_price_tag_extraction: bool = True
) -> None:
"""Run all ML models on a specific proof, and save the predictions in DB.
Expand All @@ -669,11 +669,6 @@ def run_and_save_proof_prediction(
:param run_price_tag_extraction: whether to run the price tag extraction
model on the detected price tags, defaults to True
"""
proof = Proof.objects.filter(id=proof_id).first()
if not proof:
logger.error("Proof with id %s not found", proof_id)
return

file_path_full = proof.file_path_full

if file_path_full is None or not Path(file_path_full).exists():
Expand Down
2 changes: 1 addition & 1 deletion open_prices/proofs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def proof_post_save_run_ml_models(sender, instance, created, **kwargs):
if created:
async_task(
"open_prices.proofs.ml.run_and_save_proof_prediction",
instance.id,
instance,
)


Expand Down
15 changes: 2 additions & 13 deletions open_prices/proofs/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,20 +378,11 @@ def test_fetch_and_save_ocr_data_invalid_extension(self):


class MLModelTest(TestCase):
def test_run_and_save_proof_prediction_proof_does_not_exist(self):
# check that we emit an error log
with self.assertLogs("open_prices.proofs.ml", level="ERROR") as cm:
self.assertIsNone(run_and_save_proof_prediction(1))
self.assertEqual(
cm.output,
["ERROR:open_prices.proofs.ml:Proof with id 1 not found"],
)

def test_run_and_save_proof_prediction_proof_file_not_found(self):
proof = ProofFactory()
# check that we emit an error log
with self.assertLogs("open_prices.proofs.ml", level="ERROR") as cm:
self.assertIsNone(run_and_save_proof_prediction(proof.id))
self.assertIsNone(run_and_save_proof_prediction(proof))
self.assertEqual(
cm.output,
[
Expand Down Expand Up @@ -437,9 +428,7 @@ def test_run_and_save_proof_prediction_proof(self):
return_value=detect_price_tags_response,
) as mock_detect_price_tags,
):
run_and_save_proof_prediction(
proof.id, run_price_tag_extraction=False
)
run_and_save_proof_prediction(proof, run_price_tag_extraction=False)
mock_predict_proof_type.assert_called_once()
mock_detect_price_tags.assert_called_once()

Expand Down

0 comments on commit 28b68d3

Please sign in to comment.