Skip to content

Commit

Permalink
Return -Inf in logpdf of LKJCholesky when out of support (#1610)
Browse files Browse the repository at this point in the history
* ArgumentError to DomainError

* add test

* return -Inf

* Update test/cholesky/lkjcholesky.jl

Co-authored-by: David Widmann <[email protected]>

* fix and test Float32 support

* remove redundant test

* Update src/matrix/lkj.jl

Co-authored-by: David Widmann <[email protected]>

* Update test/cholesky/lkjcholesky.jl

Co-authored-by: David Widmann <[email protected]>

* fix sumlogs

* simplify condition

* Update src/matrix/lkj.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/matrix/lkj.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/matrix/lkj.jl

Co-authored-by: David Widmann <[email protected]>

* Remove static type parameter

* Update src/matrix/lkj.jl

* fix tests

* more type generalisations

* Update test/cholesky/lkjcholesky.jl

Co-authored-by: David Widmann <[email protected]>

* Update test/cholesky/lkjcholesky.jl

Co-authored-by: David Widmann <[email protected]>

* address comments

Co-authored-by: David Widmann <[email protected]>
Co-authored-by: mohamed82008 <[email protected]>
  • Loading branch information
3 people authored Sep 1, 2022
1 parent 2dc764e commit ca13aa7
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 24 deletions.
4 changes: 2 additions & 2 deletions src/cholesky/lkjcholesky.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ function logkernel(d::LKJCholesky, R::LinearAlgebra.Cholesky)
end

function logpdf(d::LKJCholesky, R::LinearAlgebra.Cholesky)
insupport(d, R) || throw(ArgumentError("provided point is not in the support"))
return _logpdf(d, R)
lp = _logpdf(d, R)
return insupport(d, R) ? lp : oftype(lp, -Inf)
end

_logpdf(d::LKJCholesky, R::LinearAlgebra.Cholesky) = logkernel(d, R) + d.logc0
Expand Down
37 changes: 20 additions & 17 deletions src/matrix/lkj.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,13 @@ params(d::LKJ) = (d.d, d.η)
# -----------------------------------------------------------------------------

function lkj_logc0(d::Integer, η::Real)
T = float(Base.promote_typeof(d, η))
d > 1 || return zero(η)
if isone(η)
if iseven(d)
logc0 = -lkj_onion_loginvconst_uniform_even(d)
logc0 = -lkj_onion_loginvconst_uniform_even(d, T)
else
logc0 = -lkj_onion_loginvconst_uniform_odd(d)
logc0 = -lkj_onion_loginvconst_uniform_odd(d, T)
end
else
logc0 = -lkj_onion_loginvconst(d, η)
Expand Down Expand Up @@ -188,32 +189,34 @@ end

function lkj_onion_loginvconst(d::Integer, η::Real)
# Equation (17) in LKJ (2009 JMA)
sumlogs = zero(η)
for k in 2:d - 1
sumlogs += 0.5k*logπ + loggamma+ 0.5(d - 1 - k))
T = float(Base.promote_typeof(d, η))
h = T(1//2)
α = η + h * d - 1
loginvconst = (2*η + d - 3)*T(logtwo) + (T(logπ) / 4) * (d * (d - 1) - 2) + logbeta(α, α) - (d - 2) * loggamma+ h * (d - 1))
for k in 2:(d - 1)
loginvconst += loggamma+ h * (d - 1 - k))
end
α = η + 0.5d - 1
loginvconst = (2η + d - 3)*logtwo + logbeta(α, α) + sumlogs - (d - 2) * loggamma+ 0.5(d - 1))
return loginvconst
end

