Skip to content

Commit

Permalink
Update testing infrastructure to generate boolean inputs.
Browse files Browse the repository at this point in the history
  • Loading branch information
mmanzoorTT committed Nov 21, 2024
1 parent af95f51 commit d0b0da8
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions tt_torch/tools/verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def _verify_torch_module(
required_pcc,
required_atol,
input_range,
input_range_int,
compiler_config,
do_assert,
):
Expand All @@ -35,9 +36,14 @@ def _verify_torch_module(
(low - high) * torch.rand(shape, dtype=dtype) + high
for shape, dtype in zip(input_shapes, input_data_types)
]
elif all([dtype == torch.bool for dtype in input_data_types]):
inputs = [
torch.randint(0, 2, shape, dtype=torch.bool) for shape in input_shapes
]
else:
low, high = input_range_int
inputs = [
torch.randint(0, 1000, shape, dtype=torch.int32)
torch.randint(low, high, shape, dtype=torch.int32)
for shape in input_shapes
]

Expand Down Expand Up @@ -84,6 +90,7 @@ def _verify_onnx_module(
required_pcc,
required_atol,
input_range,
input_range_int,
compiler_config,
do_assert,
):
Expand All @@ -103,8 +110,9 @@ def _verify_onnx_module(
for shape, dtype in zip(input_shapes, input_data_types)
]
else:
low, high = input_range_int
inputs = [
torch.randint(0, 1000, shape, dtype=torch.int64)
torch.randint(low, high, shape, dtype=torch.int64)
for shape in input_shapes
]

Expand Down Expand Up @@ -163,6 +171,7 @@ def verify_module(
required_pcc=0.99,
required_atol=1e-2,
input_range=(-0.5, 0.5),
input_range_int=(0, 1000),
compiler_config=None,
do_assert=True,
):
Expand All @@ -179,6 +188,7 @@ def verify_module(
required_pcc,
required_atol,
input_range,
input_range_int,
compiler_config,
do_assert,
)
Expand All @@ -193,6 +203,7 @@ def verify_module(
required_pcc,
required_atol,
input_range,
input_range_int,
compiler_config,
do_assert,
)
Expand Down

0 comments on commit d0b0da8

Please sign in to comment.