-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Different types of indexing in pytorch (#976)
Exploring different types of indexing in pytorch and looking at what is supported through `tvm`/`forge-fe`/`mlir`/`ttnn`. More info about results of testing can be found [here](https://docs.google.com/document/d/1SoMZWKsplNIRXx01HDeWQWJ2NXhyQ0VHYZ9o7Pxr2h8/edit?usp=sharing).
- Loading branch information
1 parent
1db1f92
commit 64a3200
Showing
17 changed files
with
4,356 additions
and
2,080 deletions.
There are no files selected for viewing
252 changes: 252 additions & 0 deletions
252
forge/test/mlir/operators/eltwise_binary/test_eltwise_binary.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,252 @@ | ||
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import pytest | ||
import torch | ||
from torch import nn | ||
|
||
import forge | ||
from forge.verify.verify import verify | ||
from forge.verify.config import VerifyConfig | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"shape_x, shape_y", | ||
[ | ||
((1, 128, 28, 28), (1, 128, 28, 28)), | ||
((1, 64, 28, 28), (1, 64, 28, 28)), | ||
((1, 256, 28, 28), (1, 256, 28, 28)), | ||
((1, 128, 14, 14), (1, 128, 14, 14)), | ||
((1, 128, 56, 56), (1, 128, 56, 56)), | ||
((1, 32, 64, 64), (1, 32, 64, 64)), | ||
((1, 512, 7, 7), (1, 512, 7, 7)), | ||
((1, 32, 32, 32), (1, 32, 32, 32)), | ||
], | ||
) | ||
@pytest.mark.push | ||
def test_less(shape_x, shape_y): | ||
class Less(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, x, y): | ||
return torch.less(x, y) | ||
|
||
x = torch.rand(shape_x) | ||
y = torch.rand(shape_y) | ||
|
||
inputs = [x, y] | ||
|
||
framework_model = Less() | ||
compiled_model = forge.compile(framework_model, sample_inputs=inputs) | ||
|
||
verify(inputs, framework_model, compiled_model, VerifyConfig(verify_dtype=False)) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"shape_x, shape_y", | ||
[ | ||
((1, 128, 28, 28), (1, 128, 28, 28)), | ||
((1, 64, 28, 28), (1, 64, 28, 28)), | ||
((1, 256, 28, 28), (1, 256, 28, 28)), | ||
((1, 128, 14, 14), (1, 128, 14, 14)), | ||
((1, 128, 56, 56), (1, 128, 56, 56)), | ||
((1, 32, 64, 64), (1, 32, 64, 64)), | ||
((1, 512, 7, 7), (1, 512, 7, 7)), | ||
((1, 32, 32, 32), (1, 32, 32, 32)), | ||
], | ||
) | ||
@pytest.mark.push | ||
def test_greater(shape_x, shape_y): | ||
class Greater(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, x, y): | ||
return torch.greater(x, y) | ||
|
||
x = torch.rand(shape_x) | ||
y = torch.rand(shape_y) | ||
|
||
inputs = [x, y] | ||
|
||
framework_model = Greater() | ||
compiled_model = forge.compile(framework_model, sample_inputs=inputs) | ||
|
||
verify(inputs, framework_model, compiled_model, VerifyConfig(verify_dtype=False)) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"shape_x, shape_y", | ||
[ | ||
((1, 128, 28, 28), (1, 128, 28, 28)), | ||
((1, 64, 28, 28), (1, 64, 28, 28)), | ||
((1, 256, 28, 28), (1, 256, 28, 28)), | ||
((1, 128, 14, 14), (1, 128, 14, 14)), | ||
((1, 128, 56, 56), (1, 128, 56, 56)), | ||
((1, 32, 64, 64), (1, 32, 64, 64)), | ||
((1, 512, 7, 7), (1, 512, 7, 7)), | ||
((1, 32, 32, 32), (1, 32, 32, 32)), | ||
], | ||
) | ||
@pytest.mark.push | ||
def test_not_equal(shape_x, shape_y): | ||
class NotEqual(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, x, y): | ||
return torch.ne(x, y) | ||
|
||
x = torch.rand(shape_x) | ||
y = torch.rand(shape_y) | ||
|
||
inputs = [x, y] | ||
|
||
framework_model = NotEqual() | ||
compiled_model = forge.compile(framework_model, sample_inputs=inputs) | ||
|
||
verify(inputs, framework_model, compiled_model, VerifyConfig(verify_dtype=False)) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"shape", | ||
[ | ||
(1, 128, 28, 28), | ||
(1, 64, 28, 28), | ||
(1, 256, 28, 28), | ||
(1, 128, 14, 14), | ||
(1, 128, 56, 56), | ||
(1, 32, 64, 64), | ||
(1, 512, 7, 7), | ||
(1, 32, 32, 32), | ||
(128, 28, 28), | ||
(64, 28, 28), | ||
(256, 28, 28), | ||
(128, 14, 14), | ||
(128, 56, 56), | ||
(32, 64, 64), | ||
(512, 7, 7), | ||
(32, 32, 32), | ||
(128, 28), | ||
(64, 28), | ||
(256, 28), | ||
(128, 14), | ||
(128, 56), | ||
(32, 64), | ||
(512, 7), | ||
(32, 32), | ||
], | ||
) | ||
@pytest.mark.push | ||
def test_equal(shape): | ||
class Equal(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, x, y): | ||
return torch.eq(x, y) | ||
|
||
x = torch.rand(shape) | ||
y = x * 2.0 | ||
|
||
inputs = [x, y] | ||
|
||
framework_model = Equal() | ||
compiled_model = forge.compile(framework_model, sample_inputs=inputs) | ||
|
||
verify(inputs, framework_model, compiled_model, VerifyConfig(verify_dtype=False)) | ||
|
||
|
||
@pytest.mark.push | ||
def test_add(): | ||
class Add(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, a, b): | ||
return a + b | ||
|
||
inputs = [torch.rand(2, 32, 32), torch.rand(2, 32, 32)] | ||
|
||
framework_model = Add() | ||
compiled_model = forge.compile(framework_model, sample_inputs=inputs) | ||
|
||
verify(inputs, framework_model, compiled_model) | ||
|
||
|
||
@pytest.mark.parametrize("dims", [(1, 32, 64), (6, 33), (4, 16, 17)]) | ||
@pytest.mark.push | ||
def test_greater_equal(dims): | ||
class GreaterEqual(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, a, b): | ||
return torch.greater_equal(a, b) | ||
|
||
inputs = [torch.rand(dims), torch.rand(dims)] | ||
|
||
framework_model = GreaterEqual() | ||
compiled_model = forge.compile(framework_model, sample_inputs=inputs) | ||
|
||
verify(inputs, framework_model, compiled_model, VerifyConfig(verify_dtype=False)) | ||
|
||
|
||
@pytest.mark.push | ||
def test_subtract(): | ||
class Subtract(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, a, b): | ||
return a - b | ||
|
||
inputs = [torch.rand(1, 32, 32), torch.rand(1, 32, 32)] | ||
|
||
framework_model = Subtract() | ||
compiled_model = forge.compile(framework_model, sample_inputs=inputs) | ||
|
||
verify(inputs, framework_model, compiled_model, VerifyConfig(verify_dtype=False)) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"shape", | ||
[ | ||
(1, 32, 32), | ||
(12, 8640), | ||
], | ||
) | ||
@pytest.mark.push | ||
def test_multiply(shape): | ||
class Multiply(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, a, b): | ||
return a * b | ||
|
||
inputs = [torch.rand(shape), torch.rand(shape)] | ||
|
||
framework_model = Multiply() | ||
compiled_model = forge.compile(framework_model, sample_inputs=inputs) | ||
|
||
verify(inputs, framework_model, compiled_model) | ||
|
||
|
||
@pytest.mark.push | ||
def test_remainder(): | ||
class Remainder(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, a, b): | ||
return a % b | ||
|
||
inputs = [torch.rand(2, 32, 32), torch.rand(2, 32, 32)] | ||
|
||
framework_model = Remainder() | ||
compiled_model = forge.compile(framework_model, sample_inputs=inputs) | ||
|
||
verify(inputs, framework_model, compiled_model) |
74 changes: 74 additions & 0 deletions
74
forge/test/mlir/operators/eltwise_nary/test_eltwise_nary.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
import pytest | ||
import torch | ||
from torch import nn | ||
|
||
import forge | ||
from forge.verify.verify import verify | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"condition, input, other", | ||
[ | ||
( | ||
[[1, 0], [0, 1]], | ||
[[1, 2], [3, 4]], | ||
[[10, 20], [30, 40]], | ||
), | ||
], | ||
) | ||
@pytest.mark.xfail(reason="Unsupported data format during lowering from TTForge to TTIR: Bfp2_b") | ||
@pytest.mark.push | ||
def test_where(condition, input, other): | ||
class Where(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, condition, input1, input2): | ||
return torch.where(condition, input1, input2) | ||
|
||
condition = torch.tensor(condition, dtype=torch.bool) | ||
input = torch.tensor(input) | ||
other = torch.tensor(other) | ||
|
||
inputs = [condition, input, other] | ||
|
||
framework_model = Where() | ||
compiled_model = forge.compile(framework_model, sample_inputs=inputs) | ||
|
||
verify(inputs, framework_model, compiled_model) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"inputs_and_dim", | ||
[ | ||
((2, 2, 32, 32), (2, 2, 32, 32), 0), | ||
((2, 2, 32, 32), (2, 2, 32, 32), 1), | ||
((2, 2, 32, 32), (2, 2, 32, 32), 2), | ||
((2, 2, 32, 32), (2, 2, 32, 32), 3), | ||
((2, 2, 32, 32), (2, 2, 32, 32), -1), | ||
((2, 2, 32, 32), (2, 2, 32, 32), -2), | ||
((2, 2, 32, 32), (2, 2, 32, 32), -3), | ||
((2, 2, 32, 32), (2, 2, 32, 32), -4), | ||
], | ||
ids=["0", "1", "2", "3", "-1", "-2", "-3", "-4"], | ||
) | ||
@pytest.mark.push | ||
def test_concat(inputs_and_dim): | ||
in_shape1, in_shape2, dim = inputs_and_dim | ||
|
||
class Concat(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, a, b): | ||
return torch.cat((a, b), dim) | ||
|
||
inputs = [torch.rand(in_shape1), torch.rand(in_shape2)] | ||
|
||
framework_model = Concat() | ||
compiled_model = forge.compile(framework_model, sample_inputs=inputs) | ||
|
||
verify(inputs, framework_model, compiled_model) |
Oops, something went wrong.