diff --git a/Project.toml b/Project.toml index 0c208ae89..ab2c0f01f 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" version = "1.41.0" [deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" @@ -15,6 +16,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] +Adapt = "3.4.0" ChainRulesCore = "1.15.3" ChainRulesTestUtils = "1.5" Compat = "3.42.0, 4" @@ -28,7 +30,6 @@ StaticArrays = "1.2" julia = "1.6" [extras] -Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" @@ -38,4 +39,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Adapt", "ChainRulesTestUtils", "FiniteDifferences", "JLArrays", "JuliaInterpreter", "Random", "StaticArrays", "Test"] +test = ["ChainRulesTestUtils", "FiniteDifferences", "JLArrays", "JuliaInterpreter", "Random", "StaticArrays", "Test"] diff --git a/src/ChainRules.jl b/src/ChainRules.jl index e323f7b6d..78ef7cd97 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -1,9 +1,11 @@ module ChainRules +using Adapt: adapt using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable using ChainRulesCore using Compat using Distributed +using GPUArraysCore: AbstractGPUArray using IrrationalConstants: logtwo, logten using LinearAlgebra using LinearAlgebra.BLAS diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 6bea6e06c..7b8df5b6c 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -241,3 +241,23 @@ function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(map), 
f::F, xs::Tu map_back(dy::AbstractZero) = (NoTangent(), NoTangent(), ntuple(Returns(NoTangent()), num_xs)...) return y, map_pullback end + +##### +##### `task_local_storage` +##### + +# Called by `@allowscalar` from GPUArrays + +ChainRules.@non_differentiable task_local_storage(key::Any) +ChainRules.@non_differentiable task_local_storage(key::Any, value::Any) + +function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(task_local_storage), body::Function, key, value) + y, back = task_local_storage(key, value) do + rrule_via_ad(config, body) + end + function task_local_storage_pullback(dy) + dbody = only(back(dy)) + return (NoTangent(), dbody, NoTangent(), NoTangent()) + end + return y, task_local_storage_pullback +end diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index a3d25ca82..d34c91c23 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -113,8 +113,6 @@ function ∇getindex!(dx::AbstractArray, x::AbstractArray, dy, inds::Integer...) end function ∇getindex!(dx::AbstractArray, x::AbstractArray, dy, inds...) view(dx, inds...) .+= dy - # For GPU arrays, `inds::Union{Integer, Base.Slice}...` is fine, but any other AbstractArray risks overwriting. - # Those should call `NNlib.scatter!`, alla https://github.com/FluxML/Zygote.jl/pull/1131 return dx end @@ -134,6 +132,25 @@ function rrule(::typeof(∇getindex), x, dy, inds...) return z, ∇getindex_pullback end +# Indexing with repeated indices on a GPU will lead ∇getindex to have race conditions & wrong answers. +# To avoid this, copy everything back to the CPU. +# But don't do that for indices which are known to be unique, e.g. `A[1, 2:3, :]` the colon gives Base.Slice: + +function ∇getindex!(dx::AbstractGPUArray, x::AbstractArray, dy, inds::Integer...) + view(dx, inds...) .+= Ref(dy) + return dx +end +function ∇getindex!(dx::AbstractGPUArray, x::AbstractArray, dy, inds::Union{Integer, AbstractUnitRange, Base.Slice}...) + view(dx, inds...) 
.+= dy + return dx +end +function ∇getindex!(dx::AbstractGPUArray, x::AbstractArray, dy, inds...) + dx_cpu = adapt(Array, dx) + view(dx_cpu, adapt(Array, inds)...) .+= adapt(Array, dy) + copyto!(dx, dx_cpu) + return dx +end + ##### ##### first, tail ##### diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index 0d5b09398..c2b9203a4 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -143,6 +143,25 @@ test_rrule(∇getindex, [rand(2) for _ in 1:3], rand(2), 3; check_inferred=false) test_rrule(∇getindex, [rand(2) for _ in 1:3], [rand(2), rand(2)], 1:2; check_inferred=false) end + + @testset "GPU" begin + x_23_gpu = jl(rand(2, 3)) + + # Scalar indexing, copied from: @macroexpand @allowscalar A[i] + # Gives an error in Pkg.test, no idea why + # y1, bk1 = rrule(CFG, Base.task_local_storage, () -> x_23_gpu[1], :ScalarIndexing, ScalarAllowed) + # @test y1 == @allowscalar x_23_gpu[1] + # bk1(1.0) # This is zero, because finite-differencing ignores the function + # ... 
but this works, and calls the rule: + # Zygote.gradient(x -> @allowscalar(x[1]), jl(rand(3)))[1] + + y2, bk2 = rrule(getindex, x_23_gpu, :, 2:3) # fast path, just broadcast .+= + @test unthunk(bk2(jl(ones(2,2)))[2]) == jl([0 1 1; 0 1 1]) + + y3, bk3 = rrule(getindex, x_23_gpu, 1, [1,1,2]) # slow path, copy to CPU + @test_skip Array(y3) == Array(x_23_gpu)[1, [1,1,2]] # error in Pkg.test, no idea why + @test unthunk(bk3(jl(ones(3)))[2]) == jl([2 1 0; 0 0 0]) + end end @testset "first & tail" begin @@ -178,6 +197,7 @@ end end @testset "unsafe_getindex" begin + # In real life this is called only on some AbstractRanges, but easier to test on Array: test_frule(Base.unsafe_getindex, collect(1:0.1:2), 3) test_rrule(Base.unsafe_getindex, collect(1:0.1:2), 3) end diff --git a/test/runtests.jl b/test/runtests.jl index 24c1d85b9..840c6de65 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,12 +2,15 @@ using Test, ChainRulesCore, ChainRulesTestUtils @nospecialize +using Adapt using Base.Broadcast: broadcastable using ChainRules using ChainRulesCore using ChainRulesTestUtils using ChainRulesTestUtils: rand_tangent, _fdm using FiniteDifferences +using GPUArraysCore +using JLArrays using LinearAlgebra using LinearAlgebra.BLAS using LinearAlgebra: dot