From 15fb19494d38b9e0108d35642e064228a433d8f5 Mon Sep 17 00:00:00 2001 From: ddilbaz Date: Mon, 2 Dec 2024 22:19:51 +0000 Subject: [PATCH] Add mamba tests [#75] Add tests for mamba and mamba2. Tests are marked xfail because they yield 'AssertionError: Attempt to trace forbidden callable ' error. However, the tests generate a graph. --- .github/workflows/run-model-tests.yml | 2 +- tests/models/mamba/test_mamba.py | 59 +++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 1 deletion(-) create mode 100644 tests/models/mamba/test_mamba.py diff --git a/.github/workflows/run-model-tests.yml b/.github/workflows/run-model-tests.yml index ecb17405..e944cc93 100644 --- a/.github/workflows/run-model-tests.yml +++ b/.github/workflows/run-model-tests.yml @@ -15,7 +15,7 @@ jobs: matrix: build: [ {runs-on: n150, name: "run1", test_names: "stable_diffusion, Qwen, MobileNetV2, clip, flan_t5, mlpmixer, resnet, vilt, albert, codegen, glpn_kitti, mnist, resnet50, RMBG, unet_carvana, mgp-str-base, musicgen_small, segformer, torchvision, yolos"}, - {runs-on: n150, name: "run2", test_names: "t5, whisper, autoencoder_conv, deit, gpt2, mobilenet_ssd, roberta, timm, xglm, autoencoder_linear, detr, beit, distilbert, hand_landmark, openpose, segment_anything, unet, yolov3, bert, dpr, hardnet, opt, speecht5_tts, unet_brain, yolov5, bloom, falcon, llama, perceiver_io, squeeze_bert, gpt_neo"}, + {runs-on: n150, name: "run2", test_names: "t5, whisper, autoencoder_conv, deit, gpt2, mobilenet_ssd, roberta, timm, xglm, autoencoder_linear, detr, beit, distilbert, hand_landmark, openpose, segment_anything, unet, yolov3, bert, dpr, hardnet, opt, speecht5_tts, unet_brain, yolov5, bloom, falcon, llama, perceiver_io, squeeze_bert, gpt_neo, mamba"}, ] runs-on: - ${{ matrix.build.runs-on }} diff --git a/tests/models/mamba/test_mamba.py b/tests/models/mamba/test_mamba.py new file mode 100644 index 00000000..70fa7fc8 --- /dev/null +++ b/tests/models/mamba/test_mamba.py @@ -0,0 +1,59 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 +# Reference: https://huggingface.co/state-spaces/mamba-2.8b-hf + +from transformers import MambaForCausalLM, AutoTokenizer, GenerationConfig +import pytest +from tests.utils import ModelTester +import torch + + +class ThisTester(ModelTester): + def _load_model(self): + model = MambaForCausalLM.from_pretrained( + self.model_name, torch_dtype=torch.bfloat16 + ) + self.tokenizer = AutoTokenizer.from_pretrained( + self.model_name, torch_dtype=torch.bfloat16 + ) + return model.generate + + def _load_inputs(self): + prompt = "Hey how are you doing?" + input_ids = self.tokenizer(prompt, return_tensors="pt")["input_ids"] + generation_config = GenerationConfig(max_new_tokens=10) + arguments = {"input_ids": input_ids, "generation_config": generation_config} + return arguments + + def set_model_eval(self, model): + return model + + +@pytest.mark.parametrize( + "mode", + ["eval"], +) +@pytest.mark.parametrize( + "model_name", + [ + "state-spaces/mamba-790m-hf", + "state-spaces/mamba-2.8b-hf", + "state-spaces/mamba-1.4b-hf", + "state-spaces/mamba-370m-hf", + ], +) +@pytest.mark.xfail( + reason="Fails due to 'Attempt to trace forbidden callable', but we can still generate a graph" +) +def test_mamba(record_property, mode, model_name): + record_property("model_name", model_name) + record_property("mode", mode) + + tester = ThisTester(model_name, mode) + results = tester.test_model() + if mode == "eval": + gen_text = tester.tokenizer.batch_decode(results) + print("Generated text: ", gen_text) + + record_property("torch_ttnn", (tester, results))