-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhog_svm.py
175 lines (150 loc) · 6.33 KB
/
hog_svm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
# -*- coding=utf-8 -*-
import glob
import platform
import time
from PIL import Image
from skimage.feature import hog
import numpy as np
import os
import joblib
from sklearn.svm import LinearSVC
import shutil
import sys
def get_image_list(filePath, nameList):
    """Open every image named in ``nameList`` inside the directory ``filePath``.

    Args:
        filePath: Directory that holds the image files.
        nameList: Image file names, relative to ``filePath``.

    Returns:
        A list holding one copy of each opened image, in the same order as
        ``nameList``.
    """
    print('read image from ',filePath)
    img_list = []
    for name in nameList:
        # The context manager closes the underlying file even when copy()
        # raises; only the detached copy is kept.
        with Image.open(os.path.join(filePath, name)) as handle:
            img_list.append(handle.copy())
    return img_list
# Extract HOG features and save them to disk.
def get_feat(image_list, name_list,image_height,image_width, label_list, savePath):
    """Compute a HOG descriptor for each image and dump it to ``savePath``.

    Each image is reshaped to ``(image_height, image_width, 3)``, converted to
    grayscale with ``rgb2gray``, scaled by 1/255 and passed to ``hog``. The
    matching entry of ``label_list`` is appended as the last element of the
    descriptor, and the result is written with ``joblib.dump`` to
    ``<savePath>/<name>.feat``.

    Args:
        image_list: Images to process; entry ``k`` belongs to ``name_list[k]``
            and ``label_list[k]``.
        name_list: File name of each image, used to build the output name.
        image_height: Expected image height in pixels.
        image_width: Expected image width in pixels.
        label_list: Label of each image.
        savePath: Directory that receives the ``.feat`` files.

    Images that cannot be reshaped to the expected size are reported and
    skipped.
    """
    # enumerate() keeps the index in step with the image even when an image is
    # skipped; a hand-maintained counter that is not advanced on the skip path
    # would pair every later image with the wrong name and label.
    for index, image in enumerate(image_list):
        name = name_list[index]
        try:
            # For grayscale input change the 3 below to 1.
            image = np.reshape(image, (image_height, image_width, 3))
        except (ValueError, TypeError):
            print('发送了异常,图片大小size不满足要求:', name)
            continue
        gray = rgb2gray(image) / 255.0
        # Tune these HOG parameters to the image size in use.
        fd = hog(gray, orientations=12, block_norm='L1', pixels_per_cell=[8, 8],
                 cells_per_block=[2, 2], visualize=False, transform_sqrt=True)
        # NOTE(review): the label is appended as-is; if labels are strings,
        # NumPy may store the whole vector as text -- confirm that readers
        # convert it back to numbers.
        fd = np.concatenate((fd, [label_list[index]]))
        joblib.dump(fd, os.path.join(savePath, name + '.feat'))
    print("Test features are extracted and saved.")
# Convert a colour image to grayscale.
def rgb2gray(im):
    """Return the weighted sum of the first three channels of ``im``.

    Args:
        im: Array indexed as ``im[row, column, channel]`` with at least three
            channels.

    Returns:
        A 2-D array: channel 0 * 0.2989 + channel 1 * 0.5870 +
        channel 2 * 0.1140.
    """
    red = im[:, :, 0]
    green = im[:, :, 1]
    blue = im[:, :, 2]
    return red * 0.2989 + green * 0.5870 + blue * 0.1140
