forked from Tiiiger/bert_score
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathget_rescale_baseline.py
97 lines (79 loc) · 3.04 KB
/
get_rescale_baseline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import argparse
import gzip
import os
from random import shuffle
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sacrebleu
import torch
from tqdm.auto import tqdm
import bert_score
def get_data(lang="en"):
    """Load monolingual sentences and split them into two random, disjoint halves.

    The corpus path is chosen from ``lang``: dedicated files for ``"en"`` and
    ``"zh"``, and ``data/paracrawl/rand_{lang}.txt`` for anything else. Only
    the first 1,000,000 lines of the file are considered. A line is kept when,
    after stripping surrounding whitespace, it is non-empty and has fewer than
    32 space-separated tokens.

    Args:
        lang: Language code used to select the corpus file.

    Returns:
        A ``(hyp, cand)`` tuple of two equally long lists of sentences. The
        two lists are drawn without replacement, so no line index appears in
        both. When the number of kept lines is odd, one line is left unused.

    Raises:
        OSError: If the corpus file cannot be opened.
    """
    if lang == "en":
        file_path = "data/news.2017.en.shuffled.deduped"
    elif lang == "zh":
        file_path = "data/paracrawl/crawl_chinese.txt"
    else:
        file_path = f"data/paracrawl/rand_{lang}.txt"

    lines = []
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale (assumes the corpora are UTF-8 encoded).
    with open(file_path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            if i == 1_000_000:
                break
            line = line.strip()
            # "".split(" ") == [""], so a token-count check alone can never
            # reject a blank line; test the stripped line itself instead.
            if line and len(line.split(" ")) < 32:
                lines.append(line)

    # Draw 2 * (len(lines) // 2) distinct indices and arrange them as two rows:
    # row 0 indexes the first list, row 1 the second.
    samples = np.random.choice(
        len(lines), size=(2, len(lines) // 2), replace=False
    )

    hyp = [lines[i] for i in samples[0]]
    cand = [lines[i] for i in samples[1]]

    return hyp, cand
def chunk(l, n):
    """Lazily yield consecutive slices of ``l`` holding at most ``n`` items each.

    Slices are produced in order starting at index 0; the final slice is
    shorter when ``len(l)`` is not a multiple of ``n``.
    """
    starts = range(0, len(l), n)
    yield from (l[begin : begin + n] for begin in starts)
def main():
    """Compute and save rescale baselines for each requested model.

    For every model given on the command line, sentence pairs returned by
    ``get_data`` are scored with a ``BERTScorer`` built with
    ``all_layers=True``. The scores are averaged over all pairs and written to
    ``rescale_baseline/{lang}/{model_type}.tsv``. Models whose baseline file
    already exists are skipped.
    """
    parser = argparse.ArgumentParser(
        description="Compute BERTScore rescale baselines from random sentence pairs."
    )
    parser.add_argument(
        "--lang", type=str, required=True, help="language to compute baseline with"
    )
    # The model list is iterated below, so it must be present; without
    # required=True a missing -m surfaced as a TypeError on None.
    parser.add_argument(
        "-m", "--model", nargs="+", required=True, help="models to tune"
    )
    parser.add_argument("-b", "--batch_size", type=int, default=64)
    args = parser.parse_args()

    hyp, cand = get_data(lang=args.lang)
    # Pair the sentences once; the same pairs are reused for every model.
    pairs = list(zip(hyp, cand))
    chunk_size = 1000
    # Ceiling division so the progress bar also counts a final partial chunk.
    num_chunks = (len(pairs) + chunk_size - 1) // chunk_size

    for model_type in args.model:
        baseline_file_path = f"rescale_baseline/{args.lang}/{model_type}.tsv"
        if os.path.isfile(baseline_file_path):
            print(f"{model_type} baseline exists for {args.lang}")
            continue
        print(f"computing baseline for {model_type} on {args.lang}")

        if not pairs:
            # Without any pair there is nothing to average; fail clearly
            # instead of crashing later on a None accumulator.
            raise SystemExit(f"no sentence pairs available for {args.lang}")

        scorer = bert_score.BERTScorer(model_type=model_type, all_layers=True)
        with torch.no_grad():
            score_means = None
            count = 0
            for batches in tqdm(chunk(pairs, chunk_size), total=num_chunks):
                batch_hyp, batch_cand = zip(*batches)
                scores = scorer.score(
                    batch_hyp, batch_cand, batch_size=args.batch_size
                )
                # Stack the returned score tensors along a new leading axis.
                # NOTE(review): the column names used below imply this axis is
                # (P, R, F) and the last axis runs over sentence pairs -
                # confirm against BERTScorer.score.
                scores = torch.stack(scores, dim=0)
                batch_means = scores.mean(dim=-1)
                n_new = len(batches)
                if score_means is None:
                    score_means = batch_means
                else:
                    # Size-weighted running mean over all pairs seen so far.
                    total = count + n_new
                    score_means = (
                        score_means * count / total + batch_means * n_new / total
                    )
                count += n_new

        # .cpu() is a no-op for CPU tensors and keeps .numpy() valid otherwise.
        pd_baselines = pd.DataFrame(
            score_means.cpu().numpy().transpose(), columns=["P", "R", "F"]
        )
        pd_baselines.index.name = "LAYER"

        # model_type may contain "/" and so introduce extra directory levels.
        os.makedirs(os.path.dirname(baseline_file_path), exist_ok=True)
        # Written with to_csv's default comma separator despite the .tsv name.
        pd_baselines.to_csv(baseline_file_path)
        del scorer


if __name__ == "__main__":
    main()