-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathparse_logs.py
68 lines (55 loc) · 1.65 KB
/
parse_logs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
def parse_txt(path):
    """Read the accuracy value from a metrics text file.

    The accuracy is taken to be the last whitespace-separated token on the
    first line of the file, parsed as a float.

    Args:
        path: path to the metrics ``.txt`` file.

    Returns:
        The accuracy as a float.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if the first line is empty or its last token is not a number.
    """
    # `with` guarantees the handle is closed; only the first line is needed.
    with open(path, 'r') as fh:
        first_line = fh.readline()
    # split() (not split(' ')) tolerates trailing spaces / '\r' / tabs.
    tokens = first_line.split()
    if not tokens:
        raise ValueError('no accuracy value found on first line of %s' % path)
    return float(tokens[-1])
def parse_logs(which, labels=(10, 250, 1000), seeds=(42, 2019), metrics_dir='metrics'):
    """Collect per-run accuracies for one experiment type into a results CSV.

    For every (label count, seed) pair this reads
    ``<metrics_dir>/<which>_lab_cifar10_<label>_seed<seed>.txt`` with
    ``parse_txt`` and stores the accuracy in a table indexed by label count
    (index name ``labels``) with one column per seed. Per-row ``mean`` and
    ``std`` columns are appended. The table is printed and written to
    ``<metrics_dir>/<which>_results.csv``.

    Runs whose metric file is missing or unparsable are stored as NaN, so
    they are skipped by the mean/std rather than counted as an accuracy.

    Args:
        which: experiment prefix used in the file names (e.g. 'sup', 'ssl').
        labels: label counts to collect (table rows).
        seeds: random seeds to collect (table columns).
        metrics_dir: directory holding the metric files and the output CSV.

    Returns:
        The results DataFrame that was written to disk.
    """
    print(which)
    labels = list(labels)
    seeds = list(seeds)
    # Float dtype so missing runs are NaN and the statistics below are numeric.
    df = pd.DataFrame(index=labels, columns=seeds, dtype=float)
    df.index.names = ['labels']
    for lab in labels:
        for seed in seeds:
            path = '%s/%s_lab_cifar10_%d_seed%d.txt' % (metrics_dir, which, lab, seed)
            try:
                acc = parse_txt(path)
            except (OSError, ValueError):
                # Missing or malformed run: leave it out of the statistics.
                acc = np.nan
            # .loc writes into df itself; chained df[seed][lab] may hit a copy.
            df.loc[lab, seed] = acc
    # Take both statistics from the seed columns only, so 'std' does not
    # also include the freshly added 'mean' column.
    per_seed = df[seeds]
    df['mean'] = per_seed.mean(axis=1)
    df['std'] = per_seed.std(axis=1)
    print(df)
    df.to_csv('%s/%s_results.csv' % (metrics_dir, which))
    return df
def plot_graphs(metrics_dir='metrics', out_path='graphs/cifar_ssl_sup_compare.png', show_std=False):
    """Plot mean accuracy vs. number of labelled samples for 'ssl' and 'sup'.

    Reads ``<metrics_dir>/ssl_results.csv`` and ``<metrics_dir>/sup_results.csv``
    (the files written by ``parse_logs``, indexed by a ``labels`` column with
    ``mean`` and ``std`` columns) and saves a line plot to ``out_path``.

    Args:
        metrics_dir: directory containing the ``*_results.csv`` files.
        out_path: destination PNG path; its parent directory is created if needed.
        show_std: if True, draw each curve with its ``std`` column as error bars.
    """
    import os

    def _load(which):
        # Results CSV for one experiment type, indexed by label count.
        return pd.read_csv('%s/%s_results.csv' % (metrics_dir, which),
                           header=0, index_col='labels')

    curves = [
        ('ssl', 'semi_supervised_gan', 'red'),
        ('sup', 'supervised', 'blue'),
    ]
    # Load everything first so a missing CSV fails before any drawing.
    frames = [(_load(which), label, color) for which, label, color in curves]

    # A dedicated figure keeps repeated calls from drawing onto one another.
    fig, ax = plt.subplots()
    try:
        for df, label, color in frames:
            # Each curve uses its own index, so the files need not share label counts.
            x = df.index.values
            if show_std:
                ax.errorbar(x, df['mean'].values, yerr=df['std'].values,
                            fmt='-', label=label, color=color)
            else:
                ax.plot(x, df['mean'].values, label=label, color=color)
        ax.legend()
        ax.grid()
        ax.set_xlabel('number of labelled samples')
        ax.set_ylabel('accuracy')
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(out_path)
    finally:
        # Release the figure; pyplot otherwise keeps it alive indefinitely.
        plt.close(fig)
if __name__ == '__main__':
    # Build the per-experiment result CSVs first; plot_graphs reads the
    # CSVs that parse_logs writes.
    for experiment in ('sup', 'ssl'):
        parse_logs(experiment)
    plot_graphs()