feat(moe): support moe zero1 setting #350

Open · wants to merge 3 commits into base: develop
1 change: 1 addition & 0 deletions configs/1.8B_MoE16_sft.py
@@ -203,6 +203,7 @@
weight=dict(size=1, overlap=True),
expert=dict(size=-1, no_tp=False),
expert_weight=dict(size=1, overlap=True),
expert_zero1=dict(size=-1),
)

cudnn_deterministic = False
1 change: 1 addition & 0 deletions configs/7B_MoE4_sft.py
@@ -201,6 +201,7 @@
weight=dict(size=1, overlap=True),
expert=dict(size=-1, no_tp=False),
expert_weight=dict(size=1, overlap=True),
expert_zero1=dict(size=-1),
)

cudnn_deterministic = False
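The only user-visible change in the two configs is the new expert_zero1 entry. As a reading aid (not part of the PR), a hypothetical excerpt of such a parallel section with the knob set explicitly; the sizes are illustrative, and size=-1 falls back to the expert data-parallel size, as handled in parallel_context.py below:

parallel = dict(
    weight=dict(size=1, overlap=True),
    expert=dict(size=-1, no_tp=False),
    expert_weight=dict(size=1, overlap=True),
    # new in this PR: ZeRO-1 sharding degree for expert (MoE) optimizer states
    expert_zero1=dict(size=4),
)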
106 changes: 85 additions & 21 deletions internlm/checkpoint/components.py
@@ -14,7 +14,7 @@
from internlm.solver.optimizer import HybridZeroOptimizer, HybridZeroOptimizer_v2
from internlm.utils.common import get_current_device
from internlm.utils.logger import get_logger
from internlm.utils.parallel import is_using_isp
from internlm.utils.parallel import is_using_isp, is_using_moe
from internlm.utils.storage_manager import get_fns, llm_load, llm_save

from .utils import (
@@ -310,47 +310,93 @@ def load_optimizer_checkpoint(folder, optim):

fns = get_fns(folder)
max_tp, max_wp, max_pp, max_zero = 0, 0, 0, 0
is_moe_optim = False
max_ep, max_ewp, max_moe_zero = 0, 0, 0
for fn in fns:
if fn.startswith("optimizer_") and not fn.endswith(".md5"):
if is_using_isp():
_, wp, pp, zero = os.path.splitext(fn)[0].split("_")
max_zero = max(max_zero, int(zero[2:]))
max_wp = max(max_wp, int(wp[2:]))
if fn.startswith("optimizer_ep"):
is_moe_optim = True
_, ep, ewp, pp, moe_zero = os.path.splitext(fn)[0].split("_")
else:
_, wp, pp, zero = os.path.splitext(fn)[0].split("_")
if is_moe_optim:
max_ep = max(max_ep, int(ep[2:]))
max_ewp = max(max_ewp, int(ewp[3:]))
max_moe_zero = max(max_moe_zero, int(moe_zero[2:]))
else:
max_zero = max(max_zero, int(zero[2:]))
max_wp = max(max_wp, int(wp[2:]))
max_pp = max(max_pp, int(pp[2:]))
else:
_, tp, pp, zero = os.path.splitext(fn)[0].split("_")
max_zero = max(max_zero, int(zero[2:]))
if fn.startswith("optimizer_ep"):
is_moe_optim = True
_, ep, tp, pp, moe_zero = os.path.splitext(fn)[0].split("_")
else:
_, tp, pp, zero = os.path.splitext(fn)[0].split("_")
max_tp = max(max_tp, int(tp[2:]))
max_pp = max(max_pp, int(pp[2:]))
if is_moe_optim:
max_ep = max(max_ep, int(ep[2:]))
max_moe_zero = max(max_moe_zero, int(moe_zero[2:]))
else:
max_zero = max(max_zero, int(zero[2:]))

zero_size = gpc.get_world_size(ParallelMode.ZERO1)
tp_size = gpc.get_world_size(ParallelMode.TENSOR)
wp_size = gpc.get_world_size(ParallelMode.WEIGHT)
pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
ep_size = gpc.get_world_size(ParallelMode.EXPERT)
moe_zero_size = gpc.get_world_size(ParallelMode.EXPERT_ZERO1)
if is_using_isp():
ewp_size = gpc.get_world_size(ParallelMode.EXPERT_WEIGHT)
ewp_rank = gpc.get_local_rank(ParallelMode.EXPERT_WEIGHT)

assert zero_size == max_zero + 1, (
f"The optimizer states are save for {max_zero+1} zero parallel, "
f"while current has {zero_size} zero broadcast range."
)
if is_moe_optim:
assert moe_zero_size == max_moe_zero + 1, (
f"The optimizer states are save for {max_moe_zero+1} expert zero parallelism, "
f"while current has {moe_zero_size} expert zero broadcast range."
)
assert (
ep_size == max_ep + 1
), f"The optimizer states are save for {max_ep+1} parallelism, while current has {ep_size} weight parallelism"
if is_using_isp():
assert ewp_size == max_ewp + 1, (
f"The optimizer states are save for {max_ewp+1} expert weight parallelism, "
f"while current has {ewp_size} expert weight parallelism"
)
else:
assert zero_size == max_zero + 1, (
f"The optimizer states are save for {max_zero+1} zero parallel, "
f"while current has {zero_size} zero broadcast range."
)
assert (
wp_size == max_wp + 1
), f"The optimizer states are save for {max_wp+1} parallelism, while current has {wp_size} weight parallelism"
assert (
pp_size == max_pp + 1
), f"The optimizer states are save for {max_pp+1} pipelines, while current has {pp_size} pipelines"
if not is_using_isp():
assert (
tp_size == max_tp + 1
), f"The optimizer states are save for {max_tp+1} parallelism, while current has {tp_size} tensor parallelism"
assert (
wp_size == max_wp + 1
), f"The optimizer states are save for {max_wp+1} parallelism, while current has {wp_size} weight parallelism"

zero_rank = gpc.get_local_rank(ParallelMode.ZERO1)
tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT)
pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
ep_rank = gpc.get_local_rank(ParallelMode.EXPERT)
moe_zero_rank = gpc.get_local_rank(ParallelMode.EXPERT_ZERO1)
if is_using_isp():
fp = f"optimizer_wp{wp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
if is_using_moe() and moe_zero_size * ep_size * ewp_size > zero_size * wp_size:
fp = f"optimizer_ep{ep_rank}_ewp{ewp_rank}_pp{pp_rank}_zo{moe_zero_rank}.pt"
else:
fp = f"optimizer_wp{wp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
else:
fp = f"optimizer_tp{tp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
if is_using_moe() and moe_zero_size * ep_size > zero_size:
fp = f"optimizer_ep{ep_rank}_tp{tp_rank}_pp{pp_rank}_zo{moe_zero_rank}.pt"
else:
fp = f"optimizer_tp{tp_rank}_pp{pp_rank}_zo{zero_rank}.pt"

states = llm_load(os.path.join(folder, fp), map_location=get_current_device())

@@ -400,17 +446,35 @@ def save_optimizer_checkpoint(optim, state_path):
tp_size = gpc.get_world_size(ParallelMode.TENSOR)
wp_size = gpc.get_world_size(ParallelMode.WEIGHT)
dp_size = gpc.get_world_size(ParallelMode.DATA)
ep_size = gpc.get_world_size(ParallelMode.EXPERT)
ep_rank = gpc.get_local_rank(ParallelMode.EXPERT)
moe_data_size = gpc.get_world_size(ParallelMode.EXPERT_DATA)
moe_zero_size = gpc.get_world_size(ParallelMode.EXPERT_ZERO1)
moe_zero_rank = gpc.get_local_rank(ParallelMode.EXPERT_ZERO1)
if is_using_isp():
ewp_rank = gpc.get_local_rank(ParallelMode.EXPERT_WEIGHT)
ewp_size = gpc.get_world_size(ParallelMode.EXPERT_WEIGHT)

states = optim.state_dict()
if isinstance(optim, (HybridZeroOptimizer, HybridZeroOptimizer_v2)):
if is_using_isp():
fp = f"optimizer_wp{wp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
if (gpc.get_global_rank() % (tp_size * dp_size)) < zero_size * wp_size:
llm_save(os.path.join(state_path, fp), states)
if is_using_moe() and moe_zero_size * ep_size * ewp_size > zero_size * wp_size:
fp = f"optimizer_ep{ep_rank}_ewp{ewp_rank}_pp{pp_rank}_zo{moe_zero_rank}.pt"
if (gpc.get_global_rank() % (ewp_size * ep_size * moe_data_size)) < moe_zero_size * ewp_size * ep_size:
llm_save(os.path.join(state_path, fp), states)
else:
fp = f"optimizer_wp{wp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
if (gpc.get_global_rank() % (tp_size * dp_size)) < zero_size * wp_size:
llm_save(os.path.join(state_path, fp), states)
else:
fp = f"optimizer_tp{tp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
if (gpc.get_global_rank() % (tp_size * dp_size)) < zero_size * tp_size:
llm_save(os.path.join(state_path, fp), states)
if is_using_moe() and moe_zero_size * ep_size > zero_size:
fp = f"optimizer_ep{ep_rank}_tp{tp_rank}_pp{pp_rank}_zo{moe_zero_rank}.pt"
if (gpc.get_global_rank() % (tp_size * ep_size * moe_data_size)) < moe_zero_size * tp_size * ep_size:
llm_save(os.path.join(state_path, fp), states)
else:
fp = f"optimizer_tp{tp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
if (gpc.get_global_rank() % (tp_size * dp_size)) < zero_size * tp_size:
llm_save(os.path.join(state_path, fp), states)
if "zero_devide_optim_plan" in states:
params_per_rank_id_dict = states.pop("zero_devide_optim_plan")
fp_meta = os.path.join(state_path, optim.rank_unique_id)
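To summarize the shard naming that load_optimizer_checkpoint and save_optimizer_checkpoint now agree on, a small standalone sketch follows (the helper name and rank dict are mine, not part of the PR). Dense optimizer shards keep their existing names; MoE shards add ep/ewp (ISP) or ep/tp (non-ISP) rank components plus the expert ZeRO-1 rank, and the MoE-style name is used only when the expert ZeRO group is larger than the dense one, per the moe_zero_size * ep_size * ... > zero_size * ... conditions above.

def optimizer_shard_name(use_isp, moe_shard, r):
    # r: dict of local ranks, e.g. dict(wp=0, tp=0, ep=0, ewp=0, pp=0, zo=0, moe_zo=0)
    if use_isp:
        if moe_shard:
            return "optimizer_ep{ep}_ewp{ewp}_pp{pp}_zo{moe_zo}.pt".format(**r)
        return "optimizer_wp{wp}_pp{pp}_zo{zo}.pt".format(**r)
    if moe_shard:
        return "optimizer_ep{ep}_tp{tp}_pp{pp}_zo{moe_zo}.pt".format(**r)
    return "optimizer_tp{tp}_pp{pp}_zo{zo}.pt".format(**r)

# e.g. optimizer_shard_name(True, True, dict(ep=0, ewp=1, pp=0, moe_zo=2))
# -> "optimizer_ep0_ewp1_pp0_zo2.pt", matching the filenames parsed above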
28 changes: 25 additions & 3 deletions internlm/core/context/parallel_context.py
@@ -158,6 +158,10 @@ def __init__(self):
self.zero1_parallel_size = -1
self.nettest_parallel_size = 1
self.expert_parallel_size = -1
self.expert_tensor_parallel_size = 1
self.expert_weight_parallel_size = -1
self.expert_data_parallel_size = -1
self.expert_zero1_parallel_size = -1
self.num_processes_on_current_node = -1
self.virtual_pipeline_parallel_size = None
self.virtual_pipeline_parallel_rank = None
@@ -509,6 +513,8 @@ def init_parallel_groups(self):
parallel_config._add_item("weight", dict(size=1, overlap=False))
if "expert" not in parallel_config:
parallel_config._add_item("expert", dict(size=-1, no_tp=False))
if "expert_zero1" not in parallel_config:
parallel_config._add_item("expert_zero1", dict(size=-1))
if "expert_weight" not in parallel_config:
parallel_config._add_item("expert_weight", dict(size=1, overlap=False))
# set default value for sequence_2D
@@ -530,6 +536,7 @@ def init_parallel_groups(self):
self._set_parallel_size_from_config(parallel_config, "pipeline", "pipeline_parallel_size")
self._set_parallel_size_from_config(parallel_config, "zero1", "zero1_parallel_size")
self._set_parallel_size_from_config(parallel_config, "expert", "expert_parallel_size")
self._set_parallel_size_from_config(parallel_config, "expert_zero1", "expert_zero1_parallel_size")
self._set_parallel_size_from_config(parallel_config, "expert_weight", "expert_weight_parallel_size")

# the user should not set the data parallel size manually
@@ -576,6 +583,20 @@ def init_parallel_groups(self):
// self.expert_tensor_parallel_size
// self.expert_parallel_size,
)

if self.expert_zero1_parallel_size == -1:
self.expert_zero1_parallel_size = self.expert_data_parallel_size
self.expert_zero1_parallel_size = max(1, self.expert_zero1_parallel_size)
assert self.expert_zero1_parallel_size <= self.expert_data_parallel_size, (
f"expert_zero1_parallel_size:{self.expert_zero1_parallel_size} should be less than "
f"expert_data_parallel_size:{self.expert_data_parallel_size}"
)
assert self.expert_data_parallel_size % self.expert_zero1_parallel_size == 0, (
f"expert_data_parallel_size:{self.expert_data_parallel_size} % expert_zero1_parallel_size: "
f"{self.expert_zero1_parallel_size} != 0"
)
assert self.expert_zero1_parallel_size >= 1

if (
isinstance(parallel_config["tensor"], dict)
and parallel_config["tensor"]["mode"] == TensorParallelMode.isp.name
@@ -636,6 +657,7 @@ def init_parallel_groups(self):
self.expert_tensor_parallel_size,
self.expert_weight_parallel_size,
self.expert_data_parallel_size,
self.expert_zero1_parallel_size,
parallel_config.sequence_2D,
]

@@ -661,10 +683,10 @@ def init_parallel_groups(self):
if self.pipeline_parallel_size > 1:
initializers.append(pgroup_initializer.Initializer_Pipeline(*initializer_args))
if self.config.model.get("num_experts", 1) > 1:
if isinstance(parallel_config["tensor"], dict) and parallel_config["tensor"]["mode"] == "isp":
initializers.append(pgroup_initializer.Initializer_Expert_Weight_Data(*initializer_args))
if parallel_config["tensor"]["mode"] == TensorParallelMode.isp.name:
initializers.append(pgroup_initializer.Initializer_Expert_Weight_Data_Zero(*initializer_args))
else:
initializers.append(pgroup_initializer.Initializer_Expert_Data(*initializer_args))
initializers.append(pgroup_initializer.Initializer_Expert_Data_Zero(*initializer_args))
if parallel_config.sequence_2D.get("enable", False) is True:
initializers.append(pgroup_initializer.Initializer_2D_SEQUENCE_PARALLEL(*initializer_args))

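As a reading aid for the expert_zero1 defaulting and validation added above, a standalone sketch in plain Python; the helper name and example sizes are illustrative rather than taken from the PR:

def resolve_expert_zero1(expert_zero1_size, expert_data_size):
    # -1 means "shard expert optimizer states across the whole expert data-parallel group"
    if expert_zero1_size == -1:
        expert_zero1_size = expert_data_size
    expert_zero1_size = max(1, expert_zero1_size)
    assert expert_zero1_size <= expert_data_size
    assert expert_data_size % expert_zero1_size == 0
    return expert_zero1_size

# assuming the expert data-parallel size works out to 8 (e.g. 32 ranks, pipeline=1,
# expert tensor parallel=1, expert parallel=4):
print(resolve_expert_zero1(-1, 8))  # 8: default, full sharding of expert optimizer states
print(resolve_expert_zero1(2, 8))   # 2: partial sharding, each expert ZeRO group spans 2 ranks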