Less memory/communication-intensive pmap adjoint #1188

Status: Open
Wants to merge 6 commits into base: master
1 change: 1 addition & 0 deletions Project.toml
@@ -8,6 +8,7 @@ ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
DistributedArrays = "aaf54ef3-cdf8-58ed-94cc-d582ad619b94"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
IRTools = "7869d1d1-7146-5819-86e3-90919afe41df"
66 changes: 55 additions & 11 deletions src/lib/array.jl
@@ -1,7 +1,8 @@
using Random, FillArrays, AbstractFFTs
using FillArrays: AbstractFill, getindex_value
using Base.Broadcast: broadcasted, broadcast_shape
-using Distributed: pmap, AbstractWorkerPool
using Distributed: pmap, AbstractWorkerPool, workers, nworkers, myid, remotecall_fetch
using DistributedArrays

@adjoint Array(xs::AbstractArray) = Array(xs), ȳ -> (ȳ,)
@adjoint Array(xs::Array) = Array(xs), ȳ -> (ȳ,)
@@ -192,7 +193,7 @@ _restore(dx, ::Val{N}) where {N} = length(dx) < N ? ntuple(i -> get(dx,i,nothing
last_or_nothing(::Nothing) = nothing
last_or_nothing(x) = last(x)

-for (mapfunc,∇mapfunc) in [(:map,:∇map),(:pmap,:∇pmap)]
for (mapfunc,∇mapfunc) in [(:map,:∇map)]
@eval function $∇mapfunc(cx, f::F, args::Vararg{Any, N}) where {F, N}
ys_and_backs = $mapfunc((args...) -> _pullback(cx, f, args...), args...)
ys = map(first, ys_and_backs)
@@ -224,15 +225,58 @@ for (mapfunc,∇mapfunc) in [(:map,:∇map),(:pmap,:∇pmap)]
end
end

-@adjoint function pmap(f, wp::AbstractWorkerPool, args...; kwargs...)
-ys_backs = pmap((x...) -> _pullback(__context__, f, x...), wp, args...; kwargs...)
-ys, backs = unzip(ys_backs)
-ys, function (Δ)
-res = pmap((df,d) -> df(d), wp, backs, Δ; kwargs...)
-Δf_and_args = unzip(res)
-Δf = reduce(accum, Δf_and_args[1])
-(Δf, nothing, Δf_and_args[2:end]..., nothing, nothing)
-end
-end
# Now that there is a backwards rule for zip, it is enough for this rule to
# handle a single collection X.
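# (pmap over several collections is lowered by the Distributed stdlib to a
# single-collection call via zip, roughly pmap(f, p, xs, ys) ->
# pmap(a -> f(a...), p, zip(xs, ys)), which is how those calls reach this rule.)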
@adjoint function pmap(f, p::AbstractWorkerPool, X; kwargs...)
darr = dfill([], (nworkers(p) + 1,), vcat(myid(), workers(p))) # Include own proc to handle empty worker pool
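# darr[:L] is the executing process's local part of the DArray: a one-element
# array whose entry is the Vector in which that process accumulates its pullbacks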

function forw(x)
y, back = _pullback(__context__, f, x)
push!(darr[:L][1], back)
return y, myid(), length(darr[:L][1])
end

ys_IDs_indices = pmap(forw, p, X; kwargs...)
ys = getindex.(ys_IDs_indices, 1) # the primal values
IDs = getindex.(ys_IDs_indices, 2) # remember which processors handled which elements of X
indices = getindex.(ys_IDs_indices, 3) # remember the index of the pullback in the array on each processor
output_axes = axes(ys)

# create a list of positions in X handled by each processor
unique_IDs = sort!(unique(IDs))
T = eltype(eachindex(ys_IDs_indices))
positions = [Vector{T}() for _ in 1:length(unique_IDs)]
for i in eachindex(ys_IDs_indices)
push!(positions[searchsortedfirst(unique_IDs, IDs[i])], i)
end
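# e.g. IDs == [2, 3, 2] gives unique_IDs == [2, 3] and positions == [[1, 3], [2]]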

function pmap_pullback(Δ)
# run the pullbacks for the positions that worker ID handled in the forward pass
function run_backs(ID, positions_for_ID)
Δ_batch = Δ[positions_for_ID]
indices_batch = indices[positions_for_ID]
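# only the cotangent batch and pullback indices are shipped to worker ID;
# the pullbacks themselves never leave that worker's local part of darr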
res_batch = remotecall_fetch(ID) do
asyncmap((Δy, i) -> darr[:L][1][i](Δy), Δ_batch, indices_batch) # run all the backs in a local asyncmap
end
return res_batch
end

# combine the per-worker results into res, restoring the original order

res_batches = asyncmap(run_backs, unique_IDs, positions)
res = similar(Array{Any}, output_axes)

for (positions_for_ID, res_batch) in zip(positions, res_batches)
res[positions_for_ID] = res_batch
end

# extract f̄ and X̄
Δf_and_args = unzip(res)
Δf = reduce(accum, Δf_and_args[1])
return (Δf, nothing, Δf_and_args[2:end]..., nothing, nothing)
end

return ys, pmap_pullback
end

for t in subtypes(AbstractWorkerPool)
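For context, a minimal usage sketch of the new rule (a hypothetical driver script, not part of this diff; assumes worker processes can be started):

using Distributed
addprocs(2)            # illustrative worker count
@everywhere using Zygote

xs = rand(10)
# the forward pass stores each pullback on the worker that computed it;
# the reverse pass ships only the per-element cotangents back to those workers
grad = gradient(x -> sum(pmap(abs2, x)), xs)[1]
grad ≈ 2 .* xs         # true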
94 changes: 58 additions & 36 deletions test/gradcheck.jl
@@ -2,7 +2,7 @@ using Zygote, Test, Random, LinearAlgebra, Statistics, FillArrays,
AbstractFFTs, FFTW, Distances
using Zygote: gradient
using Base.Broadcast: broadcast_shape
-using Distributed: pmap, CachingPool, workers
using Distributed: pmap, CachingPool, workers, nworkers
import FiniteDifferences

function ngradient(f, xs::AbstractArray...)
@@ -279,57 +279,57 @@ end
for mapfunc in [map,pmap]
@testset "$mapfunc" begin
@test gradtest(xs -> sum(mapfunc(x -> x^2, xs)), rand(2,3))
-@test gradtest((xss...) -> sum(mapfunc((xs...) -> sqrt(sum(xs.^2)), xss...)), [rand(5) for _ in 1:6]...)
@test gradtest((xss...) -> sum(mapfunc((xs...) -> sqrt(sum(xs.^2)), xss...)), [rand(5) for _ in 1:6]...) # test multiple collections
function foo(y)
bar = (x) -> x*y
sum(mapfunc(bar, 1:5))
end
@test gradtest(foo, 3)
@test gradient(v -> sum([x for x in v]), [1.1,2.2,3.3]) == ([1, 1, 1],)
end
end

@testset "Tuple adjoint" begin
x = randn(3)
_, pb = Zygote.pullback(x -> map(abs2, x), x)
Δy = randn(3)
@test first(pb((Δy..., ))) ≈ first(pb(Δy))
end
@testset "Tuple adjoint" begin
x = randn(3)
_, pb = Zygote.pullback(x -> map(abs2, x), x)
Δy = randn(3)
@test first(pb((Δy..., ))) ≈ first(pb(Δy))
end

@testset "empty tuples" begin
out, pb = Zygote.pullback(map, -, ())
@test pb(out) === (nothing, ())
@testset "empty tuples" begin
out, pb = Zygote.pullback(map, -, ())
@test pb(out) === (nothing, ())

out, pb = Zygote.pullback(map, +, (), ())
@test pb(()) === (nothing, (), ())
out, pb = Zygote.pullback(map, +, (), ())
@test pb(()) === (nothing, (), ())

function build_foo(z)
foo(x) = x * z
return foo
end
out, pb = Zygote.pullback(map, build_foo(5.0), ())
@test pb(()) === (nothing, ())
function build_foo(z)
foo(x) = x * z
return foo
end
out, pb = Zygote.pullback(map, build_foo(5.0), ())
@test pb(()) === (nothing, ())
end

@testset "Vector{Nothing} cotangent" begin
Δ = Vector{Nothing}(nothing, 5)
@testset "Vector{Nothing} cotangent" begin
Δ = Vector{Nothing}(nothing, 5)

# Unary stateless
out, pb = Zygote.pullback(map, -, randn(5))
@test pb(Δ)[2] isa Vector{Nothing}
# Unary stateless
out, pb = Zygote.pullback(map, -, randn(5))
@test pb(Δ)[2] isa Vector{Nothing}

# Binary stateless
out, pb = Zygote.pullback(map, +, randn(5), randn(5))
@test pb(Δ)[2] isa Vector{Nothing}
@test pb(Δ)[3] isa Vector{Nothing}
# Binary stateless
out, pb = Zygote.pullback(map, +, randn(5), randn(5))
@test pb(Δ)[2] isa Vector{Nothing}
@test pb(Δ)[3] isa Vector{Nothing}

# Stateful
function build_foo(z)
foo(x) = x * z
return foo
end
out, pb = Zygote.pullback(map, build_foo(5.0), randn(5))
@test pb(Δ)[2] isa Vector{Nothing}
# Stateful
function build_foo(z)
foo(x) = x * z
return foo
end
out, pb = Zygote.pullback(map, build_foo(5.0), randn(5))
@test pb(Δ)[2] isa Vector{Nothing}
end

# Check that map infers correctly. pmap still doesn't infer.
@@ -364,7 +364,7 @@
@test gradient(x -> map(+, x, [1,2,3])[1], (4,5,6,99)) == ((1.0, 0.0, 0.0, nothing),)
end

@testset "Alternative Pmap Dispatch" begin
@testset "pmap with caching pool" begin
cache_and_map(f,xs...) = pmap(f, CachingPool(workers()), xs...; batch_size = 1)
@test gradtest(xs -> sum(cache_and_map(x -> x^2, xs)), rand(2,3))
@test gradtest((xss...) -> sum(cache_and_map((xs...) -> sqrt(sum(xs.^2)), xss...)), [rand(5) for _ in 1:6]...)
@@ -376,6 +376,28 @@ end
@test gradient(v -> sum([x for x in v]), [1.1,2.2,3.3]) == ([1, 1, 1],)
end

# More elaborate tests of the pmap rule
@testset "multiple pmaps" begin
function sequential(xs)
ys = pmap(x -> x^2, xs)
sum(pmap(y -> y^3, ys))
end
@test gradtest(sequential, rand(2,3))

function nested(xs)
X = [xs[:, i] for i in 1:size(xs)[2]]
inner(arr) = sum(pmap(x -> x^2, arr))
sum(pmap(inner, X))
end
xs = rand(10, clamp(nworkers() - 1, 1, 2)) # use more than one outer iteration only when it won't exhaust the worker pool
@test gradtest(nested, xs)
end

@testset "pmap kwargs" begin
@test gradtest(xs -> pmap(x -> x^2, xs, batch_size=2), rand(4)) # batch_size > 1
@test gradtest(xs -> pmap(x -> x^2, xs, distributed=false), rand(4)) # distributed = false
end

@testset "Stateful Map" begin
s = 0
f(x) = (s += x)