Skip to content

Commit

Permalink
Ensure all torch inputs are contiguous for ttrt
Browse files Browse the repository at this point in the history
Modify input tensors to be contiguous in-place inside `tt_mlir.run`
  • Loading branch information
LPanosTT committed Jan 7, 2025
1 parent 3bdfb8b commit 98afd89
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion tt_torch/csrc/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ static torch::Tensor create_torch_tensor(const tt::runtime::Tensor &tensor,
return torch_tensor;
}

std::vector<at::Tensor> run(const std::vector<at::Tensor> &inputs,
std::vector<at::Tensor> run(std::vector<at::Tensor> &inputs,
py::bytes byte_stream) {

std::string data_str = byte_stream;
Expand All @@ -127,6 +127,15 @@ std::vector<at::Tensor> run(const std::vector<at::Tensor> &inputs,
int program_idx = 0;
auto input_descs = binary.getProgramInputs(program_idx);

for (int idx = 0; idx < inputs.size(); idx++) {
if (!inputs[idx].is_contiguous()) {
std::cout << "WARINING: Input " << idx
<< " is not contiguous. Converting to contiguous in-place."
<< std::endl;
inputs[idx].set_(inputs[idx].contiguous());
}
}

std::vector<tt::runtime::Tensor> rt_inputs;
for (auto const &input : inputs) {
rt_inputs.emplace_back(create_tensor(input));
Expand Down

0 comments on commit 98afd89

Please sign in to comment.