From 83731a21954aecfd2469d6b857f5f65f5d41e3cc Mon Sep 17 00:00:00 2001 From: Defne Dilbaz Date: Wed, 4 Dec 2024 14:30:51 -0500 Subject: [PATCH] Add test for mgp-str-base [#57] (#86) * Add test for mgp-str-base [#57] Multi-Granularity Prediction for Scene Text Recognition from https://huggingface.co/alibaba-damo/mgp-str-base * Add to nightly tests --- .github/workflows/run-model-tests.yml | 2 +- .../models/mgp-str-base/test_mgp_str_base.py | 53 +++++++++++++++++++ 2 files changed, 54 insertions(+), 1 deletion(-) create mode 100644 tests/models/mgp-str-base/test_mgp_str_base.py diff --git a/.github/workflows/run-model-tests.yml b/.github/workflows/run-model-tests.yml index b3505378..777e150e 100644 --- a/.github/workflows/run-model-tests.yml +++ b/.github/workflows/run-model-tests.yml @@ -16,7 +16,7 @@ jobs: build: [ {runs-on: n150, n300, name: "run1", test_names: "stable_diffusion"}, {runs-on: n150, n300, name: "run2", test_names: "MobileNetV2, clip, flan_t5, mlpmixer, resnet, vilt, albert, codegen, glpn_kitti, mnist, resnet50, t5, whisper, autoencoder_conv, deit, gpt2, mobilenet_ssd, roberta, timm, xglm, autoencoder_linear, detr, gpt_neo, musicgen_small, segformer, torchvision, yolos"}, - {runs-on: n300, n150, name: "run3", test_names: "beit, distilbert, hand_landmark, openpose, segment_anything, unet, yolov3, bert, dpr, hardnet, opt, speecht5_tts, unet_brain, yolov5, bloom, falcon, llama, perceiver_io, squeeze_bert, unet_carvana"}, + {runs-on: n300, n150, name: "run3", test_names: "beit, distilbert, hand_landmark, openpose, segment_anything, unet, yolov3, bert, dpr, hardnet, opt, speecht5_tts, unet_brain, yolov5, bloom, falcon, llama, perceiver_io, squeeze_bert, unet_carvana, mgp-str-base"}, ] runs-on: - ${{ matrix.build.runs-on }} diff --git a/tests/models/mgp-str-base/test_mgp_str_base.py b/tests/models/mgp-str-base/test_mgp_str_base.py new file mode 100644 index 00000000..14e4c724 --- /dev/null +++ b/tests/models/mgp-str-base/test_mgp_str_base.py @@ -0,0 +1,53 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 +# From: https://huggingface.co/alibaba-damo/mgp-str-base + +from PIL import Image +import requests +import torch +from transformers import MgpstrProcessor, MgpstrForSceneTextRecognition +import pytest +from tests.utils import ModelTester + + +class ThisTester(ModelTester): + def _load_model(self): + model = MgpstrForSceneTextRecognition.from_pretrained( + "alibaba-damo/mgp-str-base", torch_dtype=torch.bfloat16 + ) + self.processor = MgpstrProcessor.from_pretrained( + "alibaba-damo/mgp-str-base", torch_dtype=torch.bfloat16 + ) + return model + + def _load_inputs(self): + url = "https://i.postimg.cc/ZKwLg2Gw/367-14.png" # generated_text = "ticket" + image = Image.open(requests.get(url, stream=True).raw).convert("RGB") + inputs = self.processor( + images=image, + return_tensors="pt", + ) + inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16) + return inputs + + +@pytest.mark.parametrize( + "mode", + ["train", "eval"], +) +def test_mgp_str_base(record_property, mode): + model_name = "alibaba-damo/mgp-str-base" + record_property("model_name", model_name) + record_property("mode", mode) + + tester = ThisTester(model_name, mode) + results = tester.test_model() + + if mode == "eval": + logits = results.logits + generated_text = tester.processor.batch_decode(logits)["generated_text"] + print(f"Generated text: '{generated_text}'") + assert generated_text[0] == "ticket" + + record_property("torch_ttnn", (tester, results))