# Read the image names and their labels from a label file.
def get_name_label(file_path):
    """Parse a label file into parallel lists of image names and labels.

    Each record is ``<name> <label>`` separated by whitespace. Every line is
    echoed to stdout as it is read. Lines shorter than three characters and
    whitespace-only lines are ignored.

    Args:
        file_path: Path of the label file.

    Returns:
        ``(name_list, label_list)``; labels are returned as strings of digits.

    Raises:
        SystemExit: With code 1 when a record has a missing or non-numeric
            label (an explanatory message is printed first).
    """
    print("read label from ",file_path)
    name_list = []
    label_list = []
    with open(file_path) as f:
        for line in f:
            print(line)
            # split() with no argument copes with tabs, repeated spaces and
            # trailing "\r"/"\n", none of which split(' ') handles.
            fields = line.split()
            # A record needs a name, a separator and a label, so anything
            # under three characters (or blank) cannot be one.
            if len(line) < 3 or not fields:
                continue
            name_list.append(fields[0])
            # A missing label becomes '' so it is rejected by the digit check
            # below instead of raising IndexError.
            label_list.append(fields[1] if len(fields) > 1 else '')
            if not label_list[-1].isdigit():
                print("label必须为数字,得到的是:",label_list[-1],"程序终止,请检查文件")
                # sys.exit is always available; the bare exit() helper is only
                # injected by the site module.
                sys.exit(1)
    return name_list, label_list
# Extract features for the training and test sets.
def extra_feat(train_label_path,test_label_path,train_image_path,test_image_path,train_feat_path,test_feat_path,image_height,image_width):
    """Read labels, load images and write feature files for both data sets.

    Steps run in this order: both label files are parsed, then both image
    sets are loaded, then features are written for the training set followed
    by the test set.

    Args:
        train_label_path: Label file of the training set.
        test_label_path: Label file of the test set.
        train_image_path: Directory of the training images.
        test_image_path: Directory of the test images.
        train_feat_path: Output directory for training features.
        test_feat_path: Output directory for test features.
        image_height: Expected image height, forwarded to ``get_feat``.
        image_width: Expected image width, forwarded to ``get_feat``.
    """
    train_names, train_labels = get_name_label(train_label_path)
    test_names, test_labels = get_name_label(test_label_path)

    train_images = get_image_list(train_image_path, train_names)
    test_images = get_image_list(test_image_path, test_names)

    jobs = (
        (train_images, train_names, train_labels, train_feat_path),
        (test_images, test_names, test_labels, test_feat_path),
    )
    for images, names, labels, out_dir in jobs:
        get_feat(images, names, image_height, image_width, labels, out_dir)
# Create the feature directories.
def mkdir(train_feat_path,test_feat_path):
    """Make sure both feature directories exist.

    Missing parent directories are created as well, and directories that
    already exist are left untouched.

    Args:
        train_feat_path: Directory for the training feature files.
        test_feat_path: Directory for the test feature files.

    Raises:
        FileExistsError: If either path exists but is not a directory.
    """
    for folder in (train_feat_path, test_feat_path):
        # exist_ok avoids the check-then-create race and, unlike os.mkdir,
        # makedirs also handles nested paths.
        os.makedirs(folder, exist_ok=True)
# Train the classifier and evaluate it on the test set.
def _split_feature_vector(data):
    """Split a stored vector into ``(float64 features, int label)``."""
    vector = np.asarray(data)
    # The vector may have been stored as text; convert explicitly so the
    # classifier always receives numeric input.
    return vector[:-1].astype(np.float64), int(float(vector[-1]))


