From aac380c045ca7e07ed17b7264463fad92f7c814e Mon Sep 17 00:00:00 2001 From: Frames White Date: Thu, 18 May 2023 14:04:16 +0800 Subject: [PATCH] = --- src/rulesets/Base/indexing.jl | 2 +- test/rulesets/Base/indexing.jl | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 0ca102143..3871831f3 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -101,7 +101,7 @@ and its element type is wide enough to allow `setindex!(dx, dy, inds...)`, which It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't allow `eltype(dy)`, nor does it work for many structured matrices. """ -_setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...) = fill!(similar(x, typeof(dy), axes(x)), false) +_setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...) = fill!(similar(x, eltype(dy), axes(x)), false) _setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy), axes(x)), false) function _setindex_zero(x::AbstractArray, dy, inds::Integer...) # This allows for types which don't define zero (like Vector) and types whose zero special (like Tangent), diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index 3dbcd0bc9..c76a81bfc 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -166,6 +166,26 @@ @test Array(y3) == Array(x_23_gpu)[1, [1,1,2]] @test unthunk(bk3(jl(ones(3)))[2]) == jl([2 1 0; 0 0 0]) end + + # https://github.com/JuliaDiff/ChainRules.jl/issues/697 + @testset "pulling back mixes AbstractZero and co" begin + _, back = rrule(getindex, [1], 1) + _, gs = back(@not_implemented("test")) + @test unthunk(gs[2]) isa NotImplemented + + _, back2 = rrule(getindex, [1], 1) + gs2 = back2([NoTangent()]) + @test unthunk(gs2[2]) isa NoTangent + + # Above are not realistic since they should be solved by the AD not calling with that + # but more realistic is ended up with a tangent that has a mixture of things + _, back3 = rrule(getindex, [10, 0, -1], :) + gs3 = back3([2.0, NoTangent(), (@not_implemented "test2")]) + num, notan, not_imp = unthunk(gs3[2]) + @test num isa Real + @test iszero(notan) # We don't care if this gets converted to a 0.0 + @test not_imp isa NotImplemented + end end @testset "first & tail" begin