-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathframetracking-evaluate
executable file
·105 lines (79 loc) · 3.07 KB
/
frametracking-evaluate
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
import json
import click
from collections import defaultdict
import tabulate
import pickle
from frames.utils import get_users_for_fold, cmp_turns, make_table
# neobunch is an optional dependency: when it is not installed, expose
# ``neobunchify`` as None so callers can test for its availability.
try:
    from neobunch import neobunchify
except ImportError:
    # Only a missing package is tolerated here; a bare ``except`` would also
    # swallow KeyboardInterrupt/SystemExit and hide unrelated failures.
    neobunchify = None
# Root click command group. It performs no work itself; the subcommands
# below attach to it through ``@cli.command()``.
# (Deliberately documented with comments rather than a docstring, since click
# would display a docstring as the group's --help text.)
@click.group()
def cli():
    pass
@cli.command()
@click.argument('original_json', type=click.File('r'))
def foldinfo(original_json):
    """
    for every fold, summarize how many dialogues and turns it contains
    """
    # NOTE: the docstring above doubles as the command's --help text (click
    # reads it), so further documentation lives in comments.
    #
    # original_json: open text file holding a JSON list of dialogues; each
    # dialogue is indexed by 'user_id' and 'turns' below.
    gt_dlgs = json.load(original_json)

    # Folds are numbered 1..10 (range upper bound is exclusive).
    for fold in range(1, 11):
        users = get_users_for_fold(fold)
        # Keep only the dialogues whose user belongs to this fold.
        fold_dlgs = [d for d in gt_dlgs if d['user_id'] in users]
        n_turns = sum(len(d['turns']) for d in fold_dlgs)
        print("Fold: %d, d=%d, t=%d [%s]"
              % (fold, len(fold_dlgs), n_turns, ", ".join(users)))
@cli.command()
@click.option("--save-as", "-o", type=click.Path(),
              help="save a json file with wrong instances tagged as 'wrong'")
@click.argument('original_json', type=click.File('r'))
@click.argument('results_json', type=click.File('r'))
@click.argument('fold', type=int)
def eval(original_json, results_json, fold, save_as):
    """
    original_json:
        the original dataset file, possibly augmented with turn tags
    results_json:
        the predictions (for all dialogues in a fold)
    fold:
        the fold the predictions in results_json are for.
        One of 1,2,3,4,5,6,7,8,9,10
    """
    # NOTE: the docstring above is the command's --help text (click reads it).
    #
    # Raises RuntimeError when a ground-truth dialogue of the fold has no
    # prediction, or when the predicted dialogue has a different number of
    # turns than the ground truth.
    gt_dlgs = json.load(original_json)
    # ``pred_dlgs_lst`` keeps the predictions exactly as loaded; that object
    # is what --save-as writes back out at the end.
    pred_dlgs = pred_dlgs_lst = json.load(results_json)
    # Three independent counters filled by cmp_turns and rendered by
    # make_table.
    cnt, cnt_noanno, cnt_comb = [defaultdict(float) for i in range(3)]
    users = get_users_for_fold(fold)

    # results can be given either as a dict mapping from dialogue id to
    # dialogue, or as a list of dialogues (which we transform here into the
    # dict)
    if not isinstance(pred_dlgs, dict):
        pred_dlgs = {d['id']: d
                     for d in pred_dlgs_lst if d['user_id'] in users}

    # Only the ground-truth dialogues of the requested fold are evaluated.
    gt_dlgs = [d for d in gt_dlgs if d['user_id'] in users]

    print("predicted dialogues:", len(pred_dlgs))
    print("groundtruth dialogues:", len(gt_dlgs))

    for gt_dlg in gt_dlgs:
        dlg_id = gt_dlg['id']
        if dlg_id not in pred_dlgs:
            raise RuntimeError(
                "Dialogue '%s' not found in predictions" % dlg_id)
        pred_dlg = pred_dlgs[dlg_id]
        if len(gt_dlg['turns']) != len(pred_dlg['turns']):
            # The dialogue id is interpolated so the message identifies the
            # offending dialogue instead of printing a literal '%s'.
            raise RuntimeError(
                "Dialogue '%s' number of turns differs from ground truth."
                % dlg_id)
        for gt_turn, pred_turn in zip(gt_dlg['turns'], pred_dlg['turns']):
            # Only user turns are scored.
            if gt_turn['author'] != 'user':
                continue
            cmp_turns(cnt, cnt_noanno, cnt_comb, gt_turn, pred_turn)

    tables = make_table(cnt, cnt_noanno, cnt_comb)
    for table in tables:
        print(table)

    if save_as:
        # NOTE(review): presumably cmp_turns tags mismatching predicted turns
        # in place, which is why the loaded predictions are dumped here --
        # confirm against frames.utils.cmp_turns.
        with open(save_as, "w") as f:
            json.dump(pred_dlgs_lst, f)
# Script entry point: hand control to the click command group, which parses
# the command line and dispatches to the selected subcommand.
if __name__ == "__main__":
    cli()