-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathview_results.py
153 lines (123 loc) · 5.39 KB
/
view_results.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import sys
from haven import haven_jupyter as hj
from haven import haven_results as hr
import pandas as pd
import os
import numpy as np
import pylab as plt
from haven import haven_utils as hu
from src import utils as ut
import argparse, exp_configs
from src import pretty_plot
def filter_exp_list(exp_list):
    """Return the experiments in ``exp_list`` accepted by ``ut.is_valid_exp``.

    A new list is built; the input is not modified and the relative order of
    the surviving experiment dicts is preserved.
    """
    # Drop every combination that the project helper marks as invalid.
    return [candidate for candidate in exp_list if ut.is_valid_exp(candidate)]
def _find_converged_point(score_df):
    """Return the convergence marker for one experiment, or ``None``.

    The marker is ``{"Y": value, "X": row}`` where ``row`` is the position
    right after the last missing (NaN) entry of the "converged" column and
    ``value`` is the column entry at that position.

    ``None`` is returned when the column is absent, or when there is no entry
    after the last missing one (column empty or ending in NaN).
    """
    if "converged" not in score_df.columns:
        return None

    converged_col = score_df["converged"]
    # isna() also copes with object-dtype columns, where np.isnan would raise.
    nan_rows = np.where(converged_col.isna().to_numpy())[0]

    # No NaN at all means the very first row already holds a value.
    ind = nan_rows[-1] + 1 if nan_rows.size else 0
    if ind >= len(converged_col):
        # The column ends with NaN (or is empty): nothing to mark.
        return None

    return {"Y": converged_col.iloc[ind], "X": ind}


def get_one_plot(exp_list, savedir_base, plot_names=None):
    """Build one trace per experiment for a single figure.

    Args:
        exp_list: experiment dicts; each must provide the keys
            'partition', 'selection' and 'update'.
        savedir_base: passed to ``hr.get_score_lists`` together with
            ``exp_list`` (presumably the directory holding saved results --
            confirm against haven).
        plot_names: forwarded unchanged to ``ut.legendFunc``.

    Returns:
        A list with one dict per experiment holding
        "Y" (the "loss" column as an array), "X" (the "iteration" column as
        an array), "legend" (result of ``ut.legendFunc``) and "converged"
        (``{"Y": ..., "X": ...}`` or ``None``).
    """
    # Distinct rule names across the figure; handed to the legend helper.
    p_rules = np.unique([e['partition'] for e in exp_list])
    s_rules = np.unique([e['selection'] for e in exp_list])
    u_rules = np.unique([e['update'] for e in exp_list])

    score_list_list = hr.get_score_lists(exp_list, savedir_base)
    # zip() below would silently truncate on a length mismatch.
    assert len(exp_list) == len(score_list_list)

    trace_list = []
    for exp_dict, score_list in zip(exp_list, score_list_list):
        score_df = pd.DataFrame(score_list)

        legend = ut.legendFunc(exp_dict['partition'], exp_dict['selection'],
                               exp_dict['update'], p_rules, s_rules, u_rules,
                               plot_names=plot_names)

        trace_list.append({"Y": np.array(score_df["loss"]),
                           "X": np.array(score_df["iteration"]),
                           "legend": legend,
                           "converged": _find_converged_point(score_df)})
    return trace_list
def get_dataset_plots(exp_list, plot_names, savedir_base):
    """Build the figure descriptions for one dataset, one per block size.

    The experiments are grouped by 'block_size' via ``hr.group_exp_list`` and
    the groups are processed in increasing block-size order.  The loss name
    and dataset name used in the y-label are read from the first experiment.

    Returns:
        A list of dicts with the keys 'traceList' (from ``get_one_plot``),
        "xlabel", "ylabel" and "yscale" (always "log").
    """
    dataset_cfg = exp_list[0]['dataset']
    loss_name = dataset_cfg['loss']
    dataset = dataset_cfg['name'].upper()
    # The y-label is identical for every figure of this dataset.
    ylabel = "$f(x) - f^*$ for %s on Dataset %s" % (loss_name, dataset)

    groups = sorted(
        hr.group_exp_list(exp_list, groupby_list=['block_size']),
        key=lambda group: group[0]['block_size'],
    )

    figures = []
    for group in groups:
        block_size = group[0]['block_size']
        # A block size of -1 gets the plain label without the |b| suffix.
        xlabel = ("Iterations" if block_size == -1
                  else "Iterations with |b|=%d" % block_size)
        traces = get_one_plot(group, savedir_base, plot_names=plot_names)
        figures.append({'traceList': traces,
                        "xlabel": xlabel,
                        "ylabel": ylabel,
                        "yscale": "log"})
    return figures
