Skip to content

Commit

Permalink
Add test for logical operations
Browse files Browse the repository at this point in the history
* Logical and
* Logical or
* Logical not
* Logical xor
  • Loading branch information
mmanzoorTT committed Nov 21, 2024
1 parent d0b0da8 commit 8ed2f62
Showing 1 changed file with 69 additions and 0 deletions.
69 changes: 69 additions & 0 deletions tests/torch/test_logical.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
import torch
from torch import nn
import pytest

import tt_torch
from tt_torch.tools.verify import verify_module


def test_and():
class Basic(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, y):
return torch.logical_and(x, y)

verify_module(
Basic(),
input_shapes=[(64, 64), (64, 64)],
input_data_types=[torch.bool],
)


def test_not():
class Basic(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.logical_not(x)

verify_module(
Basic(),
input_shapes=[(64, 64)],
input_data_types=[torch.bool],
)


def test_or():
class Basic(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, y):
return torch.logical_or(x, y)

verify_module(
Basic(),
input_shapes=[(64, 64), (64, 64)],
input_data_types=[torch.bool],
)


def test_xor():
class Basic(nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, y):
return torch.logical_xor(x, y)

verify_module(
Basic(),
input_shapes=[(64, 64), (64, 64)],
input_data_types=[torch.bool],
)

0 comments on commit 8ed2f62

Please sign in to comment.