From ff153098aba5475527e54785b95feff9a8fe3eef Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Tue, 2 Nov 2021 22:07:55 -0400 Subject: [PATCH 1/7] Importance funsor --- funsor/constant.py | 5 +++-- funsor/delta.py | 18 +++++++++--------- funsor/distribution.py | 15 ++++++++++++--- funsor/integrate.py | 2 +- funsor/montecarlo.py | 17 ++++++++++++++++- funsor/terms.py | 37 +++++++++++++++++++++++++++++++++++++ 6 files changed, 78 insertions(+), 16 deletions(-) diff --git a/funsor/constant.py b/funsor/constant.py index 5267240d..f38aa299 100644 --- a/funsor/constant.py +++ b/funsor/constant.py @@ -10,6 +10,7 @@ Binary, Funsor, FunsorMeta, + Importance, Number, Reduce, Unary, @@ -165,7 +166,7 @@ def eager_binary_constant_constant(op, lhs, rhs): return op(lhs.arg, rhs.arg) -@eager.register(Binary, ops.BinaryOp, Constant, (Number, Tensor)) +@eager.register(Binary, ops.BinaryOp, Constant, (Importance, Number, Tensor)) def eager_binary_constant_tensor(op, lhs, rhs): const_inputs = OrderedDict( (k, v) for k, v in lhs.const_inputs.items() if k not in rhs.inputs @@ -175,7 +176,7 @@ def eager_binary_constant_tensor(op, lhs, rhs): return op(lhs.arg, rhs) -@eager.register(Binary, ops.BinaryOp, (Number, Tensor), Constant) +@eager.register(Binary, ops.BinaryOp, (Importance, Number, Tensor), Constant) def eager_binary_tensor_constant(op, lhs, rhs): const_inputs = OrderedDict( (k, v) for k, v in rhs.const_inputs.items() if k not in lhs.inputs diff --git a/funsor/delta.py b/funsor/delta.py index fec86b58..8eb75d03 100644 --- a/funsor/delta.py +++ b/funsor/delta.py @@ -145,15 +145,15 @@ def eager_subs(self, subs): new_terms.append((value.name, (point, log_density))) continue - if not any( - d.dtype == "real" - for side in (value, point) - for d in side.inputs.values() - ): - dtype = get_default_dtype() - is_equal = ops.astype((value == point).all(), dtype) - log_densities.append(is_equal.log() + log_density) - continue + # if not any( + # d.dtype == "real" + # for side in (value, point) + # for d in side.inputs.values() + # ): + dtype = get_default_dtype() + is_equal = ops.astype((value == point).all(), dtype) + log_densities.append(is_equal.log() + log_density) + continue # Try to invert the substitution. soln = solve(value, point) diff --git a/funsor/distribution.py b/funsor/distribution.py index b5bef194..d26a89f2 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -235,12 +235,21 @@ def _sample(self, sampled_vars, sample_inputs, rng_key): if not raw_dist.has_rsample: # scaling of dice_factor by num samples should already be handled by Funsor.sample raw_log_prob = raw_dist.log_prob(raw_value) - dice_factor = to_funsor( - raw_log_prob - ops.detach(raw_log_prob), + log_prob = to_funsor( + raw_log_prob, output=self.output, dim_to_name=dim_to_name, ) - result = funsor.delta.Delta(value_name, funsor_value, dice_factor) + model_sample = funsor.delta.Delta(value_name, funsor_value, log_prob) + guide_sample = funsor.delta.Delta( + value_name, funsor_value, ops.detach(log_prob) + ) + result = funsor.terms.Importance( + ops.logaddexp, + model_sample, + guide_sample, + frozenset({Variable(value_name, self.inputs[value_name])}), + ) else: result = funsor.delta.Delta(value_name, funsor_value) return result diff --git a/funsor/integrate.py b/funsor/integrate.py index ab6f98ad..c2b111fb 100644 --- a/funsor/integrate.py +++ b/funsor/integrate.py @@ -183,7 +183,7 @@ def eager_integrate(delta, integrand, reduced_vars): if name in reduced_names ) new_integrand = Subs(integrand, subs) - new_log_measure = Subs(delta, subs) + new_log_measure = delta.reduce(ops.logaddexp, reduced_names) result = Integrate(new_log_measure, new_integrand, reduced_vars - delta_fresh) return result diff --git a/funsor/montecarlo.py b/funsor/montecarlo.py index 06d66961..dc08ca13 100644 --- a/funsor/montecarlo.py +++ b/funsor/montecarlo.py @@ -9,7 +9,7 @@ from funsor.integrate import Integrate from funsor.interpretations import StatefulInterpretation from funsor.tensor import Tensor -from funsor.terms import Approximate, Funsor, Number +from funsor.terms import Approximate, Funsor, Importance, Number from funsor.util import get_backend from . import ops @@ -60,6 +60,21 @@ def monte_carlo_approximate(state, op, model, guide, approx_vars): return result +@MonteCarlo.register(Importance, ops.LogaddexpOp, Funsor, Funsor, frozenset) +def monte_carlo_importance(state, op, model, guide, approx_vars): + sample_options = {} + if state.rng_key is not None and get_backend() == "jax": + import jax + + sample_options["rng_key"], state.rng_key = jax.random.split(state.rng_key) + + sample = guide.sample(approx_vars, state.sample_inputs, **sample_options) + + result = sample + model - guide + + return result + + @functools.singledispatch def extract_samples(discrete_density): """ diff --git a/funsor/terms.py b/funsor/terms.py index 2c39bab0..c09f8d22 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -1333,6 +1333,43 @@ def eager_approximate(op, model, guide, approx_vars): return model # exact +class Importance(Funsor): + """ + Interpretation-specific approximation wrt a set of variables. + + The default eager interpretation should be exact. + The user-facing interface is the :meth:`Funsor.approximate` method. + + :param op: An associative operator. + :type op: ~funsor.ops.AssociativeOp + :param Funsor model: An exact funsor depending on ``reduced_vars``. + :param Funsor guide: A proposal funsor guiding optional approximation. + :param frozenset approx_vars: A set of variables over which to approximate. + """ + + def __init__(self, op, model, guide, approx_vars): + assert isinstance(op, AssociativeOp) + assert isinstance(model, Funsor) + assert isinstance(guide, Funsor) + assert model.output is guide.output + assert isinstance(approx_vars, frozenset), approx_vars + inputs = model.inputs.copy() + inputs.update(guide.inputs) + output = model.output + super().__init__(inputs, output) + self.op = op + self.model = model + self.guide = guide + self.approx_vars = approx_vars + + def eager_reduce(self, op, reduced_vars): + assert reduced_vars.issubset(self.inputs) + if not reduced_vars: + return self + + return self.model.reduce(op, reduced_vars) + + class NumberMeta(FunsorMeta): """ Wrapper to fill in default ``dtype``. From d32fb832640e4d55bf8e2d0373bdfcf6522b4614 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Wed, 3 Nov 2021 01:19:04 -0400 Subject: [PATCH 2/7] replace Importance with Approximate(ops.sample, ...) --- funsor/__init__.py | 2 ++ funsor/constant.py | 6 +++--- funsor/distribution.py | 13 +++---------- funsor/montecarlo.py | 18 +----------------- funsor/terms.py | 21 +++++++++++++-------- 5 files changed, 22 insertions(+), 38 deletions(-) diff --git a/funsor/__init__.py b/funsor/__init__.py index dbed289a..dfbbf86a 100644 --- a/funsor/__init__.py +++ b/funsor/__init__.py @@ -10,6 +10,7 @@ from funsor.sum_product import MarkovProduct from funsor.tensor import Tensor, function from funsor.terms import ( + Approximate, Cat, Funsor, Independent, @@ -55,6 +56,7 @@ __all__ = [ "__version__", + "Approximate", "Array", "Bint", "Cat", diff --git a/funsor/constant.py b/funsor/constant.py index f38aa299..9007ca26 100644 --- a/funsor/constant.py +++ b/funsor/constant.py @@ -7,10 +7,10 @@ import funsor.ops as ops from funsor.tensor import Tensor from funsor.terms import ( + Approximate, Binary, Funsor, FunsorMeta, - Importance, Number, Reduce, Unary, @@ -166,7 +166,7 @@ def eager_binary_constant_constant(op, lhs, rhs): return op(lhs.arg, rhs.arg) -@eager.register(Binary, ops.BinaryOp, Constant, (Importance, Number, Tensor)) +@eager.register(Binary, ops.BinaryOp, Constant, (Approximate, Number, Tensor)) def eager_binary_constant_tensor(op, lhs, rhs): const_inputs = OrderedDict( (k, v) for k, v in lhs.const_inputs.items() if k not in rhs.inputs @@ -176,7 +176,7 @@ def eager_binary_constant_tensor(op, lhs, rhs): return op(lhs.arg, rhs) -@eager.register(Binary, ops.BinaryOp, (Importance, Number, Tensor), Constant) +@eager.register(Binary, ops.BinaryOp, (Approximate, Number, Tensor), Constant) def eager_binary_tensor_constant(op, lhs, rhs): const_inputs = OrderedDict( (k, v) for k, v in rhs.const_inputs.items() if k not in lhs.inputs diff --git a/funsor/distribution.py b/funsor/distribution.py index d26a89f2..0b5152c8 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -240,16 +240,9 @@ def _sample(self, sampled_vars, sample_inputs, rng_key): output=self.output, dim_to_name=dim_to_name, ) - model_sample = funsor.delta.Delta(value_name, funsor_value, log_prob) - guide_sample = funsor.delta.Delta( - value_name, funsor_value, ops.detach(log_prob) - ) - result = funsor.terms.Importance( - ops.logaddexp, - model_sample, - guide_sample, - frozenset({Variable(value_name, self.inputs[value_name])}), - ) + model = funsor.delta.Delta(value_name, funsor_value, log_prob) + guide = funsor.delta.Delta(value_name, funsor_value, ops.detach(log_prob)) + result = model.approximate(ops.sample, guide, value_name) else: result = funsor.delta.Delta(value_name, funsor_value) return result diff --git a/funsor/montecarlo.py b/funsor/montecarlo.py index dc08ca13..30e4ef52 100644 --- a/funsor/montecarlo.py +++ b/funsor/montecarlo.py @@ -45,6 +45,7 @@ def monte_carlo_integrate(state, log_measure, integrand, reduced_vars): @MonteCarlo.register(Approximate, ops.LogaddexpOp, Funsor, Funsor, frozenset) +@MonteCarlo.register(Approximate, ops.SampleOp, Funsor, Funsor, frozenset) def monte_carlo_approximate(state, op, model, guide, approx_vars): sample_options = {} if state.rng_key is not None and get_backend() == "jax": @@ -53,23 +54,6 @@ def monte_carlo_approximate(state, op, model, guide, approx_vars): sample_options["rng_key"], state.rng_key = jax.random.split(state.rng_key) sample = guide.sample(approx_vars, state.sample_inputs, **sample_options) - if sample is guide: - return model # cannot progress - result = sample + model - guide - - return result - - -@MonteCarlo.register(Importance, ops.LogaddexpOp, Funsor, Funsor, frozenset) -def monte_carlo_importance(state, op, model, guide, approx_vars): - sample_options = {} - if state.rng_key is not None and get_backend() == "jax": - import jax - - sample_options["rng_key"], state.rng_key = jax.random.split(state.rng_key) - - sample = guide.sample(approx_vars, state.sample_inputs, **sample_options) - result = sample + model - guide return result diff --git a/funsor/terms.py b/funsor/terms.py index c09f8d22..fd23e930 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -1313,19 +1313,18 @@ def __init__(self, op, model, guide, approx_vars): inputs = model.inputs.copy() inputs.update(guide.inputs) output = model.output - fresh = frozenset(v.name for v in approx_vars) - bound = {v.name: v.output for v in approx_vars} - super().__init__(inputs, output, fresh, bound) + super().__init__(inputs, output) self.op = op self.model = model self.guide = guide self.approx_vars = approx_vars - def _alpha_convert(self, alpha_subs): - alpha_subs = {k: to_funsor(v, self.bound[k]) for k, v in alpha_subs.items()} - op, model, guide, approx_vars = super()._alpha_convert(alpha_subs) - approx_vars = frozenset(alpha_subs.get(var.name, var) for var in approx_vars) - return op, model, guide, approx_vars + def eager_reduce(self, op, reduced_vars): + assert reduced_vars.issubset(self.inputs) + if not reduced_vars: + return self + + return self.model.reduce(op, reduced_vars) @eager.register(Approximate, AssociativeOp, Funsor, Funsor, frozenset) @@ -1333,6 +1332,12 @@ def eager_approximate(op, model, guide, approx_vars): return model # exact +@eager.register(Approximate, ops.SampleOp, Funsor, Funsor, frozenset) +def eager_approximate(op, model, guide, approx_vars): + expr = reflect.interpret(Approximate, op, model, guide, approx_vars) + return expr + + class Importance(Funsor): """ Interpretation-specific approximation wrt a set of variables. From 1eece101a95680b4b1f3c9f1a3c7aca94e11c6b2 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Wed, 3 Nov 2021 01:23:25 -0400 Subject: [PATCH 3/7] remove Importance --- funsor/montecarlo.py | 2 +- funsor/terms.py | 37 ------------------------------------- 2 files changed, 1 insertion(+), 38 deletions(-) diff --git a/funsor/montecarlo.py b/funsor/montecarlo.py index 30e4ef52..ac5a0446 100644 --- a/funsor/montecarlo.py +++ b/funsor/montecarlo.py @@ -9,7 +9,7 @@ from funsor.integrate import Integrate from funsor.interpretations import StatefulInterpretation from funsor.tensor import Tensor -from funsor.terms import Approximate, Funsor, Importance, Number +from funsor.terms import Approximate, Funsor, Number from funsor.util import get_backend from . import ops diff --git a/funsor/terms.py b/funsor/terms.py index fd23e930..e9f7044d 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -1338,43 +1338,6 @@ def eager_approximate(op, model, guide, approx_vars): return expr -class Importance(Funsor): - """ - Interpretation-specific approximation wrt a set of variables. - - The default eager interpretation should be exact. - The user-facing interface is the :meth:`Funsor.approximate` method. - - :param op: An associative operator. - :type op: ~funsor.ops.AssociativeOp - :param Funsor model: An exact funsor depending on ``reduced_vars``. - :param Funsor guide: A proposal funsor guiding optional approximation. - :param frozenset approx_vars: A set of variables over which to approximate. - """ - - def __init__(self, op, model, guide, approx_vars): - assert isinstance(op, AssociativeOp) - assert isinstance(model, Funsor) - assert isinstance(guide, Funsor) - assert model.output is guide.output - assert isinstance(approx_vars, frozenset), approx_vars - inputs = model.inputs.copy() - inputs.update(guide.inputs) - output = model.output - super().__init__(inputs, output) - self.op = op - self.model = model - self.guide = guide - self.approx_vars = approx_vars - - def eager_reduce(self, op, reduced_vars): - assert reduced_vars.issubset(self.inputs) - if not reduced_vars: - return self - - return self.model.reduce(op, reduced_vars) - - class NumberMeta(FunsorMeta): """ Wrapper to fill in default ``dtype``. From 1c04b05e7f8fdbd87f2149644468882a017206de Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Wed, 3 Nov 2021 02:22:42 -0400 Subject: [PATCH 4/7] fixes --- funsor/delta.py | 15 ++++++--------- test/test_distribution.py | 12 ++++++++---- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/funsor/delta.py b/funsor/delta.py index 8eb75d03..82c469de 100644 --- a/funsor/delta.py +++ b/funsor/delta.py @@ -145,15 +145,12 @@ def eager_subs(self, subs): new_terms.append((value.name, (point, log_density))) continue - # if not any( - # d.dtype == "real" - # for side in (value, point) - # for d in side.inputs.values() - # ): - dtype = get_default_dtype() - is_equal = ops.astype((value == point).all(), dtype) - log_densities.append(is_equal.log() + log_density) - continue + var_diff = value.input_vars ^ point.input_vars + if not any(d.output.dtype == "real" for d in var_diff): + dtype = get_default_dtype() + is_equal = ops.astype((value == point).all(), dtype) + log_densities.append(is_equal.log() + log_density) + continue # Try to invert the substitution. soln = solve(value, point) diff --git a/test/test_distribution.py b/test/test_distribution.py index 36144149..abd27d50 100644 --- a/test/test_distribution.py +++ b/test/test_distribution.py @@ -754,9 +754,10 @@ def _get_stat_diff( funsor_dist = funsor_dist_class(*params) rng_key = None if get_backend() == "torch" else np.array([0, 0], dtype=np.uint32) - sample_value = funsor_dist.sample( - frozenset(["value"]), sample_inputs, rng_key=rng_key - ) + with funsor.montecarlo.MonteCarlo(): + sample_value = funsor_dist.sample( + frozenset(["value"]), sample_inputs, rng_key=rng_key + ) expected_inputs = OrderedDict( tuple(sample_inputs.items()) + tuple(inputs.items()) @@ -1426,7 +1427,10 @@ def test_categorical_event_dim_conversion(batch_shape, event_shape): name_to_dim = {batch_dim: -1 - i for i, batch_dim in enumerate(batch_dims)} rng_key = None if get_backend() == "torch" else np.array([0, 0], dtype=np.uint32) - data = actual.sample(frozenset(["value"]), rng_key=rng_key).terms[0][1][0] + with funsor.montecarlo.MonteCarlo(): + data = ( + actual.sample(frozenset(["value"]), rng_key=rng_key).terms[0].terms[0][1][0] + ) actual_log_prob = funsor.to_data(actual(value=data), name_to_dim=name_to_dim) expected_log_prob = funsor.to_data(actual, name_to_dim=name_to_dim).log_prob( From 271fda9f637048f6a5a3c8dc8d210580c2c2c260 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Wed, 3 Nov 2021 22:54:20 -0400 Subject: [PATCH 5/7] revert Approximate changes; Importance funsor --- funsor/__init__.py | 2 -- funsor/constant.py | 6 ++--- funsor/distribution.py | 2 +- funsor/importance.py | 57 +++++++++++++++++++++++++++++++++++++++ funsor/montecarlo.py | 3 ++- funsor/terms.py | 21 ++++++--------- test/test_distribution.py | 12 +++------ 7 files changed, 75 insertions(+), 28 deletions(-) create mode 100644 funsor/importance.py diff --git a/funsor/__init__.py b/funsor/__init__.py index dfbbf86a..dbed289a 100644 --- a/funsor/__init__.py +++ b/funsor/__init__.py @@ -10,7 +10,6 @@ from funsor.sum_product import MarkovProduct from funsor.tensor import Tensor, function from funsor.terms import ( - Approximate, Cat, Funsor, Independent, @@ -56,7 +55,6 @@ __all__ = [ "__version__", - "Approximate", "Array", "Bint", "Cat", diff --git a/funsor/constant.py b/funsor/constant.py index 9007ca26..d31d9a06 100644 --- a/funsor/constant.py +++ b/funsor/constant.py @@ -5,9 +5,9 @@ from functools import reduce import funsor.ops as ops +from funsor.importance import Importance from funsor.tensor import Tensor from funsor.terms import ( - Approximate, Binary, Funsor, FunsorMeta, @@ -166,7 +166,7 @@ def eager_binary_constant_constant(op, lhs, rhs): return op(lhs.arg, rhs.arg) -@eager.register(Binary, ops.BinaryOp, Constant, (Approximate, Number, Tensor)) +@eager.register(Binary, ops.BinaryOp, Constant, (Importance, Number, Tensor)) def eager_binary_constant_tensor(op, lhs, rhs): const_inputs = OrderedDict( (k, v) for k, v in lhs.const_inputs.items() if k not in rhs.inputs @@ -176,7 +176,7 @@ def eager_binary_constant_tensor(op, lhs, rhs): return op(lhs.arg, rhs) -@eager.register(Binary, ops.BinaryOp, (Approximate, Number, Tensor), Constant) +@eager.register(Binary, ops.BinaryOp, (Importance, Number, Tensor), Constant) def eager_binary_tensor_constant(op, lhs, rhs): const_inputs = OrderedDict( (k, v) for k, v in rhs.const_inputs.items() if k not in lhs.inputs diff --git a/funsor/distribution.py b/funsor/distribution.py index 0b5152c8..fce5bb82 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -242,7 +242,7 @@ def _sample(self, sampled_vars, sample_inputs, rng_key): ) model = funsor.delta.Delta(value_name, funsor_value, log_prob) guide = funsor.delta.Delta(value_name, funsor_value, ops.detach(log_prob)) - result = model.approximate(ops.sample, guide, value_name) + result = funsor.importance.Importance(model, guide) else: result = funsor.delta.Delta(value_name, funsor_value) return result diff --git a/funsor/importance.py b/funsor/importance.py new file mode 100644 index 00000000..e38055cf --- /dev/null +++ b/funsor/importance.py @@ -0,0 +1,57 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + + +from funsor.delta import Delta +from funsor.interpretations import DispatchedInterpretation +from funsor.terms import Funsor, eager, reflect + + +class Importance(Funsor): + """ + Importance sampling for approximating integrals wrt a set of variables. + + When the proposal distribution (guide) is Delta then the eager + interpretation is ``Delta + log_importance_weight``. + The user-facing interface is the :meth:`Funsor.approximate` method. + + :param Funsor model: An exact funsor depending on ``sampled_vars``. + :param Funsor guide: A proposal distribution. + :param frozenset approx_vars: A set of variables over which to approximate. + """ + + def __init__(self, model, guide, sampled_vars): + assert isinstance(model, Funsor) + assert isinstance(guide, Funsor) + assert isinstance(sampled_vars, frozenset), sampled_vars + inputs = model.inputs.copy() + inputs.update(guide.inputs) + output = model.output + super().__init__(inputs, output) + self.model = model + self.guide = guide + self.sampled_vars = sampled_vars + + def eager_reduce(self, op, reduced_vars): + assert reduced_vars.issubset(self.inputs) + if not reduced_vars: + return self + + return self.model.reduce(op, reduced_vars) + + +@eager.register(Importance, Funsor, Delta) +def eager_importance(model, guide): + # Delta + log_importance_weight + return guide + model - guide + + +lazy_importance = DispatchedInterpretation("lazy_importance") +""" +Lazy interpretation of Importance with a Delta guide. +""" + + +@lazy_importance.register(Importance, Funsor, Delta) +def _lazy_importance(model, guide): + return reflect.interpret(Importance, model, guide) diff --git a/funsor/montecarlo.py b/funsor/montecarlo.py index ac5a0446..06d66961 100644 --- a/funsor/montecarlo.py +++ b/funsor/montecarlo.py @@ -45,7 +45,6 @@ def monte_carlo_integrate(state, log_measure, integrand, reduced_vars): @MonteCarlo.register(Approximate, ops.LogaddexpOp, Funsor, Funsor, frozenset) -@MonteCarlo.register(Approximate, ops.SampleOp, Funsor, Funsor, frozenset) def monte_carlo_approximate(state, op, model, guide, approx_vars): sample_options = {} if state.rng_key is not None and get_backend() == "jax": @@ -54,6 +53,8 @@ def monte_carlo_approximate(state, op, model, guide, approx_vars): sample_options["rng_key"], state.rng_key = jax.random.split(state.rng_key) sample = guide.sample(approx_vars, state.sample_inputs, **sample_options) + if sample is guide: + return model # cannot progress result = sample + model - guide return result diff --git a/funsor/terms.py b/funsor/terms.py index e9f7044d..2c39bab0 100644 --- a/funsor/terms.py +++ b/funsor/terms.py @@ -1313,18 +1313,19 @@ def __init__(self, op, model, guide, approx_vars): inputs = model.inputs.copy() inputs.update(guide.inputs) output = model.output - super().__init__(inputs, output) + fresh = frozenset(v.name for v in approx_vars) + bound = {v.name: v.output for v in approx_vars} + super().__init__(inputs, output, fresh, bound) self.op = op self.model = model self.guide = guide self.approx_vars = approx_vars - def eager_reduce(self, op, reduced_vars): - assert reduced_vars.issubset(self.inputs) - if not reduced_vars: - return self - - return self.model.reduce(op, reduced_vars) + def _alpha_convert(self, alpha_subs): + alpha_subs = {k: to_funsor(v, self.bound[k]) for k, v in alpha_subs.items()} + op, model, guide, approx_vars = super()._alpha_convert(alpha_subs) + approx_vars = frozenset(alpha_subs.get(var.name, var) for var in approx_vars) + return op, model, guide, approx_vars @eager.register(Approximate, AssociativeOp, Funsor, Funsor, frozenset) @@ -1332,12 +1333,6 @@ def eager_approximate(op, model, guide, approx_vars): return model # exact -@eager.register(Approximate, ops.SampleOp, Funsor, Funsor, frozenset) -def eager_approximate(op, model, guide, approx_vars): - expr = reflect.interpret(Approximate, op, model, guide, approx_vars) - return expr - - class NumberMeta(FunsorMeta): """ Wrapper to fill in default ``dtype``. diff --git a/test/test_distribution.py b/test/test_distribution.py index abd27d50..d5ba96c9 100644 --- a/test/test_distribution.py +++ b/test/test_distribution.py @@ -754,10 +754,9 @@ def _get_stat_diff( funsor_dist = funsor_dist_class(*params) rng_key = None if get_backend() == "torch" else np.array([0, 0], dtype=np.uint32) - with funsor.montecarlo.MonteCarlo(): - sample_value = funsor_dist.sample( - frozenset(["value"]), sample_inputs, rng_key=rng_key - ) + sample_value = funsor_dist.sample( + frozenset(["value"]), sample_inputs, rng_key=rng_key + ) expected_inputs = OrderedDict( tuple(sample_inputs.items()) + tuple(inputs.items()) @@ -1427,10 +1426,7 @@ def test_categorical_event_dim_conversion(batch_shape, event_shape): name_to_dim = {batch_dim: -1 - i for i, batch_dim in enumerate(batch_dims)} rng_key = None if get_backend() == "torch" else np.array([0, 0], dtype=np.uint32) - with funsor.montecarlo.MonteCarlo(): - data = ( - actual.sample(frozenset(["value"]), rng_key=rng_key).terms[0].terms[0][1][0] - ) + data = actual.sample(frozenset(["value"]), rng_key=rng_key).terms[0].terms[0][1][0] actual_log_prob = funsor.to_data(actual(value=data), name_to_dim=name_to_dim) expected_log_prob = funsor.to_data(actual, name_to_dim=name_to_dim).log_prob( From cb255a4ba09fdf0c0a7e0c2b8f16d647a93b2504 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Wed, 3 Nov 2021 23:22:12 -0400 Subject: [PATCH 6/7] fixes --- funsor/distribution.py | 3 ++- funsor/importance.py | 16 ++++++++-------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/funsor/distribution.py b/funsor/distribution.py index fce5bb82..ed79b1df 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -242,7 +242,8 @@ def _sample(self, sampled_vars, sample_inputs, rng_key): ) model = funsor.delta.Delta(value_name, funsor_value, log_prob) guide = funsor.delta.Delta(value_name, funsor_value, ops.detach(log_prob)) - result = funsor.importance.Importance(model, guide) + sampled_var = frozenset({Variable(value_name, self.inputs[value_name])}) + result = funsor.importance.Importance(model, guide, sampled_var) else: result = funsor.delta.Delta(value_name, funsor_value) return result diff --git a/funsor/importance.py b/funsor/importance.py index e38055cf..4eb04de8 100644 --- a/funsor/importance.py +++ b/funsor/importance.py @@ -15,9 +15,9 @@ class Importance(Funsor): interpretation is ``Delta + log_importance_weight``. The user-facing interface is the :meth:`Funsor.approximate` method. - :param Funsor model: An exact funsor depending on ``sampled_vars``. + :param Funsor model: A funsor depending on ``sampled_vars``. :param Funsor guide: A proposal distribution. - :param frozenset approx_vars: A set of variables over which to approximate. + :param frozenset sampled_vars: A set of input variables to sample. """ def __init__(self, model, guide, sampled_vars): @@ -40,18 +40,18 @@ def eager_reduce(self, op, reduced_vars): return self.model.reduce(op, reduced_vars) -@eager.register(Importance, Funsor, Delta) -def eager_importance(model, guide): +@eager.register(Importance, Funsor, Delta, frozenset) +def eager_importance(model, guide, sampled_vars): # Delta + log_importance_weight return guide + model - guide lazy_importance = DispatchedInterpretation("lazy_importance") """ -Lazy interpretation of Importance with a Delta guide. +Lazy interpretation of the Importance with a Delta guide. """ -@lazy_importance.register(Importance, Funsor, Delta) -def _lazy_importance(model, guide): - return reflect.interpret(Importance, model, guide) +@lazy_importance.register(Importance, Funsor, Delta, frozenset) +def _lazy_importance(model, guide, sampled_vars): + return reflect.interpret(Importance, model, guide, sampled_vars) From 7dd3a886acecfde1b4560cecd7db37d86b566c9f Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev Date: Wed, 3 Nov 2021 23:46:12 -0400 Subject: [PATCH 7/7] docs for importance modele --- docs/source/funsors.rst | 8 ++++++++ funsor/__init__.py | 2 ++ funsor/distribution.py | 2 +- 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/docs/source/funsors.rst b/docs/source/funsors.rst index 2334ed50..9b26d1b2 100644 --- a/docs/source/funsors.rst +++ b/docs/source/funsors.rst @@ -64,3 +64,11 @@ Constant :undoc-members: :show-inheritance: :member-order: bysource + +Importance +---------- +.. automodule:: funsor.importance + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource diff --git a/funsor/__init__.py b/funsor/__init__.py index dbed289a..8724c7ed 100644 --- a/funsor/__init__.py +++ b/funsor/__init__.py @@ -4,6 +4,7 @@ from funsor.constant import Constant from funsor.domains import Array, Bint, Domain, Real, Reals, bint, find_domain, reals from funsor.factory import make_funsor +from funsor.importance import Importance from funsor.integrate import Integrate from funsor.interpreter import interpretation, reinterpret from funsor.op_factory import make_op @@ -61,6 +62,7 @@ "Constant", "Domain", "Funsor", + "Importance", "Independent", "Integrate", "Lambda", diff --git a/funsor/distribution.py b/funsor/distribution.py index ed79b1df..5c4332ee 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -243,7 +243,7 @@ def _sample(self, sampled_vars, sample_inputs, rng_key): model = funsor.delta.Delta(value_name, funsor_value, log_prob) guide = funsor.delta.Delta(value_name, funsor_value, ops.detach(log_prob)) sampled_var = frozenset({Variable(value_name, self.inputs[value_name])}) - result = funsor.importance.Importance(model, guide, sampled_var) + result = funsor.Importance(model, guide, sampled_var) else: result = funsor.delta.Delta(value_name, funsor_value) return result