From d0b0da80d61b76c5bf4d864ba49293e4b9cf01d6 Mon Sep 17 00:00:00 2001 From: Muhammad Asif Manzoor Date: Thu, 21 Nov 2024 16:51:12 +0000 Subject: [PATCH] Update testing infrastructure to generate boolean inputs. --- tt_torch/tools/verify.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/tt_torch/tools/verify.py b/tt_torch/tools/verify.py index b444ae16..5b0f5589 100644 --- a/tt_torch/tools/verify.py +++ b/tt_torch/tools/verify.py @@ -18,6 +18,7 @@ def _verify_torch_module( required_pcc, required_atol, input_range, + input_range_int, compiler_config, do_assert, ): @@ -35,9 +36,14 @@ def _verify_torch_module( (low - high) * torch.rand(shape, dtype=dtype) + high for shape, dtype in zip(input_shapes, input_data_types) ] + elif all([dtype == torch.bool for dtype in input_data_types]): + inputs = [ + torch.randint(0, 2, shape, dtype=torch.bool) for shape in input_shapes + ] else: + low, high = input_range_int inputs = [ - torch.randint(0, 1000, shape, dtype=torch.int32) + torch.randint(low, high, shape, dtype=torch.int32) for shape in input_shapes ] @@ -84,6 +90,7 @@ def _verify_onnx_module( required_pcc, required_atol, input_range, + input_range_int, compiler_config, do_assert, ): @@ -103,8 +110,9 @@ def _verify_onnx_module( for shape, dtype in zip(input_shapes, input_data_types) ] else: + low, high = input_range_int inputs = [ - torch.randint(0, 1000, shape, dtype=torch.int64) + torch.randint(low, high, shape, dtype=torch.int64) for shape in input_shapes ] @@ -163,6 +171,7 @@ def verify_module( required_pcc=0.99, required_atol=1e-2, input_range=(-0.5, 0.5), + input_range_int=(0, 1000), compiler_config=None, do_assert=True, ): @@ -179,6 +188,7 @@ def verify_module( required_pcc, required_atol, input_range, + input_range_int, compiler_config, do_assert, ) @@ -193,6 +203,7 @@ def verify_module( required_pcc, required_atol, input_range, + input_range_int, compiler_config, do_assert, )