Skip to content

Commit

Permalink
Add onnx compile
Browse files Browse the repository at this point in the history
  • Loading branch information
LPanosTT committed Nov 7, 2024
1 parent 6503c70 commit 564c0f4
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 11 deletions.
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@ matplotlib
mlp_mixer_pytorch
opencv-python
xlsxwriter
onnx
onnxruntime
33 changes: 30 additions & 3 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ class Basic(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return x + x
def forward(self, x, y):
return torch.add(x, y)

verify_module(Basic(), [(256, 256)])
verify_module(Basic(), [(1, 1, 32, 32)] * 2)


def test_concat_dim0():
Expand Down Expand Up @@ -353,3 +353,30 @@ def forward(self, x):
cc = CompilerConfig()
cc.compile_depth = tt_torch.tools.utils.CompileDepth.COMPILE_OP_BY_OP
verify_module(Basic(), [(256, 256)], compiler_config=cc, do_assert=False)


def test_onnx():

import onnx

import torch_mlir.dialects.torch
import tt_mlir

from torch_mlir.compiler_utils import (
OutputType,
run_pipeline_with_repro_report,
lower_mlir_module,
)

class Basic(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, y):
return torch.add(x, y)

torch_model = Basic()
torch_input = (torch.randn(1, 1, 32, 32), torch.randn(1, 1, 32, 32))
torch.onnx.export(torch_model, torch_input, "add.onnx")

verify_module("add.onnx")
22 changes: 22 additions & 0 deletions tt_torch/dynamo/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#
# SPDX-License-Identifier: Apache-2.0
import torch
import onnx

from torch._dynamo.backends.common import aot_autograd
from torch.fx.experimental.proxy_tensor import make_fx
from torch._functorch.compile_utils import strip_overloads
Expand All @@ -13,6 +15,8 @@
import tt_mlir
from torch_mlir.ir import Context
from torch_mlir.extras.fx_importer import FxImporter
from torch_mlir.extras import onnx_importer

from torch_mlir.dialects import torch as torch_dialect

from torch_mlir.compiler_utils import (
Expand Down Expand Up @@ -92,6 +96,7 @@ def lower_to_stable_hlo(module, op=None):
f"builtin.module(torchdynamo-export-to-torch-backend-pipeline)",
"Lowering TorchFX IR -> Torch Backend IR",
)
module.dump()
if op is not None:
op.compilation_status = OpCompilationStatus.CONVERTED_TO_TORCH_BACKEND_IR
lower_mlir_module(False, OutputType.STABLEHLO, module)
Expand Down Expand Up @@ -378,3 +383,20 @@ def backend(gm, example_inputs, options=None):


# backend = aot_autograd(fw_compiler=_base_backend)


def compile_onnx_module_to_stablehlo(module: onnx.ModelProto):
context = Context()
torch_dialect.register_dialect(context)
module_info = onnx_importer.ModelInfo(module)
module = module_info.create_module(context=context).operation
imp = onnx_importer.NodeImporter.define_function(module_info.main_graph, module)
imp.import_all()

run_pipeline_with_repro_report(
module,
"builtin.module(torch-onnx-to-torch-backend-pipeline)",
"Lowering Torch Onnx IR -> Torch Backend IR",
)
lower_mlir_module(False, OutputType.STABLEHLO, module)
return module
121 changes: 113 additions & 8 deletions tt_torch/tools/verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,22 @@
#
# SPDX-License-Identifier: Apache-2.0
import torch
import onnx
from onnxruntime import InferenceSession
import numpy as np
from tt_torch.dynamo.backend import backend
import tt_mlir
from tt_torch.dynamo.backend import backend, compile_onnx_module_to_stablehlo


def verify_module(
def _verify_torch_module(
mod,
input_shapes,
input_data_types=[torch.float32],
required_pcc=0.99,
required_atol=1e-2,
input_range=(-0.5, 0.5),
compiler_config=None,
do_assert=True,
input_data_types,
required_pcc,
required_atol,
input_range,
compiler_config,
do_assert,
):
tt_mod = torch.compile(mod, backend=backend, options=compiler_config)

Expand Down Expand Up @@ -43,3 +46,105 @@ def verify_module(
)
)
assert (do_assert and pcc) >= required_pcc, f"PCC too low: {pcc} vs {required_pcc}"


def _verify_onnx_module(
filename,
input_data_types,
required_pcc,
required_atol,
input_range,
compiler_config,
do_assert,
):

sess = InferenceSession(filename)
input_shapes = [nodearg.shape for nodearg in sess.get_inputs()]

if all([dtype.is_floating_point for dtype in input_data_types]):
low, high = input_range
# Uniformly distribute random numbers within the input_range
inputs = [(low - high) * torch.rand(shape) + high for shape in input_shapes]
else:
inputs = [
torch.randint(0, 1000, shape, dtype=torch.int32) for shape in input_shapes
]

inputs_dict = {
nodearg.name: input.numpy().astype(np.float32)
if input.dtype == torch.bfloat16
else input.numpy()
for nodearg, input in zip(sess.get_inputs(), inputs)
}
golden = sess.run(None, inputs_dict)

for i in range(len(golden)):
golden[i] = torch.tensor(golden[i])

mod = onnx.load(filename)
stablehlo_mod = compile_onnx_module_to_stablehlo(mod)
binary = tt_mlir.compile(stablehlo_mod.operation.get_asm())

ret = tt_mlir.run(inputs, binary)
assert len(golden) == len(
ret
), f"Number of outputs mismatch between golden and compiled: {len(golden)} vs {len(ret)}"

for golden_out, tt_out in zip(golden, ret):
atol = torch.max(torch.abs(golden_out - tt_out)).item()
assert (
do_assert and atol
) <= required_atol, f"ATOL too high: {atol} vs {required_atol}"

if np.prod(golden_out.shape) == 1:
return
pcc = np.min(
np.ma.corrcoef(
np.ma.masked_invalid(torch.squeeze(tt_out).detach().numpy()).flatten(),
np.ma.masked_invalid(
torch.squeeze(golden_out).detach().numpy()
).flatten(),
)
)
assert (
do_assert and pcc
) >= required_pcc, f"PCC too low: {pcc} vs {required_pcc}"


def verify_module(
mod,
input_shapes=None,
input_data_types=[torch.float32],
required_pcc=0.99,
required_atol=1e-2,
input_range=(-0.5, 0.5),
compiler_config=None,
do_assert=True,
):
if isinstance(mod, torch.nn.Module):
assert (
input_shapes is not None
), "Verifying a torch module requires that you provide input_shapes"
_verify_torch_module(
mod,
input_shapes,
input_data_types,
required_pcc,
required_atol,
input_range,
compiler_config,
do_assert,
)
elif isinstance(mod, str) and mod.endswith(".onnx"):
assert (
input_shapes is None
), "When verifying an ONNX module, input_shapes must be None as they are inferred from the ONNX model"
_verify_onnx_module(
mod,
input_data_types,
required_pcc,
required_atol,
input_range,
compiler_config,
do_assert,
)

0 comments on commit 564c0f4

Please sign in to comment.