-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathonnx_inference.py
87 lines (69 loc) · 2.69 KB
/
onnx_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.
import argparse
import os
import cv2
import time
import numpy as np
import sys
sys.path.append('../../')
import onnxruntime
from utils.misc import PreProcessor, PostProcessor
from utils.vis_tools import visualize
def make_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the ONNX Runtime inference demo.

    Options:
        -m / --model       path to the exported ONNX model file.
        -i / --image_path  path to the input image.
        -o / --output_dir  directory the visualized result is written to.
        -s / --score_thr   score threshold used for post-processing and drawing.
        -size / --img_size square input size passed to the preprocessor.

    Returns:
        An ``argparse.ArgumentParser``; call ``.parse_args()`` on it.
    """
    parser = argparse.ArgumentParser("onnxruntime inference sample")
    parser.add_argument("-m", "--model", type=str, default="../../weights/onnx/11/yolov1.onnx",
                        help="Input your onnx model.")
    parser.add_argument("-i", "--image_path", type=str, default='../test_image.jpg',
                        help="Path to your input image.")
    parser.add_argument("-o", "--output_dir", type=str, default='../../det_results/onnx/',
                        help="Path to your output directory.")
    parser.add_argument("-s", "--score_thr", type=float, default=0.35,
                        help="Score threshold to filter the result.")
    parser.add_argument("-size", "--img_size", type=int, default=640,
                        help="Specify an input shape for inference.")
    return parser
if __name__ == '__main__':
    args = make_parser().parse_args()

    # Class colors for visualization; the fixed seed keeps colors identical
    # across runs. 80 entries to match num_classes below.
    np.random.seed(0)
    class_colors = [(np.random.randint(255),
                     np.random.randint(255),
                     np.random.randint(255)) for _ in range(80)]

    # Pre-/post-processors. num_classes=80 is hard-coded (presumably COCO —
    # NOTE(review): confirm it matches the exported model's head).
    preprocess = PreProcessor(img_size=args.img_size)
    postprocess = PostProcessor(num_classes=80, conf_thresh=args.score_thr, nms_thresh=0.5)

    # Read the image. cv2.imread returns None instead of raising when the
    # file is missing or unreadable, so fail early with a clear message.
    origin_img = cv2.imread(args.image_path)
    if origin_img is None:
        sys.exit("Could not read image: {}".format(args.image_path))

    # preprocess; `ratio` is used below to map boxes back to the original image
    x, ratio = preprocess(origin_img)

    # Build the session (model load + graph optimization) BEFORE starting
    # the timer so the printed figure measures inference only.
    session = onnxruntime.InferenceSession(args.model)
    # add a leading batch dimension of 1 for the model's first input
    ort_inputs = {session.get_inputs()[0].name: x[None, :, :, :]}

    # inference
    t0 = time.time()
    output = session.run(None, ort_inputs)
    print("inference time: {:.1f} ms".format((time.time() - t0)*1000))

    # post process; rescale boxes from network-input to original-image coordinates
    t0 = time.time()
    bboxes, scores, labels = postprocess(output[0])
    bboxes /= ratio
    print("post-process time: {:.1f} ms".format((time.time() - t0)*1000))

    # visualize detection
    origin_img = visualize(
        img=origin_img,
        bboxes=bboxes,
        scores=scores,
        labels=labels,
        vis_thresh=args.score_thr,
        class_colors=class_colors
    )

    # Save first: cv2.imshow blocks (and raises on headless machines), which
    # would otherwise prevent the result from ever being written.
    os.makedirs(args.output_dir, exist_ok=True)
    output_path = os.path.join(args.output_dir, os.path.basename(args.image_path))
    cv2.imwrite(output_path, origin_img)

    # show until a key is pressed, then close the window
    cv2.imshow('onnx detection', origin_img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()