Skip to content

Commit

Permalink
fix dataset and dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
janinezhao committed Jan 4, 2024
1 parent b255e5b commit 845fef8
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 42 deletions.
4 changes: 2 additions & 2 deletions pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,15 @@ def main():
help="Number of prediction labels.")
parser.add_argument("--dropout", type=float, default=0.1, help="Dropout value.")
parser.add_argument("--seed", type=int, default=7, help="Random seed.")
parser.add_argument("--seq_length", type=int, default=128,
help="Sequence length.")

# Preprocess options.
tokenizer_opts(parser)
tgt_tokenizer_opts(parser)

# Model options.
model_opts(parser)
parser.add_argument("--vision_model_missing_prefix", type=str, required=False, default="embedding.vision_language.vision_",
help="Extra prefix when loading the vision pretrained model as the embedding of the whole model.")

# Model parallelism options.
mp_opts(parser)
Expand Down
28 changes: 28 additions & 0 deletions scripts/convert_model_add_prefix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import argparse
import collections
import torch


def _add_prefix(state_dict, prefix):
    """Return a new ordered mapping whose keys are ``prefix + key``.

    Values are reused as-is (not copied) and the original key order is kept.
    """
    return collections.OrderedDict(
        (prefix + name, value) for name, value in state_dict.items()
    )


