
[Model] Llama 3B: Include tests with proper setups for main building blocks + minor fix. #123 #124 #125 #126 #165 #167

Merged (1 commit, Aug 26, 2024)
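All of the new building-block tests below share the same structure; here is a condensed sketch of that pattern (the run_block_test helper is hypothetical, but load_model, pybuda.compile, and compare_with_golden_pcc come directly from the diff):

import torch

import pybuda
from test.mlir.llama.utils.utils import load_model
from pybuda.op.eval.common import compare_with_golden_pcc


def run_block_test(select_block, sample_inputs):
    # Isolate a single Llama 3B building block (embedding, MLP, RMSNorm, ...).
    block = select_block(load_model())

    # Golden reference on CPU.
    golden = block(*sample_inputs)

    # Compile for and run on the TT device.
    compiled = pybuda.compile(block, sample_inputs=sample_inputs)
    tt_out = [out.to("cpu") for out in compiled(*sample_inputs)]

    # Compare against the golden output with a 0.99 PCC threshold.
    assert compare_with_golden_pcc(golden=golden, calculated=tt_out[0], pcc=0.99)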
2 changes: 2 additions & 0 deletions pybuda/pybuda/module.py
@@ -240,6 +240,8 @@ def get_parameters(self) -> List[Parameter]:
        for name, param in itertools.chain(*all_params):
            if name in recorded_names:
                continue
            if param is None:
                continue
            pybuda_param = Parameter(
                param.cpu(),
                requires_grad = param.requires_grad,
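The minor fix above skips parameters that resolve to None so that the subsequent param.cpu() call cannot fail. A minimal standalone sketch of the same guard (collect_params and the sample groups are hypothetical, not the pybuda API):

import itertools

import torch


def collect_params(*param_groups):
    # Same idea as the guard above: skip duplicate names and skip entries
    # whose parameter object is None instead of calling .cpu() on them.
    seen = set()
    collected = []
    for name, param in itertools.chain(*param_groups):
        if name in seen:
            continue
        if param is None:
            continue
        seen.add(name)
        collected.append((name, param.cpu()))
    return collected


# Example: the second group carries an optional weight that was never allocated.
group_a = [("weight", torch.nn.Parameter(torch.rand(2, 2)))]
group_b = [("bias", None)]
print(collect_params(group_a, group_b))  # only ("weight", ...) survives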
19 changes: 5 additions & 14 deletions pybuda/test/mlir/llama/test_llama_inference.py
@@ -2,28 +2,19 @@

# SPDX-License-Identifier: Apache-2.0

import pytest

from test.mlir.llama.utils.utils import load_model
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer

import pybuda


@pytest.mark.xfail(reason="Tile broadcast op is not supported on MLIR.")
def test_llama_inference():
    # Compiler configurations
    compiler_cfg = pybuda.config._get_global_compiler_config()
    compiler_cfg.enable_tvm_cpu_fallback = False

    # Load Llama 3B model and tokenizer
    model_path = "openlm-research/open_llama_3b"
    config = LlamaConfig()
    config.hidden_size = 3200
    config.intermediate_size = 8640
    config.num_hidden_layers = 26
    config.pad_token_id = 0
    config.return_dict = False
    framework_model = LlamaForCausalLM.from_pretrained(
        model_path, device_map="auto", config=config
    )
    framework_model.eval()
    framework_model = load_model(model_path)
    tokenizer = LlamaTokenizer.from_pretrained(model_path)

    prompt = "Q: What is the largest animal?\nA:"
36 changes: 36 additions & 0 deletions pybuda/test/mlir/llama/tests/test_llama_embedding.py
@@ -0,0 +1,36 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0
import torch
import pytest

import pybuda
from test.mlir.llama.utils.utils import load_model
from pybuda.op.eval.common import compare_with_golden_pcc


@pytest.mark.xfail(reason="Embedding op is not supported on MLIR.")
def test_llama_embedding():
    # Load Llama 3B model and tokenizer
    framework_model = load_model()
    vocab_size = framework_model.config.vocab_size
    framework_model = framework_model.model.embed_tokens

    # Input samples
    inputs = [
        torch.randint(0, vocab_size, (1, 12)),  # Input token IDs
    ]

    # Sanity run
    golden_output = framework_model(*inputs)

    # Compile the model
    compiled_model = pybuda.compile(framework_model, sample_inputs=inputs)

    # Run on TT device
    tt_out = compiled_model(*inputs)
    tt_out = [out.to("cpu") for out in tt_out]

    # Validate results
    assert compare_with_golden_pcc(golden=golden_output, calculated=tt_out[0], pcc=0.99)

34 changes: 34 additions & 0 deletions pybuda/test/mlir/llama/tests/test_llama_lm_head.py
@@ -0,0 +1,34 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0
import torch
import pytest

import pybuda
from test.mlir.llama.utils.utils import load_model
from pybuda.op.eval.common import compare_with_golden_pcc


@pytest.mark.xfail(reason="Squeeze op is not supported on MLIR.")
def test_llama_lm_head():
    # Load Llama 3B model and tokenizer
    framework_model = load_model()
    framework_model = framework_model.lm_head

    # Input samples
    inputs = [
        torch.rand((1, 12, 3200)),  # Hidden states
    ]

    # Sanity run
    golden_output = framework_model(*inputs)

    # Compile the model
    compiled_model = pybuda.compile(framework_model, sample_inputs=inputs)

    # Run on TT device
    tt_out = compiled_model(*inputs)
    tt_out = [out.to("cpu") for out in tt_out]

    # Validate results
    assert compare_with_golden_pcc(golden=golden_output, calculated=tt_out[0], pcc=0.99)
34 changes: 34 additions & 0 deletions pybuda/test/mlir/llama/tests/test_llama_mlp.py
@@ -0,0 +1,34 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0
import torch
import pytest

import pybuda
from test.mlir.llama.utils.utils import load_model
from pybuda.op.eval.common import compare_with_golden_pcc


@pytest.mark.xfail(reason="Squeeze op is not supported on MLIR.")
def test_llama_mlp():
    # Load Llama 3B model and tokenizer
    framework_model = load_model()
    framework_model = framework_model.model.layers[0].mlp

    # Input samples
    inputs = [
        torch.rand((1, 12, 3200)),  # Hidden states
    ]

    # Sanity run
    golden_output = framework_model(*inputs)

    # Compile the model
    compiled_model = pybuda.compile(framework_model, sample_inputs=inputs)

    # Run on TT device
    tt_out = compiled_model(*inputs)
    tt_out = [out.to("cpu") for out in tt_out]

    # Validate results
    assert compare_with_golden_pcc(golden=golden_output, calculated=tt_out[0], pcc=0.99)
34 changes: 34 additions & 0 deletions pybuda/test/mlir/llama/tests/test_llama_rms_norm.py
@@ -0,0 +1,34 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0
import torch
import pytest

import pybuda
from test.mlir.llama.utils.utils import load_model
from pybuda.op.eval.common import compare_with_golden_pcc


@pytest.mark.xfail(reason="Tile broadcast op is not supported on MLIR.")
def test_llama_rms_norm():
    # Load Llama 3B model and tokenizer
    framework_model = load_model()
    framework_model = framework_model.model.norm

    # Input samples
    inputs = [
        torch.rand((1, 12, 3200)),  # Hidden states
    ]

    # Sanity run
    golden_output = framework_model(*inputs)

    # Compile the model
    compiled_model = pybuda.compile(framework_model, sample_inputs=inputs)

    # Run on TT device
    tt_out = compiled_model(*inputs)
    tt_out = [out.to("cpu") for out in tt_out]

    # Validate results
    assert compare_with_golden_pcc(golden=golden_output, calculated=tt_out[0], pcc=0.99)
36 changes: 36 additions & 0 deletions pybuda/test/mlir/llama/tests/test_llama_rotary_emb.py
@@ -0,0 +1,36 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0
import torch
import pytest

import pybuda
from test.mlir.llama.utils.utils import load_model
from pybuda.op.eval.common import compare_with_golden_pcc


@pytest.mark.xfail(reason="Dynamic shapes..")
def test_llama_rotary_emb():
    # Load Llama 3B model and tokenizer
    framework_model = load_model()
    kv_seq_len = 12
    framework_model = framework_model.model.layers[0].self_attn.rotary_emb

    # Input samples
    inputs = [
        torch.rand((1, 32, kv_seq_len, 100)),  # Value states
        torch.unsqueeze(torch.tensor(kv_seq_len), 0),  # Sequence length
    ]

    # Sanity run
    golden_output = framework_model(*inputs)

    # Compile the model
    compiled_model = pybuda.compile(framework_model, sample_inputs=inputs)

    # Run on TT device
    tt_out = compiled_model(*inputs)
    tt_out = [out.to("cpu") for out in tt_out]

    # Validate results
    assert compare_with_golden_pcc(golden=golden_output, calculated=tt_out[0], pcc=0.99)
47 changes: 47 additions & 0 deletions pybuda/test/mlir/llama/tests/test_llama_self_attn.py
@@ -0,0 +1,47 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0
import torch
import pytest

import pybuda
from test.mlir.llama.utils.utils import load_model
from pybuda.op.eval.common import compare_with_golden_pcc


@pytest.mark.xfail(reason="Squeeze op is not supported on MLIR.")
def test_llama_self_attn():
    # Define wrapper function
    class Wrapper(torch.nn.Module):
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, *inputs):
            hidden_states, _, _ = self.model(*inputs)

            return hidden_states

    # Load Llama 3B model and tokenizer
    framework_model = load_model()
    framework_model = Wrapper(framework_model.model.layers[0].self_attn)

    # Input samples
    inputs = [
        torch.rand((1, 12, 3200)),  # Hidden states
        torch.ones((1, 1, 12, 12)),  # Attention mask
        torch.arange(12).unsqueeze(0),  # Position IDs
    ]

    # Sanity run
    golden_output = framework_model(*inputs)

    # Compile the model
    compiled_model = pybuda.compile(framework_model, sample_inputs=inputs)

    # Run on TT device
    tt_out = compiled_model(*inputs)
    tt_out = [out.to("cpu") for out in tt_out]

    # Validate results
    assert compare_with_golden_pcc(golden=golden_output, calculated=tt_out[0], pcc=0.99)
22 changes: 22 additions & 0 deletions pybuda/test/mlir/llama/utils/utils.py
@@ -0,0 +1,22 @@
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer

import pybuda

def load_model(model_path="openlm-research/open_llama_3b"):
    # Compiler configurations
    compiler_cfg = pybuda.config._get_global_compiler_config()
    compiler_cfg.enable_tvm_cpu_fallback = False

    # Load Llama 3B model
    config = LlamaConfig()
    config.hidden_size = 3200
    config.intermediate_size = 8640
    config.num_hidden_layers = 26
    config.pad_token_id = 0
    config.return_dict = False
    framework_model = LlamaForCausalLM.from_pretrained(
        model_path, device_map="auto", config=config
    )
    framework_model.eval()

    return framework_model
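For reference, this helper is what the block-level tests above use to isolate individual submodules; a short usage sketch (attribute paths taken from the tests in this PR):

from test.mlir.llama.utils.utils import load_model

model = load_model()  # full Llama 3B in eval mode, compiler config applied

# Building blocks exercised by the new tests.
embedding = model.model.embed_tokens         # token embedding
self_attn = model.model.layers[0].self_attn  # first attention block
mlp = model.model.layers[0].mlp              # first feed-forward block
rms_norm = model.model.norm                  # final RMSNorm
lm_head = model.lm_head                      # output projection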
9 changes: 9 additions & 0 deletions pytest.ini
@@ -6,10 +6,19 @@ addopts = -svv --junit-xml=reports/report.xml

# Where pytest should look for tests
testpaths =
    # Ops
    pybuda/test/mlir/test_ops.py

    # API
    pybuda/test/test_api.py

    # MNIST Linear
    pybuda/test/mlir/mnist/test_inference.py
    pybuda/test/mlir/test_training.py

    # Llama 3B
    pybuda/test/mlir/llama/test_llama_inference.py
    pybuda/test/mlir/llama/tests

filterwarnings =
    ignore::DeprecationWarning
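With these paths registered, the new Llama tests can also be run on their own; a sketch (assumes the repository root as the working directory):

# Equivalent to `pytest -svv pybuda/test/mlir/llama` from the repo root.
import pytest

pytest.main([
    "-svv",
    "pybuda/test/mlir/llama/test_llama_inference.py",
    "pybuda/test/mlir/llama/tests",
])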