def train_and_test(label_map,train_feat_path,test_feat_path,model_path):
    """Fit a ``LinearSVC`` on the training features and score the test set.

    Every ``*.feat`` file holds a feature vector whose last element is the
    label. The fitted model is saved as ``<model_path>/model``, each test
    prediction is written through ``write_to_txt``, and the accuracy and the
    elapsed time are printed.

    Args:
        label_map: Mapping from integer label to the name written in the
            result file.
        train_feat_path: Directory with the training ``.feat`` files.
        test_feat_path: Directory with the test ``.feat`` files.
        model_path: Directory in which the model is stored; created if
            missing.
    """
    t0 = time.time()
    features = []
    labels = []
    for feat_path in glob.glob(os.path.join(train_feat_path, '*.feat')):
        feat, label = _split_feature_vector(joblib.load(feat_path))
        features.append(feat)
        labels.append(label)
    print("Training a Linear LinearSVM Classifier.")
    clf = LinearSVC()
    clf.fit(features, labels)
    # Save the model. os.path.join works whether or not model_path ends with
    # a path separator.
    os.makedirs(model_path, exist_ok=True)
    joblib.dump(clf, os.path.join(model_path, 'model'))
    # To skip training, comment out the fit/dump above and load the model:
    # clf = joblib.load(os.path.join(model_path, 'model'))
    print("训练之后的模型存放在model文件夹中")
    result_list = []
    correct_number = 0
    total = 0
    suffix = '.feat'
    for feat_path in glob.glob(os.path.join(test_feat_path, '*' + suffix)):
        total += 1
        # basename is correct for any directory depth and on every platform;
        # slicing removes only the trailing ".feat".
        image_name = os.path.basename(feat_path)[:-len(suffix)]
        feat, expected = _split_feature_vector(joblib.load(feat_path))
        predicted = int(clf.predict(feat.reshape((1, -1)))[0])
        result_list.append(image_name + ' ' + label_map[predicted] + '\n')
        if predicted == expected:
            correct_number += 1
    write_to_txt(result_list)
    # Report 0 instead of dividing by zero when there are no test features.
    rate = float(correct_number) / total if total else 0.0
    t1 = time.time()
    print('accuracy : %f' % rate)
    print('spend time : %f' % (t1 - t0))
def write_to_txt(list):
    """Write the prediction lines to ``result.txt`` in the working directory.

    The file is overwritten. It starts with a header line, followed by the
    given lines written verbatim.

    Args:
        list: Iterable of already newline-terminated strings. The name
            shadows the builtin but is kept so existing callers keep working.
    """
    # The header is non-ASCII, so pin the encoding instead of relying on the
    # platform default, which cannot represent it everywhere.
    with open('result.txt', 'w', encoding='utf-8') as f:
        f.write('圖片'+' '+'預測結果'+'\n')
        f.writelines(list)
def main():
    """Run the pipeline: optional feature extraction, then train and test.

    The feature directories are created if missing. The user is then asked
    whether to re-extract features; answering ``y`` wipes both feature
    directories and rebuilds them from the image folders. Training and
    evaluation always run afterwards.
    """
    # Integer class label -> name written to the result file.
    label_map = {1:'elaina', 2:'mashiro', 3:'sakura_kyoku'}

    # Image folders; each one also contains its label file.
    train_image_path, test_image_path = 'train_anime_img', 'test_anime_img'
    train_label_path = os.path.join(train_image_path, 'train.txt')
    # NOTE(review): the test labels are also read from a file named
    # train.txt -- confirm this is intended.
    test_label_path = os.path.join(test_image_path, 'train.txt')

    image_height, image_width = 128, 100
    train_feat_path, test_feat_path, model_path = 'train/', 'test/', 'model/'

    # Create the feature folders when they do not exist yet.
    mkdir(train_feat_path, test_feat_path)

    # Disabled interactive configuration, kept for reference:
    # need_input = input('Enter the settings manually? y/n\n')
    # if need_input == 'y':
    #     train_image_path = input('Training image folder, e.g. /home/icelee/image\n')
    #     test_image_path = input('Test image folder, e.g. /home/icelee/image\n')
    #     train_label_path = input('Training label file, e.g. /home/icelee/train.txt\n')
    #     test_label_path = input('Test label file, e.g. /home/icelee/test.txt\n')
    #     size = int(input('Image size, e.g. enter 64 for 64x64\n'))

    if input('是否重新萃取特真?y/n\n') == 'y':
        # Start from empty folders so stale feature files cannot leak in.
        for stale_dir in (train_feat_path, test_feat_path):
            shutil.rmtree(stale_dir)
        mkdir(train_feat_path, test_feat_path)
        # Extract the features and store them in the feature folders.
        extra_feat(
            train_label_path,
            test_label_path,
            train_image_path,
            test_image_path,
            train_feat_path,
            test_feat_path,
            image_height,
            image_width,
        )

    # Train and predict.
    train_and_test(label_map, train_feat_path, test_feat_path, model_path)
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()