Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ssim, ssim_loss, ssim_loss_fast #2178

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
ImageQualityIndexes = "2996bd0c-7a13-11e9-2da2-2f5ce47296a9"
TestImages = "5e47fb64-e119-507b-a336-dd2b206d9990"

[targets]
test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"]
test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "Images", "ImageQualityIndexes", "TestImages"]
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this may be a mistake on my part, including Images etc. in [targets]

7 changes: 5 additions & 2 deletions src/losses/Losses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ using Zygote: @adjoint
using ChainRulesCore
using ..Flux: ofeltype, epseltype
using CUDA
using NNlib: logsoftmax, logσ, ctc_loss, ctc_alpha, ∇ctc_loss
using Adapt
using MLUtils: ones_like
using NNlib: logsoftmax, logσ, ctc_loss, ctc_alpha, ∇ctc_loss, conv, pad_symmetric
import Base.Broadcast: broadcasted

export mse, mae, msle,
Expand All @@ -19,7 +21,8 @@ export mse, mae, msle,
dice_coeff_loss,
poisson_loss,
hinge_loss, squared_hinge_loss,
binary_focal_loss, focal_loss, siamese_contrastive_loss
binary_focal_loss, focal_loss, siamese_contrastive_loss,
ssim, ssim_loss, ssim_loss_fast
Copy link
Contributor

@johnnychen94 johnnychen94 Feb 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we shouldn't export ssim here if the main goal is to use ssim loss to train the network?

My own experience here is that people might end up using only ssim_loss_fast even if you provide multiple choices...


include("utils.jl")
include("functions.jl")
Expand Down
96 changes: 96 additions & 0 deletions src/losses/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,102 @@ function siamese_contrastive_loss(ŷ, y; agg = mean, margin::Real = 1)
return agg(@. (1 - y) * ŷ^2 + y * max(0, margin - ŷ)^2)
end

"""
ssim(x, y, kernel=ssim_kernel(x); peakval=1, crop=true, dims=:)

Return the [structural similarity index
measure](https://en.wikipedia.org/wiki/Structural_similarity) (SSIM) between
two signals. SSIM is computed via the mean of a sliding window of
statistics computed between the two signals. By default, the sliding window is
a Gaussian with side-length 11 in each signal dimension and σ=1.5. `crop=false` will pad `x` and `y`
such that the sliding window computes statistics centered at every pixel of the input (via same-size convolution).
`ssim` computes statistics independently over channel and batch dimensions.
`x` and `y` may be 3D/4D/5D tensors with channel and batch-dimensions.

`peakval=1` is the standard for image comparisons, but in practice should be
set to the maximum value of your signal type.

`dims` determines which dimensions to average the computed statistics over. If
`dims=1:ndims(x)-1`, SSIM will be computed for each batch-element separately.

The results of `ssim` are matched against those of
[ImageQualityIndexes](https://github.com/JuliaImages/ImageQualityIndexes.jl)
for grayscale and RGB images (i.e. x, y both of size (N1, N2, 1, B) and (N1, N2, 3, B) for grayscale and color images, resp.).

See also [`ssim_loss`](@ref), [`ssim_loss_fast`](@ref).
"""
function ssim(x::AbstractArray{T,N}, y::AbstractArray{T,N}, kernel=ssim_kernel(x); peakval=T(1.0), crop=true, dims=:) where {T,N}
_check_sizes(x, y)

# apply same kernel on each channel dimension separately via groups=in_channels
groups = size(x, N-1)
kernel = repeat(kernel, ones(Int, N-1)..., groups)

# constants to avoid division by zero
SSIM_K = (0.01, 0.03)
C₁, C₂ = @. T(peakval * SSIM_K)^2

# crop==true -> valid-sized conv (do nothing),
# otherwise, pad for same-sized conv
if !crop
# from src/layers/conv.jl (calc_padding)
padding = Tuple(mapfoldl(i -> [cld(i, 2), fld(i,2)], vcat, size(kernel)[1:N-2] .- 1))
x = pad_symmetric(x, padding)
y = pad_symmetric(y, padding)
end

μx = conv(x, kernel; groups=groups)
μy = conv(y, kernel; groups=groups)
μx² = μx.^2
μy² = μy.^2
μxy = μx.*μy
σx² = conv(x.^2, kernel; groups=groups) .- μx²
σy² = conv(y.^2, kernel; groups=groups) .- μy²
σxy = conv(x.*y, kernel; groups=groups) .- μxy

ssim_map = @. (2μxy + C₁)*(2σxy + C₂)/((μx² + μy² + C₁)*(σx² + σy² + C₂))
return mean(ssim_map, dims=dims)
end

