[VectorExt] Add support for masking for toLayout vectorization
Currently, vectorization of toLayout ops expects the source to be
statically shaped. However, to vectorize a dynamically shaped tensor
into a statically shaped vector, the toLayout op needs to perform
masking.

This commit adds support for masking in the vectorization of
toLayout ops.

Signed-off-by: Manupa Karunaratne <[email protected]>
manupak committed Jan 23, 2025
1 parent 6aedfd3 commit 00b7184
Showing 5 changed files with 163 additions and 32 deletions.
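
At a high level, the rewritten pattern reads the possibly dynamically shaped source with a masked vector.transfer_read, applies the layout to the resulting vector, and writes it back out through a vector.transfer_write that reuses the same mask. Below is a minimal sketch of the IR this produces for a tensor<?x64xf32> input, assuming a layout #layout whose undistributed shape is 64x64; the value names are illustrative, not taken from the commit:

  %c0 = arith.constant 0 : index
  %c64 = arith.constant 64 : index
  %pad = arith.constant 0.0 : f32
  // Only the rows within the dynamic extent %d0 are live.
  %d0 = tensor.dim %src, %c0 : tensor<?x64xf32>
  %mask = vector.create_mask %d0, %c64 : vector<64x64xi1>
  %read = vector.transfer_read %src[%c0, %c0], %pad, %mask
      {in_bounds = [true, true]} : tensor<?x64xf32>, vector<64x64xf32>
  %laid_out = iree_vector_ext.to_layout %read to layout(#layout) : vector<64x64xf32>
  // The write targets a fresh tensor of the source's dynamic size,
  // guarded by the same mask that guarded the read.
  %empty = tensor.empty(%d0) : tensor<?x64xf32>
  %out = vector.transfer_write %laid_out, %empty[%c0, %c0], %mask
      {in_bounds = [true, true]} : vector<64x64xf32>, tensor<?x64xf32>

The in_bounds attribute is all-true because the mask already excludes the out-of-range lanes.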
@@ -147,7 +147,7 @@ LogicalResult NestedLayoutAttr::isValidLayout(ShapedType shapeTy,
     int64_t expectedShape = getSubgroupTile()[i] * getBatchTile()[i] *
                             getOuterTile()[i] * getThreadTile()[i] *
                             getElementTile()[i];
-    if (expectedShape != shape[i]) {
+    if (!ShapedType::isDynamic(shape[i]) && expectedShape != shape[i]) {
       std::string shapeStr;
       llvm::raw_string_ostream shapeOs(shapeStr);
       llvm::interleaveComma(shape, shapeOs);
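
With the dynamic-dimension check relaxed as above, isValidLayout only compares the static dimensions of the shape against the layout's expected sizes, so a to_layout op may now carry a partially dynamic tensor type. A hedged example, assuming #layout describes an undistributed 64x64 shape:

  // Dim 0 is dynamic and skipped by the validator; dim 1 must still equal 64.
  %l = iree_vector_ext.to_layout %t to layout(#layout) : tensor<?x64xf32>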
@@ -22,45 +22,84 @@ struct VectorizeToLayoutOpPattern final
     : OpRewritePattern<IREE::VectorExt::ToLayoutOp> {
   using OpRewritePattern::OpRewritePattern;
 
+  vector::TransferReadOp
+  createReadOp(PatternRewriter &rewriter,
+               IREE::VectorExt::ToLayoutOp toLayoutOp) const {
+    Location loc = toLayoutOp.getLoc();
+    ShapedType inputTy = toLayoutOp.getType();
+    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto identityMap = rewriter.getMultiDimIdentityMap(inputTy.getRank());
+    SmallVector<int64_t> readShape =
+        toLayoutOp.getLayout().getUndistributedShape();
+    Value mask = nullptr;
+    if (!toLayoutOp.getType().hasStaticShape()) {
+      SmallVector<OpFoldResult> mixedSourceDims =
+          tensor::getMixedSizes(rewriter, loc, toLayoutOp.getInput());
+      auto maskType = VectorType::get(readShape, rewriter.getI1Type());
+      mask =
+          rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
+    }
+    VectorType vectorType =
+        VectorType::get(readShape, inputTy.getElementType());
+    auto inBounds = rewriter.getBoolArrayAttr(
+        SmallVector<bool>(vectorType.getRank(), true));
+    auto padValue = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(inputTy.getElementType()));
+    auto read = rewriter.create<vector::TransferReadOp>(
+        loc,
+        /*type=*/vectorType,
+        /*source=*/toLayoutOp.getInput(),
+        /*indices=*/ValueRange{SmallVector<Value>(readShape.size(), zero)},
+        /*permutation_map=*/identityMap,
+        /*padding=*/padValue,
+        /*mask=*/mask,
+        /*in_bounds=*/inBounds);
+    return read;
+  }
+
+  vector::TransferWriteOp
+  createWriteOp(PatternRewriter &rewriter,
+                IREE::VectorExt::ToLayoutOp tensorLayoutOp,
+                Value vectorLayoutOp, Value mask) const {
+    Location loc = tensorLayoutOp.getLoc();
+    ShapedType tensorTy = tensorLayoutOp.getType();
+    auto resType =
+        RankedTensorType::get(tensorTy.getShape(), tensorTy.getElementType());
+    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    int64_t rank = tensorTy.getShape().size();
+    auto inBounds = rewriter.getBoolArrayAttr(SmallVector<bool>(rank, true));
+    auto identityMap = rewriter.getMultiDimIdentityMap(tensorTy.getRank());
+    auto empty = rewriter.create<tensor::EmptyOp>(
+        loc, tensor::getMixedSizes(rewriter, loc, tensorLayoutOp.getInput()),
+        tensorTy.getElementType());
+    return rewriter.create<vector::TransferWriteOp>(
+        loc,
+        /*result=*/resType,
+        /*vector=*/vectorLayoutOp,
+        /*source=*/empty,
+        /*indices=*/ValueRange{SmallVector<Value>(rank, zero)},
+        /*permutation_map=*/identityMap,
+        /*mask=*/mask,
+        /*inBounds=*/inBounds);
+  }
+
   LogicalResult matchAndRewrite(IREE::VectorExt::ToLayoutOp toLayoutOp,
                                 PatternRewriter &rewriter) const override {
     if (!toLayoutOp.hasTensorSemantics()) {
       return failure();
     }
-    if (!toLayoutOp.getType().hasStaticShape()) {
-      return rewriter.notifyMatchFailure(toLayoutOp,
-                                         "non-static shape for vectorization");
-    }
 
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPoint(toLayoutOp);
 
     Location loc = toLayoutOp.getLoc();
-    ShapedType inputTy = toLayoutOp.getType();
-
-    // Construct the (never used) zero padding value for input.
-    auto padValue = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getZeroAttr(inputTy.getElementType()));
-
-    auto newInput = vector::createReadOrMaskedRead(
-        rewriter, loc, toLayoutOp.getInput(), inputTy.getShape(), padValue,
-        /*useInBoundsInsteadOfMasking=*/true);
-
+    vector::TransferReadOp readOp = createReadOp(rewriter, toLayoutOp);
     // Create the toLayout operation but with vector types instead.
     auto newLayoutOp = rewriter.create<IREE::VectorExt::ToLayoutOp>(
-        loc, newInput, toLayoutOp.getLayout(), toLayoutOp.getMmaKindAttr(),
+        loc, readOp, toLayoutOp.getLayout(), toLayoutOp.getMmaKindAttr(),
         toLayoutOp.getSharedMemoryConversion());
 
-    // Create the write back to a tensor.
-    int64_t rank = inputTy.getRank();
-    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    auto empty = rewriter.create<tensor::EmptyOp>(loc, inputTy, ValueRange());
-    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        toLayoutOp,
-        /*vector=*/newLayoutOp,
-        /*source=*/empty,
-        /*indices=*/SmallVector<Value>(rank, zero),
-        /*inBounds=*/SmallVector<bool>(rank, true));
+    vector::TransferWriteOp writeOp =
+        createWriteOp(rewriter, toLayoutOp, newLayoutOp, readOp.getMask());
+    rewriter.replaceOp(toLayoutOp, writeOp);
     return success();
   }
 };
@@ -0,0 +1,30 @@
+# Copyright 2024 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# Tests for common transforms.
+
+load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob")
+load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite")
+
+package(
+    features = ["layering_check"],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+iree_lit_test_suite(
+    name = "lit",
+    srcs = enforce_glob(
+        [
+            "vectorize_vector_ext_ops.mlir",
+        ],
+        include = ["*.mlir"],
+    ),
+    cfg = "//compiler:lit.cfg.py",
+    tools = [
+        "//tools:iree-opt",
+        "@llvm-project//llvm:FileCheck",
+    ],
+)
@@ -0,0 +1,23 @@
+################################################################################
+# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from           #
+# compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/BUILD.bazel#
+#                                                                              #
+# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary   #
+# CMake-only content.                                                          #
+#                                                                              #
+# To disable autogeneration for this file entirely, delete this header.        #
+################################################################################
+
+iree_add_all_subdirs()
+
+iree_lit_test_suite(
+  NAME
+    lit
+  SRCS
+    "vectorize_vector_ext_ops.mlir"
+  TOOLS
+    FileCheck
+    iree-opt
+)
+
+### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ###
@@ -1,4 +1,4 @@
-// RUN: iree-opt %s -pass-pipeline='builtin.module(func.func(iree-vector-ext-vectorize-ops, iree-codegen-generic-vectorization))' | FileCheck %s
+// RUN: iree-opt %s -pass-pipeline='builtin.module(func.func(iree-vector-ext-vectorize-ops, iree-codegen-generic-vectorization{enable-vector-masking=true}),canonicalize,cse,canonicalize)' --split-input-file --mlir-print-local-scope | FileCheck %s
 
 #layout = #iree_vector_ext.nested_layout<
   subgroup_tile = [1, 1],
@@ -15,9 +15,9 @@ func.func @vectorize_matmul_layout(%A: tensor<64x64xf32>,
                                     %B: tensor<64x64xf32>,
                                     %C: tensor<64x64xf32>)
     -> tensor<64x64xf32> {
-  %AL = iree_vector_ext.to_layout %A to #layout : tensor<64x64xf32>
-  %BL = iree_vector_ext.to_layout %B to #layout : tensor<64x64xf32>
-  %CL = iree_vector_ext.to_layout %C to #layout : tensor<64x64xf32>
+  %AL = iree_vector_ext.to_layout %A to layout(#layout) : tensor<64x64xf32>
+  %BL = iree_vector_ext.to_layout %B to layout(#layout) : tensor<64x64xf32>
+  %CL = iree_vector_ext.to_layout %C to layout(#layout) : tensor<64x64xf32>
   %matmul = linalg.matmul ins(%AL, %BL : tensor<64x64xf32>, tensor<64x64xf32>)
                           outs(%CL: tensor<64x64xf32>) -> tensor<64x64xf32>
   return %matmul : tensor<64x64xf32>
@@ -35,3 +35,42 @@ func.func @vectorize_matmul_layout(%A: tensor<64x64xf32>,

 // CHECK: vector.contract
 // CHECK-SAME: %[[A]], %[[B]], %[[C]]
+
+// -----
+
+#layout = #iree_vector_ext.nested_layout<
+  subgroup_tile = [1, 1],
+  batch_tile = [1, 1],
+  outer_tile = [1, 1],
+  thread_tile = [1, 1],
+  element_tile = [64, 64],
+
+  subgroup_strides = [0, 0],
+  thread_strides = [0, 0]
+>
+
+func.func @vectorize_matmul_dyn_parallel(%A: tensor<?x64xf32>,
+                                         %B: tensor<64x?xf32>,
+                                         %C: tensor<?x?xf32>)
+    -> tensor<?x?xf32> {
+  %AL = iree_vector_ext.to_layout %A to layout(#layout) : tensor<?x64xf32>
+  %BL = iree_vector_ext.to_layout %B to layout(#layout) : tensor<64x?xf32>
+  %CL = iree_vector_ext.to_layout %C to layout(#layout) : tensor<?x?xf32>
+  %matmul = linalg.matmul ins(%AL, %BL : tensor<?x64xf32>, tensor<64x?xf32>)
+      outs(%CL: tensor<?x?xf32>) {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[0, 0, 0], [64, 64, 0], [0, 0, 64]]>}
+      -> tensor<?x?xf32>
+  return %matmul : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func.func @vectorize_matmul_dyn_parallel
+// CHECK-SAME:    %[[AT:.+]]: tensor<?x64xf32>, %[[BT:.+]]: tensor<64x?xf32>, %[[CT:.+]]: tensor<?x?xf32>
+// CHECK-DAG:     %[[PAD:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:     %[[ADIM:.+]] = tensor.dim %arg0, %c0 : tensor<?x64xf32>
+// CHECK-DAG:     %[[BDIM:.+]] = tensor.dim %arg1, %c1 : tensor<64x?xf32>
+// CHECK-DAG:     %[[AMASK:.+]] = vector.create_mask %[[ADIM]], %c64 : vector<64x64xi1>
+// CHECK-DAG:     %[[AV:.+]] = vector.transfer_read %arg0[%c0, %c0], %[[PAD]], %[[AMASK]]
+// CHECK-DAG:     %[[A:.+]] = iree_vector_ext.to_layout %[[AV]] to layout(#iree_vector_ext.nested_layout<subgroup_tile = [1, 1], batch_tile = [1, 1], outer_tile = [1, 1], thread_tile = [1, 1], element_tile = [64, 64], subgroup_strides = [0, 0], thread_strides = [0, 0]>) : vector<64x64xf32>
+
+// CHECK-DAG:     %[[OPMASK:.+]] = vector.create_mask %[[ADIM]], %[[BDIM]], %c64 : vector<64x64x64xi1>
+// CHECK-DAG:     vector.mask %[[OPMASK]] { vector.contract {{.*}} %[[A]]