[lower_to_mlir] lower reduce op from tt-forge to mlir
dgolubovicTT committed Aug 12, 2024
1 parent 6d34c6b commit 65ace5a
Showing 4 changed files with 72 additions and 4 deletions.
8 changes: 8 additions & 0 deletions pybuda/csrc/passes/lower_to_mlir.cpp
@@ -134,6 +134,12 @@ class MLIRGenerator
        return builder_.getSI32IntegerAttr(arg);
    } else if constexpr (std::is_same_v<T, float>) {
        return builder_.getF32FloatAttr(arg);
    } else if constexpr (std::is_same_v<T, std::vector<int>>) {
        llvm::SmallVector<mlir::Attribute> attributes;
        for (auto& element : arg) {
            attributes.push_back(builder_.getI32IntegerAttr(element));
        }
        return builder_.getArrayAttr(attributes);
    } else {
        // If the type is not handled, throw an exception or handle it appropriately
        throw std::runtime_error("Unhandled attribute type");
@@ -433,6 +439,8 @@ class MLIRGenerator
lowering_handler_map["relu"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::ReluOp>;
lowering_handler_map["matmul"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::MatmulOp>;
lowering_handler_map["softmax"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SoftmaxOp>;
lowering_handler_map["reduce_sum"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SumOp>;
lowering_handler_map["reduce_avg"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::MeanOp>;
}
};
}
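
The two new entries extend a plain dispatch table: each tt-forge op name is bound to the generic emitter, specialized on the matching TTIR op (SumOp, MeanOp). A minimal Python sketch of the same pattern, with illustrative names only:

# Sketch of the lowering_handler_map dispatch pattern. The real map binds
# op-name strings to templated C++ emitters; these names are illustrative.
def emit_sum_op(node):
    return f"ttir.sum({node})"

def emit_mean_op(node):
    return f"ttir.mean({node})"

lowering_handler_map = {
    "reduce_sum": emit_sum_op,
    "reduce_avg": emit_mean_op,
}

def lower_node(op_name, node):
    handler = lowering_handler_map.get(op_name)
    if handler is None:
        raise RuntimeError(f"Unhandled tt-forge op: {op_name}")
    return handler(node)

print(lower_node("reduce_sum", "%arg0"))  # -> ttir.sum(%arg0)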
16 changes: 16 additions & 0 deletions pybuda/pybuda/op/eval/common.py
@@ -191,6 +191,22 @@ def calculate_pcc(a, b):

    return pcc

# Compares golden and calculated tensors using PCC. Returns True if the calculated PCC meets or exceeds the given threshold.
def compare_with_golden_pcc(golden: Union[torch.Tensor, tf.Tensor, tf.Variable], calculated: torch.Tensor, pcc):
    pcc_value = 0
    if not (pcc is None or golden.flatten().size() == (1,)): # PCC is not meaningful for single-element tensors
        pcc_value = calculate_pcc(golden, calculated)
    if pcc_value >= pcc:
        logger.trace("PCC is correct")
        logger.trace("Golden (shape = {}):", golden.shape)
        logger.trace(golden)
        logger.trace("Calculated (shape = {}):", calculated.shape)
        logger.trace(calculated)
        return True
    else:
        logger.error("Tensor mismatch")
        return False
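
A minimal usage sketch for the new helper (the tensors are made up for illustration; a tiny perturbation keeps the PCC above the 0.99 threshold):

# Usage sketch for compare_with_golden_pcc; values are illustrative.
import torch

golden = torch.rand(1, 32, 32)
calculated = golden + 1e-4 * torch.randn_like(golden)  # near-identical output
assert compare_with_golden_pcc(golden=golden, calculated=calculated, pcc=0.99)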

def compare_tensor_to_golden(name: str, golden: Union[torch.Tensor, tf.Tensor, tf.Variable], calculated: torch.Tensor, is_buda=False, rtol=None, atol=None, pcc=None, warning_only=False, relative_atol=None, verify_cfg=None):
    # Convert golden to pytorch tensor for comparisons
    if isinstance(golden, (tf.Tensor, tf.Variable)):
10 changes: 6 additions & 4 deletions pybuda/pybuda/op/reduce.py
@@ -7,7 +7,8 @@
def ReduceSum(
        name: str,
        operandA: Tensor,
-       dim: int) -> Tensor:
+       dim: int,
+       keep_dim: bool = True) -> Tensor:
    """
    Reduce by summing along the given dimension
@@ -32,12 +33,13 @@ def ReduceSum(
    # if dim < 0:
    #     dim += 4

-   return op("reduce_sum", name, operandA, attrs=(dim,)).get_tensor()
+   return op("reduce_sum", name, operandA, attrs=(dim,), dim_arg=[dim], keep_dim=keep_dim).get_tensor()

def ReduceAvg(
        name: str,
        operandA: Tensor,
-       dim: int) -> Tensor:
+       dim: int,
+       keep_dim: bool = True) -> Tensor:
    """
    Reduce by averaging along the given dimension
@@ -62,7 +64,7 @@ def ReduceAvg(
    # if dim < 0:
    #     dim += 4

-   return op("reduce_avg", name, operandA, attrs=(dim,)).get_tensor()
+   return op("reduce_avg", name, operandA, attrs=(dim,), dim_arg=[dim], keep_dim=keep_dim).get_tensor()

def GroupedReduceAvg(
name: str,
42 changes: 42 additions & 0 deletions pybuda/test/mlir/test_ops.py
@@ -8,6 +8,7 @@
from torch import nn

import pybuda
from pybuda.op.eval.common import compare_with_golden_pcc

def test_add():
    class Add(nn.Module):
@@ -130,3 +131,44 @@ def forward(self, a):

    co_out = [co.to("cpu") for co in co_out]
    assert all(torch.allclose(fo, co) for fo, co in zip(fw_out, co_out))

def test_reduce_sum():
    class ReduceSum(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, a):
            # reduce is supported on tt-metal only with keepdim=True
            return torch.sum(a, dim=1, keepdim=True)

    inputs = [torch.rand(1, 32, 32)]

    framework_model = ReduceSum()
    fw_out = framework_model(*inputs)

    compiled_model = pybuda.compile(framework_model, sample_inputs=inputs)
    co_out = compiled_model(*inputs)

    co_out = [co.to("cpu") for co in co_out]
    assert compare_with_golden_pcc(golden=fw_out, calculated=co_out[0], pcc=0.99)


def test_reduce_mean():
    class ReduceMean(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, a):
            # reduce is supported on tt-metal only with keepdim=True
            return torch.mean(a, dim=1, keepdim=True)

    inputs = [torch.rand(1, 32, 32)]

    framework_model = ReduceMean()
    fw_out = framework_model(*inputs)

    compiled_model = pybuda.compile(framework_model, sample_inputs=inputs)
    co_out = compiled_model(*inputs)

    co_out = [co.to("cpu") for co in co_out]
    assert compare_with_golden_pcc(golden=fw_out, calculated=co_out[0], pcc=0.99)
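
For reference, the pcc=0.99 assertions above compare the Pearson correlation coefficient (PCC) of the flattened golden and device outputs against the threshold. A standalone sanity check of the same criterion (a sketch, not pybuda's calculate_pcc implementation):

# Standalone Pearson-correlation check mirroring the pcc=0.99 threshold.
import torch

def pearson_pcc(golden: torch.Tensor, calculated: torch.Tensor) -> float:
    x = golden.flatten().double() - golden.double().mean()
    y = calculated.flatten().double() - calculated.double().mean()
    return float((x @ y) / (x.norm() * y.norm()))

a = torch.rand(1, 32, 32)
b = a + 1e-3 * torch.randn_like(a)
print(pearson_pcc(a, b) >= 0.99)  # True for a small perturbation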
