Skip to content

Commit

Permalink
Add mutexes to protect preemptive start/exit of compile thread
Browse files Browse the repository at this point in the history
  • Loading branch information
AleksKnezevic committed Nov 14, 2024
1 parent 5e8317d commit 4cb37e4
Show file tree
Hide file tree
Showing 11 changed files with 36 additions and 29 deletions.
1 change: 0 additions & 1 deletion results/parse_op_by_op_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ def process_json_files():
if result.returncode != 0:
error = result.stderr.split("\n")[0]
trace_dump = result.stderr
print(error)
row_data = [
name,
input_shapes,
Expand Down
6 changes: 3 additions & 3 deletions tests/models/codegen/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ def set_model_eval(self, model):
"mode",
["eval"],
)
@pytest.mark.xfail(
reason="Fails due to pt2 compile issue when finishing generation, but we can still generate a graph"
)
def test_codegen(record_property, mode):
pytest.xfail(
"Fails due to pt2 compile issue when finishing generation, but we can still generate a graph"
)
model_name = "codegen"
record_property("model_name", model_name)
record_property("mode", mode)
Expand Down
6 changes: 3 additions & 3 deletions tests/models/flan_t5/test_flan_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ def set_model_eval(self, model):
"mode",
["eval"],
)
@pytest.mark.xfail(
reason="Fails due to pt2 compile issue when finishing generation, but we can still generate a graph"
)
def test_flan_t5(record_property, mode):
pytest.xfail(
"Fails due to pt2 compile issue when finishing generation, but we can still generate a graph"
)
model_name = "FLAN-T5"
record_property("model_name", model_name)
record_property("mode", mode)
Expand Down
6 changes: 3 additions & 3 deletions tests/models/gpt_neo/test_gpt_neo.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ def set_model_eval(self, model):
"mode",
["eval"],
)
@pytest.mark.xfail(
reason="Fails due to pt2 compile issue when finishing generation, but we can still generate a graph"
)
def test_gpt_neo(record_property, mode):
pytest.xfail(
"Fails due to pt2 compile issue when finishing generation, but we can still generate a graph"
)
model_name = "GPTNeo"
record_property("model_name", model_name)
record_property("mode", mode)
Expand Down
2 changes: 1 addition & 1 deletion tests/models/hand_landmark/test_hand_landmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def set_model_eval(self, model):
"mode",
["eval"],
)
@pytest.mark.xfail(reason="Need to debud")
def test_hand_landmark(record_property, mode):
pytest.xfail("Need to debug")
model_name = "Hand Landmark"
record_property("model_name", model_name)
record_property("mode", mode)
Expand Down
2 changes: 1 addition & 1 deletion tests/models/mobilenet_ssd/test_mobilenet_ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def _load_inputs(self):
"mode",
["eval"],
)
@pytest.mark.xfail(reason="Need to debug")
def test_mobilenet_ssd(record_property, mode):
pytest.xfail("Need to debug")
model_name = "MobileNetSSD"
record_property("model_name", model_name)
record_property("mode", mode)
Expand Down
2 changes: 1 addition & 1 deletion tests/models/opt/test_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def set_model_eval(self, model):
"mode",
["eval"],
)
@pytest.mark.xfail(reason="Need to debug")
def test_opt(record_property, mode):
pytest.xfail("Need to debug")
model_name = "OPT"
record_property("model_name", model_name)
record_property("mode", mode)
Expand Down
6 changes: 3 additions & 3 deletions tests/models/t5/test_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ def set_model_eval(self, model):
"mode",
["eval"],
)
@pytest.mark.xfail(
reason="Fails due to pt2 compile issue when finishing generation, but we can still generate a graph"
)
@pytest.mark.parametrize("model_name", ["t5-small", "t5-base", "t5-large"])
def test_t5(record_property, model_name, mode):
pytest.xfail(
"Fails due to pt2 compile issue when finishing generation, but we can still generate a graph"
)
record_property("model_name", model_name)
record_property("mode", mode)

Expand Down
6 changes: 3 additions & 3 deletions tests/models/whisper/test_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ def set_model_eval(self, model):
"mode",
["eval"],
)
@pytest.mark.xfail(
reason="Fails due to pt2 compile issue when finishing generation, but we can still generate a graph"
)
def test_whisper(record_property, mode):
pytest.xfail(
"Fails due to pt2 compile issue when finishing generation, but we can still generate a graph"
)
model_name = "Whisper"
record_property("model_name", model_name)
record_property("mode", mode)
Expand Down
4 changes: 1 addition & 3 deletions tests/models/yolov5/test_yolov5.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,8 @@ def _load_inputs(self):
"mode",
["eval"],
)
@pytest.mark.xfail(reason="Fails due to pt2 compile issue")
def test_yolov5(record_property, mode):
pytest.xfail(
"Fails due to pt2 compile issue when finishing generation, but we can still generate a graph"
)
model_name = "YOLOv5"
record_property("model_name", model_name)
record_property("mode", mode)
Expand Down
24 changes: 17 additions & 7 deletions tt_torch/dynamo/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,27 +101,27 @@ def lower_to_stable_hlo(module, op=None):
op.compilation_status = OpCompilationStatus.CONVERTED_TO_STABLE_HLO