def plot_exp_list(exp_list, savedir_base, exp_name, outdir='figures', plot_names=None):
    """Draw a grid of subplots for ``exp_list`` and optionally save it as PDF.

    Experiments are first filtered with ``filter_exp_list`` and grouped by
    'dataset'; each dataset becomes one row of the grid and each figure
    returned by ``get_dataset_plots`` one column.  The number of columns is
    taken from the first row.

    Args:
        exp_list: experiment dicts to plot.
        savedir_base: forwarded to ``get_dataset_plots``.
        exp_name: used as the plot title and as the PDF file stem.
        outdir: directory for ``<exp_name>.pdf``; ``None`` skips saving.
        plot_names: forwarded to ``get_dataset_plots``.

    Returns:
        The ``fig`` attribute of the ``PrettyPlot`` instance.
    """
    valid_exps = filter_exp_list(exp_list)
    per_dataset = hr.group_exp_list(valid_exps, groupby_list=['dataset'])

    # One row of figure descriptions per dataset.
    grid = [get_dataset_plots(dataset_exps, plot_names, savedir_base)
            for dataset_exps in per_dataset]

    nrows = len(grid)
    ncols = len(grid[0])

    # Main plot
    pp_main = pretty_plot.PrettyPlot(title=exp_name,
                                     axFontSize=14,
                                     axTickSize=11,
                                     legend_size=8,
                                     figsize=(5 * ncols, 4 * nrows),
                                     legend_type="line",
                                     yscale="linear",
                                     subplots=(nrows, ncols),
                                     linewidth=1,
                                     box_linewidth=1,
                                     markersize=8,
                                     y_axis_size=10,
                                     x_axis_size=10,
                                     shareRowLabel=True)

    for row in grid:
        for figure in row:
            # Queue every trace of this figure, then render the subplot.
            for trace in figure["traceList"]:
                pp_main.add_yxList(y_vals=trace["Y"],
                                   x_vals=trace["X"],
                                   label=trace["legend"],
                                   converged=trace["converged"])
            pp_main.plot(ylabel=figure["ylabel"],
                         xlabel=figure["xlabel"],
                         yscale=figure["yscale"])

    fig = pp_main.fig
    if outdir is None:
        return fig

    # Save the whole plot.
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    fig.suptitle("")
    fig_name = os.path.join(outdir, "%s.pdf" % (exp_name))
    target_dir = os.path.dirname(fig_name)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    fig.savefig(fig_name, dpi=600)
    return fig
if __name__ == "__main__":
    # Command-line entry point: for every requested experiment group, render
    # its plot into 'docs/<group>.pdf' using the results under savedir_base.
    parser = argparse.ArgumentParser()
    # Required: without it args.exp_group_list is None and the loop below
    # would fail with an unhelpful TypeError.
    parser.add_argument('-e', '--exp_group_list', nargs="+", required=True)
    parser.add_argument('-sb', '--savedir_base', required=True)
    args = parser.parse_args()

    # Plot experiments
    # ===================
    for exp_group_name in args.exp_group_list:
        exp_list = exp_configs.EXP_GROUPS[exp_group_name]

        fig = plot_exp_list(exp_list, args.savedir_base, outdir='docs',
                            exp_name=exp_group_name,
                            plot_names=exp_configs.PLOT_NAMES[exp_group_name])
        # Close this group's figure explicitly so figures do not pile up
        # across groups.
        plt.close(fig)
        print(exp_group_name, 'saved.')