Skip to content

Commit

Permalink
Chain rule for EachVariate constructor (#1627)
Browse files Browse the repository at this point in the history
* added chain rule for eachvariate

* version bump

* added comment to explain why to_vec overload is needed

* simplified impl of rrule

* Update src/eachvariate.jl

Co-authored-by: David Widmann <[email protected]>

* Update runtests.jl

* Update src/eachvariate.jl

Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
torfjelde and devmotion authored Oct 11, 2022
1 parent 2d0bce0 commit a31ebc4
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Distributions"
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
authors = ["JuliaStats"]
version = "0.25.75"
version = "0.25.76"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
11 changes: 11 additions & 0 deletions src/eachvariate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,17 @@ function EachVariate{V}(x::AbstractArray{<:Real,M}) where {V,M}
return EachVariate{V,typeof(x),typeof(ax),T,M-V}(x, ax)
end

function ChainRulesCore.rrule(::Type{EachVariate{V}}, x::AbstractArray{<:Real}) where {V}
y = EachVariate{V}(x)
size_x = size(x)
function EachVariate_pullback(Δ)
# TODO: Should we also handle `Tangent{<:EachVariate}`?
Δ_out = reshape(mapreduce(vec, vcat, ChainRulesCore.unthunk(Δ)), size_x)
return (ChainRulesCore.NoTangent(), Δ_out)
end
return y, EachVariate_pullback
end

Base.IteratorSize(::Type{EachVariate{V,P,A,T,N}}) where {V,P,A,T,N} = Base.HasShape{N}()

Base.axes(x::EachVariate) = x.axes
Expand Down
17 changes: 17 additions & 0 deletions test/eachvariate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using ChainRulesTestUtils
using ChainRulesTestUtils: FiniteDifferences

# Without this, `to_vec` will also include the `axes` field of `EachVariate`.
function FiniteDifferences.to_vec(xs::Distributions.EachVariate{V}) where {V}
vals, vals_from_vec = FiniteDifferences.to_vec(xs.parent)
return vals, x -> Distributions.EachVariate{V}(vals_from_vec(x))
end

@testset "eachvariate.jl" begin
@testset "ChainRules" begin
xs = randn(2, 3, 4, 5)
test_rrule(Distributions.EachVariate{1}, xs)
test_rrule(Distributions.EachVariate{2}, xs)
test_rrule(Distributions.EachVariate{3}, xs)
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ const tests = [
"univariate/discrete/discreteuniform",
"univariate/continuous/tdist",
"multivariate/product",
"eachvariate",

### missing files compared to /src:
# "common",
Expand Down

2 comments on commit a31ebc4

@devmotion
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/69907

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.25.76 -m "<description of version>" a31ebc4de29a491971587cf159b184349d6a24e9
git push origin v0.25.76

Please sign in to comment.