diff --git a/pretrain.py b/pretrain.py
index e5cd5cf..6b56dc0 100644
--- a/pretrain.py
+++ b/pretrain.py
@@ -37,6 +37,8 @@ def main():
                         help="Number of prediction labels.")
     parser.add_argument("--dropout", type=float, default=0.1, help="Dropout value.")
     parser.add_argument("--seed", type=int, default=7, help="Random seed.")
+    parser.add_argument("--seq_length", type=int, default=128,
+                        help="Sequence length.")
 
     # Preprocess options.
     tokenizer_opts(parser)
@@ -44,8 +46,6 @@ def main():
 
     # Model options.
     model_opts(parser)
-    parser.add_argument("--vision_model_missing_prefix", type=str, required=False, default="embedding.vision_language.vision_",
-                        help="Extra prefix when loading the vision pretrained model as the embedding of the whole model.")
 
     # Model parallelism options.
     mp_opts(parser)
diff --git a/scripts/convert_model_add_prefix.py b/scripts/convert_model_add_prefix.py
new file mode 100755
index 0000000..f75ca66
--- /dev/null
+++ b/scripts/convert_model_add_prefix.py
@@ -0,0 +1,28 @@
+import argparse
+import collections
+import torch
+
+
+def main():
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument("--input_model_path", type=str, default="models/input_model.bin",
+                        help="Path of the input model.")
+    parser.add_argument("--output_model_path", type=str, default="models/output_model.bin",
+                        help="Path of the output model.")
+    parser.add_argument("--prefix", type=str, default="", help="Prefix to add to each parameter name.")
+
+
+    args = parser.parse_args()
+
+    input_model = torch.load(args.input_model_path, map_location="cpu")
+
+    output_model = collections.OrderedDict()
+
+    for k in input_model.keys():
+        output_model[args.prefix + k] = input_model[k]
+
+    torch.save(output_model, args.output_model_path)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/generate_lm_llava_deepspeed.py b/scripts/generate_lm_llava_deepspeed.py
index c7d2e3f..157b9ac 100755
--- a/scripts/generate_lm_llava_deepspeed.py
+++ b/scripts/generate_lm_llava_deepspeed.py
@@ -96,7 +96,7 @@ def load_or_initialize_parameters(args, model):
         args.logger.info("unexpected_keys: {0}".format(keys_info.unexpected_keys))
         if args.vision_model_in_VL_emb_path is not None:
             args.logger.info("loading model from {0}".format(args.vision_model_in_VL_emb_path))
-            model = load_model(model, args.vision_model_in_VL_emb_path, missing_prefix="embedding.vision_language.vision_")
+            model = load_model(model, args.vision_model_in_VL_emb_path)
     else:
         # Initialize with normal distribution.
         for n, p in list(model.named_parameters()):
@@ -112,8 +112,6 @@ def load_or_initialize_parameters(args, model):
     parser.add_argument("--top_k", type=int, default=70)
     parser.add_argument("--top_p", type=float, default=0)
     parser.add_argument("--temperature", type=float, default=1.0)
-    parser.add_argument("--vision_model_in_VL_emb_path", type=str, default=None,
-                        help="Path of the vision pretrained model in the vision language embedding.")
     parser.add_argument("--instruction_template", type=str, choices=["sys0", "sys1", "sys2", "sys3", "sys4"],
                         help="The instruction type for training large language-vision model.", default="sys0")
 
@@ -146,7 +144,7 @@ def load_or_initialize_parameters(args, model):
     if args.pretrained_model_path:
         model = _load_state_dict_into_model(model, args.pretrained_model_path)
         if args.vision_model_in_VL_emb_path is not None:
-            model = _load_state_dict_into_model(model, args.vision_model_in_VL_emb_path, missing_prefix="embedding.vision_language.vision_")
+            model = _load_state_dict_into_model(model, args.vision_model_in_VL_emb_path)
     else:
         model = LLaVaGenerate(args)
     load_or_initialize_parameters(args, model)
diff --git a/tencentpretrain/model_loader.py b/tencentpretrain/model_loader.py
index 29db8ef..a8e0b17 100644
--- a/tencentpretrain/model_loader.py
+++ b/tencentpretrain/model_loader.py
@@ -1,32 +1,24 @@
 import os
 import torch
-import collections
 from tencentpretrain import mpu
 
 
-def load_model(model, model_path, lora_pretrained_model_path=None, missing_prefix=""):
+def load_model(model, model_path, lora_pretrained_model_path=None):
     """
     Load model from saved weights.
     """
-    state_dict = torch.load(model_path, map_location="cpu")
-    if missing_prefix != "":
-        state_dict_withprefix = collections.OrderedDict()
-        for k in state_dict.keys():
-            state_dict_withprefix[missing_prefix + k] = state_dict[k]
-        del state_dict
-        state_dict = state_dict_withprefix
     if hasattr(model, "module"):
-        model.module.load_state_dict(state_dict, strict=False)
+        model.module.load_state_dict(torch.load(model_path, map_location="cpu"), strict=False)
         if lora_pretrained_model_path is not None:
             model.module.load_state_dict(torch.load(lora_pretrained_model_path, map_location="cpu"), strict=False)
     else:
-        model.load_state_dict(state_dict, strict=False)
+        model.load_state_dict(torch.load(model_path, map_location="cpu"), strict=False)
         if lora_pretrained_model_path is not None:
             model.load_state_dict(torch.load(lora_pretrained_model_path, map_location="cpu"), strict=False)
     return model
 
 
-def _load_state_dict_into_model(model_to_load, model_path, start_prefix="", missing_prefix=""):
+def _load_state_dict_into_model(model_to_load, model_path, start_prefix=""):
     # Convert old format to new format if needed from a PyTorch state_dict
 
     # copy state_dict so _load_from_state_dict can modify it
@@ -61,12 +53,6 @@ def load(module, state_dict, prefix=""):
             for name, child in module._modules.items():
                 if child is not None:
                     load(child, state_dict, prefix + name + ".")
-    if missing_prefix != "":
-        state_dict_withprefix = collections.OrderedDict()
-        for k in state_dict.keys():
-            state_dict_withprefix[missing_prefix + k] = state_dict[k]
-        del state_dict
-        state_dict = state_dict_withprefix
     load(model_to_load, state_dict, prefix=start_prefix)
 
     # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
diff --git a/tencentpretrain/trainer.py b/tencentpretrain/trainer.py
index 7301239..105ce75 100755
--- a/tencentpretrain/trainer.py
+++ b/tencentpretrain/trainer.py
@@ -96,8 +96,7 @@ def init_model(args):
 
     if args.vision_model_in_VL_emb_path is not None:
         args.logger.info("loading: {}".format(args.vision_model_in_VL_emb_path))
-        model_for_training = _load_state_dict_into_model(model_for_training, args.vision_model_in_VL_emb_path, missing_prefix=args.vision_model_missing_prefix)
-        # model_for_training = load_model(model_for_training, args.vision_model_path)
+        model_for_training = _load_state_dict_into_model(model_for_training, args.vision_model_in_VL_emb_path)
 
     return model_for_training, model_for_dataloader
 
diff --git a/tencentpretrain/utils/dataloader.py b/tencentpretrain/utils/dataloader.py
index 03a2bcf..8f1a4f5 100755
--- a/tencentpretrain/utils/dataloader.py
+++ b/tencentpretrain/utils/dataloader.py
@@ -546,6 +546,7 @@ def __init__(self, args, dataset_path, batch_size, global_rank, world_size, loca
         self.patch_size = args.patch_size
         self.image_height = args.image_height
         self.image_width = args.image_width
+        self.args = args
 
         from torchvision import transforms
         from tencentpretrain.utils.misc import ZeroOneNormalize
@@ -990,7 +991,9 @@ def __iter__(self):
         """
        from torchvision.io import read_image
        from torchvision.io.image import ImageReadMode
-        seg_num = (self.image_height // self.patch_size) * (self.image_width // self.patch_size) + 1
+
+        seg_image_num = (self.image_height // self.patch_size) * (self.image_width // self.patch_size)
+        text_seq_length = self.args.seq_length - seg_image_num
         while True:
             while self._empty():
                 self._fill_buf()
@@ -1013,14 +1016,17 @@ def __iter__(self):
                 ins_seg_nums_src, ins_seg_nums_tgt = ins[1]
                 ins_src_image, ins_image_pos = ins[2]
 
-                src_text.append(ins_src)
-                tgt.append(ins_tgt)
+                src_text.append(ins_src[:text_seq_length])
                 ins_seg_src = [1] * ins_seg_nums_src[0] + [0] * ins_seg_nums_src[1]
-                ins_seg_tgt = []
+                seg_text.append(ins_seg_src[:text_seq_length])
+
+                ins_tgt_new = [self.vocab.get(PAD_TOKEN)] * seg_image_num + ins_tgt
+                tgt.append(ins_tgt_new[:self.args.seq_length])
+                ins_seg_tgt = [0] * seg_image_num
                 for i, num in enumerate(ins_seg_nums_tgt):
                     ins_seg_tgt = ins_seg_tgt + [i % 2] * num
-                seg_text.append(ins_seg_src)
-                seg_tgt.append(ins_seg_tgt)
+                seg_tgt.append(ins_seg_tgt[:self.args.seq_length])
+
                 try:
                     image = read_image(ins_src_image, ImageReadMode.RGB)
                 except:
@@ -1028,7 +1034,7 @@ def __iter__(self):
                     continue
                 image = image.cuda(self.local_rank)
                 src_image.append(self.transform(image))
-                seg_image.append([1] * seg_num)
+                seg_image.append([1] * (seg_image_num + 1))
                 image_pos.append(ins_image_pos)
             if len(src_image) == 0:
                 continue
diff --git a/tencentpretrain/utils/dataset.py b/tencentpretrain/utils/dataset.py
index 0566a69..c2f03f7 100755
--- a/tencentpretrain/utils/dataset.py
+++ b/tencentpretrain/utils/dataset.py
@@ -1089,8 +1089,6 @@ def worker(self, proc_id, start, end):
 class LlavaDataset(Dataset):
     def worker(self, proc_id, start, end):
         import json
-        num_image_tokens = self.args.vision_seq_length_in_VL # 576
-        seq_text = self.seq_length - num_image_tokens
         PAD_ID = self.tokenizer.convert_tokens_to_ids([PAD_TOKEN])[0]
         role1, role2 = "USER", "ASSISTANT"
         im_start, im_end = "", ""
@@ -1100,9 +1098,9 @@ def worker(self, proc_id, start, end):
         pos = start
         skip_item = 0
         with open(self.corpus_path, mode="r", encoding="utf-8") as f:
-            datas = json.load(f)
+            data = json.load(f)
         while True:
-            item = datas[pos]
+            item = data[pos]
             pos += 1
             try:
                 path = item["image"]
@@ -1128,11 +1126,11 @@ def worker(self, proc_id, start, end):
                         prompt_after_image_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(prompt_after_image))
                         seg_before_image = [1] * len(prompt_before_image_id)
                         seg_after_image = [1] * len(prompt_after_image_id)
-                        if len(prompt_before_image_id) + len(prompt_after_image_id) > seq_text:
+                        if len(prompt_before_image_id) + len(prompt_after_image_id) > self.seq_length:
                             print("promt too long, jumped")
                             continue
                         prompt_answer_id = prompt_before_image_id + prompt_after_image_id
-                        tgt_id = [PAD_ID] * (len(prompt_answer_id) + num_image_tokens - 1)
+                        tgt_id = [PAD_ID] * (len(prompt_answer_id) - 1)
                         tgt_seg_nums = [len(tgt_id)]
                     elif i % 2 == 0: # human
                         prompt = conv["value"]
@@ -1159,10 +1157,10 @@ def worker(self, proc_id, start, end):
                 pad_num = self.seq_length - sum(tgt_seg_nums)
                 tgt_seg_nums = tgt_seg_nums + [pad_num]
 
-                if len(prompt_answer_id) > seq_text :
-                    prompt_answer_id = prompt_answer_id[:seq_text]
+                if len(prompt_answer_id) > self.seq_length:
+                    prompt_answer_id = prompt_answer_id[:self.seq_length]
 
-                pad_num = seq_text - len(prompt_answer_id)
+                pad_num = self.seq_length - len(prompt_answer_id)
                 prompt_answer_seg_nums = [len(prompt_answer_id), pad_num]
                 prompt_answer_id = prompt_answer_id + [PAD_ID] * pad_num
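
With the missing_prefix argument removed from load_model and _load_state_dict_into_model, the prefix handling appears to move into an offline conversion step via the new scripts/convert_model_add_prefix.py, after which the checkpoint is passed as --vision_model_in_VL_emb_path and loaded normally. A minimal sketch of that step, assuming a hypothetical vision checkpoint at models/vit_model.bin and the prefix string that was previously hard-coded in the removed code:

    # Equivalent command line (paths are hypothetical):
    #   python3 scripts/convert_model_add_prefix.py \
    #       --input_model_path models/vit_model.bin \
    #       --output_model_path models/vit_model_with_prefix.bin \
    #       --prefix embedding.vision_language.vision_
    import collections

    import torch

    # Load the raw vision checkpoint and rewrite every key with the prefix
    # expected by the vision-language embedding of the full model.
    state_dict = torch.load("models/vit_model.bin", map_location="cpu")
    prefixed = collections.OrderedDict(
        ("embedding.vision_language.vision_" + k, v) for k, v in state_dict.items()
    )
    torch.save(prefixed, "models/vit_model_with_prefix.bin")

The converted checkpoint can then go through the simplified load_model / _load_state_dict_into_model paths without any prefix rewriting at load time.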
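The __iter__ changes in tencentpretrain/utils/dataloader.py also re-budget the new --seq_length option between image patches and text. A small sketch of that bookkeeping, using hypothetical sizes (not taken from the diff) only to make the arithmetic concrete:

    # Hypothetical sizes, chosen only so the numbers are easy to follow.
    seq_length = 128                   # --seq_length added to pretrain.py
    image_height = image_width = 64    # hypothetical
    patch_size = 32                    # hypothetical

    # Image patch positions and the remaining text positions.
    seg_image_num = (image_height // patch_size) * (image_width // patch_size)  # 2 * 2 = 4
    text_seq_length = seq_length - seg_image_num                                # 128 - 4 = 124

    # Per the diff: the source text and its segment ids are truncated to
    # text_seq_length; the target is left-padded with seg_image_num PAD ids so
    # labels line up after the image span, then truncated to seq_length; and
    # seg_image holds seg_image_num + 1 ones per sample.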