Nested AD failure with logσ after https://github.com/JuliaDiff/ChainRules.jl/pull/644 #432

Closed
ToucheSir opened this issue Aug 13, 2022 · 3 comments · Fixed by FluxML/Zygote.jl#1287

@ToucheSir
Member

Found while creating FluxML/Zygote.jl#1285. Doing some printf debugging near https://github.com/FluxML/Zygote.jl/blob/be5b47fad5fc9c0a3e22f239ec8517df60ffdf4c/src/lib/broadcast.jl#L196:

summary(y∂b) = "5-element Vector{Tuple{Float64, Zygote.ZBack{NNlib.var\"#sigmoid_fast_pullback#142\"{Float64, ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}}"
args = (Base.Broadcast.Broadcasted(-, ([-0.8, -0.1, 0.2, 0.4, 1.3],)),)

This indicates that the entry

(:logσ, :(sigmoid_fast(-x))),

is working, but for some reason the broadcast isn't being materialized before hitting the rule. Could this be something similar to dfdx/Yota.jl#121 (comment)?

cc @mcabbott

@mcabbott
Member

mcabbott commented Aug 13, 2022

The CR rule for .+ is lazier, doesn't materialize, so that things can fuse. But the mystery is why that should be called here, instead of Zygote's own rules.
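
For readers less familiar with lazy broadcasting, here is a minimal sketch of what "doesn't materialize" means. It uses only Base (no Zygote or ChainRules), and the variable names are illustrative rather than taken from either package.

```julia
using Base.Broadcast: Broadcasted, broadcasted, materialize

x = [-0.9, -0.2, 0.1, 0.3, 1.2]

# `broadcasted` only records the operation; no result array exists yet.
lazy = broadcasted(+, x, 0.1)
@assert lazy isa Broadcasted

# A second call nests the first one, which is what lets both operations
# run as a single fused loop later on.
fused = broadcasted(abs2, lazy)
@assert fused.args[1] === lazy

# `materialize` runs the fused loop and allocates the result.
y = materialize(fused)
@assert y isa Vector{Float64}
@assert y ≈ abs2.(x .+ 0.1)
```

A rule that returns the lazy object instead of calling `materialize` hands a `Broadcasted`, not an array, to whatever comes next.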

Edit, MWE is this:

julia> using Zygote, NNlib, ChainRules, Test

julia> ENV["JULIA_DEBUG"] = ChainRules;

julia> x = [-0.9, -0.2, 0.1, 0.3, 1.2];

julia> H1 = Zygote.hessian_dual(x -> sum(abs2, relu.(x .+ 0.1)), x);

julia> @test H1 ≈ Zygote.hessian_reverse(x -> sum(abs2, relu.(x .+ 0.1)), x)  # ok
Test Passed

julia> H2 = Zygote.hessian_dual(x -> sum(abs2, logsigmoid.(x .+ 0.1)), x);

julia> @test H2 ≈ Zygote.hessian_reverse(x -> sum(abs2, logsigmoid.(x .+ 0.1)), x)
┌ Debug: broadcasting: minus 1
└ @ ChainRules ~/.julia/packages/ChainRules/iTLxh/src/rulesets/Base/broadcast.jl:179
Error During Test at REPL[35]:1
  Test threw exception
  Expression: H2 ≈ Zygote.hessian_reverse((x->begin
                sum(abs2, logsigmoid.(x .+ 0.1))
            end), x)
  MethodError: no method matching unbroadcast(::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(-), Tuple{Vector{Float64}}}, ::Vector{Float64})
  Closest candidates are:
    unbroadcast(::Number, ::Any) at ~/.julia/packages/Zygote/H6vD3/src/lib/broadcast.jl:57
    unbroadcast(::Base.RefValue, ::Any) at ~/.julia/packages/Zygote/H6vD3/src/lib/broadcast.jl:59
    unbroadcast(::Tuple{Any}, ::Any) at ~/.julia/packages/Zygote/H6vD3/src/lib/broadcast.jl:58
    ...
  Stacktrace:
    [1] map(f::typeof(Zygote.unbroadcast), t::Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(-), Tuple{Vector{Float64}}}}, s::Tuple{Vector{Float64}})
      @ Base ./tuple.jl:246
    [2] (::Zygote.var"#∇broadcasted#1106"{Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(-), Tuple{Vector{Float64}}}}, Vector{Tuple{Float64, Zygote.ZBack{NNlib.var"#sigmoid_fast_pullback#142"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}}, Val{2}})(ȳ::Vector{Float64})
      @ Zygote ~/.julia/packages/Zygote/H6vD3/src/lib/broadcast.jl:198
    [3] (::Zygote.var"#4012#back#1109"{Zygote.var"#∇broadcasted#1106"{Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(-), Tuple{Vector{Float64}}}}, Vector{Tuple{Float64, Zygote.ZBack{NNlib.var"#sigmoid_fast_pullback#142"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}}, Val{2}}})(Δ::Vector{Float64})
      @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
    [4] (::Zygote.var"#212#213"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#4012#back#1109"{Zygote.var"#∇broadcasted#1106"{Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(-), Tuple{Vector{Float64}}}}, Vector{Tuple{Float64, Zygote.ZBack{NNlib.var"#sigmoid_fast_pullback#142"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}}, Val{2}}}})(Δ::Vector{Float64})
      @ Zygote ~/.julia/packages/Zygote/H6vD3/src/lib/lib.jl:203
    [5] (::Zygote.var"#1750#back#214"{Zygote.var"#212#213"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{}}, Zygote.var"#4012#back#1109"{Zygote.var"#∇broadcasted#1106"{Tuple{Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(-), Tuple{Vector{Float64}}}}, Vector{Tuple{Float64, Zygote.ZBack{NNlib.var"#sigmoid_fast_pullback#142"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}}, Val{2}}}}})(Δ::Vector{Float64})
      @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
    [6] Pullback
      @ ./broadcast.jl:1298 [inlined]
    [7] Pullback
      @ ~/.julia/packages/NNlib/TAcqa/src/activations.jl:886 [inlined]
    [8] (::typeof(∂(λ)))(Δ::Vector{Float64})
      @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
    [9] Pullback
      @ ~/.julia/packages/ChainRulesCore/ctmSK/src/tangent_types/thunks.jl:199 [inlined]
   [10] Pullback
      @ ~/.julia/packages/ChainRulesCore/ctmSK/src/tangent_types/thunks.jl:232 [inlined]
   [11] (::typeof(∂(unthunk)))(Δ::Vector{Float64})
      @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
   [12] Pullback
      @ ~/.julia/packages/Zygote/H6vD3/src/compiler/chainrules.jl:104 [inlined]
   [13] (::Zygote.var"#561#566")(::Tuple{Vector{Float64}, typeof(∂(wrap_chainrules_output))}, δ::Vector{Float64})
      @ Zygote ~/.julia/packages/Zygote/H6vD3/src/lib/array.jl:202
   [14] map
      @ ./tuple.jl:247 [inlined]
   [15] map
      @ ./tuple.jl:250 [inlined]
   [16] (::Zygote.var"#map_back#565"{typeof(Zygote.wrap_chainrules_output), 1, Tuple{Tuple{ChainRulesCore.NoTangent, ChainRulesCore.NoTangent, ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{NNlib.var"#48#51"{Vector{Float64}, Vector{Float64}}}, NNlib.var"#47#50"{Vector{Float64}, Vector{Float64}}}}}, Tuple{Val{3}}, Tuple{Tuple{Nothing, typeof(∂(wrap_chainrules_output))}, Tuple{Nothing, typeof(∂(wrap_chainrules_output))}, Tuple{Vector{Float64}, typeof(∂(wrap_chainrules_output))}}})(Δ::Tuple{Nothing, Nothing, Vector{Float64}})
      @ Zygote ~/.julia/packages/Zygote/H6vD3/src/lib/array.jl:202
   [17] (::Zygote.var"#2593#back#569"{Zygote.var"#map_back#565"{typeof(Zygote.wrap_chainrules_output), 1, Tuple{Tuple{ChainRulesCore.NoTangent, ChainRulesCore.NoTangent, ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{NNlib.var"#48#51"{Vector{Float64}, Vector{Float64}}}, NNlib.var"#47#50"{Vector{Float64}, Vector{Float64}}}}}, Tuple{Val{3}}, Tuple{Tuple{Nothing, typeof(∂(wrap_chainrules_output))}, Tuple{Nothing, typeof(∂(wrap_chainrules_output))}, Tuple{Vector{Float64}, typeof(∂(wrap_chainrules_output))}}}})(Δ::Tuple{Nothing, Nothing, Vector{Float64}})
      @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
   [18] Pullback
      @ ~/.julia/packages/Zygote/H6vD3/src/compiler/chainrules.jl:105 [inlined]
   [19] Pullback
      @ ~/.julia/packages/Zygote/H6vD3/src/compiler/chainrules.jl:204 [inlined]
   [20] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Nothing, Vector{Float64}})
      @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
   [21] Pullback
      @ ./none:0 [inlined]
   [22] (::typeof(∂(λ)))(Δ::Tuple{Nothing, Vector{Float64}})
      @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
   [23] Pullback
      @ ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:41 [inlined]
   [24] (::typeof(∂(λ)))(Δ::Tuple{Vector{Float64}})
      @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
   [25] Pullback
      @ ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:76 [inlined]
   [26] (::typeof(∂(gradient)))(Δ::Tuple{Vector{Float64}})
      @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
   [27] Pullback
      @ ~/.julia/packages/Zygote/H6vD3/src/lib/grad.jl:87 [inlined]
   [28] (::typeof(∂(#100)))(Δ::Vector{Float64})
      @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
   [29] #212
      @ ~/.julia/packages/Zygote/H6vD3/src/lib/lib.jl:203 [inlined]
   [30] (::Zygote.var"#1750#back#214"{Zygote.var"#212#213"{Tuple{Tuple{Nothing}}, typeof(∂(#100))}})(Δ::Vector{Float64})
      @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
   [31] Pullback
      @ ./operators.jl:1030 [inlined]
   [32] (::typeof(∂(#_#95)))(Δ::Vector{Float64})
      @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
   [33] (::Zygote.var"#212#213"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing}}, typeof(∂(#_#95))})(Δ::Vector{Float64})
      @ Zygote ~/.julia/packages/Zygote/H6vD3/src/lib/lib.jl:203
   [34] #1750#back
      @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
   [35] Pullback
      @ ./operators.jl:1030 [inlined]
   [36] (::typeof(∂(ComposedFunction{typeof(Zygote._jvec), Zygote.var"#100#101"{var"#25#26"}}(Zygote._jvec, Zygote.var"#100#101"{var"#25#26"}(var"#25#26"())))))(Δ::Vector{Float64})
      @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface2.jl:0
   [37] (::Zygote.var"#56#57"{typeof(∂(ComposedFunction{typeof(Zygote._jvec), Zygote.var"#100#101"{var"#25#26"}}(Zygote._jvec, Zygote.var"#100#101"{var"#25#26"}(var"#25#26"()))))})(Δ::Vector{Float64})
      @ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:41
   [38] withjacobian(f::Function, args::Vector{Float64})
      @ Zygote ~/.julia/packages/Zygote/H6vD3/src/lib/grad.jl:162
   [39] jacobian
      @ ~/.julia/packages/Zygote/H6vD3/src/lib/grad.jl:140 [inlined]
   [40] hessian_reverse(f::Function, x::Vector{Float64})
      @ Zygote ~/.julia/packages/Zygote/H6vD3/src/lib/grad.jl:87
   [41] top-level scope
      @ /Applications/Julia-1.8.app/Contents/Resources/julia/share/julia/stdlib/v1.8/Test/src/Test.jl:464
   [42] eval
      @ ./boot.jl:368 [inlined]
   [43] eval_user_input(ast::Any, backend::REPL.REPLBackend)
      @ REPL /Applications/Julia-1.8.app/Contents/Resources/julia/share/julia/stdlib/v1.8/REPL/src/REPL.jl:151
   [44] repl_backend_loop(backend::REPL.REPLBackend)
      @ REPL /Applications/Julia-1.8.app/Contents/Resources/julia/share/julia/stdlib/v1.8/REPL/src/REPL.jl:247
   [45] start_repl_backend(backend::REPL.REPLBackend, consumer::Any)
      @ REPL /Applications/Julia-1.8.app/Contents/Resources/julia/share/julia/stdlib/v1.8/REPL/src/REPL.jl:232
   [46] run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool)
      @ REPL /Applications/Julia-1.8.app/Contents/Resources/julia/share/julia/stdlib/v1.8/REPL/src/REPL.jl:369
   [47] run_repl(repl::REPL.AbstractREPL, consumer::Any)
      @ REPL /Applications/Julia-1.8.app/Contents/Resources/julia/share/julia/stdlib/v1.8/REPL/src/REPL.jl:355
   [48] (::Base.var"#966#968"{Bool, Bool, Bool})(REPL::Module)
      @ Base ./client.jl:419
   [49] #invokelatest#2
      @ ./essentials.jl:729 [inlined]
   [50] invokelatest
      @ ./essentials.jl:726 [inlined]
   [51] run_main_repl(interactive::Bool, quiet::Bool, banner::Bool, history_file::Bool, color_set::Bool)
      @ Base ./client.jl:404
   [52] exec_options(opts::Base.JLOptions)
      @ Base ./client.jl:318
ERROR: There was an error during testing

A minimal hack to fix it would be to allow Zygote's unbroadcast to fall back to the one in ChainRules, which does allow such input:

julia> Zygote.unbroadcast(x, dx) = ChainRules.unbroadcast(x, dx)

julia> @test H2 ≈ Zygote.hessian_reverse(x -> sum(abs2, logsigmoid.(x .+ 0.1)), x)
┌ Debug: broadcasting: minus 1
└ @ ChainRules ~/.julia/packages/ChainRules/iTLxh/src/rulesets/Base/broadcast.jl:179
Test Passed
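
Why an untyped method acts as a fallback rather than an override can be seen in a toy setting. In the sketch below, `ToyZ` and `ToyCR` are hypothetical modules standing in for the two packages; their method bodies are invented for illustration and are not the real `unbroadcast` implementations.

```julia
# Toy modules; neither is Zygote's or ChainRules' real code.
module ToyZ
    # Specific methods, analogous to the "closest candidates" in the error.
    unbroadcast(x::Number, dx) = sum(dx)
    unbroadcast(x::Tuple{Any}, dx) = (sum(dx),)
end

module ToyCR
    # A permissive method that accepts any input.
    unbroadcast(x, dx) = dx
end

# Same shape as the hack above: an untyped method that forwards elsewhere.
ToyZ.unbroadcast(x, dx) = ToyCR.unbroadcast(x, dx)

# The specific method still handles numbers ...
@assert ToyZ.unbroadcast(2.0, [1.0, 2.0]) == 3.0

# ... while an input it never covered now reaches the fallback
# instead of raising a MethodError.
lazy = Base.Broadcast.broadcasted(-, [1.0, 2.0])
@assert ToyZ.unbroadcast(lazy, [5.0, 6.0]) == [5.0, 6.0]
```

Because dispatch picks the most specific applicable method, the untyped definition is only reached for inputs that none of the existing methods accept.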

But really better would be to figure out why Zygote is calling the ChainRules rule in the first place. Notice from the debug output that it's the rule for one-arg -, presumably from the minus inside the definition of logsigmoid's own broadcasting rule.

Edit': One guess is this rrule_via_ad shortcut, which I think doesn't favour the @adjoint over an rrule the way Zygote does elsewhere. But this seems not to be the issue here:
https://github.com/FluxML/Zygote.jl/blob/99d5a38b14dc842643acfa624b6f0f89061efbbf/src/compiler/chainrules.jl#L243-L246

Here's the stacktrace if I insert an error into the rrule being called unexpectedly:

julia> Zygote.hessian_reverse(x -> sum(abs2, logsigmoid.(x .+ 0.1)), x)
┌ Debug: broadcasting: minus 1
└ @ ChainRules ~/.julia/dev/ChainRules/src/rulesets/Base/broadcast.jl:179
ERROR: "rrule for broadcasted minus"
Stacktrace:
  [1] rrule(#unused#::typeof(Base.Broadcast.broadcasted), #unused#::typeof(-), x::Vector{Float64})
    @ ChainRules ~/.julia/dev/ChainRules/src/rulesets/Base/broadcast.jl:180
  [2] rrule(::Zygote.ZygoteRuleConfig{Zygote.Context{false}}, ::Function, ::Function, ::Vector{Float64})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/ctmSK/src/rules.jl:134
  [3] chain_rrule
    @ ~/.julia/dev/Zygote/src/compiler/chainrules.jl:218 [inlined]
  [4] macro expansion
    @ ~/.julia/dev/Zygote/src/compiler/interface2.jl:0 [inlined]
  [5] _pullback(::Zygote.Context{false}, ::typeof(Base.Broadcast.broadcasted), ::typeof(-), ::Vector{Float64})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:9
  [6] _pullback
    @ ~/.julia/packages/NNlib/0QnJJ/src/activations.jl:885 [inlined]
  [7] _pullback(::Zygote.Context{false}, ::NNlib.var"#48#51"{Vector{Float64}, Vector{Float64}})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
  [8] _pullback
    @ ~/.julia/packages/ChainRulesCore/ctmSK/src/tangent_types/thunks.jl:199 [inlined]
  [9] _pullback
    @ ~/.julia/packages/ChainRulesCore/ctmSK/src/tangent_types/thunks.jl:232 [inlined]
 [10] _pullback(ctx::Zygote.Context{false}, f::typeof(ChainRulesCore.unthunk), args::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{NNlib.var"#48#51"{Vector{Float64}, Vector{Float64}}}, NNlib.var"#47#50"{Vector{Float64}, Vector{Float64}}})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [11] _pullback
    @ ~/.julia/dev/Zygote/src/compiler/chainrules.jl:105 [inlined]
 [12] (::Zygote.var"#485#489"{Zygote.Context{false}, typeof(Zygote.wrap_chainrules_output)})(args::ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{NNlib.var"#48#51"{Vector{Float64}, Vector{Float64}}}, NNlib.var"#47#50"{Vector{Float64}, Vector{Float64}}})
    @ Zygote ~/.julia/dev/Zygote/src/lib/array.jl:193
 [13] map
    @ ./tuple.jl:275 [inlined]
 [14] ∇map(cx::Zygote.Context{false}, f::typeof(Zygote.wrap_chainrules_output), args::Tuple{ChainRulesCore.NoTangent, ChainRulesCore.NoTangent, ChainRulesCore.InplaceableThunk{ChainRulesCore.Thunk{NNlib.var"#48#51"{Vector{Float64}, Vector{Float64}}}, NNlib.var"#47#50"{Vector{Float64}, Vector{Float64}}}})
    @ Zygote ~/.julia/dev/Zygote/src/lib/array.jl:193
 [15] adjoint
    @ ~/.julia/dev/Zygote/src/lib/array.jl:219 [inlined]
 [16] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [17] _pullback
    @ ~/.julia/dev/Zygote/src/compiler/chainrules.jl:106 [inlined]
 [18] _pullback
    @ ~/.julia/dev/Zygote/src/compiler/chainrules.jl:206 [inlined]
 [19] _pullback
    @ ./REPL[10]:1 [inlined]
 [20] _pullback(ctx::Zygote.Context{false}, f::typeof(∂(#11)), args::Float64)
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [21] _pullback
    @ ~/.julia/dev/Zygote/src/compiler/interface.jl:45 [inlined]
 [22] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#60#61"{typeof(∂(#11))}, args::Float64)
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [23] _pullback
    @ ~/.julia/dev/Zygote/src/compiler/interface.jl:97 [inlined]
 [24] _pullback(::Zygote.Context{false}, ::typeof(gradient), ::var"#11#12", ::Vector{Float64})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [25] _pullback
    @ ~/.julia/dev/Zygote/src/lib/grad.jl:75 [inlined]
 [26] _pullback(ctx::Zygote.Context{false}, f::Zygote.var"#106#107"{var"#11#12"}, args::Vector{Float64})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [27] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:831
 [28] adjoint
    @ ~/.julia/dev/Zygote/src/lib/lib.jl:203 [inlined]
 [29] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [30] _pullback
    @ ./operators.jl:1020 [inlined]
 [31] _pullback
    @ ./operators.jl:1019 [inlined]
 [32] _pullback
    @ ./operators.jl:1016 [inlined]
 [33] _pullback(::Zygote.Context{false}, ::Base.var"##_#95", ::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::ComposedFunction{typeof(Zygote._jvec), Zygote.var"#106#107"{var"#11#12"}}, ::Vector{Float64})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [34] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:831
 [35] adjoint
    @ ~/.julia/dev/Zygote/src/lib/lib.jl:203 [inlined]
 [36] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [37] _pullback
    @ ./operators.jl:1016 [inlined]
 [38] _pullback(ctx::Zygote.Context{false}, f::ComposedFunction{typeof(Zygote._jvec), Zygote.var"#106#107"{var"#11#12"}}, args::Vector{Float64})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface2.jl:0
 [39] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
    @ Zygote ~/.julia/dev/Zygote/src/compiler/interface.jl:44
 [40] pullback
    @ ~/.julia/dev/Zygote/src/compiler/interface.jl:42 [inlined]
 [41] withjacobian(f::Function, args::Vector{Float64})
    @ Zygote ~/.julia/dev/Zygote/src/lib/grad.jl:141
 [42] jacobian
    @ ~/.julia/dev/Zygote/src/lib/grad.jl:128 [inlined]
 [43] hessian_reverse(f::Function, x::Vector{Float64})
    @ Zygote ~/.julia/dev/Zygote/src/lib/grad.jl:75
 [44] top-level scope
    @ REPL[10]:1

@ToucheSir
Member Author

ToucheSir commented Aug 14, 2022

I think hitting the rrule first for unary - makes sense, since the call chain goes broadcasted(-, arg) -> broadcasted(style, -, arg) and Zygote only has a rule for the latter in this case. What confused me at first was why the subsequent broadcast(sigmoid_fast, ...) is hitting Zygote, but that also makes sense because the first broadcast returns a Broadcasted and https://github.com/FluxML/NNlib.jl/blob/v0.8.9/src/activations.jl#L880 only dispatches on arrays. Should we change that?
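
The dispatch point can be reproduced without any AD package. In this sketch `array_only_rule` and `wide_rule` are hypothetical stand-ins for a rule whose signature accepts only arrays and for a widened version of it; they are not NNlib's code.

```julia
using Base.Broadcast: Broadcasted, broadcasted, materialize

# Hypothetical stand-in for a rule that is only defined for arrays.
array_only_rule(x::AbstractArray) = :array_rule
# Fallback representing "no specialised rule matched".
array_only_rule(x) = :generic_path

x = [-0.8, -0.1, 0.2, 0.4, 1.3]

# Unary minus left unmaterialized, as in the `args` printed earlier in this thread.
lazy = broadcasted(-, x)
@assert lazy isa Broadcasted
@assert !(lazy isa AbstractArray)

# The lazy object misses the array-only method; the materialized array hits it.
@assert array_only_rule(lazy) == :generic_path
@assert array_only_rule(materialize(lazy)) == :array_rule

# Widening the signature lets the lazy object match as well.
wide_rule(x::Union{AbstractArray, Broadcasted}) = :array_rule
@assert wide_rule(lazy) == :array_rule
```

Whether a real rule should accept `Broadcasted` directly or materialize first is a separate design question; the sketch only shows why the array-only signature is skipped.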

@mcabbott
Member

mcabbott commented Aug 15, 2022

Ah right, sorry I assumed there was a rule, but only binary - has one:
https://github.com/FluxML/Zygote.jl/blob/master/src/lib/broadcast.jl#L81-L82

So the simple solution is to define such a rule in Zygote. And perhaps audit the other fused rules to look for others with no Zygote equivalent:
https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/Base/broadcast.jl#L149-L231

> NNlib.jl/.../src/activations.jl#L880 only dispatches on arrays. Should we change that?

Yes, probably those should all allow for Broadcasted.
