Skip to content

Commit

Permalink
Fixed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
LPanosTT committed Oct 29, 2024
1 parent 7fb08e8 commit 57a76c4
Showing 1 changed file with 34 additions and 8 deletions.
42 changes: 34 additions & 8 deletions test/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,23 +117,49 @@ def forward(self, x):

verify_module(Basic(), [(32, 32)], required_atol=3e-2, input_range=(0.1, 1))

@pytest.mark.parametrize("begin_W", torch.arange(64).tolist())
@pytest.mark.parametrize("end_W", torch.arange(64, 128).tolist())
@pytest.mark.parametrize("dim", [0, 1, 2, 3])
def test_slice(begin_W, end_W, dim):
dim0_cases = []
for begin in torch.arange(10).tolist():
for end in torch.arange(90, 100).tolist():
dim0_cases.append((begin, end, 0))

dim1_cases = []
for begin in torch.arange(10).tolist():
for end in torch.arange(90, 100).tolist():
dim1_cases.append((begin, end, 1))

dim2_cases = []
for begin in torch.arange(0, 64, 32).tolist():
for end in torch.arange(64, 128, 32).tolist():
dim2_cases.append((begin, end, 2))

dim3_cases = []
for begin in torch.arange(0, 64, 32).tolist():
for end in torch.arange(64, 128, 32).tolist():
dim3_cases.append((begin, end, 3))

@pytest.mark.parametrize(
"begin, end, dim",
[
*dim2_cases,
*dim3_cases,
*dim0_cases,
*dim1_cases
]
)
def test_slice(begin, end, dim):
class Basic(nn.Module):
def __init__(self):
super().__init__()

def forward(self, a):
if dim == 0:
return a[begin_W:end_W, :, :, :]
return a[begin:end, :, :, :]
elif dim == 1:
return a[:, begin_W:end_W, :, :]
return a[:, begin:end, :, :]
elif dim == 2:
return a[:, :, begin_W:end_W, :]
return a[:, :, begin:end, :]
else:
return a[:, :, :, begin_W:end_W]
return a[:, :, :, begin:end]

shape = [10, 10, 10, 10]
shape[dim] = 128
Expand Down

0 comments on commit 57a76c4

Please sign in to comment.