-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmain_c.py
93 lines (72 loc) · 2.94 KB
/
main_c.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# -*- coding: utf-8 -*-
import keras as K
import keras.layers as L
import numpy as np
import os
import time
import h5py
import argparse
from data_util import *
from models_c import *
from ops import *
from keras.callbacks import ModelCheckpoint
from keras.callbacks import EarlyStopping
from keras.callbacks import TensorBoard
# Output locations. `_weights_m` receives the checkpointed weights of the
# merge model; `_weights` is tagged with 2*r+1 (presumably the patch width --
# TODO confirm) and is not referenced by the visible code in this file;
# `_TFBooard` is the TensorBoard event directory.
_weights_m = "logs/weights/Hou_merge_weights.h5"
_weights = "logs/weights/Hou_weights_{}.h5".format(2 * r + 1)
_TFBooard = 'logs/events/'

# Command-line options, parsed at import time.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--modelname',
    type=str,
    default='logs/weights/models.h5',
    help='final model save name',
)
parser.add_argument(
    '--epochs',
    type=int,
    default=30,
    help='number of epochs',
)
args = parser.parse_args()

# Create the output directories if they are missing. 'logs/' is created by
# the first makedirs call, so a plain mkdir suffices for the event directory.
if not os.path.exists('logs/weights/'):
    os.makedirs('logs/weights/')
if not os.path.exists(_TFBooard):
    os.mkdir(_TFBooard)
def train_merge(model):
    """Regenerate the data files, fit ``model`` on them and save it.

    Calls ``creat_train`` for the training and validation splits, loads the
    resulting ``./file/*.npy`` arrays and one-hot encodes the labels. It then
    fits ``model`` for ``args.epochs`` epochs, keeping the best weights in
    ``_weights_m`` (``save_best_only=True``). It prints the evaluation scores
    and saves the final model to ``args.modelname``.

    Note: the printed "Test" scores are computed on the *validation* split,
    using the weights from the last epoch, not the best checkpoint.

    Args:
        model: a compiled two-input Keras model (presumably the one built by
            ``merge_branch()``). Assumes it was compiled with a single metric
            (accuracy), so ``evaluate`` returns ``[loss, accuracy]`` -- TODO
            confirm against the model definition.
    """
    # Regenerate the training and validation .npy files on disk.
    creat_train(validation=False)
    creat_train(validation=True)

    Xm_train = np.load('./file/train_Xm.npy')
    Y_train = K.utils.np_utils.to_categorical(np.load('./file/train_Y.npy'))
    Xm_val = np.load('./file/val_Xm.npy')
    Y_val = K.utils.np_utils.to_categorical(np.load('./file/val_Y.npy'))

    # The model takes two inputs: the whole patch array and the vector at
    # spatial index (r, r) with a trailing axis added. Assumes Xm is
    # (samples, height, width, bands) with (r, r) the centre pixel -- TODO
    # confirm against creat_train. Build each pair once and reuse it.
    train_inputs = [Xm_train, Xm_train[:, r, r, :, np.newaxis]]
    val_inputs = [Xm_val, Xm_val[:, r, r, :, np.newaxis]]

    model_ckt = ModelCheckpoint(filepath=_weights_m, verbose=1, save_best_only=True)
    # To log to TensorBoard while training, add a TensorBoard callback:
    #   TFBoard = TensorBoard(log_dir=_TFBooard, write_graph=True, write_images=False)
    #   model.fit(train_inputs, Y_train, ..., callbacks=[model_ckt, TFBoard],
    #             validation_data=(val_inputs, Y_val))
    model.fit(train_inputs, Y_train, batch_size=BATCH_SIZE, epochs=args.epochs,
              callbacks=[model_ckt], validation_data=(val_inputs, Y_val))

    # Evaluated on the validation split despite the "Test" labels below.
    scores = model.evaluate(val_inputs, Y_val, batch_size=100)
    print('Test score:', scores[0])
    print('Test accuracy:', scores[1])
    model.save(args.modelname)
def test(network, mode=None):
    """Evaluate a saved network on the test data and print accuracy/Kappa.

    Rebuilds the merge model and loads the weights stored at ``_weights_m``.
    Predicts on the array returned by ``make_cTest()`` and saves the raw
    predictions to ``pred.npy``. Prints the accuracy and Kappa values
    returned by ``cvt_map``.

    Args:
        network: name of the network to evaluate; only ``'merge'`` is
            supported.
        mode: unused; kept so existing callers that pass it keep working.

    Raises:
        ValueError: if ``network`` is not ``'merge'``. Without this check no
            model or prediction would be bound, and the call would fail
            later with an UnboundLocalError.
    """
    if network != 'merge':
        raise ValueError(
            "unsupported network: {!r} (expected 'merge')".format(network))

    model = merge_branch()
    model.load_weights(_weights_m)
    Xm = make_cTest()
    # Same two-input layout used for training: the full array plus the
    # vector at spatial index (r, r) with a trailing axis.
    pred = model.predict([Xm, Xm[:, r, r, :, np.newaxis]])
    np.save('pred.npy', pred)
    acc, kappa = cvt_map(pred, show=False)
    print('acc: {:.2f}% Kappa: {:.4f}'.format(acc, kappa))
def main():
    """Build the merge model, visualise it, train it, then evaluate it.

    Only the evaluation phase (``test('merge')``) is timed.
    """
    net = merge_branch()
    # Presumably renders the model graph to this image file.
    visual_model(net, 'merge_model.png')
    train_merge(net)

    # Test phase: time the evaluation only.
    t0 = time.time()
    test('merge')
    elapsed = time.time() - t0
    print('elapsed time:{:.2f}s'.format(elapsed))


if __name__ == '__main__':
    main()