diff --git a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp index 79ae8d3b2ba8..7ce4bddd5731 100644 --- a/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/BubbleUpExpandShapes.cpp @@ -38,6 +38,81 @@ struct BubbleUpExpandShapesPass final void runOnOperation() override; }; +/// Bubbles a `tensor.expand_shape` op through a `tensor.extract_slice` op. This +/// pattern only gets applied when the `extract_slice` doesn't modify dimensions +/// that are expanded by the `expand_shape` and when the `extract_slice` is +/// completely static. +/// TODO: move this upstream with other tensor bubbling patterns. +struct BubbleExpandThroughExtract final + : public OpRewritePattern { + + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp, + PatternRewriter &rewriter) const override { + auto extractOp = expandOp.getSrc().getDefiningOp(); + if (!extractOp) { + return failure(); + } + + auto srcType = extractOp.getSourceType(); + auto extractedType = extractOp.getType(); + auto expandedType = expandOp.getType(); + + if (srcType.getRank() != extractedType.getRank()) { + return rewriter.notifyMatchFailure( + extractOp, "Rank reducing extract_slice not supported"); + } + + if (!srcType.hasStaticShape() || !extractedType.hasStaticShape() || + !expandedType.hasStaticShape()) { + return failure(); + } + + auto reassoc = expandOp.getReassociationIndices(); + for (auto i : llvm::seq(0, extractedType.getRank())) { + if (reassoc[i].size() == 1) { + continue; + } + + if (srcType.getShape()[i] != extractedType.getShape()[i]) { + return rewriter.notifyMatchFailure( + extractOp, "Extract modifies the expanded dimension"); + } + } + + SmallVector newExpandShape; + SmallVector offsets; + SmallVector sizes; + SmallVector strides; + for (auto [inDim, outDims] : llvm::enumerate(reassoc)) { + if (outDims.size() == 1) { + newExpandShape.push_back(srcType.getShape()[inDim]); + offsets.push_back(extractOp.getStaticOffsets()[inDim]); + sizes.push_back(extractOp.getStaticSizes()[inDim]); + strides.push_back(extractOp.getStaticStrides()[inDim]); + } else { + for (auto outDim : outDims) { + newExpandShape.push_back(expandedType.getShape()[outDim]); + offsets.push_back(0); + sizes.push_back(expandedType.getShape()[outDim]); + strides.push_back(1); + } + } + } + + Type newExpandType = + RankedTensorType::get(newExpandShape, expandedType.getElementType()); + auto newExpand = rewriter.create( + expandOp.getLoc(), newExpandType, extractOp.getSource(), reassoc); + + rewriter.replaceOpWithNewOp( + expandOp, expandedType, newExpand, ValueRange{}, ValueRange{}, + ValueRange{}, offsets, sizes, strides); + return success(); + } +}; + } // namespace void BubbleUpExpandShapesPass::runOnOperation() { @@ -87,6 +162,7 @@ void BubbleUpExpandShapesPass::runOnOperation() { // Add patterns to do some additional cleanup (on top of canonicalizations // that can be done later) of reshape ops. tensor::populateFoldTensorEmptyPatterns(bubbleExpandShapePatterns); + bubbleExpandShapePatterns.insert(context); GreedyRewriteConfig rewriteConfig; rewriteConfig.maxIterations = GreedyRewriteConfig::kNoLimit; diff --git a/compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel b/compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel index c132debab4f1..f45b4b75c30d 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel +++ b/compiler/src/iree/compiler/DispatchCreation/test/BUILD.bazel @@ -27,6 +27,7 @@ iree_lit_test_suite( "form_dispatch_regions.mlir", "dispatch_linalg_on_tensors.mlir", "convert_region_to_workgroups.mlir", + "bubble_up_expand_shapes.mlir", "bubble_up_extract_slice.mlir", "form_dispatch_workgroups.mlir", "dispatch_linalg_ext_fusion.mlir", diff --git a/compiler/src/iree/compiler/DispatchCreation/test/CMakeLists.txt b/compiler/src/iree/compiler/DispatchCreation/test/CMakeLists.txt index 582e9ae937cc..d13f858de549 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/DispatchCreation/test/CMakeLists.txt @@ -15,6 +15,7 @@ iree_lit_test_suite( lit SRCS "attention_fuse_by_expansion.mlir" + "bubble_up_expand_shapes.mlir" "bubble_up_extract_slice.mlir" "clone_producers_into_dispatch_regions.mlir" "collapse_dimensions.mlir" diff --git a/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_expand_shapes.mlir b/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_expand_shapes.mlir new file mode 100644 index 000000000000..b014d59f881c --- /dev/null +++ b/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_expand_shapes.mlir @@ -0,0 +1,23 @@ +// RUN: iree-opt --split-input-file --mlir-print-local-scope --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-bubble-up-expand-shapes))" %s | FileCheck %s + +util.func public @bubbble_expand_through_extract(%arg0 : tensor<2x4096x5120xf16>) -> (tensor<2x64x64x2560xf16>) { + %extracted_slice_237 = tensor.extract_slice %arg0[0, 0, 0] [2, 4096, 2560] [1, 1, 1] : tensor<2x4096x5120xf16> to tensor<2x4096x2560xf16> + %expanded_239 = tensor.expand_shape %extracted_slice_237 [[0], [1, 2], [3]] output_shape [2, 64, 64, 2560] : tensor<2x4096x2560xf16> into tensor<2x64x64x2560xf16> + util.return %expanded_239 : tensor<2x64x64x2560xf16> +} + +// CHECK-LABEL: @bubbble_expand_through_extract +// CHECK: %[[EXPAND:.+]] = tensor.expand_shape +// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[EXPAND]] + +// ----- + +util.func public @unsupported_bubbble_expand_through_extract(%arg0 : tensor<2x4096x5120xf16>) -> (tensor<2x32x64x2560xf16>) { + %extracted_slice_237 = tensor.extract_slice %arg0[0, 0, 0] [2, 2048, 2560] [1, 1, 1] : tensor<2x4096x5120xf16> to tensor<2x2048x2560xf16> + %expanded_239 = tensor.expand_shape %extracted_slice_237 [[0], [1, 2], [3]] output_shape [2, 32, 64, 2560] : tensor<2x2048x2560xf16> into tensor<2x32x64x2560xf16> + util.return %expanded_239 : tensor<2x32x64x2560xf16> +} + +// CHECK-LABEL: @unsupported_bubbble_expand_through_extract +// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice +// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[EXTRACT]] diff --git a/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_extract_slice.mlir b/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_extract_slice.mlir index 691cbfab0a19..b582b5628fa3 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_extract_slice.mlir +++ b/compiler/src/iree/compiler/DispatchCreation/test/bubble_up_extract_slice.mlir @@ -95,6 +95,8 @@ util.func public @bubble_up_extract_with_use(%arg0 : tensor<1024x7x7x2xi8>) -> ( // CHECK-SAME: ins(%[[EXTRACT0]] : tensor<1024x7x7xi8>) // CHECK: util.return %[[GENERIC1]], %[[GENERIC0]] +// ----- + util.func public @bubble_up_extract_fill_multi_use() -> tensor<2x320x130x130xf8E4M3FNUZ> { %cst_1 = arith.constant 1.000000e+00 : f8E4M3FNUZ %cst_2 = arith.constant 2.000000e+00 : f8E4M3FNUZ