diff --git a/Project.toml b/Project.toml
index 4b88d68de..48dbf0249 100644
--- a/Project.toml
+++ b/Project.toml
@@ -8,6 +8,7 @@ ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
 Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
+DistributedArrays = "aaf54ef3-cdf8-58ed-94cc-d582ad619b94"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 IRTools = "7869d1d1-7146-5819-86e3-90919afe41df"
diff --git a/src/lib/array.jl b/src/lib/array.jl
index 7734ad5ca..aff0883b4 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -1,7 +1,8 @@
 using Random, FillArrays, AbstractFFTs
 using FillArrays: AbstractFill, getindex_value
 using Base.Broadcast: broadcasted, broadcast_shape
-using Distributed: pmap, AbstractWorkerPool
+using Distributed: pmap, AbstractWorkerPool, workers, nworkers, myid, remotecall_fetch
+using DistributedArrays
 
 @adjoint Array(xs::AbstractArray) = Array(xs), ȳ -> (ȳ,)
 @adjoint Array(xs::Array) = Array(xs), ȳ -> (ȳ,)
@@ -192,7 +193,7 @@ _restore(dx, ::Val{N}) where {N} = length(dx) < N ? ntuple(i -> get(dx,i,nothing
 last_or_nothing(::Nothing) = nothing
 last_or_nothing(x) = last(x)
 
-for (mapfunc,∇mapfunc) in [(:map,:∇map),(:pmap,:∇pmap)]
+for (mapfunc,∇mapfunc) in [(:map,:∇map)]
   @eval function $∇mapfunc(cx, f::F, args::Vararg{Any, N}) where {F, N}
     ys_and_backs = $mapfunc((args...) -> _pullback(cx, f, args...), args...)
     ys = map(first, ys_and_backs)
@@ -224,15 +225,58 @@ for (mapfunc,∇mapfunc) in [(:map,:∇map),(:pmap,:∇pmap)]
   end
 end
 
-@adjoint function pmap(f, wp::AbstractWorkerPool, args...; kwargs...)
-  ys_backs = pmap((x...) -> _pullback(__context__, f, x...), wp, args...; kwargs...)
-  ys, backs = unzip(ys_backs)
-  ys, function (Δ)
-    res = pmap((df,d) -> df(d), wp, backs, Δ; kwargs...)
-    Δf_and_args = unzip(res)
-    Δf = reduce(accum, Δf_and_args[1])
-    (Δf, nothing, Δf_and_args[2:end]..., nothing, nothing)
-  end
+# Now that there is a backwards rule for zip,
+# it should be fine to deal with only a single collection X
+@adjoint function pmap(f, p::AbstractWorkerPool, X; kwargs...)
+  darr = dfill([], (nworkers(p) + 1,), vcat(myid(), workers(p))) # Include own proc to handle empty worker pool
+
+  function forw(x)
+    y, back = _pullback(__context__, f, x)
+    push!(darr[:L][1], back)
+    return y, myid(), length(darr[:L][1])
+  end
+
+  ys_IDs_indices = pmap(forw, p, X; kwargs...)
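+  # Each element of ys_IDs_indices is a triple (primal value, id of the proc
+  # that ran f, slot of the pullback in that proc's local part of darr); the
+  # id/slot pair lets the backward pass find and rerun the right closure remotely.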
+  ys = getindex.(ys_IDs_indices, 1) # the primal values
+  IDs = getindex.(ys_IDs_indices, 2) # remember which proc handled each element of X
+  indices = getindex.(ys_IDs_indices, 3) # remember each pullback's slot in its proc's local array
+  output_axes = axes(ys)
+
+  # create a list of the positions in X handled by each proc
+  unique_IDs = sort!(unique(IDs))
+  T = eltype(eachindex(ys_IDs_indices))
+  positions = [Vector{T}() for _ in 1:length(unique_IDs)]
+  for i in eachindex(ys_IDs_indices)
+    push!(positions[searchsortedfirst(unique_IDs, IDs[i])], i)
+  end
+
+  function pmap_pullback(Δ)
+    # run the pullbacks for the positions proc `ID` handled in the forward pass
+    function run_backs(ID, pos)
+      Δ_batch = Δ[pos]
+      indices_batch = indices[pos]
+      res_batch = remotecall_fetch(ID) do
+        asyncmap((Δy, i) -> darr[:L][1][i](Δy), Δ_batch, indices_batch) # run all the backs in a local asyncmap
+      end
+      return res_batch
+    end
+
+    # combine the results from each proc into res
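+    # (asyncmap keeps one remotecall_fetch in flight per proc, so the
+    # per-proc batches of pullbacks run concurrently rather than serially)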
+    res_batches = asyncmap(run_backs, unique_IDs, positions)
+    res = similar(Array{Any}, output_axes)
+
+    for (pos, res_batch) in zip(positions, res_batches)
+      res[pos] = res_batch
+    end
+
+    # extract f̄ and X̄
+    Δf_and_args = unzip(res)
+    Δf = reduce(accum, Δf_and_args[1])
+    return (Δf, nothing, Δf_and_args[2:end]..., nothing, nothing)
+  end
+
+  return ys, pmap_pullback
 end
 
 for t in subtypes(AbstractWorkerPool)
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index 90f0a4b4a..d1ea79121 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -2,7 +2,7 @@ using Zygote, Test, Random, LinearAlgebra, Statistics, FillArrays,
     AbstractFFTs, FFTW, Distances
 using Zygote: gradient
 using Base.Broadcast: broadcast_shape
-using Distributed: pmap, CachingPool, workers
+using Distributed: pmap, CachingPool, workers, nworkers
 import FiniteDifferences
 
 function ngradient(f, xs::AbstractArray...)
@@ -279,7 +279,7 @@ end
 for mapfunc in [map,pmap]
   @testset "$mapfunc" begin
     @test gradtest(xs -> sum(mapfunc(x -> x^2, xs)), rand(2,3))
-    @test gradtest((xss...) -> sum(mapfunc((xs...) -> sqrt(sum(xs.^2)), xss...)), [rand(5) for _ in 1:6]...)
+    @test gradtest((xss...) -> sum(mapfunc((xs...) -> sqrt(sum(xs.^2)), xss...)), [rand(5) for _ in 1:6]...) # test multiple collections
     function foo(y)
       bar = (x) -> x*y
       sum(mapfunc(bar, 1:5))
@@ -287,49 +287,49 @@ for mapfunc in [map,pmap]
     @test gradtest(foo, 3)
     @test gradient(v -> sum([x for x in v]), [1.1,2.2,3.3]) == ([1, 1, 1],)
   end
+end
 
-  @testset "Tuple adjoint" begin
-    x = randn(3)
-    _, pb = Zygote.pullback(x -> map(abs2, x), x)
-    Δy = randn(3)
-    @test first(pb((Δy..., ))) ≈ first(pb(Δy))
-  end
+@testset "Tuple adjoint" begin
+  x = randn(3)
+  _, pb = Zygote.pullback(x -> map(abs2, x), x)
+  Δy = randn(3)
+  @test first(pb((Δy..., ))) ≈ first(pb(Δy))
+end
 
-  @testset "empty tuples" begin
-    out, pb = Zygote.pullback(map, -, ())
-    @test pb(out) === (nothing, ())
+@testset "empty tuples" begin
+  out, pb = Zygote.pullback(map, -, ())
+  @test pb(out) === (nothing, ())
 
-    out, pb = Zygote.pullback(map, +, (), ())
-    @test pb(()) === (nothing, (), ())
+  out, pb = Zygote.pullback(map, +, (), ())
+  @test pb(()) === (nothing, (), ())
 
-    function build_foo(z)
-      foo(x) = x * z
-      return foo
-    end
-    out, pb = Zygote.pullback(map, build_foo(5.0), ())
-    @test pb(()) === (nothing, ())
+  function build_foo(z)
+    foo(x) = x * z
+    return foo
   end
+  out, pb = Zygote.pullback(map, build_foo(5.0), ())
+  @test pb(()) === (nothing, ())
+end
 
-  @testset "Vector{Nothing} cotangent" begin
-    Δ = Vector{Nothing}(nothing, 5)
+@testset "Vector{Nothing} cotangent" begin
+  Δ = Vector{Nothing}(nothing, 5)
 
-    # Unary stateless
-    out, pb = Zygote.pullback(map, -, randn(5))
-    @test pb(Δ)[2] isa Vector{Nothing}
+  # Unary stateless
+  out, pb = Zygote.pullback(map, -, randn(5))
+  @test pb(Δ)[2] isa Vector{Nothing}
 
-    # Binary stateless
-    out, pb = Zygote.pullback(map, +, randn(5), randn(5))
-    @test pb(Δ)[2] isa Vector{Nothing}
-    @test pb(Δ)[3] isa Vector{Nothing}
+  # Binary stateless
+  out, pb = Zygote.pullback(map, +, randn(5), randn(5))
+  @test pb(Δ)[2] isa Vector{Nothing}
+  @test pb(Δ)[3] isa Vector{Nothing}
 
-    # Stateful
-    function build_foo(z)
-      foo(x) = x * z
-      return foo
-    end
-    out, pb = Zygote.pullback(map, build_foo(5.0), randn(5))
-    @test pb(Δ)[2] isa Vector{Nothing}
+  # Stateful
+  function build_foo(z)
+    foo(x) = x * z
+    return foo
   end
+  out, pb = Zygote.pullback(map, build_foo(5.0), randn(5))
+  @test pb(Δ)[2] isa Vector{Nothing}
 end
 
 # Check that map infers correctly. pmap still doesn't infer.
@@ -364,7 +364,7 @@ end
   @test gradient(x -> map(+, x, [1,2,3])[1], (4,5,6,99)) == ((1.0, 0.0, 0.0, nothing),)
 end
 
-@testset "Alternative Pmap Dispatch" begin
+@testset "pmap with caching pool" begin
   cache_and_map(f,xs...) = pmap(f, CachingPool(workers()), xs...; batch_size = 1)
   @test gradtest(xs -> sum(cache_and_map(x -> x^2, xs)), rand(2,3))
   @test gradtest((xss...) -> sum(cache_and_map((xs...) -> sqrt(sum(xs.^2)), xss...)), [rand(5) for _ in 1:6]...)
@@ -376,6 +376,28 @@ end
   @test gradient(v -> sum([x for x in v]), [1.1,2.2,3.3]) == ([1, 1, 1],)
 end
 
+# more elaborate tests of the pmap rule
+@testset "multiple pmaps" begin
+  function sequential(xs)
+    ys = pmap(x -> x^2, xs)
+    sum(pmap(y -> y^3, ys))
+  end
+  @test gradtest(sequential, rand(2,3))
+
+  function nested(xs)
+    X = [xs[:, i] for i in 1:size(xs, 2)]
+    inner(arr) = sum(pmap(x -> x^2, arr))
+    sum(pmap(inner, X))
+  end
+  xs = rand(10, clamp(nworkers() - 1, 1, 2)) # only set outer iterations > 1 if we won't exhaust the worker pool
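+  # (each outer pmap task starts an inner pmap on the same default pool, so
+  # the number of outer columns is capped to leave workers for the inner calls)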
+  @test gradtest(nested, xs)
+end
+
+@testset "pmap kwargs" begin
+  @test gradtest(xs -> pmap(x -> x^2, xs, batch_size=2), rand(4)) # batch_size > 1
+  @test gradtest(xs -> pmap(x -> x^2, xs, distributed=false), rand(4)) # distributed = false
+end
+
 @testset "Stateful Map" begin
   s = 0
   f(x) = (s += x)
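
For reviewers who want to exercise the new rule by hand, a minimal session might look like the sketch below (not part of the patch; it assumes this branch of Zygote can be loaded on every process):

  using Distributed
  addprocs(2)                # two local workers
  @everywhere using Zygote   # workers must be able to run the saved pullbacks

  xs = rand(6)
  # the forward pass farms x -> x^2 out to the pool; the backward pass then
  # invokes each stored pullback on the worker that created it
  g = gradient(v -> sum(pmap(x -> x^2, v)), xs)[1]
  g ≈ 2 .* xs                # expected to hold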