-
Notifications
You must be signed in to change notification settings - Fork 163
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support ScatterOp Codegen #1270
Conversation
0664a4e
to
91d3a3f
Compare
91d3a3f
to
520e5a9
Compare
3505b6d
to
72b23e5
Compare
e0fad5d
to
6e778a0
Compare
a83f009
to
dadfe54
Compare
b8b5e2a
to
422fc9a
Compare
auto sliceOpResultType = RankedTensorType::get(indexType.getShape(), srcType.getElementType()); | ||
src = rewriter.create<mhlo::RealDynamicSliceOp>(loc, getTypeConverter()->convertType(sliceOpResultType), src, baseIndicesValue, limitIndicesValue, stridesValue); | ||
|
||
// Construct ScatterDimensionNumbersAttr |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
duplicate comments with L1595 ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
|
||
// Reference implementation: https://github.com/pytorch/xla/blob/master/torch_xla/csrc/xla_lower_util.cpp#L139 | ||
template <> | ||
LogicalResult ConvertAtenOp<AtenScatterSrcOp>::matchAndRewrite( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this rewriter function is too long, can you add some pesucode to describe how to lower the AtenScatterOp briefly?
@@ -295,5 +295,26 @@ def LHLO_ArgsMutationOp : LHLODISC_Op<"args_mutation", []> { | |||
); | |||
} | |||
|
|||
def LHLODISC_InplaceScatterOp: LHLODISC_Op<"inplace_scatter", []> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we need a new operator to define a INPLACE op? Maybe we simple way is if (op.operand[0] == op.operand[-1])
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah you are right. Fixed.
output_placements = "cpu", | ||
outputs = "output0" | ||
}} { | ||
%2 = "mhlo.scatter"(%arg0, %arg1, %arg2) ({ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this ut is too simple, please check the IR of scatter op lowing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
c7c9589
to
d1eece4
Compare
bb5c6e6
to
799b5b7
Compare
799b5b7
to
002f0bd
Compare
002f0bd
to
20cb101
Compare
No description provided.