Skip to content

Commit

Permalink
Merge pull request #244 from FluxML/logsumexp
Browse files Browse the repository at this point in the history
add logsumexp
  • Loading branch information
CarloLucibello authored Dec 15, 2020
2 parents d130ba8 + 94737bc commit c2a2c84
Show file tree
Hide file tree
Showing 5 changed files with 295 additions and 256 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ Requires = "0.5, 1.0"
julia = "1.3"

[extras]
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "Zygote"]
test = ["Test", "Statistics", "Zygote"]
17 changes: 16 additions & 1 deletion src/softmax.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
export softmax, softmax!, ∇softmax, ∇softmax!,
logsoftmax, logsoftmax!, ∇logsoftmax, ∇logsoftmax!
logsoftmax, logsoftmax!, ∇logsoftmax, ∇logsoftmax!,
logsumexp

"""
softmax(x; dims=1)
Expand Down Expand Up @@ -114,3 +115,17 @@ end

∇logsoftmax(Δ, xs; dims=1) = Δ .- sum(Δ, dims=dims) .* softmax(xs, dims=dims)
∇logsoftmax!(Δ, xs) = ∇logsoftmax!(Δ, Δ, xs)

"""
logsumexp(x; dims=:)
Computes `log.(sum(exp.(x); dims=dims))` in a numerically stable
way.
See also [`logsoftmax`](@ref).
"""
function logsumexp(xs::AbstractArray; dims=:)
max_ = maximum(xs, dims=dims)
log_ = log.(sum(exp.(xs .- max_), dims=dims))
return max_ .+ log_
end
Loading

0 comments on commit c2a2c84

Please sign in to comment.