From 5574431378fefc96c7edd8ee3dbe9c7bc21100f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Bournhonesque?= Date: Mon, 16 Dec 2024 18:17:20 +0100 Subject: [PATCH] feat: add `price_tag` table and routes Linked issue: #611 --- open_prices/api/prices/serializers.py | 2 +- open_prices/api/prices/tests.py | 2 +- open_prices/api/proofs/filters.py | 9 +- open_prices/api/proofs/serializers.py | 43 ++- open_prices/api/proofs/tests.py | 340 +++++++++++++++++- open_prices/api/proofs/views.py | 79 +++- open_prices/api/urls.py | 3 +- open_prices/common/constants.py | 12 + open_prices/common/tasks.py | 2 +- open_prices/prices/models.py | 4 +- open_prices/prices/tests.py | 2 +- open_prices/proofs/admin.py | 2 +- open_prices/proofs/factories.py | 21 +- .../management/commands/run_ml_model.py | 2 +- .../proofs/migrations/0007_pricetag.py | 117 ++++++ open_prices/proofs/models/__init__.py | 0 open_prices/proofs/models/price_tag.py | 150 ++++++++ .../proofs/{models.py => models/proof.py} | 3 +- open_prices/proofs/tests.py | 55 ++- open_prices/stats/models.py | 2 +- 20 files changed, 830 insertions(+), 20 deletions(-) create mode 100644 open_prices/proofs/migrations/0007_pricetag.py create mode 100644 open_prices/proofs/models/__init__.py create mode 100644 open_prices/proofs/models/price_tag.py rename open_prices/proofs/{models.py => models/proof.py} (99%) diff --git a/open_prices/api/prices/serializers.py b/open_prices/api/prices/serializers.py index ba8dea2f..aeedac4b 100644 --- a/open_prices/api/prices/serializers.py +++ b/open_prices/api/prices/serializers.py @@ -6,7 +6,7 @@ from open_prices.locations.models import Location from open_prices.prices.models import Price from open_prices.products.models import Product -from open_prices.proofs.models import Proof +from open_prices.proofs.models.proof import Proof class PriceSerializer(serializers.ModelSerializer): diff --git a/open_prices/api/prices/tests.py b/open_prices/api/prices/tests.py index 18320260..45250e55 100644 --- a/open_prices/api/prices/tests.py +++ b/open_prices/api/prices/tests.py @@ -13,7 +13,7 @@ from open_prices.products.models import Product from open_prices.proofs import constants as proof_constants from open_prices.proofs.factories import ProofFactory -from open_prices.proofs.models import Proof +from open_prices.proofs.models.proof import Proof from open_prices.users.factories import SessionFactory PRICE_8001505005707 = { diff --git a/open_prices/api/proofs/filters.py b/open_prices/api/proofs/filters.py index ff136247..e9217fed 100644 --- a/open_prices/api/proofs/filters.py +++ b/open_prices/api/proofs/filters.py @@ -1,6 +1,6 @@ import django_filters -from open_prices.proofs.models import Proof +from open_prices.proofs.models.proof import Proof class ProofFilter(django_filters.FilterSet): @@ -38,3 +38,10 @@ class Meta: "owner", "price_count", ] + + +class PriceTagFilter(django_filters.FilterSet): + status__isnull = django_filters.BooleanFilter( + field_name="status", lookup_expr="isnull" + ) + status = django_filters.CharFilter(field_name="status", lookup_expr="exact") diff --git a/open_prices/api/proofs/serializers.py b/open_prices/api/proofs/serializers.py index 197bae15..7398f976 100644 --- a/open_prices/api/proofs/serializers.py +++ b/open_prices/api/proofs/serializers.py @@ -2,7 +2,9 @@ from open_prices.api.locations.serializers import LocationSerializer from open_prices.locations.models import Location -from open_prices.proofs.models import Proof, ProofPrediction +from open_prices.prices.models import Price +from open_prices.proofs.models.price_tag import PriceTag +from open_prices.proofs.models.proof import Proof, ProofPrediction class ProofPredictionSerializer(serializers.ModelSerializer): @@ -81,3 +83,42 @@ class ProofProcessWithGeminiSerializer(serializers.Serializer): mode = ( serializers.CharField() ) # TODO: this mode param should be used to select the prompt to execute, unimplemented for now + + +class PriceTagSerializer(serializers.ModelSerializer): + price_id = serializers.PrimaryKeyRelatedField(read_only=True) + + class Meta: + model = PriceTag + exclude = ["price", "proof"] + + +class PriceTagFullSerializer(PriceTagSerializer): + proof = ProofSerializer() + + class Meta: + model = PriceTag + exclude = ["price"] + + +class PriceTagCreateSerializer(serializers.ModelSerializer): + proof_id = serializers.PrimaryKeyRelatedField( + queryset=Proof.objects.all(), source="proof" + ) + price_id = serializers.PrimaryKeyRelatedField( + queryset=Price.objects.all(), source="price", required=False + ) + + class Meta: + model = PriceTag + fields = PriceTag.CREATE_FIELDS + + +class PriceTagUpdateSerializer(serializers.ModelSerializer): + price_id = serializers.PrimaryKeyRelatedField( + queryset=Price.objects.all(), source="price" + ) + + class Meta: + model = PriceTag + fields = PriceTag.UPDATE_FIELDS diff --git a/open_prices/api/proofs/tests.py b/open_prices/api/proofs/tests.py index 5a30a6a6..d0b0657d 100644 --- a/open_prices/api/proofs/tests.py +++ b/open_prices/api/proofs/tests.py @@ -6,13 +6,19 @@ from django.urls import reverse from PIL import Image +from open_prices.common.constants import PriceTagStatus from open_prices.locations import constants as location_constants from open_prices.locations.factories import LocationFactory from open_prices.prices.factories import PriceFactory from open_prices.prices.models import Price from open_prices.proofs import constants as proof_constants -from open_prices.proofs.factories import ProofFactory, ProofPredictionFactory -from open_prices.proofs.models import Proof +from open_prices.proofs.factories import ( + PriceTagFactory, + ProofFactory, + ProofPredictionFactory, +) +from open_prices.proofs.models.price_tag import PriceTag +from open_prices.proofs.models.proof import Proof from open_prices.users.factories import SessionFactory LOCATION_OSM_NODE_652825274 = { @@ -389,3 +395,333 @@ def test_proof_delete(self): self.assertEqual( Proof.objects.filter(owner=self.user_session_1.user.user_id).count(), 0 ) + + +class PriceTagListApiTest(TestCase): + @classmethod + def setUpTestData(cls): + cls.url = reverse("api:price-tags-list") + cls.proof = ProofFactory(type=proof_constants.TYPE_PRICE_TAG) + cls.proof_2 = ProofFactory(type=proof_constants.TYPE_PRICE_TAG) + cls.price = PriceFactory(proof=cls.proof) + cls.price_tag_1 = PriceTagFactory( + proof=cls.proof, + price=cls.price, + status=PriceTagStatus.linked_to_price.value, + ) + cls.price_tag_2 = PriceTagFactory(proof=cls.proof) + cls.price_tag_3 = PriceTagFactory( + proof=cls.proof_2, status=PriceTagStatus.deleted + ) + + def test_price_tag_list(self): + # Check that we can access price tags anonymously + # We only have 2 queries: + # - 1 to count the number of price tags + # - 1 to get the price tags and their associated proof + with self.assertNumQueries(2): + response = self.client.get(self.url) + self.assertEqual(response.status_code, 200) + data = response.data + + self.assertEqual(data["total"], 3) + self.assertEqual(len(data["items"]), 3) + item = data["items"][0] + self.assertEqual(item["id"], self.price_tag_1.id) # default order: created + self.assertNotIn("price", item) # not returned in "list" + self.assertEqual(item["price_id"], self.price.id) + item_2 = data["items"][1] + self.assertEqual(item_2["id"], self.price_tag_2.id) + self.assertIsNone(item_2["price_id"]) + + def test_price_tag_list_filter_with_status(self): + url = self.url + "?status=0" # deleted + response = self.client.get(url) + self.assertEqual(response.data["total"], 1) + self.assertEqual(len(response.data["items"]), 1) + self.assertEqual(response.data["items"][0]["id"], self.price_tag_3.id) + + def test_price_tag_list_filter_with_status_is_null(self): + url = self.url + "?status__isnull=True" + response = self.client.get(url) + print(response.data) + self.assertEqual(response.data["total"], 1) + self.assertEqual(len(response.data["items"]), 1) + # Price tag 1 is linked to price, price tag 3 is deleted, so only + # price tag 2 is returned + self.assertEqual(response.data["items"][0]["id"], self.price_tag_2.id) + + +class PriceTagDetailApiTest(TestCase): + @classmethod + def setUpTestData(cls): + cls.proof = ProofFactory(type=proof_constants.TYPE_PRICE_TAG) + cls.price = PriceFactory(proof=cls.proof) + cls.price_tag_1 = PriceTagFactory(proof=cls.proof) + cls.price_tag_2 = PriceTagFactory(proof=cls.proof, price=cls.price) + cls.url = reverse("api:price-tags-detail", args=[cls.price_tag_2.id]) + + def test_price_tag_detail(self): + # Check that we can retrieve a single price tags anonymously + # We only have 1 query to get the price tags and their associated proof + with self.assertNumQueries(1): + response = self.client.get(self.url) + self.assertEqual(response.status_code, 200) + data = response.data + self.assertEqual(data["id"], self.price_tag_2.id) + self.assertIn("proof", data.keys()) + proof = data["proof"] + self.assertEquals(proof["id"], self.proof.id) + self.assertEqual(proof["type"], proof_constants.TYPE_PRICE_TAG) + self.assertNotIn("price", data) # not returned in "detail" + self.assertEqual(data["price_id"], self.price.id) + + +class PriceTagCreateApiTest(TestCase): + @classmethod + def setUpTestData(cls): + cls.url = reverse("api:price-tags-list") + cls.user_session = SessionFactory() + cls.proof = ProofFactory(type=proof_constants.TYPE_PRICE_TAG) + cls.price = PriceFactory(proof=cls.proof) + cls.default_bounding_box = [0.1, 0.2, 0.3, 0.4] + + def test_price_tag_create_unauthenticated(self): + response = self.client.post( + self.url, + data={"bounding_box": self.default_bounding_box, "proof_id": self.proof.id}, + ) + self.assertEqual(response.status_code, 403) + self.assertEqual( + response.data, {"detail": "Authentication credentials were not provided."} + ) + + def test_price_tag_create_missing_proof_id(self): + response = self.client.post( + self.url, + data={"bounding_box": self.default_bounding_box}, + headers={"Authorization": f"Bearer {self.user_session.token}"}, + ) + self.assertEqual(response.status_code, 400) + self.assertDictEqual(response.data, {"proof_id": ["This field is required."]}) + + def test_price_tag_create_proof_not_found(self): + response = self.client.post( + self.url, + data={"bounding_box": self.default_bounding_box, "proof_id": 999}, + headers={"Authorization": f"Bearer {self.user_session.token}"}, + ) + self.assertEqual(response.status_code, 400) + self.assertDictEqual( + response.data, {"proof_id": ['Invalid pk "999" - object does not exist.']} + ) + + def test_price_tag_create_price_not_found(self): + response = self.client.post( + self.url, + data={ + "bounding_box": self.default_bounding_box, + "proof_id": self.proof.id, + "price_id": 998, + }, + headers={"Authorization": f"Bearer {self.user_session.token}"}, + ) + self.assertEqual(response.status_code, 400) + self.assertDictEqual( + response.data, {"price_id": ['Invalid pk "998" - object does not exist.']} + ) + + def test_price_tag_create(self): + response = self.client.post( + self.url, + data={"bounding_box": self.default_bounding_box, "proof_id": self.proof.id}, + headers={"Authorization": f"Bearer {self.user_session.token}"}, + ) + self.assertEqual(response.status_code, 201) + self.assertEqual(response.data["created_by"], self.user_session.user.user_id) + self.assertEqual(response.data["updated_by"], self.user_session.user.user_id) + self.assertEqual(response.data["status"], None) + self.assertEqual(response.data["bounding_box"], self.default_bounding_box) + self.assertEqual(response.data["price_id"], None) + + def test_price_tag_create_with_price(self): + response = self.client.post( + self.url, + data={ + "bounding_box": self.default_bounding_box, + "proof_id": self.proof.id, + "price_id": self.price.id, + }, + headers={"Authorization": f"Bearer {self.user_session.token}"}, + ) + self.assertEqual(response.status_code, 201) + self.assertEqual(response.data["created_by"], self.user_session.user.user_id) + self.assertEqual(response.data["updated_by"], self.user_session.user.user_id) + self.assertEqual(response.data["price_id"], self.price.id) + self.assertEqual(response.data["status"], PriceTagStatus.linked_to_price.value) + + +class PriceTagUpdateApiTest(TestCase): + @classmethod + def setUpTestData(cls): + cls.user_session = SessionFactory() + cls.proof = ProofFactory(type=proof_constants.TYPE_PRICE_TAG) + cls.proof_2 = ProofFactory(type=proof_constants.TYPE_PRICE_TAG) + cls.price = PriceFactory(proof=cls.proof) + cls.price_2 = PriceFactory(proof=cls.proof_2) + cls.price_tag = PriceTagFactory( + proof=cls.proof, model_version="object-detector" + ) + cls.url = reverse("api:price-tags-detail", args=[cls.price_tag.id]) + cls.new_bounding_box = [0.2, 0.3, 0.4, 0.5] + + def test_price_tag_create_unauthenticated(self): + response = self.client.patch( + self.url, data={"bounding_box": self.new_bounding_box} + ) + self.assertEqual(response.status_code, 403) + self.assertEqual( + response.data, {"detail": "Authentication credentials were not provided."} + ) + + def test_price_tag_create_update_read_only_fields(self): + self.assertEqual(self.price_tag.model_version, "object-detector") + self.assertNotEqual(self.price_tag.bounding_box, self.new_bounding_box) + response = self.client.patch( + self.url, + content_type="application/json", + data={ + "bounding_box": self.new_bounding_box, + "proof_id": self.proof_2.id, + "model_version": "test", + }, + headers={"Authorization": f"Bearer {self.user_session.token}"}, + ) + self.assertEqual(response.status_code, 200) + # Proof ID didn't change + self.assertEqual(response.data["proof"]["id"], self.proof.id) + # Model version didn't change + self.assertEqual(response.data["model_version"], "object-detector") + # New bounding box was set + self.assertEqual(response.data["bounding_box"], self.new_bounding_box) + + def test_price_tag_set_price_id(self): + self.assertEqual(self.price_tag.price_id, None) + self.assertEqual(self.price_tag.status, None) + response = self.client.patch( + self.url, + content_type="application/json", + data={"price_id": self.price.id}, + headers={"Authorization": f"Bearer {self.user_session.token}"}, + ) + self.assertEqual(response.status_code, 200) + # Price ID was set to the new value + self.assertEqual(response.data["price_id"], self.price.id) + # Status was automatically set to linked_to_price + self.assertEqual(response.data["status"], PriceTagStatus.linked_to_price.value) + + def test_price_tag_set_invalid_price_id(self): + self.assertEqual(self.price_tag.price_id, None) + response = self.client.patch( + self.url, + content_type="application/json", + # Price associated with another proof + data={"price_id": self.price_2.id}, + headers={"Authorization": f"Bearer {self.user_session.token}"}, + ) + self.assertEqual(response.status_code, 400) + self.assertEqual( + response.data, {"price": ["Price should belong to the same proof."]} + ) + + def test_price_tag_set_status(self): + self.assertEqual(self.price_tag.status, None) + response = self.client.patch( + self.url, + content_type="application/json", + # Price associated with another proof + data={"status": PriceTagStatus.not_readable.value}, + headers={"Authorization": f"Bearer {self.user_session.token}"}, + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data["status"], PriceTagStatus.not_readable.value) + + def test_price_tag_invalid_status(self): + self.assertEqual(self.price_tag.status, None) + response = self.client.patch( + self.url, + content_type="application/json", + # Invalid status value + data={"status": 999}, + headers={"Authorization": f"Bearer {self.user_session.token}"}, + ) + self.assertEqual(response.status_code, 400) + self.assertEqual(response.data, {"status": ['"999" is not a valid choice.']}) + + def test_price_tag_set_new_bounding_box(self): + self.assertNotEqual(self.price_tag.bounding_box, self.new_bounding_box) + response = self.client.patch( + self.url, + content_type="application/json", + data={"bounding_box": self.new_bounding_box}, + headers={"Authorization": f"Bearer {self.user_session.token}"}, + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data["bounding_box"], self.new_bounding_box) + + def test_price_tag_invalid_bounding_box(self): + response = self.client.patch( + self.url, + content_type="application/json", + data={"bounding_box": [0.1, 0.2, 0.3]}, # only 3 values + headers={"Authorization": f"Bearer {self.user_session.token}"}, + ) + self.assertEqual(response.status_code, 400) + self.assertEqual( + response.data, {"bounding_box": ["Bounding box should have 4 values."]} + ) + + +class PriceTagDeleteApiTest(TestCase): + @classmethod + def setUpTestData(cls): + cls.user_session = SessionFactory() + cls.proof = ProofFactory(type=proof_constants.TYPE_PRICE_TAG) + cls.price = PriceFactory(proof=cls.proof) + cls.price_tag = PriceTagFactory(proof=cls.proof) + cls.price_tag_with_associated_price = PriceTagFactory( + proof=cls.proof, price=cls.price + ) + cls.url = reverse("api:price-tags-detail", args=[cls.price_tag.id]) + cls.url_with_associated_price = reverse( + "api:price-tags-detail", args=[cls.price_tag_with_associated_price.id] + ) + + def test_price_tag_delete_unauthenticated(self): + response = self.client.delete(self.url) + self.assertEqual(response.status_code, 403) + self.assertEqual( + response.data, {"detail": "Authentication credentials were not provided."} + ) + + def test_price_tag_delete_with_associated_price(self): + response = self.client.delete( + self.url_with_associated_price, + headers={"Authorization": f"Bearer {self.user_session.token}"}, + ) + self.assertEqual(response.status_code, 403) + self.assertEqual( + response.data, {"detail": "Cannot delete price tag with associated prices."} + ) + + def test_price_tag_delete(self): + response = self.client.delete( + self.url, + headers={"Authorization": f"Bearer {self.user_session.token}"}, + ) + self.assertEqual(response.status_code, 204) + self.assertEqual(response.data, None) + self.assertEqual(PriceTag.objects.filter(id=self.price_tag.id).count(), 1) + self.assertEqual( + PriceTag.objects.get(id=self.price_tag.id).status, PriceTagStatus.deleted + ) diff --git a/open_prices/api/proofs/views.py b/open_prices/api/proofs/views.py index 6d4da5e8..c4df9458 100644 --- a/open_prices/api/proofs/views.py +++ b/open_prices/api/proofs/views.py @@ -9,8 +9,11 @@ from rest_framework.request import Request from rest_framework.response import Response -from open_prices.api.proofs.filters import ProofFilter +from open_prices.api.proofs.filters import PriceTagFilter, ProofFilter from open_prices.api.proofs.serializers import ( + PriceTagCreateSerializer, + PriceTagFullSerializer, + PriceTagUpdateSerializer, ProofCreateSerializer, ProofFullSerializer, ProofHalfFullSerializer, @@ -20,8 +23,10 @@ ) from open_prices.api.utils import get_source_from_request from open_prices.common.authentication import CustomAuthentication +from open_prices.common.constants import PriceTagStatus from open_prices.common.gemini import handle_bulk_labels -from open_prices.proofs.models import Proof +from open_prices.proofs.models.price_tag import PriceTag +from open_prices.proofs.models.proof import Proof from open_prices.proofs.utils import store_file @@ -123,3 +128,73 @@ def process_with_gemini(self, request: Request) -> Response: sample_files = [PIL.Image.open(file.file) for file in files] res = handle_bulk_labels(sample_files) return Response(res, status=status.HTTP_200_OK) + + +class PriceTagViewSet( + mixins.ListModelMixin, + mixins.RetrieveModelMixin, + mixins.UpdateModelMixin, + mixins.DestroyModelMixin, + viewsets.GenericViewSet, +): + authentication_classes = [CustomAuthentication] + permission_classes = [IsAuthenticatedOrReadOnly] + http_method_names = ["get", "post", "patch", "delete"] # disable "put" + queryset = PriceTag.objects.select_related("proof").all() + serializer_class = PriceTagFullSerializer + filter_backends = [DjangoFilterBackend, filters.OrderingFilter] + filterset_class = PriceTagFilter + ordering_fields = ["created"] + ordering = ["created"] + + def get_queryset(self): + if self.action in ("create", "update"): + # We need to prefetch the price object if it exists to validate the + # price_id field, and the proof object to validate the proof_id + # field + return ( + PriceTag.objects.select_related("proof").select_related("price").all() + ) + return super().get_queryset() + + def get_serializer_class(self): + if self.request.method == "POST": + return PriceTagCreateSerializer + elif self.request.method == "PATCH": + return PriceTagUpdateSerializer + return self.serializer_class + + def destroy(self, request: Request, *args, **kwargs) -> Response: + price_tag = self.get_object() + if price_tag.price_id is not None: + return Response( + {"detail": "Cannot delete price tag with associated prices."}, + status=status.HTTP_403_FORBIDDEN, + ) + price_tag.status = PriceTagStatus.deleted + price_tag.save() + return Response(status=status.HTTP_204_NO_CONTENT) + + def create(self, request: Request, *args, **kwargs): + # validate + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + # save + + user_id = self.request.user.user_id + price = serializer.save(updated_by=user_id, created_by=user_id) + # return full price + return Response( + self.serializer_class(price).data, status=status.HTTP_201_CREATED + ) + + def update(self, request: Request, *args, **kwargs): + # validate + serializer = self.get_serializer( + self.get_object(), data=request.data, partial=True + ) + serializer.is_valid(raise_exception=True) + # save + price = serializer.save(updated_by=self.request.user.user_id) + # return full price + return Response(self.serializer_class(price).data) diff --git a/open_prices/api/urls.py b/open_prices/api/urls.py index b9224a0f..511f3967 100644 --- a/open_prices/api/urls.py +++ b/open_prices/api/urls.py @@ -10,7 +10,7 @@ from open_prices.api.locations.views import LocationViewSet from open_prices.api.prices.views import PriceViewSet from open_prices.api.products.views import ProductViewSet -from open_prices.api.proofs.views import ProofViewSet +from open_prices.api.proofs.views import PriceTagViewSet, ProofViewSet from open_prices.api.stats.views import StatsView from open_prices.api.users.views import UserViewSet from open_prices.api.views import StatusView @@ -23,6 +23,7 @@ router.register(r"v1/products", ProductViewSet, basename="products") router.register(r"v1/proofs", ProofViewSet, basename="proofs") router.register(r"v1/prices", PriceViewSet, basename="prices") +router.register(r"v1/price_tags", PriceTagViewSet, basename="price-tags") urlpatterns = [ # auth urls diff --git a/open_prices/common/constants.py b/open_prices/common/constants.py index 87326065..29c3f1c9 100644 --- a/open_prices/common/constants.py +++ b/open_prices/common/constants.py @@ -1,4 +1,16 @@ +import enum + from babel.numbers import list_currencies CURRENCY_LIST = sorted(list_currencies()) CURRENCY_CHOICES = [(key, key) for key in CURRENCY_LIST] + + +class PriceTagStatus(enum.IntEnum): + deleted = 0 + linked_to_price = 1 + not_readable = 2 + not_price_tag = 3 + + +PRICE_TAG_STATUS_CHOICES = [(item.value, item.name) for item in PriceTagStatus] diff --git a/open_prices/common/tasks.py b/open_prices/common/tasks.py index 907705ad..bdeefd2e 100644 --- a/open_prices/common/tasks.py +++ b/open_prices/common/tasks.py @@ -14,7 +14,7 @@ from open_prices.locations.models import Location from open_prices.prices.models import Price from open_prices.products.models import Product -from open_prices.proofs.models import Proof +from open_prices.proofs.models.proof import Proof from open_prices.stats.models import TotalStats from open_prices.users.models import User diff --git a/open_prices/prices/models.py b/open_prices/prices/models.py index be7ba285..cd20d479 100644 --- a/open_prices/prices/models.py +++ b/open_prices/prices/models.py @@ -19,7 +19,7 @@ from open_prices.prices import constants as price_constants from open_prices.products.models import Product from open_prices.proofs import constants as proof_constants -from open_prices.proofs.models import Proof +from open_prices.proofs.models.proof import Proof from open_prices.users.models import User # Taxonomy mapping generation takes ~200ms, so we cache it to avoid @@ -451,7 +451,7 @@ def clean(self, *args, **kwargs): # - receipt_quantity can only be set for receipts (default to 1) if self.proof_id: proof = None - from open_prices.proofs.models import Proof + from open_prices.proofs.models.proof import Proof try: proof = Proof.objects.get(id=self.proof_id) diff --git a/open_prices/prices/tests.py b/open_prices/prices/tests.py index b4d08d81..aad96169 100644 --- a/open_prices/prices/tests.py +++ b/open_prices/prices/tests.py @@ -13,7 +13,7 @@ from open_prices.products.models import Product from open_prices.proofs import constants as proof_constants from open_prices.proofs.factories import ProofFactory -from open_prices.proofs.models import Proof +from open_prices.proofs.models.proof import Proof from open_prices.users.factories import SessionFactory from open_prices.users.models import User diff --git a/open_prices/proofs/admin.py b/open_prices/proofs/admin.py index 10e7781d..c92f3a18 100644 --- a/open_prices/proofs/admin.py +++ b/open_prices/proofs/admin.py @@ -2,7 +2,7 @@ from django.urls import reverse from django.utils.html import format_html -from open_prices.proofs.models import Proof +from open_prices.proofs.models.proof import Proof @admin.register(Proof) diff --git a/open_prices/proofs/factories.py b/open_prices/proofs/factories.py index 10f71a20..2f9e4b04 100644 --- a/open_prices/proofs/factories.py +++ b/open_prices/proofs/factories.py @@ -5,7 +5,8 @@ from factory.django import DjangoModelFactory from open_prices.proofs import constants as proof_constants -from open_prices.proofs.models import Proof, ProofPrediction +from open_prices.proofs.models.price_tag import PriceTag +from open_prices.proofs.models.proof import Proof, ProofPrediction class ProofFactory(DjangoModelFactory): @@ -44,3 +45,21 @@ class Meta: } value = "SHELF" max_confidence = 0.98 + + +class PriceTagFactory(DjangoModelFactory): + class Meta: + model = PriceTag + + proof = factory.SubFactory(ProofFactory) + model_version = "price_tag_detection-1.0" + created = factory.LazyFunction( + lambda: datetime.datetime.now(tz=datetime.timezone.utc) + ) + updated = factory.LazyFunction( + lambda: datetime.datetime.now(tz=datetime.timezone.utc) + ) + bounding_box = [0.1, 0.2, 0.3, 0.4] + status = None + created_by = None + updated_by = None diff --git a/open_prices/proofs/management/commands/run_ml_model.py b/open_prices/proofs/management/commands/run_ml_model.py index 533049fe..7e09a66f 100644 --- a/open_prices/proofs/management/commands/run_ml_model.py +++ b/open_prices/proofs/management/commands/run_ml_model.py @@ -10,7 +10,7 @@ PROOF_CLASSIFICATION_MODEL_NAME, run_and_save_proof_prediction, ) -from open_prices.proofs.models import Proof +from open_prices.proofs.models.proof import Proof # Initializing root logger get_logger() diff --git a/open_prices/proofs/migrations/0007_pricetag.py b/open_prices/proofs/migrations/0007_pricetag.py new file mode 100644 index 00000000..5da2bd04 --- /dev/null +++ b/open_prices/proofs/migrations/0007_pricetag.py @@ -0,0 +1,117 @@ +# Generated by Django 5.1.4 on 2024-12-16 14:04 + +import django.contrib.postgres.fields +import django.db.models.deletion +import django.utils.timezone +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("prices", "0004_price_type"), + ("proofs", "0006_add_proof_prediction_table"), + ] + + operations = [ + migrations.CreateModel( + name="PriceTag", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ( + "created", + models.DateTimeField( + default=django.utils.timezone.now, + verbose_name="When the tag was created in DB", + ), + ), + ( + "updated", + models.DateTimeField( + auto_now=True, verbose_name="When the tag was last updated" + ), + ), + ( + "bounding_box", + django.contrib.postgres.fields.ArrayField( + base_field=models.FloatField(), + size=None, + verbose_name="Coordinates of the bounding box, in the format [y_min, x_min, y_max, x_max]", + ), + ), + ( + "status", + models.IntegerField( + blank=True, + choices=[ + (0, "deleted"), + (1, "linked_to_price"), + (2, "not_readable"), + (3, "not_price_tag"), + ], + null=True, + verbose_name="The annotation status. Possible values are: - null: not annotated yet- 0 (the price tag was deleted by a user)- 1 (the price tag is linked to a price)- 2 (the price tag barcode or price cannot be read)- 3 (the object is not a price tag)", + ), + ), + ( + "model_version", + models.CharField( + max_length=30, + verbose_name="The version of the object detector model that generated the prediction", + null=True, + blank=True, + ), + ), + ( + "created_by", + models.CharField( + blank=True, + max_length=100, + null=True, + verbose_name="The name of the user who created this price tag. This field is null if the tag was created by a model.", + ), + ), + ( + "updated_by", + models.CharField( + blank=True, + max_length=100, + null=True, + verbose_name="The name of the user who last updated this price tag bounding boxes. If the price tag bounding boxes were never updated, this field is null.", + ), + ), + ( + "price", + models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="price_tags", + to="prices.price", + verbose_name="The price linked to this tag", + ), + ), + ( + "proof", + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="price_tags", + to="proofs.proof", + verbose_name="The proof this price tag belongs to", + ), + ), + ], + options={ + "verbose_name": "Price Tag", + "verbose_name_plural": "Price Tags", + "db_table": "price_tags", + }, + ), + ] diff --git a/open_prices/proofs/models/__init__.py b/open_prices/proofs/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/open_prices/proofs/models/price_tag.py b/open_prices/proofs/models/price_tag.py new file mode 100644 index 00000000..6fce720a --- /dev/null +++ b/open_prices/proofs/models/price_tag.py @@ -0,0 +1,150 @@ +from django.contrib.postgres.fields import ArrayField +from django.core.exceptions import ValidationError +from django.db import models +from django.utils import timezone + +from open_prices.common import constants, utils +from open_prices.prices.models import Price +from open_prices.proofs import constants as proof_constants + +from .proof import Proof + + +class PriceTag(models.Model): + """A single price tag in a proof.""" + + UPDATE_FIELDS = ["bounding_box", "status", "price_id"] + CREATE_FIELDS = UPDATE_FIELDS + ["proof_id"] + + proof = models.ForeignKey( + Proof, + on_delete=models.CASCADE, + related_name="price_tags", + verbose_name="The proof this price tag belongs to", + ) + price = models.ForeignKey( + Price, + on_delete=models.SET_NULL, + related_name="price_tags", + null=True, + blank=True, + verbose_name="The price linked to this tag", + ) + created = models.DateTimeField( + default=timezone.now, verbose_name="When the tag was created in DB" + ) + updated = models.DateTimeField( + auto_now=True, verbose_name="When the tag was last updated" + ) + bounding_box = ArrayField( + base_field=models.FloatField(), + verbose_name="Coordinates of the bounding box, in the format [y_min, x_min, y_max, x_max]", + ) + status = models.IntegerField( + choices=constants.PRICE_TAG_STATUS_CHOICES, + null=True, + blank=True, + verbose_name="The annotation status. Possible values are: " + "- null: not annotated yet" + "- 0 (the price tag was deleted by a user)" + "- 1 (the price tag is linked to a price)" + "- 2 (the price tag barcode or price cannot be read)" + "- 3 (the object is not a price tag)", + ) + model_version = models.CharField( + max_length=30, + verbose_name="The version of the object detector model that generated the prediction", + blank=True, + null=True, + ) + created_by = models.CharField( + max_length=100, + verbose_name="The name of the user who created this price tag. This field is null if " + "the tag was created by a model.", + null=True, + blank=True, + ) + updated_by = models.CharField( + max_length=100, + verbose_name="The name of the user who last updated this price tag bounding boxes. " + "If the price tag bounding boxes were never updated, this field is null.", + null=True, + blank=True, + ) + + class Meta: + db_table = "price_tags" + verbose_name = "Price Tag" + verbose_name_plural = "Price Tags" + + def __str__(self): + return f"{self.proof} - {self.status}" + + def clean(self, *args, **kwargs): + validation_errors = dict() + if self.bounding_box is not None: + if len(self.bounding_box) != 4: + utils.add_validation_error( + validation_errors, + "bounding_box", + "Bounding box should have 4 values.", + ) + else: + if not all(isinstance(value, float) for value in self.bounding_box): + utils.add_validation_error( + validation_errors, + "bounding_box", + "Bounding box values should be floats.", + ) + elif not all(value >= 0 and value <= 1 for value in self.bounding_box): + utils.add_validation_error( + validation_errors, + "bounding_box", + "Bounding box values should be between 0 and 1.", + ) + else: + y_min, x_min, y_max, x_max = self.bounding_box + if y_min >= y_max or x_min >= x_max: + utils.add_validation_error( + validation_errors, + "bounding_box", + "Bounding box values should be in the format [y_min, x_min, y_max, x_max].", + ) + + # self.proof and self.price is fetched with select_related in the view + # when the action is "create" or "update" + # We therefore only check the validity of the relationship if the user + # tries to update the price tag + if self.proof: + if self.proof.type != proof_constants.TYPE_PRICE_TAG: + utils.add_validation_error( + validation_errors, + "proof", + "Proof should have type PRICE_TAG.", + ) + + if self.price: + if self.proof and self.price.proof_id != self.proof.id: + utils.add_validation_error( + validation_errors, + "price", + "Price should belong to the same proof.", + ) + + if self.status is None: + self.status = constants.PriceTagStatus.linked_to_price.value + elif self.status != constants.PriceTagStatus.linked_to_price.value: + utils.add_validation_error( + validation_errors, + "status", + "Status should be `linked_to_price` when price_id is set.", + ) + + if bool(validation_errors): + raise ValidationError(validation_errors) + + super().clean(*args, **kwargs) + + def save(self, *args, **kwargs): + self.full_clean() + super().save(*args, **kwargs) diff --git a/open_prices/proofs/models.py b/open_prices/proofs/models/proof.py similarity index 99% rename from open_prices/proofs/models.py rename to open_prices/proofs/models/proof.py index 6a3be976..25d93ee8 100644 --- a/open_prices/proofs/models.py +++ b/open_prices/proofs/models/proof.py @@ -1,7 +1,8 @@ import decimal from django.conf import settings -from django.core.validators import MinValueValidator, ValidationError +from django.core.exceptions import ValidationError +from django.core.validators import MinValueValidator from django.db import models from django.db.models import Count, signals from django.dispatch import receiver diff --git a/open_prices/proofs/tests.py b/open_prices/proofs/tests.py index d822333f..c72fc45a 100644 --- a/open_prices/proofs/tests.py +++ b/open_prices/proofs/tests.py @@ -14,7 +14,11 @@ from open_prices.locations.factories import LocationFactory from open_prices.prices.factories import PriceFactory from open_prices.proofs import constants as proof_constants -from open_prices.proofs.factories import ProofFactory, ProofPredictionFactory +from open_prices.proofs.factories import ( + PriceTagFactory, + ProofFactory, + ProofPredictionFactory, +) from open_prices.proofs.ml import ( PRICE_TAG_DETECTOR_MODEL_NAME, PRICE_TAG_DETECTOR_MODEL_VERSION, @@ -25,7 +29,7 @@ run_and_save_proof_prediction, run_and_save_proof_type_prediction, ) -from open_prices.proofs.models import Proof +from open_prices.proofs.models.proof import Proof from open_prices.proofs.utils import fetch_and_save_ocr_data, select_proof_image_dir LOCATION_OSM_NODE_652825274 = { @@ -539,3 +543,50 @@ def test_select_proof_image_dir_existing_dir_create_new_dir(self): (images_dir / "0001" / "0001.jpg").touch() selected_dir = select_proof_image_dir(images_dir, max_images_per_dir=1) self.assertEqual(selected_dir, images_dir / "0002") + + +class TestPriceTagCreation(TestCase): + def test_create_price_tag_invalid_bounding_box_length(self): + with self.assertRaises(ValidationError) as cm: + PriceTagFactory(bounding_box=[0.1, 0.2]) + self.assertEqual( + str(cm.exception), + "{'bounding_box': ['Bounding box should have 4 values.']}", + ) + + def test_create_price_tag_invalid_bounding_box_value(self): + with self.assertRaises(ValidationError) as cm: + PriceTagFactory(bounding_box=None) + self.assertEqual( + str(cm.exception), + "{'bounding_box': ['This field cannot be null.']}", + ) + + with self.assertRaises(ValidationError) as cm: + PriceTagFactory(bounding_box=["st", 0.2, 0.3, 0.4]) + self.assertEqual( + str(cm.exception), + "{'bounding_box': ['Bounding box values should be floats.']}", + ) + + with self.assertRaises(ValidationError) as cm: + PriceTagFactory(bounding_box=[0.1, 1.2, 0.3, 0.4]) + self.assertEqual( + str(cm.exception), + "{'bounding_box': ['Bounding box values should be between 0 and 1.']}", + ) + + with self.assertRaises(ValidationError) as cm: + PriceTagFactory(bounding_box=[0.5, 0.1, 0.4, 0.4]) + self.assertEqual( + str(cm.exception), + "{'bounding_box': ['Bounding box values should be in the format [y_min, x_min, y_max, x_max].']}", + ) + + def test_create_price_tag_invalid_proof_type(self): + with self.assertRaises(ValidationError) as cm: + PriceTagFactory(bounding_box=None, proof__type=proof_constants.TYPE_RECEIPT) + self.assertEqual( + str(cm.exception), + "{'bounding_box': ['Proof should have type PRICE_TAG.']}", + ) diff --git a/open_prices/stats/models.py b/open_prices/stats/models.py index a6391323..56b4efca 100644 --- a/open_prices/stats/models.py +++ b/open_prices/stats/models.py @@ -88,7 +88,7 @@ def update_location_stats(self): self.save(update_fields=self.LOCATION_COUNT_FIELDS + ["updated"]) def update_proof_stats(self): - from open_prices.proofs.models import Proof + from open_prices.proofs.models.proof import Proof self.proof_count = Proof.objects.count() self.proof_with_price_count = Proof.objects.has_prices().count()