From f51166d9c7fc1b7e4c9e3f101f2f85b33638e9ee Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 25 Feb 2023 17:41:00 +0100 Subject: [PATCH 1/6] Add StatsBase as a dependency --- Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index 9d9a654..023a8ad 100644 --- a/Project.toml +++ b/Project.toml @@ -11,9 +11,11 @@ DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [compat] AbstractMCMC = "2, 3, 4" DensityInterface = "0.4" Setfield = "0.8.2, 1" +StatsBase = "0.32, 0.33" julia = "~1.6.6, 1.7.3" From f4d7823d43c59170aadd42bf6bf9f689e554a0b1 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 25 Feb 2023 17:41:16 +0100 Subject: [PATCH 2/6] Implement StatsBase.predict --- src/abstractprobprog.jl | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/src/abstractprobprog.jl b/src/abstractprobprog.jl index 30b8b35..b5dfbe4 100644 --- a/src/abstractprobprog.jl +++ b/src/abstractprobprog.jl @@ -1,6 +1,7 @@ using AbstractMCMC using DensityInterface using Random +using StatsBase """ @@ -80,3 +81,29 @@ end function Base.rand(model::AbstractProbabilisticProgram) return rand(Random.default_rng(), NamedTuple, model) end + +""" + predict( + [rng::AbstractRNG=Random.default_rng(),] + [T=NamedTuple,] + model::AbstractProbabilisticProgram, + params, + ) -> T + +Draw a sample from the joint distribution specified by `model` conditioned on the values in +`params`. + +The sample will be returned as format specified by `T`. +""" +function StatsBase.predict(rand::AbstractRNG, T::Type, model::AbstractProbabilisticProgram, params) + return rand(rng, T, condition(model, params)) +end +function StatsBase.predict(T::Type, model::AbstractProbabilisticProgram, params) + return StatsBase.predict(Random.default_rng(), T, model, params) +end +function StatsBase.predict(model::AbstractProbabilisticProgram, params) + return StatsBase.predict(NamedTuple, model, params) +end +function StatsBase.predict(rng::AbstractRNG, params) + return StatsBase.predict(rng, NamedTuple, model, params) +end From b72a96368dc247549b9bbe10e13282ccd003517c Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 28 Oct 2024 05:30:03 +0000 Subject: [PATCH 3/6] use `fix` and fix some errors --- src/abstractprobprog.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/abstractprobprog.jl b/src/abstractprobprog.jl index c1b137f..98b37ea 100644 --- a/src/abstractprobprog.jl +++ b/src/abstractprobprog.jl @@ -131,13 +131,12 @@ end params, ) -> T -Draw a sample from the joint distribution specified by `model` conditioned on the values in -`params`. +Draw a sample from the predictive distribution specified by `model` with its parameters fixed to `params`. The sample will be returned as format specified by `T`. """ -function StatsBase.predict(rand::AbstractRNG, T::Type, model::AbstractProbabilisticProgram, params) - return rand(rng, T, condition(model, params)) +function StatsBase.predict(rng::AbstractRNG, T::Type, model::AbstractProbabilisticProgram, params) + return rand(rng, T, fix(model, params)) end function StatsBase.predict(T::Type, model::AbstractProbabilisticProgram, params) return StatsBase.predict(Random.default_rng(), T, model, params) @@ -145,6 +144,6 @@ end function StatsBase.predict(model::AbstractProbabilisticProgram, params) return StatsBase.predict(NamedTuple, model, params) end -function StatsBase.predict(rng::AbstractRNG, params) +function StatsBase.predict(rng::AbstractRNG, T::Type, model::AbstractProbabilisticProgram, params) return StatsBase.predict(rng, NamedTuple, model, params) end From 076a7a54ad5e4992ef6f75cbe7f8bb137d65fd5b Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 6 Dec 2024 21:53:44 +0000 Subject: [PATCH 4/6] Format Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/abstractprobprog.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/abstractprobprog.jl b/src/abstractprobprog.jl index 4209d57..3e8dd3b 100644 --- a/src/abstractprobprog.jl +++ b/src/abstractprobprog.jl @@ -130,7 +130,9 @@ Draw a sample from the predictive distribution specified by `model` with its par The sample will be returned as format specified by `T`. """ -function StatsBase.predict(rng::AbstractRNG, T::Type, model::AbstractProbabilisticProgram, params) +function StatsBase.predict( + rng::AbstractRNG, T::Type, model::AbstractProbabilisticProgram, params +) return rand(rng, T, fix(model, params)) end function StatsBase.predict(T::Type, model::AbstractProbabilisticProgram, params) @@ -139,6 +141,8 @@ end function StatsBase.predict(model::AbstractProbabilisticProgram, params) return StatsBase.predict(NamedTuple, model, params) end -function StatsBase.predict(rng::AbstractRNG, T::Type, model::AbstractProbabilisticProgram, params) +function StatsBase.predict( + rng::AbstractRNG, T::Type, model::AbstractProbabilisticProgram, params +) return StatsBase.predict(rng, NamedTuple, model, params) end From 58a7931a8c8a7838a0bdbfa1f11e4b504b32cecb Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 6 Dec 2024 21:58:10 +0000 Subject: [PATCH 5/6] Bump StatsBase compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 6252e32..5304405 100644 --- a/Project.toml +++ b/Project.toml @@ -20,5 +20,5 @@ Accessors = "0.1" DensityInterface = "0.4" JSON = "0.19 - 0.21" Random = "1.6" -StatsBase = "0.32, 0.33" +StatsBase = "0.32, 0.33, 0.34" julia = "~1.6.6, 1.7.3" From 66bf79e640e66bb519a27e1ca385525e55c0c426 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 16 Dec 2024 09:57:11 -0800 Subject: [PATCH 6/6] slim down implementations --- Project.toml | 1 - src/abstractprobprog.jl | 20 ++------------------ 2 files changed, 2 insertions(+), 19 deletions(-) diff --git a/Project.toml b/Project.toml index 5304405..df46559 100644 --- a/Project.toml +++ b/Project.toml @@ -11,7 +11,6 @@ Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [compat] diff --git a/src/abstractprobprog.jl b/src/abstractprobprog.jl index 3e8dd3b..32e125d 100644 --- a/src/abstractprobprog.jl +++ b/src/abstractprobprog.jl @@ -121,28 +121,12 @@ end """ predict( [rng::AbstractRNG=Random.default_rng(),] - [T=NamedTuple,] model::AbstractProbabilisticProgram, params, - ) -> T + ) Draw a sample from the predictive distribution specified by `model` with its parameters fixed to `params`. - -The sample will be returned as format specified by `T`. """ -function StatsBase.predict( - rng::AbstractRNG, T::Type, model::AbstractProbabilisticProgram, params -) - return rand(rng, T, fix(model, params)) -end -function StatsBase.predict(T::Type, model::AbstractProbabilisticProgram, params) - return StatsBase.predict(Random.default_rng(), T, model, params) -end function StatsBase.predict(model::AbstractProbabilisticProgram, params) - return StatsBase.predict(NamedTuple, model, params) -end -function StatsBase.predict( - rng::AbstractRNG, T::Type, model::AbstractProbabilisticProgram, params -) - return StatsBase.predict(rng, NamedTuple, model, params) + return predict(Random.default_rng(), NamedTuple, model, params) end