Skip to content

Commit

Permalink
feat(*): re-impl embedding/head of isp version (#261)
Browse files Browse the repository at this point in the history
Co-authored-by: huangting4201 <[email protected]>
  • Loading branch information
mwiacx and huangting4201 authored Jul 17, 2024
1 parent 2c6df5c commit 7cd091c
Show file tree
Hide file tree
Showing 22 changed files with 504 additions and 140 deletions.
64 changes: 29 additions & 35 deletions internlm/checkpoint/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def load_model_checkpoint(folder, model):
If tensor parallel mode is isp, the saved weight is named:
- folder
- model_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}.pt
- model_wp{wp_rank}_pp{pp_rank}.pt
If fsdp is activated, the saved weight is named:
- folder
Expand All @@ -122,19 +122,19 @@ def load_model_checkpoint(folder, model):
fns = get_fns(folder)

# avoid ckpt misuse between FSDP and no-FSDP
test_fn = list([f for f in fns if f.startswith("model_t") and not f.endswith(".md5")]).pop()
_start_with = "model_w" if is_using_isp() else "model_t"
test_fn = list([f for f in fns if f.startswith(_start_with) and not f.endswith(".md5")]).pop()
assert ("_dp" in test_fn and gpc.config.parallel.zero1.fsdp) or (
"_dp" not in test_fn and not gpc.config.parallel.zero1.fsdp
), "FSDP model wants to load no-FSDP ckpts or reverse"

max_pp, max_wp, max_tp, max_zo = 0, 0, 0, 0
for fn in fns:
if fn.startswith("model_t") and not fn.endswith(".md5"):
if fn.startswith(_start_with) and not fn.endswith(".md5"):
segements = os.path.splitext(fn)[0].split("_")
if is_using_isp():
max_pp = max(max_pp, int(segements[-1][2:]))
max_wp = max(max_wp, int(segements[-2][2:]))
max_tp = max(max_tp, int(segements[-3][2:]))
elif gpc.config.parallel.zero1.fsdp:
max_zo = max(max_zo, int(segements[-1][2:]))
max_pp = max(max_pp, int(segements[-2][2:]))
Expand All @@ -149,16 +149,17 @@ def load_model_checkpoint(folder, model):
assert (
wp_size == max_wp + 1
), f"The weights are save for {max_wp+1} parallelism, while current has {wp_size} weight parallelism"
assert (
tp_size == max_tp + 1
), f"The weights are save for {max_tp+1} parallelism, while current has {tp_size} tensor parallelism"
if not is_using_isp():
assert (
tp_size == max_tp + 1
), f"The weights are save for {max_tp+1} parallelism, while current has {tp_size} tensor parallelism"
if gpc.config.parallel.zero1.fsdp:
assert (
dp_size == max_zo + 1
), f"The weights are save for {max_zo+1} FSDP shards , while current has {dp_size} FSDP shards"

if is_using_isp():
should_load_name = f"model_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}.pt"
should_load_name = f"model_wp{wp_rank}_pp{pp_rank}.pt"
elif gpc.config.parallel.zero1.fsdp:
should_load_name = f"model_tp{tp_rank}_pp{pp_rank}_dp{dp_rank}.pt"
else:
Expand Down Expand Up @@ -205,7 +206,7 @@ def save_model_checkpoint(folder, model):
If tensor parallel mode is isp, the saved weight is named:
- folder
- model_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}.pt
- model_wp{wp_rank}_pp{pp_rank}.pt
If fsdp is activated, the saved weight is named:
- folder
Expand Down Expand Up @@ -243,11 +244,11 @@ def save_model_checkpoint(folder, model):

