diff --git a/pybuda/csrc/passes/lower_to_mlir.cpp b/pybuda/csrc/passes/lower_to_mlir.cpp index 5317ebeaa..10bb809ec 100644 --- a/pybuda/csrc/passes/lower_to_mlir.cpp +++ b/pybuda/csrc/passes/lower_to_mlir.cpp @@ -4,14 +4,22 @@ #include "lower_to_mlir.hpp" // Standard headers +#include #include #include -// PyBuda headers +// TTForge headers #include "graph_lib/graph.hpp" #include "graph_lib/node.hpp" #include "graph_lib/utils.hpp" #include "graph_lib/node_types.hpp" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" #include "utils/logger.hpp" // MLIR headers @@ -27,9 +35,7 @@ #pragma clang diagnostic pop // TTMLIR headers -#include "ttmlir/Dialect/TT/IR/TT.h" #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" -#include "ttmlir/Dialect/TTIR/IR/TTIR.h" #include "ttmlir/Dialect/TTIR/IR/TTIROps.h" #include "ttmlir/Dialect/TTIR/IR/TTIROps.h" @@ -37,24 +43,29 @@ namespace { using namespace tt; /** - * @brief Implementation of TT-MLIR emission from the PyBuda graph. + * @brief Implementation of TT-MLIR emission from the TTForge graph. */ class MLIRGenerator { public: /// Construct a new MLIRGenerator object. - MLIRGenerator(mlir::MLIRContext &context) : builder_(&context) {} + MLIRGenerator(mlir::MLIRContext &context) : builder_(&context) + { + tt::log_info("MLIRGenerator"); + init_lowering_handler_map(); + } - /// Public API: Convert the PyBuda graph into an MLIR module operation for TTIR. + /// Public API: Convert the TTForge graph into an MLIR module operation for TTIR. mlir::ModuleOp emit_mlir(graphlib::Graph *graph) { - graphModule_ = mlir::ModuleOp::create(get_module_location(graph), "pybuda_graph"); + graphModule_ = mlir::ModuleOp::create(get_module_location(graph), "tt-forge-graph"); graphModule_->setAttr(mlir::tt::SystemDescAttr::name, mlir::tt::SystemDescAttr::getDefault(builder_.getContext())); - builder_.setInsertionPointToStart(&graphModule_.getBodyRegion().front()); - emit_mlir_function(graph); + mlir::OpPrintingFlags printFlags; + printFlags.enableDebugInfo(); + graphModule_.print(llvm::outs(), printFlags); /// Verify the module after we have finished constructing it, this will check /// the structural properties of the IR and invoke any specific verifiers we @@ -62,31 +73,33 @@ class MLIRGenerator if (failed(mlir::verify(graphModule_))) { graphModule_.emitError("module verification failed."); - return nullptr; + throw std::runtime_error("Generated MLIR module failed verification."); } - - mlir::OpPrintingFlags printFlags; - printFlags.enableDebugInfo(); - graphModule_.print(llvm::outs(), printFlags); - return graphModule_; } - private: - /// A "module" matches a PyBuda graph: containing a single function to exectue. + /// A "module" matches a TTForge graph: containing a single function to exectue. mlir::ModuleOp graphModule_; + /// The builder is a helper class to create IR. The builder /// is stateful, in particular it keeps an "insertion point": this is where /// the next operations will be introduced. mlir::OpBuilder builder_; - // The symbol table maintains a mapping between the names of pybuda nodes and their corresponding values in the current scope. - // Initially, the function arguments (model activations) are added to the symbol table. - // After evaluating each pybuda op node, the declare function adds a new entry to the symbol table for future reference. + + /// The symbol table maintains a mapping between the names of ttforge nodes and their corresponding values in the current scope. + /// Initially, the function arguments (model activations) are added to the symbol table. + /// After evaluating each ttforge op node, the declare function adds a new entry to the symbol table for future reference. std::map> symbolTable_; + /// Handler type for lowering ttforge operations to MLIR. + using HandlerType = mlir::Value (MLIRGenerator::*)(tt::graphlib::Graph *, tt::graphlib::OpNode *); + + /// Map of lowering handlers for ttforge operations to MLIR. + std::map lowering_handler_map; + /// Declares a variable in the current (only) scope. - /// The declaration corresponds to exactly one operation node in the PyBuda graph. + /// The declaration corresponds to exactly one operation node in the TTForge graph. void declare(graphlib::Node *node, mlir::Value value) { if (symbolTable_.find(node->name()) != symbolTable_.end()) { @@ -97,12 +110,11 @@ class MLIRGenerator } /// Emit a new function in MLIR. - /// A function represents a set of PyBuda operations that are executed to produce output results. - /// This function will generate the MLIR code for each PyBuda operation in the graph and emit the return operation for the function. + /// A function represents a set of TTForge operations that are executed to produce output results. + /// This function will generate the MLIR code for each TTForge operation in the graph and emit the return operation for the function. mlir::func::FuncOp emit_mlir_function(tt::graphlib::Graph *graph) { // Assemble the function arguments (inputs) llvm::SmallVector arguments; - for (auto *input : graph->nodes_by_type(tt::graphlib::kInput)) { arguments.push_back(get_node_type(input)); @@ -135,90 +147,148 @@ class MLIRGenerator // function. builder_.setInsertionPointToStart(entryBlock); - // Walk the graph in topological order and generate MLIR for each PyBuda operation + // Walk the graph in topological order and generate MLIR for each TTForge operation // node in the graph. For each new operation result, declare it in the symbol table. for (auto *node : graphlib::topological_sort(*graph)) { - // Skip if the node isn't PyBuda operation + // Skip if the node isn't TTForge operation if (node->node_type() != tt::graphlib::NodeType::kPyOp) { continue; } log_trace(LogMLIRGenerator, "Emitting MLIR for node {}", node->name()); - tt::graphlib::OpNode *op_node = dynamic_cast(node); - // Emit MLIR for the PyBuda operation node - mlir::Value opValue = emit_mlir_pybuda_operation(graph, op_node); - log_trace(LogMLIRGenerator, "Generated MLIR for node {} with value {}", node->name(), covnert_mlir_value_to_string(opValue)); + // Emit MLIR for the TTForge operation node + mlir::Value opValue = emit_mlir_tt_forge_operation(graph, op_node); + log_trace(LogMLIRGenerator, "Generated MLIR for node {} with value {}", + node->name(), covnert_mlir_value_to_string(opValue)); } - emit_mlir_return_op(graph); - return func; } - /// Emit an MLIR operation for a PyBuda node. - mlir::Value emit_mlir_pybuda_operation(tt::graphlib::Graph *graph, tt::graphlib::OpNode *op_node) + /// Emit an MLIR operation for a TTForge node. + mlir::Value emit_mlir_tt_forge_operation(tt::graphlib::Graph *graph, tt::graphlib::OpNode *op_node) { - mlir::Value opResult; - if (tt::graphlib::is_eltwise(op_node)) + auto handler = lowering_handler_map.find(op_node->op_name()); + // There is no known lowering handler for this operation. Report error. + if (handler == lowering_handler_map.end()) { - opResult = emit_mlir_pybuda_elementwise_op(graph, op_node); + log_error("Unsupported operation for lowering from TTForge to TTIR: {}", op_node->op_name()); + throw std::runtime_error("Unsupported operation for lowering from TTForge to TTIR: " + op_node->op_name()); } - // This is the first time we are visiting this PyBuda node during the traversal of the graph using topological sort. + // Call the handler to lower the TTForge op to MLIR + mlir::Value opResult = (this->*(handler->second))(graph, op_node); + + // This is the first time we are visiting this TTForge node during the traversal of the graph using topological sort. // Therefore, we need to declare the result of this operation so that we can refer to it later if needed. declare(op_node, opResult); - return opResult; } - /// Emit an MLIR operation for a PyBuda elementwise operation. - mlir::Value emit_mlir_pybuda_elementwise_op(tt::graphlib::Graph *graph, tt::graphlib::OpNode *op_node) + /// Emit an MLIR operation for a ttforge elementwise operation. + template + mlir::Value emit_mlir_ttforge_op(tt::graphlib::Graph *graph, tt::graphlib::OpNode *op_node) { // Evaluate operation return type - llvm::SmallVector return_type_vector; - return_type_vector.push_back(get_node_type(op_node)); - mlir::TypeRange return_types(return_type_vector); + llvm::SmallVector return_types = get_mlir_type_range(op_node); - // Creating input value range for the operation - // Since we are traversing the PyBuda graph using topological sort, - // all operands must be present in the symbol table. - // We iterate over the operands of the current node and retrieve their corresponding values from the symbol table. - llvm::SmallVector input_vector; - for (auto operand : graph->operands(op_node)) + // Evaluate operation operands: inputs and outputs per DPS + llvm::SmallVector operands = get_mlir_operands(graph, op_node); + + // Evaluate opeartion attributes + llvm::SmallVector attributes; + ::llvm::ArrayRef<::llvm::StringRef> operation_attributes = TTIROp::getAttributeNames(); + for(auto attribute_name: operation_attributes) { - input_vector.push_back(symbolTable_.at(operand->name()).first); + if(attribute_name.equals("operand_constraints")) + { + // Create operation constraint attributes + mlir::NamedAttribute operand_constraints_attribute = builder_.getNamedAttr( + "operand_constraints", + builder_.getArrayAttr(get_mlir_operand_constraint_attributes(graph, op_node))); + attributes.push_back(operand_constraints_attribute); + } + else if(attribute_name.equals(mlir::OpTrait::AttrSizedOperandSegments::getOperandSegmentSizeAttr())) + { + // Create operation segment sizes attributes + mlir::NamedAttribute operand_segment_sizes_attribute = builder_.getNamedAttr( + mlir::OpTrait::AttrSizedOperandSegments::getOperandSegmentSizeAttr(), + builder_.getDenseI32ArrayAttr({ + static_cast(graph->operands(op_node).size()), + static_cast(1) + })); + attributes.push_back(operand_segment_sizes_attribute); + } } - mlir::ValueRange inputs(input_vector); + // Workaround for now, need to figure out how to handle this properly + if(op_node->op_name() == "softmax") + { + log_info("Softmax"); + int32_t dimension = std::get(op_node->op_attrs()[0]); + mlir::NamedAttribute dimension_attribute = builder_.getNamedAttr( + "dimension", + builder_.getSI32IntegerAttr(dimension)); + attributes.push_back(dimension_attribute); + } - // Creating output value range for the operation by creating an empty tensor to hold the output value - llvm::SmallVector output_vector; - output_vector.push_back(emit_mlir_empty_tensor(graph, op_node)); - mlir::ValueRange outputs = mlir::ValueRange(output_vector); + auto op = builder_.create( + get_tt_forge_operation_location(graph, op_node), + mlir::TypeRange(return_types), + mlir::ValueRange(operands), + attributes); - // Create an array attribute with three elements, each representing an operand constraint of type "AnyDevice" - auto atributes = builder_.getArrayAttr(llvm::SmallVector( - 3, builder_.getAttr( - mlir::tt::OperandConstraint::AnyDevice))); + return op.getOperation()->getResult(0); + } + + // Get the TT-MLIR type for a TTForge operation. + llvm::SmallVector get_mlir_type_range(tt::graphlib::OpNode *op_node) + { + llvm::SmallVector return_type_vector; + return_type_vector.push_back(get_node_type(op_node)); + return return_type_vector; + } - if (op_node->op_name() == "add") + // All operands must be present in the symbol table, since we are + // traversing the TTForge graph using topological sort. We iterate over the + // operands of the current node and retrieve their corresponding values + // from the symbol table. + llvm::SmallVector get_mlir_operands( + tt::graphlib::Graph *graph, + tt::graphlib::OpNode *op_node) + { + llvm::SmallVector operands; + for (auto operand : graph->operands(op_node)) { - auto opResult = builder_.create(get_pybuda_operation_location(graph, op_node), return_types, inputs, outputs, atributes); - return opResult.getResult(0); + operands.push_back(symbolTable_.at(operand->name()).first); } - else if (op_node->op_name() == "multiply") + operands.push_back(emit_mlir_empty_tensor(graph, op_node)); + return operands; + } + + // Get the MLIR operand constraint attributes for a TTForge operation. + llvm::SmallVector get_mlir_operand_constraint_attributes( + tt::graphlib::Graph *graph, + tt::graphlib::OpNode *op_node) + { + llvm::SmallVector operand_constraints; + for ([[maybe_unused]] auto& operand: graph->operands(op_node)) { - auto opResult = builder_.create(get_pybuda_operation_location(graph, op_node), return_types, inputs, outputs, atributes); - return opResult.getResult(0); + mlir::Attribute operand_constraint_attribute = builder_.getAttr( + mlir::tt::OperandConstraint::AnyDevice); + operand_constraints.push_back(operand_constraint_attribute); } - else { - log_error("Unsupported operation for lowering from PyBuda to TTIR: {}", op_node->op_name()); - throw std::runtime_error("Unsupported operation for lowering from PyBuda to TTIR"); + for ([[maybe_unused]] auto& user: graph->data_users(op_node)) + { + mlir::Attribute operand_constraint_attribute = builder_.getAttr( + mlir::tt::OperandConstraint::AnyDevice); + operand_constraints.push_back(operand_constraint_attribute); } + return operand_constraints; } /// Emit an MLIR operation for an empty tensor. @@ -230,7 +300,10 @@ class MLIRGenerator shape_vec.push_back((int64_t)dim); } - return builder_.create(get_pybuda_operation_location(graph, node), shape_vec, get_float_type(node)); + return builder_.create( + get_tt_forge_operation_location(graph, node), + shape_vec, + get_float_type(node)); } /// Emit the return operation for the function. @@ -245,10 +318,12 @@ class MLIRGenerator returnValues.push_back(outputValue); } - builder_.create(builder_.getUnknownLoc(), mlir::ValueRange(returnValues)); + builder_.create( + builder_.getUnknownLoc(), + mlir::ValueRange(returnValues)); } - /// Get the MLIR float type type for a PyBuda node. + /// Get the MLIR float type type for a TTForge node. mlir::FloatType get_float_type(graphlib::Node *node) { switch (node->output_df()) @@ -262,12 +337,11 @@ class MLIRGenerator default: TT_ASSERT(false); } - // TODO add all supported types in switch return builder_.getF32Type(); } - /// Get the MLIR type for a PyBuda node. + /// Get the MLIR type for a TTForge node. mlir::Type get_node_type(graphlib::Node *node) { std::vector shape_vec; @@ -281,19 +355,22 @@ class MLIRGenerator /// Get the location for a module. mlir::Location get_module_location(tt::graphlib::Graph *graph) { - return mlir::FileLineColLoc::get(builder_.getContext(), graph->name(), graph->id(), 0); + return mlir::FileLineColLoc::get( + builder_.getContext(), graph->name(), graph->id(), 0); } /// Get the simple location for a node in a format "graph_name", (graph_id), (node_id) mlir::Location get_node_location(tt::graphlib::Graph *graph, tt::graphlib::Node *node) { - return mlir::FileLineColLoc::get(builder_.getContext(), graph->name(), graph->id(), node->id()); + return mlir::FileLineColLoc::get( + builder_.getContext(), graph->name(), graph->id(), node->id()); } - /// Get the location for a PyBuda operation. The location is a combination of the operation name and the node location. - mlir::Location get_pybuda_operation_location(tt::graphlib::Graph *graph, tt::graphlib::Node *node) + /// Get the location for a TTForge operation. The location is a combination of the operation name and the node location. + mlir::Location get_tt_forge_operation_location(tt::graphlib::Graph *graph, tt::graphlib::Node *node) { - return mlir::NameLoc::get(builder_.getStringAttr(node->name()), get_node_location(graph, node)); + return mlir::NameLoc::get( + builder_.getStringAttr(node->name()), get_node_location(graph, node)); } /// Convert an MLIR value to a string. @@ -301,18 +378,26 @@ class MLIRGenerator { std::string string_value; llvm::raw_string_ostream os(string_value); - os << value; - os.flush(); return string_value; } + + /// Initialize lowering handler map + void init_lowering_handler_map() + { + lowering_handler_map["add"] = &MLIRGenerator::emit_mlir_ttforge_op; + lowering_handler_map["multiply"] = &MLIRGenerator::emit_mlir_ttforge_op; + lowering_handler_map["subtract"] = &MLIRGenerator::emit_mlir_ttforge_op; + lowering_handler_map["relu"] = &MLIRGenerator::emit_mlir_ttforge_op; + lowering_handler_map["matmul"] = &MLIRGenerator::emit_mlir_ttforge_op; + lowering_handler_map["softmax"] = &MLIRGenerator::emit_mlir_ttforge_op; + } }; } - namespace tt::passes { - /// Public API for generating MLIR from the PyBuda graph. + /// Public API for generating MLIR from the TTForge graph. mlir::OwningOpRef lower_to_mlir(graphlib::Graph * graph, mlir::MLIRContext& context) { return MLIRGenerator(context).emit_mlir(graph); diff --git a/pybuda/csrc/passes/lower_to_mlir.hpp b/pybuda/csrc/passes/lower_to_mlir.hpp index c6ec163a3..25f7fe2ad 100644 --- a/pybuda/csrc/passes/lower_to_mlir.hpp +++ b/pybuda/csrc/passes/lower_to_mlir.hpp @@ -15,7 +15,7 @@ namespace mlir { namespace tt::passes { - // Public API for generating MLIR from the PyBuda graph. + // Public API for generating MLIR from the TT-Forge graph. mlir::OwningOpRef lower_to_mlir(tt::graphlib::Graph * graph, mlir::MLIRContext& context); } // namespace tt:passes