Skip to content

Commit

Permalink
fix(unpack_data): pad -100 on labels (#154)
Browse files Browse the repository at this point in the history
  • Loading branch information
sunpengsdu authored Mar 29, 2024
1 parent 2a6b1ce commit 87e8a9e
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 10 deletions.
2 changes: 1 addition & 1 deletion internlm/core/scheduler/no_pipeline_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def _load_accum_batch(self, data: Any, label: Any):

if self.data_process_func:
_data["input_ids"] = self.data_process_func(_data["input_ids"], _data["cu_seqlens"])
_label = self.data_process_func(_label, _data["cu_seqlens"])
_label = self.data_process_func(_label, _data["cu_seqlens"], padding_v=-100)
_data.pop("cu_seqlens")
_data.pop("indexes")

Expand Down
8 changes: 6 additions & 2 deletions internlm/core/scheduler/pipeline_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,9 @@ def load_micro_batch(self):
micro_batch_data["input_ids"] = self.data_process_func(
micro_batch_data["input_ids"], micro_batch_data["cu_seqlens"]
)
micro_batch_label = self.data_process_func(micro_batch_label, micro_batch_data["cu_seqlens"])
micro_batch_label = self.data_process_func(
micro_batch_label, micro_batch_data["cu_seqlens"], padding_v=-100
)

micro_batch_data.pop("cu_seqlens")
micro_batch_data.pop("indexes")
Expand Down Expand Up @@ -822,7 +824,9 @@ def load_micro_batch(self, model_chunk_id):
micro_batch_data["input_ids"] = self.data_process_func(
micro_batch_data["input_ids"], micro_batch_data["cu_seqlens"]
)
micro_batch_label = self.data_process_func(micro_batch_label, micro_batch_data["cu_seqlens"])
micro_batch_label = self.data_process_func(
micro_batch_label, micro_batch_data["cu_seqlens"], padding_v=-100
)

micro_batch_data.pop("cu_seqlens")
micro_batch_data.pop("indexes")
Expand Down
16 changes: 9 additions & 7 deletions internlm/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def get_dataset_type_id(dataset_type_ids_map, path):
return match_idxes[0]


def unpack_data(input_ids, cu_seqlens, is_type_ids: bool = False):
def unpack_data(input_ids, cu_seqlens, is_type_ids: bool = False, padding_v: int = 0):
"""
input_ids: if input_ids is not type_ids, the shape is (1, packed_length)
else the shape is (micro_num, packed_length)
Expand All @@ -36,16 +36,18 @@ def unpack_data(input_ids, cu_seqlens, is_type_ids: bool = False):
"""
bsz = input_ids.shape[0]

num_sequence = gpc.config.data["micro_bsz"]
num_seq = gpc.config.data["micro_bsz"]
seq_len_ = gpc.config.data.seq_len
dtype_ = input_ids.dtype

outputs = torch.zeros(bsz, num_sequence, gpc.config.data.seq_len, device=input_ids.device, dtype=input_ids.dtype)
outputs = torch.empty(bsz, num_seq, seq_len_, device=input_ids.device, dtype=dtype_).fill_(padding_v)

for i in range(bsz):
output = torch.zeros(num_sequence, gpc.config.data.seq_len, device=input_ids.device, dtype=input_ids.dtype)
output = torch.empty(num_seq, seq_len_, device=input_ids.device, dtype=dtype_).fill_(padding_v)
cu_seqlens_slice = cu_seqlens[i]
for j in range(num_sequence):
seq_length = cu_seqlens_slice[j + 1] - cu_seqlens_slice[j]
output[j, 0:seq_length] = input_ids[i, cu_seqlens_slice[j] : cu_seqlens_slice[j + 1]]
for j in range(num_seq):
length = cu_seqlens_slice[j + 1] - cu_seqlens_slice[j]
output[j, 0:length] = input_ids[i, cu_seqlens_slice[j] : cu_seqlens_slice[j + 1]]
outputs[i] = output

    # if the input_ids is not type_ids, we need to squeeze the first dimension if it is 1.
Expand Down

0 comments on commit 87e8a9e

Please sign in to comment.