
Invalid Runtime inputs to embedding #952

Open

dgolubovicTT opened this issue Dec 23, 2024 · 10 comments

@dgolubovicTT (Contributor)

I am getting a data mismatch for the embedding op when the input indices are int64. It turns out that the input indices are invalid right before the call to ttnn::embedding in runtime: third_party/tt-mlir/runtime/lib/ttnn/operations/embedding/embedding.cpp

Namely, the input indices in torch look like this:

16044 8239 2933 13760 16963 16379 31427 6503 31497 9683 14101 26866

In the run method of third_party/tt-mlir/runtime/lib/ttnn/operations/embedding/embedding.cpp they look like this:

16044 0 8239 0 2933 0 13760 0 16963 0 16379 0, with data type UINT32

So every other index is now 0, which causes embedding to return invalid rows from the embedding matrix.

I've verified that the input tensor is still valid in forge/csrc/runtime/runtime.cpp::run_binary, so the issue is somewhere in between.

To repro:

checkout: dgolubovic/repro-embedding-input-indices-issue
run: pytest -svv forge/test/mlir/test_ops.py::test_embedding

One additional note: if the input indices are int64:

    inputs = [
        torch.randint(0, vocab_size, (1, token_num)),
    ]

This issue occurs. However, if the input indices are int32, everything works:

    inputs = [
        torch.randint(0, vocab_size, (1, token_num), dtype=torch.int32),
    ]

On the other hand, the ttnn::embedding op accepts dtype=DataType::UINT32, so the runtime tensor has to have that datatype. @jnie-TT do you know where we do this datatype conversion from pytorch tensors to ttnn.Tensor?

@jnie-TT (Contributor) commented Dec 23, 2024

@dgolubovicTT can you also include the dumped MLIR graph here, in the TTNN dialect?

@dgolubovicTT (Contributor, Author)

Ofc. It is the same in both cases (input indices int32 and int64), which is expected.

Embedding_test_ops_ttnn.txt

@jnie-TT (Contributor) commented Dec 30, 2024

The issue is probably coming from this function:

static target::DataType torch_scalar_type_to_dt(torch::ScalarType st)
{
    switch (st)
    {
        case torch::ScalarType::Byte: return target::DataType::UInt8;
        case torch::ScalarType::Char: return target::DataType::UInt8;
        case torch::ScalarType::Short: return target::DataType::UInt16;
        case torch::ScalarType::Int: return target::DataType::UInt32;
        case torch::ScalarType::Long: return target::DataType::UInt32;
        case torch::ScalarType::Half: return target::DataType::Float16;
        case torch::ScalarType::Float: return target::DataType::Float32;
        // case torch::ScalarType::Double:
        // case torch::ScalarType::ComplexHalf:
        // case torch::ScalarType::ComplexFloat:
        // case torch::ScalarType::ComplexDouble:
        // case torch::ScalarType::Bool:
        case torch::ScalarType::BFloat16: return target::DataType::BFloat16;
        default: break;
    }

    log_fatal(LogTTDevice, "Unhandled dtype {}", st);
}

For int64 (case torch::ScalarType::Long) it returns DataType::UInt32, which gets passed as the data type into create_tensor. When runtime creates the borrowed runtime tensor, it interprets every 32-bit chunk as an independent value. It doesn't truncate the original int64 (it doesn't know it was int64 in the first place); instead, every int64 is viewed as 2 separate uint32 values. That's why you're seeing 0s in between the values: those are the upper 32 bits of the original int64 values being viewed as independent uint32 values.
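A quick way to see this reinterpretation (illustrative only, assuming a little-endian host; not part of the original report) is to view an int64 torch tensor as int32:

    import torch

    idx64 = torch.tensor([16044, 8239, 2933], dtype=torch.int64)
    # Viewing the same buffer as 32-bit integers doubles the element count and
    # interleaves zeros: the upper 32 bits of each int64 value.
    print(idx64.view(torch.int32))
    # tensor([16044,     0,  8239,     0,  2933,     0], dtype=torch.int32)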

@dgolubovicTT (Contributor, Author)

Makes sense. What do you suggest as the solution? We need to be able to handle int64 as input dtype.

@jnie-TT (Contributor) commented Jan 8, 2025

@dgolubovicTT ttnn runtime doesn't support 64-bit values. We'll just have to cast them to 32-bit values before running anything through runtime.

@dgolubovicTT (Contributor, Author)

So it is up to the user to cast any int64 to int32. @nvukobratTT fyi.

jnie-TT assigned dgolubovicTT and unassigned jnie-TT on Jan 8, 2025
@dgolubovicTT (Contributor, Author)

It doesn't seem right to make this the user's responsibility. @jnie-TT since ttnn doesn't support int64, can we cast the torch tensor from int64 to int32 before we create ttnn runtime tensors?

@jnie-TT (Contributor) commented Jan 9, 2025

@dgolubovicTT yeah, you can cast it before creating runtime tensors, as long as it is not in a 64-bit data format by the time it gets to runtime.
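A minimal sketch of such a frontend-side cast (the helper name and placement are illustrative assumptions, not the actual forge implementation):

    import torch

    # Hypothetical helper: downcast 64-bit integer tensors to int32 before the
    # frontend creates ttnn runtime tensors, since ttnn has no 64-bit format.
    def to_runtime_compatible(t: torch.Tensor) -> torch.Tensor:
        if t.dtype == torch.int64:
            return t.to(torch.int32)
        return t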

@nvukobratTT (Contributor)

> @dgolubovicTT yeah you can cast it before creating runtime tensors, as long as when it gets to runtime it's not in 64 bit data format.

@jnie-TT Is this a valid approach? Shouldn't runtime do the casting if the forwarded activation and the data format defined in the IR don't match? The main reason is that the frontend can't always accurately track how the graph is changing.

Here are the 2 main cases:

  1. The frontend does graph optimization and changes the input data format from the original one.
     In this case, we can track whether the data format changed before generating TTIR, and therefore do the cast in the frontend runtime (before pushing it to the device, for example).
  2. MLIR does graph optimization and changes the input data format from the original one.
     In this case, frontends can't easily track whether something changed from the original graph we generated. For example, if MLIR changes the input data type during the TTIR => TTNN conversion, we won't know that, and therefore we'll still send the tensor in its original format.

With case 2 in mind, the approach that covers both seems to be as follows:

  • The MLIR runtime should compare the forwarded tensor's dtype with the one defined at the IR level, and do a cast if the IR requires it.

Let us know your thoughts.

@jnie-TT (Contributor) commented Jan 10, 2025

@nvukobratTT I don't think I'm following point 2. In my understanding, the FE and runtime see the same thing: the final compiled flatbuffer, which contains the final resulting input/output data formats. The frontend can extract the input/output data formats directly from the flatbuffer and create the tensors. Any recompile/reoptimization will create a new flatbuffer containing any updated data formats. The final data type of any tensor at the IR level should be equally visible to runtime and the FE. This is also what's happening in ttrt.

One problem with runtime doing the typecast implicitly is that the FE needs to create the output tensor beforehand and then memcpy the runtime output tensor into the pre-created output tensor. If runtime were to implicitly typecast to int32, the output data format may end up changing to int32 as well. In this case, the FE would need to override the data format of the pre-created output tensor to ensure the memcpy is valid. Therefore, for better predictability, it might be better for the FE to do the typecast on the input explicitly; this way there are no surprises and the input/output flows are consistent.
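A rough sketch of the explicit flow described above (the helper is hypothetical, not a real forge or tt-mlir API; the expected dtypes are assumed to come from the compiled flatbuffer and are just passed in explicitly here):

    import torch

    # Hypothetical sketch: the frontend casts its inputs to the dtypes recorded
    # in the compiled flatbuffer before invoking runtime, so runtime never casts
    # implicitly and pre-created output tensors keep their expected formats.
    def cast_inputs_to_program_dtypes(torch_inputs, expected_dtypes):
        return [t if t.dtype == dt else t.to(dt)
                for t, dt in zip(torch_inputs, expected_dtypes)]

    # Example: an int64 index tensor cast to int32 before runtime sees it.
    idx = torch.randint(0, 32000, (1, 12))  # torch.randint defaults to int64
    (casted,) = cast_inputs_to_program_dtypes([idx], [torch.int32])
    print(casted.dtype)  # torch.int32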
