-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathcnn_inference.py
executable file
·95 lines (75 loc) · 3.3 KB
/
cnn_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import onnx
import numpy as np
import onnxruntime as rt
import cv2
import argparse
class CNNInference:
    '''
    Single-image classifier backed by an ONNX model run with onnxruntime.

    model_path   : path to the ONNX model file
    img_size     : side length (pixels) the input image is resized to
    classes_path : text file with one class name per line; line i is used as
                   the label for output index i
    '''

    def __init__(self, model_path, img_size, classes_path) -> None:
        self.model_path = model_path
        self.img_size = img_size
        self.classes_path = classes_path
        model = onnx.load(model_path)
        self.session = rt.InferenceSession(model.SerializeToString())

    def get_image(self, path, show=False):
        '''
        Read the image from disk and optionally display it.

        path : input image path
        show : if True, display the image and block until a key is pressed

        Raises FileNotFoundError if OpenCV cannot read the file.
        '''
        img = cv2.imread(path)
        # cv2.imread reports failure by returning None instead of raising;
        # fail here rather than with an obscure error later in preprocess().
        if img is None:
            raise FileNotFoundError('Could not read image: %s' % path)
        if show:
            cv2.imshow("Frame", img)
            cv2.waitKey(0)
            cv2.destroyWindow("Frame")
        return img

    def preprocess(self, img, use_transform=False, crop_size=224):
        '''
        Convert an image into a 1 x C x H x W float32 batch for inference.

        img           : H x W x C image array, as returned by get_image
        use_transform : if True, normalise with the ImageNet mean/std;
                        otherwise center-crop to crop_size instead
        crop_size     : side of the center crop used when use_transform is
                        False; clamped to the resized image size
        '''
        img = img / 255.
        img = cv2.resize(img, (self.img_size, self.img_size))
        if use_transform:
            # NOTE(review): these are the ImageNet statistics in RGB order, but
            # cv2.imread yields BGR -- confirm the channel order used in training.
            img = (img - [0.485, 0.456, 0.406]) / [0.229, 0.224, 0.225]
        else:
            h, w = img.shape[0], img.shape[1]
            # Clamp so an image smaller than the crop is kept whole instead of
            # being sliced with negative offsets.
            crop_h, crop_w = min(crop_size, h), min(crop_size, w)
            y0 = (h - crop_h) // 2
            x0 = (w - crop_w) // 2
            img = img[y0:y0 + crop_h, x0:x0 + crop_w, :]
        # HWC -> CHW, then add the batch dimension.
        img = np.transpose(img, axes=[2, 0, 1])
        img = img.astype(np.float32)
        img = np.expand_dims(img, axis=0)
        return img

    def get_classes(self):
        '''
        Read the class file and return the class names as a list,
        one entry per line with surrounding whitespace stripped.
        '''
        with open(self.classes_path, encoding='utf-8') as f:
            return [line.strip() for line in f]

    def predict(self, path,
                show_img=False,
                use_transform=False):
        '''
        Classify one image, print the top-1 class name and return it.

        path          : input image path
        show_img      : display the image before running inference
        use_transform : forwarded to preprocess()

        Returns the predicted class name.
        '''
        img = self.get_image(path, show=show_img)
        img = self.preprocess(img, use_transform=use_transform)
        inputs = {self.session.get_inputs()[0].name: img}
        preds = self.session.run(None, inputs)[0]
        preds = np.squeeze(preds)
        # Only the best class is needed, so argmax replaces a full argsort.
        top = int(np.argmax(preds))
        label = self.get_classes()[top]
        print('Predicted Class : %s' % label)
        return label
def str2bool(value):
    '''
    argparse ``type=`` converter for boolean command-line options.

    ``type=bool`` must not be used with argparse: bool('False') is True, so any
    non-empty value would switch the option on. Accepts true/false, t/f,
    yes/no, y/n and 1/0 (case-insensitive); real bools are passed through.
    '''
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got %r' % value)


if __name__ == '__main__':
    # Example:
    # python cnn_inference.py --model_path=models/cats_vs_dogs/cats_vs_dogs_resnet18_exp_1.onnx --class_path=models/cats_vs_dogs/classes.txt --img_path=test1.jpg --image_size=224 --use_transform=True
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, default='models/cats_vs_dogs/cats_vs_dogs_resnet18_exp_1.onnx', help='ONNX model path')
    parser.add_argument('--class_path', type=str, default='models/cats_vs_dogs/classes.txt', help='Class file path which contain class names')
    parser.add_argument('--img_path', type=str, default='test1.jpg', help='Input Image path')
    parser.add_argument('--image_size', type=int, default=224, help='Input Image size (Used for the training)')
    parser.add_argument('--show_image', type=str2bool, default=True, help='Display the image')
    parser.add_argument('--use_transform', type=str2bool, default=True, help='Use image transforms in pre-processing step')
    args = parser.parse_args()
    cnn_infer = CNNInference(args.model_path, args.image_size, args.class_path)
    cnn_infer.predict(args.img_path, show_img=args.show_image, use_transform=args.use_transform)