From fb34703a2c57cec8cbdd23be816250dc9de17e91 Mon Sep 17 00:00:00 2001 From: Brian Chen Date: Thu, 25 Aug 2022 21:28:19 -0700 Subject: [PATCH] Handle nothing grads for Pairs.data --- src/lib/base.jl | 2 +- test/features.jl | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/lib/base.jl b/src/lib/base.jl index 21ca62b1c..1a85cc56c 100644 --- a/src/lib/base.jl +++ b/src/lib/base.jl @@ -162,7 +162,7 @@ function _pullback(cx::AContext, ::typeof(literal_getindex), ps::Iterators.Pairs{<:Any,<:Any,<:Any,<:NamedTuple}, ::Val{K}) where K val, gf_back = _pullback(cx, literal_getfield, NamedTuple(ps), Val(K)) function kwargs_literal_getindex_pullback(Δ) - dps = (data = gf_back(Δ)[2], itr = nothing) + dps = (data = gradindex(gf_back(Δ), 2), itr = nothing) return (nothing, dps, nothing) end return val, kwargs_literal_getindex_pullback diff --git a/test/features.jl b/test/features.jl index 4c16267f2..e3e0e55bd 100644 --- a/test/features.jl +++ b/test/features.jl @@ -591,6 +591,10 @@ end h(somedata) = g(; somedata...) @test gradient(h, (; x=3.0, y=4.0, z=2.3)) == ((x = 2.3, y = nothing, z = 3.0),) @test gradient(h, Dict(:x=>3.0, :y=>4.0, :z=>2.3)) == ((y = nothing, z = 3.0, x = 2.3),) + + # for when no kwargs have grads backpropogated + no_kwarg_grad(x; kwargs...) = x[kwargs[:i]] + @test gradient(x -> no_kwarg_grad(x; i=1), [1]) == (1,) end @testset "Iterators" begin