diff --git a/open_prices/api/proofs/views.py b/open_prices/api/proofs/views.py index d6575841..94a5b98e 100644 --- a/open_prices/api/proofs/views.py +++ b/open_prices/api/proofs/views.py @@ -25,7 +25,7 @@ from open_prices.api.utils import get_source_from_request from open_prices.common.authentication import CustomAuthentication from open_prices.common.constants import PriceTagStatus -from open_prices.proofs.ml import extract_from_price_tags +from open_prices.proofs.ml import extract_from_price_tag from open_prices.proofs.models import PriceTag, Proof from open_prices.proofs.utils import store_file @@ -126,8 +126,8 @@ def upload(self, request: Request) -> Response: def process_with_gemini(self, request: Request) -> Response: files = request.FILES.getlist("files") sample_files = [PIL.Image.open(file.file) for file in files] - res = extract_from_price_tags(sample_files) - return Response(res, status=status.HTTP_200_OK) + labels = [extract_from_price_tag(sample_file) for sample_file in sample_files] + return Response({"labels": labels}, status=status.HTTP_200_OK) class PriceTagViewSet( diff --git a/open_prices/proofs/ml.py b/open_prices/proofs/ml.py index 969825e8..2ed002fe 100644 --- a/open_prices/proofs/ml.py +++ b/open_prices/proofs/ml.py @@ -74,6 +74,7 @@ class Products(enum.Enum): CUCUMBERS = "en:cucumbers" DATES = "en:dates" ENDIVES = "en:endives" + FENNEL_BULBS = "en:fennel-bulbs" FIGS = "en:figs" GARLIC = "en:garlic" GINGER = "en:ginger" @@ -157,40 +158,35 @@ class Label(typing.TypedDict): product_name: str -class Labels(typing.TypedDict): - labels: list[Label] +def extract_from_price_tag(image: Image.Image) -> Label: + """Extract price tag information from an image. - -def extract_from_price_tags(images: Image.Image) -> Labels: - """Extract price tag information from a list of images.""" + :param image: the input Pillow image + :return: the extracted information as a dictionary + """ # Gemini model max payload size is 20MB # To prevent the payload from being too large, we resize the images before # upload - resized_images = [] max_size = 1024 - for image in images: - if image.width > max_size or image.height > max_size: - resized_image = image.copy() - resized_image.thumbnail((max_size, max_size)) - resized_images.append(resized_image) - else: - resized_images.append(image) + if image.width > max_size or image.height > max_size: + image = image.copy() + image.thumbnail((max_size, max_size)) response = model.generate_content( [ ( - f"Here are {len(resized_images)} pictures containing a label. " - "For each picture of a label, please extract all the following attributes: " + "Here is one picture containing a label. " + "Please extract all the following attributes: " "the product category matching product name, the origin category matching country of origin, the price, " "is the product organic, the unit (per KILOGRAM or per UNIT) and the barcode (valid EAN-13 usually). " - f"I expect a list of {len(resized_images)} labels in your reply, no more, no less. " - "If you cannot decode an attribute, set it to an empty string" - ) - ] - + resized_images, + "I expect a single JSON in your reply, no more, no less. " + "If you cannot decode an attribute, set it to an empty string." + ), + image, + ], generation_config=genai.GenerationConfig( - response_mime_type="application/json", response_schema=Labels + response_mime_type="application/json", response_schema=Label ), ) return json.loads(response.text) @@ -285,7 +281,7 @@ def run_and_save_price_tag_extraction( logger.error("Proof file not found: %s", proof.file_path_full) return [] - cropped_images = [] + predictions = [] for price_tag in price_tags: y_min, x_min, y_max, x_max = price_tag.bounding_box image = Image.open(proof.file_path_full) @@ -296,12 +292,7 @@ def run_and_save_price_tag_extraction( y_max * image.height, ) cropped_image = image.crop((left, top, right, bottom)) - cropped_images.append(cropped_image) - - labels = extract_from_price_tags(cropped_images) - - predictions = [] - for price_tag, label in zip(price_tags, labels["labels"]): + label = extract_from_price_tag(cropped_image) prediction = PriceTagPrediction.objects.create( price_tag=price_tag, type=constants.PRICE_TAG_EXTRACTION_TYPE, @@ -351,8 +342,8 @@ def update_price_tag_extraction(price_tag_id: int) -> PriceTagPrediction: y_max * image.height, ) cropped_image = image.crop((left, top, right, bottom)) - gemini_output = extract_from_price_tags([cropped_image]) - price_tag_prediction.data = gemini_output["labels"][0] + gemini_output = extract_from_price_tag(cropped_image) + price_tag_prediction.data = gemini_output price_tag_prediction.model_name = GEMINI_MODEL_NAME price_tag_prediction.model_version = GEMINI_MODEL_VERSION price_tag_prediction.save()