naturalbench_retrieval.py
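"""Evaluate image-text retrieval on NaturalBench-Retrieval with a t2v_metrics score model.

Each benchmark sample pairs two images with two captions. The score model rates every
(image, caption) combination, and evaluate_scores reduces those scores to text, image,
and group accuracies (the Winoground-style retrieval metrics).
"""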
import argparse
import json
import os

import t2v_metrics
from PIL import Image
from torch.utils.data import Dataset


def get_retrieval_scores(scores_i2t):
    """Repackage an (N, 2, 2) image-to-text score array into per-sample dicts keyed as c<caption>_i<image>."""
    ids = list(range(scores_i2t.shape[0]))
    retrieval_scores = []
    for id, score_i2t in zip(ids, scores_i2t):
        retrieval_scores.append({
            "id": id,
            "c0_i0": score_i2t[0][0],
            "c0_i1": score_i2t[1][0],
            "c1_i0": score_i2t[0][1],
            "c1_i1": score_i2t[1][1],
        })
    return retrieval_scores


def get_retrieval_acc(scores):
    """Compute text, image, and group retrieval accuracies from the per-sample score dicts."""
    text_correct_count = 0
    image_correct_count = 0
    group_correct_count = 0

    def text_correct(result):
        return result["c0_i0"] > result["c1_i0"] and result["c1_i1"] > result["c0_i1"]

    def image_correct(result):
        return result["c0_i0"] > result["c0_i1"] and result["c1_i1"] > result["c1_i0"]

    def group_correct(result):
        return image_correct(result) and text_correct(result)

    for result in scores:
        text_correct_count += 1 if text_correct(result) else 0
        image_correct_count += 1 if image_correct(result) else 0
        group_correct_count += 1 if group_correct(result) else 0

    denominator = len(scores)
    result = {
        'text': text_correct_count / denominator,
        'image': image_correct_count / denominator,
        'group': group_correct_count / denominator,
    }
    return result
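
# Minimal sanity check (illustrative, not part of the original script): with a single
# 2x2 score matrix in which each caption scores highest on its own image, all three
# accuracies come out as 1.0.
#
#   import numpy as np
#   toy = np.array([[[0.9, 0.1],    # image_0 scored against caption_0, caption_1
#                    [0.2, 0.8]]])  # image_1 scored against caption_0, caption_1
#   get_retrieval_acc(get_retrieval_scores(toy))
#   # -> {'text': 1.0, 'image': 1.0, 'group': 1.0}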


class NaturalBench_Retrieval(Dataset):
    def __init__(self,
                 root_dir='./datasets',
                 download=True,
                 image_preprocess=None,
                 return_image_paths=True):
        self.root_dir = root_dir
        self.dataset_dir = os.path.join(root_dir, "NaturalBench-Retrieval")
        self.image_dir = os.path.join(self.dataset_dir, "images")
        self.metadata_path = os.path.join(self.dataset_dir, 'metadata.json')
        self.download_links = "https://huggingface.co/datasets/BaiqiL/NaturalBench/resolve/main/NaturalBench-Retrieval.zip"
        if not os.path.exists(self.dataset_dir):
            if download:
                import subprocess
                model_file_name = "NaturalBench-Retrieval.zip"
                image_zip_file = os.path.join(self.root_dir, model_file_name)
                if not os.path.exists(image_zip_file):
                    subprocess.call(
                        ["wget", self.download_links, "-O", model_file_name], cwd=self.root_dir
                    )
                subprocess.call(["unzip", "-q", model_file_name], cwd=self.root_dir)

        with open(self.metadata_path, 'r', encoding='utf-8') as file:
            self.metadata = json.load(file)

        self.return_image_paths = return_image_paths
        if return_image_paths:
            assert image_preprocess is None
            self.preprocess = None
        else:
            self.preprocess = image_preprocess

    def image_loader(self, path):
        # The original file uses self.image_loader without defining it; a plain
        # PIL RGB loader is assumed here so the non-path branch below is runnable.
        return Image.open(path).convert("RGB")

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        assert self.metadata[idx]['index'] == idx
        image_0_path = os.path.join(self.image_dir, self.metadata[idx]['image_0'])
        image_1_path = os.path.join(self.image_dir, self.metadata[idx]['image_1'])
        if self.return_image_paths:
            image_0 = image_0_path
            image_1 = image_1_path
        else:
            image_0 = self.preprocess(self.image_loader(image_0_path))
            image_1 = self.preprocess(self.image_loader(image_1_path))
        caption_0 = self.metadata[idx]['caption_0']
        caption_1 = self.metadata[idx]['caption_1']
        item = {
            "images": [image_0, image_1],
            "texts": [caption_0, caption_1]
        }
        return item

    def evaluate_scores(self, scores):
        retrieval_scores = get_retrieval_scores(scores)
        acc = get_retrieval_acc(retrieval_scores)
        print("NaturalBench-Retrieval performance (overall)")
        print(f"{'Dataset': <70} {'Text': <10} {'Image': <10} {'Group': <10}")
        print(f"{'NaturalBench-Retrieval': <70} {acc['text']: <10.2%} {acc['image']: <10.2%} {acc['group']: <10.2%}")
        results = {}
        results['all'] = acc
        return results


def config():
    parser = argparse.ArgumentParser()
    parser.add_argument("--root_dir", default="./datasets", type=str,
                        help='Root directory for saving datasets.')
    parser.add_argument("--cache_dir", default=t2v_metrics.constants.HF_CACHE_DIR, type=str)
    parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument("--batch_size", default=16, type=int)
    parser.add_argument("--model", default="openai:ViT-L-14", type=str)  # VQAScore: "clip-flant5-xxl"
    parser.add_argument("--question", default=None, type=str)
    parser.add_argument("--answer", default=None, type=str)
    return parser.parse_args()


def main():
    args = config()
    if not os.path.exists(args.root_dir):
        os.makedirs(args.root_dir)

    score_func = t2v_metrics.get_score_model(model=args.model, device=args.device, cache_dir=args.cache_dir)

    kwargs = {}
    if args.question is not None:
        print(f"Using question template: {args.question}")
        kwargs['question_template'] = args.question
    if args.answer is not None:
        print(f"Using answer template: {args.answer}")
        kwargs['answer_template'] = args.answer

    print(f"Performance of {args.model}.")
    dataset = NaturalBench_Retrieval(root_dir=args.root_dir)
    scores = score_func.batch_forward(dataset, batch_size=args.batch_size, **kwargs).cpu()
    dataset.evaluate_scores(scores)


if __name__ == "__main__":
    main()
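
# Example invocation (a sketch; assumes t2v_metrics is installed, the dataset and model
# weights can be downloaded, and a CUDA device is available):
#
#   python naturalbench_retrieval.py --root_dir ./datasets --model openai:ViT-L-14
#   python naturalbench_retrieval.py --model clip-flant5-xxl   # the VQAScore model noted in config()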