diff --git a/Project.toml b/Project.toml index fedd2a600..124a1fbaf 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.61.0" +version = "1.61.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -20,6 +20,7 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [compat] Adapt = "3.4.0, 4" +AxisArrays = "0.4.7" ChainRulesCore = "1.20" ChainRulesTestUtils = "1.5" Compat = "3.46, 4.2" @@ -41,6 +42,7 @@ SuiteSparse = "1" julia = "1.6" [extras] +AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" @@ -50,4 +52,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["ChainRulesTestUtils", "FiniteDifferences", "JLArrays", "JuliaInterpreter", "Random", "StaticArrays", "Test"] +test = ["AxisArrays", "ChainRulesTestUtils", "FiniteDifferences", "JLArrays", "JuliaInterpreter", "Random", "StaticArrays", "Test"] diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 830571ecd..ea081c99c 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -125,21 +125,21 @@ Base.:(+)(oe1::OneElement, oe2::OneElement) = +(collect(oe1), oe2) This returns roughly `dx = zero(x)`, except that this is guaranteed to be mutable via `similar`, and its element type is wide enough to allow `setindex!(dx, dy, inds...)`, which is exactly what `∇getindex` does next. - -It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't -allow `eltype(dy)`, nor does it work for many structured matrices. """ -_setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...) = fill!(similar(x, typeof(dy), axes(x)), false) -_setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy), axes(x)), false) +_setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...) = + fill!(similar(x, typeof(dy)), false) +function _setindex_zero(x::AbstractArray{<:Number}, dy, inds...) + return fill!(similar(x, eltype(dy)), false) +end function _setindex_zero(x::AbstractArray, dy, inds::Integer...) # This allows for types which don't define zero (like Vector) and types whose zero special (like Tangent), # but always makes an abstract type. TODO: make it infer concrete type for e.g. vectors of SVectors T = Union{typeof(dy), ZeroTangent} - return fill!(similar(x, T, axes(x)), ZeroTangent()) + return fill!(similar(x, T), ZeroTangent()) end function _setindex_zero(x::AbstractArray, dy, inds...) T = Union{eltype(dy), ZeroTangent} - return fill!(similar(x, T, axes(x)), ZeroTangent()) + return fill!(similar(x, T), ZeroTangent()) end ChainRules.@non_differentiable _setindex_zero(x::AbstractArray, dy::Any, inds::Any...) diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index e878dd061..423a7afe7 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -128,6 +128,13 @@ end @test dx23[3] == dxfix[3] end + @testset "getindex(::AxisArray{<:Number})" begin + X = randn((2, 3)) + A = AxisArray(X; row=[:a, :b], col=[:x, :y, :z]) + dA, back = rrule(getindex, A, [:a], [:x, :z]) + unthunk(back(ones(1, 2))[2]) == [1.0 0.0 1.0; 0.0 0.0 0.0] + end + @testset "second derivatives: ∇getindex" begin @eval using ChainRules: ∇getindex # Forward, scalar result diff --git a/test/runtests.jl b/test/runtests.jl index 768f7c208..81bb4ee22 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using Test, ChainRulesCore, ChainRulesTestUtils @nospecialize using Adapt +using AxisArrays using Base.Broadcast: broadcastable using ChainRules using ChainRules: stack