"""
ssim_loss(ŷ, y, kernel=ssim_kernel(x); peakval=1, crop=true, dims=:)

Computes the 1-ssim(ŷ,y), suitable for use as a loss function with gradient descent.
For faster training, it is recommended to store a kernel and reuse it, ex.,
```julia
kernel = Flux.Losses.ssim_kernel(Float32, 2) |> gpu
# or alternatively for faster computation
# kernel = ones(Float32, 5, 5, 1, num_channels) |> gpu

for (x, y) in dataloader
x, y = (x, y) .|> gpu
grads = Flux.gradient(model) do m
ŷ = m(x)
Flux.ssim_loss(ŷ, y, kernel)
end
# update the model ...
end
```
See [`ssim`](@ref) for a detailed description of SSIM and the above arguments.
See also [`ssim_loss_fast`](@ref).
"""
ssim_loss(x, args...; kws...) = one(eltype(x)) - ssim(x, args...; kws...)

"""
ssim_loss_fast(ŷ, y; kernel_length=5, peakval=1, crop=true, dims=:)

Computes `ssim_loss` with an averaging kernel instead of a large Gaussian
kernel for faster computation. `kernel_length` specifies the averaging kernel
side-length in each signal dimension of ŷ, y. See [`ssim`](@ref) for a
detailed description of SSIM and the above arguments.

See also [`ssim_loss`](@ref).
"""
function ssim_loss_fast(ŷ, y; kernel_length=5, kws...)
kernel = ones_like(y, (kernel_length*ones(Int, ndims(y)-2)..., 1, 1))
return ssim_loss(ŷ, y, kernel; kws...)
end

```@meta
DocTestFilters = nothing
```
36 changes: 36 additions & 0 deletions src/losses/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,39 @@ end
_check_sizes(ŷ, y) = nothing # pass-through, for constant label e.g. y = 1

ChainRulesCore.@non_differentiable _check_sizes(ŷ::Any, y::Any)

# Gaussian kernel std=1.5, length=11
const SSIM_KERNEL =
[0.00102838008447911,
0.007598758135239185,
0.03600077212843083,
0.10936068950970002,
0.2130055377112537,
0.26601172486179436,
0.2130055377112537,
0.10936068950970002,
0.03600077212843083,
0.007598758135239185,
0.00102838008447911]

