From 7edf47caa3efb4f402c399db7cf83ba8699c42cd Mon Sep 17 00:00:00 2001 From: Lewis Panos Date: Thu, 7 Nov 2024 20:16:59 +0000 Subject: [PATCH] Add onnx compile --- .github/workflows/build-and-test.yml | 11 ++- requirements.txt | 2 + tests/onnx/test_basic.py | 25 ++++++ tests/{ => torch}/test_basic.py | 6 +- tt_torch/dynamo/backend.py | 2 + tt_torch/onnx_compile/__init__.py | 4 + tt_torch/onnx_compile/onnx_compile.py | 31 +++++++ tt_torch/tools/verify.py | 121 ++++++++++++++++++++++++-- 8 files changed, 190 insertions(+), 12 deletions(-) create mode 100644 tests/onnx/test_basic.py rename tests/{ => torch}/test_basic.py (98%) create mode 100644 tt_torch/onnx_compile/__init__.py create mode 100644 tt_torch/onnx_compile/onnx_compile.py diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index 629a5048..18dd6d61 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -144,9 +144,16 @@ jobs: cmake --build ${{ steps.strings.outputs.build-output-dir }} cmake --install ${{ steps.strings.outputs.build-output-dir }} - - name: Run tests + - name: Run PyTorch tests shell: bash run: | export LD_LIBRARY_PATH="/opt/ttmlir-toolchain/lib/:${{ steps.strings.outputs.install-output-dir }}/lib:${LD_LIBRARY_PATH}" source env/activate - pytest -v tests --ignore=tests/models + pytest -v tests/torch + + - name: Run ONNX tests + shell: bash + run: | + export LD_LIBRARY_PATH="/opt/ttmlir-toolchain/lib/:${{ steps.strings.outputs.install-output-dir }}/lib:${LD_LIBRARY_PATH}" + source env/activate + pytest -v tests/onnx diff --git a/requirements.txt b/requirements.txt index e43b5376..75563cd7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,3 +22,5 @@ matplotlib mlp_mixer_pytorch opencv-python xlsxwriter +onnx +onnxruntime diff --git a/tests/onnx/test_basic.py b/tests/onnx/test_basic.py new file mode 100644 index 00000000..b4fededd --- /dev/null +++ b/tests/onnx/test_basic.py @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 +import torch +from torch import nn +import pytest + +import tt_torch +from tt_torch.tools.verify import verify_module +from tt_torch.tools.utils import CompilerConfig + + +def test_add(): + class Basic(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.add(x, y) + + torch_model = Basic() + torch_input = (torch.randn(256, 256), torch.randn(256, 256)) + torch.onnx.export(torch_model, torch_input, "add.onnx") + + verify_module("add.onnx") diff --git a/tests/test_basic.py b/tests/torch/test_basic.py similarity index 98% rename from tests/test_basic.py rename to tests/torch/test_basic.py index c9e36ea7..1bec508d 100644 --- a/tests/test_basic.py +++ b/tests/torch/test_basic.py @@ -26,10 +26,10 @@ class Basic(nn.Module): def __init__(self): super().__init__() - def forward(self, x): - return x + x + def forward(self, x, y): + return torch.add(x, y) - verify_module(Basic(), [(256, 256)]) + verify_module(Basic(), [(256, 256)] * 2) def test_concat_dim0(): diff --git a/tt_torch/dynamo/backend.py b/tt_torch/dynamo/backend.py index 8ad9f749..5fdbd5db 100644 --- a/tt_torch/dynamo/backend.py +++ b/tt_torch/dynamo/backend.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 import torch + from torch._dynamo.backends.common import aot_autograd from torch.fx.experimental.proxy_tensor import make_fx from torch._functorch.compile_utils import strip_overloads @@ -13,6 +14,7 @@ import tt_mlir from torch_mlir.ir import Context from torch_mlir.extras.fx_importer import FxImporter + from torch_mlir.dialects import torch as torch_dialect from torch_mlir.compiler_utils import ( diff --git a/tt_torch/onnx_compile/__init__.py b/tt_torch/onnx_compile/__init__.py new file mode 100644 index 00000000..cfa2cc48 --- /dev/null +++ b/tt_torch/onnx_compile/__init__.py @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 +from .onnx_compile import compile_onnx diff --git a/tt_torch/onnx_compile/onnx_compile.py b/tt_torch/onnx_compile/onnx_compile.py new file mode 100644 index 00000000..2ecc195c --- /dev/null +++ b/tt_torch/onnx_compile/onnx_compile.py @@ -0,0 +1,31 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 +import onnx +from torch_mlir.extras import onnx_importer +import tt_mlir +from torch_mlir.ir import Context +from torch_mlir.dialects import torch as torch_dialect + +from torch_mlir.compiler_utils import ( + OutputType, + run_pipeline_with_repro_report, + lower_mlir_module, +) + + +def compile_onnx(module: onnx.ModelProto): + context = Context() + torch_dialect.register_dialect(context) + module_info = onnx_importer.ModelInfo(module) + module = module_info.create_module(context=context).operation + imp = onnx_importer.NodeImporter.define_function(module_info.main_graph, module) + imp.import_all() + + run_pipeline_with_repro_report( + module, + "builtin.module(torch-onnx-to-torch-backend-pipeline)", + "Lowering Torch Onnx IR -> Torch Backend IR", + ) + lower_mlir_module(False, OutputType.STABLEHLO, module) + return tt_mlir.compile(module.operation.get_asm()) diff --git a/tt_torch/tools/verify.py b/tt_torch/tools/verify.py index 93bde8be..fc275143 100644 --- a/tt_torch/tools/verify.py +++ b/tt_torch/tools/verify.py @@ -2,19 +2,23 @@ # # SPDX-License-Identifier: Apache-2.0 import torch +import onnx +from onnxruntime import InferenceSession import numpy as np +import tt_mlir +from tt_torch.onnx_compile import compile_onnx from tt_torch.dynamo.backend import backend -def verify_module( +def _verify_torch_module( mod, input_shapes, - input_data_types=[torch.float32], - required_pcc=0.99, - required_atol=1e-2, - input_range=(-0.5, 0.5), - compiler_config=None, - do_assert=True, + input_data_types, + required_pcc, + required_atol, + input_range, + compiler_config, + do_assert, ): tt_mod = torch.compile(mod, backend=backend, options=compiler_config) @@ -43,3 +47,106 @@ def verify_module( ) ) assert (do_assert and pcc) >= required_pcc, f"PCC too low: {pcc} vs {required_pcc}" + + +def _verify_onnx_module( + filename, + input_data_types, + required_pcc, + required_atol, + input_range, + compiler_config, + do_assert, +): + + sess = InferenceSession(filename) + input_shapes = [nodearg.shape for nodearg in sess.get_inputs()] + + if all([dtype.is_floating_point for dtype in input_data_types]): + low, high = input_range + # Uniformly distribute random numbers within the input_range + inputs = [(low - high) * torch.rand(shape) + high for shape in input_shapes] + else: + inputs = [ + torch.randint(0, 1000, shape, dtype=torch.int32) for shape in input_shapes + ] + + inputs_dict = { + nodearg.name: input.numpy().astype(np.float32) + if input.dtype == torch.bfloat16 + else input.numpy() + for nodearg, input in zip(sess.get_inputs(), inputs) + } + golden = sess.run(None, inputs_dict) + + for i in range(len(golden)): + golden[i] = torch.tensor(golden[i]) + + mod = onnx.load(filename) + binary = compile_onnx(mod) + + ret = tt_mlir.run(inputs, binary) + assert len(golden) == len( + ret + ), f"Number of outputs mismatch between golden and compiled: {len(golden)} vs {len(ret)}" + + for golden_out, tt_out in zip(golden, ret): + atol = torch.max(torch.abs(golden_out - tt_out)).item() + assert ( + do_assert and atol + ) <= required_atol, f"ATOL too high: {atol} vs {required_atol}" + + if np.prod(golden_out.shape) == 1: + return + pcc = np.min( + np.ma.corrcoef( + np.ma.masked_invalid(torch.squeeze(tt_out).detach().numpy()).flatten(), + np.ma.masked_invalid( + torch.squeeze(golden_out).detach().numpy() + ).flatten(), + ) + ) + assert ( + do_assert and pcc + ) >= required_pcc, f"PCC too low: {pcc} vs {required_pcc}" + + +def verify_module( + mod, + input_shapes=None, + input_data_types=[torch.float32], + required_pcc=0.99, + required_atol=1e-2, + input_range=(-0.5, 0.5), + compiler_config=None, + do_assert=True, +): + if isinstance(mod, torch.nn.Module): + assert ( + input_shapes is not None + ), "Verifying a torch module requires that you provide input_shapes" + _verify_torch_module( + mod, + input_shapes, + input_data_types, + required_pcc, + required_atol, + input_range, + compiler_config, + do_assert, + ) + elif isinstance(mod, str) and mod.endswith(".onnx"): + assert ( + input_shapes is None + ), "When verifying an ONNX module, input_shapes must be None as they are inferred from the ONNX model" + _verify_onnx_module( + mod, + input_data_types, + required_pcc, + required_atol, + input_range, + compiler_config, + do_assert, + ) + else: + raise ValueError("Invalid module type")