diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 39b86901e..fb387faba 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -52,7 +52,6 @@ function rrule(::typeof(getindex), x::Tuple, ::Colon) return x, getindex_back_4 end - ##### ##### getindex(::AbstractArray) ##### @@ -174,6 +173,15 @@ function rrule(::typeof(view), x::AbstractArray, inds...) return view(x, inds...), view_pullback end +function rrule(::typeof(view), x::AbstractArray, i::Integer, jkl::Integer...) + # This case returns a zero-dim array, unlike getindex. So we fool ∇getindex: + function view_pullback_0(dy) + nots = map(Returns(NoTangent()), (i, jkl...)) + return (NoTangent(), thunked∇getindex(x, dy, i:i, jkl...), nots...) + end + return view(x, i, jkl...), view_pullback_0 +end + ##### ##### setindex! ##### diff --git a/test/rulesets/Base/array.jl b/test/rulesets/Base/array.jl index 42e3674e9..075c5e050 100644 --- a/test/rulesets/Base/array.jl +++ b/test/rulesets/Base/array.jl @@ -366,6 +366,7 @@ end @test [5 0; 6 0] == @inferred unthunk(rrule(findmin, [1 2; 3 4], dims=2)[2]((hcat([5,6]), nothing))[2]) test_rrule(findmin, rand(3,4), fkwargs=(dims=1,), output_tangent = (rand(1,4), NoTangent())) test_rrule(findmin, rand(3,4), fkwargs=(dims=2,)) + test_rrule(findmin, rand(3,4), fkwargs=(dims=(1,2),)) end @testset "$imum" for imum in [maximum, minimum] diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index a02026fe7..695f06010 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -169,7 +169,7 @@ end test_rrule(view, rand(3, 4), :, 1) test_rrule(view, rand(3, 4), 2, [1, 1, 2]) - @test_broken test_rrule(view, rand(3, 4), 3, 4) # This is why ∇getindex needs one more argument, dammit + test_rrule(view, rand(3, 4), 3, 4) end @testset "setindex!" begin