# for tensor parallel mode with isp
if is_using_isp():
if wdp_rank == 0 or dp_rank == 0:
fn = f"model_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}.pt"
if wdp_rank == 0:
fn = f"model_wp{wp_rank}_pp{pp_rank}.pt"
fp = os.path.join(folder, fn)
llm_save(fp, saved_obj=states)
topo_fn = f"topo_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}.json"
topo_fn = f"topo_wp{wp_rank}_pp{pp_rank}.json"
topo_fp = os.path.join(folder, topo_fn)
llm_save(topo_fp, saved_obj=topo)
else:
Expand Down Expand Up @@ -292,13 +293,12 @@ def load_optimizer_checkpoint(folder, optim):
"""

fns = get_fns(folder)
max_tp, max_wp, max_pp, max_zero, max_dp = 0, 0, 0, 0, 0
max_tp, max_wp, max_pp, max_zero = 0, 0, 0, 0
for fn in fns:
if fn.startswith("optimizer_") and not fn.endswith(".md5"):
if is_using_isp():
_, tp, wp, pp, dp = os.path.splitext(fn)[0].split("_")
max_dp = max(max_dp, int(dp[2:]))
max_tp = max(max_tp, int(tp[2:]))
_, wp, pp, zero = os.path.splitext(fn)[0].split("_")
max_zero = max(max_zero, int(zero[2:]))
max_wp = max(max_wp, int(wp[2:]))
max_pp = max(max_pp, int(pp[2:]))
else:
Expand All @@ -311,24 +311,18 @@ def load_optimizer_checkpoint(folder, optim):
tp_size = gpc.get_world_size(ParallelMode.TENSOR)
wp_size = gpc.get_world_size(ParallelMode.WEIGHT)
pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
dp_size = gpc.get_world_size(ParallelMode.DATA)

if is_using_isp():
assert dp_size == max_dp + 1, (
f"The optimizer states are save for {max_dp+1} data parallelism, "
f"while current has {dp_size} data parallelism"
)
if not is_using_isp():
assert zero_size == max_zero + 1, (
f"The optimizer states are save for {max_zero+1} zero parallel, "
f"while current has {zero_size} zero broadcast range."
)
assert zero_size == max_zero + 1, (
f"The optimizer states are save for {max_zero+1} zero parallel, "
f"while current has {zero_size} zero broadcast range."
)
assert (
pp_size == max_pp + 1
), f"The optimizer states are save for {max_pp+1} pipelines, while current has {pp_size} pipelines"
assert (
tp_size == max_tp + 1
), f"The optimizer states are save for {max_tp+1} parallelism, while current has {tp_size} tensor parallelism"
if not is_using_isp():
assert (
tp_size == max_tp + 1
), f"The optimizer states are save for {max_tp+1} parallelism, while current has {tp_size} tensor parallelism"
assert (
wp_size == max_wp + 1
), f"The optimizer states are save for {max_wp+1} parallelism, while current has {wp_size} weight parallelism"
Expand All @@ -337,9 +331,8 @@ def load_optimizer_checkpoint(folder, optim):
tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT)
pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
dp_rank = gpc.get_local_rank(ParallelMode.DATA)
if is_using_isp():
fp = f"optimizer_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}_dp{dp_rank}.pt"
fp = f"optimizer_wp{wp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
else:
fp = f"optimizer_tp{tp_rank}_pp{pp_rank}_zo{zero_rank}.pt"

Expand Down Expand Up @@ -387,16 +380,17 @@ def save_optimizer_checkpoint(optim, state_path):
tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT)
pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
dp_rank = gpc.get_local_rank(ParallelMode.DATA)
zero_size = gpc.get_world_size(ParallelMode.ZERO1)
tp_size = gpc.get_world_size(ParallelMode.TENSOR)
wp_size = gpc.get_world_size(ParallelMode.WEIGHT)
dp_size = gpc.get_world_size(ParallelMode.DATA)

states = optim.state_dict()
if isinstance(optim, (HybridZeroOptimizer, HybridZeroOptimizer_v2)):
if is_using_isp():
fp = f"optimizer_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}_dp{dp_rank}.pt"
llm_save(os.path.join(state_path, fp), states)
fp = f"optimizer_wp{wp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
if (gpc.get_global_rank() % (tp_size * dp_size)) < zero_size * wp_size:
llm_save(os.path.join(state_path, fp), states)
else:
fp = f"optimizer_tp{tp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
if (gpc.get_global_rank() % (tp_size * dp_size)) < zero_size * tp_size:
Expand Down
2 changes: 0 additions & 2 deletions internlm/core/context/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from .parallel_context import (
IS_REPLICA_ZERO_PARALLEL,
IS_TENSOR_DATA_PARALLEL,
IS_TENSOR_EXPERT_DATA_PARALLEL,
IS_TENSOR_ZERO_PARALLEL,
IS_WEIGHT_ZERO_PARALLEL,
Expand Down Expand Up @@ -32,7 +31,6 @@
__all__ = [
"Config",
"IS_TENSOR_ZERO_PARALLEL",
"IS_TENSOR_DATA_PARALLEL",
"IS_REPLICA_ZERO_PARALLEL",
"IS_WEIGHT_ZERO_PARALLEL",
"IS_TENSOR_EXPERT_DATA_PARALLEL",
Expand Down
8 changes: 5 additions & 3 deletions internlm/core/context/parallel_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@
from .process_group_initializer import ParallelMode
from .random import add_seed, get_seeds, set_mode

# for layernorm
IS_REPLICA_ZERO_PARALLEL = "is_replica_zero_parallel"
# for isp, with optimizer split in dp group
IS_TENSOR_DATA_PARALLEL = "is_tensor_data_parallel"
# for mtp/msp/fsp, with optimizer split in zero1 group
# for mtp/msp/fsp with tensor parallel, and optimizer split in zero1 group
IS_TENSOR_ZERO_PARALLEL = "is_tensor_zero_parallel"
# for isp with weight parallel, and optimizer split in zero1 group
IS_WEIGHT_ZERO_PARALLEL = "is_weight_zero_parallel"
# for moe
IS_TENSOR_EXPERT_DATA_PARALLEL = "is_tensor_expert_data_parallel"

logger = get_logger(__file__)
Expand Down Expand Up @@ -564,6 +565,7 @@ def init_parallel_groups(self):
initializers.append(pgroup_initializer.Initializer_Weight_Data(*initializer_args))
initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args))
initializers.append(pgroup_initializer.Initializer_Data(*initializer_args))
initializers.append(pgroup_initializer.Initializer_ISP_Data(*initializer_args))
if isinstance(parallel_config["tensor"], dict) and parallel_config["tensor"]["mode"] == "isp":
initializers.append(pgroup_initializer.Initializer_Zero1_ISP(*initializer_args))
else:
Expand Down
63 changes: 63 additions & 0 deletions internlm/core/context/process_group_initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ class ParallelMode(Enum):
# sequence parallel
SEQUENCE = "sequence"

# real data parallel for isp
ISP_DATA = "isp_data"

# grouped query attention
GQA = "gqa"

Expand Down Expand Up @@ -854,6 +857,66 @@ def init_dist_group(self, use_cpu: bool = False):
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode


class Initializer_ISP_Data(ProcessGroupInitializer):
"""A ProcessGroupInitializer for real data parallel group in isp.
Args:
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
weight_parallel_size (int): Size of model weight parallel.
weight_data_parallel_size (int): Size of data parallel for common weight.
sequence_parallel_size (int): Size of data sequence parallel.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
zero1_parallel_size (int): Size of zero1 parallel.
nettest_parallel_size (int): Size of net testing parallel.
expert_parallel_size (int): Size of expert parallel.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.isp_data_parallel_size = self.tensor_parallel_size * self.data_parallel_size
self.num_isp_data_parallel_group = self.world_size // self.isp_data_parallel_size

assert self.world_size % self.isp_data_parallel_size == 0

def init_dist_group(self, use_cpu: bool = False):
"""Initialize real data parallel groups for isp, and assign local_ranks and groups to each gpu.
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
A real data parallelism's information tuple.
"""
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.ISP_DATA

for i in range(self.num_isp_data_parallel_group):
ranks = [i * self.isp_data_parallel_size + j for j in range(self.isp_data_parallel_size)]
group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
if use_cpu:
group_cpu = (
dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
if dist.get_backend() != "gloo"
else group
)
else:
group_cpu = None

if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks

return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode


class Initializer_GQA(ProcessGroupInitializer):
"""A ProcessGroupInitializer for allreduce kv gradients with common attention head.
Expand Down
Loading

0 comments on commit 7cd091c

Please sign in to comment.