Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Run ops on device #25

Merged
merged 4 commits into from
Nov 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion env/activate
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ else
fi
export TTTORCH_ENV_ACTIVATED=1
export TTMLIR_ENV_ACTIVATED=1
export PATH=$TTMLIR_TOOLCHAIN_DIR/bin:$PATH
export PATH=$TT_TORCH_HOME/third_party/tt-mlir/src/tt-mlir-build/bin:$TTMLIR_TOOLCHAIN_DIR/bin:$PATH
export TOKENIZERS_PARALLELISM=false
if [ -n "$PROJECT_ROOT" ]; then
export TT_METAL_HOME="$PROJECT_ROOT/third_party/tt-mlir/src/tt-mlir/third_party/tt-metal/src/tt-metal"
Expand Down
99 changes: 94 additions & 5 deletions results/parse_op_by_op_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import xlsxwriter
from mdutils.mdutils import MdUtils

import subprocess

# Script to parse the results of the unique ops json files and combine them into a spreadsheet
# This script parses models compiled into stable hlo / TTIR op by op
def find_json_files(directory="results"):
Expand Down Expand Up @@ -92,6 +94,10 @@ def process_json_files():
"Status",
"Ops",
"Raw SHLO",
"Raw TTIR",
"Raw TTNNIR",
"Compile Error",
"Trace dump",
)
worksheet.write_row(row, 0, header, bold)
row += 1
Expand All @@ -112,6 +118,9 @@ def process_json_files():
"status": value["compilation_status"],
"stable_hlo_graph": value["stable_hlo_graph"],
"ops": value["stable_hlo_ops"],
"ttir_graph": value["ttir_graph"],
"ttnn_graph": value["ttnn_graph"],
"key": key,
}
)
ops_per_model[model_name] = list(torch_ops.keys())
Expand All @@ -124,18 +133,74 @@ def process_json_files():
for torch_name, torch_op in sorted(torch_ops.items()):
stable_hlo_ops_per_torch_op[torch_name] = set()
name = torch_name
test_num = 0
for op in torch_op:
num_ops = op["num_ops"]
input_shapes = extract_shape(op["input_shapes"])
output_shapes = extract_shape(op["output_shapes"])
status = op["status"]
raw_shlo = op["stable_hlo_graph"]
ops = op["ops"]
row_data = [name, input_shapes, output_shapes, num_ops, status]
error = ""
trace_dump = ""
if status == 5 or status == 4:
if status == 5:
# Does not compile to TTNNIR, create unit test
test_name = f"{torch_name}_{test_num}.mlir"
test_num += 1
with open(f"results/mlir_tests/ttir/{test_name}", "w") as f:
f.write(op["ttir_graph"])

result = subprocess.run(
[
"ttmlir-opt",
"--ttir-to-ttnn-backend-pipeline",
f"results/mlir_tests/ttir/{test_name}",
],
capture_output=True,
text=True,
)
elif status == 4:
# Does not compile to TTIR, create unit test
test_name = f"{torch_name}_{test_num}.mlir"
test_num += 1
with open(
f"results/mlir_tests/stable_hlo/{test_name}", "w"
) as f:
f.write(op["stable_hlo_graph"])

result = subprocess.run(
[
"ttmlir-opt",
"--stablehlo-to-ttir-pipeline=enable-remove-dead-values=true",
f"results/mlir_tests/stable_hlo/{test_name}",
],
capture_output=True,
text=True,
)
if result.returncode != 0:
error = result.stderr.split("\n")[0]
trace_dump = result.stderr
print(error)
row_data = [
name,
input_shapes,
output_shapes,
num_ops,
status,
"",
raw_shlo,
op["ttir_graph"],
op["ttnn_graph"],
error,
trace_dump,
]
all_ops[op["key"]]["error"] = error
all_ops[op["key"]]["trace_dump"] = trace_dump
worksheet.write_row(row, 0, row_data)
name = ""
row += 1
row_data = ["", "", "", "", "", raw_shlo]
row_data = ["", "", "", "", "", "", raw_shlo]
worksheet.write_row(row, 0, row_data)
worksheet.set_row(row, None, None, {"hidden": True})
row += 1
Expand Down Expand Up @@ -171,6 +236,10 @@ def process_json_files():
"Status",
"Ops",
"Raw SHLO",
"Raw TTIR",
"Raw TTNNIR",
"Compile Error",
"Trace dump",
)
worksheet.write_row(row, 0, header, bold)
row += 1
Expand All @@ -191,6 +260,10 @@ def process_json_files():
"status": value["compilation_status"],
"stable_hlo_graph": value["stable_hlo_graph"],
"ops": value["stable_hlo_ops"],
"ttir_graph": value["ttir_graph"],
"ttnn_graph": value["ttnn_graph"],
"error": value["error"],
"trace_dump": value["trace_dump"],
}
)

Expand All @@ -203,11 +276,27 @@ def process_json_files():
status = op["status"]
raw_shlo = op["stable_hlo_graph"]
ops = op["ops"]
row_data = [name, input_shapes, output_shapes, num_ops, status]
ttir_graph = op["ttir_graph"]
ttnn_graph = op["ttnn_graph"]
error = op["error"]
trace_dump = op["trace_dump"]
row_data = [
name,
input_shapes,
output_shapes,
num_ops,
status,
"",
raw_shlo,
ttir_graph,
ttnn_graph,
error,
trace_dump,
]
name = ""
worksheet.write_row(row, 0, row_data)
row += 1
row_data = ["", "", "", "", "", raw_shlo]
row_data = ["", "", "", "", "", "", raw_shlo, ttir_graph, ttnn_graph]
worksheet.write_row(row, 0, row_data)
worksheet.set_row(row, None, None, {"hidden": True})
row += 1
Expand All @@ -217,7 +306,7 @@ def process_json_files():
worksheet.set_row(row, None, None, {"hidden": True})
row += 1

