diff --git a/Project.toml b/Project.toml index 39bda8294..56e5ff128 100644 --- a/Project.toml +++ b/Project.toml @@ -18,17 +18,28 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" +[weakdeps] +Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458" +GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" + +[extensions] +ChainRulesKernelAbstractionsExt = ["Atomix", "GPUArrays", "KernelAbstractions"] + [compat] Adapt = "3.4.0, 4" +Atomix = "0.1" ChainRulesCore = "1.25" ChainRulesTestUtils = "1.5" Compat = "3.46, 4.2" Distributed = "1" FiniteDifferences = "0.12.20" GPUArraysCore = "0.1.0, 0.2" +GPUArrays = "10, 11" IrrationalConstants = "0.1.1, 0.2" JLArrays = "0.1" JuliaInterpreter = "0.8,0.9" +KernelAbstractions = "0.9" LinearAlgebra = "1" Random = "1" RealDot = "0.1" @@ -41,10 +52,13 @@ SuiteSparse = "1" julia = "1.6" [extras] +Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" JuliaInterpreter = "aa1ae85d-cabe-5617-a682-6adf51b2e16a" +KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/ext/ChainRulesKernelAbstractionsExt/ChainRulesKernelAbstractionsExt.jl b/ext/ChainRulesKernelAbstractionsExt/ChainRulesKernelAbstractionsExt.jl new file mode 100644 index 000000000..c171d049f --- /dev/null +++ b/ext/ChainRulesKernelAbstractionsExt/ChainRulesKernelAbstractionsExt.jl @@ -0,0 +1,45 @@ +module ChainRulesKernelAbstractionsExt + +import Adapt +import Atomix +import ChainRules +import GPUArrays +import KernelAbstractions as KA + +using GPUArraysCore: AbstractGPUArray +using KernelAbstractions + +function ChainRules.∇getindex!(dx::AbstractGPUArray, dy, inds...) + kab = get_backend(dx) + + if KA.supports_atomics(kab) + gids = GPUArrays.to_indices(dx, inds) + idims = map(length, gids) + Is = map(Adapt.adapt(GPUArrays.ToGPU(dy)), gids) + scatter!(kab)(+, dx, dy, idims, Is...; ndrange=length(dy)) + else + dx_cpu = Adapt.adapt(Array, dx) + view(dx_cpu, Adapt.adapt(Array, inds)...) .+= Adapt.adapt(Array, dy) + copyto!(dx, dx_cpu) + end + return dx +end + +@kernel function scatter!(op, dest, src, idims, Is::Vararg{Any, N}) where N + _scatter!(@index(Global), op, dest, src, idims, Is...) +end + +@generated function _scatter!(i, op, dest, src, idims, Is::Vararg{Any, N}) where N + quote + is = @inbounds CartesianIndices(idims)[i] + dv = src[i] + Base.Cartesian.@nexprs $N j -> I_j = @inbounds((Is[j])[is[j]]) + Base.Cartesian.@ncall $N _accum! op dest dv j -> I_j + end +end + +function _accum!(op, dest, val, ids...) + Atomix.modify!(Atomix.IndexableRef(dest, (ids...,)), op, val) +end + +end diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 61216bda2..329d0969e 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -168,9 +168,9 @@ function rrule(::typeof(∇getindex), x, dy, inds...) return z, ∇getindex_pullback end -# Indexing with repeated indices on a GPU will lead ∇getindex to have race conditions & wrong answers. -# To avoid this, copy everything back to the CPU. -# But don't do that for indices which are known to be unique, e.g. `A[1, 2:3, :]` the colon gives Base.Slice: +# NOTE: +# Generic `∇getindex!(dx::AbstractGPUArray, dy, inds...)` +# is implemented in `ext/` with a custom kernel. function ∇getindex!(dx::AbstractGPUArray, dy, inds::Integer...) view(dx, inds...) .+= Ref(dy) @@ -180,12 +180,6 @@ function ∇getindex!(dx::AbstractGPUArray, dy, inds::Union{Integer, AbstractUni view(dx, inds...) .+= dy return dx end -function ∇getindex!(dx::AbstractGPUArray, dy, inds...) - dx_cpu = adapt(Array, dx) - view(dx_cpu, adapt(Array, inds)...) .+= adapt(Array, dy) - copyto!(dx, dx_cpu) - return dx -end ##### ##### view