From 6b64fca85f7be6cb29b7b20d15f1c13d298f7491 Mon Sep 17 00:00:00 2001 From: Ian Wood Date: Mon, 6 Jan 2025 09:20:50 -0800 Subject: [PATCH] [LinalgExt] Fusion support for LinalgExt ScatterOp 1/3 (#19560) These changes are needed to be able to propagate reshapes and fold unit dimensions. This essentially changes `scatter` to be more closely in line with [tf.tensor_scatter_nd_update](https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update) except with a `dimension_map` (side note: the linked tensorflow docs have a really good explanation of the op). This also removes support for non-contiguous scatters because the slice must be right justified (along the innermost dimensions of `updates` and `original`) to prevent ambiguity around how to index `original` and how to scatter `updates`. #### Overview: - Update verifier to handle multiple batch dimensions. Restrict `dimension_map` to allow indexing only of the outermost dimensions, ensuring slices are inserted contiguously. - Fix `TilingInterfaceImpl` to support multiple "batch" dimensions and added test cases to `convert_to_loops.mlir` and `tiling.mlir` - Fix `ScatterOp` description to align with verifier - Add new test cases for `ScatterOp` and remove a few that are no longer supported. --------- Signed-off-by: Ian Wood --- .../test/stablehlo_to_linalg_ext.mlir | 21 --- .../Dialect/LinalgExt/IR/LinalgExtOps.cpp | 128 ++++++++++------- .../Dialect/LinalgExt/IR/LinalgExtOps.td | 108 +++++++++++--- .../LinalgExt/IR/TilingInterfaceImpl.cpp | 29 ++-- .../Dialect/LinalgExt/IR/test/invalid.mlir | 136 +++++++++++++++--- .../Dialect/LinalgExt/IR/test/roundtrip.mlir | 104 ++++++++++++++ .../Transforms/test/convert_to_loops.mlir | 35 ++++- .../LinalgExt/Transforms/test/tiling.mlir | 50 +++++++ tests/e2e/linalg_ext_ops/scatter.mlir | 32 +++++ tests/e2e/stablehlo_ops/scatter.mlir | 59 -------- 10 files changed, 515 insertions(+), 187 deletions(-) diff --git a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_ext.mlir b/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_ext.mlir index 713fb05c61e4..ca6343c73d1f 100644 --- a/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_ext.mlir +++ b/compiler/plugins/input/StableHLO/Conversion/test/stablehlo_to_linalg_ext.mlir @@ -344,27 +344,6 @@ func.func @scatter_add_slice_2D(%arg0: tensor<6x3xi32>, %arg1: tensor<2x1xi32>, // ----- -// CHECK-LABEL: func.func @scatter_partial -func.func @scatter_partial(%arg0: tensor<10x5xf32>, %arg1: tensor<3x1xi32>, %arg2: tensor<3x3xf32>) -> tensor<10x5xf32> { - %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ( { - ^bb0(%arg3: tensor, %arg4: tensor): // no predecessors - %1 = stablehlo.add %arg3, %arg4 : tensor - "stablehlo.return"(%1) : (tensor) -> () - }) {indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false} : (tensor<10x5xf32>, tensor<3x1xi32>, tensor<3x3xf32>) -> tensor<10x5xf32> - return %0 : tensor<10x5xf32> -} - -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] -// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter -// CHECK-SAME: unique_indices(false) -// CHECK-SAME: ins(%[[ARG2]], %[[ARG1]] : tensor<3x3xf32>, tensor<3x1xi32>) -// CHECK-SAME: outs(%[[ARG0]] : tensor<10x5xf32>) -// CHECK: return %[[SCATTER]] - -// ----- - // CHECK-LABEL: func.func @scatter_ui32 func.func @scatter_ui32(%arg0: tensor<1xui32>, %arg1: tensor<1x1xi32>, %arg2: tensor<1xui32>) -> tensor<1xui32> { %0 = 
"stablehlo.scatter"(%arg0, %arg1, %arg2) ({ diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp index d66c5f2162f6..c1f452e04ca2 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp @@ -24,7 +24,6 @@ #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" @@ -97,6 +96,26 @@ static bool isInvalid(ArrayRef dimsPos, int64_t rank) { dimsPos, [rank](int64_t dimPos) { return dimPos < 0 || dimPos >= rank; }); } +/// Emit an error and return failure when `seq` is invalid. It is only valid +/// when it is a permutation of the sequence 0...length(seq) - 1. +static LogicalResult +isPermSequence(function_ref emitError, + ArrayRef seq) { + BitVector seen(seq.size(), false); + for (auto [idx, dim] : llvm::enumerate(seq)) { + if (dim < 0 || dim >= seq.size()) { + return emitError().attachNote() << "element (" << dim << ") at index#" + << idx << " is out of bounds"; + } + if (seen.test(dim)) { + return emitError().attachNote() + << "element (" << dim << ") at index#" << idx << " is a duplicate"; + } + seen.set(dim); + } + return success(); +} + /// Returns true if the dimension of `sourceShape` is smaller than the dimension /// of the `limitShape`. static bool isSmallerThan(ArrayRef sourceShape, @@ -126,15 +145,12 @@ LogicalResult ScatterOp::verify() { if (getOutputs().size() != 1) { return op->emitOpError("expected one output operand"); } - auto checkDimensionsMatch = [&](ShapedType t1, ShapedType t2, unsigned dim) { - return t1.getShape()[dim] == t2.getShape()[dim]; - }; auto indicesType = getIndicesType(); - if (indicesType.getRank() != 2 || + if (indicesType.getRank() < 2 || !isa(indicesType.getElementType())) { - return op->emitOpError( - "expected indices to be of rank 2 of integer element type"); + return op->emitOpError("expected indices to be of rank 2 or greater and of " + "integer element type"); } auto indexDepth = getIndexDepth(); if (ShapedType::isDynamic(indexDepth)) { @@ -143,67 +159,81 @@ LogicalResult ScatterOp::verify() { ArrayRef dimMap = getDimensionMap(); if (dimMap.size() != indexDepth) { - return op->emitOpError("invalid number of dimension map entries "); + return op->emitOpError("invalid number of dimension map entries"); } auto originalType = getOriginalType(); - if (isInvalid(dimMap, originalType.getRank())) { - return op->emitOpError("dimension map is invalid"); + if (failed(isPermSequence( + [&]() { return this->emitOpError("dimension map is invalid."); }, + dimMap))) { + return failure(); } - // The first dimension of the indices should match the first dimension of the - // output. They indicate to the number of updates. 
- auto updateType = getUpdateType(); - if (updateType.getRank() < 1) { - return op->emitOpError("expected update value to be at least rank 1"); - } - if (!checkDimensionsMatch(indicesType, updateType, 0)) { + if (indexDepth > originalType.getShape().size()) { return op->emitOpError( - "mismatch in shape of indices and update value at dim#0"); + "index depth is greater than the rank of the original value"); } - if (updateType.getRank() - 1 > originalType.getRank()) { + + auto updateType = getUpdateType(); + auto batchRank = indicesType.getRank() - 1; + if (updateType.getRank() < batchRank) { + return op->emitOpError("expected update value to be of rank greater than " + "or equal to rank(indices) - 1") + << batchRank; + } + + // Validate the shape of indices and update value match for the first + // `batchRank` dims. + auto [indicesIt, updateIt] = + llvm::mismatch(indicesType.getShape().take_front(batchRank), + updateType.getShape().take_front(batchRank)); + if (indicesIt != indicesType.getShape().take_front(batchRank).end()) { return op->emitOpError( - "update value rank exceeds the rank of the original value"); + "mismatch in shape of indices and update value at dim#") + << (indicesIt - indicesType.getShape().begin()); } - // indexDepth + update dims should cover the original dims. The first dim of - // update is the number of updates. - if (originalType.getRank() > indexDepth + updateType.getRank() - 1) { + if (updateType.getRank() - batchRank > originalType.getRank()) { + return op->emitOpError("update operand's slice rank (") + << updateType.getRank() - batchRank + << " = rank(updates) - batch rank) exceeds the rank of the original " + "value (" + << originalType.getRank() << ")"; + } + + // TODO: make it illegal for `numImplicitDims` to be non-zero. + auto numImplicitDims = originalType.getRank() - getUpdateSliceRank(); + if (numImplicitDims > indexDepth) { return op->emitOpError( - "index depth and update value does not cover rank of original value"); - } - - // Validate the non-indexed update dims cover the full slice size of the - // original tensor. - int64_t fullSliceDims = originalType.getRank() - indexDepth; - for (auto it : - llvm::zip_equal(llvm::seq(indexDepth, originalType.getRank()), - llvm::seq(updateType.getRank() - fullSliceDims, - updateType.getRank()))) { - int64_t originalDim = std::get<0>(it); - int64_t updateDim = std::get<1>(it); + "update and index depth does not fully index original"); + } + + // updateSlice[0..indexDepth] <= original[0..indexDepth] + // updateSlice[indexDepth..] == original[indexDepth..] + auto updateSliceShape = getUpdateSliceShape(); + for (uint64_t fullSliceIdx : + llvm::seq(numImplicitDims, indexDepth)) { + int64_t originalDim = fullSliceIdx; + int64_t updateSliceDim = fullSliceIdx - numImplicitDims; if (!originalType.isDynamicDim(originalDim) && - updateType.getDimSize(updateDim) > + updateSliceShape[updateSliceDim] > originalType.getDimSize(originalDim)) { return op->emitOpError("shape of update value dim#") - << updateDim << " exceeds original value at dim#" << originalDim; + << updateSliceDim + batchRank << " exceeds original value at dim#" + << originalDim; } } - // Check that the remaining update indices do not exceed the update length. 
-  int64_t insertDims = originalType.getRank() - updateType.getRank() + 1;
-  for (auto it : llvm::zip_equal(
-           llvm::seq(insertDims, indexDepth),
-           llvm::seq(1, updateType.getRank() - fullSliceDims))) {
-    int64_t originalDim = std::get<0>(it);
-    int64_t updateDim = std::get<1>(it);
+  for (auto fullSliceIdx :
+       llvm::seq(indexDepth, originalType.getRank())) {
+    int64_t originalDim = fullSliceIdx;
+    int64_t updateSliceDim = fullSliceIdx - numImplicitDims;
     if (!originalType.isDynamicDim(originalDim) &&
-        updateType.getDimSize(updateDim) >
+        updateSliceShape[updateSliceDim] !=
             originalType.getDimSize(originalDim)) {
-      return op->emitOpError("indexed shape of update value dim#")
-             << updateDim << " exceeds original value at dim#" << originalDim
-             << " " << updateType.getDimSize(updateDim) << " "
-             << originalType.getDimSize(originalDim);
+      return op->emitOpError("shape of update value dim#")
+             << updateSliceDim + batchRank
+             << " must match original value at dim#" << originalDim;
     }
   }

diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
index 137899125cb7..067c54d412e8 100644
--- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
+++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
@@ -113,34 +113,52 @@ def IREELinalgExt_ScatterOp : IREELinalgExt_Op<"scatter",
     current value with the value in `updates` using the
     computation specified in `region`. The `region` specifies a binary
     operation of signature (T, T) -> T, where `T` is the element-type of
-    `updates` (and `original`). The first argument correspond the
-    value to be updated (i.e. from `updates`), and the second the
-    current value (i.e. value from `original`).
+    `updates` (and `original`). The first argument is from `updates`,
+    and the second is from `original`.

-    The `indices` is a 2D tensor/memref type. The first dim is the number of
-    updates, and the second dim is index depth. The index depth should always be
-    static.
+    The operand `indices` is an N-D tensor/memref type that is composed
+    of two logical parts:
+    - The first `N-1` dimensions represent the batch of updates.
+    - The last dim (at index `N-1`) is the `index_depth`, which should
+      always be static.

-    The first dim of `updates` and `indices` is identical, since they represent
-    the number of updates.
+    For example, given `indices` of shape `[4, 3, 2]`, the batch dimensions
+    are `[4, 3]` and the `index_depth` is `2`.

-    The rank of the `original`/`result` is at least
-    `index_depth + rank(%updates) - 1`. The first `index_depth` indices are
-    derived from `indices` and the shape of update value has the last
-    rank(%original) - index_depth values match %(originals) last dimensions,
-    with the previous dims extending from the index offsets.
+    The operand `update` is an M-D tensor/memref type and similarly
+    consists of two parts:
+    - The first `N-1` dimensions represent the batch of updates. This
+      must exactly match the first `N-1` dimensions of `indices`
+      (from the example above, `update` must also start with `[4, 3]`).
+    - Dimensions `N..M-1` represent the slice scattered into `original`.
+      The first part of this slice represents the dimensions indexed
+      by `indices`. This must be no larger than `index_depth` but can be
+      less if unit dimensions are omitted.
+
+      The second part represents a contiguous slice to be inserted into
+      `original`.
+
+    The operand `original` is a DPS init representing the destination that
+    `update` gets scattered to.
+
+
+    The rank of the `original` is at least `rank(%updates) - batch_rank`.
+    The first `index_depth` indices are derived from `indices`, and the
+    last `rank(%original) - index_depth` dimensions of the update slice
+    match the last dimensions of `original`, with the preceding dims
+    extending from the index offsets.

     The dimension_map attributes describes which index value maps to which
-    dimension in the destionation. It cannot contain duplicate values, must
-    have as many entries as index depth, and values must be within the rank of
-    the destination.
+    dimension in the destination. Its size must equal `index_depth`, as it
+    represents a permutation of the indices before indexing into `original`.

-    The unique_indices attribute carries the information whether all the indices
-    are unique. If there are repeated indices, the first iteration loop will be
-    marked as reduction.
+    The unique_indices attribute carries the information whether all the
+    indices are unique. If `unique_indices` is `true` and two or more updates
+    scatter to the same location in `original`, the final value in `original`
+    is not guaranteed. If `unique_indices` is set to false, the first
+    `batch_rank` iteration loops will be marked as reduction.

-    The shapes definition follows tensorflow operations execept that it force
-    batch dims to be 1D. See more information in
+    The shapes definition follows tensorflow operations. See more information in
     https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update
   }];

   let arguments = (ins
@@ -159,6 +177,9 @@ def IREELinalgExt_ScatterOp : IREELinalgExt_Op<"scatter",
     $region (`->` type($results)^)?
   }];
   let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
+    static constexpr unsigned kUpdatesOpNum = 0;
+    static constexpr unsigned kIndicesOpNum = 1;
+    static constexpr unsigned kOriginalOpNum = 2;

     int64_t getIndexDepth() {
       return cast<ShapedType>(getDpsInputOperand(1)->get().getType())
@@ -190,8 +211,51 @@ def IREELinalgExt_ScatterOp : IREELinalgExt_Op<"scatter",
       return cast<ShapedType>(getOriginal().getType());
     }

+    /// Utility to get the rank of the portion of `indices` that
+    /// represents the batch dimensions.
+    int64_t getBatchRank() {
+      return getIndicesType().getRank() - 1;
+    }
+
+    /// Utility to get the shape of the portion of `indices` that
+    /// represents the batch dimensions.
+    ArrayRef<int64_t> getBatchShape() {
+      return getIndicesType().getShape().slice(0, getBatchRank());
+    }
+
+    /// Utility to get the rank of the portion of `updates` that
+    /// is scattered into `original`.
     int64_t getUpdateSliceRank() {
-      return cast<ShapedType>(getUpdates().getType()).getRank() - 1;
+      return getUpdateType().getRank() - getBatchRank();
+    }
+
+    /// Utility to get the shape of the portion of `updates` that
+    /// is scattered into `original`.
+    ArrayRef<int64_t> getUpdateSliceShape() {
+      return getUpdateType().getShape().slice(getBatchRank(),
+                                              getUpdateSliceRank());
+    }
+
+    /// Utility to get the dimension in `updates` that corresponds
+    /// to the given dimension in `original`.
+    int64_t convertOriginalDimToUpdatesDim(uint64_t dim) {
+      assert(dim >= 0 && dim < getOriginalType().getRank() &&
+             "expected dimension to be within original rank");
+      int64_t updateDim =
+          getUpdateType().getRank() - getOriginalType().getRank() + dim;
+      assert(updateDim >= getBatchRank() &&
+             "dim doesn't map to a dim in updates");
+      return updateDim;
+    }
+
+    /// Get the dimension in `original` that corresponds to the given
+    /// dimension in `updates`.
+ int64_t convertUpdatesDimToOriginalDim(uint64_t dim) { + assert(dim >= getBatchRank() && + "update batch dim doesn't map to original"); + assert(dim < getUpdateType().getRank() && + "expected dimension to be within updates rank"); + return getOriginalType().getRank() - getUpdateType().getRank() + dim; } bool isScalarUpdate() { diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp index 9c88f93778a4..4b21177e3287 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp @@ -73,7 +73,10 @@ SmallVector ScatterOp::getLoopIteratorTypes() { SmallVector iteratorTypes(getUpdateType().getRank(), utils::IteratorType::parallel); if (!getUniqueIndices()) { - iteratorTypes[0] = utils::IteratorType::reduction; + int64_t batchRank = getBatchRank(); + for (auto i : llvm::seq(0, batchRank)) { + iteratorTypes[i] = utils::IteratorType::reduction; + } } return iteratorTypes; } @@ -84,7 +87,7 @@ SmallVector ScatterOp::getIterationDomain(OpBuilder &builder) { Value one = builder.create(loc, 1); SmallVector ranges; for (auto dim : llvm::seq(0, getUpdateType().getRank())) { - Value ub = getDimValue(builder, loc, getUpdates(), dim); + OpFoldResult ub = getDim(builder, loc, getUpdates(), dim); ranges.emplace_back(Range{zero, ub, one}); } return ranges; @@ -113,14 +116,12 @@ ScatterOp::getTiledImplementation(OpBuilder &builder, // Slice of indices. auto indicesRank = getIndicesType().getRank(); - SmallVector indicesOffsets(indicesRank, zeroAttr); - SmallVector indicesSizes(indicesRank); - indicesOffsets[0] = offsets[0]; - indicesSizes[0] = sizes[0]; - for (auto dim : llvm::seq(1, indicesRank)) { - indicesSizes[dim] = getDim(builder, loc, getIndices(), dim); - } + SmallVector indicesOffsets(offsets.take_front(getBatchRank())); + indicesOffsets.push_back(zeroAttr); + SmallVector indicesSizes(sizes.take_front(getBatchRank())); + indicesSizes.push_back(builder.getIndexAttr(getIndexDepth())); SmallVector indicesStrides(indicesRank, oneAttr); + Operation *indicesSlice = getSlice(builder, loc, getIndices(), indicesOffsets, indicesSizes, indicesStrides); if (!indicesSlice) { @@ -170,11 +171,11 @@ LogicalResult ScatterOp::getResultTilePosition( auto updateRank = getUpdateType().getRank(); Location loc = getLoc(); - for (auto dim : llvm::seq(0, originalRank - updateRank + 1)) { + for (auto dim : llvm::seq(0, originalRank - getUpdateSliceRank())) { resultSizes[dim] = getDim(builder, loc, getOriginal(), dim); } for (auto dim : - llvm::seq(originalRank - updateRank + 1, originalRank)) { + llvm::seq(originalRank - getUpdateSliceRank(), originalRank)) { resultOffsets[dim] = offsets[dim - (originalRank - updateRank)]; resultSizes[dim] = sizes[dim - (originalRank - updateRank)]; } @@ -226,13 +227,13 @@ LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b, Value update = b.create(loc, getUpdates(), ivs); SmallVector starts; SmallVector loadIndices; - loadIndices.push_back(ivs.front()); + append_range(loadIndices, ivs.take_front(getBatchRank())); loadIndices.push_back(Value()); // Populate with empty values. 
- auto originalTy = cast(getOriginal().getType()); + auto originalTy = getOriginalType(); starts.resize(originalTy.getRank(), Value()); - auto updateIvs = ivs.drop_front(1); + auto updateIvs = ivs.drop_front(getBatchRank()); int64_t offset = starts.size() - updateIvs.size(); for (auto [idx, iv] : llvm::enumerate(updateIvs)) { diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir index 32ed87e1311d..93581a1cbe1b 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir @@ -44,7 +44,7 @@ func.func @sort_mismatch_shape(%arg0: tensor, %arg1: tensor<42xf32>) func.func @scatter_extra_outputs( %update : tensor, %indices : tensor, %original : tensor) -> (tensor, tensor) { - // expected-error @+1 {{expected the number of tensor results (2) to be equal to the number of output tensors (1)}} + // expected-error @below {{'iree_linalg_ext.scatter' op expected the number of tensor results (2) to be equal to the number of output tensors (1)}} %0, %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { @@ -76,7 +76,8 @@ func.func @scatter_mistmatch_dim_map_entries( func.func @scatter_duplicate_dim_map_entries( %update : tensor, %indices : tensor, %original : tensor) -> tensor { - // expected-error @+1 {{dimension map is invalid}} + // expected-error @below {{'iree_linalg_ext.scatter' op dimension map is invalid.}} + // expected-note @below {{element (1) at index#1 is a duplicate}} %0 = iree_linalg_ext.scatter dimension_map = [1, 1] unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { @@ -92,7 +93,8 @@ func.func @scatter_duplicate_dim_map_entries( func.func @scatter_invalid_dim_map_entries( %update : tensor, %indices : tensor, %original : tensor) -> tensor { - // expected-error @+1 {{dimension map is invalid}} + // expected-error @below {{'iree_linalg_ext.scatter' op dimension map is invalid.}} + // expected-note @below {{element (2) at index#0 is out of bounds}} %0 = iree_linalg_ext.scatter dimension_map = [2] unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { @@ -105,6 +107,23 @@ func.func @scatter_invalid_dim_map_entries( // ----- +func.func @scatter_invalid_dim_map_entries( + %update : tensor, %indices : tensor, + %original : tensor) -> tensor { + // expected-error @below {{'iree_linalg_ext.scatter' op dimension map is invalid.}} + // expected-note @below {{element (1) at index#0 is out of bounds}} + %0 = iree_linalg_ext.scatter dimension_map = [1] unique_indices(true) + ins(%update, %indices : tensor, tensor) + outs(%original : tensor) { + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.addf %arg1, %arg2 : f32 + iree_linalg_ext.yield %1 : f32 + } -> tensor + return %0 : tensor +} + +// ----- + func.func @scatter_output_type_mismatch( %update : tensor, %indices : tensor, %original : tensor) -> tensor<4x?xf32> { @@ -124,7 +143,7 @@ func.func @scatter_output_type_mismatch( func.func @scatter_dim_mismatch( %update : tensor, %indices : tensor<48x1xi32>, %original : tensor) -> tensor { - // expected-error @+1 {{mismatch in shape of indices and update value at dim#0}} + // expected-error @below {{'iree_linalg_ext.scatter' op mismatch in shape of indices and update value at dim#0}} %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%update, %indices : tensor, 
tensor<48x1xi32>) outs(%original : tensor) { @@ -137,6 +156,22 @@ func.func @scatter_dim_mismatch( // ----- +func.func @scatter_dim_mismatch( + %update : tensor<48x?x?xf32>, %indices : tensor<48x10x1xi32>, + %original : tensor) -> tensor { + // expected-error @below {{'iree_linalg_ext.scatter' op mismatch in shape of indices and update value at dim#1}} + %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) + ins(%update, %indices : tensor<48x?x?xf32>, tensor<48x10x1xi32>) + outs(%original : tensor) { + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.addf %arg1, %arg2 : f32 + iree_linalg_ext.yield %1 : f32 + } -> tensor + return %0 : tensor +} + +// ----- + func.func @scatter_dim_mismatch( %update : tensor<64x?xf32>, %indices : tensor<48x1xi32>, %original : tensor) -> tensor { @@ -154,9 +189,41 @@ func.func @scatter_dim_mismatch( // ----- func.func @scatter_dim_mismatch( + %update : tensor<48x?x2x11xf32>, %indices : tensor<48x?x1xi32>, + %original : tensor) -> tensor { + // expected-error @below {{'iree_linalg_ext.scatter' op shape of update value dim#3 must match original value at dim#1}} + %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) + ins(%update, %indices : tensor<48x?x2x11xf32>, tensor<48x?x1xi32>) + outs(%original : tensor) { + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.addf %arg1, %arg2 : f32 + iree_linalg_ext.yield %1 : f32 + } -> tensor + return %0 : tensor +} + +// ----- + +func.func @scatter_rank_mismatch( + %update : tensor, %indices : tensor, + %original : tensor) -> tensor { + // expected-error @below {{'iree_linalg_ext.scatter' op update operand's slice rank (3 = rank(updates) - batch rank) exceeds the rank of the original value (2)}} + %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) + ins(%update, %indices : tensor, tensor) + outs(%original : tensor) { + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.addf %arg1, %arg2 : f32 + iree_linalg_ext.yield %1 : f32 + } -> tensor + return %0 : tensor +} + +// ----- + +func.func @scatter_rank_mismatch( %update : tensor, %indices : tensor, %original : tensor) -> tensor { - // expected-error @+1 {{op update value rank exceeds the rank of the original value}} + // expected-error @below {{'iree_linalg_ext.scatter' op update operand's slice rank (3 = rank(updates) - batch rank) exceeds the rank of the original value (2)}} %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { @@ -169,10 +236,26 @@ func.func @scatter_dim_mismatch( // ----- +func.func @scatter_rank_mismatch( + %update : tensor, %indices : tensor, + %original : tensor) -> tensor { + // expected-error @below {{'iree_linalg_ext.scatter' op update operand's slice rank (3 = rank(updates) - batch rank) exceeds the rank of the original value (2)}} + %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) + ins(%update, %indices : tensor, tensor) + outs(%original : tensor) { + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.addf %arg1, %arg2 : f32 + iree_linalg_ext.yield %1 : f32 + } -> tensor + return %0 : tensor +} + +// ----- + func.func @scatter_dim_mismatch( %update : tensor, %indices : tensor, %original : tensor) -> tensor { - // expected-error @+1 {{op shape of update value dim#1 exceeds original value at dim#1}} + // expected-error @below {{'iree_linalg_ext.scatter' op shape of update value dim#1 must match original value at dim#1}} %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%update, %indices : tensor, 
tensor) outs(%original : tensor) { @@ -321,19 +404,36 @@ func.func @scatter_index_depth_dynamic( // ----- -func.func @scatter_original_rank_mismatch( - %update : tensor, %indices : tensor, - %original : tensor) -> tensor { - // expected-error @+1 {{op index depth and update value does not cover rank of original value}} +func.func @scatter_index_depth_too_large( + %original: tensor, %indices: tensor, + %update: tensor) -> tensor { + // expected-error @below {{'iree_linalg_ext.scatter' op index depth is greater than the rank of the original value}} + %0 = iree_linalg_ext.scatter + dimension_map = [0, 1, 2] + unique_indices(true) + ins(%update, %indices : tensor, tensor) + outs(%original: tensor) { + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.addf %arg1, %arg2 : f32 + iree_linalg_ext.yield %1 : f32 + } -> tensor + return %0 : tensor +} + +// ----- + +func.func @scatter_index_depth_too_small( + %update : tensor, %indices : tensor, + %original : tensor) -> tensor { + // expected-error @below {{'iree_linalg_ext.scatter' op update and index depth does not fully index original}} %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) - ins(%update, %indices : tensor, tensor) - outs(%original : tensor) { - ^bb0(%arg1: i64, %arg2: i64): - %1 = arith.addi %arg1, %arg2 : i64 - %2 = arith.trunci %1 : i64 to i32 - iree_linalg_ext.yield %1, %2 : i64, i32 - } -> tensor - return %0 : tensor + ins(%update, %indices : tensor, tensor) + outs(%original : tensor) { + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.addf %arg1, %arg2 : f32 + iree_linalg_ext.yield %1 : f32 + } -> tensor + return %0 : tensor } // ----- diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir index 7394e09a3f17..1bc505f7584d 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/roundtrip.mlir @@ -377,6 +377,110 @@ func.func @scatter_update_slice_2D( // ----- +func.func @scatter_batch_2D_dynamic( + %update : tensor<48x?x?xf32>, %indices : tensor<48x?x1xi32>, + %original : tensor) -> tensor { + %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) + ins(%update, %indices : tensor<48x?x?xf32>, tensor<48x?x1xi32>) + outs(%original : tensor) { + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.addf %arg1, %arg2 : f32 + iree_linalg_ext.yield %1 : f32 + } -> tensor + return %0 : tensor +} +// CHECK-LABEL: func.func @scatter_batch_2D_dynamic( +// CHECK-SAME: %[[UPDATE:[a-zA-Z0-9_]+]] +// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]] +// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]] +// CHECK: %[[RESULT:.+]] = iree_linalg_ext.scatter +// CHECK-SAME: dimension_map = [0] +// CHECK-SAME: unique_indices(true) +// CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]] +// CHECK-SAME: outs(%[[ORIGINAL]] +// CHECK: iree_linalg_ext.yield %{{.+}} : f32 +// CHECK: return %[[RESULT]] + +// ----- + +func.func @scatter_batch_2D_static( + %update : tensor<48x?x1x10xf32>, %indices : tensor<48x?x1xi32>, + %original : tensor) -> tensor { + %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) + ins(%update, %indices : tensor<48x?x1x10xf32>, tensor<48x?x1xi32>) + outs(%original : tensor) { + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.addf %arg1, %arg2 : f32 + iree_linalg_ext.yield %1 : f32 + } -> tensor + return %0 : tensor +} +// CHECK-LABEL: func.func @scatter_batch_2D_static( +// CHECK-SAME: %[[UPDATE:[a-zA-Z0-9_]+]] +// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]] +// 
CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]] +// CHECK: %[[RESULT:.+]] = iree_linalg_ext.scatter +// CHECK-SAME: dimension_map = [0] +// CHECK-SAME: unique_indices(true) +// CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]] +// CHECK-SAME: outs(%[[ORIGINAL]] +// CHECK: iree_linalg_ext.yield %{{.+}} : f32 +// CHECK: return %[[RESULT]] + +// ----- + +func.func @scatter_rank_reduced( + %update : tensor<48x10xf32>, %indices : tensor<48x1xi32>, + %original : tensor) -> tensor { + %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) + ins(%update, %indices : tensor<48x10xf32>, tensor<48x1xi32>) + outs(%original : tensor) { + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.addf %arg1, %arg2 : f32 + iree_linalg_ext.yield %1 : f32 + } -> tensor + return %0 : tensor +} +// CHECK-LABEL: func.func @scatter_rank_reduced( +// CHECK-SAME: %[[UPDATE:[a-zA-Z0-9_]+]] +// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]] +// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]] +// CHECK: %[[RESULT:.+]] = iree_linalg_ext.scatter +// CHECK-SAME: dimension_map = [0] +// CHECK-SAME: unique_indices(true) +// CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]] +// CHECK-SAME: outs(%[[ORIGINAL]] +// CHECK: iree_linalg_ext.yield %{{.+}} : f32 +// CHECK: return %[[RESULT]] + +// ----- + +func.func @scatter_batch_2D_rank_reduced( + %update : tensor<48x?x10xf32>, %indices : tensor<48x?x1xi32>, + %original : tensor) -> tensor { + %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) + ins(%update, %indices : tensor<48x?x10xf32>, tensor<48x?x1xi32>) + outs(%original : tensor) { + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.addf %arg1, %arg2 : f32 + iree_linalg_ext.yield %1 : f32 + } -> tensor + return %0 : tensor +} +// CHECK-LABEL: func.func @scatter_batch_2D_rank_reduced( +// CHECK-SAME: %[[UPDATE:[a-zA-Z0-9_]+]] +// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]] +// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]] +// CHECK: %[[RESULT:.+]] = iree_linalg_ext.scatter +// CHECK-SAME: dimension_map = [0] +// CHECK-SAME: unique_indices(true) +// CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]] +// CHECK-SAME: outs(%[[ORIGINAL]] +// CHECK: iree_linalg_ext.yield %{{.+}} : f32 +// CHECK: return %[[RESULT]] + +// ----- + func.func @scatter_update_slice_2D( %original: tensor<4x?xi32>, %indices: tensor<1x1xi32>, %updates: tensor<1x3xi32>) -> tensor<4x?xi32> { diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir index f136ab511342..c5b624b73bbd 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/convert_to_loops.mlir @@ -124,6 +124,33 @@ func.func @scatter_update_scalar_1D( // ----- +func.func @scatter_batch_2D( + %original: memref<8xi32>, %indices: memref<1x3x1xi32>, + %updates: memref<1x3xi32>) { + iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) + ins(%updates, %indices : memref<1x3xi32>, memref<1x3x1xi32>) + outs(%original : memref<8xi32>) { + ^bb0(%arg0: i32, %arg1: i32): // no predecessors + iree_linalg_ext.yield %arg0 : i32 + } + return +} +// CHECK-LABEL: func.func @scatter_batch_2D +// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]] +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index +// CHECK: scf.for %[[I0:.+]] = %[[C0]] to %[[C1]] 
step %[[C1]] { +// CHECK: scf.for %[[I1:.+]] = %[[C0]] to %[[C3]] step %[[C1]] { +// CHECK: %[[T1:.+]] = memref.load %[[UPDATES]][%[[I0]], %[[I1]]] : memref<1x3xi32> +// CHECK: %[[T2:.+]] = memref.load %[[INDICES]][%[[I0]], %[[I1]], %[[C0]]] : memref<1x3x1xi32> +// CHECK: %[[IDX:.+]] = arith.index_cast %[[T2]] : i32 to index +// CHECK: memref.store %[[T1]], %[[ORIGINAL]][%[[IDX]]] + +// ----- + func.func @scatter_add_scalar_2D( %original: memref<4x3xi32>, %indices: memref<3x2xi32>, %updates: memref<3xi32>) { @@ -345,10 +372,10 @@ func.func @scatter_partial_slices(%arg0: memref<2x64x12xf32>, %arg1: memref<2x3x // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] -// CHECK-DAG: %[[C0:.+]] = arith.constant -// CHECK-DAG: %[[C1:.+]] = arith.constant -// CHECK-DAG: %[[C2:.+]] = arith.constant -// CHECK-DAG: %[[C12:.+]] = arith.constant +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[C12:.+]] = arith.constant 12 : index // CHECK: scf.for %[[ARG3:.+]] = %[[C0]] to %[[C2]] step %[[C1]] { // CHECK-NEXT: scf.for %[[ARG4:.+]] = %[[C0]] to %[[C1]] step %[[C1]] { // CHECK-NEXT: scf.for %[[ARG5:.+]] = %[[C0]] to %[[C12]] step %[[C1]] { diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir index dfe59f9b6b25..a6baca15a877 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Transforms/test/tiling.mlir @@ -196,6 +196,56 @@ module attributes { transform.with_named_sequence } { // ----- +func.func @scatter_batch_2D( + %original: memref, %indices: memref, + %updates: memref) { + iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) + ins(%updates, %indices : memref, memref) + outs(%original : memref) { + ^bb0(%arg0: i32, %arg1: i32): // no predecessors + iree_linalg_ext.yield %arg0 : i32 + } + return +} +module attributes { transform.with_named_sequence } { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["iree_linalg_ext.scatter"]} in %module_op : (!transform.any_op) -> !transform.any_op + %1, %loops = transform.structured.tile_using_for %0 tile_sizes [0, 20] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: #[[MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 20)> +// CHECK: func.func @scatter_batch_2D +// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[INDICES:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9]+]] +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index +// CHECK-DAG: %[[D0:.+]] = memref.dim %[[UPDATES]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = memref.dim %[[UPDATES]], %[[C1]] +// CHECK-DAG: %[[D2:.+]] = memref.dim %[[UPDATES]], %[[C2]] +// CHECK: scf.for %[[I:.+]] = %[[C0]] to %[[D1]] step %[[C20]] +// CHECK: %[[SZ:.+]] = affine.min #[[MAP]](%[[I]])[%[[D1]]] +// CHECK: %[[UPDATES_TILE:.+]] = memref.subview +// CHECK-SAME: %[[UPDATES]][0, %[[I]], 0] +// CHECK-SAME: [%[[D0]], %[[SZ]], %[[D2]]] +// CHECK: %[[INDICES_TILE:.+]] = memref.subview +// CHECK-SAME: %[[INDICES]][0, %[[I]], 0] +// CHECK-SAME: [%[[D0]], 
%[[SZ]], 1] +// CHECK: %[[ORIGINAL_TILE:.+]] = memref.subview +// CHECK-SAME: %[[ORIGINAL]][0] +// CHECK-SAME: [%[[D2]]] +// CHECK: %[[ORIG_CAST:.+]] = memref.cast %[[ORIGINAL_TILE]] +// CHECK: iree_linalg_ext.scatter +// CHECK-SAME: unique_indices(true) +// CHECK-SAME: ins(%[[UPDATES_TILE]], %[[INDICES_TILE]] +// CHECK-SAME: outs(%[[ORIG_CAST]] + +// ----- + + func.func @sort_1d(%arg0: tensor) -> tensor { %0 = iree_linalg_ext.sort dimension(0) diff --git a/tests/e2e/linalg_ext_ops/scatter.mlir b/tests/e2e/linalg_ext_ops/scatter.mlir index 786764d5684d..808475a678dc 100644 --- a/tests/e2e/linalg_ext_ops/scatter.mlir +++ b/tests/e2e/linalg_ext_ops/scatter.mlir @@ -96,6 +96,38 @@ func.func @scatter_2d_multiple() { return } +func.func @scatter_2d_unit_batch() { + %original = util.unfoldable_constant dense<0> : tensor<2x2xi32> + %update = util.unfoldable_constant dense<1> : tensor<1x2xi32> + %indices = util.unfoldable_constant dense<[[[0, 0], [1, 1]]]> : tensor<1x2x2xi32> + %result = iree_linalg_ext.scatter dimension_map = [0, 1] unique_indices(true) + ins(%update, %indices : tensor<1x2xi32>, tensor<1x2x2xi32>) + outs(%original : tensor<2x2xi32>) { + ^bb0(%arg0: i32, %arg1: i32): + iree_linalg_ext.yield %arg0 : i32 + } -> tensor<2x2xi32> + + check.expect_eq_const(%result, dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>) : tensor<2x2xi32> + + return +} + +func.func @scatter_2d_batch() { + %original = util.unfoldable_constant dense<0> : tensor<2x2xi32> + %update = util.unfoldable_constant dense<1> : tensor<2x2xi32> + %indices = util.unfoldable_constant dense<[[[0, 0], [1, 1]], [[1, 0], [0, 1]]]> : tensor<2x2x2xi32> + %result = iree_linalg_ext.scatter dimension_map = [0, 1] unique_indices(true) + ins(%update, %indices : tensor<2x2xi32>, tensor<2x2x2xi32>) + outs(%original : tensor<2x2xi32>) { + ^bb0(%arg0: i32, %arg1: i32): + iree_linalg_ext.yield %arg0 : i32 + } -> tensor<2x2xi32> + + check.expect_eq_const(%result, dense<[[1, 1], [1, 1]]> : tensor<2x2xi32>) : tensor<2x2xi32> + + return +} + func.func @scatter_2d_multiple_slice() { %original = util.unfoldable_constant dense<0> : tensor<3x3xi32> %update = util.unfoldable_constant dense<1> : tensor<2x2xi32> diff --git a/tests/e2e/stablehlo_ops/scatter.mlir b/tests/e2e/stablehlo_ops/scatter.mlir index 0a9fdec5f7ab..07a8ec34ecdf 100644 --- a/tests/e2e/stablehlo_ops/scatter.mlir +++ b/tests/e2e/stablehlo_ops/scatter.mlir @@ -87,33 +87,6 @@ func.func @scatter_update_slice_2D() { return } -func.func @scatter_update_slice_partial_2D() { - %arg0 = util.unfoldable_constant dense<0> : tensor<6x3xi32> - %arg1 = util.unfoldable_constant dense<[[2], [4]]> : tensor<2x1xi32> - %arg2 = util.unfoldable_constant dense<[[1, 2], - [4, 5]]> : tensor<2x2xi32> - %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ( { - ^bb0(%arg3: tensor, %arg4: tensor): // no predecessors - "stablehlo.return"(%arg4) : (tensor) -> () - }) { - indices_are_sorted = false, - scatter_dimension_numbers = #stablehlo.scatter< - update_window_dims = [1], - inserted_window_dims = [0], - scatter_dims_to_operand_dims = [0], - index_vector_dim = 1, - >, - unique_indices = true - } : (tensor<6x3xi32>, tensor<2x1xi32>, tensor<2x2xi32>) -> tensor<6x3xi32> - check.expect_eq_const(%0, dense<[[0, 0, 0], - [0, 0, 0], - [1, 2, 0], - [0, 0, 0], - [4, 5, 0], - [0, 0, 0]]> : tensor<6x3xi32>) : tensor<6x3xi32> - return -} - func.func @scatter_add_slice_2D() { %arg0 = util.unfoldable_constant dense<1> : tensor<6x3xi32> %arg1 = util.unfoldable_constant dense<[[2], [4]]> : tensor<2x1xi32> @@ -204,35 +177,3 @@ func.func 
@scatter_2D_large() { check.expect_eq_const(%result, dense<2> : tensor<200x300xi32>) : tensor<200x300xi32> return } - -func.func @scatter_2D_large_permuted() { - %original = util.unfoldable_constant dense<1> : tensor<200x300xi32> - %update = util.unfoldable_constant dense<2> : tensor<300x200xi32> - %init = tensor.empty() : tensor<300xi32> - %indices = linalg.generic { - indexing_maps = [affine_map<(d0) -> (d0)>], - iterator_types = ["parallel"]} - outs(%init : tensor<300xi32>) { - ^bb0(%arg0: i32): - %0 = linalg.index 0 : index - %1 = arith.index_cast %0 : index to i32 - linalg.yield %1 : i32 - } -> tensor<300xi32> - %indices_reshaped = tensor.expand_shape %indices [[0, 1]] output_shape [300, 1] : - tensor<300xi32> into tensor<300x1xi32> - %result = "stablehlo.scatter"(%original, %indices_reshaped, %update)({ - ^bb0(%arg3 : tensor, %arg4 : tensor): - "stablehlo.return"(%arg4) : (tensor) -> () - }) { - indices_are_sorted = false, - scatter_dimension_numbers = #stablehlo.scatter< - update_window_dims = [1], - inserted_window_dims = [1], - scatter_dims_to_operand_dims = [1], - index_vector_dim = 1, - >, - unique_indices = true - } : (tensor<200x300xi32>, tensor<300x1xi32>, tensor<300x200xi32>) -> tensor<200x300xi32> - check.expect_eq_const(%result, dense<2> : tensor<200x300xi32>) : tensor<200x300xi32> - return -}