pg-evaluate-model
#!/usr/bin/env python
import logging
import os.path as op
import sys
import time
from collections import Counter
from pathlib import Path

import click

from proc_gen import Problem, TASK_TO_PROBLEMS
from proc_gen.evaluate import get_scores, scores_to_latex
from proc_gen.utils import get_ckpt_dir, replace_in_path

logger = logging.getLogger("evaluate")
logger.addHandler(logging.StreamHandler(sys.stdout))
logger.propagate = False
logger.setLevel(logging.INFO)

def _get_data(log_path_or_paths, score_reference=False):
    # Copyright (c) Facebook, Inc. and its affiliates.
    # The code in this function is licensed under the MIT license.
    """Parse generation log(s) containing S-/T-/D-/H-/TT- prefixed lines into
    sources, references, reference token ids, and per-model hypotheses."""
    if isinstance(log_path_or_paths, str):
        log_path_or_paths = [log_path_or_paths]
    ids, src, ref, ref_toks, hypo, hypo_toks = None, None, None, None, {}, {}
names = Counter()
for k, log_path in enumerate(log_path_or_paths):
assert op.isfile(log_path)
cur_src, cur_ref, cur_ref_toks, cur_hypo, cur_hypo_toks = {}, {}, {}, {}, {}
with open(log_path) as f:
            for l in f:
                line = l.strip()
                if line.startswith("D-"):
                    # D-<id>: detokenized hypothesis text (middle field ignored)
                    if not score_reference:
                        _id, _, sent = line.split("\t", 2)
                        cur_hypo[_id[2:]] = sent
                elif line.startswith("T-"):
                    # T-<id>: target / reference text
                    _id, sent = line.split("\t", 1)
                    cur_ref[_id[2:]] = sent
                    if score_reference:
                        cur_hypo[_id[2:]] = sent
                elif line.startswith("S-"):
                    # S-<id>: source text
                    _id, sent = line.split("\t", 1)
                    cur_src[_id[2:]] = sent
                elif line.startswith("TT-"):
                    # TT-<id>: reference token ids
                    _id, sent = line.split("\t", 1)
                    cur_ref_toks[_id[3:]] = [int(i) for i in sent.split()]
                elif line.startswith("H-"):
                    # H-<id>: hypothesis token ids (middle field ignored)
                    _id, _, sent = line.split("\t", 2)
                    cur_hypo_toks[_id[2:]] = [int(i) for i in sent.split()]
cur_ids = sorted(cur_src.keys())
assert (
set(cur_ids) == set(cur_ref.keys()) == set(cur_hypo.keys())
) # == set(cur_ref_toks.keys())
cur_src = [cur_src[i] for i in cur_ids]
cur_ref = [cur_ref[i] for i in cur_ids]
# cur_ref_toks = [cur_ref_toks[i] for i in cur_ids]
if k == 0:
ids, src, ref, ref_toks = cur_ids, cur_src, cur_ref, cur_ref_toks
else:
assert set(ids) == set(cur_ids) and set(src) == set(cur_src)
assert set(ref) == set(cur_ref)
name = op.splitext(op.basename(log_path))[0]
names.update([name])
if names[name] > 1:
name += f".{names[name]}"
hypo[name] = [cur_hypo[i] for i in cur_ids]
hypo_toks[name] = [cur_hypo_toks[i] for i in cur_ids]
return {"0": src}, {"0": ref}, {"0": ref_toks}, hypo, hypo_toks
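

# A minimal sketch of the per-line format `_get_data` expects, inferred from the
# parsing above; fields are tab-separated and the concrete values are illustrative:
#   S-0 <tab> <source sentence>
#   T-0 <tab> <reference sentence>
#   D-0 <tab> <ignored field> <tab> <detokenized hypothesis text>
#   H-0 <tab> <ignored field> <tab> <hypothesis token ids>
#   TT-0 <tab> <reference token ids>
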
@click.command()
@click.option(
"--data_dir",
default="/data/procgen/v1/processed",
help="Base dir for saving the processed train/val/test files.",
)
@click.option(
"--dataset",
type=click.Choice(["Recipe1M", "dummy"]),
    help="Which dataset the processed files belong to.",
)
@click.option(
"--problem", type=click.Choice(Problem.__members__.keys()),
)
@click.option(
"--model_type",
type=click.Choice(["onmt", "huggingface", "fairseq"]),
    help="Which modeling library was used.",
)
@click.option(
"--model_arch",
type=click.Choice(["lstm", "conv", "transformer", "bart", "gpt2"]),
help="Which model architecture to evaluate.",
)
@click.option("--version", type=int, default=0, help="Model version; currently unused in this script.")
def evaluate(data_dir, dataset, problem, model_type, model_arch, version):
    # Every valid --problem choice is a Problem member name, so look the enum
    # member up directly rather than enumerating each case.
    problem = Problem[problem]
data_dir = Path(data_dir) / problem.name / dataset / model_type
if model_type == "fairseq":
def score_predictions():
results_dir: Path = replace_in_path(data_dir, "data", "results")
log_path = (
results_dir
/ model_arch
/ f"{model_arch}-on-test-0"
/ "generate-test.txt"
)
logger.info(f"Loading data from {log_path}")
sources = []
references = []
if problem in TASK_TO_PROBLEMS["language_modeling"]:
src_lang = problem.name
else:
src_lang, _ = problem.name.replace("_", "").split("TO")
with open(data_dir / f"test.{src_lang}", "r") as f:
lines = list(map(lambda l: l.rstrip(), f.readlines()))
for rec in lines:
splits = rec.split(" <rts> ")
sources.append(splits[0])
references.append(splits[1])
sources = {"0": sources}
references = {"0": references}
with open(str(log_path), "r") as f:
model_to_hypotheses = {
model_arch: list(
map(lambda l: l.rstrip().split(" <rts> ")[1], f.readlines())
)
}
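            # Line formats assumed by the two splits above (inferred from the code,
            # not from repo documentation): each test.<src_lang> line looks like
            # "<source> <rts> <reference>", and each generate-test.txt line looks
            # like "<prefix> <rts> <generated hypothesis>".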
# Task- or product-level
task_or_product_metrics = [
"token_acc",
"gleu",
"chrf",
"wer",
"bleu",
"rouge_1",
"meteor",
"bert_score",
]
# Task set-level
if problem in (
Problem.TargetProductAndRequirements_TO_Tasks,
Problem.Requirements_TO_TargetProductAndTasks,
Problem.TargetProductAndRequirementsAndTasks,
Problem.RequirementsAndTargetProductAndTasks,
):
task_set_metrics = [
"kendall_task_ranking", # task order
"req_cov", # requirement coverage
"essential_req_cov", # essential requirement coverage
]
else:
task_set_metrics = []
# Note: corpus score = mean(sentence scores)
            logger.info("Computing scores...")
corpus_scores, group_scores = get_scores(
sources,
references,
model_to_hypotheses,
metrics=task_or_product_metrics + task_set_metrics,
verbose=True,
problem=problem.name,
)
logger.info(
f"LaTeX table with corpus scores: {scores_to_latex(corpus_scores)}"
)
else:
raise NotImplementedError(f"TODO: Implement results for {model_type}")
start = time.time()
score_predictions()
    logger.info(f"Time elapsed: {time.time() - start:.2f} s")


if __name__ == "__main__":
evaluate()
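

# Example invocation (values are illustrative and drawn from the declared
# choices; only --model_type fairseq is implemented in this script):
#
#   python pg-evaluate-model \
#       --data_dir /data/procgen/v1/processed \
#       --dataset Recipe1M \
#       --problem TargetProductAndRequirements_TO_Tasks \
#       --model_type fairseq \
#       --model_arch transformer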