load_imagenet_vqa_dataset_followup.py
"""
Example script on how to load this dataset without depending on the entire framework.
For followup question we need existing output to ask a followup question about.
The model output for this dataset can be downloaded as described in the readme.
Also note that correctly answered questions will not be asked again, so the dataset becomes smaller.
"""
from collections import Counter
from copy import deepcopy
from pprint import pprint

from packg.iotools import load_yaml
from ovqa.paths import get_data_dir
from torch.utils.data import DataLoader
from ovqa.datasets.classifier_vqa_dataset import ClassifierVQADataset
from ovqa.datasets.imagenet_hierarchy import load_hierarchy
from ovqa.followup import Followup
from ovqa.processors import BlipImageEvalProcessor
from ovqa.result_loader import read_single_result


def text_processor_noop(x):
    return x


def main():
    # ----- load dataset as before
    data_dir = get_data_dir()
    imagenet_dir = data_dir / "imagenet1k"
    vis_root = imagenet_dir
    ann_paths = [
        "ovqa/annotations/imagenet1k/generated/val.json",
        "ovqa/annotations/imagenet1k/generated/classes_data.json",
    ]
    vis_processor = None  # None will give a pillow image back

    # select which question the model will be asked
    question_type = "what-seen-image"  # "what-is-in-image", "whats-this"

    # whether to use cropped images for imagenet or not
    cropped_images_dir = "square"  # "" or "square"

    # see ovqa/configs/datasets/imagenet1k.yaml
    config = {
        "question_type": question_type,
        "class_name_key": "clip_bench_label",
        "cropped_images_dir": cropped_images_dir,
    }
    dataset = ClassifierVQADataset(
        vis_processor=vis_processor,
        text_processor=text_processor_noop,
        vis_root=vis_root,
        ann_paths=ann_paths,
        config=config,
    )
    print(f"Original dataset length: {len(dataset)}")

    # load followup info
    followup_cfg = load_yaml("ovqa/configs/followup/followup_imagenet.yaml")["run"]["followup_cfg"]
    pprint(followup_cfg)
    default_followup_object = "object"
    hier = load_hierarchy("imagenet1k")
    synonym_dict = None
    targets = {v["key"]: v["class_idx"] for v in dataset.annotation}
    follower = Followup(followup_cfg, hier, dataset.classnames, synonym_dict, targets)

    # load previous model output
    followup_prev_dir = "output/imagenet1k-square~val/blip1~ftvqa~default~none~what-seen-image"
    result_obj = read_single_result(followup_prev_dir)
    assert result_obj is not None, f"Failed to read output from: {followup_prev_dir}"
    preds = result_obj.load_output()
    if next(iter(targets.keys())) not in preds:
        # fix prediction keys from '0' to 'val_00000001' etc.
        new_preds = {}
        for i, v in enumerate(dataset.annotation):
            key = v["key"]
            pred = preds[str(i)]
            new_preds[key] = pred
        preds = new_preds

    # run followup pipeline
    to_followup = follower.evaluate_pipeline(preds)
    # to_followup now looks like
    # {'val_00000003': {'status': 'followup', 'object': 'dog'}, ...}
    # where status is "correct", "failed" or "followup", and in case of followup "object" is set.
    counter_followup = Counter(v["status"] for v in to_followup.values())
    print(str(dict(counter_followup)))
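    # the print above shows how many samples fall into each status, something like
    # {'correct': ..., 'followup': ..., 'failed': ...} (the exact counts depend on the model output)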

    # update dataset and config based on the followup questions to ask
    new_anns = []
    for ann in dataset.annotation:
        ann_followup = to_followup[ann["key"]]
        if ann_followup["status"] == "correct":
            continue
        # define the followup question
        if ann_followup["status"] == "followup":
            ask_object = ann_followup[default_followup_object]
        elif ann_followup["status"] == "failed":
            ask_object = default_followup_object
        else:
            raise ValueError(f"Unknown status: {ann_followup['status']}")
        new_ann = deepcopy(ann)
        # note this is used in ClassifierVQADataset.get_item
        new_ann["question_followup"] = ask_object
        new_anns.append(new_ann)
    dataset.annotation = new_anns
    print(f"Updated dataset, new length: {len(dataset.annotation)}")

    # ----- look at the final dataset
    # note that to get the final followup question, the text_input from the dataset must be
    # formatted with the correct prompt. the correct prompt depends on the model (see model configs)
    followup_prompt = "What type of {} is this?"
    datapoint = dataset[0]
    pprint(datapoint)
    followup_question = followup_prompt.format(datapoint["text_input"])
    print(f"Actual text_input: {followup_question}")
    print()

    # in order to use a dataloader, we need to transform the images to tensors, so we can stack them
    dataset.vis_processor = BlipImageEvalProcessor(
        image_size=224, mean=(0.5, 0.5, 0.5), std=(0.25, 0.25, 0.25)
    )
    dataloader = DataLoader(
        dataset,
        shuffle=False,
        num_workers=0,
        batch_size=2,
        collate_fn=dataset.collater,
    )
    for i, batch in enumerate(dataloader):
        image_tensor = batch.pop("image")
        print("image:", image_tensor.shape, image_tensor.dtype, image_tensor.device)
        pprint(batch)
        followup_questions = [followup_prompt.format(t) for t in batch["text_input"]]
        print(f"Followup questions: {followup_questions}")
        print()
        break
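
# Note: the relative paths above ("ovqa/annotations/...", "output/...") suggest this script is meant
# to be run from the repository root, with the imagenet1k images and the downloaded model output
# (see readme) in place.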

if __name__ == "__main__":
    main()