Skip to content

Commit

Permalink
[lower_to_mlir] Use named_attrs from tt-forge ops instead of hardcodi…
Browse files Browse the repository at this point in the history
…ng attributes for each op, while lowering to mlir.
  • Loading branch information
dgolubovicTT committed Aug 7, 2024
1 parent eda3539 commit 5a9d443
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 12 deletions.
35 changes: 25 additions & 10 deletions pybuda/csrc/passes/lower_to_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,23 @@ using namespace tt;
/**
* @brief Implementation of TT-MLIR emission from the TTForge graph.
*/
mlir::Attribute convert_to_mlir_attribute(const tt::BudaOpAttr& value, mlir::OpBuilder& builder) {
return std::visit([&builder](auto&& arg) -> mlir::Attribute {
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, std::string>) {
return builder.getStringAttr(arg);
} else if constexpr (std::is_same_v<T, bool>) {
return builder.getBoolAttr(arg);
} else if constexpr (std::is_same_v<T, int>) {
return builder.getI32IntegerAttr(arg);
} else if constexpr (std::is_same_v<T, float>) {
return builder.getF32FloatAttr(arg);
} else {
// If type not handled, throw an exception or handle it appropriately
throw std::runtime_error("Unhandled attribute type");
}
}, value);
}
class MLIRGenerator
{
public:
Expand Down Expand Up @@ -204,15 +221,15 @@ class MLIRGenerator
::llvm::ArrayRef<::llvm::StringRef> operation_attributes = TTIROp::getAttributeNames();
for(auto attribute_name: operation_attributes)
{
if(attribute_name.equals("operand_constraints"))
if(attribute_name == "operand_constraints")
{
// Create operation constraint attributes
mlir::NamedAttribute operand_constraints_attribute = builder_.getNamedAttr(
"operand_constraints",
builder_.getArrayAttr(get_mlir_operand_constraint_attributes(graph, op_node)));
attributes.push_back(operand_constraints_attribute);
}
else if(attribute_name.equals(mlir::OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr()))
else if(attribute_name == mlir::OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr())
{
// Create operation segment sizes attributes
mlir::NamedAttribute operand_segment_sizes_attribute = builder_.getNamedAttr(
Expand All @@ -225,15 +242,13 @@ class MLIRGenerator
}
}

// Workaround for now, need to figure out how to handle this properly
if(op_node->op_name() == "softmax")
for(const auto & attribute: op_node->op_type().named_attrs)
{
log_info("Softmax");
int32_t dimension = std::get<int>(op_node->op_attrs()[0]);
mlir::NamedAttribute dimension_attribute = builder_.getNamedAttr(
"dimension",
builder_.getSI32IntegerAttr(dimension));
attributes.push_back(dimension_attribute);
// convert atribute to mlir atribute
auto mlir_atribute = convert_to_mlir_attribute(attribute.second, builder_);
mlir::NamedAttribute named_attribute = builder_.getNamedAttr(
attribute.first, mlir_atribute);
attributes.push_back(named_attribute);
}

auto op = builder_.create<TTIROp>(
Expand Down
4 changes: 2 additions & 2 deletions pybuda/pybuda/op/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def Softmax(
Tensor
Buda tensor
"""
return op("softmax", name, operandA, attrs=(dim, stable)).get_tensor()
return op("softmax", name, operandA, attrs=(dim, stable), dimension=dim, stable=stable).get_tensor()


def LogSoftmax(
Expand Down Expand Up @@ -82,7 +82,7 @@ def LogSoftmax(
Tensor
Buda tensor
"""
return op("log_softmax", name, operandA, attrs=(dim, stable)).get_tensor()
return op("log_softmax", name, operandA, attrs=(dim, stable), dimension=dim, stable=stable).get_tensor()

def Layernorm(
name: str,
Expand Down

0 comments on commit 5a9d443

Please sign in to comment.