Skip to content

Commit

Permalink
Merge pull request #87 from tenstorrent/nvukobrat/llama_placeholder
Browse files Browse the repository at this point in the history
Placeholder for Llama 3B + minor fix
  • Loading branch information
nvukobratTT authored Aug 15, 2024
2 parents 0b5aa21 + 8316409 commit 7af981e
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 4 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ tt_debug
build
net2pipe_output/
third_party/llvm
venv/

/llk_out/

Expand Down
11 changes: 7 additions & 4 deletions pybuda/csrc/passes/lower_to_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ class MLIRGenerator
return builder_.create<mlir::tensor::EmptyOp>(
get_tt_forge_operation_location(graph, node),
shape_vec,
get_float_type(node));
get_data_type(node));
}

/// Emit the return operation for the function.
Expand All @@ -364,8 +364,8 @@ class MLIRGenerator
mlir::ValueRange(returnValues));
}

/// Get the MLIR float type type for a TTForge node.
mlir::FloatType get_float_type(graphlib::Node *node)
/// Get the MLIR data type for a TTForge node.
mlir::Type get_data_type(graphlib::Node *node)
{
switch (node->output_df())
{
Expand All @@ -375,7 +375,10 @@ class MLIRGenerator
return builder_.getBF16Type();
case tt::DataFormat::Float16:
return builder_.getF16Type();
case tt::DataFormat::Int8:
return builder_.getI8Type();
default:
log_error("Unsupported data format during lowering from TTForge to TTIR: {}", node->output_df());
TT_ASSERT(false);
}
// TODO add all supported types in switch
Expand All @@ -390,7 +393,7 @@ class MLIRGenerator
{
shape_vec.push_back((int64_t)dim);
}
return mlir::RankedTensorType::get(shape_vec, get_float_type(node));
return mlir::RankedTensorType::get(shape_vec, get_data_type(node));
}

/// Get the location for a module.
Expand Down
37 changes: 37 additions & 0 deletions pybuda/test/mlir/llama/test_llama_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0

from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer

import pybuda


def test_llama_inference():
# Compiler configurations
compiler_cfg = pybuda.config._get_global_compiler_config()
compiler_cfg.enable_tvm_cpu_fallback = False

# Load Llama 3B model and tokenizer
model_path = "openlm-research/open_llama_3b"
config = LlamaConfig()
config.hidden_size = 3200
config.intermediate_size = 8640
config.num_hidden_layers = 26
config.pad_token_id = 0
config.return_dict = False
framework_model = LlamaForCausalLM.from_pretrained(
model_path, device_map="auto", config=config
)
framework_model.eval()
tokenizer = LlamaTokenizer.from_pretrained(model_path)

prompt = "Q: What is the largest animal?\nA:"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

# Sanity run
generation_output = framework_model.generate(input_ids=input_ids, max_new_tokens=32)
print(tokenizer.decode(generation_output[0]))

# Compile the model
compiled_model = pybuda.compile(framework_model, input_ids)

0 comments on commit 7af981e

Please sign in to comment.