diff --git a/env/activate b/env/activate
index 8e7a1171..b95434c2 100644
--- a/env/activate
+++ b/env/activate
@@ -32,7 +32,7 @@ else
 fi
 export TTTORCH_ENV_ACTIVATED=1
 export TTMLIR_ENV_ACTIVATED=1
-export PATH=$TTMLIR_TOOLCHAIN_DIR/bin:$PATH
+export PATH=$TT_TORCH_HOME/third_party/tt-mlir/src/tt-mlir-build/bin:$TTMLIR_TOOLCHAIN_DIR/bin:$PATH
 export TOKENIZERS_PARALLELISM=false
 if [ -n "$PROJECT_ROOT" ]; then
   export TT_METAL_HOME="$PROJECT_ROOT/third_party/tt-mlir/src/tt-mlir/third_party/tt-metal/src/tt-metal"
diff --git a/results/parse_op_by_op_results.py b/results/parse_op_by_op_results.py
index 8e4341cd..afba9e26 100644
--- a/results/parse_op_by_op_results.py
+++ b/results/parse_op_by_op_results.py
@@ -8,6 +8,8 @@
 import xlsxwriter
 from mdutils.mdutils import MdUtils
 
+import subprocess
+
 # Script to parse the results of the unique ops json files and combine them into a spreadsheet
 # This script parses models compiled into stable hlo / TTIR op by op
 def find_json_files(directory="results"):
@@ -92,6 +94,10 @@
             "Status",
             "Ops",
             "Raw SHLO",
+            "Raw TTIR",
+            "Raw TTNNIR",
+            "Compile Error",
+            "Trace dump",
         )
         worksheet.write_row(row, 0, header, bold)
         row += 1
@@ -112,6 +118,9 @@
                     "status": value["compilation_status"],
                     "stable_hlo_graph": value["stable_hlo_graph"],
                     "ops": value["stable_hlo_ops"],
+                    "ttir_graph": value["ttir_graph"],
+                    "ttnn_graph": value["ttnn_graph"],
+                    "key": key,
                 }
             )
         ops_per_model[model_name] = list(torch_ops.keys())
@@ -124,6 +133,7 @@
         for torch_name, torch_op in sorted(torch_ops.items()):
            stable_hlo_ops_per_torch_op[torch_name] = set()
            name = torch_name
+           test_num = 0
            for op in torch_op:
                num_ops = op["num_ops"]
                input_shapes = extract_shape(op["input_shapes"])
@@ -131,11 +141,66 @@
                status = op["status"]
                raw_shlo = op["stable_hlo_graph"]
                ops = op["ops"]
-               row_data = [name, input_shapes, output_shapes, num_ops, status]
+               error = ""
+               trace_dump = ""
+               if status == 5 or status == 4:
+                   if status == 5:
+                       # Does not compile to TTNNIR, create unit test
+                       test_name = f"{torch_name}_{test_num}.mlir"
+                       test_num += 1
+                       with open(f"results/mlir_tests/ttir/{test_name}", "w") as f:
+                           f.write(op["ttir_graph"])
+
+                       result = subprocess.run(
+                           [
+                               "ttmlir-opt",
+                               "--ttir-to-ttnn-backend-pipeline",
+                               f"results/mlir_tests/ttir/{test_name}",
+                           ],
+                           capture_output=True,
+                           text=True,
+                       )
+                   elif status == 4:
+                       # Does not compile to TTIR, create unit test
+                       test_name = f"{torch_name}_{test_num}.mlir"
+                       test_num += 1
+                       with open(
+                           f"results/mlir_tests/stable_hlo/{test_name}", "w"
+                       ) as f:
+                           f.write(op["stable_hlo_graph"])
+
+                       result = subprocess.run(
+                           [
+                               "ttmlir-opt",
+                               "--stablehlo-to-ttir-pipeline=enable-remove-dead-values=true",
+                               f"results/mlir_tests/stable_hlo/{test_name}",
+                           ],
+                           capture_output=True,
+                           text=True,
+                       )
+                   if result.returncode != 0:
+                       error = result.stderr.split("\n")[0]
+                       trace_dump = result.stderr
+                       print(error)
+               row_data = [
+                   name,
+                   input_shapes,
+                   output_shapes,
+                   num_ops,
+                   status,
+                   "",
+                   raw_shlo,
+                   op["ttir_graph"],
+                   op["ttnn_graph"],
+                   error,
+                   trace_dump,
+               ]
+               all_ops[op["key"]]["error"] = error
+               all_ops[op["key"]]["trace_dump"] = trace_dump
                worksheet.write_row(row, 0, row_data)
                name = ""
                row += 1
