diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 998ed64..e9c3537 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -55,10 +55,10 @@ jobs: - uses: actions/checkout@v4 - uses: quarto-dev/quarto-actions/setup@v2 with: - version: 1.6.30 + version: 1.5.57 - uses: julia-actions/setup-julia@latest with: - version: '1' + version: '1.10' - uses: julia-actions/cache@v2 - name: Cache Quarto id: cache-quarto diff --git a/docs/src/index.md b/docs/src/index.md index 83cbadc..0d896d9 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -17,15 +17,15 @@ The package adds two new optimisers that are compatible with the [Optimisers.jl] 1. Stochastic Gradient Langevin Dynamics (SGLD) (Welling and Teh 2011) — [`SGLD`](@ref). 2. Improper SGLD (see, for example, Grathwohl et al. (2020)) — [`ImproperSGLD`](@ref). -SGLD is an efficient gradient-based Markov Chain Monte Carlo (MCMC) method that can be used in the context of EBM to draw samples from the model posterior (Murphy 2023). Formally, we can draw from $p_{\theta}(\mathbf{x})$ as follows +SGLD is an efficient gradient-based Markov Chain Monte Carlo (MCMC) method that can be used in the context of EBM to draw samples from the model posterior (Murphy 2023). Formally, we can draw from $p_{\theta}(x)$ as follows ``` math \begin{aligned} - \mathbf{x}_{j+1} &\leftarrow \mathbf{x}_j - \frac{\epsilon_j^2}{2} \nabla_x \mathcal{E}_{\theta}(\mathbf{x}_j) + \epsilon_j \mathbf{r}_j, && j=1,...,J + x_{j+1} &\leftarrow x_j - \frac{\epsilon_j^2}{2} \nabla_x \mathcal{E}_{\theta}(x_j) + \epsilon_j r_j, && j=1,...,J \end{aligned} ``` -where $\mathbf{r}_j \sim \mathcal{N}(\mathbf{0},\mathbf{I})$ is a stochastic term and the step-size $\epsilon_j$ is typically polynomially decayed (Welling and Teh 2011). To allow for faster sampling, it is common practice to choose the step-size $\epsilon_j$ and the standard deviation of $\mathbf{r}_j$ separately. While $\mathbf{x}_J$ is only guaranteed to distribute as $p_{\theta}(\mathbf{x})$ if $\epsilon \rightarrow 0$ and $J \rightarrow \infty$, the bias introduced for a small finite $\epsilon$ is negligible in practice (Murphy 2023). We denote this form of sampling as Improper SGLD. +where $r_j \sim \mathcal{N}(0,I)$ is a stochastic term and the step-size $\epsilon_j$ is typically polynomially decayed (Welling and Teh 2011). To allow for faster sampling, it is common practice to choose the step-size $\epsilon_j$ and the standard deviation of $r_j$ separately. While $x_J$ is only guaranteed to distribute as $p_{\theta}(x)$ if $\epsilon \rightarrow 0$ and $J \rightarrow \infty$, the bias introduced for a small finite $\epsilon$ is negligible in practice (Murphy 2023). We denote this form of sampling as Improper SGLD. ### Example: Bayesian Inferecne with SGLD @@ -154,7 +154,7 @@ plot(p1, p2, size=(800, 400)) │ The input will be converted, but any earlier layers may be very slow. │ layer = Dense(3 => 1, σ) # 4 parameters │ summary(x) = "3×7000 adjoint(::Matrix{Float64}) with eltype Float64" - └ @ Flux ~/.julia/packages/Flux/HBF2N/src/layers/stateless.jl:60 + └ @ Flux ~/.julia/packages/Flux/htpCe/src/layers/stateless.jl:59 Final parameters are Float32[-2.3311744 1.1305944 -1.5102222 -4.0762844] Test accuracy is 0.9666666666666667 Final parameters are Float32[-0.6106307 2.760134 -0.031244753 -5.8856964] @@ -166,9 +166,9 @@ plot(p1, p2, size=(800, 400)) In the context of EBM, the optimisers can be used to sample from a model posterior. To this end, the package provides the following samples: -1. [`UnconditionalSampler`](@ref) — samples from the unconditional distribution $p_{\theta}(\mathbf{x})$ as in Grathwohl et al. (2020). -2. [`ConditionalSampler`](@ref) — samples from the conditional distribution $p_{\theta}(\mathbf{x}|y)$ as in Grathwohl et al. (2020). -3. [`JointSampler`](@ref) — samples from the joint distribution $p_{\theta}(\mathbf{x},y)$ as in Kelly, Zemel, and Grathwohl (2021). +1. [`UnconditionalSampler`](@ref) — samples from the unconditional distribution $p_{\theta}(x)$ as in Grathwohl et al. (2020). +2. [`ConditionalSampler`](@ref) — samples from the conditional distribution $p_{\theta}(x|y)$ as in Grathwohl et al. (2020). +3. [`JointSampler`](@ref) — samples from the joint distribution $p_{\theta}(x,y)$ as in Kelly, Zemel, and Grathwohl (2021). ### Example: Joint Energy-Based Model @@ -221,7 +221,7 @@ end ``` [ Info: Epoch 1 - Accuracy: 0.9995 + Accuracy: 0.99 [ Info: Epoch 2 Accuracy: 0.9995 [ Info: Epoch 3 @@ -270,7 +270,7 @@ plot(plt) │ The input will be converted, but any earlier layers may be very slow. │ layer = Dense(2 => 2) # 6 parameters │ summary(x) = "2-element Vector{Float64}" - └ @ Flux ~/.julia/packages/Flux/HBF2N/src/layers/stateless.jl:60 + └ @ Flux ~/.julia/packages/Flux/htpCe/src/layers/stateless.jl:59 ![](index_files/figure-commonmark/cell-8-output-2.svg) diff --git a/docs/src/index_files/figure-commonmark/cell-5-output-2.svg b/docs/src/index_files/figure-commonmark/cell-5-output-2.svg index 3535cd6..1301bef 100644 --- a/docs/src/index_files/figure-commonmark/cell-5-output-2.svg +++ b/docs/src/index_files/figure-commonmark/cell-5-output-2.svg @@ -1,106 +1,106 @@ - + - + - + - + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - + diff --git a/docs/src/index_files/figure-commonmark/cell-8-output-2.svg b/docs/src/index_files/figure-commonmark/cell-8-output-2.svg index 5841ffe..50e91c3 100644 --- a/docs/src/index_files/figure-commonmark/cell-8-output-2.svg +++ b/docs/src/index_files/figure-commonmark/cell-8-output-2.svg @@ -1,2308 +1,2312 @@ - + - + - + - + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +