Skip to content

Commit

Permalink
Add onnx compile
Browse files Browse the repository at this point in the history
  • Loading branch information
LPanosTT committed Nov 8, 2024
1 parent 0ea020e commit 7edf47c
Show file tree
Hide file tree
Showing 8 changed files with 190 additions and 12 deletions.
11 changes: 9 additions & 2 deletions .github/workflows/build-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,16 @@ jobs:
cmake --build ${{ steps.strings.outputs.build-output-dir }}
cmake --install ${{ steps.strings.outputs.build-output-dir }}
- name: Run tests
- name: Run PyTorch tests
shell: bash
run: |
export LD_LIBRARY_PATH="/opt/ttmlir-toolchain/lib/:${{ steps.strings.outputs.install-output-dir }}/lib:${LD_LIBRARY_PATH}"
source env/activate
pytest -v tests --ignore=tests/models
pytest -v tests/torch
- name: Run ONNX tests
shell: bash
run: |
export LD_LIBRARY_PATH="/opt/ttmlir-toolchain/lib/:${{ steps.strings.outputs.install-output-dir }}/lib:${LD_LIBRARY_PATH}"
source env/activate
pytest -v tests/onnx
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@ matplotlib
mlp_mixer_pytorch
opencv-python
xlsxwriter
onnx
onnxruntime
25 changes: 25 additions & 0 deletions tests/onnx/test_basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
import torch
from torch import nn
import pytest

import tt_torch
from tt_torch.tools.verify import verify_module
from tt_torch.tools.utils import CompilerConfig


def test_add():
class Basic(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, y):
return torch.add(x, y)

torch_model = Basic()
torch_input = (torch.randn(256, 256), torch.randn(256, 256))
torch.onnx.export(torch_model, torch_input, "add.onnx")

verify_module("add.onnx")
6 changes: 3 additions & 3 deletions tests/test_basic.py → tests/torch/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ class Basic(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return x + x
def forward(self, x, y):
return torch.add(x, y)

verify_module(Basic(), [(256, 256)])
verify_module(Basic(), [(256, 256)] * 2)


def test_concat_dim0():
Expand Down
2 changes: 2 additions & 0 deletions tt_torch/dynamo/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0
import torch

from torch._dynamo.backends.common import aot_autograd
from torch.fx.experimental.proxy_tensor import make_fx
from torch._functorch.compile_utils import strip_overloads
Expand All @@ -13,6 +14,7 @@
import tt_mlir
from torch_mlir.ir import Context
from torch_mlir.extras.fx_importer import FxImporter

from torch_mlir.dialects import torch as torch_dialect

from torch_mlir.compiler_utils import (
Expand Down
4 changes: 4 additions & 0 deletions tt_torch/onnx_compile/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
from .onnx_compile import compile_onnx
31 changes: 31 additions & 0 deletions tt_torch/onnx_compile/onnx_compile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
import onnx
from torch_mlir.extras import onnx_importer
import tt_mlir
from torch_mlir.ir import Context
from torch_mlir.dialects import torch as torch_dialect

from torch_mlir.compiler_utils import (
OutputType,
run_pipeline_with_repro_report,
lower_mlir_module,
)


def compile_onnx(module: onnx.ModelProto):
context = Context()
torch_dialect.register_dialect(context)
module_info = onnx_importer.ModelInfo(module)
module = module_info.create_module(context=context).operation
imp = onnx_importer.NodeImporter.define_function(module_info.main_graph, module)
imp.import_all()

run_pipeline_with_repro_report(
module,
"builtin.module(torch-onnx-to-torch-backend-pipeline)",
"Lowering Torch Onnx IR -> Torch Backend IR",
)
lower_mlir_module(False, OutputType.STABLEHLO, module)
return tt_mlir.compile(module.operation.get_asm())
121 changes: 114 additions & 7 deletions tt_torch/tools/verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,23 @@
#
# SPDX-License-Identifier: Apache-2.0
import torch
import onnx
from onnxruntime import InferenceSession
import numpy as np
import tt_mlir
from tt_torch.onnx_compile import compile_onnx
from tt_torch.dynamo.backend import backend


def verify_module(
def _verify_torch_module(
mod,
input_shapes,
input_data_types=[torch.float32],
required_pcc=0.99,
required_atol=1e-2,
input_range=(-0.5, 0.5),
compiler_config=None,
do_assert=True,
input_data_types,
required_pcc,
required_atol,
input_range,
compiler_config,
do_assert,
):
tt_mod = torch.compile(mod, backend=backend, options=compiler_config)

Expand Down Expand Up @@ -43,3 +47,106 @@ def verify_module(
)
)
assert (do_assert and pcc) >= required_pcc, f"PCC too low: {pcc} vs {required_pcc}"


def _verify_onnx_module(
filename,
input_data_types,
required_pcc,
required_atol,
input_range,
compiler_config,
do_assert,
):

sess = InferenceSession(filename)
input_shapes = [nodearg.shape for nodearg in sess.get_inputs()]

if all([dtype.is_floating_point for dtype in input_data_types]):
low, high = input_range
# Uniformly distribute random numbers within the input_range
inputs = [(low - high) * torch.rand(shape) + high for shape in input_shapes]
else:
inputs = [
torch.randint(0, 1000, shape, dtype=torch.int32) for shape in input_shapes
]

inputs_dict = {
nodearg.name: input.numpy().astype(np.float32)
if input.dtype == torch.bfloat16
else input.numpy()
for nodearg, input in zip(sess.get_inputs(), inputs)
}
golden = sess.run(None, inputs_dict)

for i in range(len(golden)):
golden[i] = torch.tensor(golden[i])

mod = onnx.load(filename)
binary = compile_onnx(mod)

ret = tt_mlir.run(inputs, binary)
assert len(golden) == len(
ret
), f"Number of outputs mismatch between golden and compiled: {len(golden)} vs {len(ret)}"

for golden_out, tt_out in zip(golden, ret):
atol = torch.max(torch.abs(golden_out - tt_out)).item()
assert (
do_assert and atol
) <= required_atol, f"ATOL too high: {atol} vs {required_atol}"

if np.prod(golden_out.shape) == 1:
return
pcc = np.min(
np.ma.corrcoef(
np.ma.masked_invalid(torch.squeeze(tt_out).detach().numpy()).flatten(),
np.ma.masked_invalid(
torch.squeeze(golden_out).detach().numpy()
).flatten(),
)
)
assert (
do_assert and pcc
) >= required_pcc, f"PCC too low: {pcc} vs {required_pcc}"


def verify_module(
mod,
input_shapes=None,
input_data_types=[torch.float32],
required_pcc=0.99,
required_atol=1e-2,
input_range=(-0.5, 0.5),
compiler_config=None,
do_assert=True,
):
if isinstance(mod, torch.nn.Module):
assert (
input_shapes is not None
), "Verifying a torch module requires that you provide input_shapes"
_verify_torch_module(
mod,
input_shapes,
input_data_types,
required_pcc,
required_atol,
input_range,
compiler_config,
do_assert,
)
elif isinstance(mod, str) and mod.endswith(".onnx"):
assert (
input_shapes is None
), "When verifying an ONNX module, input_shapes must be None as they are inferred from the ONNX model"
_verify_onnx_module(
mod,
input_data_types,
required_pcc,
required_atol,
input_range,
compiler_config,
do_assert,
)
else:
raise ValueError("Invalid module type")

0 comments on commit 7edf47c

Please sign in to comment.