[LLVMGPUVectorDistribute] Add support for inter-subgroup multi_reduction #19596

Open · wants to merge 1 commit into base: main
@@ -414,8 +414,9 @@ static int64_t getShuffleWidth(NestedLayoutAttr layout, int64_t dim) {
/// by doing a butterfly shuffle.
/// 3. Accumulator Reduce: Each thread reduces its intermediate reduced
/// results with the accumulator it holds.
/// Currently, reduction across warps is not supported, but it would just add
/// another step, Warp Reduce, where threads do an atomic addition on a buffer.
/// 4. Subgroup Reduce: each subgroup stores its partial reduction to shared
///    memory; the partials are then reloaded with a layout that places them
///    within the threads of a single subgroup for the final reduction.
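/// An illustrative sketch (not the exact IR produced) of step 4, for a
/// reduction dim whose subgroup tile is larger than 1:
///   %alloc = memref.alloc() : memref<..., #gpu.address_space<workgroup>>
///   gpu.barrier
///   vector.transfer_write %partials, %alloc[...]  // one partial per subgroup
///   gpu.barrier
///   %reloaded = vector.transfer_read %alloc[...]
///   %res = vector.multi_reduction <add>, %reloaded, %acc [dims]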
struct DistributeMultiReduction final
: OpDistributionPattern<vector::MultiDimReductionOp> {
using OpDistributionPattern::OpDistributionPattern;
@@ -460,7 +461,6 @@ struct DistributeMultiReduction final
}

Location loc = multiReduceOp.getLoc();

SmallVector<bool> reducedDims = multiReduceOp.getReductionMask();
int64_t rank = srcVector.getType().getRank();

@@ -492,47 +492,219 @@ struct DistributeMultiReduction final
assert(locallyReduced && "result should have been a vector");

// Flatten the locally reduced value.
VectorValue threadReduced;
VectorType shaped = locallyReduced.getType();
int64_t numElements = shaped.getNumElements();
SmallVector<int64_t> flatShape(1, numElements);
VectorType flatVecType = VectorType::get(flatShape, elemTy);
VectorValue flat =
rewriter.create<vector::ShapeCastOp>(loc, flatVecType, locallyReduced);
bool hasThreadReductions =
llvm::any_of(multiReduceOp.getReductionDims(), [&](int64_t rDim) {
return srcLayout.getThreadTile()[rDim] > 1;
});
if (hasThreadReductions) {
int64_t numElements = shaped.getNumElements();
SmallVector<int64_t> flatShape(1, numElements);
VectorType flatVecType = VectorType::get(flatShape, elemTy);
VectorValue flat = rewriter.create<vector::ShapeCastOp>(loc, flatVecType,
locallyReduced);

// Do inter-thread/warp reduce.
FailureOr<VectorValue> threadReducedFlat = doThreadReduction(
rewriter, srcLayout, flat, multiReduceOp.getKind(), reducedDims);
if (failed(threadReducedFlat)) {
return failure();
}

// Do inter-thread/warp reduce.
FailureOr<VectorValue> threadReduced = doThreadReduction(
rewriter, srcLayout, flat, multiReduceOp.getKind(), reducedDims);
if (failed(threadReduced)) {
return failure();
// Cast the thread-reduced value back to its original (unflattened) shape.
threadReduced = rewriter.create<vector::ShapeCastOp>(
loc, shaped, threadReducedFlat.value());
} else {
threadReduced = locallyReduced;
}

// Do reduction against accumulator, which needs to be done after thread
// reduction.
VectorValue unflattened = rewriter.create<vector::ShapeCastOp>(
loc, shaped, threadReduced.value());

if (!accVector) {
// Broadcast the scalar (e.g., f32) to a vector type (e.g., vector<f32>)
// because the following implementation requires the operand to be a
// vector.
disAcc = rewriter.create<vector::BroadcastOp>(loc, shaped, disAcc);
}

Value accReduction = vector::makeArithReduction(
rewriter, loc, multiReduceOp.getKind(), unflattened, disAcc);
auto accReduced = dyn_cast<VectorValue>(accReduction);
if (!accReduced) {
return failure();
bool hasSubgroupReductions =
llvm::any_of(multiReduceOp.getReductionDims(), [&](int64_t rDim) {
return srcLayout.getSubgroupTile()[rDim] > 1;
});
// We can exit here if the reduction is confined to each subgroup (i.e. no
// inter-subgroup reduction is needed).
if (!hasSubgroupReductions) {
Value accReduction = vector::makeArithReduction(
rewriter, loc, multiReduceOp.getKind(), threadReduced, disAcc);
auto accReduced = dyn_cast<VectorValue>(accReduction);
if (!accReduced) {
return failure();
}
if (resVector) {
replaceOpWithDistributedValues(rewriter, multiReduceOp, accReduced);
} else {
Value accReducedVal = rewriter.create<vector::ExtractOp>(
loc, accReduction, ArrayRef{int64_t(0)});
replaceOpWithDistributedValues(rewriter, multiReduceOp, accReducedVal);
}
return success();
}

if (resVector) {
replaceOpWithDistributedValues(rewriter, multiReduceOp, accReduced);
} else {
Value accReducedVal = rewriter.create<vector::ExtractOp>(
loc, accReduction, ArrayRef{int64_t(0)});
replaceOpWithDistributedValues(rewriter, multiReduceOp, accReducedVal);
// If the reduction spans subgroups, we need to relayout the partial
// reductions so that they land within a single subgroup for further
// reduction.

// Shape cast to re-insert the reduced dims as unit dimensions.
SmallVector<int64_t> partialReducedDistributedShape =
srcLayout.getDistributedShape();
for (int64_t tileGroupIdx : llvm::seq<int64_t>(3)) {
int64_t tileGroupOffset = tileGroupIdx * rank;
for (int64_t rDim : multiReduceOp.getReductionDims()) {
partialReducedDistributedShape[tileGroupOffset + rDim] = 1;
}
}
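// Illustrative example (assuming the distributed shape packs the batch,
// outer and element tile groups in that order): for a rank-2 source reduced
// along dim 1, [B0, B1, O0, O1, E0, E1] becomes [B0, 1, O0, 1, E0, 1].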
VectorType partialReducedDistributedType = VectorType::get(
partialReducedDistributedShape, srcVector.getType().getElementType());
Value isoRankThreadReduced = rewriter.create<vector::ShapeCastOp>(
loc, partialReducedDistributedType, threadReduced);

SmallVector<int64_t> preDistrShape =
srcLayout.getUndistributedPackedShape();
SmallVector<int64_t> partialReductionShape =
llvm::to_vector(srcVector.getType().getShape());
for (int64_t rDim : multiReduceOp.getReductionDims()) {
// The first #rank elements of the packed shape form the subgroup tile.
// Replace the reduced dims of the input shape with the subgroup tile,
// because every tile except the subgroup tile has already been reduced.
partialReductionShape[rDim] = preDistrShape[rDim];
}
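// Illustrative example: a 32x64 source reduced along dim 1 with a subgroup
// tile of [1, 4] yields a 32x4 shared memory buffer, i.e. one partial
// reduction per subgroup along the reduced dim.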
auto workgroupMemoryAddressSpace = Attribute(gpu::AddressSpaceAttr::get(
rewriter.getContext(), gpu::AddressSpace::Workgroup));
MemRefType allocType = MemRefType::get(
partialReductionShape, srcVector.getType().getElementType(),
AffineMap(), workgroupMemoryAddressSpace);
auto alloc = rewriter.create<memref::AllocOp>(loc, allocType);
VectorType unDistributedType = VectorType::get(
partialReductionShape, srcVector.getType().getElementType());
Value undistrWrite = rewriter.create<IREE::VectorExt::ToSIMDOp>(
loc, unDistributedType, isoRankThreadReduced);
Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
SmallVector<Value> indices(unDistributedType.getRank(), c0);
SmallVector<bool> inBounds(unDistributedType.getRank(), true);
// Insert gpu.barrier to make sure the previous iteration of the batch loop
// has fully read the subgroup partial reductions.
rewriter.create<gpu::BarrierOp>(multiReduceOp.getLoc());
auto write = rewriter.create<vector::TransferWriteOp>(
loc, undistrWrite, alloc, indices, inBounds);
// Set the layout signature for the write.
// We need to set the layout on the srcVector/first operand.
auto unitAttr = UnitAttr::get(rewriter.getContext());
{
SmallVector<int64_t> subgroupTileLens =
llvm::to_vector(srcLayout.getSubgroupTile());
SmallVector<int64_t> batchTileLens =
llvm::to_vector(srcLayout.getBatchTile());
SmallVector<int64_t> outerTileLens =
llvm::to_vector(srcLayout.getOuterTile());
SmallVector<int64_t> threadTileLens =
llvm::to_vector(srcLayout.getThreadTile());
SmallVector<int64_t> elementTileLens =
llvm::to_vector(srcLayout.getElementTile());
SmallVector<int64_t> subgroupStrides =
llvm::to_vector(srcLayout.getSubgroupStrides());
SmallVector<int64_t> threadStrides =
llvm::to_vector(srcLayout.getThreadStrides());
// Replace the reduced tiles with unit dimensions.
for (int64_t rDim : multiReduceOp.getReductionDims()) {
batchTileLens[rDim] = 1;
outerTileLens[rDim] = 1;
threadTileLens[rDim] = 1;
elementTileLens[rDim] = 1;
threadStrides[rDim] = 0;
}
auto interSubGroupLayout = IREE::VectorExt::NestedLayoutAttr::get(
multiReduceOp.getContext(), subgroupTileLens, batchTileLens,
outerTileLens, threadTileLens, elementTileLens, subgroupStrides,
threadStrides);
auto writeAttrs =
SmallVector<Attribute>(write->getNumOperands(), unitAttr);
writeAttrs[0] = interSubGroupLayout;
ArrayAttr writeOperandsAttr =
ArrayAttr::get(rewriter.getContext(), writeAttrs);
ArrayAttr writeResultsAttr = ArrayAttr::get(rewriter.getContext(), {});
setSignatureForRedistribution(rewriter, write.getOperation(),
writeOperandsAttr, writeResultsAttr);
}
// Insert gpu.barrier
rewriter.create<gpu::BarrierOp>(write.getLoc());
auto read = rewriter.create<vector::TransferReadOp>(loc, unDistributedType,
alloc, indices);
// Create a new layout where the subgroup dims are squashed into the
// element tile.
IREE::VectorExt::NestedLayoutAttr intraSubGroupLayout;
{
// We intentionally set the subgroup tile to 1 along the reduced dims.
SmallVector<int64_t> subgroupTileLens =
llvm::to_vector(srcLayout.getSubgroupTile());
SmallVector<int64_t> batchTileLens =
llvm::to_vector(srcLayout.getBatchTile());
SmallVector<int64_t> outerTileLens =
llvm::to_vector(srcLayout.getOuterTile());
SmallVector<int64_t> threadTileLens =
llvm::to_vector(srcLayout.getThreadTile());
SmallVector<int64_t> elementTileLens =
llvm::to_vector(srcLayout.getElementTile());
SmallVector<int64_t> subgroupStrides =
llvm::to_vector(srcLayout.getSubgroupStrides());
SmallVector<int64_t> threadStrides =
llvm::to_vector(srcLayout.getThreadStrides());
for (int64_t rDim : multiReduceOp.getReductionDims()) {
subgroupTileLens[rDim] = 1;
batchTileLens[rDim] = 1;
outerTileLens[rDim] = 1;
threadTileLens[rDim] = 1;
// The partial reductions that were spread across subgroups will be
// loaded as the element tile. We can revisit whether this should be
// something else, such as the thread tile.
elementTileLens[rDim] = srcLayout.getSubgroupTile()[rDim];
subgroupStrides[rDim] = 0;
threadStrides[rDim] = 0;
}
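// Illustrative example: if the reduced dim had subgroup_tile = 4, the
// reloaded layout uses subgroup_tile = 1 and element_tile = 4 for that dim,
// so each thread holds all four subgroup partials in its element tile.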
intraSubGroupLayout = IREE::VectorExt::NestedLayoutAttr::get(
multiReduceOp.getContext(), subgroupTileLens, batchTileLens,
outerTileLens, threadTileLens, elementTileLens, subgroupStrides,
threadStrides);
auto readAttrs = SmallVector<Attribute>(read->getNumOperands(), unitAttr);
ArrayAttr readOperandsAttr =
ArrayAttr::get(rewriter.getContext(), readAttrs);
ArrayAttr readResultsAttr =
ArrayAttr::get(rewriter.getContext(), {intraSubGroupLayout});
setSignatureForRedistribution(rewriter, read.getOperation(),
readOperandsAttr, readResultsAttr);
}

// A newly created reduction completes the reduction by combining the
// partial results that were otherwise spread across different subgroups.
auto secondReduction = rewriter.create<vector::MultiDimReductionOp>(
loc, read, acc, reducedDims, multiReduceOp.getKind());
{
auto reduceAttrs =
SmallVector<Attribute>(secondReduction->getNumOperands(), unitAttr);
reduceAttrs[0] = intraSubGroupLayout;
ArrayAttr reduceResultsAttr =
ArrayAttr::get(rewriter.getContext(), {unitAttr});
if (auto dstLayout =
dyn_cast_or_null<NestedLayoutAttr>(signature[resVector])) {
reduceAttrs[1] = dstLayout;
reduceResultsAttr = ArrayAttr::get(rewriter.getContext(), {dstLayout});
}
ArrayAttr reduceOperandsAttr =
ArrayAttr::get(rewriter.getContext(), reduceAttrs);
setSignatureForRedistribution(rewriter, secondReduction.getOperation(),
reduceOperandsAttr, reduceResultsAttr);
}
rewriter.replaceOp(multiReduceOp, {secondReduction.getResult()});
return success();
}

@@ -217,6 +217,7 @@ struct VectorDistributionListener : public RewriterBase::Listener {
void notifyOperationModified(Operation *op) override {
if (op->hasAttr(kVectorLayoutRedistributeAttrName) &&
op->hasAttrOfType<ArrayAttr>(kVectorLayoutFetcherStorageAttrName)) {
op->removeAttr(kVectorLayoutRedistributeAttrName);
toBeDistributed.push_back(op);
}
}
@@ -34,6 +34,7 @@ iree_lit_test_suite(
"gpu_lower_to_ukernels.mlir",
"gpu_nested_layout_contract_amdgpu.mlir",
"gpu_nested_layout_vector_distribution.mlir",
"gpu_nested_layout_vector_distribution_multi_reduce.mlir",
"gpu_nested_layout_vector_distribution_step.mlir",
"gpu_pad_operands.mlir",
"gpu_pipeline.mlir",
@@ -29,6 +29,7 @@ iree_lit_test_suite(
"gpu_lower_to_ukernels.mlir"
"gpu_nested_layout_contract_amdgpu.mlir"
"gpu_nested_layout_vector_distribution.mlir"
"gpu_nested_layout_vector_distribution_multi_reduce.mlir"
"gpu_nested_layout_vector_distribution_step.mlir"
"gpu_pack_to_instrinsics.mlir"
"gpu_pad_operands.mlir"