Run tt-mlir-opt on generated graph before creating spreadsheet
AleksKnezevic committed Nov 8, 2024
1 parent 8f08af4 commit da46355
Showing 2 changed files with 53 additions and 6 deletions.
env/activate: 2 changes (1 addition, 1 deletion)
@@ -32,7 +32,7 @@ else
fi
export TTTORCH_ENV_ACTIVATED=1
export TTMLIR_ENV_ACTIVATED=1
export PATH=$TTMLIR_TOOLCHAIN_DIR/bin:$PATH
export PATH=$TT_TORCH_HOME/third_party/tt-mlir/src/tt-mlir-build/bin:$TTMLIR_TOOLCHAIN_DIR/bin:$PATH
export TOKENIZERS_PARALLELISM=false
if [ -n "$PROJECT_ROOT" ]; then
export TT_METAL_HOME="$PROJECT_ROOT/third_party/tt-mlir/src/tt-mlir/third_party/tt-metal/src/tt-metal"
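
Note: the PATH change above puts the freshly built tt-mlir binaries ahead of the toolchain install, so the ttmlir-opt calls added to the result parser below resolve to the local build. A minimal sketch (not part of the commit) for checking which binary an activated environment will pick up:

import shutil

# Resolve ttmlir-opt through PATH; after activating the environment this is
# expected to point into third_party/tt-mlir/src/tt-mlir-build/bin.
ttmlir_opt = shutil.which("ttmlir-opt")
if ttmlir_opt is None:
    raise RuntimeError("ttmlir-opt not found on PATH; run 'source env/activate' first")
print(f"ttmlir-opt resolves to: {ttmlir_opt}")
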
results/parse_op_by_op_results.py: 57 changes (52 additions, 5 deletions)
@@ -7,6 +7,8 @@
import csv
import xlsxwriter

import subprocess

# Script to parse the results of the unique ops json files and combine them into a spreadsheet
# This script parses models compiled into stable hlo / TTIR op by op
def find_json_files(directory="results"):
@@ -79,6 +81,10 @@ def process_json_files():
"Status",
"Ops",
"Raw SHLO",
"Raw TTIR",
"Raw TTNNIR",
"Compile Error",
"Trace dump",
)
worksheet.write_row(row, 0, header, bold)
row += 1
@@ -99,6 +105,9 @@ def process_json_files():
"status": value["compilation_status"],
"stable_hlo_graph": value["stable_hlo_graph"],
"ops": value["stable_hlo_ops"],
"ttir_graph": value["ttir_graph"],
"ttnn_graph": value["ttnn_graph"],
"key": key,
}
)
ops_per_model[model_name] = list(torch_ops.keys())
@@ -111,18 +120,44 @@ def process_json_files():
for torch_name, torch_op in sorted(torch_ops.items()):
stable_hlo_ops_per_torch_op[torch_name] = set()
name = torch_name
test_num = 0
for op in torch_op:
num_ops = op["num_ops"]
input_shapes = extract_shape(op["input_shapes"])
output_shapes = extract_shape(op["output_shapes"])
status = op["status"]
raw_shlo = op["stable_hlo_graph"]
ops = op["ops"]
row_data = [name, input_shapes, output_shapes, num_ops, status]
error = ""
trace_dump = ""
if status == 5 or status == 4:
if status == 5:
# Does not compile to TTNNIR, create unit test
test_name = f"{torch_name}_{test_num}.mlir"
test_num += 1
with open(f"results/mlir_tests/ttir/{test_name}", "w") as f:
f.write(op["ttir_graph"])

result = subprocess.run(["ttmlir-opt", "--ttir-to-ttnn-backend-pipeline", f"results/mlir_tests/ttir/{test_name}"], capture_output=True, text=True)
elif status == 4:
# Does not compile to TTIR, create unit test
test_name = f"{torch_name}_{test_num}.mlir"
test_num += 1
with open(f"results/mlir_tests/stable_hlo/{test_name}", "w") as f:
f.write(op["stable_hlo_graph"])

result = subprocess.run(["ttmlir-opt", "--stablehlo-to-ttir-pipeline=enable-remove-dead-values=true", f"results/mlir_tests/stable_hlo/{test_name}"], capture_output=True, text=True)
if result.returncode != 0:
error = result.stderr.split("\n")[0]
trace_dump = result.stderr
print(error)
row_data = [name, input_shapes, output_shapes, num_ops, status, "", raw_shlo, op["ttir_graph"], op["ttnn_graph"], error, trace_dump]
all_ops[op["key"]]["error"] = error
all_ops[op["key"]]["trace_dump"] = trace_dump
worksheet.write_row(row, 0, row_data)
name = ""
row += 1
row_data = ["", "", "", "", "", raw_shlo]
row_data = ["", "", "", "", "", "", raw_shlo]
worksheet.write_row(row, 0, row_data)
worksheet.set_row(row, None, None, {"hidden": True})
row += 1
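
For context, the logic added in the hunk above dispatches on the compilation status: status 4 means the op did not lower from StableHLO to TTIR, so its StableHLO graph is dumped under results/mlir_tests/stable_hlo/ and run through ttmlir-opt's --stablehlo-to-ttir-pipeline; status 5 means TTIR was produced but lowering to TTNN IR failed, so the TTIR graph is dumped under results/mlir_tests/ttir/ and run through --ttir-to-ttnn-backend-pipeline. The first line of stderr is recorded as the compile error and the full stderr as the trace dump. A condensed, self-contained sketch of that pattern (run_ttmlir_opt and the file paths here are illustrative, not part of the commit):

import subprocess

def run_ttmlir_opt(pipeline, test_path):
    # Run ttmlir-opt with the given pipeline on a dumped .mlir unit test and
    # return (first stderr line, full stderr) on failure, empty strings otherwise.
    result = subprocess.run(
        ["ttmlir-opt", pipeline, test_path],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        return result.stderr.split("\n")[0], result.stderr
    return "", ""

# Status 4: StableHLO -> TTIR lowering failed.
error, trace_dump = run_ttmlir_opt(
    "--stablehlo-to-ttir-pipeline=enable-remove-dead-values=true",
    "results/mlir_tests/stable_hlo/aten_add_0.mlir",  # hypothetical test file
)

# Status 5: TTIR -> TTNN backend lowering failed.
error, trace_dump = run_ttmlir_opt(
    "--ttir-to-ttnn-backend-pipeline",
    "results/mlir_tests/ttir/aten_add_0.mlir",  # hypothetical test file
)
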
@@ -150,6 +185,10 @@ def process_json_files():
"Status",
"Ops",
"Raw SHLO",
"Raw TTIR",
"Raw TTNNIR",
"Compile Error",
"Trace dump",
)
worksheet.write_row(row, 0, header, bold)
row += 1
@@ -170,6 +209,10 @@ def process_json_files():
"status": value["compilation_status"],
"stable_hlo_graph": value["stable_hlo_graph"],
"ops": value["stable_hlo_ops"],
"ttir_graph": value["ttir_graph"],
"ttnn_graph": value["ttnn_graph"],
"error": value["error"],
"trace_dump": value["trace_dump"],
}
)

@@ -182,11 +225,15 @@ def process_json_files():
status = op["status"]
raw_shlo = op["stable_hlo_graph"]
ops = op["ops"]
row_data = [name, input_shapes, output_shapes, num_ops, status]
ttir_graph = op["ttir_graph"]
ttnn_graph = op["ttnn_graph"]
error = op["error"]
trace_dump = op["trace_dump"]
row_data = [name, input_shapes, output_shapes, num_ops, status, "", raw_shlo, ttir_graph, ttnn_graph, error, trace_dump]
name = ""
worksheet.write_row(row, 0, row_data)
row += 1
row_data = ["", "", "", "", "", raw_shlo]
row_data = ["", "", "", "", "", "", raw_shlo, ttir_graph, ttnn_graph]
worksheet.write_row(row, 0, row_data)
worksheet.set_row(row, None, None, {"hidden": True})
row += 1
@@ -196,7 +243,7 @@ def process_json_files():
worksheet.set_row(row, None, None, {"hidden": True})
row += 1

worksheet.autofit()
# worksheet.autofit()

ops = list(models_per_op.keys())
ops.sort()
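
One assumption worth calling out: the new unit-test dumps are written with open(..., "w") to results/mlir_tests/ttir/ and results/mlir_tests/stable_hlo/, and open() does not create missing parent directories. If those directories are not created elsewhere in the setup, a small guard along these lines (illustrative only) avoids a FileNotFoundError:

import os

# Make sure the unit-test output directories exist before dumping .mlir files.
for subdir in ("results/mlir_tests/stable_hlo", "results/mlir_tests/ttir"):
    os.makedirs(subdir, exist_ok=True)
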
