diff --git a/examples/minipyro.py b/examples/minipyro.py
index 4b3ca061..55b03ea9 100644
--- a/examples/minipyro.py
+++ b/examples/minipyro.py
@@ -11,11 +11,11 @@
 
 import torch
 from pyroapi import distributions as dist
-from pyroapi import infer, optim, pyro, pyro_backend
+from pyroapi import infer, pyro, pyro_backend
 from torch.distributions import constraints
 
 import funsor
-from funsor.montecarlo import MonteCarlo
+from funsor.adam import Adam
 
 
 def main(args):
@@ -43,20 +43,15 @@ def guide(data):
 
     # Because the API in minipyro matches that of Pyro proper,
     # training code works with generic Pyro implementations.
-    with pyro_backend(args.backend), MonteCarlo():
+    with pyro_backend(args.backend):
         # Construct an SVI object so we can do variational inference on our
         # model/guide pair.
         Elbo = infer.JitTrace_ELBO if args.jit else infer.Trace_ELBO
         elbo = Elbo()
-        adam = optim.Adam({"lr": args.learning_rate})
+        adam = Adam(args.num_steps, lr=args.learning_rate, log_every=args.log_every)
         svi = infer.SVI(model, guide, adam, elbo)
-
-        # Basic training loop
         pyro.get_param_store().clear()
-        for step in range(args.num_steps):
-            loss = svi.step(data)
-            if args.verbose and step % 100 == 0:
-                print(f"step {step} loss = {loss}")
+        svi.run(data)
 
     # Report the final values of the variational parameters
     # in the guide after training.
@@ -75,7 +70,8 @@
     parser = argparse.ArgumentParser(description="Minipyro demo")
     parser.add_argument("-b", "--backend", default="funsor")
    parser.add_argument("-n", "--num-steps", default=1001, type=int)
-    parser.add_argument("-lr", "--learning-rate", default=0.02, type=float)
+    parser.add_argument("-lr", "--learning-rate", default=0.01, type=float)
+    parser.add_argument("--log-every", type=int, default=20)
     parser.add_argument("--jit", action="store_true")
     parser.add_argument("-v", "--verbose", action="store_true")
     args = parser.parse_args()
diff --git a/examples/talbot.py b/examples/talbot.py
index b35c9609..485a5534 100644
--- a/examples/talbot.py
+++ b/examples/talbot.py
@@ -132,15 +132,11 @@ def main(args):
         "rate": Tensor(torch.tensor(5.0, requires_grad=True)),
         "nsteps": Tensor(torch.tensor(2.0, requires_grad=True)),
     }
+    adam = Adam(args.num_steps, lr=args.learning_rate, log_every=args.log_every)
     with Talbot(num_steps=args.talbot_num_steps):
         # Fit the data
-        with Adam(
-            args.num_steps,
-            lr=args.learning_rate,
-            log_every=args.log_every,
-            params=init_params,
-        ) as optim:
+        with adam.with_init(init_params) as optim:
             loss.reduce(ops.min, {"rate", "nsteps"})
 
         # Fitted curve.
         fitted_curve = pred(rate=optim.param("rate"), nsteps=optim.param("nsteps"))
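Note on the example changes above: the training loop moves out of the caller and into the optimizer. The sketch below reconstructs the new examples/minipyro.py flow using only names that appear in this diff; it is illustrative, not part of the patch.

    # Illustrative sketch of the new example flow (funsor backend assumed).
    import torch
    from pyroapi import distributions as dist
    from pyroapi import infer, pyro, pyro_backend

    from funsor.adam import Adam

    def model(data):
        loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
        with pyro.plate("data", len(data), dim=-1):
            pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

    def guide(data):
        loc = pyro.param("guide_loc", torch.tensor(0.0))
        pyro.sample("loc", dist.Normal(loc, 1.0))

    data = torch.randn(100) + 3.0
    with pyro_backend("funsor"):
        # Before: adam = optim.Adam({"lr": 0.02}) plus a hand-written
        # `for step in range(args.num_steps): svi.step(data)` loop.
        # After: Adam owns num_steps and logging, and svi.run() does the fit.
        adam = Adam(1001, lr=0.01, log_every=20)
        svi = infer.SVI(model, guide, adam, infer.Trace_ELBO())
        pyro.get_param_store().clear()
        svi.run(data)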
diff --git a/funsor/adam.py b/funsor/adam.py
index 2e361c0e..8cbf7869 100644
--- a/funsor/adam.py
+++ b/funsor/adam.py
@@ -26,7 +26,11 @@ def __init__(self, num_steps, **kwargs):
         self.num_steps = num_steps
         self.log_every = kwargs.pop("log_every", 0)
         self.optim_params = kwargs  # TODO make precise
-        self.params = kwargs.pop("params", {})
+        self.params = {}
+
+    def with_init(self, init_params):
+        self.params = init_params
+        return self
 
     def param(self, name, domain=None):
         if name not in self.params:
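The funsor/adam.py change above separates optimizer configuration from parameter initialization: `with_init` seeds `self.params` per fit and returns the optimizer, so one configured `Adam` can be reused, as examples/talbot.py now does. A minimal standalone sketch, assuming the torch backend; the variable `x` and the quadratic objective are invented for illustration:

    import torch

    import funsor
    import funsor.ops as ops
    from funsor.adam import Adam
    from funsor.tensor import Tensor

    funsor.set_backend("torch")

    # Toy objective with one free real input, minimized at x = 3.
    x = funsor.Variable("x", funsor.Real)
    loss = (x - 3.0) * (x - 3.0)

    adam = Adam(500, lr=0.1, log_every=100)
    init = {"x": Tensor(torch.tensor(0.0, requires_grad=True))}

    # Before this patch: Adam(500, lr=0.1, params=init).
    # After: the same optimizer object is re-seeded for each fit.
    with adam.with_init(init) as optim:
        loss.reduce(ops.min, {"x"})
    print(optim.param("x"))  # roughly 3.0 after 500 steps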
diff --git a/funsor/minipyro.py b/funsor/minipyro.py
index b6fc1812..ad220802 100644
--- a/funsor/minipyro.py
+++ b/funsor/minipyro.py
@@ -21,9 +21,9 @@
 
 import torch
 from pyro.distributions import validation_enabled
-from pyro.optim.clipped_adam import ClippedAdam as _ClippedAdam
 
 import funsor
+from funsor.adam import Adam  # noqa: F401
 
 
 # Funsor represents distributions in a fundamentally different way from
@@ -227,6 +227,7 @@ def __enter__(self):
         super(log_joint, self).__enter__()
         self.log_factors = OrderedDict()  # maps site name to log_prob factor
         self.plates = set()
+        self.params = set()
         return self
 
     def process_message(self, msg):
@@ -234,6 +235,27 @@ def process_message(self, msg):
             if msg["value"] is None:
                 # Create a delayed sample.
                 msg["value"] = funsor.Variable(msg["name"], msg["fn"].output)
+        elif msg["type"] == "param":
+            if msg["value"] is None:
+                # Create a delayed constrained parameter.
+                constraint = msg["args"][1]
+                batch_names = tuple(
+                    frame.name for frame in msg["cond_indep_stack"].values()
+                )
+                batch_shape = tuple(
+                    frame.size for frame in msg["cond_indep_stack"].values()
+                )
+                output = funsor.Reals[batch_shape + msg["output"].shape]
+                if constraint == torch.distributions.constraints.real:
+                    msg["value"] = funsor.Variable(msg["name"], output)[batch_names]
+                else:
+                    transform = torch.distributions.transform_to(constraint)
+                    op = funsor.ops.WrappedTransformOp(transform)
+                    # FIXME fix the message output
+                    unconstrained = funsor.Variable(
+                        msg["name"] + "_unconstrained", output
+                    )[batch_names]
+                    msg["value"] = op(unconstrained)
 
     def postprocess_message(self, msg):
         if msg["type"] == "sample":
@@ -243,6 +265,8 @@ def postprocess_message(self, msg):
             log_prob = msg["fn"].log_prob(msg["value"])
             self.log_factors[msg["name"]] = log_prob
             self.plates.update(f.name for f in msg["cond_indep_stack"].values())
+        elif msg["type"] == "param":
+            self.params.update(msg["value"].inputs)
 
 
 # apply_stack is called by pyro.sample and pyro.param.
@@ -306,11 +330,13 @@ def param(
     event_dim=None,
 ):
     cond_indep_stack = {}
-    output = None
-    if init_value is not None:
-        if event_dim is None:
-            event_dim = init_value.dim()
-        output = funsor.Reals[init_value.shape[init_value.dim() - event_dim :]]
+    # infer output
+    value = init_value
+    if value is None:
+        value, _ = PARAM_STORE[name]
+    if event_dim is None:
+        event_dim = value.dim()
+    output = funsor.Reals[value.shape[value.dim() - event_dim :]]
 
     def fn(init_value, constraint):
         if name in PARAM_STORE:
@@ -362,40 +388,6 @@ def plate(name, size, dim):
     return PlateMessenger(fn=None, name=name, size=size, dim=dim)
 
 
-# This is a thin wrapper around the `torch.optim.Optimizer` class that
-# dynamically generates optimizers for dynamically generated parameters.
-# See http://docs.pyro.ai/en/0.3.1/optimization.html
-class PyroOptim(object):
-    def __init__(self, optim_args):
-        self.optim_args = optim_args
-        # Each parameter will get its own optimizer, which we keep track
-        # of using this dictionary keyed on parameters.
-        self.optim_objs = {}
-
-    def __call__(self, params):
-        for param in params:
-            # If we've seen this parameter before, use the previously
-            # constructed optimizer.
-            if param in self.optim_objs:
-                optim = self.optim_objs[param]
-            # If we've never seen this parameter before, construct
-            # an Adam optimizer and keep track of it.
-            else:
-                optim = self.TorchOptimizer([param], **self.optim_args)
-                self.optim_objs[param] = optim
-            # Take a gradient step for the parameter param.
-            optim.step()
-
-
-# We wrap some commonly used PyTorch optimizers.
-class Adam(PyroOptim):
-    TorchOptimizer = torch.optim.Adam
-
-
-class ClippedAdam(PyroOptim):
-    TorchOptimizer = _ClippedAdam
-
-
 # This is a unified interface for stochastic variational inference in Pyro.
 # The actual construction of the loss is taken care of by `loss`.
 # See http://docs.pyro.ai/en/0.3.1/inference_algos.html
@@ -408,24 +400,33 @@ def __init__(self, model, guide, optim, loss):
 
     # This method handles running the model and guide, constructing the loss
     # function, and taking a gradient step.
-    def step(self, *args, **kwargs):
+    def run(self, *args, **kwargs):
         # This wraps both the call to `model` and `guide` in a `trace` so that
         # we can record all the parameters that are encountered. Note that
         # further tracing occurs inside of `loss`.
-        with trace() as param_capture:
-            # We use block here to allow tracing to record parameters only.
-            with block(hide_fn=lambda msg: msg["type"] != "param"):
-                loss = self.loss(self.model, self.guide, *args, **kwargs)
-        # Differentiate the loss.
-        funsor.to_data(loss).backward()
-        # Grab all the parameters from the trace.
-        params = [site["value"].data.unconstrained() for site in param_capture.values()]
-        # Take a step w.r.t. each parameter in params.
-        self.optim(params)
-        # Zero out the gradients so that they don't accumulate.
-        for p in params:
-            p.grad = torch.zeros_like(p.grad)
-        return loss.item()
+        with funsor.terms.lazy:
+            with trace() as param_capture:
+                # We use block here to allow tracing to record parameters only.
+                with block(hide_fn=lambda msg: msg["type"] != "param"):
+                    loss = self.loss(self.model, self.guide, *args, **kwargs)
+        init_params = {
+            name: funsor.to_funsor(site["fn"](*site["args"]).data.unconstrained())
+            for name, site in param_capture.items()
+        }
+        with self.optim.with_init(init_params):
+            with funsor.montecarlo.MonteCarlo():
+                result = loss.reduce(funsor.ops.min)
+        for name, value in self.optim.params.items():
+            name = name.replace("_unconstrained", "")
+            value = value.data
+            old_value, constraint = PARAM_STORE[name]
+            if old_value is not value:
+                old_value.data.copy_(value)
+        return result
+
+    def step(self, *args, **kwargs):
+        self.optim.num_steps = 1
+        return self.run(*args, **kwargs)
 
 
 # TODO(eb8680) Replace this with funsor.Expectation.
@@ -455,12 +456,14 @@ def elbo(model, guide, *args, **kwargs):
     with log_joint() as model_log_joint:
         model(*args, **kwargs)
+    params = model_log_joint.params | guide_log_joint.params
 
     # contract out auxiliary variables in the guide
     guide_log_probs = list(guide_log_joint.log_factors.values())
     guide_aux_vars = (
         frozenset().union(*(f.inputs for f in guide_log_probs))
         - frozenset(guide_log_joint.plates)
         - frozenset(model_log_joint.log_factors)
+        - frozenset(params)
     )
     if guide_aux_vars:
         guide_log_probs = funsor.sum_product.partial_sum_product(
@@ -477,6 +480,7 @@ def elbo(model, guide, *args, **kwargs):
         frozenset().union(*(f.inputs for f in model_log_probs))
         - frozenset(model_log_joint.plates)
         - frozenset(guide_log_joint.log_factors)
+        - frozenset(params)
     )
     if model_aux_vars:
         model_log_probs = funsor.sum_product.partial_sum_product(
@@ -517,7 +521,7 @@ def elbo(model, guide, *args, **kwargs):
     )
 
     loss = -elbo
-    assert not loss.inputs
+    assert set(loss.inputs).issubset(params)
     return loss
 
 
@@ -533,8 +537,7 @@ def __call__(self, model, guide, *args, **kwargs):
 # This is a wrapper for compatibility with full Pyro.
 class Trace_ELBO(ELBO):
     def __call__(self, model, guide, *args, **kwargs):
-        with funsor.montecarlo.MonteCarlo():
-            return elbo(model, guide, *args, **kwargs)
+        return elbo(model, guide, *args, **kwargs)
 
 
 class TraceMeanField_ELBO(ELBO):
@@ -571,7 +574,8 @@ def __call__(self, *args):
 
         # Augment args with reads from the global param store.
         unconstrained_params = tuple(
-            param(name).data.unconstrained() for name in self._param_trace
+            site["fn"](*site["args"]).data.unconstrained()
+            for site in self._param_trace.values()
        )
         params_and_args = unconstrained_params + args
 
@@ -587,7 +591,9 @@ def compiled(*params_and_args):
             constrained_param = param(name)  # assume param has been initialized
             assert constrained_param.data.unconstrained() is unconstrained_param
             self._param_trace[name]["value"] = constrained_param
-        result = replay(self.fn, guide_trace=self._param_trace)(*args)
+        with funsor.terms.eager:
+            with funsor.montecarlo.MonteCarlo():
+                result = replay(self.fn, guide_trace=self._param_trace)(*args)
         assert not result.inputs
         assert result.output == funsor.Real
         return funsor.to_data(result)
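The heart of the minipyro changes above is that `pyro.param` sites become delayed funsors inside `log_joint`: a real-constrained parameter is a free `funsor.Variable` named after the site, while any other constraint is optimized in unconstrained space and mapped back through `transform_to`. A simplified sketch of that branch for a scalar site with no plates; the site name `guide_scale` is hypothetical:

    import torch

    import funsor

    funsor.set_backend("torch")

    name = "guide_scale"
    constraint = torch.distributions.constraints.positive
    output = funsor.Real  # scalar site, no batch dims

    if constraint == torch.distributions.constraints.real:
        # Unconstrained parameters stay as plain free variables.
        value = funsor.Variable(name, output)
    else:
        # Constrained parameters get a free "<name>_unconstrained" variable
        # pushed through transform_to(constraint), as in process_message.
        transform = torch.distributions.transform_to(constraint)
        op = funsor.ops.WrappedTransformOp(transform)
        value = op(funsor.Variable(name + "_unconstrained", output))

    # SVI.run() minimizes the lazy ELBO over exactly these free inputs and
    # strips the "_unconstrained" suffix when writing back to PARAM_STORE.
    print(value.inputs)  # expect a free 'guide_scale_unconstrained' input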
diff --git a/test/test_minipyro.py b/test/test_minipyro.py
index c5f83c6a..57846d24 100644
--- a/test/test_minipyro.py
+++ b/test/test_minipyro.py
@@ -162,7 +162,6 @@ def model(data=None):
         ("pyro", "ClippedAdam"),
         ("minipyro", "Adam"),
         ("funsor", "Adam"),
-        ("funsor", "ClippedAdam"),
     ],
 )
 def test_optimizer(backend, optim_name, jit):
@@ -384,8 +383,11 @@ def guide():
         vectorize_particles=True,
         strict_enumeration_warning=True,
     )
-    elbo = elbo.differentiable_loss if backend == "pyro" else elbo
-    actual_loss = funsor.to_data(elbo(model, guide))
+    if backend == "pyro":
+        elbo = elbo.differentiable_loss
+        actual_loss = funsor.to_data(elbo(model, guide))
+    else:
+        actual_loss = funsor.to_data(elbo(model, guide)(q=q))
     actual_loss.backward()
     actual_grad = funsor.to_data(pyro.param("q")).grad
 
@@ -466,12 +468,20 @@ def hand_model():
     def guide():
         pass
 
+    params = {name: pyro.param(name) for name in pyro.get_param_store().keys()}
     elbo = infer.TraceEnum_ELBO(max_plate_nesting=1)
-    elbo = elbo.differentiable_loss if backend == "pyro" else elbo
-    auto_loss = elbo(auto_model, guide)
+    if backend == "pyro":
+        elbo = elbo.differentiable_loss
+        auto_loss = elbo(auto_model, guide)
+    else:
+        auto_loss = elbo(auto_model, guide)(**params)
+
     elbo = infer.TraceEnum_ELBO(max_plate_nesting=0)
-    elbo = elbo.differentiable_loss if backend == "pyro" else elbo
-    hand_loss = elbo(hand_model, guide)
+    if backend == "pyro":
+        elbo = elbo.differentiable_loss
+        hand_loss = elbo(hand_model, guide)
+    else:
+        hand_loss = elbo(hand_model, guide)(**params)
     _check_loss_and_grads(hand_loss, auto_loss)
 
@@ -577,12 +587,20 @@ def hand_guide(data):
             pyro.sample("c_{}".format(i), dist.Categorical(probs_c[a]))
 
     data = torch.tensor([0, 0])
+    params = {name: pyro.param(name) for name in pyro.get_param_store().keys()}
     elbo = infer.TraceEnum_ELBO(max_plate_nesting=1)
-    elbo = elbo.differentiable_loss if backend == "pyro" else elbo
-    auto_loss = elbo(auto_model, auto_guide, data)
+    if backend == "pyro":
+        elbo = elbo.differentiable_loss
+        auto_loss = elbo(auto_model, auto_guide, data)
+    else:
+        auto_loss = elbo(auto_model, auto_guide, data)(**params)
+
     elbo = infer.TraceEnum_ELBO(max_plate_nesting=0)
-    elbo = elbo.differentiable_loss if backend == "pyro" else elbo
-    hand_loss = elbo(hand_model, hand_guide, data)
+    if backend == "pyro":
+        elbo = elbo.differentiable_loss
+        hand_loss = elbo(hand_model, hand_guide, data)
+    else:
+        hand_loss = elbo(hand_model, hand_guide, data)(**params)
     _check_loss_and_grads(hand_loss, auto_loss)