Nested AD failure with logσ after https://github.com/JuliaDiff/ChainRules.jl/pull/644 #432
The CR rule for Edit, MWE is this:
julia> using Zygote, NNlib, ChainRules, Test
julia> ENV["JULIA_DEBUG"] = ChainRules;
julia> x = [-0.9, -0.2, 0.1, 0.3, 1.2];
julia> H1 = Zygote.hessian_dual(x -> sum(abs2, relu.(x .+ 0.1)), x);
julia> @test H1 ≈ Zygote.hessian_reverse(x -> sum(abs2, relu.(x .+ 0.1)), x) # ok
Test Passed
julia> H2 = Zygote.hessian_dual(x -> sum(abs2, logsigmoid.(x .+ 0.1)), x);
julia> @test H2 ≈ Zygote.hessian_reverse(x -> sum(abs2, logsigmoid.(x .+ 0.1)), x)
┌ Debug: broadcasting: minus 1
└ @ ChainRules ~/.julia/packages/ChainRules/iTLxh/src/rulesets/Base/broadcast.jl:179
Error During Test at REPL[35]:1
Test threw exception
Expression: H2 ≈ Zygote.hessian_reverse((x->begin
sum(abs2, logsigmoid.(x .+ 0.1))
end), x)
MethodError: no method matching unbroadcast(::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(-), Tuple{Vector{Float64}}}, ::Vector{Float64})
Closest candidates are:
unbroadcast(::Number, ::Any) at ~/.julia/packages/Zygote/H6vD3/src/lib/broadcast.jl:57
unbroadcast(::Base.RefValue, ::Any) at ~/.julia/packages/Zygote/H6vD3/src/lib/broadcast.jl:59
unbroadcast(::Tuple{Any}, ::Any) at ~/.julia/packages/Zygote/H6vD3/src/lib/broadcast.jl:58
...
Stacktrace:
[1] map(f::typeof(Zygote.unbroadcast), t::Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(-), Tuple{Vector{Float64}}}}, s::Tuple{Vector{Float64}})
@ Base ./tuple.jl:246
[2] (::Zygote.var"#∇broadcasted#1106"{Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(-), Tuple{Vector{Float64}}}}, Vector{Tuple{Float64, Zygote.ZBack{NNlib.var"#sigmoid_fast_pullback#142"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}}, Val{2}})(ȳ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/lib/broadcast.jl:198
[3] (::Zygote.var"#4012#back#1109"{Zygote.var"#∇broadcasted#1106"{Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(-), Tuple{Vector{Float64}}}}, Vector{Tuple{Float64, Zygote.ZBack{NNlib.var"#sigmoid_fast_pullback#142"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}}, Val{2}}})(Δ::Vector{Float64})
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[4] (::Zygote.var"#212#213"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#4012#back#1109"{Zygote.var"#∇broadcasted#1106"{Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(-), Tuple{Vector{Float64}}}}, Vector{Tuple{Float64, Zygote.ZBack{NNlib.var"#sigmoid_fast_pullback#142"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}}, Val{2}}}})(Δ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/lib/lib.jl:203
[5] (::Zygote.var"#1750#back#214"{Zygote.var"#212#213"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#4012#back#1109"{Zygote.var"#∇broadcasted#1106"{Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(-), Tuple{Vector{Float64}}}}, Vector{Tuple{Float64, Zygote.ZBack{NNlib.var"#sigmoid_fast_pullback#142"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}}, Val{2}}}}})(Δ::Vector{Float64})
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[6] Pullback
@ ./broadcast.jl:1298 [inlined]
[7] Pullback
@ ~/.julia/packages/NNlib/TAcqa/src/activations.jl:886 [inlined]
[8] (::typeof(∂(λ)))(Δ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
[9] Pullback
@ ~/.julia/packages/ChainRulesCore/ctmSK/src/tangent_types/thunks.jl:199 [inlined]
[10] Pullback
@ ~/.julia/packages/ChainRulesCore/ctmSK/src/tangent_types/thunks.jl:232 [inlined]
[11] (::typeof(∂(unthunk)))(Δ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
[12] Pullback
@ ~/.julia/packages/Zygote/H6vD3/src/compiler/chainrules.jl:104 [inlined]
[13] (::Zygote.var"#561#566")(::Tuple{Vector{Float64}, typeof(∂(wrap_chainrules_output))}, δ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/lib/array.jl:202
[14] map
@ ./tuple.jl:247 [inlined]
[15] map
@ ./tuple.jl:250 [inlined]
[16] (::Zygote.var"#map_back#565"{typeof(Zygote.wrap_chainrules_output), 1, Tuple{Tuple{ChainRulesCore.NoTangent, ChainRulesCore.NoTangent, ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{NNlib.var"#48#51"{Vector{Float64}, Vector{Float64}}}, NNlib.var"#47#50"{Vector{Float64}, Vector{Float64}}}}}, Tuple{Val{3}}, Tuple{Tuple{Nothing, typeof(∂(wrap_chainrules_output))}, Tuple{Nothing, typeof(∂(wrap_chainrules_output))}, Tuple{Vector{Float64}, typeof(∂(wrap_chainrules_output))}}})(Δ::Tuple{Nothing, Nothing, Vector{Float64}})
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/lib/array.jl:202
[17] (::Zygote.var"#2593#back#569"{Zygote.var"#map_back#565"{typeof(Zygote.wrap_chainrules_output), 1, Tuple{Tuple{ChainRulesCore.NoTangent, ChainRulesCore.NoTangent, ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{NNlib.var"#48#51"{Vector{Float64}, Vector{Float64}}}, NNlib.var"#47#50"{Vector{Float64}, Vector{Float64}}}}}, Tuple{Val{3}}, Tuple{Tuple{Nothing, typeof(∂(wrap_chainrules_output))}, Tuple{Nothing, typeof(∂(wrap_chainrules_output))}, Tuple{Vector{Float64}, typeof(∂(wrap_chainrules_output))}}}})(Δ::Tuple{Nothing, Nothing, Vector{Float64}})
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[18] Pullback
@ ~/.julia/packages/Zygote/H6vD3/src/compiler/chainrules.jl:105 [inlined]
[19] Pullback
@ ~/.julia/packages/Zygote/H6vD3/src/compiler/chainrules.jl:204 [inlined]
[20] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Nothing, Vector{Float64}})
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
[21] Pullback
@ ./none:0 [inlined]
[22] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Vector{Float64}})
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
[23] Pullback
@ ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:41 [inlined]
[24] (::typeof(∂(λ)))(Δ::Tuple{Vector{Float64}})
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
[25] Pullback
@ ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:76 [inlined]
[26] (::typeof(∂(gradient)))(Δ::Tuple{Vector{Float64}})
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
[27] Pullback
@ ~/.julia/packages/Zygote/H6vD3/src/lib/grad.jl:87 [inlined]
[28] (::typeof(∂(#100)))(Δ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
[29] #212
@ ~/.julia/packages/Zygote/H6vD3/src/lib/lib.jl:203 [inlined]
[30] (::Zygote.var"#1750#back#214"{Zygote.var"#212#213"{Tuple{Tuple{Nothing}}, typeof(∂(#100))}})(Δ::Vector{Float64})
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[31] Pullback
@ ./operators.jl:1030 [inlined]
[32] (::typeof(∂(#_#95)))(Δ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
[33] (::Zygote.var"#212#213"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing}}, typeof(∂(#_#95))})(Δ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/lib/lib.jl:203
[34] #1750#back
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
[35] Pullback
@ ./operators.jl:1030 [inlined]
[36] (::typeof(∂(ComposedFunction{typeof(Zygote._jvec), Zygote.var"#100#101"{var"#25#26"}}(Zygote._jvec, Zygote.var"#100#101"{var"#25#26"}(var"#25#26"())))))(Δ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
[37] (::Zygote.var"#56#57"{typeof(∂(ComposedFunction{typeof(Zygote._jvec), Zygote.var"#100#101"{var"#25#26"}}(Zygote._jvec, Zygote.var"#100#101"{var"#25#26"}(var"#25#26"()))))})(Δ::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:41
[38] withjacobian(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/lib/grad.jl:162
[39] jacobian
@ ~/.julia/packages/Zygote/H6vD3/src/lib/grad.jl:140 [inlined]
[40] hessian_reverse(f::Function, x::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/lib/grad.jl:87
[41] top-level scope
@ /Applications/Julia-1.8.app/Contents/Resources/julia/share/julia/stdlib/v1.8/Test/src/Test.jl:464
[42] eval
@ ./boot.jl:368 [inlined]
[43] eval_user_input(ast::Any, backend::REPL.REPLBackend)
@ REPL /Applications/Julia-1.8.app/Contents/Resources/julia/share/julia/stdlib/v1.8/REPL/src/REPL.jl:151
[44] repl_backend_loop(backend::REPL.REPLBackend)
@ REPL /Applications/Julia-1.8.app/Contents/Resources/julia/share/julia/stdlib/v1.8/REPL/src/REPL.jl:247
[45] start_repl_backend(backend::REPL.REPLBackend, consumer::Any)
@ REPL /Applications/Julia-1.8.app/Contents/Resources/julia/share/julia/stdlib/v1.8/REPL/src/REPL.jl:232
[46] run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool)
@ REPL /Applications/Julia-1.8.app/Contents/Resources/julia/share/julia/stdlib/v1.8/REPL/src/REPL.jl:369
[47] run_repl(repl::REPL.AbstractREPL, consumer::Any)
@ REPL /Applications/Julia-1.8.app/Contents/Resources/julia/share/julia/stdlib/v1.8/REPL/src/REPL.jl:355
[48] (::Base.var"#966#968"{Bool, Bool, Bool})(REPL::Module)
@ Base ./client.jl:419
[49] #invokelatest#2
@ ./essentials.jl:729 [inlined]
[50] invokelatest
@ ./essentials.jl:726 [inlined]
[51] run_main_repl(interactive::Bool, quiet::Bool, banner::Bool, history_file::Bool, color_set::Bool)
@ Base ./client.jl:404
[52] exec_options(opts::Base.JLOptions)
@ Base ./client.jl:318
ERROR: There was an error during testing
A minimal hack to fix it would be to allow Zygote's
julia> Zygote.unbroadcast(x, dx) = ChainRules.unbroadcast(x, dx)
julia> @test H2 ≈ Zygote.hessian_reverse(x -> sum(abs2, logsigmoid.(x .+ 0.1)), x)
┌ Debug: broadcasting: minus 1
└ @ ChainRules ~/.julia/packages/ChainRules/iTLxh/src/rulesets/Base/broadcast.jl:179
Test Passed
But it would really be better to figure out why Zygote is calling the ChainRules rule in the first place. Notice from the debug output that it's the rule for one-arg
Edit': One guess is this
Here's the stacktrace if I insert an error into the
julia> Zygote.hessian_reverse(x -> sum(abs2, logsigmoid.(x .+ 0.1)), x)
┌ Debug: broadcasting: minus 1
└ @ ChainRules ~/.julia/dev/ChainRules/src/rulesets/Base/broadcast.jl:179
ERROR: "rrule for broadcasted minus"
Stacktrace:
[1] rrule(#unused#::typeof(Base.Broadcast.broadcasted), #unused#::typeof(-), x::Vector{Float64})
@ ChainRules ~/.julia/dev/ChainRules/src/rulesets/Base/broadcast.jl:180
[2] rrule(::Zygote.ZygoteRuleConfig{Zygote.Context{false}}, ::Function, ::Function, ::Vector{Float64})
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/ctmSK/src/rules.jl:134
[3] chain_rrule
@ ~/.julia/dev/Zygote/src/compiler/chainrules.jl:218 [inlined]
[4] macro expansion
@ ~/.julia/dev/Zygote/src/compiler/interface2.jl:0 [inlined]
[5] _pullback(::Zygote.Context{false}, ::typeof(Base.Broadcast.broadcasted), ::typeof(-), ::Vector{Float64})
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:9
[6] _pullback
@ ~/.julia/packages/NNlib/0QnJJ/src/activations.jl:885 [inlined]
[7] _pullback(::Zygote.Context{false}, ::NNlib.var"#48#51"{Vector{Float64}, Vector{Float64}})
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
[8] _pullback
@ ~/.julia/packages/ChainRulesCore/ctmSK/src/tangent_types/thunks.jl:199 [inlined]
[9] _pullback
@ ~/.julia/packages/ChainRulesCore/ctmSK/src/tangent_types/thunks.jl:232 [inlined]
[10] _pullback(ctx::Zygote.Context{false}, f::typeof(ChainRulesCore.unthunk), args::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{NNlib.var"#48#51"{Vector{Float64}, Vector{Float64}}}, NNlib.var"#47#50"{Vector{Float64}, Vector{Float64}}})
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
[11] _pullback
@ ~/.julia/dev/Zygote/src/compiler/chainrules.jl:105 [inlined]
[12] (::Zygote.var"#485#489"{Zygote.Context{false}, typeof(Zygote.wrap_chainrules_output)})(args::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{NNlib.var"#48#51"{Vector{Float64}, Vector{Float64}}}, NNlib.var"#47#50"{Vector{Float64}, Vector{Float64}}})
@ Zygote ~/.julia/dev/Zygote/src/lib/array.jl:193
[13] map
@ ./tuple.jl:275 [inlined]
[14] ∇map(cx::Zygote.Context{false}, f::typeof(Zygote.wrap_chainrules_output), args::Tuple{ChainRulesCore.NoTangent, ChainRulesCore.NoTangent, ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{NNlib.var"#48#51"{Vector{Float64}, Vector{Float64}}}, NNlib.var"#47#50"{Vector{Float64}, Vector{Float64}}}})
@ Zygote ~/.julia/dev/Zygote/src/lib/array.jl:193
[15] adjoint
@ ~/.julia/dev/Zygote/src/lib/array.jl:219 [inlined]
[16] _pullback
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
[17] _pullback
@ ~/.julia/dev/Zygote/src/compiler/chainrules.jl:106 [inlined]
[18] _pullback
@ ~/.julia/dev/Zygote/src/compiler/chainrules.jl:206 [inlined]
[19] _pullback
@ ./REPL[10]:1 [inlined]
[20] _pullback(ctx::Zygote.Context{false}, f::typeof(∂(#11)), args::Float64)
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
[21] _pullback
@ ~/.julia/dev/Zygote/src/compiler/interface.jl:45 [inlined]
[22] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#60#61"{typeof(∂(#11))}, args::Float64)
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
[23] _pullback
@ ~/.julia/dev/Zygote/src/compiler/interface.jl:97 [inlined]
[24] _pullback(::Zygote.Context{false}, ::typeof(gradient), ::var"#11#12", ::Vector{Float64})
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
[25] _pullback
@ ~/.julia/dev/Zygote/src/lib/grad.jl:75 [inlined]
[26] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#106#107"{var"#11#12"}, args::Vector{Float64})
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
[27] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:831
[28] adjoint
@ ~/.julia/dev/Zygote/src/lib/lib.jl:203 [inlined]
[29] _pullback
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
[30] _pullback
@ ./operators.jl:1020 [inlined]
[31] _pullback
@ ./operators.jl:1019 [inlined]
[32] _pullback
@ ./operators.jl:1016 [inlined]
[33] _pullback(::Zygote.Context{false}, ::Base.var"##_#95", ::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::ComposedFunction{typeof(Zygote._jvec), Zygote.var"#106#107"{var"#11#12"}}, ::Vector{Float64})
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
[34] _apply(::Function, ::Vararg{Any})
@ Core ./boot.jl:831
[35] adjoint
@ ~/.julia/dev/Zygote/src/lib/lib.jl:203 [inlined]
[36] _pullback
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
[37] _pullback
@ ./operators.jl:1016 [inlined]
[38] _pullback(ctx::Zygote.Context{false}, f::ComposedFunction{typeof(Zygote._jvec), Zygote.var"#106#107"{var"#11#12"}}, args::Vector{Float64})
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
[39] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
@ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:44
[40] pullback
@ ~/.julia/dev/Zygote/src/compiler/interface.jl:42 [inlined]
[41] withjacobian(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/dev/Zygote/src/lib/grad.jl:141
[42] jacobian
@ ~/.julia/dev/Zygote/src/lib/grad.jl:128 [inlined]
[43] hessian_reverse(f::Function, x::Vector{Float64})
@ Zygote ~/.julia/dev/Zygote/src/lib/grad.jl:75
[44] top-level scope
@ REPL[10]:1
I think hitting the
Ah right, sorry, I assumed there was a rule, but only binary
So the simple solution is to define such a rule in Zygote, and perhaps to audit the other fused rules for any others with no Zygote equivalent:
Yes, probably those should all allow for Broadcasted.
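For illustration only, here is a rough sketch of what "define such a rule in Zygote" could look like for one-argument broadcasted minus. This is an assumption about the shape of the fix, not Zygote's actual source: the name `minus_back` is hypothetical, the signature is restricted to plain numeric arrays, and defining it from user code is type piracy, so it is only suitable for experimenting in a REPL.

```julia
using Zygote

# Hypothetical sketch, not Zygote's actual source: a hand-written adjoint for
# one-argument broadcasted minus on plain numeric arrays. Because it defines a
# method on Base.Broadcast.broadcasted from user code, it is type piracy.
Zygote.@adjoint function Base.Broadcast.broadcasted(::typeof(-), x::AbstractArray{<:Number})
    y = .-x                          # eager forward pass, materialises an Array
    minus_back(Δ) = (nothing, .-Δ)   # no gradient for `-` itself; negate the cotangent
    return y, minus_back
end

x = [-0.9, -0.2, 0.1, 0.3, 1.2]

# First-order sanity check: d/dx sum(abs2, -x) is 2x.
g = Zygote.gradient(x -> sum(abs2, .-x), x)[1]
```

Since the forward pass here returns a materialised array rather than a lazy Broadcasted, nothing downstream should need to call unbroadcast on a Broadcasted object. Whether this alone makes the logsigmoid Hessian test pass has not been checked here.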
Found while creating FluxML/Zygote.jl#1285. Doing some printf debugging near https://github.com/FluxML/Zygote.jl/blob/be5b47fad5fc9c0a3e22f239ec8517df60ffdf4c/src/lib/broadcast.jl#L196:
This indicates
NNlib.jl/src/activations.jl
Line 849 in cb30e19
cc @mcabbott