Skip to content

Commit

Permalink
Add rule for Dict iteration
Browse files Browse the repository at this point in the history
  • Loading branch information
ToucheSir committed Aug 11, 2022
1 parent 99d5a38 commit 83cdacc
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 0 deletions.
39 changes: 39 additions & 0 deletions src/lib/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,45 @@ end
end
end

# This rule behaves much like the getindex adjoint,
# just with an (internal) ordinal index instead of a key.
function _pullback(cx::AContext, ::typeof(iterate), d::Dict, i)
iter = iterate(d, i)
function dict_iterate_pullback(Δ)
(iter === nothing || Δ === nothing) && return
k, v = iter[1]
_, dv = Δ[1]
accum_param(cx, v, dv) === nothing && return
grad = grad_mut(cx, d)
grad[k] = accum(get(grad, k, nothing), dv)
return (nothing, grad, nothing)
end
return iter, dict_iterate_pullback
end

# ...while this one is to avoid duplicating code or differentiating skip_deleted.
# The alternative would be to write a rule for the private _iterate(::Dict, i).
function _pullback(cx::AContext, ::typeof(iterate), d::Dict)
# Calculation of i is the same used in iterate(::Dict)
return _pullback(cx, iterate, d, Base.skip_deleted(d, d.idxfloor))
end

function _pullback(cx::AContext, ::typeof(iterate), vi::Base.ValueIterator{<:Dict}, i::Int)
iter = iterate(vi, i)
function values_iterate_pullback(Δ)
(iter === nothing || Δ === nothing) && return
v, dv = iter[1], Δ[1]
accum_param(cx, v, dv) === nothing && return
# Same as vi.dict.keys[i], but without reaching into Dict internals.
# Iterating the dict instead of keys() is to hit the rules above in nested AD.
k = iterate(vi.dict, i)[1][1]
grad = grad_mut(cx, vi.dict)
grad[k] = accum(get(grad, k, nothing), dv)
return (nothing, (; dict = grad), nothing)
end
return iter, values_iterate_pullback
end

# Channels

grad_mut(ch::Channel) = Channel(ch.sz_max)
Expand Down
32 changes: 32 additions & 0 deletions test/lib/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,36 @@

@test result1 == result2
end

@testset "Dict iteration" begin
# https://github.com/FluxML/Zygote.jl/issues/1065
function sumkv(d)
s = zero(d["c"])
for (k, v) in d
s += v
k == :b && (s += v)
end
return sum(s)
end

function sumvals(d)
s = zero(d["c"])
for v in values(d)
s += v
end
return sum(s)
end

d_num = Dict(:a => 3, :b => 4, "c" => 5)
d_arr = Dict(:a => [3], :b => [4], "c" => [5])
ps = d_arr |> values |> collect |> Params

@test gradient(sumkv, d_num)[1] == Dict(:a => 1, :b => 2, "c" => 1)
grads = gradient(() -> sumkv(d_arr), ps)
@test (grads[d_arr[:a]], grads[d_arr[:b]], grads[d_arr["c"]]) == ([1], [2], [1])

@test gradient(sumvals, d_num)[1] == Dict(:a => 1, :b => 1, "c" => 1)
grads = gradient(() -> sumvals(d_arr), ps)
@test (grads[d_arr[:a]], grads[d_arr[:b]], grads[d_arr["c"]]) == ([1], [1], [1])
end
end

0 comments on commit 83cdacc

Please sign in to comment.