From 76de60dbdcdbd589fbaf27ca67c1f418bb5a7850 Mon Sep 17 00:00:00 2001
From: OTABI Tomoya
Date: Wed, 8 Nov 2023 21:08:32 +0900
Subject: [PATCH] Fix import error when torch>=2.0.1 and torch.distributed is
 disabled (#2121)

---
 src/accelerate/utils/other.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/accelerate/utils/other.py b/src/accelerate/utils/other.py
index 970b2969478..a1e2c3bb2ce 100644
--- a/src/accelerate/utils/other.py
+++ b/src/accelerate/utils/other.py
@@ -28,7 +28,7 @@
 from ..state import PartialState
 from .constants import FSDP_PYTORCH_VERSION
 from .dataclasses import DistributedType
-from .imports import is_deepspeed_available, is_safetensors_available, is_tpu_available
+from .imports import is_deepspeed_available, is_safetensors_available, is_torch_distributed_available, is_tpu_available
 from .transformer_engine import convert_model
 from .versions import is_torch_version
@@ -77,7 +77,7 @@ def extract_model_from_parallel(model, keep_fp32_wrapper: bool = True):
         options += (DeepSpeedEngine,)

-    if is_torch_version(">=", FSDP_PYTORCH_VERSION):
+    if is_torch_version(">=", FSDP_PYTORCH_VERSION) and is_torch_distributed_available():
         from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP

         options += (FSDP,)
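
Note: the sketch below illustrates why the added guard avoids the import error; it assumes
is_torch_distributed_available() simply wraps torch.distributed.is_available(), which is how
the helper in src/accelerate/utils/imports.py is expected to behave. It is not part of the patch.

    # Minimal sketch (assumption: helper wraps torch.distributed.is_available()).
    import torch

    def is_torch_distributed_available() -> bool:
        # False when PyTorch was built without distributed support, in which case
        # importing torch.distributed.fsdp raises an error.
        return torch.distributed.is_available()

    options = ()
    if is_torch_distributed_available():
        # Only attempt the FSDP import when distributed support is actually compiled in.
        from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP

        options += (FSDP,)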