export.py
import argparse
import os
import torch
cur_path = os.path.abspath(os.path.dirname(__file__))
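# export.py: export a trained PyTorch model (saved as a full model object) to
# TorchScript / ONNX / TensorRT / OpenVINO / MNN, then compare each backend's
# output against the original PyTorch output.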
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Export model")
# torch
parser.add_argument("--img_size", type=str, default="1,3,224,224", help="推理尺寸")
parser.add_argument("--weights", help="模型权重", required=True)
# torchscript
parser.add_argument("--torch2script", action="store_true", help="(可选)转为torchscript")
parser.add_argument("--script_gpu", action="store_true", help="(可选)导出GPU模型,默认CPU模型")
# onnx
parser.add_argument("--torch2onnx", action="store_true", help="(可选)转为onnx")
parser.add_argument("--dynamic", action="store_true", help="(可选)batch轴设为动态")
# tensorrt
parser.add_argument("--onnx2trt", action="store_true", help="(可选)转为tensorrt")
parser.add_argument("--trt_fp16", action="store_true", help="(可选)保存为fp16模型")
# openvino
parser.add_argument("--onnx2openvino", action="store_true", help="(可选)转为openvino")
# mnn
parser.add_argument("--onnx2mnn", action="store_true", help="(可选)转为mnn")
parser.add_argument("--mnn_fp16", action="store_true", help="(可选)保存为fp16模型")
cfg = parser.parse_args()
# ==========================torch===============================
    cfg.img_size = [int(dim) for dim in cfg.img_size.split(",")]
    imgs = torch.ones(tuple(cfg.img_size))
    model = torch.load(cfg.weights, map_location="cpu")  # load the full model object, not model.state_dict
    model.eval()
    output_torch = model(imgs).detach().numpy()
    print("output shape is", output_torch.shape)
    # ==========================Export TorchScript===============================
if cfg.torch2script:
from Models.Backend.torchscript import ScriptBackend
if cfg.script_gpu and torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
script_weights = cfg.weights.split(".")[0] + ".torchscript"
ScriptBackend.convert(
model=model.to(device),
imgs=imgs.to(device),
weights=script_weights,
)
output_script = ScriptBackend.infer(
weights=script_weights, imgs=imgs.to(device)
)
output_script = output_script.detach().cpu().numpy()
    # ==========================Export ONNX===============================
    if cfg.torch2onnx:
        from Models.Backend.onnx import OnnxBackend
        onnx_weights = os.path.splitext(cfg.weights)[0] + ".onnx"
        # torch -> ONNX
OnnxBackend.convert(
model=model,
imgs=imgs,
weights=onnx_weights,
dynamic=cfg.dynamic
)
output_onnx = OnnxBackend.infer(weights=onnx_weights, imgs=imgs.numpy())
    # ==========================Export TensorRT===============================
    if cfg.onnx2trt:
        assert cfg.torch2onnx, "Error: --onnx2trt requires --torch2onnx in the same run"
        assert not cfg.dynamic, "Error: only fixed shapes are supported"
        assert os.path.exists(onnx_weights), "Error: %s does not exist" % onnx_weights
        from Models.Backend.tensorrt import TensorrtBackend
        trt_weights = os.path.splitext(onnx_weights)[0] + ".trt"
        # ONNX -> TensorRT
TensorrtBackend.convert(
onnx_weights=onnx_weights,
trt_weights=trt_weights,
fp16=cfg.trt_fp16,
)
output_trt = TensorrtBackend.infer(
weights=trt_weights, imgs=imgs.numpy(), output_shape=output_onnx.shape
)
    # ==========================Export OpenVINO===============================
    if cfg.onnx2openvino:
        assert cfg.torch2onnx, "Error: --onnx2openvino requires --torch2onnx in the same run"
        assert not cfg.dynamic, "Error: only fixed shapes are supported"
        assert os.path.exists(onnx_weights), "Error: %s does not exist" % onnx_weights
        openvino_weights = os.path.splitext(onnx_weights)[0] + "_openvino"
        from Models.Backend.openvino import OpennVINOBackend
        # ONNX -> OpenVINO
output_openvino = OpennVINOBackend.infer(
weights=openvino_weights, imgs=imgs.numpy()
)
    # ==========================Export MNN===============================
    if cfg.onnx2mnn:
        assert cfg.torch2onnx, "Error: --onnx2mnn requires --torch2onnx in the same run"
        assert os.path.exists(onnx_weights), "Error: %s does not exist" % onnx_weights
        mnn_weights = os.path.splitext(onnx_weights)[0] + ".mnn"
from Models.Backend.mnn import MNNBackbend
MNNBackbend.convert(onnx_weights, mnn_weights, fp16=cfg.mnn_fp16)
output_mnn = MNNBackbend.infer(
mnn_weights, imgs.numpy(), output_shape=output_onnx.shape
)
    # ==========================Verify results===============================
print("\n", "*" * 28)
    if cfg.torch2script:
        print("max |output_torch - output_script| =", abs(output_torch - output_script).max())
    if cfg.torch2onnx:
        print("max |output_torch - output_onnx| =", abs(output_torch - output_onnx).max())
    if cfg.onnx2trt:
        print("max |output_torch - output_trt| =", abs(output_torch - output_trt).max())
    if cfg.onnx2openvino:
        print(
            "max |output_torch - output_openvino| =",
            abs(output_torch - output_openvino).max(),
        )
    if cfg.onnx2mnn:
        print("max |output_torch - output_mnn| =", abs(output_torch - output_mnn).max())