-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.c
110 lines (97 loc) · 2.3 KB
/
main.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
// -----------------------------------------------------------------------------
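// MNIST evaluation driver: runs mnist_model through infercat on the MNIST
// t10k test files and prints per-image outputs plus running accuracy.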
// -----------------------------------------------------------------------------
#include <inttypes.h>
#include <math.h>
#include <stdio.h>
#include "infercat.h"
#include "mnistReader.h"
#include "mnist_model.h"
// ...
// Number of label/image pairs main() evaluates (indices 0 .. data_count-1).
// File-local: these generic names must not leak into the global link namespace.
static const int32_t data_count = 10000;
// Paths to the MNIST label and image files, relative to the current working
// directory the program is launched from. Both pointer and pointee are const.
static const char *const labels = "./dataset/t10k-labels-idx1-ubyte";
static const char *const images = "./dataset/t10k-images-idx3-ubyte";
// ...
// Return the index of the largest non-NaN element of values[0 .. count-1],
// or -1 when count <= 0 or every element is NaN. On ties the first (lowest)
// index wins.
static int32_t output_argmax(const float *values, int32_t count)
{
    int32_t best = -1;
    for(int32_t j = 0; j < count; j++)
    {
        // NaN compares false against everything; skip it explicitly so it
        // can never become the running maximum.
        if(isnan(values[j]))
        {
            continue;
        }
        // Seed from the first real element instead of a finite sentinel, so
        // arbitrarily negative outputs are still considered.
        if(best < 0 || values[j] > values[best])
        {
            best = j;
        }
    }
    return best;
}

// Evaluate mnist_model on the first data_count label/image pairs of the MNIST
// files named by `labels` / `images`. For every image, print its index, the
// ground-truth label, each network output value, and running accuracy /
// error-rate statistics. A prediction is the index of the largest output.
//
// argc/argv are currently unused; the dataset paths are fixed.
// Returns 0 after the loop, or -1 if mnistReader_init() returns non-zero.
int main(int argc, char const *argv[])
{
    (void)argc;
    (void)argv;

    // Bytes per raw image == floats per network input. Presumably 28x28 MNIST
    // pixels; NOTE(review): must match what mnistReader_getByIndex() writes.
    enum { image_pixels = 784 };

    int32_t error_count = 0;
    int32_t total_count = 0;

    if(mnistReader_init(labels, images) != 0)
    {
        printf("\n");
        printf("Dataset open problem ...\n");
        return -1;
    }

    for(int32_t i = 0; i < data_count; i++)
    {
        // Get another label / image pair from the MNIST database.
        // Zero-initialised so nothing indeterminate is read if the reader
        // leaves them untouched. NOTE(review): the reader's return value (if
        // any) is not checked; its signature is not visible from this file.
        uint8_t img_label = 0;
        uint8_t img_raw[image_pixels] = {0};
        mnistReader_getByIndex(i, img_raw, &img_label);

        // Scale raw bytes by 1/256, giving values in [0, 255/256] (not [0, 1]).
        // Kept at 256 so inputs stay exactly as before; NOTE(review): confirm
        // the divisor matches the model's training preprocessing.
        float img[image_pixels];
        for(int32_t j = 0; j < image_pixels; j++)
        {
            img[j] = (float)img_raw[j] / 256.0f;
        }

        // Inference happens here! Outputs are initialised so a call that does
        // not set them cannot leave them indeterminate.
        float *output = NULL;
        int32_t output_size = 0;
        infercat_iterate(
            img,
            (InfercatLayer**)(mnist_model),
            mnist_model_LAYERCOUNT,
            &output, &output_size
        );

        printf("\n");
        printf("IMAGE INDEX\n");
        printf("===========\n");
        printf("%" PRId32 "\n", i);

        printf("\n");
        printf("Label\n");
        printf("-----\n");
        printf("%d\n", img_label); // uint8_t promotes to int

        printf("\n");
        printf("Network Output\n");
        printf("--------------\n");
        for(int32_t j = 0; j < output_size; j++)
        {
            printf("[%" PRId32 "]: %f\n", j, output[j]);
        }
        int32_t output_maxIndex = output_argmax(output, output_size);

        total_count += 1;
        if(output_maxIndex != img_label)
        {
            error_count += 1;
        }
        // total_count >= 1 here, so the division is always defined.
        float error_rate = (float)error_count / (float)total_count;

        printf("\n");
        printf("Statistics\n");
        printf("----------\n");
        printf("Accuracy: %%%.3f (%" PRId32 "/%" PRId32 ")\n",
            (100.0 * (1.0 - error_rate)),
            total_count - error_count,
            total_count
        );
        printf("Error rate: %%%.3f (%" PRId32 "/%" PRId32 ")\n",
            (100.0 * error_rate), error_count, total_count
        );
    }

    printf("\n");
    return 0;
}