Commit 4cef092

Add delay between sending ttir and continuing compilation to allow parent process to read results before potential crash
AleksKnezevic committed Nov 10, 2024
1 parent 4e71fd4 commit 4cef092
Showing 4 changed files with 11 additions and 8 deletions.
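The idea behind the change, sketched as a minimal standalone example (illustrative only — the names `worker`, the message keys, and the 5 s timeout are placeholders, not code from this repository): a child process puts each intermediate result on a queue and sleeps briefly after the put, so the parent can drain the queue even if the next compilation stage crashes the child. This mirrors the tt_torch/dynamo/backend.py diff below.

# Minimal sketch (not project code): a child process streams intermediate
# results over a multiprocessing queue and pauses briefly after each put()
# so the parent can read them even if a later stage crashes the child.
import multiprocessing as mp
import queue
import time

def worker(results):
    results.put({"stage1": "intermediate result"})
    time.sleep(0.1)  # let the parent drain the queue before a possible crash
    # ... a later stage that may hard-crash the process would run here ...
    results.put({"stage2": "final result"})
    time.sleep(0.1)

if __name__ == "__main__":
    results = mp.Queue()
    proc = mp.Process(target=worker, args=(results,))
    proc.start()

    collected = {}
    deadline = time.time() + 5.0  # arbitrary timeout for the sketch
    while time.time() < deadline:
        try:
            collected.update(results.get_nowait())
        except queue.Empty:
            pass
        if "stage2" in collected or (not proc.is_alive() and results.empty()):
            break  # done, or the child died and nothing more is coming
        time.sleep(0.01)

    proc.join()
    print(collected)  # whatever stages made it out before any crash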
tt_torch/csrc/bindings.cpp (3 changes: 2 additions & 1 deletion)

@@ -136,14 +136,15 @@ std::string compile_stable_hlo_to_ttir(std::string_view code) {
   return ret;
 }
 
-std::tuple<py::bytes, std::string_view>
+std::tuple<py::bytes, std::string>
 compile_ttir_to_bytestream(std::string_view code) {
   auto [binary, ttnn] = tt::torch::compileTTIRToTTNN(code);
   auto size = ::flatbuffers::GetSizePrefixedBufferLength(
       static_cast<const uint8_t *>(binary->get()));
 
   std::string data_str(static_cast<const char *>(binary->get()), size);
+  delete binary;
 
   return std::make_tuple(py::bytes(data_str), ttnn);
 }
tt_torch/csrc/tt-mlir-interface.cpp (6 changes: 3 additions & 3 deletions)

@@ -94,7 +94,7 @@ std::string compileStableHLOToTTIR(std::string_view code) {
   return buffer;
 }
 
-std::tuple<std::shared_ptr<void> *, std::string_view>
+std::tuple<std::shared_ptr<void> *, std::string>
 compileTTIRToTTNN(std::string_view code) {
 
   mlir::MLIRContext context;
@@ -141,10 +141,10 @@ compileTTIRToTTNN(std::string_view code) {
 
   std::string buffer;
   llvm::raw_string_ostream os(buffer);
-  mlir_module.get()->print(os);
+  mlir_module->print(os);
   os.flush();
 
-  return std::make_tuple(binary, std::string_view(buffer));
+  return std::make_tuple(binary, buffer);
 }
 
 std::shared_ptr<void> *Compile(std::string_view code) {
tt_torch/csrc/tt-mlir-interface.hpp (2 changes: 1 addition & 1 deletion)

@@ -11,6 +11,6 @@
 namespace tt::torch {
 std::shared_ptr<void> *Compile(std::string_view code);
 std::string compileStableHLOToTTIR(std::string_view code);
-std::tuple<std::shared_ptr<void> *, std::string_view>
+std::tuple<std::shared_ptr<void> *, std::string>
 compileTTIRToTTNN(std::string_view code);
 } // namespace tt::torch
tt_torch/dynamo/backend.py (8 changes: 5 additions & 3 deletions)

@@ -107,6 +107,7 @@ def compile_process(receiver, sender):
     asm = obj["asm"]
     ttir = tt_mlir.compile_stable_hlo_to_ttir(asm)
     sender.put({"ttir": ttir})
+    time.sleep(0.1)
     binary, ttnn = tt_mlir.compile_ttir_to_bytestream(ttir)
     sender.put({"binary": binary, "ttnn": ttnn})
     time.sleep(0.1)
@@ -246,15 +247,16 @@ def compile_op(self, node, *inputs, **kwargs):
         process.start()
         sender.put(obj)
         start = time.time()
-        result = {}
+        binary = None
         while True:
             try:
                 result = receiver.get_nowait()
                 if "ttir" in result:
                     op.compilation_status = OpCompilationStatus.CONVERTED_TO_TTIR
                     op.add_ttir_graph(result["ttir"])
                 if "binary" in result:
-                    op.binary = result["binary"]
+                    binary = result["binary"]
+                    op.binary = binary
                     op.add_ttnn_graph(result["ttnn"])
                     op.compilation_status = OpCompilationStatus.CONVERTED_TO_TTNN
                     break
@@ -267,7 +269,7 @@ def compile_op(self, node, *inputs, **kwargs):
                 break
             time.sleep(0.01)
         process.join()
-        return result["binary"], op
+        return binary, op
 
     def run_op(self, binary, *inputs):
         sender = mp.Queue()
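Note on the compile_op changes above: if the child crashes after sending the TTIR but before sending the binary, the old `return result["binary"], op` would raise a KeyError; initializing `binary = None` and returning it lets the caller detect a partially compiled op. A minimal illustration with a hypothetical helper (not repository code):

# Hypothetical helper, for illustration only: drain queue messages and keep
# the last "binary" seen; None signals that compilation never finished.
def collect_binary(messages):
    binary = None
    for msg in messages:
        binary = msg.get("binary", binary)
    return binary

assert collect_binary([{"ttir": "..."}]) is None  # child crashed after TTIR
assert collect_binary([{"ttir": "..."}, {"binary": b"\x00"}]) == b"\x00"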
