From b1b1f53db2712904c04ab61356fe257eb8e9803a Mon Sep 17 00:00:00 2001
From: Qian Zhao <112053249+C1rN09@users.noreply.github.com>
Date: Thu, 2 Mar 2023 14:14:23 +0800
Subject: [PATCH] [Fix] Fix optim_wrapper unittest for pytorch <= 1.10.0 (#975)

---
 .../test_optimizer/test_optimizer_wrapper.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py
index 5ebebcb4c2..dc9727de63 100644
--- a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py
+++ b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py
@@ -17,6 +17,8 @@
 from mmengine.optim import AmpOptimWrapper, ApexOptimWrapper, OptimWrapper
 from mmengine.testing import assert_allclose
 from mmengine.testing._internal import MultiProcessTestCase
+from mmengine.utils.dl_utils import TORCH_VERSION
+from mmengine.utils.version_utils import digit_version
 
 is_apex_available = False
 try:
@@ -438,6 +440,10 @@ def test_init(self):
         not torch.cuda.is_available(),
         reason='`torch.cuda.amp` is only available when pytorch-gpu installed')
     def test_step(self, dtype):
+        if dtype is not None and (digit_version(TORCH_VERSION) <
+                                  digit_version('1.10.0')):
+            raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to '
+                                    'support `dtype` argument in autocast')
         if dtype == 'bfloat16' and not bf16_supported():
             raise unittest.SkipTest('bfloat16 not supported by device')
         optimizer = MagicMock(spec=Optimizer)
@@ -454,6 +460,10 @@ def test_step(self, dtype):
         not torch.cuda.is_available(),
         reason='`torch.cuda.amp` is only available when pytorch-gpu installed')
     def test_backward(self, dtype):
+        if dtype is not None and (digit_version(TORCH_VERSION) <
+                                  digit_version('1.10.0')):
+            raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to '
+                                    'support `dtype` argument in autocast')
         if dtype == 'bfloat16' and not bf16_supported():
             raise unittest.SkipTest('bfloat16 not supported by device')
         amp_optim_wrapper = AmpOptimWrapper(
@@ -512,6 +522,10 @@ def test_load_state_dict(self):
         not torch.cuda.is_available(),
         reason='`torch.cuda.amp` is only available when pytorch-gpu installed')
     def test_optim_context(self, dtype, target_dtype):
+        if dtype is not None and (digit_version(TORCH_VERSION) <
+                                  digit_version('1.10.0')):
+            raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to '
+                                    'support `dtype` argument in autocast')
         if dtype == 'bfloat16' and not bf16_supported():
             raise unittest.SkipTest('bfloat16 not supported by device')
         amp_optim_wrapper = AmpOptimWrapper(
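
For reference, the gate added to each test boils down to the pattern below: compare the running PyTorch version against 1.10.0 with mmengine's digit_version and TORCH_VERSION helpers, and skip before ever passing a `dtype` to autocast. This is a minimal, self-contained sketch under the assumption that mmengine is installed; the _DummyAutocastTest class, its test body, and the torch.autocast call are illustrative and not part of the patch.

# Sketch of the version gate used in the patch above; the test class and
# the autocast body are illustrative, only the gate mirrors the patch.
import unittest

import torch
from mmengine.utils.dl_utils import TORCH_VERSION
from mmengine.utils.version_utils import digit_version


class _DummyAutocastTest(unittest.TestCase):
    """Illustrative only: demonstrates skipping on old PyTorch versions."""

    def test_autocast_dtype(self):
        # `torch.autocast(..., dtype=...)` only exists from PyTorch 1.10.0
        # onwards, so skip the test on older builds instead of failing.
        if digit_version(TORCH_VERSION) < digit_version('1.10.0'):
            raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to '
                                    'support `dtype` argument in autocast')
        # CPU autocast with bfloat16 is also available from 1.10.0, and
        # matmul is on its lower-precision cast list.
        with torch.autocast('cpu', dtype=torch.bfloat16):
            x = torch.randn(4, 4)
            y = x @ x
        self.assertEqual(y.dtype, torch.bfloat16)


if __name__ == '__main__':
    unittest.main()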