diff --git a/pybuda/csrc/passes/lower_to_mlir.cpp b/pybuda/csrc/passes/lower_to_mlir.cpp index 8bf6a4266..9611c809d 100644 --- a/pybuda/csrc/passes/lower_to_mlir.cpp +++ b/pybuda/csrc/passes/lower_to_mlir.cpp @@ -174,6 +174,15 @@ class MLIRGenerator auto funcType = builder_.getType(mlir::TypeRange(argument_types), mlir::TypeRange(returns)); auto func = builder_.create(graphModule_.getLoc(), "main", funcType); + // Set the function argument names + for(size_t i = 0; i < argument_nodes.size(); i++) + { + graphlib::Node* argument_node = argument_nodes[i]; + llvm::SmallVector named_attributes; + named_attributes.push_back(builder_.getNamedAttr("ttir.name", builder_.getStringAttr(argument_node->name()))); + func.setArgAttrs(i, named_attributes); + } + // Start the body of the function by creating an entry block. mlir::Block *entryBlock = func.addEntryBlock();