diff --git a/Project.toml b/Project.toml index 2a5cc3d668..23422b5343 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Distributions" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" authors = ["JuliaStats"] -version = "0.25.75" +version = "0.25.76" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/eachvariate.jl b/src/eachvariate.jl index 701be99faa..36a9ae9e97 100644 --- a/src/eachvariate.jl +++ b/src/eachvariate.jl @@ -11,6 +11,17 @@ function EachVariate{V}(x::AbstractArray{<:Real,M}) where {V,M} return EachVariate{V,typeof(x),typeof(ax),T,M-V}(x, ax) end +function ChainRulesCore.rrule(::Type{EachVariate{V}}, x::AbstractArray{<:Real}) where {V} + y = EachVariate{V}(x) + size_x = size(x) + function EachVariate_pullback(Δ) + # TODO: Should we also handle `Tangent{<:EachVariate}`? + Δ_out = reshape(mapreduce(vec, vcat, ChainRulesCore.unthunk(Δ)), size_x) + return (ChainRulesCore.NoTangent(), Δ_out) + end + return y, EachVariate_pullback +end + Base.IteratorSize(::Type{EachVariate{V,P,A,T,N}}) where {V,P,A,T,N} = Base.HasShape{N}() Base.axes(x::EachVariate) = x.axes diff --git a/test/eachvariate.jl b/test/eachvariate.jl new file mode 100644 index 0000000000..f41a5207d2 --- /dev/null +++ b/test/eachvariate.jl @@ -0,0 +1,17 @@ +using ChainRulesTestUtils +using ChainRulesTestUtils: FiniteDifferences + +# Without this, `to_vec` will also include the `axes` field of `EachVariate`. +function FiniteDifferences.to_vec(xs::Distributions.EachVariate{V}) where {V} + vals, vals_from_vec = FiniteDifferences.to_vec(xs.parent) + return vals, x -> Distributions.EachVariate{V}(vals_from_vec(x)) +end + +@testset "eachvariate.jl" begin + @testset "ChainRules" begin + xs = randn(2, 3, 4, 5) + test_rrule(Distributions.EachVariate{1}, xs) + test_rrule(Distributions.EachVariate{2}, xs) + test_rrule(Distributions.EachVariate{3}, xs) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 715247bf85..614ec7fb1c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -86,6 +86,7 @@ const tests = [ "univariate/discrete/discreteuniform", "univariate/continuous/tdist", "multivariate/product", + "eachvariate", ### missing files compared to /src: # "common",