Skip to content

Commit

Permalink
Fold the unit dimensions on scatter ops
Browse files Browse the repository at this point in the history
Signed-off-by: Ian Wood <[email protected]>
  • Loading branch information
IanWood1 committed Dec 27, 2024
1 parent f8e8dbe commit 612804e
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,68 @@ struct FoldAttentionWithProducerReshapeByExpansion final
linalg::ControlFusionFn controlFoldingReshapes;
};

/// Remove the unit dims from `iree_linalg_ext.scatter` 's `update` operand.
/// The `update` tensor is scanned from left to right, starting from the second
/// element. The number of unit dimensions are counted until reaching a non unit
/// dim.
///
/// The dropped dims are removed by inserting a `tensor.collapse_shape` on the
/// `updates` operand; the scatter op itself is only updated in place to take
/// the collapsed tensor. Which dims may be dropped is further restricted by
/// `options.controlFn`.
struct FoldScatterUnitDims final : public OpRewritePattern<ScatterOp> {
  FoldScatterUnitDims(MLIRContext *context, linalg::ControlDropUnitDims options,
                      PatternBenefit benefit = 1)
      : OpRewritePattern<ScatterOp>(context, benefit),
        options(std::move(options)) {}

  LogicalResult matchAndRewrite(ScatterOp scatterOp,
                                PatternRewriter &rewriter) const override {
    // Only the reassociative-reshape rank-reduction strategy is implemented
    // (we emit tensor.collapse_shape below); bail out for any other strategy.
    if (options.rankReductionStrategy !=
        linalg::ControlDropUnitDims::RankReductionStrategy::
            ReassociativeReshape) {
      return rewriter.notifyMatchFailure(
          scatterOp, "Only reassociative reshape strategy supported");
    }
    // Dims the caller permits dropping; filtered further below.
    llvm::SmallVector<unsigned> canDrop = options.controlFn(scatterOp);
    const ArrayRef<int64_t> updateShape = scatterOp.getUpdateType().getShape();

    // Find the first `numDimsToDrop` unit dimensions in the update tensor,
    // these are the ones that can be dropped.
    // NOTE(review): this subtracts `updateShape.begin()` from an iterator
    // obtained from `getUpdateSliceShape()` — a different range. This is only
    // well-defined if the slice shape is an ArrayRef aliasing the tail of
    // `updateShape` (i.e. updateShape with the batch dims dropped); confirm
    // that invariant holds for ScatterOp.
    int64_t numDimsToDrop =
        llvm::find_if(scatterOp.getUpdateSliceShape(),
                      [](int64_t val) { return val != 1; }) -
        updateShape.begin() - 1;

    // Keep only candidate dims that fall inside the leading run of unit dims
    // of the update *slice* (i.e. positions [batchRank, batchRank +
    // numDimsToDrop)); batch dims themselves are never dropped.
    int64_t batchRank = scatterOp.getBatchRank();
    llvm::erase_if(canDrop, [&](unsigned dimPos) {
      return dimPos < batchRank || dimPos >= batchRank + numDimsToDrop;
    });
    if (canDrop.empty()) {
      return failure();
    }

    // Build the collapsed update shape by skipping every dropped position.
    SmallVector<int64_t> droppedUpdateShape;
    droppedUpdateShape.reserve(updateShape.size() - canDrop.size());
    for (auto [idx, dimLen] : llvm::enumerate(updateShape)) {
      if (!llvm::is_contained(canDrop, idx)) {
        droppedUpdateShape.push_back(dimLen);
      }
    }

    // Since only unit dims were removed, a valid reassociation from the old
    // shape to the collapsed shape must exist.
    auto reassoc =
        getReassociationIndicesForCollapse(updateShape, droppedUpdateShape);
    assert(reassoc.has_value() && "expected reassociation to be valid");
    auto collapseOp = rewriter.create<tensor::CollapseShapeOp>(
        scatterOp.getLoc(),
        RankedTensorType::get(droppedUpdateShape,
                              scatterOp.getUpdateType().getElementType()),
        scatterOp.getUpdates(), reassoc.value());

    // Only the updates operand changes; indices/original and the result type
    // are untouched, so an in-place operand swap is sufficient.
    rewriter.modifyOpInPlace(scatterOp, [&]() {
      scatterOp.setOperand(ScatterOp::kUpdatesOpNum, collapseOp.getResult());
    });
    return success();
  }

  // Controls the rank-reduction strategy and which dims may be dropped.
  linalg::ControlDropUnitDims options;
};

} // namespace

/// Return the `reassociation` indices to use to collapse the operand when the
Expand Down Expand Up @@ -708,4 +770,14 @@ void populateFoldReshapeOpsByExpansionPatterns(
patterns.getContext(), controlFoldingReshapes);
}

/// Default control function for unit-dim folding on LinalgExt ops: every loop
/// dimension of the op (per its fusion interface) is a candidate to drop.
SmallVector<unsigned> defaultControlDropUnitDims(Operation *op) {
  auto fusionOp = cast<LinalgFusionOpInterface>(op);
  const unsigned numLoops = fusionOp.getNumLoops();
  SmallVector<unsigned> candidates;
  candidates.reserve(numLoops);
  for (unsigned dim = 0; dim < numLoops; ++dim) {
    candidates.push_back(dim);
  }
  return candidates;
}

