Skip to content

Commit

Permalink
Remove greek-letter keyword from normalise (#2252)
Browse files Browse the repository at this point in the history
* use _greek_ascii_depwarn in normalise

* also a better example

* change use in LayerNorm
  • Loading branch information
mcabbott authored May 3, 2023
1 parent d5a0643 commit 7088682
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/layers/normalise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ function (a::LayerNorm)(x::AbstractArray)
end
end
eps = convert(float(eltype(x)), a.ϵ) # avoids promotion for Float16 data, but should ε chage too?
a.diag(normalise(x, dims=1:length(a.size), ϵ=eps))
a.diag(normalise(x; dims=1:length(a.size), eps))
end

function Base.show(io::IO, l::LayerNorm)
Expand Down
28 changes: 19 additions & 9 deletions src/layers/stateless.jl
Original file line number Diff line number Diff line change
@@ -1,34 +1,44 @@

"""
normalise(x; dims=ndims(x), ϵ=1e-5)
normalise(x; dims=ndims(x), eps=1e-5)
Normalise `x` to mean 0 and standard deviation 1 across the dimension(s) given by `dims`.
Per default, `dims` is the last dimension.
`ϵ` is a small additive factor added to the denominator for numerical stability.
`eps` is a small term added to the denominator for numerical stability.
# Examples
```jldoctest
julia> using Statistics
julia> x = [9, 10, 20, 60];
julia> x = [90, 100, 110, 130, 70];
julia> y = Flux.normalise(x);
julia> mean(x), std(x; corrected=false)
(100.0, 20.0)
julia> isapprox(std(y), 1, atol=0.2) && std(y) != std(x)
julia> y = Flux.normalise(x)
5-element Vector{Float64}:
-0.49999975000012503
0.0
0.49999975000012503
1.499999250000375
-1.499999250000375
julia> isapprox(std(y; corrected=false), 1, atol=1e-5)
true
julia> x = rand(1:100, 10, 2);
julia> x = rand(10:100, 10, 10);
julia> y = Flux.normalise(x, dims=1);
julia> isapprox(std(y, dims=1), ones(1, 2), atol=0.2) && std(y, dims=1) != std(x, dims=1)
julia> isapprox(std(y; dims=1, corrected=false), ones(1, 10), atol=1e-5)
true
```
"""
@inline function normalise(x::AbstractArray; dims=ndims(x), ϵ=ofeltype(x, 1e-5))
@inline function normalise(x::AbstractArray; dims=ndims(x), eps=ofeltype(x, 1e-5), ϵ=nothing)
ε = _greek_ascii_depwarn=> eps, :InstanceNorm, "ϵ" => "eps")
μ = mean(x, dims=dims)
σ = std(x, dims=dims, mean=μ, corrected=false)
return @. (x - μ) /+ ϵ)
return @. (x - μ) /+ ε)
end

"""
Expand Down

0 comments on commit 7088682

Please sign in to comment.