diff --git a/forge/test/operators/pytorch/tm/test_squeeze.py b/forge/test/operators/pytorch/tm/test_squeeze.py new file mode 100644 index 000000000..ef0c9de41 --- /dev/null +++ b/forge/test/operators/pytorch/tm/test_squeeze.py @@ -0,0 +1,155 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 + +import forge +import math +import torch +import pytest +import random +import os + +from typing import List, Dict +from loguru import logger + +from forge.verify.config import VerifyConfig + +from forge.verify.value_checkers import AllCloseValueChecker +from forge.verify.verify import verify as forge_verify + +from test.operators.utils import InputSourceFlags, VerifyUtils +from test.operators.utils import InputSource +from test.operators.utils import TestVector +from test.operators.utils import TestPlan +from test.operators.utils import TestPlanUtils +from test.operators.utils import FailingReasons +from test.operators.utils.compat import TestDevice +from test.operators.utils import TestCollection +from test.operators.utils import TestCollectionCommon +from test.operators.utils import ValueRanges + +from test.operators.pytorch.eltwise_unary import ModelFromAnotherOp, ModelDirect, ModelConstEvalPass + + +class TestVerification: + + MODEL_TYPES = { + InputSource.FROM_ANOTHER_OP: ModelFromAnotherOp, + InputSource.FROM_HOST: ModelDirect, + InputSource.FROM_DRAM_QUEUE: ModelDirect, + InputSource.CONST_EVAL_PASS: ModelConstEvalPass, + } + + @classmethod + def verify( + cls, + test_device: TestDevice, + test_vector: TestVector, + input_params: List[Dict] = [], + warm_reset: bool = False, + ): + + input_source_flag: InputSourceFlags = None + if test_vector.input_source in (InputSource.FROM_DRAM_QUEUE,): + input_source_flag = InputSourceFlags.FROM_DRAM + + operator = getattr(torch, test_vector.operator) + kwargs = test_vector.kwargs if test_vector.kwargs else {} + + model_type = cls.MODEL_TYPES[test_vector.input_source] + pytorch_model = ( + model_type(operator, test_vector.input_shape, kwargs) + if test_vector.input_source in (InputSource.CONST_EVAL_PASS,) + else model_type(operator, kwargs) + ) + + input_shapes = tuple([test_vector.input_shape]) + + logger.trace(f"***input_shapes: {input_shapes}") + + VerifyUtils.verify( + model=pytorch_model, + test_device=test_device, + input_shapes=input_shapes, + input_params=input_params, + input_source_flag=input_source_flag, + dev_data_format=test_vector.dev_data_format, + math_fidelity=test_vector.math_fidelity, + warm_reset=warm_reset, + value_range=ValueRanges.SMALL, + # deprecated_verification=False, + # verify_config=VerifyConfig(value_checker=AllCloseValueChecker()), + ) + + +class TestParamsData: + + __test__ = False + + test_plan: TestPlan = None + + # fmt: off + specific_squeezes = { + # input_shape: (squeeze_dims) # squeeze_dims=None means all dimensions are squeezed + (2, 1, 2, 1, 2): (None, 0, 1), + (100, 1, 100, 1): (None, 1, 2, 3), + (84, 25, 100, 1, 41): (None, 1, 2, 3), + } + # fmt: on + + @classmethod + def generate_kwargs(cls, test_vector: TestVector): + return {} + + +TestParamsData.test_plan = TestPlan( + verify=lambda test_device, test_vector: TestVerification.verify( + test_device, + test_vector, + ), + collections=[ + # Test operators with all shapes and input sources collection: + TestCollection( + operators=["squeeze"], + input_sources=TestCollectionCommon.all.input_sources, + input_shapes=TestCollectionCommon.all.input_shapes, + # kwargs=lambda test_vector: TestParamsData.generate_kwargs(test_vector), + ), + # Test Data formats collection: + TestCollection( + operators=["squeeze"], + input_sources=TestCollectionCommon.single.input_sources, + input_shapes=TestCollectionCommon.single.input_shapes, + # kwargs=lambda test_vector: TestParamsData.generate_kwargs(test_vector), + dev_data_formats=[ + item + for item in TestCollectionCommon.all.dev_data_formats + if item not in TestCollectionCommon.single.dev_data_formats + ], + math_fidelities=TestCollectionCommon.single.math_fidelities, + ), + # Test Math fidelities collection: + TestCollection( + operators=["squeeze"], + input_sources=TestCollectionCommon.single.input_sources, + input_shapes=TestCollectionCommon.single.input_shapes, + # kwargs=lambda test_vector: TestParamsData.generate_kwargs(test_vector), + dev_data_formats=TestCollectionCommon.single.dev_data_formats, + math_fidelities=TestCollectionCommon.all.math_fidelities, + ), + # Test specific squeezes collection: + TestCollection( + operators=["squeeze"], + input_sources=TestCollectionCommon.all.input_sources, + input_shapes=TestParamsData.specific_squeezes.keys(), + # kwargs=lambda test_vector: TestParamsData.generate_kwargs(test_vector), + dev_data_formats=TestCollectionCommon.single.dev_data_formats, + math_fidelities=TestCollectionCommon.all.math_fidelities, + ), + ], + failing_rules=[], # No failing rules for this test plan +) + + +def get_test_plans() -> List[TestPlan]: + return [TestParamsData.test_plan]