-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathutil.py
48 lines (39 loc) · 1.35 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import numpy as np
def calc_bleu2(hypotheis, refers, eos_id=3):
    """Score generated token-id sequences against references with bigram BLEU.

    Each hypothesis is truncated at its first ``eos_id`` before scoring; the
    references are used as given. Every pair is scored with NLTK's
    ``sentence_bleu`` using ``weights=(0, 1, 0, 0)`` (only the 2-gram
    precision contributes) and ``SmoothingFunction().method1``.

    Args:
        hypotheis: Sized collection of generated token-id sequences.
        refers: Reference token-id sequences, paired with ``hypotheis`` by
            position.
        eos_id: Token id marking end-of-sequence in a hypothesis. Defaults
            to 3, the value previously hard-coded here.

    Returns:
        A tuple ``(mean_bleu, max_bleu, great, great_target)``:
        the sum of the per-pair scores divided by ``len(hypotheis)``, the
        highest single score, and the truncated hypothesis and reference
        (both as lists) that achieved it. On ties the earliest pair wins.
        ``great`` and ``great_target`` are ``None`` when no pair scored
        above 0. An empty ``hypotheis`` yields ``(0.0, 0, None, None)``.
    """
    total = len(hypotheis)
    # Use len() rather than truthiness: the batch may be an array-like whose
    # truth value is ambiguous.
    if total == 0:
        return 0.0, 0, None, None

    smoothie = SmoothingFunction()
    bleu = 0
    max_bleu = 0
    # Track the best pair across the whole batch; these must be initialised
    # once, outside the loop, or an earlier best pair is forgotten.
    great = None
    great_target = None
    for h, r in zip(hypotheis, refers):
        rr = list(r)
        hh = list(h)
        try:
            hh = hh[:hh.index(eos_id)]  # truncate at the first EOS
        except ValueError:
            pass  # no EOS present: score the full hypothesis
        cur_bleu = sentence_bleu(
            [rr],
            hh,
            weights=(0, 1, 0, 0),
            smoothing_function=smoothie.method1,
        )
        bleu += cur_bleu
        if cur_bleu > max_bleu:
            max_bleu = cur_bleu
            great = hh
            great_target = rr
    return bleu / total, max_bleu, great, great_target
def translate(idx, idx2word):
    """Turn a sequence of token ids into a space-separated string.

    Ids equal to 0 (padding) are skipped. Conversion stops at the first id
    whose word is ``'<EOS>'``; that marker is not included in the output.

    Args:
        idx: Iterable of token ids.
        idx2word: Mapping from token id to its word.

    Returns:
        The words before the end marker, joined by single spaces.
    """
    tokens = []
    for token_id in idx:
        if token_id == 0:  # padding carries no word
            continue
        token = idx2word[token_id]
        if token == '<EOS>':
            break
        tokens.append(token)
    return " ".join(tokens)
def translate_pairs(topic_list, target_list, generated_list, word2idx):
    """Convert parallel batches of id sequences into readable text triples.

    The ``word2idx`` vocabulary is inverted once, then each positionally
    matched (topic, target, generated) triple is passed through
    ``translate``.

    Args:
        topic_list: Iterable of topic id sequences.
        target_list: Iterable of target id sequences.
        generated_list: Iterable of generated id sequences.
        word2idx: Mapping from word to token id.

    Returns:
        A list of ``(topic_text, target_text, generated_text)`` tuples, as
        long as the shortest of the three inputs.
    """
    idx2word = {index: token for token, index in word2idx.items()}
    return [
        tuple(translate(ids, idx2word) for ids in triple)
        for triple in zip(topic_list, target_list, generated_list)
    ]