Skip to content

Commit

Permalink
Dump IR when TT_TORCH_ENABLE_IR_PRINTING is set
Browse files Browse the repository at this point in the history
  • Loading branch information
AleksKnezevic committed Jan 3, 2025
1 parent 50eeae1 commit 2566b58
Showing 1 changed file with 12 additions and 59 deletions.
71 changes: 12 additions & 59 deletions tt_torch/csrc/tt-mlir-interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ std::string compileStableHLOToTTIR(std::string_view code) {
// conversion.
mlir::PassManager shlo_pm(mlir_module.get()->getName(),
mlir::PassManager::Nesting::Implicit);
const char *enable_printing = std::getenv("TT_TORCH_ENABLE_IR_PRINTING");
if (enable_printing && std::string(enable_printing) == "1") {
shlo_pm.getContext()->disableMultithreading();
shlo_pm.enableIRPrinting();
}
mlir::tt::ttir::StableHLOToTTIRPipelineOptions shlo_options;
shlo_options.arithDialectConversionsEnabled = true;
shlo_options.removeDeadValuesEnabled = true;
Expand Down Expand Up @@ -122,6 +127,11 @@ compileTTIRToTTNN(std::string_view code) {
mlir::tt::ttnn::registerPasses();

mlir::PassManager pm(mlir_module.get()->getName());
const char *enable_printing = std::getenv("TT_TORCH_ENABLE_IR_PRINTING");
if (enable_printing && std::string(enable_printing) == "1") {
pm.getContext()->disableMultithreading();
pm.enableIRPrinting();
}
mlir::tt::ttnn::TTIRToTTNNBackendPipelineOptions options;
mlir::tt::ttnn::createTTIRToTTNNBackendPipeline(pm, options);

Expand All @@ -147,65 +157,8 @@ compileTTIRToTTNN(std::string_view code) {
}

std::shared_ptr<void> *Compile(std::string_view code) {

mlir::MLIRContext context;
mlir::DialectRegistry registry;

registry.insert<mlir::arith::ArithDialect>();
registry.insert<mlir::func::FuncDialect>();
registry.insert<mlir::ml_program::MLProgramDialect>();
registry.insert<mlir::shape::ShapeDialect>();

mlir::tt::registerAllDialects(registry);
mlir::stablehlo::registerAllDialects(registry);
mlir::func::registerAllExtensions(registry);
mlir::tt::registerAllExtensions(registry);

context.appendDialectRegistry(registry);

mlir::OwningOpRef<mlir::ModuleOp> mlir_module =
mlir::parseSourceString<mlir::ModuleOp>(
llvm::StringRef(code.data(), code.size()),
// IR may be invalid because some fields may be using DenseElements
// instead of DenseArray. We rectify that below and verify after.
mlir::ParserConfig{&context, /*verifyAfterParse=*/true});

mlir::tt::ttir::registerPasses();
mlir::tt::ttnn::registerPasses();

// Implicit nesting required to call the stablehlo.composite --> func.call
// conversion.
mlir::PassManager shlo_pm(mlir_module.get()->getName(),
mlir::PassManager::Nesting::Implicit);
mlir::tt::ttir::StableHLOToTTIRPipelineOptions shlo_options;
shlo_options.arithDialectConversionsEnabled = true;
shlo_options.removeDeadValuesEnabled = true;
shlo_options.legalizeCompositeToCallEnabled = true;
mlir::tt::ttir::createStableHLOToTTIRPipeline(shlo_pm, shlo_options);
// Run the pass manager.
if (mlir::failed(shlo_pm.run(mlir_module.get()))) {
throw std::runtime_error(
"Failed to run StableHLO to TTIR compiler pass pipeline.");
}

mlir::PassManager pm(mlir_module.get()->getName());
mlir::tt::ttnn::TTIRToTTNNBackendPipelineOptions options;
mlir::tt::ttnn::createTTIRToTTNNBackendPipeline(pm, options);

// Run the pass manager.
if (mlir::failed(pm.run(mlir_module.get()))) {
throw std::runtime_error(
"Failed to run TTIR TO TTNN compiler pass pipeline.");
}

std::shared_ptr<void> *binary = new std::shared_ptr<void>();
*binary = mlir::tt::ttnn::ttnnToFlatbuffer(mlir_module.get());

if (binary == nullptr) {
throw std::runtime_error("Failed to generate flatbuffer binary.");
}

return binary;
std::string ttir = compileStableHLOToTTIR(code);
return std::get<0>(compileTTIRToTTNN(ttir));
}

} // namespace tt::torch

0 comments on commit 2566b58

Please sign in to comment.