
Commit

[Fix] Fix optim_wrapper unittest for pytorch <= 1.10.0 (#975)
C1rN09 authored Mar 2, 2023
1 parent 2ed8e34 commit b1b1f53
Showing 1 changed file with 14 additions and 0 deletions.
tests/test_optim/test_optimizer/test_optimizer_wrapper.py: 14 additions, 0 deletions
@@ -17,6 +17,8 @@
 from mmengine.optim import AmpOptimWrapper, ApexOptimWrapper, OptimWrapper
 from mmengine.testing import assert_allclose
 from mmengine.testing._internal import MultiProcessTestCase
+from mmengine.utils.dl_utils import TORCH_VERSION
+from mmengine.utils.version_utils import digit_version
 
 is_apex_available = False
 try:
@@ -438,6 +440,10 @@ def test_init(self):
         not torch.cuda.is_available(),
         reason='`torch.cuda.amp` is only available when pytorch-gpu installed')
     def test_step(self, dtype):
+        if dtype is not None and (digit_version(TORCH_VERSION) <
+                                  digit_version('1.10.0')):
+            raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to '
+                                    'support `dtype` argument in autocast')
         if dtype == 'bfloat16' and not bf16_supported():
             raise unittest.SkipTest('bfloat16 not supported by device')
         optimizer = MagicMock(spec=Optimizer)
@@ -454,6 +460,10 @@ def test_step(self, dtype):
         not torch.cuda.is_available(),
         reason='`torch.cuda.amp` is only available when pytorch-gpu installed')
     def test_backward(self, dtype):
+        if dtype is not None and (digit_version(TORCH_VERSION) <
+                                  digit_version('1.10.0')):
+            raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to '
+                                    'support `dtype` argument in autocast')
         if dtype == 'bfloat16' and not bf16_supported():
             raise unittest.SkipTest('bfloat16 not supported by device')
         amp_optim_wrapper = AmpOptimWrapper(
@@ -512,6 +522,10 @@ def test_load_state_dict(self):
         not torch.cuda.is_available(),
         reason='`torch.cuda.amp` is only available when pytorch-gpu installed')
     def test_optim_context(self, dtype, target_dtype):
+        if dtype is not None and (digit_version(TORCH_VERSION) <
+                                  digit_version('1.10.0')):
+            raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to '
+                                    'support `dtype` argument in autocast')
         if dtype == 'bfloat16' and not bf16_supported():
             raise unittest.SkipTest('bfloat16 not supported by device')
         amp_optim_wrapper = AmpOptimWrapper(
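For context: the three guards added by this commit are instances of the same pattern. Autocast only accepts a `dtype` argument from PyTorch 1.10.0 onward, so the parametrized cases that pass a dtype must be skipped on older versions. Below is a minimal, self-contained sketch of that pattern; the test class, test name, and assertions are illustrative only and not part of the commit.

import unittest

import torch

from mmengine.utils.dl_utils import TORCH_VERSION
from mmengine.utils.version_utils import digit_version


class TestAutocastDtypeGuard(unittest.TestCase):

    def test_fp16_matmul_under_autocast(self):
        # Same guard as in the commit: the `dtype` argument of autocast is
        # only supported from PyTorch 1.10.0 onward, so skip on older versions.
        if digit_version(TORCH_VERSION) < digit_version('1.10.0'):
            raise unittest.SkipTest('Require PyTorch version >= 1.10.0 to '
                                    'support `dtype` argument in autocast')
        if not torch.cuda.is_available():
            raise unittest.SkipTest('`torch.cuda.amp` requires a CUDA device')
        with torch.autocast('cuda', dtype=torch.float16):
            x = torch.randn(4, 4, device='cuda')
            y = x @ x
        # Matrix multiplication runs in the autocast dtype, so the result is
        # float16 rather than the default float32.
        self.assertEqual(y.dtype, torch.float16)


if __name__ == '__main__':
    unittest.main()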
