Skip to content

Commit

Permalink
[LinalgExt] Fusion support for LinalgExt ScatterOp 1/3 (#19560)
Browse files Browse the repository at this point in the history
These changes are needed to be able to propagate reshapes and fold unit
dimensions. This essentially changes `scatter` to be more closely in
line with
[tf.tensor_scatter_nd_update](https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update)
except with a `dimension_map` (side note: the linked tensorflow docs
have a really good explanation of the op).

This also removes support for non-contiguous scatters because the slice
must be right justified (along the innermost dimensions of `updates` and
`original`) to prevent ambiguity around how to index `original` and how
to scatter `updates`.

#### Overview:
- Update verifier to handle multiple batch dimensions. Restrict
`dimension_map` to allow indexing only of the outermost
  dimensions, ensuring slices are inserted contiguously.
- Fix `TilingInterfaceImpl` to support multiple "batch" dimensions
  and added test cases to `convert_to_loops.mlir` and `tiling.mlir`
- Fix `ScatterOp` description to align with verifier
- Add new test cases for `ScatterOp` and remove a few that are no longer
supported.

---------

Signed-off-by: Ian Wood <[email protected]>
  • Loading branch information
IanWood1 authored Jan 6, 2025
1 parent 9cb984f commit 6b64fca
Show file tree
Hide file tree
Showing 10 changed files with 515 additions and 187 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -344,27 +344,6 @@ func.func @scatter_add_slice_2D(%arg0: tensor<6x3xi32>, %arg1: tensor<2x1xi32>,

// -----

// CHECK-LABEL: func.func @scatter_partial
func.func @scatter_partial(%arg0: tensor<10x5xf32>, %arg1: tensor<3x1xi32>, %arg2: tensor<3x3xf32>) -> tensor<10x5xf32> {
%0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ( {
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>): // no predecessors
%1 = stablehlo.add %arg3, %arg4 : tensor<f32>
"stablehlo.return"(%1) : (tensor<f32>) -> ()
}) {indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false} : (tensor<10x5xf32>, tensor<3x1xi32>, tensor<3x3xf32>) -> tensor<10x5xf32>
return %0 : tensor<10x5xf32>
}

// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
// CHECK-SAME: unique_indices(false)
// CHECK-SAME: ins(%[[ARG2]], %[[ARG1]] : tensor<3x3xf32>, tensor<3x1xi32>)
// CHECK-SAME: outs(%[[ARG0]] : tensor<10x5xf32>)
// CHECK: return %[[SCATTER]]

// -----

// CHECK-LABEL: func.func @scatter_ui32
func.func @scatter_ui32(%arg0: tensor<1xui32>, %arg1: tensor<1x1xi32>, %arg2: tensor<1xui32>) -> tensor<1xui32> {
%0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({
Expand Down
128 changes: 79 additions & 49 deletions compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
Expand Down Expand Up @@ -97,6 +96,26 @@ static bool isInvalid(ArrayRef<int64_t> dimsPos, int64_t rank) {
dimsPos, [rank](int64_t dimPos) { return dimPos < 0 || dimPos >= rank; });
}

/// Emit an error and return failure when `seq` is invalid. It is only valid
/// when it is a permutation of the sequence 0...length(seq) - 1.
static LogicalResult
isPermSequence(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> seq) {
BitVector seen(seq.size(), false);
for (auto [idx, dim] : llvm::enumerate(seq)) {
if (dim < 0 || dim >= seq.size()) {
return emitError().attachNote() << "element (" << dim << ") at index#"
<< idx << " is out of bounds";
}
if (seen.test(dim)) {
return emitError().attachNote()
<< "element (" << dim << ") at index#" << idx << " is a duplicate";
}
seen.set(dim);
}
return success();
}

/// Returns true if the dimension of `sourceShape` is smaller than the dimension
/// of the `limitShape`.
static bool isSmallerThan(ArrayRef<int64_t> sourceShape,
Expand Down Expand Up @@ -126,15 +145,12 @@ LogicalResult ScatterOp::verify() {
if (getOutputs().size() != 1) {
return op->emitOpError("expected one output operand");
}
auto checkDimensionsMatch = [&](ShapedType t1, ShapedType t2, unsigned dim) {
return t1.getShape()[dim] == t2.getShape()[dim];
};

auto indicesType = getIndicesType();
if (indicesType.getRank() != 2 ||
if (indicesType.getRank() < 2 ||
!isa<IntegerType>(indicesType.getElementType())) {
return op->emitOpError(
"expected indices to be of rank 2 of integer element type");
return op->emitOpError("expected indices to be of rank 2 or greater and of "
"integer element type");
}
auto indexDepth = getIndexDepth();
if (ShapedType::isDynamic(indexDepth)) {
Expand All @@ -143,67 +159,81 @@ LogicalResult ScatterOp::verify() {

ArrayRef<int64_t> dimMap = getDimensionMap();
if (dimMap.size() != indexDepth) {
return op->emitOpError("invalid number of dimension map entries ");
return op->emitOpError("invalid number of dimension map entries");
}

auto originalType = getOriginalType();
if (isInvalid(dimMap, originalType.getRank())) {
return op->emitOpError("dimension map is invalid");
if (failed(isPermSequence(
[&]() { return this->emitOpError("dimension map is invalid."); },
dimMap))) {
return failure();
}

// The first dimension of the indices should match the first dimension of the
// output. They indicate to the number of updates.
auto updateType = getUpdateType();
if (updateType.getRank() < 1) {
return op->emitOpError("expected update value to be at least rank 1");
}
if (!checkDimensionsMatch(indicesType, updateType, 0)) {
if (indexDepth > originalType.getShape().size()) {
return op->emitOpError(
"mismatch in shape of indices and update value at dim#0");
"index depth is greater than the rank of the original value");
}
if (updateType.getRank() - 1 > originalType.getRank()) {

auto updateType = getUpdateType();
auto batchRank = indicesType.getRank() - 1;
if (updateType.getRank() < batchRank) {
return op->emitOpError("expected update value to be of rank greater than "
"or equal to rank(indices) - 1")
<< batchRank;
}

// Validate the shape of indices and update value match for the first
// `batchRank` dims.
auto [indicesIt, updateIt] =
llvm::mismatch(indicesType.getShape().take_front(batchRank),
updateType.getShape().take_front(batchRank));
if (indicesIt != indicesType.getShape().take_front(batchRank).end()) {
return op->emitOpError(
"update value rank exceeds the rank of the original value");
"mismatch in shape of indices and update value at dim#")
<< (indicesIt - indicesType.getShape().begin());
}

// indexDepth + update dims should cover the original dims. The first dim of
// update is the number of updates.
if (originalType.getRank() > indexDepth + updateType.getRank() - 1) {
if (updateType.getRank() - batchRank > originalType.getRank()) {
return op->emitOpError("update operand's slice rank (")
<< updateType.getRank() - batchRank
<< " = rank(updates) - batch rank) exceeds the rank of the original "
"value ("
<< originalType.getRank() << ")";
}

// TODO: make it illegal for `numImplicitDims` to be non-zero.
auto numImplicitDims = originalType.getRank() - getUpdateSliceRank();
if (numImplicitDims > indexDepth) {
return op->emitOpError(
"index depth and update value does not cover rank of original value");
}

// Validate the non-indexed update dims cover the full slice size of the
// original tensor.
int64_t fullSliceDims = originalType.getRank() - indexDepth;
for (auto it :
llvm::zip_equal(llvm::seq<unsigned>(indexDepth, originalType.getRank()),
llvm::seq<unsigned>(updateType.getRank() - fullSliceDims,
updateType.getRank()))) {
int64_t originalDim = std::get<0>(it);
int64_t updateDim = std::get<1>(it);
"update and index depth does not fully index original");
}

// updateSlice[0..indexDepth] <= original[0..indexDepth]
// updateSlice[indexDepth..] == original[indexDepth..]
auto updateSliceShape = getUpdateSliceShape();
for (uint64_t fullSliceIdx :
llvm::seq<uint64_t>(numImplicitDims, indexDepth)) {
int64_t originalDim = fullSliceIdx;
int64_t updateSliceDim = fullSliceIdx - numImplicitDims;
if (!originalType.isDynamicDim(originalDim) &&
updateType.getDimSize(updateDim) >
updateSliceShape[updateSliceDim] >
originalType.getDimSize(originalDim)) {
return op->emitOpError("shape of update value dim#")
<< updateDim << " exceeds original value at dim#" << originalDim;
<< updateSliceDim + batchRank << " exceeds original value at dim#"
<< originalDim;
}
}

// Check that the remaining update indices do not exceed the update length.
int64_t insertDims = originalType.getRank() - updateType.getRank() + 1;
for (auto it : llvm::zip_equal(
llvm::seq<unsigned>(insertDims, indexDepth),
llvm::seq<unsigned>(1, updateType.getRank() - fullSliceDims))) {
int64_t originalDim = std::get<0>(it);
int64_t updateDim = std::get<1>(it);
for (auto fullSliceIdx :
llvm::seq<int64_t>(indexDepth, originalType.getRank())) {
int64_t originalDim = fullSliceIdx;
int64_t updateSliceDim = fullSliceIdx - numImplicitDims;
if (!originalType.isDynamicDim(originalDim) &&
updateType.getDimSize(updateDim) >
updateSliceShape[updateSliceDim] !=
originalType.getDimSize(originalDim)) {
return op->emitOpError("indexed shape of update value dim#")
<< updateDim << " exceeds original value at dim#" << originalDim
<< " " << updateType.getDimSize(updateDim) << " "
<< originalType.getDimSize(originalDim);
return op->emitOpError("shape of update value dim#")
<< updateSliceDim + batchRank
<< " must match original value at dim#" << originalDim;
}
}

Expand Down
108 changes: 86 additions & 22 deletions compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -113,34 +113,52 @@ def IREELinalgExt_ScatterOp : IREELinalgExt_Op<"scatter",
current value with the value in `updates` using the computation
specified in `region`. The `region` specifies a binary operation
of signature (T, T) -> T, where `T` is the element-type of
`updates` (and `original`). The first argument correspond the
value to be updated (i.e. from `updates`), and the second the
current value (i.e. value from `original`).
`updates` (and `original`). The first argument is from `updates`,
and the second is from `original`.

The `indices` is a 2D tensor/memref type. The first dim is the number of
updates, and the second dim is index depth. The index depth should always be
static.
The operand `indices` is a N-D tensor/memref type that is composed
of two logical parts:
- The first `N-1` dimensions represent the batch of updates.
- The last dim (at index `N-1`) is the `index_depth`, which should
always be static.

The first dim of `updates` and `indices` is identical, since they represent
the number of updates.
For example, given `indices` of shape `[4, 3, 2]`, the batch dimensions
are `[4, 3]` and the `index_depth` is `2`.

The rank of the `original`/`result` is at least
`index_depth + rank(%updates) - 1`. The first `index_depth` indices are
derived from `indices` and the shape of update value has the last
rank(%original) - index_depth values match %(originals) last dimensions,
with the previous dims extending from the index offsets.
The operand `update` is a M-D tensor/memref type and similarly
consists of two parts:
- The first `N-1` dimensions represent the batch of updates. This
must exactly match to the first `N-1` dimensions in `indices`
(from the example above: `indices` must start with `[4, 3]`)
- Dimensions `N..M-1` represent the slice scattered into `original`.
The first part of this tensor represents the dimensions indexed
by `indices`. This must be no larger than `index_depth` but can be
less if unit dimensions are omitted.

The second part represents a contiguous slice to be inserted into
`original`.

The operand `original` is a DPS init representing the destination that
`update` gets scattered to.


The rank of the `original` is at least `rank(%updates) - batch_rank`.
The first `index_depth` indices are derived from `indices` and the
shape of update value has the last rank(%original) - index_depth values
match %(originals) last dimensions, with the previous dims extending
from the index offsets.

The dimension_map attribute describes which index value maps to which
dimension in the destionation. It cannot contain duplicate values, must
have as many entries as index depth, and values must be within the rank of
the destination.
dimension in the destination. Its rank must equal `index_depth` as it
represents a permutation of the indices before indexing into `original`.

The unique_indices attribute carries the information whether all the indices
are unique. If there are repeated indices, the first iteration loop will be
marked as reduction.
The unique_indices attribute carries the information whether all the
indices are unique. If `unique_indices` is `true` and two or more updates
scatter to the same location in `original` the final value in `original` is
not guaranteed. If `unique_indices` is set to false, the first
`batch_rank` iteration loops will be marked as reduction.

The shapes definition follows tensorflow operations execept that it force
batch dims to be 1D. See more information in
The shapes definition follows tensorflow operations. See more information in
https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update
}];
let arguments = (ins
Expand All @@ -159,6 +177,9 @@ def IREELinalgExt_ScatterOp : IREELinalgExt_Op<"scatter",
$region (`->` type($results)^)?
}];
let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
static constexpr unsigned kUpdatesOpNum = 0;
static constexpr unsigned kIndicesOpNum = 1;
static constexpr unsigned kOriginalOpNum = 2;

int64_t getIndexDepth() {
return cast<ShapedType>(getDpsInputOperand(1)->get().getType())
Expand Down Expand Up @@ -190,8 +211,51 @@ def IREELinalgExt_ScatterOp : IREELinalgExt_Op<"scatter",
return cast<ShapedType>(getOriginal().getType());
}

/// Utility to get the rank of the portion of `indices` that
/// represents the batch dimensions
int64_t getBatchRank() {
return getIndicesType().getRank() - 1;
}

/// Utility to get the shape of the portion of `indices` that
/// represents the batch dimensions.
ArrayRef<int64_t> getBatchShape() {
return getIndicesType().getShape().slice(0, getBatchRank());
}

/// Utility to get the rank of the portion of `updates` that
/// is scattered into `original`.
int64_t getUpdateSliceRank() {
return cast<ShapedType>(getUpdates().getType()).getRank() - 1;
return getUpdateType().getRank() - getBatchRank();
}

/// Utility to get the shape of the portion of `updates` that
/// is scattered into `original`.
ArrayRef<int64_t> getUpdateSliceShape() {
return getUpdateType().getShape().slice(getBatchRank(),
getUpdateSliceRank());
}

/// Utility to get the dimension in `updates` that corresponds
/// to the given dimension in `original`.
int64_t convertOriginalDimToUpdatesDim(uint64_t dim) {
assert(dim >= 0 && dim < getOriginalType().getRank() &&
"expected dimension to be within original rank");
int64_t updateDim =
getUpdateType().getRank() - getOriginalType().getRank() + dim;
assert(updateDim >= getBatchRank() &&
"dim doesn't map to a dim in updates");
return updateDim;
}

/// Get the dimension in `original` that corresponds to the given
/// dimension in `updates`.
int64_t convertUpdatesDimToOriginalDim(uint64_t dim) {
assert(dim >= getBatchRank() &&
"update batch dim doesn't map to original");
assert(dim < getUpdateType().getRank() &&
"expected dimension to be within updates rank");
return getOriginalType().getRank() - getUpdateType().getRank() + dim;
}

bool isScalarUpdate() {
Expand Down
Loading

0 comments on commit 6b64fca

Please sign in to comment.