Skip to content

Commit

Permalink
[NFC][Vectorization] Refactor vector size inference out of the pass (#…
Browse files Browse the repository at this point in the history
…19768)

Currently, the vector length inference lives inside the generic
vectorization pass. However, we need to infer vector lengths when
setting layouts for LLVMGPUVectorDistribute. This currently happens
prior to generic vectorization.

Therefore, this commit refactors inferSizesfromIR API into codegen utils
to be able to generally use it where its needed.

Signed-off-by: Manupa Karunaratne <[email protected]>
  • Loading branch information
manupak authored Jan 23, 2025
1 parent d6b2b0d commit 6aedfd3
Show file tree
Hide file tree
Showing 3 changed files with 234 additions and 220 deletions.
221 changes: 1 addition & 220 deletions compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Common/TileSizeSelection.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
Expand All @@ -28,226 +29,6 @@ namespace mlir::iree_compiler {

namespace {

struct VectorizationTileSizes {
SmallVector<int64_t> destShape;
SmallVector<int64_t> vectorSizes;
SmallVector<bool> vectorScalableFlags;
};

/// Returns a VectorizationTileSizes which contains the inferred bounded result
/// shape and vector input sizes. This is useful to infer the sizes from a
/// chain.
static std::optional<VectorizationTileSizes> inferSizesFromIR(Value val);

/// Tries to infer the vector sizes from an IR using ValueBounds analysis. If
/// `opResult` is provided, it stores the bounded result shapes to destShape.
/// Returns std::nullopt if vector sizes can't be inferred.
static std::optional<VectorizationTileSizes>
inferSizesFromIR(linalg::LinalgOp linalgOp, std::optional<OpResult> opResult) {
LLVM_DEBUG({
VEC_DBGS() << "Inferring sizes for:\n" << linalgOp;
if (opResult) {
VEC_DBGS() << " with OpResult.resultNumber="
<< opResult->getResultNumber();
}
VEC_DBGS() << '\n';
});

std::optional<vector::VscaleRange> vscaleRange;
if (!opResult) {
// Note: Inferring scalable sizes is not supported is `opResult` is set
// (which is used to compute sizes for tensor.pack/unpack).
auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(linalgOp);
vscaleRange = getDefaultVscaleRange(targetAttr);
}

VectorizationTileSizes result;
unsigned numDims = linalgOp.getNumLoops();
for (int dim = 0; dim < numDims; ++dim) {
// Map dimension `dim` to an operand dimension that we will use to
// traverse the U-D chain to get `dim` vector size information.
SmallVector<std::pair<Value, unsigned>> operandDimPairs;
linalgOp.mapIterationSpaceDimToAllOperandDims(dim, operandDimPairs);
if (operandDimPairs.empty()) {
return std::nullopt;
}

Value firstOperand = operandDimPairs[0].first;
unsigned firstOperandDim = operandDimPairs[0].second;

// Trivial case: `dim` size is available in the operand type.
int64_t dimSize = llvm::cast<ShapedType>(firstOperand.getType())
.getShape()[firstOperandDim];
bool dimScalable = false;
if (!ShapedType::isDynamic(dimSize)) {
result.vectorSizes.push_back(dimSize);
result.vectorScalableFlags.push_back(dimScalable);
LLVM_DEBUG(VEC_DBGS() << "Inferred iteration size '" << dimSize
<< "' for dimension '" << dim << "'\n");
continue;
}

// Use ValueBounds analysis to infer `dim` size upper bound.
FailureOr<DimBoundSize> maybeDimBound;
for (auto operandDimPair : operandDimPairs) {
Value operand = operandDimPair.first;
unsigned operandDim = operandDimPair.second;
maybeDimBound = computeDimUpperBound(operand, operandDim, vscaleRange,
RoundUpVscaleMultiple::Yes);
if (succeeded(maybeDimBound)) {
break;
}
}

if (failed(maybeDimBound)) {
return std::nullopt;
}

dimSize = maybeDimBound->baseSize;
dimScalable = maybeDimBound->scalable;
result.vectorSizes.push_back(dimSize);
result.vectorScalableFlags.push_back(dimScalable);

LLVM_DEBUG(VEC_DBGS() << "Inferred iteration size '" << dimSize
<< (dimScalable ? " x vscale" : "")
<< "' for dimension '" << dim << "'\n");
}

if (opResult) {
assert(!llvm::is_contained(result.vectorScalableFlags, true) &&
"inferring scalable bounds with `opResult` not supported!");
result.destShape = linalgOp.getIndexingMapMatchingResult(opResult.value())
.compose(result.vectorSizes);
}

return result;
}

/// Returns the result sizes and vector input sizes of the tensor.pack op. The
/// inferred bounding size is returned if it is dynamic shape. Returns
/// std::nullopt if the shape inference failed.
static std::optional<VectorizationTileSizes>
inferSizesFromIR(tensor::PackOp op) {
LLVM_DEBUG(VEC_DBGS() << "Inferring dest sizes for:\n" << op << "\n");

if (llvm::any_of(op.getInnerTiles(), [](OpFoldResult v) {
return !getConstantIntValue(v).has_value();
})) {
LLVM_DEBUG(VEC_DBGS() << "skip, because inner_tiles are not all constant");
return std::nullopt;
}

VectorizationTileSizes result;
std::optional<VectorizationTileSizes> inferred =
inferSizesFromIR(op.getSource());
if (!inferred) {
return std::nullopt;
}
result.vectorSizes = inferred.value().destShape;

for (auto [dimPos, tileSize] :
llvm::zip_equal(op.getInnerDimsPos(), op.getStaticInnerTiles())) {
if (result.vectorSizes[dimPos] % tileSize != 0) {
return std::nullopt;
}
result.vectorSizes[dimPos] /= tileSize;
}
auto outerDimsPerm = op.getOuterDimsPerm();
if (!outerDimsPerm.empty()) {
applyPermutationToVector(result.vectorSizes, outerDimsPerm);
}

LLVM_DEBUG({
VEC_DBGS() << "After adjustment with inner tiles and "
"outer_dims_perm:\n";
for (auto [idx, val] : llvm::enumerate(result.vectorSizes)) {
llvm::dbgs() << "Dim #" << idx << ": " << val << "\n";
}
});
result.destShape = result.vectorSizes;

return result;
}

/// Returns the result sizes and vector input sizes of the tensor.unpack op. The
/// inferred bounding size is returned if it is dynamic shape. Returns
/// std::nullopt if the shape inference failed.
static std::optional<VectorizationTileSizes>
inferSizesFromIR(tensor::UnPackOp op) {
LLVM_DEBUG(VEC_DBGS() << "Inferring dest sizes for:\n" << op << "\n");

if (llvm::any_of(op.getInnerTiles(), [](OpFoldResult v) {
return !getConstantIntValue(v).has_value();
})) {
LLVM_DEBUG(
VEC_DBGS()
<< "failed on inference because inner_tiles are not all constant");
return std::nullopt;
}

VectorizationTileSizes result;
std::optional<VectorizationTileSizes> inferred =
inferSizesFromIR(op.getSource());
if (!inferred) {
return std::nullopt;
}
result.vectorSizes = inferred.value().destShape;

result.vectorSizes.resize(op.getDestType().getRank());
auto outerDimsPerm = op.getOuterDimsPerm();
if (!outerDimsPerm.empty()) {
applyPermutationToVector(result.vectorSizes,
invertPermutationVector(outerDimsPerm));
}
for (auto [dimPos, tileSize] :
llvm::zip_equal(op.getInnerDimsPos(), op.getStaticInnerTiles())) {
result.vectorSizes[dimPos] *= tileSize;
}

LLVM_DEBUG({
VEC_DBGS() << "After adjustment with inner tiles and "
"outer_dims_perm:\n";
for (auto [idx, val] : llvm::enumerate(result.vectorSizes)) {
llvm::dbgs() << "Dim #" << idx << ": " << val << "\n";
}
});
result.destShape = result.vectorSizes;

return result;
}

/// See the documentation in the above function declaration.
static std::optional<VectorizationTileSizes> inferSizesFromIR(Value val) {
std::optional<VectorizationTileSizes> result;
TypeSwitch<Operation *, void>(val.getDefiningOp())
.Case<linalg::LinalgOp>(
[&](auto op) { result = inferSizesFromIR(op, cast<OpResult>(val)); })
.Case<tensor::PackOp>([&](auto op) { result = inferSizesFromIR(op); })
.Case<tensor::ExtractSliceOp>([&](tensor::ExtractSliceOp op) {
// tensor::ExtractSliceOp is not vectorizable, so only `destShape` has
// the values.
result = VectorizationTileSizes();
LLVM_DEBUG(VEC_DBGS() << "Inferring sizes for:\n" << op << "\n");
int64_t destRank = op.getResult().getType().getRank();
for (int dim = 0; dim < destRank; ++dim) {
LLVM_DEBUG(VEC_DBGS() << "Dim #" << dim << ": ");
FailureOr<int64_t> maybeDimBound =
ValueBoundsConstraintSet::computeConstantBound(
presburger::BoundType::UB, {op, dim},
/*stopCondition=*/nullptr, /*closedUB=*/true);
if (failed(maybeDimBound)) {
LLVM_DEBUG(llvm::dbgs() << "failed\n");
result = std::nullopt;
return;
}
LLVM_DEBUG(llvm::dbgs() << maybeDimBound.value() << "\n");
result->destShape.push_back(maybeDimBound.value());
}
})
.Default([&](Operation *) {});
return result;
}

// Returns the vector sizes from the local lowering config or try to infer them
// from the tensor shapes and tiled loops in the IR.
static std::optional<SizesAndScalableFlags>
Expand Down
Loading

0 comments on commit 6aedfd3

Please sign in to comment.