-
Notifications
You must be signed in to change notification settings - Fork 21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Delayed param #534
base: master
Are you sure you want to change the base?
Delayed param #534
Conversation
|
||
import funsor | ||
from funsor.adam import Adam # noqa: F401 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for compatibility with pyroapi
value, _ = PARAM_STORE[name] | ||
if event_dim is None: | ||
event_dim = value.dim() | ||
output = funsor.Reals[value.shape[value.dim() - event_dim :]] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
infer output
when pyro.param
was already defined elsewhere
|
||
def step(self, *args, **kwargs): | ||
self.optim.num_steps = 1 | ||
return self.run(*args, **kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for compatibility with SVI
interface
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, let's think about alternative workarounds... One issue here is that the Adam optimizer statistics would not be persisted across svi steps.
One option is simply to change pyroapi's SVI interface to look for either .run()
or if missing fall back to .step()
. Also I think it's more important to create a simple didactic example than to fastidiously conform to the pyroapi interface (since that interface hasn't seen much use).
for p in params: | ||
p.grad = torch.zeros_like(p.grad) | ||
return loss.item() | ||
with funsor.terms.lazy: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lazy
interpretation is needed here to make sure that funsor.Integrate
is not eagerly expanded in Expectation
|
I think you're right, but let's discuss. That's a little different from Pyro where jit is baked into ELBO subclasses. |
Addresses #533
Group coded with @fritzo @eb8680 @fehiepsi