diff --git a/funsor/autodiff.py b/funsor/autodiff.py
new file mode 100644
index 000000000..f11ac2fd9
--- /dev/null
+++ b/funsor/autodiff.py
@@ -0,0 +1,281 @@
+# Copyright Contributors to the Pyro project.
+# SPDX-License-Identifier: Apache-2.0
+
+import math
+from collections import defaultdict
+from functools import reduce, singledispatch
+
+import funsor.ops as ops
+from funsor import Tensor
+from funsor.adjoint import _alpha_unmangle
+from funsor.cnf import Contraction
+from funsor.domains import Array, Bint, Real, Reals
+from funsor.interpretations import autodiff, trace
+from funsor.interpreter import interpretation
+from funsor.ops import AssociativeOp, LogOp
+from funsor.terms import (
+    Binary,
+    Funsor,
+    Lambda,
+    Number,
+    Reduce,
+    Tuple,
+    Unary,
+    Variable,
+    eager,
+    lazy,
+)
+
+
+class JVP(Tuple):
+    """
+    Tuple: (Primal, Tangent)
+    Semiring: (Add, Mul)
+    """
+
+    sum_op = ops.add
+    prod_op = ops.mul
+    div_op = ops.safediv
+    zero = Number(0)
+    one = Number(1)
+
+    @property
+    def primal(self):
+        return self[0]
+
+    @property
+    def tangent(self):
+        return self[1]
+
+
+class LJVP(Tuple):
+    """
+    Tuple: (LogPrimal, LogTangent)
+    Semiring: (Logaddexp, Add)
+    """
+
+    sum_op = ops.logaddexp
+    prod_op = ops.add
+    div_op = ops.safesub
+    zero = Number(-math.inf)
+    one = Number(0)
+
+    @property
+    def primal(self):
+        return self[0]
+
+    @property
+    def tangent(self):
+        return self[1]
+
+
+@trace.register(Binary, AssociativeOp, Funsor, Funsor)
+def trace_binary_associativeop(op, lhs, rhs):
+    with lazy:
+        result = Binary(op, lhs, rhs)
+    return result
+
+
+@trace.register(Reduce, AssociativeOp, Funsor, frozenset)
+def trace_reduce_associativeop(op, arg, reduced_args):
+    with lazy:
+        result = Reduce(op, arg, reduced_args)
+    return result
+
+
+def to_jvp(primal):
+    input_vars = tuple(Variable(key, value) for key, value in primal.inputs.items())
+    output = reduce(lambda x, y: Lambda(y, x), reversed(input_vars), primal).output
+    tangent_placeholder = Variable(str(id(primal)), output)[tuple(primal.inputs)]
+    return JVP(primal, tangent_placeholder)
+
+
+def to_ljvp(primal):
+    input_vars = tuple(Variable(key, value) for key, value in primal.inputs.items())
+    output = reduce(lambda x, y: Lambda(y, x), reversed(input_vars), primal).output
+    tangent_placeholder = Variable(str(id(primal)), output)[tuple(primal.inputs)]
+    return LJVP(primal, tangent_placeholder)
+
+
+def grad(expr, targets, out_tangent=None):
+    out_tangent = expr.one if out_tangent is None else out_tangent
+    in_tangents = set(target.tangent for target in targets)
+    transposes = transpose(
+        expr.tangent,
+        out_tangent,
+        in_tangents,
+        defaultdict(lambda: expr.zero),
+        type(expr),
+    )
+    result = {}
+    for target in targets:
+        result[target] = transposes[target.tangent]
+    return result
+
+
+@singledispatch
+def transpose(expr, out_tangent, in_tangents, result, semiring):
+    if expr in in_tangents:
+        result[expr] = semiring.sum_op(result[expr], out_tangent)
+    return result
+
+
+@transpose.register(Binary)
+def transpose_binary(expr, out_tangent, in_tangents, result, semiring):
+
+    op, lhs, rhs = expr.op, expr.lhs, expr.rhs
+    sum_op, prod_op = semiring.sum_op, semiring.prod_op
+
+    if expr in in_tangents:
+        result[expr] = sum_op(result[expr], out_tangent)
+        out_tangent = result[expr]
+
+    if op is sum_op:
+        lhs_adj = out_tangent.reduce(sum_op, out_tangent.input_vars - lhs.input_vars)
+        rhs_adj = out_tangent.reduce(sum_op, out_tangent.input_vars - rhs.input_vars)
+    elif op is prod_op:
+        lhs_adj = prod_op(rhs, out_tangent).reduce(
+            sum_op, out_tangent.input_vars - lhs.input_vars
+        )
+        rhs_adj = prod_op(lhs, out_tangent).reduce(
+            sum_op, out_tangent.input_vars - rhs.input_vars
+        )
+    else:
+        return result  # is it always correct?
+    result = transpose(lhs, lhs_adj, in_tangents, result, semiring)
+    result = transpose(rhs, rhs_adj, in_tangents, result, semiring)
+    return result
+
+
+@transpose.register(Reduce)
+def transpose_reduce(expr, out_tangent, in_tangents, result, semiring):
+    # fix this in contraction as well
+    op, arg, reduced_vars = _alpha_unmangle(expr)
+    sum_op, prod_op = semiring.sum_op, semiring.prod_op
+
+    if expr in in_tangents:
+        result[expr] = sum_op(result[expr], out_tangent)
+        out_tangent = result[expr]
+
+    if op is sum_op:
+        arg_adj = out_tangent.expand(tuple(reduced_vars))
+        result = transpose(arg, arg_adj, in_tangents, result, semiring)
+        return result
+    elif op is prod_op:
+        # this is unnecessary
+        return result
+    else:
+        raise ValueError
+
+
+@transpose.register(Contraction)
+def transpose_contraction(expr, out_tangent, in_tangents, result, semiring):
+    if expr in in_tangents:
+        result[expr] += out_tangent
+        out_tangent = result[expr]
+
+    if expr.red_op is ops.nullop:
+        for term in expr.terms:
+            if expr.bin_op is ops.add:
+                term_adj = out_tangent.reduce(
+                    ops.add, out_tangent.input_vars - term.input_vars
+                )
+            elif expr.bin_op is ops.mul:
+                expr_div_term = reduce(
+                    ops.mul, tuple(t for t in expr.terms if t is not term)
+                )
+                term_adj = (out_tangent * expr_div_term).reduce(
+                    ops.add, out_tangent.input_vars - term.input_vars
+                )
+            else:
+                raise ValueError
+            result = transpose(term, term_adj, in_tangents, result, semiring)
+    elif expr.bin_op is ops.nullop:
+        for term in expr.terms:  # only one term
+            if expr.red_op is ops.add:
+                term_adj = out_tangent.expand(tuple(expr.reduced_vars))
+            elif expr.red_op is ops.mul:
+                term_adj = ops.safediv(ops.mul(out_tangent, expr), term)
+            else:
+                raise ValueError
+            result = transpose(term, term_adj, in_tangents, result, semiring)
+    else:
+        raise ValueError
+    return result
+
+
+@eager.register(Binary, AssociativeOp, JVP, JVP)
+@eager.register(Binary, AssociativeOp, LJVP, LJVP)
+@autodiff.register(Binary, AssociativeOp, JVP, JVP)
+@autodiff.register(Binary, AssociativeOp, LJVP, LJVP)
+def jvp_binary(op, lhs, rhs):
+    sum_op, prod_op = lhs.sum_op, lhs.prod_op
+    primal = Binary(op, lhs.primal, rhs.primal)
+    if op is sum_op:
+        tangent = sum_op(lhs.tangent, rhs.tangent)
+    elif op is prod_op:
+        tangent = sum_op(
+            prod_op(rhs.primal, lhs.tangent), prod_op(lhs.primal, rhs.tangent)
+        )
+    else:
+        raise NotImplementedError
+    return type(lhs)(primal, tangent)
+
+
+@eager.register(Binary, AssociativeOp, JVP, (Number, Tensor))
+@eager.register(Binary, AssociativeOp, LJVP, (Number, Tensor))
+@autodiff.register(Binary, AssociativeOp, JVP, (Number, Tensor))
+@autodiff.register(Binary, AssociativeOp, LJVP, (Number, Tensor))
+def jvp_binary_jvp_funsor(op, lhs, rhs):
+    sum_op, prod_op = lhs.sum_op, lhs.prod_op
+    primal = Binary(op, lhs.primal, rhs)
+    if op is sum_op:
+        tangent = sum_op(lhs.tangent, rhs)
+    elif op is prod_op:
+        tangent = prod_op(lhs.tangent, rhs)
+    else:
+        raise NotImplementedError
+    return type(lhs)(primal, tangent)
+
+
+@eager.register(Binary, AssociativeOp, (Number, Tensor), JVP)
+@eager.register(Binary, AssociativeOp, (Number, Tensor), LJVP)
+@autodiff.register(Binary, AssociativeOp, (Number, Tensor), JVP)
+@autodiff.register(Binary, AssociativeOp, (Number, Tensor), LJVP)
+def jvp_binary_funsor_jvp(op, lhs, rhs):
+    sum_op, prod_op = rhs.sum_op, rhs.prod_op
+    primal = Binary(op, lhs, rhs.primal)
+    if op is sum_op:
+        tangent = sum_op(lhs, rhs.tangent)
+    elif op is prod_op:
+        tangent = prod_op(lhs, rhs.tangent)
+    else:
+        raise NotImplementedError
+    return type(rhs)(primal, tangent)
+
+
+@eager.register(Reduce, AssociativeOp, JVP, frozenset)
+@eager.register(Reduce, AssociativeOp, LJVP, frozenset)
+@autodiff.register(Reduce, AssociativeOp, JVP, frozenset)
+@autodiff.register(Reduce, AssociativeOp, LJVP, frozenset)
+def jvp_reduce(op, arg, reduced_vars):
+    sum_op, prod_op, div_op = arg.sum_op, arg.prod_op, arg.div_op
+    primal = Reduce(op, arg.primal, reduced_vars)
+    if op is sum_op:
+        tangent = Reduce(sum_op, arg.tangent, reduced_vars)
+    elif op is prod_op:
+        tangent = Reduce(
+            sum_op, div_op(prod_op(arg.tangent, primal), arg.primal), reduced_vars
+        )
+    else:
+        raise NotImplementedError
+    return type(arg)(primal, tangent)
+
+
+# @lazy.register(Unary, LogOp, JVP)
+# @eager.register(Unary, LogOp, JVP)
+# def jvp_log(op, arg):
+#     arg_primal, arg_tangent = arg
+#     primal = Unary(op, arg_primal)
+#     tangent = Binary(ops.truediv, arg_tangent, arg_primal)
+#     return JVP(primal, tangent)
diff --git a/funsor/domains.py b/funsor/domains.py
index 617e965c5..c447f1f0a 100644
--- a/funsor/domains.py
+++ b/funsor/domains.py
@@ -261,6 +261,7 @@ def _find_domain_getitem(op, lhs_domain, rhs_domain):
         return Array[dtype, shape]
     elif isinstance(lhs_domain, ProductDomain):
         # XXX should this return a Union?
+        return Real
         raise NotImplementedError(
             "Cannot statically infer domain from: " f"{lhs_domain}[{rhs_domain}]"
         )
@@ -325,7 +326,10 @@ def _find_domain_associative_generic(op, *domains):
         return Array[domains[0].dtype, ()]
 
     lhs, rhs = domains
-    if lhs.dtype == "real" or rhs.dtype == "real":
+    # FIXME
+    if lhs is rhs:
+        return lhs
+    elif lhs.dtype == "real" or rhs.dtype == "real":
         dtype = "real"
     elif op in (ops.add, ops.mul, ops.pow, ops.max, ops.min):
         dtype = op(lhs.dtype - 1, rhs.dtype - 1) + 1
diff --git a/funsor/interpretations.py b/funsor/interpretations.py
index 48f042de1..2e38d6b9b 100644
--- a/funsor/interpretations.py
+++ b/funsor/interpretations.py
@@ -329,6 +329,20 @@ def reflect(cls, *args):
 Eager exact naive interpretation wherever possible.
 """
 
+trace_base = DispatchedInterpretation("trace")
+trace = PrioritizedInterpretation(trace_base, eager_base, normalize_base, reflect)
+"""
+Constructs a trace (expression) in terms of primitive operations.
+"""
+
+autodiff_base = DispatchedInterpretation("autodiff")
+autodiff = PrioritizedInterpretation(
+    autodiff_base, trace_base, eager_base, normalize_base, reflect
+)
+"""
+Automatic differentiation interpretation, layered on top of the trace interpretation.
+"""
+
 die = DispatchedInterpretation("die")
 eager_or_die = PrioritizedInterpretation(eager_base, die, reflect)
 
diff --git a/funsor/tensor.py b/funsor/tensor.py
index 63f5ffb1e..77cfe207e 100644
--- a/funsor/tensor.py
+++ b/funsor/tensor.py
@@ -18,9 +18,10 @@
 from . import ops
 from .delta import Delta
 from .domains import Array, ArrayType, Bint, Product, Real, Reals, find_domain
-from .ops import GetitemOp, MatmulOp, Op, ReshapeOp
+from .ops import AssociativeOp, GetitemOp, MatmulOp, Op, ReshapeOp
 from .terms import (
     Binary,
+    Expand,
     Finitary,
     Funsor,
     FunsorMeta,
@@ -682,6 +683,28 @@ def eager_scatter_tensor(op, subs, source, reduced_vars):
     return Tensor(data, destin_inputs, output.dtype)
 
 
+@eager.register(Expand, Number, tuple)
+def eager_number_expand(arg, expanded_vars):
+    shape = tuple(var.output.size for var in expanded_vars)
+    inputs = OrderedDict([(var.name, var.output) for var in expanded_vars])
+    data = ops.new_full(
+        get_default_prototype(),
+        shape,
+        arg.data
+    )
+    return Tensor(data, inputs, arg.dtype)
+
+
+@eager.register(Expand, Tensor, tuple)
+def eager_tensor_expand(arg, expanded_vars):
+    expanded_shape = tuple(var.output.size for var in expanded_vars)
+    old_shape = (-1,) * (len(arg.inputs) + len(arg.output.shape))
+    new_shape = expanded_shape + old_shape
+    inputs = OrderedDict([(var.name, var.output) for var in expanded_vars])
+    inputs.update(arg.inputs)
+    return Tensor(ops.expand(arg.data, new_shape), inputs, arg.dtype)
+
+
 @eager.register(Binary, Op, Tensor, Number)
 def eager_binary_tensor_number(op, lhs, rhs):
     dtype = find_domain(op, lhs.output, rhs.output).dtype
diff --git a/funsor/terms.py b/funsor/terms.py
index 9423e0266..0768b764e 100644
--- a/funsor/terms.py
+++ b/funsor/terms.py
@@ -337,6 +337,13 @@ def item(self):
     def requires_grad(self):
         return False
 
+    def expand(self, expanded_vars):
+        # Eagerly skip trivial (empty) expansions.
+        assert isinstance(expanded_vars, tuple)
+        if not expanded_vars:
+            return self
+        return Expand(self, expanded_vars)
+
     def reduce(self, op, reduced_vars=None):
         """
         Reduce along all or a subset of inputs.
@@ -994,6 +1001,46 @@ def die_binary(op, lhs, rhs):
     raise NotImplementedError(f"Missing pattern for {repr(expr)}")
 
 
+class Expand(Funsor):
+    """
+    Lazy expand operation over multiple variables.
+
+    The user-facing interface is the :meth:`Funsor.expand` method.
+
+    :param arg: The funsor to be expanded.
+    :type arg: ~funsor.terms.Funsor
+    :param expanded_vars: A tuple of fresh :class:`Variable` s to expand over.
+    :type expanded_vars: tuple
+    """
+
+    def __init__(self, arg, expanded_vars):
+        assert isinstance(arg, Funsor)
+        assert isinstance(expanded_vars, tuple)
+        assert all(isinstance(v, Variable) for v in expanded_vars)
+        inputs = OrderedDict([(var.name, var.output) for var in expanded_vars])
+        inputs.update(arg.inputs)
+        output = arg.output
+        fresh = frozenset()
+        bound = {}
+        super().__init__(inputs, output, fresh, bound)
+        self.arg = arg
+        self.expanded_vars = expanded_vars
+
+    def __repr__(self):
+        assert self.expanded_vars
+        rvars = [repr(v) for v in self.expanded_vars]
+        return "{}.expand({{{}}})".format(
+            repr(self.arg), ", ".join(rvars)
+        )
+
+    def __str__(self):
+        assert self.expanded_vars
+        rvars = [repr(v) for v in self.expanded_vars]
+        return "{}.expand({{{}}})".format(
+            repr(self.arg), ", ".join(rvars)
+        )
+
+
 class Reduce(Funsor):
     """
     Lazy reduction over multiple variables.
diff --git a/test/test_autodiff.py b/test/test_autodiff.py
new file mode 100644
index 000000000..1f6291d85
--- /dev/null
+++ b/test/test_autodiff.py
@@ -0,0 +1,245 @@
+# Copyright Contributors to the Pyro project.
+# SPDX-License-Identifier: Apache-2.0
+
+import math
+from collections import OrderedDict
+from functools import reduce
+
+import pytest
+import torch
+
+import funsor
+import funsor.ops as ops
+from funsor.autodiff import JVP, grad, to_jvp, to_ljvp
+from funsor.domains import Bint, Real, Reals
+from funsor.factory import Bound, Fresh, Has, make_funsor
+from funsor.interpretations import autodiff, trace
+from funsor.interpreter import interpretation
+from funsor.optimizer import apply_optimizer
+from funsor.sum_product import MarkovProduct
+from funsor.tensor import Tensor
+from funsor.terms import Binary, Funsor, Lambda, Number, Tuple, Variable, lazy
+from funsor.testing import assert_close, random_tensor
+
+funsor.set_backend("torch")
+
+
+@pytest.mark.parametrize(
+    "sum_op,prod_op,tojvp",
+    [(ops.add, ops.mul, to_jvp), (ops.logaddexp, ops.add, to_ljvp)],
+)
+def test_mul_x_y(sum_op, prod_op, tojvp):
+    with autodiff:
+        # Mul
+        x = tojvp(random_tensor(OrderedDict(j=Bint[4])))
+        y = tojvp(random_tensor(OrderedDict(j=Bint[4], k=Bint[5])))
+        out_adj = random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))
+
+        z = prod_op(x, y)
+        result = grad(z, (x, y), out_adj)
+
+    expected_x = prod_op(out_adj, y.primal).reduce(sum_op, "k")
+    expected_y = prod_op(out_adj, x.primal)
+
+    actual_x = apply_optimizer(result[x])
+    actual_y = apply_optimizer(result[y])
+
+    assert_close(actual_x, expected_x, rtol=1e-5)
+    assert_close(actual_y, expected_y, rtol=1e-5)
+
+
+@pytest.mark.parametrize(
+    "sum_op,prod_op,tojvp",
+    [(ops.add, ops.mul, to_jvp), (ops.logaddexp, ops.add, to_ljvp)],
+)
+def test_mul_x_x(sum_op, prod_op, tojvp):
+    with autodiff:
+        # Mul
+        x = tojvp(random_tensor(OrderedDict(j=Bint[4])))
+        out_adj = random_tensor(OrderedDict(j=Bint[4]))
+
+        z = prod_op(x, x)
+        result = grad(z, (x,), out_adj)
+
+    two = 2 if tojvp is to_jvp else math.log(2)
+    expected_x = reduce(prod_op, (two, out_adj, x.primal))
+    actual_x = apply_optimizer(result[x])
+    assert_close(actual_x, expected_x)
+
+
+@pytest.mark.parametrize(
+    "sum_op,prod_op,tojvp",
+    [(ops.add, ops.mul, to_jvp), (ops.logaddexp, ops.add, to_ljvp)],
+)
+def test_add_x_x(sum_op, prod_op, tojvp):
+    with autodiff:
+        # Add
+        x = tojvp(random_tensor(OrderedDict(j=Bint[4])))
+        out_adj = random_tensor(OrderedDict(j=Bint[4]))
+
+        z = sum_op(x, x)
+        result = grad(z, (x,), out_adj)
+
+    two = 2 if tojvp is to_jvp else math.log(2)
+    expected_x = prod_op(two, out_adj)
+    actual_x = apply_optimizer(result[x])
+    assert_close(actual_x, expected_x)
+
+
+@pytest.mark.parametrize(
+    "sum_op,prod_op,tojvp",
+    [(ops.add, ops.mul, to_jvp), (ops.logaddexp, ops.add, to_ljvp)],
+)
+def test_add_x_y(sum_op, prod_op, tojvp):
+    with autodiff:
+        # Add
+        x = tojvp(random_tensor(OrderedDict(j=Bint[4])))
+        y = tojvp(random_tensor(OrderedDict(j=Bint[4], k=Bint[5])))
+        out_adj = random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))
+
+        z = sum_op(x, y)
+        result = grad(z, (x, y), out_adj)
+
+    expected_x = out_adj.reduce(sum_op, "k")
+    expected_y = out_adj
+
+    actual_x = apply_optimizer(result[x])
+    actual_y = apply_optimizer(result[y])
+
+    assert_close(actual_x, expected_x)
+    assert_close(actual_y, expected_y)
+
+
+@pytest.mark.parametrize(
+    "sum_op,prod_op,tojvp",
+    [(ops.add, ops.mul, to_jvp), (ops.logaddexp, ops.add, to_ljvp)],
+)
+def test_mul_add_x_x_y(sum_op, prod_op, tojvp):
+    with autodiff:
+        # Add Mul
+        x = tojvp(random_tensor(OrderedDict(j=Bint[4])))
+        y = tojvp(random_tensor(OrderedDict(j=Bint[4], k=Bint[5])))
+        out_adj = random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))
+
+        z = sum_op(prod_op(x, x), y)
+        result = grad(z, (x, y), out_adj)
+
+    two = 2 if tojvp is to_jvp else math.log(2)
+    expected_x = reduce(prod_op, (two, x.primal, out_adj.reduce(sum_op, "k")))
+    expected_y = out_adj
+
+    actual_x = apply_optimizer(result[x])
+    actual_y = apply_optimizer(result[y])
+
+    assert_close(actual_x, expected_x)
+    assert_close(actual_y, expected_y)
+
+
+@pytest.mark.parametrize(
+    "sum_op,prod_op,tojvp",
+    [(ops.add, ops.mul, to_jvp), (ops.logaddexp, ops.add, to_ljvp)],
+)
+def test_mul_add_xx_yy(sum_op, prod_op, tojvp):
+    with autodiff:
+        # Add Mul
+        x = tojvp(random_tensor(OrderedDict(j=Bint[4])))
+        y = tojvp(random_tensor(OrderedDict(j=Bint[4], k=Bint[5])))
+        out_adj = random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))
+
+        z = reduce(sum_op, (prod_op(x, x), y, y))
+        result = grad(z, (x, y), out_adj)
+
+    two = 2 if tojvp is to_jvp else math.log(2)
+    expected_x = reduce(prod_op, (two, x.primal, out_adj.reduce(sum_op, "k")))
+    expected_y = prod_op(two, out_adj)
+
+    actual_x = apply_optimizer(result[x])
+    actual_y = apply_optimizer(result[y])
+
+    assert_close(actual_x, expected_x)
+    assert_close(actual_y, expected_y)
+
+
+@pytest.mark.parametrize(
+    "sum_op,prod_op,tojvp",
+    [(ops.add, ops.mul, to_jvp), (ops.logaddexp, ops.add, to_ljvp)],
+)
+def test_reduce_add_x(sum_op, prod_op, tojvp):
+    with autodiff:
+        # Reduce
+        y = tojvp(random_tensor(OrderedDict(j=Bint[4], k=Bint[5])))
+        out_adj = random_tensor(OrderedDict(j=Bint[4]))
+
+        z = y.reduce(sum_op, "k")
+        result = grad(z, (y,), out_adj)
+
+    expected_y = out_adj.expand((Variable("k", Bint[5]),))
+    actual_y = apply_optimizer(result[y])
+    assert_close(actual_y, expected_y, rtol=1e-5)
+
+
+@pytest.mark.parametrize(
+    "sum_op,prod_op,div_op,tojvp",
+    [
+        (ops.add, ops.mul, ops.safediv, to_jvp),
+        (ops.logaddexp, ops.add, ops.safesub, to_ljvp),
+    ],
+)
+def test_reduce_mul_x(sum_op, prod_op, div_op, tojvp):
+    with autodiff:
+        # Reduce
+        y = tojvp(random_tensor(OrderedDict(j=Bint[4], k=Bint[5])))
+        out_adj = random_tensor(OrderedDict(j=Bint[4]))
+
+        z = y.reduce(prod_op, "k")
+        result = grad(z, (y,), out_adj)
+
+    actual_y = apply_optimizer(result[y])
+    expected_y = div_op(prod_op(out_adj, z.primal), y.primal)
+    assert_close(actual_y, apply_optimizer(expected_y), rtol=1e-5)
+
+
+@pytest.mark.parametrize(
+    "sum_op,prod_op,tojvp",
+    [(ops.add, ops.mul, to_jvp), (ops.logaddexp, ops.add, to_ljvp)],
+)
+def test_mul_reduce_x_y(sum_op, prod_op, tojvp):
+    with autodiff:
+        # Reduce
+        x = tojvp(random_tensor(OrderedDict(j=Bint[4])))
+        y = tojvp(random_tensor(OrderedDict(j=Bint[4], k=Bint[5])))
+        out_adj = random_tensor(OrderedDict(k=Bint[5]))
+
+        z = prod_op(x, y).reduce(sum_op, "j")
+        result = grad(z, (x, y), out_adj)
+
+    expected_x = prod_op(y.primal, out_adj).reduce(sum_op, "k")
+    expected_y = prod_op(x.primal, out_adj.expand((Variable("j", Bint[4]),)))
+
+    actual_x = apply_optimizer(result[x])
+    actual_y = apply_optimizer(result[y])
+
+    assert_close(actual_x, expected_x, rtol=1e-5)
+    assert_close(actual_y, expected_y, rtol=1e-5)
+
+
+# def test_trace():
+#     @make_funsor
+#     def Matmul(
+#         x: Has[{"i"}],
+#         y: Has[{"i"}],
+#         i: Bound
+#     ) -> Fresh[lambda x: x]:
+#         return (x * y).reduce(ops.add, i)
+#
+#     x = random_tensor(OrderedDict(j=Bint[4]))
+#     y = random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))
+#     eager_z = Matmul(x, y, "j")
+#     with lazy:
+#         lazy_z = Matmul(x, y, "j")
+#
+#     with trace:
+#         trace_z = Matmul(x, y, "j")
+#
+#     assert_close(eager_z, apply_optimizer(lazy_z))
+#     assert_close(eager_z, apply_optimizer(trace_z))
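
For reviewers, a minimal usage sketch of the new API, based on test_mul_x_y above (illustrative only, not part of the patch):

    # Sketch mirroring test_mul_x_y; assumes the torch backend used by the tests.
    from collections import OrderedDict

    import funsor
    import funsor.ops as ops
    from funsor.autodiff import grad, to_jvp
    from funsor.domains import Bint
    from funsor.interpretations import autodiff
    from funsor.optimizer import apply_optimizer
    from funsor.testing import random_tensor

    funsor.set_backend("torch")

    with autodiff:
        x = to_jvp(random_tensor(OrderedDict(j=Bint[4])))             # (primal, tangent) pair
        y = to_jvp(random_tensor(OrderedDict(j=Bint[4], k=Bint[5])))
        out_adj = random_tensor(OrderedDict(j=Bint[4], k=Bint[5]))    # adjoint of the output

        z = ops.mul(x, y)                  # traced lazily under the autodiff interpretation
        grads = grad(z, (x, y), out_adj)   # maps each JVP input to its adjoint expression

    dx = apply_optimizer(grads[x])  # equals (out_adj * y.primal).reduce(ops.add, "k")
    dy = apply_optimizer(grads[y])  # equals out_adj * x.primal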