forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlazy_dyndep_test.py
133 lines (100 loc) · 3.82 KB
/
lazy_dyndep_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#!/usr/bin/env python3
from hypothesis import given, settings
import hypothesis.strategies as st
from multiprocessing import Process
import numpy as np
import tempfile
import shutil
import caffe2.python.hypothesis_test_util as hu
import unittest
# Engine name stored under the "engine" key of the rendezvous settings that
# allcompare_process attaches to its model.
op_engine = 'GLOO'
class TemporaryDirectory:
    """Context manager that provides a scratch directory for a ``with`` block.

    Entering the context creates a directory with ``tempfile.mkdtemp`` and
    yields its path; leaving the context removes the whole directory tree
    with ``shutil.rmtree``.
    """

    def __enter__(self):
        """Create the directory, keep its path on ``self.tmpdir``, return it."""
        created = tempfile.mkdtemp()
        self.tmpdir = created
        return created

    def __exit__(self, type, value, traceback):
        """Remove the directory tree created by ``__enter__``.

        The exception arguments are not inspected and nothing is returned,
        so an exception raised inside the ``with`` body keeps propagating.
        """
        shutil.rmtree(self.tmpdir)
def allcompare_process(filestore_dir, process_id, data, num_procs):
    """Run one participant of the cross-process comparison of ``data``.

    Used as the ``multiprocessing.Process`` target in
    ``TestLazyDynDepAllCompare.test_allcompare``.

    Args:
        filestore_dir: Path passed to the ``FileStoreHandlerCreate`` operator.
        process_id: Stored as ``shard_id`` in the rendezvous settings.
        data: Value fed into the workspace as the ``test_data`` blob.
        num_procs: Stored as ``num_shards`` in the rendezvous settings.
    """
    # NOTE(review): caffe2 is imported inside the function, presumably so that
    # each child process performs its own import - confirm before hoisting.
    from caffe2.python import core, data_parallel_model, workspace, lazy_dyndep
    from caffe2.python.model_helper import ModelHelper
    from caffe2.proto import caffe2_pb2

    lazy_dyndep.RegisterOpsLibrary("@/caffe2/caffe2/distributed:file_store_handler_ops")

    # Create the "store_handler" blob that the rendezvous settings refer to.
    create_store_op = core.CreateOperator(
        "FileStoreHandlerCreate", [], ["store_handler"], path=filestore_dir
    )
    workspace.RunOperatorOnce(create_store_op)

    model = ModelHelper()
    model._rendezvous = {
        "kv_handler": "store_handler",
        "shard_id": process_id,
        "num_shards": num_procs,
        "engine": op_engine,
        "exit_nets": None,
    }

    workspace.FeedBlob("test_data", data)
    cpu_device = core.DeviceOption(caffe2_pb2.CPU, 0)
    data_parallel_model._RunComparison(model, "test_data", cpu_device)
class TestLazyDynDepAllCompare(hu.HypothesisTestCase):
    """Runs ``allcompare_process`` in several processes on the same array."""

    @given(
        d=st.integers(1, 5), n=st.integers(2, 11), num_procs=st.integers(1, 8)
    )
    @settings(deadline=None)
    def test_allcompare(self, d, n, num_procs):
        """Spawn ``num_procs`` workers and require every one to succeed.

        Args:
            d: Number of dimensions of the random test array.
            n: Exclusive upper bound for each dimension's size.
            num_procs: Number of worker processes to start.
        """
        # randint's ``high`` is exclusive, so every dimension is in [1, n).
        dims = [np.random.randint(1, high=n) for _ in range(d)]
        test_data = np.random.ranf(size=tuple(dims)).astype(np.float32)

        with TemporaryDirectory() as tempdir:
            processes = []
            for idx in range(num_procs):
                process = Process(
                    target=allcompare_process,
                    args=(tempdir, idx, test_data, num_procs)
                )
                processes.append(process)
                process.start()

            # Join every worker before checking results so that no child is
            # left running while the shared directory is removed.
            for process in processes:
                process.join()

        # An exception in a child only shows up as a non-zero exit code in the
        # parent; without this check a failing worker would go unnoticed.
        failed = {
            idx: process.exitcode
            for idx, process in enumerate(processes)
            if process.exitcode != 0
        }
        self.assertEqual(
            failed, {},
            f"worker processes exited abnormally (index -> exit code): {failed}"
        )
class TestLazyDynDepError(unittest.TestCase):
    """Tests for the error handler installed with ``lazy_dyndep.SetErrorHandler``.

    Each test registers the path of an empty temporary file as an ops library
    and installs a handler that raises ``ValueError("test")``, then expects
    that error from the caffe2 entry point it calls.
    NOTE(review): presumably the empty file fails to load as a library, which
    is what invokes the handler - confirm against lazy_dyndep.
    """

    def test_errorhandler(self):
        """``core.RefreshRegisteredOperators`` raises the handler's error."""
        from caffe2.python import core, lazy_dyndep

        with tempfile.NamedTemporaryFile() as f:
            lazy_dyndep.RegisterOpsLibrary(f.name)

            def handler(e):
                raise ValueError("test")

            lazy_dyndep.SetErrorHandler(handler)
            # assertRaises(..., msg="test") would only set the failure
            # message; assertRaisesRegex checks the exception text itself.
            with self.assertRaisesRegex(ValueError, "test"):
                core.RefreshRegisteredOperators()

    def test_importaftererror(self):
        """A later registration can still be refreshed after a failed one."""
        from caffe2.python import core, lazy_dyndep

        with tempfile.NamedTemporaryFile() as f:
            lazy_dyndep.RegisterOpsLibrary(f.name)

            def handler(e):
                raise ValueError("test")

            lazy_dyndep.SetErrorHandler(handler)
            with self.assertRaises(ValueError):
                core.RefreshRegisteredOperators()

            def reraising_handler(e):
                # A bare ``raise`` re-raises the exception currently being
                # handled. NOTE(review): this assumes lazy_dyndep calls the
                # handler from inside an ``except`` block - confirm.
                raise

            lazy_dyndep.SetErrorHandler(reraising_handler)
            lazy_dyndep.RegisterOpsLibrary("@/caffe2/caffe2/distributed:file_store_handler_ops")
            # Must complete without raising.
            core.RefreshRegisteredOperators()

    def test_workspacecreatenet(self):
        """``workspace.CreateNet`` raises the handler's error."""
        from caffe2.python import workspace, lazy_dyndep

        with tempfile.NamedTemporaryFile() as f:
            lazy_dyndep.RegisterOpsLibrary(f.name)

            def handler(e):
                raise ValueError("test")

            lazy_dyndep.SetErrorHandler(handler)
            # Check the exception text, not just the exception type.
            with self.assertRaisesRegex(ValueError, "test"):
                workspace.CreateNet("fake")
# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()