Skip to content

Commit

Permalink
chore: Add unit test for proof type detection (#597)
Browse files Browse the repository at this point in the history
  • Loading branch information
raphael0202 authored Dec 5, 2024
1 parent 8d049aa commit 0386cff
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/workflows/container-deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ jobs:
echo "GOOGLE_CLOUD_VISION_API_KEY=${{ secrets.GOOGLE_CLOUD_VISION_API_KEY }}" >> .env
echo "GOOGLE_GEMINI_API_KEY=${{ secrets.GOOGLE_GEMINI_API_KEY }}" >> .env
echo "TRITON_URI=${{ env.TRITON_URI }}" >> .env
ECHO "ENABLE_ML_PREDICTIONS=True" >> .env
echo "ENABLE_ML_PREDICTIONS=True" >> .env
- name: Create Docker volumes
uses: appleboy/ssh-action@master
Expand Down
63 changes: 63 additions & 0 deletions open_prices/proofs/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@

from django.core.exceptions import ValidationError
from django.test import TestCase
from PIL import Image

from open_prices.locations import constants as location_constants
from open_prices.locations.factories import LocationFactory
from open_prices.prices.factories import PriceFactory
from open_prices.proofs import constants as proof_constants
from open_prices.proofs.factories import ProofFactory
from open_prices.proofs.ml.image_classifier import run_and_save_proof_prediction
from open_prices.proofs.models import Proof
from open_prices.proofs.utils import fetch_and_save_ocr_data

Expand Down Expand Up @@ -349,3 +351,64 @@ def test_fetch_and_save_ocr_data_invalid_extension(self):
f.write("test")
output = fetch_and_save_ocr_data(image_path)
self.assertFalse(output)


class ImageClassifierTest(TestCase):
def test_run_and_save_proof_prediction_proof_does_not_exist(self):
self.assertIsNone(run_and_save_proof_prediction(1))

def test_run_and_save_proof_prediction_proof_file_not_found(self):
proof = ProofFactory()
self.assertIsNone(run_and_save_proof_prediction(proof.id))

def test_run_and_save_proof_prediction_proof(self):
# Create a white blank image with Pillow
image = Image.new("RGB", (100, 100), "white")
predict_proof_type_response = [
("SHELF", 0.9786477088928223),
("PRICE_TAG", 0.021345501765608788),
]

# We save the image to a temporary file
with tempfile.TemporaryDirectory() as tmpdirname:
NEW_IMAGE_DIR = Path(tmpdirname)
file_path = NEW_IMAGE_DIR / "1.jpg"
image.save(file_path)

# change temporarily settings.IMAGE_DIR
with self.settings(IMAGE_DIR=NEW_IMAGE_DIR):
proof = ProofFactory(file_path=file_path)

# Patch predict_proof_type to return a fixed response
with unittest.mock.patch(
"open_prices.proofs.ml.image_classifier.predict_proof_type",
return_value=predict_proof_type_response,
) as mock_predict_proof_type:
run_and_save_proof_prediction(proof.id)
mock_predict_proof_type.assert_called_once()
proof_prediction = proof.predictions.first()
self.assertIsNotNone(proof_prediction)
self.assertEqual(
proof_prediction.type,
proof_constants.PROOF_PREDICTION_CLASSIFICATION_TYPE,
)

self.assertEqual(
proof_prediction.model_name, "price_proof_classification"
)
self.assertEqual(
proof_prediction.model_version, "price_proof_classification-1.0"
)
self.assertEqual(proof_prediction.value, "SHELF")
self.assertEqual(proof_prediction.max_confidence, 0.9786477088928223)
self.assertEqual(
proof_prediction.data,
{
"prediction": [
{"label": "SHELF", "score": 0.9786477088928223},
{"label": "PRICE_TAG", "score": 0.021345501765608788},
]
},
)
proof_prediction.delete()
proof.delete()

0 comments on commit 0386cff

Please sign in to comment.