run_experiments.py
import numpy as np
from tqdm import tqdm, trange
import json
import argparse
import itertools
from SimilarityVLM import SimilarityVLM
from dataset import DatasetHandler
from FewShotTestHandler import FewShotTestHandler
from classifier import WeightedTextFewShotClassifier # TODO: Support multiple variations of classifiers
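
# Example invocation (the parameter json path below is a hypothetical placeholder):
#   python run_experiments.py CLIP experiment_params.json
# Assumes the script is launched inside the conda environment matching the chosen VLM.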

def get_vlm(vlm_name: str) -> SimilarityVLM:
    """Instantiate the SimilarityVLM wrapper corresponding to the given name."""
    if vlm_name == "VT-TWINS":
        from VTTWINS.wrapper import VTTWINS_SimilarityVLM
        return VTTWINS_SimilarityVLM(reset_cache=False)
    if vlm_name == "CLIP":
        from CLIP.CLIPVLM import ClipVLM
        return ClipVLM(reset_cache=False)
    if vlm_name == "UniVL":
        from UNIVL.wrapper import UniVL_SimilarityVLM
        return UniVL_SimilarityVLM(reset_cache=False)
    if vlm_name == "MILES":
        from MILES.wrapper import MILES_SimilarityVLM
        return MILES_SimilarityVLM(reset_cache=False)
    if vlm_name == "VideoCLIP":
        from video_clip.video_clip import VideoClipVLM
        return VideoClipVLM(reset_cache=False)
    raise ValueError(f"Unrecognized VLM name: {vlm_name}")

def get_param_iterator(param_json_file: str) -> list:
    """Load the experiment parameter json file and expand it into a list of all valid parameter combinations."""
    PARAM_KEYS = ["dataset_name", "dataset_split", "n_way", "n_support", "n_query", "n_episodes", "text_weight"]

    with open(param_json_file, "r") as fp:
        params = json.load(fp)

    # Check json validity
    missing_keys = [key for key in PARAM_KEYS if key not in params.keys()]
    if len(missing_keys):
        raise ValueError(f"Param json file missing required keys: {missing_keys}")

    # Convert all given param values to lists (even if they only have a single value)
    for key in params.keys():
        if type(params[key]) is not list:
            params[key] = [params[key]]

    # Create an experiment param iterator which samples all valid combinations of params
    experiment_param_iter = [
        {PARAM_KEYS[i]: param_values_tuple[i] for i in range(len(PARAM_KEYS))}
        for param_values_tuple in itertools.product(*[params[key] for key in PARAM_KEYS])
    ]

    # Filter out parameter instances which have n_support = 0 but text_weight != 1 (text_weight does nothing in the zero-shot case)
    experiment_param_iter = list(filter(lambda exp_params: not (exp_params["n_support"] == 0 and exp_params["text_weight"] != 1), experiment_param_iter))

    return experiment_param_iter
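
# For reference, a minimal example parameter file. The dataset name and split shown here are
# hypothetical placeholders; use whatever values DatasetHandler accepts in this repo.
#
# {
#     "dataset_name": "kinetics_100",
#     "dataset_split": "val",
#     "n_way": 5,
#     "n_support": [0, 1, 5],
#     "n_query": 15,
#     "n_episodes": 1000,
#     "text_weight": [0.5, 1.0]
# }
#
# With the list values above, get_param_iterator expands this into 3 * 2 = 6 experiments,
# minus any zero-shot combinations (n_support = 0) whose text_weight is not 1.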

if __name__ == "__main__":
    argparser = argparse.ArgumentParser(description="Runs repeated few-shot tests on the given VLM with all combinations of the parameters specified in the given json file.")
    argparser.add_argument("vlm_name", type=str, help="VLM name to use for experiments. Assumes the script is run in the corresponding conda environment.")
    argparser.add_argument("parameters", type=str, help="Path to a json file specifying parameter values (singular or lists) to run the experiments on.")
    args = argparser.parse_args()

    print(f"\n\n\n----- {args.vlm_name} -----")

    vlm = get_vlm(args.vlm_name)
    param_iter = get_param_iterator(args.parameters)

    # During testing, save the most recently loaded dataset for reuse.
    # Assumes dataset params are in the outermost loop of the product / first in the PARAM_KEYS list.
    query_dataset = None
    train_dataset = None

    # Save which datasets have been explicitly cached (repeatedly running fill_cache is fast but still wasted time)
    datasets_in_cache = set()

    test_handler = FewShotTestHandler()

    pbar = tqdm(param_iter)
    for exp_params in pbar:
        pbar.set_postfix(exp_params)

        # Load dataset
        if query_dataset is None or not (query_dataset.name == exp_params["dataset_name"] and query_dataset.split == exp_params["dataset_split"]):
            query_dataset = DatasetHandler(exp_params["dataset_name"], exp_params["dataset_split"])
            train_dataset = DatasetHandler(exp_params["dataset_name"], "train")

        # Fill vlm cache
        if query_dataset.id() not in datasets_in_cache:
            query_dataset.fill_cache(vlm)
            datasets_in_cache.add(query_dataset.id())
        if train_dataset.id() not in datasets_in_cache:
            train_dataset.fill_cache(vlm)
            datasets_in_cache.add(train_dataset.id())

        # Construct classifier around vlm
        classifier = WeightedTextFewShotClassifier(vlm, text_weight=exp_params["text_weight"])

        # Run experiment
        test_handler.run_few_shot_test(classifier, query_dataset, train_dataset, exp_params["n_way"], exp_params["n_support"],
                                       exp_params["n_query"], exp_params["n_episodes"])