From 85f6b7d0def444b5a81dd3a0170999543058e03d Mon Sep 17 00:00:00 2001 From: caikun-pjlab <116071181+caikun-pjlab@users.noreply.github.com> Date: Mon, 1 Apr 2024 20:58:36 +0800 Subject: [PATCH] feat(deeplink): add deeplink as new backend (#168) --- internlm/accelerator/abstract_accelerator.py | 26 +- internlm/accelerator/dipu_accelerator.py | 388 ++++++++++++++++++ internlm/core/communication/isp.py | 20 +- internlm/initialize/launch.py | 9 +- internlm/model/metrics.py | 8 +- internlm/model/modeling_internlm2.py | 4 +- internlm/model/modeling_llama.py | 4 +- internlm/model/modules/embedding.py | 16 +- .../model/modules/multi_head_attention.py | 19 +- internlm/model/utils.py | 30 +- internlm/train/pipeline.py | 3 +- internlm/utils/common.py | 4 +- internlm/utils/gputest.py | 13 +- 13 files changed, 505 insertions(+), 39 deletions(-) create mode 100644 internlm/accelerator/dipu_accelerator.py diff --git a/internlm/accelerator/abstract_accelerator.py b/internlm/accelerator/abstract_accelerator.py index 104a5176..943a2eb3 100644 --- a/internlm/accelerator/abstract_accelerator.py +++ b/internlm/accelerator/abstract_accelerator.py @@ -9,7 +9,8 @@ class AcceleratorType(enum.Enum): GPU = 1 NPU = 2 CPU = 3 - OTHER = 4 + DIPU = 4 + OTHER = 5 internlm_accelerator = None @@ -80,7 +81,7 @@ def get_accelerator(): accelerator_name = None # 1. Detect whether there is override of DeepSpeed accelerators from environment variable. - intern_accelerator_LIST = ["cuda", "npu"] + intern_accelerator_LIST = ["cuda", "npu", "dipu"] if "INTERNLM_ACCELERATOR" in os.environ: accelerator_name = os.environ["INTERNLM_ACCELERATOR"] if accelerator_name == "npu": @@ -89,6 +90,15 @@ def get_accelerator(): except (ImportError, ModuleNotFoundError): raise ValueError("NPU_Accelerator requires torch_npu, which is not installed on this system.") pass + elif accelerator_name == "dipu": + try: + import deeplink_ext # noqa # pylint: disable=W0611 + import torch_dipu # noqa # pylint: disable=W0611 + except (ImportError, ModuleNotFoundError): + raise ValueError( + "DIPU_Accelerator requires torch_dipu and deeplink_ext, which is not installed on this system." + ) + pass elif accelerator_name != "cuda": raise ValueError( f"accelerator_name must be one of {intern_accelerator_LIST}." @@ -96,6 +106,14 @@ def get_accelerator(): ) # 2. If no override, detect which accelerator to use automatically + if accelerator_name is None: + try: + import deeplink_ext # noqa: F401,F811 # type: ignore + import torch_dipu # noqa: F401,F811 # type: ignore + + accelerator_name = "dipu" + except (ImportError, ModuleNotFoundError): + pass if accelerator_name is None: try: import torch_npu # noqa: F401,F811 # type: ignore @@ -115,5 +133,9 @@ def get_accelerator(): from .npu_accelerator import ASCEND_Accelerator internlm_accelerator = ASCEND_Accelerator() + elif accelerator_name == "dipu": + from .dipu_accelerator import DIPU_Accelerator + + internlm_accelerator = DIPU_Accelerator() return internlm_accelerator diff --git a/internlm/accelerator/dipu_accelerator.py b/internlm/accelerator/dipu_accelerator.py new file mode 100644 index 00000000..7943b4c7 --- /dev/null +++ b/internlm/accelerator/dipu_accelerator.py @@ -0,0 +1,388 @@ +from .abstract_accelerator import Accelerator, AcceleratorType + +try: + import torch.cuda +except ImportError: + pass + + +class DIPU_Accelerator(Accelerator): + """Accelerator for CUDA device. 
+ + Args: + Accelerator (Accelerator): _description_ + """ + + def __init__(self) -> None: + self._name_str = "cuda" + self._communication_backend_name = "nccl" + self.amp = self.get_amp() + self.memory = torch.cuda.memory + self._find_or_mock_module("flash_attn_2_cuda") + + def _find_or_mock_module(self, module_name) -> bool: + import importlib.util + import sys + import types + + """Find or mock a module. Return True if the module is found, False otherwise.""" + module_spec = importlib.util.find_spec(module_name) + if module_spec is None: + sys.modules[module_name] = types.SimpleNamespace() # type: ignore + return module_spec is not None + + def get_backend_name(self): + """ + Return the name of the accelerator. + """ + return self._name_str + + def get_accelerator_backend(self): + """ + Return the name of the backend. + """ + return AcceleratorType.DIPU + + # Device APIs + def device_name(self, device_index=None): + """ + Return the name of the device. + """ + if device_index is None: + return "cuda" + return "cuda:{}".format(device_index) + + def set_device(self, device_index): + """ + Bind the current process to a device. + """ + torch.cuda.set_device(device_index) + + def get_device_id(self): + """ + Return the current device index. + """ + return torch.cuda.current_device() + + def current_device_name(self): + """ + Return the name of the current device. + """ + return "cuda:{}".format(torch.cuda.current_device()) + + def device_count(self): + """ + Return the number of devices on the machine. + """ + return torch.cuda.device_count() + + def synchronize(self, device_index=None): + """ + Synchronize the current process. + """ + return torch.cuda.synchronize(device_index) + + # RNG APIs + def random(self): + """ + Get random number. + """ + return torch.random + + def set_rng_state(self, new_state, device_index=None): + """ + Sets the random number generator state of the specified GPU. + """ + if device_index is None: + return torch.cuda.set_rng_state(new_state) + + return torch.cuda.set_rng_state(new_state, device_index) + + def get_rng_state(self, device_index=None): + """ + Returns the random number generator state of the specified GPU as a ByteTensor. + """ + if device_index is None: + return torch.cuda.get_rng_state() + + return torch.cuda.get_rng_state(device_index) + + def manual_seed(self, seed): + """ + Sets the seed for generating random numbers for the current GPU. + """ + return torch.cuda.manual_seed(seed) + + def manual_seed_all(self, seed): + """ + Set the random seed for the all processes. + """ + return torch.cuda.manual_seed_all(seed) + + def initial_seed(self): + """ + Returns the current random seed of the current GPU. + """ + return torch.cuda.initial_seed() + + def default_generator(self, device_index): + """ + Returns the default generators according to device index + """ + return torch.cuda.default_generators[device_index] + + # Streams/Events + @property + def Stream(self): + """ + A CUDA stream is a linear sequence of execution that belongs to + a specific device, independent from other streams. + See cuda-semantics for details. + """ + return torch.cuda.Stream + + def stream(self, _stream): + """ + Wrapper around the Context-manager StreamContext that selects a given stream. + """ + return torch.cuda.stream(_stream) + + def current_stream(self, device_index=None): + """ + Returns the currently selected Stream for a given device. 
+ """ + return torch.cuda.current_stream(device_index) + + def default_stream(self, device_index=None): + """ + Returns the default Stream for a given device. + """ + return torch.cuda.default_stream(device_index) + + @property + def Event(self): + """ + CUDA events are synchronization markers that can be used + to monitor the device's progress, to accurately measure timing, + and to synchronize CUDA streams. + """ + return torch.cuda.Event + + # Memory management + def empty_cache(self): + """ + Releases all unoccupied cached memory currently held by the caching allocator + so that those can be used in other GPU application and visible in nvidia-smi. + """ + return torch.cuda.empty_cache() + + def memory_allocated(self, device_index=None): + """ + Returns the current GPU memory occupied by tensors in bytes for a given device. + """ + return torch.cuda.memory_allocated(device_index) + + def max_memory_allocated(self, device_index=None): + """ + Returns the maximum GPU memory occupied by tensors in bytes for a given device. + """ + return torch.cuda.max_memory_allocated(device_index) + + def reset_max_memory_allocated(self, device_index=None): + """ + Resets the starting point in tracking maximum GPU memory occupied by + tensors for a given device. + """ + return torch.cuda.reset_max_memory_allocated(device_index) + + def memory_cached(self, device_index=None): + """ + Returns the cached memory + """ + return torch.cuda.memory_cached(device_index) + + def max_memory_cached(self, device_index=None): + """ + Returns the maximum cached memory + """ + return torch.cuda.max_memory_cached(device_index) + + def reset_max_memory_cached(self, device_index=None): + """ + Resets the starting point in tracking maximum GPU memory managed by + the caching allocator for a given device. + """ + return torch.cuda.reset_max_memory_cached(device_index) + + def memory_stats(self, device_index=None): + """ + Returns the memory stats + """ + if hasattr(torch.cuda, "memory_stats"): + return torch.cuda.memory_stats(device_index) + + def reset_peak_memory_stats(self, device_index=None): + """ + Resets the "peak" stats tracked by the CUDA memory allocator. + """ + if hasattr(torch.cuda, "reset_peak_memory_stats"): + return torch.cuda.reset_peak_memory_stats(device_index) + + def memory_reserved(self, device_index=None): + """ + Returns the current GPU memory managed by the caching allocator + in bytes for a given device. + """ + if hasattr(torch.cuda, "memory_reserved"): + return torch.cuda.memory_reserved(device_index) + + def max_memory_reserved(self, device_index=None): + """ + Returns the maximum GPU memory managed by the caching allocator + in bytes for a given device. + """ + if hasattr(torch.cuda, "max_memory_reserved"): + return torch.cuda.max_memory_reserved(device_index) + + def total_memory(self, device_index=None): + """ + Returns the total memory + """ + return torch.cuda.get_device_properties(device_index).total_memory + + # Data types + def is_bf16_supported(self): + """ + Returns true if bf16 is supported. Otherwise, returns false + """ + return torch.cuda.is_bf16_supported() + + def is_fp16_supported(self): + """ + Returns true if fp16 is supported. Otherwise, returns false + """ + major, _ = torch.cuda.get_device_capability() + return bool(major >= 7) + + # Misc + def get_amp(self): + """ + Returns the 'amp' module from torch.cuda if available, else returns None. 
+ """ + if hasattr(torch.cuda, "amp"): + return torch.cuda.amp + return None + + def is_available(self): + """ + Checks and returns True if CUDA is available, False otherwise. + """ + return torch.cuda.is_available() + + def range_push(self, msg): + """ + If available, pushes a range with the given message for profiling using NVTX. + """ + if hasattr(torch.cuda.nvtx, "range_push"): + return torch.cuda.nvtx.range_push(msg) + + def range_pop(self): + """ + If available, pops the most recent range pushed using NVTX. + """ + if hasattr(torch.cuda.nvtx, "range_pop"): + return torch.cuda.nvtx.range_pop() + + def lazy_call(self, callback): + """ + Executes the given callback with lazy propagation if available. + """ + return torch.cuda._lazy_call(callback) + + def communication_backend_name(self): + """ + Returns the name of the current communication backend. + """ + return self._communication_backend_name + + # Tensor operations + + @property + def BFloat16Tensor(self): + """ + Property to get the BFloat16Tensor class from torch.cuda. + """ + return torch.cuda.BFloat16Tensor + + @property + def ByteTensor(self): + """ + Property to get the ByteTensor class from torch.cuda. + """ + return torch.cuda.ByteTensor + + @property + def DoubleTensor(self): + """ + Property to get the DoubleTensor class from torch.cuda. + """ + return torch.cuda.DoubleTensor + + @property + def FloatTensor(self): + """ + Property to get the FloatTensor class from torch.cuda. + """ + return torch.cuda.FloatTensor + + @property + def HalfTensor(self): + """ + Property to get the HalfTensor class from torch.cuda. + """ + return torch.cuda.HalfTensor + + @property + def IntTensor(self): + """ + Property to get the IntTensor class from torch.cuda. + """ + return torch.cuda.IntTensor + + @property + def LongTensor(self): + """ + Property to get the LongTensor class from torch.cuda. + """ + return torch.cuda.LongTensor + + def pin_memory(self, tensor): + """ + Pins the memory of the given tensor, if it's a CUDA tensor. + """ + return tensor.pin_memory() + + def on_accelerator(self, tensor): + """ + Checks and returns True if the given tensor is on an accelerator (CUDA device), False otherwise. + """ + device_str = str(tensor.device) + return bool(device_str.startswith("cuda:")) + + def set_allow_tf32(self, enable: bool): + """ + Sets the `allow_tf32` flag in cuDNN and CUDA matrix multiplication to the given boolean value. + """ + print(f"Not support tf32 for DIPU, {enable}!") + + def return_custom_bwd(self): + """ + Returns the custom backward hook function from torch.cuda.amp, if available. + """ + return torch.cuda.amp.custom_bwd + + def return_custom_fwd(self): + """ + Returns the custom forward hook function from torch.cuda.amp, if available. + """ + return torch.cuda.amp.custom_fwd diff --git a/internlm/core/communication/isp.py b/internlm/core/communication/isp.py index b821e994..2976899a 100644 --- a/internlm/core/communication/isp.py +++ b/internlm/core/communication/isp.py @@ -1,7 +1,6 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from dataclasses import dataclass from functools import partial from typing import Any, Callable, Dict, List, Union @@ -16,16 +15,25 @@ from internlm.utils.common import SchedulerHook, get_current_device -@dataclass class ISPCommModelConfig: """ model config for isp communicator. 
""" - dtype: torch.dtype = torch.half - device: torch.device = get_current_device() - activation_checkpointing: float = 0.0 - module_shapes: Dict[str, torch.Size] = None + def __init__( + self, + dtype: torch.dtype = torch.half, + device: torch.device = None, + activation_checkpointing: float = 0.0, + module_shapes: Dict[str, torch.Size] = None, + ) -> None: + self.dtype = dtype + if device is None: + self.device = get_current_device() + else: + self.device = device + self.activation_checkpointing = activation_checkpointing + self.module_shapes = module_shapes class MemoryPool: diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index a4ec30d7..bd211947 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -323,13 +323,16 @@ def args_sanity_check(): gpc.config.model._add_item("use_flash_attn", True) gpc.config["use_cuda_flash_attn"] = False - if gpc.config.model.use_flash_attn and internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU: + if gpc.config.model.use_flash_attn and ( + internlm_accelerator.get_accelerator_backend() in [AcceleratorType.GPU, AcceleratorType.DIPU] + ): gpc.config["use_cuda_flash_attn"] = True # for NPU accelerator supports: 1)FA-True + Packed-False 2) FA-False + Packed-False + # for DIPU accelerator supports: 1)FA-True + Packed-False 2) FA-False + Packed-False # for GPU accelerator supports: 1)FA-True + Packed-True 2) FA-False + Packed-False - if internlm_accelerator.get_accelerator_backend() == AcceleratorType.NPU: - assert gpc.config.data.use_packed_dataset is False, "packed data is not supported for NPU accelerator" + if internlm_accelerator.get_accelerator_backend() in [AcceleratorType.NPU, AcceleratorType.DIPU]: + assert gpc.config.data.use_packed_dataset is False, "packed data is not supported for NPU/DIPU accelerator" else: assert ( gpc.config.use_cuda_flash_attn == gpc.config.data.use_packed_dataset diff --git a/internlm/model/metrics.py b/internlm/model/metrics.py index 96d0fbff..72777c59 100644 --- a/internlm/model/metrics.py +++ b/internlm/model/metrics.py @@ -90,7 +90,7 @@ def __init__(self, device, tp_pg, dp_pg, tokenizer=None, dataset_types: List[str self.loss_with_type_id = LossWithTypeId(device, dp_pg, dataset_types) - if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU: + if internlm_accelerator.get_accelerator_backend() in [AcceleratorType.GPU, AcceleratorType.DIPU]: self.scatter_sum = cuda_scatter else: self.scatter_sum = vanilla_scatter @@ -156,7 +156,7 @@ def update(self, logits, labels, type_ids=None): acc = corrects.sum() torch.distributed.all_reduce(acc, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg) # The synchronization here is to prevent unpredictable HANG when the NPU is running. 
- if internlm_accelerator.get_accelerator_backend() == AcceleratorType.NPU: + if internlm_accelerator.get_accelerator_backend() in [AcceleratorType.NPU, AcceleratorType.DIPU]: internlm_accelerator.current_stream().synchronize() self.right += acc # Masked_fill is not needed here because -100 is not available anyway self.total += mask.sum() @@ -262,7 +262,7 @@ def __init__(self, device, dp_pg, dataset_types: List[str] = None) -> None: self.ds_loss = torch.zeros(self.total_type_count, dtype=torch.float, device=device) self.ds_token_num = torch.zeros(self.total_type_count, dtype=torch.float, device=device) - if gpc.config.use_cuda_flash_attn: + if gpc.config.use_cuda_flash_attn and AcceleratorType.GPU == get_accelerator(): from flash_attn.losses.cross_entropy import ( CrossEntropyLoss as FlashCrossEntropyLoss, ) @@ -273,7 +273,7 @@ def __init__(self, device, dp_pg, dataset_types: List[str] = None) -> None: else: self.loss_fn = nn.CrossEntropyLoss(reduction="none") - if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU: + if internlm_accelerator.get_accelerator_backend() in [AcceleratorType.GPU, AcceleratorType.DIPU]: self.scatter_sum = cuda_scatter else: self.scatter_sum = vanilla_scatter diff --git a/internlm/model/modeling_internlm2.py b/internlm/model/modeling_internlm2.py index 5e9ab5ee..4208d5a8 100644 --- a/internlm/model/modeling_internlm2.py +++ b/internlm/model/modeling_internlm2.py @@ -419,7 +419,7 @@ def _packed_forward(self, x, inference_params=None, **kwargs): if inference_params is None: kv = torch.concat([k.unsqueeze(2), v.unsqueeze(2)], dim=2) # for packed data, batch dimension with a size of 1 should be directly squeezed off. - if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU: + if internlm_accelerator.get_accelerator_backend() in [AcceleratorType.GPU, AcceleratorType.DIPU]: q = q.squeeze(0) kv = kv.squeeze(0) @@ -457,7 +457,7 @@ def _packed_forward(self, x, inference_params=None, **kwargs): context = rearrange(context, "b h d -> b (h d)") # recover shape # restore bsz dimension - if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU: + if internlm_accelerator.get_accelerator_backend() in [AcceleratorType.GPU, AcceleratorType.DIPU]: context = context.unsqueeze(0) out = self.wo(context) diff --git a/internlm/model/modeling_llama.py b/internlm/model/modeling_llama.py index 3e752112..660dd192 100644 --- a/internlm/model/modeling_llama.py +++ b/internlm/model/modeling_llama.py @@ -414,7 +414,7 @@ def _packed_forward(self, x, inference_params=None, **kwargs): if inference_params is None: kv = torch.concat([k.unsqueeze(2), v.unsqueeze(2)], dim=2) # for packed data, batch dimension with a size of 1 should be directly squeezed off. 
- if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU: + if internlm_accelerator.get_accelerator_backend() in [AcceleratorType.GPU, AcceleratorType.DIPU]: q = q.squeeze(0) kv = kv.squeeze(0) @@ -452,7 +452,7 @@ def _packed_forward(self, x, inference_params=None, **kwargs): context = rearrange(context, "b h d -> b (h d)") # recover shape # restore bsz dimension - if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU: + if internlm_accelerator.get_accelerator_backend() in [AcceleratorType.GPU, AcceleratorType.DIPU]: context = context.unsqueeze(0) out = self.wo(context) diff --git a/internlm/model/modules/embedding.py b/internlm/model/modules/embedding.py index d155a97e..0e68e847 100644 --- a/internlm/model/modules/embedding.py +++ b/internlm/model/modules/embedding.py @@ -8,7 +8,7 @@ from einops import rearrange from torch import Tensor, nn -from internlm.accelerator import get_accelerator +from internlm.accelerator import AcceleratorType, get_accelerator from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc @@ -170,7 +170,12 @@ def backward(ctx, do): return dx, None, None, None, None -apply_rotary_emb = ApplyRotaryEmb.apply +if AcceleratorType.DIPU == get_accelerator().get_accelerator_backend(): + from deeplink_ext.internlm_ops.rotary.deeplink import DeeplinkApplyRotaryEmb + + apply_rotary_emb = DeeplinkApplyRotaryEmb.apply +else: + apply_rotary_emb = ApplyRotaryEmb.apply class ApplyRotaryEmbQKV_(torch.autograd.Function): @@ -253,7 +258,12 @@ def backward(ctx, dqkv): return dqkv, None, None, None, None, None -apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply +if AcceleratorType.DIPU == get_accelerator().get_accelerator_backend(): + from deeplink_ext.internlm_ops.rotary.deeplink import DeeplinkApplyRotaryEmbQKV_ + + apply_rotary_emb_qkv_ = DeeplinkApplyRotaryEmbQKV_.apply +else: + apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply class RotaryEmbedding(torch.nn.Module): diff --git a/internlm/model/modules/multi_head_attention.py b/internlm/model/modules/multi_head_attention.py index 076becab..aa8509a6 100644 --- a/internlm/model/modules/multi_head_attention.py +++ b/internlm/model/modules/multi_head_attention.py @@ -44,6 +44,14 @@ def get_gqa_attn_cls(use_flash_attn, tp_mode, causal, softmax_scale, dropout, se inner_attn_cls, inner_cross_attn_cls = AscendFlashSelfAttention, AscendFlashSelfAttention inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) + elif device_backend == AcceleratorType.DIPU: + from deeplink_ext.internlm_ops.mha import ( + DeepLinkCrossAttention, + DeepLinkSelfAttention, + ) + + inner_attn_cls, inner_cross_attn_cls = DeepLinkSelfAttention, DeepLinkCrossAttention + inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) else: raise NotImplementedError(f"Unsupport device type: {device_backend} for flash attention") else: @@ -554,6 +562,13 @@ def __init__( ) elif internlm_accelerator.get_accelerator_backend() == AcceleratorType.NPU: FlashCrossAttention, FlashSelfAttention = AscendFlashSelfAttention, AscendFlashSelfAttention + elif internlm_accelerator.get_accelerator_backend() == AcceleratorType.DIPU: + from deeplink_ext.internlm_ops.mha import ( + DeepLinkCrossAttention, + DeepLinkSelfAttention, + ) + + FlashCrossAttention, FlashSelfAttention = DeepLinkCrossAttention, DeepLinkSelfAttention inner_attn_cls = FlashSelfAttention inner_cross_attn_cls = FlashCrossAttention @@ -822,7 +837,7 @@ def 
_packed_forward(self, x, inference_params=None, **kwargs): kwargs.pop("indexes") # for packed data, batch dimension with a size of 1 should be directly squeezed off. - if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU: + if internlm_accelerator.get_accelerator_backend() in [AcceleratorType.GPU, AcceleratorType.DIPU]: qkv = qkv.squeeze(0) if inference_params is None: if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn: @@ -838,7 +853,7 @@ def _packed_forward(self, x, inference_params=None, **kwargs): context = rearrange(context, "b h d -> b (h d)") # recover the shape # restore bsz dimension - if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU: + if internlm_accelerator.get_accelerator_backend() in [AcceleratorType.GPU, AcceleratorType.DIPU]: context = context.unsqueeze(0) out = self.out_proj(context) diff --git a/internlm/model/utils.py b/internlm/model/utils.py index 7396a2ae..6e3e0074 100644 --- a/internlm/model/utils.py +++ b/internlm/model/utils.py @@ -295,7 +295,7 @@ def backward(ctx, grad_output, *args): sequence_parallel = ctx.sequence_parallel gather_dim = ctx.gather_dim - if gpc.config.use_cuda_flash_attn: + if gpc.config.use_cuda_flash_attn and AcceleratorType.GPU == get_accelerator().get_accelerator_backend(): assert ctx.is_using_cuda, "CUDA Flash Attention only support GPU device" backward_func = fused_dense_cuda.linear_bias_wgrad else: @@ -416,7 +416,7 @@ def backward(ctx, grad_output, *args): process_group = ctx.process_group sequence_parallel = ctx.sequence_parallel - if gpc.config.use_cuda_flash_attn: + if gpc.config.use_cuda_flash_attn and AcceleratorType.GPU == get_accelerator().get_accelerator_backend(): assert ctx.is_using_cuda, "CUDA Flash Attention only support GPU device" backward_func = fused_dense_cuda.linear_bias_wgrad else: @@ -521,7 +521,7 @@ def backward(ctx, grad_output, *args): module = ctx.module communicator = ctx.communicator - if gpc.config.use_cuda_flash_attn: + if gpc.config.use_cuda_flash_attn and AcceleratorType.GPU == get_accelerator().get_accelerator_backend(): assert ctx.is_using_cuda, "CUDA Flash Attention only support GPU device" backward_func = fused_dense_cuda.linear_bias_wgrad else: @@ -597,7 +597,9 @@ def fused_dense_func( dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or ( x.dtype == torch.float32 and torch.is_autocast_enabled() ) - is_using_cuda = (internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU) and dtype_eligible + is_using_cuda = ( + internlm_accelerator.get_accelerator_backend() in [AcceleratorType.GPU, AcceleratorType.DIPU] + ) and dtype_eligible return FusedDenseFunc.apply( x, weight, @@ -622,7 +624,9 @@ def megatron_fused_dense_func( dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or ( x.dtype == torch.float32 and torch.is_autocast_enabled() ) - is_using_cuda = (internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU) and dtype_eligible + is_using_cuda = ( + internlm_accelerator.get_accelerator_backend() in [AcceleratorType.GPU, AcceleratorType.DIPU] + ) and dtype_eligible return MegatronFusedDenseFunc.apply( x, weight, @@ -646,7 +650,9 @@ def isp_fused_dense_func( dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or ( x.dtype == torch.float32 and torch.is_autocast_enabled() ) - is_using_cuda = (internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU) and dtype_eligible + is_using_cuda = ( + internlm_accelerator.get_accelerator_backend() in [AcceleratorType.GPU, 
AcceleratorType.DIPU] + ) and dtype_eligible return ISPFusedDenseFunc.apply( x, weight, @@ -664,9 +670,17 @@ def try_import_RMSNorm(): """ try: - from apex.normalization.fused_layer_norm import MixedFusedRMSNorm as RMSNorm + device_backend = internlm_accelerator.get_accelerator_backend() + if device_backend == AcceleratorType.DIPU: + from deeplink_ext.internlm_ops.rms_norm import ( + DeepLinkRMSNormWithNormalizedShape as RMSNorm, + ) - return RMSNorm + return RMSNorm + else: + from apex.normalization.fused_layer_norm import MixedFusedRMSNorm as RMSNorm + + return RMSNorm except (ModuleNotFoundError, ImportError): logger.warning("The torch implementation for MixFusedRMSNorm is slower than apex. Please note this!") from internlm.model.ops.norm import RMSNormTorch as RMSNorm diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py index ae3a3b96..5475754c 100644 --- a/internlm/train/pipeline.py +++ b/internlm/train/pipeline.py @@ -299,11 +299,12 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList], isp_communicato adam_extra_kwargs = {} # set fused=True to avoid nan grad norm when model size is larger and use_fp32_norm=True + # TODO(caikun): add DIPU backend adamw if internlm_accelerator.get_accelerator_backend() == AcceleratorType.NPU: internlm_adamw = torch_npu.optim.NpuFusedAdamW else: internlm_adamw = torch.optim.AdamW - if torch.__version__ >= "2.1.0": + if torch.__version__ >= "2.1.0" and internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU: adam_extra_kwargs["fused"] = True naive_optimizer = internlm_adamw( diff --git a/internlm/utils/common.py b/internlm/utils/common.py index 2f39a7a2..80150a6c 100644 --- a/internlm/utils/common.py +++ b/internlm/utils/common.py @@ -14,7 +14,7 @@ import torch import internlm -from internlm.accelerator import get_accelerator +from internlm.accelerator import AcceleratorType, get_accelerator from internlm.utils.logger import get_logger CURRENT_TIME = None @@ -239,7 +239,7 @@ def get_megatron_flops( def enable_pytorch_expandable_segments(): - if torch.__version__ >= "2.1.0" and "cuda" in internlm_accelerator.current_device_name(): + if torch.__version__ >= "2.1.0" and AcceleratorType.GPU == internlm_accelerator.get_accelerator_backend(): _alloc_setting = "expandable_segments:True" if os.getenv("PYTORCH_CUDA_ALLOC_CONF", None) is not None: _alloc_setting = os.getenv("PYTORCH_CUDA_ALLOC_CONF") + "," + _alloc_setting diff --git a/internlm/utils/gputest.py b/internlm/utils/gputest.py index a1b76c2e..f419e61d 100644 --- a/internlm/utils/gputest.py +++ b/internlm/utils/gputest.py @@ -9,7 +9,7 @@ import torch.distributed as dist from torch.utils import benchmark -from internlm.accelerator import get_accelerator +from internlm.accelerator import AcceleratorType, get_accelerator from internlm.model.modules.multi_head_attention import SelfAttention from internlm.monitor import send_alert_message from internlm.utils.common import get_current_device @@ -94,7 +94,7 @@ def get_gpu_temperature(): except AssertionError: gpu_id = -1 - if GPUtil is not None and gpu_id >= 0: + if GPUtil is not None and gpu_id >= 0 and internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU: gpus = GPUtil.getGPUs() gpu_temperature = gpus[gpu_id].temperature else: @@ -236,9 +236,14 @@ def bench_gpu(use_flash_attn=True): nheads = dim // headdim if use_flash_attn: - from flash_attn.modules.mha import FlashSelfAttention + if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU: + from flash_attn.modules.mha import 
FlashSelfAttention - inner_attn = FlashSelfAttention + inner_attn = FlashSelfAttention + elif internlm_accelerator.get_accelerator_backend() == AcceleratorType.DIPU: + from deeplink_ext.internlm_ops.mha import DeepLinkSelfAttention + + inner_attn = DeepLinkSelfAttention else: inner_attn = SelfAttention
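
For reference, the selection order introduced in abstract_accelerator.py is: an explicit INTERNLM_ACCELERATOR override ("cuda", "npu" or "dipu") wins; otherwise DIPU is probed first (it needs both torch_dipu and deeplink_ext), then NPU (torch_npu), with CUDA as the final fallback. The snippet below is a minimal standalone sketch of that order, not the code in this patch: it uses importlib.util.find_spec where the patch uses try/except imports, and the helper name detect_accelerator_name, the _SUPPORTED tuple and the __main__ guard are illustrative only.

    import importlib.util
    import os

    _SUPPORTED = ("cuda", "npu", "dipu")  # values accepted by INTERNLM_ACCELERATOR

    def detect_accelerator_name() -> str:
        # 1. Explicit override from the environment, mirroring INTERNLM_ACCELERATOR.
        override = os.environ.get("INTERNLM_ACCELERATOR")
        if override is not None:
            if override not in _SUPPORTED:
                raise ValueError(f"accelerator_name must be one of {_SUPPORTED}, got {override!r}")
            return override
        # 2. Auto-detection: DIPU requires both torch_dipu and deeplink_ext.
        if importlib.util.find_spec("torch_dipu") and importlib.util.find_spec("deeplink_ext"):
            return "dipu"
        # 3. Fall back to NPU if torch_npu is importable, else CUDA.
        if importlib.util.find_spec("torch_npu"):
            return "npu"
        return "cuda"

    if __name__ == "__main__":
        print(detect_accelerator_name())

In practice, exporting INTERNLM_ACCELERATOR=dipu before launching training forces the DIPU_Accelerator path added here; its constructor also registers a stand-in for flash_attn_2_cuda via _find_or_mock_module, so code that imports that CUDA extension does not fail on DIPU devices, while the deeplink_ext ops (DeepLinkSelfAttention, DeepLinkCrossAttention, the rotary embedding and RMSNorm kernels) are substituted at the call sites shown above.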