Use LogExpFunctions for losses #1866
@@ -2,7 +2,6 @@ using Test
 using Flux: onehotbatch, σ

 using Flux.Losses: mse, label_smoothing, crossentropy, logitcrossentropy, binarycrossentropy, logitbinarycrossentropy
-using Flux.Losses: xlogx, xlogy

 # group here all losses, used in tests
 const ALL_LOSSES = [Flux.Losses.mse, Flux.Losses.mae, Flux.Losses.msle,
@@ -17,22 +16,6 @@ const ALL_LOSSES = [Flux.Losses.mse, Flux.Losses.mae, Flux.Losses.msle,
                     Flux.Losses.binary_focal_loss, Flux.Losses.focal_loss]


-@testset "xlogx & xlogy" begin
-  @test iszero(xlogx(0))
-  @test isnan(xlogx(NaN))
-  @test xlogx(2) ≈ 2.0 * log(2.0)
-  @inferred xlogx(2)
-  @inferred xlogx(0)
-
-  @test iszero(xlogy(0, 1))
-  @test isnan(xlogy(NaN, 1))
-  @test isnan(xlogy(1, NaN))
-  @test isnan(xlogy(NaN, NaN))
-  @test xlogy(2, 3) ≈ 2.0 * log(3.0)
-  @inferred xlogy(2, 3)
-  @inferred xlogy(0, 1)
-end
Do these tests pass?

I tested locally before removing them, and they do. https://github.com/JuliaStats/LogExpFunctions.jl/blob/master/test/basicfuns.jl also looks like a strict superset of the Flux tests.

 # First, regression-style y's
 y = [1, 1, 0, 0]
 ŷ = [.9, .1, .1, .9]

One question is whether there are any performance differences, and whether we care. IIRC the replacements use `if`/`else` branches instead of `ifelse`, but perhaps the compiler sorts it out?
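To make the distinction concrete, here is a minimal sketch of the two styles being compared. The names `xlogx_ifelse` and `xlogx_branch` are hypothetical and used only for illustration; neither body is claimed to be the exact source of Flux or LogExpFunctions.

```julia
# Hypothetical illustration of the two styles under discussion.

# Select with `ifelse`: an ordinary function call, so both candidate
# values are evaluated before one of them is chosen.
function xlogx_ifelse(x::Number)
    result = x * log(x)
    return ifelse(iszero(x), zero(result), result)
end

# Select with a conditional expression (`if`/`else` or the ternary form).
function xlogx_branch(x::Number)
    result = x * log(x)
    return iszero(x) ? zero(result) : result
end

# Both return 0 at x == 0 (instead of NaN from 0 * -Inf) and agree elsewhere.
xs = [0.0, 0.5, 1.0, 2.0]
@assert xlogx_ifelse.(xs) == xlogx_branch.(xs)
```

Since `result` is computed before the selection in both versions, the only difference is how the final value is picked, which is why the compiler may well generate equivalent code. Timing each version over a large array (for example with `@btime` from BenchmarkTools.jl) is one way to check this on a given machine.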
I found JuliaStats/LogExpFunctions.jl#26. GPU is the big question mark, but if #1791 is any indication there may not be a difference there either.
Looks pretty comparable:
I did a couple of runs and there was a not insignificant amount of variability, but at least the relative times aren't too far off.
I get similar numbers.