From 9c9d5f756b3e0d7d3943dfc210de29dfc62c8319 Mon Sep 17 00:00:00 2001 From: Guo-Peilin Date: Tue, 31 Oct 2023 11:40:59 +0800 Subject: [PATCH] add more ops for MLIR-based end-to-end GPU Tensor Core GEMM codegen (#1260) --- .../data/matmul_nn_s_f16_gpu_schedule_1.mlir | 78 + .../disc-transform/default_schedule_matmul.cc | 1 - .../mlir/disc/tools/disc-transform/BUILD | 2 + .../TransformOps/GPUPipeline.cc | 785 ++++++++++ .../TransformOps/OptimizeShareMemory.cc | 268 ++++ .../TransformOps/TransformOpsExt.cc | 1337 +++++++++++++++++ .../TransformOps/TransformOpsExt.h | 9 + .../TransformOps/TransformOpsExt.td | 333 +++- .../disc_lower_gpu_ops_to_nvvm_ops.cc | 44 + .../transforms/disc_transform_schedule.cc | 13 +- .../transforms/revise_kernel_outlining.cc | 25 +- 11 files changed, 2882 insertions(+), 13 deletions(-) create mode 100644 tao_compiler/mlir/disc/tests/disc-transform/data/matmul_nn_s_f16_gpu_schedule_1.mlir create mode 100644 tao_compiler/mlir/disc/tools/disc-transform/TransformOps/GPUPipeline.cc create mode 100644 tao_compiler/mlir/disc/tools/disc-transform/TransformOps/OptimizeShareMemory.cc diff --git a/tao_compiler/mlir/disc/tests/disc-transform/data/matmul_nn_s_f16_gpu_schedule_1.mlir b/tao_compiler/mlir/disc/tests/disc-transform/data/matmul_nn_s_f16_gpu_schedule_1.mlir new file mode 100644 index 00000000000..443918164f6 --- /dev/null +++ b/tao_compiler/mlir/disc/tests/disc-transform/data/matmul_nn_s_f16_gpu_schedule_1.mlir @@ -0,0 +1,78 @@ +transform.sequence failures(propagate) { +^bb0(%arg0: !transform.any_op): + %0 = transform.structured.match attributes {disc.transform.name = "dot_general"} in %arg0 : (!transform.any_op) -> !transform.any_op + %1:2 = split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %forall_op, %tiled_op = transform.structured.tile_to_forall_op %1#1 num_threads [] tile_sizes [128, 128](mapping = [#gpu.block, #gpu.block]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %fused_op, %new_containing_op = transform.structured.fuse_into_containing_op %1#0 into %forall_op : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + %padding_mn = transform.disc.padding_mn %tiled_op padding_values [0.0:f16, 0.0:f16, 0.0:f16] tile_sizes [128, 128] : (!transform.any_op) -> (!transform.any_op) + %for_op, %splitted_op = transform.disc.split_reduction_serial %padding_mn by tile_sizes = [32] loop_type = "cta-k-loop" : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %padding_k = transform.disc.padding_k %for_op padding_values [0.0:f16, 0.0:f16] tile_sizes [32] : (!transform.any_op) -> (!transform.any_op) + transform.disc.apply_dce %arg0 : !transform.any_op + transform.disc.apply_cse %arg0 : !transform.any_op + %promoted_dot, %lhs_alloc, %rhs_alloc = transform.disc.promote_dot_operands %padding_k [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + %forall_op_0, %tiled_op_1 = transform.structured.tile_to_forall_op %promoted_dot num_threads [] tile_sizes [64, 64](mapping = [#gpu.warp, #gpu.warp]) : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %for_op_2, %splitted_op_3 = transform.disc.split_reduction_serial %tiled_op_1 by tile_sizes = [32] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %tiled_linalg_op, %loops:3 = transform.structured.tile %for_op_2[16, 8, 16] {interchange = [0, 1, 2]} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + transform.disc.apply_licm 
%arg0 : !transform.any_op + transform.disc.apply_dce %arg0 : !transform.any_op + transform.disc.apply_cse %arg0 : !transform.any_op + %2 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %3 = transform.disc.apply_patterns %2 {canonicalization} : (!transform.any_op) -> !transform.any_op + %4 = transform.structured.vectorize %3 {vectorize_padding} : (!transform.any_op) -> !transform.any_op + %func1 = transform.structured.match ops{["func.func"]} in %4 : (!transform.any_op) -> !transform.any_op + transform.disc.swap_alloc_tensor %func1 : (!transform.any_op) -> () + %5 = transform.disc.bufferize {target_gpu} %arg0 : (!transform.any_op) -> !transform.any_op + %6 = transform.structured.match ops{["func.func"]} in %5 : (!transform.any_op) -> !transform.any_op + transform.disc.erase_dealloc %6 : (!transform.any_op) -> () + transform.disc.apply_dce %5 : !transform.any_op + transform.disc.apply_cse %5 : !transform.any_op + %8 = transform.structured.match ops{["scf.forall"]} attributes {mapping = [#gpu.block, #gpu.block]} in %5 : (!transform.any_op) -> !transform.any_op + %9 = transform.disc.forall_to_gpu_ctas %8 : (!transform.any_op) -> !transform.any_op + %10 = transform.structured.match ops{["scf.forall"]} attributes {mapping = [#gpu.warp, #gpu.warp]} in %5 : (!transform.any_op) -> !transform.any_op + transform.disc.forall_to_gpu_warps %10 : (!transform.any_op) -> () + transform.disc.apply_dce %5 : !transform.any_op + transform.disc.apply_cse %5 : !transform.any_op + %12 = transform.structured.match ops{["func.func"]} in %5 : (!transform.any_op) -> !transform.any_op + transform.disc.vector.vector_to_mma_conversion %12 : (!transform.any_op) -> () + transform.disc.apply_dce %5 : !transform.any_op + transform.disc.apply_cse %5 : !transform.any_op + // 1. use register to cache the result of ldmatrix + // 2. use register to cache the result of mma's accumulation result + // 3. store the final result from reg to smem and to gmem + // 4. 
use padding for output smem matrix to avoid bank conflict + %mma = transform.structured.match ops{["nvgpu.mma.sync"]} in %5 : (!transform.any_op) -> !transform.any_op + transform.disc.move_data_to_register %mma by block_mn_shape = [128, 128] smem_padding = 8 : (!transform.any_op) -> () + transform.disc.apply_licm %5 : !transform.any_op + transform.disc.apply_dce %5 : !transform.any_op + transform.disc.apply_cse %5 : !transform.any_op + // use cp.async to load matrix A and B from gmem to smem + %transfer_write = transform.structured.match ops{["func.func"]} in %5 : (!transform.any_op) -> !transform.any_op + transform.disc.expand_transfer_rw_to_memref_copy %transfer_write : (!transform.any_op) -> () + // swizzle the accesses of the input matrices, + // including from gmem to smem by cp.async and from smem to reg by ldmatrix + %swizzle = transform.structured.match ops{["func.func"]} in %5 : (!transform.any_op) -> !transform.any_op + transform.disc.swizzle_smem %swizzle : (!transform.any_op) -> () + // multi-buffering for the software pipeline + %multi_buffering = transform.structured.match ops{["func.func"]} in %5 : (!transform.any_op) -> !transform.any_op + transform.disc.multi_buffering %multi_buffering by multi_buffering_factor = 2 : (!transform.any_op) -> () + // reuse smem for the input and output matrices + %pack_smem = transform.structured.match ops{["scf.parallel"]} in %5 : (!transform.any_op) -> !transform.any_op + transform.disc.pack_smem %pack_smem : (!transform.any_op) -> () + // manually lower nvgpu's DeviceAsyncCreateGroupOp and DeviceAsyncWaitOp to their NVVM counterparts, + // so that DeviceAsyncToken is no longer a loop-carried variable of the cta-k-loop, + // which makes further software pipelining easier + %14 = transform.structured.match ops{["func.func"]} in %5 : (!transform.any_op) -> !transform.any_op + transform.disc.convert_nvgpu_async_cp_to_nvvm_async_cp %14 : (!transform.any_op) -> () + // software pipelining + %pipeline = transform.structured.match ops{["func.func"]} in %5 : (!transform.any_op) -> !transform.any_op + transform.disc.gpu_software_pipeline %pipeline by depth = 2: (!transform.any_op) -> () + transform.disc.apply_licm %5 : !transform.any_op + transform.disc.apply_dce %5 : !transform.any_op + transform.disc.apply_cse %5 : !transform.any_op + %13 = transform.structured.match ops{["func.func"]} in %5 : (!transform.any_op) -> !transform.any_op + transform.disc.inline_and_convert_gpu_ids %13 : (!transform.any_op) -> () + transform.disc.apply_licm %5 : !transform.any_op + transform.disc.apply_dce %5 : !transform.any_op + transform.disc.apply_cse %5 : !transform.any_op + %canonicalization1 = transform.disc.apply_patterns %5 {canonicalization} : (!transform.any_op) -> !transform.any_op +} diff --git a/tao_compiler/mlir/disc/tests/disc-transform/default_schedule_matmul.cc b/tao_compiler/mlir/disc/tests/disc-transform/default_schedule_matmul.cc index 2f6f613abf2..247199b5ad4 100644 --- a/tao_compiler/mlir/disc/tests/disc-transform/default_schedule_matmul.cc +++ b/tao_compiler/mlir/disc/tests/disc-transform/default_schedule_matmul.cc @@ -287,5 +287,4 @@ TEST(Matmul, F16_256x256x128_Using_Default_Schedule) { /*expected_output_vals*/ {}, /*profiling*/ true)); } - } // namespace mlir_test diff --git a/tao_compiler/mlir/disc/tools/disc-transform/BUILD b/tao_compiler/mlir/disc/tools/disc-transform/BUILD index aa73e1888f0..d63058f44e6 100644 --- a/tao_compiler/mlir/disc/tools/disc-transform/BUILD +++ b/tao_compiler/mlir/disc/tools/disc-transform/BUILD @@ -217,6 +217,8 @@ cc_library( 
"TransformOps/TransformOpsExt.cc", "TransformOps/TransformOpsExt.cc.inc", "TransformOps/TransformOpsExt.h.inc", + "TransformOps/GPUPipeline.cc", + "TransformOps/OptimizeShareMemory.cc", ], hdrs = [ "TransformOps/TransformOpsExt.h", diff --git a/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/GPUPipeline.cc b/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/GPUPipeline.cc new file mode 100644 index 00000000000..33d2bccc352 --- /dev/null +++ b/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/GPUPipeline.cc @@ -0,0 +1,785 @@ +// Copyright 2023 The BladeDISC Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "llvm/ADT/MapVector.h" +#include "llvm/Support/Debug.h" +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/Dialect/SCF/Transforms/Transforms.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/MathExtras.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/RegionUtils.h" +#include "mlir/disc/tools/disc-transform/TransformOps/TransformOpsExt.h" +#include "mlir/disc/tools/disc-transform/utils.h" + +//====---------------------------------------------------------------------===// +// Pass to pipeline copy to shared memory for matmul op. +//====---------------------------------------------------------------------===// + +namespace mlir { +namespace disc_ral { +namespace transform_dialect { + +using namespace mlir; +using namespace mlir::scf; + +namespace { +/// Helper to keep internal information during pipelining transformation. +struct LoopPipelinerInternal { + /// Coarse liverange information for ops used across stages. + struct LiverangeInfo { + unsigned lastUseStage = 0; + unsigned defStage = 0; + }; + + protected: + ForOp forOp; + unsigned maxStage = 0; + DenseMap stages; + std::vector opOrder; + Value upperBound; + int64_t ub; + int64_t lb; + int64_t step; + PipeliningOption::AnnotationlFnType annotateFn = nullptr; + bool peelEpilogue; + PipeliningOption::PredicateOpFn predicateFn = nullptr; + + // When peeling the kernel we generate several version of each value for + // different stage of the prologue. This map tracks the mapping between + // original Values in the loop and the different versions + // peeled from the loop. + DenseMap> valueMapping; + + /// Assign a value to `valueMapping`, this means `val` represents the version + /// `idx` of `key` in the epilogue. 
+ void setValueMapping(Value key, Value el, int64_t idx); + + public: + /// Initalize the information for the given `op`, return true if it + /// satisfies the pre-condition to apply pipelining. + bool initializeLoopInfo(ForOp op, const PipeliningOption& options); + /// Emits the prologue, this creates `maxStage - 1` part which will contain + /// operations from stages [0; i], where i is the part index. + void emitPrologue(RewriterBase& rewriter); + /// Gather liverange information for Values that are used in a different stage + /// than its definition. + llvm::MapVector analyzeCrossStageValues(); + scf::ForOp createKernelLoop( + const llvm::MapVector& crossStageValues, + RewriterBase& rewriter, + llvm::DenseMap, unsigned>& loopArgMap); + /// Emits the pipelined kernel. This clones loop operations following user + /// order and remaps operands defined in a different stage as their use. + LogicalResult createKernel( + scf::ForOp newForOp, + const llvm::MapVector& crossStageValues, + const llvm::DenseMap, unsigned>& loopArgMap, + RewriterBase& rewriter); + /// Emits the epilogue, this creates `maxStage - 1` part which will contain + /// operations from stages [i; maxStage], where i is the part index. + llvm::SmallVector emitEpilogue(RewriterBase& rewriter); +}; + +bool LoopPipelinerInternal::initializeLoopInfo( + ForOp op, const PipeliningOption& options) { + forOp = op; + upperBound = forOp.getUpperBound(); + auto upperBoundCst = upperBound.getDefiningOp(); + if (upperBoundCst) { + ub = upperBoundCst.value(); + } + // only require lowerBound and step to be constant + auto lowerBoundCst = + forOp.getLowerBound().getDefiningOp(); + if (!lowerBoundCst) return false; + lb = lowerBoundCst.value(); + auto stepCst = forOp.getStep().getDefiningOp(); + if (!stepCst) return false; + step = stepCst.value(); + peelEpilogue = options.peelEpilogue; + predicateFn = options.predicateFn; + if (!peelEpilogue && predicateFn == nullptr) return false; + std::vector> schedule; + options.getScheduleFn(forOp, schedule); + if (schedule.empty()) return false; + + // Note: user need to assure that loop's iteration + // must greater than maxStage + opOrder.reserve(schedule.size()); + for (auto& opSchedule : schedule) { + maxStage = std::max(maxStage, opSchedule.second); + stages[opSchedule.first] = opSchedule.second; + opOrder.push_back(opSchedule.first); + } + + // All operations need to have a stage. + for (Operation& op : forOp.getBody()->without_terminator()) { + if (stages.find(&op) == stages.end()) { + op.emitOpError("not assigned a pipeline stage"); + return false; + } + } + + // Currently, we do not support assigning stages to ops in nested regions. The + // block of all operations assigned a stage should be the single `scf.for` + // body block. + for (const auto& [op, stageNum] : stages) { + (void)stageNum; + if (op == forOp.getBody()->getTerminator()) { + op->emitError("terminator should not be assigned a stage"); + return false; + } + if (op->getBlock() != forOp.getBody()) { + op->emitOpError( + "the owning Block of all operations assigned a stage " + "should be the loop body block"); + return false; + } + } + + // Only support loop carried dependency with a distance of 1. This means the + // source of all the scf.yield operands needs to be defined by operations in + // the loop. 
+ if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(), + [this](Value operand) { + Operation* def = operand.getDefiningOp(); + return !def || stages.find(def) == stages.end(); + })) + return false; + annotateFn = options.annotateFn; + return true; +} + +/// Clone `op` and call `callback` on the cloned op's oeprands as well as any +/// operands of nested ops that: +/// 1) aren't defined within the new op or +/// 2) are block arguments. +static Operation* cloneAndUpdateOperands( + RewriterBase& rewriter, Operation* op, + function_ref callback) { + Operation* clone = rewriter.clone(*op); + for (OpOperand& operand : clone->getOpOperands()) callback(&operand); + clone->walk([&](Operation* nested) { + for (OpOperand& operand : nested->getOpOperands()) { + Operation* def = operand.get().getDefiningOp(); + if ((def && !clone->isAncestor(def)) || + operand.get().isa()) + callback(&operand); + } + }); + return clone; +} + +void LoopPipelinerInternal::emitPrologue(RewriterBase& rewriter) { + // Initialize the iteration argument to the loop initiale values. + for (BlockArgument& arg : forOp.getRegionIterArgs()) { + OpOperand& operand = forOp.getOpOperandForRegionIterArg(arg); + setValueMapping(arg, operand.get(), 0); + } + auto yield = cast(forOp.getBody()->getTerminator()); + for (int64_t i = 0; i < maxStage; i++) { + // special handling for induction variable as the increment is implicit. + Value iv = + rewriter.create(forOp.getLoc(), lb + i * step); + setValueMapping(forOp.getInductionVar(), iv, i); + for (Operation* op : opOrder) { + if (stages[op] > i) continue; + Operation* newOp = + cloneAndUpdateOperands(rewriter, op, [&](OpOperand* newOperand) { + auto it = valueMapping.find(newOperand->get()); + if (it != valueMapping.end()) { + Value replacement = it->second[i - stages[op]]; + newOperand->set(replacement); + } + }); + if (annotateFn) + annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i); + for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) { + setValueMapping(op->getResult(destId), newOp->getResult(destId), + i - stages[op]); + // If the value is a loop carried dependency update the loop argument + // mapping. + for (OpOperand& operand : yield->getOpOperands()) { + if (operand.get() != op->getResult(destId)) continue; + setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], + newOp->getResult(destId), i - stages[op] + 1); + } + } + } + } +} + +llvm::MapVector +LoopPipelinerInternal::analyzeCrossStageValues() { + llvm::MapVector crossStageValues; + for (Operation* op : opOrder) { + unsigned stage = stages[op]; + + auto analyzeOperand = [&](OpOperand& operand) { + Operation* def = operand.get().getDefiningOp(); + if (!def) return; + auto defStage = stages.find(def); + if (defStage == stages.end() || defStage->second == stage) return; + assert(stage > defStage->second); + LiverangeInfo& info = crossStageValues[operand.get()]; + info.defStage = defStage->second; + info.lastUseStage = std::max(info.lastUseStage, stage); + }; + + for (OpOperand& operand : op->getOpOperands()) analyzeOperand(operand); + visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand* operand) { + analyzeOperand(*operand); + }); + } + return crossStageValues; +} + +scf::ForOp LoopPipelinerInternal::createKernelLoop( + const llvm::MapVector& + crossStageValues, + RewriterBase& rewriter, + llvm::DenseMap, unsigned>& loopArgMap) { + // Creates the list of initial values associated to values used across + // stages. 
The initial values come from the prologue created above. + // Keep track of the kernel argument associated to each version of the + // values passed to the kernel. + llvm::SmallVector newLoopArg; + // For existing loop argument initialize them with the right version from the + // prologue. + for (const auto& retVal : + llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) { + Operation* def = retVal.value().getDefiningOp(); + assert(def && "Only support loop carried dependencies of distance 1"); + unsigned defStage = stages[def]; + Value valueVersion = valueMapping[forOp.getRegionIterArgs()[retVal.index()]] + [maxStage - defStage]; + assert(valueVersion); + newLoopArg.push_back(valueVersion); + } + for (auto escape : crossStageValues) { + LiverangeInfo& info = escape.second; + Value value = escape.first; + for (unsigned stageIdx = 0; stageIdx < info.lastUseStage - info.defStage; + stageIdx++) { + Value valueVersion = + valueMapping[value][maxStage - info.lastUseStage + stageIdx]; + assert(valueVersion); + newLoopArg.push_back(valueVersion); + loopArgMap[std::make_pair(value, info.lastUseStage - info.defStage - + stageIdx)] = newLoopArg.size() - 1; + } + } + + // Create the new kernel loop. When we peel the epilgue we need to peel + // `numStages - 1` iterations. Then we adjust the upper bound to remove those + // iterations. + Value newUb = forOp.getUpperBound(); + if (peelEpilogue) { + if (ub == 0) { + newUb = rewriter.create(forOp.getLoc(), + ub - maxStage * step); + } else { + newUb = rewriter.create( + forOp.getLoc(), upperBound, + rewriter.create(forOp.getLoc(), + maxStage * step)); + } + } + auto newForOp = + rewriter.create(forOp.getLoc(), forOp.getLowerBound(), newUb, + forOp.getStep(), newLoopArg); + // When there are no iter args, the loop body terminator will be created. + // Since we always create it below, remove the terminator if it was created. + if (!newForOp.getBody()->empty()) + rewriter.eraseOp(newForOp.getBody()->getTerminator()); + return newForOp; +} + +/// Replace any use of `target` with `replacement` in `op`'s operands or within +/// `op`'s nested regions. +static void replaceInOp(Operation* op, Value target, Value replacement) { + for (auto& use : llvm::make_early_inc_range(target.getUses())) { + if (op->isAncestor(use.getOwner())) use.set(replacement); + } +} + +/// Given a cloned op in the new kernel body, updates induction variable uses. +/// We replace it with a version incremented based on the stage where it is +/// used. +static void updateInductionVariableUses(RewriterBase& rewriter, Location loc, + Operation* newOp, Value newForIv, + unsigned maxStage, unsigned useStage, + unsigned step) { + rewriter.setInsertionPoint(newOp); + Value offset = rewriter.create( + loc, (maxStage - useStage) * step); + Value iv = rewriter.create(loc, newForIv, offset); + replaceInOp(newOp, newForIv, iv); + rewriter.setInsertionPointAfter(newOp); +} + +/// If the value is a loop carried value coming from stage N + 1 remap, it will +/// become a direct use. 
+static void updateIterArgUses(RewriterBase& rewriter, IRMapping& bvm, + Operation* newOp, ForOp oldForOp, ForOp newForOp, + unsigned useStage, + const DenseMap& stages) { + for (unsigned i = 0; i < oldForOp.getNumRegionIterArgs(); i++) { + Value yieldedVal = oldForOp.getBody()->getTerminator()->getOperand(i); + Operation* dep = yieldedVal.getDefiningOp(); + if (!dep) continue; + auto stageDep = stages.find(dep); + if (stageDep == stages.end() || stageDep->second == useStage) continue; + if (stageDep->second != useStage + 1) continue; + Value replacement = bvm.lookup(yieldedVal); + replaceInOp(newOp, newForOp.getRegionIterArg(i), replacement); + } +} + +/// For operands defined in a previous stage we need to remap it to use the +/// correct region argument. We look for the right version of the Value based +/// on the stage where it is used. +static void updateCrossStageUses( + RewriterBase& rewriter, Operation* newOp, IRMapping& bvm, ForOp newForOp, + unsigned useStage, const DenseMap& stages, + const llvm::DenseMap, unsigned>& loopArgMap) { + // Because we automatically cloned the sub-regions, there's no simple way + // to walk the nested regions in pairs of (oldOps, newOps), so we just + // traverse the set of remapped loop arguments, filter which ones are + // relevant, and replace any uses. + for (auto [remapPair, newIterIdx] : loopArgMap) { + auto [crossArgValue, stageIdx] = remapPair; + Operation* def = crossArgValue.getDefiningOp(); + assert(def); + unsigned stageDef = stages.lookup(def); + if (useStage <= stageDef || useStage - stageDef != stageIdx) continue; + + // Use "lookupOrDefault" for the target value because some operations + // are remapped, while in other cases the original will be present. + Value target = bvm.lookupOrDefault(crossArgValue); + Value replacement = newForOp.getRegionIterArg(newIterIdx); + + // Replace uses in the new op's operands and any nested uses. + replaceInOp(newOp, target, replacement); + } +} + +LogicalResult LoopPipelinerInternal::createKernel( + scf::ForOp newForOp, + const llvm::MapVector& + crossStageValues, + const llvm::DenseMap, unsigned>& loopArgMap, + RewriterBase& rewriter) { + valueMapping.clear(); + + // Create the kernel, we clone instruction based on the order given by + // user and remap operands coming from a previous stages. + rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin()); + IRMapping mapping; + mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); + for (const auto& arg : llvm::enumerate(forOp.getRegionIterArgs())) { + mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); + } + SmallVector predicates(maxStage + 1, nullptr); + if (!peelEpilogue) { + // Create a predicate for each stage except the last stage. + for (unsigned i = 0; i < maxStage; i++) { + Value c; + if (ub == 0) { + c = rewriter.create(newForOp.getLoc(), + ub - (maxStage - i) * step); + } else { + c = rewriter.create( + newForOp.getLoc(), upperBound, + rewriter.create(newForOp.getLoc(), + (maxStage - i) * step)); + } + Value pred = rewriter.create( + newForOp.getLoc(), arith::CmpIPredicate::slt, + newForOp.getInductionVar(), c); + predicates[i] = pred; + } + } + for (Operation* op : opOrder) { + int64_t useStage = stages[op]; + auto* newOp = rewriter.clone(*op, mapping); + // Within the kernel body, update uses of the induction variable, uses of + // the original iter args, and uses of cross stage values. 
+ updateInductionVariableUses(rewriter, forOp.getLoc(), newOp, + newForOp.getInductionVar(), maxStage, + stages[op], step); + updateIterArgUses(rewriter, mapping, newOp, forOp, newForOp, useStage, + stages); + updateCrossStageUses(rewriter, newOp, mapping, newForOp, useStage, stages, + loopArgMap); + + if (predicates[useStage]) { + newOp = predicateFn(rewriter, newOp, predicates[useStage]); + if (!newOp) return failure(); + // Remap the results to the new predicated one. + for (auto values : llvm::zip(op->getResults(), newOp->getResults())) + mapping.map(std::get<0>(values), std::get<1>(values)); + } + rewriter.setInsertionPointAfter(newOp); + if (annotateFn) + annotateFn(newOp, PipeliningOption::PipelinerPart::Kernel, 0); + } + + // Collect the Values that need to be returned by the forOp. For each + // value we need to have `LastUseStage - DefStage` number of versions + // returned. + // We create a mapping between original values and the associated loop + // returned values that will be needed by the epilogue. + llvm::SmallVector yieldOperands; + for (Value retVal : forOp.getBody()->getTerminator()->getOperands()) { + yieldOperands.push_back(mapping.lookupOrDefault(retVal)); + } + for (auto& it : crossStageValues) { + int64_t version = maxStage - it.second.lastUseStage + 1; + unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage; + // add the original version to yield ops. + // If there is a live range spanning across more than 2 stages we need to + // add extra arg. + for (unsigned i = 1; i < numVersionReturned; i++) { + setValueMapping(it.first, newForOp->getResult(yieldOperands.size()), + version++); + yieldOperands.push_back( + newForOp.getBody()->getArguments()[yieldOperands.size() + 1 + + newForOp.getNumInductionVars()]); + } + setValueMapping(it.first, newForOp->getResult(yieldOperands.size()), + version++); + yieldOperands.push_back(mapping.lookupOrDefault(it.first)); + } + // Map the yield operand to the forOp returned value. + for (const auto& retVal : + llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) { + Operation* def = retVal.value().getDefiningOp(); + assert(def && "Only support loop carried dependencies of distance 1"); + unsigned defStage = stages[def]; + setValueMapping(forOp.getRegionIterArgs()[retVal.index()], + newForOp->getResult(retVal.index()), + maxStage - defStage + 1); + } + rewriter.create(forOp.getLoc(), yieldOperands); + return success(); +} + +llvm::SmallVector LoopPipelinerInternal::emitEpilogue( + RewriterBase& rewriter) { + llvm::SmallVector returnValues(forOp->getNumResults()); + // Emit different versions of the induction variable. They will be + // removed by dead code if not used. + for (int64_t i = 0; i < maxStage; i++) { + Value newlastIter; + if (ub == 0) { + newlastIter = rewriter.create( + forOp.getLoc(), lb + step * ((((ub - 1) - lb) / step) - i)); + } else { + AffineExpr d0; + auto ctx = rewriter.getContext(); + bindDims(ctx, d0); + auto map = AffineMap::get( + 1, 0, {lb + step * ((((d0 - 1) - lb).floorDiv(step)) - i)}, ctx); + newlastIter = rewriter.create(forOp.getLoc(), map, + upperBound); + } + setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i); + } + // Emit `maxStage - 1` epilogue part that includes operations from stages + // [i; maxStage]. 
+ for (int64_t i = 1; i <= maxStage; i++) { + for (Operation* op : opOrder) { + if (stages[op] < i) continue; + Operation* newOp = + cloneAndUpdateOperands(rewriter, op, [&](OpOperand* newOperand) { + auto it = valueMapping.find(newOperand->get()); + if (it != valueMapping.end()) { + Value replacement = it->second[maxStage - stages[op] + i]; + newOperand->set(replacement); + } + }); + if (annotateFn) + annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1); + for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) { + setValueMapping(op->getResult(destId), newOp->getResult(destId), + maxStage - stages[op] + i); + // If the value is a loop carried dependency update the loop argument + // mapping and keep track of the last version to replace the original + // forOp uses. + for (OpOperand& operand : + forOp.getBody()->getTerminator()->getOpOperands()) { + if (operand.get() != op->getResult(destId)) continue; + unsigned version = maxStage - stages[op] + i + 1; + // If the version is greater than maxStage it means it maps to the + // original forOp returned value. + if (version > maxStage) { + returnValues[operand.getOperandNumber()] = newOp->getResult(destId); + continue; + } + setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], + newOp->getResult(destId), version); + } + } + } + } + return returnValues; +} + +void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) { + auto it = valueMapping.find(key); + // If the value is not in the map yet add a vector big enough to store all + // versions. + if (it == valueMapping.end()) + it = + valueMapping + .insert(std::make_pair(key, llvm::SmallVector(maxStage + 1))) + .first; + it->second[idx] = el; +} + +} // namespace + +/// Populate `ops` with the set of operations that belong to the stage 0 of the +/// pipelined version of the given loop when pipelining copies to shared memory. +/// Specifically, this collects: +/// +/// 1. all loads from global memory, both sync and async; +/// 2. the barriers for async loads. +/// +/// In particular, barriers are omitted if they do not dominate at least one +/// async load for which there is not yet a barrier. +static LogicalResult collectStage0PipeliningOps( + scf::ForOp forOp, llvm::SmallPtrSet& ops) { + llvm::SmallPtrSet barriers; + for (Operation& op : *forOp.getBody()) { + if (isa(op)) { + barriers.insert(&op); + continue; + } + + if (isa(op)) { + ops.insert(&op); + ops.insert(std::make_move_iterator(barriers.begin()), + std::make_move_iterator(barriers.end())); + assert(barriers.empty() && + "expected to have moved the barriers into another set"); + continue; + } + } + + return success(); +} + +/// Hook for the loop pipeliner that sets the "num groups in flight" attribute +/// of async wait operations corresponding to pipelined shared memory copies. +// TODO: this currently assumes that there are no groups that could be in flight +// in the existing code. +static void setAsyncWaitGroupsInFlight( + OpBuilder& builder, Operation* op, + scf::PipeliningOption::PipelinerPart part, unsigned iteration, + unsigned depth) { + // Based on the order of copies within the loop we need to set the number + // of copies in flight, unless it is already set. 
+ auto waitOp = dyn_cast(op); + if (!waitOp || waitOp.getN()) return; + + int numGroupInFlight = 0; + if (part == scf::PipeliningOption::PipelinerPart::Kernel || + part == scf::PipeliningOption::PipelinerPart::Prologue) { + numGroupInFlight = depth - 1; + } else { + // By construction there should be no wait op in the prologue as all the + // wait should be in the last stage. + assert(part == scf::PipeliningOption::PipelinerPart::Epilogue); + // Based on the schedule we pick we know how many groups are in flight for + // each iteration of the epilogue. + numGroupInFlight = depth - 1 - iteration; + } + waitOp.setN(numGroupInFlight); +} + +/// Hook for the loop pipeliner that populates `ops` with the stage information +/// as follows: +/// +/// - operations in `stage0Ops` (typically loads from global memory and +/// related barriers) are at stage 0; +/// - operations in the backward slice of any stage0Ops are all at stage 0; +/// - other operations are at stage `depth`; +/// - the internal order of the pipelined loop has ops at stage `depth` first, +/// then those at stage 0, with relative order within each group preserved. +/// +static void getPipelineStages( + scf::ForOp forOp, + std::vector>& opsWithPipelineStages, + unsigned depth, llvm::SmallPtrSetImpl& stage0Ops) { + SetVector dependencies; + BackwardSliceOptions options([&](Operation* visited) { + return visited->getBlock() == forOp.getBody(); + }); + options.inclusive = true; + for (Operation& op : forOp.getBody()->getOperations()) { + if (stage0Ops.contains(&op)) getBackwardSlice(&op, &dependencies, options); + } + + for (Operation& op : forOp.getBody()->getOperations()) { + if (!dependencies.contains(&op) && !isa(op)) + opsWithPipelineStages.emplace_back(&op, depth); + } + for (Operation& op : forOp.getBody()->getOperations()) { + if (dependencies.contains(&op)) opsWithPipelineStages.emplace_back(&op, 0); + } +} + +/// Hook for the loop pipeliner. Replaces op with a predicated version and +/// returns the resulting operation. Returns the original op if the predication +/// isn't necessary for the given op. Returns null if predication is needed but +/// not supported. +static Operation* replaceOpWithPredicatedOp(RewriterBase& rewriter, + Operation* op, Value predicate) { + // Some operations may be fine to execute "speculatively" more times than the + // original number of iterations, in particular side-effect free operations + // and barriers, even if they cannot be predicated. + if (isMemoryEffectFree(op) || + isa(op)) { + return op; + } + + // Otherwise, only async copies can currently be predicated. + auto asyncCopyOp = dyn_cast(op); + if (!asyncCopyOp) return nullptr; + + // Create srcElement Value based on `predicate`. The next lines generate + // the following code: + // + // srcElement = (pred) ? prevSrcElements : 0; + // + Location loc = asyncCopyOp->getLoc(); + Value dstElements = + rewriter.create(loc, asyncCopyOp.getDstElementsAttr()); + Value originalSrcElement = + asyncCopyOp.getSrcElements() ? 
asyncCopyOp.getSrcElements() : dstElements; + Value c0Index = rewriter.create(loc, 0); + auto srcElements = rewriter.create( + loc, predicate, originalSrcElement, c0Index); + auto asyncCopyZeroFillOp = rewriter.create( + loc, nvgpu::DeviceAsyncTokenType::get(asyncCopyOp.getContext()), + asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(), + asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements, + UnitAttr()); + rewriter.replaceOp(asyncCopyOp, ValueRange{asyncCopyZeroFillOp}); + return asyncCopyZeroFillOp; +} + +/// Applies loop pipelining with the given depth to the given loop so that +/// copies into the shared memory are pipelined. Doesn't affect other loops. +/// Returns a pair containing the error state and the pipelined op, the latter +/// being null in case of any failure. The error state contains a definite error +/// if the IR has been modified and a silenceable error otherwise. +std::tuple applyPipelining( + scf::ForOp forOp, int64_t depth, bool epiloguePeeling) { + llvm::SmallPtrSet stage0Ops; + if (failed(collectStage0PipeliningOps(forOp, stage0Ops))) { + return std::make_tuple( + emitSilenceableFailure(forOp, "cannot find stage 0 ops for pipelining"), + scf::ForOp()); + } + if (stage0Ops.empty()) { + return std::make_tuple( + emitSilenceableFailure(forOp, "no shared memory copy"), scf::ForOp()); + } + + IRRewriter rewriter(forOp->getContext()); + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(forOp); + scf::PipeliningOption options; + unsigned maxDepth = depth; + auto setAnnotation = [&](Operation* op, + scf::PipeliningOption::PipelinerPart part, + unsigned iteration) { + return setAsyncWaitGroupsInFlight(rewriter, op, part, iteration, maxDepth); + }; + options.getScheduleFn = + [&](scf::ForOp schedulingFor, + std::vector>& ops) { + if (schedulingFor != forOp) return; + return getPipelineStages(forOp, ops, maxDepth, stage0Ops); + }; + options.annotateFn = setAnnotation; + if (!epiloguePeeling) { + options.peelEpilogue = false; + options.predicateFn = replaceOpWithPredicatedOp; + } + + LoopPipelinerInternal pipeliner; + if (!pipeliner.initializeLoopInfo(forOp, options)) { + return std::make_tuple( + emitSilenceableFailure(forOp, "failed to initialize loop info"), + scf::ForOp()); + } + + // 1. Emit prologue. + pipeliner.emitPrologue(rewriter); + + // 2. Track values used across stages. When a value cross stages it will + // need to be passed as loop iteration arguments. + // We first collect the values that are used in a different stage than where + // they are defined. + llvm::MapVector + crossStageValues = pipeliner.analyzeCrossStageValues(); + + // Mapping between original loop values used cross stage and the block + // arguments associated after pipelining. A Value may map to several + // arguments if its liverange spans across more than 2 stages. + llvm::DenseMap, unsigned> loopArgMap; + // 3. Create the new kernel loop and return the block arguments mapping. + ForOp newForOp = + pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap); + // Create the kernel block, order ops based on user choice and remap + // operands. + if (failed(pipeliner.createKernel(newForOp, crossStageValues, loopArgMap, + rewriter))) + return std::make_tuple( + emitSilenceableFailure(forOp, "pipeliner failed to create kernel"), + scf::ForOp()); + llvm::SmallVector returnValues = + newForOp.getResults().take_front(forOp->getNumResults()); + if (options.peelEpilogue) { + // 4. Emit the epilogue after the new forOp. 
+ rewriter.setInsertionPointAfter(newForOp); + returnValues = pipeliner.emitEpilogue(rewriter); + } + // 5. Erase the original loop and replace the uses with the epilogue output. + if (forOp->getNumResults() > 0) + rewriter.replaceOp(forOp, returnValues); + else + rewriter.eraseOp(forOp); + return std::make_tuple(DiagnosedSilenceableFailure::success(), newForOp); +} + +} // namespace transform_dialect +} // namespace disc_ral +} // namespace mlir diff --git a/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/OptimizeShareMemory.cc b/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/OptimizeShareMemory.cc new file mode 100644 index 00000000000..708d06cd497 --- /dev/null +++ b/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/OptimizeShareMemory.cc @@ -0,0 +1,268 @@ +// Copyright 2023 The BladeDISC Authors. All rights reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//===- OptimizeSharedMemory.cpp - MLIR NVGPU pass implementation ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements transforms to optimize accesses to shared memory. +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/Dialect/NVGPU/Transforms/Transforms.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/disc/tools/disc-transform/TransformOps/TransformOpsExt.h" +#include "mlir/disc/tools/disc-transform/utils.h" + +namespace mlir { +namespace disc_ral { +namespace transform_dialect { + +using namespace mlir; +using namespace mlir::nvgpu; + +/// The size of a shared memory line according to NV documentation. +constexpr int64_t kSharedMemoryLineSizeBytes = 128; +/// We optimize for 128bit accesses, but this can be made an argument in the +/// future. 
+constexpr int64_t kDefaultVectorSizeBits = 128; + +static Operation::operand_range getIndices(Operation* op) { + if (auto ldmatrixOp = dyn_cast(op)) + return ldmatrixOp.getIndices(); + if (auto copyOp = dyn_cast(op)) + return copyOp.getDstIndices(); + if (auto loadOp = dyn_cast(op)) return loadOp.getIndices(); + if (auto storeOp = dyn_cast(op)) return storeOp.getIndices(); + if (auto vectorReadOp = dyn_cast(op)) + return vectorReadOp.getIndices(); + if (auto vectorStoreOp = dyn_cast(op)) + return vectorStoreOp.getIndices(); + if (auto transferReadOp = dyn_cast(op)) + return transferReadOp.getIndices(); + if (auto transferWriteOp = dyn_cast(op)) + return transferWriteOp.getIndices(); + llvm_unreachable("unsupported op type"); +} + +static void setIndices(Operation* op, ArrayRef indices) { + if (auto ldmatrixOp = dyn_cast(op)) + return ldmatrixOp.getIndicesMutable().assign(indices); + if (auto copyOp = dyn_cast(op)) + return copyOp.getDstIndicesMutable().assign(indices); + if (auto loadOp = dyn_cast(op)) + return loadOp.getIndicesMutable().assign(indices); + if (auto storeOp = dyn_cast(op)) + return storeOp.getIndicesMutable().assign(indices); + if (auto vectorReadOp = dyn_cast(op)) + return vectorReadOp.getIndicesMutable().assign(indices); + if (auto vectorStoreOp = dyn_cast(op)) + return vectorStoreOp.getIndicesMutable().assign(indices); + if (auto transferReadOp = dyn_cast(op)) + return transferReadOp.getIndicesMutable().assign(indices); + if (auto transferWriteOp = dyn_cast(op)) + return transferWriteOp.getIndicesMutable().assign(indices); + llvm_unreachable("unsupported op type"); +} + +/// Uses `srcIndexValue` to permute `tgtIndexValue` via +/// `result = xor(floordiv(srcIdxVal,permuteEveryN), +/// floordiv(tgtIdxVal,vectorSize))) +/// + tgtIdxVal % vectorSize` +/// This is done using an optimized sequence of `arith` operations. +static Value permuteVectorOffset(OpBuilder& b, Location loc, + ArrayRef indices, MemRefType memrefTy, + int64_t srcDim, int64_t tgtDim) { + // Adjust the src index to change how often the permutation changes + // if necessary. + Value src = indices[srcDim]; + + // We only want to permute every N iterations of the target dim where N is + // ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)). + const int64_t permuteEveryN = std::max( + 1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) * + memrefTy.getElementTypeBitWidth()) / + 8)); + + // clang-format off + // Index bit representation (b0 = least significant bit) for dim(1) + // of a `memref` is as follows: + // N := log2(128/elementSizeBits) + // M := log2(dimSize(1)) + // then + // bits[0:N] = sub-vector element offset + // bits[N:M] = vector index + // clang-format on + int64_t n = + llvm::Log2_64(kDefaultVectorSizeBits / memrefTy.getElementTypeBitWidth()); + int64_t m = llvm::Log2_64(memrefTy.getDimSize(tgtDim)); + + // Capture bits[0:(M-N)] of src by first creating a (M-N) mask. + int64_t mask = (1LL << (m - n)) - 1; + if (permuteEveryN > 1) mask = mask << llvm::Log2_64(permuteEveryN); + Value srcBits = b.create(loc, mask); + srcBits = b.create(loc, src, srcBits); + + // Use the src bits to permute the target bits b[N:M] containing the + // vector offset. 
+ if (permuteEveryN > 1) { + int64_t shlBits = n - llvm::Log2_64(permuteEveryN); + if (shlBits > 0) { + Value finalShiftVal = b.create(loc, shlBits); + srcBits = b.createOrFold(loc, srcBits, finalShiftVal); + } else if (shlBits < 0) { + Value finalShiftVal = b.create(loc, -1 * shlBits); + srcBits = b.createOrFold(loc, srcBits, finalShiftVal); + } + } else { + Value finalShiftVal = b.create(loc, n); + srcBits = b.createOrFold(loc, srcBits, finalShiftVal); + } + + Value permutedVectorIdx = + b.create(loc, indices[tgtDim], srcBits); + return permutedVectorIdx; +} + +static void transformIndices(OpBuilder& builder, Location loc, + SmallVector& indices, + MemRefType memrefTy, int64_t srcDim, + int64_t tgtDim) { + indices[tgtDim] = + permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim); +} + +/// Return all operations within `parentOp` that read from or write to +/// `shmMemRef`. +static LogicalResult getShmReadAndWriteOps( + Operation* parentOp, Value shmMemRef, SmallVector& readOps, + SmallVector& writeOps) { + parentOp->walk([&](Operation* op) { + MemoryEffectOpInterface iface = dyn_cast(op); + if (!iface) return; + std::optional effect = + iface.getEffectOnValue(shmMemRef); + if (effect) { + readOps.push_back(op); + return; + } + effect = iface.getEffectOnValue(shmMemRef); + if (effect) writeOps.push_back(op); + }); + + // Restrict to a supported set of ops. We also require at least 2D access, + // although this could be relaxed. + if (llvm::any_of(readOps, [](Operation* op) { + return !isa(op) || + getIndices(op).size() < 2; + })) + return failure(); + if (llvm::any_of(writeOps, [](Operation* op) { + return !isa( + op) || + getIndices(op).size() < 2; + })) + return failure(); + + return success(); +} + +LogicalResult optimizeSharedMemoryReadsAndWrites(Operation* parentOp, + Value memrefValue) { + auto memRefType = dyn_cast(memrefValue.getType()); + if (!memRefType || !mlir::disc_ral::hasSharedMemoryAddressSpace(memRefType)) + return failure(); + + // Note: Upstream will check subview ops, we don't + // clang-format off + // Abort if the given value has any sub-views; we do not do any alias + // analysis. + // bool hasSubView = false; + // parentOp->walk([&](memref::SubViewOp subView) { + // subView.dump(); + // hasSubView = true; + // }); + // if (hasSubView) + // return failure(); + // clang-format on + + // Check if this is necessary given the assumption of 128b accesses: + // If dim[rank-1] is small enough to fit 8 rows in a 128B line. + const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1); + const int64_t rowsPerLine = + (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) / + rowSize; + const int64_t threadGroupSize = + 1LL << (7 - llvm::Log2_64(kDefaultVectorSizeBits / 8)); + if (rowsPerLine >= threadGroupSize) return failure(); + + // Get sets of operations within the function that read/write to shared + // memory. + SmallVector shmReadOps; + SmallVector shmWriteOps; + if (failed(getShmReadAndWriteOps(parentOp, memrefValue, shmReadOps, + shmWriteOps))) + return failure(); + + if (shmReadOps.empty() || shmWriteOps.empty()) return failure(); + + OpBuilder builder(parentOp->getContext()); + + int64_t tgtDim = memRefType.getRank() - 1; + int64_t srcDim = memRefType.getRank() - 2; + + // Transform indices for the ops writing to shared memory. 
+ while (!shmWriteOps.empty()) { + Operation* shmWriteOp = shmWriteOps.back(); + shmWriteOps.pop_back(); + builder.setInsertionPoint(shmWriteOp); + + auto indices = getIndices(shmWriteOp); + SmallVector transformedIndices(indices.begin(), indices.end()); + transformIndices(builder, shmWriteOp->getLoc(), transformedIndices, + memRefType, srcDim, tgtDim); + setIndices(shmWriteOp, transformedIndices); + } + + // Transform indices for the ops reading from shared memory. + while (!shmReadOps.empty()) { + Operation* shmReadOp = shmReadOps.back(); + shmReadOps.pop_back(); + builder.setInsertionPoint(shmReadOp); + + auto indices = getIndices(shmReadOp); + SmallVector transformedIndices(indices.begin(), indices.end()); + transformIndices(builder, shmReadOp->getLoc(), transformedIndices, + memRefType, srcDim, tgtDim); + setIndices(shmReadOp, transformedIndices); + } + + return success(); +} + +} // namespace transform_dialect +} // namespace disc_ral +} // namespace mlir \ No newline at end of file diff --git a/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/TransformOpsExt.cc b/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/TransformOpsExt.cc index 2f86b5c8be4..88fe7233bf5 100644 --- a/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/TransformOpsExt.cc +++ b/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/TransformOpsExt.cc @@ -29,6 +29,7 @@ #include "mlir/Dialect/Bufferization/Transforms/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" @@ -43,6 +44,7 @@ #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/IR/Dominance.h" #include "mlir/IR/Matchers.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "mlir/Interfaces/ValueBoundsOpInterface.h" @@ -52,6 +54,7 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" #include "mlir/Transforms/Passes.h" +#include "mlir/disc/IR/disc_shape_ops.h" #include "mlir/disc/tools/disc-transform/LinalgExt/LinalgExtOps.h" #include "mlir/disc/tools/disc-transform/utils.h" #include "mlir/disc/transforms/codegen_utils.h" @@ -3562,11 +3565,37 @@ DiagnosedSilenceableFailure DISCSplitReductionSerialOp::applyToOne( Value dimM = b.createOrFold(loc, lhs, zero); Value dimN = b.createOrFold(loc, rhs, one); Value dimK = b.createOrFold(loc, lhs, one); + // TODO: use a better way to get the dim K + if (lhs.getType().cast().isDynamicDim(1)) { + tensor::PadOp pad = lhs.getDefiningOp(); + if (pad) { + tensor::ExtractSliceOp slice = + pad.getSource().getDefiningOp(); + if (slice) { + dimK = b.create(loc, slice.getSource(), one); + } + } + } scf::ForOp forOp = b.create(loc, zero, dimK, step, ValueRange{output}); + if (getLoopType().has_value()) + if (getLoopType().value().equals("cta-k-loop")) + forOp->setAttr("loop-type", StringAttr::get(ctx, "cta-k-loop")); b.setInsertionPoint(forOp.getBody(), forOp.getBody()->begin()); Value iv = forOp.getInductionVar(); + auto lhsTy = lhs.getType().cast(); + bool isKDimDynamic = lhsTy.isDynamicDim(1); + if (isKDimDynamic || + (!isKDimDynamic && lhsTy.getDimSize(1) % staticTileSize != 0)) { + AffineExpr d0, s0; + bindDims(ctx, d0); + bindSymbols(ctx, 
s0); + AffineMap minMap = AffineMap::get( + 1, 1, {d0 * (-1) + s0, b.getAffineConstantExpr(staticTileSize)}, ctx); + step = b.create(loc, b.getIndexType(), minMap, + ValueRange{iv, dimK}); + } SmallVector lhsOffsets{zero, iv}; SmallVector lhsDimUppers{dimM, step}; SmallVector lhsStrides{one, one}; @@ -4029,6 +4058,1314 @@ void transform_dialect::ApplyLoopIndependentCodeMotionOp::getEffects( transform::modifiesPayload(effects); } +//===----------------------------------------------------------------------===// +// DISCPaddingMNOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform_dialect::DISCPaddingMN::applyToOne( + Operation* target, transform::ApplyToEachResultList& results, + transform::TransformState& state) { + linalg::MatmulOp matmul = dyn_cast(target); + if (!matmul) { + return mlir::emitDefiniteFailure(target, + "apples only to linalg.matmul op."); + } + + const ArrayRef tileSizes = getTileSizes(); + if (tileSizes.size() != 2) { + return mlir::emitDefiniteFailure(target, "expect only tile M and N"); + } + int64_t MTileSize = tileSizes[0]; + int64_t NTileSize = tileSizes[1]; + + MLIRContext* ctx = target->getContext(); + IRRewriter rewriter(ctx); + Location loc = matmul.getLoc(); + ImplicitLocOpBuilder b(loc, rewriter); + b.setInsertionPoint(target); + + Value lhs, rhs, res; + lhs = matmul.getOperand(0); + rhs = matmul.getOperand(1); + res = matmul.getOperand(2); + auto lhsTy = lhs.getType().cast(); + auto rhsTy = rhs.getType().cast(); + auto resTy = res.getType().cast(); + + bool shouldPaddingM = true, shouldPaddingN = true; + if (!lhsTy.isDynamicDim(0)) { + // TODO: support small static shape + if (lhsTy.getDimSize(0) < MTileSize) { + return mlir::emitDefiniteFailure( + target, "expected dimension M greater than MTileSize"); + } + if (lhsTy.getDimSize(0) % MTileSize == 0) shouldPaddingM = false; + } + if (!rhsTy.isDynamicDim(1)) { + if (rhsTy.getDimSize(1) < NTileSize) { + return mlir::emitDefiniteFailure( + target, "expected dimension N greater than NTileSize"); + } + if (rhsTy.getDimSize(1) % NTileSize == 0) shouldPaddingN = false; + } + if (!shouldPaddingM && !shouldPaddingN) { + results.push_back(matmul); + return DiagnosedSilenceableFailure::success(); + } + + // Padded shape + SmallVector lhsPadShape = {MTileSize, lhsTy.isDynamicDim(1) + ? ShapedType::kDynamic + : lhsTy.getDimSize(1)}; + SmallVector rhsPadShape = { + rhsTy.isDynamicDim(0) ? ShapedType::kDynamic : rhsTy.getDimSize(0), + NTileSize}; + SmallVector resPadShape = {MTileSize, NTileSize}; + + // Padding low and high + auto zeroIndex = b.createOrFold(0); + SmallVector paddingLow; + paddingLow.resize(lhsTy.getRank(), zeroIndex); + AffineExpr d0; + bindDims(rewriter.getContext(), d0); + auto mDimSubMap = AffineMap::get(1, 0, {MTileSize - d0}, ctx); + auto nDimSubMap = AffineMap::get(1, 0, {NTileSize - d0}, ctx); + Value paddingMHigh = zeroIndex, paddingNHigh = zeroIndex; + Value mDim = b.create(lhs, 0); + Value nDim = b.create(rhs, 1); + if (shouldPaddingM) { + paddingMHigh = b.create(mDimSubMap, mDim); + } + if (shouldPaddingN) { + paddingNHigh = b.create(nDimSubMap, nDim); + } + SmallVector lhsPadingHigh = {paddingMHigh, zeroIndex}; + SmallVector rhsPadingHigh = {zeroIndex, paddingNHigh}; + SmallVector resPadingHigh = {paddingMHigh, paddingNHigh}; + + // Convert the padding values to attributes. 
+ SmallVector paddingValues; + if (getPaddingValues().size() != 3) { + emitOpError("expects padding A, B and C"); + return DiagnosedSilenceableFailure::definiteFailure(); + } + for (auto const& it : + llvm::zip(getPaddingValues(), matmul.getOperandTypes())) { + auto attr = dyn_cast(std::get<0>(it)); + if (!attr) { + emitOpError("expects padding values to be typed attributes"); + return DiagnosedSilenceableFailure::definiteFailure(); + } + auto elementType = getElementTypeOrSelf(std::get<1>(it)); + if (attr.getType() != elementType) { + auto diag = this->emitOpError("expects a padding value of type ") + << elementType << ", got " << attr; + diag.attachNote(matmul.getLoc()) << "when applied to this op"; + return DiagnosedSilenceableFailure::definiteFailure(); + } + paddingValues.push_back(attr); + } + Value lhsPaddingValue = b.create(paddingValues[0]); + Value rhsPaddingValue = b.create(paddingValues[1]); + Value resPaddingValue = b.create(paddingValues[2]); + + if (shouldPaddingM) { + lhs = b.create(lhsTy.clone(lhsPadShape), lhs, paddingLow, + lhsPadingHigh, lhsPaddingValue); + } + if (shouldPaddingN) { + rhs = b.create(rhsTy.clone(rhsPadShape), rhs, paddingLow, + rhsPadingHigh, rhsPaddingValue); + } + res = b.create(resTy.clone(resPadShape), res, paddingLow, + resPadingHigh, resPaddingValue, true); + SmallVector matmulInputVals = {lhs, rhs}; + auto newMatmul = b.create( + matmulInputVals, ArrayRef(res), target->getAttrs()); + results.push_back(newMatmul); + + // Extract padded matmul + SmallVector offsets = {zeroIndex, zeroIndex}; + Value mSize = shouldPaddingM + ? mDim + : b.create(lhsTy.getDimSize(0)); + Value nSize = shouldPaddingN + ? nDim + : b.create(rhsTy.getDimSize(1)); + SmallVector sizes = {mSize, nSize}; + auto oneIndex = b.createOrFold(1); + SmallVector strides = {oneIndex, oneIndex}; + auto extractSliceMatmul = b.create( + resTy, newMatmul.getResult(0), offsets, sizes, strides); + matmul.getResult(0).replaceAllUsesWith(extractSliceMatmul.getResult()); + + return DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// +// DISCPaddingKOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform_dialect::DISCPaddingK::applyToOne( + Operation* target, transform::ApplyToEachResultList& results, + transform::TransformState& state) { + linalg::MatmulOp matmul = dyn_cast(target); + if (!matmul) { + return mlir::emitDefiniteFailure(target, + "apples only to linalg.matmul op."); + } + + const ArrayRef tileSizes = getTileSizes(); + if (tileSizes.size() != 1) { + return mlir::emitDefiniteFailure(target, "expect only tile K"); + } + + int64_t KTileSize = tileSizes[0]; + + MLIRContext* ctx = target->getContext(); + IRRewriter rewriter(ctx); + Location loc = matmul.getLoc(); + ImplicitLocOpBuilder b(loc, rewriter); + b.setInsertionPoint(target); + + Value lhs, rhs, res; + lhs = matmul.getOperand(0); + rhs = matmul.getOperand(1); + res = matmul.getOperand(2); + auto lhsTy = lhs.getType().cast(); + auto rhsTy = rhs.getType().cast(); + + bool shouldPaddingK = true; + if (!lhsTy.isDynamicDim(1)) { + if (lhsTy.getDimSize(1) < KTileSize) { + return mlir::emitDefiniteFailure( + target, "expect dimemsion K greater than KTileSize"); + } + if (lhsTy.getDimSize(1) % KTileSize == 0) shouldPaddingK = false; + } + + // Padded shape + SmallVector lhsPadShape = { + lhsTy.isDynamicDim(0) ? 
ShapedType::kDynamic : lhsTy.getDimSize(0), + KTileSize}; + SmallVector rhsPadShape = {KTileSize, rhsTy.isDynamicDim(1) + ? ShapedType::kDynamic + : rhsTy.getDimSize(1)}; + + // Padding low and high + auto zeroIndex = b.createOrFold(0); + SmallVector padingLow; + padingLow.resize(lhsTy.getRank(), zeroIndex); + AffineExpr d0; + bindDims(rewriter.getContext(), d0); + auto kDimSubMap = AffineMap::get(1, 0, {KTileSize - d0}, ctx); + Value kDim = b.create(lhs, 1); + // Add an unecessary zero padding for later + // FoldOrthogonalPaddings canonicalization + Value paddingKHigh = shouldPaddingK + ? b.create(kDimSubMap, kDim) + : zeroIndex; + SmallVector lhsPadingHigh = {zeroIndex, paddingKHigh}; + SmallVector rhsPadingHigh = {paddingKHigh, zeroIndex}; + + // Convert the padding values to attributes. + SmallVector paddingValues; + if (getPaddingValues().size() != 2) { + emitOpError("expects only padding A, B"); + return DiagnosedSilenceableFailure::definiteFailure(); + } + for (auto const& it : + llvm::zip(getPaddingValues(), matmul.getOperandTypes())) { + auto attr = dyn_cast(std::get<0>(it)); + if (!attr) { + emitOpError("expects padding values to be typed attributes"); + return DiagnosedSilenceableFailure::definiteFailure(); + } + Type elementType = getElementTypeOrSelf(std::get<1>(it)); + if (attr.getType() != elementType) { + auto diag = this->emitOpError("expects a padding value of type ") + << elementType << ", got " << attr; + diag.attachNote(matmul.getLoc()) << "when applied to this op"; + return DiagnosedSilenceableFailure::definiteFailure(); + } + paddingValues.push_back(attr); + } + + Value lhsPaddingValue = b.create(paddingValues[0]); + Value rhsPaddingValue = b.create(paddingValues[1]); + Value lhsPadOp = + b.create(lhsTy.clone(lhsPadShape), lhs, padingLow, + lhsPadingHigh, lhsPaddingValue, true); + Value rhsPadOp = + b.create(rhsTy.clone(rhsPadShape), rhs, padingLow, + rhsPadingHigh, rhsPaddingValue, true); + + SmallVector matmulInputVals = {lhsPadOp, rhsPadOp}; + auto newMatmul = b.create( + matmulInputVals, ArrayRef(res), target->getAttrs()); + + matmul.getResult(0).replaceAllUsesWith(newMatmul.getResult(0)); + results.push_back(newMatmul); + + RewritePatternSet pattern(ctx); + func::FuncOp funcOp = newMatmul->getParentOfType(); + tensor::PadOp::getCanonicalizationPatterns(pattern, ctx); + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(pattern)))) { + emitOpError("failed to run padop canonicalization patterns"); + return DiagnosedSilenceableFailure::definiteFailure(); + } + return DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// +// DISCSwapAllocTensorOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform_dialect::DISCSwapAllocTensor::applyToOne( + Operation* target, transform::ApplyToEachResultList& results, + transform::TransformState& state) { + MLIRContext* ctx = target->getContext(); + SimplePatternRewriter rewriter(ctx); + Location loc = target->getLoc(); + ImplicitLocOpBuilder b(loc, rewriter); + + SmallVector allocs; + // Collect all the candidate alloc operations. 
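+  // A candidate is an alloc_tensor whose `copy` operand comes from a
+  // vector.transfer_write into a tensor.empty, i.e. the pattern
+  //   %empty = tensor.empty()
+  //   %w     = vector.transfer_write %v, %empty[%c0, %c0]
+  //   %alloc = bufferization.alloc_tensor() copy(%w)
+  // which is rewritten below into an uncopied alloc_tensor that the
+  // transfer_write targets directly.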
+ target->walk([&](bufferization::AllocTensorOp allocOp) { + vector::TransferWriteOp xWrite = + allocOp.getCopy().getDefiningOp(); + if (xWrite) { + tensor::EmptyOp emptyOp = + cast(xWrite.getOperand(1).getDefiningOp()); + if (emptyOp) { + allocs.push_back(allocOp); + } + } + }); + + for (auto allocOp : allocs) { + vector::TransferWriteOp xWrite = + allocOp.getCopy().getDefiningOp(); + tensor::EmptyOp emptyOp = + cast(xWrite.getOperand(1).getDefiningOp()); + // alloc A and B before cta-k-loop + b.setInsertionPoint(emptyOp->getParentOp()); + std::optional memorySpace = allocOp.getMemorySpace(); + Value newAllocOp = b.create( + allocOp.getType(), allocOp.getDynamicSizes(), + /*copy=*/Value(), + memorySpace ? cast(*memorySpace) : IntegerAttr()); + + b.setInsertionPoint(xWrite); + auto newXWrite = b.create( + xWrite.getVector(), newAllocOp, xWrite.getIndices(), + xWrite.getPermutationMapAttr(), xWrite.getMask(), + xWrite.getInBoundsAttr()); + allocOp.getResult().replaceAllUsesWith(newXWrite.getResult()); + + allocOp.erase(); + xWrite.erase(); + } + + return DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// +// DISCExpandTransferRWToMemrefCopyOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform_dialect::DISCExpandTransferRWToMemrefCopy::applyToOne( + Operation* target, transform::ApplyToEachResultList& results, + transform::TransformState& state) { + MLIRContext* ctx = target->getContext(); + IRRewriter rewriter(ctx); + Location loc = target->getLoc(); + ImplicitLocOpBuilder b(loc, rewriter); + SmallVector writes; + SmallVector copies; + scf::ForOp ctaKLoop; + scf::ParallelOp parallelOp; + target->walk([&](scf::ForOp forOp) { + auto loopType = forOp->getAttrOfType("loop-type"); + if (loopType && loopType.getValue().equals("cta-k-loop") && + isa(forOp->getParentOp())) { + ctaKLoop = forOp; + parallelOp = cast(forOp->getParentOp()); + WalkResult::interrupt(); + } + }); + if (!ctaKLoop) { + return mlir::emitDefiniteFailure(target, "cannot find ctaKLoop"); + } + // %subview_8 = memref.subview %arg0[%4, %arg5] [%2, %8] [1, 1] + // %9 = vector.transfer_read %subview_8[%c0, %c0], %cst_0 + // vector.transfer_write %9, %subview_7[%c0, %c0] + // TODO: use another way to find out the R/W of ShareMemory + ctaKLoop->walk([&](vector::TransferWriteOp write) { + auto loop = dyn_cast_or_null(write->getParentOp()); + if (!loop || loop != ctaKLoop) { + WalkResult::skip(); + } + auto read = write.getVector().getDefiningOp(); + auto dst = write.getSource(); + if (!read || !dst) WalkResult::skip(); + if (!write.getPermutationMap().isMinorIdentity()) WalkResult::skip(); + if (!vector::isLastMemrefDimUnitStride( + dyn_cast(write.getShapedType()))) + WalkResult::skip(); + if (!hasSharedMemoryAddressSpace( + llvm::cast(write.getShapedType()))) { + WalkResult::skip(); + } + if (write.hasOutOfBoundsDim() || write.getMask()) WalkResult::skip(); + if (write.getMask() || read.getMask()) WalkResult::skip(); + if (!read.getPermutationMap().isMinorIdentity()) WalkResult::skip(); + if (!vector::isLastMemrefDimUnitStride( + dyn_cast(read.getShapedType()))) { + WalkResult::skip(); + } + writes.push_back(write); + }); + if (writes.empty()) { + return DiagnosedSilenceableFailure::success(); + } + + Value ctaKLoopIV = ctaKLoop.getInductionVar(); + Value ctaKLoopStep = + ctaKLoop.getStep().getDefiningOp(); + Value ctaKLoopUB = ctaKLoop.getUpperBound(); + + const size_t kThreadCopyBytes 
= 16; + const size_t kThreadsPerBlock = 128; + const size_t kThreadsPerWarp = 32; + b.setInsertionPointToStart(¶llelOp.getRegion().front()); + Value threadId = parallelOp.getInductionVars()[1]; + Value warpSize = b.create(kThreadsPerWarp); + Value warpId = b.create(threadId, warpSize); + Value laneId = b.create(threadId, warpSize); + SmallVector tokens; + auto d0 = b.getAffineDimExpr(0); + auto s0 = b.getAffineSymbolExpr(0); + for (auto write : writes) { + auto read = write.getVector().getDefiningOp(); + auto padding = read.getPadding(); + auto dst = write.getSource(); + auto dstMemref = cast(dst.getType()); + auto src = read.getSource(); + ArrayRef shape = + dyn_cast(write.getShapedType()).getShape(); + int64_t total_chunk_lines = shape[0]; + int64_t chunk_size_in_bytes = shape[1] * 2; + int64_t chunk_copy_lines_per_waro = + kThreadsPerWarp * kThreadCopyBytes / chunk_size_in_bytes; + int64_t chunk_copy_lines_per_block = + kThreadsPerBlock * kThreadCopyBytes / chunk_size_in_bytes; + int64_t iterations = shape[0] / chunk_copy_lines_per_block; + int64_t chunk_copy_line_lanes = chunk_size_in_bytes / kThreadCopyBytes; + b.setInsertionPointAfter(write); + Value zero = b.create(0); + Value one = b.create(1); + Value numElements = b.create(kThreadCopyBytes / 2); + for (int64_t i = 0; i < iterations; ++i) { + auto offsetXMap = AffineMap::get( + 1, 0, + {i * chunk_copy_lines_per_block + d0.floorDiv(chunk_copy_line_lanes)}, + ctx); + auto offsetYMap = AffineMap::get( + 1, 0, {(d0 % chunk_copy_line_lanes) * (kThreadCopyBytes / 2)}, ctx); + Value offsetX = + b.create(offsetXMap, ValueRange{threadId}); + Value offsetY = + b.create(offsetYMap, ValueRange{threadId}); + // clang-format off + // if (offsetX >= src.dim0 || offsetY >= src.dim1) + // srcElements = 0; + // else + // if (offsetY + numElements > src.dim1) + // srcElements = src.dim1 - offsetY; + // else + // srcElements = numElements + // clang-format on + Value dim0 = b.create(src, zero); + Value dim1 = b.create(src, one); + Value dim0GeX = + b.create(arith::CmpIPredicate::uge, offsetX, dim0); + Value dim1GeY = + b.create(arith::CmpIPredicate::uge, offsetY, dim1); + Value cond0 = b.create(dim0GeX, dim1GeY); + + Value diffY = b.create(dim1, offsetY); + Value cond1 = b.create( + arith::CmpIPredicate::ugt, + b.create(offsetY, numElements), dim1); + Value srcElements = b.create( + cond0, zero, b.create(cond1, diffY, numElements)); + auto token = b.create( + nvgpu::DeviceAsyncTokenType::get(ctx), dst, + /*dstIndices*/ ValueRange{offsetX, offsetY}, src, + /*srcIndices*/ ValueRange{offsetX, offsetY}, + /*dstElements*/ b.getIndexAttr(kThreadCopyBytes / 2), + /*srcElements*/ srcElements, + /*bypassL1*/ b.getUnitAttr()); + tokens.push_back(token); + } + } + auto tokenGroup = b.create( + nvgpu::DeviceAsyncTokenType::get(ctx), tokens); + // we will manually lower nvgpu::DeviceAsyncWaitOp to + // NVVM::CpAsyncCommitGroupOp, so it's ok to use the + // `nullptr` directly. + b.create(tokenGroup, nullptr); + b.create(); + + for (auto write : writes) { + write.erase(); + } + return DiagnosedSilenceableFailure::success(); +} + +/// Replace the uses of `oldOp` with the given `val` and for subview uses +/// propagate the type change. Changing the memref type may require propagating +/// it through subview ops so we cannot just do a replaceAllUse but need to +/// propagate the type change and erase old subview ops. 
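+/// For example, when a shared-memory alloc is replaced by a rank-reduced
+/// slice of a multi-buffered alloc, every `memref.subview` of the old alloc
+/// is re-created on the new value with its result type re-inferred, and the
+/// old subview is erased; all other users simply get their operand switched
+/// to the new value.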
+static void replaceUsesAndPropagateType(RewriterBase& rewriter, + Operation* oldOp, Value val) { + SmallVector opsToDelete; + SmallVector operandsToReplace; + + // Save the operand to replace / delete later (avoid iterator invalidation). + // TODO: can we use an early_inc iterator? + for (OpOperand& use : oldOp->getUses()) { + // Non-subview ops will be replaced by `val`. + auto subviewUse = dyn_cast(use.getOwner()); + if (!subviewUse) { + operandsToReplace.push_back(&use); + continue; + } + + // `subview(old_op)` is replaced by a new `subview(val)`. + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(subviewUse); + Type newType = memref::SubViewOp::inferRankReducedResultType( + subviewUse.getType().getShape(), cast(val.getType()), + subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(), + subviewUse.getStaticStrides()); + Value newSubview = rewriter.create( + subviewUse->getLoc(), cast(newType), val, + subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(), + subviewUse.getMixedStrides()); + + // Ouch recursion ... is this really necessary? + replaceUsesAndPropagateType(rewriter, subviewUse, newSubview); + + opsToDelete.push_back(use.getOwner()); + } + + // Perform late replacement. + // TODO: can we use an early_inc iterator? + for (OpOperand* operand : operandsToReplace) { + Operation* op = operand->getOwner(); + rewriter.startRootUpdate(op); + operand->set(val); + rewriter.finalizeRootUpdate(op); + } + + // Perform late op erasure. + // TODO: can we use an early_inc iterator? + for (Operation* op : opsToDelete) rewriter.eraseOp(op); +} + +//===----------------------------------------------------------------------===// +// DISCMultiBufferingOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform_dialect::DISCMultiBuffering::applyToOne( + Operation* target, transform::ApplyToEachResultList& results, + transform::TransformState& state) { + MLIRContext* ctx = target->getContext(); + IRRewriter rewriter(ctx); + Location loc = target->getLoc(); + ImplicitLocOpBuilder b(loc, rewriter); + auto funcOp = cast(target); + // Get ctaKLoop + scf::ForOp ctaKLoop; + target->walk([&](scf::ForOp loop) { + auto reductionTy = loop->getAttrOfType("loop-type"); + if (reductionTy && reductionTy.getValue().equals("cta-k-loop")) { + ctaKLoop = loop; + WalkResult::interrupt(); + } + }); + scf::ForOp mmaKLoop; + target->walk([&](nvgpu::MmaSyncOp mmaSyncOp) { + auto loop = dyn_cast_or_null(mmaSyncOp->getParentOp()); + if (loop) { + mmaKLoop = loop; + WalkResult::interrupt(); + } + }); + + DominanceInfo dom(ctaKLoop); + SmallVector allocs; + // Collect all the candidate alloc operations + // 1. shared memory + // 2. dominate ctaKLoop + // 3. all users inside ctaKLoop or inside mmaKLoop + funcOp.walk([&](memref::AllocOp allocOp) { + if (!hasSharedMemoryAddressSpace(allocOp.getType()) || + !dom.properlyDominates(allocOp.getOperation(), ctaKLoop)) { + return WalkResult::advance(); + } + for (Operation* user : allocOp->getUsers()) { + if (isa(user)) continue; + auto loop = dyn_cast_or_null(user->getParentOp()); + if (!loop && loop != ctaKLoop && loop != mmaKLoop) + return WalkResult::advance(); + } + allocOp.dump(); + allocs.push_back(allocOp); + return WalkResult::advance(); + }); + + // Try to apply multi-buffering to all of them. 
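+  // For each candidate the buffering factor is prepended to the allocated
+  // shape and every cta-k-loop iteration addresses its own slice, e.g. with
+  // factor 2 a memref<32x128xf16, workgroup> alloc becomes (roughly):
+  //   %buf   = memref.alloc() : memref<2x32x128xf16, workgroup>
+  //   %idx   = ((%iv - %lb) floordiv %step) mod 2
+  //   %slice = memref.subview %buf[%idx, 0, 0] [1, 32, 128] [1, 1, 1]
+  // and all users of the original alloc are redirected to %slice via
+  // replaceUsesAndPropagateType.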
+ for (memref::AllocOp allocOp : allocs) { + DominanceInfo dom(allocOp->getParentOp()); + LoopLikeOpInterface candidateLoop = ctaKLoop; + + std::optional inductionVar = candidateLoop.getSingleInductionVar(); + std::optional lowerBound = + candidateLoop.getSingleLowerBound(); + std::optional singleStep = candidateLoop.getSingleStep(); + if (!inductionVar || !lowerBound || !singleStep) { + return DiagnosedSilenceableFailure::definiteFailure(); + } + + // Start multibuffering loop + // 1. Construct the multi-buffered memref type. + ArrayRef originalShape = allocOp.getType().getShape(); + int64_t multiBufferingFactor = getMultiBufferingFactor(); + SmallVector multiBufferedShape{multiBufferingFactor}; + llvm::append_range(multiBufferedShape, originalShape); + MemRefType mbMemRefType = MemRefType::Builder(allocOp.getType()) + .setShape(multiBufferedShape) + .setLayout(MemRefLayoutAttrInterface()); + + // 2. Create the multi-buffered alloc. + Location loc = allocOp->getLoc(); + OpBuilder::InsertionGuard g(rewriter); + b.setInsertionPoint(allocOp); + auto mbAlloc = b.create(loc, mbMemRefType, ValueRange{}, + allocOp->getAttrs()); + + // 3. Within the loop, build the modular leading index (i.e. each loop + // iteration %iv accesses slice ((%iv - %lb) / %step) % %mb_factor). + b.setInsertionPointToStart(&candidateLoop.getLoopBody().front()); + Value ivVal = *inductionVar; + Value lbVal = getValueOrCreateConstantIndexOp(b, loc, *lowerBound); + Value stepVal = getValueOrCreateConstantIndexOp(b, loc, *singleStep); + AffineExpr iv, lb, step; + bindDims(b.getContext(), iv, lb, step); + Value bufferIndex = affine::makeComposedAffineApply( + b, loc, ((iv - lb).floorDiv(step)) % multiBufferingFactor, + {ivVal, lbVal, stepVal}); + + // 4. Build the subview accessing the particular slice, + // taking modular rotation into account. + int64_t mbMemRefTypeRank = mbMemRefType.getRank(); + IntegerAttr zero = b.getIndexAttr(0); + IntegerAttr one = b.getIndexAttr(1); + SmallVector offsets(mbMemRefTypeRank, zero); + SmallVector sizes(mbMemRefTypeRank, one); + SmallVector strides(mbMemRefTypeRank, one); + // Offset is [bufferIndex, 0 ... 0 ]. + offsets.front() = bufferIndex; + // Sizes is [1, original_size_0 ... original_size_n ]. + for (int64_t i = 0, e = originalShape.size(); i != e; ++i) + sizes[1 + i] = b.getIndexAttr(originalShape[i]); + // Strides is [1, 1 ... 1 ]. + auto dstMemref = + cast(memref::SubViewOp::inferRankReducedResultType( + originalShape, mbMemRefType, offsets, sizes, strides)); + Value subview = b.create(dstMemref, mbAlloc, offsets, + sizes, strides); + + // 5. Due to the recursive nature of replaceUsesAndPropagateType, + // we need to handle dealloc uses separately. + for (OpOperand& use : llvm::make_early_inc_range(allocOp->getUses())) { + auto deallocOp = dyn_cast(use.getOwner()); + if (!deallocOp) continue; + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(deallocOp); + auto newDeallocOp = + b.create(deallocOp->getLoc(), mbAlloc); + (void)newDeallocOp; + deallocOp.erase(); + } + + // 6. RAUW with the particular slice, taking modular rotation into account. + replaceUsesAndPropagateType(rewriter, allocOp, subview); + + // 7. Finally, erase the old allocOp. 
+ allocOp.erase(); + } + return DiagnosedSilenceableFailure::success(); +} + +void transform_dialect::DISCMultiBuffering::getEffects( + SmallVectorImpl& effects) { + transform::onlyReadsHandle(getTarget(), effects); + transform::modifiesPayload(effects); +} + +//===----------------------------------------------------------------------===// +// DISCSwizzleShareMemoryOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform_dialect::DISCSwizzleShareMemoryOp::applyToOne( + Operation* target, transform::ApplyToEachResultList& results, + transform::TransformState& state) { + MLIRContext* ctx = target->getContext(); + IRRewriter rewriter(ctx); + Location loc = target->getLoc(); + ImplicitLocOpBuilder b(loc, rewriter); + auto funcOp = cast(target); + SmallVector shmAllocOps; + RewritePatternSet pattern(funcOp.getContext()); + memref::populateFoldMemRefAliasOpPatterns(pattern); + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(pattern)))) { + return DiagnosedSilenceableFailure::definiteFailure(); + } + funcOp->walk([&](memref::AllocOp allocOp) { + // Only apply swizzling to input shared memory. + if (hasSharedMemoryAddressSpace(allocOp.getType())) { + auto memoryTy = allocOp->getAttrOfType("memory-type"); + if (!memoryTy) { + shmAllocOps.push_back(allocOp); + } + } + }); + LogicalResult result(failure()); + for (auto allocOp : shmAllocOps) { + result = disc_ral::transform_dialect::optimizeSharedMemoryReadsAndWrites( + funcOp, allocOp.getMemref()); + if (failed(result)) { + return DiagnosedSilenceableFailure::definiteFailure(); + } + } + return DiagnosedSilenceableFailure::success(); +} + +static int64_t getAllocSize(Operation* op, DataLayout& dataLayout) { + auto allocOp = cast(op); + int64_t numElements = allocOp.getType().getNumElements(); + return (dataLayout.getTypeSizeInBits(allocOp.getType().getElementType()) * + numElements) / + 8; +} + +//===----------------------------------------------------------------------===// +// DISCPackSharedMemoryAllocOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform_dialect::DISCPackSharedMemoryAllocOp::applyToOne( + Operation* target, transform::ApplyToEachResultList& results, + transform::TransformState& state) { + MLIRContext* ctx = target->getContext(); + IRRewriter rewriter(ctx); + Location loc = target->getLoc(); + ImplicitLocOpBuilder b(loc, rewriter); + auto parallelOp = cast(target); + DominanceInfo dominators(parallelOp); + SmallVector inputAllocs; + SmallVector outputAllocs; + parallelOp.walk([&](memref::AllocOp alloc) { + if (hasSharedMemoryAddressSpace(alloc.getType())) { + auto memoryType = alloc->getAttrOfType("memory-type"); + if (memoryType && memoryType.getValue().equals("output")) { + outputAllocs.push_back(alloc); + } else { + inputAllocs.push_back(alloc); + } + } + }); + + DataLayout dataLayout = DataLayout::closest(parallelOp); + int64_t maxAlloc = 0; + int64_t inputAllocSize = 0, outputAllocSize = 0; + for (auto alloc : inputAllocs) { + inputAllocSize += getAllocSize(alloc, dataLayout); + maxAlloc = std::max(maxAlloc, inputAllocSize); + } + for (auto alloc : outputAllocs) { + outputAllocSize += getAllocSize(alloc, dataLayout); + maxAlloc = std::max(maxAlloc, outputAllocSize); + } + b.setInsertionPointToStart(¶llelOp.getRegion().front()); + Attribute memorySpace = gpu::AddressSpaceAttr::get( + ctx, gpu::GPUDialect::getWorkgroupAddressSpace()); + MemRefType allocType = + 
MemRefType::get({maxAlloc}, b.getI8Type(), AffineMap(), memorySpace); + Value packedAlloc = b.create(allocType); + + int64_t inputOffset = 0, outputOffset = 0; + for (auto alloc : inputAllocs) { + b.setInsertionPoint(alloc); + Value offsetValue = b.create(inputOffset); + Value newAlloc = b.create(alloc.getType(), packedAlloc, + offsetValue, ArrayRef({})); + inputOffset += getAllocSize(alloc, dataLayout); + alloc.replaceAllUsesWith(newAlloc); + } + for (auto alloc : outputAllocs) { + b.setInsertionPoint(alloc); + Value offsetValue = b.create(outputOffset); + Value newAlloc = b.create(alloc.getType(), packedAlloc, + offsetValue, ArrayRef({})); + outputOffset += getAllocSize(alloc, dataLayout); + alloc.replaceAllUsesWith(newAlloc); + } + for (auto alloc : inputAllocs) { + alloc.erase(); + } + for (auto alloc : outputAllocs) { + alloc.erase(); + } + return DiagnosedSilenceableFailure::success(); +} + +void transform_dialect::DISCPackSharedMemoryAllocOp::getEffects( + SmallVectorImpl& effects) { + transform::onlyReadsHandle(getTarget(), effects); + transform::modifiesPayload(effects); +} + +DiagnosedSilenceableFailure +transform_dialect::DISCMoveDataToRegister::applyToOne( + Operation* target, transform::ApplyToEachResultList& results, + transform::TransformState& state) { + MLIRContext* ctx = target->getContext(); + IRRewriter rewriter(ctx); + Location loc = target->getLoc(); + ImplicitLocOpBuilder b(loc, rewriter); + auto mma = cast(target); + if (!mma) { + return mlir::emitDefiniteFailure( + target, "DISCUseRegForAccumalation expect mma operation"); + } + auto ldMatrixA = mma.getMatrixA().getDefiningOp(); + auto ldMatrixB = mma.getMatrixB().getDefiningOp(); + if (!ldMatrixA || !ldMatrixB) { + return mlir::emitDefiniteFailure(target, "expect ldmatrixA and ldmatrixB"); + } + // mma shape + std::array mmaShapeArray = mma.getMmaShapeAsArray(); + int64_t mmaMShape = mmaShapeArray[0]; + int64_t mmaNShape = mmaShapeArray[1]; + int64_t mmaKShape = mmaShapeArray[2]; + + // get mma's M, N, K loop + auto mmaKLoop = dyn_cast_or_null(mma->getParentOp()); + if (!mmaKLoop) { + return mlir::emitDefiniteFailure(target, "expect mma inner loop"); + } + auto mmaNLoop = dyn_cast_or_null(mmaKLoop->getParentOp()); + if (!mmaNLoop) { + return mlir::emitDefiniteFailure(target, "expect mma mid loop"); + } + auto mmaMLoop = dyn_cast_or_null(mmaNLoop->getParentOp()); + if (!mmaMLoop) { + return mlir::emitDefiniteFailure(target, "expect mma outer loop"); + } + auto mmaMLoopUB = + mmaMLoop.getUpperBound().getDefiningOp(); + auto mmaNLoopUB = + mmaNLoop.getUpperBound().getDefiningOp(); + auto mmaKLoopUB = + mmaKLoop.getUpperBound().getDefiningOp(); + if (!mmaMLoopUB || !mmaNLoopUB || !mmaKLoopUB) { + return mlir::emitDefiniteFailure(target, "expect constant warp shape"); + } + // warp shape + int64_t warpMShape = mmaMLoopUB.value(); + int64_t warpNShape = mmaNLoopUB.value(); + int64_t warpKShape = mmaKLoopUB.value(); + // block shape + const ArrayRef blockShape = getBlockMnShape(); + if (blockShape.size() != 2) { + return mlir::emitDefiniteFailure(target, "expect block M, N shape"); + } + int64_t blockMShape = blockShape[0]; + int64_t blockNShape = blockShape[1]; + if (blockMShape % warpMShape != 0 || blockNShape % warpNShape != 0 || + warpMShape % mmaMShape != 0 || warpNShape % mmaNShape != 0 || + warpKShape % mmaKShape != 0) { + return mlir::emitDefiniteFailure(target, "invalid shape of block or warp"); + } + auto parallelOp = mmaMLoop->getParentOfType(); + if (!parallelOp) { + return mlir::emitDefiniteFailure(target, + 
"It should have a parent parallel op"); + } + + scf::ForOp ctaKLoop; + parallelOp->walk([&](scf::ForOp forOp) { + auto reductionTy = forOp->getAttrOfType("loop-type"); + if (reductionTy && reductionTy.getValue().equals("cta-k-loop")) { + ctaKLoop = forOp; + WalkResult::interrupt(); + } + }); + if (!ctaKLoop) { + return mlir::emitDefiniteFailure(target, "expect cta K reduction loop"); + } + + auto privateAS = gpu::GPUDialect::getPrivateAddressSpace(); + auto shareAS = gpu::GPUDialect::getWorkgroupAddressSpace(); + auto privateASAttr = gpu::AddressSpaceAttr::get(ctx, privateAS); + auto shareASAttr = gpu::AddressSpaceAttr::get(ctx, shareAS); + Type elementType = mma.getType().cast().getElementType(); + + b.setInsertionPointToStart(¶llelOp.getRegion().front()); + + auto zeroAttr = b.getZeroAttr(elementType); + Value zeroCst = b.create(zeroAttr); + Value zeroCstVec = b.create(DenseElementsAttr::get( + VectorType::get({2}, elementType), FloatAttr::get(elementType, 0.0))); + auto d0 = b.getAffineDimExpr(0); + auto s0 = b.getAffineSymbolExpr(0); + Value zeroIndex = b.create(0); + Value oneIndex = b.create(1); + + Value threadId = parallelOp.getInductionVars()[1]; + Value warpSize = b.create(kWarpSize); + Value warpId = b.create(threadId, warpSize); + Value laneId = b.create(threadId, warpSize); + // A. alloc register for input matrix + MemRefType allocTypeA = + MemRefType::get({warpMShape / mmaMShape, warpKShape / mmaKShape, 4, 2}, + elementType, AffineMap(), privateASAttr); + MemRefType allocTypeB = + MemRefType::get({warpKShape / mmaKShape, warpNShape / mmaNShape, 2, 2}, + elementType, AffineMap(), privateASAttr); + MemRefType allocTypeC = + MemRefType::get({warpMShape / mmaMShape, warpNShape / mmaNShape, 2, 2}, + elementType, AffineMap(), privateASAttr); + Value aWarpRegAlloc = b.create(allocTypeA); + Value bWarpRegAlloc = b.create(allocTypeB); + Value cWarpRegAlloc = b.create(allocTypeC); + for (int64_t i = 0; i < warpMShape / mmaMShape; ++i) { + for (int64_t j = 0; j < warpNShape / mmaNShape; ++j) + for (int64_t m = 0; m < 2; ++m) + b.create( + zeroCstVec, cWarpRegAlloc, + ValueRange{b.create(i), + b.create(j), + b.create(m), zeroIndex}); + } + + // create vector for a, b, c + Value matrixAReg = b.create(DenseElementsAttr::get( + ldMatrixA.getRes().getType(), FloatAttr::get(elementType, 0.0))); + Value matrixBReg = b.create(DenseElementsAttr::get( + ldMatrixB.getRes().getType(), FloatAttr::get(elementType, 0.0))); + Value matrixCReg = b.create(DenseElementsAttr::get( + mma.getRes().getType(), FloatAttr::get(elementType, 0.0))); + + SmallVector outputTypes(2, b.getIndexType()); + SmallVector shape; + shape.push_back(b.create(blockMShape / warpMShape)); + shape.push_back(b.create(blockNShape / warpNShape)); + auto delinearizeOp = b.create( + outputTypes, b.create(threadId, warpSize), shape); + SmallVector delinearizeRes; + for (Value result : delinearizeOp->getResults()) + delinearizeRes.push_back(result); + Value matrixAWarpOffset = b.create( + delinearizeRes[0], b.create(warpMShape)); + Value matrixBWarpOffset = b.create( + delinearizeRes[1], b.create(warpNShape)); + + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(mmaMLoop); + auto ldAMOffsetMap = + AffineMap::get(1, 1, + // {d0 + s0 - (s0.floorDiv(mmaMShape)) * mmaMShape}, ctx); + {d0 + s0 % mmaMShape}, ctx); + auto ldAKOffsetMap = + AffineMap::get(1, 1, {d0 + (s0.floorDiv(mmaKShape)) * 8}, ctx); + // create M/K nested loop to cache ldmatrix A's result + auto ldAMLoop = b.create( + zeroIndex, b.create(warpMShape), + 
b.create(mmaMShape)); + b.setInsertionPoint(ldAMLoop.getBody(), ldAMLoop.getBody()->begin()); + auto ldAMLoopIV = ldAMLoop.getInductionVar(); + auto ldAMOffset = b.create( + ldAMOffsetMap, + ValueRange{b.create(ldAMLoopIV, matrixAWarpOffset), + laneId}); + auto ldAKLoop = b.create( + zeroIndex, b.create(warpKShape), + b.create(mmaKShape)); + b.setInsertionPoint(ldAKLoop.getBody(), ldAKLoop.getBody()->begin()); + auto ldAKLoopIV = ldAKLoop.getInductionVar(); + auto ldAKOffset = b.create( + ldAKOffsetMap, ValueRange{ldAKLoopIV, laneId}); + auto ldARes = b.create( + ldMatrixA.getRes().getType(), ldMatrixA.getSrcMemref(), + ValueRange{ldAMOffset, ldAKOffset}, ldMatrixA.getTransposeAttr(), + ldMatrixA.getNumTilesAttr()); + for (unsigned i = 0; i < 4; ++i) { + Value extractOp = b.create( + ldARes.getRes(), ValueRange{b.create(i)}); + b.create( + extractOp, aWarpRegAlloc, + ValueRange{b.create( + ldAMLoopIV, b.create(mmaMShape)), + b.create( + ldAKLoopIV, b.create(mmaKShape)), + b.create(i), zeroIndex}); + } + b.setInsertionPointAfter(ldAMLoop); + (void)mlir::loopUnrollByFactor(ldAKLoop, warpKShape / mmaKShape); + (void)mlir::loopUnrollByFactor(ldAMLoop, warpMShape / mmaMShape); + + auto ldBKOffsetMap = AffineMap::get(1, 1, {d0 + s0 % mmaKShape}, ctx); + auto ldBNOffsetMap = + AffineMap::get(1, 1, {d0 + (s0.floorDiv(mmaKShape)) * 8}, ctx); + // create K/N nested loop to cache ldmartixB's result + auto ldBKLoop = b.create( + zeroIndex, b.create(warpKShape), + b.create(mmaKShape)); + b.setInsertionPoint(ldBKLoop.getBody(), ldBKLoop.getBody()->begin()); + auto ldBKLoopIV = ldBKLoop.getInductionVar(); + auto ldBKOffset = b.create( + ldBKOffsetMap, ValueRange{ldBKLoopIV, laneId}); + auto ldBNLoop = b.create( + zeroIndex, b.create(warpNShape), + b.create(mmaNShape)); + b.setInsertionPoint(ldBNLoop.getBody(), ldBNLoop.getBody()->begin()); + auto ldBNLoopIV = ldBNLoop.getInductionVar(); + auto ldBNOffset = b.create( + ldBNOffsetMap, + ValueRange{b.create(ldBNLoopIV, matrixBWarpOffset), + laneId}); + auto ldBRes = b.create( + ldMatrixB.getRes().getType(), ldMatrixB.getSrcMemref(), + ValueRange{ldBKOffset, ldBNOffset}, ldMatrixB.getTransposeAttr(), + ldMatrixB.getNumTilesAttr()); + for (unsigned i = 0; i < 2; ++i) { + Value extractOp = b.create( + ldBRes.getRes(), ValueRange{b.create(i)}); + b.create( + extractOp, bWarpRegAlloc, + ValueRange{b.create( + ldBKLoopIV, b.create(mmaKShape)), + b.create( + ldBNLoopIV, b.create(mmaNShape)), + b.create(i), zeroIndex}); + } + b.setInsertionPointAfter(ldBKLoop); + (void)mlir::loopUnrollByFactor(ldBNLoop, warpNShape / mmaNShape); + (void)mlir::loopUnrollByFactor(ldBKLoop, warpKShape / mmaKShape); + + // create M/N/K nested loop to compute mma + int64_t vector_width = 2; + auto newMmaKLoop = b.create( + zeroIndex, b.create(warpKShape), + b.create(mmaKShape)); + b.setInsertionPoint(newMmaKLoop.getBody(), newMmaKLoop.getBody()->begin()); + auto mmaKLoopIV = newMmaKLoop.getInductionVar(); + auto mmaKShapeValue = b.create(mmaKShape); + auto kIndexValue = b.create(mmaKLoopIV, mmaKShapeValue); + auto newMmaMLoop = b.create( + zeroIndex, b.create(warpMShape), + b.create(mmaMShape)); + b.setInsertionPoint(newMmaMLoop.getBody(), newMmaMLoop.getBody()->begin()); + auto mmaMLoopIV = newMmaMLoop.getInductionVar(); + auto mmaMShapeValue = b.create(mmaMShape); + auto mIndexValue = b.create(mmaMLoopIV, mmaMShapeValue); + // load A from reg + for (unsigned i = 0; i < 4; ++i) { + matrixAReg = b.create( + b.create( + VectorType::get({vector_width}, elementType), aWarpRegAlloc, + 
ValueRange{mIndexValue, kIndexValue, + b.create(i), zeroIndex}), + matrixAReg, (int64_t[]){i}); + } + auto newMmaNLoop = b.create( + zeroIndex, b.create(warpNShape), + b.create(mmaNShape)); + b.setInsertionPoint(newMmaNLoop.getBody(), newMmaNLoop.getBody()->begin()); + auto mmaNLoopIV = newMmaNLoop.getInductionVar(); + auto mmaNShapeValue = b.create(mmaNShape); + auto nIndexValue = b.create(mmaNLoopIV, mmaNShapeValue); + // load C form reg + for (unsigned i = 0; i < 2; ++i) { + matrixCReg = b.create( + b.create( + VectorType::get({vector_width}, elementType), cWarpRegAlloc, + ValueRange{mIndexValue, nIndexValue, + b.create(i), zeroIndex}), + matrixCReg, (int64_t[]){i}); + } + // load B from reg + for (unsigned i = 0; i < 2; ++i) { + matrixBReg = b.create( + b.create( + VectorType::get({vector_width}, elementType), bWarpRegAlloc, + ValueRange{kIndexValue, nIndexValue, + b.create(i), zeroIndex}), + matrixBReg, (int64_t[]){i}); + } + // compute mma + matrixCReg = b.create(matrixAReg, matrixBReg, matrixCReg, + mma.getMmaShapeAttr()); + // store c to reg + for (unsigned i = 0; i < 2; ++i) { + b.create( + b.create(matrixCReg, (int64_t[]){i}), cWarpRegAlloc, + ValueRange{mIndexValue, nIndexValue, + b.create(i), zeroIndex}); + } + b.setInsertionPointAfter(newMmaKLoop); + (void)mlir::loopUnrollByFactor(newMmaNLoop, warpNShape / mmaNShape); + (void)mlir::loopUnrollByFactor(newMmaMLoop, warpMShape / mmaMShape); + (void)mlir::loopUnrollByFactor(newMmaKLoop, warpKShape / mmaKShape); + + // no longer need origin mmaMLoop + for (auto user : mmaMLoop->getUsers()) user->erase(); + mmaMLoop.erase(); + + // 4. Move data from cWarpReg to cWarpSmem + b.setInsertionPointAfter(ctaKLoop); + b.create(); + const int64_t kThreadsPerBlock = 128; + const int64_t kWarpsPersBlock = kThreadsPerBlock / kWarpSize; + const int64_t kWarpsPerRowInBlock = blockNShape / warpNShape; + if ((blockMShape / warpMShape) * (blockNShape / warpNShape) != + kWarpsPersBlock) { + return mlir::emitDefiniteFailure(target, + "warps should be equal to thread groups"); + } + // Use padding to avoid output bank conflict + int64_t smemPadding = getSmemPadding(); + auto cBlockSmemAlloc = b.create( + MemRefType::get({blockMShape, blockNShape + smemPadding}, elementType, + AffineMap(), shareASAttr)); + cBlockSmemAlloc->setAttr("memory-type", StringAttr::get(ctx, "output")); + // Get cWarpSmem of cBlockSmem by subviewing + auto cWarpSmemRowOffsetMap = AffineMap::get( + 1, 0, {(d0.floorDiv(kWarpsPerRowInBlock)) * warpMShape}, ctx); + auto cWarpSmemColOffsetMap = + AffineMap::get(1, 0, {(d0 % kWarpsPerRowInBlock) * warpNShape}, ctx); + auto cWarpSmemRowOffset = + b.create(cWarpSmemRowOffsetMap, warpId); + auto cWarpSmemColOffset = + b.create(cWarpSmemColOffsetMap, warpId); + + auto cWarpSmemColMap = AffineMap::get(1, 1, {d0 + (s0 % 4) * 2}, ctx); + Value warpBaseRow = b.create( + cWarpSmemRowOffset, + b.create(laneId, b.create(2))); + mmaMLoop = b.create(zeroIndex, + b.create(warpMShape), + b.create(mmaMShape)); + mmaMLoop->setAttr("loop-type", StringAttr::get(ctx, "reg-to-smem-loop")); + b.setInsertionPoint(mmaMLoop.getBody(), mmaMLoop.getBody()->begin()); + mmaMLoopIV = mmaMLoop.getInductionVar(); + mmaNLoop = b.create(zeroIndex, + b.create(warpNShape), + b.create(mmaNShape)); + b.setInsertionPoint(mmaNLoop.getBody(), mmaNLoop.getBody()->begin()); + mmaNLoopIV = mmaNLoop.getInductionVar(); + auto mIndex = b.create( + mmaMLoopIV, b.create(mmaMShape)); + auto nIndex = b.create( + mmaNLoopIV, b.create(mmaNShape)); + + Value cWarpSmemRow0 = 
b.create(warpBaseRow, mmaMLoopIV); + Value cWarpSmemRow1 = b.create( + cWarpSmemRow0, b.create(8)); + Value cWarpSmemCol = b.create( + cWarpSmemColOffset, b.create( + cWarpSmemColMap, ValueRange{mmaNLoopIV, laneId})); + // TODO: store 4xf16 rather than 2x2xf16 + b.create( + b.create( + VectorType::get({vector_width}, elementType), cWarpRegAlloc, + ValueRange{mIndex, nIndex, zeroIndex, zeroIndex}), + cBlockSmemAlloc, ValueRange{cWarpSmemRow0, cWarpSmemCol}); + b.create( + b.create(VectorType::get({vector_width}, elementType), + cWarpRegAlloc, + ValueRange{mIndex, nIndex, oneIndex, zeroIndex}), + cBlockSmemAlloc, ValueRange{cWarpSmemRow1, cWarpSmemCol}); + b.setInsertionPointAfter(mmaMLoop); + b.create(); + (void)mlir::loopUnrollByFactor(mmaNLoop, warpNShape / mmaNShape); + (void)mlir::loopUnrollByFactor(mmaMLoop, warpMShape / mmaMShape); + + // 5. Move data from smem to gmem + linalg::GenericOp genericOp; + parallelOp->walk([&](linalg::GenericOp generic) { + // TODO: check input is shared memory, output is global + genericOp = generic; + WalkResult::interrupt(); + }); + if (!genericOp) return mlir::emitDefiniteFailure(target, "expect generic op"); + Value input = genericOp.getInputs()[0]; + Value output = genericOp.getOutputs()[0]; + auto inputSubView = + dyn_cast_or_null(input.getDefiningOp()); + auto outputSubView = + dyn_cast_or_null(output.getDefiningOp()); + if (!inputSubView || !outputSubView) { + return mlir::emitDefiniteFailure(target, + "expect generic op's operand is subview"); + } + + // expand genericOp to loop + b.setInsertionPoint(genericOp); + int64_t cpBytesPerThread = 16; // 128 bits + int64_t cpElementsPerThread = cpBytesPerThread / 2; + int64_t cpThreadsPerRow = blockNShape / cpElementsPerThread; + int64_t cpRowsPerBlock = 128 * cpElementsPerThread / blockNShape; + auto smemToGmemLoop = + b.create(b.create(0), + b.create(blockMShape), + b.create(cpRowsPerBlock)); + smemToGmemLoop->setAttr("loop-type", + StringAttr::get(ctx, "smem-to-gmem-loop")); + + b.setInsertionPoint(smemToGmemLoop.getBody(), + smemToGmemLoop.getBody()->begin()); + auto iv = smemToGmemLoop.getInductionVar(); + auto offsetY = b.create( + b.create(cpElementsPerThread), + b.create( + threadId, b.create(cpThreadsPerRow))); + auto offsetX = b.create( + iv, b.create( + threadId, b.create(cpThreadsPerRow))); + auto dimM = b.create(outputSubView.getSource(), + b.create(0)); + auto dimN = b.create(outputSubView.getSource(), + b.create(1)); + // TODO: enable mask vector store + auto vec8xf16 = b.create( + VectorType::get({cpElementsPerThread}, elementType), cBlockSmemAlloc, + ValueRange{offsetX, offsetY}); + auto vecStore = b.create(vec8xf16, outputSubView, + ValueRange{offsetX, offsetY}); + vecStore->setAttr("alignment", IntegerAttr::get(b.getI32Type(), 16)); + b.setInsertionPointAfter(smemToGmemLoop); + b.create(); + (void)mlir::loopUnrollByFactor(smemToGmemLoop, blockMShape / cpRowsPerBlock); + genericOp->erase(); + + RewritePatternSet patterns(ctx); + addAllRegisteredCanonicalizationPatterns(patterns); + GreedyRewriteConfig config; + if (failed(applyPatternsAndFoldGreedily(parallelOp, std::move(patterns), + config))) { + return mlir::emitDefiniteFailure(target, + "greedy pattern applicatin failed"); + } + + // Delete any remain transfer_write op + Operation* transferWrite = nullptr; + parallelOp->walk([&](memref::AllocOp alloc) { + if (llvm::hasSingleElement(alloc->getUsers())) { + Operation* op = *(alloc->getUsers().begin()); + auto xWrite = dyn_cast_or_null(op); + if (xWrite) { + transferWrite = op; + } + 
} + }); + if (transferWrite != nullptr) transferWrite->erase(); + + return DiagnosedSilenceableFailure::success(); +} + +//===----------------------------------------------------------------------===// +// DISCGPUSoftwarePipelineOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform_dialect::DISCGPUSoftwarePipeline::applyToOne( + Operation* target, transform::ApplyToEachResultList& results, + transform::TransformState& state) { + MLIRContext* ctx = target->getContext(); + IRRewriter rewriter(ctx); + Location loc = target->getLoc(); + ImplicitLocOpBuilder b(loc, rewriter); + auto funcOp = cast(target); + scf::ForOp ctaKLoop; + target->walk([&](scf::ForOp forOp) { + auto reductionTy = forOp->getAttrOfType("loop-type"); + if (reductionTy && reductionTy.getValue().equals("cta-k-loop") && + isa(forOp->getParentOp())) { + ctaKLoop = forOp; + WalkResult::interrupt(); + } + }); + auto [diag, pipelined] = applyPipelining(ctaKLoop, getDepth(), true); + + if (diag.succeeded()) { + return DiagnosedSilenceableFailure::success(); + } + if (diag.isDefiniteFailure()) { + auto diag = emitDefiniteFailure("irreversible pipelining failure"); + return diag; + } + + return std::move(diag); +} + +//===----------------------------------------------------------------------===// +// DISCConvertNVGPUAsyncCpTONVVMAsyncCp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform_dialect::DISCConvertNVGPUAsyncCpTONVVMAsyncCp::applyToOne( + Operation* target, transform::ApplyToEachResultList& results, + transform::TransformState& state) { + MLIRContext* ctx = target->getContext(); + IRRewriter rewriter(ctx); + Location loc = target->getLoc(); + ImplicitLocOpBuilder b(loc, rewriter); + SmallVector waitGroups; + SmallVector commitGroups; + + target->walk([&](Operation* op) { + if (isa(op)) { + rewriter.setInsertionPoint(op); + auto commitGroup = + rewriter.create(op->getLoc()); + commitGroups.push_back(op); + } else if (isa(op)) { + rewriter.setInsertionPoint(op); + auto waitGroup = rewriter.create( + op->getLoc(), + cast(op).getNumGroups().value_or(0)); + rewriter.setInsertionPointAfter(waitGroup); + waitGroups.push_back(op); + } + }); + for (auto op : waitGroups) { + rewriter.eraseOp(op); + } + for (auto op : commitGroups) { + rewriter.eraseOp(op); + } + return DiagnosedSilenceableFailure::success(); +} + } // namespace transform_dialect void registerTransformDialectCommonExtension(DialectRegistry& registry) { diff --git a/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/TransformOpsExt.h b/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/TransformOpsExt.h index 5c6e45ec261..f177a9fa84f 100644 --- a/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/TransformOpsExt.h +++ b/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/TransformOpsExt.h @@ -9,6 +9,7 @@ #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" @@ -36,6 +37,14 @@ class CommonExtensions public: CommonExtensions(); }; + +/// Pipeline copy to shared memory for matmul op +std::tuple applyPipelining( + scf::ForOp forOp, int64_t depth, bool epiloguePeeling); + +LogicalResult optimizeSharedMemoryReadsAndWrites(Operation* parentOp, + 
Value memrefValue); + } // namespace transform_dialect } // namespace disc_ral } // namespace mlir diff --git a/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/TransformOpsExt.td b/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/TransformOpsExt.td index 697e083a672..6803877c122 100644 --- a/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/TransformOpsExt.td +++ b/tao_compiler/mlir/disc/tools/disc-transform/TransformOps/TransformOpsExt.td @@ -1098,13 +1098,15 @@ def DISCSplitReductionSerialOp : // TODO: support mixed static-dynamic (see TileToForallOp). let arguments = (ins TransformHandleTypeInterface:$target, - DefaultValuedAttr:$tile_sizes); + DefaultValuedAttr:$tile_sizes, + OptionalAttr:$loop_type); let results = (outs TransformHandleTypeInterface:$for_op, TransformHandleTypeInterface:$splitted_op); let assemblyFormat = [{ $target `by` `tile_sizes` `=` $tile_sizes + oilist (`loop_type` `=` $loop_type) attr-dict `:` functional-type(operands, results) }]; @@ -1415,4 +1417,331 @@ def ApplyLoopIndependentCodeMotionOp : Op { + let description = [{ + For a matmul of 'C += A * B', padding along the dimension M and N. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target, + DefaultValuedAttr:$padding_values, + DefaultValuedAttr:$tile_sizes); + let results = (outs TransformHandleTypeInterface:$padding_dot); + + let assemblyFormat = [{ + $target + `padding_values` $padding_values + `tile_sizes` $tile_sizes + attr-dict + `:` functional-type(operands, results) + }]; + + let cppNamespace = "::mlir::disc_ral::transform_dialect"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +def DISCPaddingK : Op { + let description = [{ + For a matmul of 'C += A * B', padding along the dimension K. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target, + DefaultValuedAttr:$padding_values, + DefaultValuedAttr:$tile_sizes); + let results = (outs TransformHandleTypeInterface:$padding_dot); + + let assemblyFormat = [{ + $target + `padding_values` $padding_values + `tile_sizes` $tile_sizes + attr-dict + `:` functional-type(operands, results) + }]; + + let cppNamespace = "::mlir::disc_ral::transform_dialect"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +def DISCSwapAllocTensor : Op { + let description = [{ + Swaps bufferization.alloc_tensor with the copied vector.transfer_write op + when the destination of this write is an empty op + + Example: + + ``` + %empty = tensor.empty() + %val = vector.transfer_read + %write = vector.transfer_write %val, %empty[%c0, %c0] + %alloc = bufferization.alloc_tensor() copy(%write) + ... = ... alloc + + is transformed to: + + ``` + %alloc = bufferization.alloc_tensor() + %read = vector.transfer_read + %write = vector.transfer_write %read, %alloc + ... = ... 
%write + ``` + }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs); + + let assemblyFormat = [{ + $target + attr-dict + `:` functional-type(operands, results) + }]; + + let cppNamespace = "::mlir::disc_ral::transform_dialect"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +def DISCExpandTransferRWToMemrefCopy : Op { + let description = [{ + Expand vetor::Transfer_Read/Write op to a sequence of linalg::FillOp and + Memref::Copy ops. + + Example: + + ``` + %alloc = bufferization.alloc_tensor() + %val = memref.subview ... + %read = vector.transfer_read ... %padding ... %val + vector.transfer_write %read, %alloc + + is transformed to: + + ``` + %alloc = bufferization.alloc_tensor() + linalg.fill %padding, %alloc + %val = memref.subview ... + memref.copy %val, %alloc + ``` + }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs); + + let assemblyFormat = [{ + $target + attr-dict + `:` functional-type(operands, results) + }]; + + let cppNamespace = "::mlir::disc_ral::transform_dialect"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +def DISCMultiBuffering : Op, + TransformEachOpTrait, + TransformOpInterface]> { + let description = [{ + }]; + + let arguments = (ins TransformHandleTypeInterface:$target, + I64Attr:$multi_buffering_factor); + let results = (outs); + + let assemblyFormat = [{ + $target + attr-dict + `by` `multi_buffering_factor` `=` $multi_buffering_factor + `:` functional-type(operands, results) + }]; + + let cppNamespace = "::mlir::disc_ral::transform_dialect"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +def DISCSwizzleShareMemoryOp : Op { + let description = [{ + }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs); + + let assemblyFormat = [{ + $target + attr-dict + `:` functional-type(operands, results) + }]; + + let cppNamespace = "::mlir::disc_ral::transform_dialect"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +def DISCPackSharedMemoryAllocOp : Op, + TransformEachOpTrait, + TransformOpInterface]> { + let description = [{ + }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs); + + let assemblyFormat = [{ + $target + attr-dict + `:` functional-type(operands, results) + }]; + + let cppNamespace = "::mlir::disc_ral::transform_dialect"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +def DISCMoveDataToRegister : Op { + let description = [{ + }]; + + let arguments = (ins TransformHandleTypeInterface:$target, + DenseI64ArrayAttr:$block_mn_shape, + I64Attr:$smem_padding); + let results = (outs); + + let assemblyFormat = [{ + $target + attr-dict + `by` 
`block_mn_shape` `=` $block_mn_shape + `smem_padding` `=` $smem_padding + `:` functional-type(operands, results) + }]; + + let cppNamespace = "::mlir::disc_ral::transform_dialect"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +def DISCGPUSoftwarePipeline : Op { + let description = [{ + }]; + + let arguments = (ins TransformHandleTypeInterface:$target, + I64Attr:$depth); + let results = (outs); + + let assemblyFormat = [{ + $target + attr-dict + `by` `depth` `=` $depth + `:` functional-type(operands, results) + }]; + + let cppNamespace = "::mlir::disc_ral::transform_dialect"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + +def DISCConvertNVGPUAsyncCpTONVVMAsyncCp : Op { + let description = [{ + }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs); + + let assemblyFormat = [{ + $target + attr-dict + `:` functional-type(operands, results) + }]; + + let cppNamespace = "::mlir::disc_ral::transform_dialect"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} +#endif // DISC_TRANSFORM_OPS_EXT \ No newline at end of file diff --git a/tao_compiler/mlir/disc/transforms/disc_lower_gpu_ops_to_nvvm_ops.cc b/tao_compiler/mlir/disc/transforms/disc_lower_gpu_ops_to_nvvm_ops.cc index 9bf3aa597b2..8847331e0e1 100644 --- a/tao_compiler/mlir/disc/transforms/disc_lower_gpu_ops_to_nvvm_ops.cc +++ b/tao_compiler/mlir/disc/transforms/disc_lower_gpu_ops_to_nvvm_ops.cc @@ -43,6 +43,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/MemRef/Transforms/Transforms.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/IRMapping.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Transforms/DialectConversion.h" @@ -62,6 +63,47 @@ namespace { /// Import the GPU Ops to NVVM Patterns. #include "GPUToNVVM.cpp.inc" +/// Conversion vector.store with align attribute to llvm.store +class VectorStoreWithAlignToLLVMPattern + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult matchAndRewrite( + vector::StoreOp storeOp, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + // Only 1-D vectors can be lowered to LLVM. + VectorType vectorTy = storeOp.getVectorType(); + if (vectorTy.getRank() > 1) return failure(); + auto alignAttr = storeOp->getAttrOfType("alignment"); + if (!alignAttr) return failure(); + unsigned align = alignAttr.getInt(); + + auto loc = storeOp->getLoc(); + MemRefType memRefTy = storeOp.getMemRefType(); + + // Resolve address. + auto vtype = cast( + this->typeConverter->convertType(storeOp.getVectorType())); + Value dataPtr = this->getStridedElementPtr(loc, memRefTy, adaptor.getBase(), + adaptor.getIndices(), rewriter); + // Casts a strided element pointer to a vector pointer. The vector pointer + // will be in the same address space as the incoming memref type. 
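+    // Net effect: a `vector.store` of a `vector<8xf16>` value tagged with
+    // `alignment = 16` is rewritten to an `llvm.store` with alignment 16,
+    // which the NVPTX backend can emit as a single 128-bit global store.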
+ Value ptr; + if ((*this->getTypeConverter()).useOpaquePointers()) { + ptr = dataPtr; + } else { + unsigned addressSpace = + *(*this->getTypeConverter()).getMemRefAddressSpace(memRefTy); + auto pType = LLVM::LLVMPointerType::get(vtype, addressSpace); + ptr = rewriter.create(loc, pType, dataPtr); + } + + rewriter.replaceOpWithNewOp( + storeOp, adaptor.getValueToStore(), ptr, align); + return success(); + } +}; + /// A pass that replaces all occurrences of GPU device operations with their /// corresponding NVVM equivalent. /// @@ -125,6 +167,8 @@ struct DiscLowerGpuOpsToNVVMOpsPass llvmPatterns.add( converter, /* PatternBenefit */ 3); llvmPatterns.add(converter); + llvmPatterns.add(converter, + /* PatternBenefit */ 3); arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns); cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns); populateVectorToLLVMConversionPatterns(converter, llvmPatterns); diff --git a/tao_compiler/mlir/disc/transforms/disc_transform_schedule.cc b/tao_compiler/mlir/disc/transforms/disc_transform_schedule.cc index 95cad8156a1..120fde4ba78 100644 --- a/tao_compiler/mlir/disc/transforms/disc_transform_schedule.cc +++ b/tao_compiler/mlir/disc/transforms/disc_transform_schedule.cc @@ -394,11 +394,18 @@ transform_dialect::DISCPromoteDotOperandsOp buildPromoteDotOperandsOp( } transform_dialect::DISCSplitReductionSerialOp buildSplitReductionSerialOp( - OpBuilder& b, Location& loc, Value target, ArrayRef tileSizes) { + OpBuilder& b, Location& loc, Value target, ArrayRef tileSizes, + StringAttr loopType = nullptr) { SmallVector transformOpTypes(2, transform::AnyOpType::get(b.getContext())); - return b.create( - loc, transformOpTypes, target, tileSizes); + if (!loopType) { + return b.create( + loc, transformOpTypes, target, tileSizes, + StringAttr::get(b.getContext(), "")); + } else { + return b.create( + loc, transformOpTypes, target, tileSizes, loopType); + } } transform_dialect::DISCVectorToMMAConversionOp buildVectorToMMAConversionOp( diff --git a/tao_compiler/mlir/disc/transforms/revise_kernel_outlining.cc b/tao_compiler/mlir/disc/transforms/revise_kernel_outlining.cc index 4cd9dba69fb..6a47f3827d7 100644 --- a/tao_compiler/mlir/disc/transforms/revise_kernel_outlining.cc +++ b/tao_compiler/mlir/disc/transforms/revise_kernel_outlining.cc @@ -362,6 +362,13 @@ void convertWorkgroupBuffer(gpu::GPUFuncOp gpu_func_op, AllocOp alloc) { alloc.erase(); } +void convertPrivateBuffer(gpu::GPUFuncOp gpu_func_op, AllocOp alloc) { + auto memref_type = alloc.getResult().getType().cast(); + auto buffer = gpu_func_op.addPrivateAttribution(memref_type, alloc.getLoc()); + alloc.replaceAllUsesWith(buffer); + alloc.erase(); +} + /* This pass revises the kernel outlining: * * 1, For a MemRef resides in host memory, which always means that the MemRef @@ -418,7 +425,7 @@ class ReviseGpuKernelOutliningPass } } - // convert for shared buffer + // convert for shared/private buffer module.walk([&](gpu::LaunchFuncOp launch_func_op) { auto gpu_module = module.lookupSymbol( launch_func_op.getKernelModuleName()); @@ -428,12 +435,16 @@ class ReviseGpuKernelOutliningPass assert(gpu_func_op && "gpu_func_op is empty"); gpu_func_op.walk([&](AllocOp alloc) { auto memref_type = alloc.getResult().getType().cast(); - assert(memref_type.getMemorySpace() - .dyn_cast() - .getValue() == - gpu::GPUDialect::getWorkgroupAddressSpace() && - "unexpected alloc op in gpu_func_op"); - convertWorkgroupBuffer(gpu_func_op, alloc); + gpu::AddressSpace addressSpace = memref_type.getMemorySpace() + 
.dyn_cast() + .getValue(); + if (addressSpace == gpu::GPUDialect::getWorkgroupAddressSpace()) { + convertWorkgroupBuffer(gpu_func_op, alloc); + } else if (addressSpace == gpu::GPUDialect::getPrivateAddressSpace()) { + convertPrivateBuffer(gpu_func_op, alloc); + } else { + llvm_unreachable("unexpected alloc op in gpu_func_op"); + } }); }); }