Skip to content

Commit

Permalink
Preserve the identities of valued/observed variables
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed May 6, 2023
1 parent dd8c4a9 commit 95a4ceb
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 66 deletions.
16 changes: 5 additions & 11 deletions aeppl/joint_logprob.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,18 +141,16 @@ def conditional_logprob(
# maps to the logprob graphs and value variables before returning them.
rv_values = {**original_rv_values, **realized}

fgraph, _, memo = construct_ir_fgraph(rv_values, ir_rewriter=ir_rewriter)

if extra_rewrites is not None:
extra_rewrites.add_requirements(fgraph, rv_values, memo)
extra_rewrites.apply(fgraph)
fgraph, new_rv_values = construct_ir_fgraph(
rv_values, ir_rewriter=ir_rewriter, extra_rewrites=extra_rewrites
)

# We assign log-densities on a per-node basis, and not per-output/variable.
realized_vars = set()
new_to_old_rvs = {}
nodes_to_vals: Dict["Apply", List[Tuple["Variable", "Variable"]]] = {}

for bnd_var, (old_mvar, old_val) in zip(fgraph.outputs, rv_values.items()):
for bnd_var, (old_mvar, val) in zip(fgraph.outputs, new_rv_values.items()):
mnode = bnd_var.owner
assert mnode and isinstance(mnode.op, ValuedVariable)

Expand All @@ -165,11 +163,7 @@ def conditional_logprob(
if old_mvar in realized:
realized_vars.add(rv_var)

# Do this just in case a value variable was changed. (Some transforms
# do this.)
new_val = memo[old_val]

nodes_to_vals.setdefault(rv_node, []).append((val_var, new_val))
nodes_to_vals.setdefault(rv_node, []).append((val_var, val))

new_to_old_rvs[rv_var] = old_mvar

Expand Down
55 changes: 36 additions & 19 deletions aeppl/rewriting.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from typing import Dict, Optional, Tuple
from typing import Dict, Optional, Tuple, Union

import aesara.tensor as at
from aesara.compile.mode import optdb
from aesara.graph.basic import Apply, Variable
from aesara.graph.features import Feature
from aesara.graph.fg import FunctionGraph
from aesara.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
from aesara.graph.rewriting.basic import (
GraphRewriter,
NodeRewriter,
in2out,
node_rewriter,
)
from aesara.graph.rewriting.db import EquilibriumDB, RewriteDatabaseQuery, SequenceDB
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.extra_ops import BroadcastTo
Expand Down Expand Up @@ -180,9 +185,10 @@ def incsubtensor_rv_replace(fgraph, node):


def construct_ir_fgraph(
rv_values: Dict[Variable, Variable],
rvs_to_values: Dict[Variable, Variable],
ir_rewriter: Optional[GraphRewriter] = None,
) -> Tuple[FunctionGraph, Dict[Variable, Variable], Dict[Variable, Variable]]:
extra_rewrites: Optional[Union[GraphRewriter, NodeRewriter]] = None,
) -> Tuple[FunctionGraph, Dict[Variable, Variable]]:
r"""Construct a `FunctionGraph` in measurable IR form for the keys in `rvs_to_values`.
A custom IR rewriter can be specified. By default,
Expand Down Expand Up @@ -215,9 +221,8 @@ def construct_ir_fgraph(
Returns
-------
A `FunctionGraph` of the measurable IR, a copy of `rvs_to_values` containing
the new, cloned versions of the original variables in `rv_values`, and
a ``dict`` mapping all the original variables to their cloned values in
the `FunctionGraph`.
the new, cloned versions of the original variables in `rv_values`.
"""

# We're going to create a `FunctionGraph` that effectively represents the
Expand All @@ -233,16 +238,20 @@ def construct_ir_fgraph(
# so that they're distinct nodes in the graph. This allows us to replace
# all instances of the original random variables with their value
# variables, while leaving the output clones untouched.
rv_value_clones = {}
rv_clone_to_value_clone = {}
rv_to_value_clone = {}
value_clone_to_value = {}
measured_outputs = {}
memo = {}
for rv, val in rv_values.items():
memo: Dict[Variable, Variable] = {}
for rv, val in rvs_to_values.items():
rv_node_clone = rv.owner.clone()
rv_clone = rv_node_clone.outputs[rv.owner.outputs.index(rv)]
rv_value_clones[rv_clone] = val
measured_outputs[rv] = valued_variable(rv_clone, val)
# Prevent value variables from being cloned
memo[val] = val
val_clone = val.clone()
val_clone.name = "val_clone"
rv_clone_to_value_clone[rv_clone] = val_clone
rv_to_value_clone[rv] = val_clone
value_clone_to_value[val_clone] = val
measured_outputs[rv] = valued_variable(rv_clone, val_clone)

# We add `ShapeFeature` because it will get rid of references to the old
# `RandomVariable`s that have been lifted; otherwise, it will be difficult
Expand All @@ -257,9 +266,6 @@ def construct_ir_fgraph(
copy_inputs=False,
)

# Update `rv_values` so that it uses the new cloned variables
rv_value_clones = {memo[k]: v for k, v in rv_value_clones.items()}

# Replace valued non-output variables with their values
fgraph.replace_all(
[(memo[rv], val) for rv, val in measured_outputs.items() if rv in memo],
Expand All @@ -272,11 +278,22 @@ def construct_ir_fgraph(

ir_rewriter.rewrite(fgraph)

if extra_rewrites is not None:
# Expect `value_clone_to_value` to be updated in-place
extra_rewrites.add_requirements(fgraph, rv_to_value_clone, value_clone_to_value)
extra_rewrites.apply(fgraph)

# Undo un-valued measurable IR rewrites
new_to_old = tuple((v, k) for k, v in fgraph.measurable_conversions.items())
fgraph.replace_all(new_to_old, reason="undo-unvalued-measurables")
# and add the original value variables back in
new_to_old += tuple(value_clone_to_value.items())
fgraph.replace_all(
new_to_old, reason="undo-unvalued-measurables", import_missing=True
)

new_rvs_to_values = dict(zip(rvs_to_values.keys(), value_clone_to_value.values()))

return fgraph, rv_value_clones, memo
return fgraph, new_rvs_to_values


@register_useless
Expand Down
36 changes: 17 additions & 19 deletions aeppl/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,14 +161,10 @@ def transform_values(fgraph: FunctionGraph, node: Apply):
4. Replace the old `ValuedVariable` with a new one containing a
`TransformedVariable` value.
Step 3. is currently accomplished by updating the `memo` dictionary
associated with the `FunctionGraph`. Our main entry-point,
Step 3. is currently accomplished by updating the `rvs_to_values`
dictionary associated with the `FunctionGraph`. Our main entry-point,
`conditional_logprob`, checks this dictionary for value variable changes.
TODO: This approach is less than ideal, because it puts awkward demands on
users/callers of this rewrite to check with `memo`; let's see if we can do
something better.
The new value variable mentioned in Step 2. may be of a different `Type`
(e.g. extra/fewer dimensions) than the original value variable; this is why
we must replace the corresponding original value variables before we
Expand Down Expand Up @@ -235,8 +231,8 @@ def transform_values(fgraph: FunctionGraph, node: Apply):

# This effectively lets the caller know that a value variable has been
# replaced (i.e. they should filter all their old value variables through
# the memo/replacements map).
fgraph.memo[value_var] = trans_value_var
# the replacements map).
fgraph.value_clone_to_value[value_var] = trans_value_var

trans_var = trans_node.outputs[rv_var_out_idx]
new_var = valued_variable(trans_var, untrans_value_var)
Expand All @@ -252,7 +248,7 @@ class TransformValuesMapping(Feature):
"""

def __init__(self, values_to_transforms, memo):
def __init__(self, values_to_transforms, value_clone_to_value):
"""
Parameters
==========
Expand All @@ -261,20 +257,19 @@ def __init__(self, values_to_transforms, memo):
value variable can be assigned one of `RVTransform`,
`DEFAULT_TRANSFORM`, or ``None``. Random variables with no
transform specified remain unchanged.
memo
Mapping from variables to their clones. This is updated
in-place whenever a value variable is transformed.
value_clone_to_value
Mapping between random variable value clones and their original
value variables.
"""
self.values_to_transforms = values_to_transforms
self.memo = memo
self.value_clone_to_value = value_clone_to_value

def on_attach(self, fgraph):
if hasattr(fgraph, "values_to_transforms"):
raise AlreadyThere()

fgraph.values_to_transforms = self.values_to_transforms
fgraph.memo = self.memo
fgraph.value_clone_to_value = self.value_clone_to_value


class TransformValuesRewrite(GraphRewriter):
Expand Down Expand Up @@ -322,6 +317,7 @@ def __init__(
measurable variable can be assigned an `RVTransform` instance,
`DEFAULT_TRANSFORM`, or ``None``. Measurable variables with no
transform specified remain unchanged.
rvs_to_values
"""

Expand All @@ -330,14 +326,16 @@ def __init__(
def add_requirements(
self,
fgraph,
rv_to_values: Dict[TensorVariable, TensorVariable],
memo: Dict[TensorVariable, TensorVariable],
rvs_to_values: Dict[TensorVariable, TensorVariable],
value_clone_to_value: Dict[TensorVariable, TensorVariable],
):
values_to_transforms = {
rv_to_values[rv]: transform
rvs_to_values[rv]: transform
for rv, transform in self.rvs_to_transforms.items()
}
values_transforms_feature = TransformValuesMapping(values_to_transforms, memo)
values_transforms_feature = TransformValuesMapping(
values_to_transforms, value_clone_to_value
)
fgraph.attach_feature(values_transforms_feature)

def apply(self, fgraph: FunctionGraph):
Expand Down
17 changes: 6 additions & 11 deletions tests/test_composite_logprob.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,25 +79,20 @@ def test_unvalued_ir_reversion():
"""Make sure that un-valued IR rewrites are reverted."""
srng = at.random.RandomStream(0)

x_rv = srng.normal()
x_rv = srng.normal(name="X")
y_rv = at.clip(x_rv, 0, 1)
z_rv = srng.normal(y_rv, 1, name="z")
y_rv.name = "Y"
z_rv = srng.normal(y_rv, 1, name="Z")
z_vv = z_rv.clone()
z_vv.name = "z"

# Only the `z_rv` is "valued", so `y_rv` doesn't need to be converted into
# measurable IR.
rv_values = {z_rv: z_vv}

z_fgraph, _, memo = construct_ir_fgraph(rv_values)
z_fgraph, new_rvs_to_values = construct_ir_fgraph(rv_values)

assert memo[y_rv] in z_fgraph.measurable_conversions

measurable_y_rv = z_fgraph.measurable_conversions[memo[y_rv]]
assert isinstance(measurable_y_rv.owner.op, MeasurableClip)

# `construct_ir_fgraph` should've reverted the un-valued measurable IR
# change
assert measurable_y_rv not in z_fgraph
assert not any(isinstance(node.op, MeasurableClip) for node in z_fgraph.apply_nodes)


def test_shifted_cumsum():
Expand Down
6 changes: 3 additions & 3 deletions tests/test_convolutions.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_add_independent_normals(mu_x, mu_y, sigma_x, sigma_y, x_shape, y_shape,
Z_rv.name = "Z"
z_vv = Z_rv.clone()

fgraph, _, _ = construct_ir_fgraph({Z_rv: z_vv})
fgraph, *_ = construct_ir_fgraph({Z_rv: z_vv})

(valued_var_out_node) = fgraph.outputs[0].owner
# The convolution should be applied, and not the transform
Expand Down Expand Up @@ -108,7 +108,7 @@ def test_normal_add_input_valued():
Z_rv.name = "Z"
z_vv = Z_rv.clone()

fgraph, _, _ = construct_ir_fgraph({Z_rv: z_vv, X_rv: x_vv})
fgraph, *_ = construct_ir_fgraph({Z_rv: z_vv, X_rv: x_vv})

valued_var_out_node = fgraph.outputs[0].owner
# We should not expect the convolution to be applied; instead, the
Expand Down Expand Up @@ -136,7 +136,7 @@ def test_normal_add_three_inputs():
Z_rv.name = "Z"
z_vv = Z_rv.clone()

fgraph, _, _ = construct_ir_fgraph({Z_rv: z_vv})
fgraph, *_ = construct_ir_fgraph({Z_rv: z_vv})

valued_var_out_node = fgraph.outputs[0].owner
# The convolution should be applied, and not the transform
Expand Down
6 changes: 3 additions & 3 deletions tests/test_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,7 +685,7 @@ def test_switch_mixture():
z_vv = Z1_rv.clone()
z_vv.name = "z1"

fgraph, _, _ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv})
fgraph, *_ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv})

out_rv = fgraph.outputs[0].owner.inputs[0]
assert isinstance(out_rv.owner.op, MixtureRV)
Expand All @@ -696,7 +696,7 @@ def test_switch_mixture():

Z1_rv.name = "Z1"

fgraph, _, _ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv})
fgraph, *_ = construct_ir_fgraph({Z1_rv: z_vv, I_rv: i_vv})

out_rv = fgraph.outputs[0].owner.inputs[0]
assert out_rv.name == "Z1-mixture"
Expand All @@ -705,7 +705,7 @@ def test_switch_mixture():

Z2_rv = at.stack((X_rv, Y_rv))[I_rv]

fgraph2, _, _ = construct_ir_fgraph({Z2_rv: z_vv, I_rv: i_vv})
fgraph2, *_ = construct_ir_fgraph({Z2_rv: z_vv, I_rv: i_vv})

assert equal_computations(fgraph.outputs, fgraph2.outputs)

Expand Down

0 comments on commit 95a4ceb

Please sign in to comment.