Skip to content

Commit

Permalink
Add README and more examples
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-lew committed Dec 13, 2022
1 parent 1492269 commit 6844796
Show file tree
Hide file tree
Showing 9 changed files with 242 additions and 9 deletions.
148 changes: 146 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,147 @@
# adev
# ADEV

Haskell prototype to accompany the paper "ADEV: Sound Automatic Differentiation of Expected Values of Probabilistic Programs"
This repository contains the Haskell prototype that accompanies the paper "[ADEV: Sound Automatic Differentiation of Expected Values of Probabilistic Programs](https://popl23.sigplan.org/details/POPL-2023-popl-research-papers/5/ADEV-Sound-Automatic-Differentiation-of-Expected-Values-of-Probabilistic-Programs)".

## Overview

![Overview of ADEV](figures/adev-diagram.png)
ADEV is a method of automatically differentiating loss functions defined as *expected values* of probabilistic processes. ADEV users define a _probabilistic program_ $t$, which, given a parameter of type $\mathbb{R}$ (or a subtype), outputs a value of type $\widetilde{\mathbb{R}}$,
which represents probabilistic estimators of losses. We translate $t$ to a new probabilistic program $s$,
whose expected return value is the derivative of $s$’s expected return value. Running $s$ yields provably unbiased
estimates $x_i$ of the loss's derivative, which can be used in the inner loop of stochastic optimization algorithms like ADAM or stochastic gradient descent.

ADEV goes beyond standard AD by explicitly supporting probabilistic primitives, like `flip`, for flipping a coin. If these probabilistic constructs are ignored, standard AD may produce incorrect results, as this figure from our paper illustrates:
![Optimizing an example loss function using ADEV](figures/example.png)
In this example, standard AD
fails to account for the parameter $\theta$'s effect on the *probability* of entering each branch. ADEV, by contrast, correctly accounts
for the probabilistic effects, generating similar code to what a practitioner might hand-derive. Correct
gradients are often crucial for downstream applications, e.g. optimization via stochastic gradient descent.

ADEV compositionally supports various gradient estimation strategies from the literature, including:
- Reparameterization trick (Kingma & Welling 2014)
- Score function estimator (Ranganath et al. 2014)
- Baselines as control variates (Mnih and Gregor 2014)
- Multi-sample estimators that Storchastic supports (e.g. leave-one-out baselines) (van Krieken et al. 2021)
- Variance reduction via dependency tracking (Schulman et al. 2015)
- Special estimators for differentiable particle filtering (Ścibior et al. 2021)
- Implicit reparameterization (Figurnov et al. 2018)
- Measure-valued derivatives (Heidergott and Vázquez-Abad 2000)
- Reparameterized rejection sampling (Nasseth et al. 2017)


## Haskell Example

ADEV extends forward-mode automatic differentiation to support *probabilistic programs*. Consider the following example:

```haskell
import Numeric.ADEV.Class (ADEV(..))
import Numeric.ADEV.Interp ()
import Numeric.ADEV.Diff (diff)
import Control.Monad (replicateM)
import Control.Monad.Bayes.Sampler.Strict (sampleIO)

-- Define a loss function l as the expected
-- value of a probabilistic process.
l theta = expect $ do
b <- flip_reinforce theta
if b then
return 0
else
return (-theta / 2)

-- Take its derivative.
l' = diff l

-- Helper function for computing averages
mean xs = sum xs / (realToFrac $ length xs)

-- Estimating the loss and its derivative
-- by averaging many samples
estimate_loss = fmap mean (replicateM 1000 (l 0.4))
estimate_deriv = fmap mean (replicateM 1000 (l' 0.4))

main = do
loss <- sampleIO estimate_loss
deriv <- sampleIO estimate_deriv
print (loss, deriv)
```

**Defining a loss.** The function `l` is defined to be the *expected value* of a probabilistic process, using `expect`. The process in question involves flipping a coin, whose probability of heads is `theta`, and returning either `0` or `-theta / 2`, depending on the coin flip's result.

