-
Notifications
You must be signed in to change notification settings - Fork 645
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[GPU] Add a pass to convert accumulating GEMMs to GEMMs
Signed-off-by: Nirvedh Meshram <[email protected]>
- Loading branch information
1 parent
fc6c518
commit 78481bb
Showing
8 changed files
with
184 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
136 changes: 136 additions & 0 deletions
136
compiler/src/iree/compiler/Codegen/Common/ConvertAccGEMMToGEMMPass.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
// Copyright 2025 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
//===- ConvertAccGEMMToGEMMPass.cpp ----------------------------------===// | ||
// | ||
// Converts Accumulating GEMM to GEMM + elementwise add. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "iree/compiler/Codegen/Common/Passes.h" | ||
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h" | ||
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
#include "mlir/Dialect/Linalg/IR/Linalg.h" | ||
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" | ||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" | ||
#include "mlir/Dialect/Tensor/IR/Tensor.h" | ||
#include "mlir/Dialect/Tensor/Utils/Utils.h" | ||
#include "mlir/Dialect/Utils/StructuredOpsUtils.h" | ||
#include "mlir/IR/AffineMap.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
|
||
namespace mlir::iree_compiler { | ||
|
||
#define GEN_PASS_DEF_CONVERTACCGEMMTOGEMMPASS | ||
#include "iree/compiler/Codegen/Common/Passes.h.inc" | ||
|
||
namespace { | ||
|
||
struct ConvertAccGEMMtoGEMM | ||
: public OpInterfaceRewritePattern<linalg::LinalgOp> { | ||
using OpInterfaceRewritePattern::OpInterfaceRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp, | ||
PatternRewriter &rewriter) const override { | ||
if (!linalg::isaContractionOpInterface(linalgOp) && | ||
!isa<linalg::ConvolutionOpInterface>(*linalgOp)) { | ||
return failure(); | ||
} | ||
if (!linalgOp.hasPureTensorSemantics()) | ||
return failure(); | ||
|
||
// Nothing to do if the output tensor operand is already a fill op. | ||
SmallVector<OpOperand *> outputOperands; | ||
if (!linalgOp.hasPureBufferSemantics()) { | ||
outputOperands = llvm::to_vector( | ||
llvm::map_range(linalgOp.getDpsInitsMutable(), | ||
[](OpOperand &opOperand) { return &opOperand; })); | ||
} | ||
// Right now all the cases we see have one output. This can be relaxed once | ||
// we see multiple output ops. | ||
if (outputOperands.size() != 1) | ||
return failure(); | ||
Value outputOperand = outputOperands.front()->get(); | ||
|
||
auto outsDefiningOp = | ||
outputOperand.getDefiningOp<IREE::Flow::DispatchTensorLoadOp>(); | ||
if (!outsDefiningOp) { | ||
// If not DispatchTensorLoadOp then do nothing. | ||
return failure(); | ||
} | ||
auto outputType = llvm::cast<RankedTensorType>(outputOperand.getType()); | ||
if (!outputType.getElementType().isIntOrFloat()) | ||
return failure(); | ||
auto elementType = outputType.getElementType(); | ||
|
||
Location loc = linalgOp.getLoc(); | ||
|
||
// Check if the output tensor access is a projected permutation | ||
if (!linalgOp.getMatchingIndexingMap(outputOperands.front()) | ||
.isProjectedPermutation()) { | ||
return rewriter.notifyMatchFailure( | ||
linalgOp, "Output indexing map must be a projected permutation."); | ||
} | ||
|
||
int64_t outputRank = outputType.getRank(); | ||
SmallVector<utils::IteratorType> iterators(outputRank, | ||
utils::IteratorType::parallel); | ||
SmallVector<AffineMap> maps(3, rewriter.getMultiDimIdentityMap(outputRank)); | ||
|
||
// Create a zero tensor as the new output tensor operand to the Linalg | ||
// contraction op. | ||
SmallVector<OpFoldResult> mixedSizes = | ||
tensor::getMixedSizes(rewriter, loc, outputOperand); | ||
auto initOp = | ||
rewriter.create<tensor::EmptyOp>(loc, mixedSizes, elementType); | ||
Value zero = rewriter.create<arith::ConstantOp>( | ||
loc, rewriter.getZeroAttr(elementType)); | ||
Value fill = | ||
rewriter.create<linalg::FillOp>(loc, zero, initOp.getResult()).result(); | ||
|
||
// Update the contraction op to use the new zero tensor as output operand. | ||
rewriter.modifyOpInPlace(linalgOp, | ||
[&]() { linalgOp.setDpsInitOperand(0, fill); }); | ||
|
||
// Create a generic op to add back the original output tensor operand. | ||
rewriter.setInsertionPointAfter(linalgOp); | ||
auto genericOp = rewriter.create<linalg::GenericOp>( | ||
loc, outputType, ValueRange{linalgOp->getResult(0), outputOperand}, | ||
fill, maps, iterators, | ||
[&](OpBuilder &b, Location nestedLoc, ValueRange args) { | ||
Value result; | ||
if (llvm::isa<FloatType>(elementType)) { | ||
result = b.create<arith::AddFOp>(nestedLoc, args[0], args[1]); | ||
} else { | ||
result = b.create<arith::AddIOp>(nestedLoc, args[0], args[1]); | ||
} | ||
b.create<linalg::YieldOp>(nestedLoc, result); | ||
}); | ||
linalgOp->getResult(0).replaceAllUsesExcept(genericOp->getResult(0), | ||
genericOp); | ||
return success(); | ||
} | ||
}; | ||
|
||
struct ConvertAccGEMMToGEMMPass | ||
: public impl::ConvertAccGEMMToGEMMPassBase<ConvertAccGEMMToGEMMPass> { | ||
void getDependentDialects(DialectRegistry ®istry) const override { | ||
registry.insert<arith::ArithDialect, linalg::LinalgDialect, | ||
tensor::TensorDialect>(); | ||
} | ||
|
||
void runOnOperation() override { | ||
RewritePatternSet patterns(&getContext()); | ||
patterns.add<ConvertAccGEMMtoGEMM>(&getContext()); | ||
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { | ||
return signalPassFailure(); | ||
} | ||
} | ||
}; | ||
|
||
} // namespace | ||
} // namespace mlir::iree_compiler |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
36 changes: 36 additions & 0 deletions
36
compiler/src/iree/compiler/Codegen/Common/test/convert_accgemm_to_gemm.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
// RUN: iree-opt --split-input-file --iree-convert-accgemm-to-gemm %s | FileCheck %s | ||
|
||
#pipeline_layout = #hal.pipeline.layout<bindings = [ | ||
#hal.pipeline.binding<storage_buffer>, | ||
#hal.pipeline.binding<storage_buffer>, | ||
#hal.pipeline.binding<storage_buffer> | ||
]> | ||
|
||
// Test input: an i8 x i8 -> i32 contraction written as a linalg.generic whose | ||
// outs operand (%5) is loaded from the readwrite binding %2, i.e. the op | ||
// accumulates into data that already lives in the output buffer. | ||
func.func @accumulate_gemm() { | ||
%c0 = arith.constant 0 : index | ||
// Bindings 0 and 1 are the readonly 512x128xi8 inputs; binding 2 is the | ||
// readwrite 512x512xi32 accumulator/result. | ||
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<512x128xi8>> | ||
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<512x128xi8>> | ||
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<readwrite:tensor<512x512xi32>> | ||
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [512, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<512x128xi8>> -> tensor<512x128xi8> | ||
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [512, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<512x128xi8>> -> tensor<512x128xi8> | ||
// The accumulator is produced directly by flow.dispatch.tensor.load. | ||
%5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [512, 512], strides = [1, 1] : !flow.dispatch.tensor<readwrite:tensor<512x512xi32>> -> tensor<512x512xi32> | ||
// d0 and d1 are parallel, d2 is the reduction; both inputs are indexed with | ||
// the reduction dimension last and the output is indexed by (d0, d1). | ||
%6 = linalg.generic { | ||
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, | ||
affine_map<(d0, d1, d2) -> (d1, d2)>, | ||
affine_map<(d0, d1, d2) -> (d0, d1)>], | ||
iterator_types = ["parallel", "parallel", "reduction"]} | ||
ins(%3, %4 : tensor<512x128xi8>, tensor<512x128xi8>) outs(%5 : tensor<512x512xi32>) { | ||
^bb0(%in: i8, %in_0: i8, %out: i32): | ||
// Sign-extend both i8 inputs to i32, multiply, and add into %out. | ||
%7 = arith.extsi %in : i8 to i32 | ||
%8 = arith.extsi %in_0 : i8 to i32 | ||
%9 = arith.muli %7, %8 : i32 | ||
%10 = arith.addi %out, %9 : i32 | ||
linalg.yield %10 : i32 | ||
} -> tensor<512x512xi32> | ||
// The result is stored back to the same readwrite binding it was loaded from. | ||
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [512, 512], strides = [1, 1] : tensor<512x512xi32> -> !flow.dispatch.tensor<readwrite:tensor<512x512xi32>> | ||
return | ||
} | ||
|
||
// CHECK-LABEL: func.func @accumulate_gemm() | ||
// CHECK: %[[GEMM:.+]] = linalg.generic | ||
// CHECK: %[[ADD:.+]] = linalg.generic {{.+}} ins(%[[GEMM]] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters