test.py
import argparse
import json
import os
import sys

import torch
import matplotlib.pyplot as plt
import numpy as np
from torch_geometric.loader import DataLoader

from Dataset import *
from model.GATE18 import *

class RMSELoss(torch.nn.Module):
    """Root-mean-square error: the square root of the MSE."""
    def __init__(self):
        super(RMSELoss, self).__init__()
        self.mse = torch.nn.MSELoss()

    def forward(self, output, targets):
        return torch.sqrt(self.mse(output, targets))

def load_model_state(model, state_dict_path, device='cpu'):
    # map_location lets checkpoints saved on GPU be loaded on CPU-only machines
    model.load_state_dict(torch.load(state_dict_path, map_location=device))
    model.eval()  # Set the model to evaluation mode
    return model

# Evaluation Function
#-------------------------------------------------------------------------------------------------------------------------------
def evaluate(models, loader, criterion, device):
    # Initialize variables to accumulate the evaluation results
    total_loss = 0.0
    y_true = []
    y_pred = []
    ids = []

    # Disable gradient calculation during evaluation
    with torch.no_grad():
        for graphbatch in loader:
            graphbatch.to(device)
            targets = graphbatch.y

            # Forward pass of the ENSEMBLE: average the predictions of all member models
            outputs = [model(graphbatch).view(-1) for model in models]
            output = torch.mean(torch.stack(outputs), dim=0)
            loss = criterion(output, targets)

            # Accumulate loss and collect the true and predicted values for later use
            total_loss += loss.item()
            y_true.extend(targets.tolist())
            y_pred.extend(output.tolist())
            ids.extend(graphbatch.id)

    # Calculate evaluation metrics
    eval_loss = total_loss / len(loader)

    # Pearson correlation coefficient
    corr_matrix = np.corrcoef(y_true, y_pred)
    r = corr_matrix[0, 1]

    # Link the predictions to the corresponding ids in a dictionary
    id_to_pred = dict(zip(ids, zip(y_true, y_pred)))

    # R2 score (coefficient of determination)
    y_true_arr = np.array(y_true)
    y_pred_arr = np.array(y_pred)
    r2_score = 1 - np.sum((y_true_arr - y_pred_arr) ** 2) / np.sum((y_true_arr - np.mean(y_true_arr)) ** 2)

    # RMSE in pK units: undo the min-max scaling of the labels (pK range 0-16 scaled to [0, 1])
    pk_min, pk_max = 0, 16
    true_labels_unscaled = torch.tensor(y_true) * (pk_max - pk_min) + pk_min
    predictions_unscaled = torch.tensor(y_pred) * (pk_max - pk_min) + pk_min
    rmse = criterion(predictions_unscaled, true_labels_unscaled).item()

    return eval_loss, r, rmse, r2_score, true_labels_unscaled, predictions_unscaled, id_to_pred
#-------------------------------------------------------------------------------------------------------------------------------
# Plotting Functions
#-------------------------------------------------------------------------------------------------------------------------
def plot_error_histogram(ax, errors, title):
    n, bins, patches = ax.hist(errors, bins=50, color='blue', edgecolor='black')

    # Add the bin count on top of each column
    for count, patch in zip(n, patches):
        ax.text(patch.get_x() + patch.get_width() / 2, patch.get_height(), f'{int(count)}',
                ha='center', va='bottom')

    ax.set_title(title)
    ax.set_xlabel('Absolute Error (pK)')
    ax.set_ylabel('Frequency')

def plot_predictions(y_true, y_pred, title, metrics='', filepath=None, axislim=14):
    plt.scatter(y_true, y_pred, alpha=0.5, c='blue')

    # Display the metrics in the top left corner of the plotting area
    plt.text(0.05, 0.95, metrics, fontsize=14, transform=plt.gca().transAxes,
             verticalalignment='top', horizontalalignment='left')

    # Identity line: perfect predictions would lie on y = x
    plt.plot([0, axislim], [0, axislim], color='red', linestyle='--')

    plt.xlabel('True pK Values', fontsize=12)
    plt.ylabel('Predicted pK Values', fontsize=12)
    plt.ylim(0, axislim)
    plt.yticks(fontsize=12)
    plt.xlim(0, axislim)
    plt.xticks(fontsize=12)
    plt.title(title, fontsize=14, fontweight='bold')

    if filepath is not None:
        plt.savefig(filepath, dpi=300)
#-------------------------------------------------------------------------------------------------------------------------
def parse_args():
    parser = argparse.ArgumentParser(description="Testing Parameters and Input Dataset Control")

    # REQUIRED Arguments
    parser.add_argument("--stdicts", type=str, required=True,
                        help="Comma-separated paths to the state dicts that should be tested as an ensemble")
    parser.add_argument("--dataset_path", required=True, help="The path to the test dataset .pt file")

    # OPTIONAL Arguments
    parser.add_argument("--model_arch", default="GATE18d", help="The name of the model architecture")
    parser.add_argument("--save_path", default=None, help="The path where the results should be exported to")
    return parser.parse_args()
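
# Example invocation (hypothetical paths, for illustration only):
#   python test.py --stdicts run1_stdict.pt,run2_stdict.pt,run3_stdict.pt \
#                  --dataset_path data/test_dataset.pt \
#                  --model_arch GATE18d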

args = parse_args()

# Paths
dataset_path = args.dataset_path
stdicts = args.stdicts.split(',')
save_path = args.save_path
if save_path is None:
    save_path = os.path.dirname(dataset_path)

# Load the dataset
test_dataset = torch.load(dataset_path)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, shuffle=True, num_workers=4, persistent_workers=True)

# Ensemble Model
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model_arch = args.model_arch

# Dropout is disabled at test time
conv_dropout_prob = 0
dropout_prob = 0

criterion = RMSELoss()

# Infer the feature dimensions from the first graph of the dataset
node_feat_dim = test_dataset[0].x.shape[1]
edge_feat_dim = test_dataset[0].edge_attr.shape[1]

# Look up the model class by name and instantiate one model per state dict
model_class = getattr(sys.modules[__name__], model_arch)
models = [model_class(
              dropout_prob=dropout_prob,
              in_channels=node_feat_dim,
              edge_dim=edge_feat_dim,
              conv_dropout_prob=conv_dropout_prob).float().to(device)
          for _ in range(len(stdicts))]

# Load the trained weights into each ensemble member
model_paths = list(stdicts)
models = [load_model_state(model, path, device=device) for model, path in zip(models, model_paths)]

# Run inference
loss, r, rmse, r2_score, y_true, y_pred, id_to_pred = evaluate(models, test_loader, criterion, device)
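
# Report the evaluation metrics (all values computed by evaluate above)
print(f"Eval loss: {loss:.4f} | Pearson r: {r:.3f} | RMSE: {rmse:.3f} pK | R2: {r2_score:.3f}")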
# Plotting
#-------------------------------------------------------------------------------------------------------------------------
test_dataset_name = os.path.basename(dataset_path).split('.')[0]
# Save the predictions to a json file
with open(os.path.join(save_path, f'{test_dataset_name}_predictions.json'), 'w', encoding='utf-8') as json_file:
    json.dump(id_to_pred, json_file, ensure_ascii=False, indent=4)
# Save Predictions Scatterplot
filepath = os.path.join(save_path, f'{test_dataset_name}_predictions.png')
plot_predictions(y_true, y_pred, test_dataset_name, metrics=f"R = {r:.3f}\nRMSE = {rmse:.3f}", filepath=filepath, axislim=14)
#-------------------------------------------------------------------------------------------------------------------------
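
# Error histogram: a minimal sketch using plot_error_histogram (defined above but not
# otherwise called), assuming the unscaled pK tensors returned by evaluate()
fig, ax = plt.subplots(figsize=(10, 6))
abs_errors = (y_pred - y_true).abs().tolist()
plot_error_histogram(ax, abs_errors, f'{test_dataset_name} Absolute Errors')
fig.savefig(os.path.join(save_path, f'{test_dataset_name}_error_histogram.png'), dpi=300)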