From ee832801dd50f4dff1f89d891b8e9257d8efa586 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Thu, 20 Jun 2024 18:30:32 +0300 Subject: [PATCH 1/5] Add scatter --- src/rulesets/Base/indexing.jl | 5 ++++ t.jl | 55 +++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+) create mode 100644 t.jl diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 61216bda2..42aa34ff0 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -180,7 +180,12 @@ function ∇getindex!(dx::AbstractGPUArray, dy, inds::Union{Integer, AbstractUni view(dx, inds...) .+= dy return dx end + function ∇getindex!(dx::AbstractGPUArray, dy, inds...) + # TODO we want this + # @atomic dx[inds...] .+= dy + # return dx + dx_cpu = adapt(Array, dx) view(dx_cpu, adapt(Array, inds)...) .+= adapt(Array, dy) copyto!(dx, dx_cpu) diff --git a/t.jl b/t.jl new file mode 100644 index 000000000..235f18050 --- /dev/null +++ b/t.jl @@ -0,0 +1,55 @@ +using ChainRules +using GPUArrays +using Zygote +using AMDGPU +using KernelAbstractions +using KernelAbstractions: @atomic + +function _accum!(dest, val, ids...) + # TODO support passing `op` + @atomic dest[ids...] += val +end + +@generated function _scatter!(i, dest, src, idims, Is::Vararg{Any, N}) where N + quote + is = @inbounds CartesianIndices(idims)[i] + Base.Cartesian.@nexprs $N j -> I_j = @inbounds((Is[j])[is[j]]) + dv = dest[i] + Base.Cartesian.@ncall $N _accum! src dv j -> I_j + end +end + +@kernel function scatter!(dest, src, idims, Is::Vararg{Any, N}) where N + _scatter!(@index(Global), dest, src, idims, Is...) +end + +function main() + x = ROCArray(zeros(Float32, 16, 4, 2, 3)) + y = ROCArray(ones(Float32, 6, 2, 2)) + ids = ([4, 1, 4, 3, 2, 1], 1, :, 3) + + gids = GPUArrays.to_indices(x, ids) + idims = map(length, gids) + Is = map(AMDGPU.Adapt.adapt(GPUArrays.ToGPU(y)), gids) + + kab = get_backend(x) + scatter!(kab, 256)(y, x, idims, Is...; ndrange=length(y)) + @show y + @show Array(x)[:, 1, 1, 3] + + # @show x[ids...] + # x[ids...] .+= y + # return + + # Δ = ROCArray(ones(Float32, 1)) + + # y, back = Zygote.pullback(x) do x + # # xd = x[[4, 3, 2, 1], :, 1, [3, 1]] + # xd = x[] + # sum(xd; dims=(1:ndims(xd)...,)) + # end + # println("===============") + # back(Δ) + return +end +main() From 36b5c57f2e6ab27f8ad1d7a3018130bc7ce8395c Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Fri, 21 Jun 2024 14:15:27 +0300 Subject: [PATCH 2/5] Add extension --- Project.toml | 11 ++++ .../ChainRulesKernelAbstractionsExt.jl | 45 +++++++++++++++ src/rulesets/Base/indexing.jl | 17 +----- t.jl | 55 ------------------- 4 files changed, 59 insertions(+), 69 deletions(-) create mode 100644 ext/ChainRulesKernelAbstractionsExt/ChainRulesKernelAbstractionsExt.jl delete mode 100644 t.jl diff --git a/Project.toml b/Project.toml index 39bda8294..c647c759b 100644 --- a/Project.toml +++ b/Project.toml @@ -18,17 +18,28 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" +[weakdeps] +Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458" +GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" + +[extensions] +ChainRulesKernelAbstractionsExt = ["Atomix", "GPUArrays", "KernelAbstractions"] + [compat] Adapt = "3.4.0, 4" +Atomix = "0.1" ChainRulesCore = "1.25" ChainRulesTestUtils = "1.5" Compat = "3.46, 4.2" Distributed = "1" FiniteDifferences = "0.12.20" GPUArraysCore = "0.1.0, 0.2" +GPUArrays = "10, 11" IrrationalConstants = "0.1.1, 0.2" JLArrays = "0.1" JuliaInterpreter = "0.8,0.9" +KernelAbstractions = "0.9" LinearAlgebra = "1" Random = "1" RealDot = "0.1" diff --git a/ext/ChainRulesKernelAbstractionsExt/ChainRulesKernelAbstractionsExt.jl b/ext/ChainRulesKernelAbstractionsExt/ChainRulesKernelAbstractionsExt.jl new file mode 100644 index 000000000..153de32e0 --- /dev/null +++ b/ext/ChainRulesKernelAbstractionsExt/ChainRulesKernelAbstractionsExt.jl @@ -0,0 +1,45 @@ +module ChainRulesKernelAbstractionsExt + +import Adapt +import Atomix +import ChainRules +import GPUArrays +import KernelAbstractions as KA + +using GPUArraysCore: AbstractGPUArray +using KernelAbstractions + +function ChainRules.∇getindex!(dx::AbstractGPUArray, dy, inds...) + # kab = get_backend(dx) + + # if KA.supports_atomics(kab) + # gids = GPUArrays.to_indices(dx, inds) + # idims = map(length, gids) + # Is = map(Adapt.adapt(GPUArrays.ToGPU(dy)), gids) + # scatter!(kab)(+, dx, dy, idims, Is...; ndrange=length(dy)) + # else + dx_cpu = Adapt.adapt(Array, dx) + view(dx_cpu, Adapt.adapt(Array, inds)...) .+= Adapt.adapt(Array, dy) + copyto!(dx, dx_cpu) + # end + return dx +end + +@kernel function scatter!(op, dest, src, idims, Is::Vararg{Any, N}) where N + _scatter!(@index(Global), op, dest, src, idims, Is...) +end + +@generated function _scatter!(i, op, dest, src, idims, Is::Vararg{Any, N}) where N + quote + is = @inbounds CartesianIndices(idims)[i] + Base.Cartesian.@nexprs $N j -> I_j = @inbounds((Is[j])[is[j]]) + dv = src[i] + Base.Cartesian.@ncall $N _accum! op dest dv j -> I_j + end +end + +function _accum!(op, dest, val, ids...) + Atomix.modify!(Atomix.IndexableRef(dest, (ids...,)), op, val) +end + +end diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 42aa34ff0..329d0969e 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -168,9 +168,9 @@ function rrule(::typeof(∇getindex), x, dy, inds...) return z, ∇getindex_pullback end -# Indexing with repeated indices on a GPU will lead ∇getindex to have race conditions & wrong answers. -# To avoid this, copy everything back to the CPU. -# But don't do that for indices which are known to be unique, e.g. `A[1, 2:3, :]` the colon gives Base.Slice: +# NOTE: +# Generic `∇getindex!(dx::AbstractGPUArray, dy, inds...)` +# is implemented in `ext/` with a custom kernel. function ∇getindex!(dx::AbstractGPUArray, dy, inds::Integer...) view(dx, inds...) .+= Ref(dy) @@ -181,17 +181,6 @@ function ∇getindex!(dx::AbstractGPUArray, dy, inds::Union{Integer, AbstractUni return dx end -function ∇getindex!(dx::AbstractGPUArray, dy, inds...) - # TODO we want this - # @atomic dx[inds...] .+= dy - # return dx - - dx_cpu = adapt(Array, dx) - view(dx_cpu, adapt(Array, inds)...) .+= adapt(Array, dy) - copyto!(dx, dx_cpu) - return dx -end - ##### ##### view ##### diff --git a/t.jl b/t.jl deleted file mode 100644 index 235f18050..000000000 --- a/t.jl +++ /dev/null @@ -1,55 +0,0 @@ -using ChainRules -using GPUArrays -using Zygote -using AMDGPU -using KernelAbstractions -using KernelAbstractions: @atomic - -function _accum!(dest, val, ids...) - # TODO support passing `op` - @atomic dest[ids...] += val -end - -@generated function _scatter!(i, dest, src, idims, Is::Vararg{Any, N}) where N - quote - is = @inbounds CartesianIndices(idims)[i] - Base.Cartesian.@nexprs $N j -> I_j = @inbounds((Is[j])[is[j]]) - dv = dest[i] - Base.Cartesian.@ncall $N _accum! src dv j -> I_j - end -end - -@kernel function scatter!(dest, src, idims, Is::Vararg{Any, N}) where N - _scatter!(@index(Global), dest, src, idims, Is...) -end - -function main() - x = ROCArray(zeros(Float32, 16, 4, 2, 3)) - y = ROCArray(ones(Float32, 6, 2, 2)) - ids = ([4, 1, 4, 3, 2, 1], 1, :, 3) - - gids = GPUArrays.to_indices(x, ids) - idims = map(length, gids) - Is = map(AMDGPU.Adapt.adapt(GPUArrays.ToGPU(y)), gids) - - kab = get_backend(x) - scatter!(kab, 256)(y, x, idims, Is...; ndrange=length(y)) - @show y - @show Array(x)[:, 1, 1, 3] - - # @show x[ids...] - # x[ids...] .+= y - # return - - # Δ = ROCArray(ones(Float32, 1)) - - # y, back = Zygote.pullback(x) do x - # # xd = x[[4, 3, 2, 1], :, 1, [3, 1]] - # xd = x[] - # sum(xd; dims=(1:ndims(xd)...,)) - # end - # println("===============") - # back(Δ) - return -end -main() From 4d4c70f448e94c57a6ee2beaee3808afd18a13ca Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Fri, 21 Jun 2024 14:16:26 +0300 Subject: [PATCH 3/5] Cleanup --- .../ChainRulesKernelAbstractionsExt.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/ext/ChainRulesKernelAbstractionsExt/ChainRulesKernelAbstractionsExt.jl b/ext/ChainRulesKernelAbstractionsExt/ChainRulesKernelAbstractionsExt.jl index 153de32e0..a55837dba 100644 --- a/ext/ChainRulesKernelAbstractionsExt/ChainRulesKernelAbstractionsExt.jl +++ b/ext/ChainRulesKernelAbstractionsExt/ChainRulesKernelAbstractionsExt.jl @@ -10,18 +10,18 @@ using GPUArraysCore: AbstractGPUArray using KernelAbstractions function ChainRules.∇getindex!(dx::AbstractGPUArray, dy, inds...) - # kab = get_backend(dx) - - # if KA.supports_atomics(kab) - # gids = GPUArrays.to_indices(dx, inds) - # idims = map(length, gids) - # Is = map(Adapt.adapt(GPUArrays.ToGPU(dy)), gids) - # scatter!(kab)(+, dx, dy, idims, Is...; ndrange=length(dy)) - # else + kab = get_backend(dx) + + if KA.supports_atomics(kab) + gids = GPUArrays.to_indices(dx, inds) + idims = map(length, gids) + Is = map(Adapt.adapt(GPUArrays.ToGPU(dy)), gids) + scatter!(kab)(+, dx, dy, idims, Is...; ndrange=length(dy)) + else dx_cpu = Adapt.adapt(Array, dx) view(dx_cpu, Adapt.adapt(Array, inds)...) .+= Adapt.adapt(Array, dy) copyto!(dx, dx_cpu) - # end + end return dx end From 35a599e01c9ce4fe62b7ceecb14e529eba4d60d5 Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Fri, 21 Jun 2024 14:27:00 +0300 Subject: [PATCH 4/5] Minor refactor --- .../ChainRulesKernelAbstractionsExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/ChainRulesKernelAbstractionsExt/ChainRulesKernelAbstractionsExt.jl b/ext/ChainRulesKernelAbstractionsExt/ChainRulesKernelAbstractionsExt.jl index a55837dba..c171d049f 100644 --- a/ext/ChainRulesKernelAbstractionsExt/ChainRulesKernelAbstractionsExt.jl +++ b/ext/ChainRulesKernelAbstractionsExt/ChainRulesKernelAbstractionsExt.jl @@ -32,8 +32,8 @@ end @generated function _scatter!(i, op, dest, src, idims, Is::Vararg{Any, N}) where N quote is = @inbounds CartesianIndices(idims)[i] - Base.Cartesian.@nexprs $N j -> I_j = @inbounds((Is[j])[is[j]]) dv = src[i] + Base.Cartesian.@nexprs $N j -> I_j = @inbounds((Is[j])[is[j]]) Base.Cartesian.@ncall $N _accum! op dest dv j -> I_j end end From 24cce69c5e29473bbc90c521e97287f130d8b5fa Mon Sep 17 00:00:00 2001 From: Anton Smirnov Date: Fri, 21 Jun 2024 14:29:49 +0300 Subject: [PATCH 5/5] Fix 1.6 compat --- Project.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Project.toml b/Project.toml index c647c759b..56e5ff128 100644 --- a/Project.toml +++ b/Project.toml @@ -52,10 +52,13 @@ SuiteSparse = "1" julia = "1.6" [extras] +Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" JuliaInterpreter = "aa1ae85d-cabe-5617-a682-6adf51b2e16a" +KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"