-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathcmp_test_cases.py
192 lines (147 loc) · 8.37 KB
/
cmp_test_cases.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import string
import shutil
import numpy as np
import os
import filecmp
import matplotlib.pyplot as plt
import unittest
import numpy as np
import re
import sys
from wetb.hawc2.Hawc2io import ReadHawc2
class CompareTestCases(unittest.TestCase):
    """Compare HAWC2 result files (.sel/.dat) of a reference run with a test run.

    The class derives from ``unittest.TestCase`` to reuse its assertion
    helpers, but it is driven directly through :meth:`compare_folder` /
    :meth:`compare_file` (``runTest`` is a no-op).
    """

    # Version tag directly in front of the ".sel" extension, e.g.
    # "case_wsp10_11.4.sel" -> "11.4". Tags containing letters ("11.8w") do not match.
    _VERSION_RE = re.compile(r".*_(\d*\.\d*)\.sel")

    def compare_lines(self, ref_file, test_file, skip_first=0):
        """Assert that two UTF-8 text files match line by line, ignoring case.

        :param ref_file: path of the reference file.
        :param test_file: path of the file under test.
        :param skip_first: number of leading lines excluded from the
            line-by-line comparison (the total line counts must still match).
        :raises AssertionError: if the line counts differ, or on the first
            differing line; the message gives the 1-based line number in the
            file and marks differing characters with ``^``.
        """
        with open(ref_file, encoding='utf-8') as ref:
            ref_lines = ref.readlines()
        with open(test_file, encoding='utf-8') as test:
            test_lines = test.readlines()
        self.assertEqual(len(ref_lines), len(test_lines),
                         "\nNumber of lines differs in: '%s' and '%s'" % (ref_file, test_file))
        pairs = zip(ref_lines[skip_first:], test_lines[skip_first:])
        # start=skip_first + 1 makes line_no the real 1-based line number in the file.
        for line_no, (ref_l, test_l) in enumerate(pairs, start=skip_first + 1):
            if ref_l.lower() != test_l.lower():
                diff = "".join("^" if a != b else " " for a, b in zip(ref_l, test_l))
                err_str = "%s%s%s\n\n" % (ref_l, test_l, diff)
                raise AssertionError("Difference in line %d of %s\n%s" % (line_no, ref_file, err_str))

    def compare_sel(self, ref_file, test_file):
        """Compare two .sel files line by line, ignoring their first 8 lines.

        NOTE(review): the 8 skipped lines presumably hold run-specific header
        data (version, date, ...) -- confirm against the HAWC2 .sel format.
        """
        self.compare_lines(ref_file, test_file, 8)

    def compare_dat_contents(self, ref_file, test_file):
        """Compare two files byte-wise; on mismatch, run a line-by-line diff.

        :raises AssertionError: via :meth:`compare_lines` if the files differ
            other than in letter case.
        """
        if not filecmp.cmp(ref_file, test_file, shallow=False):
            self.compare_lines(ref_file, test_file)

    def min_tol(self, ref_data, test_data):
        """Estimate small (atol, rtol) for which the data are all close.

        Minimises ``atol + rtol`` subject to
        ``np.allclose(test_data, ref_data, rtol, atol)`` using scipy's
        Nelder-Mead ``fmin`` started at (1, 1). The objective is a step
        function (10**99 when not close), so the result is a heuristic.

        :returns: tuple ``(atol, rtol)``.
        """
        from scipy.optimize import fmin

        def error(x, a, b):
            atol, rtol = x
            if rtol > 0 and atol > 0 and np.allclose(b, a, rtol, atol):
                return rtol + atol
            return 10 ** 99

        atol, rtol = fmin(error, (1, 1), (ref_data, test_data), disp=False)
        return atol, rtol

    @staticmethod
    def _error_summary(ref, test):
        """Format mean/max absolute and relative (%) errors as 20-char columns.

        Relative errors only use entries where ``ref`` is non-zero; if there
        are none, the relative columns are reported as ``nan`` instead of
        failing on an empty reduction.
        """
        abs_err = np.abs(ref - test)
        mask = np.abs(ref) != 0
        if np.any(mask):
            rel_err = abs_err[mask] / np.abs(ref[mask]) * 100
            mean_rel, max_rel = np.mean(rel_err), np.max(rel_err)
        else:
            mean_rel = max_rel = float("nan")
        values = [np.mean(abs_err), mean_rel, np.max(abs_err), max_rel]
        return " ".join("%-20s" % ("%.6f" % v) for v in values)

    def compare_dat_plot(self, ref_file, test_file, show_plot=False, rtol=1.e-5, atol=1.e-8):
        """Compare two HAWC2 data files numerically and plot differing sensors.

        Both files are read with ``ReadHawc2`` (extension stripped) and compared
        with ``np.allclose(ref, test, rtol=rtol, atol=atol)``, i.e.
        ``|ref - test| <= atol + rtol * |test|``. When they are not all close,
        error statistics go to stderr and each differing sensor is plotted:
        shown interactively if ``show_plot`` is true, otherwise saved as a PNG
        in ``<test dir>/Compare/<ref basename>/`` (recreated on each call).
        Value differences are reported, not raised.

        :returns: True if all values are close, False otherwise.
        :raises AssertionError: if the two data arrays have different shapes.
        """
        from matplotlib.font_manager import FontProperties

        ref = ReadHawc2(os.path.splitext(ref_file)[0])
        test = ReadHawc2(os.path.splitext(test_file)[0])
        ref_data = ref()
        test_data = test()
        # Mismatched shapes would make np.allclose fail with a broadcasting
        # ValueError; report them like any other comparison failure instead.
        if np.shape(ref_data) != np.shape(test_data):
            raise AssertionError("Data shape differs: %s in '%s' and %s in '%s'"
                                 % (np.shape(ref_data), ref_file, np.shape(test_data), test_file))
        if np.allclose(ref_data, test_data, rtol=rtol, atol=atol):
            return True

        different_sensors = [i for i in range(ref.NrCh)
                             if not np.allclose(ref_data[:, i], test_data[:, i], rtol=rtol, atol=atol)]

        path = os.path.join(os.path.dirname(test_file), "Compare",
                            os.path.splitext(os.path.basename(ref_file))[0])
        shutil.rmtree(path, ignore_errors=True)
        try:
            os.makedirs(path, exist_ok=True)
        except OSError as err:
            # Best effort: statistics are still reported even if plots cannot be saved.
            sys.stderr.write("Could not create plot folder '%s': %s\n" % (path, err))

        valid_chars = "-_.() %s%s" % (string.ascii_letters, string.digits)
        sys.stderr.write("%d sensors are different in datafiles\n" % len(different_sensors))
        sys.stderr.write("All close with\nAbsolute tolerance: %.6f\nRelative tolerance: %.6f\n\n"
                         % self.min_tol(ref_data, test_data))
        headers = ["Mean abs error", "Mean rel error (%)", "Max abs error", "Max rel error(%)", "Sensor"]
        sys.stderr.write(" ".join("%-20s" % s for s in headers) + "\n")
        sys.stderr.write("%s All data values\n" % self._error_summary(ref_data, test_data))

        font = FontProperties()
        font.set_size('small')
        for i in different_sensors:
            info = "%s [%s] %s" % (ref.ChInfo[0][i], ref.ChInfo[1][i], ref.ChInfo[2][i])
            sys.stderr.write("%s %d %s\n" % (self._error_summary(ref_data[:, i], test_data[:, i]), i + 1, info))
            sys.stderr.flush()
            plt.cla()
            plt.plot(ref_data[:, i], 'g', lw=3, label="Ref: " + info)
            plt.plot(test_data[:, i], 'r', lw=1,
                     label="test: %s [%s] %s" % (test.ChInfo[0][i], test.ChInfo[1][i], test.ChInfo[2][i]))
            plt.legend(loc='best', prop=font)
            # gca() titles the axes holding the curves; plt.axes() would add a
            # new, empty axes on top of them in current matplotlib.
            plt.gca().set_title(os.path.basename(ref_file))
            if show_plot:
                plt.show()
            else:
                name = "".join(c for c in ref.ChInfo[0][i] if c in valid_chars)
                plt.savefig(os.path.join(path, "%03d_%s.png" % (i + 1, name)))
        return False

    def version_tag(self, filename):
        """Return the numeric version tag right before ".sel" (e.g. "11.4"), or "" if none."""
        match = self._VERSION_RE.match(filename)
        return match.group(1) if match else ""

    def common_path(self, path1, path2):
        """Return the longest common leading directory of two paths.

        Both paths are resolved with ``os.path.realpath`` and compared
        component by component (case-sensitively).
        """
        cp = []
        for f1, f2 in zip(os.path.realpath(path1).split(os.path.sep),
                          os.path.realpath(path2).split(os.path.sep)):
            if f1 != f2:
                break
            cp.append(f1)
        return os.path.sep.join(cp)

    def compare_file(self, ref_file, test_file, show_plot=False, rtol=1.e-5, atol=1.e-8):
        """Compare a .sel file pair and the corresponding .dat files.

        All failures are written to stderr; nothing is raised for a missing
        test file or a detected difference. "ok" is printed when both .sel and
        data match, and "Data file ok" when only the data match.
        """
        if not os.path.isfile(test_file):
            sys.stderr.write("File '%s' not found\n\n" % test_file)
            return
        sel_ok = True
        try:
            self.compare_sel(ref_file, test_file)
        except AssertionError as e:
            sys.stderr.write(str(e) + "\n")
            sel_ok = False
        try:
            data_ok = self.compare_dat_plot(os.path.splitext(ref_file)[0] + ".dat",
                                            os.path.splitext(test_file)[0] + ".dat",
                                            show_plot=show_plot, rtol=rtol, atol=atol)
        except AssertionError as e:
            sys.stderr.write(str(e) + "\n\n")
            return
        if data_ok:
            print("ok\n\n\n" if sel_ok else "Data file ok\n\n\n")

    def compare_folder(self, ref_res_path, test_res_path, ref_version_tag, test_version_tag,
                       show_plot=False, rtol=1.e-5, atol=1.e-8):
        """Compare every .sel result in ``ref_res_path`` with its match in ``test_res_path``.

        A reference file's version tag is detected from its name (falling back
        to ``ref_version_tag`` when none is detected). The test file must share
        the text before and after that tag. If several test files qualify,
        ``prefix + test_version_tag + postfix`` is preferred; otherwise the
        first in sorted order is used. Reference files without a match are
        reported on stderr and skipped.
        """
        ref_files = sorted(f for f in os.listdir(ref_res_path) if f.endswith(".sel"))
        test_files = sorted(f for f in os.listdir(test_res_path) if f.endswith(".sel"))
        common_path = self.common_path(ref_res_path, test_res_path)
        for filename in ref_files:
            ref_file = os.path.join(ref_res_path, filename)
            tag = self.version_tag(filename)
            if not tag and ref_version_tag and ref_version_tag in filename:
                tag = ref_version_tag
            print("-" * 50)
            # The regex captures the last tag before ".sel", hence rindex.
            split = filename.rindex(tag) if tag else 0
            prefix, postfix = filename[:split], filename[split + len(tag):]
            candidates = [f for f in test_files if f.startswith(prefix) and f.endswith(postfix)]
            if not candidates:
                sys.stdout.flush()
                sys.stderr.write("\nNo matching test file found for %s\n\n" % ref_file.replace(common_path, ""))
                sys.stderr.flush()
                continue
            exact = prefix + test_version_tag + postfix if test_version_tag else None
            test_filename = exact if exact in candidates else candidates[0]
            test_file = os.path.join(test_res_path, test_filename)
            print("Comparing %s and %s\n" % tuple(f.replace(common_path, "") for f in (ref_file, test_file)))
            self.compare_file(ref_file, test_file, show_plot, rtol, atol)

    def runTest(self):
        """No-op so the class can be instantiated without a test runner."""
        pass
if __name__ == "__main__":
    # Reference/test result folders and their version tags. The defaults below
    # can be overridden from the command line:
    #   python cmp_test_cases.py REF_PATH TEST_PATH [REF_TAG TEST_TAG]
    ref_path = r'S:\AED\HAWC2\HAWC2_release_test_cases\version_11.4\output\res/'
    test_path = r'S:\AED\HAWC2\HAWC2_release_test_cases\version_11.8w\output\res/'
    ref_tag, test_tag = "11.4", "11.8w"
    if len(sys.argv) >= 3:
        ref_path, test_path = sys.argv[1:3]
    if len(sys.argv) >= 5:
        ref_tag, test_tag = sys.argv[3:5]
    # rtol: relative tolerance
    # atol: absolute tolerance
    # A value passes when absolute(ref - test) <= atol + rtol * absolute(test)
    CompareTestCases().compare_folder(ref_path, test_path, ref_tag, test_tag,
                                      show_plot=False, rtol=1.e-5, atol=1.e-5)