Skip to content

Commit

Permalink
Support more functionalities for MUSA backend (#3359)
Browse files Browse the repository at this point in the history
* Support more functionalities for MUSA backend

* fix lint
  • Loading branch information
fmo-mt authored Jan 23, 2025
1 parent 4c2c89e commit 8f2d31c
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,8 +329,8 @@ def __init__(
if compare_versions("deepspeed-mlu", "<", "0.10.1"):
raise ImportError("DeepSpeed MLU version must be >= 0.10.1. Please update DeepSpeed MLU.")
elif is_musa_available():
if compare_versions("deepspeed", ">", "0.14.3"):
raise ImportError("DeepSpeed MUSA version must be <= 0.14.3. Please downgrade DeepSpeed.")
if compare_versions("deepspeed", "<", "0.14.3"):
raise ImportError("DeepSpeed MUSA version must be >= 0.14.3. Please update DeepSpeed.")
elif compare_versions("deepspeed", "<", "0.9.3"):
raise ImportError("DeepSpeed version must be >= 0.9.3. Please update DeepSpeed.")

Expand Down
5 changes: 5 additions & 0 deletions src/accelerate/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
WEIGHTS_NAME,
get_pretty_name,
is_mlu_available,
is_musa_available,
is_torch_xla_available,
is_xpu_available,
load,
Expand Down Expand Up @@ -152,6 +153,8 @@ def save_accelerator_state(
states["torch_xpu_manual_seed"] = torch.xpu.get_rng_state_all()
if is_mlu_available():
states["torch_mlu_manual_seed"] = torch.mlu.get_rng_state_all()
if is_musa_available():
states["torch_musa_manual_seed"] = torch.musa.get_rng_state_all()
else:
states["torch_cuda_manual_seed"] = torch.cuda.get_rng_state_all()
if is_torch_xla_available():
Expand Down Expand Up @@ -275,6 +278,8 @@ def load_accelerator_state(
torch.xpu.set_rng_state_all(states["torch_xpu_manual_seed"])
if is_mlu_available():
torch.mlu.set_rng_state_all(states["torch_mlu_manual_seed"])
if is_musa_available():
torch.musa.set_rng_state_all(states["torch_musa_manual_seed"])
else:
torch.cuda.set_rng_state_all(states["torch_cuda_manual_seed"])
if is_torch_xla_available():
Expand Down
2 changes: 2 additions & 0 deletions src/accelerate/commands/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ def env_command(args):
info["GPU type"] = torch.cuda.get_device_name()
if pt_mlu_available:
info["MLU type"] = torch.mlu.get_device_name()
if pt_musa_available:
info["MUSA type"] = torch.musa.get_device_name()
if pt_npu_available:
info["CANN version"] = torch.version.cann

Expand Down
3 changes: 3 additions & 0 deletions src/accelerate/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
)
from .utils.imports import (
is_mlu_available,
is_musa_available,
is_npu_available,
is_xpu_available,
)
Expand Down Expand Up @@ -391,6 +392,8 @@ def post_forward(self, module, output):
device = f"npu:{device}"
elif is_mlu_available():
device = f"mlu:{device}"
elif is_musa_available():
device = f"musa:{device}"
elif is_xpu_available():
device = f"xpu:{device}"
del self.tied_params_map[value_pointer][device]
Expand Down
3 changes: 3 additions & 0 deletions src/accelerate/utils/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
is_cuda_available,
is_mlu_available,
is_msamp_available,
is_musa_available,
is_npu_available,
is_transformer_engine_available,
is_xpu_available,
Expand Down Expand Up @@ -1686,6 +1687,8 @@ def __post_init__(self):
device = torch.npu.current_device()
elif is_mlu_available():
device = torch.mlu.current_device()
elif is_musa_available():
device = torch.musa.current_device()
elif is_cuda_available():
device = torch.cuda.current_device()
elif is_xpu_available():
Expand Down

0 comments on commit 8f2d31c

Please sign in to comment.