Delayed param #534
base: master
@@ -21,9 +21,9 @@
import torch
from pyro.distributions import validation_enabled
from pyro.optim.clipped_adam import ClippedAdam as _ClippedAdam

import funsor
from funsor.adam import Adam  # noqa: F401


# Funsor represents distributions in a fundamentally different way from
@@ -227,13 +227,35 @@ def __enter__(self):
        super(log_joint, self).__enter__()
        self.log_factors = OrderedDict()  # maps site name to log_prob factor
        self.plates = set()
        self.params = set()
        return self

    def process_message(self, msg):
        if msg["type"] == "sample":
            if msg["value"] is None:
                # Create a delayed sample.
                msg["value"] = funsor.Variable(msg["name"], msg["fn"].output)
        elif msg["type"] == "param":
            if msg["value"] is None:
                # Create a delayed constrained parameter.
                constraint = msg["args"][1]
                batch_names = tuple(
                    frame.name for frame in msg["cond_indep_stack"].values()
                )
                batch_shape = tuple(
                    frame.size for frame in msg["cond_indep_stack"].values()
                )
                output = funsor.Reals[batch_shape + msg["output"].shape]
                if constraint == torch.distributions.constraints.real:
                    msg["value"] = funsor.Variable(msg["name"], output)[batch_names]
                else:
                    transform = torch.distributions.transform_to(constraint)
                    op = funsor.ops.WrappedTransformOp(transform)
                    # FIXME fix the message output
                    unconstrained = funsor.Variable(
                        msg["name"] + "_unconstrained", output
                    )[batch_names]
                    msg["value"] = op(unconstrained)

    def postprocess_message(self, msg):
        if msg["type"] == "sample":

@@ -243,6 +265,8 @@ def postprocess_message(self, msg):
            log_prob = msg["fn"].log_prob(msg["value"])
            self.log_factors[msg["name"]] = log_prob
            self.plates.update(f.name for f in msg["cond_indep_stack"].values())
        elif msg["type"] == "param":
            self.params.update(msg["value"].inputs)

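The constrained branch of process_message is the core of this PR: instead of reading a concrete tensor, a param site gets a lazy funsor built from an unconstrained Variable pushed through the constraint's transform. A standalone sketch of that representation (illustrative only: the name `scale`, the positive constraint, and the explicit backend call are not part of the diff):

import torch
import funsor
funsor.set_backend("torch")

constraint = torch.distributions.constraints.positive
transform = torch.distributions.transform_to(constraint)  # torch bijection into the constrained space
op = funsor.ops.WrappedTransformOp(transform)              # lift the torch transform to a funsor op
unconstrained = funsor.Variable("scale_unconstrained", funsor.Real)
scale = op(unconstrained)  # lazy funsor whose only free input is "scale_unconstrained"

Downstream log-prob factors that use `scale` keep "scale_unconstrained" as a free input, which is exactly what postprocess_message records in self.params.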
# apply_stack is called by pyro.sample and pyro.param.
@@ -306,11 +330,13 @@ def param(
    event_dim=None,
):
    cond_indep_stack = {}
    output = None
    if init_value is not None:
        if event_dim is None:
            event_dim = init_value.dim()
        output = funsor.Reals[init_value.shape[init_value.dim() - event_dim :]]
    # infer output
    value = init_value
    if value is None:
        value, _ = PARAM_STORE[name]
    if event_dim is None:
        event_dim = value.dim()
    output = funsor.Reals[value.shape[value.dim() - event_dim :]]

    def fn(init_value, constraint):
        if name in PARAM_STORE:
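For concreteness, here is how the output inference above plays out on a hypothetical parameter (shapes invented for illustration):

value = torch.randn(3, 2)  # one batch dim, one event dim (hypothetical)
event_dim = 1
output = funsor.Reals[value.shape[value.dim() - event_dim:]]  # == funsor.Reals[2]

The leading batch dimension is deliberately excluded from `output`; batch dimensions are carried by the plate frames in cond_indep_stack, as in process_message above.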
@@ -362,40 +388,6 @@ def plate(name, size, dim):
    return PlateMessenger(fn=None, name=name, size=size, dim=dim)


# This is a thin wrapper around the `torch.optim.Optimizer` class that
# dynamically generates optimizers for dynamically generated parameters.
# See http://docs.pyro.ai/en/0.3.1/optimization.html
class PyroOptim(object):
    def __init__(self, optim_args):
        self.optim_args = optim_args
        # Each parameter will get its own optimizer, which we keep track
        # of using this dictionary keyed on parameters.
        self.optim_objs = {}

    def __call__(self, params):
        for param in params:
            # If we've seen this parameter before, use the previously
            # constructed optimizer.
            if param in self.optim_objs:
                optim = self.optim_objs[param]
            # If we've never seen this parameter before, construct
            # an Adam optimizer and keep track of it.
            else:
                optim = self.TorchOptimizer([param], **self.optim_args)
                self.optim_objs[param] = optim
            # Take a gradient step for the parameter param.
            optim.step()


# We wrap some commonly used PyTorch optimizers.
class Adam(PyroOptim):
    TorchOptimizer = torch.optim.Adam


class ClippedAdam(PyroOptim):
    TorchOptimizer = _ClippedAdam


# This is a unified interface for stochastic variational inference in Pyro.
# The actual construction of the loss is taken care of by `loss`.
# See http://docs.pyro.ai/en/0.3.1/inference_algos.html
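This hunk removes minipyro's PyTorch optimizer wrappers: with delayed parameters the loss stays lazy and is minimized by the funsor.adam.Adam interpretation imported at the top of the file, rather than by calling backward() and stepping a per-parameter torch optimizer. A hedged sketch of that pattern, mirroring SVI.run below (model, guide, data are placeholders, and Adam's constructor arguments are assumed rather than verified):

optim = Adam(num_steps=100, lr=0.01)        # constructor arguments assumed
with funsor.terms.lazy:
    loss = elbo(model, guide, data)         # lazy funsor with free parameter inputs
with optim:                                 # SVI.run below uses optim.with_init(...) to seed values
    with funsor.montecarlo.MonteCarlo():
        loss.reduce(funsor.ops.min)         # Adam performs num_steps updates on the free inputs
updated = optim.params                      # name -> optimized funsor, as read back in SVI.run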
@@ -408,24 +400,33 @@ def __init__(self, model, guide, optim, loss):

    # This method handles running the model and guide, constructing the loss
    # function, and taking a gradient step.
    def step(self, *args, **kwargs):
    def run(self, *args, **kwargs):
        # This wraps both the call to `model` and `guide` in a `trace` so that
        # we can record all the parameters that are encountered. Note that
        # further tracing occurs inside of `loss`.
        with trace() as param_capture:
            # We use block here to allow tracing to record parameters only.
            with block(hide_fn=lambda msg: msg["type"] != "param"):
                loss = self.loss(self.model, self.guide, *args, **kwargs)
        # Differentiate the loss.
        funsor.to_data(loss).backward()
        # Grab all the parameters from the trace.
        params = [site["value"].data.unconstrained() for site in param_capture.values()]
        # Take a step w.r.t. each parameter in params.
        self.optim(params)
        # Zero out the gradients so that they don't accumulate.
        for p in params:
            p.grad = torch.zeros_like(p.grad)
        return loss.item()
        with funsor.terms.lazy:
            with trace() as param_capture:
                # We use block here to allow tracing to record parameters only.
                with block(hide_fn=lambda msg: msg["type"] != "param"):
                    loss = self.loss(self.model, self.guide, *args, **kwargs)
        init_params = {
            name: funsor.to_funsor(site["fn"](*site["args"]).data.unconstrained())
            for name, site in param_capture.items()
        }
        with self.optim.with_init(init_params):
            with funsor.montecarlo.MonteCarlo():
                result = loss.reduce(funsor.ops.min)
        for name, value in self.optim.params.items():
            name = name.replace("_unconstrained", "")
            value = value.data
            old_value, constraint = PARAM_STORE[name]
            if old_value is not value:
                old_value.data.copy_(value)
        return result

    def step(self, *args, **kwargs):
        self.optim.num_steps = 1
        return self.run(*args, **kwargs)
Review comment: for compatibility with pyroapi

Review comment: Hmm, let's think about alternative workarounds... One issue here is that the Adam optimizer statistics would not be persisted across svi steps. One option is simply to change pyroapi's SVI interface to look for either …
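A usage sketch that makes the concern concrete (model, guide, data and the optimizer arguments are placeholders; the behavior described follows the comment above and is not verified here):

# Repeated step() calls each enter a fresh one-step Adam interpretation,
# so, per the comment above, moment estimates are not carried across iterations:
svi = SVI(model, guide, Adam(num_steps=1, lr=0.01), Trace_ELBO())
for _ in range(100):
    svi.step(data)

# A single run() keeps one Adam interpretation alive for the whole optimization:
svi = SVI(model, guide, Adam(num_steps=100, lr=0.01), Trace_ELBO())
svi.run(data)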
# TODO(eb8680) Replace this with funsor.Expectation.

@@ -455,12 +456,14 @@ def elbo(model, guide, *args, **kwargs):
    with log_joint() as model_log_joint:
        model(*args, **kwargs)

    params = model_log_joint.params | guide_log_joint.params
    # contract out auxiliary variables in the guide
    guide_log_probs = list(guide_log_joint.log_factors.values())
    guide_aux_vars = (
        frozenset().union(*(f.inputs for f in guide_log_probs))
        - frozenset(guide_log_joint.plates)
        - frozenset(model_log_joint.log_factors)
        - frozenset(params)
    )
    if guide_aux_vars:
        guide_log_probs = funsor.sum_product.partial_sum_product(

@@ -477,6 +480,7 @@ def elbo(model, guide, *args, **kwargs):
        frozenset().union(*(f.inputs for f in model_log_probs))
        - frozenset(model_log_joint.plates)
        - frozenset(guide_log_joint.log_factors)
        - frozenset(params)
    )
    if model_aux_vars:
        model_log_probs = funsor.sum_product.partial_sum_product(

@@ -517,7 +521,7 @@ def elbo(model, guide, *args, **kwargs):
    )

    loss = -elbo
    assert not loss.inputs
    assert set(loss.inputs).issubset(params)
    return loss
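Why the new params set matters here: with delayed parameters, parameter names (or their _unconstrained counterparts) appear as free inputs of the log-prob factors, just like sample sites. They must not be summed or integrated out by partial_sum_product, so both auxiliary-variable computations subtract them, and the final check is relaxed from `assert not loss.inputs` to `assert set(loss.inputs).issubset(params)`: the ELBO may remain a lazy function of the parameters, which SVI.run then minimizes under the Adam interpretation.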
@@ -533,8 +537,7 @@ def __call__(self, model, guide, *args, **kwargs):
# This is a wrapper for compatibility with full Pyro.
class Trace_ELBO(ELBO):
    def __call__(self, model, guide, *args, **kwargs):
        with funsor.montecarlo.MonteCarlo():
            return elbo(model, guide, *args, **kwargs)
        return elbo(model, guide, *args, **kwargs)


class TraceMeanField_ELBO(ELBO):
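With this change Trace_ELBO no longer forces a MonteCarlo interpretation itself; under the delayed-parameter scheme the loss is kept lazy, and the MonteCarlo (and Adam) interpretations are applied later, in SVI.run and in the jitted path below.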
@@ -571,7 +574,8 @@ def __call__(self, *args):

        # Augment args with reads from the global param store.
        unconstrained_params = tuple(
            param(name).data.unconstrained() for name in self._param_trace
            site["fn"](*site["args"]).data.unconstrained()
            for site in self._param_trace.values()
        )
        params_and_args = unconstrained_params + args
@@ -587,7 +591,9 @@ def compiled(*params_and_args):
                    constrained_param = param(name)  # assume param has been initialized
                    assert constrained_param.data.unconstrained() is unconstrained_param
                    self._param_trace[name]["value"] = constrained_param
                result = replay(self.fn, guide_trace=self._param_trace)(*args)
                with funsor.terms.eager:
                    with funsor.montecarlo.MonteCarlo():
                        result = replay(self.fn, guide_trace=self._param_trace)(*args)
                assert not result.inputs
                assert result.output == funsor.Real
                return funsor.to_data(result)
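Here the replay is wrapped in eager plus MonteCarlo so that, inside the compiled function, the otherwise-lazy funsor expression is forced down to a single concrete value (the asserts require no free inputs and a funsor.Real output) before funsor.to_data converts the result to a plain tensor for tracing.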