From baf5ce363be3cdeda9b58fdc2a5384b15df484b4 Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Sat, 11 Jul 2020 21:48:31 +0800 Subject: [PATCH 1/5] use vmap for all activations --- src/activation.jl | 7 ++----- test/activation.jl | 9 +++++++++ 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/activation.jl b/src/activation.jl index 3e7dd22c4..58b473709 100644 --- a/src/activation.jl +++ b/src/activation.jl @@ -118,7 +118,7 @@ elu(x::RealOrFloatType, α = one(x)) = vifelse(x ≥ 0, x / one(x), α * (exp(x) activation function. """ function gelu(x::RealOrFloatType) - p = oftype(x / 1, π) + p = oftype(x / 1, Float64(π)) λ = oftype(x / 1, √(2 / p)) α = oftype(x / 1, 0.044715) h = oftype(x / 1, 0.5) @@ -166,7 +166,7 @@ end Continuously Differentiable Exponential Linear Units See [Continuously Differentiable Exponential Linear Units](https://arxiv.org/pdf/1704.07483.pdf). """ -celu(x::RealOrFloatType, α::Real = one(x)) = vifelse(x ≥ 0, x / one(x), α * (exp(x/α) - one(x))) +celu(x::RealOrFloatType, α::RealOrFloatType = one(x)) = vifelse(x ≥ 0, x / one(x), α * (exp(x/α) - one(x))) """ @@ -230,8 +230,5 @@ softshrink(x::RealOrFloatType, λ = oftype(x/1, 0.5)) = min(max(zero(x), x - λ) for f in (:σ, :hardσ, :logσ, :hardtanh, :relu, :leakyrelu, :relu6, :rrelu, :elu, :gelu, :swish, :lisht, :selu, :celu, :trelu, :softsign, :softplus, :logcosh, :mish, :tanhshrink, :softshrink) @eval $(f)(x::AbstractArray, args...) = error("Use broadcasting (`", $(string(f)), ".(x)`) to apply activation functions to arrays.") -end - -for f in (:σ, :tanh) @eval Base.broadcasted(::typeof($f), x::Array{T, N}) where {T <: Union{Float64, Float32}, N} = vmap($f, x) end diff --git a/test/activation.jl b/test/activation.jl index 70558fc62..13336d1af 100644 --- a/test/activation.jl +++ b/test/activation.jl @@ -112,6 +112,15 @@ end end end + @testset "Broadcasting" begin + for T in (Float32, Float64) + x = rand(T, 5) + for a in ACTIVATION_FUNCTIONS + @test a.(x) ≈ map(a, x) + end + end + end + @testset "Test Integer64 and Integer32 inputs will force Float64 outputs" begin test_value_int_input_forces_float64.(filter(x -> (x != relu && x != relu6 && x != hardtanh && x != trelu), ACTIVATION_FUNCTIONS)) From 6d624cce633fbf5f63631fd220822ee835e57ba4 Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Sat, 11 Jul 2020 22:19:11 +0800 Subject: [PATCH 2/5] test broadcasting gradient --- test/activation.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/activation.jl b/test/activation.jl index 13336d1af..7c180b1d2 100644 --- a/test/activation.jl +++ b/test/activation.jl @@ -117,10 +117,11 @@ end x = rand(T, 5) for a in ACTIVATION_FUNCTIONS @test a.(x) ≈ map(a, x) + @test Zygote.gradient(z -> sum(a.(z)), x)[1] == a'.(x) end end end - + @testset "Test Integer64 and Integer32 inputs will force Float64 outputs" begin test_value_int_input_forces_float64.(filter(x -> (x != relu && x != relu6 && x != hardtanh && x != trelu), ACTIVATION_FUNCTIONS)) From 375b9c2415fecedba3e84ac1fc28942fa168d044 Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Sat, 11 Jul 2020 22:42:11 +0800 Subject: [PATCH 3/5] add back tanh --- src/activation.jl | 2 +- test/activation.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/activation.jl b/src/activation.jl index 58b473709..77582e326 100644 --- a/src/activation.jl +++ b/src/activation.jl @@ -227,7 +227,7 @@ See [Softshrink Activation Function](https://www.gabormelli.com/RKB/Softshrink_A softshrink(x::RealOrFloatType, λ = oftype(x/1, 0.5)) = min(max(zero(x), x - λ), x + λ) # 
Provide an informative error message if activation functions are called with an array -for f in (:σ, :hardσ, :logσ, :hardtanh, :relu, :leakyrelu, :relu6, :rrelu, :elu, :gelu, :swish, :lisht, :selu, :celu, :trelu, :softsign, :softplus, :logcosh, :mish, :tanhshrink, :softshrink) +for f in (:σ, :hardσ, :logσ, :tanh, :hardtanh, :relu, :leakyrelu, :relu6, :rrelu, :elu, :gelu, :swish, :lisht, :selu, :celu, :trelu, :softsign, :softplus, :logcosh, :mish, :tanhshrink, :softshrink) @eval $(f)(x::AbstractArray, args...) = error("Use broadcasting (`", $(string(f)), ".(x)`) to apply activation functions to arrays.") @eval Base.broadcasted(::typeof($f), x::Array{T, N}) where {T <: Union{Float64, Float32}, N} = vmap($f, x) diff --git a/test/activation.jl b/test/activation.jl index 7c180b1d2..75246d004 100644 --- a/test/activation.jl +++ b/test/activation.jl @@ -1,6 +1,6 @@ using NNlib, Test, Zygote -ACTIVATION_FUNCTIONS = [σ, hardσ, logσ, hardtanh, relu, leakyrelu, relu6, rrelu, elu, gelu, celu, swish, lisht, selu, trelu, softplus, softsign, logcosh, mish, tanhshrink, softshrink]; +ACTIVATION_FUNCTIONS = [σ, hardσ, logσ, tanh, hardtanh, relu, leakyrelu, relu6, rrelu, elu, gelu, celu, swish, lisht, selu, trelu, softplus, softsign, logcosh, mish, tanhshrink, softshrink]; function test_value_float_precision_preserving(a) @testset "$(a): " begin From 33854f5a0a329f9276bb4dd58614cfefb0d925ed Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Sat, 11 Jul 2020 22:53:27 +0800 Subject: [PATCH 4/5] import tanh --- src/activation.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/activation.jl b/src/activation.jl index 77582e326..12dcd691c 100644 --- a/src/activation.jl +++ b/src/activation.jl @@ -1,6 +1,7 @@ export σ, sigmoid, hardσ, hardsigmoid, hardtanh, relu, leakyrelu, relu6, rrelu, elu, gelu, swish, selu, celu, softplus, softsign, logσ, logsigmoid, logcosh, mish, tanhshrink, softshrink, thresholdrelu, trelu, lisht +import Base: tanh import LoopVectorization: vifelse using LoopVectorization.SLEEFPirates: FloatType From d8e827b92076aeebc9929f87fd242c29a3a664e1 Mon Sep 17 00:00:00 2001 From: Yao Lu Date: Sun, 12 Jul 2020 00:12:39 +0800 Subject: [PATCH 5/5] isapprox --- test/activation.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/activation.jl b/test/activation.jl index 75246d004..78249c145 100644 --- a/test/activation.jl +++ b/test/activation.jl @@ -117,7 +117,7 @@ end x = rand(T, 5) for a in ACTIVATION_FUNCTIONS @test a.(x) ≈ map(a, x) - @test Zygote.gradient(z -> sum(a.(z)), x)[1] == a'.(x) + @test isapprox(gradient(z -> sum(a.(z)), x)[1], a'.(x)) end end end
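
Note: the core technique in this series is a broadcast override, so that `σ.(x)` (and the other listed activations) applied to dense Float32/Float64 arrays is evaluated through LoopVectorization's `vmap` instead of the generic broadcast machinery. Below is a minimal standalone sketch of that pattern for a single hand-written activation, `myσ` (a hypothetical stand-in, not NNlib API), assuming LoopVectorization is installed:

    # Minimal sketch of the broadcast-to-vmap pattern from this patch series,
    # written for one hand-rolled activation instead of NNlib's @eval loop.
    # `myσ` is a hypothetical stand-in; LoopVectorization is assumed installed.
    using LoopVectorization: vmap

    myσ(x::Real) = one(x) / (one(x) + exp(-x))

    # Same guard the patches add: calling the scalar function on an array errors.
    myσ(x::AbstractArray, args...) =
        error("Use broadcasting (`myσ.(x)`) to apply activation functions to arrays.")

    # Route broadcasting over dense Float32/Float64 arrays through vmap,
    # which evaluates the scalar function with SIMD.
    Base.broadcasted(::typeof(myσ), x::Array{T, N}) where {T <: Union{Float64, Float32}, N} =
        vmap(myσ, x)

    x = rand(Float32, 1024)
    y = myσ.(x)              # dispatches to the vmap method above
    @assert y ≈ map(myσ, x)  # values match the plain element-wise map

Because the override returns an eager array rather than a lazy `Broadcasted` object, `myσ.(x)` bypasses normal broadcast fusion for this call; the tests added in this series therefore compare both values (`a.(x) ≈ map(a, x)`) and Zygote gradients to confirm the override stays consistent with plain element-wise application.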