def compile_process(receiver, sender):
def compile_process(receiver, sender, ttir_event, ttnn_event):
    """Compile a StableHLO module down to a TTNN bytestream in a child process.

    Runs as the target of a ``multiprocessing.Process`` so that a compiler
    crash cannot take down the parent.  Intermediate results are streamed
    back through *sender*; the parent sets *ttir_event* / *ttnn_event* only
    after it has drained the corresponding result, so this process does not
    exit (and tear down the queue's feeder thread) before the parent has
    actually received the data.  This event handshake replaces the earlier
    race-prone ``time.sleep(0.1)`` approach.

    Args:
        receiver: queue delivering ``{"asm": <stablehlo asm string>}``.
        sender: queue used to return ``{"ttir": ...}`` and then
            ``{"binary": ..., "ttnn": ...}`` to the parent.
        ttir_event: set by the parent once the TTIR result was consumed.
        ttnn_event: set by the parent once the binary/TTNN result was consumed.

    Note: renamed the misspelled ``ttir_envent`` parameter; the only visible
    caller passes these arguments positionally, so the rename is safe.
    """
    obj = receiver.get()
    # The child inherits the parent's fault handler; disable it so an
    # anticipated hard crash inside compilation doesn't dump tracebacks.
    faulthandler.disable()
    asm = obj["asm"]
    ttir = tt_mlir.compile_stable_hlo_to_ttir(asm)
    sender.put({"ttir": ttir})
    # Block until the parent confirms it has received the TTIR result.
    ttir_event.wait()
    binary, ttnn = tt_mlir.compile_ttir_to_bytestream(ttir)
    sender.put({"binary": binary, "ttnn": ttnn})
    # Block until the parent confirms receipt before exiting cleanly.
    ttnn_event.wait()
    sys.exit(0)


def execute_process(receiver, sender):
def execute_process(receiver, sender, exec_event):
    """Execute a compiled binary on device inside a child process.

    Spawned via ``multiprocessing.Process`` to isolate the parent from
    runtime crashes.  The parent sets *exec_event* once it has pulled the
    outputs off *sender*, guaranteeing this process does not exit before
    the result has been delivered.

    Args:
        receiver: queue delivering ``{"binary": ..., "inputs": ...}``.
        sender: queue used to hand ``{"outputs": ...}`` back to the parent.
        exec_event: set by the parent after the outputs were consumed.
    """
    request = receiver.get()
    # Suppress inherited fault-handler tracebacks for anticipated crashes.
    faulthandler.disable()
    outputs = tt_mlir.run(request["inputs"], request["binary"])
    sender.put({"outputs": outputs})
    # Don't exit until the parent has confirmed receipt of the outputs.
    exec_event.wait()
    sys.exit(0)


Expand Down Expand Up @@ -242,8 +242,12 @@ def compile_op(self, node, *inputs, **kwargs):

sender = mp.Queue()
receiver = mp.Queue()
ttir_event = mp.Event()
ttnn_event = mp.Event()
obj = {"asm": module.operation.get_asm()}
process = mp.Process(target=compile_process, args=(sender, receiver))
process = mp.Process(
target=compile_process, args=(sender, receiver, ttir_event, ttnn_event)
)
process.start()
sender.put(obj)
start = time.time()
Expand All @@ -254,10 +258,12 @@ def compile_op(self, node, *inputs, **kwargs):
if "ttir" in result:
op.compilation_status = OpCompilationStatus.CONVERTED_TO_TTIR
op.add_ttir_graph(result["ttir"])
ttir_event.set()
if "binary" in result:
binary = result["binary"]
op.binary = binary
op.add_ttnn_graph(result["ttnn"])
ttnn_event.set()
op.compilation_status = OpCompilationStatus.CONVERTED_TO_TTNN
break
except mp.queues.Empty:
Expand All @@ -276,7 +282,10 @@ def run_op(self, binary, *inputs):
receiver = mp.Queue()
obj = {"binary": binary, "inputs": inputs}

process = mp.Process(target=execute_process, args=(sender, receiver))
exec_event = mp.Event()
process = mp.Process(
target=execute_process, args=(sender, receiver, exec_event)
)
process.start()
sender.put(obj)
result = {}
Expand All @@ -286,6 +295,7 @@ def run_op(self, binary, *inputs):
break
try:
result = receiver.get_nowait()
exec_event.set()
break
except mp.queues.Empty:
pass
Expand Down

0 comments on commit 4cb37e4

Please sign in to comment.