svi_horovod.py
import argparse

import torch
import torch.multiprocessing as mp

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.nn import PyroModule
from pyro.optim import Adam, HorovodOptimizer


# We define a model as usual, with no reference to Horovod.
# This model is data parallel and supports subsampling.
class Model(PyroModule):
    def __init__(self, size):
        super().__init__()
        self.size = size

    def forward(self, covariates, data=None):
        coeff = pyro.sample("coeff", dist.Normal(0, 1))
        bias = pyro.sample("bias", dist.Normal(0, 1))
        scale = pyro.sample("scale", dist.LogNormal(0, 1))

        # Since we'll use a distributed dataloader during training, we need to
        # manually pass minibatches of (covariates, data) that are smaller than
        # the full self.size. In particular we cannot rely on pyro.plate to
        # automatically subsample, since that would lead to all workers drawing
        # identical subsamples.
        with pyro.plate("data", self.size, len(covariates)):
            loc = bias + coeff * covariates
            return pyro.sample("obs", dist.Normal(loc, scale),
                               obs=data)

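# For contrast, outside of a distributed setting one could let pyro.plate draw
# the subsample itself. A sketch of that alternative (not used by this script,
# since under Horovod every worker would then draw identical subsamples;
# batch_size here is a hypothetical subsample size):
#
#     with pyro.plate("data", self.size, subsample_size=batch_size) as idx:
#         loc = bias + coeff * covariates[idx]
#         return pyro.sample("obs", dist.Normal(loc, scale),
#                            obs=None if data is None else data[idx])
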

# The following is a standard training loop. To emphasize the Horovod-specific
# parts, we've guarded them by `if args.horovod:`.
def main(args):
    # Create a model, synthetic data, and a guide.
    pyro.set_rng_seed(args.seed)
    model = Model(args.size)
    covariates = torch.randn(args.size)
    data = model(covariates)
    guide = AutoNormal(model)

    if args.horovod:
        # Initialize Horovod and set PyTorch globals.
        import horovod.torch as hvd
        hvd.init()
        torch.set_num_threads(1)
        if args.cuda:
            torch.cuda.set_device(hvd.local_rank())
    if args.cuda:
        torch.set_default_tensor_type("torch.cuda.FloatTensor")
    device = torch.tensor(0).device

    if args.horovod:
        # Initialize parameters and broadcast to all workers.
        guide(covariates[:1], data[:1])  # Initializes model and guide.
        hvd.broadcast_parameters(guide.state_dict(), root_rank=0)
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
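        # The guide call above triggers Pyro's lazy parameter initialization,
        # so there are concrete tensors to broadcast; copying rank 0's values
        # to every worker ensures all processes start from identical
        # parameters.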

    # Create an ELBO loss and a Pyro optimizer.
    elbo = Trace_ELBO()
    optim = Adam({"lr": args.learning_rate})

    if args.horovod:
        # Wrap the basic optimizer in a distributed optimizer.
        optim = HorovodOptimizer(optim)
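        # Roughly speaking, the wrapped optimizer averages gradients across
        # workers on each step, so all processes apply the same update and
        # their parameters stay synchronized.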

    # Create a dataloader.
    dataset = torch.utils.data.TensorDataset(covariates, data)
    import socket
    print("Running on", socket.gethostname())
    if args.horovod:
        # Horovod requires a distributed sampler.
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset, hvd.size(), hvd.rank())
    else:
        sampler = torch.utils.data.RandomSampler(dataset)
    config = {"batch_size": args.batch_size, "sampler": sampler}
    if args.cuda:
        config["num_workers"] = 1
        config["pin_memory"] = True
        # Try to use forkserver to spawn workers instead of fork.
        if (hasattr(mp, "_supports_context") and mp._supports_context and
                "forkserver" in mp.get_all_start_methods()):
            config["multiprocessing_context"] = "forkserver"
    dataloader = torch.utils.data.DataLoader(dataset, **config)
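    # Note that the DistributedSampler partitions the dataset across workers,
    # so with W workers each process sees roughly args.size / W samples per
    # epoch and the effective global batch size is W * args.batch_size
    # (e.g. 4 workers with the default batch size of 16 give 64 samples per
    # SVI step).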

    # Run stochastic variational inference.
    svi = SVI(model, guide, optim, elbo)
    for epoch in range(args.num_epochs):
        if args.horovod:
            # Set rng seeds on distributed samplers. This is required.
            sampler.set_epoch(epoch)

        for step, (covariates_batch, data_batch) in enumerate(dataloader):
            loss = svi.step(covariates_batch.to(device), data_batch.to(device))

            if args.horovod:
                # Optionally average loss metric across workers.
                # You can do this with arbitrary torch.Tensors.
                loss = torch.tensor(loss)
                loss = hvd.allreduce(loss, "loss")
                loss = loss.item()

                # Print only on the rank=0 worker.
                if step % 100 == 0 and hvd.rank() == 0:
                    print("epoch {} step {} loss = {:0.4g}".format(epoch, step, loss))
            else:
                if step % 100 == 0:
                    print("epoch {} step {} loss = {:0.4g}".format(epoch, step, loss))

    if args.horovod:
        # After we're done with the distributed parts of the program,
        # we can shutdown all but the rank=0 worker.
        hvd.shutdown()
        if hvd.rank() != 0:
            return

    if args.outfile:
        print("saving to {}".format(args.outfile))
        torch.save({"model": model, "guide": guide}, args.outfile)

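# A checkpoint saved by the rank-0 worker above can later be reloaded for
# inspection or prediction, roughly as follows (an illustrative sketch, not
# executed by this script):
#
#     checkpoint = torch.load(args.outfile)
#     model, guide = checkpoint["model"], checkpoint["guide"]
#     with torch.no_grad():
#         medians = guide.median()  # AutoNormal exposes posterior medians
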

if __name__ == "__main__":
    assert pyro.__version__.startswith('1.6.0')
    parser = argparse.ArgumentParser(description="Distributed training via Horovod")
    parser.add_argument("-o", "--outfile")
    parser.add_argument("-s", "--size", default=1000, type=int)
    parser.add_argument("-b", "--batch-size", default=16, type=int)
    parser.add_argument("-n", "--num-epochs", default=10, type=int)
    parser.add_argument("-lr", "--learning-rate", default=0.01, type=float)
    parser.add_argument("--cuda", action="store_true")
    parser.add_argument("--horovod", action="store_true", default=True)
    parser.add_argument("--no-horovod", action="store_false", dest="horovod")
    parser.add_argument("--seed", default=20200723, type=int)
    args = parser.parse_args()
    main(args)
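
# This script is typically launched with one process per worker, for example
# via Horovod's launcher (an illustrative invocation; adjust the process count
# and flags to your cluster):
#
#     horovodrun -np 2 python svi_horovod.py --outfile model.pt
#
# Passing --no-horovod runs the same training loop in a single process.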