feat(checkpoint): support universal checkpoint #394

Open · wants to merge 12 commits into base: develop
2 changes: 1 addition & 1 deletion .github/workflows/lint_check.yaml
@@ -10,7 +10,7 @@ on:
jobs:
# lint check can be auto-executed by the workflow
lint-check:
runs-on: ubuntu-latest
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v3

6 changes: 6 additions & 0 deletions configs/7B_internlm2.py
@@ -38,6 +38,12 @@
async_upload=True, # async ckpt upload. (only works for boto3 ckpt)
async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporary files during asynchronous upload.
oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
# INFO: Universal ckpt is not compatible with the original ckpt format.
# The default is to use async_save and not broadcast_load,
# as broadcast_load may degrade loading performance.
# NOTE: If async_save is used, there is a risk of losing the latest ckpt
# when training is interrupted suddenly.
universal_ckpt=dict(enable=False, async_save=True, broadcast_load=False),
)

TRAIN_FOLDER = None
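For reference, a minimal sketch of what the ckpt block might look like with the new option turned on; every value except universal_ckpt below is an illustrative placeholder rather than a recommendation:

CHECKPOINT_EVERY = 50  # illustrative save frequency

ckpt = dict(
    checkpoint_every=CHECKPOINT_EVERY,
    async_upload=True,  # async ckpt upload (only works for boto3 ckpt)
    async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/",  # temp files for async upload
    oss_snapshot_freq=int(CHECKPOINT_EVERY / 2),  # snapshot ckpt save frequency
    # Universal ckpt uses its own layout and cannot be mixed with the original
    # ckpt format; async_save trades a small risk of losing the very latest
    # ckpt on a sudden interruption for faster saves.
    universal_ckpt=dict(enable=True, async_save=True, broadcast_load=False),
)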
56 changes: 42 additions & 14 deletions internlm/checkpoint/checkpoint_manager.py
@@ -8,6 +8,7 @@
import torch

from internlm.accelerator import get_accelerator
from internlm.checkpoint.universal_checkpoint.api import universal_load, universal_save
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.trainer import TrainState
@@ -60,7 +61,7 @@ class CheckpointLoadContent:
SCHEDULAER = "scheduler"


def try_load_internevo_ckpt(ckpt_mm, load_info, train_state: TrainState = None):
def try_load_internevo_ckpt(ckpt_mm, load_info, train_state: TrainState = None, universal_ckpt=False):
"""Tries to load a checkpoint from the given folder.

Args:
@@ -83,7 +84,19 @@ def try_load_internevo_ckpt(ckpt_mm, load_info, train_state: TrainState = None):
"""
load_content_str, load_ckpt_folder, load_content = process_load_info(load_info)

if load_content.need_load(CheckpointLoadContent.MODEL):
if universal_ckpt:
checkpoint_state = {}
if load_content.need_load(CheckpointLoadContent.MODEL):
checkpoint_state["model"] = ckpt_mm.model
if load_content.need_load(CheckpointLoadContent.OPIMIZER):
checkpoint_state["optimizer"] = ckpt_mm.optimizer
universal_load(
load_ckpt_folder, checkpoint_state, broadcast_checkpoint=gpc.config.ckpt.universal_ckpt.broadcast_load
)
if gpc.is_rank_for_log():
logger.warning("Finsh loading universal model checkpoint and optimizer checkpoint.")

if not universal_ckpt and load_content.need_load(CheckpointLoadContent.MODEL):
load_model_checkpoint(folder=load_ckpt_folder, model=ckpt_mm.model)
load_content_str += f"{CheckpointLoadContent.MODEL}, "

@@ -93,12 +106,12 @@ def try_load_internevo_ckpt(ckpt_mm, load_info, train_state: TrainState = None):
load_context(load_ckpt_folder, train_state)

# load optimizer states.
if load_content.need_load(CheckpointLoadContent.OPIMIZER):
if not universal_ckpt and load_content.need_load(CheckpointLoadContent.OPIMIZER):
load_optimizer_checkpoint(load_ckpt_folder, ckpt_mm.optimizer)
load_content_str += f"{CheckpointLoadContent.OPIMIZER}, "
else:
if gpc.is_rank_for_log():
logger.warning("CheckpointManager has no 'optimizer', skip reload optim checkpoint!")

if not load_content.need_load(CheckpointLoadContent.OPIMIZER) and gpc.is_rank_for_log():
logger.warning("CheckpointManager has no 'optimizer', skip reload optim checkpoint!")

# load lr scheduler states.
if load_content.need_load(CheckpointLoadContent.SCHEDULAER):
@@ -109,7 +122,7 @@ def try_load_internevo_ckpt(ckpt_mm, load_info, train_state: TrainState = None):
if gpc.is_rank_for_log():
logger.warning("CheckpointManager has no 'lr_scheduler', skip reload lr_scheduler checkpoint!")

if not load_content.need_load(CheckpointLoadContent.OPIMIZER):
if not universal_ckpt and not load_content.need_load(CheckpointLoadContent.OPIMIZER):
if ckpt_mm.lr_scheduler and train_state:
gpc.config.only_load_lr = True
load_optimizer_checkpoint(load_ckpt_folder, ckpt_mm.optimizer)
@@ -440,6 +453,7 @@ def try_save_checkpoint(self, train_state, force=False):
train_state=train_state,
model_config=self.model_config,
model_config_file=self.model_config_file,
universal_ckpt=gpc.config.ckpt.universal_ckpt.enable,
)

if (
@@ -579,9 +593,15 @@ def try_resume_training(self, train_state: TrainState, current_time=""):
load_path = self.load_ckpt_info["path"]
load_content = self.load_ckpt_info["content"]
load_type = self.load_ckpt_info["ckpt_type"]
universal_ckpt = gpc.config.ckpt.universal_ckpt.enable
kwargs = {}

if universal_ckpt:
assert load_type == "internevo", "Only internevo ckpt supports universal ckpt."
kwargs = {"universal_ckpt": universal_ckpt}

load_func = CheckpointLoadMethod.get_ckpt_load_type_func(load_type)
load_content_str = load_func(self, self.load_ckpt_info, train_state)
load_content_str = load_func(self, self.load_ckpt_info, train_state, **kwargs)

# If we only load model weight, we need rewrite zero optim's fp32 buffer.
if (
@@ -609,6 +629,7 @@ def save_checkpoint(
train_state: TrainState,
model_config: Dict = None,
model_config_file: str = None,
universal_ckpt=False,
):
"""
Save checkpoint to the given folder path.
@@ -621,13 +642,20 @@
if gpc.is_rank_for_log():
logger.info(f"Saving checkpoint to `{folder}` at batch count:{train_state.step_count}...")

timer("save-model").start()
save_model_checkpoint(folder=folder, model=model)
timer("save-model").stop()
if not universal_ckpt:
timer("save-model").start()
save_model_checkpoint(folder=folder, model=model)
timer("save-model").stop()

timer("save-optimizer").start()
save_optimizer_checkpoint(optim=optimizer, state_path=folder)
timer("save-optimizer").stop()
timer("save-optimizer").start()
save_optimizer_checkpoint(optim=optimizer, state_path=folder)
timer("save-optimizer").stop()
else:
universal_save(
path=folder,
checkpoint_state={"model": model, "optimizer": optimizer},
async_checkpoint=gpc.config.ckpt.universal_ckpt.async_save,
)

if (
hasattr(train_state, "data_state_dict")
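Usage note for the manager changes above: when universal_ckpt.enable is True, try_resume_training asserts that the checkpoint type is "internevo", so a resume config would look roughly like the hypothetical sketch below (the path and content values are placeholders, not part of this PR):

load_ckpt_info = dict(
    path="local:/mnt/ckpt/universal/100/",  # placeholder path to a universal ckpt
    content=("model", "optimizer"),         # states to restore
    ckpt_type="internevo",                  # the only type that supports universal ckpt
)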
6 changes: 6 additions & 0 deletions internlm/checkpoint/universal_checkpoint/__init__.py
@@ -0,0 +1,6 @@
from .api import universal_load, universal_save

__all__ = [
"universal_save",
"universal_load",
]
48 changes: 48 additions & 0 deletions internlm/checkpoint/universal_checkpoint/api.py
@@ -0,0 +1,48 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

# adapted from https://github.com/volcengine/veScale/blob/main/vescale/checkpoint

from .checkpointer import UniversalCheckpointer
from .common import CheckpointState


def universal_save(path: str, checkpoint_state: CheckpointState, async_checkpoint=False):
"""
Save a checkpoint to a given path.
Args:
path: Defines the storage path for the checkpoint.
checkpoint_state: A dictionary containing key-value pairs for the model and optimizer.
- Model: Identified by the 'model' key; the value should be a model instance.
- Optimizer: Identified by the 'optimizer' key; the value should be an optimizer instance.
async_checkpoint: A boolean value indicating whether to save the checkpoint asynchronously,
i.e. after tensors are dumped from GPU memory to host memory,
the training program can continue training immediately,
while universal_checkpoint serializes the tensors and dumps them to
persistent storage asynchronously.
Example:
>>> checkpoint_state = {"model": nn.Module, "optimizer": HybridZeroOptimizer}
>>> universal_save("/ckpt", checkpoint_state)
"""
UniversalCheckpointer.save(path, checkpoint_state, async_checkpoint=async_checkpoint)


def universal_load(path: str, checkpoint_state: CheckpointState, broadcast_checkpoint=False):
"""
Load a checkpoint from a given path.
Args:
path: Defines the storage path for the checkpoint.
checkpoint_state: A dictionary containing key-value pairs for the model and optimizer.
- Model: Identified by the 'model' key; the value should be a model instance.
- Optimizer: Identified by the 'optimizer' key; the value should be an optimizer instance.
broadcast_checkpoint: A boolean value deciding whether to load a model replica in one data parallel
process group and then broadcast the tensors to the other data parallel process groups over GPUs,
to reduce file system access.
For example, when data parallel size = 2,
processes with data parallel rank = 0 load the model from the file system
and then broadcast it to processes with data parallel rank = 1.
Example:
>>> checkpoint_state = {"model": nn.Module, "optimizer": HybridZeroOptimizer}
>>> universal_load("/ckpt", checkpoint_state)
"""
UniversalCheckpointer.load(path, checkpoint_state, broadcast_checkpoint=broadcast_checkpoint)
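To tie the two entry points together, a minimal end-to-end sketch, assuming `model` and `optimizer` already exist from the usual InternEvo training setup and the path is a placeholder:

from internlm.checkpoint.universal_checkpoint import universal_load, universal_save

checkpoint_state = {"model": model, "optimizer": optimizer}

# Asynchronous save: training continues once tensors reach host memory, while
# serialization to persistent storage happens in the background.
universal_save("/mnt/ckpt/universal/100", checkpoint_state, async_checkpoint=True)

# On restart: restore both states in place; with broadcast_checkpoint=True only
# data parallel rank 0 reads from storage and broadcasts to the other ranks.
universal_load("/mnt/ckpt/universal/100", checkpoint_state, broadcast_checkpoint=True)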