forked from castorini/castor
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexperimental_settings.py
160 lines (132 loc) · 5.61 KB
/
experimental_settings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
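# Experimental conditions that vary across runs:
# idf_source, stopwords_and_stemming, punctuation, words-with-hyphens
# (the __main__ block of this script configures idf_source, punctuation and dash_words)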
import argparse
import itertools
import shlex
import subprocess
class Setting(object):
    """One experimental dimension: a label plus its named choices.

    ``choice_flags`` maps each choice name to the command-line flag text that
    selects it (an empty string when the choice needs no flag).
    """

    def __init__(self, label, value_flag_map):
        """Store the setting's label and its choice-name -> flag mapping."""
        self.choice_flags = value_flag_map
        self.label = label

    def get_settings(self):
        """Return the choice names of this setting (the mapping's keys view)."""
        flag_map = self.choice_flags
        return flag_map.keys()

    def get_choice(self, setting):
        """Return the flag text stored for the choice named ``setting``.

        Raises:
            KeyError: if ``setting`` is not one of the known choice names.
        """
        flag_map = self.choice_flags
        return flag_map[setting]

    def get_options(self):
        """Return every choice as a ``"<label>:<choice>"`` string, in mapping order."""
        prefix = self.label
        return ["{}:{}".format(prefix, name) for name in self.choice_flags]
class Experiments(object):
    """Enumerates combinations of experimental settings and runs them.

    Every registered setting contributes one choice to a combination. Running
    a combination builds a ``python main.py ...`` command line from the chosen
    flags, executes it, writes its output to ``<model_name>.log`` and then
    scores the ``raw-dev`` and ``raw-test`` run files with ``trec_eval`` and
    ``rbp_eval``.
    """

    def __init__(self, qa_dataset, word_embeddings_file):
        """Remember the dataset / embeddings paths and build the command roots.

        Args:
            qa_dataset: path of the QA dataset folder; it is passed to
                ``--dataset_folder`` and is also where the ``.qrel`` files
                are looked up.
            word_embeddings_file: path passed to ``--word_vectors_file``.
        """
        self.settings = {}
        self.combinations = []
        self.qa_data = qa_dataset
        self.w2v_file = word_embeddings_file
        # Commands are tokenised with shlex.split() before execution, so the
        # paths are quoted to survive spaces and shell metacharacters.
        # An earlier variant of this command used
        # `--epochs 1 --num_conv_filters 5` instead of `--paper-ext-feats`.
        self.cmd_root = (
            "python main.py --dataset_folder {} --word_vectors_file {} "
            "--run-name-prefix run --paper-ext-feats"
        ).format(shlex.quote(str(self.qa_data)), shlex.quote(str(self.w2v_file)))
        self.eval_cmd_root = "../../Anserini/eval/trec_eval.9.0/trec_eval -m map -m recip_rank -m bpref"
        self.rbp_cmd_root = "rbp_eval"

    def add_setting(self, setting):
        """Register ``setting`` under its label and rebuild all combinations.

        A setting added under an already-used label replaces the earlier one.
        """
        self.settings[setting.label] = setting
        self._setup_combinations()

    def _setup_combinations(self):
        """Recompute ``self.combinations`` as the product of all settings' options."""
        option_lists = [setting.get_options() for setting in self.settings.values()]
        self.combinations = list(itertools.product(*option_lists))

    def _run_cmd(self, cmd):
        """Run ``cmd`` (a command-line string) without a shell.

        Returns:
            A ``(stdout, stderr)`` tuple of text captured from the process.
        """
        pargs = shlex.split(cmd)
        p = subprocess.Popen(pargs, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                             bufsize=1, universal_newlines=True)
        pout, perr = p.communicate()
        return pout, perr

    def _run_eval(self):
        """Score the dev and test run files and print one metric/score table each.

        For each split the metric names are printed on one tab-separated line
        and the matching scores on the next.
        """
        for split in ['raw-dev', 'raw-test']:
            # Both evaluators take the same "<qrel file> <run file>" arguments.
            qrel_path = shlex.quote('{}/{}.qrel'.format(self.qa_data, split))
            cmd = '{} {} run.{}.smrun'.format(self.eval_cmd_root, qrel_path, split)
            out, err = self._run_cmd(cmd)
            print(split, '------')

            # trec_eval scores: first field is the metric, last is the score
            metrics = []
            scores = []
            for line in str(out).split('\n'):
                fields = line.split()
                if not fields:
                    continue
                metrics.append(fields[0])
                scores.append(fields[-1])

            # rbp_eval scores: only the line reported for p = 0.50 is kept,
            # and its last two fields are recorded as the score
            cmd = '{} {} run.{}.smrun'.format(self.rbp_cmd_root, qrel_path, split)
            out, err = self._run_cmd(cmd)
            for line in str(out).split('\n'):
                if not line.startswith('p= 0.50'):
                    continue
                metrics.append('rbp_p0.5')
                scores.append(' '.join(line.strip().split()[-2:]))

            print('\t'.join(metrics))
            print('\t'.join(scores))

    def run(self, indices):
        """Run the setting combinations at the given positions.

        Args:
            indices: iterable of integer positions into ``self.combinations``
                (as printed by ``list_settings``).

        Raises:
            IndexError: if any index is out of range. All indices are checked
                before the first experiment starts, so a bad index does not
                surface only after earlier experiments have already run.
        """
        indices = list(indices)
        total = len(self.combinations)
        for ci in indices:
            if not -total <= ci < total:
                raise IndexError(
                    'combination index {} out of range ({} combinations)'.format(ci, total))

        for ci in indices:
            combo = self.combinations[ci]
            print(combo)

            cmd_args = []
            # The model name encodes every "<setting>-<choice>" of the combination.
            name_parts = ['sm_cnn']
            for setting_choice in combo:
                # Split on the first colon only, so choice names may contain ':'.
                setting, choice = setting_choice.split(':', 1)
                name_parts.append('{}-{}'.format(setting, choice))
                cmd_args.append(self.settings[setting].choice_flags[choice])
            name_parts.append('model')
            model_name = '.'.join(name_parts)

            cmd = '{} {} {}'.format(self.cmd_root, ' '.join(cmd_args), shlex.quote(model_name))
            print(cmd)
            out, err = self._run_cmd(cmd)

            with open(model_name + '.log', 'w') as lf:
                print('---------- OUT ------------', file=lf)
                print(out, file=lf)
                print('---------- ERR ------------', file=lf)
                print(err, file=lf)

            self._run_eval()

    def run_all(self):
        """Run every available combination, in the order ``list_settings`` shows."""
        self.run(range(len(self.combinations)))

    def list_settings(self):
        """Print each ``(index, combination)`` pair followed by a usage hint."""
        for c in enumerate(self.combinations):
            print(c)
        print("--run X Y Z to run combinations number X Y Z")
# Command-line entry point: registers the experimental settings, prints every
# combination and, when --run is given, runs the selected combinations.
if __name__ == "__main__":
    ap = argparse.ArgumentParser(description="Lists exerimental settings and runs experiments")
    # NOTE(review): --list is parsed but never consulted; the listing further
    # down is printed unconditionally.
    ap.add_argument("--list", help="lists all available experimental settings combinations",
                    action="store_true")
    ap.add_argument("--run", help='runs experimenal setting combination NUMBER(s)',
                    nargs="+", type=int)
    ap.add_argument("indexPath", help="required for some combination of experiments")
    ap.add_argument('qa_data', help="path to the QA dataset",
                    choices=['../../data/TrecQA', '../../data/WikiQA'])
    ap.add_argument('word_embeddings_file', help="the word embeddings file")

    args = ap.parse_args()

    experiments = Experiments(args.qa_data, args.word_embeddings_file)

    # Each setting maps a choice name to the extra main.py flag it contributes;
    # an empty string means the choice adds no flag.
    experiments.add_setting(Setting('idf_source', {
        'qa-data': '',
        # The flag text is later tokenised with shlex.split(), so the index
        # path is quoted to stay a single argument even if it contains spaces.
        'corpus-index': '--index-for-corpusIDF {}'.format(shlex.quote(args.indexPath))
    }))

    experiments.add_setting(Setting('punctuation', {
        'keep': '',
        'remove': '--stop-punct'
    }))

    experiments.add_setting(Setting('dash_words', {
        'keep': '',
        'split': '--dash-split'
    }))

    experiments.list_settings()

    if args.run:
        experiments.run(args.run)