-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy path vgg16_lstm_train.py
34 lines (23 loc) · 1.26 KB
/
vgg16_lstm_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import numpy as np
from keras import backend as K
import sys
import os
def main():
    """Train a VGG16+LSTM video classifier on UCF-101 and plot its training history.

    Steps:
      1. Configure Keras for channels-last ('tf') image tensors.
      2. Put the directory above this script on ``sys.path`` so the
         ``keras_video_classifier`` package can be imported.
      3. Fetch the UCF-101 videos into ``very_large_data`` via ``load_ucf``.
      4. Fit ``VGG16LSTMVideoClassifier``, passing ``models/UCF-101`` as its
         model directory.
      5. Save the returned history plot to
         ``reports/UCF-101/<model_name>-history.png``.

    All paths are resolved relative to this script's own directory, not the
    current working directory.
    """
    # 'tf' dim ordering is the same as 'channels_last'. set_image_dim_ordering
    # is the legacy Keras 1 setter that newer Keras releases no longer provide,
    # so prefer the replacement API and fall back only when it is missing.
    if hasattr(K, 'set_image_data_format'):
        K.set_image_data_format('channels_last')
    else:
        K.set_image_dim_ordering('tf')

    # os.path.dirname(__file__) may be '' or relative; make it absolute so
    # every derived path below is independent of the working directory.
    script_dir = os.path.dirname(os.path.abspath(__file__))

    # The project package lives one level above this script and must be
    # importable before the project imports that follow.
    sys.path.append(os.path.join(script_dir, '..'))
    from keras_video_classifier.library.utility.plot_utils import plot_and_save_history
    from keras_video_classifier.library.recurrent_networks import VGG16LSTMVideoClassifier
    from keras_video_classifier.library.utility.ucf.UCF101_loader import load_ucf

    data_set_name = 'UCF-101'
    input_dir_path = os.path.join(script_dir, 'very_large_data')
    output_dir_path = os.path.join(script_dir, 'models', data_set_name)
    report_dir_path = os.path.join(script_dir, 'reports', data_set_name)

    # Fixed seed so NumPy-driven randomness is reproducible across runs.
    np.random.seed(42)

    # this line downloads the video files of UCF-101 dataset if they are not available in the very_large_data folder
    load_ucf(input_dir_path)

    classifier = VGG16LSTMVideoClassifier()
    history = classifier.fit(data_dir_path=input_dir_path,
                             model_dir_path=output_dir_path,
                             data_set_name=data_set_name)

    # Nothing visible here creates the report directory. Make sure it exists
    # before the plot is written into it. NOTE(review): plot_and_save_history
    # may also create it; this check is harmless either way.
    if not os.path.isdir(report_dir_path):
        os.makedirs(report_dir_path)

    model_name = VGG16LSTMVideoClassifier.model_name
    plot_and_save_history(history, model_name,
                          os.path.join(report_dir_path, model_name + '-history.png'))
if __name__ == '__main__':
    # Start training only when executed as a script, never on import.
    main()