From 9b5a072a0ebaa5782c419ad10e6dd4763536f513 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 27 Jun 2023 01:32:14 +0200 Subject: [PATCH 1/6] Make Test a weak dependency --- Project.toml | 8 +- ext/KernelFunctionsTestExt.jl | 268 +++++++++++++++++++++++++++++++++ src/KernelFunctions.jl | 1 + src/TestUtils.jl | 275 ++-------------------------------- 4 files changed, 287 insertions(+), 265 deletions(-) create mode 100644 ext/KernelFunctionsTestExt.jl diff --git a/Project.toml b/Project.toml index c763a8824..f978c5567 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "KernelFunctions" uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392" -version = "0.10.55" +version = "0.10.56" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -21,6 +21,12 @@ TensorCore = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" +[weakdeps] +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[extensions] +KernelFunctionsTestExt = "Test" + [compat] ChainRulesCore = "1" Compat = "3.7, 4" diff --git a/ext/KernelFunctionsTestExt.jl b/ext/KernelFunctionsTestExt.jl new file mode 100644 index 000000000..843a0e9c1 --- /dev/null +++ b/ext/KernelFunctionsTestExt.jl @@ -0,0 +1,268 @@ +module KernelFunctionsTestExt + +using KernelFunctions +using KernelFunctions: TestUtils, LinearAlgebra, Random +using Test + +""" + test_interface( + k::Kernel, + x0::AbstractVector, + x1::AbstractVector, + x2::AbstractVector; + rtol=1e-6, + atol=rtol, + ) + +Run various consistency checks on `k` at the inputs `x0`, `x1`, and `x2`. +`x0` and `x1` should be of the same length with different values, while `x0` and `x2` should +be of different lengths. + +These tests are intended to pick up on really substantial issues with a kernel implementation +(e.g. substantial asymmetry in the kernel matrix, large negative eigenvalues), rather than to +test the numerics in detail, which can be kernel-specific. +""" +function TestUtils.test_interface( + k::Kernel, + x0::AbstractVector, + x1::AbstractVector, + x2::AbstractVector; + rtol=1e-6, + atol=rtol, +) + # Ensure that we have the required inputs. + @assert length(x0) == length(x1) + @assert length(x0) ≠ length(x2) + + # Check that kernelmatrix_diag basically works. + @test kernelmatrix_diag(k, x0, x1) isa AbstractVector + @test length(kernelmatrix_diag(k, x0, x1)) == length(x0) + + # Check that pairwise basically works. + @test kernelmatrix(k, x0, x2) isa AbstractMatrix + @test size(kernelmatrix(k, x0, x2)) == (length(x0), length(x2)) + + # Check that elementwise is consistent with pairwise. + @test kernelmatrix_diag(k, x0, x1) ≈ LinearAlgebra.diag(kernelmatrix(k, x0, x1)) atol = atol rtol = + rtol + + # Check additional binary elementwise properties for kernels. + @test kernelmatrix_diag(k, x0, x1) ≈ kernelmatrix_diag(k, x1, x0) + @test kernelmatrix(k, x0, x2) ≈ permutedims(kernelmatrix(k, x2, x0)) atol = atol rtol = rtol + + # Check that unary elementwise basically works. + @test kernelmatrix_diag(k, x0) isa AbstractVector + @test length(kernelmatrix_diag(k, x0)) == length(x0) + + # Check that unary pairwise basically works. + @test kernelmatrix(k, x0) isa AbstractMatrix + @test size(kernelmatrix(k, x0)) == (length(x0), length(x0)) + @test kernelmatrix(k, x0) ≈ permutedims(kernelmatrix(k, x0)) atol = atol rtol = rtol + + # Check that unary elementwise is consistent with unary pairwise. + @test kernelmatrix_diag(k, x0) ≈ LinearAlgebra.diag(kernelmatrix(k, x0)) atol = atol rtol = rtol + + # Check that unary pairwise produces a positive definite matrix (approximately). + @test LinearAlgebra.eigmin(Matrix(kernelmatrix(k, x0))) > -atol + + # Check that unary elementwise / pairwise are consistent with the binary versions. + @test kernelmatrix_diag(k, x0) ≈ kernelmatrix_diag(k, x0, x0) atol = atol rtol = rtol + @test kernelmatrix(k, x0) ≈ kernelmatrix(k, x0, x0) atol = atol rtol = rtol + + # Check that basic kernel evaluation succeeds and is consistent with `kernelmatrix`. + @test k(first(x0), first(x1)) isa Real + @test kernelmatrix(k, x0, x2) ≈ [k(xl, xr) for xl in x0, xr in x2] + + tmp = Matrix{Float64}(undef, length(x0), length(x2)) + @test kernelmatrix!(tmp, k, x0, x2) ≈ kernelmatrix(k, x0, x2) + + tmp_square = Matrix{Float64}(undef, length(x0), length(x0)) + @test kernelmatrix!(tmp_square, k, x0) ≈ kernelmatrix(k, x0) + + tmp_diag = Vector{Float64}(undef, length(x0)) + @test kernelmatrix_diag!(tmp_diag, k, x0) ≈ kernelmatrix_diag(k, x0) + @test kernelmatrix_diag!(tmp_diag, k, x0, x1) ≈ kernelmatrix_diag(k, x0, x1) +end + +""" + test_interface([rng::AbstractRNG], k::Kernel, ::Type{T}=Float64; kwargs...) where {T} + +Run the [`test_interface`](@ref) tests for randomly generated inputs of types `Vector{T}`, +`Vector{Vector{T}}`, `ColVecs{T}`, and `RowVecs{T}`. + +For other input types, please provide the data manually. + +The keyword arguments are forwarded to the invocations of [`test_interface`](@ref) with the +randomly generated inputs. +""" +function TestUtils.test_interface(k::Kernel, T::Type=Float64; kwargs...) + return TestUtils.test_interface(Random.default_rng(), k, T; kwargs...) +end + +function TestUtils.test_interface(rng::Random.AbstractRNG, k::Kernel, T::Type=Float64; kwargs...) + return TestUtils.test_with_type(TestUtils.test_interface, rng, k, T; kwargs...) +end + +""" + test_type_stability( + k::Kernel, + x0::AbstractVector, + x1::AbstractVector, + x2::AbstractVector, + ) + +Run type stability checks over `k(x,y)` and the different functions of the API +(`kernelmatrix`, `kernelmatrix_diag`). `x0` and `x1` should be of the same +length with different values, while `x0` and `x2` should be of different lengths. +""" +function TestUtils.test_type_stability( + k::Kernel, x0::AbstractVector, x1::AbstractVector, x2::AbstractVector +) + # Ensure that we have the required inputs. + @assert length(x0) == length(x1) + @assert length(x0) ≠ length(x2) + @test @inferred(kernelmatrix(k, x0)) isa AbstractMatrix + @test @inferred(kernelmatrix(k, x0, x2)) isa AbstractMatrix + @test @inferred(kernelmatrix_diag(k, x0)) isa AbstractVector + @test @inferred(kernelmatrix_diag(k, x0, x1)) isa AbstractVector +end + +function TestUtils.test_type_stability(k::Kernel, ::Type{T}=Float64; kwargs...) where {T} + return TestUtils.test_type_stability(Random.default_rng(), k, T; kwargs...) +end + +function TestUtils.test_type_stability(rng::Random.AbstractRNG, k::Kernel, ::Type{T}; kwargs...) where {T} + return TestUtils.test_with_type(TestUtils.test_type_stability, rng, k, T; kwargs...) +end + +""" + test_with_type(f, rng::AbstractRNG, k::Kernel, ::Type{T}; kwargs...) where {T} + +Run the functions `f`, (for example [`test_interface`](@ref) or +[`test_type_stable`](@ref)) for randomly generated inputs of types `Vector{T}`, +`Vector{Vector{T}}`, `ColVecs{T}`, and `RowVecs{T}`. + +For other input types, please provide the data manually. + +The keyword arguments are forwarded to the invocations of `f` with the +randomly generated inputs. +""" +function TestUtils.test_with_type(f, rng::Random.AbstractRNG, k::Kernel, ::Type{T}; kwargs...) where {T} + @testset "Vector{$T}" begin + TestUtils.test_with_type(f, rng, k, Vector{T}; kwargs...) + end + @testset "ColVecs{$T}" begin + TestUtils.test_with_type(f, rng, k, ColVecs{T}; kwargs...) + end + @testset "RowVecs{$T}" begin + TestUtils.test_with_type(f, rng, k, RowVecs{T}; kwargs...) + end + @testset "Vector{Vector{$T}}" begin + TestUtils.test_with_type(f, rng, k, Vector{Vector{T}}; kwargs...) + end +end + +function TestUtils.test_with_type( + f, rng::Random.AbstractRNG, k::Kernel, ::Type{Vector{T}}; kwargs... +) where {T<:Real} + return f(k, randn(rng, T, 11), randn(rng, T, 11), randn(rng, T, 13); kwargs...) +end + +function TestUtils.test_with_type( + f, rng::Random.AbstractRNG, k::MOKernel, ::Type{Vector{Tuple{T,Int}}}; dim_out=3, kwargs... +) where {T<:Real} + return f( + k, + [(randn(rng, T), rand(rng, 1:dim_out)) for i in 1:11], + [(randn(rng, T), rand(rng, 1:dim_out)) for i in 1:11], + [(randn(rng, T), rand(rng, 1:dim_out)) for i in 1:13]; + kwargs..., + ) +end + +function TestUtils.test_with_type( + f, rng::Random.AbstractRNG, k::Kernel, ::Type{<:ColVecs{T}}; dim_in=2, kwargs... +) where {T<:Real} + return f( + k, + ColVecs(randn(rng, T, dim_in, 11)), + ColVecs(randn(rng, T, dim_in, 11)), + ColVecs(randn(rng, T, dim_in, 13)); + kwargs..., + ) +end + +function TestUtils.test_with_type( + f, rng::Random.AbstractRNG, k::Kernel, ::Type{<:RowVecs{T}}; dim_in=2, kwargs... +) where {T<:Real} + return f( + k, + RowVecs(randn(rng, T, 11, dim_in)), + RowVecs(randn(rng, T, 11, dim_in)), + RowVecs(randn(rng, T, 13, dim_in)); + kwargs..., + ) +end + +function TestUtils.test_with_type( + f, rng::Random.AbstractRNG, k::Kernel, ::Type{<:Vector{Vector{T}}}; dim_in=2, kwargs... +) where {T<:Real} + return f( + k, + [randn(rng, T, dim_in) for _ in 1:11], + [randn(rng, T, dim_in) for _ in 1:11], + [randn(rng, T, dim_in) for _ in 1:13]; + kwargs..., + ) +end + +function TestUtils.test_with_type(f, rng::Random.AbstractRNG, k::Kernel, ::Type{Vector{String}}; kwargs...) + return f( + k, + [Random.randstring(rng) for _ in 1:3], + [Random.randstring(rng) for _ in 1:3], + [Random.randstring(rng) for _ in 1:4]; + kwargs..., + ) +end + +function test_with_type( + f, rng::Random.AbstractRNG, k::Kernel, ::Type{ColVecs{String}}; dim_in=2, kwargs... +) + return f( + k, + ColVecs([Random.randstring(rng) for _ in 1:dim_in, _ in 1:3]), + ColVecs([Random.randstring(rng) for _ in 1:dim_in, _ in 1:3]), + ColVecs([Random.randstring(rng) for _ in 1:dim_in, _ in 1:4]); + kwargs..., + ) +end + +function TestUtils.test_with_type(f, k::Kernel, T::Type{<:Real}; kwargs...) + return TestUtils.test_with_type(f, Random.default_rng(), k, T; kwargs...) +end + +""" + example_inputs(rng::AbstractRNG, type) + +Return a tuple of 4 inputs of type `type`. See `methods(example_inputs)` for information +around supported types. It is recommended that you utilise `StableRNGs.jl` for `rng` here +to ensure consistency across Julia versions. +""" +function TestUtils.example_inputs(rng::Random.AbstractRNG, ::Type{Vector{Float64}}) + return map(n -> randn(rng, Float64, n), (1, 2, 3, 4)) +end + +function TestUtils.example_inputs( + rng::Random.AbstractRNG, ::Type{ColVecs{Float64,Matrix{Float64}}}; dim::Int=2 +) + return map(n -> ColVecs(randn(rng, dim, n)), (1, 2, 3, 4)) +end + +function TestUtils.example_inputs( + rng::Random.AbstractRNG, ::Type{RowVecs{Float64,Matrix{Float64}}}; dim::Int=2 +) + return map(n -> RowVecs(randn(rng, n, dim)), (1, 2, 3, 4)) +end + +end # module diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index 63205b5bf..76f86478d 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -53,6 +53,7 @@ using CompositionsBase using Distances using FillArrays using Functors +using Random using LinearAlgebra using Requires using SpecialFunctions: loggamma, besselk, polygamma diff --git a/src/TestUtils.jl b/src/TestUtils.jl index cd14ec718..b1c5967bc 100644 --- a/src/TestUtils.jl +++ b/src/TestUtils.jl @@ -1,270 +1,17 @@ module TestUtils -using Distances -using LinearAlgebra -using KernelFunctions -using Random -using Test - -""" - test_interface( - k::Kernel, - x0::AbstractVector, - x1::AbstractVector, - x2::AbstractVector; - rtol=1e-6, - atol=rtol, - ) - -Run various consistency checks on `k` at the inputs `x0`, `x1`, and `x2`. -`x0` and `x1` should be of the same length with different values, while `x0` and `x2` should -be of different lengths. - -These tests are intended to pick up on really substantial issues with a kernel implementation -(e.g. substantial asymmetry in the kernel matrix, large negative eigenvalues), rather than to -test the numerics in detail, which can be kernel-specific. -""" -function test_interface( - k::Kernel, - x0::AbstractVector, - x1::AbstractVector, - x2::AbstractVector; - rtol=1e-6, - atol=rtol, -) - # Ensure that we have the required inputs. - @assert length(x0) == length(x1) - @assert length(x0) ≠ length(x2) - - # Check that kernelmatrix_diag basically works. - @test kernelmatrix_diag(k, x0, x1) isa AbstractVector - @test length(kernelmatrix_diag(k, x0, x1)) == length(x0) - - # Check that pairwise basically works. - @test kernelmatrix(k, x0, x2) isa AbstractMatrix - @test size(kernelmatrix(k, x0, x2)) == (length(x0), length(x2)) - - # Check that elementwise is consistent with pairwise. - @test kernelmatrix_diag(k, x0, x1) ≈ diag(kernelmatrix(k, x0, x1)) atol = atol rtol = - rtol - - # Check additional binary elementwise properties for kernels. - @test kernelmatrix_diag(k, x0, x1) ≈ kernelmatrix_diag(k, x1, x0) - @test kernelmatrix(k, x0, x2) ≈ kernelmatrix(k, x2, x0)' atol = atol rtol = rtol - - # Check that unary elementwise basically works. - @test kernelmatrix_diag(k, x0) isa AbstractVector - @test length(kernelmatrix_diag(k, x0)) == length(x0) - - # Check that unary pairwise basically works. - @test kernelmatrix(k, x0) isa AbstractMatrix - @test size(kernelmatrix(k, x0)) == (length(x0), length(x0)) - @test kernelmatrix(k, x0) ≈ kernelmatrix(k, x0)' atol = atol rtol = rtol - - # Check that unary elementwise is consistent with unary pairwise. - @test kernelmatrix_diag(k, x0) ≈ diag(kernelmatrix(k, x0)) atol = atol rtol = rtol - - # Check that unary pairwise produces a positive definite matrix (approximately). - @test eigmin(Matrix(kernelmatrix(k, x0))) > -atol - - # Check that unary elementwise / pairwise are consistent with the binary versions. - @test kernelmatrix_diag(k, x0) ≈ kernelmatrix_diag(k, x0, x0) atol = atol rtol = rtol - @test kernelmatrix(k, x0) ≈ kernelmatrix(k, x0, x0) atol = atol rtol = rtol - - # Check that basic kernel evaluation succeeds and is consistent with `kernelmatrix`. - @test k(first(x0), first(x1)) isa Real - @test kernelmatrix(k, x0, x2) ≈ [k(xl, xr) for xl in x0, xr in x2] - - tmp = Matrix{Float64}(undef, length(x0), length(x2)) - @test kernelmatrix!(tmp, k, x0, x2) ≈ kernelmatrix(k, x0, x2) - - tmp_square = Matrix{Float64}(undef, length(x0), length(x0)) - @test kernelmatrix!(tmp_square, k, x0) ≈ kernelmatrix(k, x0) - - tmp_diag = Vector{Float64}(undef, length(x0)) - @test kernelmatrix_diag!(tmp_diag, k, x0) ≈ kernelmatrix_diag(k, x0) - @test kernelmatrix_diag!(tmp_diag, k, x0, x1) ≈ kernelmatrix_diag(k, x0, x1) -end - -""" - test_interface([rng::AbstractRNG], k::Kernel, ::Type{T}=Float64; kwargs...) where {T} - -Run the [`test_interface`](@ref) tests for randomly generated inputs of types `Vector{T}`, -`Vector{Vector{T}}`, `ColVecs{T}`, and `RowVecs{T}`. - -For other input types, please provide the data manually. - -The keyword arguments are forwarded to the invocations of [`test_interface`](@ref) with the -randomly generated inputs. -""" -function test_interface(k::Kernel, T::Type=Float64; kwargs...) - return test_interface(Random.GLOBAL_RNG, k, T; kwargs...) -end - -function test_interface(rng::AbstractRNG, k::Kernel, T::Type=Float64; kwargs...) - return test_with_type(test_interface, rng, k, T; kwargs...) -end - -""" - test_type_stability( - k::Kernel, - x0::AbstractVector, - x1::AbstractVector, - x2::AbstractVector, - ) - -Run type stability checks over `k(x,y)` and the different functions of the API -(`kernelmatrix`, `kernelmatrix_diag`). `x0` and `x1` should be of the same -length with different values, while `x0` and `x2` should be of different lengths. -""" -function test_type_stability( - k::Kernel, x0::AbstractVector, x1::AbstractVector, x2::AbstractVector -) - # Ensure that we have the required inputs. - @assert length(x0) == length(x1) - @assert length(x0) ≠ length(x2) - @test @inferred(kernelmatrix(k, x0)) isa AbstractMatrix - @test @inferred(kernelmatrix(k, x0, x2)) isa AbstractMatrix - @test @inferred(kernelmatrix_diag(k, x0)) isa AbstractVector - @test @inferred(kernelmatrix_diag(k, x0, x1)) isa AbstractVector -end - -function test_type_stability(k::Kernel, ::Type{T}=Float64; kwargs...) where {T} - return test_type_stability(Random.GLOBAL_RNG, k, T; kwargs...) -end - -function test_type_stability(rng::AbstractRNG, k::Kernel, ::Type{T}; kwargs...) where {T} - return test_with_type(test_type_stability, rng, k, T; kwargs...) -end - -""" - test_with_type(f, rng::AbstractRNG, k::Kernel, ::Type{T}; kwargs...) where {T} - -Run the functions `f`, (for example [`test_interface`](@ref) or -[`test_type_stable`](@ref)) for randomly generated inputs of types `Vector{T}`, -`Vector{Vector{T}}`, `ColVecs{T}`, and `RowVecs{T}`. - -For other input types, please provide the data manually. - -The keyword arguments are forwarded to the invocations of `f` with the -randomly generated inputs. -""" -function test_with_type(f, rng::AbstractRNG, k::Kernel, ::Type{T}; kwargs...) where {T} - @testset "Vector{$T}" begin - test_with_type(f, rng, k, Vector{T}; kwargs...) +function test_interface end +function test_with_type end +function test_type_stability end +function example_inputs end + +function __init__() + # Better error message if users forget to load Test + Base.Experimental.register_error_hint(MethodError) do io, exc, _, _ + if exc.f === test_interface || exc.f === test_with_type || exc.f === test_type_stability || exc.f === example_inputs + print(io, "\\nDid you forget to load Test?") + end end - @testset "ColVecs{$T}" begin - test_with_type(f, rng, k, ColVecs{T}; kwargs...) - end - @testset "RowVecs{$T}" begin - test_with_type(f, rng, k, RowVecs{T}; kwargs...) - end - @testset "Vector{Vector{$T}}" begin - test_with_type(f, rng, k, Vector{Vector{T}}; kwargs...) - end -end - -function test_with_type( - f, rng::AbstractRNG, k::Kernel, ::Type{Vector{T}}; kwargs... -) where {T<:Real} - return f(k, randn(rng, T, 11), randn(rng, T, 11), randn(rng, T, 13); kwargs...) -end - -function test_with_type( - f, rng::AbstractRNG, k::MOKernel, ::Type{Vector{Tuple{T,Int}}}; dim_out=3, kwargs... -) where {T<:Real} - return f( - k, - [(randn(rng, T), rand(rng, 1:dim_out)) for i in 1:11], - [(randn(rng, T), rand(rng, 1:dim_out)) for i in 1:11], - [(randn(rng, T), rand(rng, 1:dim_out)) for i in 1:13]; - kwargs..., - ) -end - -function test_with_type( - f, rng::AbstractRNG, k::Kernel, ::Type{<:ColVecs{T}}; dim_in=2, kwargs... -) where {T<:Real} - return f( - k, - ColVecs(randn(rng, T, dim_in, 11)), - ColVecs(randn(rng, T, dim_in, 11)), - ColVecs(randn(rng, T, dim_in, 13)); - kwargs..., - ) -end - -function test_with_type( - f, rng::AbstractRNG, k::Kernel, ::Type{<:RowVecs{T}}; dim_in=2, kwargs... -) where {T<:Real} - return f( - k, - RowVecs(randn(rng, T, 11, dim_in)), - RowVecs(randn(rng, T, 11, dim_in)), - RowVecs(randn(rng, T, 13, dim_in)); - kwargs..., - ) -end - -function test_with_type( - f, rng::AbstractRNG, k::Kernel, ::Type{<:Vector{Vector{T}}}; dim_in=2, kwargs... -) where {T<:Real} - return f( - k, - [randn(rng, T, dim_in) for _ in 1:11], - [randn(rng, T, dim_in) for _ in 1:11], - [randn(rng, T, dim_in) for _ in 1:13]; - kwargs..., - ) -end - -function test_with_type(f, rng::AbstractRNG, k::Kernel, ::Type{Vector{String}}; kwargs...) - return f( - k, - [randstring(rng) for _ in 1:3], - [randstring(rng) for _ in 1:3], - [randstring(rng) for _ in 1:4]; - kwargs..., - ) -end - -function test_with_type( - f, rng::AbstractRNG, k::Kernel, ::Type{ColVecs{String}}; dim_in=2, kwargs... -) - return f( - k, - ColVecs([randstring(rng) for _ in 1:dim_in, _ in 1:3]), - ColVecs([randstring(rng) for _ in 1:dim_in, _ in 1:3]), - ColVecs([randstring(rng) for _ in 1:dim_in, _ in 1:4]); - kwargs..., - ) -end - -function test_with_type(f, k::Kernel, T::Type{<:Real}; kwargs...) - return test_with_type(f, Random.GLOBAL_RNG, k, T; kwargs...) -end - -""" - example_inputs(rng::AbstractRNG, type) - -Return a tuple of 4 inputs of type `type`. See `methods(example_inputs)` for information -around supported types. It is recommended that you utilise `StableRNGs.jl` for `rng` here -to ensure consistency across Julia versions. -""" -function example_inputs(rng::AbstractRNG, ::Type{Vector{Float64}}) - return map(n -> randn(rng, Float64, n), (1, 2, 3, 4)) -end - -function example_inputs( - rng::AbstractRNG, ::Type{ColVecs{Float64,Matrix{Float64}}}; dim::Int=2 -) - return map(n -> ColVecs(randn(rng, dim, n)), (1, 2, 3, 4)) -end - -function example_inputs( - rng::AbstractRNG, ::Type{RowVecs{Float64,Matrix{Float64}}}; dim::Int=2 -) - return map(n -> RowVecs(randn(rng, n, dim)), (1, 2, 3, 4)) end end # module From 616dbca28668ba576d65ef24d0f17da82d3f2d10 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 27 Jun 2023 01:53:48 +0200 Subject: [PATCH 2/6] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- ext/KernelFunctionsTestExt.jl | 33 ++++++++++++++++++++++++--------- src/TestUtils.jl | 5 ++++- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/ext/KernelFunctionsTestExt.jl b/ext/KernelFunctionsTestExt.jl index 843a0e9c1..cfe476d25 100644 --- a/ext/KernelFunctionsTestExt.jl +++ b/ext/KernelFunctionsTestExt.jl @@ -43,12 +43,13 @@ function TestUtils.test_interface( @test size(kernelmatrix(k, x0, x2)) == (length(x0), length(x2)) # Check that elementwise is consistent with pairwise. - @test kernelmatrix_diag(k, x0, x1) ≈ LinearAlgebra.diag(kernelmatrix(k, x0, x1)) atol = atol rtol = - rtol + @test kernelmatrix_diag(k, x0, x1) ≈ LinearAlgebra.diag(kernelmatrix(k, x0, x1)) atol = + atol rtol = rtol # Check additional binary elementwise properties for kernels. @test kernelmatrix_diag(k, x0, x1) ≈ kernelmatrix_diag(k, x1, x0) - @test kernelmatrix(k, x0, x2) ≈ permutedims(kernelmatrix(k, x2, x0)) atol = atol rtol = rtol + @test kernelmatrix(k, x0, x2) ≈ permutedims(kernelmatrix(k, x2, x0)) atol = atol rtol = + rtol # Check that unary elementwise basically works. @test kernelmatrix_diag(k, x0) isa AbstractVector @@ -60,7 +61,8 @@ function TestUtils.test_interface( @test kernelmatrix(k, x0) ≈ permutedims(kernelmatrix(k, x0)) atol = atol rtol = rtol # Check that unary elementwise is consistent with unary pairwise. - @test kernelmatrix_diag(k, x0) ≈ LinearAlgebra.diag(kernelmatrix(k, x0)) atol = atol rtol = rtol + @test kernelmatrix_diag(k, x0) ≈ LinearAlgebra.diag(kernelmatrix(k, x0)) atol = atol rtol = + rtol # Check that unary pairwise produces a positive definite matrix (approximately). @test LinearAlgebra.eigmin(Matrix(kernelmatrix(k, x0))) > -atol @@ -99,7 +101,9 @@ function TestUtils.test_interface(k::Kernel, T::Type=Float64; kwargs...) return TestUtils.test_interface(Random.default_rng(), k, T; kwargs...) end -function TestUtils.test_interface(rng::Random.AbstractRNG, k::Kernel, T::Type=Float64; kwargs...) +function TestUtils.test_interface( + rng::Random.AbstractRNG, k::Kernel, T::Type=Float64; kwargs... +) return TestUtils.test_with_type(TestUtils.test_interface, rng, k, T; kwargs...) end @@ -131,7 +135,9 @@ function TestUtils.test_type_stability(k::Kernel, ::Type{T}=Float64; kwargs...) return TestUtils.test_type_stability(Random.default_rng(), k, T; kwargs...) end -function TestUtils.test_type_stability(rng::Random.AbstractRNG, k::Kernel, ::Type{T}; kwargs...) where {T} +function TestUtils.test_type_stability( + rng::Random.AbstractRNG, k::Kernel, ::Type{T}; kwargs... +) where {T} return TestUtils.test_with_type(TestUtils.test_type_stability, rng, k, T; kwargs...) end @@ -147,7 +153,9 @@ For other input types, please provide the data manually. The keyword arguments are forwarded to the invocations of `f` with the randomly generated inputs. """ -function TestUtils.test_with_type(f, rng::Random.AbstractRNG, k::Kernel, ::Type{T}; kwargs...) where {T} +function TestUtils.test_with_type( + f, rng::Random.AbstractRNG, k::Kernel, ::Type{T}; kwargs... +) where {T} @testset "Vector{$T}" begin TestUtils.test_with_type(f, rng, k, Vector{T}; kwargs...) end @@ -169,7 +177,12 @@ function TestUtils.test_with_type( end function TestUtils.test_with_type( - f, rng::Random.AbstractRNG, k::MOKernel, ::Type{Vector{Tuple{T,Int}}}; dim_out=3, kwargs... + f, + rng::Random.AbstractRNG, + k::MOKernel, + ::Type{Vector{Tuple{T,Int}}}; + dim_out=3, + kwargs..., ) where {T<:Real} return f( k, @@ -216,7 +229,9 @@ function TestUtils.test_with_type( ) end -function TestUtils.test_with_type(f, rng::Random.AbstractRNG, k::Kernel, ::Type{Vector{String}}; kwargs...) +function TestUtils.test_with_type( + f, rng::Random.AbstractRNG, k::Kernel, ::Type{Vector{String}}; kwargs... +) return f( k, [Random.randstring(rng) for _ in 1:3], diff --git a/src/TestUtils.jl b/src/TestUtils.jl index b1c5967bc..bb74471b6 100644 --- a/src/TestUtils.jl +++ b/src/TestUtils.jl @@ -8,7 +8,10 @@ function example_inputs end function __init__() # Better error message if users forget to load Test Base.Experimental.register_error_hint(MethodError) do io, exc, _, _ - if exc.f === test_interface || exc.f === test_with_type || exc.f === test_type_stability || exc.f === example_inputs + if exc.f === test_interface || + exc.f === test_with_type || + exc.f === test_type_stability || + exc.f === example_inputs print(io, "\\nDid you forget to load Test?") end end From e9dbd71600398b37b79df7274dff324c2421c318 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 27 Jun 2023 01:56:20 +0200 Subject: [PATCH 3/6] Fix Julia < 1.9 --- src/KernelFunctions.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index 76f86478d..40658e9d6 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -126,6 +126,9 @@ include("chainrules.jl") include("zygoterules.jl") include("TestUtils.jl") +if !isdefined(Base, :get_extension) + include("../ext/KernelFunctionsTestExt.jl") +end function __init__() @require Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e" begin From ab090024082890cc1b985353555e321c64ad5239 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 20 Oct 2023 10:25:06 +0200 Subject: [PATCH 4/6] Improve error message --- src/TestUtils.jl | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/TestUtils.jl b/src/TestUtils.jl index bb74471b6..97f70650f 100644 --- a/src/TestUtils.jl +++ b/src/TestUtils.jl @@ -5,14 +5,17 @@ function test_with_type end function test_type_stability end function example_inputs end -function __init__() - # Better error message if users forget to load Test - Base.Experimental.register_error_hint(MethodError) do io, exc, _, _ - if exc.f === test_interface || - exc.f === test_with_type || - exc.f === test_type_stability || - exc.f === example_inputs - print(io, "\\nDid you forget to load Test?") +if isdefined(Base, :get_extension) && isdefined(Base.Experimental, :register_error_hint) + function __init__() + # Better error message if users forget to load Test + Base.Experimental.register_error_hint(MethodError) do io, exc, _, _ + if (exc.f === test_interface || + exc.f === test_with_type || + exc.f === test_type_stability || + exc.f === example_inputs) && + (Base.get_extension(Distributions, :DistributionsTestExt) === nothing) + print(io, "\nDid you forget to load Test?") + end end end end From f5035670c7853535fb8d101c447831c2cef0c77d Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 20 Oct 2023 10:40:18 +0200 Subject: [PATCH 5/6] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/TestUtils.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/TestUtils.jl b/src/TestUtils.jl index 97f70650f..2751f6192 100644 --- a/src/TestUtils.jl +++ b/src/TestUtils.jl @@ -9,11 +9,12 @@ if isdefined(Base, :get_extension) && isdefined(Base.Experimental, :register_err function __init__() # Better error message if users forget to load Test Base.Experimental.register_error_hint(MethodError) do io, exc, _, _ - if (exc.f === test_interface || + if ( + exc.f === test_interface || exc.f === test_with_type || exc.f === test_type_stability || - exc.f === example_inputs) && - (Base.get_extension(Distributions, :DistributionsTestExt) === nothing) + exc.f === example_inputs + ) && (Base.get_extension(Distributions, :DistributionsTestExt) === nothing) print(io, "\nDid you forget to load Test?") end end From ff5a142edfba4f759889175820c336d9f3ce0e11 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 20 Oct 2023 10:41:17 +0200 Subject: [PATCH 6/6] Update TestUtils.jl --- src/TestUtils.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/TestUtils.jl b/src/TestUtils.jl index 2751f6192..5663fb97c 100644 --- a/src/TestUtils.jl +++ b/src/TestUtils.jl @@ -1,5 +1,7 @@ module TestUtils +using ..KernelFunctions: KernelFunctions + function test_interface end function test_with_type end function test_type_stability end @@ -14,7 +16,7 @@ if isdefined(Base, :get_extension) && isdefined(Base.Experimental, :register_err exc.f === test_with_type || exc.f === test_type_stability || exc.f === example_inputs - ) && (Base.get_extension(Distributions, :DistributionsTestExt) === nothing) + ) && (Base.get_extension(KernelFunctions, :KernelFunctionsTestExt) === nothing) print(io, "\nDid you forget to load Test?") end end