From 05fe34077dffe66ed2808c07e21e261496c4f089 Mon Sep 17 00:00:00 2001 From: st-- Date: Tue, 21 Dec 2021 00:49:37 +0200 Subject: [PATCH] use only() instead of first() (#403) * use only() instead of first() for 1-"vectors" that were for the benefit of Flux * fix one test that should not have worked as it was * add missing scalar Sinus constructor --- src/basekernels/constant.jl | 4 ++-- src/basekernels/exponential.jl | 4 ++-- src/basekernels/fbm.jl | 6 +++--- src/basekernels/matern.jl | 4 ++-- src/basekernels/polynomial.jl | 10 +++++----- src/basekernels/rational.jl | 16 ++++++++-------- src/distances/sinus.jl | 4 +++- src/kernels/scaledkernel.jl | 4 ++-- src/kernels/transformedkernel.jl | 4 ++-- src/transform/ardtransform.jl | 2 +- src/transform/periodic_transform.jl | 8 ++++---- src/transform/scaletransform.jl | 12 ++++++------ test/Project.toml | 1 + test/basekernels/constant.jl | 2 +- test/basekernels/exponential.jl | 2 +- test/distances/sinus.jl | 3 ++- test/runtests.jl | 1 + test/test_utils.jl | 4 ++-- 18 files changed, 48 insertions(+), 43 deletions(-) diff --git a/src/basekernels/constant.jl b/src/basekernels/constant.jl index 5996546f1..80087d758 100644 --- a/src/basekernels/constant.jl +++ b/src/basekernels/constant.jl @@ -73,8 +73,8 @@ end @functor ConstantKernel -kappa(κ::ConstantKernel, x::Real) = first(κ.c) * one(x) +kappa(κ::ConstantKernel, x::Real) = only(κ.c) * one(x) metric(::ConstantKernel) = Delta() -Base.show(io::IO, κ::ConstantKernel) = print(io, "Constant Kernel (c = ", first(κ.c), ")") +Base.show(io::IO, κ::ConstantKernel) = print(io, "Constant Kernel (c = ", only(κ.c), ")") diff --git a/src/basekernels/exponential.jl b/src/basekernels/exponential.jl index c7a788b8a..2061d40f9 100644 --- a/src/basekernels/exponential.jl +++ b/src/basekernels/exponential.jl @@ -137,7 +137,7 @@ end @functor GammaExponentialKernel -kappa(κ::GammaExponentialKernel, d::Real) = exp(-d^first(κ.γ)) +kappa(κ::GammaExponentialKernel, d::Real) = exp(-d^only(κ.γ)) metric(k::GammaExponentialKernel) = k.metric @@ -145,6 +145,6 @@ iskroncompatible(::GammaExponentialKernel) = true function Base.show(io::IO, κ::GammaExponentialKernel) return print( - io, "Gamma Exponential Kernel (γ = ", first(κ.γ), ", metric = ", κ.metric, ")" + io, "Gamma Exponential Kernel (γ = ", only(κ.γ), ", metric = ", κ.metric, ")" ) end diff --git a/src/basekernels/fbm.jl b/src/basekernels/fbm.jl index 213cb3c36..08c7b3695 100644 --- a/src/basekernels/fbm.jl +++ b/src/basekernels/fbm.jl @@ -28,16 +28,16 @@ function (κ::FBMKernel)(x::AbstractVector{<:Real}, y::AbstractVector{<:Real}) modX = sum(abs2, x) modY = sum(abs2, y) modXY = sqeuclidean(x, y) - h = first(κ.h) + h = only(κ.h) return (modX^h + modY^h - modXY^h) / 2 end function (κ::FBMKernel)(x::Real, y::Real) - return (abs2(x)^first(κ.h) + abs2(y)^first(κ.h) - abs2(x - y)^first(κ.h)) / 2 + return (abs2(x)^only(κ.h) + abs2(y)^only(κ.h) - abs2(x - y)^only(κ.h)) / 2 end function Base.show(io::IO, κ::FBMKernel) - return print(io, "Fractional Brownian Motion Kernel (h = ", first(κ.h), ")") + return print(io, "Fractional Brownian Motion Kernel (h = ", only(κ.h), ")") end _fbm(modX, modY, modXY, h) = (modX^h + modY^h - modXY^h) / 2 diff --git a/src/basekernels/matern.jl b/src/basekernels/matern.jl index a3c20efd7..a1ae4dfcc 100644 --- a/src/basekernels/matern.jl +++ b/src/basekernels/matern.jl @@ -34,7 +34,7 @@ MaternKernel(; nu::Real=1.5, ν::Real=nu, metric=Euclidean()) = MaternKernel(ν, @functor MaternKernel @inline function kappa(κ::MaternKernel, d::Real) - result = _matern(first(κ.ν), d) + result = _matern(only(κ.ν), d) return ifelse(iszero(d), one(result), result) end @@ -46,7 +46,7 @@ end metric(k::MaternKernel) = k.metric function Base.show(io::IO, κ::MaternKernel) - return print(io, "Matern Kernel (ν = ", first(κ.ν), ", metric = ", κ.metric, ")") + return print(io, "Matern Kernel (ν = ", only(κ.ν), ", metric = ", κ.metric, ")") end ## Matern12Kernel = ExponentialKernel aliased in exponential.jl diff --git a/src/basekernels/polynomial.jl b/src/basekernels/polynomial.jl index da686e2c9..e0c0bfcb0 100644 --- a/src/basekernels/polynomial.jl +++ b/src/basekernels/polynomial.jl @@ -26,11 +26,11 @@ LinearKernel(; c::Real=0.0) = LinearKernel(c) @functor LinearKernel -kappa(κ::LinearKernel, xᵀy::Real) = xᵀy + first(κ.c) +kappa(κ::LinearKernel, xᵀy::Real) = xᵀy + only(κ.c) metric(::LinearKernel) = DotProduct() -Base.show(io::IO, κ::LinearKernel) = print(io, "Linear Kernel (c = ", first(κ.c), ")") +Base.show(io::IO, κ::LinearKernel) = print(io, "Linear Kernel (c = ", only(κ.c), ")") """ PolynomialKernel(; degree::Int=2, c::Real=0.0) @@ -53,7 +53,7 @@ struct PolynomialKernel{Tc<:Real} <: SimpleKernel function PolynomialKernel{Tc}(degree::Int, c::Vector{Tc}) where {Tc} @check_args(PolynomialKernel, degree, degree >= one(degree), "degree ≥ 1") - @check_args(PolynomialKernel, c, first(c) >= zero(Tc), "c ≥ 0") + @check_args(PolynomialKernel, c, only(c) >= zero(Tc), "c ≥ 0") return new{Tc}(degree, c) end end @@ -68,10 +68,10 @@ function Functors.functor(::Type{<:PolynomialKernel}, x) return (c=x.c,), reconstruct_polynomialkernel end -kappa(κ::PolynomialKernel, xᵀy::Real) = (xᵀy + first(κ.c))^κ.degree +kappa(κ::PolynomialKernel, xᵀy::Real) = (xᵀy + only(κ.c))^κ.degree metric(::PolynomialKernel) = DotProduct() function Base.show(io::IO, κ::PolynomialKernel) - return print(io, "Polynomial Kernel (c = ", first(κ.c), ", degree = ", κ.degree, ")") + return print(io, "Polynomial Kernel (c = ", only(κ.c), ", degree = ", κ.degree, ")") end diff --git a/src/basekernels/rational.jl b/src/basekernels/rational.jl index 8ed396b51..0300b4f0d 100644 --- a/src/basekernels/rational.jl +++ b/src/basekernels/rational.jl @@ -32,13 +32,13 @@ end @functor RationalKernel function kappa(κ::RationalKernel, d::Real) - return (one(d) + d / first(κ.α))^(-first(κ.α)) + return (one(d) + d / only(κ.α))^(-only(κ.α)) end metric(k::RationalKernel) = k.metric function Base.show(io::IO, κ::RationalKernel) - return print(io, "Rational Kernel (α = ", first(κ.α), ", metric = ", κ.metric, ")") + return print(io, "Rational Kernel (α = ", only(κ.α), ", metric = ", κ.metric, ")") end """ @@ -72,10 +72,10 @@ end @functor RationalQuadraticKernel function kappa(κ::RationalQuadraticKernel, d::Real) - return (one(d) + d^2 / (2 * first(κ.α)))^(-first(κ.α)) + return (one(d) + d^2 / (2 * only(κ.α)))^(-only(κ.α)) end function kappa(κ::RationalQuadraticKernel{<:Real,<:Euclidean}, d²::Real) - return (one(d²) + d² / (2 * first(κ.α)))^(-first(κ.α)) + return (one(d²) + d² / (2 * only(κ.α)))^(-only(κ.α)) end metric(k::RationalQuadraticKernel) = k.metric @@ -83,7 +83,7 @@ metric(::RationalQuadraticKernel{<:Real,<:Euclidean}) = SqEuclidean() function Base.show(io::IO, κ::RationalQuadraticKernel) return print( - io, "Rational Quadratic Kernel (α = ", first(κ.α), ", metric = ", κ.metric, ")" + io, "Rational Quadratic Kernel (α = ", only(κ.α), ", metric = ", κ.metric, ")" ) end @@ -122,7 +122,7 @@ end @functor GammaRationalKernel function kappa(κ::GammaRationalKernel, d::Real) - return (one(d) + d^first(κ.γ) / first(κ.α))^(-first(κ.α)) + return (one(d) + d^only(κ.γ) / only(κ.α))^(-only(κ.α)) end metric(k::GammaRationalKernel) = k.metric @@ -131,9 +131,9 @@ function Base.show(io::IO, κ::GammaRationalKernel) return print( io, "Gamma Rational Kernel (α = ", - first(κ.α), + only(κ.α), ", γ = ", - first(κ.γ), + only(κ.γ), ", metric = ", κ.metric, ")", diff --git a/src/distances/sinus.jl b/src/distances/sinus.jl index 4bcf4bdf0..51d14c47d 100644 --- a/src/distances/sinus.jl +++ b/src/distances/sinus.jl @@ -2,10 +2,12 @@ struct Sinus{T} <: Distances.UnionSemiMetric r::Vector{T} end +Sinus(r::Real) = Sinus([r]) + Distances.parameters(d::Sinus) = d.r @inline Distances.eval_op(::Sinus, a::Real, b::Real, p::Real) = abs2(sinpi(a - b) / p) @inline (dist::Sinus)(a::AbstractArray, b::AbstractArray) = Distances._evaluate(dist, a, b) -@inline (dist::Sinus)(a::Number, b::Number) = abs2(sinpi(a - b) / first(dist.r)) +@inline (dist::Sinus)(a::Number, b::Number) = abs2(sinpi(a - b) / only(dist.r)) Distances.result_type(::Sinus{T}, Ta::Type, Tb::Type) where {T} = promote_type(T, Ta, Tb) diff --git a/src/kernels/scaledkernel.jl b/src/kernels/scaledkernel.jl index 897bdda1a..f7c6b9eae 100644 --- a/src/kernels/scaledkernel.jl +++ b/src/kernels/scaledkernel.jl @@ -23,7 +23,7 @@ end @functor ScaledKernel -(k::ScaledKernel)(x, y) = first(k.σ²) * k.kernel(x, y) +(k::ScaledKernel)(x, y) = only(k.σ²) * k.kernel(x, y) function kernelmatrix(κ::ScaledKernel, x::AbstractVector, y::AbstractVector) return κ.σ² .* kernelmatrix(κ.kernel, x, y) @@ -75,5 +75,5 @@ Base.show(io::IO, κ::ScaledKernel) = printshifted(io, κ, 0) function printshifted(io::IO, κ::ScaledKernel, shift::Int) printshifted(io, κ.kernel, shift) - return print(io, "\n" * ("\t"^(shift + 1)) * "- σ² = $(first(κ.σ²))") + return print(io, "\n" * ("\t"^(shift + 1)) * "- σ² = $(only(κ.σ²))") end diff --git a/src/kernels/transformedkernel.jl b/src/kernels/transformedkernel.jl index 94ae5c147..88e719ef1 100644 --- a/src/kernels/transformedkernel.jl +++ b/src/kernels/transformedkernel.jl @@ -28,10 +28,10 @@ function (k::TransformedKernel{<:SimpleKernel,<:ScaleTransform})( end function _scale(t::ScaleTransform, metric::Euclidean, x, y) - return first(t.s) * evaluate(metric, x, y) + return only(t.s) * evaluate(metric, x, y) end function _scale(t::ScaleTransform, metric::Union{SqEuclidean,DotProduct}, x, y) - return first(t.s)^2 * evaluate(metric, x, y) + return only(t.s)^2 * evaluate(metric, x, y) end _scale(t::ScaleTransform, metric, x, y) = evaluate(metric, t(x), t(y)) diff --git a/src/transform/ardtransform.jl b/src/transform/ardtransform.jl index 4e71d0141..726d940ad 100644 --- a/src/transform/ardtransform.jl +++ b/src/transform/ardtransform.jl @@ -32,7 +32,7 @@ end dim(t::ARDTransform) = length(t.v) -(t::ARDTransform)(x::Real) = first(t.v) * x +(t::ARDTransform)(x::Real) = only(t.v) * x (t::ARDTransform)(x) = t.v .* x _map(t::ARDTransform, x::AbstractVector{<:Real}) = t.v' .* x diff --git a/src/transform/periodic_transform.jl b/src/transform/periodic_transform.jl index 3430a63a1..098262309 100644 --- a/src/transform/periodic_transform.jl +++ b/src/transform/periodic_transform.jl @@ -25,16 +25,16 @@ PeriodicTransform(f::Real) = PeriodicTransform([f]) dim(t::PeriodicTransform) = 2 -(t::PeriodicTransform)(x::Real) = [sinpi(2 * first(t.f) * x), cospi(2 * first(t.f) * x)] +(t::PeriodicTransform)(x::Real) = [sinpi(2 * only(t.f) * x), cospi(2 * only(t.f) * x)] function _map(t::PeriodicTransform, x::AbstractVector{<:Real}) - return RowVecs(hcat(sinpi.((2 * first(t.f)) .* x), cospi.((2 * first(t.f)) .* x))) + return RowVecs(hcat(sinpi.((2 * only(t.f)) .* x), cospi.((2 * only(t.f)) .* x))) end function Base.isequal(t1::PeriodicTransform, t2::PeriodicTransform) - return isequal(first(t1.f), first(t2.f)) + return isequal(only(t1.f), only(t2.f)) end function Base.show(io::IO, t::PeriodicTransform) - return print(io, "Periodic Transform with frequency $(first(t.f))") + return print(io, "Periodic Transform with frequency $(only(t.f))") end diff --git a/src/transform/scaletransform.jl b/src/transform/scaletransform.jl index 4cd1c5443..164ed3b39 100644 --- a/src/transform/scaletransform.jl +++ b/src/transform/scaletransform.jl @@ -24,12 +24,12 @@ end set!(t::ScaleTransform, ρ::Real) = t.s .= [ρ] -(t::ScaleTransform)(x) = first(t.s) * x +(t::ScaleTransform)(x) = only(t.s) * x -_map(t::ScaleTransform, x::AbstractVector{<:Real}) = first(t.s) .* x -_map(t::ScaleTransform, x::ColVecs) = ColVecs(first(t.s) .* x.X) -_map(t::ScaleTransform, x::RowVecs) = RowVecs(first(t.s) .* x.X) +_map(t::ScaleTransform, x::AbstractVector{<:Real}) = only(t.s) .* x +_map(t::ScaleTransform, x::ColVecs) = ColVecs(only(t.s) .* x.X) +_map(t::ScaleTransform, x::RowVecs) = RowVecs(only(t.s) .* x.X) -Base.isequal(t::ScaleTransform, t2::ScaleTransform) = isequal(first(t.s), first(t2.s)) +Base.isequal(t::ScaleTransform, t2::ScaleTransform) = isequal(only(t.s), only(t2.s)) -Base.show(io::IO, t::ScaleTransform) = print(io, "Scale Transform (s = ", first(t.s), ")") +Base.show(io::IO, t::ScaleTransform) = print(io, "Scale Transform (s = ", only(t.s), ")") diff --git a/test/Project.toml b/test/Project.toml index ef3a56dc2..eea862763 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9" +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" diff --git a/test/basekernels/constant.jl b/test/basekernels/constant.jl index e18df9419..f626fb297 100644 --- a/test/basekernels/constant.jl +++ b/test/basekernels/constant.jl @@ -36,6 +36,6 @@ # Standardised tests. TestUtils.test_interface(k, Float64) - test_ADs(c -> ConstantKernel(; c=first(c)), [c]) + test_ADs(c -> ConstantKernel(; c=only(c)), [c]) end end diff --git a/test/basekernels/exponential.jl b/test/basekernels/exponential.jl index a002bf29d..21586cadd 100644 --- a/test/basekernels/exponential.jl +++ b/test/basekernels/exponential.jl @@ -56,7 +56,7 @@ @test metric(k2) isa WeightedEuclidean @test k2(v1, v2) ≈ k(v1, v2) - test_ADs(γ -> GammaExponentialKernel(; gamma=first(γ)), [1 + 0.5 * rand()]) + test_ADs(γ -> GammaExponentialKernel(; gamma=only(γ)), [1 + 0.5 * rand()]) test_params(k, ([γ],)) TestUtils.test_interface(GammaExponentialKernel(; γ=1.36)) diff --git a/test/distances/sinus.jl b/test/distances/sinus.jl index 91ba1b028..d903e765d 100644 --- a/test/distances/sinus.jl +++ b/test/distances/sinus.jl @@ -5,5 +5,6 @@ d = KernelFunctions.Sinus(p) @test Distances.parameters(d) == p @test evaluate(d, A, B) == sum(abs2.(sinpi.(A - B) ./ p)) - @test d(3.0, 2.0) == abs2(sinpi(3.0 - 2.0) / first(p)) + d1 = KernelFunctions.Sinus(first(p)) + @test d1(3.0, 2.0) == abs2(sinpi(3.0 - 2.0) / first(p)) end diff --git a/test/runtests.jl b/test/runtests.jl index 63c8e8f89..c9bc932f3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,6 +14,7 @@ using Zygote: Zygote using ForwardDiff: ForwardDiff using ReverseDiff: ReverseDiff using FiniteDifferences: FiniteDifferences +using Compat: only using KernelFunctions: SimpleKernel, metric, kappa, ColVecs, RowVecs, TestUtils diff --git a/test/test_utils.jl b/test/test_utils.jl index 1871a99ca..b8c349e37 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -45,7 +45,7 @@ const FDM = FiniteDifferences.central_fdm(5, 1) gradient(f, s::Symbol, args) = gradient(f, Val(s), args) function gradient(f, ::Val{:Zygote}, args) - g = first(Zygote.gradient(f, args)) + g = only(Zygote.gradient(f, args)) if isnothing(g) if args isa AbstractArray{<:Real} return zeros(size(args)) # To respect the same output as other ADs @@ -66,7 +66,7 @@ function gradient(f, ::Val{:ReverseDiff}, args) end function gradient(f, ::Val{:FiniteDiff}, args) - return first(FiniteDifferences.grad(FDM, f, args)) + return only(FiniteDifferences.grad(FDM, f, args)) end function compare_gradient(f, ::Val{:FiniteDiff}, args)