-
Notifications
You must be signed in to change notification settings - Fork 5.5k
/
metric.py
50 lines (44 loc) · 1.69 KB
/
metric.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import numpy as np
import mxnet as mx
class AccMetric(mx.metric.EvalMetric):
    """Classification accuracy computed from the network's second output.

    ``update`` compares ``labels[0]`` with ``preds[1]``. When the prediction's
    shape differs from the label's, the prediction is treated as per-class
    scores and reduced with ``argmax`` over ``self.axis``. Otherwise it is
    used as-is as predicted class ids.
    """

    def __init__(self):
        self.axis = 1
        super(AccMetric, self).__init__('acc',
                                        axis=self.axis,
                                        output_names=None,
                                        label_names=None)
        # Not read anywhere in this class; kept so external code that touches
        # it keeps working.
        self.losses = []
        # Number of update() calls; this class never resets it.
        self.count = 0

    def update(self, labels, preds):
        """Accumulate the number of correct predictions for one batch.

        Parameters
        ----------
        labels : list of NDArray
            ``labels[0]`` holds the ground truth. If it is 2-D, only its first
            column is used as the class id.
        preds : list of NDArray
            ``preds[1]`` holds either class scores (argmax-reduced over
            ``self.axis``) or class ids shaped like the label.

        Raises
        ------
        ValueError
            If the flattened predictions and labels differ in shape.
        """
        self.count += 1
        label = labels[0]
        pred_label = preds[1]
        if pred_label.shape != label.shape:
            # Shapes differ, so these are scores: take the top-scoring class.
            pred_label = mx.ndarray.argmax(pred_label, axis=self.axis)
        pred_label = pred_label.asnumpy().astype('int32').flatten()
        label = label.asnumpy()
        if label.ndim == 2:
            # Only column 0 is used as the class id.
            label = label[:, 0]
        label = label.astype('int32').flatten()
        # Explicit check rather than ``assert`` so it still runs under
        # ``python -O``. A skipped check would let numpy broadcast a
        # length-1 side and silently corrupt the metric.
        if label.shape != pred_label.shape:
            raise ValueError(
                'Shape of labels {} does not match shape of predictions {}'
                .format(label.shape, pred_label.shape))
        # Both arrays are 1-D here, so no ``.flat`` is needed.
        self.sum_metric += (pred_label == label).sum()
        self.num_inst += len(pred_label)
class LossValueMetric(mx.metric.EvalMetric):
    """Tracks a loss value emitted as the network's last output.

    Each ``update`` call adds the first element of ``preds[-1]`` to the
    running sum and counts it as one instance. MXNet's ``EvalMetric.get()``
    divides ``sum_metric`` by ``num_inst``, so the reported value is the mean
    of these per-call values. Labels are ignored.
    """

    def __init__(self):
        self.axis = 1
        parent_init = super(LossValueMetric, self).__init__
        parent_init(
            'lossvalue',
            axis=self.axis,
            output_names=None,
            label_names=None,
        )
        # Not read anywhere in this class; retained for external callers.
        self.losses = []

    def update(self, labels, preds):
        """Add one loss sample taken from the last prediction output.

        Parameters
        ----------
        labels : list of NDArray
            Unused.
        preds : list of NDArray
            ``preds[-1]`` carries the loss; only its first element is read.
        """
        batch_loss = preds[-1].asnumpy()[0]
        self.sum_metric += batch_loss
        self.num_inst += 1.0