From a2dc626be7df141ddf72331b8cec7c4108f26f39 Mon Sep 17 00:00:00 2001
From: Usman Aziz
Date: Thu, 23 Jan 2025 11:29:09 -0500
Subject: [PATCH] Tighten implicit broadcast constraints for multiplyOp and
 maximumOp f… (#1919)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

…or lack of support. Refactor repeat tests.
---
 include/ttmlir/Dialect/TTIR/IR/TTIROps.td     |  4 +-
 .../Dialect/TTNN/implicit_broadcast.mlir      |  8 +--
 test/ttmlir/Dialect/TTNN/simple_repeat.mlir   | 58 +++++++++----------
 .../Silicon/TTNN/implicit_broadcast.mlir      |  6 +-
 4 files changed, 38 insertions(+), 38 deletions(-)
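Note: a minimal sketch of the behavior this change encodes, assuming
TTIR_FullyBroadcastable permits implicit broadcast folding of a broadcast
operand, TTIR_PartiallyBroadcastable restricts it, and omitting the trait
(as for maximum below) disables it entirely. With multiply and maximum
tightened, only patterns like the ttir.add one here still fold away the
explicit ttnn.repeat; shapes are illustrative, taken from the tests:

  // Broadcast feeding a fully broadcastable op folds to an implicit
  // broadcast, so no ttnn.repeat is emitted:
  %1 = "ttir.broadcast"(%arg1, %0) <{broadcast_dimensions = array<i64: 1, 16, 1>}> : (tensor<1x1x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32>
  %3 = "ttir.add"(%arg0, %1, %2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1x16x32xf32>, tensor<1x16x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32>
  // The same pattern feeding ttir.maximum now keeps the materialized repeat.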
"ttir.add"(%arg0, %1, %2) <{operandSegmentSizes = array}> : (tensor<1x16x32xf32>, tensor<1x16x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32> %4 = tensor.empty() : tensor<1x16x32xf32> %5 = "ttir.add"(%1, %3, %4) <{operandSegmentSizes = array}> : (tensor<1x16x32xf32>, tensor<1x16x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32> return %5 : tensor<1x16x32xf32> diff --git a/test/ttmlir/Dialect/TTNN/simple_repeat.mlir b/test/ttmlir/Dialect/TTNN/simple_repeat.mlir index 00fddfb786..b7e5278168 100644 --- a/test/ttmlir/Dialect/TTNN/simple_repeat.mlir +++ b/test/ttmlir/Dialect/TTNN/simple_repeat.mlir @@ -27,36 +27,36 @@ module { } module { - func.func @main(%arg0: tensor<1x23x40x1xf32>, %arg1: tensor<128xf32>) -> tensor<1x23x40x128xf32> { - // CHECK: %{{[0-9]+}} = "ttnn.reshape" - // CHECK: %{{[0-9]+}} = "ttnn.repeat" - // CHECK-SAME: shape = [1 : i32, 23 : i32, 40 : i32, 1 : i32] - %0 = tensor.empty() : tensor<1x23x40x128xf32> - %1 = "ttir.broadcast"(%arg0, %0) <{broadcast_dimensions = array}> : (tensor<1x23x40x1xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32> - %2 = tensor.empty() : tensor<1x1x1x128xf32> - %3 = "ttir.reshape"(%arg1, %2) <{shape = [1 : i32, 1 : i32, 1 : i32, 128 : i32]}> : (tensor<128xf32>, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> - %4 = tensor.empty() : tensor<1x23x40x128xf32> - %5 = "ttir.broadcast"(%3, %4) <{broadcast_dimensions = array}> : (tensor<1x1x1x128xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32> - %6 = tensor.empty() : tensor<1x23x40x128xf32> - %7 = "ttir.div"(%1, %5, %6) <{operandSegmentSizes = array}> : (tensor<1x23x40x128xf32>, tensor<1x23x40x128xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32> - return %7 : tensor<1x23x40x128xf32> - } + func.func @main(%arg0: tensor<1x23x40x1xf32>, %arg1: tensor<128xf32>) -> tensor<1x23x40x128xf32> { + // CHECK: %{{[0-9]+}} = "ttnn.reshape" + // CHECK: %{{[0-9]+}} = "ttnn.repeat" + // CHECK-SAME: shape = [1 : i32, 23 : i32, 40 : i32, 1 : i32] + %0 = tensor.empty() : tensor<1x23x40x128xf32> + %1 = "ttir.broadcast"(%arg0, %0) <{broadcast_dimensions = array}> : (tensor<1x23x40x1xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32> + %2 = tensor.empty() : tensor<1x1x1x128xf32> + %3 = "ttir.reshape"(%arg1, %2) <{shape = [1 : i32, 1 : i32, 1 : i32, 128 : i32]}> : (tensor<128xf32>, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32> + %4 = tensor.empty() : tensor<1x23x40x128xf32> + %5 = "ttir.broadcast"(%3, %4) <{broadcast_dimensions = array}> : (tensor<1x1x1x128xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32> + %6 = tensor.empty() : tensor<1x23x40x128xf32> + %7 = "ttir.div"(%1, %5, %6) <{operandSegmentSizes = array}> : (tensor<1x23x40x128xf32>, tensor<1x23x40x128xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32> + return %7 : tensor<1x23x40x128xf32> } +} module { - func.func @main(%arg0: tensor<6x2xf32>) -> tensor<2400x2xf32> { - // CHECK: %{{[0-9]+}} = "ttnn.repeat" - // CHECK-SAME: shape = [400 : i32, 1 : i32, 1 : i32, 1 : i32] - %0 = tensor.empty() : tensor<1x6x2xf32> - %1 = "ttir.reshape"(%arg0, %0) <{shape = [1 : i32, 6 : i32, 2 : i32]}> : (tensor<6x2xf32>, tensor<1x6x2xf32>) -> tensor<1x6x2xf32> - %2 = tensor.empty() : tensor<1x6x1x2xf32> - %3 = "ttir.reshape"(%1, %2) <{shape = [1 : i32, 6 : i32, 1 : i32, 2 : i32]}> : (tensor<1x6x2xf32>, tensor<1x6x1x2xf32>) -> tensor<1x6x1x2xf32> - %4 = tensor.empty() : tensor<400x6x1x2xf32> - %5 = "ttir.broadcast"(%3, %4) <{broadcast_dimensions = array}> : (tensor<1x6x1x2xf32>, tensor<400x6x1x2xf32>) -> tensor<400x6x1x2xf32> - %6 = 
diff --git a/test/ttmlir/Silicon/TTNN/implicit_broadcast.mlir b/test/ttmlir/Silicon/TTNN/implicit_broadcast.mlir
index bd6d811c96..9a7658583b 100644
--- a/test/ttmlir/Silicon/TTNN/implicit_broadcast.mlir
+++ b/test/ttmlir/Silicon/TTNN/implicit_broadcast.mlir
@@ -4,17 +4,17 @@
 module {
   func.func @main(%arg0: tensor<1x16x32xf32>, %arg1: tensor<1x1x32xf32>) -> tensor<1x16x32xf32> {
     // CHECK-NOT: ttnn.repeat
-    // CHECK: %{{[0-9]+}} = "ttnn.multiply"
+    // CHECK: %{{[0-9]+}} = "ttnn.add"
     %0 = tensor.empty() : tensor<1x16x32xf32>
     %1 = "ttir.broadcast"(%arg1, %0) <{broadcast_dimensions = array<i64: 1, 16, 1>}> : (tensor<1x1x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32>
     %2 = tensor.empty() : tensor<1x16x32xf32>
-    %3 = "ttir.multiply"(%arg0, %1, %2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1x16x32xf32>, tensor<1x16x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32>
+    %3 = "ttir.add"(%arg0, %1, %2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1x16x32xf32>, tensor<1x16x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32>
     return %3 : tensor<1x16x32xf32>
   }
 }
 
 module {
-func.func @main(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>) -> tensor<784x128xf32> {
+  func.func @main(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>) -> tensor<784x128xf32> {
   // CHECK: %{{[0-9]+}} = "ttnn.reshape"
   // CHECK-NOT: "ttnn.repeat"
   // CHECK: %{{[0-9]+}} = "ttnn.reshape"