diff --git a/tests/test_basic.py b/tests/test_basic.py index 6885f430..1657a6af 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -29,7 +29,7 @@ def __init__(self): def forward(self, x, y): return torch.add(x + y) - verify_module(Basic(), [(1, 1, 32, 32)]*2) + verify_module(Basic(), [(1, 1, 32, 32)] * 2) def test_concat_dim0(): @@ -354,10 +354,11 @@ def forward(self, x): cc.compile_depth = tt_torch.tools.utils.CompileDepth.COMPILE_OP_BY_OP verify_module(Basic(), [(256, 256)], compiler_config=cc, do_assert=False) + def test_onnx(): import onnx - + import torch_mlir.dialects.torch import tt_mlir @@ -366,7 +367,7 @@ def test_onnx(): run_pipeline_with_repro_report, lower_mlir_module, ) - + class Basic(nn.Module): def __init__(self): super().__init__() @@ -374,7 +375,6 @@ def __init__(self): def forward(self, x, y): return torch.add(x, y) - torch_model = Basic() torch_input = (torch.randn(1, 1, 32, 32), torch.randn(1, 1, 32, 32)) torch.onnx.export(torch_model, torch_input, "add.onnx") diff --git a/tt_torch/dynamo/backend.py b/tt_torch/dynamo/backend.py index 3a280c4b..07d81a92 100644 --- a/tt_torch/dynamo/backend.py +++ b/tt_torch/dynamo/backend.py @@ -384,12 +384,15 @@ def backend(gm, example_inputs, options=None): # backend = aot_autograd(fw_compiler=_base_backend) + def generate_torch_onnx_module(module: onnx.ModelProto): context = Context() torch_dialect.register_dialect(context) module_info = onnx_importer.ModelInfo(module) torch_onnx_mod = module_info.create_module(context=context).operation - imp = onnx_importer.NodeImporter.define_function(module_info.main_graph, torch_onnx_mod) + imp = onnx_importer.NodeImporter.define_function( + module_info.main_graph, torch_onnx_mod + ) imp.import_all() run_pipeline_with_repro_report( @@ -399,6 +402,3 @@ def generate_torch_onnx_module(module: onnx.ModelProto): ) lower_mlir_module(False, OutputType.STABLEHLO, torch_onnx_mod) return torch_onnx_mod - - - diff --git a/tt_torch/tools/verify.py b/tt_torch/tools/verify.py index 9ae4f73e..99c7ddc6 100644 --- a/tt_torch/tools/verify.py +++ b/tt_torch/tools/verify.py @@ -8,6 +8,7 @@ import tt_mlir from tt_torch.dynamo.backend import backend, generate_torch_onnx_module + def _verify_torch_module( mod, input_shapes, @@ -46,6 +47,7 @@ def _verify_torch_module( ) assert (do_assert and pcc) >= required_pcc, f"PCC too low: {pcc} vs {required_pcc}" + def _verify_onnx_module( filename, input_data_types, @@ -54,11 +56,11 @@ def _verify_onnx_module( input_range, compiler_config, do_assert, -): - +): + sess = InferenceSession(filename) input_shapes = [nodearg.shape for nodearg in sess.get_inputs()] - + if all([dtype.is_floating_point for dtype in input_data_types]): low, high = input_range # Uniformly distribute random numbers within the input_range @@ -67,9 +69,13 @@ def _verify_onnx_module( inputs = [ torch.randint(0, 1000, shape, dtype=torch.int32) for shape in input_shapes ] - - - inputs_dict = {nodearg.name: input.numpy().astype(np.float32) if input.dtype == torch.bfloat16 else input.numpy() for nodearg, input in zip(sess.get_inputs(), inputs)} + + inputs_dict = { + nodearg.name: input.numpy().astype(np.float32) + if input.dtype == torch.bfloat16 + else input.numpy() + for nodearg, input in zip(sess.get_inputs(), inputs) + } golden = sess.run(None, inputs_dict) for i in range(len(golden)): @@ -80,7 +86,9 @@ def _verify_onnx_module( binary = tt_mlir.compile(stablehlo_mod.operation.get_asm()) ret = tt_mlir.run(inputs, binary) - assert len(golden) == len(ret), f"Number of outputs mismatch between golden and compiled: {len(golden)} vs {len(ret)}" + assert len(golden) == len( + ret + ), f"Number of outputs mismatch between golden and compiled: {len(golden)} vs {len(ret)}" for golden_out, tt_out in zip(golden, ret): atol = torch.max(torch.abs(golden_out - tt_out)).item() @@ -93,11 +101,14 @@ def _verify_onnx_module( pcc = np.min( np.ma.corrcoef( np.ma.masked_invalid(torch.squeeze(tt_out).detach().numpy()).flatten(), - np.ma.masked_invalid(torch.squeeze(golden_out).detach().numpy()).flatten(), + np.ma.masked_invalid( + torch.squeeze(golden_out).detach().numpy() + ).flatten(), ) ) - assert (do_assert and pcc) >= required_pcc, f"PCC too low: {pcc} vs {required_pcc}" - + assert ( + do_assert and pcc + ) >= required_pcc, f"PCC too low: {pcc} vs {required_pcc}" def verify_module( @@ -111,7 +122,9 @@ def verify_module( do_assert=True, ): if isinstance(mod, torch.nn.Module): - assert input_shapes is not None, "Verifying a torch module requires that you provide input_shapes" + assert ( + input_shapes is not None + ), "Verifying a torch module requires that you provide input_shapes" _verify_torch_module( mod, input_shapes, @@ -123,7 +136,9 @@ def verify_module( do_assert, ) elif isinstance(mod, str) and mod.endswith(".onnx"): - assert input_shapes is None, "When verifying an ONNX module, input_shapes must be None as they are inferred from the ONNX model" + assert ( + input_shapes is None + ), "When verifying an ONNX module, input_shapes must be None as they are inferred from the ONNX model" _verify_onnx_module( mod, input_data_types, @@ -132,4 +147,4 @@ def verify_module( input_range, compiler_config, do_assert, - ) \ No newline at end of file + )