-
Notifications
You must be signed in to change notification settings - Fork 54
/
Copy path test_model.py
116 lines (99 loc) · 4.02 KB
/
test_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import tensorflow as tf
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import os
import json
import math
import pickle
from utils import basic_hyperparams
from utils import load_data
from utils import load_global_inputs
from utils import get_valid_batch_feed_dict
def root_mean_squared_error(labels, preds):
    """Return the root-mean-squared error between ``labels`` and ``preds``.

    Both arguments are converted to float64 arrays first. This means plain
    Python sequences are accepted, and unsigned or narrow integer inputs
    cannot wrap around when they are subtracted or squared.

    Args:
        labels: ground-truth values (array-like).
        preds: predicted values, broadcast-compatible with ``labels``.

    Returns:
        ``sqrt(sum((labels - preds) ** 2) / labels.size)`` as a NumPy float.
        An empty ``labels`` yields NaN (NumPy also emits a RuntimeWarning).
    """
    labels = np.asarray(labels, dtype=np.float64)
    preds = np.asarray(preds, dtype=np.float64)
    # Normalise by the number of label entries, as the metric always has.
    total_size = np.size(labels)
    return np.sqrt(np.sum(np.square(labels - preds)) / total_size)
def mean_absolute_error(labels, preds):
    """Return the mean absolute error between ``labels`` and ``preds``.

    Both arguments are converted to float64 arrays first. This means plain
    Python sequences are accepted, and unsigned or narrow integer inputs
    cannot wrap around in the subtraction.

    Args:
        labels: ground-truth values (array-like).
        preds: predicted values, broadcast-compatible with ``labels``.

    Returns:
        ``sum(|labels - preds|) / labels.size`` as a NumPy float. An empty
        ``labels`` yields NaN (NumPy also emits a RuntimeWarning).
    """
    labels = np.asarray(labels, dtype=np.float64)
    preds = np.asarray(preds, dtype=np.float64)
    # Normalise by the number of label entries, as the metric always has.
    total_size = np.size(labels)
    return np.sum(np.abs(labels - preds)) / total_size
if __name__ == '__main__':
    # Evaluate a saved GeoMAN checkpoint on the test split and print the mean
    # per-batch RMSE / MAE computed on inverse-transformed (original-scale)
    # values.

    # use specific gpu; set before any TF session is created
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    tf_config = tf.ConfigProto()
    # Allocate GPU memory on demand rather than reserving it up front.
    tf_config.gpu_options.allow_growth = True

    # load hyperparameters: defaults from utils, overridden by the JSON file
    hps = basic_hyperparams()
    with open('./hparam_files/AirQualityGeoMan.json', 'r') as hps_file:
        hps_dict = json.load(hps_file)
    hps.override_from_dict(hps_dict)

    # model construction on a fresh default graph
    tf.reset_default_graph()
    from GeoMAN import GeoMAN
    print(hps)
    model = GeoMAN(hps)

    # read data from test set
    input_path = './data/'
    test_data = load_data(
        input_path, 'test', hps.n_steps_encoder, hps.n_steps_decoder)
    global_inputs, global_attn_states = load_global_inputs(
        input_path, hps.n_steps_encoder, hps.n_steps_decoder)
    num_test = len(test_data[0])
    print('test samples: {0}'.format(num_test))
    # test_data[4] holds the labels that are sliced per batch below; fetch once.
    all_labels = test_data[4]

    # read scaler of the labels
    # NOTE(review): pickle.load can execute arbitrary code from the file;
    # only load scaler files produced by this project's own preprocessing.
    with open('./data/scalers/scaler-0.pkl', 'rb') as scaler_file:
        scaler = pickle.load(scaler_file)

    # path: the model-name suffix encodes which components were enabled
    if hps.ext_flag:
        if hps.s_attn_flag == 0:
            model_name = 'GeoMANng'
        elif hps.s_attn_flag == 1:
            model_name = 'GeoMANnl'
        else:
            model_name = 'GeoMAN'
    else:
        model_name = 'GeoMANne'
    model_path = './logs/{}-{}-{}-{}-{}-{:.2f}-{:.3f}/'.format(model_name,
                                                             hps.n_steps_encoder,
                                                             hps.n_steps_decoder,
                                                             hps.n_stacked_layers,
                                                             hps.n_hidden_encoder,
                                                             hps.dropout_rate,
                                                             hps.lambda_l2_reg)
    model_path += 'saved_models/final_model.ckpt'

    # test params
    n_split_test = 500  # number of split points; yields n_split_test - 1 batches
    test_rmses = []
    test_maes = []

    # restore model
    print("Starting loading model...")
    saver = tf.train.Saver()
    # Use the GPU config built above so allow_growth actually takes effect.
    with tf.Session(config=tf_config) as sess:
        model.init(sess)
        # Restore model weights from previously saved model
        saver.restore(sess, model_path)
        print("Model successfully restored from file: %s" % model_path)
        # test
        test_indexes = np.int64(
            np.linspace(0, num_test, n_split_test))
        for k in range(n_split_test - 1):
            start, end = test_indexes[k], test_indexes[k + 1]
            if start == end:
                # With fewer samples than splits some batches are empty; they
                # would yield NaN metrics and poison the final mean.
                continue
            feed_dict = get_valid_batch_feed_dict(
                model, test_indexes, k, test_data, global_inputs, global_attn_states)
            # re-scale predicted labels
            batch_preds = sess.run(model.phs['preds'], feed_dict)
            # Swap the first two axes, then flatten each row so the shape can
            # match the label rows (presumably preds are time-major; TODO confirm).
            batch_preds = np.swapaxes(batch_preds, 0, 1)
            batch_preds = np.reshape(batch_preds, [batch_preds.shape[0], -1])
            batch_preds = scaler.inverse_transform(batch_preds)
            # re-scale real labels
            batch_labels = scaler.inverse_transform(all_labels[start:end])
            test_rmses.append(root_mean_squared_error(
                batch_labels, batch_preds))
            test_maes.append(mean_absolute_error(batch_labels, batch_preds))

    # Reported metrics are the unweighted mean over per-batch values.
    test_rmses = np.asarray(test_rmses)
    test_maes = np.asarray(test_maes)
    print('===============METRIC===============')
    print('rmse = {:.6f}'.format(test_rmses.mean()))
    print('mae = {:.6f}'.format(test_maes.mean()))