Skip to content

Commit

Permalink
add scatter op lowering unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
Pokemons386 committed Nov 9, 2023
1 parent 2d91798 commit 520e5a9
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 1 deletion.
32 changes: 31 additions & 1 deletion tao_compiler/mlir/disc/transforms/lhlo_elemental_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -440,9 +440,10 @@ Value elementalLower<lmhlo::RealDynamicSliceOp>(OpBuilder* b, Location loc,
mayCreateStore(b, loc, op.getOperation(), result, output_index, lower_config);
return result;
}

namespace {



template <typename T>
Value elementalLowerImplForBroadcastInDimOps(OpBuilder* b, Location loc,
T broadcast_in_dim,
Expand Down Expand Up @@ -504,6 +505,8 @@ Value elementalLowerImplForBroadcastInDimOps(OpBuilder* b, Location loc,

} // namespace



template <>
Value elementalLower<lmhlo::DynamicBroadcastInDimOp>(
OpBuilder* b, Location loc, lmhlo::DynamicBroadcastInDimOp op,
Expand All @@ -514,6 +517,33 @@ Value elementalLower<lmhlo::DynamicBroadcastInDimOp>(
return result;
}

template <>
Value elementalLower<lmhlo::ScatterOp>(OpBuilder* b, Location loc,
                                       lmhlo::ScatterOp op,
                                       ValueRange output_index,
                                       bool check_cache,
                                       LowerConfig* lower_config) {
  // Elemental lowering for scatter: each output element is a plain
  // pass-through copy of the `operand` buffer at the same multi-dim index.
  // The scatter updates themselves are not applied here.
  // NOTE(review): this assumes output shape == operand shape — confirm with
  // the in-place scatter lowering that consumes this result.
  SmallVector<Value, 4> operand_index(output_index.begin(),
                                      output_index.end());

  // The first operand of lmhlo::ScatterOp is the buffer being scattered into.
  Value operand_memref = op->getOperand(0);

  Value loaded;
  if (check_cache) {
    // Reuse a previously materialized load of the same element if available.
    loaded = createLoadOrUseCachedValue(loc, b, op.getOperation(),
                                        operand_memref, operand_index,
                                        b->saveInsertionPoint(), lower_config);
  } else {
    loaded = createMaySpecificLoad(*b, loc, op.getOperation(), operand_memref,
                                   operand_index, lower_config);
  }
  mayCreateStore(b, loc, op.getOperation(), loaded, output_index, lower_config);
  return loaded;
}

template <>
Value elementalLower<lmhlo::BroadcastInDimOp>(OpBuilder* b, Location loc,
lmhlo::BroadcastInDimOp op,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,8 @@ LogicalResult lowerHelper(OpBuilder& b, Location loc, Operation* op,
succeeded(miscLowerHelper<lmhlo::ReverseOp>(
b, loc, op, output_linear_index, shape_analysis, vector_size, lower_config)) ||
succeeded(miscLowerHelper<lmhlo::DynamicUpdateSliceOp>(
b, loc, op, output_linear_index, shape_analysis, vector_size, lower_config)) ||
succeeded(miscLowerHelper<lmhlo::ScatterOp>(
b, loc, op, output_linear_index, shape_analysis, vector_size, lower_config))
) {
return success();
Expand Down Expand Up @@ -5749,6 +5751,7 @@ struct DiscLhloLegalizeRootsToParallelLoops
// TODO(disc): single nodes with non kLoop schedule like ReduceOp
// is not implemented yet. Currently ReduceOp is lowered with loop
// schedule, which means for poor performance.

if (failed(lowerWithScheduleLoop({op}, op, nullptr,
/*non_fusion=*/true,
/*parallel_loop=*/true))) {
Expand Down
24 changes: 24 additions & 0 deletions tao_compiler/mlir/disc/transforms/tests/scatter-legalize.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// RUN: disc-opt -split-input-file -disc-hlo-legalize-to-lhlo -hlo-legalize-to-lhlo \
// RUN: -canonicalize %s -o - | FileCheck %s

// -----


// Verifies that mhlo.scatter survives bufferization: after legalization the
// entry function takes memrefs whose shapes match the original tensor
// operands (operand, indices, updates).
// CHECK-LABEL: @test_scatterop_lowering
// CHECK-SAME: %[[ARG0:.*]]: memref<32000x4096xf32>, %[[ARG1:.*]]: memref<8192x1xi64>, %[[ARG2:.*]]: memref<8192x4096xf32>
func.func @test_scatterop_lowering(%arg0: tensor<32000x4096xf32>, %arg1: tensor<8192x1xi64>, %arg2: tensor<8192x4096xf32>) -> tensor<32000x4096xf32>
attributes {
  tf.entry_function = {
    input_placements = "cpu, cpu, cpu",
    inputs = "input0, input1, input2",
    output_placements = "cpu",
    outputs = "output0"
  }} {
%2 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({
  ^bb0(%arg143: tensor<f32>, %arg144: tensor<f32>):
    %1 = mhlo.add %arg143, %arg144 : tensor<f32>
    mhlo.return %1 : tensor<f32>
}) {indices_are_sorted = false, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = false} : (tensor<32000x4096xf32>, tensor<8192x1xi64>, tensor<8192x4096xf32>) -> tensor<32000x4096xf32>

return %2 : tensor<32000x4096xf32>
}

0 comments on commit 520e5a9

Please sign in to comment.