Skip to content

Commit

Permalink
Clamp cdf and ccdf of Truncated (#1865)
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored May 30, 2024
1 parent fe57164 commit fa493fb
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/truncate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,8 @@ function logpdf(d::Truncated, x::Real)
end

function cdf(d::Truncated, x::Real)
result = (cdf(d.untruncated, x) - d.lcdf) / d.tp
result = clamp((cdf(d.untruncated, x) - d.lcdf) / d.tp, 0, 1)
# Special cases for values outside of the support to avoid e.g. NaN issues with `Binomial`
return if d.lower !== nothing && x < d.lower
zero(result)
elseif d.upper !== nothing && x >= d.upper
Expand All @@ -188,7 +189,8 @@ function logcdf(d::Truncated, x::Real)
end

function ccdf(d::Truncated, x::Real)
result = (d.ucdf - cdf(d.untruncated, x)) / d.tp
result = clamp((d.ucdf - cdf(d.untruncated, x)) / d.tp, 0, 1)
# Special cases for values outside of the support to avoid e.g. NaN issues with `Binomial`
return if d.lower !== nothing && x <= d.lower
one(result)
elseif d.upper !== nothing && x > d.upper
Expand Down
18 changes: 18 additions & 0 deletions test/truncate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,3 +214,21 @@ end

@test isa(quantile(d, ForwardDiff.Dual(1.,0.)), ForwardDiff.Dual)
end

@testset "cdf outside of [0, 1] (#1854)" begin
dist = truncated(Normal(2.5, 0.2); lower=0.0)
@test @inferred(cdf(dist, 3.741058503233821e-17)) === 0.0
@test @inferred(ccdf(dist, 3.741058503233821e-17)) === 1.0
@test @inferred(cdf(dist, 1.4354474178676617e-18)) === 0.0
@test @inferred(ccdf(dist, 1.4354474178676617e-18)) === 1.0
@test @inferred(cdf(dist, 8.834854780587132e-18)) === 0.0
@test @inferred(ccdf(dist, 8.834854780587132e-18)) === 1.0

dist = truncated(
Normal(2.122039143928797, 0.07327367710864985);
lower = 1.9521656132878236,
upper = 2.8274429454898398,
)
@test @inferred(cdf(dist, 2.82)) === 1.0
@test @inferred(ccdf(dist, 2.82)) === 0.0
end

0 comments on commit fa493fb

Please sign in to comment.