diff --git a/src/Turing.jl b/src/Turing.jl index 6318e2bd5..6a6f25b56 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -107,6 +107,7 @@ export @model, # modelling AutoForwardDiff, # ADTypes AutoReverseDiff, AutoZygote, + AutoEnzyme, AutoMooncake, setprogress!, # debugging Flat, diff --git a/src/essential/Essential.jl b/src/essential/Essential.jl index c04c7e862..8e3cf4bab 100644 --- a/src/essential/Essential.jl +++ b/src/essential/Essential.jl @@ -11,7 +11,8 @@ using Bijectors: PDMatDistribution using AdvancedVI using StatsFuns: logsumexp, softmax @reexport using DynamicPPL -using ADTypes: ADTypes, AutoForwardDiff, AutoReverseDiff, AutoZygote, AutoMooncake +using ADTypes: + ADTypes, AutoForwardDiff, AutoEnzyme, AutoReverseDiff, AutoZygote, AutoMooncake using AdvancedPS: AdvancedPS @@ -19,6 +20,7 @@ include("container.jl") export @model, @varname, + AutoEnzyme, AutoForwardDiff, AutoZygote, AutoReverseDiff, diff --git a/test/Project.toml b/test/Project.toml index 1620e6f4b..19343bbf3 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -12,6 +12,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" @@ -52,6 +53,7 @@ Distributions = "0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" DynamicPPL = "0.32.2" +Enzyme = "0.13" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12 - 0.10.32, 0.10" HypothesisTests = "0.11" diff --git a/test/mcmc/Inference.jl b/test/mcmc/Inference.jl index da29e7708..cd60b50d1 100644 --- a/test/mcmc/Inference.jl +++ b/test/mcmc/Inference.jl @@ -7,6 +7,7 @@ using Distributions: Bernoulli, Beta, InverseGamma, Normal using Distributions: sample 
import DynamicPPL using DynamicPPL: Sampler, getlogp +import Enzyme import ForwardDiff using LinearAlgebra: I import MCMCChains @@ -449,53 +450,56 @@ using Turing alg = HMC(0.01, 5; adtype=adbackend) res = sample(StableRNG(seed), vdemo2(randn(D, 100)), alg, 10) - # Vector assumptions - N = 10 - alg = HMC(0.2, 4; adtype=adbackend) + # TODO(mhauru) Type unstable getfield of tuple not supported in Enzyme yet + if !(adbackend isa AutoEnzyme) + # Vector assumptions + N = 10 + alg = HMC(0.2, 4; adtype=adbackend) - @model function vdemo3() - x = Vector{Real}(undef, N) - for i in 1:N - x[i] ~ Normal(0, sqrt(4)) + @model function vdemo3() + x = Vector{Real}(undef, N) + for i in 1:N + x[i] ~ Normal(0, sqrt(4)) + end end - end - # TODO(mhauru) What is the point of the below @elapsed stuff? It prints out some - # timings. Do we actually ever look at them? - t_loop = @elapsed res = sample(StableRNG(seed), vdemo3(), alg, 1000) + # TODO(mhauru) What is the point of the below @elapsed stuff? It prints out some + # timings. Do we actually ever look at them? + t_loop = @elapsed res = sample(StableRNG(seed), vdemo3(), alg, 1000) - # Test for vectorize UnivariateDistribution - @model function vdemo4() - x = Vector{Real}(undef, N) - @. x ~ Normal(0, 2) - end + # Test for vectorize UnivariateDistribution + @model function vdemo4() + x = Vector{Real}(undef, N) + @. 
x ~ Normal(0, 2) + end - t_vec = @elapsed res = sample(StableRNG(seed), vdemo4(), alg, 1000) + t_vec = @elapsed res = sample(StableRNG(seed), vdemo4(), alg, 1000) - @model vdemo5() = x ~ MvNormal(zeros(N), 4 * I) + @model vdemo5() = x ~ MvNormal(zeros(N), 4 * I) - t_mv = @elapsed res = sample(StableRNG(seed), vdemo5(), alg, 1000) + t_mv = @elapsed res = sample(StableRNG(seed), vdemo5(), alg, 1000) - println("Time for") - println(" Loop : ", t_loop) - println(" Vec : ", t_vec) - println(" Mv : ", t_mv) + println("Time for") + println(" Loop : ", t_loop) + println(" Vec : ", t_vec) + println(" Mv : ", t_mv) - # Transformed test - @model function vdemo6() - x = Vector{Real}(undef, N) - @. x ~ InverseGamma(2, 3) - end + # Transformed test + @model function vdemo6() + x = Vector{Real}(undef, N) + @. x ~ InverseGamma(2, 3) + end - sample(StableRNG(seed), vdemo6(), alg, 10) + sample(StableRNG(seed), vdemo6(), alg, 10) - N = 3 - @model function vdemo7() - x = Array{Real}(undef, N, N) - @. x ~ [InverseGamma(2, 3) for i in 1:N] - end + N = 3 + @model function vdemo7() + x = Array{Real}(undef, N, N) + @. x ~ [InverseGamma(2, 3) for i in 1:N] + end - sample(StableRNG(seed), vdemo7(), alg, 10) + sample(StableRNG(seed), vdemo7(), alg, 10) + end end @testset "vectorization .~" begin @@ -519,51 +523,54 @@ using Turing alg = HMC(0.01, 5; adtype=adbackend) res = sample(StableRNG(seed), vdemo2(randn(D, 100)), alg, 10) - # Vector assumptions - N = 10 - alg = HMC(0.2, 4; adtype=adbackend) + # TODO(mhauru) Type unstable getfield of tuple not supported in Enzyme yet + if !(adbackend isa AutoEnzyme) + # Vector assumptions + N = 10 + alg = HMC(0.2, 4; adtype=adbackend) - @model function vdemo3() - x = Vector{Real}(undef, N) - for i in 1:N - x[i] ~ Normal(0, sqrt(4)) + @model function vdemo3() + x = Vector{Real}(undef, N) + for i in 1:N + x[i] ~ Normal(0, sqrt(4)) + end end - end - # TODO(mhauru) Same question as above about @elapsed. 
- t_loop = @elapsed res = sample(StableRNG(seed), vdemo3(), alg, 1_000) + # TODO(mhauru) Same question as above about @elapsed. + t_loop = @elapsed res = sample(StableRNG(seed), vdemo3(), alg, 1_000) - # Test for vectorize UnivariateDistribution - @model function vdemo4() - x = Vector{Real}(undef, N) - return x .~ Normal(0, 2) - end + # Test for vectorize UnivariateDistribution + @model function vdemo4() + x = Vector{Real}(undef, N) + return x .~ Normal(0, 2) + end - t_vec = @elapsed res = sample(StableRNG(seed), vdemo4(), alg, 1_000) + t_vec = @elapsed res = sample(StableRNG(seed), vdemo4(), alg, 1_000) - @model vdemo5() = x ~ MvNormal(zeros(N), 4 * I) + @model vdemo5() = x ~ MvNormal(zeros(N), 4 * I) - t_mv = @elapsed res = sample(StableRNG(seed), vdemo5(), alg, 1_000) + t_mv = @elapsed res = sample(StableRNG(seed), vdemo5(), alg, 1_000) - println("Time for") - println(" Loop : ", t_loop) - println(" Vec : ", t_vec) - println(" Mv : ", t_mv) + println("Time for") + println(" Loop : ", t_loop) + println(" Vec : ", t_vec) + println(" Mv : ", t_mv) - # Transformed test - @model function vdemo6() - x = Vector{Real}(undef, N) - return x .~ InverseGamma(2, 3) - end + # Transformed test + @model function vdemo6() + x = Vector{Real}(undef, N) + return x .~ InverseGamma(2, 3) + end - sample(StableRNG(seed), vdemo6(), alg, 10) + sample(StableRNG(seed), vdemo6(), alg, 10) - @model function vdemo7() - x = Array{Real}(undef, N, N) - return x .~ [InverseGamma(2, 3) for i in 1:N] - end + @model function vdemo7() + x = Array{Real}(undef, N, N) + return x .~ [InverseGamma(2, 3) for i in 1:N] + end - sample(StableRNG(seed), vdemo7(), alg, 10) + sample(StableRNG(seed), vdemo7(), alg, 10) + end end @testset "Type parameters" begin diff --git a/test/mcmc/abstractmcmc.jl b/test/mcmc/abstractmcmc.jl index 6486a8628..b526919e1 100644 --- a/test/mcmc/abstractmcmc.jl +++ b/test/mcmc/abstractmcmc.jl @@ -5,6 +5,7 @@ using AdvancedMH: AdvancedMH using Distributions: sample using 
Distributions.FillArrays: Zeros using DynamicPPL: DynamicPPL +import Enzyme using ForwardDiff: ForwardDiff using LinearAlgebra: I using LogDensityProblems: LogDensityProblems diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl index 1d7208b43..ebeece03a 100644 --- a/test/mcmc/gibbs.jl +++ b/test/mcmc/gibbs.jl @@ -12,6 +12,7 @@ import Combinatorics using Distributions: InverseGamma, Normal using Distributions: sample using DynamicPPL: DynamicPPL +using Enzyme: Enzyme using ForwardDiff: ForwardDiff using Random: Random using ReverseDiff: ReverseDiff diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index d45846f3d..31d5f1102 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -7,6 +7,7 @@ import ..ADUtils using Distributions: Bernoulli, Beta, Categorical, Dirichlet, Normal, Wishart, sample import DynamicPPL using DynamicPPL: Sampler +import Enzyme import ForwardDiff using HypothesisTests: ApproximateTwoSampleKSTest, pvalue import ReverseDiff @@ -15,7 +16,7 @@ import Random using StableRNGs: StableRNG using StatsFuns: logistic import Mooncake -using Test: @test, @test_logs, @testset, @test_throws +using Test: @test, @test_broken, @test_logs, @testset, @test_throws using Turing @testset "Testing hmc.jl with $adbackend" for adbackend in ADUtils.adbackends @@ -84,7 +85,8 @@ using Turing r = reshape(Array(chain), n_samples, 2, 2) r_mean = dropdims(mean(r; dims=1); dims=1) - @test isapprox(r_mean, mean(dist); atol=0.2) + # TODO(mhauru) The below remains broken for Enzyme. Need to investigate why. 
+ @test isapprox(r_mean, mean(dist); atol=0.2) broken = (adbackend isa AutoEnzyme) end @testset "multivariate support" begin @@ -112,18 +114,14 @@ using Turing alpha = 0.16 # regularizatin term var_prior = sqrt(1.0 / alpha) # variance of the Gaussian prior - @model function bnn(ts) - b1 ~ MvNormal( - [0.0; 0.0; 0.0], [var_prior 0.0 0.0; 0.0 var_prior 0.0; 0.0 0.0 var_prior] - ) - w11 ~ MvNormal([0.0; 0.0], [var_prior 0.0; 0.0 var_prior]) - w12 ~ MvNormal([0.0; 0.0], [var_prior 0.0; 0.0 var_prior]) - w13 ~ MvNormal([0.0; 0.0], [var_prior 0.0; 0.0 var_prior]) + @model function bnn(ts, var_prior) + b1 ~ MvNormal(zeros(3), var_prior * I) + w11 ~ MvNormal(zeros(2), var_prior * I) + w12 ~ MvNormal(zeros(2), var_prior * I) + w13 ~ MvNormal(zeros(2), var_prior * I) bo ~ Normal(0, var_prior) - wo ~ MvNormal( - [0.0; 0; 0], [var_prior 0.0 0.0; 0.0 var_prior 0.0; 0.0 0.0 var_prior] - ) + wo ~ MvNormal(zeros(3), var_prior * I) for i in rand(1:N, 10) y = nn(xs[i], b1, w11, w12, w13, bo, wo) ts[i] ~ Bernoulli(y) @@ -132,7 +130,9 @@ using Turing end # Sampling - chain = sample(StableRNG(seed), bnn(ts), HMC(0.1, 5; adtype=adbackend), 10) + chain = sample( + StableRNG(seed), bnn(ts, var_prior), HMC(0.1, 5; adtype=adbackend), 10 + ) end @testset "hmcda inference" begin @@ -345,12 +345,16 @@ using Turing end @testset "Check ADType" begin - alg = HMC(0.1, 10; adtype=adbackend) - m = DynamicPPL.contextualize( - gdemo_default, ADTypeCheckContext(adbackend, gdemo_default.context) - ) - # These will error if the adbackend being used is not the one set. - sample(StableRNG(seed), m, alg, 10) + # These tests don't make sense for Enzyme, since it does not use a particular element + # type. + if !(adbackend isa AutoEnzyme) + alg = HMC(0.1, 10; adtype=adbackend) + m = DynamicPPL.contextualize( + gdemo_default, ADTypeCheckContext(adbackend, gdemo_default.context) + ) + # These will error if the adbackend being used is not the one set. 
+ sample(StableRNG(seed), m, alg, 10) + end end end diff --git a/test/mcmc/sghmc.jl b/test/mcmc/sghmc.jl index c1d07d2ce..22327b981 100644 --- a/test/mcmc/sghmc.jl +++ b/test/mcmc/sghmc.jl @@ -4,6 +4,7 @@ using ..Models: gdemo_default using ..NumericalTests: check_gdemo import ..ADUtils using Distributions: sample +import Enzyme import ForwardDiff using LinearAlgebra: dot import ReverseDiff diff --git a/test/optimisation/Optimisation.jl b/test/optimisation/Optimisation.jl index cc9ab8c87..d93781837 100644 --- a/test/optimisation/Optimisation.jl +++ b/test/optimisation/Optimisation.jl @@ -5,6 +5,7 @@ using ..ADUtils: ADUtils using Distributions using Distributions.FillArrays: Zeros using DynamicPPL: DynamicPPL +using Enzyme: Enzyme using ForwardDiff: ForwardDiff using LinearAlgebra: Diagonal, I using Mooncake: Mooncake diff --git a/test/test_utils/ad_utils.jl b/test/test_utils/ad_utils.jl index 2c01dc524..231c9ec88 100644 --- a/test/test_utils/ad_utils.jl +++ b/test/test_utils/ad_utils.jl @@ -1,11 +1,11 @@ module ADUtils +using Enzyme: Enzyme using ForwardDiff: ForwardDiff using Pkg: Pkg using Random: Random using ReverseDiff: ReverseDiff using Mooncake: Mooncake -using Test: Test using Turing: Turing using Turing: DynamicPPL using Zygote: Zygote @@ -239,6 +239,10 @@ adbackends = [ Turing.AutoForwardDiff(; chunksize=0), Turing.AutoReverseDiff(; compile=false), Turing.AutoMooncake(; config=nothing), + # TODO(mhauru) Do we want to run both? For now yes, while building up Enzyme + # integration, but in the long term maybe not? 
+ Turing.AutoEnzyme(; mode=Enzyme.Forward), + Turing.AutoEnzyme(; mode=Enzyme.Reverse), ] end diff --git a/test/test_utils/test_utils.jl b/test/test_utils/test_utils.jl index bf9f2b9b8..b29ae6226 100644 --- a/test/test_utils/test_utils.jl +++ b/test/test_utils/test_utils.jl @@ -1,7 +1,7 @@ """Module for testing the test utils themselves.""" module TestUtilsTests -using ..ADUtils: ADTypeCheckContext, AbstractWrongADBackendError +using ..ADUtils: ADTypeCheckContext, AbstractWrongADBackendError, adbackends using ForwardDiff: ForwardDiff using ReverseDiff: ReverseDiff using Test: @test, @testset, @test_throws @@ -13,12 +13,11 @@ using Zygote: Zygote @testset "ADTypeCheckContext" begin Turing.@model test_model() = x ~ Turing.Normal(0, 1) tm = test_model() - adtypes = ( - Turing.AutoForwardDiff(), - Turing.AutoReverseDiff(), - Turing.AutoZygote(), - # TODO: Mooncake - # Turing.AutoMooncake(config=nothing), + # These tests don't make sense for Enzyme, since it doesn't have its own element type. + # TODO(mhauru): Make these tests work for Mooncake as well. + adtypes = filter( + adtype -> !(adtype isa Turing.AutoMooncake || adtype isa Turing.AutoEnzyme), + adbackends, ) for actual_adtype in adtypes sampler = Turing.HMC(0.1, 5; adtype=actual_adtype)