Invalid Runtime inputs to embedding #952
@dgolubovicTT can you also include the dumped MLIR graph in the ttnn dialect here?
Of course. It is the same in both cases (input indices int32 and int64), which is expected.
The issue is probably coming from this function:
For int64 (
Makes sense. What do you suggest as the solution? We need to be able to handle int64 as an input dtype.
@dgolubovicTT ttnn runtime doesn't support 64-bit values. We'll just have to cast them to 32-bit values before running anything through runtime.
So it is up to the user to cast any int64 to int32. @nvukobratTT FYI.
It doesn't seem right to make this the user's responsibility. @jnie-TT since ttnn doesn't support int64, can we cast the torch tensor from int64 to int32 before we create ttnn runtime tensors?
@dgolubovicTT yeah, you can cast it before creating runtime tensors, as long as by the time it reaches runtime it is not in a 64-bit data format.
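For illustration, a minimal sketch of what that frontend-side cast could look like, assuming plain torch tensors on the forge side (the helper name and the dtype mapping are hypothetical, not existing forge API):

```python
import torch

# Downcasts for dtypes the ttnn runtime reportedly cannot consume.
# Hypothetical mapping; it only covers the cases raised in this thread.
_DOWNCAST = {torch.int64: torch.int32, torch.float64: torch.float32}

def cast_runtime_inputs(tensors):
    """Cast 64-bit torch tensors to 32-bit before creating ttnn runtime tensors."""
    # t.to(t.dtype) is a no-op, so supported dtypes pass through unchanged.
    return [t.to(_DOWNCAST.get(t.dtype, t.dtype)) for t in tensors]
```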
@jnie-TT Is this a valid approach? Shouldn't runtime do the casting if the forwarded activation and the data format defined in the IR don't match? The main reason is that the frontend can't always accurately track how the graph is changing. Here are the 2 main cases:
With case 2 in mind, a valid approach that covers both seems to be as follows:
Let us know your thoughts.
@nvukobratTT I don't think I'm following point 2. In my understanding, the FE and runtime see the same thing, which is the final compiled flatbuffer containing the final input/output data formats. The frontend can extract the input/output data formats directly from the flatbuffer and create the tensors accordingly. Any recompile/reoptimization will create a new flatbuffer containing the updated data formats, so the final data type of any tensor at the IR level should be equally visible to runtime and the FE. This is also what happens in ttrt.

One problem with runtime doing the typecast implicitly is that the FE needs to create the output tensor beforehand and then memcpy the runtime output tensor into the pre-created output tensor. If runtime were to implicitly typecast the input to int32, the output data format might end up changing to int32 as well. In that case, the FE would need to override the data format of the pre-created output tensor to keep the memcpy valid. For better predictability, it might be better for the FE to explicitly do the typecast on the input as well; that way there are no surprises and the input/output flows stay consistent.
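A rough sketch of the flow argued for above, with hypothetical stand-ins for the real bindings (input_dtypes, output_specs, and submit are placeholders, not actual tt-mlir runtime API):

```python
import torch

def run_with_explicit_casts(binary, torch_inputs, input_dtypes, output_specs, submit):
    """Hypothetical FE-side flow: cast inputs explicitly, pre-create outputs.

    input_dtypes and output_specs stand for the formats recorded in the
    compiled flatbuffer; submit(binary, inputs, outputs) stands for the
    runtime call. None of these names are real API.
    """
    # Explicit cast on the way in, so runtime never receives a 64-bit tensor.
    inputs = [t.to(dtype) for t, dtype in zip(torch_inputs, input_dtypes)]
    # Pre-create output tensors with the dtypes the flatbuffer advertises;
    # if runtime changed formats implicitly instead, the memcpy into these
    # pre-created tensors would no longer be valid.
    outputs = [torch.empty(shape, dtype=dtype) for shape, dtype in output_specs]
    submit(binary, inputs, outputs)
    return outputs
```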
I am getting a data mismatch for the embedding op when the input indices are int64. It turns out that the input indices are invalid right before the call to ttnn::embedding in runtime:
third_party/tt-mlir/runtime/lib/ttnn/operations/embedding/embedding.cpp
Namely, the input indices in torch look like this:
16044 8239 2933 13760 16963 16379 31427 6503 31497 9683 14101 26866
In the run method of
third_party/tt-mlir/runtime/lib/ttnn/operations/embedding/embedding.cpp
they look like this (data type UINT32):
16044 0 8239 0 2933 0 13760 0 16963 0 16379 0
So every other index is now 0, which causes embedding to return invalid rows from the embedding matrix.
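That interleaved-zero pattern is exactly what little-endian int64 data looks like when reinterpreted as uint32, which suggests the buffer is being handed to runtime with the wrong element type rather than actually converted. A quick standalone check (numpy used purely for illustration):

```python
import numpy as np

indices = np.array([16044, 8239, 2933, 13760, 16963, 16379], dtype=np.int64)

# Reinterpreting (not converting) the int64 buffer as uint32 splits each
# 8-byte element into its low and high 32-bit halves; on a little-endian
# machine the high half of every small index is 0.
print(indices.view(np.uint32))
# [16044     0  8239     0  2933     0 13760     0 16963     0 16379     0]
```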
I've verified that the input tensor is still valid in
forge/csrc/runtime/runtime.cpp::run_binary
, so the issue is somewhere in between.

To repro:
checkout: dgolubovic/repro-embedding-input-indices-issue
run: pytest -svv forge/test/mlir/test_ops.py::test_embedding
One additional note: if the input indices are int64, this issue occurs; if the input indices are int32, everything works.
On the other hand, the ttnn::embedding op accepts dtype=DataType::UINT32, so the runtime tensor has to have that data type. @jnie-TT do you know where we do this data type conversion from pytorch tensors to ttnn.Tensor?
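In the meantime, the int32 observation above suggests a test-side workaround: downcast the indices before they reach runtime. A sketch assuming standard torch semantics (shapes and sizes are made up for illustration):

```python
import torch

vocab_size, hidden = 32000, 128
embedding = torch.nn.Embedding(vocab_size, hidden)

indices = torch.randint(0, vocab_size, (12,))  # torch.randint yields int64 by default
indices = indices.to(torch.int32)              # workaround: downcast before compiling/running

# int32 indices avoid the corrupted-buffer path described above; torch itself
# accepts int32 indices for embedding lookups, so the reference result still works.
out = embedding(indices)
```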