Skip to content

Commit

Permalink
Reshape operator test (#862)
Browse files Browse the repository at this point in the history
  • Loading branch information
vobojevicTT authored Jan 10, 2025
1 parent d9f3cfb commit 5002141
Show file tree
Hide file tree
Showing 8 changed files with 787 additions and 50 deletions.
8 changes: 8 additions & 0 deletions forge/test/operators/pytorch/eltwise_unary/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0

from .models import ModelFromAnotherOp, ModelDirect, ModelConstEvalPass

__all__ = [
"ModelFromAnotherOp",
"ModelDirect",
"ModelConstEvalPass",
]
45 changes: 45 additions & 0 deletions forge/test/operators/pytorch/eltwise_unary/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0

import torch
import torch.nn as nn

from forge.op_repo import TensorShape


class ModelFromAnotherOp(nn.Module):
def __init__(self, operator, kwargs):
super().__init__()
self.testname = "Element_wise_unary_operators_test_op_src_from_another_op"
self.operator = operator
self.kwargs = kwargs

def forward(self, x):
xx = torch.add(x, x)
return self.operator(xx, **self.kwargs)


class ModelDirect(nn.Module):
def __init__(self, operator, kwargs):
super().__init__()
self.testname = "Element_wise_unary_operators_test_op_src_from_host"
self.operator = operator
self.kwargs = kwargs

def forward(self, x):
return self.operator(x, **self.kwargs)


class ModelConstEvalPass(nn.Module):
def __init__(self, operator, shape: TensorShape, kwargs):
super().__init__()
self.testname = "Element_wise_unary_operators_test_op_src_const_eval_pass"
self.operator = operator
self.kwargs = kwargs
self.c = (torch.rand(shape, requires_grad=False) - 0.5).detach()

def forward(self, x):
cc = self.operator(self.c, **self.kwargs)
xx = self.operator(x, **self.kwargs)
return torch.add(xx, cc)
41 changes: 1 addition & 40 deletions forge/test/operators/pytorch/eltwise_unary/test_unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,7 @@
# (/) Reuse inputs for selected operators


import pytest
import torch
import torch.nn as nn
import forge
from forge.op_repo import TensorShape

from typing import List, Dict
from loguru import logger
Expand All @@ -69,42 +65,7 @@
from test.operators.utils import TestCollectionCommon
from test.operators.utils import ValueRanges


class ModelFromAnotherOp(nn.Module):
def __init__(self, operator, kwargs):
super().__init__()
self.testname = "Element_wise_unary_operators_test_op_src_from_another_op"
self.operator = operator
self.kwargs = kwargs

def forward(self, x):
xx = torch.add(x, x)
return self.operator(xx, **self.kwargs)


class ModelDirect(nn.Module):
def __init__(self, operator, kwargs):
super().__init__()
self.testname = "Element_wise_unary_operators_test_op_src_from_host"
self.operator = operator
self.kwargs = kwargs

def forward(self, x):
return self.operator(x, **self.kwargs)


class ModelConstEvalPass(nn.Module):
def __init__(self, operator, shape: TensorShape, kwargs):
super().__init__()
self.testname = "Element_wise_unary_operators_test_op_src_const_eval_pass"
self.operator = operator
self.kwargs = kwargs
self.c = (torch.rand(shape, requires_grad=False) - 0.5).detach()

def forward(self, x):
cc = self.operator(self.c, **self.kwargs)
xx = self.operator(x, **self.kwargs)
return torch.add(xx, cc)
from .models import ModelFromAnotherOp, ModelDirect, ModelConstEvalPass


class TestVerification:
Expand Down
Loading

0 comments on commit 5002141

Please sign in to comment.