Skip to content

Commit

Permalink
[DispatchCreation] Move the logic to transpose indexing maps into dispatch formation logic.
Browse files Browse the repository at this point in the history

For cases like

```
%0 = linalg.matmul
%1 = linalg.generic {
    indexing_maps=[affine_map<(d0, d1) -> (d1, d0)>, ...], ...}
    ins(%0, ... : ...) outs(...) {...}

```

some preprocessing patterns convert these to

```
%0 = linalg.matmul
%1 = linalg.generic {
    indexing_maps=[affine_map<(d0, d1) -> (d0, d1)>, ...], ...}
    ins(%0, ... : ...) outs(...) {...}

```

to make these operations fusable. But these preprocessing are run too
early and doing a spooky-action-at-a-distance.

Instead just move this logic into dispatch formation itself.

Signed-off-by: MaheshRavishankar <[email protected]>
  • Loading branch information
MaheshRavishankar committed Jan 3, 2025
1 parent fc6c518 commit 6888a18
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
Expand Down Expand Up @@ -383,6 +384,70 @@ static bool areOpsFusable(Operation *producer, Operation *consumer,
return true;
}

/// The logic to decide fusability (using the `hasCompatibleOuterParallelLoops`)
/// currently works when the indexing map corresponding to result of the
/// producer and indexing map corresponding to operand in the result are not
/// transposed with respect to each other. To find more fusion opportunities for
/// consumer elementwise operation, the indexing maps in the consumer can be
/// made to "align" with the indexing map of the producer to enhance fusion.
///
/// NOTE: despite the predicate-style name, on success this function MUTATES
/// the consumer in place — its loops are interchanged so that the indexing
/// map of `fusableOperand` becomes the identity — before returning true.
static bool areOpsFusableAfterInterchangeOfConsumer(
OpOperand &fusableOperand,
const llvm::SmallBitVector &rootOuterParallelLoops) {
// The operand must come from an op's result; a block argument has no
// producer to fuse with.
Operation *producer = fusableOperand.get().getDefiningOp();
if (!producer) {
return false;
}

// Only `linalg.generic` consumers can have their loops interchanged here.
Operation *consumer = fusableOperand.getOwner();
auto genericOp = dyn_cast<linalg::GenericOp>(consumer);
if (!genericOp) {
return false;
}
assert(genericOp.getNumDpsInputs() > 0 &&
"expected consumer to have at least one input");

// Restrict the rewrite to single-result elementwise consumers.
if (!linalg::isElementwise(genericOp) || genericOp.getNumResults() != 1) {
return false;
}
// The fusable operand's map must be a non-identity permutation: if it is
// already the identity an interchange would be a no-op (and fusion was
// already rejected); if it is not a permutation, an interchange alone
// cannot normalize it.
AffineMap inputMap = genericOp.getMatchingIndexingMap(&fusableOperand);
if (!inputMap.isPermutation() || inputMap.isIdentity()) {
return false;
}
// The result map must also be a permutation for the interchange to keep
// the op's result indexing well-formed.
OpResult result = cast<OpResult>(genericOp.getResult(0));
if (!genericOp.getIndexingMapMatchingResult(result).isPermutation()) {
return false;
}

// For now this is restricting that all indexing maps corresponding to the
// input are same as the indexing map of the fused operand. That is
// overly conservative. Really just need to check that the indexing maps
// are permutations.
if (!llvm::all_of(
genericOp.getDpsInputOperands(), [&](OpOperand *inputOperand) {
AffineMap map = genericOp.getMatchingIndexingMap(inputOperand);
return map == inputMap;
})) {
return false;
}

// Make the input map identity: build the permutation from the dim
// positions of `inputMap` and interchange the generic op with it.
auto perm =
llvm::map_to_vector(inputMap.getResults(), [](AffineExpr e) -> unsigned {
return cast<AffineDimExpr>(e).getPosition();
});
IRRewriter rewriter(consumer->getContext());
FailureOr<linalg::GenericOp> interchangedOp =
linalg::interchangeGenericOp(rewriter, genericOp, perm);
// `interchangedOp` is only examined by the asserts below; the void cast
// silences the unused-variable warning in release (NDEBUG) builds.
(void)interchangedOp;
assert(succeeded(interchangedOp) && "expected interchange to succeed");
assert(interchangedOp.value() == genericOp &&
"expected interchange to happen in place");
assert(
areOpsFusable(producer, interchangedOp.value(), rootOuterParallelLoops) &&
"expected the interchanged op to be fusable");
return true;
}

/// For the fusion of root op -> elementwise operation to be bufferized
/// in-place without use of extra memory, the result of the root operation
/// must be able to reuse the buffer for the result of the elementwise
Expand Down Expand Up @@ -531,7 +596,10 @@ isFusableWithConsumer(OpOperand &fusedOperand,
}

if (!areOpsFusable(producer, consumer, rootOuterParallelLoops)) {
return false;
if (!areOpsFusableAfterInterchangeOfConsumer(fusedOperand,
rootOuterParallelLoops)) {
return false;
}
}

// Check if the iteration spaces of the producer and consumer are same.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -922,3 +922,35 @@ util.func @custom_op_no_producer_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<
// CHECK-SAME: ins(%[[DISPATCH1]],
// CHECK: flow.return %[[CUSTOM_OP]]
// CHECK: util.return %[[DISPATCH2]]

// -----

// Tests that an elementwise `linalg.generic` whose operand indexing map is a
// transpose of the matmul result (d1, d0) is still fused into the same
// dispatch region: dispatch formation interchanges the consumer's loops so
// the fused operand's map becomes the identity.
util.func @fuse_transposed_op(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%cst = arith.constant 0.0: f32
%m = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%n = tensor.dim %arg1, %c1 : tensor<?x?xf32>
%empty = tensor.empty(%m, %n) : tensor<?x?xf32>
%fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<?x?xf32>) -> tensor<?x?xf32>
%matmul = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
// Output of the generic is the transpose (n x m) of the matmul result.
%empty2 = tensor.empty(%n, %m) : tensor<?x?xf32>
%generic = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%matmul, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%empty2 : tensor<?x?xf32>) {
^bb0(%b0: f32, %b1 : f32, %b2 : f32):
%0 = arith.addf %b0, %b1 : f32
linalg.yield %0 : f32
} -> tensor<?x?xf32>
util.return %generic : tensor<?x?xf32>
}
// Both ops must land in a single dispatch region, with the generic consuming
// the matmul result directly.
// NOTE(review): the label below relies on the printer emitting
// "util.func public" — confirm against the other tests in this file.
// CHECK-LABEL: func public @fuse_transposed_op
// CHECK: %[[DISPATCH:.+]] = flow.dispatch.region
// CHECK: %[[MATMUL:.+]] = linalg.matmul
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[MATMUL]],
// CHECK: flow.return %[[GENERIC]]
// CHECK: return %[[DISPATCH]]

0 comments on commit 6888a18

Please sign in to comment.