-
Notifications
You must be signed in to change notification settings - Fork 2
/
vgg16_bidirectional_lstm_hi_dim_predict.py
47 lines (33 loc) · 1.8 KB
/
vgg16_bidirectional_lstm_hi_dim_predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import numpy as np
import sys
import os
def main():
    """Score the saved VGG16 bidirectional-LSTM video classifier on UCF-101 clips.

    Loads the model configuration and weights from ``models/UCF-101`` next to
    this script, predicts a label for every video returned by
    ``scan_ucf_with_labels`` (visited in a seeded, shuffled order) and prints
    each prediction followed by the running accuracy so far.
    """
    script_dir = os.path.dirname(__file__)

    # Extend the module search path with the parent directory before the
    # project imports below, so they can resolve when the package sits there.
    sys.path.append(os.path.join(script_dir, '..'))

    from keras_video_classifier.library.recurrent_networks import VGG16BidirectionalLSTMVideoClassifier
    from keras_video_classifier.library.utility.ucf.UCF101_loader import load_ucf, scan_ucf_with_labels

    classifier_cls = VGG16BidirectionalLSTMVideoClassifier
    include_top = False
    dataset_dir = os.path.join(script_dir, 'very_large_data')
    model_dir = os.path.join(script_dir, 'models/UCF-101')
    config_path = classifier_cls.get_config_file_path(model_dir, vgg16_include_top=include_top)
    weight_path = classifier_cls.get_weight_file_path(model_dir, vgg16_include_top=include_top)

    # Seed NumPy's global RNG so the shuffle below is reproducible.
    np.random.seed(42)

    # NOTE(review): presumably fetches/prepares UCF-101 under dataset_dir
    # when it is missing -- confirm against UCF101_loader.
    load_ucf(dataset_dir)

    classifier = classifier_cls()
    classifier.load_model(config_path, weight_path)

    # Only scan videos whose label appears among the keys of classifier.labels.
    known_labels = [name for name, _ in classifier.labels.items()]
    labelled_videos = scan_ucf_with_labels(dataset_dir, known_labels)

    shuffled_paths = np.array(list(labelled_videos.keys()))
    np.random.shuffle(shuffled_paths)

    hits = 0
    for seen, path in enumerate(shuffled_paths, start=1):
        expected = labelled_videos[path]
        predicted = classifier.predict(path)
        print('predicted: ' + predicted + ' actual: ' + expected)
        if expected == predicted:
            hits += 1
        # Running accuracy over the videos evaluated so far.
        accuracy = hits / seen
        print('accuracy: ', accuracy)
# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()