ssim, ssim_loss, ssim_loss_fast #2178
Conversation
Codecov Report: Base 86.02% // Head 83.59% // decreases project coverage by 2.44%.
@@ Coverage Diff @@
##           master   FluxML/Flux.jl#2178     +/-   ##
==========================================
- Coverage    86.02%   83.59%   -2.44%
==========================================
  Files           19       19
  Lines         1460     1493     +33
==========================================
- Hits          1256     1248      -8
- Misses         204      245     +41
test/cuda/losses.jl
Outdated
end
for loss in ALL_LOSSES
gpu_autodiff_test(loss, x, y)
end
please stick to Flux's indentation style in both the old and new code
src/losses/utils.jl
Outdated
return reshape(T.(kernel), size(kernel)..., 1, 1)
end
ssim_kernel(x::Array{T, N}) where {T, N} = ssim_kernel(T, N)
ssim_kernel(x::AnyCuArray{T, N}) where {T, N} = cu(ssim_kernel(T, N))
since these functions are not of general utility, they'd better go close to the ssim losses
Project.toml
Outdated
[targets]
- test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"]
+ test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "Images", "ImageQualityIndexes", "TestImages"]
this may be a mistake on my part, including Images etc. in [targets]
Some tests in the latest commit are marked as broken for two reasons, to do with Float16 (#2184; bug reported in FluxML/NNlib.jl#505).
Other than those issues, the ssim loss seems ready to me. I've made the ssim CUDA test very verbose to make clear exactly what the issues are, but let me know if it's preferred to be a bit more concise.
Approval for the implementation -- quite straightforward to me.
@@ -19,7 +21,8 @@ export mse, mae, msle,
      dice_coeff_loss,
      poisson_loss,
      hinge_loss, squared_hinge_loss,
-     binary_focal_loss, focal_loss, siamese_contrastive_loss
+     binary_focal_loss, focal_loss, siamese_contrastive_loss,
+     ssim, ssim_loss, ssim_loss_fast
Maybe we shouldn't export ssim here if the main goal is to use the ssim loss to train the network? My own experience is that people might end up using only ssim_loss_fast even if you provide multiple choices...
@nikopj can we please get this going? I had been thinking of adding more image quality indexes to JuliaImages's ImageQualityIndexes.jl and siphoning them through to the DL frameworks
The GPU tests on Julia 1.6 failed because some of my […] My SSIM loss had one failure on Julia 1.8 because of the relative tolerance on a Float16 test being too small:

f16 cpu-cpu: Test Failed at /var/lib/buildkite-agent/builds/gpuci-2/julialang/flux-dot-jl/test/cuda/losses.jl:45
  Expression: isapprox(loss(f16(x), f16(y)), Float16(loss(x, y)); rtol = 0.1)
   Evaluated: isapprox(Float16(0.01132), Float16(-0.02692); rtol = 0.1)
I'm less worried about the tolerance, but that discrepancy looks pretty concerning?
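For context on why loosening rtol cannot fix the failure above: a relative tolerance scales the allowed gap by the magnitudes of the two values, so two opposite-sign values of similar size will always fail — the sign flip is a real discrepancy, not a tolerance issue. A quick sketch using Python's `math.isclose`, which behaves like Julia's `isapprox` with `rtol` here (an illustration, not part of the PR):

```python
import math

# The two Float16 results from the failing test above.
a, b = 0.01132, -0.02692

# The allowed gap is rel_tol * max(|a|, |b|); with opposite signs the
# actual gap is |a| + |b|, which a small rel_tol can never cover.
print(math.isclose(a, b, rel_tol=0.1))         # False: sign flip, not tolerance
print(math.isclose(a, a * 1.05, rel_tol=0.1))  # True: same sign, within 10%
```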
@ToucheSir It seems like there's some numerical instability in the Float16 case. Adding some prints into the code (see below), the instability really shows up in the ssim_map:

julia> using Flux, Images, TestImages, ImageView

julia> img_x = testimage("monarch_color_256") .|> Gray .|> Float64 |> unsqueeze(dims=3) |> unsqueeze(dims=4);

julia> img_y = testimage("fabio_color_256") .|> Gray .|> Float64 |> unsqueeze(dims=3) |> unsqueeze(dims=4);

julia> Flux.ssim(Float32.(img_x), Float32.(img_y))
C₁ = 0.0001f0
C₂ = 0.0009f0
extrema(μx) = (0.12141457f0, 0.87607676f0)
extrema(μy) = (0.098093934f0, 0.91049606f0)
extrema(μxy) = (0.019157613f0, 0.71647954f0)
extrema(μx²) = (0.014741498f0, 0.7675105f0)
extrema(μy²) = (0.00962242f0, 0.8290031f0)
extrema(σx²) = (2.9206276f-6, 0.103816226f0)
extrema(σy²) = (3.086403f-6, 0.06908727f0)
extrema(σxy) = (-0.035758138f0, 0.04259558f0)
extrema(abs, σx²) = (2.9206276f-6, 0.103816226f0)
extrema(abs, σy²) = (3.086403f-6, 0.06908727f0)
extrema(abs, σxy) = (0.0f0, 0.04259558f0)
extrema((2μxy .+ C₁) .* (2σxy .+ C₂)) = (-0.040532943f0, 0.043345165f0)
extrema(abs, (2μxy .+ C₁) .* (2σxy .+ C₂)) = (3.2075288f-8, 0.043345165f0)
extrema(((μx² .+ μy²) .+ C₁) .* ((σx² .+ σy²) .+ C₂)) = (4.8610345f-5, 0.093692295f0)
extrema(abs, ((μx² .+ μy²) .+ C₁) .* ((σx² .+ σy²) .+ C₂)) = (4.8610345f-5, 0.093692295f0)
extrema(ssim_map) = (-0.8139806f0, 0.9414297f0)
0.13379417f0

julia> Flux.ssim(Float16.(img_x), Float16.(img_y))
C₁ = Float16(0.0001)
C₂ = Float16(0.0009)
extrema(μx) = (Float16(0.12134), Float16(0.876))
extrema(μy) = (Float16(0.09796), Float16(0.909))
extrema(μxy) = (Float16(0.01909), Float16(0.713))
extrema(μx²) = (Float16(0.014725), Float16(0.7676))
extrema(μy²) = (Float16(0.0096), Float16(0.8267))
extrema(σx²) = (Float16(-0.002197), Float16(0.10645))
extrema(σy²) = (Float16(-0.004883), Float16(0.06934))
extrema(σxy) = (Float16(-0.03528), Float16(0.04224))
extrema(abs, σx²) = (Float16(0.0), Float16(0.10645))
extrema(abs, σy²) = (Float16(0.0), Float16(0.06934))
extrema(abs, σxy) = (Float16(0.0), Float16(0.04224))
extrema((2μxy .+ C₁) .* (2σxy .+ C₂)) = (Float16(-0.03915), Float16(0.04926))
extrema(abs, (2μxy .+ C₁) .* (2σxy .+ C₂)) = (Float16(9.0e-7), Float16(0.04926))
extrema(((μx² .+ μy²) .+ C₁) .* ((σx² .+ σy²) .+ C₂)) = (Float16(-0.001884), Float16(0.09894))
extrema(abs, ((μx² .+ μy²) .+ C₁) .* ((σx² .+ σy²) .+ C₂)) = (Float16(2.91e-5), Float16(0.09894))
extrema(ssim_map) = (Float16(-12.11), Float16(6.137))
Float16(0.1942)

Plotting the ssim_map for each precision:

julia> using Plots

julia> p1 = heatmap(map32[:,:,1,1]; cbar=true, title="ssim_map Float32");

julia> p2 = heatmap(map16[:,:,1,1]; cbar=true, title="ssim_map Float16");

julia> p3 = heatmap(clamp.(map16[:,:,1,1], -1, 1); cbar=true, title="clamp(ssim_map, 0, 1) Float16");

julia> plot(p1, p2, p3; layout=(1,3), size=(1200, 400))
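The negative extrema of σx² and σy² in the Float16 run above are the telltale sign of catastrophic cancellation: if the local variance is formed as μ(x²) − μx² (an assumption inferred from the printed quantities, not confirmed against the PR's source), two nearly equal half-precision numbers are subtracted and rounding error dominates, so the "variance" can come out far too large, zero, or even negative. A minimal NumPy sketch of the failure mode (Python for illustration; the PR itself is Julia):

```python
import numpy as np

# Four Float16-exact samples around 102.5; the exact variance is 1.25.
x = np.array([101.0, 102.0, 103.0, 104.0], dtype=np.float16)

def f16_mean(a):
    """Sequential Float16 accumulation, mimicking low-precision arithmetic."""
    s = np.float16(0.0)
    for v in a:
        s = np.float16(s + v)
    return np.float16(s / np.float16(len(a)))

mu  = f16_mean(x)          # 102.5, exactly representable in Float16
mu2 = f16_mean(x * x)      # E[x^2]; each square is already Float16-rounded

# Naive E[x^2] - E[x]^2: two nearly equal ~1e4 values cancel, and the
# Float16 rounding error (ulp = 8 at this magnitude) swamps the result.
var_naive = np.float16(mu2 - mu * mu)     # 8.0 -- 6.4x the true 1.25

# Shifted E[(x - E[x])^2]: deviations are small, so no cancellation occurs.
var_shifted = f16_mean((x - mu) ** 2)     # 1.25 exactly

print(var_naive, var_shifted)
```

The shifted form is non-negative by construction, which would also rule out the negative σ² values seen above.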
I don't think this matters for DL tasks. The SSIM implementation in ImageQualityIndexes is carefully written to keep numerical compatibility with widely-used ones. But for DL tasks what matters is the trend -- IMO even choosing a slightly different C₁ or C₂ would be okay.
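For reference, the quantities printed in the debug output combine into the standard SSIM formula, with C₁ = (0.01·L)² and C₂ = (0.03·L)² (matching the C₁ = 0.0001, C₂ = 0.0009 printed above for L = 1). A single-window NumPy sketch, computing one set of global statistics instead of the PR's sliding Gaussian window (`ssim_global` is a hypothetical helper for illustration, not the PR's `ssim`):

```python
import numpy as np

def ssim_global(x, y, K1=0.01, K2=0.03, L=1.0):
    """Single-window SSIM: the standard formula evaluated once on global
    image statistics (a sliding-window version computes these per pixel)."""
    C1, C2 = (K1 * L) ** 2, (K2 * L) ** 2   # C1 = 1e-4, C2 = 9e-4 for L = 1
    mx, my = x.mean(), y.mean()
    vx = ((x - mx) ** 2).mean()             # shifted moments avoid the
    vy = ((y - my) ** 2).mean()             # E[x^2] - E[x]^2 cancellation
    cxy = ((x - mx) * (y - my)).mean()
    return ((2 * mx * my + C1) * (2 * cxy + C2)) / (
           (mx ** 2 + my ** 2 + C1) * (vx + vy + C2))

rng = np.random.default_rng(0)
x = rng.random((64, 64))
y = rng.random((64, 64))
print(ssim_global(x, x))   # 1.0 for identical images
print(ssim_global(x, y))   # near 0 for unrelated noise
```

Changing C₁ or C₂ slightly only perturbs the stabilizing offsets in numerator and denominator, which is why the trend (and hence the loss gradient) is largely unaffected.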
I've moved this to its own package SSIMLoss.jl so it can be used with other frameworks more easily. Thanks @ToucheSir @CarloLucibello @johnnychen94 for the help and feedback.
Implements structural similarity index measurement (SSIM) loss, addressing #2165.

The main function is ssim, which is used by ssim_loss and ssim_loss_fast. The fast version uses a smaller averaging kernel instead of the standard 11x11 Gaussian kernel. 1D/2D/3D inputs (3D/4D/5D channel+batch tensors) and GPU are accounted for. The corresponding kernel is by default a 1D/2D/3D Gaussian kernel (see ssim_kernel). ssim is tested against ImageQualityIndexes's assess_ssim for grayscale and RGB images.

This doesn't have the full IQI capabilities (such as the alpha, beta, gamma params), but it seems those are kept at their default of 1 in other libraries that implement an SSIM loss (e.g. https://github.com/VainF/pytorch-msssim). I believe we would need them if we extended the implementation to an MS-SSIM loss.
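The standard 11x11 Gaussian window mentioned above uses σ = 1.5 in the original SSIM paper (Wang et al., 2004). A NumPy sketch of building such a window (illustrative only; not the PR's ssim_kernel implementation):

```python
import numpy as np

def gaussian_kernel_2d(size=11, sigma=1.5):
    """Separable 2D Gaussian window, normalized to sum to 1.
    size=11 and sigma=1.5 are the values from the original SSIM paper."""
    r = np.arange(size) - (size - 1) / 2      # offsets from the window center
    g = np.exp(-(r ** 2) / (2 * sigma ** 2))  # 1D Gaussian profile
    g /= g.sum()                              # normalize the 1D profile
    return np.outer(g, g)                     # outer product -> 2D window

k = gaussian_kernel_2d()
print(k.shape)   # (11, 11)
```

The 1D and 3D variants follow the same pattern with the outer product taken over one or three copies of the profile.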
Feedback is greatly appreciated.
PR Checklist