diff --git a/src/Flux.jl b/src/Flux.jl
index 0cacbd419a..c07e4e9e61 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -9,7 +9,9 @@ using MacroTools: @forward
 using MLUtils
 import Optimisers: Optimisers, trainable, destructure # before v0.13, Flux owned these functions
 
-using Zygote, ChainRulesCore
+using ChainRulesCore
+import ChainRulesCore: rrule, RuleConfig, HasReverseMode
+using Zygote
 using Zygote: Params, @adjoint, gradient, pullback, @nograd
 export gradient
 
diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index 961f653f68..f010676ead 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -21,7 +21,7 @@ julia> x = rand(10, 32);
 julia> m(x) == m[2](m[1](x))
 true
 
-julia> m2 = Chain(enc = Chain(Flux.flatten, Dense(10 => 5, tanh)), 
+julia> m2 = Chain(enc = Chain(Flux.flatten, Dense(10 => 5, tanh)),
                   dec = Dense(5 => 2));
 
 julia> m2(x) == (m2[:dec] ∘ m2[:enc])(x)
@@ -65,6 +65,47 @@ function applychain(layers::AbstractVector, x) # type-unstable path, helps comp
   x
 end
 
+_push!!(xs::Vector{T}, x::T) where T = push!(xs, x)
+_push!!(@nospecialize(xs::Vector), @nospecialize(x)) = vcat(xs, [x])
+
+_unpack_2tup(::Type{Tuple{A, B}}) where {A, B} = (A, B)
+
+function _pullback_stack_eltype(config::C, ::Type{F}, x::X) where {C, F, X}
+  Y1, P1 = _unpack_2tup(Core.Compiler.return_type(rrule_via_ad, Tuple{C, F, X}))
+  Y2, P2 = _unpack_2tup(Core.Compiler.return_type(rrule_via_ad, Tuple{C, F, Y1}))
+  return promote_type(Y1, Y2), promote_type(P1, P2)
+end
+
+struct VectorChainPullback{P}
+  pullbacks::P
+end
+function (pb::VectorChainPullback)(dy)
+  # Run the stored pullbacks last-to-first, collecting one tangent per layer.
+  dlayers = Union{}[]
+  dx = dy
+  # `_push!!` widens `dlayers` as needed, so tangents of differing types are fine.
+  for back in reverse(pb.pullbacks)
+    dlayer, dx = back(dx)
+    # `dx` now holds the gradient with respect to this layer's input.
+    dlayers = _push!!(dlayers, dlayer)
+  end
+  # `dlayers` was filled in reverse layer order; flip it to match `layers`.
+  return (NoTangent(), reverse!(dlayers), dx)
+end
+
+function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(applychain), layers::AbstractVector, x)
+  YT, ST = _pullback_stack_eltype(config, eltype(layers), x)
+  pullbacks = ST[]
+  y::YT = x
+  for l in layers
+    y, pb = rrule_via_ad(config, l, y::YT)
+    # Keep each layer's pullback for the backward pass.
+    pullbacks = _push!!(pullbacks, pb)
+  end
+
+  return y::YT, VectorChainPullback(pullbacks)
+end
+
 Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i])
 Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) =
     Chain(NamedTuple{keys(c)[i]}(Tuple(c.layers)[i]))
@@ -159,7 +200,7 @@ function (a::Dense)(x::AbstractVecOrMat)
   return σ.(a.weight * x .+ a.bias)
 end
 
-(a::Dense)(x::AbstractArray) = 
+(a::Dense)(x::AbstractArray) =
   reshape(a(reshape(x, size(x,1), :)), :, size(x)[2:end]...)
 
 function Base.show(io::IO, l::Dense)
@@ -178,9 +219,9 @@ Create an element-wise layer, whose forward pass is given by:
     y = σ.(scale .* x .+ bias)
 
 This uses `.*` instead of matrix multiplication `*` of [`Dense`](@ref).
-  
+
 The learnable scale & bias are initialised `init(size...)` and `zeros32(size...)`,
-with `init=ones32` by default. You may specify the function `init`, 
+with `init=ones32` by default. You may specify the function `init`,
 turn off trainable bias with `bias=false`, or provide the array(s) explicitly.
 
 Used by [`LayerNorm`](@ref) with `affine=true`.
@@ -248,7 +289,7 @@ Instead of defining layers individually, you can provide a zero-argument functio
 which constructs them, and the number to construct.
 
 Maxout over linear dense layers satisfies the universal approximation theorem.
-See Goodfellow, Warde-Farley, Mirza, Courville & Bengio "Maxout Networks" 
+See Goodfellow, Warde-Farley, Mirza, Courville & Bengio "Maxout Networks"
 [1302.4389](https://arxiv.org/abs/1302.4389).
 
 See also [`Parallel`](@ref) to reduce with other operators.
@@ -501,7 +542,7 @@ end
 
 function _parallel_check(layers, xs)
   nl = length(layers)
-  nx = length(xs) 
+  nx = length(xs)
   if (nl != nx)
     throw(ArgumentError("Parallel with $nl sub-layers can take one input or $nl inputs, but got $nx inputs"))
   end
@@ -531,14 +572,14 @@ end
 
 ## Arguments
 
-- `connection`: A function taking 2 inputs and combining them into a single output 
+- `connection`: A function taking 2 inputs and combining them into a single output
 - `layers`: The layers whose outputs are combined
 
 ## Inputs
 
 This layer behaves differently based on input type:
 
-1. If input `x` is a tuple of length N (or the input is `xs` with N `x`'s), matching the number of `layers`, 
+1. If input `x` is a tuple of length N (or the input is `xs` with N `x`'s), matching the number of `layers`,
   then each layer receives a new input `x[i]` combined with the previous output `y[i-1]` using `connection`.
   Thus `(y1, y2, y3) = PairwiseFusion(connection, layer1, layer2, layer3)((x1, x2, x3))` may be drawn as:
 
@@ -633,12 +674,12 @@ end
 """
     Embedding(in => out; init=randn)
 
-A lookup table that stores embeddings of dimension `out` 
+A lookup table that stores embeddings of dimension `out`
 for a vocabulary of size `in`.
 
-This layer is often used to store word embeddings and retrieve them using indices. 
+This layer is often used to store word embeddings and retrieve them using indices.
 The input to the layer can be either a vector of indexes
-or the corresponding [onehot encoding](@ref Flux.OneHotArray). 
+or the corresponding [onehot encoding](@ref Flux.OneHotArray).
 
 # Examples
 ```jldoctest
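For review, here is a rough sanity check one might run on this branch; it is not part of the diff. The names `mvec`/`mtup`, the layer sizes, and the "expect true" comments are illustrative assumptions, and the exact nested shape of the structural gradient can differ between Zygote versions.

```julia
# Sketch: the Vector-backed Chain should give the same gradients as the
# Tuple-backed Chain, now routed through the new rrule for applychain.
using Flux, Zygote

layers = [Dense(4 => 4, tanh), Dense(4 => 4, tanh), Dense(4 => 4, tanh)]
x = randn(Float32, 4, 8)

mvec = Chain(layers)      # layers kept in a Vector -> applychain(::AbstractVector, x)
mtup = Chain(layers...)   # layers kept in a Tuple  -> existing unrolled path

# Gradient with respect to the input must agree between the two paths.
gradient(x -> sum(mvec(x)), x)[1] ≈ gradient(x -> sum(mtup(x)), x)[1]  # expect true

# Structural gradients should line up with the layers in forward order,
# e.g. the first layer's weight gradient matches across representations.
gvec = gradient(m -> sum(m(x)), mvec)[1]
gtup = gradient(m -> sum(m(x)), mtup)[1]
gvec.layers[1].weight ≈ gtup.layers[1].weight  # expect true
```

The second comparison is what motivates reversing `dlayers` before `VectorChainPullback` returns: the pullbacks run last-to-first, but the tangent vector handed back for `layers` must be in forward order.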