Allow Parallel(+, f)(x, y, z) to work like broadcasting, and enable `Chain(identity, Parallel(+, f))(x, y, z)` (#2393)

* let Parallel(+, f)(x, y, z) work like broadcasting

* add (::Chain)(xs...) method

* more examples

* correction

* change implementation to dispatch

* nicer errors when called on zero inputs

* disallow zero layers, let's try this out
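
A sketch of the behaviour this enables, condensed from the updated docstrings in the diff below:

```julia
using Flux

Parallel(+, inv)(1, 2, 4)   # one layer, three inputs: inv(1) + inv(2) + inv(4) == 1.75

Chain(identity, Parallel(+, inv))(1, 2, 4)   # three arguments become a tuple, which Parallel splats: 1.75
```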
mcabbott authored Nov 5, 2024
1 parent 7525499 commit 7be1ca7
Showing 2 changed files with 77 additions and 16 deletions.
70 changes: 60 additions & 10 deletions src/layers/basic.jl
@@ -28,6 +28,20 @@ julia> m2(x) == (m2[:dec] ∘ m2[:enc])(x)
true
```
A chain may be called with multiple arguments, which is equivalent to calling it
with one tuple of these arguments. Such a tuple is understood by [`Parallel`](@ref)
to mean the same as several arguments:
```jldoctest
julia> Chain(println, println)(1, 2, 3) # three arguments become a tuple
(1, 2, 3)
nothing

julia> Chain(x->@show(x), Parallel(+, inv, abs2))(4, 5) # returns 1/4 + 5^2
x = (4, 5)
25.25
```
For large models, there is a special type-unstable path which can reduce compilation
times. This can be used by supplying a vector of layers `Chain([layer1, layer2, ...])`.
This feature is somewhat experimental, beware!
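For instance, a minimal sketch of that vector form (illustrative, not from this diff):

```julia
using Flux

m = Chain([Dense(10 => 5, relu), Dense(5 => 2)])   # Vector of layers: type-unstable, but compiles faster
m(randn32(10))
```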
@@ -46,9 +60,10 @@ end
@forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
  Base.iterate, Base.lastindex, Base.keys, Base.firstindex

-@layer :expand Chain # the + opts-in to container-style pretty-printing
+@layer :expand Chain # the option :expand opts-in to container-style pretty-printing

(c::Chain)(x) = _applychain(c.layers, x)
(c::Chain)(x, ys...) = _applychain(c.layers, (x, ys...))
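# several arguments are bundled into one tuple, which e.g. Parallel can then unpack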

@generated function _applychain(layers::Tuple{Vararg{Any,N}}, x) where {N}
  symbols = vcat(:x, [gensym() for _ in 1:N])
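  # Illustration (not part of the diff): for three layers the generated body
  # unrolls to straight-line code, roughly
  #   y1 = layers[1](x); y2 = layers[2](y1); y3 = layers[3](y2); y3
  # with gensym'd names in place of y1, y2, y3.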
@@ -68,6 +83,7 @@ end
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i])
Base.getindex(c::Chain{<:NamedTuple}, i::AbstractArray) =
  Chain(NamedTuple{keys(c)[i]}(Tuple(c.layers)[i]))

function Base.show(io::IO, c::Chain)
  print(io, "Chain(")
  _show_layers(io, c.layers)
@@ -475,8 +491,11 @@ end
Create a layer which passes an input array to each path in
`layers`, before reducing the output with `connection`.
-Called with one input `x`, this is equivalent to `connection([l(x) for l in layers]...)`.
-If called with multiple inputs, one is passed to each layer, thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`.
+Obeys rules similar to broadcasting:
+* Called with one input `x`, this is equivalent to `connection([l(x) for l in layers]...)`.
+* With multiple `inputs` and just one layer, it is instead `connection([layer(x) for x in inputs]...)`.
+* With multiple inputs and multiple layers, one input is passed to each layer,
+  thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`.
Like [`Chain`](@ref), its sub-layers may be given names using the keyword constructor.
These can be accessed by indexing: `m[1] == m[:name]` is the first layer.
@@ -486,6 +505,25 @@ and [`Maxout`](@ref) which reduces by broadcasting `max`.
# Examples
```jldoctest
julia> p = Parallel(+, abs2, sqrt);

julia> p(3, 4) # == 3^2 + √4, two functions two inputs
11.0

julia> p((3, 4)) # tuple is always splatted
11.0

julia> p(4) # == 4^2 + √4, one input used twice
18.0

julia> Parallel(hcat, inv)(1, 2, 4) # one function three inputs
1×3 Matrix{Float64}:
 1.0  0.5  0.25
```
With Flux layers:
```jldoctest
julia> model = Chain(Dense(3 => 5),
                     Parallel(vcat, Dense(5 => 4), Chain(Dense(5 => 7), Dense(7 => 4))),
@@ -516,35 +554,47 @@ struct Parallel{F, T<:Union{Tuple, NamedTuple}}
layers::T
end

_ParallelONE{T} = Parallel{T, <:Union{Tuple{Any}, NamedTuple{<:Any, <:Tuple{Any}}}}

Parallel(connection, layers...) = Parallel(connection, layers)
function Parallel(connection; kw...)
  layers = NamedTuple(kw)
  if :layers in keys(layers) || :connection in keys(layers)
    throw(ArgumentError("a Parallel layer cannot have a named sub-layer called `connection` or `layers`"))
  end
  isempty(layers) && return Parallel(connection, ())
  Parallel(connection, layers)
end
Parallel(connection, layers::Union{Tuple{}, @NamedTuple{}}) =
  throw(ArgumentError("cannot construct a Parallel layer with no sub-layers"))

@layer :expand Parallel

-(m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...)
-(m::Parallel)(xs::Tuple) = m(xs...)
+(m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...) # one argument

function _parallel_check(layers, xs)
  nl = length(layers)
-  nx = length(xs)
+  @assert nl > 1 # dispatch handles nl==1 cases
+  nx = length(xs)
  if (nl != nx)
-    throw(ArgumentError(lazy"Parallel with $nl sub-layers can take one input or $nl inputs, but got $nx inputs"))
+    throw(ArgumentError(lazy"Parallel with $nl > 1 sub-layers can take one input or $nl inputs, but got $nx inputs"))
  end
end
ChainRulesCore.@non_differentiable _parallel_check(nl, nx)

-function (m::Parallel)(xs...)
+function (m::Parallel)(x, ys...)
+  xs = (x, ys...)
  _parallel_check(m.layers, xs)
-  m.connection(map(|>, xs, Tuple(m.layers))...)
+  m.connection(map(|>, xs, Tuple(m.layers))...) # multiple arguments & multiple layers
end

(m::_ParallelONE)(x, ys...) =
  m.connection(map(z -> only(m.layers)(z), (x, ys...))...) # multiple arguments, one layer
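# e.g. Parallel(+, inv)(1, 2, 4) == inv(1) + inv(2) + inv(4) == 1.75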

(m::Parallel)(xs::Tuple) = m(xs...) # tuple is always splatted
(m::_ParallelONE)(xs::Tuple) = m(xs...) # solves an ambiguity
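# without it, a one-layer Parallel called on a tuple would hit a method ambiguity
# between (m::Parallel)(xs::Tuple) and (m::_ParallelONE)(x, ys...)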

(m::Parallel)() = throw(ArgumentError("Parallel layer cannot take 0 inputs"))

Base.getindex(m::Parallel, i) = m.layers[i]
Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i])
Base.getindex(m::Parallel{<:Any, <:NamedTuple}, i::AbstractVector) =
23 changes: 17 additions & 6 deletions test/layers/basic.jl
@@ -35,6 +35,8 @@ using Flux: activations
c = Chain(Dense(10, 5, σ), Dense(5, 2), Dense(2, 1, relu))
@test c[1] == c[begin]
@test c[3] == c[end]

@test Chain(identity)(1,2,3) == (1,2,3) # multiple args become a tuple
end

@testset "Activations" begin
@@ -228,17 +230,20 @@ using Flux: activations
end

@testset "concat size" begin
-input = randn(10, 2)
+input = randn32(10, 2)
@test size(Parallel((a, b) -> cat(a, b; dims=2), Dense(10, 10), identity)(input)) == (10, 4)
@test size(Parallel(hcat, one = Dense(10, 10), two = identity)(input)) == (10, 4)
end

@testset "vararg input" begin
-inputs = randn(10), randn(5), randn(4)
+inputs = randn32(10), randn32(5), randn32(4)
@test size(Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2))(inputs)) == (2,)
@test size(Parallel(+; a = Dense(10, 2), b = Dense(5, 2), c = Dense(4, 2))(inputs)) == (2,)
@test_throws ArgumentError Parallel(+, sin, cos)(1,2,3) # wrong number of inputs
-@test Parallel(+, sin, cos)(pi/2) ≈ 1
+@test Parallel(+, sin, cos)(pi/2) ≈ 1 # one input, several layers
+@test Parallel(/, abs)(3, -4) ≈ 3/4 # one layer, several inputs
+@test Parallel(/, abs)((3, -4)) ≈ 3/4
+@test Parallel(/; f=abs)(3, -4) ≈ 3/4
end

@testset "named access" begin
@@ -256,9 +261,13 @@ using Flux: activations
end

@testset "trivial cases" begin
-@test Parallel(hcat) isa Parallel{typeof(hcat), Tuple{}} # not a NamedTuple
-@test Parallel(hcat)(1) == hcat()
-@test Parallel(hcat, inv)(2) == hcat(1/2) # still calls connection once.
+# zero inputs, always an error
+@test_throws ArgumentError Parallel(hcat)()
+@test_throws ArgumentError Parallel(hcat, inv)()
+@test_throws ArgumentError Parallel(hcat, inv, sqrt)()

+# zero layers -- not useful... now made an error
+@test_throws ArgumentError Parallel(hcat)
end

@testset "connection is called once" begin
@@ -270,6 +279,8 @@ using Flux: activations
@test CNT[] == 2
Parallel(f_cnt, sin)(1)
@test CNT[] == 3
Parallel(f_cnt, sin)(1,2,3)
@test CNT[] == 4
end

# Ref https://github.com/FluxML/Flux.jl/issues/1673
