export_onnx_tests_generator.py (forked from pytorch/pytorch)
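
"""Generate ONNX export test cases from the PyTorch nn module test suite.

For every test case in ``module_tests`` + ``new_module_tests``, this script
instantiates the module, exports it to ONNX with ATen fallback, validates the
result with the ONNX checker, and writes ``model.onnx`` together with
``test_data_set_*/input_*.pb`` and ``output_*.pb`` protobufs under
``onnx_test_common.pytorch_converted_dir``. A summary of fully, semi, and
unsupported operators is printed at the end.
"""
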
import io
import os
import shutil
import traceback

import onnx
import onnx_test_common
import torch
from onnx import numpy_helper
from test_nn import new_module_tests
from torch.autograd import Variable
from torch.testing._internal.common_nn import module_tests


# Take a test case (a dict) as input, return the test name.
def get_test_name(testcase):
    if "fullname" in testcase:
        return "test_" + testcase["fullname"]

    test_name = "test_" + testcase["constructor"].__name__
    if "desc" in testcase:
        test_name += "_" + testcase["desc"]
    return test_name


# Take a test case (a dict) as input, return the input for the module.
def gen_input(testcase):
    if "input_size" in testcase:
        if (
            testcase["input_size"] == ()
            and "desc" in testcase
            and testcase["desc"][-6:] == "scalar"
        ):
            # A 0-dim "scalar" case is promoted to a 1-element tensor for export.
            testcase["input_size"] = (1,)
        return Variable(torch.randn(*testcase["input_size"]))
    elif "input_fn" in testcase:
        input = testcase["input_fn"]()
        if isinstance(input, Variable):
            return input
        return Variable(testcase["input_fn"]())


def gen_module(testcase):
    # Instantiate the module described by the test case and switch it to eval mode.
    if "constructor_args" in testcase:
        args = testcase["constructor_args"]
        module = testcase["constructor"](*args)
        module.train(False)
        return module
    module = testcase["constructor"]()
    module.train(False)
    return module
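

# Illustrative only: a minimal sketch of the kind of test-case dict these
# helpers consume, assuming the usual keys used by common_nn's module_tests
# ("constructor", "constructor_args", "input_size", "desc"). The values below
# are hypothetical and not taken from the real test suite.
def _example_helper_usage():
    example_case = {
        "constructor": torch.nn.Linear,  # hypothetical test case
        "constructor_args": (4, 2),
        "input_size": (3, 4),
        "desc": "example",
    }
    name = get_test_name(example_case)  # -> "test_Linear_example"
    module = gen_module(example_case)   # Linear(4, 2) in eval mode
    inp = gen_input(example_case)       # random tensor of shape (3, 4)
    return name, module, inp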


def print_stats(FunctionalModule_nums, nn_module):
    print(f"{FunctionalModule_nums} functional modules detected.")
    supported = []
    unsupported = []
    not_fully_supported = []
    for key, value in nn_module.items():
        if value == 1:
            supported.append(key)
        elif value == 2:
            unsupported.append(key)
        elif value == 3:
            not_fully_supported.append(key)

    def fun(info, l):
        print(info)
        for v in l:
            print(v)

    # Fully Supported Ops: all related test cases of these ops have been exported.
    # Semi-Supported Ops: part of the related test cases of these ops have been exported.
    # Unsupported Ops: none of the related test cases of these ops have been exported.
    for info, l in [
        [f"{len(supported)} Fully Supported Operators:", supported],
        [
            f"{len(not_fully_supported)} Semi-Supported Operators:",
            not_fully_supported,
        ],
        [f"{len(unsupported)} Unsupported Operators:", unsupported],
    ]:
        fun(info, l)


def convert_tests(testcases, sets=1):
    print(f"Collect {len(testcases)} test cases from PyTorch.")
    failed = 0
    FunctionalModule_nums = 0
    nn_module = {}
    for t in testcases:
        test_name = get_test_name(t)
        module = gen_module(t)
        module_name = str(module).split("(")[0]
        if module_name == "FunctionalModule":
            FunctionalModule_nums += 1
        else:
            if module_name not in nn_module:
                nn_module[module_name] = 0
        try:
            input = gen_input(t)
            f = io.BytesIO()
            torch.onnx._export(
                module,
                input,
                f,
                operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
            )
            onnx_model = onnx.load_from_string(f.getvalue())
            onnx.checker.check_model(onnx_model)
            onnx.helper.strip_doc_string(onnx_model)
            output_dir = os.path.join(onnx_test_common.pytorch_converted_dir, test_name)

            if os.path.exists(output_dir):
                shutil.rmtree(output_dir)
            os.makedirs(output_dir)
            with open(os.path.join(output_dir, "model.onnx"), "wb") as file:
                file.write(onnx_model.SerializeToString())

            # Write one input/output protobuf pair per requested data set.
            for i in range(sets):
                output = module(input)
                data_dir = os.path.join(output_dir, f"test_data_set_{i}")
                os.makedirs(data_dir)

                for index, var in enumerate([input]):
                    tensor = numpy_helper.from_array(var.data.numpy())
                    with open(
                        os.path.join(data_dir, f"input_{index}.pb"), "wb"
                    ) as file:
                        file.write(tensor.SerializeToString())
                for index, var in enumerate([output]):
                    tensor = numpy_helper.from_array(var.data.numpy())
                    with open(
                        os.path.join(data_dir, f"output_{index}.pb"), "wb"
                    ) as file:
                        file.write(tensor.SerializeToString())
                input = gen_input(t)
            if module_name != "FunctionalModule":
                # Bit 0 marks a successful export; combined with bit 1 this yields
                # 1 = fully supported, 2 = unsupported, 3 = semi-supported.
                nn_module[module_name] |= 1
        except:  # noqa: E722,B001
            traceback.print_exc()
            if module_name != "FunctionalModule":
                # Bit 1 marks a failed export for this module type.
                nn_module[module_name] |= 2
            failed += 1

    print(
        "Collect {} test cases from PyTorch repo, failed to export {} cases.".format(
            len(testcases), failed
        )
    )
    print(
        "PyTorch converted cases are stored in {}.".format(
            onnx_test_common.pytorch_converted_dir
        )
    )
    print_stats(FunctionalModule_nums, nn_module)
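

# A minimal sketch (not used by the script itself) of how a generated case
# could be read back for inspection, using only standard ONNX APIs
# (onnx.load, onnx.TensorProto, numpy_helper.to_array). The directory layout
# matches what convert_tests writes above; the helper name is ours.
def _load_converted_case(output_dir, set_index=0):
    model = onnx.load(os.path.join(output_dir, "model.onnx"))
    data_dir = os.path.join(output_dir, f"test_data_set_{set_index}")

    def _read_tensor(path):
        tensor = onnx.TensorProto()
        with open(path, "rb") as f:
            tensor.ParseFromString(f.read())
        return numpy_helper.to_array(tensor)

    inp = _read_tensor(os.path.join(data_dir, "input_0.pb"))
    out = _read_tensor(os.path.join(data_dir, "output_0.pb"))
    return model, inp, out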


if __name__ == "__main__":
    testcases = module_tests + new_module_tests
    convert_tests(testcases)