From fa493fb84020b9c8f9449ffeccbc5b0d6c313831 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 30 May 2024 08:58:28 +0200 Subject: [PATCH] Clamp `cdf` and `ccdf` of `Truncated` (#1865) --- src/truncate.jl | 6 ++++-- test/truncate.jl | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/truncate.jl b/src/truncate.jl index 45709f6b5..48d62b015 100644 --- a/src/truncate.jl +++ b/src/truncate.jl @@ -166,7 +166,8 @@ function logpdf(d::Truncated, x::Real) end function cdf(d::Truncated, x::Real) - result = (cdf(d.untruncated, x) - d.lcdf) / d.tp + result = clamp((cdf(d.untruncated, x) - d.lcdf) / d.tp, 0, 1) + # Special cases for values outside of the support to avoid e.g. NaN issues with `Binomial` return if d.lower !== nothing && x < d.lower zero(result) elseif d.upper !== nothing && x >= d.upper @@ -188,7 +189,8 @@ function logcdf(d::Truncated, x::Real) end function ccdf(d::Truncated, x::Real) - result = (d.ucdf - cdf(d.untruncated, x)) / d.tp + result = clamp((d.ucdf - cdf(d.untruncated, x)) / d.tp, 0, 1) + # Special cases for values outside of the support to avoid e.g. NaN issues with `Binomial` return if d.lower !== nothing && x <= d.lower one(result) elseif d.upper !== nothing && x > d.upper diff --git a/test/truncate.jl b/test/truncate.jl index b409cba93..9c2a286d0 100644 --- a/test/truncate.jl +++ b/test/truncate.jl @@ -214,3 +214,21 @@ end @test isa(quantile(d, ForwardDiff.Dual(1.,0.)), ForwardDiff.Dual) end + +@testset "cdf outside of [0, 1] (#1854)" begin + dist = truncated(Normal(2.5, 0.2); lower=0.0) + @test @inferred(cdf(dist, 3.741058503233821e-17)) === 0.0 + @test @inferred(ccdf(dist, 3.741058503233821e-17)) === 1.0 + @test @inferred(cdf(dist, 1.4354474178676617e-18)) === 0.0 + @test @inferred(ccdf(dist, 1.4354474178676617e-18)) === 1.0 + @test @inferred(cdf(dist, 8.834854780587132e-18)) === 0.0 + @test @inferred(ccdf(dist, 8.834854780587132e-18)) === 1.0 + + dist = truncated( + Normal(2.122039143928797, 0.07327367710864985); + lower = 1.9521656132878236, + upper = 2.8274429454898398, + ) + @test @inferred(cdf(dist, 2.82)) === 1.0 + @test @inferred(ccdf(dist, 2.82)) === 0.0 +end