From d0a19fb1f513ddbb53d6ba94bd87569b8a3ce5e7 Mon Sep 17 00:00:00 2001 From: sallyjunjun <72725839+sallyjunjun@users.noreply.github.com> Date: Thu, 26 Sep 2024 11:18:10 +0800 Subject: [PATCH] fix inject model and add multimodal dataloader (#341) --- internlm/core/parallel/shard.py | 2 +- internlm/data/build_dataloader.py | 1 + internlm/data/streaming/collaters.py | 26 +++++-- internlm/data/streaming/dataset.py | 112 +++++++++++++++++++++------ internlm/model/modules/linear.py | 6 ++ internlm/train/pipeline.py | 8 +- 6 files changed, 122 insertions(+), 33 deletions(-) diff --git a/internlm/core/parallel/shard.py b/internlm/core/parallel/shard.py index 3c3f3fb3..fa79ddc9 100644 --- a/internlm/core/parallel/shard.py +++ b/internlm/core/parallel/shard.py @@ -159,7 +159,7 @@ def get_parallel_strategies_split_mode(linear_name: str) -> str: if linear_name in ("head", "output"): return "head" if linear_name in ("gate"): - return "head" # for MoE model + return "gate" # for MoE model elif linear_name in ("wqkv", "wq", "wk", "wv", "wkv", "w1", "w3", "w13"): return "column" elif linear_name in ("fc1", "fc2", "linear_1", "linear_2"): # for vit model diff --git a/internlm/data/build_dataloader.py b/internlm/data/build_dataloader.py index a317ae3e..64da9539 100644 --- a/internlm/data/build_dataloader.py +++ b/internlm/data/build_dataloader.py @@ -133,6 +133,7 @@ def get_streaming_train_loader_items(data_cfg): train_folder=data_cfg.train_folder, tokenizer_path=data_cfg.tokenizer_path, model_max_length=data_cfg.seq_len, + image_folder=data_cfg.get("image_folder", None), content_name=data_cfg.get("content_name", "text"), subset_name=data_cfg.get("subset_name", None), ) diff --git a/internlm/data/streaming/collaters.py b/internlm/data/streaming/collaters.py index 39f8fd70..9aa42de7 100644 --- a/internlm/data/streaming/collaters.py +++ b/internlm/data/streaming/collaters.py @@ -3,10 +3,12 @@ def streaming_packed_collate_fn(batch): input_ids_list = [] + images_list = [] cu_seqlens_list = [] indexes_list = [] type_ids_list = [] labels_list = [] + has_image = False for b in batch: input_ids_list.append(torch.LongTensor(b["input_ids"])) @@ -14,10 +16,22 @@ def streaming_packed_collate_fn(batch): indexes_list.append(torch.IntTensor(b["indexes"])) type_ids_list.append(torch.LongTensor(b["type_ids"])) labels_list.append(torch.LongTensor(b["labels"])) + if b.get("images", None) is not None: + has_image = True + images_list.append(torch.Tensor(b["images"])) - return { - "input_ids": torch.stack(input_ids_list), - "cu_seqlens": cu_seqlens_list, - "indexes": torch.stack(indexes_list), - "type_ids": torch.stack(type_ids_list), - }, torch.stack(labels_list) + if has_image: + return { + "input_ids": torch.stack(input_ids_list), + "images": torch.stack(images_list), + "cu_seqlens": cu_seqlens_list, + "indexes": torch.stack(indexes_list), + "type_ids": torch.stack(type_ids_list), + }, torch.stack(labels_list) + else: + return { + "input_ids": torch.stack(input_ids_list), + "cu_seqlens": cu_seqlens_list, + "indexes": torch.stack(indexes_list), + "type_ids": torch.stack(type_ids_list), + }, torch.stack(labels_list) diff --git a/internlm/data/streaming/dataset.py b/internlm/data/streaming/dataset.py index 564648c9..8b0755ed 100644 --- a/internlm/data/streaming/dataset.py +++ b/internlm/data/streaming/dataset.py @@ -1,9 +1,12 @@ import itertools +import os import sys import datasets import numpy as np +import torch from datasets.distributed import split_dataset_by_node +from PIL import Image from torch.utils.data import Dataset from internlm.core.context import ParallelMode @@ -21,6 +24,7 @@ def __init__( train_folder, tokenizer_path, model_max_length, + image_folder=None, content_name="text", subset_name=None, split="train", @@ -32,6 +36,7 @@ def __init__( ) self.content_name = content_name self.buffer_size = buffer_size + self.image_folder = image_folder self.senior_iterator = iter(self) self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True) @@ -52,11 +57,37 @@ def __len__(self): return sys.maxsize def _tokenize(self, samples): - texts = [sample[self.content_name] for sample in samples] - tokenized_outputs = self.tokenizer(texts, truncation=True) - for i in range(len(samples)): - if len(tokenized_outputs["input_ids"][i]) > 0: - yield {key: tokenized_outputs[key][i] for key in tokenized_outputs} + if self.image_folder is None: + texts = [sample[self.content_name] for sample in samples] + tokenized_outputs = self.tokenizer(texts, truncation=True) + for i in range(len(samples)): + if len(tokenized_outputs["input_ids"][i]) > 0: + yield {key: tokenized_outputs[key][i] for key in tokenized_outputs} + else: + processed_images = [] + texts = [] + for sample in samples: + image_path = os.path.join(self.image_folder, sample["image"]) + image = Image.open(image_path).convert("RGB") + image = self.preprocess_image(image) + processed_images.append(image) + text = "\n".join([conv["value"] for conv in sample[self.content_name]]) + texts.append(text) + tokenized_outputs = self.tokenizer(texts, truncation=True) + for i in range(len(samples)): + if len(tokenized_outputs["input_ids"][i]) > 0: + tokenized_output = {key: tokenized_outputs[key][i] for key in tokenized_outputs} + tokenized_output["images"] = processed_images[i] + yield tokenized_output + + def preprocess_image(self, image): + image = image.resize((gpc.config.data.image_size, gpc.config.data.image_size)) + image = np.array(image) + image = np.moveaxis(image, -1, 0) + image = image.astype(np.float32) + image /= 255.0 + image = torch.from_numpy(image).unsqueeze(0) + return image def __getitem__(self, _): return next(self.senior_iterator) @@ -89,12 +120,15 @@ class StreamingDatasetPackSampleWithPad(Dataset): """ - def __init__(self, dataset, seq_len, micro_bsz, pad_token_id=0): + def __init__(self, dataset, seq_len, micro_bsz, pad_token_id=0, image_token_id=200000): self.dataset = dataset self.seq_len = seq_len self.micro_bsz = micro_bsz self.pad_token_id = pad_token_id self.senior_iterator = iter(self) + if gpc.config.data.get("is_multimodal", False): + self.image_token_id = image_token_id + self.image_token_size = int(gpc.config.data.image_size // gpc.config.data.patch_size) ** 2 def __iter__(self): input_ids = [] @@ -109,15 +143,30 @@ def __iter__(self): else cu_seqlens ) labels = labels + [-100] * (self.micro_bsz * self.seq_len - len(labels)) - yield { - "input_ids": input_ids, - "cu_seqlens": cu_seqlens, - "indexes": list( - itertools.chain(*[np.arange(l2 - l1) for l1, l2 in zip(cu_seqlens[:-1], cu_seqlens[1:])]) - ), - "labels": labels, - "type_ids": [0] * (self.micro_bsz * self.seq_len), - } + if "images" in sample: + image_token_id_list = [self.image_token_id] * self.image_token_size + input_ids = input_ids[: self.micro_bsz * self.seq_len - self.image_token_size] + input_ids = image_token_id_list + input_ids + yield { + "input_ids": input_ids, + "images": sample["images"], + "cu_seqlens": cu_seqlens, + "indexes": list( + itertools.chain(*[np.arange(l2 - l1) for l1, l2 in zip(cu_seqlens[:-1], cu_seqlens[1:])]) + ), + "labels": labels, + "type_ids": [0] * (self.micro_bsz * self.seq_len), + } + else: + yield { + "input_ids": input_ids, + "cu_seqlens": cu_seqlens, + "indexes": list( + itertools.chain(*[np.arange(l2 - l1) for l1, l2 in zip(cu_seqlens[:-1], cu_seqlens[1:])]) + ), + "labels": labels, + "type_ids": [0] * (self.micro_bsz * self.seq_len), + } input_ids = sample["input_ids"] cu_seqlens = [0, len(sample["input_ids"])] labels = [w if w > 0 else -100 for w in sample["input_ids"]][1:] + [-100] @@ -134,15 +183,30 @@ def __iter__(self): else cu_seqlens ) labels = labels + [-100] * (self.micro_bsz * self.seq_len - len(labels)) - yield { - "input_ids": input_ids, - "cu_seqlens": cu_seqlens, - "indexes": list( - itertools.chain(*[np.arange(l2 - l1) for l1, l2 in zip(cu_seqlens[:-1], cu_seqlens[1:])]) - ), - "labels": labels, - "type_ids": [0] * (self.micro_bsz * self.seq_len), - } + if "images" in self.dataset[-1]: + image_token_id_list = [self.image_token_id] * self.image_token_size + input_ids = input_ids[: self.micro_bsz * self.seq_len - self.image_token_size] + input_ids = image_token_id_list + input_ids + yield { + "input_ids": input_ids, + "images": self.dataset[-1]["images"], + "cu_seqlens": cu_seqlens, + "indexes": list( + itertools.chain(*[np.arange(l2 - l1) for l1, l2 in zip(cu_seqlens[:-1], cu_seqlens[1:])]) + ), + "labels": labels, + "type_ids": [0] * (self.micro_bsz * self.seq_len), + } + else: + yield { + "input_ids": input_ids, + "cu_seqlens": cu_seqlens, + "indexes": list( + itertools.chain(*[np.arange(l2 - l1) for l1, l2 in zip(cu_seqlens[:-1], cu_seqlens[1:])]) + ), + "labels": labels, + "type_ids": [0] * (self.micro_bsz * self.seq_len), + } def __len__(self): return sys.maxsize diff --git a/internlm/model/modules/linear.py b/internlm/model/modules/linear.py index 2426ab8a..a3c684f6 100644 --- a/internlm/model/modules/linear.py +++ b/internlm/model/modules/linear.py @@ -639,6 +639,12 @@ def new_linear( dtype, is_expert, ) + elif split_mode == "gate": + return nn.Linear( + in_features, + out_features, + bias, + ) else: err_msg = ( f"Parallel strategies for linear is unsupported, which is named as {name}.\n" diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py index 5547a9fb..34c1479b 100644 --- a/internlm/train/pipeline.py +++ b/internlm/train/pipeline.py @@ -151,6 +151,10 @@ def _check_module(name, module): setattr(param, IS_TENSOR_ZERO_PARALLEL, True) # for moe linear module + if isinstance(module, nn.Linear) and not isinstance(module, ParallelLinearWithCommExt): + for param in module.parameters(): + setattr(param, IS_REPLICA_ZERO_PARALLEL, True) + if isinstance(module, Experts): for param in module.parameters(): if ( @@ -178,7 +182,7 @@ def _check_module(name, module): for _chunk in unwrap_naive_amp(model): # special case for pure dp or pure wdp mode - if gpc.get_world_size(ParallelMode.DATA) == gpc.get_world_size(ParallelMode.GLOBAL) or gpc.get_world_size( + if gpc.get_world_size(ParallelMode.DATA) == gpc.get_world_size(ParallelMode.GLOBAL) and gpc.get_world_size( ParallelMode.WEIGHT_DATA ) == gpc.get_world_size(ParallelMode.GLOBAL): _check_module_func = _check_module_pure_dp_wdp @@ -927,7 +931,7 @@ def inject_model_helper(model: Union[nn.Module, nn.ModuleList], inject_info: Opt # inject modules for _chunk in model: - if gpc.get_world_size(ParallelMode.DATA) == gpc.get_world_size(ParallelMode.GLOBAL) or gpc.get_world_size( + if gpc.get_world_size(ParallelMode.DATA) == gpc.get_world_size(ParallelMode.GLOBAL) and gpc.get_world_size( ParallelMode.WEIGHT_DATA ) == gpc.get_world_size(ParallelMode.GLOBAL): continue