Skip to content

Commit

Permalink
feat(model): fix "dict has no attribute 'mode'" error
Browse files Browse the repository at this point in the history
  • Loading branch information
huangting4201 committed Jan 23, 2024
1 parent 32df5ad commit d388ddc
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 23 deletions.
4 changes: 3 additions & 1 deletion internlm/model/modeling_internlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,9 @@ def __init__(
super().__init__()

checkpoint_layer_num = int(num_layers * checkpoint)
self.tp_mode = gpc.config.parallel.tensor.mode
self.tp_mode = "mtp"
if isinstance(gpc.config.parallel.tensor, dict):
self.tp_mode = gpc.config.parallel.tensor.get("mode", "mtp")

if is_reward:
head_cls = RewardModelLinear
Expand Down
5 changes: 4 additions & 1 deletion internlm/solver/optimizer/hybrid_zero_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,10 @@ def __init__(
clip_grad_norm = zero_cfg.clip_grad_norm
self._overlap_sync_grad = zero_cfg.overlap_sync_grad
self._overlap_sync_param = zero_cfg.overlap_sync_param
self.use_isp = gpc.config.parallel.tensor.mode == "isp"
self.use_isp = (
isinstance(gpc.config.parallel["tensor"], dict)
and gpc.config.parallel["tensor"].get("mode", "mtp") == "isp"
)

super().__init__(optim=optimizer)

Expand Down
4 changes: 3 additions & 1 deletion internlm/solver/optimizer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,9 @@ def compute_norm(
Total norm of the parameters, need total_norm**(1/norm) before using.
"""

weight_parallel_mode = ParallelMode.WEIGHT if gpc.config.parallel.tensor.mode == "isp" else ParallelMode.TENSOR
weight_parallel_mode = (
ParallelMode.WEIGHT if gpc.config.parallel["tensor"].get("mode", "mtp") == "isp" else ParallelMode.TENSOR
)
enable_cuda_kernels = gradients[0].device.type == "cuda"
# Norm parameters.
norm_type = float(norm_type)
Expand Down
16 changes: 10 additions & 6 deletions internlm/train/training_internlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ def set_fp32_attr_for_model(model: Union[nn.Module, nn.ModuleList]):


def set_parallel_attr_for_param_groups(model: Union[nn.Module, nn.ModuleList]):
tp_mode = gpc.config.parallel["tensor"].get("mode", "mtp")

def _check_module(module):
# layer_norm
if isinstance(module, (RMSNorm, nn.LayerNorm)):
Expand All @@ -103,17 +105,17 @@ def _check_module(module):
# embedding and head
if isinstance(module, (Embedding1D, ParallelGPT2Embeddings, BaseScaleColumnParallelLinear)):
for param in module.parameters():
if gpc.is_initialized(ParallelMode.TENSOR) and gpc.config.parallel.tensor.mode == "isp":
if gpc.is_initialized(ParallelMode.TENSOR) and tp_mode == "isp":
setattr(param, IS_TENSOR_DATA_PARALLEL, True)
elif gpc.is_initialized(ParallelMode.TENSOR) and gpc.config.parallel.tensor.mode != "isp":
elif gpc.is_initialized(ParallelMode.TENSOR) and tp_mode != "isp":
setattr(param, IS_TENSOR_ZERO_PARALLEL, True)

# for linear module
if isinstance(module, (ColumnParallelLinear, RowParallelLinear)):
for param in module.parameters():
if gpc.is_initialized(ParallelMode.TENSOR) and gpc.config.parallel.tensor.mode != "isp":
if gpc.is_initialized(ParallelMode.TENSOR) and tp_mode != "isp":
setattr(param, IS_TENSOR_ZERO_PARALLEL, True)
elif gpc.is_initialized(ParallelMode.WEIGHT) and gpc.config.parallel.tensor.mode == "isp":
elif gpc.is_initialized(ParallelMode.WEIGHT) and tp_mode == "isp":
setattr(param, IS_WEIGHT_ZERO_PARALLEL, True)

if not isinstance(model, nn.ModuleList):
Expand Down Expand Up @@ -187,13 +189,15 @@ def initialize_model(pre_process_func: Optional[Callable] = None, post_process_f

# Change random state mode to ParallelMode.DATA after model is built, guaranteeing the random
# state in the same dp group are all the same.
random_mode = ParallelMode.WEIGHT_DATA if gpc.config.parallel.tensor["mode"] == "isp" else ParallelMode.DATA
random_mode = (
ParallelMode.WEIGHT_DATA if gpc.config.parallel["tensor"].get("mode", "mtp") == "isp" else ParallelMode.DATA
)
set_mode(random_mode)

# if fsdp enabled, wrap the model
model = wrap_FSDP_model(model)

if gpc.config.parallel.tensor.mode != "isp":
if gpc.config.parallel["tensor"].get("mode", "mtp") != "isp":
isp_communicator = None
else:
isp_communicator = ISPCommunicator(
Expand Down
2 changes: 1 addition & 1 deletion internlm/train/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def split_params_into_different_groups_for_optimizer_with_new_partition_strategy
pgroup["optimizer_mode"] = ParallelMode.ZERO1

# param groups may contain empty groups, such as embed_head
if gpc.config.parallel.tensor.mode == "isp":
if gpc.config.parallel["tensor"].get("mode", "mtp") == "isp":
param_groups.extend(new_groups.values())
else:
assert len(new_groups["embed_head"]["params"]) <= 0
Expand Down
16 changes: 8 additions & 8 deletions internlm/utils/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def save_model_checkpoint(folder, model):
# even if pp is not considered, it will definitely not be written on the same machine.

# for tensor parallel mode with isp
if gpc.config.parallel.tensor.mode == "isp":
if gpc.config.parallel["tensor"].get("mode", "mtp") == "isp":
if wdp_rank == 0 or dp_rank == 0:
fn = f"model_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}.pt"
fp = os.path.join(folder, fn)
Expand Down Expand Up @@ -564,7 +564,7 @@ def load_model_checkpoint(folder, model):
for fn in fns:
if fn.startswith("model_t") and not fn.endswith(".md5"):
segements = os.path.splitext(fn)[0].split("_")
if gpc.config.parallel.tensor.mode == "isp":
if gpc.config.parallel["tensor"].get("mode", "mtp") == "isp":
max_pp = max(max_pp, int(segements[-1][2:]))
max_wp = max(max_wp, int(segements[-2][2:]))
max_tp = max(max_tp, int(segements[-3][2:]))
Expand All @@ -590,7 +590,7 @@ def load_model_checkpoint(folder, model):
dp_size == max_zo + 1
), f"The weights are save for {max_zo+1} FSDP shards , while current has {dp_size} FSDP shards"

if gpc.config.parallel.tensor.mode == "isp":
if gpc.config.parallel["tensor"].get("mode", "mtp") == "isp":
should_load_name = f"model_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}.pt"
elif gpc.config.parallel.zero1.fsdp:
should_load_name = f"model_tp{tp_rank}_pp{pp_rank}_dp{dp_rank}.pt"
Expand Down Expand Up @@ -702,7 +702,7 @@ def save_optimizer_checkpoint(optim, state_path):

states = optim.state_dict()
if isinstance(optim, HybridZeroOptimizer):
if gpc.config.parallel.tensor.mode == "isp":
if gpc.config.parallel["tensor"].get("mode", "mtp") == "isp":
fp = f"optimizer_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}_dp{dp_rank}.pt"
llm_save(os.path.join(state_path, fp), states)
else:
Expand Down Expand Up @@ -752,7 +752,7 @@ def load_optimizer_checkpoint(folder, optim):
max_tp, max_wp, max_pp, max_zero, max_dp = 0, 0, 0, 0, 0
for fn in fns:
if fn.startswith("optimizer_") and not fn.endswith(".md5"):
if gpc.config.parallel.tensor.mode == "isp":
if gpc.config.parallel["tensor"].get("mode", "mtp") == "isp":
_, tp, wp, pp, dp = os.path.splitext(fn)[0].split("_")
max_dp = max(max_dp, int(dp[2:]))
max_tp = max(max_tp, int(tp[2:]))
Expand All @@ -770,12 +770,12 @@ def load_optimizer_checkpoint(folder, optim):
pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
dp_size = gpc.get_world_size(ParallelMode.DATA)

if gpc.config.parallel.tensor.mode == "isp":
if gpc.config.parallel["tensor"].get("mode", "mtp") == "isp":
assert dp_size == max_dp + 1, (
f"The optimizer states are save for {max_dp+1} data parallelism, "
f"while current has {dp_size} data parallelism"
)
if gpc.config.parallel.tensor.mode != "isp":
if gpc.config.parallel["tensor"].get("mode", "mtp") != "isp":
assert zero_size == max_zero + 1, (
f"The optimizer states are save for {max_zero+1} zero parallel, "
f"while current has {zero_size} zero broadcast range."
Expand All @@ -795,7 +795,7 @@ def load_optimizer_checkpoint(folder, optim):
wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT)
pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
dp_rank = gpc.get_local_rank(ParallelMode.DATA)
if gpc.config.parallel.tensor.mode == "isp":
if gpc.config.parallel["tensor"].get("mode", "mtp") == "isp":
fp = f"optimizer_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}_dp{dp_rank}.pt"
else:
fp = f"optimizer_tp{tp_rank}_pp{pp_rank}_zo{zero_rank}.pt"
Expand Down
14 changes: 9 additions & 5 deletions internlm/utils/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def is_replica_zero_parallel_parameter(p):
def is_tensor_data_parallel_parameter(p):
return (
gpc.is_initialized(ParallelMode.TENSOR)
and gpc.config.parallel.tensor.mode == "isp"
and gpc.config.parallel["tensor"].get("mode", "mtp") == "isp"
and hasattr(p, IS_TENSOR_DATA_PARALLEL)
and getattr(p, IS_TENSOR_DATA_PARALLEL)
)
Expand All @@ -35,7 +35,7 @@ def is_tensor_data_parallel_parameter(p):
def is_tensor_zero_parallel_parameter(p):
return (
gpc.is_initialized(ParallelMode.TENSOR)
and gpc.config.parallel.tensor.mode != "isp"
and gpc.config.parallel["tensor"].get("mode", "mtp") != "isp"
and hasattr(p, IS_TENSOR_ZERO_PARALLEL)
and getattr(p, IS_TENSOR_ZERO_PARALLEL)
)
Expand All @@ -44,7 +44,7 @@ def is_tensor_zero_parallel_parameter(p):
def is_weight_zero_parallel_parameter(p):
return (
gpc.is_initialized(ParallelMode.WEIGHT)
and gpc.config.parallel.tensor.mode == "isp"
and gpc.config.parallel["tensor"].get("mode", "mtp") == "isp"
and hasattr(p, IS_WEIGHT_ZERO_PARALLEL)
and getattr(p, IS_WEIGHT_ZERO_PARALLEL)
)
Expand All @@ -58,7 +58,9 @@ def sync_model_param(model):
"""

sync_moe_param = gpc.is_using_parallel_mode(ParallelMode.EXPERT_DATA)
sync_parallel_mode = ParallelMode.WEIGHT_DATA if gpc.config.parallel.tensor["mode"] == "isp" else ParallelMode.DATA
sync_parallel_mode = (
ParallelMode.WEIGHT_DATA if gpc.config.parallel["tensor"].get("mode", "mtp") == "isp" else ParallelMode.DATA
)
for param in model.parameters():
if sync_moe_param and getattr(param, "is_expert", False):
ranks = gpc.get_ranks_in_group(ParallelMode.EXPERT_DATA)
Expand All @@ -79,7 +81,9 @@ def sync_model_replica_param_group(model):
model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency.
"""

parallel_mode = ParallelMode.WEIGHT if gpc.config.parallel.tensor["mode"] == "isp" else ParallelMode.TENSOR
parallel_mode = (
ParallelMode.WEIGHT if gpc.config.parallel["tensor"].get("mode", "mtp") == "isp" else ParallelMode.TENSOR
)
if gpc.is_using_parallel_mode(parallel_mode):
for param in model.parameters():
if is_replica_zero_parallel_parameter(param):
Expand Down

0 comments on commit d388ddc

Please sign in to comment.