diff --git a/requirements.txt b/requirements.txt index e43b5376..75563cd7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,3 +22,5 @@ matplotlib mlp_mixer_pytorch opencv-python xlsxwriter +onnx +onnxruntime diff --git a/tests/test_basic.py b/tests/test_basic.py index c9e36ea7..1657a6af 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -26,10 +26,10 @@ class Basic(nn.Module): def __init__(self): super().__init__() - def forward(self, x): - return x + x + def forward(self, x, y): + return torch.add(x + y) - verify_module(Basic(), [(256, 256)]) + verify_module(Basic(), [(1, 1, 32, 32)] * 2) def test_concat_dim0(): @@ -353,3 +353,30 @@ def forward(self, x): cc = CompilerConfig() cc.compile_depth = tt_torch.tools.utils.CompileDepth.COMPILE_OP_BY_OP verify_module(Basic(), [(256, 256)], compiler_config=cc, do_assert=False) + + +def test_onnx(): + + import onnx + + import torch_mlir.dialects.torch + import tt_mlir + + from torch_mlir.compiler_utils import ( + OutputType, + run_pipeline_with_repro_report, + lower_mlir_module, + ) + + class Basic(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.add(x, y) + + torch_model = Basic() + torch_input = (torch.randn(1, 1, 32, 32), torch.randn(1, 1, 32, 32)) + torch.onnx.export(torch_model, torch_input, "add.onnx") + + verify_module("add.onnx") diff --git a/tt_torch/dynamo/backend.py b/tt_torch/dynamo/backend.py index 8ad9f749..89151e78 100644 --- a/tt_torch/dynamo/backend.py +++ b/tt_torch/dynamo/backend.py @@ -2,6 +2,8 @@ # # SPDX-License-Identifier: Apache-2.0 import torch +import onnx + from torch._dynamo.backends.common import aot_autograd from torch.fx.experimental.proxy_tensor import make_fx from torch._functorch.compile_utils import strip_overloads @@ -13,6 +15,8 @@ import tt_mlir from torch_mlir.ir import Context from torch_mlir.extras.fx_importer import FxImporter +from torch_mlir.extras import onnx_importer + from torch_mlir.dialects import torch as torch_dialect from torch_mlir.compiler_utils import ( @@ -92,6 +96,7 @@ def lower_to_stable_hlo(module, op=None): f"builtin.module(torchdynamo-export-to-torch-backend-pipeline)", "Lowering TorchFX IR -> Torch Backend IR", ) + module.dump() if op is not None: op.compilation_status = OpCompilationStatus.CONVERTED_TO_TORCH_BACKEND_IR lower_mlir_module(False, OutputType.STABLEHLO, module) @@ -378,3 +383,22 @@ def backend(gm, example_inputs, options=None): # backend = aot_autograd(fw_compiler=_base_backend) + + +def compile_onnx_module_to_stablehlo(module: onnx.ModelProto): + context = Context() + torch_dialect.register_dialect(context) + module_info = onnx_importer.ModelInfo(module) + module = module_info.create_module(context=context).operation + imp = onnx_importer.NodeImporter.define_function( + module_info.main_graph, module + ) + imp.import_all() + + run_pipeline_with_repro_report( + module, + "builtin.module(torch-onnx-to-torch-backend-pipeline)", + "Lowering Torch Onnx IR -> Torch Backend IR", + ) + lower_mlir_module(False, OutputType.STABLEHLO, module) + return module diff --git a/tt_torch/tools/verify.py b/tt_torch/tools/verify.py index 93bde8be..d2f7f409 100644 --- a/tt_torch/tools/verify.py +++ b/tt_torch/tools/verify.py @@ -2,19 +2,22 @@ # # SPDX-License-Identifier: Apache-2.0 import torch +import onnx +from onnxruntime import InferenceSession import numpy as np -from tt_torch.dynamo.backend import backend +import tt_mlir +from tt_torch.dynamo.backend import backend, compile_onnx_module_to_stablehlo -def verify_module( +def _verify_torch_module( mod, input_shapes, - input_data_types=[torch.float32], - required_pcc=0.99, - required_atol=1e-2, - input_range=(-0.5, 0.5), - compiler_config=None, - do_assert=True, + input_data_types, + required_pcc, + required_atol, + input_range, + compiler_config, + do_assert, ): tt_mod = torch.compile(mod, backend=backend, options=compiler_config) @@ -43,3 +46,105 @@ def verify_module( ) ) assert (do_assert and pcc) >= required_pcc, f"PCC too low: {pcc} vs {required_pcc}" + + +def _verify_onnx_module( + filename, + input_data_types, + required_pcc, + required_atol, + input_range, + compiler_config, + do_assert, +): + + sess = InferenceSession(filename) + input_shapes = [nodearg.shape for nodearg in sess.get_inputs()] + + if all([dtype.is_floating_point for dtype in input_data_types]): + low, high = input_range + # Uniformly distribute random numbers within the input_range + inputs = [(low - high) * torch.rand(shape) + high for shape in input_shapes] + else: + inputs = [ + torch.randint(0, 1000, shape, dtype=torch.int32) for shape in input_shapes + ] + + inputs_dict = { + nodearg.name: input.numpy().astype(np.float32) + if input.dtype == torch.bfloat16 + else input.numpy() + for nodearg, input in zip(sess.get_inputs(), inputs) + } + golden = sess.run(None, inputs_dict) + + for i in range(len(golden)): + golden[i] = torch.tensor(golden[i]) + + mod = onnx.load(filename) + stablehlo_mod = compile_onnx_module_to_stablehlo(mod) + binary = tt_mlir.compile(stablehlo_mod.operation.get_asm()) + + ret = tt_mlir.run(inputs, binary) + assert len(golden) == len( + ret + ), f"Number of outputs mismatch between golden and compiled: {len(golden)} vs {len(ret)}" + + for golden_out, tt_out in zip(golden, ret): + atol = torch.max(torch.abs(golden_out - tt_out)).item() + assert ( + do_assert and atol + ) <= required_atol, f"ATOL too high: {atol} vs {required_atol}" + + if np.prod(golden_out.shape) == 1: + return + pcc = np.min( + np.ma.corrcoef( + np.ma.masked_invalid(torch.squeeze(tt_out).detach().numpy()).flatten(), + np.ma.masked_invalid( + torch.squeeze(golden_out).detach().numpy() + ).flatten(), + ) + ) + assert ( + do_assert and pcc + ) >= required_pcc, f"PCC too low: {pcc} vs {required_pcc}" + + +def verify_module( + mod, + input_shapes=None, + input_data_types=[torch.float32], + required_pcc=0.99, + required_atol=1e-2, + input_range=(-0.5, 0.5), + compiler_config=None, + do_assert=True, +): + if isinstance(mod, torch.nn.Module): + assert ( + input_shapes is not None + ), "Verifying a torch module requires that you provide input_shapes" + _verify_torch_module( + mod, + input_shapes, + input_data_types, + required_pcc, + required_atol, + input_range, + compiler_config, + do_assert, + ) + elif isinstance(mod, str) and mod.endswith(".onnx"): + assert ( + input_shapes is None + ), "When verifying an ONNX module, input_shapes must be None as they are inferred from the ONNX model" + _verify_onnx_module( + mod, + input_data_types, + required_pcc, + required_atol, + input_range, + compiler_config, + do_assert, + )