
ssim, ssim_loss, ssim_loss_fast #2178

Closed · wants to merge 5 commits

Conversation

@nikopj commented Feb 5, 2023

Implements structural similarity index measure (SSIM) loss, addressing #2165.
The main function is ssim, which is used by ssim_loss and ssim_loss_fast. The fast version uses a smaller averaging kernel instead of the standard 11×11 Gaussian kernel.

1D/2D/3D inputs (3D/4D/5D channel+batch tensors) and GPU arrays are accounted for. The kernel defaults to a matching 1D/2D/3D Gaussian kernel (see ssim_kernel).
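For concreteness, a minimal usage sketch of the proposed API (array sizes are illustrative; Flux's WHCN layout assumed):

```julia
using Flux

# Two batches of 8 RGB images in WHCN layout (width, height, channel, batch).
x = rand(Float32, 64, 64, 3, 8)
y = rand(Float32, 64, 64, 3, 8)

Flux.ssim(x, y)           # similarity score; 1 when x == y
Flux.ssim_loss(x, y)      # 1 - ssim(x, y), suitable for minimization
Flux.ssim_loss_fast(x, y) # same, but with the smaller averaging kernel
```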

ssim is tested against ImageQualityIndexes' assess_ssim for grayscale and RGB images.

This doesn't cover the full IQI capabilities (such as the alpha, beta, gamma params), but those seem to be kept at their default of 1 in other libraries that implement an SSIM loss (e.g. https://github.com/VainF/pytorch-msssim). I believe we would only need them if we extended the implementation to an MS-SSIM loss.

Feedback is greatly appreciated.

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable

@codecov-commenter commented Feb 5, 2023

Codecov Report

Base: 86.02% // Head: 83.59% // Decreases project coverage by 2.44% ⚠️

Coverage data is based on head (0a1ea93) compared to base (c5a691a).
Patch coverage: 94.44% of modified lines in pull request are covered.


Additional details and impacted files

```
@@            Coverage Diff             @@
##           master    FluxML/Flux.jl#2178      +/-   ##
==========================================
- Coverage   86.02%   83.59%   -2.44%
==========================================
  Files          19       19
  Lines        1460     1493      +33
==========================================
- Hits         1256     1248       -8
- Misses        204      245      +41
```

| Impacted Files | Coverage Δ |
| --- | --- |
| src/losses/utils.jl | 88.88% <83.33%> (-4.45%) ⬇️ |
| src/losses/functions.jl | 98.97% <100.00%> (+0.33%) ⬆️ |
| src/cuda/cudnn.jl | 0.00% <0.00%> (-90.91%) ⬇️ |
| src/functor.jl | 45.90% <0.00%> (-44.58%) ⬇️ |
| src/layers/normalise.jl | 86.71% <0.00%> (-1.40%) ⬇️ |
| src/layers/conv.jl | 87.93% <0.00%> (-0.07%) ⬇️ |


```julia
end
for loss in ALL_LOSSES
  gpu_autodiff_test(loss, x, y)
end
```
Member: please stick to Flux's indentation style, in both the old and the new code.

```julia
    return reshape(T.(kernel), size(kernel)..., 1, 1)
end
ssim_kernel(x::Array{T, N}) where {T, N} = ssim_kernel(T, N)
ssim_kernel(x::AnyCuArray{T, N}) where {T, N} = cu(ssim_kernel(T, N))
```
Member: since these functions are not of general utility, they'd be better placed next to the SSIM losses.
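For context, the window these helpers reshape is presumably the standard SSIM one: an 11-tap Gaussian with σ = 1.5 (Wang et al., 2004), matching the "standard 11x11 gaussian kernel" mentioned in the PR description. A minimal sketch of that construction, where gaussian_window is a hypothetical stand-in rather than the PR's code:

```julia
# Hypothetical sketch: the standard 1D SSIM window, an 11-tap Gaussian with
# σ = 1.5, normalized to sum to 1. N-D kernels are outer products of this.
function gaussian_window(T::Type{<:Real}, n::Integer = 11, σ::Real = 1.5)
    r = (n - 1) / 2
    w = T.(exp.(-((-r:r) .^ 2) ./ (2σ^2)))
    w ./ sum(w)   # e.g. gaussian_window(Float32) gives an 11-element Vector{Float32}
end
```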

Project.toml (Outdated)

```diff
 [targets]
-test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"]
+test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "Images", "ImageQualityIndexes", "TestImages"]
```
Author: this may be a mistake on my part, including Images etc. in [targets].

@nikopj commented Feb 16, 2023

Some tests in the latest commit are marked as broken for two reasons, both to do with Float16 (#2184; bug reported in FluxML/NNlib.jl#505):

  • Float16 conv is broken for 5D tensors
  • the Float16 conv gradient is broken for 3D/4D/5D tensors, which is also observed in test/cuda/layers.jl

Other than those issues, the SSIM loss seems ready to me. I've made the ssim CUDA test very verbose to make clear exactly what the issues are, but let me know if it's preferred to be more concise.

@johnnychen94 (Contributor) left a comment

Approval for the implementation -- quite straightforward to me.

```diff
@@ -19,7 +21,8 @@ export mse, mae, msle,
         dice_coeff_loss,
         poisson_loss,
         hinge_loss, squared_hinge_loss,
-        binary_focal_loss, focal_loss, siamese_contrastive_loss
+        binary_focal_loss, focal_loss, siamese_contrastive_loss,
+        ssim, ssim_loss, ssim_loss_fast
```
@johnnychen94 commented Feb 22, 2023

Maybe we shouldn't export ssim here if the main goal is to use the SSIM loss to train networks?

My own experience here is that people might end up using only ssim_loss_fast even if you provide multiple choices...

@ashwanirathee commented May 3, 2023

@nikopj can we please get this going? I had been thinking of adding more image quality indexes to JuliaImages' ImageQualityIndexes.jl and siphoning those through to the DL frameworks.

@nikopj commented May 5, 2023

The GPU tests on Julia 1.6 failed because some of my @test calls used the kwarg skip=cond, as some Float16 convolution functionality is still broken per FluxML/NNlib.jl#505.

My SSIM loss had one failure on Julia 1.8 because the relative tolerance on a Float16 test was too small:

```
f16 cpu-cpu: Test Failed at /var/lib/buildkite-agent/builds/gpuci-2/julialang/flux-dot-jl/test/cuda/losses.jl:45
  Expression: isapprox(loss(f16(x), f16(y)), Float16(loss(x, y)); rtol = 0.1)
  Evaluated: isapprox(Float16(0.01132), Float16(-0.02692); rtol = 0.1)
```
  • Do we avoid using @test kwargs for 1.6 compatibility? Should I remove them?
  • Is this rtol test sensible? I took it from existing Float16 tests... though 0.1 seems a bit arbitrary. For reference, the numbers above are small to begin with, as they are the SSIM of two random images. The test only fails for ssim and not for ssim_loss or ssim_loss_fast, because ssim_loss is 1 - ssim, and after that shift the relative tolerance is satisfied (illustrated below). I could perhaps change the ssim case to check Float16 agreement on a known test image instead of random images.
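To make the rtol point concrete, using the two values from the failure above:

```julia
julia> isapprox(Float16(0.01132), Float16(-0.02692); rtol = 0.1)
false  # the gap ≈ 0.038 exceeds rtol * max(|a|, |b|) ≈ 0.0027

julia> isapprox(1 - Float16(0.01132), 1 - Float16(-0.02692); rtol = 0.1)
true   # after the 1 - ssim shift, rtol * max ≈ 0.103 covers the same 0.038 gap
```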

@ToucheSir (Member)

  • Do we avoid using @test kwargs for 1.6 compatibility? Should I remove them?

atol and rtol work in 1.6. What doesn't, IIRC, is using e.g. @test broken=true instead of @test_broken.
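For example (a sketch; f16_conv_ok and cond are placeholders, not code from this PR):

```julia
using Test

f16_conv_ok() = false  # placeholder for the currently-broken Float16 conv check
cond = true            # placeholder gate condition

@test_broken f16_conv_ok()        # works on Julia 1.6
@test f16_conv_ok() broken = true # the broken kwarg needs Julia ≥ 1.7
@test f16_conv_ok() skip = cond   # the skip kwarg also needs Julia ≥ 1.7
```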

  • Is this rtol test sensible? ...

I'm less worried about the tolerance, but that discrepancy looks pretty concerning?

@nikopj commented May 5, 2023

@ToucheSir It seems like there's some numerical instability in the Float16 case. Adding some prints into the code (see below), the instability really shows up in the ssim_map calculation of ssim. My guess is that the C₁ and C₂ constants were not tuned for Float16 arithmetic. Any thoughts @johnnychen94?
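For reference, the standard SSIM map (Wang et al., 2004) behind these quantities is

ssim_map = ((2μx·μy + C₁)(2σxy + C₂)) / ((μx² + μy² + C₁)(σx² + σy² + C₂))

with C₁ = (0.01L)² and C₂ = (0.03L)² for dynamic range L; with L = 1 here, that gives exactly the C₁ = 1e-4 and C₂ = 9e-4 printed below. Note that in the Float16 run the local variances σx² and σy² already come out negative at their minima, which never happens in Float32, so precision is lost before the division even happens.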

```julia
julia> using Flux, Images, TestImages, ImageView

julia> img_x = testimage("monarch_color_256") .|> Gray .|> Float64 |> unsqueeze(dims=3) |> unsqueeze(dims=4);

julia> img_y = testimage("fabio_color_256") .|> Gray .|> Float64 |> unsqueeze(dims=3) |> unsqueeze(dims=4);

julia> Flux.ssim(Float32.(img_x), Float32.(img_y))
C₁ = 0.0001f0
C₂ = 0.0009f0
extrema(μx) = (0.12141457f0, 0.87607676f0)
extrema(μy) = (0.098093934f0, 0.91049606f0)
extrema(μxy) = (0.019157613f0, 0.71647954f0)
extrema(μx²) = (0.014741498f0, 0.7675105f0)
extrema(μy²) = (0.00962242f0, 0.8290031f0)
extrema(σx²) = (2.9206276f-6, 0.103816226f0)
extrema(σy²) = (3.086403f-6, 0.06908727f0)
extrema(σxy) = (-0.035758138f0, 0.04259558f0)
extrema(abs, σx²) = (2.9206276f-6, 0.103816226f0)
extrema(abs, σy²) = (3.086403f-6, 0.06908727f0)
extrema(abs, σxy) = (0.0f0, 0.04259558f0)
extrema((2μxy .+ C₁) .* (2σxy .+ C₂)) = (-0.040532943f0, 0.043345165f0)
extrema(abs, (2μxy .+ C₁) .* (2σxy .+ C₂)) = (3.2075288f-8, 0.043345165f0)
extrema(((μx² .+ μy²) .+ C₁) .* ((σx² .+ σy²) .+ C₂)) = (4.8610345f-5, 0.093692295f0)
extrema(abs, ((μx² .+ μy²) .+ C₁) .* ((σx² .+ σy²) .+ C₂)) = (4.8610345f-5, 0.093692295f0)
extrema(ssim_map) = (-0.8139806f0, 0.9414297f0)
0.13379417f0

julia> Flux.ssim(Float16.(img_x), Float16.(img_y))
C₁ = Float16(0.0001)
C₂ = Float16(0.0009)
extrema(μx) = (Float16(0.12134), Float16(0.876))
extrema(μy) = (Float16(0.09796), Float16(0.909))
extrema(μxy) = (Float16(0.01909), Float16(0.713))
extrema(μx²) = (Float16(0.014725), Float16(0.7676))
extrema(μy²) = (Float16(0.0096), Float16(0.8267))
extrema(σx²) = (Float16(-0.002197), Float16(0.10645))
extrema(σy²) = (Float16(-0.004883), Float16(0.06934))
extrema(σxy) = (Float16(-0.03528), Float16(0.04224))
extrema(abs, σx²) = (Float16(0.0), Float16(0.10645))
extrema(abs, σy²) = (Float16(0.0), Float16(0.06934))
extrema(abs, σxy) = (Float16(0.0), Float16(0.04224))
extrema((2μxy .+ C₁) .* (2σxy .+ C₂)) = (Float16(-0.03915), Float16(0.04926))
extrema(abs, (2μxy .+ C₁) .* (2σxy .+ C₂)) = (Float16(9.0e-7), Float16(0.04926))
extrema(((μx² .+ μy²) .+ C₁) .* ((σx² .+ σy²) .+ C₂)) = (Float16(-0.001884), Float16(0.09894))
extrema(abs, ((μx² .+ μy²) .+ C₁) .* ((σx² .+ σy²) .+ C₂)) = (Float16(2.91e-5), Float16(0.09894))
extrema(ssim_map) = (Float16(-12.11), Float16(6.137))
Float16(0.1942)
```

Plotting the ssim_map from each call, the difference is easy to see. For example, with map32 and map16 holding the ssim_map arrays from the two calls above:

```julia
julia> using Plots

julia> p1 = heatmap(map32[:,:,1,1]; cbar=true, title="ssim_map Float32");

julia> p2 = heatmap(map16[:,:,1,1]; cbar=true, title="ssim_map Float16");

julia> p3 = heatmap(clamp.(map16[:,:,1,1], -1, 1); cbar=true, title="clamp(ssim_map, -1, 1) Float16");

julia> plot(p1, p2, p3; layout=(1,3), size=(1200, 400))
```

[figure ssim_stability: side-by-side heatmaps of the Float32 ssim_map, the Float16 ssim_map, and the clamped Float16 ssim_map]

@johnnychen94 commented May 10, 2023

> It seems like there's some numerical instability in the Float16 case.

I don't think this matters for DL tasks. The SSIM implementation in ImageQualityIndexes is carefully written to keep numerical compatibility with the widely used implementations. But for DL tasks what matters is the trend -- IMO even choosing a slightly different C₁ or C₂ would be okay.
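If exact Float16 behaviour ever matters, one possible guard in that spirit (purely hypothetical, not something this PR does) is to accumulate in Float32:

```julia
# Hypothetical Float16 guard, not part of the PR: compute SSIM in Float32
# and cast the scalar result back to the input precision.
ssim_f16(x::AbstractArray{Float16}, y::AbstractArray{Float16}) =
    Float16(Flux.ssim(Float32.(x), Float32.(y)))
```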

@nikopj commented Jul 12, 2023

I've moved this to its own package, SSIMLoss.jl, so it can be used with other frameworks more easily. Thanks @ToucheSir @CarloLucibello @johnnychen94 for the help and feedback.

@nikopj closed this on Jul 12, 2023