diff --git a/pybuda/pybuda/module.py b/pybuda/pybuda/module.py
index 49cb5adb4..cd5051d05 100644
--- a/pybuda/pybuda/module.py
+++ b/pybuda/pybuda/module.py
@@ -240,6 +240,8 @@ def get_parameters(self) -> List[Parameter]:
         for name, param in itertools.chain(*all_params):
             if name in recorded_names:
                 continue
+            if param is None:
+                continue
             pybuda_param = Parameter(
                 param.cpu(),
                 requires_grad = param.requires_grad,
diff --git a/pybuda/test/mlir/llama/test_llama_inference.py b/pybuda/test/mlir/llama/test_llama_inference.py
index 5db01a63b..a24c52465 100644
--- a/pybuda/test/mlir/llama/test_llama_inference.py
+++ b/pybuda/test/mlir/llama/test_llama_inference.py
@@ -2,28 +2,19 @@
 
 # SPDX-License-Identifier: Apache-2.0
 
+import pytest
+
+from test.mlir.llama.utils.utils import load_model
 from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
 
 import pybuda
 
 
+@pytest.mark.xfail(reason="Tile broadcast op is not supported on MLIR.")
 def test_llama_inference():
-    # Compiler configurations
-    compiler_cfg = pybuda.config._get_global_compiler_config()
-    compiler_cfg.enable_tvm_cpu_fallback = False
-
     # Load Llama 3B model and tokenizer
     model_path = "openlm-research/open_llama_3b"
-    config = LlamaConfig()
-    config.hidden_size = 3200
-    config.intermediate_size = 8640
-    config.num_hidden_layers = 26
-    config.pad_token_id = 0
-    config.return_dict = False
-    framework_model = LlamaForCausalLM.from_pretrained(
-        model_path, device_map="auto", config=config
-    )
-    framework_model.eval()
+    framework_model = load_model(model_path)
     tokenizer = LlamaTokenizer.from_pretrained(model_path)
 
     prompt = "Q: What is the largest animal?\nA:"
diff --git a/pybuda/test/mlir/llama/tests/test_llama_embedding.py b/pybuda/test/mlir/llama/tests/test_llama_embedding.py
new file mode 100644
index 000000000..180f746ec
--- /dev/null
+++ b/pybuda/test/mlir/llama/tests/test_llama_embedding.py
@@ -0,0 +1,36 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+
+# SPDX-License-Identifier: Apache-2.0
+import torch
+import pytest
+
+import pybuda
+from test.mlir.llama.utils.utils import load_model
+from pybuda.op.eval.common import compare_with_golden_pcc
+
+
+@pytest.mark.xfail(reason="Embedding op is not supported on MLIR.")
+def test_llama_embedding():
+    # Load Llama 3B model
+    framework_model = load_model()
+    vocab_size = framework_model.config.vocab_size
+    framework_model = framework_model.model.embed_tokens
+
+    # Input samples
+    inputs = [
+        torch.randint(0, vocab_size, (1, 12)), # Input token IDs
+    ]
+
+    # Sanity run
+    golden_output = framework_model(*inputs)
+
+    # Compile the model
+    compiled_model = pybuda.compile(framework_model, sample_inputs=inputs)
+
+    # Run on TT device
+    tt_out = compiled_model(*inputs)
+    tt_out = [out.to("cpu") for out in tt_out]
+
+    # Validate results
+    assert compare_with_golden_pcc(golden=golden_output, calculated=tt_out[0], pcc=0.99)
+
diff --git a/pybuda/test/mlir/llama/tests/test_llama_lm_head.py b/pybuda/test/mlir/llama/tests/test_llama_lm_head.py
new file mode 100644
index 000000000..1f08801fb
--- /dev/null
+++ b/pybuda/test/mlir/llama/tests/test_llama_lm_head.py
@@ -0,0 +1,34 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+
+# SPDX-License-Identifier: Apache-2.0
+import torch
+import pytest
+
+import pybuda
+from test.mlir.llama.utils.utils import load_model
+from pybuda.op.eval.common import compare_with_golden_pcc
+
+
+@pytest.mark.xfail(reason="Squeeze op is not supported on MLIR.")
+def test_llama_lm_head():
+    # Load Llama 3B model
+    framework_model = load_model()
+    framework_model = framework_model.lm_head
+
+    # Input samples
+    inputs = [
+        torch.rand((1, 12, 3200)), # Hidden states
+    ]
+
+    # Sanity run
+    golden_output = framework_model(*inputs)
+
+    # Compile the model
+    compiled_model = pybuda.compile(framework_model, sample_inputs=inputs)
+
+    # Run on TT device
+    tt_out = compiled_model(*inputs)
+    tt_out = [out.to("cpu") for out in tt_out]
+
+    # Validate results
+    assert compare_with_golden_pcc(golden=golden_output, calculated=tt_out[0], pcc=0.99)
diff --git a/pybuda/test/mlir/llama/tests/test_llama_mlp.py b/pybuda/test/mlir/llama/tests/test_llama_mlp.py
new file mode 100644
index 000000000..c40cb6d8b
--- /dev/null
+++ b/pybuda/test/mlir/llama/tests/test_llama_mlp.py
@@ -0,0 +1,34 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+
+# SPDX-License-Identifier: Apache-2.0
+import torch
+import pytest
+
+import pybuda
+from test.mlir.llama.utils.utils import load_model
+from pybuda.op.eval.common import compare_with_golden_pcc
+
+
+@pytest.mark.xfail(reason="Squeeze op is not supported on MLIR.")
+def test_llama_mlp():
+    # Load Llama 3B model
+    framework_model = load_model()
+    framework_model = framework_model.model.layers[0].mlp
+
+    # Input samples
+    inputs = [
+        torch.rand((1, 12, 3200)), # Hidden states
+    ]
+
+    # Sanity run
+    golden_output = framework_model(*inputs)
+
+    # Compile the model
+    compiled_model = pybuda.compile(framework_model, sample_inputs=inputs)
+
+    # Run on TT device
+    tt_out = compiled_model(*inputs)
+    tt_out = [out.to("cpu") for out in tt_out]
+
+    # Validate results
+    assert compare_with_golden_pcc(golden=golden_output, calculated=tt_out[0], pcc=0.99)
diff --git a/pybuda/test/mlir/llama/tests/test_llama_rms_norm.py b/pybuda/test/mlir/llama/tests/test_llama_rms_norm.py
new file mode 100644
index 000000000..dab328f8e
--- /dev/null
+++ b/pybuda/test/mlir/llama/tests/test_llama_rms_norm.py
@@ -0,0 +1,34 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+
+# SPDX-License-Identifier: Apache-2.0
+import torch
+import pytest
+
+import pybuda
+from test.mlir.llama.utils.utils import load_model
+from pybuda.op.eval.common import compare_with_golden_pcc
+
+
+@pytest.mark.xfail(reason="Tile broadcast op is not supported on MLIR.")
+def test_llama_rms_norm():
+    # Load Llama 3B model
+    framework_model = load_model()
+    framework_model = framework_model.model.norm
+
+    # Input samples
+    inputs = [
+        torch.rand((1, 12, 3200)), # Hidden states
+    ]
+
+    # Sanity run
+    golden_output = framework_model(*inputs)
+
+    # Compile the model
+    compiled_model = pybuda.compile(framework_model, sample_inputs=inputs)
+
+    # Run on TT device
+    tt_out = compiled_model(*inputs)
+    tt_out = [out.to("cpu") for out in tt_out]
+
+    # Validate results
+    assert compare_with_golden_pcc(golden=golden_output, calculated=tt_out[0], pcc=0.99)
diff --git a/pybuda/test/mlir/llama/tests/test_llama_rotary_emb.py b/pybuda/test/mlir/llama/tests/test_llama_rotary_emb.py
new file mode 100644
index 000000000..fca3c8346
--- /dev/null
+++ b/pybuda/test/mlir/llama/tests/test_llama_rotary_emb.py
@@ -0,0 +1,36 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+
+# SPDX-License-Identifier: Apache-2.0
+import torch
+import pytest
+
+import pybuda
+from test.mlir.llama.utils.utils import load_model
+from pybuda.op.eval.common import compare_with_golden_pcc
+
+
+@pytest.mark.xfail(reason="Dynamic shapes.")
+def test_llama_rotary_emb():
+    # Load Llama 3B model
+    framework_model = load_model()
+    kv_seq_len = 12
+    framework_model = framework_model.model.layers[0].self_attn.rotary_emb
+
+    # Input samples
+    inputs = [
+        torch.rand((1, 32, kv_seq_len, 100)), # Value states
+        torch.unsqueeze(torch.tensor(kv_seq_len), 0), # Sequence length
+    ]
+
+    # Sanity run
+    golden_output = framework_model(*inputs)
+
+    # Compile the model
+    compiled_model = pybuda.compile(framework_model, sample_inputs=inputs)
+
+    # Run on TT device
+    tt_out = compiled_model(*inputs)
+    tt_out = [out.to("cpu") for out in tt_out]
+
+    # Validate results
+    assert compare_with_golden_pcc(golden=golden_output, calculated=tt_out[0], pcc=0.99)
diff --git a/pybuda/test/mlir/llama/tests/test_llama_self_attn.py b/pybuda/test/mlir/llama/tests/test_llama_self_attn.py
new file mode 100644
index 000000000..e86c661f4
--- /dev/null
+++ b/pybuda/test/mlir/llama/tests/test_llama_self_attn.py
@@ -0,0 +1,47 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+
+# SPDX-License-Identifier: Apache-2.0
+import torch
+import pytest
+
+import pybuda
+from test.mlir.llama.utils.utils import load_model
+from pybuda.op.eval.common import compare_with_golden_pcc
+
+
+@pytest.mark.xfail(reason="Squeeze op is not supported on MLIR.")
+def test_llama_self_attn():
+    # Define wrapper module that returns only the hidden states output
+    class Wrapper(torch.nn.Module):
+        def __init__(self, model):
+            super().__init__()
+            self.model = model
+
+        def forward(self, *inputs):
+            hidden_states, _, _ = self.model(*inputs)
+
+            return hidden_states
+
+    # Load Llama 3B model
+    framework_model = load_model()
+    framework_model = Wrapper(framework_model.model.layers[0].self_attn)
+
+    # Input samples
+    inputs = [
+        torch.rand((1, 12, 3200)), # Hidden states
+        torch.ones((1, 1, 12, 12)), # Attention mask
+        torch.arange(12).unsqueeze(0), # Position IDs
+    ]
+
+    # Sanity run
+    golden_output = framework_model(*inputs)
+
+    # Compile the model
+    compiled_model = pybuda.compile(framework_model, sample_inputs=inputs)
+
+    # Run on TT device
+    tt_out = compiled_model(*inputs)
+    tt_out = [out.to("cpu") for out in tt_out]
+
+    # Validate results
+    assert compare_with_golden_pcc(golden=golden_output, calculated=tt_out[0], pcc=0.99)
diff --git a/pybuda/test/mlir/llama/utils/utils.py b/pybuda/test/mlir/llama/utils/utils.py
new file mode 100644
index 000000000..03751ccb0
--- /dev/null
+++ b/pybuda/test/mlir/llama/utils/utils.py
@@ -0,0 +1,22 @@
+from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
+
+import pybuda
+
+def load_model(model_path="openlm-research/open_llama_3b"):
+    # Compiler configurations
+    compiler_cfg = pybuda.config._get_global_compiler_config()
+    compiler_cfg.enable_tvm_cpu_fallback = False
+
+    # Load Llama 3B model
+    config = LlamaConfig()
+    config.hidden_size = 3200
+    config.intermediate_size = 8640
+    config.num_hidden_layers = 26
+    config.pad_token_id = 0
+    config.return_dict = False
+    framework_model = LlamaForCausalLM.from_pretrained(
+        model_path, device_map="auto", config=config
+    )
+    framework_model.eval()
+
+    return framework_model
diff --git a/pytest.ini b/pytest.ini
index bd4293fd7..8eebb2c29 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -6,10 +6,19 @@ addopts = -svv --junit-xml=reports/report.xml
 
 # Where pytest should look for tests
 testpaths = 
+    # Ops
     pybuda/test/mlir/test_ops.py
+
+    # API
     pybuda/test/test_api.py
+
+    # MNIST Linear
     pybuda/test/mlir/mnist/test_inference.py
     pybuda/test/mlir/test_training.py
 
+    # Llama 3B
+    pybuda/test/mlir/llama/test_llama_inference.py
+    pybuda/test/mlir/llama/tests
+
 filterwarnings =
     ignore::DeprecationWarning
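
Note: every new Llama sub-module test above follows the same golden-vs-device flow. The sketch below condenses that flow for reference; the helper name `run_golden_comparison` and its parameters are hypothetical placeholders and not part of the change, while `load_model`, `pybuda.compile`, and `compare_with_golden_pcc` are the calls used in the tests themselves.

```python
# Minimal sketch (not part of the change) of the shared test pattern above.
import torch

import pybuda
from pybuda.op.eval.common import compare_with_golden_pcc
from test.mlir.llama.utils.utils import load_model


def run_golden_comparison(module_under_test: torch.nn.Module, sample_inputs, pcc=0.99):
    # Golden run of the PyTorch sub-module on CPU
    golden_output = module_under_test(*sample_inputs)

    # Compile through pybuda and run on the TT device
    compiled_model = pybuda.compile(module_under_test, sample_inputs=sample_inputs)
    tt_out = [out.to("cpu") for out in compiled_model(*sample_inputs)]

    # Compare device output against the golden output by PCC
    return compare_with_golden_pcc(golden=golden_output, calculated=tt_out[0], pcc=pcc)


# Usage mirroring test_llama_mlp: compare the first decoder layer's MLP.
# framework_model = load_model()
# assert run_golden_comparison(framework_model.model.layers[0].mlp,
#                              [torch.rand((1, 12, 3200))])
```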