Commit e23f5a2
ukernel-lowering
Signed-off-by: Benoit Jacob <[email protected]>
bjacob committed Dec 19, 2024
1 parent ed9a028 commit e23f5a2
Showing 14 changed files with 189 additions and 95 deletions.
8 changes: 4 additions & 4 deletions compiler/plugins/target/ROCM/builtins/ukernel/BUILD.bazel
@@ -60,19 +60,19 @@ argmax_bc_files = [
]

iree_amdgpu_bitcode_library(
name = "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4_gfx942",
name = "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_gfx942",
srcs = [
"common.h",
"iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4.c",
"iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8.c",
],
out = "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4.gfx942.bc",
out = "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8.gfx942.bc",
gpu_arch = "gfx942",
)

iree_c_embed_data(
name = "iree_uk_amdgpu_bitcode",
srcs = argmax_bc_files + [
"iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4.gfx942.bc",
"iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8.gfx942.bc",
],
c_file_output = "iree_uk_amdgpu_bitcode.c",
flatten = True,
8 changes: 4 additions & 4 deletions compiler/plugins/target/ROCM/builtins/ukernel/CMakeLists.txt
@@ -208,14 +208,14 @@ iree_amdgpu_bitcode_library(

iree_amdgpu_bitcode_library(
NAME
iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4_gfx942
iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_gfx942
GPU_ARCH
gfx942
SRCS
"common.h"
"iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4.c"
"iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8.c"
OUT
"iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4.gfx942.bc"
"iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8.gfx942.bc"
)

iree_c_embed_data(
@@ -238,7 +238,7 @@ iree_c_embed_data(
"iree_uk_amdgpu_argmax_f32i64.gfx1100.bc"
"iree_uk_amdgpu_argmax_f32i64.gfx90a.bc"
"iree_uk_amdgpu_argmax_f32i64.gfx942.bc"
"iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4.gfx942.bc"
"iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8.gfx942.bc"
C_FILE_OUTPUT
"iree_uk_amdgpu_bitcode.c"
H_FILE_OUTPUT
1 change: 0 additions & 1 deletion compiler/plugins/target/ROCM/builtins/ukernel/common.h
@@ -61,7 +61,6 @@ typedef __UINT64_TYPE__ uint64_t;
// Vector typedefs
//===----------------------------------------------------------------------===//

typedef __attribute__((__vector_size__(8 * 2))) int64_t int64x2_t;
typedef __attribute__((__vector_size__(4 * 4))) int32_t int32x4_t;

//===----------------------------------------------------------------------===//
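(Note: the int64x2_t typedef becomes unused because the new multi_mma ukernel below reads its MFMA operands as scalar int64_t values rather than as vectors.)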
65 changes: 65 additions & 0 deletions compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8.c
@@ -0,0 +1,65 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "compiler/plugins/target/ROCM/builtins/ukernel/common.h"

// Very naive kernel. TODO(bjacob):
// 1. Inlining: the `always_inline` attribute here is correctly preserved in
// the bitcode, but isn't having the intended effect of inlining calls to
// this function. Making that work is key as various function parameters
// (e.g. `unroll_m`) are meant to be constants.
// 2. Shared memory: can't allocate it within the microkernel (which is just a
// helper device function, not the actual amdgpu_kernel). Need to get it
// passed down here as a `T [[clang::address_space(3)]] *` parameter.
// 3. Better scheduling via either barrier intrinsics or inline assembly.
// 4. Subgroups1x4 being asymmetric is a historical accident... should be 2x2.
[[clang::always_inline]] void iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8(
const int8_t *a_buffer, int64_t a_offset, const int8_t *b_buffer,
int64_t b_offset, int32_t *c_buffer, int64_t c_offset, int32_t k_size,
int32_t unroll_m, int32_t subgroups_m, int32_t unroll_n,
int32_t subgroups_n, int32_t unroll_k) {
/*
TODO(bjacob): reenable this once inlining works.
// Load existing accumulators. This is a VLA, but should become fixed-size
// once this function is inlined and unroll_* factors become constants.
int32x4_t c[unroll_m][unroll_n];
*/
// Load existing accumulators.
if (unroll_m > 8 || unroll_n > 2) {
__builtin_trap();
}
int32x4_t c[8][2];
int32x4_t *c_global = (int32x4_t *)(c_buffer + c_offset);
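  // Note: the stride of 64 int32x4_t below is one full tile: each 16x16
  // accumulator tile is spread across the 64 lanes of a subgroup, one
  // int32x4_t fragment per lane.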
for (int m = 0; m < unroll_m; ++m) {
for (int n = 0; n < unroll_n; ++n) {
c[m][n] = c_global[64 * (m * unroll_n + n)];
}
}

// Arithmetic loop.
const int64_t *a_global = (const int64_t *)(a_buffer + a_offset);
const int64_t *b_global = (const int64_t *)(b_buffer + b_offset);
for (int k_outer = 0; k_outer < k_size; ++k_outer) {
for (int m = 0; m < unroll_m; ++m) {
for (int n = 0; n < unroll_n; ++n) {
for (int k = 0; k < unroll_k; ++k) {
c[m][n] = __builtin_amdgcn_mfma_i32_16x16x32_i8(
a_global[64 * unroll_k * m + k], b_global[64 * unroll_k * n + k],
c[m][n], 0, 0, 0);
}
}
}
a_global += 64 * unroll_m * subgroups_m * unroll_k;
b_global += 64 * unroll_n * subgroups_n * unroll_k;
}

// Store accumulators.
for (int m = 0; m < unroll_m; ++m) {
for (int n = 0; n < unroll_n; ++n) {
c_global[64 * (m * unroll_n + n)] = c[m][n];
}
}
}
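
As a sketch of TODO item 2 above (not part of this commit): since this helper device function cannot allocate shared memory itself, the LDS buffer would have to be passed in as an explicit address-space(3) pointer parameter, roughly as below. The shared_buffer parameter name and its position in the signature are hypothetical.

// Hypothetical future signature: the enclosing amdgpu_kernel allocates the
// LDS staging buffer and passes it down to this (inlined) helper.
[[clang::always_inline]] void iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8(
    const int8_t *a_buffer, int64_t a_offset, const int8_t *b_buffer,
    int64_t b_offset, int32_t *c_buffer, int64_t c_offset,
    int8_t [[clang::address_space(3)]] *shared_buffer,
    int32_t k_size, int32_t unroll_m, int32_t subgroups_m, int32_t unroll_n,
    int32_t subgroups_n, int32_t unroll_k);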

This file was deleted: compiler/plugins/target/ROCM/builtins/ukernel/iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4.c

@@ -23,7 +23,7 @@ func.func @multi_mma_mfma_i32_16x16x32_i8(%a : tensor<1x2x8x4x16x2x8xi8>,

// CHECK-LABEL: @multi_mma_mfma_i32_16x16x32_i8
// CHECK: iree_gpu.multi_mma
// CHECK-SAME: #hal.executable.object<{path = "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4.gfx942.bc"
// CHECK-SAME: #hal.executable.object<{path = "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8.gfx942.bc"
// CHECK-NOT: promote_operands
// CHECK-SAME: reduction = [0, 0, 0]
// CHECK-SAME: #iree_gpu.ukernel_config<name = "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8_unroll8x2x2_subgroups1x4"
// CHECK-SAME: #iree_gpu.ukernel_config<name = "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8"
@@ -9,6 +9,7 @@
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
@@ -33,12 +34,8 @@ namespace {
static FailureOr<IREE::Codegen::UKernelOpInterface>
matchArgmaxDAGForUKernel(RewriterBase &rewriter, linalg::GenericOp op) {
Value input = op.getDpsInputOperand(0)->get();
auto inputType = cast<ShapedType>(input.getType());
Value index = op.getDpsInitOperand(1)->get();
auto indexType = cast<ShapedType>(index.getType());
std::string suffix;
llvm::raw_string_ostream(suffix)
<< inputType.getElementType() << indexType.getElementType();
auto loweringConfig = getLoweringConfig<IREE::GPU::LoweringConfigAttr>(op);
if (!loweringConfig) {
return rewriter.notifyMatchFailure(op, "no lowering_config on this op");
@@ -84,6 +81,50 @@ struct LowerArgmaxToUKernelPattern : OpRewritePattern<linalg::GenericOp> {
}
};

struct LowerMultiMmaToUKernelPattern : OpRewritePattern<IREE::GPU::MultiMmaOp> {
LowerMultiMmaToUKernelPattern(MLIRContext *context)
: OpRewritePattern<IREE::GPU::MultiMmaOp>(context) {}

LogicalResult matchAndRewrite(IREE::GPU::MultiMmaOp op,
PatternRewriter &rewriter) const override {
auto loweringConfig = getLoweringConfig<IREE::GPU::LoweringConfigAttr>(op);
if (!loweringConfig) {
return rewriter.notifyMatchFailure(op, "no lowering_config on this op");
}
IREE::GPU::UKernelConfigAttr ukernelAttr =
IREE::GPU::getUkernelSpec(loweringConfig);
if (!ukernelAttr) {
return rewriter.notifyMatchFailure(op, "no ukernel selected for this op");
}
auto mma = dyn_cast<IREE::GPU::DataTiledMMAAttr>(op.getKind());
if (!mma) {
return rewriter.notifyMatchFailure(op, "unhandled MMAInterfaceAttr");
}
auto castIndexToI32 = [&](Value val) {
return rewriter.create<arith::IndexCastOp>(op.getLoc(),
rewriter.getI32Type(), val);
};
auto constI32 = [&](int val) {
return rewriter.create<arith::ConstantIntOp>(op.getLoc(), val,
rewriter.getI32Type());
};
Value k = castIndexToI32(
rewriter.create<tensor::DimOp>(op.getLoc(), op.getLhs(), 1));
Value unrollM = constI32(mma.getUnrollM());
Value subgroupsM = constI32(mma.getSubgroupsM());
Value unrollN = constI32(mma.getUnrollN());
Value subgroupsN = constI32(mma.getSubgroupsN());
Value unrollK = constI32(mma.getUnrollK());
rewriter.replaceOpWithNewOp<IREE::Codegen::UKernelGenericOp>(
op, TypeRange{op.getAccType()}, ukernelAttr.getName(),
ValueRange{op.getLhs(), op.getRhs()}, op.getAcc(),
ValueRange{k, unrollM, subgroupsM, unrollN, subgroupsN, unrollK},
ukernelAttr.getDefAttrs(),
/*strided_outer_dims=*/rewriter.getIndexAttr(0));
return success();
}
};

struct GPULowerToUKernelsPass final
: impl::GPULowerToUKernelsPassBase<GPULowerToUKernelsPass> {
void runOnOperation() override {
@@ -101,7 +142,8 @@ struct GPULowerToUKernelsPass final
// evidence that it is difficult for codegen to consistently approach
// microkernels performance, and that consideration overrides the benefit of
// fusions for these ops.
patterns.insert<LowerArgmaxToUKernelPattern>(context);
patterns.add<LowerArgmaxToUKernelPattern, LowerMultiMmaToUKernelPattern>(
context);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
return signalPassFailure();
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
@@ -111,6 +111,8 @@ def GPULowerToUKernelsPass :
let dependentDialects = [
"::mlir::iree_compiler::IREE::Codegen::IREECodegenDialect",
"::mlir::iree_compiler::IREE::GPU::IREEGPUDialect",
"::mlir::arith::ArithDialect",
"::mlir::tensor::TensorDialect",
];
}

@@ -1,9 +1,7 @@
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-lower-to-ukernels,cse,canonicalize))" %s | FileCheck %s

#config = #iree_gpu.lowering_config<{ukernel = #iree_gpu.ukernel_config<name = "some_ukernel", def_attrs = {vm.import.module = "rocm"}>}>
func.func @argmax_f32i64_with_selected_ukernel(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes {
hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "all"}>
} {
func.func @argmax_f32i64_with_selected_ukernel(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> {
%c0_i64 = arith.constant 0 : i64
%cst = arith.constant 0xFF800000 : f32
%0 = tensor.empty() : tensor<1xi64>
@@ -42,9 +40,7 @@ func.func @argmax_f32i64_with_selected_ukernel(%arg0 : tensor<1x?xf32>) -> tenso

// -----

func.func @argmax_f32i64_without_selected_ukernel(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> attributes {
hal.executable.target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {ukernels = "all"}>
} {
func.func @argmax_f32i64_without_selected_ukernel(%arg0 : tensor<1x?xf32>) -> tensor<1xi64> {
%c0_i64 = arith.constant 0 : i64
%cst = arith.constant 0xFF800000 : f32
%0 = tensor.empty() : tensor<1xi64>
@@ -70,3 +66,27 @@ func.func @argmax_f32i64_without_selected_ukernel(%arg0 : tensor<1x?xf32>) -> te
//CHECK-LABEL: func @argmax_f32i64_without_selected_ukernel(
// CHECK-NOT: iree_codegen.ukernel.generic
// CHECK: linalg.generic

// -----

func.func @multi_mma_mfma_i32_16x16x32_i8(%a : tensor<1x2x8x1x1x2x8xi8>, %b : tensor<1x2x1x2x1x1x2x8xi8>, %c : tensor<1x1x1x8x2x1x1x4xi32>) -> tensor<1x1x1x8x2x1x1x4xi32> {
%d = iree_gpu.multi_mma %a, %b, %c {
indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = [#iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<parallel>, #iree_gpu.iterator_type<reduction>],
kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_I32_16x16x32_I8, unroll_m = 8, unroll_n = 2, subgroups_n = 4, unroll_k = 2>,
lowering_config = #iree_gpu.lowering_config<{
reduction = [0, 0, 0],
ukernel = #iree_gpu.ukernel_config<name = "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8", def_attrs = {vm.import.module = "rocm"}>,
workgroup = [1, 1, 0]}>
} : tensor<1x2x8x1x1x2x8xi8>, tensor<1x2x1x2x1x1x2x8xi8> into tensor<1x1x1x8x2x1x1x4xi32>
return %d : tensor<1x1x1x8x2x1x1x4xi32>
}

// CHECK-LABEL: func @multi_mma_mfma_i32_16x16x32_i8(
// CHECK-DAG: %c2_i32 = arith.constant 2 : i32
// CHECK-DAG: %c8_i32 = arith.constant 8 : i32
// CHECK-DAG: %c1_i32 = arith.constant 1 : i32
// CHECK-DAG: %c4_i32 = arith.constant 4 : i32
// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic
// CHECK-SAME: "iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8"
// CHECK-SAME: (%c2_i32, %c8_i32, %c1_i32, %c2_i32, %c4_i32, %c2_i32 : i32, i32, i32, i32, i32, i32)
@@ -620,16 +620,11 @@ distributeMultiMmaOp(RewriterBase &rewriter, IREE::GPU::MultiMmaOp mmaOp,
accStrides);

// Step 3. Create the new multi_mma op.
auto newKind = mmaOp.getKind();
if (auto dataTiledMma = dyn_cast<DataTiledMMAAttr>(newKind)) {
newKind = DataTiledMMAAttr::get(
context, dataTiledMma.getIntrinsic(), dataTiledMma.getUnrollM(),
/*subgroups_m=*/1, dataTiledMma.getUnrollN(),
/*subgroups_n=*/1, dataTiledMma.getUnrollK());
}
auto newMmaOp = rewriter.create<IREE::GPU::MultiMmaOp>(
loc, lhsSlice, rhsSlice, accSlice, mmaOp.getIndexingMaps(),
mmaOp.getIteratorTypes(), newKind);
mmaOp.getIteratorTypes(), mmaOp.getKind());

newMmaOp->setDiscardableAttrs(mmaOp->getDiscardableAttrDictionary());

// Step 4. Insert the result of the multi_mma using the same offsets/sizes as
// the accumulator slice.
@@ -471,7 +471,7 @@ func.func @data_tiled_2x2x4_tensor_multi_mma_unrolled_to_subgroups(%lhs: tensor<
// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC_ARG]]
// CHECK-SAME: [0, 0, %[[ACC_IDS]]#1, %[[ACC_IDS]]#2, %[[ACC_IDS]]#3, %[[ACC_IDS]]#4, 0] [1, 1, 1, 1, 1, 1, 4] [1, 1, 1, 1, 1, 1, 1]
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS_SLICE]], %[[RHS_SLICE]], %[[ACC_SLICE]]
// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, unroll_k = 4>}
// CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout<intrinsic = MFMA_F32_16x16x4_F32, subgroups_m = 2, subgroups_n = 2, unroll_k = 4>}
// CHECK-SAME: : tensor<1x1x1x1x1x4xf32>, tensor<1x1x1x1x1x4xf32> into tensor<1x1x1x1x1x1x4xf32>
// CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC_ARG]]
// CHECK-SAME: [0, 0, %[[ACC_IDS]]#1, %[[ACC_IDS]]#2, %[[ACC_IDS]]#3, %[[ACC_IDS]]#4, 0] [1, 1, 1, 1, 1, 1, 4] [1, 1, 1, 1, 1, 1, 1]
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -410,6 +410,9 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager,
}
funcPassManager.addPass(IREE::GPU::createDistributeMmaToLanesPass());

// Step 4.5. Things that need to happen right after distribution to threads.
funcPassManager.addPass(createGPULowerToUKernelsPass());

// Normalize loop bounds for later lowerings.
funcPassManager.addPass(iree_compiler::createNormalizeLoopBoundsPass(
NormalizeLoopBoundsPassOptions{/*normalizeFor=*/false,
@@ -42,17 +42,8 @@ getUKernelNameAndSuffixForMultiMma(IREE::GPU::MultiMmaOp op) {
if (!mma) {
return {}; // Only handling DataTiledMMAAttr for now.
}
std::string suffix{
stringifyMMAIntrinsic(mma.getIntrinsic().getValue()).lower()};
if (mma.getUnrollM() != 1 || mma.getUnrollN() != 1 || mma.getUnrollK() != 1) {
suffix += llvm::formatv("_unroll{}x{}x{}", mma.getUnrollM(),
mma.getUnrollN(), mma.getUnrollK());
}
if (mma.getSubgroupsM() != 1 || mma.getSubgroupsN() != 1) {
suffix += llvm::formatv("_subgroups{}x{}", mma.getSubgroupsM(),
mma.getSubgroupsN());
}
return {"multi_mma", suffix};
return {"multi_mma",
stringifyMMAIntrinsic(mma.getIntrinsic().getValue()).lower()};
}
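
In other words, the ukernel name now encodes only the intrinsic, yielding iree_uk_amdgpu_multi_mma_mfma_i32_16x16x32_i8 for this data-tiled MMA; the unroll and subgroup factors that previously produced suffixes like _unroll8x2x2_subgroups1x4 are no longer baked into specialized bitcode variants but are passed to the single generic ukernel as runtime i32 operands by LowerMultiMmaToUKernelPattern above.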

// Returns ukernel name and suffix for any op. Empty name = no ukernel.