Skip to content

Commit

Permalink
Remove action-at-a-distance patterns.
Browse files Browse the repository at this point in the history
Signed-off-by: MaheshRavishankar <[email protected]>
  • Loading branch information
MaheshRavishankar committed Jan 3, 2025
1 parent 6888a18 commit ea95ca5
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 114 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,44 +38,6 @@ namespace mlir::iree_compiler::DispatchCreation {

namespace {

//===----------------------------------------------------------------------===//
// ElementwiseOpInterchangePattern
//===----------------------------------------------------------------------===//

// If possible, interchange indexing maps to make input maps all identity.
struct ElementwiseOpInterchangePattern final
: public OpRewritePattern<linalg::GenericOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
PatternRewriter &rewriter) const override {
if (!linalg::isElementwise(genericOp) || genericOp.getNumResults() != 1 ||
genericOp.getNumDpsInputs() == 0)
return failure();

// All input maps must be equal and non-identity. All maps, including
// output, must be be permutations. Permutation maps are checked by
// isElementwise but may be removed.
AffineMap inputMap = genericOp.getIndexingMapsArray().front();
auto *initOperand = genericOp.getDpsInitOperand(0);
if (inputMap.isIdentity() || !inputMap.isPermutation() ||
!genericOp.getMatchingIndexingMap(initOperand).isPermutation()) {
return failure();
}
for (auto *operand : genericOp.getDpsInputOperands()) {
if (genericOp.getMatchingIndexingMap(operand) != inputMap) {
return failure();
}
}

// Make all inputs identity.
ArrayRef<AffineExpr> exprs = inputMap.getResults();
auto perm = llvm::map_to_vector(exprs, [](AffineExpr e) -> unsigned {
return cast<AffineDimExpr>(e).getPosition();
});
return linalg::interchangeGenericOp(rewriter, genericOp, perm);
}
};

//===----------------------------------------------------------------------===//
// FoldSuccessiveTensorInsertSliceOps
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -153,8 +115,7 @@ struct FusionPreprocessingPass final
: public impl::FusionPreprocessingPassBase<FusionPreprocessingPass> {
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
patterns.add<ElementwiseOpInterchangePattern,
FoldSuccessiveTensorInsertSliceOps>(&getContext());
patterns.add<FoldSuccessiveTensorInsertSliceOps>(&getContext());

// Fold away `tensor.dim` operations that can be resolved in terms of its
// operand shapes.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,77 +30,3 @@ util.func public @fold_insert_slices(%source : tensor<?x?xf32>,
// CHECK: %[[RETURN:.+]] = tensor.insert_slice %[[SOURCE]] into %[[FILL]]
// CHECK-SAME: [%[[NEW_OFFSET0]], %[[NEW_OFFSET1]]] [%[[SIZE0]], %[[SIZE1]]]
// CHECK: util.return %[[RETURN]]

// -----

#ident = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#perm = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d1)>
util.func @single_input_interchange(%arg0: tensor<2x128x128x320xf32>) -> tensor<2x320x128x128xf16> {
%0 = tensor.empty() : tensor<2x320x128x128xf16>
%1 = linalg.generic {indexing_maps = [#perm, #ident], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x128x128x320xf32>) outs(%0 : tensor<2x320x128x128xf16>) {
^bb0(%in: f32, %out: f16):
%2 = arith.truncf %in : f32 to f16
linalg.yield %2 : f16
} -> tensor<2x320x128x128xf16>
util.return %1 : tensor<2x320x128x128xf16>
}

// CHECK-DAG: #[[$IDENT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-DAG: #[[$PERM_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d1)>
// CHECK-LABEL: util.func public @single_input_interchange
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x128x128x320xf32>
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x320x128x128xf16>
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$IDENT_MAP]], #[[$PERM_MAP]]]
// CHECK-SAME: ins(%[[ARG0]] : tensor<2x128x128x320xf32>)
// CHECK-SAME: outs(%[[EMPTY]] : tensor<2x320x128x128xf16>)

// -----

#ident = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#perm = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d1)>
util.func @multi_input_interchange(%arg0: tensor<2x128x128x320xf32>) -> tensor<2x320x128x128xf16> {
%0 = tensor.empty() : tensor<2x320x128x128xf16>
%1 = linalg.generic {indexing_maps = [#perm, #perm, #ident], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %arg0 : tensor<2x128x128x320xf32>, tensor<2x128x128x320xf32>) outs(%0 : tensor<2x320x128x128xf16>) {
^bb0(%in: f32, %in_1: f32, %out: f16):
%2 = arith.addf %in, %in_1 : f32
%3 = arith.truncf %2 : f32 to f16
linalg.yield %3 : f16
} -> tensor<2x320x128x128xf16>
util.return %1 : tensor<2x320x128x128xf16>
}

// CHECK-DAG: #[[$IDENT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-DAG: #[[$PERM_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d1)>
// CHECK-LABEL: util.func public @multi_input_interchange
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x128x128x320xf32>
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<2x320x128x128xf16>
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$IDENT_MAP]], #[[$IDENT_MAP]], #[[$PERM_MAP]]]
// CHECK-SAME: ins(%[[ARG0]], %[[ARG0]] : tensor<2x128x128x320xf32>, tensor<2x128x128x320xf32>)
// CHECK-SAME: outs(%[[EMPTY]] : tensor<2x320x128x128xf16>)

// -----

#ident = affine_map<(d0, d1) -> (d0, d1)>
#perm0 = affine_map<(d0, d1) -> (d1, d0)>
util.func @multi_input_no_interchange(%arg0: tensor<10x10xf32>) -> tensor<10x10xf16> {
%0 = tensor.empty() : tensor<10x10xf16>
%1 = linalg.generic {indexing_maps = [#ident, #perm0, #perm0], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg0 : tensor<10x10xf32>, tensor<10x10xf32>) outs(%0 : tensor<10x10xf16>) {
^bb0(%in: f32, %in_1: f32, %out: f16):
%2 = arith.addf %in, %in_1 : f32
%3 = arith.truncf %2 : f32 to f16
linalg.yield %3 : f16
} -> tensor<10x10xf16>
util.return %1 : tensor<10x10xf16>
}

// CHECK-DAG: #[[$IDENT_MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[$PERM_MAP0:.+]] = affine_map<(d0, d1) -> (d1, d0)>
// CHECK-LABEL: util.func public @multi_input_no_interchange
// CHECK-SAME: %[[ARG0:.*]]: tensor<10x10xf32>
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<10x10xf16>
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$IDENT_MAP]], #[[$PERM_MAP0]], #[[$PERM_MAP0]]]
// CHECK-SAME: ins(%[[ARG0]], %[[ARG0]] : tensor<10x10xf32>, tensor<10x10xf32>)
// CHECK-SAME: outs(%[[EMPTY]] : tensor<10x10xf16>)

0 comments on commit ea95ca5

Please sign in to comment.