Skip to content

Commit

Permalink
Tighten implicit broadcast constraints for multiplyOp and maximumOp for lack of support. Refactor repeat tests. (#1919)
Browse files Browse the repository at this point in the history
  • Loading branch information
  • Loading branch information
uazizTT authored Jan 23, 2025
1 parent 652fa47 commit a2dc626
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 38 deletions.
4 changes: 2 additions & 2 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1546,7 +1546,7 @@ def TTIR_AddOp : TTIR_GenericElementwiseBinaryOp<"add", [TTIR_FullyBroadcastable
}];
}

def TTIR_MultiplyOp : TTIR_GenericElementwiseBinaryOp<"multiply", [TTIR_FullyBroadcastable]> {
def TTIR_MultiplyOp : TTIR_GenericElementwiseBinaryOp<"multiply", [TTIR_PartiallyBroadcastable]> {
let summary = "Eltwise multiply.";
let description = [{
Eltwise multiply operation.
Expand All @@ -1560,7 +1560,7 @@ def TTIR_DivOp : TTIR_GenericElementwiseBinaryOp<"div", [TTIR_PartiallyBroadcast
}];
}

def TTIR_MaximumOp : TTIR_GenericElementwiseBinaryOp<"maximum", [TTIR_PartiallyBroadcastable]> {
def TTIR_MaximumOp : TTIR_GenericElementwiseBinaryOp<"maximum"> {
let summary = "Eltwise maximum.";
let description = [{
Calculates maximum of input tensors' values element-wise and stores result in output tensor.
Expand Down
8 changes: 4 additions & 4 deletions test/ttmlir/Dialect/TTNN/implicit_broadcast.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
module {
func.func @main(%arg0: tensor<1x16x32xf32>, %arg1: tensor<1x1x32xf32>) -> tensor<1x16x32xf32> {
// CHECK-NOT: ttnn.repeat
// CHECK: %{{[0-9]+}} = "ttnn.multiply"
// CHECK: %{{[0-9]+}} = "ttnn.add"
%0 = tensor.empty() : tensor<1x16x32xf32>
%1 = "ttir.broadcast"(%arg1, %0) <{broadcast_dimensions = array<i32: 1, 16, 1>}> : (tensor<1x1x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32>
%2 = tensor.empty() : tensor<1x16x32xf32>
%3 = "ttir.multiply"(%arg0, %1, %2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1x16x32xf32>, tensor<1x16x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32>
%3 = "ttir.add"(%arg0, %1, %2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1x16x32xf32>, tensor<1x16x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32>
return %3 : tensor<1x16x32xf32>
}
}
Expand Down Expand Up @@ -51,12 +51,12 @@
module {
  func.func @main(%arg0: tensor<1x16x32xf32>, %arg1: tensor<1x1x32xf32>
module {
func.func @main(%arg0: tensor<1x16x32xf32>, %arg1: tensor<1x1x32xf32>) -> tensor<1x16x32xf32> {
// CHECK-NOT: ttnn.repeat
// CHECK: [[VAL0:%[0-9]+]] = "ttnn.multiply"
// CHECK: [[VAL0:%[0-9]+]] = "ttnn.add"
// CHECK: %{{[0-9]+}} = "ttnn.add"(%{{[0-9]+}}, [[VAL0]], %{{[0-9]+}})
%0 = tensor.empty() : tensor<1x16x32xf32>
%1 = "ttir.broadcast"(%arg1, %0) <{broadcast_dimensions = array<i32: 1, 16, 1>}> : (tensor<1x1x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32>
%2 = tensor.empty() : tensor<1x16x32xf32>
%3 = "ttir.multiply"(%arg0, %1, %2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1x16x32xf32>, tensor<1x16x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32>
%3 = "ttir.add"(%arg0, %1, %2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1x16x32xf32>, tensor<1x16x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32>
%4 = tensor.empty() : tensor<1x16x32xf32>
%5 = "ttir.add"(%1, %3, %4) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1x16x32xf32>, tensor<1x16x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32>
return %5 : tensor<1x16x32xf32>
Expand Down
58 changes: 29 additions & 29 deletions test/ttmlir/Dialect/TTNN/simple_repeat.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -27,36 +27,36 @@ module {
}

module {
func.func @main(%arg0: tensor<1x23x40x1xf32>, %arg1: tensor<128xf32>) -> tensor<1x23x40x128xf32> {
// CHECK: %{{[0-9]+}} = "ttnn.reshape"
// CHECK: %{{[0-9]+}} = "ttnn.repeat"
// CHECK-SAME: shape = [1 : i32, 23 : i32, 40 : i32, 1 : i32]
%0 = tensor.empty() : tensor<1x23x40x128xf32>
%1 = "ttir.broadcast"(%arg0, %0) <{broadcast_dimensions = array<i32 : 1, 1, 1, 128>}> : (tensor<1x23x40x1xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
%2 = tensor.empty() : tensor<1x1x1x128xf32>
%3 = "ttir.reshape"(%arg1, %2) <{shape = [1 : i32, 1 : i32, 1 : i32, 128 : i32]}> : (tensor<128xf32>, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32>
%4 = tensor.empty() : tensor<1x23x40x128xf32>
%5 = "ttir.broadcast"(%3, %4) <{broadcast_dimensions = array<i32 : 1, 23, 40, 1>}> : (tensor<1x1x1x128xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
%6 = tensor.empty() : tensor<1x23x40x128xf32>
%7 = "ttir.div"(%1, %5, %6) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1x23x40x128xf32>, tensor<1x23x40x128xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
return %7 : tensor<1x23x40x128xf32>
}
func.func @main(%arg0: tensor<1x23x40x1xf32>, %arg1: tensor<128xf32>) -> tensor<1x23x40x128xf32> {
// CHECK: %{{[0-9]+}} = "ttnn.reshape"
// CHECK: %{{[0-9]+}} = "ttnn.repeat"
// CHECK-SAME: shape = [1 : i32, 23 : i32, 40 : i32, 1 : i32]
%0 = tensor.empty() : tensor<1x23x40x128xf32>
%1 = "ttir.broadcast"(%arg0, %0) <{broadcast_dimensions = array<i32 : 1, 1, 1, 128>}> : (tensor<1x23x40x1xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
%2 = tensor.empty() : tensor<1x1x1x128xf32>
%3 = "ttir.reshape"(%arg1, %2) <{shape = [1 : i32, 1 : i32, 1 : i32, 128 : i32]}> : (tensor<128xf32>, tensor<1x1x1x128xf32>) -> tensor<1x1x1x128xf32>
%4 = tensor.empty() : tensor<1x23x40x128xf32>
%5 = "ttir.broadcast"(%3, %4) <{broadcast_dimensions = array<i32 : 1, 23, 40, 1>}> : (tensor<1x1x1x128xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
%6 = tensor.empty() : tensor<1x23x40x128xf32>
%7 = "ttir.div"(%1, %5, %6) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1x23x40x128xf32>, tensor<1x23x40x128xf32>, tensor<1x23x40x128xf32>) -> tensor<1x23x40x128xf32>
return %7 : tensor<1x23x40x128xf32>
}
}

module {
func.func @main(%arg0: tensor<6x2xf32>) -> tensor<2400x2xf32> {
// CHECK: %{{[0-9]+}} = "ttnn.repeat"
// CHECK-SAME: shape = [400 : i32, 1 : i32, 1 : i32, 1 : i32]
%0 = tensor.empty() : tensor<1x6x2xf32>
%1 = "ttir.reshape"(%arg0, %0) <{shape = [1 : i32, 6 : i32, 2 : i32]}> : (tensor<6x2xf32>, tensor<1x6x2xf32>) -> tensor<1x6x2xf32>
%2 = tensor.empty() : tensor<1x6x1x2xf32>
%3 = "ttir.reshape"(%1, %2) <{shape = [1 : i32, 6 : i32, 1 : i32, 2 : i32]}> : (tensor<1x6x2xf32>, tensor<1x6x1x2xf32>) -> tensor<1x6x1x2xf32>
%4 = tensor.empty() : tensor<400x6x1x2xf32>
%5 = "ttir.broadcast"(%3, %4) <{broadcast_dimensions = array<i32: 400, 1, 1, 1>}> : (tensor<1x6x1x2xf32>, tensor<400x6x1x2xf32>) -> tensor<400x6x1x2xf32>
%6 = tensor.empty() : tensor<2400x1x2xf32>
%7 = "ttir.reshape"(%5, %6) <{shape = [2400 : i32, 1 : i32, 2 : i32]}> : (tensor<400x6x1x2xf32>, tensor<2400x1x2xf32>) -> tensor<2400x1x2xf32>
%8 = tensor.empty() : tensor<2400x2xf32>
%9 = "ttir.reshape"(%7, %8) <{shape = [2400 : i32, 2 : i32]}> : (tensor<2400x1x2xf32>, tensor<2400x2xf32>) -> tensor<2400x2xf32>
return %9 : tensor<2400x2xf32>
}
func.func @main(%arg0: tensor<6x2xf32>) -> tensor<2400x2xf32> {
// CHECK: %{{[0-9]+}} = "ttnn.repeat"
// CHECK-SAME: shape = [400 : i32, 1 : i32, 1 : i32, 1 : i32]
%0 = tensor.empty() : tensor<1x6x2xf32>
%1 = "ttir.reshape"(%arg0, %0) <{shape = [1 : i32, 6 : i32, 2 : i32]}> : (tensor<6x2xf32>, tensor<1x6x2xf32>) -> tensor<1x6x2xf32>
%2 = tensor.empty() : tensor<1x6x1x2xf32>
%3 = "ttir.reshape"(%1, %2) <{shape = [1 : i32, 6 : i32, 1 : i32, 2 : i32]}> : (tensor<1x6x2xf32>, tensor<1x6x1x2xf32>) -> tensor<1x6x1x2xf32>
%4 = tensor.empty() : tensor<400x6x1x2xf32>
%5 = "ttir.broadcast"(%3, %4) <{broadcast_dimensions = array<i32: 400, 1, 1, 1>}> : (tensor<1x6x1x2xf32>, tensor<400x6x1x2xf32>) -> tensor<400x6x1x2xf32>
%6 = tensor.empty() : tensor<2400x1x2xf32>
%7 = "ttir.reshape"(%5, %6) <{shape = [2400 : i32, 1 : i32, 2 : i32]}> : (tensor<400x6x1x2xf32>, tensor<2400x1x2xf32>) -> tensor<2400x1x2xf32>
%8 = tensor.empty() : tensor<2400x2xf32>
%9 = "ttir.reshape"(%7, %8) <{shape = [2400 : i32, 2 : i32]}> : (tensor<2400x1x2xf32>, tensor<2400x2xf32>) -> tensor<2400x2xf32>
return %9 : tensor<2400x2xf32>
}
}
6 changes: 3 additions & 3 deletions test/ttmlir/Silicon/TTNN/implicit_broadcast.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@
module {
func.func @main(%arg0: tensor<1x16x32xf32>, %arg1: tensor<1x1x32xf32>) -> tensor<1x16x32xf32> {
// CHECK-NOT: ttnn.repeat
// CHECK: %{{[0-9]+}} = "ttnn.multiply"
// CHECK: %{{[0-9]+}} = "ttnn.add"
%0 = tensor.empty() : tensor<1x16x32xf32>
%1 = "ttir.broadcast"(%arg1, %0) <{broadcast_dimensions = array<i32: 1, 16, 1>}> : (tensor<1x1x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32>
%2 = tensor.empty() : tensor<1x16x32xf32>
%3 = "ttir.multiply"(%arg0, %1, %2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1x16x32xf32>, tensor<1x16x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32>
%3 = "ttir.add"(%arg0, %1, %2) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<1x16x32xf32>, tensor<1x16x32xf32>, tensor<1x16x32xf32>) -> tensor<1x16x32xf32>
return %3 : tensor<1x16x32xf32>
}
}

module {
func.func @main(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>) -> tensor<784x128xf32> {
func.func @main(%arg0: tensor<128xf32>, %arg1: tensor<128xf32>) -> tensor<784x128xf32> {
// CHECK: %{{[0-9]+}} = "ttnn.reshape"
// CHECK-NOT: "ttnn.repeat"
// CHECK: %{{[0-9]+}} = "ttnn.reshape"
Expand Down

0 comments on commit a2dc626

Please sign in to comment.