**Differentiating.** ADEV's `diff` operator converts such a loss into a new function `l'` representing its derivative, with respect to the input parameter `theta`.

**Running the estimators.** Operationally, neither `l` nor `l'` compute exact expectations (or derivatives of expecations): instead, they represent _unbiased estimators_ of the desired values, which can be run using `sampleIO`.
On one run, the above code printed `(-0.122, -0.10)`, which are very close to the correct values of $-0.12$ and $-0.1$.

**Composing `expect` with other operators.** Note that ADEV also provides primitives for manipulating expected values, e.g. `exp_` for taking their exponents. For example, the code `fmap mean (replicateM 1000 (exp_ (l 0.4)))` yielded `0.881` on a sample run, close to the true value of $e^{-0.12} = 0.886$. This is the exponent of the expected value, not the expected value of the exponent, which is slightly different ($0.6 \times e^{-0.2} + 0.4 \times e^0 = 0.891$).

**Optimization.** We can use ADEV's estimated derivatives to implement a stochastic optimization algorithm:

```haskell
sgd loss eta x0 steps =
if steps == 0 then
return [x0]
else do
v <- diff loss x0
let x1 = x0 - eta * v
xs <- sgd loss eta x1 (steps - 1)
return (x0:xs)
```

Running `sampleIO $ sgd l 0.2 0.2 100` finds the value of $\theta$ that minimizes $l$, namely $\theta = 0.5$.

## Haskell Encoding of ADEV Programs

In the ADEV paper, the program `l` above would have type $\mathbb{R} \to \widetilde{\mathbb{R}}$.
In Haskell, its type is `ADEV p m r => r -> m r`. Why?

In general, expressions in the ADEV source language are represented by Haskell expressions with polymorphic type `ADEV p m r => ...`, where the `...` is a Haskell type that uses the three type variables `p`, `m`, and `r` as follows:

* `r` represents real numbers, $\mathbb{R}$ in the ADEV paper. (The type of positive reals reals, $\mathbb{R}_{>0}$, is represented as `Log r`.)
* `m r` represents estimated real numbers, $\widetilde{\mathbb{R}}$ in the ADEV paper.
* `p m a` represents probabilistic programs returning `a`, $P\,a$ in the ADEV paper.

Below, we show how `l`'s type relates to the types of its sub-expressions:
```haskell
-- `flip_reinforce` takes a real parameter and
-- probabilistically outputs a Boolean.
flip_reinforce :: ADEV p m r => r -> p m Bool

-- Using `do`, we can build a larger computation that
-- uses the result of a flip to compute a real.
-- Its type reflects that it still takes a real parameter
-- as input, but now probabilistically outputs a real.
prog :: ADEV p m r => r -> p m r
prog theta = do
b <- flip_reinforce theta
if b then
return 0
else
return (-theta/2)

-- The `expect` operation turns a probabilistic computation
-- over reals (type P R) into an estimator of its expected
-- value (type R~).
expect :: ADEV p m r => p m r -> m r

-- By composing expect and prog, we get l from above.
l :: ADEV p m r => r -> m r
l = expect . prog
```

## Implementation

To understand ADEV's implementation, it is useful to first skim the ADEV paper, which explains how ADEV modularly extends standard forward-mode AD with support for probabilistic primitives. The Haskell code is a relatively direct encoding of the ideas described in the paper. Briefly:

* All the primitives in the ADEV language, including those introduced by the extensions from Appendix B, are encoded as methods of the `ADEV` typeclass, in the [Numeric.ADEV.Class](src/Numeric/ADEV/Class.hs) module. This is like a 'specification' that a specific interpreter of the ADEV language can satisfy. It leaves open what concrete Haskell types will be used to represent the ADEV types of real numbers $\mathbb{R}$, estimated reals $\widetilde{\mathbb{R}}$, and monadic probabilistic programs $P\,\tau$ — it uses the type variables `r`, `m r`, and `p m tau` for this purpose.

