diff --git a/tt_torch/csrc/tt-mlir-interface.cpp b/tt_torch/csrc/tt-mlir-interface.cpp index 44d01ae4..89d2fc5c 100644 --- a/tt_torch/csrc/tt-mlir-interface.cpp +++ b/tt_torch/csrc/tt-mlir-interface.cpp @@ -75,6 +75,11 @@ std::string compileStableHLOToTTIR(std::string_view code) { // conversion. mlir::PassManager shlo_pm(mlir_module.get()->getName(), mlir::PassManager::Nesting::Implicit); + const char *enable_printing = std::getenv("TT_TORCH_ENABLE_IR_PRINTING"); + if (enable_printing && std::string(enable_printing) == "1") { + shlo_pm.getContext()->disableMultithreading(); + shlo_pm.enableIRPrinting(); + } mlir::tt::ttir::StableHLOToTTIRPipelineOptions shlo_options; shlo_options.arithDialectConversionsEnabled = true; shlo_options.removeDeadValuesEnabled = true; @@ -122,6 +127,11 @@ compileTTIRToTTNN(std::string_view code) { mlir::tt::ttnn::registerPasses(); mlir::PassManager pm(mlir_module.get()->getName()); + const char *enable_printing = std::getenv("TT_TORCH_ENABLE_IR_PRINTING"); + if (enable_printing && std::string(enable_printing) == "1") { + pm.getContext()->disableMultithreading(); + pm.enableIRPrinting(); + } mlir::tt::ttnn::TTIRToTTNNBackendPipelineOptions options; mlir::tt::ttnn::createTTIRToTTNNBackendPipeline(pm, options); @@ -147,65 +157,8 @@ compileTTIRToTTNN(std::string_view code) { } std::shared_ptr *Compile(std::string_view code) { - - mlir::MLIRContext context; - mlir::DialectRegistry registry; - - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - - mlir::tt::registerAllDialects(registry); - mlir::stablehlo::registerAllDialects(registry); - mlir::func::registerAllExtensions(registry); - mlir::tt::registerAllExtensions(registry); - - context.appendDialectRegistry(registry); - - mlir::OwningOpRef mlir_module = - mlir::parseSourceString( - llvm::StringRef(code.data(), code.size()), - // IR may be invalid because some fields may be using DenseElements - // instead of DenseArray. We rectify that below and verify after. - mlir::ParserConfig{&context, /*verifyAfterParse=*/true}); - - mlir::tt::ttir::registerPasses(); - mlir::tt::ttnn::registerPasses(); - - // Implicit nesting required to call the stablehlo.composite --> func.call - // conversion. - mlir::PassManager shlo_pm(mlir_module.get()->getName(), - mlir::PassManager::Nesting::Implicit); - mlir::tt::ttir::StableHLOToTTIRPipelineOptions shlo_options; - shlo_options.arithDialectConversionsEnabled = true; - shlo_options.removeDeadValuesEnabled = true; - shlo_options.legalizeCompositeToCallEnabled = true; - mlir::tt::ttir::createStableHLOToTTIRPipeline(shlo_pm, shlo_options); - // Run the pass manager. - if (mlir::failed(shlo_pm.run(mlir_module.get()))) { - throw std::runtime_error( - "Failed to run StableHLO to TTIR compiler pass pipeline."); - } - - mlir::PassManager pm(mlir_module.get()->getName()); - mlir::tt::ttnn::TTIRToTTNNBackendPipelineOptions options; - mlir::tt::ttnn::createTTIRToTTNNBackendPipeline(pm, options); - - // Run the pass manager. - if (mlir::failed(pm.run(mlir_module.get()))) { - throw std::runtime_error( - "Failed to run TTIR TO TTNN compiler pass pipeline."); - } - - std::shared_ptr *binary = new std::shared_ptr(); - *binary = mlir::tt::ttnn::ttnnToFlatbuffer(mlir_module.get()); - - if (binary == nullptr) { - throw std::runtime_error("Failed to generate flatbuffer binary."); - } - - return binary; + std::string ttir = compileStableHLOToTTIR(code); + return std::get<0>(compileTTIRToTTNN(ttir)); } } // namespace tt::torch