
Commit

Add llm auto configurator and apply per seq sft loss for qwen2/2.5 models (#371)

Co-authored-by: 同润 <[email protected]>
jerryli1981 and 同润 authored Oct 30, 2024
1 parent 8003e0c commit bf582d8
Showing 13 changed files with 1,025 additions and 27 deletions.
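For context on the title: "per seq sft loss" changes how masked token-level losses are averaged during SFT. With plain per-token averaging, all supervised tokens in a batch are pooled, so long responses dominate the objective; with per-sequence averaging, each sequence is first averaged over its own supervised tokens, so every sequence contributes equally. In symbols (notation introduced here for illustration, not taken from the patch), with token losses $\ell_{i,t}$, loss mask $m_{i,t}\in\{0,1\}$, and $N$ sequences:

$$
\mathcal{L}_{\text{per-token}}=\frac{\sum_{i=1}^{N}\sum_{t} m_{i,t}\,\ell_{i,t}}{\sum_{i=1}^{N}\sum_{t} m_{i,t}},
\qquad
\mathcal{L}_{\text{per-seq}}=\frac{1}{N}\sum_{i=1}^{N}\frac{\sum_{t} m_{i,t}\,\ell_{i,t}}{\sum_{t} m_{i,t}}.
$$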
1 change: 1 addition & 0 deletions README.md
@@ -19,6 +19,7 @@ English | [简体中文](./README_zh-CN.md)
Pai-Megatron-Patch (https://github.com/alibaba/Pai-Megatron-Patch) is a deep learning training toolkit that lets developers easily train and run inference on LLMs & VLMs with the Megatron framework. As LLMs continue to develop, model structures and scales are evolving rapidly. Although these models can be built conveniently with the Transformers or DeepSpeed training frameworks, training efficiency is comparatively low, and the problem becomes even more severe once the model scale exceeds 10 billion parameters. The primary objective of Pai-Megatron-Patch is to effectively utilize GPU compute for LLMs. This tool allows convenient training of commonly used LLMs with all the acceleration techniques provided by Megatron-LM.

What's New:
- **Add llm auto configurator and apply per seq sft loss for qwen2/2.5 models.** [🔥🔥 2024.10.30]
- **Upgrade deepseek-v2-moe models to support MLA via transformer engine and pipeline ckpts conversion.** [🔥🔥 2024.09.26]
- **Support training Qwen2.5 models by using Megatron-Core.** [🔥🔥 2024.09.20]
- **Support Sequence Packing in SFT for Qwen2 and LLaMA 3.1 models.** [🔥🔥 2024.09.13]
1 change: 1 addition & 0 deletions README_zh-CN.md
@@ -40,6 +40,7 @@ Pai-Megatron-Patch是各类开源大模型和Megatron训练加速引擎之间的
- [阿里云PAI获得FewCLUE基于大模型的小样本学习双料冠军](https://developer.aliyun.com/article/788081?spm=a2c6h.12873639.article-detail.17.11c5383cHpFZks&tlog=yuekan_8)

新功能:
- **添加大模型训练最优吞吐参数自动配置以及针对qwen2/2.5系列模型优化微调per seq sft loss.** [🔥🔥 2024.10.30]
- **升级Deepseek-V2-MoE系列模型支持TE版的MLA以及流水并行CKPT转换** [🔥🔥 2024.09.26]
- **支持用Megatron-Core框架训练Qwen2.5系列模型** [🔥🔥 2024.09.20]
- **支持Qwen2及LLaMA-3.1系列模型SFT的Sequence Packing技术.** [🔥🔥 2024.09.13]
52 changes: 35 additions & 17 deletions examples/qwen2/pretrain_qwen.py
@@ -95,17 +95,27 @@ def get_batch(data_iterator):
if args.train_mode == "pretrain":
raise ValueError('The LLama-SFT-Raw dataset should only be used for finetuning!')
# get batches based on the TP rank you are on
batch = get_batch_on_this_tp_rank_original(data_iterator)
batch = get_batch_on_this_tp_rank_original(data_iterator, per_seq_average=True)
# slice batch along sequence dimension for context parallelism
num_seqs = batch.pop('num_seqs')
batch = get_batch_on_this_cp_rank(batch)

return tuple([*batch.values(), None])
return (
batch['tokens'],
batch['labels'],
batch['loss_mask'],
batch['attention_mask'],
batch['position_ids'],
num_seqs,
None
)
elif "-Idxmap" in args.dataset:
# get batches based on the TP rank you are on
if args.train_mode == "pretrain":
batch = get_batch_on_this_tp_rank(data_iterator)

else:
batch = get_batch_on_this_tp_rank_idxmap_sft(data_iterator)
batch = get_batch_on_this_tp_rank_idxmap_sft(data_iterator, per_seq_average=True)

packed_seq_params = None
if args.reset_position_ids:
@@ -129,14 +139,23 @@ def get_batch(data_iterator):
if packed_seq_params is not None and args.context_parallel_size > 1:
raise ValueError('Sequence Packing is not supported when CP>1 !')
# slice batch along sequence dimension for context parallelism
num_seqs = batch.pop('num_seqs', None)
batch = get_batch_on_this_cp_rank(batch)

return tuple([*batch.values(), packed_seq_params])
return (
batch['tokens'],
batch['labels'],
batch['loss_mask'],
batch['attention_mask'],
batch['position_ids'],
num_seqs,
packed_seq_params
)
else:
raise ValueError("please set correct --dataset ")


def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
def loss_func(loss_mask: torch.Tensor, num_seqs: torch.Tensor, output_tensor: torch.Tensor):
"""Loss function.
Args:
@@ -147,27 +166,26 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):

losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()

loss = torch.stack([torch.sum(losses.view(-1) * loss_mask), loss_mask.sum()])
if args.context_parallel_size > 1:
loss = torch.cat(
[torch.sum(losses.view(-1) * loss_mask).view(1), loss_mask.sum().view(1)]
)
torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group())
loss = loss[0] / loss[1]
else:
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

# Check individual rank losses are not NaN prior to DP all-reduce.
if args.check_for_nan_in_loss_and_grad:
global_rank = torch.distributed.get_rank()
assert not loss.isnan(), (
assert not loss.isnan().any(), (
f"Rank {global_rank}: found NaN in local forward loss calculation. "
f"Device: {torch.cuda.current_device()}, node: {os.uname()[1]}"
)

# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
averaged_loss = average_losses_across_data_parallel_group(loss)
averaged_loss = averaged_loss[0] / averaged_loss[1]

return loss * args.context_parallel_size, {"lm loss": averaged_loss[0]}
# NOTE: The grad will be scaled down by the CP size later; do not remove this multiplication factor.
# LINK: https://github.com/NVIDIA/Megatron-LM/issues/906
# The issue has been solved since 0926
return loss[0] * args.context_parallel_size, num_seqs.sum(), {"lm loss": averaged_loss}
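loss_func now takes num_seqs and returns the summed loss, the summed num_seqs, and the logging dict instead of a single pre-averaged scalar. Because the loss mask has already been divided by each (sub-)sequence's unmasked-token count in megatron_patch/data/utils.py (see below), `torch.sum(losses * loss_mask)` is a sum of per-sequence mean losses and `loss_mask.sum()` counts sequences; with `--calculate-per-token-loss` enabled in the launch scripts, Megatron-LM divides the accumulated numerator by the accumulated count. A self-contained sketch of that construction (toy values, not the patch verbatim):

```python
import torch

# Toy micro-batch of three sequences with 1, 2 and 5 supervised tokens.
token_losses = torch.tensor([[3., 0., 0., 0., 0.],
                             [1., 1., 0., 0., 0.],
                             [2., 2., 2., 2., 2.]])
raw_mask = torch.tensor([[1., 0., 0., 0., 0.],
                         [1., 1., 0., 0., 0.],
                         [1., 1., 1., 1., 1.]])

# What the data utilities do before loss_func runs: rescale each row to sum to 1.
scaled_mask = raw_mask / raw_mask.sum(dim=-1, keepdim=True)

# What loss_func stacks: [sum of per-sequence mean losses, number of sequences].
loss = torch.stack([(token_losses * scaled_mask).sum(), scaled_mask.sum()])
print(loss[0] / loss[1])                                  # tensor(2.)     -> per-sequence average
print((token_losses * raw_mask).sum() / raw_mask.sum())   # tensor(1.8750) -> per-token average, for contrast
```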


def forward_step(data_iterator, model: GPTModel):
@@ -182,11 +200,11 @@ def forward_step(data_iterator, model: GPTModel):

# Get the batch.
timers("batch-generator", log_level=2).start()
tokens, labels, loss_mask, attention_mask, position_ids, packed_seq_params = get_batch(data_iterator)
tokens, labels, loss_mask, attention_mask, position_ids, num_seqs, packed_seq_params = get_batch(data_iterator)
timers("batch-generator").stop()
output_tensor = model(tokens, position_ids, attention_mask, labels=labels, packed_seq_params=packed_seq_params)

return output_tensor, partial(loss_func, loss_mask)
return output_tensor, partial(loss_func, loss_mask, num_seqs)
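forward_step now also passes num_seqs to the loss callback by binding it with functools.partial, so Megatron-LM's training loop can keep calling the callback with only output_tensor. A minimal, self-contained illustration of that binding pattern (hypothetical toy loss, not the patch's loss_func):

```python
from functools import partial

import torch

def toy_loss_func(loss_mask: torch.Tensor, num_seqs: torch.Tensor, output_tensor: torch.Tensor):
    # Same argument order as above: the framework fills only the last slot.
    return (output_tensor * loss_mask).sum(), num_seqs.sum()

loss_mask = torch.tensor([[1., 1., 0.], [1., 0., 0.]])
num_seqs = torch.tensor([1, 1])
callback = partial(toy_loss_func, loss_mask, num_seqs)   # batch-derived tensors pre-bound here

output_tensor = torch.full((2, 3), 0.5)
loss, count = callback(output_tensor)                    # caller supplies only output_tensor
print(loss, count)                                       # tensor(1.5000) tensor(2)
```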


def is_dataset_built_on_rank():
1 change: 1 addition & 0 deletions examples/qwen2/run_mcore_qwen.sh
@@ -410,6 +410,7 @@ megatron_options=" \
--rotary-base 1000000 \
--rotary-seq-len-interpolation-factor 1 \
--no-save-optim \
--calculate-per-token-loss \
"

run_cmd="torchrun $DISTRIBUTED_ARGS pretrain_qwen.py
1 change: 1 addition & 0 deletions examples/qwen2_5/run_mcore_qwen.sh
@@ -416,6 +416,7 @@ megatron_options=" \
--rotary-base 1000000 \
--rotary-seq-len-interpolation-factor 1 \
--no-save-optim \
--calculate-per-token-loss \
"

run_cmd="torchrun $DISTRIBUTED_ARGS ../qwen2/pretrain_qwen.py
118 changes: 109 additions & 9 deletions megatron_patch/data/utils.py
@@ -18,18 +18,79 @@
from megatron import get_args
except:
from megatron.training import get_args
try:
from megatron.utils import get_ltor_masks_and_position_ids
except:
from megatron.training.utils import get_ltor_masks_and_position_ids

from megatron_patch.tokenizer import get_tokenizer

def get_ltor_masks_and_position_ids(data,
eod_token,
reset_position_ids,
reset_attention_mask,
eod_mask_loss,
create_attention_mask: bool=True):
"""Build masks and position id for left to right model."""

# Extract batch size and sequence length.
micro_batch_size, seq_length = data.size()

# Attention mask (lower triangular).
if reset_attention_mask:
att_mask_batch = micro_batch_size
else:
att_mask_batch = 1
if create_attention_mask:
attention_mask = torch.tril(torch.ones(
(att_mask_batch, seq_length, seq_length), device=data.device)).view(
att_mask_batch, 1, seq_length, seq_length)
else:
attention_mask = None

def get_batch_on_this_tp_rank_original(data_iterator):
# Loss mask.
loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
if eod_mask_loss:
loss_mask[data == eod_token] = 0.0

# Position ids.
position_ids = torch.arange(seq_length, dtype=torch.long,
device=data.device)
position_ids = position_ids.unsqueeze(0).expand_as(data)
# We need to clone as the ids will be modified based on batch index.
if reset_position_ids:
position_ids = position_ids.clone()

if reset_position_ids or reset_attention_mask:
# Loop through the batches:
for b in range(micro_batch_size):

# Find indices where the EOD token is.
eod_index = position_ids[b, data[b] == eod_token]
# Detach indices from positions if going to modify positions.
if reset_position_ids:
eod_index = eod_index.clone()

# Loop through EOD indices:
prev_index = 0
for j in range(eod_index.size()[0]):
i = eod_index[j]
# Mask attention loss.
if reset_attention_mask and attention_mask is not None:
attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
# Reset positions.
if reset_position_ids:
position_ids[b, (i + 1):] -= (i + 1 - prev_index)
prev_index = i + 1

if attention_mask is not None:
# Convert attention mask to binary:
attention_mask = (attention_mask < 0.5)

return attention_mask, loss_mask, position_ids
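A small usage sketch of the mask/position-id builder defined above, on a toy row containing two documents separated by an EOD token (toy token ids; assumes the function is importable from megatron_patch.data.utils as defined here):

```python
import torch

from megatron_patch.data.utils import get_ltor_masks_and_position_ids

# Toy batch: two documents packed into one row, separated by an EOD token (id 0).
data = torch.tensor([[5, 6, 0, 7, 8, 9, 0, 3]])

attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
    data,
    eod_token=0,
    reset_position_ids=True,       # restart position ids after each EOD
    reset_attention_mask=True,     # block attention across document boundaries
    eod_mask_loss=True,            # exclude EOD tokens from the loss
    create_attention_mask=True,
)
print(position_ids)   # tensor([[0, 1, 2, 0, 1, 2, 3, 0]])
print(loss_mask)      # tensor([[1., 1., 0., 1., 1., 1., 0., 1.]])
```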

def get_batch_on_this_tp_rank_original(data_iterator, per_seq_average=False):
args = get_args()
tokenizer = get_tokenizer()
def _broadcast(item):
if item is None:
return
torch.distributed.broadcast(item, mpu.get_tensor_model_parallel_src_rank(),
group=mpu.get_tensor_model_parallel_group())

@@ -54,13 +115,20 @@ def _broadcast(item):
args.reset_position_ids,
args.reset_attention_mask,
args.eod_mask_loss)

num_seqs = None
if per_seq_average:
# NOTE: raw dataset does not support sequence packing
num_seqs = loss_mask.sum(dim=-1).long() # [mbs]
loss_mask = loss_mask / num_seqs.view(-1, 1)

batch = {
'tokens': tokens.cuda(non_blocking=True),
'labels': labels.cuda(non_blocking=True),
'loss_mask': loss_mask.cuda(non_blocking=True),
'attention_mask': attention_mask.cuda(non_blocking=True),
'position_ids': position_ids.cuda(non_blocking=True)
'position_ids': position_ids.cuda(non_blocking=True),
'num_seqs': num_seqs.cuda(non_blocking=True) if num_seqs is not None else None
}

if args.pipeline_model_parallel_size == 1:
@@ -69,6 +137,7 @@ def _broadcast(item):
_broadcast(batch['loss_mask'])
_broadcast(batch['attention_mask'])
_broadcast(batch['position_ids'])
_broadcast(num_seqs)

elif mpu.is_pipeline_first_stage():
_broadcast(batch['tokens'])
@@ -92,13 +161,19 @@ def _broadcast(item):
device=torch.cuda.current_device())
position_ids = torch.empty((args.micro_batch_size, args.seq_length), dtype=torch.int64,
device=torch.cuda.current_device())

num_seqs = None
if per_seq_average:
num_seqs = torch.empty((args.micro_batch_size,), dtype=torch.int64,
device=torch.cuda.current_device())

if args.pipeline_model_parallel_size == 1:
_broadcast(tokens)
_broadcast(labels)
_broadcast(loss_mask)
_broadcast(attention_mask)
_broadcast(position_ids)
_broadcast(num_seqs)

elif mpu.is_pipeline_first_stage():
labels = None
@@ -121,12 +196,13 @@ def _broadcast(item):
'labels': labels,
'loss_mask': loss_mask,
'attention_mask': attention_mask,
'position_ids': position_ids
'position_ids': position_ids,
'num_seqs': num_seqs
}

return batch

def get_batch_on_this_tp_rank_idxmap_sft(data_iterator):
def get_batch_on_this_tp_rank_idxmap_sft(data_iterator, per_seq_average=False):
args = get_args()
tokenizer = get_tokenizer()
def _broadcast(item):
@@ -158,20 +234,38 @@ def _broadcast(item):
False,
args.create_attention_mask_in_dataloader
)

num_seqs = None
if per_seq_average:
num_seqs = torch.zeros(position_ids.shape[0], device=torch.cuda.current_device(), dtype=torch.int64)
for b in range(position_ids.shape[0]):
p = position_ids[b]
start_indices = (p == 0).nonzero(as_tuple=True)[0]
seqlens = start_indices[1:] - start_indices[:-1]
num_seqs[b] = len(seqlens)
seqlens = seqlens.cpu().numpy().tolist() + [p.shape[0] - start_indices[-1].item()]
subseqs = torch.split(loss_mask[b], seqlens)
for start_idx, seqlen, subseq in zip(start_indices, seqlens, subseqs):
assert subseq.sum() > 0
loss_mask[b, start_idx: start_idx + seqlen] /= subseq.sum()
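In the packed (idxmap) SFT path a single row can hold several sequences. Their boundaries are recovered from position_ids, which restarts at 0 at every packed sequence, and each sub-sequence's slice of the loss mask is divided by its own unmasked-token count. A toy sketch of that boundary recovery and rescaling (illustrative only; it mirrors the loop above but is not the patch verbatim):

```python
import torch

# One packed row: two sequences of lengths 3 and 5 (position_ids restart at 0).
position_ids = torch.tensor([0, 1, 2, 0, 1, 2, 3, 4])
loss_mask = torch.tensor([0., 1., 1., 0., 1., 1., 1., 1.])   # prompt tokens masked out

starts = (position_ids == 0).nonzero(as_tuple=True)[0]                           # tensor([0, 3])
lengths = torch.diff(starts).tolist() + [len(position_ids) - starts[-1].item()]  # [3, 5]

for start, length in zip(starts.tolist(), lengths):
    segment = loss_mask[start:start + length]
    loss_mask[start:start + length] = segment / segment.sum()  # each sub-sequence now sums to 1

print(loss_mask)   # -> 0.00, 0.50, 0.50, 0.00, 0.25, 0.25, 0.25, 0.25
```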


# dtype: long, long, float, bool, long
batch = {
'tokens': tokens.cuda(non_blocking=True),
'labels': labels.cuda(non_blocking=True),
'loss_mask': loss_mask.cuda(non_blocking=True),
'attention_mask': attention_mask.cuda(non_blocking=True) if attention_mask is not None else None,
'position_ids': position_ids.cuda(non_blocking=True)
'position_ids': position_ids.cuda(non_blocking=True),
'num_seqs': num_seqs.cuda(non_blocking=True) if num_seqs is not None else None
}

if args.pipeline_model_parallel_size == 1:
_broadcast(batch['tokens'])
_broadcast(batch['labels'])
_broadcast(batch['loss_mask'])
_broadcast(batch['attention_mask'])
_broadcast(num_seqs)

elif mpu.is_pipeline_first_stage():
_broadcast(batch['tokens'])
@@ -200,11 +294,17 @@ def _broadcast(item):
position_ids = torch.empty((args.micro_batch_size, args.seq_length), dtype=torch.int64,
device=torch.cuda.current_device())

num_seqs = None
if per_seq_average:
num_seqs = torch.empty((args.micro_batch_size,), dtype=torch.int64,
device=torch.cuda.current_device())

if args.pipeline_model_parallel_size == 1:
_broadcast(tokens)
_broadcast(labels)
_broadcast(loss_mask)
_broadcast(attention_mask)
_broadcast(num_seqs)

elif mpu.is_pipeline_first_stage():
labels = None