* The [Numeric.ADEV.Interp](src/Numeric/ADEV/Interp.hs) module provides one instance of the `ADEV` typeclass, implementing the standard semantics of an ADEV term. The type variables `p`, `m`, and `r` are instantiated so that the type of reals `r` is interpreted as `Double`, the type `m r` of *estimated* reals is interpreted as the type `m Double` for some `MonadDistribution` `m` (where `MonadDistribution` is the [monad-bayes](https://github.com/tweag/monad-bayes) typeclass for probabilistic programs), and the type of probabilistic programs `p m a` is interpreted as `WriterT Sum m a` (the `Sum` maintains an accumulated loss, and is described in Appendix B.2 of the ADEV paper).

* The [Numeric.ADEV.Diff](src/Numeric/ADEV/Diff.hs) module provides built-in derivatives for each primitive. These are organized into a second instance of the `ADEV` typeclass, where now the type `r` of reals is interpreted as `ForwardDouble`, representing forward-mode dual numbers $\mathcal{D}\{\mathbb{R}\}$ from the paper; the type `m r` of estimated reals is interpreted as the type `m ForwardDouble` for some `MonadDistribution` `m`, which implements the type $\widetilde{\mathbb{R}}_\mathcal{D}$ of estimated dual numbers from the paper; and the type `p m tau` of probabilistic programs is interpreted as `ContT ForwardDouble m tau`, i.e., the type of *higher-order functions* that transform an input `loss_to_go : tau -> m ForwardDouble` into an estimated dual-number loss of type `m ForwardDouble` (this implements the type $P_\mathcal{D}\,\tau$ from the paper).


## Installing ADEV
1. Install `stack` (https://docs.haskellstack.org/en/stable/install_and_upgrade/).
2. Clone this repository.
3. Run the examples using `stack run ExampleName`, where `ExampleName.hs` is the name of a file from the `examples` directory.
4. Or: Run `stack ghci` to enter a REPL.
21 changes: 18 additions & 3 deletions adev.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ library
exposed-modules:
Numeric.ADEV.Class
Numeric.ADEV.Diff
Numeric.ADEV.Distributions
Numeric.ADEV.Interp
other-modules:
Paths_adev
Expand All @@ -43,10 +44,24 @@ library
, vector >=0.12.3.1 && <0.12.4
default-language: Haskell2010

executable adev-exe
executable Figure2
main-is: Figure2.hs
other-modules:
Paths_adev
hs-source-dirs:
examples
ghc-options: -Wall -Wcompat -Widentities -Wincomplete-record-updates -Wincomplete-uni-patterns -Wmissing-export-lists -Wmissing-home-modules -Wpartial-fields -Wredundant-constraints -threaded -rtsopts -with-rtsopts=-N
build-depends:
ad ==4.5.2.*
, adev
, base >=4.7 && <5
, log-domain >=0.12 && <0.14
, monad-bayes >=1.1.0
, mtl ==2.2.2.*
, transformers ==0.5.6.2.*
, vector >=0.12.3.1 && <0.12.4
default-language: Haskell2010

executable ParticleFilterExample
main-is: ParticleFilterExample.hs
hs-source-dirs:
examples
ghc-options: -Wall -Wcompat -Widentities -Wincomplete-record-updates -Wincomplete-uni-patterns -Wmissing-export-lists -Wmissing-home-modules -Wpartial-fields -Wredundant-constraints -threaded -rtsopts -with-rtsopts=-N
Expand Down
47 changes: 47 additions & 0 deletions examples/ParticleFilterExample.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
module Main (main) where

import Numeric.ADEV.Class
import Numeric.ADEV.Diff (diff)
import Numeric.ADEV.Interp ()
import Numeric.ADEV.Distributions (normalD)
import Numeric.AD.Mode.Forward.Double (ForwardDouble)
import Control.Monad.Bayes.Class (MonadDistribution)
import Control.Monad.Bayes.Sampler.Strict (sampleIO)
import Numeric.Log (Log(..))

-- smc :: ([a] -> Log r) -> D m r a -> (a -> D m r a) -> ([a] -> m r) -> Int -> Int -> m r

dens :: (RealFrac r, Floating r) => D m r Double -> Double -> Log r
dens (D _ f) x = f x

normalDensity :: (RealFrac r, Floating r) => r -> r -> Double -> Log r
normalDensity mu sig x = Exp $ -log(sig) - log(2*pi) / 2 - ((realToFrac x)-mu)^2/(2*sig^2)

ys = [undefined, 1,2,3,4,5]

l :: (MonadDistribution m, RealFloat r, Floating r, ADEV p m r) => r -> m r
l theta = smc p q0 q f 2 1000
where
p xs = let xys = zip (map realToFrac xs) (reverse (take (length xs) ys)) in
pxys xys
pxys [] = undefined
pxys [(x, y)] = normalDensity 0 (exp theta) (realToFrac x)
pxys ((x,y):((xprev,yprev):xys)) = normalDensity (realToFrac xprev) (exp theta) (realToFrac x) * normalDensity (realToFrac x) 1 y * pxys ((xprev,yprev):xys)
q0 = normalD 0 (exp theta)
q x = normalD (realToFrac x) (exp theta)
f xs = return 1

sga :: MonadDistribution m => (ForwardDouble -> m ForwardDouble) -> Double -> Double -> Int -> m [Double]
sga loss eta x0 steps =
if steps == 0 then
return [x0]
else do
v <- diff loss x0
let x1 = x0 + eta * v
xs <- sga loss eta x1 (steps - 1)
return (x0:xs)

main :: IO ()
main = do
vs <- sampleIO $ sga l 10.0 0.0 500
print (vs)
Binary file added figures/adev-diagram.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added figures/example.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
13 changes: 12 additions & 1 deletion package.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,20 @@ library:
source-dirs: src

executables:
adev-exe:
Figure2:
main: Figure2.hs
source-dirs: examples
other-modules: []
ghc-options:
- -threaded
- -rtsopts
- -with-rtsopts=-N
dependencies:
- adev
ParticleFilterExample:
main: ParticleFilterExample.hs
source-dirs: examples
other-modules: []
ghc-options:
- -threaded
- -rtsopts
Expand Down
4 changes: 2 additions & 2 deletions src/Numeric/ADEV/Diff.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import Control.Monad.Bayes.Class (
uniformD,
logCategorical,
poisson,
bernoulli,
bernoulli,
normal)
import Control.Monad.Cont (ContT(..))
import Numeric.AD.Internal.Forward.Double (
Expand Down Expand Up @@ -172,7 +172,7 @@ instance MonadDistribution m => ADEV (ContT ForwardDouble) m ForwardDouble where
let (D qs qd) = dq (head v)
let qqd = Exp . primal . ln . qd
v' <- qs
return (v':v, (pp (v':v) / pp v) / qqd v')
return (v':v, w * (pp (v':v) / pp v) / qqd v')
step particles = do
particles <- resample particles
mapM propagate particles
Expand Down
16 changes: 16 additions & 0 deletions src/Numeric/ADEV/Distributions.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
module Numeric.ADEV.Distributions (normalD, geometricD) where

import Numeric.ADEV.Class

import Control.Monad.Bayes.Class (
MonadDistribution,
geometric,
normal)

import Numeric.Log (Log(..))

normalD :: (MonadDistribution m, RealFrac r, Floating r) => r -> r -> D m r Double
normalD mu sig = D (normal (realToFrac mu) (realToFrac sig)) (\x -> Exp $ -log(sig) - log(2*pi) / 2 - ((realToFrac x)-mu)^2/(2*sig^2))

geometricD :: (MonadDistribution m, RealFrac r, Floating r) => r -> D m r Int
geometricD p = D (geometric (realToFrac p)) (\x -> Exp $ log(p) + (fromIntegral $ x-1) * log(1-p))
2 changes: 1 addition & 1 deletion src/Numeric/ADEV/Interp.hs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ instance MonadDistribution m => ADEV (WriterT (Sum Double)) m Double where
let (v, w) = particle
let (D qs qd) = q (head v)
v' <- qs
return (v':v, (p (v':v) / p v) / qd v')
return (v':v, w * (p (v':v) / p v) / qd v')
step particles = do
particles <- resample particles
mapM propagate particles
Expand Down

0 comments on commit 6844796

Please sign in to comment.