Attempt to write a type stable rrule(Chain{Vector}) #2003

Draft · wants to merge 1 commit into master
4 changes: 3 additions & 1 deletion src/Flux.jl
@@ -9,7 +9,9 @@ using MacroTools: @forward
using MLUtils
import Optimisers: Optimisers, trainable, destructure # before v0.13, Flux owned these functions

using ChainRulesCore
import ChainRulesCore: rrule, RuleConfig, HasReverseMode
using Zygote
using Zygote: Params, @adjoint, gradient, pullback, @nograd
export gradient

63 changes: 52 additions & 11 deletions src/layers/basic.jl
@@ -21,7 +21,7 @@ julia> x = rand(10, 32);
julia> m(x) == m[2](m[1](x))
true

julia> m2 = Chain(enc = Chain(Flux.flatten, Dense(10 => 5, tanh)),
                  dec = Dense(5 => 2));

julia> m2(x) == (m2[:dec] ∘ m2[:enc])(x)
@@ -65,6 +65,47 @@ function applychain(layers::AbstractVector, x)  # type-unstable path, helps compile times
  x
end

_push!!(xs::Vector{T}, x::T) where T = push!(xs, x)
_push!!(@nospecialize(xs::Vector), @nospecialize(x)) = vcat(xs, [x])
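
`_push!!` pushes in place when the element type already matches, and otherwise rebuilds a wider vector via `vcat`, whose eltype promotion does the widening. A minimal sketch of that behaviour (illustration only, not part of the diff):

```julia
xs = _push!!(Union{}[], 1)   # no Int is a Union{}, so the vcat method widens to Vector{Int64}
xs = _push!!(xs, 2)          # exact eltype match: in-place push!
xs = _push!!(xs, 3.0)        # mismatch again: vcat promotes to Vector{Float64}
```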

_unpack_2tup(::Type{Tuple{A, B}}) where {A, B} = (A, B)

function _pullback_stack_eltype(config::C, ::Type{F}, x::X) where {C, F, X}
  Y1, P1 = _unpack_2tup(Core.Compiler.return_type(rrule_via_ad, Tuple{C, F, X}))
  Y2, P2 = _unpack_2tup(Core.Compiler.return_type(rrule_via_ad, Tuple{C, F, Y1}))
  return promote_type(Y1, Y2), promote_type(P1, P2)
end
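
`_pullback_stack_eltype` asks type inference, via `Core.Compiler.return_type`, what `(output, pullback)` pair `rrule_via_ad` would return for a layer applied to the input, repeats the question for a layer applied to that first output, and promotes the two answers. A hedged REPL probe of one inference step, assuming Zygote's `ZygoteRuleConfig` (illustration only, not part of the diff):

```julia
using Flux, Zygote
import ChainRulesCore: rrule_via_ad

config = Zygote.ZygoteRuleConfig()   # a RuleConfig with HasReverseMode
F = typeof(Dense(10 => 5, tanh))
X = Matrix{Float32}

# Inferred return type of rrule_via_ad(config, ::F, ::X): some Tuple{Y1, P1}
T = Core.Compiler.return_type(rrule_via_ad, Tuple{typeof(config), F, X})
Y1, P1 = _unpack_2tup(T)             # layer output type and pullback type
```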

struct VectorChainPullback{P}
  pullbacks::P
end
function (pb::VectorChainPullback)(dy)
  dlayers = Union{}[]  # starts at the bottom type; _push!! widens it as tangents arrive
  dx = dy
  for pbk in reverse(pb.pullbacks)  # `pbk`, not `pb`, to avoid shadowing the callee
    dlayer, dx = pbk(dx)
    dlayers = _push!!(dlayers, dlayer)
  end
  return (NoTangent(), dlayers, dx)
end

function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(applychain), layers::AbstractVector, x)
  YT, ST = _pullback_stack_eltype(config, eltype(layers), x)
  pullbacks = ST[]
  y::YT = x
  for l in layers
    y, pb = rrule_via_ad(config, l, y::YT)
    pullbacks = _push!!(pullbacks, pb)
  end

  return y::YT, VectorChainPullback(pullbacks)
end
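
With this rule defined, differentiating a `Vector`-backed `Chain` routes the reverse pass through `VectorChainPullback` instead of Zygote's generic loop differentiation. A hedged end-to-end sketch (illustration only, not part of the diff; `applychain` is unexported, hence the qualification):

```julia
using Flux, Zygote, ChainRulesCore

m = Chain([Dense(2 => 2, relu) for _ in 1:3])   # Chain over a Vector of layers
x = randn(Float32, 2, 5)

g = gradient(m -> sum(m(x)), m)[1]   # reverse pass goes through VectorChainPullback
# To inspect the rule's type stability directly:
#   @code_warntype rrule(Zygote.ZygoteRuleConfig(), Flux.applychain, m.layers, x)
```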

Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i])
Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) =
  Chain(NamedTuple{keys(c)[i]}(Tuple(c.layers)[i]))
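
For context, these `getindex` methods make slicing return a sub-`Chain` that shares the original layers. A quick sketch (illustration only, not part of the diff):

```julia
using Flux

m = Chain(Dense(10 => 5, relu), Dense(5 => 2), softmax)
m[1:2]   # a Chain of the first two layers; weights are shared with `m`
```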
@@ -159,7 +200,7 @@ function (a::Dense)(x::AbstractVecOrMat)
  return σ.(a.weight * x .+ a.bias)
end

(a::Dense)(x::AbstractArray) =
  reshape(a(reshape(x, size(x,1), :)), :, size(x)[2:end]...)
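
A small sketch of what this reshape trick does for higher-rank inputs (illustration only, not part of the diff): trailing batch dimensions are flattened, the matrix path runs, and the output is reshaped back.

```julia
using Flux

d = Dense(3 => 4)
x = rand(Float32, 3, 5, 6)
size(d(x))   # (4, 5, 6): all dims after the first are treated as batch
```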

function Base.show(io::IO, l::Dense)
@@ -178,9 +219,9 @@ Create an element-wise layer, whose forward pass is given by:
    y = σ.(scale .* x .+ bias)

This uses `.*` instead of matrix multiplication `*` of [`Dense`](@ref).

The learnable scale & bias are initialised `init(size...)` and `zeros32(size...)`,
with `init=ones32` by default. You may specify the function `init`,
turn off trainable bias with `bias=false`, or provide the array(s) explicitly.

Used by [`LayerNorm`](@ref) with `affine=true`.
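
A quick sketch of that elementwise affine map with the default identity activation (illustration only, not part of the diff):

```julia
using Flux

s = Flux.Scale(3)                 # scale = ones32(3), bias = zeros32(3)
x = rand(Float32, 3, 4)
s(x) == s.scale .* x .+ s.bias    # true: broadcast, not matrix multiply
```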
@@ -248,7 +289,7 @@ Instead of defining layers individually, you can provide a zero-argument function
which constructs them, and the number to construct.

Maxout over linear dense layers satisfies the universal approximation theorem.
See Goodfellow, Warde-Farley, Mirza, Courville & Bengio "Maxout Networks"
[1302.4389](https://arxiv.org/abs/1302.4389).

See also [`Parallel`](@ref) to reduce with other operators.
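
A minimal construction sketch using the zero-argument-constructor form described above (illustration only, not part of the diff):

```julia
using Flux

m = Maxout(() -> Dense(5 => 7, tanh), 3)   # three independently initialised Dense layers
x = rand(Float32, 5, 8)
size(m(x))   # (7, 8): the elementwise maximum over the three layer outputs
```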
@@ -501,7 +542,7 @@ end

function _parallel_check(layers, xs)
  nl = length(layers)
  nx = length(xs)
  if (nl != nx)
    throw(ArgumentError("Parallel with $nl sub-layers can take one input or $nl inputs, but got $nx inputs"))
  end
@@ -531,14 +572,14 @@ end

## Arguments

- `connection`: A function taking 2 inputs and combining them into a single output
- `layers`: The layers whose outputs are combined

## Inputs

This layer behaves differently based on input type:

1. If input `x` is a tuple of length N (or the input is `xs` with N `x`'s), matching the number of `layers`,
then each layer receives a new input `x[i]` combined with the previous output `y[i-1]` using `connection`.
Thus `(y1, y2, y3) = PairwiseFusion(connection, layer1, layer2, layer3)((x1, x2, x3))`
may be drawn as:
@@ -633,12 +674,12 @@ end
"""
Embedding(in => out; init=randn)

A lookup table that stores embeddings of dimension `out`
for a vocabulary of size `in`.

This layer is often used to store word embeddings and retrieve them using indices.
The input to the layer can be either a vector of indexes
or the corresponding [onehot encoding](@ref Flux.OneHotArray).

# Examples
```jldoctest