diff --git a/configs/7B_isp_sft.py b/configs/7B_isp_sft.py index ad68082d..ebf138d3 100644 --- a/configs/7B_isp_sft.py +++ b/configs/7B_isp_sft.py @@ -135,10 +135,18 @@ cur_iter=-1, ) +# cpu_offloading = dict( +# enable=True, +# num_layers=3, +# ) +# selective_checkpoint = True +# selective_checkpoint_offload = False + use_fp32_norm = False model = dict( checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] num_attention_heads=NUM_ATTENTION_HEAD, + num_kv_attention_heads=NUM_KV_ATTENTION_HEAD, embed_split_hidden=True, vocab_size=VOCAB_SIZE, embed_grad_scale=1, diff --git a/internlm/core/parallel/comm/__init__.py b/internlm/core/parallel/comm/__init__.py index e69de29b..be170f28 100644 --- a/internlm/core/parallel/comm/__init__.py +++ b/internlm/core/parallel/comm/__init__.py @@ -0,0 +1,3 @@ +from .attn_offload import get_offload_manager, initialize_offload_manager + +__all__ = ["initialize_offload_manager", "get_offload_manager"] diff --git a/internlm/core/parallel/comm/attn_offload.py b/internlm/core/parallel/comm/attn_offload.py new file mode 100644 index 00000000..da23f3ae --- /dev/null +++ b/internlm/core/parallel/comm/attn_offload.py @@ -0,0 +1,127 @@ +import torch + +from internlm.utils.common import get_current_device + +global_attn_offload = None + + +class AttnOffloadManager: + """ + A manager for attention output CPU offloading and GPU prefetch loading. + """ + + def __init__(self, enable_cpu_offload: bool = False) -> None: + # cpu offload overlapping + self.cpu_offload = enable_cpu_offload + # layer id mapping to flash attn output + self.fa_output_mapping = {} + self.fa_stream = torch.cuda.Stream() + self.d2h_final_event = torch.cuda.Event() + self.h2d_final_event = torch.cuda.Event() + # prepare for tensor buffer + self.tensor_id_to_tensor_bufs = {} + + def get_tensor_buf_for_offloaded_tensor(self, tensor, layer_id, tensor_id): + """Get tensor buffer for offloaded tensor.""" + layer_id = layer_id % 2 + if layer_id not in self.tensor_id_to_tensor_bufs: + self.tensor_id_to_tensor_bufs[layer_id] = {} + + if tensor_id not in self.tensor_id_to_tensor_bufs[layer_id]: + allocate_new_buf = True + else: + tensor_buf = self.tensor_id_to_tensor_bufs[layer_id][tensor_id] + allocate_new_buf = tensor_buf.size() == tensor.size() and tensor_buf.dtype == tensor.dtype + + if allocate_new_buf: + # supposed to only execute once + buffer = torch.empty( + tensor.size(), + dtype=tensor.dtype, + layout=tensor.layout, + device=tensor.device, + ) + + self.tensor_id_to_tensor_bufs[layer_id][tensor_id] = buffer + + return self.tensor_id_to_tensor_bufs[layer_id][tensor_id] + + def insert_fa_output_with_layer(self, layer_idx, output): + assert layer_idx not in self.fa_output_mapping + if self.cpu_offload is False: + self.fa_output_mapping[layer_idx] = output + return + + tensors = [] + for tensor_id, tensor in enumerate(output): + if tensor is None: + tensors.append(None) + continue + tensor_buf = self.get_tensor_buf_for_offloaded_tensor(tensor, layer_idx, tensor_id) + tensor_buf.copy_(tensor) + tensors.append(tensor_buf) + self.fa_output_mapping[layer_idx] = tensors + + def get_fa_output_with_layer(self, layer_idx): + assert layer_idx in self.fa_output_mapping + return self.fa_output_mapping.pop(layer_idx) + + def offload_fa_output_with_layer(self, layer_idx): + assert layer_idx in self.fa_output_mapping + + self.fa_stream.wait_stream(torch.cuda.current_stream()) + self.fa_stream.wait_event(self.d2h_final_event) + + with torch.cuda.stream(self.fa_stream): + _gpu_tensors = self.fa_output_mapping.pop(layer_idx) + _cpu_tensors = [] + for _tensor in _gpu_tensors: + if _tensor is None: + _cpu_tensors.append(_tensor) + continue + + _cpu_backup = torch.empty( + _tensor.size(), + dtype=_tensor.dtype, + layout=_tensor.layout, + device="cpu", + pin_memory=True, + ) + _cpu_backup.copy_(_tensor, non_blocking=True) + _cpu_tensors.append(_cpu_backup) + + # _cpu_tensors.append(_tensor.to("cpu", non_blocking=False)) + + self.fa_output_mapping[layer_idx] = _cpu_tensors + + self.fa_stream.record_event(self.d2h_final_event) + + def preload_fa_output_with_layer(self, layer_idx): + assert layer_idx in self.fa_output_mapping + + self.fa_stream.wait_stream(torch.cuda.current_stream()) + self.fa_stream.wait_event(self.h2d_final_event) + + # Important: get device before with stream, in stream get device is error + _device = get_current_device() + with torch.cuda.stream(self.fa_stream): + _cpu_tensors = self.fa_output_mapping.pop(layer_idx) + self.fa_output_mapping[layer_idx] = [ + _tensor.to(device=_device, non_blocking=True) if _tensor is not None else _tensor + for _tensor in _cpu_tensors + ] + + self.fa_stream.record_event(self.h2d_final_event) + + +def initialize_offload_manager(enable_cpu_offload: bool = False): + global global_attn_offload + if global_attn_offload is None: + global_attn_offload = AttnOffloadManager(enable_cpu_offload) + + return global_attn_offload + + +def get_offload_manager(): + assert global_attn_offload is not None + return global_attn_offload diff --git a/internlm/core/parallel/comm/cpu_offload.py b/internlm/core/parallel/comm/cpu_offload.py new file mode 100644 index 00000000..89e5912b --- /dev/null +++ b/internlm/core/parallel/comm/cpu_offload.py @@ -0,0 +1,505 @@ +# Adapted from https://github.com/NVIDIA/TransformerEngine/blob/v1.12/transformer_engine/pytorch/cpu_offload.py +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Functionality for CPU offloading of tensors saved for backward pass.""" +from __future__ import annotations + +from contextlib import nullcontext +from typing import Any, Dict, Optional + +import torch + +__all__ = ["get_cpu_offload_context"] + +CPUOffloadEnabled = False + + +def is_cpu_offload_enabled() -> bool: + """Check if CPU offloading is currently enabled.""" + return CPUOffloadEnabled + + +class CpuOffloadSavedTensorHook: + """Contex-manager that executes a pair of pack/unpack hooks for saved tensors. + In this context, the ``on_save_for_backward`` method will be called every time + a tensor is saved for backward (this includes intermediary results saved using + :func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but + also those recorded by a PyTorch-defined operation). + The ``on_get_saved_tensors`` method will be called when the backward function + of this op attempts to retrieve the saved tensor from context (this includes + :func: `torch.Tensor.backward()` or :func: `torch.autograd.grad()`. It takes the + as input the return value of the ``on_save_for_backward``, and is meant to return + an identical copy of the tensor being saved by ``on_save_for_backward`` in terms of + size, device and element values. + Example: + >>> import torch + >>> from typing import Any + >>> + >>> class DummyHook(CpuOffloadSavedTensorHook): + ... + ... def on_save_for_backward(self, tensor: torch.Tensor) -> Any: + ... logging.info("On save", tensor) + ... return (tensor,) + ... + ... def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: + ... logging.info("On get", saved_state) + ... tensor, = saved_state + ... return tensor + ... + >>> a = torch.ones(5, requires_grad=True) + >>> b = torch.ones(5, requires_grad=True) * 2 + >>> with DummyHook(): + ... y = a * b + ... + On save tensor([1., 1., 1., 1., 1.], requires_grad=True) + On save tensor([2., 2., 2., 2., 2.], grad_fn=) + >>> y.sum().backward() + On get (tensor([1., 1., 1., 1., 1.], requires_grad=True),) + On get (tensor([2., 2., 2., 2., 2.], grad_fn=),) + """ + + def __init__(self) -> None: + self.inside_context = False + + def __enter__(self): + global CPUOffloadEnabled + CPUOffloadEnabled = True + + self.inside_context = True + torch._C._autograd._push_saved_tensors_default_hooks(self.on_save_for_backward, self.on_get_saved_tensor) + + def __exit__(self, *args: Any): + global CPUOffloadEnabled + CPUOffloadEnabled = False + + self.inside_context = False + torch._C._autograd._pop_saved_tensors_default_hooks() + + def on_save_for_backward(self, tensor: torch.Tensor) -> Any: + """On save for backward.""" + raise NotImplementedError( + "`on_save_for_backward: Callable[[torch.Tensor], Any]`" + "is not implemented in CpuOffloadHook class. Inherit " + "this class and implement your custom hooks" + ) + + def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: + """On get saved tensor.""" + raise NotImplementedError( + "`on_get_saved_tensors: Callable[[Any], torch.Tensor]`" + "is not implemented in CpuOffloadHook class. Inherit " + "this class and implement your custom hooks" + ) + + +class CpuOffloadHookWithOffloadHandler(CpuOffloadSavedTensorHook): + """Context-manager that offloads/recovers tensors through an offload hander. + The hook just offloads/recovers the tensor object to the handler through `tensor_push` + and `tensor_pop` interface. How the offload-handler manages the offloading, recovering + or prefetching timing is transparent to this hook. + """ + + def __init__( + self, + offload_handler: OffloadHandler, + handler_extra_kwargs: Optional[Dict[str, Any]] = None, + debug: bool = False, + ) -> None: + if handler_extra_kwargs is None: + handler_extra_kwargs = {} + self.debug: bool = debug + self.offload_handler: OffloadHandler = offload_handler + self.handler_extra_kwargs: Dict[str, Any] = handler_extra_kwargs + super().__init__() + + def on_save_for_backward(self, tensor: torch.Tensor) -> Any: + retrieve_identifier = self.offload_handler.tensor_push(tensor, **self.handler_extra_kwargs) + return retrieve_identifier + + def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: + tensor = self.offload_handler.tensor_pop(saved_state, **self.handler_extra_kwargs) + return tensor + + +class OffloadHandler: + """A base class for CPU offload-handler.""" + + def __init__(self) -> None: + pass + + def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: + """Tensor push.""" + raise NotImplementedError( + "`tensor_push is not implented in OffloadHandler class. " + "Inherit this class and implement your custom tensor_push." + ) + + def tensor_pop(self, tensor_tag: Any, **kwargs): + """Tensor pop.""" + raise NotImplementedError( + "`tensor_pop is not implented in OffloadHandler class. " + "Inherit this class and implement your custom tensor_pop." + ) + + +class GroupCommitFunction(torch.autograd.Function): + """this is a dummy op with output identical to input. + However, it is necessary for marking a timepoint for offload handler to + accomplish all synchronizations. Implementing it as a function is necessary + because we need to actions in both forward and backward. + """ + + @staticmethod + def forward(ctx, tensor, cpu_offload_handler): + # pylint: disable=missing-function-docstring + cpu_offload_handler.on_group_commit_forward() + ctx.cpu_offload_handler = cpu_offload_handler + # return the identical tensor + return tensor + + @staticmethod + def backward(ctx, grad_output): + # pylint: disable=missing-function-docstring + cpu_offload_handler = ctx.cpu_offload_handler + cpu_offload_handler.on_group_commit_backward() + return grad_output, None + + +group_prefetch_offload_commit = GroupCommitFunction.apply + + +class SynchronizedGroupOffloadHandler(OffloadHandler): + """Offload Handler that offloads/reloads in a synchronized way. + The device-to-host and host-to-device copying happen in the same stream + as the computation kernels, thus the copying will block computation. + """ + + def __init__(self, num_offload_group, tensor_need_offloading_checker=(lambda _: True), debug=False) -> None: + super().__init__() + + self.num_offload_group = num_offload_group + self.tensor_need_offloading_checker = tensor_need_offloading_checker + self.debug = debug + + self.groupid_reset() + + def groupid_reset(self): + """Groupid reset.""" + # Data structures to label saved tensors and book-keep their cpu copies. + # Currently, on push, create a new cpu tensor and copies; on pop, copies + # the tensor back to gpu and deletes the cpu tensor. + # These will increment whenever `group_commit()` is invoked + self.current_group, self.tensor_count_current_group = (0, 0) + self.torch_tensor_count = 0 + self.tensor_tag_to_state = {} + + def on_group_commit_forward(self): + """On group commit forward.""" + # finishing up with updating current group and tensor count + self.current_group += 1 # increment + self.tensor_count_current_group = 0 # reset + + def on_group_commit_backward(self): + """On group commit backward.""" + self.current_group -= 1 + assert self.current_group >= 0 + + @staticmethod + def offload(src_tensor, pin_memory=True): + """Offload.""" + + cpu_backup = torch.empty( + src_tensor.size(), + dtype=src_tensor.dtype, + layout=src_tensor.layout, + device="cpu", + pin_memory=pin_memory, + ) + + cpu_backup.copy_(src_tensor, non_blocking=pin_memory) + state = (src_tensor.device, cpu_backup) + return state + + @staticmethod + def reload(state, non_blocking=None): + """Reload.""" + dev, cpu_backup = state + if non_blocking is None: + non_blocking = cpu_backup.is_pinned() + return cpu_backup.to(dev, non_blocking=non_blocking) + + def tensor_push(self, tensor: torch.Tensor, **kwargs): + """Tensor push.""" + # obtain a unique tensor tag + tensor_tag = (self.current_group, self.tensor_count_current_group) + self.tensor_count_current_group += 1 + assert tensor_tag not in self.tensor_tag_to_state + if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker(tensor): + state = SynchronizedGroupOffloadHandler.offload(tensor) + self.tensor_tag_to_state[tensor_tag] = state + else: + # will be offloaded together after group commit + self.tensor_tag_to_state[tensor_tag] = tensor + return tensor_tag + + def tensor_pop(self, tensor_tag, **kwargs): + """Tensor pop.""" + assert tensor_tag in self.tensor_tag_to_state + state = self.tensor_tag_to_state.pop(tensor_tag) + if isinstance(state, tuple): + tensor = SynchronizedGroupOffloadHandler.reload(state) + else: + tensor = state + return tensor + + +class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): + """Compared to synchronize, this uses more memory because of the buffer but + achieves better performance due to the overlapping. D2h and h2d copying are + completely hidden behind computation if computation time of a layer is longer + than host-device communication time. Bulk offloading with delay and bulk reloading + with prefetch are implemented.""" + + def __init__( + self, + num_offload_group, # must be <= actual number of groups (number of commits) + num_model_group, + tensor_need_offloading_checker=(lambda t: True), + debug=False, + ) -> None: + super().__init__( + num_offload_group=num_offload_group, + tensor_need_offloading_checker=tensor_need_offloading_checker, + debug=debug, + ) + # Number of layers in the model + self.num_layers = num_model_group + # Data Structure to maintain reference to activation tensors + self.tensor_tag_to_buf = {} + # Tracking the number of layers offloaded + self.offloaded_group_count = 0 + # Core data structure that decides the window for offloading + self.layer_window_map = {} + + # Logic to make offloading load balance across computation + # for optimal CPU/GPU interconnect usage + constant = 0 + for i in range(self.num_offload_group): + self.layer_window_map[i] = ((self.num_layers // self.num_offload_group) * (i + 1)) - 1 + if i < (self.num_layers % self.num_offload_group): + self.layer_window_map[i] += i + 1 + constant = i + 1 + else: + self.layer_window_map[i] += constant + + if torch.distributed.get_rank() == 0: + print( + f"Offloading {self.num_offload_group} layers' activations with " + f"layer_window_map:{self.layer_window_map}", + flush=True, + ) + + # allocate streams and events for synchronization + self.d2h_stream = torch.cuda.Stream() + self.h2d_stream = torch.cuda.Stream() + + def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: + torch_stray_tensor = False + + # torch2.4 + # torch_stray_tensor = isinstance( + # tensor, + # ( + # torch._subclasses.fake_tensor.FakeTensor, + # torch._subclasses.functional_tensor.FunctionalTensor, + # ), + # ) + + if not torch_stray_tensor: + # obtain a unique tensor tag + tensor_tag = (self.current_group, self.tensor_count_current_group) + self.tensor_count_current_group += 1 + assert tensor_tag not in self.tensor_tag_to_state + + self.tensor_tag_to_state[tensor_tag] = tensor + + if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker(tensor): + self.tensor_tag_to_buf[tensor_tag] = tensor + else: + tensor_tag = (-1, self.torch_tensor_count) + self.torch_tensor_count += 1 + self.tensor_tag_to_state[tensor_tag] = tensor + + return tensor_tag + + def tensor_pop(self, tensor_tag, **kwargs): + """Tensor pop.""" + assert tensor_tag in self.tensor_tag_to_state + tensor = self.tensor_tag_to_state.pop(tensor_tag) + self.tensor_tag_to_buf.pop(tensor_tag, None) + # the tensor should have been copied back in on_group_commit_backward() + # which invokes bulk_reload_group. + assert not isinstance(tensor, tuple) + return tensor + + def bulk_offload_group(self, group_to_offload): + """Bulk offload group.""" + with torch.cuda.stream(self.d2h_stream): + for tensor_tag, state in self.tensor_tag_to_state.items(): + group_id, _ = tensor_tag + if group_id == group_to_offload: + assert not isinstance(state, tuple) + tensor_on_device = state + + # if offload, return the reference to cpu copy + if self.tensor_need_offloading_checker(tensor_on_device): + state = SynchronizedGroupOffloadHandler.offload(tensor_on_device) + self.tensor_tag_to_state[tensor_tag] = state + + def synchronize_on_group_commit_forward(self, current_group): + """Synchronize on group commit forward.""" + + # For the first group, kickstart the offload after we have + # the first compute completion + if current_group == 0: + self.d2h_stream.wait_stream(torch.cuda.current_stream()) + self.bulk_offload_group(current_group) + + # Window map data structure helps us synchronize based on number + # of layers offloaded + # e.g. layer_window_map={0: 10, 1: 21, 2: 31} + if self.layer_window_map[self.offloaded_group_count] == current_group: + + # Stream synchronization both ways + self.d2h_stream.wait_stream(torch.cuda.current_stream()) + torch.cuda.current_stream().wait_stream(self.d2h_stream) + + # Time to free the activation memory after usage + for tensor_tag, _ in self.tensor_tag_to_buf.items(): + if tensor_tag[0] == self.offloaded_group_count: + self.tensor_tag_to_buf[tensor_tag] = None + + # Time to offload the next group + if self.offloaded_group_count < (self.num_offload_group - 1): + self.bulk_offload_group(self.offloaded_group_count + 1) + + # Increment the offload group count to keep track + self.offloaded_group_count += 1 + + def on_group_commit_forward(self): + """This function will cause host device synchronization""" + # handle synchronization events + self.synchronize_on_group_commit_forward(self.current_group) + + super().on_group_commit_forward() + + def bulk_reload_group(self, group_to_reload): + """Bulk reload group.""" + assert group_to_reload < self.num_offload_group + + with torch.cuda.stream(self.h2d_stream): + # move back tensors + for tensor_label, state in self.tensor_tag_to_state.items(): + group_id, _ = tensor_label + if group_id == group_to_reload: + if isinstance(state, tuple): + recovered_tensor = SynchronizedGroupOffloadHandler.reload(state) + self.tensor_tag_to_state[tensor_label] = recovered_tensor + + def on_group_commit_backward(self): + # first decrement the current group. + # after last commit in forward, the group will +1; in backward it -1. + # Finally it should be decremented to 0. + self.current_group -= 1 + assert self.current_group >= 0 + + # Layer window data structure helps us to reload at right times + # e.g. layer_window_map={0: 10, 1: 21, 2: 31} + if self.layer_window_map[self.offloaded_group_count - 1] == self.current_group: + + # Stream synchronization both ways + self.h2d_stream.wait_stream(torch.cuda.current_stream()) + torch.cuda.current_stream().wait_stream(self.h2d_stream) + + # Time to reload the next group + self.bulk_reload_group(self.offloaded_group_count - 1) + + # Decrease the offloading group counter + self.offloaded_group_count -= 1 if self.offloaded_group_count > 1 else 0 + + # Last group computation needs to wait till all the reloads complete + if self.current_group == 0: + torch.cuda.current_stream().wait_stream(self.h2d_stream) + self.offloaded_group_count = 0 + + +def get_cpu_offload_context( + enabled: bool = False, + num_layers: int = 1, + model_layers: int = 1, + offload_activations: bool = False, + offload_weights: bool = False, +): + """ + This function returns the CPU Offload context and the synchronizer function that needs to be + used after every transformer layer. Returns `nullcontext()` if offloading is not enabled. + Usage: + .. code-block:: python + cpu_offload_context, cpu_offload_synchronizer = get_cpu_offload_context(enabled=True) + with cpu_offload_context: + te_layer.forward(inp_tensor) + cpu_offload_synchronizer() + Parameters + ---------- + enabled: bool, default = `False` + When set to True, CPU Offloading functionality is enabled. + num_layers: int, default = 1 + Determines the number of transformer layers + you want to offload activations/weights for. + model_layers: int, default = 1 + Number of layers in the model that will be used under this context. + offload_activations: bool, default = `False` + When set to `True`, offloads the tensors with attribute 'activation_offloading' for the layer. + offload_weights: bool, default = `False` + When set to `True`, offloads the weights with attribute 'weight_offloading' for the layer. + """ + + def tensor_need_offloading_checker_base(tensor): # pylint: disable=W0613 + return True + + def tensor_need_offloading_checker_activations(tensor): + return hasattr(tensor, "activation_offloading") + + # This includes the Gradient Accumulation Buffer + def tensor_need_offloading_checker_weights(tensor): + return hasattr(tensor, "weight_offloading") + + def tensor_need_offloading_checker_all(tensor): + return hasattr(tensor, "activation_offloading") or hasattr(tensor, "weight_offloading") + + if offload_activations and offload_weights: + tensor_need_offloading_checker = tensor_need_offloading_checker_all + elif offload_activations: + tensor_need_offloading_checker = tensor_need_offloading_checker_activations + elif offload_weights: + tensor_need_offloading_checker = tensor_need_offloading_checker_weights + else: + tensor_need_offloading_checker = tensor_need_offloading_checker_base + + cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler( + num_offload_group=num_layers, + num_model_group=model_layers, + tensor_need_offloading_checker=tensor_need_offloading_checker, + ) + + def group_prefetch_offload_commit_async(tensor): + return group_prefetch_offload_commit(tensor, cpu_offload_handler) + + if enabled: + return ( + CpuOffloadHookWithOffloadHandler(offload_handler=cpu_offload_handler), + group_prefetch_offload_commit_async, + ) + return nullcontext(), group_prefetch_offload_commit_async diff --git a/internlm/core/parallel/comm/isp.py b/internlm/core/parallel/comm/isp.py index 7e722c2f..24677c09 100644 --- a/internlm/core/parallel/comm/isp.py +++ b/internlm/core/parallel/comm/isp.py @@ -37,6 +37,8 @@ params_dispatch_with_condition, ) +from .attn_offload import get_offload_manager + # not really useful, only for code hint. class WPCommunicator(ABC): @@ -306,6 +308,7 @@ def __init__( overlap: bool = False, process_group: dist.ProcessGroup = None, is_moe: bool = False, + selective_ckpt_offload: bool = False, ) -> None: self.process_group = process_group self.overlap = overlap @@ -316,6 +319,14 @@ def __init__( self._forward_prefetch_prerequisites = [] self._forward_overlap_per = self._get_forward_overlap_granularity() self._launch_before_module = self._get_launch_before_module() + # As an optimization, do not release weight after forward for the last + # transformer block since wp would prefetch it immediately + self.layers_wp_not_release = [] # [gpc.config.isp_num_layers - 1] + self.layers_fa_not_release = [ + gpc.config.isp_num_layers - 1, + int(gpc.config.model.checkpoint * gpc.config.isp_num_layers) - 1, + ] + self.sc_offload = selective_ckpt_offload # real overlap state for each chunk. self._overlap_states: Dict[int, ISPOverlapState] = {} @@ -411,6 +422,7 @@ def is_allgather_launch_module(name, module): self._overlap_states[cid].index_to_isp_modules[idx].append(child) setattr(child, "isp_name", name) + setattr(child, "isp_layer_idx", idx) full_name = f"{cid}.{idx}.{name}" setattr( @@ -506,6 +518,25 @@ def _pre_forward_hook_for_prefetch_launch_module(self, module: nn.Module, *args) if block_index + 1 < self._num_blocks: self._all_gather_block_weight(block_index + 1) + # register offload and prefetch hook for selective ckpt with wo linear + if self.sc_offload is True: + # move current layer's attn output from GPU to CPU asynchronizely + if ( + self.is_forward is True + and gpc.config.selective_checkpoint + and block_index not in self.layers_fa_not_release + and block_index < self._ckpt_block_num + ): + get_offload_manager().offload_fa_output_with_layer(layer_idx=block_index) + + # load previous layer's attn output from CPU to GPU asynchronizely + if ( + self.is_forward is False + and gpc.config.selective_checkpoint + and (0 <= (block_index - 1) < self._ckpt_block_num) + ): + get_offload_manager().preload_fa_output_with_layer(layer_idx=block_index - 1) + def _pre_forward_hook_for_module(self, module: nn.Module, *args): # pylint: disable=W0613 if module not in self._weight_global_handle: self._all_gather_module_weight(module) @@ -539,6 +570,9 @@ def _pre_forward_hook_for_module(self, module: nn.Module, *args): # pylint: dis self._all_gather_module_weight(next_module) def _post_forward_hook_for_module(self, module: nn.Module, *args): # pylint: disable=W0613 + if int(module.isp_layer_idx) in self.layers_wp_not_release: + # print(f"the layer {module.isp_layer_idx} after forward not clear weight") + return if not ((self._module_to_index[module] < self._ckpt_block_num) and self.is_forward is False): self._clear_handle(module) self._clear_weight(module) diff --git a/internlm/core/trainer_builder.py b/internlm/core/trainer_builder.py index d0ef284d..cb45be72 100644 --- a/internlm/core/trainer_builder.py +++ b/internlm/core/trainer_builder.py @@ -11,6 +11,7 @@ from internlm.checkpoint.checkpoint_manager import CheckpointManager from internlm.core.context import global_context as gpc from internlm.core.context.process_group_initializer import ParallelMode +from internlm.core.parallel.comm import initialize_offload_manager from internlm.core.trainer import Trainer from internlm.data.streaming.utils import streaming_simple_resume from internlm.data.train_state import get_train_state @@ -118,6 +119,9 @@ def __init__( # initialize isp communicator isp_communicator = initialize_parallel_communicator(model) + # initialize cpu offload manager for selective checkpoint + initialize_offload_manager(gpc.config.get("selective_checkpoint_offload", False)) + # initialize train state train_state = get_train_state(train_dl) diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index 35b3d646..5819c435 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -66,6 +66,8 @@ def get_default_parser(): def args_sanity_check(): assert gpc.config is not None, "config is not load!" + gpc.is_forward = True + if "JOB_NAME" not in gpc.config: gpc.config._add_item("JOB_NAME", "AnonymousJob") @@ -73,6 +75,18 @@ def args_sanity_check(): if "model_type" not in gpc.config: gpc.config._add_item("model_type", ModelType.INTERNLM.name) + if gpc.config.model_type == "InternLM3_M": + # TODO: need check for isp overlap + num_layers = gpc.config.model.num_self_decoder_layers + gpc.config.model.num_cross_decoder_layers + else: + num_layers = gpc.config.model.num_layers + gpc.config.isp_num_layers = num_layers + + if "cpu_offloading" not in gpc.config: + gpc.config._add_item("cpu_offloading", dict(enable=False, num_layers=0)) + if gpc.config.cpu_offloading.enable is False: + assert gpc.config.cpu_offloading.num_layers == 0, "num_layers should be 0 when cpu_offloading is disabled." + if "use_apex_adam" not in gpc.config: gpc.config._add_item("use_apex_adam", False) @@ -399,17 +413,18 @@ def args_sanity_check(): gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], mode=TensorParallelMode.mtp.name) if gpc.config.parallel["tensor"].get("mode", None) is None: gpc.config.parallel["tensor"]["mode"] = TensorParallelMode.mtp.name - assert ( - gpc.config.VOCAB_SIZE % gpc.config.parallel.tensor.size == 0 - ), "VOCAB_SIZE must be integer multiple of tensor parallel size" if gpc.config.parallel["tensor"]["mode"] == TensorParallelMode.isp.name: assert not gpc.config.parallel.zero1.fsdp, "FSDP does not support isp" assert ( torch.__version__ >= "2.1.0" ), f"requires torch>=2.1.0 when using isp but current version is {torch.__version__}" - assert ( - gpc.config.VOCAB_SIZE % gpc.config.parallel.weight.size == 0 - ), "VOCAB_SIZE must be integer multiple of wp size" + + assert ( + gpc.config.model.vocab_size % gpc.config.parallel.weight.size == 0 + ), "model.vocab_size must be integer multiple of weight parallel size" + assert ( + gpc.config.model.vocab_size % gpc.config.parallel.tensor.size == 0 + ), "model.vocab_size must be integer multiple of tensor parallel size" assert gpc.config.parallel["tensor"].get("mode", None) in [ TensorParallelMode.mtp.name, @@ -532,7 +547,20 @@ def args_sanity_check(): gpc.config.loss._add_item("moe_loss_coeff", 1.0) if "selective_checkpoint" not in gpc.config: - gpc.config._add_item("selective_checkpoint", False) + gpc.config.selective_checkpoint = False + if "selective_checkpoint_offload" not in gpc.config: + gpc.config.selective_checkpoint_offload = False + if gpc.config.selective_checkpoint is True: + assert ( + gpc.config.parallel["tensor"]["mode"] == "isp" + ), "When using selective_checkpoint, tensor parallel mode must be isp" + if gpc.config.selective_checkpoint_offload is True: + assert ( + gpc.config.selective_checkpoint is True + ), "When using selective_checkpoint_offload, selective_checkpoint must be True" + assert ( + gpc.config.parallel.weight.launch_allgather_before == "wo" + ), "When using selective_checkpoint_offload, wp launch allgather communication should be set before 'wo' module" # moe not support overlap and zero1.5 for now if gpc.config.model.get("num_experts", 1) > 1: diff --git a/internlm/model/modeling_internlm2.py b/internlm/model/modeling_internlm2.py index 69da0837..fc573da3 100644 --- a/internlm/model/modeling_internlm2.py +++ b/internlm/model/modeling_internlm2.py @@ -1,6 +1,7 @@ # Copyright (c) InternLM. All rights reserved. import math import os +from contextlib import nullcontext from functools import reduce from typing import Optional @@ -12,6 +13,7 @@ from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode from internlm.core.context.parallel_context import global_context as gpc +from internlm.core.parallel.comm.cpu_offload import get_cpu_offload_context from internlm.core.parallel.shard import partition_uniform from internlm.initialize.initialize_tensor import ( normal_, @@ -384,6 +386,16 @@ def __init__( checkpoint_layer_num = int(num_layers * checkpoint) self.embed_grad_scale = embed_grad_scale self.parallel_output = parallel_output + self.enable_cpu_offloading = gpc.config.cpu_offloading.enable + + if self.enable_cpu_offloading: + (self.offload_context, self.group_prefetch_offload_commit_async) = get_cpu_offload_context( + gpc.config.cpu_offloading.enable, + gpc.config.cpu_offloading.num_layers, + gpc.config.model.num_layers, + ) + else: + self.offload_context, self.group_prefetch_offload_commit_async = nullcontext(), None if first: self.tok_embeddings = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) @@ -406,7 +418,7 @@ def __init__( max_position_embeddings=max_position_embeddings, dtype=dtype, layer_norm_epsilon=layer_norm_epsilon, - checkpoint=lid < checkpoint_layer_num, + checkpoint=gpc.config.cpu_offloading.num_layers <= lid < checkpoint_layer_num, layer_idx=lid + start_layer_idx, # This parameter is used for caching during generation use_dynamic_ntk_rope=use_dynamic_ntk_rope, residual_in_fp32=residual_in_fp32, @@ -463,7 +475,15 @@ def forward(self, hidden_states=None, input_ids=None, **kwargs): ) for _, block in enumerate(self.layers): - hidden_states = block(hidden_states, residual=None, **kwargs) + with self.offload_context: + hidden_states = block(hidden_states, residual=None, **kwargs) + + if ( + torch.is_grad_enabled() + and self.enable_cpu_offloading + and self.group_prefetch_offload_commit_async is not None + ): + hidden_states = self.group_prefetch_offload_commit_async(hidden_states) if hasattr(self, "norm"): hidden_states = self.norm(hidden_states.float()) diff --git a/internlm/model/modeling_moe.py b/internlm/model/modeling_moe.py index f40d35f3..71ed5ef7 100644 --- a/internlm/model/modeling_moe.py +++ b/internlm/model/modeling_moe.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- - import math +from contextlib import nullcontext from typing import Optional import torch @@ -9,6 +9,7 @@ from internlm.core.context import ParallelMode from internlm.core.context.parallel_context import global_context as gpc +from internlm.core.parallel.comm.cpu_offload import get_cpu_offload_context from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal from internlm.model.base_model import BaseModel from internlm.model.modules.embedding import Embedding1D @@ -319,6 +320,16 @@ def __init__( super().__init__() checkpoint_layer_num = int(num_layers * checkpoint) + self.enable_cpu_offloading = gpc.config.cpu_offloading.enable + + if self.enable_cpu_offloading: + (self.offload_context, self.group_prefetch_offload_commit_async) = get_cpu_offload_context( + gpc.config.cpu_offloading.enable, + gpc.config.cpu_offloading.num_layers, + gpc.config.model.num_layers, + ) + else: + self.offload_context, self.group_prefetch_offload_commit_async = nullcontext(), None if first: self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) @@ -337,7 +348,7 @@ def __init__( max_position_embeddings=max_position_embeddings, dtype=dtype, layer_norm_epsilon=layer_norm_epsilon, - checkpoint=lid < checkpoint_layer_num, + checkpoint=gpc.config.cpu_offloading.num_layers <= lid < checkpoint_layer_num, layer_idx=lid + start_layer_idx, # This parameter is used for caching during generation use_dynamic_ntk_rope=use_dynamic_ntk_rope, residual_in_fp32=residual_in_fp32, @@ -386,8 +397,16 @@ def forward(self, hidden_states=None, input_ids=None, **kwargs): moe_losses = [] for _, block in enumerate(self.blocks): - hidden_states, mos_loss = block(hidden_states, **kwargs) - moe_losses.append(mos_loss) + with self.offload_context: + hidden_states, mos_loss = block(hidden_states, **kwargs) + moe_losses.append(mos_loss) + + if ( + torch.is_grad_enabled() + and self.enable_cpu_offloading + and self.group_prefetch_offload_commit_async is not None + ): + hidden_states = self.group_prefetch_offload_commit_async(hidden_states) if hasattr(self, "norm"): hidden_states = self.norm(hidden_states.float()) diff --git a/internlm/model/ops/_flash_attn.py b/internlm/model/ops/_flash_attn.py new file mode 100644 index 00000000..1d1416d9 --- /dev/null +++ b/internlm/model/ops/_flash_attn.py @@ -0,0 +1,333 @@ +# Copyright (c) InternLM. All rights reserved. +import torch + +from internlm.accelerator import get_accelerator +from internlm.core.context import global_context as gpc +from internlm.core.parallel.comm import get_offload_manager + +try: + import flash_attn + from flash_attn.flash_attn_interface import ( + _flash_attn_varlen_backward, + _flash_attn_varlen_forward, + ) + + gpu_flash_attn_impl = True +except (ModuleNotFoundError, ImportError): + gpu_flash_attn_impl = False + +internlm_accelerator = get_accelerator() +device_backend = internlm_accelerator.get_accelerator_backend() + + +class FlashAttnVarlenKVPackedFunc_V263(torch.autograd.Function): + """ + Varlen KVPacked Func from Flash Attn v2.6.3. + """ + + @staticmethod + def forward( + ctx, + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_softmax, + layer_idx, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + k, v = kv[:, 0], kv[:, 1] + + _ckpt_block_num = int(gpc.config.model.checkpoint * gpc.config.isp_num_layers) + _is_ckpt_layer = gpc.config.cpu_offloading.num_layers <= layer_idx < _ckpt_block_num + + if gpc.is_forward is False and gpc.config.selective_checkpoint and _is_ckpt_layer: + out, out_padded, softmax_lse, S_dmask, rng_state = get_offload_manager().get_fa_output_with_layer(layer_idx) + else: + ( + out, + q, + k, + v, + out_padded, + softmax_lse, + S_dmask, + rng_state, + ) = _flash_attn_varlen_forward( # pylint: disable=E1123 + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + return_softmax=return_softmax and dropout_p > 0, + block_table=None, + ) + + # store attn forward output to avoid re-computation of attn when activation checkpoint is enabled + if gpc.is_forward and gpc.config.selective_checkpoint and _is_ckpt_layer: + get_offload_manager().insert_fa_output_with_layer( + layer_idx=layer_idx, output=(out, out_padded, softmax_lse, S_dmask, rng_state) + ) + + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state) + ctx.dropout_p = dropout_p + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): # pylint: disable=W0613 + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors + dq = torch.empty_like(q) + kv_shape = k.shape[:-2] + (2, *k.shape[-2:]) + dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device) + _flash_attn_varlen_backward( # pylint: disable=E1121,E1124 + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dkv[:, 0], + dkv[:, 1], + cu_seqlens_q, + cu_seqlens_k, + ctx.max_seqlen_q, + ctx.max_seqlen_k, + ctx.dropout_p, + ctx.softmax_scale, + ctx.causal, + ctx.window_size, + ctx.softcap, + ctx.alibi_slopes, + ctx.deterministic, + rng_state=rng_state, + ) + dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension + dkv = dkv[..., : dout.shape[-1]] + return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None, None + + +class FlashAttnVarlenKVPackedFunc_V221(torch.autograd.Function): + """ + Varlen KVPacked Func from Flash Attn v2.2.1. + """ + + @staticmethod + def forward( + ctx, + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + return_softmax, + layer_idx, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + k, v = kv[:, 0], kv[:, 1] + + _ckpt_block_num = int(gpc.config.model.checkpoint * gpc.config.isp_num_layers) + _is_ckpt_layer = gpc.config.cpu_offloading.num_layers <= layer_idx < _ckpt_block_num + + if gpc.is_forward is False and gpc.config.selective_checkpoint and _is_ckpt_layer: + out, out_padded, softmax_lse, S_dmask, rng_state = get_offload_manager().get_fa_output_with_layer(layer_idx) + else: + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal=causal, + return_softmax=return_softmax and dropout_p > 0, + ) + + # store attn forward output to avoid re-computation of attn when activation checkpoint is enabled + if gpc.is_forward and gpc.config.selective_checkpoint and _is_ckpt_layer: + get_offload_manager().insert_fa_output_with_layer( + layer_idx=layer_idx, output=(out, out_padded, softmax_lse, S_dmask, rng_state) + ) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state) + ctx.dropout_p = dropout_p + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.softmax_scale = softmax_scale + ctx.causal = causal + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): # pylint: disable=W0613 + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors + dq = torch.empty_like(q) + kv_shape = k.shape[:-2] + (2, *k.shape[-2:]) + dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device) + _flash_attn_varlen_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dkv[:, 0], + dkv[:, 1], + cu_seqlens_q, + cu_seqlens_k, + ctx.max_seqlen_q, + ctx.max_seqlen_k, + ctx.dropout_p, + ctx.softmax_scale, + ctx.causal, + rng_state=rng_state, + ) + dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension + dkv = dkv[..., : dout.shape[-1]] + return dq, dkv, None, None, None, None, None, None, None, None, None + + +def flash_attn_varlen_kvpacked_func( + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + layer_idx=0, +): + """dropout_p should be set to 0.0 during evaluation + If K, V are already stacked into 1 tensor, this function will be faster than + calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation + of the gradients of K, V. + Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + Arguments: + q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. + kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch. + cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into q. + cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into kv. + max_seqlen_q: int. Maximum query sequence length in the batch. + max_seqlen_k: int. Maximum key sequence length in the batch. + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + softcap: float. Anything > 0 activates softcapping attention. + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (total, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). + """ + + assert gpu_flash_attn_impl is True and flash_attn.__version__ in [ + "2.2.1", + "2.6.3", + ], "flash-attn should be installed and version must be v2.2.1 or v2.6.3" + + if flash_attn.__version__ == "2.2.1": + return FlashAttnVarlenKVPackedFunc_V221.apply( + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + return_attn_probs, + layer_idx, + ) + + return FlashAttnVarlenKVPackedFunc_V263.apply( + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + layer_idx, + ) diff --git a/internlm/model/ops/attention.py b/internlm/model/ops/attention.py index 604ea77a..3aec51f5 100644 --- a/internlm/model/ops/attention.py +++ b/internlm/model/ops/attention.py @@ -93,13 +93,14 @@ from flash_attn.flash_attn_interface import ( flash_attn_varlen_func as _flash_varlen_qkvsplited_func, ) - from flash_attn.flash_attn_interface import ( - flash_attn_varlen_kvpacked_func as _flash_varlen_kvpacked_func, - ) from flash_attn.flash_attn_interface import ( flash_attn_varlen_qkvpacked_func as _flash_varlen_qkvpacked_func, ) + from ._flash_attn import ( + flash_attn_varlen_kvpacked_func as _flash_varlen_kvpacked_func, + ) + gpu_flash_attn_impl = True except (ModuleNotFoundError, ImportError): gpu_flash_attn_impl = False @@ -187,6 +188,7 @@ def _flash_varlen_kvpacked_attn( dropout_p=0.0, softmax_scale=None, causal=False, + layer_idx=0, ): # compatible data format: [1, packelen, 3, n_head, headim] q, kv = q.squeeze(dim=0), kv.squeeze(dim=0) @@ -204,6 +206,7 @@ def _flash_varlen_kvpacked_attn( dropout_p, softmax_scale, causal, + layer_idx=layer_idx, ) return output.unsqueeze(dim=0) @@ -521,6 +524,7 @@ def _npu_varlen_kvpacked_attn( dropout_p=0.0, softmax_scale=None, causal=False, + layer_idx=0, # pylint: disable=W0613 ): # TODO: support npu native varlen flash attention k, v = kv.unbind(dim=2) @@ -579,6 +583,7 @@ def _deeplink_varlen_kvpacked_attn( dropout_p=0.0, softmax_scale=None, causal=False, + layer_idx=0, # pylint: disable=W0613 ): # compatible data format: [1, packelen, 3, n_head, headim] q, kv = q.squeeze(dim=0), kv.squeeze(dim=0) @@ -1012,7 +1017,17 @@ def _q_kv_with_cu_seqlens( extra_args = (key_padding_mask,) if attn_type is AttnType.Torch else () return op( - q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout, softmax_scale, causal, *extra_args + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout, + softmax_scale, + causal, + *extra_args, + layer_idx=self.layer_idx, ) @forward.register(conditions=(str(QKVPackType.QKVSPLITED), str(CuSeqlenType.With))) diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py index 79e9caf4..ca11e689 100644 --- a/internlm/train/pipeline.py +++ b/internlm/train/pipeline.py @@ -363,6 +363,7 @@ def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]): gpc.config.parallel.weight.overlap, gpc.get_group(ParallelMode.WEIGHT), is_moe=False, + selective_ckpt_offload=gpc.config.get("selective_checkpoint_offload", False), ) # register communicator for isp column parallel linear. ColumnParallelLinear.register_cls_communicator(isp_communicator) diff --git a/tests/test_training/test_loss.py b/tests/test_training/test_loss.py index 2fd8ad4c..ff32160d 100644 --- a/tests/test_training/test_loss.py +++ b/tests/test_training/test_loss.py @@ -471,16 +471,16 @@ def test_training_with_isp(): global CONFIG_FILE_PATH, BASELINE_LOSS_LIST CONFIG_FILE_PATH = "./configs/7B_isp_sft.py" BASELINE_LOSS_LIST = [ - 12.225811004638672, - 12.103824615478516, - 12.223844528198242, - 11.87704849243164, - 11.651590347290039, - 11.629219055175781, - 10.242591857910156, - 9.768388748168945, - 9.330610275268555, - 5.505439758300781, + 12.159960746765137, + 12.22106647491455, + 12.106496810913086, + 11.951896667480469, + 11.644429206848145, + 11.459924697875977, + 10.127229690551758, + 9.795705795288086, + 9.255647659301758, + 5.301709175109863, ] # model training