Skip to content

Commit

Permalink
Enable unit test for div and remainder op
Browse files Browse the repository at this point in the history
* Tests are working after tt-metal fix.
* tenstorrent/tt-metal#16250
  • Loading branch information
mmanzoorTT committed Jan 6, 2025
1 parent bad537d commit 8262f18
Showing 1 changed file with 24 additions and 41 deletions.
65 changes: 24 additions & 41 deletions tests/torch/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,7 @@ def forward(self, x):
@pytest.mark.parametrize(
("input_range", "input_shapes", "input_type"),
[
pytest.param(
(-0.5, 0.5),
[(2, 2), (2, 2)],
[torch.float32, torch.float32],
marks=pytest.mark.xfail(
reason="Fails due to https://github.com/tenstorrent/tt-torch/issues/147"
),
),
((-0.5, 0.5), [(2, 2), (2, 2)], [torch.float32, torch.float32]),
((1, 10), [(32, 32), (32, 32)], [torch.bfloat16, torch.bfloat16]),
((1, 10), [(32, 32), (32, 32)], [torch.float32, torch.float32]),
],
Expand Down Expand Up @@ -422,40 +415,14 @@ def forward(self, x):
@pytest.mark.parametrize(
("input_range", "input_shapes", "input_type"),
[
pytest.param(
(1, 10),
[(32, 32), (32, 32)],
[torch.float32, torch.float32],
marks=pytest.mark.xfail(
reason="Fails due to https://github.com/tenstorrent/tt-torch/issues/147"
),
),
((1, 10), [(32, 32), (32, 32)], [torch.bfloat16, torch.bfloat16]),
((1, 10), [(32, 32), (32, 32)], [torch.float32, torch.float32]),
pytest.param(
(1, 10),
[(3, 3), (3, 3)],
[torch.float32, torch.float32],
marks=pytest.mark.skip(
reason="Fails due to https://github.com/tenstorrent/tt-metal/issues/15131"
),
),
pytest.param(
(1, 100),
[(32, 32), (32, 32)],
[torch.float32, torch.float32],
marks=pytest.mark.skip(
reason="Fails due to https://github.com/tenstorrent/tt-metal/issues/15130"
),
),
pytest.param(
(-100, 100),
[(32, 32), (32, 32)],
[torch.float32, torch.float32],
marks=pytest.mark.skip(
reason="Fails due to https://github.com/tenstorrent/tt-metal/issues/15130"
),
),
((1, 10), [(32, 32), (32, 32)], [torch.bfloat16, torch.bfloat16]),
((1, 10), [(3, 3), (3, 3)], [torch.float32, torch.float32]),
((1, 100), [(32, 32), (32, 32)], [torch.float32, torch.float32]),
        # This set of parameters can fail when we generate a right-hand operand
        # that contains a 0. TTNN returns the LHS operand instead of NaN in
        # such cases. Issue: https://github.com/tenstorrent/tt-metal/issues/16394
((-100, 100), [(32, 32), (32, 32)], [torch.float32, torch.float32]),
],
)
def test_remainder_op(input_range, input_shapes, input_type):
Expand All @@ -473,3 +440,19 @@ def forward(self, x, y):
input_range=input_range,
required_atol=1,
)


@pytest.mark.xfail(
    reason="TTNN returns LHS operand instead of NaN if divisor is 0, see https://github.com/tenstorrent/tt-metal/issues/16394",
)
def test_remainder_op_zero():
    """Check torch.remainder when the divisor tensor contains zeros.

    PyTorch yields NaN for a zero divisor; the TTNN backend currently
    returns the LHS operand instead, hence the xfail above.
    """

    class Basic(nn.Module):
        def forward(self, x, y):
            # Elementwise remainder; positions where y == 0 produce NaN on CPU.
            return torch.remainder(x, y)

    # Divisor deliberately includes zeros (indices 1 and 3) to hit the
    # divide-by-zero path tracked by the linked tt-metal issue.
    dividend = torch.tensor([1.0, -2.0, 3.0, -4.0, 5.0, 6.0, -7.0, 18.0], dtype=torch.float32)
    divisor = torch.tensor([1.0, 0.0, 3.0, 0.0, -9.0, 10.0, -7.0, 12.0], dtype=torch.float32)
    verify_module(Basic(), inputs=[dividend, divisor])

0 comments on commit 8262f18

Please sign in to comment.