diff --git a/src/losses/functions.jl b/src/losses/functions.jl index ffda2ff99a..457cfae9a4 100644 --- a/src/losses/functions.jl +++ b/src/losses/functions.jl @@ -603,14 +603,54 @@ function focal_loss(ŷ, y; dims=1, agg=mean, γ=2, ϵ=epseltype(ŷ)) agg(sum(@. -y * (1 - ŷ)^γ * log(ŷ); dims=dims)) end +""" + logit_focal_loss(ŷ, y; dims=1, agg=mean, γ=2, ϵ=eps(ŷ)) + +Return the [focal_loss](https://arxiv.org/pdf/1708.02002.pdf) +which can be used in classification tasks with highly imbalanced classes. +It down-weights well-classified examples and focuses on hard examples. +The input, 'ŷ', is expected to be normalized (i.e. [softmax](@ref Softmax) output). + +The modulating factor, `γ`, controls the down-weighting strength. +For `γ == 0`, the loss is mathematically equivalent to [`Losses.crossentropy`](@ref). + +# Example +```jldoctest +julia> y = [1 0 0 0 1 + 0 1 0 1 0 + 0 0 1 0 0] +3×5 Matrix{Int64}: + 1 0 0 0 1 + 0 1 0 1 0 + 0 0 1 0 0 + +julia> ŷ = reshape(-7:7, 3, 5) .* 1f0 +3×5 Matrix{Float32}: + 0.0900306 0.0900306 0.0900306 0.0900306 0.0900306 + 0.244728 0.244728 0.244728 0.244728 0.244728 + 0.665241 0.665241 0.665241 0.665241 0.665241 + +julia> Flux.logit_focal_loss(ŷ, y) ≈ 1.1277571935622628 +true +``` + +See also: [`Losses.focal_loss`](@ref) + +""" +function logit_focal_loss(ŷ, y; γ=2.0f0, agg=mean, dims=1, ϵ=epseltype(ŷ)) + _check_sizes(ŷ, y) + logpt = logsoftmax(ŷ; dims) + agg(sum(@. -y * (1 - exp(logpt + ϵ))^γ * (logpt + ϵ); dims)) +end + """ siamese_contrastive_loss(ŷ, y; margin = 1, agg = mean) - + Return the [contrastive loss](http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf) which can be useful for training Siamese Networks. It is given by - - agg(@. (1 - y) * ŷ^2 + y * max(0, margin - ŷ)^2) - + + agg(@. (1 - y) * ŷ^2 + y * max(0, margin - ŷ)^2) + Specify `margin` to set the baseline for distance at which pairs are dissimilar. # Example diff --git a/test/losses.jl b/test/losses.jl index 2ca697a657..d1db2390eb 100644 --- a/test/losses.jl +++ b/test/losses.jl @@ -14,7 +14,7 @@ const ALL_LOSSES = [Flux.Losses.mse, Flux.Losses.mae, Flux.Losses.msle, Flux.Losses.dice_coeff_loss, Flux.Losses.poisson_loss, Flux.Losses.hinge_loss, Flux.Losses.squared_hinge_loss, - Flux.Losses.binary_focal_loss, Flux.Losses.focal_loss, Flux.Losses.siamese_contrastive_loss] + Flux.Losses.binary_focal_loss, Flux.Losses.focal_loss, Flux.Losses.siamese_contrastive_loss, Flux.Losses.logit_focal_loss] @testset "xlogx & xlogy" begin @@ -210,7 +210,20 @@ end @test Flux.focal_loss(ŷ1, y1) ≈ 0.45990566879720157 @test Flux.focal_loss(ŷ, y; γ=0.0) ≈ Flux.crossentropy(ŷ, y) end - + +@testset "logit_focal_loss" begin + rng = Random.seed!(Random.default_rng(), 5) + y = rand(rng, Float32, 6, 40, 2) + yhat = rand(rng, Float32, 6, 40, 2) + + @test logit_focal_loss(yhat, y; γ=0) ≈ + Flux.Losses.logitcrossentropy(yhat, y) + + + @test logit_focal_loss(yhat, y; γ=2) == + Flux.Losses.focal_loss(Flux.softmax(yhat; dims=1), y; γ=2) +end + @testset "siamese_contrastive_loss" begin y = [1 0 0 0