diff --git a/pybuda/csrc/passes/lower_to_mlir.cpp b/pybuda/csrc/passes/lower_to_mlir.cpp
index 88db81e44..a9b0109d4 100644
--- a/pybuda/csrc/passes/lower_to_mlir.cpp
+++ b/pybuda/csrc/passes/lower_to_mlir.cpp
@@ -134,6 +134,12 @@ class MLIRGenerator
             return builder_.getSI32IntegerAttr(arg);
         } else if constexpr (std::is_same_v<T, float>) {
             return builder_.getF32FloatAttr(arg);
+        } else if constexpr (std::is_same_v<T, std::vector<int>>) {
+            llvm::SmallVector<mlir::Attribute> attributes;
+            for (auto& element : arg) {
+                attributes.push_back(builder_.getI32IntegerAttr(element));
+            }
+            return builder_.getArrayAttr(attributes);
         } else {
             // If type not handled, throw an exception or handle it appropriately
             throw std::runtime_error("Unhandled attribute type");
@@ -445,6 +451,8 @@ class MLIRGenerator
         lowering_handler_map["relu"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::ReluOp>;
         lowering_handler_map["matmul"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::MatmulOp>;
         lowering_handler_map["softmax"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SoftmaxOp>;
+        lowering_handler_map["reduce_sum"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::SumOp>;
+        lowering_handler_map["reduce_avg"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::MeanOp>;
     }
 };
 }
diff --git a/pybuda/pybuda/op/eval/common.py b/pybuda/pybuda/op/eval/common.py
index 501d92771..b04740c60 100644
--- a/pybuda/pybuda/op/eval/common.py
+++ b/pybuda/pybuda/op/eval/common.py
@@ -191,6 +191,26 @@ def calculate_pcc(a, b):
 
     return pcc
 
+# Calculates PCC between the golden and calculated tensors; returns True if it meets or exceeds the pcc threshold
+def compare_with_golden_pcc(golden: Union[torch.Tensor, tf.Tensor, tf.Variable], calculated: torch.Tensor, pcc, rtol=None, atol=None):
+    if not (pcc is None or golden.flatten().size() == (1,)): # PCC for single values doesn't work
+        pcc_value = calculate_pcc(golden, calculated)
+        if pcc_value >= pcc:
+            logger.trace("PCC is correct")
+            logger.trace("Golden: (shape = {})", golden.shape)
+            logger.trace(golden)
+            logger.trace("Calculated: (shape = {})", calculated.shape)
+            logger.trace(calculated)
+            return True
+        else:
+            logger.error("Tensor mismatch")
+            return False
+    else:
+        # For scalar values we can't calculate PCC; compare with relative and absolute tolerances instead
+        golden = golden.flatten()[0]
+        calculated = calculated.flatten()[0]
+        return torch.allclose(golden, calculated, rtol=rtol if rtol is not None else 1e-05, atol=atol if atol is not None else 1e-08)
+
 def compare_tensor_to_golden(name: str, golden: Union[torch.Tensor, tf.Tensor, tf.Variable], calculated: torch.Tensor, is_buda=False, rtol=None, atol=None, pcc=None, warning_only=False, relative_atol = None, verify_cfg = None):
     # Convert golden to pytorch tensor for comparisons
     if isinstance(golden, (tf.Tensor, tf.Variable)):
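[Note] A minimal usage sketch for the compare_with_golden_pcc helper added above; the tensors and the 0.99 threshold here are illustrative only, not taken from the patch:

    import torch
    from pybuda.op.eval.common import compare_with_golden_pcc

    golden = torch.rand(1, 32, 32)
    calculated = golden + 1e-4 * torch.randn(1, 32, 32)  # small perturbation of golden
    assert compare_with_golden_pcc(golden=golden, calculated=calculated, pcc=0.99)

    # Single-element tensors skip PCC and fall back to torch.allclose:
    assert compare_with_golden_pcc(torch.tensor([1.0]), torch.tensor([1.0]), pcc=0.99)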
+ return op("reduce_avg", name, operandA, attrs=(dim,), dim_arg=[dim], keep_dim= keep_dim).get_tensor() def GroupedReduceAvg( name: str, diff --git a/pybuda/test/mlir/mnist/test_inference.py b/pybuda/test/mlir/mnist/test_inference.py index 66af4d36f..e0b99560f 100644 --- a/pybuda/test/mlir/mnist/test_inference.py +++ b/pybuda/test/mlir/mnist/test_inference.py @@ -5,6 +5,7 @@ import torch from .utils import * import pybuda +from pybuda.op.eval.common import compare_with_golden_pcc def test_mnist_inference(): inputs = [torch.rand(1, 784)] @@ -16,4 +17,4 @@ def test_mnist_inference(): co_out = compiled_model(*[i.to("tt") for i in inputs]) co_out = [co.to("cpu") for co in co_out] - assert [torch.allclose(fo, co) for fo, co in zip(fw_out, co_out)] + assert [compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)] diff --git a/pybuda/test/mlir/test_ops.py b/pybuda/test/mlir/test_ops.py index d3c2ddc00..bde2a3bc8 100644 --- a/pybuda/test/mlir/test_ops.py +++ b/pybuda/test/mlir/test_ops.py @@ -5,10 +5,12 @@ import os import pytest +import pytest import torch from torch import nn import pybuda +from pybuda.op.eval.common import compare_with_golden_pcc def test_add(): class Add(nn.Module): @@ -27,7 +29,7 @@ def forward(self, a, b): co_out = compiled_model(*inputs) co_out = [co.to("cpu") for co in co_out] - assert [torch.allclose(fo, co) for fo, co in zip(fw_out, co_out)] + assert [compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)] def test_subtract(): @@ -47,7 +49,7 @@ def forward(self, a, b): co_out = compiled_model(*inputs) co_out = [co.to("cpu") for co in co_out] - assert [torch.allclose(fo, co) for fo, co in zip(fw_out, co_out)] + assert [compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)] def test_multiply(): @@ -67,7 +69,7 @@ def forward(self, a, b): co_out = compiled_model(*inputs) co_out = [co.to("cpu") for co in co_out] - assert [torch.allclose(fo, co) for fo, co in zip(fw_out, co_out)] + assert [compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)] def test_relu(): @@ -88,7 +90,7 @@ def forward(self, a): co_out = compiled_model(*inputs) co_out = [co.to("cpu") for co in co_out] - assert [torch.allclose(fo, co) for fo, co in zip(fw_out, co_out)] + assert [compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)] @pytest.mark.skip(reason="This is not ready yet") def test_linear(): @@ -109,7 +111,7 @@ def forward(self, a): co_out = compiled_model(*inputs) co_out = [co.to("cpu") for co in co_out] - assert [torch.allclose(fo, co) for fo, co in zip(fw_out, co_out)] + assert [compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)] def test_softmax(): @@ -130,4 +132,48 @@ def forward(self, a): co_out = compiled_model(*inputs) co_out = [co.to("cpu") for co in co_out] - assert [torch.allclose(fo, co) for fo, co in zip(fw_out, co_out)] + assert [compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)] + +@pytest.mark.parametrize("input_shape", [(1,32,32), (1,64,64), (1,128,128,128)], ids=["32","64","128"]) +@pytest.mark.parametrize("dim", [-1,-2], ids=["-1","-2"]) +def test_reduce_sum(input_shape, dim): + class ReduceSum(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a): + # reduce is supported on tt-metal only with keepdim=True + return torch.sum(a, dim=dim, keepdim=True) + + inputs = 
diff --git a/pybuda/test/mlir/test_ops.py b/pybuda/test/mlir/test_ops.py
index d3c2ddc00..bde2a3bc8 100644
--- a/pybuda/test/mlir/test_ops.py
+++ b/pybuda/test/mlir/test_ops.py
@@ -5,10 +5,11 @@
 import os
 import pytest
 
 import torch
 from torch import nn
 
 import pybuda
+from pybuda.op.eval.common import compare_with_golden_pcc
 
 def test_add():
     class Add(nn.Module):
@@ -27,7 +28,7 @@ def forward(self, a, b):
     co_out = compiled_model(*inputs)
     co_out = [co.to("cpu") for co in co_out]
 
-    assert [torch.allclose(fo, co) for fo, co in zip(fw_out, co_out)]
+    assert all(compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out))
 
 
 def test_subtract():
@@ -47,7 +48,7 @@ def forward(self, a, b):
     co_out = compiled_model(*inputs)
     co_out = [co.to("cpu") for co in co_out]
 
-    assert [torch.allclose(fo, co) for fo, co in zip(fw_out, co_out)]
+    assert all(compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out))
 
 
 def test_multiply():
@@ -67,7 +68,7 @@ def forward(self, a, b):
     co_out = compiled_model(*inputs)
     co_out = [co.to("cpu") for co in co_out]
 
-    assert [torch.allclose(fo, co) for fo, co in zip(fw_out, co_out)]
+    assert all(compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out))
 
 
 def test_relu():
@@ -88,7 +89,7 @@ def forward(self, a):
     co_out = compiled_model(*inputs)
     co_out = [co.to("cpu") for co in co_out]
 
-    assert [torch.allclose(fo, co) for fo, co in zip(fw_out, co_out)]
+    assert all(compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out))
 
 @pytest.mark.skip(reason="This is not ready yet")
 def test_linear():
@@ -109,7 +110,7 @@ def forward(self, a):
     co_out = compiled_model(*inputs)
     co_out = [co.to("cpu") for co in co_out]
 
-    assert [torch.allclose(fo, co) for fo, co in zip(fw_out, co_out)]
+    assert all(compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out))
 
 
 def test_softmax():
@@ -130,4 +131,48 @@ def forward(self, a):
     co_out = compiled_model(*inputs)
     co_out = [co.to("cpu") for co in co_out]
 
-    assert [torch.allclose(fo, co) for fo, co in zip(fw_out, co_out)]
+    assert all(compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out))
+
+@pytest.mark.parametrize("input_shape", [(1, 32, 32), (1, 64, 64), (1, 128, 128, 128)], ids=["32", "64", "128"])
+@pytest.mark.parametrize("dim", [-1, -2], ids=["-1", "-2"])
+def test_reduce_sum(input_shape, dim):
+    class ReduceSum(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, a):
+            # reduce is supported on tt-metal only with keepdim=True
+            return torch.sum(a, dim=dim, keepdim=True)
+
+    inputs = [torch.rand(input_shape)]
+
+    framework_model = ReduceSum()
+    fw_out = framework_model(*inputs)
+
+    compiled_model = pybuda.compile(framework_model, sample_inputs=inputs)
+    co_out = compiled_model(*inputs)
+
+    co_out = [co.to("cpu") for co in co_out]
+    assert compare_with_golden_pcc(golden=fw_out, calculated=co_out[0], pcc=0.99)
+
+@pytest.mark.parametrize("input_shape", [(1, 32, 32), (1, 64, 64), (1, 128, 128, 128)], ids=["32", "64", "128"])
+@pytest.mark.parametrize("dim", [-1, -2], ids=["-1", "-2"])
+def test_reduce_mean(input_shape, dim):
+    class ReduceMean(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, a):
+            # reduce is supported on tt-metal only with keepdim=True
+            return torch.mean(a, dim=dim, keepdim=True)
+
+    inputs = [torch.rand(input_shape)]
+
+    framework_model = ReduceMean()
+    fw_out = framework_model(*inputs)
+
+    compiled_model = pybuda.compile(framework_model, sample_inputs=inputs)
+    co_out = compiled_model(*inputs)
+
+    co_out = [co.to("cpu") for co in co_out]
+    assert compare_with_golden_pcc(golden=fw_out, calculated=co_out[0], pcc=0.99)
\ No newline at end of file
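[Note] The keepdim=True comments in the reduce tests refer to standard PyTorch semantics: the reduced axis is retained with size 1 rather than squeezed away. A quick illustration:

    import torch

    x = torch.rand(1, 32, 32)
    torch.sum(x, dim=-1, keepdim=True).shape   # torch.Size([1, 32, 1])
    torch.sum(x, dim=-1, keepdim=False).shape  # torch.Size([1, 32]), not supported by this lowering

The new tests can be run on their own with something like pytest pybuda/test/mlir/test_ops.py -k reduce (the exact invocation depends on the local setup).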