Skip to content

Commit

Permalink
Add squeeze operator test
Browse files Browse the repository at this point in the history
  • Loading branch information
vobojevicTT committed Dec 25, 2024
1 parent 40997ac commit 038649d
Showing 1 changed file with 133 additions and 0 deletions.
133 changes: 133 additions & 0 deletions forge/test/operators/pytorch/tm/test_squeeze.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0

import forge
import math
import torch
import pytest
import random
import os

from typing import List, Dict
from loguru import logger

from forge.verify.config import VerifyConfig

from forge.verify.value_checkers import AllCloseValueChecker
from forge.verify.verify import verify as forge_verify

from test.operators.utils import InputSourceFlags, VerifyUtils
from test.operators.utils import InputSource
from test.operators.utils import TestVector
from test.operators.utils import TestPlan
from test.operators.utils import TestPlanUtils
from test.operators.utils import FailingReasons
from test.operators.utils.compat import TestDevice
from test.operators.utils import TestCollection
from test.operators.utils import TestCollectionCommon
from test.operators.utils import ValueRanges

from test.operators.pytorch.eltwise_unary import ModelFromAnotherOp, ModelDirect, ModelConstEvalPass

Check failure on line 31 in forge/test/operators/pytorch/tm/test_squeeze.py

View workflow job for this annotation

GitHub Actions / TT-Forge-FE Tests

test_squeeze.forge.test.operators.pytorch.tm.test_squeeze

collection failure
Raw output
ImportError while importing test module '/__w/tt-forge-fe/tt-forge-fe/forge/test/operators/pytorch/tm/test_squeeze.py'.
Hint: make sure your test modules/packages have valid Python names.
Traceback:
/opt/ttforge-toolchain/venv/lib/python3.10/site-packages/_pytest/python.py:578: in _importtestmodule
    mod = import_path(self.fspath, mode=importmode)
/opt/ttforge-toolchain/venv/lib/python3.10/site-packages/_pytest/pathlib.py:524: in import_path
    importlib.import_module(module_name)
/usr/lib/python3.10/importlib/__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
<frozen importlib._bootstrap>:1050: in _gcd_import
    ???
<frozen importlib._bootstrap>:1027: in _find_and_load
    ???
<frozen importlib._bootstrap>:1006: in _find_and_load_unlocked
    ???
<frozen importlib._bootstrap>:688: in _load_unlocked
    ???
/opt/ttforge-toolchain/venv/lib/python3.10/site-packages/_pytest/assertion/rewrite.py:170: in exec_module
    exec(co, module.__dict__)
forge/test/operators/pytorch/tm/test_squeeze.py:31: in <module>
    from test.operators.pytorch.eltwise_unary import ModelFromAnotherOp, ModelDirect, ModelConstEvalPass
E   ImportError: cannot import name 'ModelFromAnotherOp' from 'test.operators.pytorch.eltwise_unary' (/__w/tt-forge-fe/tt-forge-fe/forge/test/operators/pytorch/eltwise_unary/__init__.py)


class TestVerification:

MODEL_TYPES = {
InputSource.FROM_ANOTHER_OP: ModelFromAnotherOp,
InputSource.FROM_HOST: ModelDirect,
InputSource.FROM_DRAM_QUEUE: ModelDirect,
InputSource.CONST_EVAL_PASS: ModelConstEvalPass,
}

@classmethod
def verify(
cls,
test_device: TestDevice,
test_vector: TestVector,
input_params: List[Dict] = [],
warm_reset: bool = False,
):

input_source_flag: InputSourceFlags = None
if test_vector.input_source in (InputSource.FROM_DRAM_QUEUE,):
input_source_flag = InputSourceFlags.FROM_DRAM

operator = getattr(torch, test_vector.operator)
kwargs = test_vector.kwargs if test_vector.kwargs else {}

model_type = cls.MODEL_TYPES[test_vector.input_source]
pytorch_model = (
model_type(operator, test_vector.input_shape, kwargs)
if test_vector.input_source in (InputSource.CONST_EVAL_PASS,)
else model_type(operator, kwargs)
)

input_shapes = tuple([test_vector.input_shape])

logger.trace(f"***input_shapes: {input_shapes}")

VerifyUtils.verify(
model=pytorch_model,
test_device=test_device,
input_shapes=input_shapes,
input_params=input_params,
input_source_flag=input_source_flag,
dev_data_format=test_vector.dev_data_format,
math_fidelity=test_vector.math_fidelity,
warm_reset=warm_reset,
value_range=ValueRanges.SMALL,
deprecated_verification=False,
verify_config=VerifyConfig(value_checker=AllCloseValueChecker()),
)


class TestParamsData:

__test__ = False

test_plan: TestPlan = None

@classmethod
def generate_kwargs(cls, test_vector: TestVector):
return {}


TestParamsData.test_plan = TestPlan(
verify=lambda test_device, test_vector: TestVerification.verify(
test_device,
test_vector,
),
collections=[
# Test operators with all shapes and input sources collection:
TestCollection(
operators=["squeeze"],
input_sources=TestCollectionCommon.all.input_sources,
input_shapes=TestCollectionCommon.all.input_shapes,
kwargs=lambda test_vector: TestParamsData.generate_kwargs(test_vector),
),
# Test Data formats collection:
TestCollection(
operators=["squeeze"],
input_sources=TestCollectionCommon.single.input_sources,
input_shapes=TestCollectionCommon.single.input_shapes,
kwargs=lambda test_vector: TestParamsData.generate_kwargs(test_vector),
dev_data_formats=[
item
for item in TestCollectionCommon.all.dev_data_formats
if item not in TestCollectionCommon.single.dev_data_formats
],
math_fidelities=TestCollectionCommon.single.math_fidelities,
),
# Test Math fidelities collection:
TestCollection(
operators=["squeeze"],
input_sources=TestCollectionCommon.single.input_sources,
input_shapes=TestCollectionCommon.single.input_shapes,
kwargs=lambda test_vector: TestParamsData.generate_kwargs(test_vector),
dev_data_formats=TestCollectionCommon.single.dev_data_formats,
math_fidelities=TestCollectionCommon.all.math_fidelities,
),
],
failing_rules=[], # No failing rules for this test plan
)

0 comments on commit 038649d

Please sign in to comment.