Commit 8d88845
[Model] Llama 3B: Include tests with proper setups for main building blocks + minor fix. #123 #124 #125 #126 #165 (#167)

Main Llama blocks:
- Embeddings
- LM Head
- MLP
- RMS Norm
- Rotary Embeddings
- Self Attention
1 parent b61a0ad · commit 8d88845
Showing 10 changed files with 259 additions and 14 deletions.
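All six tests below follow the same pattern: load one block of the model, run it on sample inputs for a golden reference, compile with pybuda, run on the TT device, and compare outputs by PCC. As a hedged sketch of how one might run just this suite from Python (the test/mlir/llama path is inferred from the imports in the diffs and may differ):

import pytest

# Hypothetical invocation: select the new Llama block tests by keyword.
# Equivalent to running `pytest -v -k llama test/mlir/llama` from the shell.
pytest.main(["-v", "-k", "llama", "test/mlir/llama"])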
@@ -0,0 +1,36 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0
import torch
import pytest

import pybuda
from test.mlir.llama.utils.utils import load_model
from pybuda.op.eval.common import compare_with_golden_pcc


@pytest.mark.xfail(reason="Embedding op is not supported on MLIR.")
def test_llama_embedding():
    # Load Llama 3B model and tokenizer
    framework_model = load_model()
    vocab_size = framework_model.config.vocab_size
    framework_model = framework_model.model.embed_tokens

    # Input samples
    inputs = [
        torch.randint(0, vocab_size, (1, 12)),  # Input token IDs
    ]

    # Sanity run
    golden_output = framework_model(*inputs)

    # Compile the model
    compiled_model = pybuda.compile(framework_model, sample_inputs=inputs)

    # Run on TT device
    tt_out = compiled_model(*inputs)
    tt_out = [out.to("cpu") for out in tt_out]

    # Validate results
    assert compare_with_golden_pcc(golden=golden_output, calculated=tt_out[0], pcc=0.99)
@@ -0,0 +1,34 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0
import torch
import pytest

import pybuda
from test.mlir.llama.utils.utils import load_model
from pybuda.op.eval.common import compare_with_golden_pcc


@pytest.mark.xfail(reason="Squeeze op is not supported on MLIR.")
def test_llama_lm_head():
    # Load Llama 3B model and tokenizer
    framework_model = load_model()
    framework_model = framework_model.lm_head

    # Input samples
    inputs = [
        torch.rand((1, 12, 3200)),  # Hidden states
    ]

    # Sanity run
    golden_output = framework_model(*inputs)

    # Compile the model
    compiled_model = pybuda.compile(framework_model, sample_inputs=inputs)

    # Run on TT device
    tt_out = compiled_model(*inputs)
    tt_out = [out.to("cpu") for out in tt_out]

    # Validate results
    assert compare_with_golden_pcc(golden=golden_output, calculated=tt_out[0], pcc=0.99)
@@ -0,0 +1,34 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0
import torch
import pytest

import pybuda
from test.mlir.llama.utils.utils import load_model
from pybuda.op.eval.common import compare_with_golden_pcc


@pytest.mark.xfail(reason="Squeeze op is not supported on MLIR.")
def test_llama_mlp():
    # Load Llama 3B model and tokenizer
    framework_model = load_model()
    framework_model = framework_model.model.layers[0].mlp

    # Input samples
    inputs = [
        torch.rand((1, 12, 3200)),  # Hidden states
    ]

    # Sanity run
    golden_output = framework_model(*inputs)

    # Compile the model
    compiled_model = pybuda.compile(framework_model, sample_inputs=inputs)

    # Run on TT device
    tt_out = compiled_model(*inputs)
    tt_out = [out.to("cpu") for out in tt_out]

    # Validate results
    assert compare_with_golden_pcc(golden=golden_output, calculated=tt_out[0], pcc=0.99)
@@ -0,0 +1,34 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0
import torch
import pytest

import pybuda
from test.mlir.llama.utils.utils import load_model
from pybuda.op.eval.common import compare_with_golden_pcc


@pytest.mark.xfail(reason="Tile broadcast op is not supported on MLIR.")
def test_llama_rms_norm():
    # Load Llama 3B model and tokenizer
    framework_model = load_model()
    framework_model = framework_model.model.norm

    # Input samples
    inputs = [
        torch.rand((1, 12, 3200)),  # Hidden states
    ]

    # Sanity run
    golden_output = framework_model(*inputs)

    # Compile the model
    compiled_model = pybuda.compile(framework_model, sample_inputs=inputs)

    # Run on TT device
    tt_out = compiled_model(*inputs)
    tt_out = [out.to("cpu") for out in tt_out]

    # Validate results
    assert compare_with_golden_pcc(golden=golden_output, calculated=tt_out[0], pcc=0.99)
@@ -0,0 +1,36 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0
import torch
import pytest

import pybuda
from test.mlir.llama.utils.utils import load_model
from pybuda.op.eval.common import compare_with_golden_pcc


@pytest.mark.xfail(reason="Dynamic shapes.")
def test_llama_rotary_emb():
    # Load Llama 3B model and tokenizer
    framework_model = load_model()
    kv_seq_len = 12
    framework_model = framework_model.model.layers[0].self_attn.rotary_emb

    # Input samples
    inputs = [
        torch.rand((1, 32, kv_seq_len, 100)),  # Value states
        torch.unsqueeze(torch.tensor(kv_seq_len), 0),  # Sequence length
    ]

    # Sanity run
    golden_output = framework_model(*inputs)

    # Compile the model
    compiled_model = pybuda.compile(framework_model, sample_inputs=inputs)

    # Run on TT device
    tt_out = compiled_model(*inputs)
    tt_out = [out.to("cpu") for out in tt_out]

    # Validate results
    assert compare_with_golden_pcc(golden=golden_output, calculated=tt_out[0], pcc=0.99)
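For context on the xfail above: the rotary-embedding module builds per-position cos/sin tables whose length depends on kv_seq_len, which is where the dynamic shape comes from. A minimal sketch of the standard RoPE computation, assuming the usual formulation (an illustration, not the exact HF LlamaRotaryEmbedding code):

import torch

def rope_cos_sin(seq_len: int, head_dim: int = 100, base: float = 10000.0):
    # Inverse frequency for every second channel, as in the RoPE paper.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(seq_len).float()
    freqs = torch.outer(t, inv_freq)         # (seq_len, head_dim // 2)
    emb = torch.cat((freqs, freqs), dim=-1)  # (seq_len, head_dim)
    return emb.cos(), emb.sin()

cos, sin = rope_cos_sin(12)  # 12 matches kv_seq_len; 100 matches the head dim in the test above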
@@ -0,0 +1,47 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0
import torch
import pytest

import pybuda
from test.mlir.llama.utils.utils import load_model
from pybuda.op.eval.common import compare_with_golden_pcc


@pytest.mark.xfail(reason="Squeeze op is not supported on MLIR.")
def test_llama_self_attn():
    # Define wrapper function. HF's attention module returns a tuple
    # (hidden_states, attn_weights, past_key_value); the wrapper keeps only
    # the hidden states so the compiled module has a single tensor output.
    class Wrapper(torch.nn.Module):
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, *inputs):
            hidden_states, _, _ = self.model(*inputs)

            return hidden_states

    # Load Llama 3B model and tokenizer
    framework_model = load_model()
    framework_model = Wrapper(framework_model.model.layers[0].self_attn)

    # Input samples
    inputs = [
        torch.rand((1, 12, 3200)),  # Hidden states
        torch.ones((1, 1, 12, 12)),  # Attention mask
        torch.arange(12).unsqueeze(0),  # Position IDs
    ]

    # Sanity run
    golden_output = framework_model(*inputs)

    # Compile the model
    compiled_model = pybuda.compile(framework_model, sample_inputs=inputs)

    # Run on TT device
    tt_out = compiled_model(*inputs)
    tt_out = [out.to("cpu") for out in tt_out]

    # Validate results
    assert compare_with_golden_pcc(golden=golden_output, calculated=tt_out[0], pcc=0.99)
@@ -0,0 +1,22 @@
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer

import pybuda


def load_model(model_path="openlm-research/open_llama_3b"):
    # Compiler configurations
    compiler_cfg = pybuda.config._get_global_compiler_config()
    compiler_cfg.enable_tvm_cpu_fallback = False

    # Load Llama 3B model
    config = LlamaConfig()
    config.hidden_size = 3200
    config.intermediate_size = 8640
    config.num_hidden_layers = 26
    config.pad_token_id = 0
    config.return_dict = False
    framework_model = LlamaForCausalLM.from_pretrained(
        model_path, device_map="auto", config=config
    )
    framework_model.eval()

    return framework_model
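For reference, a short usage sketch of load_model as the tests above use it; all attribute names are taken from the diffs in this commit:

from test.mlir.llama.utils.utils import load_model

model = load_model()
embeddings = model.model.embed_tokens  # block used by the embeddings test
mlp = model.model.layers[0].mlp        # block used by the MLP test
norm = model.model.norm                # block used by the RMS norm test
print(model.config.hidden_size)        # 3200, per the config above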