
Fuse transposition with linalg_ext.scatter #19091

Open
rsuderman opened this issue Nov 9, 2024 · 3 comments
@rsuderman (Contributor)

The sample IR shown below produces two dispatches. For linalg_ext.scatter we need to support fusion with transposition and with collapse/expand shapes. We need this because %arg0 could be created via a concatenation of contiguous memory.

func.func @main(%arg0: tensor<2x4x?x16x4x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<?x2x16x4x128xf16>) -> tensor<?x2x16x4x128xf16> {

    %c2 = arith.constant 2 : index
    %dim = tensor.dim %arg0, %c2 : tensor<2x4x?x16x4x128xf16>

    %empty = tensor.empty(%dim) : tensor<4x?x2x16x4x128xf16>
    %transpose = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d0, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x4x?x16x4x128xf16>) outs(%empty : tensor<4x?x2x16x4x128xf16>) {
    ^bb0(%in: f16, %out: f16):
      linalg.yield %in : f16
    } -> tensor<4x?x2x16x4x128xf16>

    %collapsed = tensor.collapse_shape %transpose [[0, 1], [2], [3], [4], [5]] : tensor<4x?x2x16x4x128xf16> into tensor<?x2x16x4x128xf16>
    
    %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%collapsed, %arg1 : tensor<?x2x16x4x128xf16>, tensor<?x1xi32>) outs(%arg2 : tensor<?x2x16x4x128xf16>) {
    ^bb0(%arg7: f16, %arg8: f16):
         iree_linalg_ext.yield %arg7 : f16
    } -> tensor<?x2x16x4x128xf16>
    return %1 : tensor<?x2x16x4x128xf16>
}
@MaheshRavishankar (Contributor)

I was playing around with this a bit. This is possible with some tweaks, but we need to figure out these two issues:

  1. The tensor.collapse_shape needs to be propagated "down" past the scatter. This might be simple to do, but it could also be hard since I don't fully understand the scatter op semantics. I tried to do it by hand and kept getting verification errors.
  2. For the transpose operation, the input indexing map is identity and the output indexing map is not. We want to be able to make the output indexing map of this transpose the identity, but the existing pattern actually does the opposite. Honestly that pattern is a bit of a hack because we haven't really figured out where it belongs; it has a spooky action-at-a-distance feel to it. In any case, that is an issue.

Let's say we fix those issues and end up with something like this:

func.func @main(%arg0: tensor<2x?x16x4x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<?x2x16x4x128xf16>) -> tensor<?x2x16x4x128xf16> {

    %c1 = arith.constant 1 : index
    %dim = tensor.dim %arg0, %c1 : tensor<2x?x16x4x128xf16>

    %empty = tensor.empty(%dim) : tensor<?x2x16x4x128xf16>
    %transpose = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d1, d0, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x?x16x4x128xf16>) outs(%empty : tensor<?x2x16x4x128xf16>) {
    ^bb0(%in: f16, %out: f16):
      linalg.yield %in : f16
    } -> tensor<?x2x16x4x128xf16>

    %1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%transpose, %arg1 : tensor<?x2x16x4x128xf16>, tensor<?x1xi32>) outs(%arg2 : tensor<?x2x16x4x128xf16>) {
    ^bb0(%arg7: f16, %arg8: f16):
         iree_linalg_ext.yield %arg7 : f16
    } -> tensor<?x2x16x4x128xf16>
    return %1 : tensor<?x2x16x4x128xf16>
}

This gets fused, but only under aggressive fusion. So compiling this with --iree-dispatch-creation-aggressive-fusion=true makes this one dispatch.

@IanWood1, to start, maybe move scatter fusion out of aggressive fusion. As an aside, maybe we should try turning aggressive fusion on by default and see where we stand.

@IanWood1 (Contributor) commented Dec 10, 2024

Here's a summary of the changes we found were needed for scatter fusion, with an example scatter attached below. To simplify the explanation, I'll refer to the %update operand as a composition of two parts:

  • "Batch" dimension: the number of updates, with a shape of ?.
  • "Slice" portion: the slice being written, with a shape of 1x16x8x128.
 %85 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) 
    ins(%update, %indicies : tensor<?x1x16x8x128xf16>, tensor<?x1xi32>) 
    outs(%output: tensor<?x16x8x128xf16>) 
    {...}

Required Changes:

Change (1) probably makes sense to add to FoldUnitExtentDims.cpp, while (2) and (3) should be added to BubbleUpExpandShapes.cpp.

1) Drop the unit dimensions from "slice" portion of the %update operand of the scatter op

For example, change %update's type from tensor<?x1x16x8x128xf16> to tensor<?x16x8x128xf16>.
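A rough sketch of what (1) could look like as a rewrite, using the example scatter above (hypothetical IR; the exact reassociation the pattern picks is an assumption):

```mlir
// Before: the "slice" portion of %update carries a leading unit dim.
%0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
    ins(%update, %indicies : tensor<?x1x16x8x128xf16>, tensor<?x1xi32>)
    outs(%output : tensor<?x16x8x128xf16>) {
^bb0(%arg0: f16, %arg1: f16):
  iree_linalg_ext.yield %arg0 : f16
} -> tensor<?x16x8x128xf16>

// After: a collapse_shape folds the unit dim away and the scatter
// consumes the rank-reduced %update directly.
%collapsed = tensor.collapse_shape %update [[0, 1], [2], [3], [4]]
    : tensor<?x1x16x8x128xf16> into tensor<?x16x8x128xf16>
%1 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
    ins(%collapsed, %indicies : tensor<?x16x8x128xf16>, tensor<?x1xi32>)
    outs(%output : tensor<?x16x8x128xf16>) {
^bb0(%arg0: f16, %arg1: f16):
  iree_linalg_ext.yield %arg0 : f16
} -> tensor<?x16x8x128xf16>
```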

2) Fuse with reshapes that modify the "slice" portion of the %update and/or %output operands

For example, fuse with a tensor.expand_shape that expands %update's second dim from 16 to 4x4:
tensor<?x16x8x128xf16> -> tensor<?x4x4x8x128xf16>.
This also requires expanding the %output operand accordingly and introducing a tensor.collapse_shape on the result.
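A sketch of the IR after this kind of fusion (hypothetical; %update, %indicies, and %output are assumed from the example above, and the expand_shape forms are illustrative):

```mlir
%c0 = arith.constant 0 : index
%d = tensor.dim %update, %c0 : tensor<?x16x8x128xf16>
%d_out = tensor.dim %output, %c0 : tensor<?x16x8x128xf16>

// Expand the "slice" dim 16 -> 4x4 on both %update and %output.
%upd_exp = tensor.expand_shape %update [[0], [1, 2], [3], [4]]
    output_shape [%d, 4, 4, 8, 128]
    : tensor<?x16x8x128xf16> into tensor<?x4x4x8x128xf16>
%out_exp = tensor.expand_shape %output [[0], [1, 2], [3], [4]]
    output_shape [%d_out, 4, 4, 8, 128]
    : tensor<?x16x8x128xf16> into tensor<?x4x4x8x128xf16>

// The scatter now operates on the expanded shapes...
%scattered = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
    ins(%upd_exp, %indicies : tensor<?x4x4x8x128xf16>, tensor<?x1xi32>)
    outs(%out_exp : tensor<?x4x4x8x128xf16>) {
^bb0(%arg0: f16, %arg1: f16):
  iree_linalg_ext.yield %arg0 : f16
} -> tensor<?x4x4x8x128xf16>

// ...and a collapse_shape restores the original result type.
%result = tensor.collapse_shape %scattered [[0], [1, 2], [3], [4]]
    : tensor<?x4x4x8x128xf16> into tensor<?x16x8x128xf16>
```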

3) Fuse with reshapes that modify the "batch" portion of the %update and/or %output operands.

This will be trickier, since it will require changing the verifiers and updating the operation documentation. If the "batch" portion of the operand gets expanded from ? to 4x?, then %indicies will also have to be expanded to tensor<4x?x1xi32> to account for the change from a 1D list of updates to a 2D grid of updates.
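A hypothetical sketch of the target IR for (3), assuming the verifier is relaxed to allow multiple "batch" dims on the update and indices operands (this does not verify today, and the expand_shape forms are illustrative):

```mlir
// "Batch" dim expanded ? -> 4x?; %indicies is expanded the same way,
// turning a 1D list of updates into a 2D grid of updates.
%upd_exp = tensor.expand_shape %update [[0, 1], [2], [3], [4]]
    output_shape [4, %d, 16, 8, 128]
    : tensor<?x16x8x128xf16> into tensor<4x?x16x8x128xf16>
%idx_exp = tensor.expand_shape %indicies [[0, 1], [2]]
    output_shape [4, %d, 1]
    : tensor<?x1xi32> into tensor<4x?x1xi32>

%2 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true)
    ins(%upd_exp, %idx_exp : tensor<4x?x16x8x128xf16>, tensor<4x?x1xi32>)
    outs(%output : tensor<?x16x8x128xf16>) {
^bb0(%arg0: f16, %arg1: f16):
  iree_linalg_ext.yield %arg0 : f16
} -> tensor<?x16x8x128xf16>
```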

@MaheshRavishankar @qedawkins

@IanWood1 (Contributor)

Update: I opened two PRs. The first, #19560, makes the changes to scatter needed to do 1-3. The second, #19450, drops the unit dims on scatter ops. I'm working on reshape fusion and have changes on main...IanWood1:iree:scatter_organized.

IanWood1 added a commit that referenced this issue Jan 6, 2025
This change adds patterns to drop the unit dims of an
`iree_linalg_ext.scatter`'s `%updates` tensor. It only drops the leading
unit dimensions from the portion of `%updates` that represents the
indexed dimensions.


See the main issue #19091

---------

Signed-off-by: Ian Wood <[email protected]>
IanWood1 added a commit that referenced this issue Jan 9, 2025
Implements fusion with reshapes by expansion for `LinalgExt::ScatterOp`.

See main issue #19091

---------

Signed-off-by: Ian Wood <[email protected]>