# File: test_verification.py (from a fork of pytorch/pytorch).
# Owner(s): ["module: onnx"]
import numpy as np
import torch
from torch.onnx import _experimental, verification
from torch.testing._internal import common_utils
class TestVerification(common_utils.TestCase):
    """Tests for the export-verification helpers in ``torch.onnx.verification``.

    Covers ``check_export_model_diff`` (graph diff between exports produced
    from different input groups) and ``_compare_ort_pytorch_outputs``
    (comparison of one list of numpy outputs against one list of torch
    outputs).
    """

    def setUp(self) -> None:
        """Seed the torch RNGs so the random test inputs are reproducible."""
        super().setUp()
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

    def test_check_export_model_diff_returns_diff_when_constant_mismatch(self):
        """A model that bakes an input in as a constant must produce a diff report."""

        class UnexportableModel(torch.nn.Module):
            def forward(self, x, y):
                # ``y.data`` (an attribute, not a call) will be exported as a
                # constant, leading to wrong model output under different inputs.
                return x + y.data

        # Two input groups with identical shapes but different random values.
        first_group = ((torch.randn(2, 3), torch.randn(2, 3)), {})
        second_group = ((torch.randn(2, 3), torch.randn(2, 3)), {})

        # The report must name a diverging prim::Constant and give both
        # source locations, in this order.
        expected_report = (
            r"Graph diff:(.|\n)*"
            r"First diverging operator:(.|\n)*"
            r"prim::Constant(.|\n)*"
            r"Former source location:(.|\n)*"
            r"Latter source location:"
        )

        diff_report = verification.check_export_model_diff(
            UnexportableModel(), [first_group, second_group]
        )
        self.assertRegex(diff_report, expected_report)

    def test_check_export_model_diff_returns_diff_when_dynamic_controlflow_mismatch(
        self,
    ):
        """A loop whose trip count depends on an input size must produce a diff report."""

        class UnexportableModel(torch.nn.Module):
            def forward(self, x, y):
                for i in range(x.size(0)):
                    y = x[i] + y
                return y

        # The first dimension of ``x`` differs between the groups (2 vs. 4),
        # so the Python loop above runs a different number of times.
        short_group = ((torch.randn(2, 3), torch.randn(2, 3)), {})
        long_group = ((torch.randn(4, 3), torch.randn(2, 3)), {})

        # Axis 0 of ``x`` is declared dynamic for the export.
        options = _experimental.ExportOptions(
            input_names=["x", "y"], dynamic_axes={"x": [0]}
        )

        expected_report = (
            r"Graph diff:(.|\n)*"
            r"First diverging operator:(.|\n)*"
            r"prim::Constant(.|\n)*"
            r"Latter source location:(.|\n)*"
        )

        diff_report = verification.check_export_model_diff(
            UnexportableModel(), [short_group, long_group], options
        )
        self.assertRegex(diff_report, expected_report)

    def test_check_export_model_diff_returns_empty_when_correct_export(self):
        """A plain elementwise add must yield an empty diff report."""

        class SupportedModel(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        first_group = ((torch.randn(2, 3), torch.randn(2, 3)), {})
        second_group = ((torch.randn(2, 3), torch.randn(2, 3)), {})

        diff_report = verification.check_export_model_diff(
            SupportedModel(), [first_group, second_group]
        )
        self.assertEqual(diff_report, "")

    def test_compare_ort_pytorch_outputs_no_raise_with_acceptable_error_percentage(
        self,
    ):
        """With ``acceptable_error_percentage=0.3`` the comparison must not raise.

        The two outputs differ in exactly one of their four elements
        (4.0 vs. 1.0). The test passes if the call returns without raising.
        """
        onnxruntime_outputs = [np.array([[1.0, 2.0], [3.0, 4.0]])]
        torch_outputs = [torch.tensor([[1.0, 2.0], [3.0, 1.0]])]
        comparison_options = dict(
            rtol=1e-5,
            atol=1e-6,
            check_shape=True,
            check_dtype=False,
            ignore_none=True,
            acceptable_error_percentage=0.3,
        )

        verification._compare_ort_pytorch_outputs(
            onnxruntime_outputs, torch_outputs, **comparison_options
        )

    def test_compare_ort_pytorch_outputs_raise_without_acceptable_error_percentage(
        self,
    ):
        """With ``acceptable_error_percentage=None`` the same mismatch must raise.

        Uses the same outputs as the "no raise" test; here an
        ``AssertionError`` is expected.
        """
        onnxruntime_outputs = [np.array([[1.0, 2.0], [3.0, 4.0]])]
        torch_outputs = [torch.tensor([[1.0, 2.0], [3.0, 1.0]])]
        comparison_options = dict(
            rtol=1e-5,
            atol=1e-6,
            check_shape=True,
            check_dtype=False,
            ignore_none=True,
            acceptable_error_percentage=None,
        )

        with self.assertRaises(AssertionError):
            verification._compare_ort_pytorch_outputs(
                onnxruntime_outputs, torch_outputs, **comparison_options
            )