Add TeLU activation functions `telu` and `telu_fast` #622
Conversation
Thanks! Some quick comments from a first pass...
```
    telu_fast(x)

This is a faster but less accurate version of `telu`. This function is associated with a
hard-coded derivative, `deriv_telu_fast`, which is faster but less accurate than `deriv_telu`.
```
Is this meaningfully less accurate? In the tests there should be some functions for measuring error, `countepsfrom` and friends.
My guess is that for NN purposes, we will only want the fast version, and probably `@fastmath x * tanh_fast(exp(x))` to speed up `exp` too.
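Rendered as a definition, that suggestion might look like the sketch below (a hypothetical rendering, assuming NNlib's `tanh_fast` is in scope; not necessarily the PR's final code):

```julia
using NNlib: tanh_fast  # assumed available, as in the suggestion above

# Sketch of the suggested fast forward pass; not necessarily the PR's final definition.
telu_fast(x) = @fastmath x * tanh_fast(exp(x))
```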
In my gist, but translated to this notation -- there is hardly any accuracy change:

```julia
julia> worst_eps(telu_fast, telu, -5:0.01f0:5)  # comparing to bigfloat
3

julia> worst_eps(telu, telu, -5:0.01f0:5)
2
```
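For reference, a helper in the spirit of `worst_eps` / `countepsfrom` might look like the hypothetical sketch below (not the gist's or the test suite's actual code): it reports the worst error of an implementation against a reference evaluated in BigFloat, counted in ulps of the true value.

```julia
# Hypothetical worst_eps-style helper: largest error of `f` relative to `reference`
# evaluated in BigFloat, in units of eps at the true value.
function worst_eps_sketch(f, reference, xs)
    worst = 0
    for x in xs
        exact = oftype(f(x), reference(big(x)))  # ground truth, rounded to working precision
        err = abs(f(x) - exact)
        worst = max(worst, round(Int, err / eps(exact)))
    end
    return worst
end

# e.g. worst_eps_sketch(telu_fast, telu, -5:0.01f0:5), since a generic telu
# evaluated on a BigFloat input serves as the high-precision reference
```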
Thanks for the comments! I've updated the code to reuse
The fast derivative
Unfortunately, the update broke the test, since I have a type-dependent small-

P.S. Never mind, the AD trouble is gone after I use
I timed everything and tried to simplify a bit here: https://gist.github.com/mcabbott/8fb03f175ee4e0c29ef4a7044dc19a85 Since then you've simplified this too, good going. I still wish it were shorter!
Here's what I think the entire derivative code could look like. (I've inlined the "main path" just to have fewer symbols with confusingly similar names around.)

```julia
function deriv_telu(x::Real, _)
    # Adapted from the Discourse post, to avoid bad cancellations:
    # <https://discourse.julialang.org/t/how-to-compute-tanhexp-telu-function-accurately/124464/7>
    exp_x = exp(x)
    tanh(exp_x) + 4x / (exp(exp_x - x/2) + exp(-exp_x - x/2))^2
end

function deriv_telu(x::T, Ω = telu(x)) where {T <: Union{Float16, Float32, Float64}}
    # Main path, re-using forward pass:
    tanh_exp_x = Ω / x
    sech_exp_x_squared = 1 - tanh_exp_x^2
    main = @fastmath tanh_exp_x + x * exp(x) * sech_exp_x_squared
    # That's badly behaved at zero, switch to a taylor series:
    taylor = _deriv_telu_taylor(x)
    # It's also badly behaved at large x, switch to 1!
    ifelse(abs(x) < T(0.01), taylor,    # this works just as well
           ifelse(x > 4, one(x), main))  # as does this
end

# Taylor coefficients are (tanh(1), 8*exp(1)^2 / (1+exp(1)^2)^2)
_deriv_telu_taylor(x::T) where T = convert(T, evalpoly(x, (0.7615941559557649, 0.8399486832280524)))
_deriv_telu_taylor(x::Float32) = evalpoly(x, (0.7615942f0, 0.83994865f0))
_deriv_telu_taylor(x::Float16) = evalpoly(x, (Float16(0.7617), Float16(0.84)))
```

In fact, the whole first

(The more exact formula could be kept in the tests, to compare there to this piecewise thing)
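To act on that parenthetical, a test could compare the piecewise method against the exact `Real` formula. A hypothetical sketch (assuming `telu` and both `deriv_telu` methods above are defined; not the PR's actual test code):

```julia
# Hypothetical accuracy check: evaluate the exact Real-method formula in BigFloat
# and compare the piecewise Float32 method against it over a grid.
xs = -5:0.01f0:5
maxerr = maximum(xs) do x
    exact = Float32(deriv_telu(big(x), nothing))  # exact formula, in high precision
    abs(deriv_telu(x) - exact)                    # piecewise method, recomputing Ω = telu(x)
end
@show maxerr  # expect a few eps(Float32) if the branch thresholds are well placed
```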
On one hand, the

P.S. Maybe adding a second-order term to the Taylor expansion will guarantee sufficient accuracy for any practical NN purpose.
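If that route were taken, the extra coefficient could be computed once in high precision and pasted into `_deriv_telu_taylor`. A hypothetical sketch (the coefficient values are left for the code to print, not asserted here):

```julia
# Hypothetical sketch: Taylor coefficients of d/dx (x * tanh(exp(x))) around 0,
# via central differences in BigFloat, for pasting into _deriv_telu_taylor if a
# second-order term turns out to be needed.
setprecision(BigFloat, 256) do
    f(x) = tanh(exp(x)) + x * exp(x) * (1 - tanh(exp(x))^2)  # exact derivative of telu
    h = big"1e-30"
    c0 = f(big(0))                        # constant term, tanh(1)
    c1 = (f(h) - f(-h)) / (2h)            # first-order coefficient (should match 0.83994868... above)
    c2 = (f(h) - 2c0 + f(-h)) / h^2 / 2   # second-order coefficient, the proposed addition
    @show c0 c1 c2
end
```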
Ah, sorry I see some larger errors which I missed.
I wonder a little bit if we should just take the hit & simplify to this. Counting forward+reverse it's only 30% slower (and may allow sufficiently smart AD not to keep the array containing

```julia
function deriv_telu_2(x::Real)
    # This version does not re-use forward pass, as doing so has 0/0 problems at x=0:
    exp_x = @fastmath exp(x)
    tanh_exp_x = tanh_fast(exp_x)
    main = tanh_exp_x + x * exp_x * (1 - tanh_exp_x^2)
    # That gives NaN at large x, where telu(x) is just relu(x) anyway:
    ifelse(x > 4, one(float(x)), main)
end
```
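For the AD side, wiring this up might look roughly like the plain ChainRulesCore rule below (a sketch assuming `telu` and `deriv_telu_2` above are defined; NNlib has its own machinery for registering activation gradients, so this is only illustrative):

```julia
using ChainRulesCore

# Illustrative only: an rrule for telu in terms of deriv_telu_2, which recomputes
# exp(x) in the backward pass rather than saving anything beyond x itself.
function ChainRulesCore.rrule(::typeof(telu), x::Real)
    Ω = telu(x)
    telu_pullback(ΔΩ) = (NoTangent(), ΔΩ * deriv_telu_2(x))
    return Ω, telu_pullback
end
```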
The new code which recomputes
I can also update the code and the accuracy tests ( |
This PR adds the TeLU activation function advocated by a recent paper, following discussions on Julia Discourse.

`telu` and `telu_fast` have been added to `activation.jl`. The latter is slightly faster and uses `tanh_fast` while sacrificing accuracy a bit. The hard-coded derivatives are `deriv_telu` and `deriv_telu_fast`, respectively. The accuracy gap between the derivative functions is more significant, as `deriv_telu` re-organizes the expression to avoid numerical instabilities.
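The accurate forward pass described here amounts to something like the hypothetical sketch below (the PR's actual `activation.jl` code may differ in details such as docstrings and type handling); the fast variant swaps in `tanh_fast`, as in the sketch earlier in the thread.

```julia
# Minimal sketch of the accurate forward definition (not the PR's exact code):
telu(x) = x * tanh(exp(x))

telu(1f0)  # ≈ 0.9913f0; for large positive x, telu(x) is essentially relu(x)
```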