From 8d96ba8870e220866ea5e796ab795cc8a5a2abd6 Mon Sep 17 00:00:00 2001
From: Nick Smith <127986401+nsmithtt@users.noreply.github.com>
Date: Wed, 28 Aug 2024 03:16:12 -0700
Subject: [PATCH] Create tensor grids with same rank as device grid bug fix
 #500 (#515)

Previously we created tensor layout grids of rank 1 or 2, depending on
the tensor's rank. This, however, is incorrect: the tensor grid must
have the same rank as the device grid.

The fix is to use the device grid's rank in the layout type converter,
ensuring that by default a tensor layout gets a grid of equivalent rank.
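
To illustrate (a sketch only, not part of the diff below; `ctx` stands
for an MLIRContext and the 8x8 device grid is hypothetical):

    GridAttr deviceGrid = GridAttr::get(ctx, {8, 8});
    // The new rank-based overload builds a single-core grid whose rank
    // matches the device grid, so even a 1D tensor now defaults to a
    // <1x1> grid rather than <1>.
    GridAttr tensorGrid = GridAttr::get(ctx, deviceGrid.getShape().size());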
---
 include/ttmlir/Dialect/TT/IR/TTOpsTypes.td    |  5 +++++
 lib/Dialect/TT/IR/TTOpsTypes.cpp              |  9 +++++----
 lib/Dialect/TTIR/Transforms/Passes.cpp        | 21 ++++++++++++++-------
 test/ttmlir/Dialect/TTIR/test_layout.mlir     |  2 +-
 .../TTNN/eltwise/operand_broadcasts.mlir      |  2 +-
 .../TTNN/eltwise/unary/relu/simple_relu.mlir  |  2 +-
 .../TTNN/embedding/embedding_1d_tensor.mlir   |  2 +-
 .../TTNN/embedding/simple_embedding.mlir      |  2 +-
 test/ttmlir/Dialect/TTNN/simple_concat.mlir   |  2 +-
 test/ttmlir/Dialect/TTNN/simple_div.mlir      |  2 +-
 test/ttmlir/Dialect/TTNN/simple_ge.mlir       |  2 +-
 test/ttmlir/Dialect/TTNN/simple_matmul.mlir   |  2 +-
 test/ttmlir/Dialect/TTNN/simple_mean.mlir     |  4 ++--
 test/ttmlir/Dialect/TTNN/simple_multiply.mlir |  2 +-
 test/ttmlir/Dialect/TTNN/simple_subtract.mlir |  2 +-
 test/ttmlir/Dialect/TTNN/simple_sum.mlir      |  2 +-
 .../Dialect/TTNN/softmax/simple_softmax.mlir  |  2 +-
 .../TTNN/transpose/simple_transpose.mlir      |  2 +-
 .../simple_transpose_8x16_reverse_dims.mlir   |  2 +-
 .../TTNN/transpose/simple_transpose_8x8.mlir  |  2 +-
 .../simple_transpose_negative_dims.mlir       |  2 +-
 test/ttmlir/Translate/TTNN/1d_tensor.mlir     |  8 +++++++
 22 files changed, 51 insertions(+), 30 deletions(-)
 create mode 100644 test/ttmlir/Translate/TTNN/1d_tensor.mlir

diff --git a/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td b/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
index 90008871c..6f55d29f6 100644
--- a/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
+++ b/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
@@ -44,6 +44,10 @@ def TT_GridAttr : TT_Attr<"Grid", "grid"> {
     static GridAttr get(::mlir::MLIRContext *context, ArrayRef<int64_t> shape) {
       return GridAttr::get(context, shape, AffineMap::get(context));
     }
+
+    static GridAttr get(::mlir::MLIRContext *context, std::int64_t rank) {
+      return GridAttr::get(context, SmallVector<std::int64_t>(rank, 1));
+    }
   }];
 }
 
@@ -259,6 +263,7 @@ def TT_LayoutAttr : TT_Attr<"Layout", "layout"> {
     static LayoutAttr get(::mlir::MLIRContext *context,
                           RankedTensorType ty,
                           MemorySpace memorySpace,
+                          GridAttr grid,
                           Type elementType);
     LayoutAttr withGrid(::mlir::MLIRContext *context, ArrayRef<int64_t> tensorShape, GridAttr grid, ArrayRef<std::pair<std::int64_t, std::int64_t>> collapseIntervals = {{0, -1}});
     LayoutAttr withGrid(::mlir::MLIRContext *context,
diff --git a/lib/Dialect/TT/IR/TTOpsTypes.cpp b/lib/Dialect/TT/IR/TTOpsTypes.cpp
index 6a154acd7..d760b012e 100644
--- a/lib/Dialect/TT/IR/TTOpsTypes.cpp
+++ b/lib/Dialect/TT/IR/TTOpsTypes.cpp
@@ -451,8 +451,7 @@ LayoutAttr LayoutAttr::get(
     ArrayRef<std::pair<std::int64_t, std::int64_t>> collapseIntervals,
     OOBVal oobVal) {
   if (not grid) {
-    grid = tensorShape.size() == 1 ? GridAttr::get(context, {1})
-                                   : GridAttr::get(context, {1, 1});
+    grid = GridAttr::get(context, tensorShape.size());
   }
 
   auto linear = collapsedLinearAffineMap(context, tensorShape, grid.getShape(),
@@ -474,9 +473,11 @@ LayoutAttr LayoutAttr::get(
 }
 
 LayoutAttr LayoutAttr::get(::mlir::MLIRContext *context, RankedTensorType ty,
-                           MemorySpace memorySpace, Type elementType) {
+                           MemorySpace memorySpace, GridAttr grid,
+                           Type elementType) {
   assert(ty);
-  return get(context, ty.getShape(), elementType, memorySpace, {}, {{0, -1}},
+  assert(grid);
+  return get(context, ty.getShape(), elementType, memorySpace, grid, {{0, -1}},
              OOBVal::Undef);
 }
 
diff --git a/lib/Dialect/TTIR/Transforms/Passes.cpp b/lib/Dialect/TTIR/Transforms/Passes.cpp
index 450dda983..3b6f43753 100644
--- a/lib/Dialect/TTIR/Transforms/Passes.cpp
+++ b/lib/Dialect/TTIR/Transforms/Passes.cpp
@@ -433,15 +433,20 @@ inline MemorySpace uppermostMemorySpace(OperandConstraint operandConstraint) {
 
 class TTIRLayoutTensorTypeConverter : public TypeConverter {
 public:
-  TTIRLayoutTensorTypeConverter(MLIRContext *ctx, MemorySpace initMemorySpace) {
+  TTIRLayoutTensorTypeConverter(MLIRContext *ctx, MemorySpace initMemorySpace,
+                                GridAttr deviceGrid) {
     addConversion([](Type type) { return type; });
-    addConversion([ctx, initMemorySpace](RankedTensorType type) -> Type {
+    addConversion([ctx, initMemorySpace,
+                   deviceGrid](RankedTensorType type) -> Type {
       auto layout = type.getEncoding();
       if (layout) {
        return type;
      }
+      std::int64_t deviceGridRank = deviceGrid.getShape().size();
+      // Default to single core grid
+      auto tensorGrid = GridAttr::get(ctx, deviceGridRank);
       // Default to initMemorySpace, the optimizer might decide otherwise
-      auto newLayout = LayoutAttr::get(ctx, type, initMemorySpace);
+      auto newLayout = LayoutAttr::get(ctx, type, initMemorySpace, tensorGrid);
       return RankedTensorType::get(type.getShape(), type.getElementType(),
                                    newLayout);
     });
@@ -526,8 +531,8 @@ static std::optional<Value> createToLayoutOp(PatternRewriter &rewriter,
     return std::nullopt;
   }
 
-  auto desiredLayout =
-      rewriter.getAttr<LayoutAttr>(ty, desiredMemorySpace, desiredElementType);
+  auto desiredLayout = rewriter.getAttr<LayoutAttr>(
+      ty, desiredMemorySpace, currLayout.getGrid(), desiredElementType);
   auto output = rewriter.create<tensor::EmptyOp>(
       loc, ty.getShape(), ty.getElementType(), desiredLayout);
 
@@ -627,8 +632,10 @@ class TTIRLayout : public impl::TTIRLayoutBase<TTIRLayout> {
 
   void runOnOperation() final {
     {
-      TTIRLayoutTensorTypeConverter typeConverter(&getContext(),
-                                                  initMemorySpace);
+      auto device = getCurrentScopeDevice(getOperation());
+      assert(device && "Device not found");
+      TTIRLayoutTensorTypeConverter typeConverter(
+          &getContext(), initMemorySpace, device.getGrid());
       RewritePatternSet patterns(&getContext());
       patterns.add<TTIRLayoutTensorTypeRewriter>(typeConverter, &getContext());
       FrozenRewritePatternSet patternSet(std::move(patterns));
diff --git a/test/ttmlir/Dialect/TTIR/test_layout.mlir b/test/ttmlir/Dialect/TTIR/test_layout.mlir
index 5a5f11426..c945150ee 100644
--- a/test/ttmlir/Dialect/TTIR/test_layout.mlir
+++ b/test/ttmlir/Dialect/TTIR/test_layout.mlir
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-layout %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-load-system-desc --ttir-implicit-device --ttir-layout %s | FileCheck %s
 #any_device = #tt.operand_constraint<any_device>
 module attributes {} {
   func.func @forward(%arg0: tensor<8x64x128xf32>, %arg1: tensor<8x64x128xf32>) -> tensor<8x64x128xf32> {
diff --git a/test/ttmlir/Dialect/TTNN/eltwise/operand_broadcasts.mlir b/test/ttmlir/Dialect/TTNN/eltwise/operand_broadcasts.mlir
index e031cd3c2..772cabf84 100644
--- a/test/ttmlir/Dialect/TTNN/eltwise/operand_broadcasts.mlir
+++ b/test/ttmlir/Dialect/TTNN/eltwise/operand_broadcasts.mlir
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device = #tt.operand_constraint<any_device>
 module attributes {} {
   func.func @bcast_one_dim(%arg0: tensor<2x64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<2x64x128xf32> {
diff --git a/test/ttmlir/Dialect/TTNN/eltwise/unary/relu/simple_relu.mlir b/test/ttmlir/Dialect/TTNN/eltwise/unary/relu/simple_relu.mlir
index 1b0f2f483..54a1f1ca7 100644
--- a/test/ttmlir/Dialect/TTNN/eltwise/unary/relu/simple_relu.mlir
+++ b/test/ttmlir/Dialect/TTNN/eltwise/unary/relu/simple_relu.mlir
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device = #tt.operand_constraint<any_device>
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>) -> tensor<64x128xf32> {
diff --git a/test/ttmlir/Dialect/TTNN/embedding/embedding_1d_tensor.mlir b/test/ttmlir/Dialect/TTNN/embedding/embedding_1d_tensor.mlir
index e05931c85..0990b44fd 100644
--- a/test/ttmlir/Dialect/TTNN/embedding/embedding_1d_tensor.mlir
+++ b/test/ttmlir/Dialect/TTNN/embedding/embedding_1d_tensor.mlir
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device = #tt.operand_constraint<any_device>
 module attributes {} {
   func.func @forward(%arg0: tensor<32xf32>, %arg1: tensor<512x128xf32>) -> tensor<32x128xf32> {
diff --git a/test/ttmlir/Dialect/TTNN/embedding/simple_embedding.mlir b/test/ttmlir/Dialect/TTNN/embedding/simple_embedding.mlir
index 4d6315331..9f9428520 100644
--- a/test/ttmlir/Dialect/TTNN/embedding/simple_embedding.mlir
+++ b/test/ttmlir/Dialect/TTNN/embedding/simple_embedding.mlir
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device = #tt.operand_constraint<any_device>
 module attributes {} {
   func.func @forward(%arg0: tensor<32x32xf32>, %arg1: tensor<512x128xf32>) -> tensor<32x32x128xf32> {
diff --git a/test/ttmlir/Dialect/TTNN/simple_concat.mlir b/test/ttmlir/Dialect/TTNN/simple_concat.mlir
index d9adfc286..14140b368 100644
--- a/test/ttmlir/Dialect/TTNN/simple_concat.mlir
+++ b/test/ttmlir/Dialect/TTNN/simple_concat.mlir
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device = #tt.operand_constraint<any_device>
 module attributes {} {
   func.func @forward(%arg0: tensor<32x32xf32>, %arg1: tensor<32x64xf32>) -> tensor<32x96xf32> {
diff --git a/test/ttmlir/Dialect/TTNN/simple_div.mlir b/test/ttmlir/Dialect/TTNN/simple_div.mlir
index d879f8e29..b35570494 100644
--- a/test/ttmlir/Dialect/TTNN/simple_div.mlir
+++ b/test/ttmlir/Dialect/TTNN/simple_div.mlir
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device = #tt.operand_constraint<any_device>
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
diff --git a/test/ttmlir/Dialect/TTNN/simple_ge.mlir b/test/ttmlir/Dialect/TTNN/simple_ge.mlir
index 427eb1747..b40a0028b 100644
--- a/test/ttmlir/Dialect/TTNN/simple_ge.mlir
+++ b/test/ttmlir/Dialect/TTNN/simple_ge.mlir
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device = #tt.operand_constraint<any_device>
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
diff --git a/test/ttmlir/Dialect/TTNN/simple_matmul.mlir b/test/ttmlir/Dialect/TTNN/simple_matmul.mlir
index c099bf97b..992b0c21d 100644
--- a/test/ttmlir/Dialect/TTNN/simple_matmul.mlir
+++ b/test/ttmlir/Dialect/TTNN/simple_matmul.mlir
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-load-system-desc --ttir-implicit-device --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
 #any_device_tile = #tt.operand_constraint<any_device_tile>
 // CHECK: #[[TILED_LAYOUT:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #l1_>>
 module attributes {} {
diff --git a/test/ttmlir/Dialect/TTNN/simple_mean.mlir b/test/ttmlir/Dialect/TTNN/simple_mean.mlir
index 7d1dead21..a41c63e63 100644
--- a/test/ttmlir/Dialect/TTNN/simple_mean.mlir
+++ b/test/ttmlir/Dialect/TTNN/simple_mean.mlir
@@ -1,6 +1,6 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device = #tt.operand_constraint<any_device>
-module attributes {} {
+module {
   func.func @forward(%arg0: tensor<512x1024xbf16>) -> tensor<512x32xbf16> {
     // CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
     // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
diff --git a/test/ttmlir/Dialect/TTNN/simple_multiply.mlir b/test/ttmlir/Dialect/TTNN/simple_multiply.mlir
index fbf5b5bbb..a4205de46 100644
--- a/test/ttmlir/Dialect/TTNN/simple_multiply.mlir
+++ b/test/ttmlir/Dialect/TTNN/simple_multiply.mlir
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device = #tt.operand_constraint<any_device>
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
diff --git a/test/ttmlir/Dialect/TTNN/simple_subtract.mlir b/test/ttmlir/Dialect/TTNN/simple_subtract.mlir
index c48f7603e..1c06449ee 100644
--- a/test/ttmlir/Dialect/TTNN/simple_subtract.mlir
+++ b/test/ttmlir/Dialect/TTNN/simple_subtract.mlir
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device = #tt.operand_constraint<any_device>
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
diff --git a/test/ttmlir/Dialect/TTNN/simple_sum.mlir b/test/ttmlir/Dialect/TTNN/simple_sum.mlir
index 43d515f71..ab3f7cbe6 100644
--- a/test/ttmlir/Dialect/TTNN/simple_sum.mlir
+++ b/test/ttmlir/Dialect/TTNN/simple_sum.mlir
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device = #tt.operand_constraint<any_device>
 module attributes {} {
   func.func @forward(%arg0: tensor<512x1024xbf16>) -> tensor<512x32xbf16> {
diff --git a/test/ttmlir/Dialect/TTNN/softmax/simple_softmax.mlir b/test/ttmlir/Dialect/TTNN/softmax/simple_softmax.mlir
index 96ce29093..a66681245 100644
--- a/test/ttmlir/Dialect/TTNN/softmax/simple_softmax.mlir
+++ b/test/ttmlir/Dialect/TTNN/softmax/simple_softmax.mlir
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device = #tt.operand_constraint<any_device>
 module attributes {} {
   func.func @forward(%arg0: tensor<512x1024xbf16>) -> tensor<512x1024xbf16> {
diff --git a/test/ttmlir/Dialect/TTNN/transpose/simple_transpose.mlir b/test/ttmlir/Dialect/TTNN/transpose/simple_transpose.mlir
index 6c6657a0a..fbf377df1 100644
--- a/test/ttmlir/Dialect/TTNN/transpose/simple_transpose.mlir
+++ b/test/ttmlir/Dialect/TTNN/transpose/simple_transpose.mlir
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device_tile = #tt.operand_constraint<any_device_tile>
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xbf16>) -> tensor<128x64xbf16> {
diff --git a/test/ttmlir/Dialect/TTNN/transpose/simple_transpose_8x16_reverse_dims.mlir b/test/ttmlir/Dialect/TTNN/transpose/simple_transpose_8x16_reverse_dims.mlir
index b7f643319..70640d041 100644
--- a/test/ttmlir/Dialect/TTNN/transpose/simple_transpose_8x16_reverse_dims.mlir
+++ b/test/ttmlir/Dialect/TTNN/transpose/simple_transpose_8x16_reverse_dims.mlir
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device_tile = #tt.operand_constraint<any_device_tile>
 module attributes {} {
   func.func @forward(%arg0: tensor<64x16xbf16>) -> tensor<16x64xbf16> {
diff --git a/test/ttmlir/Dialect/TTNN/transpose/simple_transpose_8x8.mlir b/test/ttmlir/Dialect/TTNN/transpose/simple_transpose_8x8.mlir
index 50debd891..b9cedf226 100644
--- a/test/ttmlir/Dialect/TTNN/transpose/simple_transpose_8x8.mlir
+++ b/test/ttmlir/Dialect/TTNN/transpose/simple_transpose_8x8.mlir
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device = #tt.operand_constraint<any_device>
 module attributes {} {
   func.func @forward(%arg0: tensor<32x32xbf16>) -> tensor<32x32xbf16> {
diff --git a/test/ttmlir/Dialect/TTNN/transpose/simple_transpose_negative_dims.mlir b/test/ttmlir/Dialect/TTNN/transpose/simple_transpose_negative_dims.mlir
index 996fd1989..035475bc4 100644
--- a/test/ttmlir/Dialect/TTNN/transpose/simple_transpose_negative_dims.mlir
+++ b/test/ttmlir/Dialect/TTNN/transpose/simple_transpose_negative_dims.mlir
@@ -1,4 +1,4 @@
-// RUN: ttmlir-opt --ttir-load-system-desc --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | FileCheck %s
 #any_device_tile = #tt.operand_constraint<any_device_tile>
 module attributes {} {
   func.func @forward(%arg0: tensor<32x32xbf16>) -> tensor<32x32xbf16> {
diff --git a/test/ttmlir/Translate/TTNN/1d_tensor.mlir b/test/ttmlir/Translate/TTNN/1d_tensor.mlir
new file mode 100644
index 000000000..695812737
--- /dev/null
+++ b/test/ttmlir/Translate/TTNN/1d_tensor.mlir
@@ -0,0 +1,8 @@
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline %s | ttmlir-translate --ttnn-to-flatbuffer
+#any_device = #tt.operand_constraint<any_device>
+
+func.func @embedding_1d_tensor(%arg0: tensor<32xf32>, %arg1: tensor<512x128xf32>) -> tensor<32x128xf32> {
+  %0 = tensor.empty() : tensor<32x128xf32>
+  %1 = "ttir.embedding"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<32xf32>, tensor<512x128xf32>, tensor<32x128xf32>) -> tensor<32x128xf32>
+  return %1 : tensor<32x128xf32>
+}