-               row_data = ["", "", "", "", "", raw_shlo]
+               row_data = ["", "", "", "", "", "", raw_shlo]
                worksheet.write_row(row, 0, row_data)
                worksheet.set_row(row, None, None, {"hidden": True})
                row += 1
@@ -171,6 +236,10 @@
             "Status",
             "Ops",
             "Raw SHLO",
+            "Raw TTIR",
+            "Raw TTNNIR",
+            "Compile Error",
+            "Trace dump",
         )
         worksheet.write_row(row, 0, header, bold)
         row += 1
@@ -191,6 +260,10 @@
                     "status": value["compilation_status"],
                     "stable_hlo_graph": value["stable_hlo_graph"],
                     "ops": value["stable_hlo_ops"],
+                    "ttir_graph": value["ttir_graph"],
+                    "ttnn_graph": value["ttnn_graph"],
+                    "error": value["error"],
+                    "trace_dump": value["trace_dump"],
                 }
             )
 
@@ -203,11 +276,27 @@
             status = op["status"]
             raw_shlo = op["stable_hlo_graph"]
             ops = op["ops"]
-            row_data = [name, input_shapes, output_shapes, num_ops, status]
+            ttir_graph = op["ttir_graph"]
+            ttnn_graph = op["ttnn_graph"]
+            error = op["error"]
+            trace_dump = op["trace_dump"]
+            row_data = [
+                name,
+                input_shapes,
+                output_shapes,
+                num_ops,
+                status,
+                "",
+                raw_shlo,
+                ttir_graph,
+                ttnn_graph,
+                error,
+                trace_dump,
+            ]
             name = ""
             worksheet.write_row(row, 0, row_data)
             row += 1
-            row_data = ["", "", "", "", "", raw_shlo]
+            row_data = ["", "", "", "", "", "", raw_shlo, ttir_graph, ttnn_graph]
             worksheet.write_row(row, 0, row_data)
             worksheet.set_row(row, None, None, {"hidden": True})
             row += 1
@@ -217,7 +306,7 @@
         worksheet.set_row(row, None, None, {"hidden": True})
         row += 1
 
-    worksheet.autofit()
+    # worksheet.autofit()
 
     ops = list(models_per_op.keys())
     ops.sort()
diff --git a/tests/torch/test_basic.py b/tests/torch/test_basic.py
index 1bec508d..88542d18 100644
--- a/tests/torch/test_basic.py
+++ b/tests/torch/test_basic.py
@@ -346,10 +346,10 @@ def __init__(self):
 
         def forward(self, x):
             y = x + x
-            z = y + x
+            z = y + y
             z = torch.argmax(z)
             return z
 
     cc = CompilerConfig()
-    cc.compile_depth = tt_torch.tools.utils.CompileDepth.COMPILE_OP_BY_OP
+    cc.compile_depth = tt_torch.tools.utils.CompileDepth.EXECUTE_OP_BY_OP
     verify_module(Basic(), [(256, 256)], compiler_config=cc, do_assert=False)
diff --git a/tt_torch/csrc/bindings.cpp b/tt_torch/csrc/bindings.cpp
index 5f666002..6a742e72 100644
--- a/tt_torch/csrc/bindings.cpp
+++ b/tt_torch/csrc/bindings.cpp
@@ -8,10 +8,6 @@
 
 namespace py = pybind11;
 
