Add TeLU activation functions telu and telu_fast #622

Open · wants to merge 4 commits into base: master

Conversation


@zengmao commented Jan 7, 2025

This PR adds the TeLU activation function proposed in a recent paper, following discussions on Julia Discourse.

telu and telu_fast have been added to activation.jl. The latter is slightly faster, using tanh_fast at a small cost in accuracy. The hard-coded derivatives are deriv_telu and deriv_telu_fast, respectively. The accuracy gap between the derivative functions is more significant, as deriv_telu re-organizes the expression to avoid numerical instabilities.
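Roughly, the two definitions amount to the following (a sketch of the idea only; the committed code may differ in details such as type handling):

telu(x) = x * tanh(exp(x))            # TeLU(x) = x · tanh(eˣ)
telu_fast(x) = x * tanh_fast(exp(x))  # same formula, via NNlib's faster tanh_fast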

@mcabbott (Member) left a comment


Thanks! Some quick comments from a first pass...

telu_fast(x)

This is a faster but less accurate version of `telu`. This function is associated with a hard-coded derivative,
`deriv_telu_fast`, which is faster but less accurate than `deriv_telu`.
Member:

Is this meaningfully less accurate? In the tests there should be some functions for measuring error, `countepsfrom` and friends.

My guess is that for NN purposes, we will only want the fast version, and probably @fastmath x * tanh_fast(exp(x)) to speed up exp too.

Member:

In my gist, but translated to this notation -- there is hardly any accuracy change:

julia> worst_eps(telu_fast, telu, -5:0.01f0:5)  # comparing to bigfloat
3

julia> worst_eps(telu, telu, -5:0.01f0:5)
2
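(For anyone reading without the gist open: the helper measures worst-case error in units of eps against a BigFloat reference, roughly along the lines of the sketch below; the gist's actual code differs in details.)

function worst_eps(f, fref, xs)
    worst = maximum(xs) do x
        y = f(x)                        # value under test
        yref = oftype(y, fref(big(x)))  # BigFloat reference, rounded back to the test type
        abs(y - yref) / eps(yref)       # error in units of eps at the reference value
    end
    round(Int, worst)
end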

(Outdated, resolved review thread on src/activations.jl.)
@zengmao (Author) commented Jan 7, 2025

Thanks for the comments! I've updated the code to reuse telu(x) to compute the derivative, using tanh(exp(x)) = telu(x) / x. This is problematic for x close to zero, so when abs(x) falls below a cutoff value I switch to a two-term Taylor expansion of telu'(x) around x = 0. I've also added the @fastmath annotation to telu_fast. Please see test/activation.jl for accuracy tests with functions like countepsfrom, e.g. under the line

@testset "tanh_fast, sigmoid_fast, telu_fast & deriv_telu_fast: Float64" begin

The fast derivative deriv_telu_fast suffers from deviations of up to 66 eps, due to numerical instabilities at moderately large x for Float32. Here's one of the worst cases: at x = 2.052801f0, the fast derivative gives 1.0000185f0, while the accurate result is 1.0000106f0. I suppose the practical consequence of this error is not significant.
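For the record, the algebra behind re-using the forward pass: with $\Omega = \mathrm{telu}(x) = x \tanh(e^x)$, the derivative is $\mathrm{telu}'(x) = \tanh(e^x) + x\,e^x \operatorname{sech}^2(e^x) = \Omega/x + x\,e^x\,\bigl(1 - (\Omega/x)^2\bigr)$, and the $\Omega/x$ factor is a 0/0 indeterminate form at $x = 0$, hence the Taylor branch there.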

@zengmao (Author) commented Jan 7, 2025

Unfortunately, the update broke the tests, since I have a type-dependent small-x cutoff which differs for Float16/32/64. Is it possible to rewrite something like small_x_cutoff(::Float32) = 4f0 to make it friendly to dual-number arguments (which are what broke the tests)? Schematically, I think I need small_x_cutoff(::Union{Float32, Dual{Float32}}) = 4f0. How can this be done? Or is it OK to hard-code a constant cutoff which is only appropriate for Float32, since that is what most people use with NNlib?

P.S. Never mind, the AD trouble is gone after I use eps to control the cutoff:

function deriv_telu_fast(x::T, Ω) where T
    ifelse(abs(x) < sqrt(eps(T)), _deriv_telu_taylor_expansion(x), # if x is close to 0, return linear-order Taylor expansion
           ifelse(x >= -log(sqrt(eps(T))), one(x), _deriv_telu_fast(x, Ω))) # cut off large x to prevent `exp(x)` overflow.
end
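(The Taylor helper isn't shown here; it is just the first-order expansion of telu'(x) about x = 0, with coefficients tanh(1) and 8e^2/(1+e^2)^2, so something along these lines:)

_deriv_telu_taylor_expansion(x) =
    evalpoly(x, (oftype(float(x), 0.7615941559557649),   # tanh(1)
                 oftype(float(x), 0.8399486832280524)))  # 8e^2 / (1 + e^2)^2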

@mcabbott (Member) commented Jan 7, 2025

I timed everything and tried to simplify a bit here:

https://gist.github.com/mcabbott/8fb03f175ee4e0c29ef4a7044dc19a85

Since then you've simplified this too, good going.

I still wish it were shorter!

  • We can just keep telu_fast, the error is like 3 eps instead of 2 eps worst case, which is totally fine, no need for two options.
  • Can't we just hardcode cutoffs like 0.01 and 4 for this taylor / const switch? When I test that, the least accurate point is always near x=-1, i.e. unrelated to the cutoffs.
  • I'm a bit scared to trust that this taylor expansion will always work out all the log etc. at compile-time, and would be happier just to have explicit coefficients.

Here's what I think the entire derivative code could look like. (I've inlined the "main path" just to have fewer symbols with confusingly similar names around.)

function deriv_telu(x::Real, _)
    # Adapted from the Discourse post, to avoid bad cancellations: <https://discourse.julialang.org/t/how-to-compute-tanhexp-telu-function-accurately/124464/7>
    exp_x = exp(x)
    tanh(exp_x) + 4x / (exp(exp_x - x/2) + exp(-exp_x - x/2))^2
end
function deriv_telu(x::T, Ω = telu(x)) where {T <: Union{Float16, Float32, Float64}}
    # Main path, re-using forward pass:
    tanh_exp_x = Ω / x
    sech_exp_x_squared = 1 - tanh_exp_x^2
    main = @fastmath tanh_exp_x + x * exp(x) * sech_exp_x_squared
    # That's badly behaved at zero, switch to a taylor series:
    taylor = _deriv_telu_taylor(x)
    # It's also badly behaved at large x, switch to 1!
    ifelse(abs(x) < T(0.01), taylor,  # this works just as well
        ifelse(x > 4, one(x), main)) # as does this
end
# Taylor coefficients are (tanh(1), 8*exp(1)^2 / (1+exp(1)^2)^2)
_deriv_telu_taylor(x::T) where T = convert(T, evalpoly(x, (0.7615941559557649, 0.8399486832280524)))
_deriv_telu_taylor(x::Float32) = evalpoly(x, (0.7615942f0, 0.83994865f0))
_deriv_telu_taylor(x::Float16) = evalpoly(x, (Float16(0.7617), Float16(0.84)))

In fact, the whole first deriv_telu(x::Real, _) method could probably be deleted, as I think the other would work fine for Dual too? But I haven't tried, and if not, it's fine to send them the slow path.

(The more exact formula could be kept in the tests, to compare there to this piecewise thing)
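A test along those lines might look roughly like this (a sketch only: it assumes ForwardDiff is available in the test environment, and deriv_telu_exact is just a name for the more exact formula kept for testing):

using ForwardDiff, Test

# Exact formula kept only as a test reference (same expression as the first method above):
deriv_telu_exact(x) = tanh(exp(x)) + 4x / (exp(exp(x) - x/2) + exp(-exp(x) - x/2))^2

@testset "deriv_telu vs exact formula and ForwardDiff" begin
    for x in -5:0.01f0:5
        @test deriv_telu(x, telu(x)) ≈ deriv_telu_exact(x) rtol=1e-4 atol=1e-6
        @test ForwardDiff.derivative(telu, x) ≈ deriv_telu_exact(x) rtol=1e-4 atol=1e-6
    end
end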

@zengmao (Author) commented Jan 8, 2025

On one hand, the x > 4 cutoff is fine for all types (Float16/32/64), because the gradient has saturated to 1.0 well before x = 4 for all of these float types. On the other hand, the lower cutoff x < 0.01 causes an error on the order of $10^{-5}$ for an $O(1)$ gradient, which is quite large for Float64 but probably OK for Float32 (about 280 eps deviation), and the latter is the dominant use case of NNlib. I'm happy for your code above to be used as the final version. (I don't have experience with gradient accuracy requirements in practical situations, so I'll trust you to make the call.)

P.S. Maybe adding a second-order term to the Taylor expansion would guarantee sufficient accuracy for any practical NN purpose.

@mcabbott (Member) commented Jan 8, 2025

> the lower cutoff x < 0.01 causes an error of the order $10^{-5}$ for an $O(1)$ gradient

Ah, sorry I see some larger errors which I missed.

julia> find_worst(deriv_telu_fast, deriv_telu_exact, -0.1:0.0001f0:1)  # with abs(x) < 0.01
(302129306388, 0.009999997221166262)

julia> find_worst(deriv_telu_fast, deriv_telu_exact, -0.1:0.0001f0:1)  # with abs(x) < sqrt(eps(T))
(3, -0.04700000133889262)

I wonder a little bit if we should just take the hit & simplify to this. Counting forward+reverse it's only 30% slower (and may allow sufficiently smart AD not to keep the array containing y = telu.(x) around until reverse-time):

function deriv_telu_2(x::Real)
    # This version does not re-use forward pass, as doing so has 0/0 problems at x=0:
    exp_x = @fastmath exp(x)
    tanh_exp_x = tanh_fast(exp_x)
    main = tanh_exp_x + x * exp_x * (1 - tanh_exp_x^2)
    # That gives NaN at large x, where telu(x) is just relu(x) anyway:
    ifelse(x > 4, one(float(x)), main)
end
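A rough way to set up that forward+reverse timing comparison (sketch only; BenchmarkTools assumed, and the gist's actual benchmark may differ):

using BenchmarkTools

x = randn(Float32, 1024)
y = telu.(x)                        # forward pass, stored for the re-use option

@btime telu.($x);                   # forward-pass cost, common to both options
@btime deriv_telu.($x, $y);         # reverse pass, re-using the stored forward result
@btime deriv_telu_2.($x);           # reverse pass, recomputing tanh_fast(exp(x))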

@zengmao (Author) commented Jan 8, 2025

The new code which recomputes tanh_fast(exp(x)) looks good to me. Please feel free to commit it.

@zengmao (Author) commented Jan 8, 2025

I can also update the code and the accuracy tests (mean_eps and worst_eps etc.) if we agree on a preferred version.
