-
Notifications
You must be signed in to change notification settings - Fork 17
/
keras_learning_curve.py
139 lines (117 loc) · 5.37 KB
/
keras_learning_curve.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
from .plot_learning_curve import PlotLearningCurve
import tensorflow.keras as keras
class KerasLearningCurve(keras.callbacks.Callback):
    """Keras.callback interface to draw learning curve

    This attempts to dynamically construct a learning curve plot
    based on the keras model configuration. This depends on `metrics`
    in `model.compile()`, `epochs` and `validation_data` in
    `model.fit()`.

    Example:
        model.fit(x_train, y_train,
                  batch_size=128,
                  epochs=20,
                  validation_data=(x_test, y_test),
                  callbacks=[KerasLearningCurve()],
                  verbose=0)

    Arguments:
        draw_interval: Only update the plot every `draw_interval` epoch. This
            can be useful on a remote connection, where data transfer between
            server and client might be slow.
        **kwargs: forwarded to PlotLearningCurve, the defaults
            are inferred from the keras model configuration, or mappings
            if mappings is defined.
    """
    def __init__(self, draw_interval=1, **kwargs):
        # Validate up front: a non-positive or non-integer interval would make
        # the modulo check in `on_epoch_end` misbehave silently.
        if not isinstance(draw_interval, int) or draw_interval <= 0:
            raise ValueError('draw_interval must be a positive integer')
        self._draw_interval = draw_interval   # redraw the plot every N epochs
        self._kwargs = kwargs                 # forwarded to PlotLearningCurve
        self._observed_metrics = set()        # metric names seen in `logs` so far
        self._dynamic = True                  # True => layout inferred from observed metrics
        self._plotter = None                  # lazily created PlotLearningCurve
        # An explicit `mappings` fixes the plot layout up front and disables
        # dynamic re-inference as new metrics appear.
        if 'mappings' in self._kwargs:
            self._dynamic = False
            # NOTE(review): this path runs before keras calls `set_params()`,
            # so `_infer_settings` may read `self.params` before training has
            # started — presumably the caller must then also supply an
            # `xaxis_config` with a concrete upper limit; verify.
            self._initialize_plotter()

    def _infer_settings(self,
                        mappings=None,
                        line_config=None,
                        facet_config=None,
                        xaxis_config=None,
                        **kwargs):
        """Fill in unspecified PlotLearningCurve settings and return them all.

        Missing values are inferred from the metric names observed so far and
        from the keras training configuration. Note that any user-supplied
        dicts are completed *in place*, so values inferred here persist inside
        `self._kwargs` across repeated calls (intentional: later calls see the
        earlier inferences).
        """
        # Dynamically set max_epoch, if not specified
        if xaxis_config is None:
            xaxis_config = dict()
        if 'name' not in xaxis_config:
            xaxis_config['name'] = 'Epoch'
        if 'limit' not in xaxis_config:
            xaxis_config['limit'] = [0, None]
        if xaxis_config['limit'][1] is None:
            # `self.params` is populated by keras (via Callback.set_params)
            # when `model.fit()` starts; epochs are 0-indexed on the x-axis,
            # hence the `- 1`.
            xaxis_config['limit'][1] = self.params['epochs'] - 1
        # Dynamically infer mappings: by default, one mapping per observed metric
        if mappings is None:
            mappings = { key:dict() for key in self._observed_metrics }
        for mapping_key, mapping_def in mappings.items():
            # keras prefixes validation metrics with 'val_'; stripping it lets
            # the train and validation series of a metric share one facet.
            infered_facet, infered_line = (mapping_key, 'train')
            if mapping_key.startswith('val_'):
                infered_facet, infered_line = (mapping_key[4:], 'validation')
            if 'line' not in mapping_def:
                mapping_def['line'] = infered_line
            if 'facet' not in mapping_def:
                mapping_def['facet'] = infered_facet
        # Dynamically infer line_config (name and color per line)
        if line_config is None:
            line_config = { mapping_def['line']:dict() for mapping_def in mappings.values() }
        for line_key, line_def in line_config.items():
            if line_key == 'train':
                infered_name, infered_color = ('Train', '#F8766D')
            elif line_key == 'validation':
                infered_name, infered_color = ('Validation', '#00BFC4')
            else:
                # Unknown line names get a neutral dark-grey default.
                infered_name, infered_color = (line_key, '#333333')
            if 'name' not in line_def:
                line_def['name'] = infered_name
            if 'color' not in line_def:
                line_def['color'] = infered_color
        # Dynamically infer facet_config (title, y-limits, and y-scale per facet)
        if facet_config is None:
            facet_config = { mapping_def['facet']:dict() for mapping_def in mappings.values() }
        for facet_key, facet_def in facet_config.items():
            if facet_key == 'loss':
                # Losses are positive and often span orders of magnitude.
                infered_name, infered_limit, infered_scale = ('Loss', [None, None], 'log10')
            elif facet_key in {'acc', 'accuracy', 'binary_accuracy', 'categorical_accuracy', 'sparse_categorical_accuracy'}:
                # All keras accuracy variants are bounded to [0, 1].
                infered_name, infered_limit, infered_scale = ('Accuracy', [0, 1], 'linear')
            elif facet_key == 'lr':
                # 'lr' is logged by keras learning-rate schedule callbacks.
                infered_name, infered_limit, infered_scale = ('Learning Rate', [0, None], 'linear')
            else:
                infered_name, infered_limit, infered_scale = (facet_key, [None, None], 'linear')
            if 'name' not in facet_def:
                facet_def['name'] = infered_name
            if 'limit' not in facet_def:
                facet_def['limit'] = infered_limit
            if 'scale' not in facet_def:
                facet_def['scale'] = infered_scale
        return {
            'mappings': mappings,
            'line_config': line_config,
            'facet_config': facet_config,
            'xaxis_config': xaxis_config,
            **kwargs
        }

    def _initialize_plotter(self):
        """Create the plotter on first call; reconfigure it on later calls."""
        settings = self._infer_settings(**self._kwargs)
        if self._plotter is None:
            self._plotter = PlotLearningCurve(**settings)
        else:
            self._plotter.reconfigure(**settings)

    def on_epoch_end(self, epoch, logs=None):
        """Record this epoch's metrics and periodically redraw the plot."""
        logs = logs or {}
        # In dynamic mode, rebuild the plot layout whenever a metric name
        # appears that was not seen before (e.g. 'lr' once a learning-rate
        # scheduler starts logging).
        if self._dynamic and len(logs.keys() - self._observed_metrics) > 0:
            self._observed_metrics.update(logs.keys())
            self._initialize_plotter()
        # NOTE(review): if `logs` is empty on the very first epoch in dynamic
        # mode, `self._plotter` is still None here — presumably keras always
        # supplies at least the loss; confirm.
        self._plotter.append(epoch, logs)
        # Update plot
        if epoch % self._draw_interval == 0:
            self._plotter.draw()

    def on_train_end(self, logs=None):
        """Render the final plot state (the `logs` argument is unused)."""
        if self._plotter is not None:
            self._plotter.finalize()