worksheet.autofit()
# worksheet.autofit()

ops = list(models_per_op.keys())
ops.sort()
Expand Down
4 changes: 2 additions & 2 deletions tests/torch/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,10 +346,10 @@ def __init__(self):

def forward(self, x):
y = x + x
z = y + x
z = y + y
z = torch.argmax(z)
return z

cc = CompilerConfig()
cc.compile_depth = tt_torch.tools.utils.CompileDepth.COMPILE_OP_BY_OP
cc.compile_depth = tt_torch.tools.utils.CompileDepth.EXECUTE_OP_BY_OP
verify_module(Basic(), [(256, 256)], compiler_config=cc, do_assert=False)
55 changes: 46 additions & 9 deletions tt_torch/csrc/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@

namespace py = pybind11;

tt::runtime::Binary compile(std::string_view code) {
return tt::torch::Compile(code);
}

static tt::target::DataType torch_scalar_type_to_dt(torch::ScalarType st) {
switch (st) {
case torch::ScalarType::Byte:
Expand Down Expand Up @@ -88,8 +84,19 @@ std::vector<int64_t> as_vec_int64(std::vector<T> const &vec) {
}
return result;
}

std::vector<at::Tensor> run(const std::vector<at::Tensor> &inputs,
tt::runtime::Binary binary) {
py::bytes byte_stream) {

std::string data_str = byte_stream;
auto binary_ptr = std::shared_ptr<void>(
new char[data_str.size()],
[](void *ptr) { delete[] static_cast<char *>(ptr); } // Custom deleter
);
// Copy data into the allocated memory
std::memcpy(binary_ptr.get(), data_str.data(), data_str.size());
tt::runtime::Binary binary = tt::runtime::Binary(binary_ptr);

auto [system_desc, chip_ids] = tt::runtime::getCurrentSystemDesc();
int dev_0 = chip_ids[0];
auto device = tt::runtime::openDevice({dev_0});
Expand Down Expand Up @@ -124,14 +131,44 @@ std::vector<at::Tensor> run(const std::vector<at::Tensor> &inputs,
return outputs;
}

// Lower a StableHLO module, given as text, to its TTIR textual form.
// Thin forwarding wrapper so the pipeline stage can be exposed to Python.
std::string compile_stable_hlo_to_ttir(std::string_view code) {
  return tt::torch::compileStableHLOToTTIR(code);
}

// Compile a TTIR module (textual form) down to a TTNN flatbuffer.
// Returns the size-prefixed flatbuffer serialized as Python bytes,
// together with the TTNN IR text produced by the pipeline.
std::tuple<py::bytes, std::string>
compile_ttir_to_bytestream(std::string_view code) {
  auto [fb_handle, ttnn_ir] = tt::torch::compileTTIRToTTNN(code);

  const auto *buffer = static_cast<const uint8_t *>(fb_handle->get());
  const auto length = ::flatbuffers::GetSizePrefixedBufferLength(buffer);
  std::string serialized(reinterpret_cast<const char *>(buffer), length);

  // compileTTIRToTTNN hands back a heap-allocated handle; we own it here.
  delete fb_handle;

  return std::make_tuple(py::bytes(serialized), ttnn_ir);
}

// Compile a StableHLO module (textual form) all the way to a TTNN
// flatbuffer and return the size-prefixed buffer as Python bytes.
py::bytes compile_stablehlo_to_bytestream(std::string_view code) {
  auto fb_handle = tt::torch::Compile(code);

  const auto *buffer = static_cast<const uint8_t *>(fb_handle->get());
  const auto length = ::flatbuffers::GetSizePrefixedBufferLength(buffer);
  std::string serialized(reinterpret_cast<const char *>(buffer), length);

  // Compile hands back a heap-allocated handle; release it once copied out.
  delete fb_handle;

  return py::bytes(serialized);
}

PYBIND11_MODULE(tt_mlir, m) {
m.doc() = "tt_mlir";
py::class_<tt::runtime::Binary>(m, "Binary")
.def("getProgramInputs", &tt::runtime::Binary::getProgramInputs)
.def("getProgramOutputs", &tt::runtime::Binary::getProgramOutputs)
.def("as_json", &tt::runtime::Binary::asJson);
m.def("compile", &compile,
"A function that compiles a stableHLO model to a flatbuffer");
.def("getProgramOutputs", &tt::runtime::Binary::getProgramOutputs);
m.def("compile", &compile_stablehlo_to_bytestream,
"A function that compiles stableHLO to a bytestream");
m.def("compile_ttir_to_bytestream", &compile_ttir_to_bytestream,
"A function that compiles TTIR to a bytestream");
m.def("compile_stable_hlo_to_ttir", &compile_stable_hlo_to_ttir,
"A function that compiles stableHLO to TTIR");
m.def("run", &run, "Push inputs and run binary");
m.def("get_current_system_desc", &tt::runtime::getCurrentSystemDesc,
"Get the current system descriptor");
Expand Down
Loading
Loading