From 0a1ea938c573a85b2d74bfe969986bc8f512fc04 Mon Sep 17 00:00:00 2001 From: Nikola Date: Sun, 5 Feb 2023 04:47:30 -0500 Subject: [PATCH 1/2] ssim, ssim_loss, ssim_loss_fast --- Project.toml | 5 ++- src/losses/Losses.jl | 7 ++- src/losses/functions.jl | 96 +++++++++++++++++++++++++++++++++++++++++ src/losses/utils.jl | 36 ++++++++++++++++ test/cuda/losses.jl | 30 ++++++++++--- test/losses.jl | 83 +++++++++++++++++++++++++++++++---- 6 files changed, 239 insertions(+), 18 deletions(-) diff --git a/Project.toml b/Project.toml index 8292de97e2..6470132f16 100644 --- a/Project.toml +++ b/Project.toml @@ -48,6 +48,9 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0" +ImageQualityIndexes = "2996bd0c-7a13-11e9-2da2-2f5ce47296a9" +TestImages = "5e47fb64-e119-507b-a336-dd2b206d9990" [targets] -test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"] +test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "Images", "ImageQualityIndexes", "TestImages"] diff --git a/src/losses/Losses.jl b/src/losses/Losses.jl index 3d8f6f8149..7fcd342b63 100644 --- a/src/losses/Losses.jl +++ b/src/losses/Losses.jl @@ -6,7 +6,9 @@ using Zygote: @adjoint using ChainRulesCore using ..Flux: ofeltype, epseltype using CUDA -using NNlib: logsoftmax, logσ, ctc_loss, ctc_alpha, ∇ctc_loss +using Adapt +using MLUtils: ones_like +using NNlib: logsoftmax, logσ, ctc_loss, ctc_alpha, ∇ctc_loss, conv, pad_symmetric import Base.Broadcast: broadcasted export mse, mae, msle, @@ -19,7 +21,8 @@ export mse, mae, msle, dice_coeff_loss, poisson_loss, hinge_loss, squared_hinge_loss, - binary_focal_loss, focal_loss, siamese_contrastive_loss + binary_focal_loss, focal_loss, siamese_contrastive_loss, + ssim, ssim_loss, ssim_loss_fast include("utils.jl") include("functions.jl") diff --git a/src/losses/functions.jl b/src/losses/functions.jl index c40d4dcd76..0d1bbf78b9 100644 --- a/src/losses/functions.jl +++ b/src/losses/functions.jl @@ -630,6 +630,102 @@ function siamese_contrastive_loss(ŷ, y; agg = mean, margin::Real = 1) return agg(@. (1 - y) * ŷ^2 + y * max(0, margin - ŷ)^2) end +""" + ssim(x, y, kernel=ssim_kernel(x); peakval=1, crop=true, dims=:) + +Return the [structural similarity index +measure](https://en.wikipedia.org/wiki/Structural_similarity) (SSIM) between +two signals. SSIM is computed via the mean of a sliding window of +statistics computed between the two signals. By default, the sliding window is +a Gaussian with side-length 11 in each signal dimension and σ=1.5. `crop=false` will pad `x` and `y` +such that the sliding window computes statistics centered at every pixel of the input (via same-size convolution). +`ssim` computes statistics independently over channel and batch dimensions. +`x` and `y` may be 3D/4D/5D tensors with channel and batch-dimensions. + +`peakval=1` is the standard for image comparisons, but in practice should be +set to the maximum value of your signal type. + +`dims` determines which dimensions to average the computed statistics over. If +`dims=1:ndims(x)-1`, SSIM will be computed for each batch-element separately. + +The results of `ssim` are matched against those of +[ImageQualityIndexes](https://github.com/JuliaImages/ImageQualityIndexes.jl) +for grayscale and RGB images (i.e. x, y both of size (N1, N2, 1, B) and (N1, N2, 3, B) for grayscale and color images, resp.). + +See also [`ssim_loss`](@ref), [`ssim_loss_fast`](@ref). +""" +function ssim(x::AbstractArray{T,N}, y::AbstractArray{T,N}, kernel=ssim_kernel(x); peakval=T(1.0), crop=true, dims=:) where {T,N} + _check_sizes(x, y) + + # apply same kernel on each channel dimension separately via groups=in_channels + groups = size(x, N-1) + kernel = repeat(kernel, ones(Int, N-1)..., groups) + + # constants to avoid division by zero + SSIM_K = (0.01, 0.03) + C₁, C₂ = @. T(peakval * SSIM_K)^2 + + # crop==true -> valid-sized conv (do nothing), + # otherwise, pad for same-sized conv + if !crop + # from src/layers/conv.jl (calc_padding) + padding = Tuple(mapfoldl(i -> [cld(i, 2), fld(i,2)], vcat, size(kernel)[1:N-2] .- 1)) + x = pad_symmetric(x, padding) + y = pad_symmetric(y, padding) + end + + μx = conv(x, kernel; groups=groups) + μy = conv(y, kernel; groups=groups) + μx² = μx.^2 + μy² = μy.^2 + μxy = μx.*μy + σx² = conv(x.^2, kernel; groups=groups) .- μx² + σy² = conv(y.^2, kernel; groups=groups) .- μy² + σxy = conv(x.*y, kernel; groups=groups) .- μxy + + ssim_map = @. (2μxy + C₁)*(2σxy + C₂)/((μx² + μy² + C₁)*(σx² + σy² + C₂)) + return mean(ssim_map, dims=dims) +end + +""" + ssim_loss(ŷ, y, kernel=ssim_kernel(x); peakval=1, crop=true, dims=:) + +Computes the 1-ssim(ŷ,y), suitable for use as a loss function with gradient descent. +For faster training, it is recommended to store a kernel and reuse it, ex., +```julia +kernel = Flux.Losses.ssim_kernel(Float32, 2) |> gpu +# or alternatively for faster computation +# kernel = ones(Float32, 5, 5, 1, num_channels) |> gpu + +for (x, y) in dataloader + x, y = (x, y) .|> gpu + grads = Flux.gradient(model) do m + ŷ = m(x) + Flux.ssim_loss(ŷ, y, kernel) + end + # update the model ... +end +``` +See [`ssim`](@ref) for a detailed description of SSIM and the above arguments. +See also [`ssim_loss_fast`](@ref). +""" +ssim_loss(x, args...; kws...) = one(eltype(x)) - ssim(x, args...; kws...) + +""" + ssim_loss_fast(ŷ, y; kernel_length=5, peakval=1, crop=true, dims=:) + +Computes `ssim_loss` with an averaging kernel instead of a large Gaussian +kernel for faster computation. `kernel_length` specifies the averaging kernel +side-length in each signal dimension of ŷ, y. See [`ssim`](@ref) for a +detailed description of SSIM and the above arguments. + +See also [`ssim_loss`](@ref). +""" +function ssim_loss_fast(ŷ, y; kernel_length=5, kws...) + kernel = ones_like(y, (kernel_length*ones(Int, ndims(y)-2)..., 1, 1)) + return ssim_loss(ŷ, y, kernel; kws...) +end + ```@meta DocTestFilters = nothing ``` diff --git a/src/losses/utils.jl b/src/losses/utils.jl index e42bdfbe2e..11a61faafb 100644 --- a/src/losses/utils.jl +++ b/src/losses/utils.jl @@ -36,3 +36,39 @@ end _check_sizes(ŷ, y) = nothing # pass-through, for constant label e.g. y = 1 ChainRulesCore.@non_differentiable _check_sizes(ŷ::Any, y::Any) + +# Gaussian kernel std=1.5, length=11 +const SSIM_KERNEL = + [0.00102838008447911, + 0.007598758135239185, + 0.03600077212843083, + 0.10936068950970002, + 0.2130055377112537, + 0.26601172486179436, + 0.2130055377112537, + 0.10936068950970002, + 0.03600077212843083, + 0.007598758135239185, + 0.00102838008447911] + +""" + ssim_kernel(T, N) + +Return Gaussian kernel with σ=1.5 and side-length 11 for use in [`ssim`](@ref). +""" +function ssim_kernel(T::Type, N::Integer) + if N-2 == 1 + kernel = SSIM_KERNEL + elseif N-2 == 2 + kernel = SSIM_KERNEL*SSIM_KERNEL' + elseif N-2 == 3 + ks = length(SSIM_KERNEL) + kernel = reshape(SSIM_KERNEL*SSIM_KERNEL', 1, ks, ks).*SSIM_KERNEL + else + throw("SSIM is only implemented for 3D/4D/5D inputs, dimension=$N provided.") + end + return reshape(T.(kernel), size(kernel)..., 1, 1) +end +ssim_kernel(x::Array{T, N}) where {T, N} = ssim_kernel(T, N) +ssim_kernel(x::AnyCuArray{T, N}) where {T, N} = cu(ssim_kernel(T, N)) + diff --git a/test/cuda/losses.jl b/test/cuda/losses.jl index a0f7f47d80..2e9aefb314 100644 --- a/test/cuda/losses.jl +++ b/test/cuda/losses.jl @@ -1,6 +1,5 @@ using Flux.Losses: crossentropy, binarycrossentropy, logitbinarycrossentropy, binary_focal_loss, focal_loss - @testset "Losses" begin x = [1.,2.,3.] @@ -26,13 +25,32 @@ y = [1 0 0 0 1 0 0 1 0 0] @test focal_loss(x, y) ≈ focal_loss(gpu(x), gpu(y)) +@testset "GPU ssim tests" begin + for N=1:3 + @testset "num_dims=$N" begin + x = rand(Float32, 16*ones(Int, N)..., 2, 2) + y = rand(Float32, 16*ones(Int, N)..., 2, 2) + + for f in (Flux.ssim, Flux.ssim_loss, Flux.ssim_loss_fast) + @test f(x, y) ≈ f(gpu(x), gpu(y)) + gpu_autodiff_test(f, x, y) + end + + x = gpu(x) + @test Flux.ssim(x, x) ≈ 1 + @test Flux.ssim_loss(x, x) ≈ 0 + @test Flux.ssim_loss_fast(x, x) ≈ 0 + end + end +end + @testset "GPU grad tests" begin - x = rand(Float32, 3,3) - y = rand(Float32, 3,3) + x = rand(Float32, 3,3) + y = rand(Float32, 3,3) - for loss in ALL_LOSSES - gpu_autodiff_test(loss, x, y) - end + for loss in ALL_LOSSES + gpu_autodiff_test(loss, x, y) + end end end #testset diff --git a/test/losses.jl b/test/losses.jl index 2ca697a657..30ad2460d3 100644 --- a/test/losses.jl +++ b/test/losses.jl @@ -1,9 +1,12 @@ using Test using Flux: onehotbatch, σ - using Flux.Losses: mse, label_smoothing, crossentropy, logitcrossentropy, binarycrossentropy, logitbinarycrossentropy using Flux.Losses: xlogx, xlogy +# for ssim_loss +using Images, TestImages, ImageQualityIndexes +using MLUtils: unsqueeze + # group here all losses, used in tests const ALL_LOSSES = [Flux.Losses.mse, Flux.Losses.mae, Flux.Losses.msle, Flux.Losses.crossentropy, Flux.Losses.logitcrossentropy, @@ -169,15 +172,15 @@ end end @testset "no spurious promotions" begin - for T in (Float32, Float64) - y = rand(T, 2) - ŷ = rand(T, 2) - for f in ALL_LOSSES - fwd, back = Flux.pullback(f, ŷ, y) - @test fwd isa T - @test eltype(back(one(T))[1]) == T + for T in (Float32, Float64) + y = rand(T, 2) + ŷ = rand(T, 2) + for f in ALL_LOSSES + fwd, back = Flux.pullback(f, ŷ, y) + @test fwd isa T + @test eltype(back(one(T))[1]) == T + end end - end end @testset "binary_focal_loss" begin @@ -248,3 +251,65 @@ end @test_throws DomainError(-0.5, "Margin must be non-negative") Flux.siamese_contrastive_loss(ŷ1, y1, margin = -0.5) @test_throws DomainError(-1, "Margin must be non-negative") Flux.siamese_contrastive_loss(ŷ, y, margin = -1) end + +# monarch_color_256 and fabio_color_256 testimages +# used to obtain below numbers. +# true/false denote `assess_ssim(...; crop=true/false)` +const iqi_rgb_true = 0.1299260389807608 +const iqi_gry_true = 0.13380159790218638 +const iqi_rgb_false = 0.13683875886675542 +const iqi_gry_false = 0.14181793989104552 + +@testset "ssim_loss" begin + # color-image testing + # ssim values for monarch-fabio + @test Flux.Losses.SSIM_KERNEL == ImageQualityIndexes.SSIM_KERNEL.parent + + # get reference images + imx_rgb = testimage("monarch_color_256") + imy_rgb = testimage("fabio_color_256") + imx_gry = Gray.(imx_rgb) + imy_gry = Gray.(imy_rgb) + x_rgb = permutedims(channelview(imx_rgb), (2, 3, 1)) .|> Float64 |> unsqueeze(dims=4) + y_rgb = permutedims(channelview(imy_rgb), (2, 3, 1)) .|> Float64 |> unsqueeze(dims=4) + x_gry = imx_gry .|> Float64 |> unsqueeze(dims=3) |> unsqueeze(dims=4) + y_gry = imy_gry .|> Float64 |> unsqueeze(dims=3) |> unsqueeze(dims=4) + + # 8 tests enumerating rgb/gray, crop/nocrop, iqi/flux vs. ref + for (ssim_iqi, crop) in + zip(((iqi_rgb_true, iqi_gry_true), (iqi_rgb_false, iqi_gry_false)), (true, false)) + for (imx, imy, x, y, ssim_ref) in + zip((imx_rgb, imx_gry), (imy_rgb, imy_gry), (x_rgb, x_gry), (y_rgb, y_gry), ssim_iqi) + + color = eltype(imx) <: RGB ? "RGB" : "Gray" + @testset "crop=$crop, color=$color" begin + # make sure IQI is same + @test assess_ssim(imx, imy; crop=crop) ≈ ssim_ref + # test flux against IQI on Image Array + @test Flux.ssim(x, y; crop=crop) ≈ ssim_ref atol=1e-6 + end + end + end + + for N=1:3 + x = rand(Float32, 16*ones(Int, N)..., 2, 2) + @testset "num_dims=$N" begin + @test Flux.ssim(x, x) ≈ 1 + @test Flux.ssim_loss(x, x) ≈ 0 + @test Flux.ssim_loss_fast(x, x) ≈ 0 + end + end + + @testset "no spurious promotions" begin + for T in (Float32, Float64) + y = rand(T, 15, 1, 1) + ŷ = rand(T, 15, 1, 1) + for f in (Flux.ssim, Flux.ssim_loss, Flux.ssim_loss_fast) + fwd, back = Flux.pullback(f, ŷ, y) + @test fwd isa T + @test eltype(back(one(T))[1]) == T + end + end + end +end + From 247cf221c781fd2c5ee645fafea233a6692e3e1d Mon Sep 17 00:00:00 2001 From: Nikola Date: Thu, 16 Feb 2023 00:31:18 -0500 Subject: [PATCH 2/2] test fixes + fast loss NaN fix --- src/losses/functions.jl | 1 + test/cuda/losses.jl | 35 ++++++++++++++++++++--------------- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/src/losses/functions.jl b/src/losses/functions.jl index 0166b6a228..80f1161f5e 100644 --- a/src/losses/functions.jl +++ b/src/losses/functions.jl @@ -766,6 +766,7 @@ See also [`ssim_loss`](@ref). """ function ssim_loss_fast(ŷ, y; kernel_length=5, kws...) kernel = ones_like(y, (kernel_length*ones(Int, ndims(y)-2)..., 1, 1)) + kernel = kernel ./ sum(kernel) return ssim_loss(ŷ, y, kernel; kws...) end diff --git a/test/cuda/losses.jl b/test/cuda/losses.jl index 9b1237970a..969d9562b0 100644 --- a/test/cuda/losses.jl +++ b/test/cuda/losses.jl @@ -26,25 +26,30 @@ y = [1 0 0 0 1 @test focal_loss(x, y) ≈ focal_loss(gpu(x), gpu(y)) @testset "GPU ssim tests" begin - for N=1:3 - @testset "num_dims=$N" begin - x = rand(Float32, 16*ones(Int, N)..., 2, 2) - y = rand(Float32, 16*ones(Int, N)..., 2, 2) - - for loss in (Flux.ssim, Flux.ssim_loss, Flux.ssim_loss_fast) - @test loss(x, y) ≈ loss(gpu(x), gpu(y)) - gpu_autodiff_test(loss, x, y) - - # Float16 tests - @test loss(f16(x), f16(y)) ≈ loss(gpu(f16(x)), gpu(f16(y))) - @test loss(f16(x), f16(y)) ≈ Float16(loss(x, y)) rtol=0.1 # no GPU in fact + @testset "num_dims=$N" for N=1:3 + x = rand(Float32, 16*ones(Int, N)..., 2, 2) + y = rand(Float32, 16*ones(Int, N)..., 2, 2) + + @testset "$loss" for loss in (Flux.ssim, Flux.ssim_loss, Flux.ssim_loss_fast) + @testset "cpu-gpu" begin loss(x, y) ≈ loss(gpu(x), gpu(y)) end + @testset "autodiff" begin gpu_autodiff_test(loss, x, y) end + # Float16 tests + @testset "f16 cpu-gpu" begin + @test isapprox(loss(f16(x), f16(y)), loss(gpu(f16(x)), gpu(f16(y))), rtol=0.1) broken=(N==3) + end + @testset "f16 cpu-cpu" begin + isapprox(loss(f16(x), f16(y)), Float16(loss(x, y)); rtol=0.1) + end + @testset "f16 grad" begin g16 = gradient(loss, f16(x), f16(y))[1] - @test g16 ≈ cpu(gradient(loss, f16(gpu(x)), f16(gpu(y)))[1]) + @test isapprox(g16, cpu(gradient(loss, f16(gpu(x)), f16(gpu(y)))[1]), rtol=0.1) broken=true end + end - # sanity checks - x = gpu(x) + # sanity checks + x = gpu(x) + @testset "sanity check" begin @test Flux.ssim(x, x) ≈ 1 @test Flux.ssim_loss(x, x) ≈ 0 @test Flux.ssim_loss_fast(x, x) ≈ 0