diff --git a/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp b/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp
index ebe4ef2aff1e..0175f19750da 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GenericVectorization.cpp
@@ -325,6 +325,14 @@ class GenericVectorizationPass final
   void runOnOperation() override;
 };
 
+/// Converts from iree_compiler::VscaleRange to vector::VscaleRange.
+static std::optional<vector::VscaleRange>
+toVectorVscaleRange(std::optional<iree_compiler::VscaleRange> vscaleRange) {
+  if (!vscaleRange.has_value())
+    return std::nullopt;
+  return vector::VscaleRange{vscaleRange->min, vscaleRange->max};
+}
+
 void GenericVectorizationPass::runOnOperation() {
   MLIRContext *context = &getContext();
   auto funcOp = getOperation();
@@ -377,6 +385,17 @@ void GenericVectorizationPass::runOnOperation() {
                             vectorizeGatherAccesses);
   };
 
+  {
+    // Eliminate (all-true) vector masks as early as possible (to avoid missing
+    // optimizations/folds). This is particularly beneficial for scalable
+    // vectors that use dynamic tensor shapes.
+    auto targetAttr =
+        iree_compiler::IREE::HAL::ExecutableTargetAttr::lookup(funcOp);
+    auto vscaleRange = iree_compiler::getDefaultVscaleRange(targetAttr);
+    vector::eliminateVectorMasks(rewriter, funcOp,
+                                 toVectorVscaleRange(vscaleRange));
+  }
+
   {
     // Canonicalize mask related ops before we lower them.
     RewritePatternSet maskCanonPatterns(funcOp.getContext());
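
For context, here is a minimal hand-written sketch of the kind of rewrite `vector::eliminateVectorMasks` enables (illustrative IR only, not output of this patch; the function name, `%src`, `%dim`, and the vector shape are made up for the example):

```mlir
// Masked scalable-vector read guarded by a dynamically-sized mask.
func.func @example(%src: tensor<?xf32>, %dim: index) -> vector<[4]xf32> {
  %c0 = arith.constant 0 : index
  %pad = arith.constant 0.0 : f32
  %mask = vector.create_mask %dim : vector<[4]xi1>
  %v = vector.mask %mask {
    vector.transfer_read %src[%c0], %pad : tensor<?xf32>, vector<[4]xf32>
  } : vector<[4]xi1> -> vector<[4]xf32>
  return %v : vector<[4]xf32>
}
```

If value-bounds analysis, seeded with the target's vscale range (obtained here via `getDefaultVscaleRange(targetAttr)`), can prove `%dim >= 4 * vscale`, the `vector.create_mask` is replaced with a constant all-true mask, which subsequent folds use to unwrap the `vector.mask` region into a plain `vector.transfer_read`. Running this before the mask canonicalization block below is what lets those later patterns see unmasked ops, per the rationale in the code comment above.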