Skip to content

Commit

Permalink
Add test for mgp-str-base [#57] (#86)
Browse files Browse the repository at this point in the history
* Add test for mgp-str-base [#57]

Multi-Granularity Prediction for Scene Text Recognition from https://huggingface.co/alibaba-damo/mgp-str-base

* Add to nightly tests
  • Loading branch information
ddilbazTT authored Dec 4, 2024
1 parent 34e925e commit 83731a2
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/workflows/run-model-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
build: [
{runs-on: n150, n300, name: "run1", test_names: "stable_diffusion"},
{runs-on: n150, n300, name: "run2", test_names: "MobileNetV2, clip, flan_t5, mlpmixer, resnet, vilt, albert, codegen, glpn_kitti, mnist, resnet50, t5, whisper, autoencoder_conv, deit, gpt2, mobilenet_ssd, roberta, timm, xglm, autoencoder_linear, detr, gpt_neo, musicgen_small, segformer, torchvision, yolos"},
{runs-on: n300, n150, name: "run3", test_names: "beit, distilbert, hand_landmark, openpose, segment_anything, unet, yolov3, bert, dpr, hardnet, opt, speecht5_tts, unet_brain, yolov5, bloom, falcon, llama, perceiver_io, squeeze_bert, unet_carvana"},
{runs-on: n300, n150, name: "run3", test_names: "beit, distilbert, hand_landmark, openpose, segment_anything, unet, yolov3, bert, dpr, hardnet, opt, speecht5_tts, unet_brain, yolov5, bloom, falcon, llama, perceiver_io, squeeze_bert, unet_carvana, mgp-str-base"},
]
runs-on:
- ${{ matrix.build.runs-on }}
Expand Down
53 changes: 53 additions & 0 deletions tests/models/mgp-str-base/test_mgp_str_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
# From: https://huggingface.co/alibaba-damo/mgp-str-base

from PIL import Image
import requests
import torch
from transformers import MgpstrProcessor, MgpstrForSceneTextRecognition
import pytest
from tests.utils import ModelTester


class ThisTester(ModelTester):
def _load_model(self):
model = MgpstrForSceneTextRecognition.from_pretrained(
"alibaba-damo/mgp-str-base", torch_dtype=torch.bfloat16
)
self.processor = MgpstrProcessor.from_pretrained(
"alibaba-damo/mgp-str-base", torch_dtype=torch.bfloat16
)
return model

def _load_inputs(self):
url = "https://i.postimg.cc/ZKwLg2Gw/367-14.png" # generated_text = "ticket"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
inputs = self.processor(
images=image,
return_tensors="pt",
)
inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
return inputs


@pytest.mark.parametrize(
"mode",
["train", "eval"],
)
def test_mgp_str_base(record_property, mode):
model_name = "alibaba-damo/mgp-str-base"
record_property("model_name", model_name)
record_property("mode", mode)

tester = ThisTester(model_name, mode)
results = tester.test_model()

if mode == "eval":
logits = results.logits
generated_text = tester.processor.batch_decode(logits)["generated_text"]
print(f"Generated text: '{generated_text}'")
assert generated_text[0] == "ticket"

record_property("torch_ttnn", (tester, results))

0 comments on commit 83731a2

Please sign in to comment.