Skip to content

Commit

Permalink
Treat Pairs(NamedTuple) as NamedTuple for indexing
Browse files Browse the repository at this point in the history
This prevents issues with double-counting when using kwargs.
  • Loading branch information
ToucheSir committed Aug 14, 2022
1 parent 99d5a38 commit 24a6111
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 5 deletions.
28 changes: 26 additions & 2 deletions src/lib/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,11 @@ end

# named tuple
@adjoint function pairs(t::NamedTuple{N}) where N

pairs_namedtuple_pullback(dx::NamedTuple) = (dx.data,)

pairs_namedtuple_pullback(dx::Tuple{}) = (NamedTuple(),)

function pairs_namedtuple_pullback::Dict)
t0 = map(zero, t)
for (idx, v) in Δ
Expand All @@ -145,6 +145,30 @@ else
@adjoint merge(nt::NamedTuple, dict::Dict) = pullback(merge, nt, (;dict...))
end

# Keyword arguments pretend to be a Dict, but are secretly wrapping a NamedTuple.
# We can treat them much the same, just with some plumbing to handle the extra `itr` field.
function _pullback(::AContext, ::typeof(getindex),
ps::Iterators.Pairs{<:Any,<:Any,<:Any,<:NamedTuple}, k)
# So we don't close over kwarg values in the pullback
data = map(_ -> nothing, NamedTuple(ps))
function kwargs_getindex_pullback(Δ)
dps = (data = Base.setindex(data, Δ, k), itr = nothing)
return (nothing, dps, nothing)
end
return ps[k], kwargs_getindex_pullback
end

function _pullback(cx::AContext, ::typeof(literal_getindex),
ps::Iterators.Pairs{<:Any,<:Any,<:Any,<:NamedTuple}, ::Val{K}) where K
val, gf_back = _pullback(cx, literal_getfield, NamedTuple(ps), Val(K))
function kwargs_literal_getindex_pullback(Δ)
dps = (data = gf_back(Δ)[2], itr = nothing)
return (nothing, dps, nothing)
end
return val, kwargs_literal_getindex_pullback
end

# Misc.
@adjoint function Base.getfield(p::Pair, i::Int)
function pair_getfield_pullback(Δ)
f, s = i == 1 ? (Δ, nothing) : (nothing, Δ)
Expand Down
17 changes: 14 additions & 3 deletions test/features.jl
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,17 @@ end
@test gradient(x -> x[].a, Ref((a=1, b=2))) == ((x = (a = 1, b = nothing),),)
@test gradient(x -> x[1][].a, [Ref((a=1, b=2)), Ref((a=3, b=4))]) == ([(x = (a = 1, b = nothing),), nothing],)
@test gradient(x -> x[1].a, [(a=1, b=2), "three"]) == ([(a = 1, b = nothing), nothing],)

@testset "indexing kwargs" begin
inner_lit_index(; kwargs...) = kwargs[:x]
outer_lit_index(; kwargs...) = inner_lit_index(; x=kwargs[:x])

inner_dyn_index(k; kwargs...) = kwargs[k]
outer_dyn_index(k; kwargs...) = inner_dyn_index(k; x=kwargs[k])

@test gradient(x -> outer_lit_index(; x), 0.0) == (1.0,)
@test gradient((x, k) -> outer_dyn_index(k; x), 0.0, :x) == (1.0, nothing)
end
end

function type_test()
Expand All @@ -562,7 +573,7 @@ end

@testset "Pairs" begin
@test (x->10*pairs((a=x, b=2))[1])'(100) === 10.0
@test (x->10*pairs((a=x, b=2))[2])'(100) === 0
@test (x->10*pairs((a=x, b=2))[2])'(100) === nothing
foo(;kw...) = 1
@test gradient(() -> foo(a=1,b=2.0)) === ()

Expand All @@ -578,8 +589,8 @@ end
@testset "kwarg splatting, pass in object" begin
g(; kwargs...) = kwargs[:x] * kwargs[:z]
h(somedata) = g(; somedata...)
@test gradient(h, (; x=3.0, y=4.0, z=2.3)) == ((x = 2.3, y = 0.0, z = 3.0),)
@test gradient(h, Dict(:x=>3.0, :y=>4.0, :z=>2.3)) == ((y = 0.0, z = 3.0, x = 2.3),)
@test gradient(h, (; x=3.0, y=4.0, z=2.3)) == ((x = 2.3, y = nothing, z = 3.0),)
@test gradient(h, Dict(:x=>3.0, :y=>4.0, :z=>2.3)) == ((y = nothing, z = 3.0, x = 2.3),)
end

@testset "Iterators" begin
Expand Down

0 comments on commit 24a6111

Please sign in to comment.