fix inject model and add multimodal dataloader (#341)
sallyjunjun authored Sep 26, 2024
1 parent f2f8e5b commit d0a19fb
Showing 6 changed files with 122 additions and 33 deletions.

internlm/core/parallel/shard.py (1 addition, 1 deletion)
@@ -159,7 +159,7 @@ def get_parallel_strategies_split_mode(linear_name: str) -> str:
     if linear_name in ("head", "output"):
         return "head"
     if linear_name in ("gate"):
-        return "head" # for MoE model
+        return "gate" # for MoE model
     elif linear_name in ("wqkv", "wq", "wk", "wv", "wkv", "w1", "w3", "w13"):
         return "column"
     elif linear_name in ("fc1", "fc2", "linear_1", "linear_2"): # for vit model
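One subtlety in this hunk: in Python, ("gate") is the plain string "gate", not a one-element tuple, so the membership test here is a substring check. It happens to behave as intended for the names routed through this function, as the minimal check below illustrates; the tuple spelling ("gate",) would be the stricter form.

    # ("gate") is a string, so `in` performs substring matching:
    assert "gate" in ("gate")    # exact name matches, as intended
    assert "ate" in ("gate")     # but any substring would match too
    assert "gate" in ("gate",)   # tuple form: exact membership only
    assert "ate" not in ("gate",)
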

internlm/data/build_dataloader.py (1 addition)
@@ -133,6 +133,7 @@ def get_streaming_train_loader_items(data_cfg):
         train_folder=data_cfg.train_folder,
         tokenizer_path=data_cfg.tokenizer_path,
         model_max_length=data_cfg.seq_len,
+        image_folder=data_cfg.get("image_folder", None),
         content_name=data_cfg.get("content_name", "text"),
         subset_name=data_cfg.get("subset_name", None),
     )
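For reference, a hypothetical data config consuming the new key might look like the sketch below. Paths and values are placeholders, and image_folder can be omitted entirely for text-only runs since it defaults to None.

    data = dict(
        train_folder="/path/to/train",         # placeholder
        tokenizer_path="/path/to/tokenizer",   # placeholder
        seq_len=4096,
        image_folder="/path/to/images",        # new in this commit; optional
        content_name="conversations",          # field holding the sample text
        subset_name=None,
    )
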

internlm/data/streaming/collaters.py (20 additions, 6 deletions)
@@ -3,21 +3,35 @@

 def streaming_packed_collate_fn(batch):
     input_ids_list = []
+    images_list = []
     cu_seqlens_list = []
     indexes_list = []
     type_ids_list = []
     labels_list = []
+    has_image = False

     for b in batch:
         input_ids_list.append(torch.LongTensor(b["input_ids"]))
         cu_seqlens_list.append(torch.IntTensor(b["cu_seqlens"]))
         indexes_list.append(torch.IntTensor(b["indexes"]))
         type_ids_list.append(torch.LongTensor(b["type_ids"]))
         labels_list.append(torch.LongTensor(b["labels"]))
+        if b.get("images", None) is not None:
+            has_image = True
+            images_list.append(torch.Tensor(b["images"]))

-    return {
-        "input_ids": torch.stack(input_ids_list),
-        "cu_seqlens": cu_seqlens_list,
-        "indexes": torch.stack(indexes_list),
-        "type_ids": torch.stack(type_ids_list),
-    }, torch.stack(labels_list)
+    if has_image:
+        return {
+            "input_ids": torch.stack(input_ids_list),
+            "images": torch.stack(images_list),
+            "cu_seqlens": cu_seqlens_list,
+            "indexes": torch.stack(indexes_list),
+            "type_ids": torch.stack(type_ids_list),
+        }, torch.stack(labels_list)
+    else:
+        return {
+            "input_ids": torch.stack(input_ids_list),
+            "cu_seqlens": cu_seqlens_list,
+            "indexes": torch.stack(indexes_list),
+            "type_ids": torch.stack(type_ids_list),
+        }, torch.stack(labels_list)
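
A minimal usage sketch of the collate function above, using a toy text-only batch (values are illustrative, not from the repository):

    import torch

    from internlm.data.streaming.collaters import streaming_packed_collate_fn

    batch = [
        {
            "input_ids": [11, 12, 13, 0],
            "cu_seqlens": [0, 3, 4],           # one packed sequence plus padding
            "indexes": [0, 1, 2, 0],           # positions restart per sequence
            "type_ids": [0, 0, 0, 0],
            "labels": [12, 13, -100, -100],
        }
    ]
    data, labels = streaming_packed_collate_fn(batch)
    assert data["input_ids"].shape == (1, 4)   # (batch, packed length)
    assert labels.shape == (1, 4)
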

internlm/data/streaming/dataset.py (88 additions, 24 deletions)
@@ -1,9 +1,12 @@
 import itertools
+import os
 import sys

 import datasets
 import numpy as np
+import torch
 from datasets.distributed import split_dataset_by_node
+from PIL import Image
 from torch.utils.data import Dataset

 from internlm.core.context import ParallelMode
@@ -21,6 +24,7 @@ def __init__(
         train_folder,
         tokenizer_path,
         model_max_length,
+        image_folder=None,
         content_name="text",
         subset_name=None,
         split="train",
@@ -32,6 +36,7 @@ def __init__(
         )
         self.content_name = content_name
         self.buffer_size = buffer_size
+        self.image_folder = image_folder
         self.senior_iterator = iter(self)

         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
@@ -52,11 +57,37 @@ def __len__(self):
         return sys.maxsize

     def _tokenize(self, samples):
-        texts = [sample[self.content_name] for sample in samples]
-        tokenized_outputs = self.tokenizer(texts, truncation=True)
-        for i in range(len(samples)):
-            if len(tokenized_outputs["input_ids"][i]) > 0:
-                yield {key: tokenized_outputs[key][i] for key in tokenized_outputs}
+        if self.image_folder is None:
+            texts = [sample[self.content_name] for sample in samples]
+            tokenized_outputs = self.tokenizer(texts, truncation=True)
+            for i in range(len(samples)):
+                if len(tokenized_outputs["input_ids"][i]) > 0:
+                    yield {key: tokenized_outputs[key][i] for key in tokenized_outputs}
+        else:
+            processed_images = []
+            texts = []
+            for sample in samples:
+                image_path = os.path.join(self.image_folder, sample["image"])
+                image = Image.open(image_path).convert("RGB")
+                image = self.preprocess_image(image)
+                processed_images.append(image)
+                text = "\n".join([conv["value"] for conv in sample[self.content_name]])
+                texts.append(text)
+            tokenized_outputs = self.tokenizer(texts, truncation=True)
+            for i in range(len(samples)):
+                if len(tokenized_outputs["input_ids"][i]) > 0:
+                    tokenized_output = {key: tokenized_outputs[key][i] for key in tokenized_outputs}
+                    tokenized_output["images"] = processed_images[i]
+                    yield tokenized_output
+
+    def preprocess_image(self, image):
+        image = image.resize((gpc.config.data.image_size, gpc.config.data.image_size))
+        image = np.array(image)
+        image = np.moveaxis(image, -1, 0)
+        image = image.astype(np.float32)
+        image /= 255.0
+        image = torch.from_numpy(image).unsqueeze(0)
+        return image

     def __getitem__(self, _):
         return next(self.senior_iterator)
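
Standalone, the preprocessing above amounts to the following sketch, where IMAGE_SIZE stands in for gpc.config.data.image_size (a run-config value):

    import numpy as np
    import torch
    from PIL import Image

    IMAGE_SIZE = 336  # placeholder for gpc.config.data.image_size

    def preprocess_image(image: Image.Image) -> torch.Tensor:
        image = image.resize((IMAGE_SIZE, IMAGE_SIZE))   # HWC uint8
        array = np.moveaxis(np.array(image), -1, 0)      # HWC -> CHW
        array = array.astype(np.float32) / 255.0         # scale to [0, 1]
        return torch.from_numpy(array).unsqueeze(0)      # (1, C, H, W)

    tensor = preprocess_image(Image.new("RGB", (640, 480)))
    assert tensor.shape == (1, 3, IMAGE_SIZE, IMAGE_SIZE)
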
@@ -89,12 +120,15 @@ class StreamingDatasetPackSampleWithPad(Dataset):
     """

-    def __init__(self, dataset, seq_len, micro_bsz, pad_token_id=0):
+    def __init__(self, dataset, seq_len, micro_bsz, pad_token_id=0, image_token_id=200000):
         self.dataset = dataset
         self.seq_len = seq_len
         self.micro_bsz = micro_bsz
         self.pad_token_id = pad_token_id
         self.senior_iterator = iter(self)
+        if gpc.config.data.get("is_multimodal", False):
+            self.image_token_id = image_token_id
+            self.image_token_size = int(gpc.config.data.image_size // gpc.config.data.patch_size) ** 2

     def __iter__(self):
         input_ids = []
@@ -109,15 +143,30 @@ def __iter__(self):
                 else cu_seqlens
             )
             labels = labels + [-100] * (self.micro_bsz * self.seq_len - len(labels))
-            yield {
-                "input_ids": input_ids,
-                "cu_seqlens": cu_seqlens,
-                "indexes": list(
-                    itertools.chain(*[np.arange(l2 - l1) for l1, l2 in zip(cu_seqlens[:-1], cu_seqlens[1:])])
-                ),
-                "labels": labels,
-                "type_ids": [0] * (self.micro_bsz * self.seq_len),
-            }
+            if "images" in sample:
+                image_token_id_list = [self.image_token_id] * self.image_token_size
+                input_ids = input_ids[: self.micro_bsz * self.seq_len - self.image_token_size]
+                input_ids = image_token_id_list + input_ids
+                yield {
+                    "input_ids": input_ids,
+                    "images": sample["images"],
+                    "cu_seqlens": cu_seqlens,
+                    "indexes": list(
+                        itertools.chain(*[np.arange(l2 - l1) for l1, l2 in zip(cu_seqlens[:-1], cu_seqlens[1:])])
+                    ),
+                    "labels": labels,
+                    "type_ids": [0] * (self.micro_bsz * self.seq_len),
+                }
+            else:
+                yield {
+                    "input_ids": input_ids,
+                    "cu_seqlens": cu_seqlens,
+                    "indexes": list(
+                        itertools.chain(*[np.arange(l2 - l1) for l1, l2 in zip(cu_seqlens[:-1], cu_seqlens[1:])])
+                    ),
+                    "labels": labels,
+                    "type_ids": [0] * (self.micro_bsz * self.seq_len),
+                }
             input_ids = sample["input_ids"]
             cu_seqlens = [0, len(sample["input_ids"])]
             labels = [w if w > 0 else -100 for w in sample["input_ids"]][1:] + [-100]
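
The image branch reserves room for visual tokens by trimming the packed text ids and prepending a fixed-size block of image token ids; a condensed illustration with made-up sizes:

    seq_capacity = 8          # micro_bsz * seq_len in the real code
    image_token_size = 3      # (image_size // patch_size) ** 2
    image_token_id = 200000

    input_ids = [11, 12, 13, 14, 15, 16, 17, 18]
    input_ids = input_ids[: seq_capacity - image_token_size]
    input_ids = [image_token_id] * image_token_size + input_ids
    assert input_ids == [200000, 200000, 200000, 11, 12, 13, 14, 15]
    assert len(input_ids) == seq_capacity
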
@@ -134,15 +183,30 @@ def __iter__(self):
             else cu_seqlens
         )
         labels = labels + [-100] * (self.micro_bsz * self.seq_len - len(labels))
-        yield {
-            "input_ids": input_ids,
-            "cu_seqlens": cu_seqlens,
-            "indexes": list(
-                itertools.chain(*[np.arange(l2 - l1) for l1, l2 in zip(cu_seqlens[:-1], cu_seqlens[1:])])
-            ),
-            "labels": labels,
-            "type_ids": [0] * (self.micro_bsz * self.seq_len),
-        }
+        if "images" in self.dataset[-1]:
+            image_token_id_list = [self.image_token_id] * self.image_token_size
+            input_ids = input_ids[: self.micro_bsz * self.seq_len - self.image_token_size]
+            input_ids = image_token_id_list + input_ids
+            yield {
+                "input_ids": input_ids,
+                "images": self.dataset[-1]["images"],
+                "cu_seqlens": cu_seqlens,
+                "indexes": list(
+                    itertools.chain(*[np.arange(l2 - l1) for l1, l2 in zip(cu_seqlens[:-1], cu_seqlens[1:])])
+                ),
+                "labels": labels,
+                "type_ids": [0] * (self.micro_bsz * self.seq_len),
+            }
+        else:
+            yield {
+                "input_ids": input_ids,
+                "cu_seqlens": cu_seqlens,
+                "indexes": list(
+                    itertools.chain(*[np.arange(l2 - l1) for l1, l2 in zip(cu_seqlens[:-1], cu_seqlens[1:])])
+                ),
+                "labels": labels,
+                "type_ids": [0] * (self.micro_bsz * self.seq_len),
+            }

     def __len__(self):
         return sys.maxsize
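
The per-token position indexes yielded above are derived from the cumulative sequence lengths, restarting at zero for each packed sequence; a small worked example:

    import itertools

    import numpy as np

    cu_seqlens = [0, 3, 7]  # two packed sequences of lengths 3 and 4
    indexes = list(
        itertools.chain(*[np.arange(l2 - l1) for l1, l2 in zip(cu_seqlens[:-1], cu_seqlens[1:])])
    )
    assert indexes == [0, 1, 2, 0, 1, 2, 3]  # positions restart per sequence
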

internlm/model/modules/linear.py (6 additions)
@@ -639,6 +639,12 @@ def new_linear(
             dtype,
             is_expert,
         )
+    elif split_mode == "gate":
+        return nn.Linear(
+            in_features,
+            out_features,
+            bias,
+        )
     else:
         err_msg = (
             f"Parallel strategies for linear is unsupported, which is named as {name}.\n"
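This branch keeps the MoE gate projection as an ordinary, replicated nn.Linear instead of a tensor-parallel linear; a standalone sketch with placeholder sizes:

    import torch
    import torch.nn as nn

    hidden_size, num_experts = 4096, 8  # placeholder sizes
    gate = nn.Linear(hidden_size, num_experts, bias=False)  # replicated on every rank

    scores = gate(torch.randn(2, hidden_size))
    assert scores.shape == (2, num_experts)
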

internlm/train/pipeline.py (6 additions, 2 deletions)
@@ -151,6 +151,10 @@ def _check_module(name, module):
                 setattr(param, IS_TENSOR_ZERO_PARALLEL, True)

     # for moe linear module
+    if isinstance(module, nn.Linear) and not isinstance(module, ParallelLinearWithCommExt):
+        for param in module.parameters():
+            setattr(param, IS_REPLICA_ZERO_PARALLEL, True)
+
     if isinstance(module, Experts):
         for param in module.parameters():
             if (
@@ -178,7 +182,7 @@ def _check_module(name, module):

     for _chunk in unwrap_naive_amp(model):
         # special case for pure dp or pure wdp mode
-        if gpc.get_world_size(ParallelMode.DATA) == gpc.get_world_size(ParallelMode.GLOBAL) or gpc.get_world_size(
+        if gpc.get_world_size(ParallelMode.DATA) == gpc.get_world_size(ParallelMode.GLOBAL) and gpc.get_world_size(
             ParallelMode.WEIGHT_DATA
         ) == gpc.get_world_size(ParallelMode.GLOBAL):
             _check_module_func = _check_module_pure_dp_wdp
@@ -927,7 +931,7 @@ def inject_model_helper(model: Union[nn.Module, nn.ModuleList], inject_info: Opt

     # inject modules
     for _chunk in model:
-        if gpc.get_world_size(ParallelMode.DATA) == gpc.get_world_size(ParallelMode.GLOBAL) or gpc.get_world_size(
+        if gpc.get_world_size(ParallelMode.DATA) == gpc.get_world_size(ParallelMode.GLOBAL) and gpc.get_world_size(
             ParallelMode.WEIGHT_DATA
         ) == gpc.get_world_size(ParallelMode.GLOBAL):
             continue
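Both world-size hunks make the same fix: the pure-DP/WDP shortcut must require both the data and weight-data groups to span the whole world, hence `and` rather than `or`. A condensed illustration (the helper name is ours, for illustration only):

    def is_pure_dp_wdp(data_ws: int, wdp_ws: int, global_ws: int) -> bool:
        # Corrected form: both groups must equal the global world size.
        return data_ws == global_ws and wdp_ws == global_ws

    assert is_pure_dp_wdp(8, 8, 8)      # pure data parallel: shortcut applies
    assert not is_pure_dp_wdp(8, 2, 8)  # weight parallelism present: with `or`
                                        # this would wrongly take the shortcut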
