diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp index e1bc3cc4d424..732282c3f6db 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtAttrs.cpp @@ -147,7 +147,7 @@ LogicalResult NestedLayoutAttr::isValidLayout(ShapedType shapeTy, int64_t expectedShape = getSubgroupTile()[i] * getBatchTile()[i] * getOuterTile()[i] * getThreadTile()[i] * getElementTile()[i]; - if (expectedShape != shape[i]) { + if (!ShapedType::isDynamic(shape[i]) && expectedShape != shape[i]) { std::string shapeStr; llvm::raw_string_ostream shapeOs(shapeStr); llvm::interleaveComma(shape, shapeOs); diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/VectorizeIREEVectorExtOps.cpp b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/VectorizeIREEVectorExtOps.cpp index 10341122e5c9..59b96ce933aa 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/VectorizeIREEVectorExtOps.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/VectorizeIREEVectorExtOps.cpp @@ -22,45 +22,84 @@ struct VectorizeToLayoutOpPattern final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; + vector::TransferReadOp + createReadOp(PatternRewriter &rewriter, + IREE::VectorExt::ToLayoutOp toLayoutOp) const { + Location loc = toLayoutOp.getLoc(); + ShapedType inputTy = toLayoutOp.getType(); + auto zero = rewriter.create(loc, 0); + auto identityMap = rewriter.getMultiDimIdentityMap(inputTy.getRank()); + SmallVector readShape = + toLayoutOp.getLayout().getUndistributedShape(); + Value mask = nullptr; + if (!toLayoutOp.getType().hasStaticShape()) { + SmallVector mixedSourceDims = + tensor::getMixedSizes(rewriter, loc, toLayoutOp.getInput()); + auto maskType = VectorType::get(readShape, rewriter.getI1Type()); + mask = + rewriter.create(loc, maskType, mixedSourceDims); + } + VectorType vectorType = + VectorType::get(readShape, inputTy.getElementType()); + auto inBounds = rewriter.getBoolArrayAttr( + SmallVector(vectorType.getRank(), true)); + auto padValue = rewriter.create( + loc, rewriter.getZeroAttr(inputTy.getElementType())); + auto read = rewriter.create( + loc, + /*type=*/vectorType, + /*source=*/toLayoutOp.getInput(), + /*indices=*/ValueRange{SmallVector(readShape.size(), zero)}, + /*permutation_map=*/identityMap, + /*padding=*/padValue, + /*mask=*/mask, + /*in_bounds=*/inBounds); + return read; + } + + vector::TransferWriteOp + createWriteOp(PatternRewriter &rewriter, + IREE::VectorExt::ToLayoutOp tensorLayoutOp, + Value vectorLayoutOp, Value mask) const { + Location loc = tensorLayoutOp.getLoc(); + ShapedType tensorTy = tensorLayoutOp.getType(); + auto resType = + RankedTensorType::get(tensorTy.getShape(), tensorTy.getElementType()); + auto zero = rewriter.create(loc, 0); + int64_t rank = tensorTy.getShape().size(); + auto inBounds = rewriter.getBoolArrayAttr(SmallVector(rank, true)); + auto identityMap = rewriter.getMultiDimIdentityMap(tensorTy.getRank()); + auto empty = rewriter.create( + loc, tensor::getMixedSizes(rewriter, loc, tensorLayoutOp.getInput()), + tensorTy.getElementType()); + return rewriter.create( + loc, + /*result=*/resType, + /*vector=*/vectorLayoutOp, + /*source=*/empty, + /*indices=*/ValueRange{SmallVector(rank, zero)}, + /*permutation_map=*/identityMap, + /*mask=*/mask, + /*inBounds=*/inBounds); + } + LogicalResult matchAndRewrite(IREE::VectorExt::ToLayoutOp toLayoutOp, PatternRewriter &rewriter) const override { if (!toLayoutOp.hasTensorSemantics()) { return failure(); } - if (!toLayoutOp.getType().hasStaticShape()) { - return rewriter.notifyMatchFailure(toLayoutOp, - "non-static shape for vectorization"); - } - OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(toLayoutOp); - Location loc = toLayoutOp.getLoc(); - ShapedType inputTy = toLayoutOp.getType(); - - // Construct the (never used) zero padding value for input. - auto padValue = rewriter.create( - loc, rewriter.getZeroAttr(inputTy.getElementType())); - - auto newInput = vector::createReadOrMaskedRead( - rewriter, loc, toLayoutOp.getInput(), inputTy.getShape(), padValue, - /*useInBoundsInsteadOfMasking=*/true); - + vector::TransferReadOp readOp = createReadOp(rewriter, toLayoutOp); // Create the toLayout operation but with vector types instead. auto newLayoutOp = rewriter.create( - loc, newInput, toLayoutOp.getLayout(), toLayoutOp.getMmaKindAttr(), + loc, readOp, toLayoutOp.getLayout(), toLayoutOp.getMmaKindAttr(), toLayoutOp.getSharedMemoryConversion()); - // Create the write back to a tensor. - int64_t rank = inputTy.getRank(); - auto zero = rewriter.create(loc, 0); - auto empty = rewriter.create(loc, inputTy, ValueRange()); - rewriter.replaceOpWithNewOp( - toLayoutOp, - /*vector=*/newLayoutOp, - /*source=*/empty, - /*indices=*/SmallVector(rank, zero), - /*inBounds=*/SmallVector(rank, true)); + vector::TransferWriteOp writeOp = + createWriteOp(rewriter, toLayoutOp, newLayoutOp, readOp.getMask()); + rewriter.replaceOp(toLayoutOp, writeOp); return success(); } }; diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/BUILD.bazel new file mode 100644 index 000000000000..fe13c88e8b85 --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/BUILD.bazel @@ -0,0 +1,30 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Tests for common transforms. + +load("//build_tools/bazel:enforce_glob.bzl", "enforce_glob") +load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite") + +package( + features = ["layering_check"], + licenses = ["notice"], # Apache 2.0 +) + +iree_lit_test_suite( + name = "lit", + srcs = enforce_glob( + [ + "vectorize_vector_ext_ops.mlir", + ], + include = ["*.mlir"], + ), + cfg = "//compiler:lit.cfg.py", + tools = [ + "//tools:iree-opt", + "@llvm-project//llvm:FileCheck", + ], +) diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/CMakeLists.txt new file mode 100644 index 000000000000..4de340372c2c --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/CMakeLists.txt @@ -0,0 +1,23 @@ +################################################################################ +# Autogenerated by build_tools/bazel_to_cmake/bazel_to_cmake.py from # +# compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/BUILD.bazel# +# # +# Use iree_cmake_extra_content from iree/build_defs.oss.bzl to add arbitrary # +# CMake-only content. # +# # +# To disable autogeneration for this file entirely, delete this header. # +################################################################################ + +iree_add_all_subdirs() + +iree_lit_test_suite( + NAME + lit + SRCS + "vectorize_vector_ext_ops.mlir" + TOOLS + FileCheck + iree-opt +) + +### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/vectorize_vector_ext_ops.mlir b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/vectorize_vector_ext_ops.mlir index 76735e3cac1f..fbeada9d1727 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/vectorize_vector_ext_ops.mlir +++ b/compiler/src/iree/compiler/Codegen/Dialect/VectorExt/Transforms/test/vectorize_vector_ext_ops.mlir @@ -1,4 +1,4 @@ -// RUN: iree-opt %s -pass-pipeline='builtin.module(func.func(iree-vector-ext-vectorize-ops, iree-codegen-generic-vectorization))' | FileCheck %s +// RUN: iree-opt %s -pass-pipeline='builtin.module(func.func(iree-vector-ext-vectorize-ops, iree-codegen-generic-vectorization{enable-vector-masking=true}),canonicalize,cse,canonicalize)' --split-input-file --mlir-print-local-scope | FileCheck %s #layout = #iree_vector_ext.nested_layout< subgroup_tile = [1, 1], @@ -15,9 +15,9 @@ func.func @vectorize_matmul_layout(%A: tensor<64x64xf32>, %B: tensor<64x64xf32>, %C: tensor<64x64xf32>) -> tensor<64x64xf32> { - %AL = iree_vector_ext.to_layout %A to #layout : tensor<64x64xf32> - %BL = iree_vector_ext.to_layout %B to #layout : tensor<64x64xf32> - %CL = iree_vector_ext.to_layout %C to #layout : tensor<64x64xf32> + %AL = iree_vector_ext.to_layout %A to layout(#layout) : tensor<64x64xf32> + %BL = iree_vector_ext.to_layout %B to layout(#layout) : tensor<64x64xf32> + %CL = iree_vector_ext.to_layout %C to layout(#layout) : tensor<64x64xf32> %matmul = linalg.matmul ins(%AL, %BL : tensor<64x64xf32>, tensor<64x64xf32>) outs(%CL: tensor<64x64xf32>) -> tensor<64x64xf32> return %matmul : tensor<64x64xf32> @@ -35,3 +35,42 @@ func.func @vectorize_matmul_layout(%A: tensor<64x64xf32>, // CHECK: vector.contract // CHECK-SAME: %[[A]], %[[B]], %[[C]] + + +// ----- + +#layout = #iree_vector_ext.nested_layout< + subgroup_tile = [1, 1], + batch_tile = [1, 1], + outer_tile = [1, 1], + thread_tile = [1, 1], + element_tile = [64, 64], + + subgroup_strides = [0, 0], + thread_strides = [0, 0] +> + +func.func @vectorize_matmul_dyn_parallel(%A: tensor, + %B: tensor<64x?xf32>, + %C: tensor) + -> tensor { + %AL = iree_vector_ext.to_layout %A to layout(#layout) : tensor + %BL = iree_vector_ext.to_layout %B to layout(#layout) : tensor<64x?xf32> + %CL = iree_vector_ext.to_layout %C to layout(#layout) : tensor + %matmul = linalg.matmul ins(%AL, %BL : tensor, tensor<64x?xf32>) + outs(%CL: tensor) {lowering_config = #iree_codegen.lowering_config} + -> tensor + return %matmul : tensor +} + +// CHECK-LABEL: func.func @vectorize_matmul_dyn_parallel +// CHECK-SAME: %[[AT:.+]]: tensor, %[[BT:.+]]: tensor<64x?xf32>, %[[CT:.+]]: tensor +// CHECK-DAG: %[[PAD:.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[ADIM:.+]] = tensor.dim %arg0, %c0 : tensor +// CHECK-DAG: %[[BDIM:.+]] = tensor.dim %arg1, %c1 : tensor<64x?xf32> +// CHECK-DAG: %[[AMASK:.+]] = vector.create_mask %[[ADIM]], %c64 : vector<64x64xi1> +// CHECK-DAG: %[[AV:.+]] = vector.transfer_read %arg0[%c0, %c0], %[[PAD]], %[[AMASK]] +// CHECK-DAG: %[[A:.+]] = iree_vector_ext.to_layout %[[AV]] to layout(#iree_vector_ext.nested_layout) : vector<64x64xf32> + +// CHECK-DAG: %[[OPMASK:.+]] = vector.create_mask %[[ADIM]], %[[BDIM]], %c64 : vector<64x64x64xi1> +// CHECK-DAG: vector.mask %[[OPMASK]] { vector.contract {{.*}} %[[A]]