def main():
    """Load a checkpoint, prepend ``--prefix`` to every key, and save the result.

    Command-line options:
        --input_model_path: file read with ``torch.load``.
        --output_model_path: file written with ``torch.save``.
        --prefix: string prepended to each key (empty by default, which
            rewrites the checkpoint with unchanged key names).
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--input_model_path", type=str, default="models/input_model.bin",
                        help="Path of the input model whose keys will be renamed.")
    parser.add_argument("--output_model_path", type=str, default="models/output_model.bin",
                        help="Path where the model with prefixed keys is saved.")
    parser.add_argument("--prefix", type=str, default="", help="prefix to add")

    args = parser.parse_args()

    # NOTE(security): torch.load unpickles the file, so only run this on
    # checkpoints from a trusted source.
    # map_location="cpu" lets the conversion run without the device the
    # checkpoint was saved from.
    # Assumes the checkpoint is a flat mapping of parameter name -> value
    # (it is only accessed through its keys and items) -- TODO confirm.
    input_model = torch.load(args.input_model_path, map_location="cpu")

    output_model = _add_prefix(input_model, args.prefix)

    torch.save(output_model, args.output_model_path)


if __name__ == "__main__":
    main()
6 changes: 2 additions & 4 deletions scripts/generate_lm_llava_deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def load_or_initialize_parameters(args, model):
args.logger.info("unexpected_keys: {0}".format(keys_info.unexpected_keys))
if args.vision_model_in_VL_emb_path is not None:
args.logger.info("loading model from {0}".format(args.vision_model_in_VL_emb_path))
model = load_model(model, args.vision_model_in_VL_emb_path, missing_prefix="embedding.vision_language.vision_")
model = load_model(model, args.vision_model_in_VL_emb_path)
else:
# Initialize with normal distribution.
for n, p in list(model.named_parameters()):
Expand All @@ -112,8 +112,6 @@ def load_or_initialize_parameters(args, model):
parser.add_argument("--top_k", type=int, default=70)
parser.add_argument("--top_p", type=float, default=0)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--vision_model_in_VL_emb_path", type=str, default=None,
help="Path of the vision pretrained model in the vision language embedding.")
parser.add_argument("--instruction_template", type=str, choices=["sys0", "sys1", "sys2", "sys3", "sys4"],
help="The instruction type for training large language-vision model.", default="sys0")

Expand Down Expand Up @@ -146,7 +144,7 @@ def load_or_initialize_parameters(args, model):
if args.pretrained_model_path:
model = _load_state_dict_into_model(model, args.pretrained_model_path)
if args.vision_model_in_VL_emb_path is not None:
model = _load_state_dict_into_model(model, args.vision_model_in_VL_emb_path, missing_prefix="embedding.vision_language.vision_")
model = _load_state_dict_into_model(model, args.vision_model_in_VL_emb_path)
else:
model = LLaVaGenerate(args)
load_or_initialize_parameters(args, model)
Expand Down
22 changes: 4 additions & 18 deletions tencentpretrain/model_loader.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,24 @@
import os
import torch
import collections
from tencentpretrain import mpu


def load_model(model, model_path, lora_pretrained_model_path=None, missing_prefix=""):
def load_model(model, model_path, lora_pretrained_model_path=None):
"""
Load model from saved weights.
"""
state_dict = torch.load(model_path, map_location="cpu")
if missing_prefix != "":
state_dict_withprefix = collections.OrderedDict()
for k in state_dict.keys():
state_dict_withprefix[missing_prefix + k] = state_dict[k]
del state_dict
state_dict = state_dict_withprefix
if hasattr(model, "module"):
model.module.load_state_dict(state_dict, strict=False)
model.module.load_state_dict(torch.load(model_path, map_location="cpu"), strict=False)
if lora_pretrained_model_path is not None:
model.module.load_state_dict(torch.load(lora_pretrained_model_path, map_location="cpu"), strict=False)
else:
model.load_state_dict(state_dict, strict=False)
model.load_state_dict(torch.load(model_path, map_location="cpu"), strict=False)
if lora_pretrained_model_path is not None:
model.load_state_dict(torch.load(lora_pretrained_model_path, map_location="cpu"), strict=False)
return model


def _load_state_dict_into_model(model_to_load, model_path, start_prefix="", missing_prefix=""):
def _load_state_dict_into_model(model_to_load, model_path, start_prefix=""):
# Convert old format to new format if needed from a PyTorch state_dict

# copy state_dict so _load_from_state_dict can modify it
Expand Down Expand Up @@ -61,12 +53,6 @@ def load(module, state_dict, prefix=""):
for name, child in module._modules.items():
if child is not None:
load(child, state_dict, prefix + name + ".")
if missing_prefix != "":
state_dict_withprefix = collections.OrderedDict()
for k in state_dict.keys():
state_dict_withprefix[missing_prefix + k] = state_dict[k]
del state_dict
state_dict = state_dict_withprefix

load(model_to_load, state_dict, prefix=start_prefix)
# Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
Expand Down
3 changes: 1 addition & 2 deletions tencentpretrain/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,7 @@ def init_model(args):

if args.vision_model_in_VL_emb_path is not None:
args.logger.info("loading: {}".format(args.vision_model_in_VL_emb_path))
model_for_training = _load_state_dict_into_model(model_for_training, args.vision_model_in_VL_emb_path, missing_prefix=args.vision_model_missing_prefix)
# model_for_training = load_model(model_for_training, args.vision_model_path)
model_for_training = _load_state_dict_into_model(model_for_training, args.vision_model_in_VL_emb_path)
return model_for_training, model_for_dataloader


Expand Down
20 changes: 13 additions & 7 deletions tencentpretrain/utils/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,7 @@ def __init__(self, args, dataset_path, batch_size, global_rank, world_size, loca
self.patch_size = args.patch_size
self.image_height = args.image_height
self.image_width = args.image_width
self.args = args

from torchvision import transforms
from tencentpretrain.utils.misc import ZeroOneNormalize
Expand Down Expand Up @@ -990,7 +991,9 @@ def __iter__(self):
"""
from torchvision.io import read_image
from torchvision.io.image import ImageReadMode
seg_num = (self.image_height // self.patch_size) * (self.image_width // self.patch_size) + 1

seg_image_num = (self.image_height // self.patch_size) * (self.image_width // self.patch_size)
text_seq_length = self.args.seq_length - seg_image_num
while True:
while self._empty():
self._fill_buf()
Expand All @@ -1013,22 +1016,25 @@ def __iter__(self):
ins_seg_nums_src, ins_seg_nums_tgt = ins[1]
ins_src_image, ins_image_pos = ins[2]

src_text.append(ins_src)
tgt.append(ins_tgt)
src_text.append(ins_src[:text_seq_length])
ins_seg_src = [1] * ins_seg_nums_src[0] + [0] * ins_seg_nums_src[1]
ins_seg_tgt = []
seg_text.append(ins_seg_src[:text_seq_length])

ins_tgt_new = [self.vocab.get(PAD_TOKEN)] * seg_image_num + ins_tgt
tgt.append(ins_tgt_new[:self.args.seq_length])
ins_seg_tgt = [0] * seg_image_num
for i, num in enumerate(ins_seg_nums_tgt):
ins_seg_tgt = ins_seg_tgt + [i % 2] * num
seg_text.append(ins_seg_src)
seg_tgt.append(ins_seg_tgt)
seg_tgt.append(ins_seg_tgt[:self.args.seq_length])

try:
image = read_image(ins_src_image, ImageReadMode.RGB)
except:
print("Something is wrong when reading {}, just skipped!".format(ins_src_image))
continue
image = image.cuda(self.local_rank)
src_image.append(self.transform(image))
seg_image.append([1] * seg_num)
seg_image.append([1] * (seg_image_num + 1))
image_pos.append(ins_image_pos)
if len(src_image) == 0:
continue
Expand Down
16 changes: 7 additions & 9 deletions tencentpretrain/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1089,8 +1089,6 @@ def worker(self, proc_id, start, end):
class LlavaDataset(Dataset):
def worker(self, proc_id, start, end):
import json
num_image_tokens = self.args.vision_seq_length_in_VL # 576
seq_text = self.seq_length - num_image_tokens
PAD_ID = self.tokenizer.convert_tokens_to_ids([PAD_TOKEN])[0]
role1, role2 = "USER", "ASSISTANT"
im_start, im_end = "<Image>", "</Image>"
Expand All @@ -1100,9 +1098,9 @@ def worker(self, proc_id, start, end):
pos = start
skip_item = 0
with open(self.corpus_path, mode="r", encoding="utf-8") as f:
datas = json.load(f)
data = json.load(f)
while True:
item = datas[pos]
item = data[pos]
pos += 1
try:
path = item["image"]
Expand All @@ -1128,11 +1126,11 @@ def worker(self, proc_id, start, end):
prompt_after_image_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(prompt_after_image))
seg_before_image = [1] * len(prompt_before_image_id)
seg_after_image = [1] * len(prompt_after_image_id)
if len(prompt_before_image_id) + len(prompt_after_image_id) > seq_text:
if len(prompt_before_image_id) + len(prompt_after_image_id) > self.seq_length:
print("promt too long, jumped")
continue
prompt_answer_id = prompt_before_image_id + prompt_after_image_id
tgt_id = [PAD_ID] * (len(prompt_answer_id) + num_image_tokens - 1)
tgt_id = [PAD_ID] * (len(prompt_answer_id) - 1)
tgt_seg_nums = [len(tgt_id)]
elif i % 2 == 0: # human
prompt = conv["value"]
Expand All @@ -1159,10 +1157,10 @@ def worker(self, proc_id, start, end):
pad_num = self.seq_length - sum(tgt_seg_nums)
tgt_seg_nums = tgt_seg_nums + [pad_num]

if len(prompt_answer_id) > seq_text :
prompt_answer_id = prompt_answer_id[:seq_text]
if len(prompt_answer_id) > self.seq_length :
prompt_answer_id = prompt_answer_id[:self.seq_length]

pad_num = seq_text - len(prompt_answer_id)
pad_num = self.seq_length - len(prompt_answer_id)
prompt_answer_seg_nums = [len(prompt_answer_id), pad_num]
prompt_answer_id = prompt_answer_id + [PAD_ID] * pad_num

Expand Down

0 comments on commit 845fef8

Please sign in to comment.