diff --git a/forge/test/operators/pytorch/tm/test_squeeze.py b/forge/test/operators/pytorch/tm/test_squeeze.py new file mode 100644 index 000000000..f3ae4b14e --- /dev/null +++ b/forge/test/operators/pytorch/tm/test_squeeze.py @@ -0,0 +1,133 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC + +# SPDX-License-Identifier: Apache-2.0 + +import forge +import math +import torch +import pytest +import random +import os + +from typing import List, Dict +from loguru import logger + +from forge.verify.config import VerifyConfig + +from forge.verify.value_checkers import AllCloseValueChecker +from forge.verify.verify import verify as forge_verify + +from test.operators.utils import InputSourceFlags, VerifyUtils +from test.operators.utils import InputSource +from test.operators.utils import TestVector +from test.operators.utils import TestPlan +from test.operators.utils import TestPlanUtils +from test.operators.utils import FailingReasons +from test.operators.utils.compat import TestDevice +from test.operators.utils import TestCollection +from test.operators.utils import TestCollectionCommon +from test.operators.utils import ValueRanges + +from test.operators.pytorch.eltwise_unary import ModelFromAnotherOp, ModelDirect, ModelConstEvalPass + + +class TestVerification: + + MODEL_TYPES = { + InputSource.FROM_ANOTHER_OP: ModelFromAnotherOp, + InputSource.FROM_HOST: ModelDirect, + InputSource.FROM_DRAM_QUEUE: ModelDirect, + InputSource.CONST_EVAL_PASS: ModelConstEvalPass, + } + + @classmethod + def verify( + cls, + test_device: TestDevice, + test_vector: TestVector, + input_params: List[Dict] = [], + warm_reset: bool = False, + ): + + input_source_flag: InputSourceFlags = None + if test_vector.input_source in (InputSource.FROM_DRAM_QUEUE,): + input_source_flag = InputSourceFlags.FROM_DRAM + + operator = getattr(torch, test_vector.operator) + kwargs = test_vector.kwargs if test_vector.kwargs else {} + + model_type = cls.MODEL_TYPES[test_vector.input_source] + pytorch_model = ( + model_type(operator, test_vector.input_shape, kwargs) + if test_vector.input_source in (InputSource.CONST_EVAL_PASS,) + else model_type(operator, kwargs) + ) + + input_shapes = tuple([test_vector.input_shape]) + + logger.trace(f"***input_shapes: {input_shapes}") + + VerifyUtils.verify( + model=pytorch_model, + test_device=test_device, + input_shapes=input_shapes, + input_params=input_params, + input_source_flag=input_source_flag, + dev_data_format=test_vector.dev_data_format, + math_fidelity=test_vector.math_fidelity, + warm_reset=warm_reset, + value_range=ValueRanges.SMALL, + deprecated_verification=False, + verify_config=VerifyConfig(value_checker=AllCloseValueChecker()), + ) + + +class TestParamsData: + + __test__ = False + + test_plan: TestPlan = None + + @classmethod + def generate_kwargs(cls, test_vector: TestVector): + return {} + + +TestParamsData.test_plan = TestPlan( + verify=lambda test_device, test_vector: TestVerification.verify( + test_device, + test_vector, + ), + collections=[ + # Test operators with all shapes and input sources collection: + TestCollection( + operators=["squeeze"], + input_sources=TestCollectionCommon.all.input_sources, + input_shapes=TestCollectionCommon.all.input_shapes, + kwargs=lambda test_vector: TestParamsData.generate_kwargs(test_vector), + ), + # Test Data formats collection: + TestCollection( + operators=["squeeze"], + input_sources=TestCollectionCommon.single.input_sources, + input_shapes=TestCollectionCommon.single.input_shapes, + kwargs=lambda test_vector: TestParamsData.generate_kwargs(test_vector), + dev_data_formats=[ + item + for item in TestCollectionCommon.all.dev_data_formats + if item not in TestCollectionCommon.single.dev_data_formats + ], + math_fidelities=TestCollectionCommon.single.math_fidelities, + ), + # Test Math fidelities collection: + TestCollection( + operators=["squeeze"], + input_sources=TestCollectionCommon.single.input_sources, + input_shapes=TestCollectionCommon.single.input_shapes, + kwargs=lambda test_vector: TestParamsData.generate_kwargs(test_vector), + dev_data_formats=TestCollectionCommon.single.dev_data_formats, + math_fidelities=TestCollectionCommon.all.math_fidelities, + ), + ], + failing_rules=[], # No failing rules for this test plan +) \ No newline at end of file