"""
ssim_kernel(T, N)

Return Gaussian kernel with σ=1.5 and side-length 11 for use in [`ssim`](@ref).
"""
function ssim_kernel(T::Type, N::Integer)
if N-2 == 1
kernel = SSIM_KERNEL
elseif N-2 == 2
kernel = SSIM_KERNEL*SSIM_KERNEL'
elseif N-2 == 3
ks = length(SSIM_KERNEL)
kernel = reshape(SSIM_KERNEL*SSIM_KERNEL', 1, ks, ks).*SSIM_KERNEL
else
throw("SSIM is only implemented for 3D/4D/5D inputs, dimension=$N provided.")
end
return reshape(T.(kernel), size(kernel)..., 1, 1)
end
ssim_kernel(x::Array{T, N}) where {T, N} = ssim_kernel(T, N)
ssim_kernel(x::AnyCuArray{T, N}) where {T, N} = cu(ssim_kernel(T, N))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since these functions are not of general utility they better go close to the ssim the losses


30 changes: 24 additions & 6 deletions test/cuda/losses.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using Flux.Losses: crossentropy, binarycrossentropy, logitbinarycrossentropy, binary_focal_loss, focal_loss


@testset "Losses" begin

x = [1.,2.,3.]
Expand All @@ -26,13 +25,32 @@ y = [1 0 0 0 1
0 0 1 0 0]
@test focal_loss(x, y) ≈ focal_loss(gpu(x), gpu(y))

@testset "GPU ssim tests" begin
for N=1:3
@testset "num_dims=$N" begin
x = rand(Float32, 16*ones(Int, N)..., 2, 2)
y = rand(Float32, 16*ones(Int, N)..., 2, 2)

for f in (Flux.ssim, Flux.ssim_loss, Flux.ssim_loss_fast)
@test f(x, y) ≈ f(gpu(x), gpu(y))
gpu_autodiff_test(f, x, y)
end

x = gpu(x)
@test Flux.ssim(x, x) ≈ 1
@test Flux.ssim_loss(x, x) ≈ 0
@test Flux.ssim_loss_fast(x, x) ≈ 0
end
end
end

@testset "GPU grad tests" begin
x = rand(Float32, 3,3)
y = rand(Float32, 3,3)
x = rand(Float32, 3,3)
y = rand(Float32, 3,3)

for loss in ALL_LOSSES
gpu_autodiff_test(loss, x, y)
end
for loss in ALL_LOSSES
gpu_autodiff_test(loss, x, y)
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please stick to flux's indentation style in old and new

end

end #testset
83 changes: 74 additions & 9 deletions test/losses.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
using Test
using Flux: onehotbatch, σ

using Flux.Losses: mse, label_smoothing, crossentropy, logitcrossentropy, binarycrossentropy, logitbinarycrossentropy
using Flux.Losses: xlogx, xlogy

# for ssim_loss
using Images, TestImages, ImageQualityIndexes
using MLUtils: unsqueeze

# group here all losses, used in tests
const ALL_LOSSES = [Flux.Losses.mse, Flux.Losses.mae, Flux.Losses.msle,
Flux.Losses.crossentropy, Flux.Losses.logitcrossentropy,
Expand Down Expand Up @@ -169,15 +172,15 @@ end
end

@testset "no spurious promotions" begin
for T in (Float32, Float64)
y = rand(T, 2)
ŷ = rand(T, 2)
for f in ALL_LOSSES
fwd, back = Flux.pullback(f, ŷ, y)
@test fwd isa T
@test eltype(back(one(T))[1]) == T
for T in (Float32, Float64)
y = rand(T, 2)
ŷ = rand(T, 2)
for f in ALL_LOSSES
fwd, back = Flux.pullback(f, ŷ, y)
@test fwd isa T
@test eltype(back(one(T))[1]) == T
end
end
end
end

@testset "binary_focal_loss" begin
Expand Down Expand Up @@ -248,3 +251,65 @@ end
@test_throws DomainError(-0.5, "Margin must be non-negative") Flux.siamese_contrastive_loss(ŷ1, y1, margin = -0.5)
@test_throws DomainError(-1, "Margin must be non-negative") Flux.siamese_contrastive_loss(ŷ, y, margin = -1)
end

# monarch_color_256 and fabio_color_256 testimages
# used to obtain below numbers.
# true/false denote `assess_ssim(...; crop=true/false)`
const iqi_rgb_true = 0.1299260389807608
const iqi_gry_true = 0.13380159790218638
const iqi_rgb_false = 0.13683875886675542
const iqi_gry_false = 0.14181793989104552

@testset "ssim_loss" begin
# color-image testing
# ssim values for monarch-fabio
@test Flux.Losses.SSIM_KERNEL == ImageQualityIndexes.SSIM_KERNEL.parent

# get reference images
imx_rgb = testimage("monarch_color_256")
imy_rgb = testimage("fabio_color_256")
imx_gry = Gray.(imx_rgb)
imy_gry = Gray.(imy_rgb)
x_rgb = permutedims(channelview(imx_rgb), (2, 3, 1)) .|> Float64 |> unsqueeze(dims=4)
y_rgb = permutedims(channelview(imy_rgb), (2, 3, 1)) .|> Float64 |> unsqueeze(dims=4)
x_gry = imx_gry .|> Float64 |> unsqueeze(dims=3) |> unsqueeze(dims=4)
y_gry = imy_gry .|> Float64 |> unsqueeze(dims=3) |> unsqueeze(dims=4)

# 8 tests enumerating rgb/gray, crop/nocrop, iqi/flux vs. ref
for (ssim_iqi, crop) in
zip(((iqi_rgb_true, iqi_gry_true), (iqi_rgb_false, iqi_gry_false)), (true, false))
for (imx, imy, x, y, ssim_ref) in
zip((imx_rgb, imx_gry), (imy_rgb, imy_gry), (x_rgb, x_gry), (y_rgb, y_gry), ssim_iqi)

color = eltype(imx) <: RGB ? "RGB" : "Gray"
@testset "crop=$crop, color=$color" begin
# make sure IQI is same
@test assess_ssim(imx, imy; crop=crop) ≈ ssim_ref
# test flux against IQI on Image Array
@test Flux.ssim(x, y; crop=crop) ≈ ssim_ref atol=1e-6
end
end
end

for N=1:3
x = rand(Float32, 16*ones(Int, N)..., 2, 2)
@testset "num_dims=$N" begin
@test Flux.ssim(x, x) ≈ 1
@test Flux.ssim_loss(x, x) ≈ 0
@test Flux.ssim_loss_fast(x, x) ≈ 0
end
end

@testset "no spurious promotions" begin
for T in (Float32, Float64)
y = rand(T, 15, 1, 1)
ŷ = rand(T, 15, 1, 1)
for f in (Flux.ssim, Flux.ssim_loss, Flux.ssim_loss_fast)
fwd, back = Flux.pullback(f, ŷ, y)
@test fwd isa T
@test eltype(back(one(T))[1]) == T
end
end
end
end