Use LogExpFunctions for losses #1866

Open · wants to merge 1 commit into base: master
2 changes: 2 additions & 0 deletions Project.toml
@@ -8,6 +8,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
@@ -25,6 +26,7 @@ Adapt = "3.0"
 ArrayInterface = "3.1, 4"
 CUDA = "3"
 Functors = "0.2.1"
+LogExpFunctions = "0.3"
 MacroTools = "0.5"
 NNlib = "0.8"
 NNlibCUDA = "0.2"
1 change: 1 addition & 0 deletions src/losses/Losses.jl
@@ -5,6 +5,7 @@ using Zygote
 using Zygote: @adjoint
 using ..Flux: ofeltype, epseltype
 using CUDA
+using LogExpFunctions: xlogx, xlogy
 using NNlib: logsoftmax, logσ
 import Base.Broadcast: broadcasted
27 changes: 1 addition & 26 deletions src/losses/utils.jl
@@ -1,31 +1,6 @@
-"""
-    xlogx(x)
-
-Return `x * log(x)` for `x ≥ 0`, handling `x == 0` by taking the limit from above, to get zero.
-"""
-function xlogx(x)
-  result = x * log(x)
-  ifelse(iszero(x), zero(result), result)
-end
-
-"""
-    xlogy(x, y)
-
-Return `x * log(y)` for `y > 0`, and zero when `x == 0`.
-"""
-function xlogy(x, y)
-  result = x * log(y)
-  ifelse(iszero(x), zero(result), result)
-end
Comment on lines -16 to -19

Member: One question is whether there are any performance differences, and whether we care. IIRC the replacements use `if`/`else` instead of `ifelse`, but perhaps the compiler sorts it out?

Member (author): I found JuliaStats/LogExpFunctions.jl#26. GPU is the big question mark, but if #1791 is any indication there may not be a difference there either.
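
For reference, the definitions being swapped in look roughly like this (paraphrased from LogExpFunctions 0.3, not verbatim); note the ternary branch where Flux's versions used `ifelse`:

# approximate paraphrase of LogExpFunctions.xlogx / xlogy, for comparison only
function xlogx(x::Number)
    result = x * log(x)
    return iszero(x) ? zero(result) : result
end

function xlogy(x::Number, y::Number)
    result = x * log(y)
    return iszero(x) && !isnan(y) ? zero(result) : result
end

Whether that branch lowers to a `select` the way `ifelse` does is exactly what the benchmarks below probe.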

Member (author): Looks pretty comparable:

using Flux.Losses: xlogx as f_xlogx, xlogy as f_xlogy
using LogExpFunctions: xlogx as l_xlogx, xlogy as l_xlogy
using BenchmarkTools, CUDA

x, y, out = ntuple(_ -> rand(Float32, 100_000), 3);
cx, cy, cout = ntuple(_ -> CUDA.rand(Float32, 100_000), 3);

julia> @btime $out .= f_xlogx.($x);
  580.412 μs (0 allocations: 0 bytes)

julia> @btime $out .= l_xlogx.($x);
  580.883 μs (0 allocations: 0 bytes)

julia> @btime $out .= f_xlogy.($x, $y);
  622.826 μs (0 allocations: 0 bytes)

julia> @btime $out .= l_xlogy.($x, $y);
  657.381 μs (0 allocations: 0 bytes)

julia> @btime CUDA.@sync $cout .= f_xlogx.($cx);
  5.896 μs (7 allocations: 480 bytes)

julia> @btime CUDA.@sync $cout .= l_xlogx.($cx);
  5.832 μs (7 allocations: 480 bytes)

julia> @btime CUDA.@sync $cout .= f_xlogy.($cx, $cy);
  7.555 μs (23 allocations: 1.61 KiB)

julia> @btime CUDA.@sync $cout .= l_xlogy.($cx, $cy);
  7.114 μs (23 allocations: 1.61 KiB)

I did a couple of runs and there was a not insignificant amount of variability, but at least the relative times aren't too far off.
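
(Using `@benchmark` instead of `@btime`, e.g. `@benchmark $out .= l_xlogy.($x, $y)`, would report min, median, and mean times plus a histogram, which helps judge that run-to-run variability.)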

Member: I get similar numbers.


 @adjoint function broadcasted(::typeof(xlogy), x::Zygote.Numeric, y::Zygote.Numeric)
   res = xlogy.(x, y)
   res, Δ -> (nothing, Zygote.unbroadcast(x, xlogy.(Δ, y)), Zygote.unbroadcast(y, Δ .* x ./ y))
 end
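
For intuition, a minimal sketch of what this rule computes, on made-up inputs:

using Zygote
using LogExpFunctions: xlogy

x, y = [0.0, 2.0], [0.5, 3.0]
gx, gy = gradient((x, y) -> sum(xlogy.(x, y)), x, y)
# gx == xlogy.(1, y) == log.(y), and gy == x ./ y == [0.0, 2/3]

Computing the `x`-gradient with `xlogy` itself, rather than `Δ .* log.(y)`, keeps a zero incoming sensitivity zero even at `y == 0`, where `0 * log(0)` would give `NaN`.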

 # This can be made an error in Flux v0.13, for now just a warning
 function _check_sizes(ŷ::AbstractArray, y::AbstractArray)
   for d in 1:max(ndims(ŷ), ndims(y))
     if size(ŷ,d) != size(y,d)
       @warn "Size mismatch in loss function! In future this will be an error. In Flux <= 0.12 broadcasting accepts this, but may not give sensible results" summary(ŷ) summary(y) maxlog=3 _id=hash(size(y))
     end
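
Assuming the public losses call `_check_sizes` first (as this helper suggests), a broadcast-compatible but mismatched pair is enough to trigger the warning; a hypothetical example:

using Flux

ŷ, y = rand(Float32, 3, 1), rand(Float32, 3, 4)  # broadcastable, but dim 2 differs
Flux.Losses.mse(ŷ, y)  # computes, but logs "Size mismatch in loss function!" once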
17 changes: 0 additions & 17 deletions test/losses.jl
@@ -2,7 +2,6 @@ using Test
 using Flux: onehotbatch, σ

 using Flux.Losses: mse, label_smoothing, crossentropy, logitcrossentropy, binarycrossentropy, logitbinarycrossentropy
-using Flux.Losses: xlogx, xlogy

 # group here all losses, used in tests
 const ALL_LOSSES = [Flux.Losses.mse, Flux.Losses.mae, Flux.Losses.msle,
@@ -17,22 +16,6 @@ const ALL_LOSSES = [Flux.Losses.mse, Flux.Losses.mae, Flux.Losses.msle,
                     Flux.Losses.binary_focal_loss, Flux.Losses.focal_loss]


-@testset "xlogx & xlogy" begin
-  @test iszero(xlogx(0))
-  @test isnan(xlogx(NaN))
-  @test xlogx(2) ≈ 2.0 * log(2.0)
-  @inferred xlogx(2)
-  @inferred xlogx(0)
-
-  @test iszero(xlogy(0, 1))
-  @test isnan(xlogy(NaN, 1))
-  @test isnan(xlogy(1, NaN))
-  @test isnan(xlogy(NaN, NaN))
-  @test xlogy(2, 3) ≈ 2.0 * log(3.0)
-  @inferred xlogy(2, 3)
-  @inferred xlogy(0, 1)
-end
Member: Do these tests pass?

Member (author): I tested locally before removing and they do. https://github.com/JuliaStats/LogExpFunctions.jl/blob/master/test/basicfuns.jl also looks like a strict superset of the Flux tests.
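
For example, the removed assertions can be re-run locally against the new dependency as a quick sanity check (assuming LogExpFunctions 0.3):

using Test
using LogExpFunctions: xlogx, xlogy

@testset "xlogx & xlogy via LogExpFunctions" begin
    @test iszero(xlogx(0))
    @test isnan(xlogx(NaN))
    @test xlogx(2) ≈ 2.0 * log(2.0)
    @test iszero(xlogy(0, 1))
    @test isnan(xlogy(1, NaN))
    @test xlogy(2, 3) ≈ 2.0 * log(3.0)
    @inferred xlogy(0, 1)
end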


 # First, regression-style y's
 y = [1, 1, 0, 0]
 ŷ = [.9, .1, .1, .9]