-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for nan_to_num, atan2 & bitwise_or op #57
base: main
Are you sure you want to change the base?
Conversation
b02be0c
to
9939beb
Compare
e66c611
to
2924f46
Compare
python/tvm/relay/frontend/pytorch.py
Outdated
|
||
dtype = input_types[0] | ||
|
||
assert dtype == "float32", f"Expected dtype to be float32, but got {dtype}. Support for {dtype} is not added yet." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
seems implementation supports all float dtypes. can we also add support for int types too? And remove this assert.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
-
Since NaN is inherently a concept in floating-point data types, the occurrence of NaN values in integer tensors is highly improbable. Integer tensors do not support NaN values, and thus, it is unnecessary to check for NaN in such tensors during model inference.
-
for float16 & float64, I faced these errors nan_to_num_float16.log , nan_to_num_float64.log
-
Due to above reasons & to push PETR model to next stage , support for float32 dtype alone is added now . will add support for other datatypes if it is needed in future.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For float16, seems issue in ttnn, can we raise an issue in ttnn and enable the support here?
For float64, can we enable the support in Forge?
BTW, This is not blocker for this PR, We can add support to this latter. What do you think @nvukobratTT ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Even as TTNN doesn't support it, I don't see why we should limit our compiler. Eventually, TTNN should add support for this.
That said, let's test out different data formats, and mark the ones that are unsupported on TTNN side as xfailed. No need to open issues for this right now. Let's just use xfail for tracking these at the moment..
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
-
Tests added for float32 & float16
-
support for this float64 data format need to be added from front end itself it seems. nan_to_num_float64.log
-
As adding support for this dtype is not proirity now, will a create a seperate PR for it if needed in future.
2917d2e
to
d951607
Compare
@@ -4620,6 +4620,67 @@ def scaled_dot_product_attention(self, inputs, input_types): | |||
attn_weight = _op.reshape(attn_weight, newshape=[-4, batch_size, -1, -2]) | |||
|
|||
return attn_weight | |||
|
|||
|
|||
def nan_to_num(self, inputs, input_types): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's add function docs.
Same for other functions added in this PR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As explanations are added as comments, I just added function docs like below. Let me know if any specific format should be followed.
python/tvm/relay/frontend/pytorch.py
Outdated
|
||
dtype = input_types[0] | ||
|
||
assert dtype == "float32", f"Expected dtype to be float32, but got {dtype}. Support for {dtype} is not added yet." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Even as TTNN doesn't support it, I don't see why we should limit our compiler. Eventually, TTNN should add support for this.
That said, let's test out different data formats, and mark the ones that are unsupported on TTNN side as xfailed. No need to open issues for this right now. Let's just use xfail for tracking these at the moment..
d951607
to
5d4e3f4
Compare
…nan decomposition logic
5d4e3f4
to
39215fb
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Summary
Logs