Skip to content

Commit

Permalink
Different types of indexing in pytorch (#976)
Browse files Browse the repository at this point in the history
Exploring different types of indexing in pytorch and looking at what is
supported through `tvm`/`forge-fe`/`mlir`/`ttnn`. More info about
results of testing can be found
[here](https://docs.google.com/document/d/1SoMZWKsplNIRXx01HDeWQWJ2NXhyQ0VHYZ9o7Pxr2h8/edit?usp=sharing).
  • Loading branch information
vkovinicTT authored Jan 13, 2025
1 parent 1db1f92 commit 64a3200
Show file tree
Hide file tree
Showing 17 changed files with 4,356 additions and 2,080 deletions.
252 changes: 252 additions & 0 deletions forge/test/mlir/operators/eltwise_binary/test_eltwise_binary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
from torch import nn

import forge
from forge.verify.verify import verify
from forge.verify.config import VerifyConfig


@pytest.mark.parametrize(
    "shape_x, shape_y",
    [
        ((1, 128, 28, 28), (1, 128, 28, 28)),
        ((1, 64, 28, 28), (1, 64, 28, 28)),
        ((1, 256, 28, 28), (1, 256, 28, 28)),
        ((1, 128, 14, 14), (1, 128, 14, 14)),
        ((1, 128, 56, 56), (1, 128, 56, 56)),
        ((1, 32, 64, 64), (1, 32, 64, 64)),
        ((1, 512, 7, 7), (1, 512, 7, 7)),
        ((1, 32, 32, 32), (1, 32, 32, 32)),
    ],
)
@pytest.mark.push
def test_less(shape_x, shape_y):
    """Compile an elementwise less-than module and verify it against PyTorch."""

    class Less(nn.Module):
        def forward(self, x, y):
            # `x < y` dispatches to the same op as torch.less / torch.lt.
            return x < y

    lhs = torch.rand(shape_x)
    rhs = torch.rand(shape_y)
    sample_inputs = [lhs, rhs]

    model = Less()
    compiled = forge.compile(model, sample_inputs=sample_inputs)

    # Comparison ops yield bool tensors, so dtype verification is disabled.
    verify(sample_inputs, model, compiled, VerifyConfig(verify_dtype=False))


@pytest.mark.parametrize(
    "shape_x, shape_y",
    [
        ((1, 128, 28, 28), (1, 128, 28, 28)),
        ((1, 64, 28, 28), (1, 64, 28, 28)),
        ((1, 256, 28, 28), (1, 256, 28, 28)),
        ((1, 128, 14, 14), (1, 128, 14, 14)),
        ((1, 128, 56, 56), (1, 128, 56, 56)),
        ((1, 32, 64, 64), (1, 32, 64, 64)),
        ((1, 512, 7, 7), (1, 512, 7, 7)),
        ((1, 32, 32, 32), (1, 32, 32, 32)),
    ],
)
@pytest.mark.push
def test_greater(shape_x, shape_y):
    """Compile an elementwise greater-than module and verify it against PyTorch."""

    class Greater(nn.Module):
        def forward(self, x, y):
            # `x > y` dispatches to the same op as torch.greater / torch.gt.
            return x > y

    lhs = torch.rand(shape_x)
    rhs = torch.rand(shape_y)
    sample_inputs = [lhs, rhs]

    model = Greater()
    compiled = forge.compile(model, sample_inputs=sample_inputs)

    # Comparison ops yield bool tensors, so dtype verification is disabled.
    verify(sample_inputs, model, compiled, VerifyConfig(verify_dtype=False))


@pytest.mark.parametrize(
    "shape_x, shape_y",
    [
        ((1, 128, 28, 28), (1, 128, 28, 28)),
        ((1, 64, 28, 28), (1, 64, 28, 28)),
        ((1, 256, 28, 28), (1, 256, 28, 28)),
        ((1, 128, 14, 14), (1, 128, 14, 14)),
        ((1, 128, 56, 56), (1, 128, 56, 56)),
        ((1, 32, 64, 64), (1, 32, 64, 64)),
        ((1, 512, 7, 7), (1, 512, 7, 7)),
        ((1, 32, 32, 32), (1, 32, 32, 32)),
    ],
)
@pytest.mark.push
def test_not_equal(shape_x, shape_y):
    """Compile an elementwise not-equal module and verify it against PyTorch."""

    class NotEqual(nn.Module):
        def forward(self, x, y):
            # `x != y` dispatches to the same op as torch.ne.
            return x != y

    lhs = torch.rand(shape_x)
    rhs = torch.rand(shape_y)
    sample_inputs = [lhs, rhs]

    model = NotEqual()
    compiled = forge.compile(model, sample_inputs=sample_inputs)

    # Comparison ops yield bool tensors, so dtype verification is disabled.
    verify(sample_inputs, model, compiled, VerifyConfig(verify_dtype=False))


@pytest.mark.parametrize(
    "shape",
    [
        (1, 128, 28, 28),
        (1, 64, 28, 28),
        (1, 256, 28, 28),
        (1, 128, 14, 14),
        (1, 128, 56, 56),
        (1, 32, 64, 64),
        (1, 512, 7, 7),
        (1, 32, 32, 32),
        (128, 28, 28),
        (64, 28, 28),
        (256, 28, 28),
        (128, 14, 14),
        (128, 56, 56),
        (32, 64, 64),
        (512, 7, 7),
        (32, 32, 32),
        (128, 28),
        (64, 28),
        (256, 28),
        (128, 14),
        (128, 56),
        (32, 64),
        (512, 7),
        (32, 32),
    ],
)
@pytest.mark.push
def test_equal(shape):
    """Compile an elementwise equality module and verify it against PyTorch.

    The second operand is a scaled copy of the first, so elements compare
    equal only where the original value is exactly zero.
    """

    class Equal(nn.Module):
        def forward(self, x, y):
            # `x == y` dispatches to the same op as torch.eq.
            return x == y

    base = torch.rand(shape)
    scaled = 2.0 * base
    sample_inputs = [base, scaled]

    model = Equal()
    compiled = forge.compile(model, sample_inputs=sample_inputs)

    # Comparison ops yield bool tensors, so dtype verification is disabled.
    verify(sample_inputs, model, compiled, VerifyConfig(verify_dtype=False))


@pytest.mark.push
def test_add():
    """Compile an elementwise addition module and verify it against PyTorch."""

    class Add(nn.Module):
        def forward(self, a, b):
            # torch.add is the functional form of `a + b`.
            return torch.add(a, b)

    shape = (2, 32, 32)
    sample_inputs = [torch.rand(shape), torch.rand(shape)]

    model = Add()
    compiled = forge.compile(model, sample_inputs=sample_inputs)

    verify(sample_inputs, model, compiled)


@pytest.mark.parametrize("dims", [(1, 32, 64), (6, 33), (4, 16, 17)])
@pytest.mark.push
def test_greater_equal(dims):
    """Compile an elementwise greater-or-equal module and verify it against PyTorch."""

    class GreaterEqual(nn.Module):
        def forward(self, a, b):
            # `a >= b` dispatches to the same op as torch.greater_equal / torch.ge.
            return a >= b

    sample_inputs = [torch.rand(dims), torch.rand(dims)]

    model = GreaterEqual()
    compiled = forge.compile(model, sample_inputs=sample_inputs)

    # Comparison ops yield bool tensors, so dtype verification is disabled.
    verify(sample_inputs, model, compiled, VerifyConfig(verify_dtype=False))


@pytest.mark.push
def test_subtract():
    """Compile an elementwise subtraction module and verify it against PyTorch."""

    class Subtract(nn.Module):
        def forward(self, a, b):
            # torch.sub is the functional form of `a - b`.
            return torch.sub(a, b)

    shape = (1, 32, 32)
    sample_inputs = [torch.rand(shape), torch.rand(shape)]

    model = Subtract()
    compiled = forge.compile(model, sample_inputs=sample_inputs)

    verify(sample_inputs, model, compiled, VerifyConfig(verify_dtype=False))


@pytest.mark.parametrize(
    "shape",
    [
        (1, 32, 32),
        (12, 8640),
    ],
)
@pytest.mark.push
def test_multiply(shape):
    """Compile an elementwise multiplication module and verify it against PyTorch."""

    class Multiply(nn.Module):
        def forward(self, a, b):
            # torch.mul is the functional form of `a * b`.
            return torch.mul(a, b)

    sample_inputs = [torch.rand(shape), torch.rand(shape)]

    model = Multiply()
    compiled = forge.compile(model, sample_inputs=sample_inputs)

    verify(sample_inputs, model, compiled)


@pytest.mark.push
def test_remainder():
    """Compile an elementwise remainder module and verify it against PyTorch."""

    class Remainder(nn.Module):
        def forward(self, a, b):
            # torch.remainder is the functional form of `a % b`.
            return torch.remainder(a, b)

    shape = (2, 32, 32)
    sample_inputs = [torch.rand(shape), torch.rand(shape)]

    model = Remainder()
    compiled = forge.compile(model, sample_inputs=sample_inputs)

    verify(sample_inputs, model, compiled)
74 changes: 74 additions & 0 deletions forge/test/mlir/operators/eltwise_nary/test_eltwise_nary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
from torch import nn

import forge
from forge.verify.verify import verify


@pytest.mark.parametrize(
    "condition, input, other",
    [
        (
            [[1, 0], [0, 1]],
            [[1, 2], [3, 4]],
            [[10, 20], [30, 40]],
        ),
    ],
)
@pytest.mark.xfail(reason="Unsupported data format during lowering from TTForge to TTIR: Bfp2_b")
@pytest.mark.push
def test_where(condition, input, other):
    """Compile torch.where (select by boolean mask) and verify it against PyTorch."""

    class Where(nn.Module):
        def forward(self, condition, input1, input2):
            return torch.where(condition, input1, input2)

    # The mask must be a bool tensor; the branch values use torch's default
    # dtype inference for the given Python lists.
    mask = torch.tensor(condition, dtype=torch.bool)
    sample_inputs = [mask, torch.tensor(input), torch.tensor(other)]

    model = Where()
    compiled = forge.compile(model, sample_inputs=sample_inputs)

    verify(sample_inputs, model, compiled)


@pytest.mark.parametrize(
    "inputs_and_dim",
    [
        ((2, 2, 32, 32), (2, 2, 32, 32), 0),
        ((2, 2, 32, 32), (2, 2, 32, 32), 1),
        ((2, 2, 32, 32), (2, 2, 32, 32), 2),
        ((2, 2, 32, 32), (2, 2, 32, 32), 3),
        ((2, 2, 32, 32), (2, 2, 32, 32), -1),
        ((2, 2, 32, 32), (2, 2, 32, 32), -2),
        ((2, 2, 32, 32), (2, 2, 32, 32), -3),
        ((2, 2, 32, 32), (2, 2, 32, 32), -4),
    ],
    ids=["0", "1", "2", "3", "-1", "-2", "-3", "-4"],
)
@pytest.mark.push
def test_concat(inputs_and_dim):
    """Compile torch.cat of two equally-shaped tensors and verify it against PyTorch.

    Covers every positive and negative axis of a rank-4 tensor.
    """
    shape_a, shape_b, axis = inputs_and_dim

    class Concat(nn.Module):
        def forward(self, a, b):
            # `axis` is captured from the enclosing test parameters.
            return torch.cat((a, b), axis)

    sample_inputs = [torch.rand(shape_a), torch.rand(shape_b)]

    model = Concat()
    compiled = forge.compile(model, sample_inputs=sample_inputs)

    verify(sample_inputs, model, compiled)
Loading

0 comments on commit 64a3200

Please sign in to comment.