function lkj_onion_loginvconst_uniform_odd(d::Integer)
function lkj_onion_loginvconst_uniform_odd(d::Integer, ::Type{T}) where {T <: Real}
# Theorem 5 in LKJ (2009 JMA)
sumlogs = 0.0
for k in 1:div(d - 1, 2)
sumlogs += loggamma(2k)
h = T(1//2)
loginvconst = (d - 1) * ((d + 1) * (T(logπ) / 4) - (d - 1) * (T(logtwo) / 4) - loggamma(h * (d + 1)))
for k in 2:2:(d - 1)
loginvconst += loggamma(T(k))
end
loginvconst = 0.25(d^2 - 1)*logπ + sumlogs - 0.25(d - 1)^2*logtwo - (d - 1)*loggamma(0.5(d + 1))
return loginvconst
end

function lkj_onion_loginvconst_uniform_even(d::Integer)
function lkj_onion_loginvconst_uniform_even(d::Integer, ::Type{T}) where {T <: Real}
# Theorem 5 in LKJ (2009 JMA)
sumlogs = 0.0
for k in 1:div(d - 2, 2)
sumlogs += loggamma(2k)
h = T(1//2)
loginvconst = d * ((d - 2) * (T(logπ) / 4) + (3 * d - 4) * (T(logtwo) / 4) + loggamma(h * d)) - (d - 1) * loggamma(T(d))
for k in 2:2:(d - 2)
loginvconst += loggamma(k)
end
loginvconst = 0.25d*(d - 2)*logπ + 0.25(3d^2 - 4d)*logtwo + d*loggamma(0.5d) + sumlogs - (d - 1)*loggamma(d)
return loginvconst
end

function lkj_vine_loginvconst(d::Integer, η::Real)
Expand Down
7 changes: 6 additions & 1 deletion test/cholesky/lkjcholesky.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,17 @@ using FiniteDifferences
@test m isa Cholesky{eltype(d)}
@test Matrix(m) I
end
@test_broken partype(LKJCholesky(2, 4f0)) <: Float32
for (d, η) in ((2, 4), (2, 1), (3, 1)), T in (Float32, Float64)
@test @inferred(partype(LKJCholesky(d, T(η)))) === T
end

@testset "insupport" begin
@test insupport(LKJCholesky(40, 2, 'U'), cholesky(rand(LKJ(40, 2))))
@test insupport(LKJCholesky(40, 2), cholesky(rand(LKJ(40, 2))))
@test !insupport(LKJCholesky(40, 2), cholesky(rand(LKJ(41, 2))))
for (d, η) in ((2, 4), (2, 1), (3, 1)), T in (Float32, Float64)
@test @inferred(logpdf(LKJCholesky(40, T(2)), cholesky(T.(rand(LKJ(41, 2)))))) === T(-Inf)
end
z = rand(LKJ(40, 1))
z .+= exp(Symmetric(randn(size(z)))) .* 1e-8
x = cholesky(z)
Expand Down
8 changes: 4 additions & 4 deletions test/matrixvariates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -454,11 +454,11 @@ function test_special(dist::Type{LKJ})
η = 1.0
lkj = LKJ(d, η)
@test Distributions.lkj_vine_loginvconst(d, η) Distributions.lkj_onion_loginvconst(d, η)
@test Distributions.lkj_onion_loginvconst(d, η) Distributions.lkj_onion_loginvconst_uniform_odd(d)
@test Distributions.lkj_onion_loginvconst(d, η) Distributions.lkj_onion_loginvconst_uniform_odd(d, Float64)
@test Distributions.lkj_vine_loginvconst(d, η) Distributions.lkj_vine_loginvconst_uniform(d)
@test Distributions.lkj_onion_loginvconst(d, η) Distributions.lkj_loginvconst_alt(d, η)
@test Distributions.lkj_onion_loginvconst(d, η) Distributions.corr_logvolume(d)
@test lkj.logc0 == -Distributions.lkj_onion_loginvconst_uniform_odd(d)
@test lkj.logc0 == -Distributions.lkj_onion_loginvconst_uniform_odd(d, Float64)
# =============
# even non-uniform
# =============
Expand All @@ -475,11 +475,11 @@ function test_special(dist::Type{LKJ})
η = 1.0
lkj = LKJ(d, η)
@test Distributions.lkj_vine_loginvconst(d, η) Distributions.lkj_onion_loginvconst(d, η)
@test Distributions.lkj_onion_loginvconst(d, η) Distributions.lkj_onion_loginvconst_uniform_even(d)
@test Distributions.lkj_onion_loginvconst(d, η) Distributions.lkj_onion_loginvconst_uniform_even(d, Float64)
@test Distributions.lkj_vine_loginvconst(d, η) Distributions.lkj_vine_loginvconst_uniform(d)
@test Distributions.lkj_onion_loginvconst(d, η) Distributions.lkj_loginvconst_alt(d, η)
@test Distributions.lkj_onion_loginvconst(d, η) Distributions.corr_logvolume(d)
@test lkj.logc0 == -Distributions.lkj_onion_loginvconst_uniform_even(d)
@test lkj.logc0 == -Distributions.lkj_onion_loginvconst_uniform_even(d, Float64)
end
@testset "check integrating constant as a volume" begin
# d = 2: Lebesgue measure of the set of correlation matrices is 2.
Expand Down

0 comments on commit ca13aa7

Please sign in to comment.