Skip to content

Commit

Permalink
[GPU] Add a pass to convert accumulating GEMMs to GEMMs
Browse files Browse the repository at this point in the history
Signed-off-by: Nirvedh Meshram <[email protected]>
  • Loading branch information
nirvedhmeshram committed Jan 2, 2025
1 parent fc6c518 commit 78481bb
Show file tree
Hide file tree
Showing 8 changed files with 184 additions and 1 deletion.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ iree_compiler_cc_library(
"CleanupBufferAllocViewPass.cpp",
"ConcretizePadResultShape.cpp",
"ConfigTrackingCanonicalizer.cpp",
"ConvertAccGEMMToGEMMPass.cpp",
"ConvertBf16ArithToF32.cpp",
"ConvertBf16ToUInt16Buffers.cpp",
"ConvertToDestinationPassingStylePass.cpp",
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ iree_cc_library(
"CleanupBufferAllocViewPass.cpp"
"ConcretizePadResultShape.cpp"
"ConfigTrackingCanonicalizer.cpp"
"ConvertAccGEMMToGEMMPass.cpp"
"ConvertBf16ArithToF32.cpp"
"ConvertBf16ToUInt16Buffers.cpp"
"ConvertToDestinationPassingStylePass.cpp"
Expand Down
136 changes: 136 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/ConvertAccGEMMToGEMMPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
// Copyright 2025 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

//===- ConvertAccGEMMtoGEMMpass.cpp ----------------------------------===//
//
// Converts Accumulating GEMM to GEMM + elementwise add.
//
//===----------------------------------------------------------------------===//

#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_CONVERTACCGEMMTOGEMMPASS
#include "iree/compiler/Codegen/Common/Passes.h.inc"

namespace {

struct ConvertAccGEMMtoGEMM
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
using OpInterfaceRewritePattern::OpInterfaceRewritePattern;

LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
PatternRewriter &rewriter) const override {
if (!linalg::isaContractionOpInterface(linalgOp) &&
!isa<linalg::ConvolutionOpInterface>(*linalgOp)) {
return failure();
}
if (!linalgOp.hasPureTensorSemantics())
return failure();

// Nothing to do if the output tensor operand is already a fill op.
SmallVector<OpOperand *> outputOperands;
if (!linalgOp.hasPureBufferSemantics()) {
outputOperands = llvm::to_vector(
llvm::map_range(linalgOp.getDpsInitsMutable(),
[](OpOperand &opOperand) { return &opOperand; }));
}
// Right now all the cases we see have one output. This can be relaxed once
// we see multiple output ops.
if (outputOperands.size() != 1)
return failure();
Value outputOperand = outputOperands.front()->get();

auto outsDefiningOp =
outputOperand.getDefiningOp<IREE::Flow::DispatchTensorLoadOp>();
if (!outsDefiningOp) {
// If not DispatchTensorLoadOp then do nothing.
return failure();
}
auto outputType = llvm::cast<RankedTensorType>(outputOperand.getType());
if (!outputType.getElementType().isIntOrFloat())
return failure();
auto elementType = outputType.getElementType();

Location loc = linalgOp.getLoc();

// Check if the output tensor access is a projected permutation
if (!linalgOp.getMatchingIndexingMap(outputOperands.front())
.isProjectedPermutation()) {
return rewriter.notifyMatchFailure(
linalgOp, "Output indexing map must be a projected permutation.");
}

int64_t outputRank = outputType.getRank();
SmallVector<utils::IteratorType> iterators(outputRank,
utils::IteratorType::parallel);
SmallVector<AffineMap> maps(3, rewriter.getMultiDimIdentityMap(outputRank));

// Create a zero tensor as the new output tensor operand to the Linalg
// contraction op.
SmallVector<OpFoldResult> mixedSizes =
tensor::getMixedSizes(rewriter, loc, outputOperand);
auto initOp =
rewriter.create<tensor::EmptyOp>(loc, mixedSizes, elementType);
Value zero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(elementType));
Value fill =
rewriter.create<linalg::FillOp>(loc, zero, initOp.getResult()).result();

// Update the contraction op to use the new zero tensor as output operand.
rewriter.modifyOpInPlace(linalgOp,
[&]() { linalgOp.setDpsInitOperand(0, fill); });

// Create a generic op to add back the original output tensor operand.
rewriter.setInsertionPointAfter(linalgOp);
auto genericOp = rewriter.create<linalg::GenericOp>(
loc, outputType, ValueRange{linalgOp->getResult(0), outputOperand},
fill, maps, iterators,
[&](OpBuilder &b, Location nestedLoc, ValueRange args) {
Value result;
if (llvm::isa<FloatType>(elementType)) {
result = b.create<arith::AddFOp>(nestedLoc, args[0], args[1]);
} else {
result = b.create<arith::AddIOp>(nestedLoc, args[0], args[1]);
}
b.create<linalg::YieldOp>(nestedLoc, result);
});
linalgOp->getResult(0).replaceAllUsesExcept(genericOp->getResult(0),
genericOp);
return success();
}
};

