ssim, ssim_loss, ssim_loss_fast #2178
The first changed file (the Losses module) adds the new dependencies and exports:

```diff
@@ -6,7 +6,9 @@ using Zygote: @adjoint
 using ChainRulesCore
 using ..Flux: ofeltype, epseltype
 using CUDA
-using NNlib: logsoftmax, logσ, ctc_loss, ctc_alpha, ∇ctc_loss
+using Adapt
+using MLUtils: ones_like
+using NNlib: logsoftmax, logσ, ctc_loss, ctc_alpha, ∇ctc_loss, conv, pad_symmetric
 import Base.Broadcast: broadcasted

 export mse, mae, msle,
@@ -19,7 +21,8 @@ export mse, mae, msle,
     dice_coeff_loss,
     poisson_loss,
     hinge_loss, squared_hinge_loss,
-    binary_focal_loss, focal_loss, siamese_contrastive_loss
+    binary_focal_loss, focal_loss, siamese_contrastive_loss,
+    ssim, ssim_loss, ssim_loss_fast

 include("utils.jl")
 include("functions.jl")
```

Review comment on the new `ssim, ssim_loss, ssim_loss_fast` export line: "Maybe we shouldn't export … My own experience here is that people might end up using only …"
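On the export question raised above: either way the new functions stay reachable, so the choice is mostly about ergonomics. A small usage sketch, assuming the names proposed in this PR (the `Flux.ssim` form is the one the GPU tests further down use):

```julia
using Flux

x = rand(Float32, 16, 16, 1, 2)   # a 4D (width, height, channels, batch) array
y = rand(Float32, 16, 16, 1, 2)

# With the export added in this hunk, `using Flux.Losses` brings the names into scope:
using Flux.Losses
ssim_loss(x, y)

# Without an export they would still be usable, fully qualified or imported explicitly:
Flux.Losses.ssim_loss_fast(x, y)
Flux.ssim(x, y)                   # the form used in the GPU tests below
```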
The second changed file adds the SSIM kernel helpers next to the existing `_check_sizes` utilities:
```diff
@@ -36,3 +36,39 @@ end
 _check_sizes(ŷ, y) = nothing # pass-through, for constant label e.g. y = 1

 ChainRulesCore.@non_differentiable _check_sizes(ŷ::Any, y::Any)
+
+# Gaussian kernel std=1.5, length=11
+const SSIM_KERNEL =
+    [0.00102838008447911,
+     0.007598758135239185,
+     0.03600077212843083,
+     0.10936068950970002,
+     0.2130055377112537,
+     0.26601172486179436,
+     0.2130055377112537,
+     0.10936068950970002,
+     0.03600077212843083,
+     0.007598758135239185,
+     0.00102838008447911]
+
+"""
+    ssim_kernel(T, N)
+
+Return Gaussian kernel with σ=1.5 and side-length 11 for use in [`ssim`](@ref).
+"""
+function ssim_kernel(T::Type, N::Integer)
+    if N-2 == 1
+        kernel = SSIM_KERNEL
+    elseif N-2 == 2
+        kernel = SSIM_KERNEL*SSIM_KERNEL'
+    elseif N-2 == 3
+        ks = length(SSIM_KERNEL)
+        kernel = reshape(SSIM_KERNEL*SSIM_KERNEL', 1, ks, ks).*SSIM_KERNEL
+    else
+        throw("SSIM is only implemented for 3D/4D/5D inputs, dimension=$N provided.")
+    end
+    return reshape(T.(kernel), size(kernel)..., 1, 1)
+end
+ssim_kernel(x::Array{T, N}) where {T, N} = ssim_kernel(T, N)
+ssim_kernel(x::AnyCuArray{T, N}) where {T, N} = cu(ssim_kernel(T, N))
```
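As an aside on where the hard-coded constants come from: they match the usual 11-tap Gaussian window with σ = 1.5, normalised to sum to one, so they can be regenerated rather than copied by hand. A minimal sketch (the `gaussian_window` name is mine, not something this PR defines):

```julia
# Regenerate the SSIM_KERNEL constants: an 11-tap Gaussian with σ = 1.5, normalised to sum to 1.
# `gaussian_window` is only illustrative; it is not a function added by this PR.
function gaussian_window(len::Integer = 11, σ::Real = 1.5)
    radius = (len - 1) ÷ 2
    w = [exp(-i^2 / (2σ^2)) for i in -radius:radius]
    return w ./ sum(w)
end

gaussian_window()[6]   # ≈ 0.26601172486179436, the centre tap of SSIM_KERNEL above
```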
Review comment on the `ssim_kernel` helpers: "since these functions are not of general utility they'd better go close to the ssim losses"
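For context on how a kernel like this is used: the sketch below is only the textbook SSIM formulation for 4D (width, height, channels, batch) arrays, written with the same NNlib pieces this diff imports (`conv`, `pad_symmetric`). It is not the PR's implementation; the `ssim_sketch` name, the `[0, 1]` dynamic-range assumption behind `c1`/`c2`, and the channel-folding trick are illustrative choices.

```julia
using NNlib: conv, pad_symmetric
using Statistics: mean

# Textbook SSIM sketch for (width, height, channels, batch) arrays with values in [0, 1].
# Not the implementation added by this PR; only an illustration of the formula.
function ssim_sketch(x::AbstractArray{T,4}, y::AbstractArray{T,4}) where T
    # 11×11 Gaussian window with σ = 1.5, matching SSIM_KERNEL above
    g = [exp(-i^2 / (2 * T(1.5)^2)) for i in -5:5]
    g ./= sum(g)
    kernel = reshape(g * g', 11, 11, 1, 1)

    # Fold channels into the batch dimension so the single-channel kernel applies per channel
    W, H, C, N = size(x)
    xr = reshape(x, W, H, 1, C * N)
    yr = reshape(y, W, H, 1, C * N)

    pad = (size(kernel, 1) - 1) ÷ 2
    μ(z) = conv(pad_symmetric(z, pad; dims = (1, 2)), kernel)   # local Gaussian-weighted means

    μx, μy = μ(xr), μ(yr)
    varx  = μ(xr .^ 2) .- μx .^ 2        # local variances
    vary  = μ(yr .^ 2) .- μy .^ 2
    covxy = μ(xr .* yr) .- μx .* μy      # local covariance

    c1, c2 = T(0.01)^2, T(0.03)^2        # stabilisers, assuming dynamic range L = 1
    ssim_map = @. ((2μx * μy + c1) * (2covxy + c2)) /
                  ((μx^2 + μy^2 + c1) * (varx + vary + c2))
    return mean(ssim_map)
end

x = rand(Float32, 16, 16, 3, 2)
ssim_sketch(x, x)   # ≈ 1, since covariance equals variance for identical inputs
```

That identity is exactly what the GPU tests below assert for `Flux.ssim`.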
The last changed file extends the GPU tests for the losses:
```diff
@@ -1,6 +1,5 @@
 using Flux.Losses: crossentropy, binarycrossentropy, logitbinarycrossentropy, binary_focal_loss, focal_loss
-

 @testset "Losses" begin

 x = [1.,2.,3.]
@@ -26,13 +25,32 @@ y = [1 0 0 0 1
      0 0 1 0 0]
 @test focal_loss(x, y) ≈ focal_loss(gpu(x), gpu(y))

+@testset "GPU ssim tests" begin
+  for N=1:3
+    @testset "num_dims=$N" begin
+      x = rand(Float32, 16*ones(Int, N)..., 2, 2)
+      y = rand(Float32, 16*ones(Int, N)..., 2, 2)
+
+      for f in (Flux.ssim, Flux.ssim_loss, Flux.ssim_loss_fast)
+        @test f(x, y) ≈ f(gpu(x), gpu(y))
+        gpu_autodiff_test(f, x, y)
+      end
+
+      x = gpu(x)
+      @test Flux.ssim(x, x) ≈ 1
+      @test Flux.ssim_loss(x, x) ≈ 0
+      @test Flux.ssim_loss_fast(x, x) ≈ 0
+    end
+  end
+end
+
 @testset "GPU grad tests" begin
-  x = rand(Float32, 3,3)
-  y = rand(Float32, 3,3)
+    x = rand(Float32, 3,3)
+    y = rand(Float32, 3,3)

-  for loss in ALL_LOSSES
-    gpu_autodiff_test(loss, x, y)
-  end
+    for loss in ALL_LOSSES
+        gpu_autodiff_test(loss, x, y)
+    end
 end

 end #testset
```

Review comment on the whitespace-only changes in the "GPU grad tests" block: "please stick to flux's indentation style in old and new"
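To try the new functions without a GPU, the same self-similarity properties checked above should hold on the CPU as well, and they are consistent with a `1 - ssim`-style loss (though the exact definitions are not shown in this diff). A small sketch assuming the API proposed here:

```julia
using Flux, Test

x = rand(Float32, 16, 16, 3, 2)     # a small WHCN image batch

@test Flux.ssim(x, x) ≈ 1           # identical inputs → perfect structural similarity
@test Flux.ssim_loss(x, x) ≈ 0      # and zero loss
@test Flux.ssim_loss_fast(x, x) ≈ 0
```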
Review comment: "this may be a mistake on my part, including Images etc. in [targets]"