diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py
index f8df049c88d..46fb02f47f4 100755
--- a/src/accelerate/accelerator.py
+++ b/src/accelerate/accelerator.py
@@ -329,8 +329,8 @@ def __init__(
             if compare_versions("deepspeed-mlu", "<", "0.10.1"):
                 raise ImportError("DeepSpeed MLU version must be >= 0.10.1. Please update DeepSpeed MLU.")
         elif is_musa_available():
-            if compare_versions("deepspeed", ">", "0.14.3"):
-                raise ImportError("DeepSpeed MUSA version must be <= 0.14.3. Please downgrade DeepSpeed.")
+            if compare_versions("deepspeed", "<", "0.14.3"):
+                raise ImportError("DeepSpeed MUSA version must be >= 0.14.3. Please update DeepSpeed.")
         elif compare_versions("deepspeed", "<", "0.9.3"):
             raise ImportError("DeepSpeed version must be >= 0.9.3. Please update DeepSpeed.")
 
diff --git a/src/accelerate/checkpointing.py b/src/accelerate/checkpointing.py
index 432c0994fc0..e83a29ea575 100644
--- a/src/accelerate/checkpointing.py
+++ b/src/accelerate/checkpointing.py
@@ -33,6 +33,7 @@
     WEIGHTS_NAME,
     get_pretty_name,
     is_mlu_available,
+    is_musa_available,
     is_torch_xla_available,
     is_xpu_available,
     load,
@@ -152,6 +153,8 @@ def save_accelerator_state(
         states["torch_xpu_manual_seed"] = torch.xpu.get_rng_state_all()
     if is_mlu_available():
         states["torch_mlu_manual_seed"] = torch.mlu.get_rng_state_all()
+    if is_musa_available():
+        states["torch_musa_manual_seed"] = torch.musa.get_rng_state_all()
     else:
         states["torch_cuda_manual_seed"] = torch.cuda.get_rng_state_all()
     if is_torch_xla_available():
@@ -275,6 +278,8 @@ def load_accelerator_state(
         torch.xpu.set_rng_state_all(states["torch_xpu_manual_seed"])
     if is_mlu_available():
         torch.mlu.set_rng_state_all(states["torch_mlu_manual_seed"])
+    if is_musa_available():
+        torch.musa.set_rng_state_all(states["torch_musa_manual_seed"])
     else:
         torch.cuda.set_rng_state_all(states["torch_cuda_manual_seed"])
     if is_torch_xla_available():
diff --git a/src/accelerate/commands/env.py b/src/accelerate/commands/env.py
index 7dd5995f6b4..841e3c7e60c 100644
--- a/src/accelerate/commands/env.py
+++ b/src/accelerate/commands/env.py
@@ -83,6 +83,8 @@ def env_command(args):
         info["GPU type"] = torch.cuda.get_device_name()
     if pt_mlu_available:
         info["MLU type"] = torch.mlu.get_device_name()
+    if pt_musa_available:
+        info["MUSA type"] = torch.musa.get_device_name()
     if pt_npu_available:
         info["CANN version"] = torch.version.cann
 
diff --git a/src/accelerate/hooks.py b/src/accelerate/hooks.py
index 2f19293271d..7aa41788017 100644
--- a/src/accelerate/hooks.py
+++ b/src/accelerate/hooks.py
@@ -28,6 +28,7 @@
 )
 from .utils.imports import (
     is_mlu_available,
+    is_musa_available,
     is_npu_available,
     is_xpu_available,
 )
@@ -391,6 +392,8 @@ def post_forward(self, module, output):
                         device = f"npu:{device}"
                     elif is_mlu_available():
                         device = f"mlu:{device}"
+                    elif is_musa_available():
+                        device = f"musa:{device}"
                     elif is_xpu_available():
                         device = f"xpu:{device}"
                 del self.tied_params_map[value_pointer][device]
diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py
index 404cd377e3c..cf66ed56a6e 100644
--- a/src/accelerate/utils/dataclasses.py
+++ b/src/accelerate/utils/dataclasses.py
@@ -41,6 +41,7 @@
     is_cuda_available,
     is_mlu_available,
     is_msamp_available,
+    is_musa_available,
     is_npu_available,
     is_transformer_engine_available,
     is_xpu_available,
@@ -1686,6 +1687,8 @@ def __post_init__(self):
             device = torch.npu.current_device()
         elif is_mlu_available():
             device = torch.mlu.current_device()
+        elif is_musa_available():
+            device = torch.musa.current_device()
         elif is_cuda_available():
             device = torch.cuda.current_device()
         elif is_xpu_available():
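
The recurring pattern across these hunks is runtime backend dispatch: probe for a device stack with one of the `is_*_available()` helpers, then call into the matching `torch.<backend>` namespace, which the MUSA extension registers the same way `torch.mlu` and `torch.xpu` are registered. A minimal sketch of a round trip through the new MUSA branch in `checkpointing.py` follows; it is illustration only, assuming `is_musa_available` is importable from `accelerate.utils`, and the `save_rng_state`/`load_rng_state` helper names are hypothetical:

```python
import torch

from accelerate.utils import is_musa_available


def save_rng_state() -> dict:
    # Record per-backend RNG state, as save_accelerator_state now does for MUSA.
    states = {"torch_manual_seed": torch.get_rng_state()}
    if is_musa_available():
        # torch.musa mirrors the torch.cuda RNG helpers used elsewhere in the diff.
        states["torch_musa_manual_seed"] = torch.musa.get_rng_state_all()
    elif torch.cuda.is_available():
        states["torch_cuda_manual_seed"] = torch.cuda.get_rng_state_all()
    return states


def load_rng_state(states: dict) -> None:
    # Restore the recorded state, as load_accelerator_state now does for MUSA.
    torch.set_rng_state(states["torch_manual_seed"])
    if is_musa_available():
        torch.musa.set_rng_state_all(states["torch_musa_manual_seed"])
    elif torch.cuda.is_available():
        torch.cuda.set_rng_state_all(states["torch_cuda_manual_seed"])
```

The same probe-then-dispatch shape drives the other hunks: `env.py` reports the device name, `hooks.py` builds `"musa:{index}"` device strings, and `dataclasses.py` selects `torch.musa.current_device()` ahead of the CUDA and XPU fallbacks.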