-
Notifications
You must be signed in to change notification settings - Fork 21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Autodiff implementation (experimental) #494
base: master
Are you sure you want to change the base?
Conversation
inputs = OrderedDict([(var.name, var.output) for var in expanded_vars]) | ||
inputs.update(arg.inputs) | ||
output = arg.output | ||
fresh = frozenset() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should this be
fresh = frozenset(v.name for v in expanded_vars)
from funsor.terms import Binary, Funsor, Lambda, Number, Tuple, Variable, lazy | ||
from funsor.testing import assert_close, random_tensor | ||
|
||
funsor.set_backend("torch") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
test files should read but not write the global backend. Instead you can decorate each test with
@pytest.mark.skipif(get_backend() != "torch", reason="backend-specific")
and then run tests with
FUNSOR_BACKEND=torch pytest test/test_autodiff.py
@@ -994,6 +1001,46 @@ def die_binary(op, lhs, rhs): | |||
raise NotImplementedError(f"Missing pattern for {repr(expr)}") | |||
|
|||
|
|||
class Expand(Funsor): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm I'd like to better understand the need for this.
We've been trying to preserve the extensionality property in Funsor, which states that: if under every grounding substitution subs
a pair of funsors f,g
satisfy f(**subs) == g(**subs)
, then it should be permissible for an optimizer to replace funsor f
with funsor g
in any expression. IIUC this Expand
funsor would break extensionality because f.expand(...)
behaves as f
under every grounding substitution.
This is an implementation of autodiff. The goal is to address issues in computing expectations in
TraceEnum_ELBO
andTraceMarkovEnum_ELBO
(#493). As of now it seems to fixnan
gradients undereager
interpretation inTraceEnum_ELBO
.The algorithm implements equivalents of
linearize()
,transpose()
functions, and is tape-free (#446).JVP(primal, tangent)
and then pattern matched to propagate tangents, e.g.:Out tangent is a linear function of in tangents.
JVP
is used for(add,mul)
semiring andLJVP
is used for(logaddexp,add)
semiring..reduce(sum_op, "i")
and.expand("i")
(broadcasting does this automatically).