struct ConvertAccGEMMToGEMMPass
: public impl::ConvertAccGEMMToGEMMPassBase<ConvertAccGEMMToGEMMPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<arith::ArithDialect, linalg::LinalgDialect,
tensor::TensorDialect>();
}

void runOnOperation() override {
RewritePatternSet patterns(&getContext());
patterns.add<ConvertAccGEMMtoGEMM>(&getContext());
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
return signalPassFailure();
}
}
};

} // namespace
} // namespace mlir::iree_compiler
5 changes: 5 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ def ConcretizePadResultShapePass :
"implements OffsetSizeAndStrideOpInterface.";
}

def ConvertAccGEMMToGEMMPass :
Pass<"iree-convert-accgemm-to-gemm", ""> {
let summary = "Convert accumulating GEMMs to GEMMs post dispatch creation.";
}

def ConvertBf16ArithToF32Pass : Pass<"iree-convert-bf16-arith-to-f32", ""> {
let summary = "Convert bf16 arithmetic operations to f32";
}
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ iree_lit_test_suite(
"bubble_up_ordinal_ops.mlir",
"bufferize_copy_only_dispatches.mlir",
"canonicalize_interface_load_store.mlir",
"convert_accgemm_to_gemm.mlir",
"convert_bf16_to_uint16_buffers.mlir",
"convert_bf16_arith_to_f32.mlir",
"convert_to_destination_passing_style.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ iree_lit_test_suite(
"bubble_up_ordinal_ops.mlir"
"bufferize_copy_only_dispatches.mlir"
"canonicalize_interface_load_store.mlir"
"convert_accgemm_to_gemm.mlir"
"convert_bf16_arith_to_f32.mlir"
"convert_bf16_to_uint16_buffers.mlir"
"convert_to_destination_passing_style.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// RUN: iree-opt --split-input-file --iree-convert-accgemm-to-gemm %s | FileCheck %s

#pipeline_layout = #hal.pipeline.layout<bindings = [
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>
]>

func.func @accumulate_gemm() {
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<512x128xi8>>
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<512x128xi8>>
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<readwrite:tensor<512x512xi32>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [512, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<512x128xi8>> -> tensor<512x128xi8>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [512, 128], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<512x128xi8>> -> tensor<512x128xi8>
%5 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [512, 512], strides = [1, 1] : !flow.dispatch.tensor<readwrite:tensor<512x512xi32>> -> tensor<512x512xi32>
%6 = linalg.generic {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
affine_map<(d0, d1, d2) -> (d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"]}
ins(%3, %4 : tensor<512x128xi8>, tensor<512x128xi8>) outs(%5 : tensor<512x512xi32>) {
^bb0(%in: i8, %in_0: i8, %out: i32):
%7 = arith.extsi %in : i8 to i32
%8 = arith.extsi %in_0 : i8 to i32
%9 = arith.muli %7, %8 : i32
%10 = arith.addi %out, %9 : i32
linalg.yield %10 : i32
} -> tensor<512x512xi32>
flow.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [512, 512], strides = [1, 1] : tensor<512x512xi32> -> !flow.dispatch.tensor<readwrite:tensor<512x512xi32>>
return
}

// CHECK-LABEL: func.func @accumulate_gemm()
// CHECK: %[[GEMM:.+]] = linalg.generic
// CHECK: %[[ADD:.+]] = linalg.generic {{.+}} ins(%[[GEMM]]
4 changes: 3 additions & 1 deletion compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,9 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager,
if (pipelineOptions.useIgemmConvolution) {
funcPassManager.addPass(createConvolutionToIGEMMPass());
}

// TODO (nirvedhmeshram) : Can remove this pass after
// https://github.com/iree-org/iree/issues/19546 is fixed.
funcPassManager.addPass(createConvertAccGEMMToGEMMPass());
tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/true,
/*convertToDpsOptions=*/std::nullopt);

Expand Down

0 comments on commit 78481bb

Please sign in to comment.