[Feature] Logit focal loss #2138
base: master
Changes from all commits
@@ -603,14 +603,54 @@ function focal_loss(ŷ, y; dims=1, agg=mean, γ=2, ϵ=epseltype(ŷ))
    agg(sum(@. -y * (1 - ŷ)^γ * log(ŷ); dims=dims))
end

"""
    logit_focal_loss(ŷ, y; dims=1, agg=mean, γ=2, ϵ=epseltype(ŷ))

Return the [focal loss](https://arxiv.org/pdf/1708.02002.pdf)
which can be used in classification tasks with highly imbalanced classes.
It down-weights well-classified examples and focuses on hard examples.
The input, `ŷ`, is expected to be unnormalized (i.e. not [softmax](@ref Softmax) output),
as `logsoftmax` is applied to it internally.

The modulating factor, `γ`, controls the down-weighting strength.
For `γ == 0`, the loss is mathematically equivalent to [`Losses.logitcrossentropy`](@ref).

# Example
```jldoctest
julia> y = [1 0 0 0 1
            0 1 0 1 0
            0 0 1 0 0]
3×5 Matrix{Int64}:
 1  0  0  0  1
 0  1  0  1  0
 0  0  1  0  0

julia> ŷ = reshape(-7:7, 3, 5) .* 1f0
3×5 Matrix{Float32}:
 -7.0  -4.0  -1.0  2.0  5.0
 -6.0  -3.0   0.0  3.0  6.0
 -5.0  -2.0   1.0  4.0  7.0

julia> Flux.logit_focal_loss(ŷ, y) ≈ 1.1277571935622628
true
```

See also: [`Losses.focal_loss`](@ref)
"""
function logit_focal_loss(ŷ, y; γ=2.0f0, agg=mean, dims=1, ϵ=epseltype(ŷ))
    _check_sizes(ŷ, y)
    logpt = logsoftmax(ŷ; dims)
    agg(sum(@. -y * (1 - exp(logpt + ϵ))^γ * (logpt + ϵ); dims))
end

Review comment (on the `logit_focal_loss` signature): Some have crept in & need fixing, but there should not be greek-letter keywords. These can be … Also, as written, …

Reply: Oh nice idea, can do!
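
A minimal sketch of the ASCII-keyword variant the reviewer seems to be asking for; the names `gamma` and `eps` are assumptions, not taken from this PR or from Flux:

```julia
# Hypothetical ASCII-keyword signature: `gamma` and `eps` are illustrative
# names; the greek letters survive only as locals inside the body.
function logit_focal_loss(ŷ, y; gamma=2.0f0, agg=mean, dims=1, eps=epseltype(ŷ))
    γ, ϵ = gamma, eps
    _check_sizes(ŷ, y)
    logpt = logsoftmax(ŷ; dims)
    agg(sum(@. -y * (1 - exp(logpt + ϵ))^γ * (logpt + ϵ); dims))
end
```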
""" | ||
siamese_contrastive_loss(ŷ, y; margin = 1, agg = mean) | ||
|
||
Return the [contrastive loss](http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf) | ||
which can be useful for training Siamese Networks. It is given by | ||
agg(@. (1 - y) * ŷ^2 + y * max(0, margin - ŷ)^2) | ||
|
||
agg(@. (1 - y) * ŷ^2 + y * max(0, margin - ŷ)^2) | ||
|
||
Specify `margin` to set the baseline for distance at which pairs are dissimilar. | ||
|
||
# Example | ||
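
The diff truncates this docstring at its example. As a quick, hand-checked illustration of the formula above (this evaluation is not part of the diff; `siamese_contrastive_loss` is Flux's existing loss):

```julia
using Flux

# ŷ holds predicted pairwise distances; in this formula, y = 1 selects the
# margin term (dissimilar pair) and y = 0 the plain squared-distance term.
ŷ = [0.3, 0.9]
y = [1, 0]

# By hand, with margin = 1:  max(0, 1 - 0.3)^2 = 0.49  and  0.9^2 = 0.81,
# so the mean is 0.65.
Flux.siamese_contrastive_loss(ŷ, y)  # ≈ 0.65
```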
|
Review comment (on the example): This example output doesn't match what's written. More importantly, the example is an opportunity to show exactly how this relates to `focal_loss`, i.e. where the `softmax` goes. And perhaps (if you can think of a compact & clear way) the relation to `crossentropy` (or rather `logitcrossentropy`?) too.

Reply: Yeah i still need to work through these tests, did not realize about the docstring tests until after already putting tests elsewhere :) Can do!
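
Picking up the reviewer's suggestion, a sketch of how the three losses relate; it assumes the `logit_focal_loss` defined in this diff is in scope, while `focal_loss`, `logitcrossentropy`, and `softmax` are existing Flux functions:

```julia
using Flux

x = reshape(-7:7, 3, 5) .* 1f0   # raw logits, as in the docstring example
y = [1 0 0 0 1
     0 1 0 1 0
     0 0 1 0 0]

# logit_focal_loss takes logits while focal_loss takes probabilities,
# so the softmax moves inside the loss function:
logit_focal_loss(x, y) ≈ Flux.focal_loss(softmax(x), y)        # true, up to ϵ

# With γ == 0 the modulating factor (1 - p)^γ is identically 1,
# and the loss collapses to plain logitcrossentropy:
logit_focal_loss(x, y; γ = 0) ≈ Flux.logitcrossentropy(x, y)   # true, up to ϵ
```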