From 612804e7934304a0f341fd81e83fd1ec6749af3f Mon Sep 17 00:00:00 2001 From: Ian Wood Date: Thu, 26 Dec 2024 07:43:31 -0800 Subject: [PATCH] Fold the unit dimensions on scatter ops Signed-off-by: Ian Wood --- .../LinalgExt/Transforms/ReshapeFusion.cpp | 72 +++++++++++++++++++ .../Dialect/LinalgExt/Transforms/Transforms.h | 7 ++ .../DispatchCreation/FoldUnitExtentDims.cpp | 7 ++ .../DispatchCreation/test/fold_unit_dims.mlir | 30 ++++++++ 4 files changed, 116 insertions(+) diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp index e87efd9f2099..83d1fa029a52 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/ReshapeFusion.cpp @@ -486,6 +486,68 @@ struct FoldAttentionWithProducerReshapeByExpansion final linalg::ControlFusionFn controlFoldingReshapes; }; +/// Remove the unit dims from `iree_linalg_ext.scatter` 's `update` operand. +/// The `update` tensor is scanned from left to right, starting from the second +/// element. The number of unit dimensions are counted until reaching a non unit +/// dim. +struct FoldScatterUnitDims final : public OpRewritePattern { + FoldScatterUnitDims(MLIRContext *context, linalg::ControlDropUnitDims options, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + options(std::move(options)) {} + + LogicalResult matchAndRewrite(ScatterOp scatterOp, + PatternRewriter &rewriter) const override { + if (options.rankReductionStrategy != + linalg::ControlDropUnitDims::RankReductionStrategy:: + ReassociativeReshape) { + return rewriter.notifyMatchFailure( + scatterOp, "Only reassociative reshape strategy supported"); + } + llvm::SmallVector canDrop = options.controlFn(scatterOp); + const ArrayRef updateShape = scatterOp.getUpdateType().getShape(); + + // Find the first `numDimsToDrop` unit dimensions in the update tensor, + // these are the ones that can be dropped. + int64_t numDimsToDrop = + llvm::find_if(scatterOp.getUpdateSliceShape(), + [](int64_t val) { return val != 1; }) - + updateShape.begin() - 1; + + int64_t batchRank = scatterOp.getBatchRank(); + llvm::erase_if(canDrop, [&](unsigned dimPos) { + return dimPos < batchRank || dimPos >= batchRank + numDimsToDrop; + }); + if (canDrop.empty()) { + return failure(); + } + + SmallVector droppedUpdateShape; + droppedUpdateShape.reserve(updateShape.size() - canDrop.size()); + for (auto [idx, dimLen] : llvm::enumerate(updateShape)) { + if (!llvm::is_contained(canDrop, idx)) { + droppedUpdateShape.push_back(dimLen); + } + } + + auto reassoc = + getReassociationIndicesForCollapse(updateShape, droppedUpdateShape); + assert(reassoc.has_value() && "expected reassociation to be valid"); + auto collapseOp = rewriter.create( + scatterOp.getLoc(), + RankedTensorType::get(droppedUpdateShape, + scatterOp.getUpdateType().getElementType()), + scatterOp.getUpdates(), reassoc.value()); + + rewriter.modifyOpInPlace(scatterOp, [&]() { + scatterOp.setOperand(ScatterOp::kUpdatesOpNum, collapseOp.getResult()); + }); + return success(); + } + + linalg::ControlDropUnitDims options; +}; + } // namespace /// Return the `reassociation` indices to use to collapse the operand when the @@ -708,4 +770,14 @@ void populateFoldReshapeOpsByExpansionPatterns( patterns.getContext(), controlFoldingReshapes); } +SmallVector defaultControlDropUnitDims(Operation *op) { + auto fusionOp = cast(op); + return llvm::to_vector(llvm::seq(0, fusionOp.getNumLoops())); +} + +void populateFoldUnitExtentDimsPatterns( + RewritePatternSet &patterns, const linalg::ControlDropUnitDims &options) { + patterns.add(patterns.getContext(), options); +} + } // namespace mlir::iree_compiler::IREE::LinalgExt diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h index 8bf84cab2574..8da0225e27ef 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h @@ -25,6 +25,13 @@ void populateBubbleTransposeFromLinalgExtOps( RewritePatternSet &patterns, const linalg::ControlFusionFn &controlFusionFn); +/// Default function to drop unit dims for for linalgext ops. +SmallVector defaultControlDropUnitDims(Operation *op); + +/// Drop unit extent dims from linalg ext ops +void populateFoldUnitExtentDimsPatterns( + RewritePatternSet &patterns, const linalg::ControlDropUnitDims &options); + /// Helper struct to hold the results of collapsing an operation. struct CollapseResult { SmallVector results; diff --git a/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp b/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp index f802f0b9742b..40fabc56bcf7 100644 --- a/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp +++ b/compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp @@ -12,6 +12,8 @@ //===----------------------------------------------------------------------===// #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" +#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h" +#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h" #include "iree/compiler/Dialect/Stream/IR/StreamTypes.h" #include "iree/compiler/Dialect/Util/Analysis/Explorer.h" #include "iree/compiler/DispatchCreation/Passes.h" @@ -151,9 +153,14 @@ void FoldUnitExtentDimsPass::runOnOperation() { if (!IREE::Flow::isNonNullAndOutsideDispatch(op)) { return SmallVector{}; } + if (isa(op)) { + return IREE::LinalgExt::defaultControlDropUnitDims(op); + } return defaultFn(op); }; linalg::populateFoldUnitExtentDimsPatterns(foldUnitDimsPatterns, options); + IREE::LinalgExt::populateFoldUnitExtentDimsPatterns(foldUnitDimsPatterns, + options); linalg::populateMoveInitOperandsToInputPattern(foldUnitDimsPatterns); if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(foldUnitDimsPatterns)))) { diff --git a/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir b/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir index 249a8b1cba4b..3afa827e4911 100644 --- a/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir +++ b/compiler/src/iree/compiler/DispatchCreation/test/fold_unit_dims.mlir @@ -106,3 +106,33 @@ module @fold_stream_parameter { // CHECK: util.global private mutable @[[GLOBAL:.+]] = #stream.parameter.named<"module"::"global"> : tensor<10xf32> // CHECK: util.func public @fold_stream_parameter // CHECK: %[[LOAD:.+]] = util.global.load @[[GLOBAL]] : tensor<10xf32> + +// ----- + +util.func public @scatter0(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) { + ^bb0(%arg3: f16, %arg4: f16): + iree_linalg_ext.yield %arg3 : f16 + } -> tensor + util.return %0 : tensor +} +// CHECK-LABEL: func public @scatter0 +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape +// CHECK-SAME: to tensor +// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter +// CHECK-SAME: ins(%[[COLLAPSE]] + +// ----- + +util.func public @scatter1(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = iree_linalg_ext.scatter dimension_map = [0, 1] unique_indices(true) ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) { + ^bb0(%arg3: f16, %arg4: f16): + iree_linalg_ext.yield %arg3 : f16 + } -> tensor + util.return %0 : tensor +} +// CHECK-LABEL: func public @scatter1 +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape +// CHECK-SAME: to tensor +// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter +// CHECK-SAME: ins(%[[COLLAPSE]]