Skip to content

Commit

Permalink
feat(deeplink): add deeplink as new backend (#168)
Browse files Browse the repository at this point in the history
  • Loading branch information
caikun-pjlab authored Apr 1, 2024
1 parent c1a1936 commit 85f6b7d
Show file tree
Hide file tree
Showing 13 changed files with 505 additions and 39 deletions.
26 changes: 24 additions & 2 deletions internlm/accelerator/abstract_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ class AcceleratorType(enum.Enum):
GPU = 1
NPU = 2
CPU = 3
OTHER = 4
DIPU = 4
OTHER = 5


internlm_accelerator = None
Expand Down Expand Up @@ -80,7 +81,7 @@ def get_accelerator():

accelerator_name = None
# 1. Detect whether there is override of DeepSpeed accelerators from environment variable.
intern_accelerator_LIST = ["cuda", "npu"]
intern_accelerator_LIST = ["cuda", "npu", "dipu"]
if "INTERNLM_ACCELERATOR" in os.environ:
accelerator_name = os.environ["INTERNLM_ACCELERATOR"]
if accelerator_name == "npu":
Expand All @@ -89,13 +90,30 @@ def get_accelerator():
except (ImportError, ModuleNotFoundError):
raise ValueError("NPU_Accelerator requires torch_npu, which is not installed on this system.")
pass
elif accelerator_name == "dipu":
try:
import deeplink_ext # noqa # pylint: disable=W0611
import torch_dipu # noqa # pylint: disable=W0611
except (ImportError, ModuleNotFoundError):
raise ValueError(
"DIPU_Accelerator requires torch_dipu and deeplink_ext, which is not installed on this system."
)
pass
elif accelerator_name != "cuda":
raise ValueError(
f"accelerator_name must be one of {intern_accelerator_LIST}."
+ " Value '{accelerator_name}' is not supported"
)

# 2. If no override, detect which accelerator to use automatically
if accelerator_name is None:
try:
import deeplink_ext # noqa: F401,F811 # type: ignore
import torch_dipu # noqa: F401,F811 # type: ignore

accelerator_name = "dipu"
except (ImportError, ModuleNotFoundError):
pass
if accelerator_name is None:
try:
import torch_npu # noqa: F401,F811 # type: ignore
Expand All @@ -115,5 +133,9 @@ def get_accelerator():
from .npu_accelerator import ASCEND_Accelerator

internlm_accelerator = ASCEND_Accelerator()
elif accelerator_name == "dipu":
from .dipu_accelerator import DIPU_Accelerator

internlm_accelerator = DIPU_Accelerator()

return internlm_accelerator
Loading

0 comments on commit 85f6b7d

Please sign in to comment.