-tt::runtime::Binary compile(std::string_view code) {
-  return tt::torch::Compile(code);
-}
-
 static tt::target::DataType torch_scalar_type_to_dt(torch::ScalarType st) {
   switch (st) {
   case torch::ScalarType::Byte:
@@ -88,8 +84,19 @@ std::vector as_vec_int64(std::vector const &vec) {
   }
   return result;
 }
+
 std::vector<at::Tensor> run(const std::vector<at::Tensor> &inputs,
-                            tt::runtime::Binary binary) {
+                            py::bytes byte_stream) {
+
+  std::string data_str = byte_stream;
+  auto binary_ptr = std::shared_ptr<void>(
+      new char[data_str.size()],
+      [](void *ptr) { delete[] static_cast<char *>(ptr); } // Custom deleter
+  );
+  // Copy data into the allocated memory
+  std::memcpy(binary_ptr.get(), data_str.data(), data_str.size());
+  tt::runtime::Binary binary = tt::runtime::Binary(binary_ptr);
+
   auto [system_desc, chip_ids] = tt::runtime::getCurrentSystemDesc();
   int dev_0 = chip_ids[0];
   auto device = tt::runtime::openDevice({dev_0});
@@ -124,14 +131,44 @@ std::vector<at::Tensor> run(const std::vector<at::Tensor> &inputs,
   return outputs;
 }
 
+std::string compile_stable_hlo_to_ttir(std::string_view code) {
+  auto ret = tt::torch::compileStableHLOToTTIR(code);
+  return ret;
+}
+
+std::tuple<py::bytes, std::string>
+compile_ttir_to_bytestream(std::string_view code) {
+  auto [binary, ttnn] = tt::torch::compileTTIRToTTNN(code);
+  auto size = ::flatbuffers::GetSizePrefixedBufferLength(
+      static_cast<const uint8_t *>(binary->get()));
+
+  std::string data_str(static_cast<const char *>(binary->get()), size);
+  delete binary;
+
+  return std::make_tuple(py::bytes(data_str), ttnn);
+}
+
+py::bytes compile_stablehlo_to_bytestream(std::string_view code) {
+  auto binary = tt::torch::Compile(code);
+  auto size = ::flatbuffers::GetSizePrefixedBufferLength(
+      static_cast<const uint8_t *>(binary->get()));
+
+  std::string data_str(static_cast<const char *>(binary->get()), size);
+  delete binary;
+  return py::bytes(data_str);
+}
+
 PYBIND11_MODULE(tt_mlir, m) {
   m.doc() = "tt_mlir";
   py::class_<tt::runtime::Binary>(m, "Binary")
       .def("getProgramInputs", &tt::runtime::Binary::getProgramInputs)
-      .def("getProgramOutputs", &tt::runtime::Binary::getProgramOutputs)
-      .def("as_json", &tt::runtime::Binary::asJson);
-  m.def("compile", &compile,
-        "A function that compiles a stableHLO model to a flatbuffer");
+      .def("getProgramOutputs", &tt::runtime::Binary::getProgramOutputs);
+  m.def("compile", &compile_stablehlo_to_bytestream,
+        "A function that compiles stableHLO to a bytestream");
+  m.def("compile_ttir_to_bytestream", &compile_ttir_to_bytestream,
+        "A function that compiles TTIR to a bytestream");
+  m.def("compile_stable_hlo_to_ttir", &compile_stable_hlo_to_ttir,
+        "A function that compiles stableHLO to TTIR");
   m.def("run", &run, "Push inputs and run binary");
   m.def("get_current_system_desc", &tt::runtime::getCurrentSystemDesc,
         "Get the current system descriptor");
diff --git a/tt_torch/csrc/tt-mlir-interface.cpp b/tt_torch/csrc/tt-mlir-interface.cpp
index c14fdc13..95fce054 100644
--- a/tt_torch/csrc/tt-mlir-interface.cpp
+++ b/tt_torch/csrc/tt-mlir-interface.cpp
@@ -46,7 +46,108 @@
 
 namespace tt::torch {
 
-tt::runtime::Binary Compile(std::string_view code) {
+std::string compileStableHLOToTTIR(std::string_view code) {
+  mlir::MLIRContext context;
+  mlir::DialectRegistry registry;
+
+  registry.insert();
+  registry.insert();
+  registry.insert();
+  registry.insert();
+
+  mlir::tt::registerAllDialects(registry);
+  mlir::stablehlo::registerAllDialects(registry);
+  mlir::func::registerAllExtensions(registry);
+  mlir::tt::registerAllExtensions(registry);
+
+  context.appendDialectRegistry(registry);
+
+  mlir::OwningOpRef<mlir::ModuleOp> mlir_module =
+      mlir::parseSourceString<mlir::ModuleOp>(
+          llvm::StringRef(code.data(), code.size()),
+          // IR may be invalid because some fields may be using DenseElements
+          // instead of DenseArray. We rectify that below and verify after.
+          mlir::ParserConfig{&context, /*verifyAfterParse=*/true});
+
+  mlir::tt::ttir::registerPasses();
+  mlir::tt::ttnn::registerPasses();
+
+  // Implicit nesting required to call the stablehlo.composite --> func.call
+  // conversion.
+  mlir::PassManager shlo_pm(mlir_module.get()->getName(),
+                            mlir::PassManager::Nesting::Implicit);
+  mlir::tt::ttir::StableHLOToTTIRPipelineOptions shlo_options;
+  shlo_options.arithDialectConversionsEnabled = true;
+  shlo_options.removeDeadValuesEnabled = true;
+  shlo_options.legalizeCompositeToCallEnabled = true;
+  mlir::tt::ttir::createStableHLOToTTIRPipeline(shlo_pm, shlo_options);
+  // Run the pass manager.
+  if (mlir::failed(shlo_pm.run(mlir_module.get()))) {
+    throw std::runtime_error(
+        "Failed to run StableHLO to TTIR compiler pass pipeline.");
+  }
+  std::string buffer;
+  llvm::raw_string_ostream os(buffer);
+  mlir_module.get()->print(os);
+  os.flush();
+
+  return buffer;
+}
+
+std::tuple<std::shared_ptr<void> *, std::string>
+compileTTIRToTTNN(std::string_view code) {
+
+  mlir::MLIRContext context;
+  mlir::DialectRegistry registry;
+
+  registry.insert();
+  registry.insert();
+  registry.insert();
+  registry.insert();
+
+  mlir::tt::registerAllDialects(registry);
+  mlir::stablehlo::registerAllDialects(registry);
+  mlir::func::registerAllExtensions(registry);
+  mlir::tt::registerAllExtensions(registry);
+
+  context.appendDialectRegistry(registry);
+
+  mlir::OwningOpRef<mlir::ModuleOp> mlir_module =
+      mlir::parseSourceString<mlir::ModuleOp>(
+          llvm::StringRef(code.data(), code.size()),
+          // IR may be invalid because some fields may be using DenseElements
+          // instead of DenseArray. We rectify that below and verify after.
+          mlir::ParserConfig{&context, /*verifyAfterParse=*/true});
+
+  mlir::tt::ttir::registerPasses();
+  mlir::tt::ttnn::registerPasses();
+
+  mlir::PassManager pm(mlir_module.get()->getName());
+  mlir::tt::ttnn::TTIRToTTNNBackendPipelineOptions options;
+  mlir::tt::ttnn::createTTIRToTTNNBackendPipeline(pm, options);
+
+  // Run the pass manager.
+  if (mlir::failed(pm.run(mlir_module.get()))) {
+    throw std::runtime_error(
+        "Failed to run TTIR TO TTNN compiler pass pipeline.");
+  }
+
+  std::shared_ptr<void> *binary = new std::shared_ptr<void>();
+  *binary = mlir::tt::ttnn::ttnnToFlatbuffer(mlir_module.get());
+
+  if (binary->get() == nullptr) {
+    throw std::runtime_error("Failed to generate flatbuffer binary.");
+  }
+
+  std::string buffer;
+  llvm::raw_string_ostream os(buffer);
+  mlir_module->print(os);
+  os.flush();
+
+  return std::make_tuple(binary, buffer);
+}
+
+std::shared_ptr<void> *Compile(std::string_view code) {
 
   mlir::MLIRContext context;
   mlir::DialectRegistry registry;
@@ -84,7 +185,8 @@ tt::runtime::Binary Compile(std::string_view code) {
   mlir::tt::ttir::createStableHLOToTTIRPipeline(shlo_pm, shlo_options);
   // Run the pass manager.
   if (mlir::failed(shlo_pm.run(mlir_module.get()))) {
-    throw std::runtime_error("Failed to run MLIR compiler pass pipeline.");
+    throw std::runtime_error(
+        "Failed to run StableHLO to TTIR compiler pass pipeline.");
   }
 
   mlir::PassManager pm(mlir_module.get()->getName());
@@ -93,17 +195,17 @@
   // Run the pass manager.
   if (mlir::failed(pm.run(mlir_module.get()))) {
-    throw std::runtime_error("Failed to run MLIR compiler pass pipeline.");
+    throw std::runtime_error(
+        "Failed to run TTIR TO TTNN compiler pass pipeline.");
   }
 
-  std::shared_ptr<void> binary_ptr =
-      mlir::tt::ttnn::ttnnToFlatbuffer(mlir_module.get());
+  std::shared_ptr<void> *binary = new std::shared_ptr<void>();
+  *binary = mlir::tt::ttnn::ttnnToFlatbuffer(mlir_module.get());
 
-  if (binary_ptr == nullptr) {
+  if (binary->get() == nullptr) {
     throw std::runtime_error("Failed to generate flatbuffer binary.");
   }
 
-  tt::runtime::Binary binary(binary_ptr);
   return binary;
 }
diff --git a/tt_torch/csrc/tt-mlir-interface.hpp b/tt_torch/csrc/tt-mlir-interface.hpp
index 5614681b..239d01e5 100644
--- a/tt_torch/csrc/tt-mlir-interface.hpp
+++ b/tt_torch/csrc/tt-mlir-interface.hpp
@@ -9,5 +9,8 @@
 #include "tt/runtime/runtime.h"
 
 namespace tt::torch {
-tt::runtime::Binary Compile(std::string_view code);
+std::shared_ptr<void> *Compile(std::string_view code);
+std::string compileStableHLOToTTIR(std::string_view code);
+std::tuple<std::shared_ptr<void> *, std::string>
+compileTTIRToTTNN(std::string_view code);
 } // namespace tt::torch
diff --git a/tt_torch/dynamo/backend.py b/tt_torch/dynamo/backend.py
index 5fdbd5db..cca333b0 100644
--- a/tt_torch/dynamo/backend.py
+++ b/tt_torch/dynamo/backend.py
@@ -105,9 +105,23 @@ def compile_process(receiver, sender):
     obj = receiver.get()
     faulthandler.disable()
     asm = obj["asm"]
-    binary = tt_mlir.compile(asm)
-    result = {"binary": binary.as_json()}
-    sender.put({"binary": result})
+    ttir = tt_mlir.compile_stable_hlo_to_ttir(asm)
+    sender.put({"ttir": ttir})
+    time.sleep(0.1)
+    binary, ttnn = tt_mlir.compile_ttir_to_bytestream(ttir)
+    sender.put({"binary": binary, "ttnn": ttnn})
+    time.sleep(0.1)
+    sys.exit(0)
+
+
+def execute_process(receiver, sender):
+    obj = receiver.get()
+    faulthandler.disable()
+    binary = obj["binary"]
+    inputs = obj["inputs"]
+    outputs = tt_mlir.run(inputs, binary)
+    sender.put({"outputs": outputs})
+    time.sleep(0.1)
     sys.exit(0)
 
 
@@ -173,13 +187,14 @@ def compile_op(self, node, *inputs, **kwargs):
                 placeholders.append(inps)
             else:
                 placeholders.append(inp)
-        placeholders = tuple(placeholders)
 
         if len(placeholders) != len(node.args):
-            raise ValueError(
-                f"Placeholders and args must be the same length: {len(placeholders)} != {len(node.args)}"
-            )
+            # are any of the args duplicates? If so, we need to duplicate the placeholders
+            for idx, arg in enumerate(node.args):
+                if arg in node.args[idx + 1 :]:
+                    placeholders.append(placeholders[idx])
 
+        placeholders = tuple(placeholders)
         for placeholder, arg in zip(placeholders, node.args):
             if isinstance(placeholder, torch.fx.node.Node):
                 placeholder.meta["tensor_meta"] = arg.meta["tensor_meta"]
@@ -232,32 +247,57 @@
         process.start()
         sender.put(obj)
         start = time.time()
+        binary = None
+        while True:
+            try:
+                result = receiver.get_nowait()
+                if "ttir" in result:
+                    op.compilation_status = OpCompilationStatus.CONVERTED_TO_TTIR
+                    op.add_ttir_graph(result["ttir"])
+                if "binary" in result:
+                    binary = result["binary"]
+                    op.binary = binary
+                    op.add_ttnn_graph(result["ttnn"])
+                    op.compilation_status = OpCompilationStatus.CONVERTED_TO_TTNN
+                    break
+            except mp.queues.Empty:
+                pass
+            if time.time() - start > self.compiler_config.single_op_timeout:
+                process.terminate()
+                break
+            if not process.is_alive():
+                break
+            time.sleep(0.01)
+        process.join()
+        return binary, op
+
+    def run_op(self, binary, *inputs):
+        sender = mp.Queue()
+        receiver = mp.Queue()
+        obj = {"binary": binary, "inputs": inputs}
+
+        process = mp.Process(target=execute_process, args=(sender, receiver))
+        process.start()
+        sender.put(obj)
         result = {}
-        result["binary"] = ""
+        start = time.time()
         while True:
             if not process.is_alive():
                 break
             try:
                 result = receiver.get_nowait()
-                op.compilation_status = OpCompilationStatus.CONVERTED_TO_TTNN_IR
                 break
             except mp.queues.Empty:
                 pass
             if time.time() - start > self.compiler_config.single_op_timeout:
                 process.terminate()
+                print("Timeout")
                 break
-            time.sleep(0.01)
+            time.sleep(0.05)
         process.join()
-        return result["binary"], op
-
-    def run_op(self, binary, *inputs):
-        pid = os.fork()
-        if pid == 0:
-            outputs = tt_mlir.run(inputs, binary)
-            if len(outputs) == 1:
-                outputs = outputs[0]
-        else:
-            pid, status = os.wait()
+        outputs = result["outputs"]
+        if len(outputs) == 1:
+            outputs = outputs[0]
         return outputs
 
     def run_gm_op_by_op(self, *inputs):
diff --git a/tt_torch/tools/utils.py b/tt_torch/tools/utils.py
index e5d08088..35d556b6 100644
--- a/tt_torch/tools/utils.py
+++ b/tt_torch/tools/utils.py
@@ -24,8 +24,9 @@ class OpCompilationStatus(IntEnum):
     CONVERTED_TO_TORCH_IR = 2
     CONVERTED_TO_TORCH_BACKEND_IR = 3
     CONVERTED_TO_STABLE_HLO = 4
-    CONVERTED_TO_TTNN_IR = 5
-    EXECUTED = 6
+    CONVERTED_TO_TTIR = 5
+    CONVERTED_TO_TTNN = 6
+    EXECUTED = 7
 
 
 class Op:
@@ -37,6 +38,8 @@ def __init__(self, torch_name, input_shapes):
 
         self.stable_hlo_graph = ""
         self.stable_hlo_ops = []
+        self.ttir_graph = ""
+        self.ttnn_graph = ""
 
         self.compilation_status = OpCompilationStatus.NOT_STARTED
         self.parsed_stable_hlo_ops = False
@@ -56,6 +59,8 @@ def to_dict(self):
             "parsed_stable_hlo_ops": self.parsed_stable_hlo_ops,
             "stable_hlo_graph": self.stable_hlo_graph,
             "stable_hlo_ops": self.stable_hlo_ops,
+            "ttir_graph": self.ttir_graph,
+            "ttnn_graph": self.ttnn_graph,
         }
 
     def unique_key(self):
@@ -76,6 +81,12 @@ def add_stable_hlo_graph(self, stable_hlo_graph: str):
         except:
             self.parsed_stable_hlo_ops = False
 
+    def add_ttir_graph(self, ttir_graph: str):
+        self.ttir_graph = ttir_graph
+
+    def add_ttnn_graph(self, ttnn_graph: str):
+        self.ttnn_graph = ttnn_graph
+
 
 class CompilerConfig:
     def __init__(self):
diff --git a/tt_torch/tools/verify.py b/tt_torch/tools/verify.py
index fc275143..0b551ef8 100644
--- a/tt_torch/tools/verify.py
+++ b/tt_torch/tools/verify.py
@@ -34,9 +34,8 @@ def _verify_torch_module(
     golden = mod(*inputs)
 
     atol = torch.max(torch.abs(golden - ret)).item()
-    assert (
-        do_assert and atol
-    ) <= required_atol, f"ATOL too high: {atol} vs {required_atol}"
+    if do_assert:
+        assert atol <= required_atol, f"ATOL too high: {atol} vs {required_atol}"
 
     if np.prod(golden.shape) == 1:
         return
@@ -46,7 +45,8 @@
             np.ma.masked_invalid(torch.squeeze(golden).detach().numpy()).flatten(),
         )
     )
-    assert (do_assert and pcc) >= required_pcc, f"PCC too low: {pcc} vs {required_pcc}"
+    if do_assert:
+        assert pcc >= required_pcc, f"PCC too low: {pcc} vs {required_pcc}"
 
 
 def _verify_onnx_module(