/// Populate `patterns` with the LinalgExt unit-extent-dim folding patterns
/// (currently the scatter `updates` operand folding), configured by `options`.
void populateFoldUnitExtentDimsPatterns(
    RewritePatternSet &patterns, const linalg::ControlDropUnitDims &options) {
  MLIRContext *context = patterns.getContext();
  patterns.add<FoldScatterUnitDims>(context, options);
}

} // namespace mlir::iree_compiler::IREE::LinalgExt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ void populateBubbleTransposeFromLinalgExtOps(
RewritePatternSet &patterns,
const linalg::ControlFusionFn &controlFusionFn);

/// Default function to drop unit dims for LinalgExt ops.
SmallVector<unsigned> defaultControlDropUnitDims(Operation *op);

/// Drop unit extent dims from LinalgExt ops.
void populateFoldUnitExtentDimsPatterns(
RewritePatternSet &patterns, const linalg::ControlDropUnitDims &options);

/// Helper struct to hold the results of collapsing an operation.
struct CollapseResult {
SmallVector<Value> results;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
//===----------------------------------------------------------------------===//

#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
#include "iree/compiler/Dialect/Util/Analysis/Explorer.h"
#include "iree/compiler/DispatchCreation/Passes.h"
Expand Down Expand Up @@ -151,9 +153,14 @@ void FoldUnitExtentDimsPass::runOnOperation() {
if (!IREE::Flow::isNonNullAndOutsideDispatch(op)) {
return SmallVector<unsigned>{};
}
if (isa<IREE::LinalgExt::LinalgExtOp>(op)) {
return IREE::LinalgExt::defaultControlDropUnitDims(op);
}
return defaultFn(op);
};
linalg::populateFoldUnitExtentDimsPatterns(foldUnitDimsPatterns, options);
IREE::LinalgExt::populateFoldUnitExtentDimsPatterns(foldUnitDimsPatterns,
options);
linalg::populateMoveInitOperandsToInputPattern(foldUnitDimsPatterns);
if (failed(applyPatternsAndFoldGreedily(moduleOp,
std::move(foldUnitDimsPatterns)))) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,33 @@ module @fold_stream_parameter {
// CHECK: util.global private mutable @[[GLOBAL:.+]] = #stream.parameter.named<"module"::"global"> : tensor<10xf32>
// CHECK: util.func public @fold_stream_parameter
// CHECK: %[[LOAD:.+]] = util.global.load @[[GLOBAL]] : tensor<10xf32>

// -----

// The update tensor has one unit dim after the batch dim
// (tensor<?x1x2x16x4x128xf16>, batch = dim 0); folding it should insert a
// tensor.collapse_shape to the rank-5 shape and feed that into the scatter.
util.func public @scatter0(%arg0: tensor<?x1x2x16x4x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<?x2x16x4x128xf16>) -> tensor<?x2x16x4x128xf16> {
  %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor<?x1x2x16x4x128xf16>, tensor<?x1xi32>) outs(%arg2 : tensor<?x2x16x4x128xf16>) {
  ^bb0(%arg3: f16, %arg4: f16):
    iree_linalg_ext.yield %arg3 : f16
  } -> tensor<?x2x16x4x128xf16>
  util.return %0 : tensor<?x2x16x4x128xf16>
}
// CHECK-LABEL: func public @scatter0
// CHECK:         %[[COLLAPSE:.+]] = tensor.collapse_shape
// CHECK-SAME:      to tensor<?x2x16x4x128xf16>
// CHECK:         %[[SCATTER:.+]] = iree_linalg_ext.scatter
// CHECK-SAME:      ins(%[[COLLAPSE]]

// -----

// With dimension_map = [0, 1] the two trailing unit dims after the batch dim
// of the update tensor (tensor<?x1x1x16x4x128xf16>) are both foldable; the
// collapse should produce a rank-4 tensor<?x16x4x128xf16> update.
util.func public @scatter1(%arg0: tensor<?x1x1x16x4x128xf16>, %arg1: tensor<?x2xi32>, %arg2: tensor<?x2x16x4x128xf16>) -> tensor<?x2x16x4x128xf16> {
  %0 = iree_linalg_ext.scatter dimension_map = [0, 1] unique_indices(true) ins(%arg0, %arg1 : tensor<?x1x1x16x4x128xf16>, tensor<?x2xi32>) outs(%arg2 : tensor<?x2x16x4x128xf16>) {
  ^bb0(%arg3: f16, %arg4: f16):
    iree_linalg_ext.yield %arg3 : f16
  } -> tensor<?x2x16x4x128xf16>
  util.return %0 : tensor<?x2x16x4x128xf16>
}
// CHECK-LABEL: func public @scatter1
// CHECK:         %[[COLLAPSE:.+]] = tensor.collapse_shape
// CHECK-SAME:      to tensor<?x16x4x128xf16>
// CHECK:         %[[SCATTER:.+]] = iree_linalg_ext.scatter
// CHECK-SAME:      ins(%[[COLLAPSE]]

0 comments on commit 612804e

Please sign in to comment.