Skip to content

Commit

Permalink
fix formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
LPanosTT committed Nov 7, 2024
1 parent 39ddc2a commit d4bb9e7
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 21 deletions.
8 changes: 4 additions & 4 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self):
def forward(self, x, y):
return torch.add(x + y)

verify_module(Basic(), [(1, 1, 32, 32)]*2)
verify_module(Basic(), [(1, 1, 32, 32)] * 2)


def test_concat_dim0():
Expand Down Expand Up @@ -354,10 +354,11 @@ def forward(self, x):
cc.compile_depth = tt_torch.tools.utils.CompileDepth.COMPILE_OP_BY_OP
verify_module(Basic(), [(256, 256)], compiler_config=cc, do_assert=False)


def test_onnx():

import onnx

import torch_mlir.dialects.torch
import tt_mlir

Expand All @@ -366,15 +367,14 @@ def test_onnx():
run_pipeline_with_repro_report,
lower_mlir_module,
)

class Basic(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, y):
return torch.add(x, y)


torch_model = Basic()
torch_input = (torch.randn(1, 1, 32, 32), torch.randn(1, 1, 32, 32))
torch.onnx.export(torch_model, torch_input, "add.onnx")
Expand Down
8 changes: 4 additions & 4 deletions tt_torch/dynamo/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,12 +384,15 @@ def backend(gm, example_inputs, options=None):

# backend = aot_autograd(fw_compiler=_base_backend)


def generate_torch_onnx_module(module: onnx.ModelProto):
context = Context()
torch_dialect.register_dialect(context)
module_info = onnx_importer.ModelInfo(module)
torch_onnx_mod = module_info.create_module(context=context).operation
imp = onnx_importer.NodeImporter.define_function(module_info.main_graph, torch_onnx_mod)
imp = onnx_importer.NodeImporter.define_function(
module_info.main_graph, torch_onnx_mod
)
imp.import_all()

run_pipeline_with_repro_report(
Expand All @@ -399,6 +402,3 @@ def generate_torch_onnx_module(module: onnx.ModelProto):
)
lower_mlir_module(False, OutputType.STABLEHLO, torch_onnx_mod)
return torch_onnx_mod



41 changes: 28 additions & 13 deletions tt_torch/tools/verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import tt_mlir
from tt_torch.dynamo.backend import backend, generate_torch_onnx_module


def _verify_torch_module(
mod,
input_shapes,
Expand Down Expand Up @@ -46,6 +47,7 @@ def _verify_torch_module(
)
assert (do_assert and pcc) >= required_pcc, f"PCC too low: {pcc} vs {required_pcc}"


def _verify_onnx_module(
filename,
input_data_types,
Expand All @@ -54,11 +56,11 @@ def _verify_onnx_module(
input_range,
compiler_config,
do_assert,
):
):

sess = InferenceSession(filename)
input_shapes = [nodearg.shape for nodearg in sess.get_inputs()]

if all([dtype.is_floating_point for dtype in input_data_types]):
low, high = input_range
# Uniformly distribute random numbers within the input_range
Expand All @@ -67,9 +69,13 @@ def _verify_onnx_module(
inputs = [
torch.randint(0, 1000, shape, dtype=torch.int32) for shape in input_shapes
]


inputs_dict = {nodearg.name: input.numpy().astype(np.float32) if input.dtype == torch.bfloat16 else input.numpy() for nodearg, input in zip(sess.get_inputs(), inputs)}

inputs_dict = {
nodearg.name: input.numpy().astype(np.float32)
if input.dtype == torch.bfloat16
else input.numpy()
for nodearg, input in zip(sess.get_inputs(), inputs)
}
golden = sess.run(None, inputs_dict)

for i in range(len(golden)):
Expand All @@ -80,7 +86,9 @@ def _verify_onnx_module(
binary = tt_mlir.compile(stablehlo_mod.operation.get_asm())

ret = tt_mlir.run(inputs, binary)
assert len(golden) == len(ret), f"Number of outputs mismatch between golden and compiled: {len(golden)} vs {len(ret)}"
assert len(golden) == len(
ret
), f"Number of outputs mismatch between golden and compiled: {len(golden)} vs {len(ret)}"

for golden_out, tt_out in zip(golden, ret):
atol = torch.max(torch.abs(golden_out - tt_out)).item()
Expand All @@ -93,11 +101,14 @@ def _verify_onnx_module(
pcc = np.min(
np.ma.corrcoef(
np.ma.masked_invalid(torch.squeeze(tt_out).detach().numpy()).flatten(),
np.ma.masked_invalid(torch.squeeze(golden_out).detach().numpy()).flatten(),
np.ma.masked_invalid(
torch.squeeze(golden_out).detach().numpy()
).flatten(),
)
)
assert (do_assert and pcc) >= required_pcc, f"PCC too low: {pcc} vs {required_pcc}"

assert (
do_assert and pcc
) >= required_pcc, f"PCC too low: {pcc} vs {required_pcc}"


def verify_module(
Expand All @@ -111,7 +122,9 @@ def verify_module(
do_assert=True,
):
if isinstance(mod, torch.nn.Module):
assert input_shapes is not None, "Verifying a torch module requires that you provide input_shapes"
assert (
input_shapes is not None
), "Verifying a torch module requires that you provide input_shapes"
_verify_torch_module(
mod,
input_shapes,
Expand All @@ -123,7 +136,9 @@ def verify_module(
do_assert,
)
elif isinstance(mod, str) and mod.endswith(".onnx"):
assert input_shapes is None, "When verifying an ONNX module, input_shapes must be None as they are inferred from the ONNX model"
assert (
input_shapes is None
), "When verifying an ONNX module, input_shapes must be None as they are inferred from the ONNX model"
_verify_onnx_module(
mod,
input_data_types,
Expand All @@ -132,4 +147,4 @@ def verify_module(
input_range,
compiler_config,
do_assert,
)
)

0 comments on commit d4bb9e7

Please sign in to comment.