-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfinal_eval..py
112 lines (75 loc) · 2.82 KB
/
final_eval..py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import torch
from datasets import Dataset
from sklearn.metrics import f1_score, mean_squared_error
from torch.utils.data import DataLoader
from training_final_trainval import CustomBert
def compute_metrics(labels, y_preds, pcl_threshold=1.5):
    """Score predictions as a regression and as a thresholded binary task.

    Args:
        labels: Tensor of gold scores.
        y_preds: Tensor of predicted scores holding the same number of
            elements as ``labels``.
        pcl_threshold: Scores strictly greater than this value count as the
            positive class. Defaults to 1.5, the value that used to be
            hard-coded here.

    Returns:
        A dict with three entries:
            "mse": mean squared error between the raw scores (scikit-learn).
            "acc": 0-dim torch tensor, fraction of thresholded predictions
                that equal the thresholded labels.
            "f1p": F1 score of the positive (``True``) class (scikit-learn).
    """
    # Flatten both inputs: comparing an (N, 1) prediction tensor with (N,)
    # labels would otherwise broadcast to an (N, N) matrix and silently
    # produce a meaningless accuracy.
    labels = labels.reshape(-1)
    y_preds = y_preds.reshape(-1)

    pred_cl = y_preds > pcl_threshold
    true_cl = labels > pcl_threshold

    mse = mean_squared_error(labels, y_preds)
    acc = torch.mean((pred_cl == true_cl).float())
    f1p = f1_score(true_cl, pred_cl, pos_label=True)

    return {"mse": mse, "acc": acc, "f1p": f1p}
def load_model(model_path="results/model.pth"):
    """Build a ``CustomBert``, restore its weights and prepare it for inference.

    Args:
        model_path: Path of the saved state dict. Defaults to
            "results/model.pth", the location that used to be hard-coded.

    Returns:
        A ``(model, device)`` tuple: the model in eval mode, already moved to
        ``device``, where ``device`` is CUDA when available and CPU otherwise.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CustomBert()
    # Without map_location, a checkpoint saved from a GPU cannot be
    # deserialized on a CPU-only host, which defeats the CPU fallback chosen
    # above. Load onto the CPU first; the model is moved to `device` below.
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    model.to(device)
    return model, device
def load_data(path):
    """Load a dataset saved on disk and wrap it in a batch-16 DataLoader.

    The "input_ids", "attention_mask" and "labels" columns are exposed in
    torch format. No shuffle argument is passed to the DataLoader.

    Args:
        path: Directory previously written by ``Dataset.save_to_disk``
            (presumably; it is handed straight to ``Dataset.load_from_disk``).

    Returns:
        A ``DataLoader`` over the loaded dataset.
    """
    dataset = Dataset.load_from_disk(path)
    wanted_columns = ["input_ids", "attention_mask", "labels"]
    dataset.set_format(type="torch", columns=wanted_columns)
    return DataLoader(dataset, batch_size=16)
def predict(model, device, loader):
    """Run the model over every batch and collect predictions with labels.

    Args:
        model: Callable invoked as ``model(input_ids, attention_mask)``.
        device: Device the two input tensors are moved to before the call.
        loader: Iterable of batches, each a mapping with "input_ids",
            "attention_mask" and "labels" entries.

    Returns:
        A ``(predictions, labels)`` tuple, each the concatenation of the
        per-batch tensors in iteration order.
    """
    collected = []
    # Gradients are not needed while scoring.
    with torch.no_grad():
        for batch in loader:
            token_ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            # Outputs are brought back to the CPU; labels are taken from the
            # batch as they are and never moved to `device`.
            collected.append((model(token_ids, mask).cpu(), batch["labels"]))

    outputs = [output for output, _ in collected]
    targets = [target for _, target in collected]
    return torch.cat(outputs), torch.cat(targets)
def load_test(path):
    """Load an unlabelled dataset from disk and wrap it in a batch-16 DataLoader.

    Only the "input_ids" and "attention_mask" columns are exposed in torch
    format; no "labels" column is requested. No shuffle argument is passed
    to the DataLoader.

    Args:
        path: Location handed straight to ``Dataset.load_from_disk``.

    Returns:
        A ``DataLoader`` over the loaded dataset.
    """
    dataset = Dataset.load_from_disk(path)
    input_columns = ["input_ids", "attention_mask"]
    dataset.set_format(type="torch", columns=input_columns)
    return DataLoader(dataset, batch_size=16)
def predict_test(model, device, loader):
    """Run the model over every batch of an unlabelled loader.

    Args:
        model: Callable invoked as ``model(input_ids, attention_mask)``.
        device: Device the two input tensors are moved to before the call.
        loader: Iterable of batches, each a mapping with "input_ids" and
            "attention_mask" entries.

    Returns:
        The per-batch outputs, moved to the CPU and concatenated in
        iteration order.
    """
    # Gradients are not needed while scoring.
    with torch.no_grad():
        outputs = [
            model(
                batch["input_ids"].to(device),
                batch["attention_mask"].to(device),
            ).cpu()
            for batch in loader
        ]
    return torch.cat(outputs)
def _dump_labels(path, pred_labels):
    """Write each predicted label to ``path``, one per line."""
    with open(path, "w") as f:
        f.writelines(f"{pred.item()}\n" for pred in pred_labels)


def main():
    """Evaluate the saved model on the dev split and label the test split.

    Prints the dev metrics, the number of dev predictions and the dev
    positive-class F1, then writes the thresholded predictions to "dev.txt"
    and "test.txt" (one integer label per line).
    """
    global binary_classifier
    # Scores strictly above this value are written out as label 1.
    threshold = 1.5

    model, device = load_model()
    dev_loader = load_data("data/dev")

    # NOTE(review): nothing in the visible code reads this module-level flag;
    # it is kept because setting it is an observable side effect of main().
    binary_classifier = False

    # Dev split: report metrics, then persist the binary predictions.
    dev_scores, dev_targets = predict(model, device, dev_loader)
    print(compute_metrics(dev_targets, dev_scores))

    dev_labels = (dev_scores > threshold).int()
    print(len(dev_labels))
    gold_labels = (dev_targets > threshold).int()
    print(f1_score(gold_labels, dev_labels, pos_label=True))
    _dump_labels("dev.txt", dev_labels)

    # Test split: no gold labels, so only the predictions are written.
    test_loader = load_test("data/test")
    test_scores = predict_test(model, device, test_loader)
    _dump_labels("test.txt", (test_scores > threshold).int())
# Run the evaluation only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()