From 98afd89d141859053b8e91d8ba5ac63ba34beab7 Mon Sep 17 00:00:00 2001 From: Lewis Panos Date: Mon, 6 Jan 2025 20:14:21 +0000 Subject: [PATCH] Ensure all torch inputs are contiguous for ttrt Modify input tensors to be contgiguous in-place inside `tt_mlir.run` --- tt_torch/csrc/bindings.cpp | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tt_torch/csrc/bindings.cpp b/tt_torch/csrc/bindings.cpp index d79b80a..4ed62af 100644 --- a/tt_torch/csrc/bindings.cpp +++ b/tt_torch/csrc/bindings.cpp @@ -108,7 +108,7 @@ static torch::Tensor create_torch_tensor(const tt::runtime::Tensor &tensor, return torch_tensor; } -std::vector run(const std::vector &inputs, +std::vector run(std::vector &inputs, py::bytes byte_stream) { std::string data_str = byte_stream; @@ -127,6 +127,15 @@ std::vector run(const std::vector &inputs, int program_idx = 0; auto input_descs = binary.getProgramInputs(program_idx); + for (int idx = 0; idx < inputs.size(); idx++) { + if (!inputs[idx].is_contiguous()) { + std::cout << "WARINING: Input " << idx + << " is not contiguous. Converting to contiguous in-place." + << std::endl; + inputs[idx].set_(inputs[idx].contiguous()); + } + } + std::vector rt_inputs; for (auto const &input : inputs) { rt_inputs.emplace_back(create_tensor(input));