diff --git a/src/deprecations.jl b/src/deprecations.jl
index dd36e17d42..e6e7360a22 100644
--- a/src/deprecations.jl
+++ b/src/deprecations.jl
@@ -36,3 +36,9 @@ zeros32(::Type, dims...) = throw(ArgumentError("Flux.zeros32 is always Float32,
 
 # v0.13 deprecations
 
+function Broadcast.broadcasted(f::Recur, args...)
+  # This had an explicit @adjoint rule, calling Zygote.∇map(__context__, f, args...), until v0.12
+  Base.depwarn("""Broadcasting is not safe to use with RNNs, as it does not guarantee an iteration order.
+    Re-writing this as a comprehension would be better.""", :broadcasted)
+  map(f, args...)  # map isn't really safe either, but it matches the old ∇map rule
+end
diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl
index 5cc1108d90..14e3b8801e 100644
--- a/src/layers/recurrent.jl
+++ b/src/layers/recurrent.jl
@@ -435,8 +435,3 @@ julia> g(rand(Float32, 3, 10)) |> size # batch size of 10
 """
 GRUv3(a...; ka...) = Recur(GRUv3Cell(a...; ka...))
 Recur(m::GRUv3Cell) = Recur(m, m.state0)
-
-# TODO move to ChainRulesCore?
-@adjoint function Broadcast.broadcasted(f::Recur, args...)
-  Zygote.∇map(__context__, f, args...)
-end
diff --git a/src/losses/ctc.jl b/src/losses/ctc.jl
index 32833934b6..ed0a06101e 100644
--- a/src/losses/ctc.jl
+++ b/src/losses/ctc.jl
@@ -134,8 +134,9 @@ for mathematical details.
 """
 ctc_loss(ŷ::AbstractArray, y) = ctc_alpha(ŷ, y).loss
 
 function ChainRulesCore.rrule(::typeof(ctc_loss), ŷ, y)
-  ctc_loss_pullback(Δ) = (NoTangent(), Δ .* ∇ctc_loss(ŷ, y, out), NoTangent())
-  return ctc_loss(ŷ, y), ctc_loss_pullback
+  tmp = ctc_alpha(ŷ, y)  # run the forward pass once, reuse it for both the loss and the pullback
+  ctc_loss_pullback(Δ) = (NoTangent(), Δ .* ∇ctc_loss(ŷ, y, tmp), NoTangent())
+  return tmp.loss, ctc_loss_pullback
 end
diff --git a/src/losses/utils.jl b/src/losses/utils.jl
index 48e49a2923..e13a3e6206 100644
--- a/src/losses/utils.jl
+++ b/src/losses/utils.jl
@@ -23,7 +23,7 @@ end
   res, Δ -> (nothing, Zygote.unbroadcast(x, xlogy.(Δ, y)), Zygote.unbroadcast(y, Δ .* x ./ y))
 end
 
-ChainRulesCore.@scalar_rule xlogy(x, y) (log(y), x/y) # is this good enough?
+ChainRulesCore.@scalar_rule xlogy(x, y) (log(y), x/y) # should help Diffractor's broadcasting
 ChainRulesCore.@scalar_rule xlogx(x) (log(x) + true)
 
 # This can be made an error in Flux v0.13, for now just a warning
diff --git a/src/utils.jl b/src/utils.jl
index 13e3caef42..4478feda26 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -793,8 +793,10 @@ true
 """
 modules(m) = [x for x in Functors.fcollect(m) if !isleaflike(x)]
 
-@nograd modules
-ChainRulesCore.@non_differentiable modules(::Any) # is this correct?
+@nograd modules # TODO: is this correct? might fail with explicit parameters.
+function ChainRulesCore.rrule(::typeof(modules), m)
+  modules(m), dm -> error("Flux.modules is not at present differentiable, sorry")
+end
 
 isleaflike(x) = Functors.isleaf(x)
 isleaflike(::Tuple{Vararg{<:Number}}) = true
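
Below the patch, a minimal sketch of the rewrite the new depwarn recommends. Everything here is hypothetical and not part of the diff: a stateful layer `m = RNN(2, 3)` and a four-step input sequence `xs`, chosen only to contrast broadcasting with a comprehension:

    using Flux
    m = RNN(2, 3)                          # Recur-wrapped RNNCell; carries hidden state between calls
    xs = [rand(Float32, 2) for _ in 1:4]   # four timesteps of input
    ys_bad  = m.(xs)                       # broadcasting: iteration order not guaranteed; now warns and falls back to map
    ys_good = [m(x) for x in xs]           # comprehension: explicit left-to-right order over timesteps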