Delayed param #534

Open · wants to merge 5 commits into base: master
18 changes: 7 additions & 11 deletions examples/minipyro.py
@@ -11,11 +11,11 @@

import torch
from pyroapi import distributions as dist
from pyroapi import infer, optim, pyro, pyro_backend
from pyroapi import infer, pyro, pyro_backend
from torch.distributions import constraints

import funsor
from funsor.montecarlo import MonteCarlo
from funsor.adam import Adam


def main(args):
@@ -43,20 +43,15 @@ def guide(data):

# Because the API in minipyro matches that of Pyro proper,
# training code works with generic Pyro implementations.
with pyro_backend(args.backend), MonteCarlo():
with pyro_backend(args.backend):
# Construct an SVI object so we can do variational inference on our
# model/guide pair.
Elbo = infer.JitTrace_ELBO if args.jit else infer.Trace_ELBO
elbo = Elbo()
adam = optim.Adam({"lr": args.learning_rate})
adam = Adam(args.num_steps, lr=args.learning_rate, log_every=args.log_every)
svi = infer.SVI(model, guide, adam, elbo)

# Basic training loop
pyro.get_param_store().clear()
for step in range(args.num_steps):
loss = svi.step(data)
if args.verbose and step % 100 == 0:
print(f"step {step} loss = {loss}")
svi.run(data)

# Report the final values of the variational parameters
# in the guide after training.
@@ -75,7 +70,8 @@ def guide(data):
parser = argparse.ArgumentParser(description="Minipyro demo")
parser.add_argument("-b", "--backend", default="funsor")
parser.add_argument("-n", "--num-steps", default=1001, type=int)
parser.add_argument("-lr", "--learning-rate", default=0.02, type=float)
parser.add_argument("-lr", "--learning-rate", default=0.01, type=float)
parser.add_argument("--log-every", type=int, default=20)
parser.add_argument("--jit", action="store_true")
parser.add_argument("-v", "--verbose", action="store_true")
args = parser.parse_args()
8 changes: 2 additions & 6 deletions examples/talbot.py
@@ -132,15 +132,11 @@ def main(args):
"rate": Tensor(torch.tensor(5.0, requires_grad=True)),
"nsteps": Tensor(torch.tensor(2.0, requires_grad=True)),
}
adam = Adam(args.num_steps, lr=args.learning_rate, log_every=args.log_every)

with Talbot(num_steps=args.talbot_num_steps):
# Fit the data
with Adam(
args.num_steps,
lr=args.learning_rate,
log_every=args.log_every,
params=init_params,
) as optim:
with adam.with_init(init_params) as optim:
loss.reduce(ops.min, {"rate", "nsteps"})
# Fitted curve.
fitted_curve = pred(rate=optim.param("rate"), nsteps=optim.param("nsteps"))
6 changes: 5 additions & 1 deletion funsor/adam.py
@@ -26,7 +26,11 @@ def __init__(self, num_steps, **kwargs):
self.num_steps = num_steps
self.log_every = kwargs.pop("log_every", 0)
self.optim_params = kwargs # TODO make precise
self.params = kwargs.pop("params", {})
self.params = {}

def with_init(self, init_params):
self.params = init_params
return self

def param(self, name, domain=None):
if name not in self.params:
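For context, a minimal self-contained sketch of how the new `with_init` hook is meant to be used, mirroring the `examples/talbot.py` change above; the quadratic toy loss and the variable name `x` are illustrative assumptions rather than code from this PR:

```python
import torch

import funsor
import funsor.ops as ops
from funsor.adam import Adam
from funsor.tensor import Tensor

funsor.set_backend("torch")

# A toy loss with a single free real variable "x"; its minimum is at x = 3.
x = funsor.Variable("x", funsor.Real)
loss = (x - 3.0) * (x - 3.0)

# Initial parameter values are now supplied through with_init() rather than
# through the Adam constructor.
adam = Adam(num_steps=200, lr=0.1, log_every=50)
init_params = {"x": Tensor(torch.tensor(0.0, requires_grad=True))}

with adam.with_init(init_params) as optim:
    # Reducing with ops.min under the Adam interpretation runs the optimizer.
    loss.reduce(ops.min, {"x"})

fitted_x = optim.param("x")  # a funsor Tensor holding the fitted value, near 3
```

Keeping `with_init` separate from the constructor lets a single `Adam` instance be configured once and then seeded with whatever parameters a particular fit needs, which is what `SVI.run` below relies on.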
126 changes: 66 additions & 60 deletions funsor/minipyro.py
@@ -21,9 +21,9 @@

import torch
from pyro.distributions import validation_enabled
from pyro.optim.clipped_adam import ClippedAdam as _ClippedAdam

import funsor
from funsor.adam import Adam # noqa: F401
@ordabayevy (Member, Author) commented on Apr 24, 2021:
for compatibility with pyroapi
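A hedged sketch of what this compatibility means in practice: under the `funsor` backend, pyroapi's generic `optim` module resolves attributes on `funsor.minipyro`, so `optim.Adam` only works if `Adam` is re-exported there. The exact resolution mechanics are an assumption based on the example and tests in this PR:

```python
import funsor  # importing funsor makes the "funsor" backend available to pyroapi
from pyroapi import optim, pyro_backend

with pyro_backend("funsor"):
    # Resolves to funsor.minipyro.Adam, i.e. the re-exported funsor.adam.Adam,
    # which is why the noqa import above is kept.
    adam = optim.Adam(num_steps=100, lr=0.01)
```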



# Funsor represents distributions in a fundamentally different way from
@@ -227,13 +227,35 @@ def __enter__(self):
super(log_joint, self).__enter__()
self.log_factors = OrderedDict() # maps site name to log_prob factor
self.plates = set()
self.params = set()
return self

def process_message(self, msg):
if msg["type"] == "sample":
if msg["value"] is None:
# Create a delayed sample.
msg["value"] = funsor.Variable(msg["name"], msg["fn"].output)
elif msg["type"] == "param":
if msg["value"] is None:
# Create a delayed constrained parameter.
constraint = msg["args"][1]
batch_names = tuple(
frame.name for frame in msg["cond_indep_stack"].values()
)
batch_shape = tuple(
frame.size for frame in msg["cond_indep_stack"].values()
)
output = funsor.Reals[batch_shape + msg["output"].shape]
if constraint == torch.distributions.constraints.real:
msg["value"] = funsor.Variable(msg["name"], output)[batch_names]
else:
transform = torch.distributions.transform_to(constraint)
op = funsor.ops.WrappedTransformOp(transform)
# FIXME fix the message output
unconstrained = funsor.Variable(
msg["name"] + "_unconstrained", output
)[batch_names]
msg["value"] = op(unconstrained)

def postprocess_message(self, msg):
if msg["type"] == "sample":
@@ -243,6 +265,8 @@ def postprocess_message(self, msg):
log_prob = msg["fn"].log_prob(msg["value"])
self.log_factors[msg["name"]] = log_prob
self.plates.update(f.name for f in msg["cond_indep_stack"].values())
elif msg["type"] == "param":
self.params.update(msg["value"].inputs)


# apply_stack is called by pyro.sample and pyro.param.
@@ -306,11 +330,13 @@ def param(
event_dim=None,
):
cond_indep_stack = {}
output = None
if init_value is not None:
if event_dim is None:
event_dim = init_value.dim()
output = funsor.Reals[init_value.shape[init_value.dim() - event_dim :]]
# infer output
value = init_value
if value is None:
value, _ = PARAM_STORE[name]
if event_dim is None:
event_dim = value.dim()
output = funsor.Reals[value.shape[value.dim() - event_dim :]]
@ordabayevy (Member, Author) commented:
infer output when pyro.param was already defined elsewhere
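For illustration, the output-domain inference above amounts to slicing the trailing `event_dim` dimensions off the stored value's shape (toy values, not PR code):

```python
import torch
import funsor

value = torch.zeros(2, 3, 4)  # leading dims (2, 3) are batch dims
event_dim = 1                 # the trailing event shape is (4,)
output = funsor.Reals[tuple(value.shape[value.dim() - event_dim:])]
assert output.shape == (4,)
```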


def fn(init_value, constraint):
if name in PARAM_STORE:
@@ -362,40 +388,6 @@ def plate(name, size, dim):
return PlateMessenger(fn=None, name=name, size=size, dim=dim)


# This is a thin wrapper around the `torch.optim.Optimizer` class that
# dynamically generates optimizers for dynamically generated parameters.
# See http://docs.pyro.ai/en/0.3.1/optimization.html
class PyroOptim(object):
def __init__(self, optim_args):
self.optim_args = optim_args
# Each parameter will get its own optimizer, which we keep track
# of using this dictionary keyed on parameters.
self.optim_objs = {}

def __call__(self, params):
for param in params:
# If we've seen this parameter before, use the previously
# constructed optimizer.
if param in self.optim_objs:
optim = self.optim_objs[param]
# If we've never seen this parameter before, construct
# an Adam optimizer and keep track of it.
else:
optim = self.TorchOptimizer([param], **self.optim_args)
self.optim_objs[param] = optim
# Take a gradient step for the parameter param.
optim.step()


# We wrap some commonly used PyTorch optimizers.
class Adam(PyroOptim):
TorchOptimizer = torch.optim.Adam


class ClippedAdam(PyroOptim):
TorchOptimizer = _ClippedAdam


# This is a unified interface for stochastic variational inference in Pyro.
# The actual construction of the loss is taken care of by `loss`.
# See http://docs.pyro.ai/en/0.3.1/inference_algos.html
@@ -408,24 +400,33 @@ def __init__(self, model, guide, optim, loss):

# This method handles running the model and guide, constructing the loss
# function, and taking a gradient step.
def step(self, *args, **kwargs):
def run(self, *args, **kwargs):
# This wraps both the call to `model` and `guide` in a `trace` so that
# we can record all the parameters that are encountered. Note that
# further tracing occurs inside of `loss`.
with trace() as param_capture:
# We use block here to allow tracing to record parameters only.
with block(hide_fn=lambda msg: msg["type"] != "param"):
loss = self.loss(self.model, self.guide, *args, **kwargs)
# Differentiate the loss.
funsor.to_data(loss).backward()
# Grab all the parameters from the trace.
params = [site["value"].data.unconstrained() for site in param_capture.values()]
# Take a step w.r.t. each parameter in params.
self.optim(params)
# Zero out the gradients so that they don't accumulate.
for p in params:
p.grad = torch.zeros_like(p.grad)
return loss.item()
with funsor.terms.lazy:
@ordabayevy (Member, Author) commented:
lazy interpretation is needed here to make sure that funsor.Integrate is not eagerly expanded in Expectation
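To make the role of the lazy interpretation concrete, here is a small sketch using funsor's public `lazy` and `reinterpret` APIs; the toy tensor is an illustrative assumption, not PR code:

```python
from collections import OrderedDict

import torch

import funsor
from funsor.interpreter import reinterpret

funsor.set_backend("torch")

x = funsor.Tensor(torch.randn(3), OrderedDict(i=funsor.Bint[3]))

# Under the default eager interpretation the sum is computed immediately.
eager_sum = (x + x).reduce(funsor.ops.add, "i")
assert isinstance(eager_sum, funsor.Tensor)

# Under the lazy interpretation the same expression stays symbolic, so later
# passes (here, the Adam and MonteCarlo interpretations inside SVI.run) can
# rewrite terms such as funsor.Integrate before they are expanded.
with funsor.terms.lazy:
    lazy_sum = (x + x).reduce(funsor.ops.add, "i")
assert isinstance(lazy_sum, funsor.terms.Reduce)

# A lazy term can still be evaluated later by reinterpreting it eagerly.
assert ((reinterpret(lazy_sum) - eager_sum).data.abs() < 1e-6).all()
```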

with trace() as param_capture:
# We use block here to allow tracing to record parameters only.
with block(hide_fn=lambda msg: msg["type"] != "param"):
loss = self.loss(self.model, self.guide, *args, **kwargs)
init_params = {
name: funsor.to_funsor(site["fn"](*site["args"]).data.unconstrained())
for name, site in param_capture.items()
}
with self.optim.with_init(init_params):
with funsor.montecarlo.MonteCarlo():
result = loss.reduce(funsor.ops.min)
for name, value in self.optim.params.items():
name = name.replace("_unconstrained", "")
value = value.data
old_value, constraint = PARAM_STORE[name]
if old_value is not value:
old_value.data.copy_(value)
return result

def step(self, *args, **kwargs):
self.optim.num_steps = 1
return self.run(*args, **kwargs)
@ordabayevy (Member, Author) commented:
for compatibility with the SVI interface

@fritzo (Member) commented on Apr 25, 2021:
Hmm, let's think about alternative workarounds... One issue here is that the Adam optimizer statistics would not be persisted across SVI steps.

One option is simply to change pyroapi's SVI interface to look for either .run() or, if missing, fall back to .step(). Also, I think it's more important to create a simple didactic example than to fastidiously conform to the pyroapi interface (since that interface hasn't seen much use).
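A rough sketch of the fallback suggested here; the `fit` helper and its signature are hypothetical, not pyroapi or funsor code:

```python
def fit(svi, *args, num_steps=1000, **kwargs):
    """Prefer a backend-provided SVI.run(); otherwise fall back to stepping.

    Falling back to .step() keeps per-step control in the caller, at the cost
    of the concern raised above: optimizer statistics must be carried across
    steps by the SVI object itself.
    """
    if hasattr(svi, "run"):
        return svi.run(*args, **kwargs)
    losses = []
    for _ in range(num_steps):
        losses.append(svi.step(*args, **kwargs))
    return losses
```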



# TODO(eb8680) Replace this with funsor.Expectation.
@@ -455,12 +456,14 @@ def elbo(model, guide, *args, **kwargs):
with log_joint() as model_log_joint:
model(*args, **kwargs)

params = model_log_joint.params | guide_log_joint.params
# contract out auxiliary variables in the guide
guide_log_probs = list(guide_log_joint.log_factors.values())
guide_aux_vars = (
frozenset().union(*(f.inputs for f in guide_log_probs))
- frozenset(guide_log_joint.plates)
- frozenset(model_log_joint.log_factors)
- frozenset(params)
)
if guide_aux_vars:
guide_log_probs = funsor.sum_product.partial_sum_product(
@@ -477,6 +480,7 @@
frozenset().union(*(f.inputs for f in model_log_probs))
- frozenset(model_log_joint.plates)
- frozenset(guide_log_joint.log_factors)
- frozenset(params)
)
if model_aux_vars:
model_log_probs = funsor.sum_product.partial_sum_product(
@@ -517,7 +521,7 @@ def elbo(model, guide, *args, **kwargs):
)

loss = -elbo
assert not loss.inputs
assert set(loss.inputs).issubset(params)
return loss


@@ -533,8 +537,7 @@ def __call__(self, model, guide, *args, **kwargs):
# This is a wrapper for compatibility with full Pyro.
class Trace_ELBO(ELBO):
def __call__(self, model, guide, *args, **kwargs):
with funsor.montecarlo.MonteCarlo():
return elbo(model, guide, *args, **kwargs)
return elbo(model, guide, *args, **kwargs)


class TraceMeanField_ELBO(ELBO):
@@ -571,7 +574,8 @@ def __call__(self, *args):

# Augment args with reads from the global param store.
unconstrained_params = tuple(
param(name).data.unconstrained() for name in self._param_trace
site["fn"](*site["args"]).data.unconstrained()
for site in self._param_trace.values()
)
params_and_args = unconstrained_params + args

@@ -587,7 +591,9 @@ def compiled(*params_and_args):
constrained_param = param(name) # assume param has been initialized
assert constrained_param.data.unconstrained() is unconstrained_param
self._param_trace[name]["value"] = constrained_param
result = replay(self.fn, guide_trace=self._param_trace)(*args)
with funsor.terms.eager:
with funsor.montecarlo.MonteCarlo():
result = replay(self.fn, guide_trace=self._param_trace)(*args)
assert not result.inputs
assert result.output == funsor.Real
return funsor.to_data(result)
40 changes: 29 additions & 11 deletions test/test_minipyro.py
@@ -162,7 +162,6 @@ def model(data=None):
("pyro", "ClippedAdam"),
("minipyro", "Adam"),
("funsor", "Adam"),
("funsor", "ClippedAdam"),
],
)
def test_optimizer(backend, optim_name, jit):
@@ -384,8 +383,11 @@ def guide():
vectorize_particles=True,
strict_enumeration_warning=True,
)
elbo = elbo.differentiable_loss if backend == "pyro" else elbo
actual_loss = funsor.to_data(elbo(model, guide))
if backend == "pyro":
elbo = elbo.differentiable_loss
actual_loss = funsor.to_data(elbo(model, guide))
else:
actual_loss = funsor.to_data(elbo(model, guide)(q=q))
actual_loss.backward()
actual_grad = funsor.to_data(pyro.param("q")).grad

@@ -466,12 +468,20 @@ def hand_model():
def guide():
pass

params = {name: pyro.param(name) for name in pyro.get_param_store().keys()}
elbo = infer.TraceEnum_ELBO(max_plate_nesting=1)
elbo = elbo.differentiable_loss if backend == "pyro" else elbo
auto_loss = elbo(auto_model, guide)
if backend == "pyro":
elbo = elbo.differentiable_loss
auto_loss = elbo(auto_model, guide)
else:
auto_loss = elbo(auto_model, guide)(**params)

elbo = infer.TraceEnum_ELBO(max_plate_nesting=0)
elbo = elbo.differentiable_loss if backend == "pyro" else elbo
hand_loss = elbo(hand_model, guide)
if backend == "pyro":
elbo = elbo.differentiable_loss
hand_loss = elbo(hand_model, guide)
else:
hand_loss = elbo(hand_model, guide)(**params)
_check_loss_and_grads(hand_loss, auto_loss)


@@ -577,12 +587,20 @@ def hand_guide(data):
pyro.sample("c_{}".format(i), dist.Categorical(probs_c[a]))

data = torch.tensor([0, 0])
params = {name: pyro.param(name) for name in pyro.get_param_store().keys()}
elbo = infer.TraceEnum_ELBO(max_plate_nesting=1)
elbo = elbo.differentiable_loss if backend == "pyro" else elbo
auto_loss = elbo(auto_model, auto_guide, data)
if backend == "pyro":
elbo = elbo.differentiable_loss
auto_loss = elbo(auto_model, auto_guide, data)
else:
auto_loss = elbo(auto_model, auto_guide, data)(**params)

elbo = infer.TraceEnum_ELBO(max_plate_nesting=0)
elbo = elbo.differentiable_loss if backend == "pyro" else elbo
hand_loss = elbo(hand_model, hand_guide, data)
if backend == "pyro":
elbo = elbo.differentiable_loss
hand_loss = elbo(hand_model, hand_guide, data)
else:
hand_loss = elbo(hand_model, hand_guide, data)(**params)
_check_loss_and_grads(hand_loss, auto_loss)

