diff --git a/Project.toml b/Project.toml
index 3eff2752eb..c6e679603a 100644
--- a/Project.toml
+++ b/Project.toml
@@ -8,6 +8,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
@@ -25,6 +26,7 @@ Adapt = "3.0"
 ArrayInterface = "3.1, 4"
 CUDA = "3"
 Functors = "0.2.1"
+LogExpFunctions = "0.3"
 MacroTools = "0.5"
 NNlib = "0.8"
 NNlibCUDA = "0.2"
diff --git a/src/losses/Losses.jl b/src/losses/Losses.jl
index 413c4ee034..332a113b49 100644
--- a/src/losses/Losses.jl
+++ b/src/losses/Losses.jl
@@ -5,6 +5,7 @@ using Zygote
 using Zygote: @adjoint
 using ..Flux: ofeltype, epseltype
 using CUDA
+using LogExpFunctions: xlogx, xlogy
 using NNlib: logsoftmax, logσ
 import Base.Broadcast: broadcasted
 
diff --git a/src/losses/utils.jl b/src/losses/utils.jl
index 386cd67166..e1cf9f7826 100644
--- a/src/losses/utils.jl
+++ b/src/losses/utils.jl
@@ -1,31 +1,6 @@
-"""
-    xlogx(x)
-
-Return `x * log(x)` for `x ≥ 0`, handling `x == 0` by taking the limit from above, to get zero.
-"""
-function xlogx(x)
-  result = x * log(x)
-  ifelse(iszero(x), zero(result), result)
-end
-
-"""
-    xlogy(x, y)
-
-Return `x * log(y)` for `y > 0`, and zero when `x == 0`. 
-"""
-function xlogy(x, y)
-  result = x * log(y)
-  ifelse(iszero(x), zero(result), result)
-end
-
-@adjoint function broadcasted(::typeof(xlogy), x::Zygote.Numeric, y::Zygote.Numeric)
-  res = xlogy.(x, y)
-  res, Δ -> (nothing, Zygote.unbroadcast(x, xlogy.(Δ, y)), Zygote.unbroadcast(y, Δ .* x ./ y))
-end
-
 # This can be made an error in Flux v0.13, for now just a warning
 function _check_sizes(ŷ::AbstractArray, y::AbstractArray)
-  for d in 1:max(ndims(ŷ), ndims(y)) 
+  for d in 1:max(ndims(ŷ), ndims(y))
     if size(ŷ,d) != size(y,d)
       @warn "Size mismatch in loss function! In future this will be an error. In Flux <= 0.12 broadcasting accepts this, but may not give sensible results" summary(ŷ) summary(y) maxlog=3 _id=hash(size(y))
     end
diff --git a/test/losses.jl b/test/losses.jl
index f95b9cefd2..c2cf4899d9 100644
--- a/test/losses.jl
+++ b/test/losses.jl
@@ -2,7 +2,6 @@ using Test
 using Flux: onehotbatch, σ
 using Flux.Losses: mse, label_smoothing, crossentropy, logitcrossentropy,
                    binarycrossentropy, logitbinarycrossentropy
-using Flux.Losses: xlogx, xlogy
 
 # group here all losses, used in tests
 const ALL_LOSSES = [Flux.Losses.mse, Flux.Losses.mae, Flux.Losses.msle,
@@ -17,22 +16,6 @@
                     Flux.Losses.binary_focal_loss, Flux.Losses.focal_loss]
 
-@testset "xlogx & xlogy" begin
-  @test iszero(xlogx(0))
-  @test isnan(xlogx(NaN))
-  @test xlogx(2) ≈ 2.0 * log(2.0)
-  @inferred xlogx(2)
-  @inferred xlogx(0)
-
-  @test iszero(xlogy(0, 1))
-  @test isnan(xlogy(NaN, 1))
-  @test isnan(xlogy(1, NaN))
-  @test isnan(xlogy(NaN, NaN))
-  @test xlogy(2, 3) ≈ 2.0 * log(3.0)
-  @inferred xlogy(2, 3)
-  @inferred xlogy(0, 1)
-end
-
 # First, regression-style y's
 y = [1, 1, 0, 0]
 ŷ = [.9, .1, .1, .9]
 