From 993249f8f02b368bad165688715738c83646bae8 Mon Sep 17 00:00:00 2001
From: Manupa Karunaratne
Date: Fri, 3 Jan 2025 03:55:15 -0800
Subject: [PATCH] [LLVMGPUVectorDistribute] Add support for inter-subgroup multi_reduction

This commit adds support for distributing multi_reductions where the
reduction dimension(s) is/are distributed across subgroups.

We perform the existing reduction distribution; however, this leaves us
with partial reductions across subgroups. Thereafter, we insert a
transfer_write / transfer_read to shared memory to achieve a layout
change that re-distributes the reduction subgroup tiles onto the element
tile. Finally, we perform another multi_reduction to complete the
reduction.

Signed-off-by: Manupa Karunaratne
---
 .../GPUNestedLayoutDistributionPatterns.cpp   | 232 +++++++++++++++---
 .../Common/GPU/GPUVectorDistribution.cpp      |   1 +
 .../Codegen/Common/GPU/test/BUILD.bazel       |   1 +
 .../Codegen/Common/GPU/test/CMakeLists.txt    |   1 +
 ...gpu_nested_layout_vector_distribution.mlir | 118 ---------
 ...yout_vector_distribution_multi_reduce.mlir | 180 ++++++++++++++
 6 files changed, 385 insertions(+), 148 deletions(-)
 create mode 100644 compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution_multi_reduce.mlir

diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
index 0b4d812def17..d129e9ad2d5d 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp
@@ -414,8 +414,9 @@ static int64_t getShuffleWidth(NestedLayoutAttr layout, int64_t dim) {
 /// by doing a butterfly shuffle.
 /// 3. Accumulator Reduce: Each thread reduces it's intermediate reduced
 /// results with the accumulator it holds.
-/// Currently, reduction across warps is not supported, but it would just add
-/// another step, Warp Reduce, where threads do an atomic addition on a buffer.
+/// 4. Subgroup Reduce: each subgroup stores its partial reductions to shared
+/// memory; they are then reloaded with a layout that places the partial
+/// reductions within threads.
 struct DistributeMultiReduction final
     : OpDistributionPattern<vector::MultiDimReductionOp> {
   using OpDistributionPattern::OpDistributionPattern;
@@ -460,7 +461,6 @@ struct DistributeMultiReduction final
     }
 
     Location loc = multiReduceOp.getLoc();
-
    SmallVector<bool> reducedDims = multiReduceOp.getReductionMask();
    int64_t rank = srcVector.getType().getRank();
 
@@ -492,25 +492,34 @@ struct DistributeMultiReduction final
     assert(locallyReduced && "result should have been a vector");
 
     // Flatten the locally reduced value.
+    VectorValue threadReduced;
     VectorType shaped = locallyReduced.getType();
-    int64_t numElements = shaped.getNumElements();
-    SmallVector<int64_t> flatShape(1, numElements);
-    VectorType flatVecType = VectorType::get(flatShape, elemTy);
-    VectorValue flat =
-        rewriter.create<vector::ShapeCastOp>(loc, flatVecType, locallyReduced);
+    bool hasThreadReductions =
+        llvm::any_of(multiReduceOp.getReductionDims(), [&](int64_t rDim) {
+          return srcLayout.getThreadTile()[rDim] > 1;
+        });
+    if (hasThreadReductions) {
+      int64_t numElements = shaped.getNumElements();
+      SmallVector<int64_t> flatShape(1, numElements);
+      VectorType flatVecType = VectorType::get(flatShape, elemTy);
+      VectorValue flat = rewriter.create<vector::ShapeCastOp>(loc, flatVecType,
+                                                              locallyReduced);
+
+      // Do inter-thread/warp reduce.
+ FailureOr threadReducedFlat = doThreadReduction( + rewriter, srcLayout, flat, multiReduceOp.getKind(), reducedDims); + if (failed(threadReducedFlat)) { + return failure(); + } - // Do inter-thread/warp reduce. - FailureOr threadReduced = doThreadReduction( - rewriter, srcLayout, flat, multiReduceOp.getKind(), reducedDims); - if (failed(threadReduced)) { - return failure(); + // Do reduction against accumulator, which needs to be done after thread + // reduction. + threadReduced = rewriter.create( + loc, shaped, threadReducedFlat.value()); + } else { + threadReduced = locallyReduced; } - // Do reduction against accumulator, which needs to be done after thread - // reduction. - VectorValue unflattened = rewriter.create( - loc, shaped, threadReduced.value()); - if (!accVector) { // Broadcast the scalar (e.g., f32) to a vector type (e.g., vector) // because the following implementation requires the operand to be a @@ -518,21 +527,184 @@ struct DistributeMultiReduction final disAcc = rewriter.create(loc, shaped, disAcc); } - Value accReduction = vector::makeArithReduction( - rewriter, loc, multiReduceOp.getKind(), unflattened, disAcc); - auto accReduced = dyn_cast(accReduction); - if (!accReduced) { - return failure(); + bool hasSubgroupReductions = + llvm::any_of(multiReduceOp.getReductionDims(), [&](int64_t rDim) { + return srcLayout.getSubgroupTile()[rDim] > 1; + }); + // We can exit here if its just a subgroup reduction. + if (!hasSubgroupReductions) { + Value accReduction = vector::makeArithReduction( + rewriter, loc, multiReduceOp.getKind(), threadReduced, disAcc); + auto accReduced = dyn_cast(accReduction); + if (!accReduced) { + return failure(); + } + if (resVector) { + replaceOpWithDistributedValues(rewriter, multiReduceOp, accReduced); + } else { + Value accReducedVal = rewriter.create( + loc, accReduction, ArrayRef{int64_t(0)}); + replaceOpWithDistributedValues(rewriter, multiReduceOp, accReducedVal); + } + return success(); } - - if (resVector) { - replaceOpWithDistributedValues(rewriter, multiReduceOp, accReduced); - } else { - Value accReducedVal = rewriter.create( - loc, accReduction, ArrayRef{int64_t(0)}); - replaceOpWithDistributedValues(rewriter, multiReduceOp, accReducedVal); + // If there is reduction across subgroups + // then we need relayout the partial reductions + // to be within a subgroup for further reduction. + + // Shapecast to re-insert reduced rank as unit dimensions. + SmallVector partialReducedDistributedShape = + srcLayout.getDistributedShape(); + for (int64_t tileGroupIdx : llvm::seq(3)) { + int64_t tileGroupOffset = tileGroupIdx * rank; + for (int64_t rDim : multiReduceOp.getReductionDims()) { + partialReducedDistributedShape[tileGroupOffset + rDim] = 1; + } + } + VectorType partialReducedDistributedType = VectorType::get( + partialReducedDistributedShape, srcVector.getType().getElementType()); + Value isoRankThreadReduced = rewriter.create( + loc, partialReducedDistributedType, threadReduced); + + SmallVector preDistrShape = + srcLayout.getUndistributedPackedShape(); + SmallVector partialReductionShape = + llvm::to_vector(srcVector.getType().getShape()); + for (int64_t rDim : multiReduceOp.getReductionDims()) { + // The first #rank elements will form the subgroup tile + // Here we replace the input shape with subgroup tile + // because every other tile is reduced except the subgroup + // tile. 
+ partialReductionShape[rDim] = preDistrShape[rDim]; + } + auto workgroupMemoryAddressSpace = Attribute(gpu::AddressSpaceAttr::get( + rewriter.getContext(), gpu::AddressSpace::Workgroup)); + MemRefType allocType = MemRefType::get( + partialReductionShape, srcVector.getType().getElementType(), + AffineMap(), workgroupMemoryAddressSpace); + auto alloc = rewriter.create(loc, allocType); + VectorType unDistributedType = VectorType::get( + partialReductionShape, srcVector.getType().getElementType()); + Value undistrWrite = rewriter.create( + loc, unDistributedType, isoRankThreadReduced); + Value c0 = rewriter.create(loc, 0); + SmallVector indices(unDistributedType.getRank(), c0); + SmallVector inBounds(unDistributedType.getRank(), true); + // Insert gpu.barrier to make sure previuos iteration + // of batch loop has fully read the subgroup partial + // reductions. + rewriter.create(multiReduceOp.getLoc()); + auto write = rewriter.create( + loc, undistrWrite, alloc, indices, inBounds); + // Set layouts signature for write. + // We need to set the layout on the srcVector/first operand. + auto unitAttr = UnitAttr::get(rewriter.getContext()); + { + SmallVector subgroupTileLens = + llvm::to_vector(srcLayout.getSubgroupTile()); + SmallVector batchTileLens = + llvm::to_vector(srcLayout.getBatchTile()); + SmallVector outerTileLens = + llvm::to_vector(srcLayout.getOuterTile()); + SmallVector threadTileLens = + llvm::to_vector(srcLayout.getThreadTile()); + SmallVector elementTileLens = + llvm::to_vector(srcLayout.getElementTile()); + SmallVector subgroupStrides = + llvm::to_vector(srcLayout.getSubgroupStrides()); + SmallVector threadStrides = + llvm::to_vector(srcLayout.getThreadStrides()); + // Replace the reduced tiles with unit dimension. + for (int64_t rDim : multiReduceOp.getReductionDims()) { + batchTileLens[rDim] = 1; + outerTileLens[rDim] = 1; + threadTileLens[rDim] = 1; + elementTileLens[rDim] = 1; + threadStrides[rDim] = 0; + } + auto interSubGroupLayout = IREE::VectorExt::NestedLayoutAttr::get( + multiReduceOp.getContext(), subgroupTileLens, batchTileLens, + outerTileLens, threadTileLens, elementTileLens, subgroupStrides, + threadStrides); + auto writeAttrs = + SmallVector(write->getNumOperands(), unitAttr); + writeAttrs[0] = interSubGroupLayout; + ArrayAttr writeOperandsAttr = + ArrayAttr::get(rewriter.getContext(), writeAttrs); + ArrayAttr writeResultsAttr = ArrayAttr::get(rewriter.getContext(), {}); + setSignatureForRedistribution(rewriter, write.getOperation(), + writeOperandsAttr, writeResultsAttr); + } + // Insert gpu.barrier + rewriter.create(write.getLoc()); + auto read = rewriter.create(loc, unDistributedType, + alloc, indices); + // Create new layout where subgroup dims are squashed to + // element tile + IREE::VectorExt::NestedLayoutAttr intraSubGroupLayout; + { + // We intentionally make the subgroup tile to be 1 + SmallVector subgroupTileLens = + llvm::to_vector(srcLayout.getSubgroupTile()); + SmallVector batchTileLens = + llvm::to_vector(srcLayout.getBatchTile()); + SmallVector outerTileLens = + llvm::to_vector(srcLayout.getOuterTile()); + SmallVector threadTileLens = + llvm::to_vector(srcLayout.getThreadTile()); + SmallVector elementTileLens = + llvm::to_vector(srcLayout.getElementTile()); + SmallVector subgroupStrides = + llvm::to_vector(srcLayout.getSubgroupStrides()); + SmallVector threadStrides = + llvm::to_vector(srcLayout.getThreadStrides()); + for (int64_t rDim : multiReduceOp.getReductionDims()) { + subgroupTileLens[rDim] = 1; + batchTileLens[rDim] = 1; + 
outerTileLens[rDim] = 1; + threadTileLens[rDim] = 1; + // the partial reductions that was across subgroups will + // will be loaded as element tile. We can revisit if this + // need to be something else such as thread tile. + elementTileLens[rDim] = srcLayout.getSubgroupTile()[rDim]; + subgroupStrides[rDim] = 0; + threadStrides[rDim] = 0; + } + intraSubGroupLayout = IREE::VectorExt::NestedLayoutAttr::get( + multiReduceOp.getContext(), subgroupTileLens, batchTileLens, + outerTileLens, threadTileLens, elementTileLens, subgroupStrides, + threadStrides); + auto readAttrs = SmallVector(read->getNumOperands(), unitAttr); + ArrayAttr readOperandsAttr = + ArrayAttr::get(rewriter.getContext(), readAttrs); + ArrayAttr readResultsAttr = + ArrayAttr::get(rewriter.getContext(), {intraSubGroupLayout}); + setSignatureForRedistribution(rewriter, read.getOperation(), + readOperandsAttr, readResultsAttr); } + // A newly created reduction to complete the reduction + // that reduces the data that was otherwise was on + // different subgroups. + auto secondReduction = rewriter.create( + loc, read, acc, reducedDims, multiReduceOp.getKind()); + { + auto reduceAttrs = + SmallVector(secondReduction->getNumOperands(), unitAttr); + reduceAttrs[0] = intraSubGroupLayout; + ArrayAttr reduceResultsAttr = + ArrayAttr::get(rewriter.getContext(), {unitAttr}); + if (auto dstLayout = + dyn_cast_or_null(signature[resVector])) { + reduceAttrs[1] = dstLayout; + reduceResultsAttr = ArrayAttr::get(rewriter.getContext(), {dstLayout}); + } + ArrayAttr reduceOperandsAttr = + ArrayAttr::get(rewriter.getContext(), reduceAttrs); + setSignatureForRedistribution(rewriter, secondReduction.getOperation(), + reduceOperandsAttr, reduceResultsAttr); + } + rewriter.replaceOp(multiReduceOp, {secondReduction.getResult()}); return success(); } diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp index 2d9bfefd9248..f36a06f33af5 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp @@ -217,6 +217,7 @@ struct VectorDistributionListener : public RewriterBase::Listener { void notifyOperationModified(Operation *op) override { if (op->hasAttr(kVectorLayoutRedistributeAttrName) && op->hasAttrOfType(kVectorLayoutFetcherStorageAttrName)) { + op->removeAttr(kVectorLayoutRedistributeAttrName); toBeDistributed.push_back(op); } } diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel index 2f3b092d5676..1bddbd79eb8f 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel @@ -34,6 +34,7 @@ iree_lit_test_suite( "gpu_lower_to_ukernels.mlir", "gpu_nested_layout_contract_amdgpu.mlir", "gpu_nested_layout_vector_distribution.mlir", + "gpu_nested_layout_vector_distribution_multi_reduce.mlir", "gpu_nested_layout_vector_distribution_step.mlir", "gpu_pad_operands.mlir", "gpu_pipeline.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt index 50be391693cc..512645298fd5 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt @@ -29,6 +29,7 @@ iree_lit_test_suite( "gpu_lower_to_ukernels.mlir" 
"gpu_nested_layout_contract_amdgpu.mlir" "gpu_nested_layout_vector_distribution.mlir" + "gpu_nested_layout_vector_distribution_multi_reduce.mlir" "gpu_nested_layout_vector_distribution_step.mlir" "gpu_pack_to_instrinsics.mlir" "gpu_pad_operands.mlir" diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir index eb2a853c23ad..3e9f15513df2 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution.mlir @@ -1035,124 +1035,6 @@ builtin.module attributes { transform.with_named_sequence } { // ----- -#nested = #iree_vector_ext.nested_layout< - subgroup_tile = [1, 1], - // We are reducing along dim=1, so each thread will reduce - // 2 batches x 4 elements = 8 elements. - batch_tile = [2, 2], - outer_tile = [1, 1], - // We are reducing on dim=1, which is distributed over 4 threads. Based - // on the subgroup basis and thread order, the shuffle offset is 16. - thread_tile = [16, 4], - element_tile = [1, 4], - - subgroup_strides = [1, 1], - thread_strides = [1, 16] -> - -func.func @mfma_16x16x16_out_reduced_dim1(%arg0: vector<32x32xf32>, %arg1: vector<32xf32>) -> vector<32xf32> { - %arg0l = iree_vector_ext.to_layout %arg0 to layout(#nested) : vector<32x32xf32> - %0 = vector.multi_reduction , %arg0l, %arg1 [1] : vector<32x32xf32> to vector<32xf32> - return %0 : vector<32xf32> -} - -builtin.module attributes { transform.with_named_sequence } { - transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) { - %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op - transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op - transform.yield - } -} - -// CHECK-LABEL: func @mfma_16x16x16_out_reduced_dim1 -// CHECK-DAG: %[[IDENTITY:.*]] = arith.constant dense<0xFF800000> : vector<2x1x1xf32> -// CHECK-DAG: %[[DARG0:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<32x32xf32> -> vector<2x2x1x1x1x4xf32> -// CHECK-DAG: %[[DARG1:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<32xf32> -> vector<2x1x1xf32> -// Local reduction -// CHECK: vector.multi_reduction , %[[DARG0]], %[[IDENTITY]] [1, 3, 5] : vector<2x2x1x1x1x4xf32> to vector<2x1x1xf32> -// Global reduction -// CHECK: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 4, stride = 16) : (f32) -> f32 -// Accumulator reduction -// CHECK: %[[ACC_REDUC:.+]] = arith.maximumf %{{.*}}, %[[DARG1]] : vector<2x1x1xf32> -// CHECK: iree_vector_ext.to_simd %[[ACC_REDUC]] : vector<2x1x1xf32> -> vector<32xf32> - -// ----- - -#nested = #iree_vector_ext.nested_layout< - subgroup_tile = [1, 1], - // We are reducing along dim=1, so each thread will reduce - // 4 batches x 4 elements = 16 elements. - batch_tile = [1, 4], - outer_tile = [1, 1], - // We are reducing on dim=1, which is distributed over 2 threads. Based - // on the subgroup basis and thread order, the shuffle offset is 32. 
- thread_tile = [32, 2], - element_tile = [1, 4], - - subgroup_strides = [1, 1], - thread_strides = [1, 32] -> - -func.func @mfma_32x32x8_out_reduced_dim1(%arg0: vector<32x32xf32>, %arg1: vector<32xf32>) -> vector<32xf32> { - %arg0l = iree_vector_ext.to_layout %arg0 to layout(#nested) : vector<32x32xf32> - %0 = vector.multi_reduction , %arg0l, %arg1 [1] : vector<32x32xf32> to vector<32xf32> - return %0 : vector<32xf32> -} - -builtin.module attributes { transform.with_named_sequence } { - transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) { - %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op - transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op - transform.yield - } -} - -// CHECK-LABEL: func @mfma_32x32x8_out_reduced_dim1 -// Local reduction -// CHECK: vector.multi_reduction , %{{.*}}, %{{.*}} [1, 3, 5] : vector<1x4x1x1x1x4xf32> to vector<1x1x1xf32> -// Global reduction -// CHECK: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 2, stride = 32) : (f32) -> f32 -// Accumulator reduction -// CHECK: arith.maximumf %{{.*}}, %{{.*}} : vector<1x1x1xf32> - -// ----- - -#nested = #iree_vector_ext.nested_layout< - subgroup_tile = [1, 1], - batch_tile = [2, 2], - outer_tile = [1, 1], - thread_tile = [16, 4], - element_tile = [1, 4], - - subgroup_strides = [1, 1], - thread_strides = [1, 16] -> - -func.func @mfma_16x16x16_out_reduced_alldims(%arg0: vector<32x32xf32>, %arg1: f32) -> f32 { - %arg0l = iree_vector_ext.to_layout %arg0 to layout(#nested) : vector<32x32xf32> - %0 = vector.multi_reduction , %arg0l, %arg1 [0, 1] : vector<32x32xf32> to f32 - return %0 : f32 -} - -builtin.module attributes { transform.with_named_sequence } { - transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) { - %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op - transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op - transform.yield - } -} - -// CHECK-LABEL: func @mfma_16x16x16_out_reduced_alldims -// Local reduction -// CHECK: vector.multi_reduction , %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5] : vector<2x2x1x1x1x4xf32> to f32 -// Global reduction -// CHECK: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 16) : (f32) -> f32 -// CHECK-NEXT: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 4, stride = 16) : (f32) -> f32 -// Accumulator reduction -// CHECK: arith.maximumf %{{.*}}, %{{.*}} : vector<1xf32> - -// ----- - #layout = #iree_vector_ext.nested_layout< subgroup_tile = [1, 1], batch_tile = [2, 2], diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution_multi_reduce.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution_multi_reduce.mlir new file mode 100644 index 000000000000..2559f1042f90 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_nested_layout_vector_distribution_multi_reduce.mlir @@ -0,0 +1,180 @@ +// RUN: iree-opt --iree-transform-dialect-interpreter --split-input-file --canonicalize -mlir-print-local-scope --cse %s | FileCheck %s + +#nested = #iree_vector_ext.nested_layout< + subgroup_tile = [1, 1], + // We are reducing along dim=1, so each thread will reduce + // 2 batches x 4 elements = 8 elements. + batch_tile = [2, 2], + outer_tile = [1, 1], + // We are reducing on dim=1, which is distributed over 4 threads. 
Based + // on the subgroup basis and thread order, the shuffle offset is 16. + thread_tile = [16, 4], + element_tile = [1, 4], + + subgroup_strides = [1, 1], + thread_strides = [1, 16] +> + +func.func @mfma_16x16x16_out_reduced_dim1(%arg0: vector<32x32xf32>, %arg1: vector<32xf32>) -> vector<32xf32> { + %arg0l = iree_vector_ext.to_layout %arg0 to layout(#nested) : vector<32x32xf32> + %0 = vector.multi_reduction , %arg0l, %arg1 [1] : vector<32x32xf32> to vector<32xf32> + return %0 : vector<32xf32> +} + +builtin.module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) { + %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op + transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op + transform.yield + } +} + +// CHECK-LABEL: func @mfma_16x16x16_out_reduced_dim1 +// CHECK-DAG: %[[IDENTITY:.*]] = arith.constant dense<0xFF800000> : vector<2x1x1xf32> +// CHECK-DAG: %[[DARG0:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<32x32xf32> -> vector<2x2x1x1x1x4xf32> +// CHECK-DAG: %[[DARG1:.*]] = iree_vector_ext.to_simt %{{.*}} : vector<32xf32> -> vector<2x1x1xf32> +// Local reduction +// CHECK: vector.multi_reduction , %[[DARG0]], %[[IDENTITY]] [1, 3, 5] : vector<2x2x1x1x1x4xf32> to vector<2x1x1xf32> +// Global reduction +// CHECK: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 4, stride = 16) : (f32) -> f32 +// Accumulator reduction +// CHECK: %[[ACC_REDUC:.+]] = arith.maximumf %{{.*}}, %[[DARG1]] : vector<2x1x1xf32> +// CHECK: iree_vector_ext.to_simd %[[ACC_REDUC]] : vector<2x1x1xf32> -> vector<32xf32> + +// ----- + +#nested = #iree_vector_ext.nested_layout< + subgroup_tile = [1, 1], + // We are reducing along dim=1, so each thread will reduce + // 4 batches x 4 elements = 16 elements. + batch_tile = [1, 4], + outer_tile = [1, 1], + // We are reducing on dim=1, which is distributed over 2 threads. Based + // on the subgroup basis and thread order, the shuffle offset is 32. 
+ thread_tile = [32, 2], + element_tile = [1, 4], + + subgroup_strides = [1, 1], + thread_strides = [1, 32] +> + +func.func @mfma_32x32x8_out_reduced_dim1(%arg0: vector<32x32xf32>, %arg1: vector<32xf32>) -> vector<32xf32> { + %arg0l = iree_vector_ext.to_layout %arg0 to layout(#nested) : vector<32x32xf32> + %0 = vector.multi_reduction , %arg0l, %arg1 [1] : vector<32x32xf32> to vector<32xf32> + return %0 : vector<32xf32> +} + +builtin.module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) { + %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op + transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op + transform.yield + } +} + +// CHECK-LABEL: func @mfma_32x32x8_out_reduced_dim1 +// Local reduction +// CHECK: vector.multi_reduction , %{{.*}}, %{{.*}} [1, 3, 5] : vector<1x4x1x1x1x4xf32> to vector<1x1x1xf32> +// Global reduction +// CHECK: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 2, stride = 32) : (f32) -> f32 +// Accumulator reduction +// CHECK: arith.maximumf %{{.*}}, %{{.*}} : vector<1x1x1xf32> + +// ----- + +#nested = #iree_vector_ext.nested_layout< + subgroup_tile = [1, 1], + batch_tile = [2, 2], + outer_tile = [1, 1], + thread_tile = [16, 4], + element_tile = [1, 4], + + subgroup_strides = [1, 1], + thread_strides = [1, 16] +> + +func.func @mfma_16x16x16_out_reduced_alldims(%arg0: vector<32x32xf32>, %arg1: f32) -> f32 { + %arg0l = iree_vector_ext.to_layout %arg0 to layout(#nested) : vector<32x32xf32> + %0 = vector.multi_reduction , %arg0l, %arg1 [0, 1] : vector<32x32xf32> to f32 + return %0 : f32 +} + +builtin.module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) { + %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op + transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op + transform.yield + } +} + +// CHECK-LABEL: func @mfma_16x16x16_out_reduced_alldims +// Local reduction +// CHECK: vector.multi_reduction , %{{.*}}, %{{.*}} [0, 1, 2, 3, 4, 5] : vector<2x2x1x1x1x4xf32> to f32 +// Global reduction +// CHECK: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 16) : (f32) -> f32 +// CHECK-NEXT: gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 4, stride = 16) : (f32) -> f32 +// Accumulator reduction +// CHECK: arith.maximumf %{{.*}}, %{{.*}} : vector<1xf32> + +// ----- + +#nested = #iree_vector_ext.nested_layout< + // There will two partial reductions across + // two subgroups. + subgroup_tile = [1, 2], + // We are reducing along dim=1, so each thread will reduce + // 1 batches x 4 elements = 4 elements. + batch_tile = [2, 1], + outer_tile = [1, 1], + // We are reducing on dim=1, which is distributed over 4 threads. Based + // on the subgroup basis and thread order, the shuffle offset is 16. 
+ thread_tile = [16, 4], + element_tile = [1, 4], + + subgroup_strides = [2, 1], + thread_strides = [1, 16] +> + +func.func @inter_subgroup_reduction(%arg0: vector<32x32xf32>, %arg1: vector<32xf32>) -> vector<32xf32> { + %arg0l = iree_vector_ext.to_layout %arg0 to layout(#nested) : vector<32x32xf32> + %0 = vector.multi_reduction , %arg0l, %arg1 [1] : vector<32x32xf32> to vector<32xf32> + return %0 : vector<32xf32> +} + +builtin.module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%variant_op: !transform.any_op {transform.readonly}) { + %top_level_func = transform.structured.match ops{["func.func"]} in %variant_op : (!transform.any_op) -> !transform.any_op + transform.iree.test_gpu_vector_distribution %top_level_func : !transform.any_op + transform.yield + } +} + +// CHECK-LABEL: func @inter_subgroup_reduction +// CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<2x1x1x1x1x2xf32> +// CHECK-DAG: %[[CST1:.+]] = arith.constant dense<0.000000e+00> : vector<2xf32> +// Local reduction +// CHECK: vector.multi_reduction , %{{.*}}, %{{.*}} [1, 3, 5] : vector<2x1x1x1x1x4xf32> to vector<2x1x1xf32> +// Thread reduction +// CHECK: %[[THREAD_RED0:.+]] = gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 4, stride = 16) : (f32) -> f32 +// CHECK: %[[THREAD_RED1:.+]] = vector.insert %[[THREAD_RED0]], %cst_1 [0] : f32 into vector<2xf32> +// CHECK: %[[THREAD_RED2:.+]] = gpu.subgroup_reduce maximumf %{{.*}} cluster(size = 4, stride = 16) : (f32) -> f32 +// CHECK: %[[THREAD_RED3:.+]] = vector.insert %[[THREAD_RED2]], %[[THREAD_RED1]] [1] : f32 into vector<2xf32> +// CHECK: %[[THREAD_RED4:.+]] = vector.shape_cast %[[THREAD_RED3]] : vector<2xf32> to vector<2x1x1xf32> +// Subgroup reduction +// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<32x2xf32, #gpu.address_space> +// CHECK: gpu.barrier +// CHECK-DAG: %[[TIDX0:.+]] = affine.apply affine_map<()[s0] -> (s0 mod 16)>()[%thread_id_x] +// CHECK-DAG: %[[TIDX1:.+]] = affine.apply affine_map<()[s0] -> (s0 mod 16 + 16)>()[%thread_id_x] +// CHECK-DAG: %[[SGIDX:.+]] = affine.apply affine_map<()[s0] -> ((s0 floordiv 64) mod 2)>()[%thread_id_x] +// CHECK-DAG: %[[EXTRACT0:.+]] = vector.extract %[[THREAD_RED4]][0] : vector<1x1xf32> from vector<2x1x1xf32> +// CHECK-DAG: %[[EXTRACT1:.+]] = vector.extract %[[THREAD_RED4]][1] : vector<1x1xf32> from vector<2x1x1xf32> +// CHECK-DAG: vector.transfer_write %[[EXTRACT0]], %[[ALLOC]][%[[TIDX0]], %[[SGIDX]]] +// CHECK-DAG: vector.transfer_write %[[EXTRACT1]], %[[ALLOC]][%[[TIDX1]], %[[SGIDX]]] +// CHECK: gpu.barrier +// CHECK-DAG: %[[READ0:.+]] = vector.transfer_read %alloc[%[[TIDX0]], %c0], {{.*}} {in_bounds = [false, true]} : memref<32x2xf32, #gpu.address_space>, vector<1x2xf32> +// CHECK-DAG: %[[GATHER0:.+]] = vector.insert_strided_slice %[[READ0]], %[[CST]] {offsets = [0, 0, 0, 0, 0, 0], strides = [1, 1]} : vector<1x2xf32> into vector<2x1x1x1x1x2xf32> +// CHECK-DAG: %[[READ1:.+]] = vector.transfer_read %alloc[%[[TIDX1]], %c0], %cst_0 {in_bounds = [false, true]} : memref<32x2xf32, #gpu.address_space>, vector<1x2xf32> +// CHECK-DAG: %[[GATHER1:.+]] = vector.insert_strided_slice %[[READ1]], %[[GATHER0]] {offsets = [1, 0, 0, 0, 0, 0], strides = [1, 1]} : vector<1x2xf32> into vector<2x1x1x1x1x2xf32> +// CHECK-DAG: %[[ACC:.+]] = iree_vector_ext.to_simt %arg1 : vector<32xf32> -> vector<2x1x1xf32> +// CHECK-DAG: %[[SGRED:.+]] = vector.multi_reduction , %[[GATHER1]], {{.*}} [1, 3, 5] : vector<2x1x1x1x1x2xf32> to vector<2x1x1xf32> +// CHECK-DAG: arith.maximumf %[[SGRED]], 
%[[ACC]] : vector<2x1x1xf32>
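
For reference, below is a minimal, undistributed sketch of the op sequence the new inter-subgroup step builds, before the staging ops are themselves distributed: the per-subgroup partial reductions are written to a workgroup allocation, barriers order the accesses against the surrounding batch loop, the tile is reloaded so the reduced subgroup dimension lands on the element tile, and a second multi_reduction folds it into the accumulator. The function name, the 32x2 staging shape, and the maximumf kind are illustrative assumptions, not IR produced verbatim by this patch.

// Illustrative sketch only; shapes and names are assumptions.
func.func @inter_subgroup_reduction_sketch(%partial: vector<32x2xf32>, %acc: vector<32xf32>) -> vector<32xf32> {
  %c0 = arith.constant 0 : index
  // Padding value for the reload; -inf is the identity of maximumf.
  %pad = arith.constant 0xFF800000 : f32
  // Stage the per-subgroup partial reductions in shared memory.
  %alloc = memref.alloc() : memref<32x2xf32, #gpu.address_space<workgroup>>
  gpu.barrier
  vector.transfer_write %partial, %alloc[%c0, %c0] {in_bounds = [true, true]} : vector<32x2xf32>, memref<32x2xf32, #gpu.address_space<workgroup>>
  gpu.barrier
  // Reload so the former subgroup dimension is now an in-thread (element) dimension.
  %reloaded = vector.transfer_read %alloc[%c0, %c0], %pad {in_bounds = [true, true]} : memref<32x2xf32, #gpu.address_space<workgroup>>, vector<32x2xf32>
  // Second reduction completes the work across what used to be subgroups.
  %red = vector.multi_reduction <maximumf>, %reloaded, %acc [1] : vector<32x2xf32> to vector<32xf32>
  return %red : vector<32xf32>
}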