
Commit

Merge pull request #29 from huggingface/phuc/feature_topology_agnostic_optim_loading

Some sanity fixes for the "[Feature] Topology-agnostic optimizer states loading" PR
xrsrke authored Jan 19, 2024
2 parents 44bebf0 + 348fc91 commit a6318ae
Showing 2 changed files with 24 additions and 18 deletions.
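
For context on the changes below: load_optimizer reads the "parallelism" section of the checkpoint's optimizer config and compares it with the current run's topology to decide whether the optimizer states must be re-sharded. A minimal sketch of that decision, assuming the config has already been loaded into a dict (needs_resharding is a hypothetical helper for illustration, not nanotron's API):

# Illustrative sketch only: decide whether checkpointed optimizer states
# need re-sharding for the current (tp, dp) topology.
def needs_resharding(ckp_optimizer_config: dict, current_tp_size: int, current_dp_size: int) -> bool:
    ckp = ckp_optimizer_config["parallelism"]
    # A tp change requires merging and re-splitting tp shards; a dp change
    # requires re-slicing the ZeRO-1 shards across data-parallel ranks.
    return int(ckp["tp_size"]) != current_tp_size or int(ckp["dp_size"]) != current_dp_size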
39 changes: 22 additions & 17 deletions src/nanotron/serialize/optimizer.py
@@ -136,25 +136,26 @@ def load_optimizer(
     with open(ckp_optimizer_config_path, "r") as file:
         ckp_optimizer_config = json.load(file)

-    if ckp_optimizer_config["parallelism"]["tp_size"] != parallel_context.tp_pg.size():
+    ckp_pp_size = ckp_optimizer_config["parallelism"]["pp_size"]
+    ckp_tp_size = ckp_optimizer_config["parallelism"]["tp_size"]
+    ckp_dp_size = ckp_optimizer_config["parallelism"]["dp_size"]
+
+    if int(ckp_tp_size) != int(parallel_context.tp_pg.size()):
         assert (
             param_shard_metadata is not None
-        ), "You have to pass how the original parameters are sharded in order to resume in a different tensor parallel size"
+        ), f"You have to pass how the original parameters are sharded in order to resume in a different tensor parallel size, ckp_tp_size: {ckp_tp_size}, current tp_size: {parallel_context.tp_pg.size()}"
         assert (
             model is not None
         ), "You have to pass the model in order to adjust the optimizer states according to how the current parameters are sharded"

         def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -> TensorMetadata:
             return param_shard_metadata[param_name.replace("module.", "")][(str(pp_rank), str(tp_rank))]

-        ckp_pp_size = ckp_optimizer_config["parallelism"]["pp_size"]
-        ckp_tp_size = ckp_optimizer_config["parallelism"]["tp_size"]
         ckp_optim_type = ckp_optimizer_config["type"]

         if ckp_optim_type == ZeroDistributedOptimizer.__name__:
             # NOTE: if the checkpoint is from a Zero-1 optimizer, then we need to merge the shards
             # across data parallel dimension, before merging the shards across tensor parallel dimension
-            ckp_dp_size = ckp_optimizer_config["parallelism"]["dp_size"]
             shard_paths = list(
                 root_folder.glob(
                     f"{ObjectType.OPTIMIZER.value}_pp-*-of-{ckp_pp_size}_dp-*-of-{ckp_dp_size}_tp-*-of-{ckp_tp_size}.pt"
@@ -177,10 +178,10 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -> TensorMetadata:

         model_state_dict = model.state_dict()
         new_optim_state_dict = optimizer.state_dict()
-        OPTIMIZER_STATE_NAMES = ckp_sharded_optim_states[(0, 0)]["state"][0].keys() - ["step"]
         # NOTE: because we can only resume training with the same optimizer type
         # (0, 0) = (pp_rank, tp_rank)
         # NOTE: also we don't merge "step" because it's just a scalar
+        OPTIMIZER_STATE_NAMES = ckp_sharded_optim_states[(0, 0)]["state"][0].keys() - ["step"]

         for param_name, _ in tqdm(
             sorted(model_state_dict.items(), key=lambda x: x[0]),
@@ -258,17 +259,21 @@ def get_checkpoint_state_metadata(param_name: str, pp_rank: int, tp_rank: int) -> TensorMetadata:
         )

     if isinstance(optimizer, ZeroDistributedOptimizer):
-        # NOTE: if the optimizer is ZeRO-1, now we shard the optimizer states across data parallel dimension
-        current_dp_rank = dist.get_rank(parallel_context.dp_pg)
-        for param_index in state_dict["state"]:
-            param_name = [name for idx, name in state_dict["names"].items() if idx == param_index][0]
-            for state_name in OPTIMIZER_STATE_NAMES:
-                sliced_tensor = get_sliced_tensor(
-                    param=state_dict["state"][param_index][state_name],
-                    start_offset=optimizer.param_name_to_dp_rank_offsets[param_name][current_dp_rank][0],
-                    end_offset=optimizer.param_name_to_dp_rank_offsets[param_name][current_dp_rank][1],
-                )
-                state_dict["state"][param_index][state_name] = sliced_tensor
+        # NOTE: only re-shard if we had to merge the tp shards (tp size changed)
+        # or the checkpoint was saved with a different dp size
+        if int(ckp_tp_size) != parallel_context.tp_pg.size() or int(ckp_dp_size) != parallel_context.dp_pg.size():
+            # NOTE: if the optimizer is ZeRO-1, now we shard the optimizer states across data parallel dimension
+            current_dp_rank = dist.get_rank(parallel_context.dp_pg)
+            OPTIMIZER_STATE_NAMES = state_dict["state"][0].keys() - ["step"]
+            for param_index in state_dict["state"]:
+                param_name = [name for idx, name in state_dict["names"].items() if idx == param_index][0]
+                for state_name in OPTIMIZER_STATE_NAMES:
+                    sliced_tensor = get_sliced_tensor(
+                        param=state_dict["state"][param_index][state_name],
+                        start_offset=optimizer.param_name_to_dp_rank_offsets[param_name][current_dp_rank][0],
+                        end_offset=optimizer.param_name_to_dp_rank_offsets[param_name][current_dp_rank][1],
+                    )
+                    state_dict["state"][param_index][state_name] = sliced_tensor

     optimizer.load_state_dict(state_dict)

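To illustrate the ZeRO-1 re-sharding step in the hunk above: each data-parallel rank owns a contiguous (start, end) slice of every flattened optimizer-state tensor, given by param_name_to_dp_rank_offsets, and get_sliced_tensor keeps only that slice. A self-contained sketch under that assumption (the get_sliced_tensor below is a stand-in, not nanotron's implementation):

import torch

# Stand-in for nanotron's get_sliced_tensor: keep only this dp rank's
# contiguous slice of a flattened optimizer-state tensor.
def get_sliced_tensor(param: torch.Tensor, start_offset: int, end_offset: int) -> torch.Tensor:
    return param.view(-1)[start_offset:end_offset]

# Example: a 10-element state tensor sharded over dp_size=2.
full_state = torch.arange(10, dtype=torch.float32)
param_name_to_dp_rank_offsets = {"model.weight": {0: (0, 5), 1: (5, 10)}}
current_dp_rank = 1
start, end = param_name_to_dp_rank_offsets["model.weight"][current_dp_rank]
shard = get_sliced_tensor(full_state, start, end)  # tensor([5., 6., 7., 8., 9.])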
3 changes: 2 additions & 1 deletion src/nanotron/trainer.py
@@ -31,6 +31,7 @@
 )
 from nanotron.logging import LoggerWriter, LogItem, human_format, log_rank, set_logger_verbosity_format
 from nanotron.models import NanotronModel, build_model
+from nanotron.models.base import check_model_has_grad
 from nanotron.models.llama import LlamaForTraining, RotaryEmbedding
 from nanotron.models.starcoder2 import Starcoder2ForTraining
 from nanotron.optim.clip_grads import clip_grad_norm
@@ -650,7 +651,7 @@ def _init_model(
         # Model make it DDP
         if make_ddp is True:
             # Check that the model has at least one grad. Necessary for DDP
-            # check_model_has_grad(model=model, parallel_context=parallel_context)
+            check_model_has_grad(model=model, parallel_context=parallel_context)
             # TODO @thomasw21: DDP doesn't support broadcasting complex buffers (and we don't really need that broadcasting anyway)
             model = DistributedDataParallel(
                 model,
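The trainer change re-enables the sanity check before wrapping the model: torch's DistributedDataParallel errors out when the wrapped module has no parameter that requires gradients, so checking up front fails with a clearer message. A hedged sketch of such a check (illustrative only, not nanotron's check_model_has_grad):

import torch.nn as nn

# Illustrative only: make sure at least one parameter is trainable before
# handing the model to DistributedDataParallel.
def assert_model_has_grad(model: nn.Module) -> None:
    if not any(p.requires_grad for p in model.parameters()):
        raise ValueError(
            "Cannot wrap the model with DistributedDataParallel: "
            "no parameter has requires_grad=True on this rank."
        )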
