From ef3b05b7aa425573b736d32efca4186be7b8fb98 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Fri, 17 May 2024 15:50:38 -0400 Subject: [PATCH] Ensure mixmax rng is not initialized with all zeros --- src/stan/services/util/create_rng.hpp | 5 +- .../unit/mcmc/hmc/nuts/base_nuts_test.cpp | 4 +- .../unit/mcmc/hmc/nuts/softabs_nuts_test.cpp | 14 +++--- .../unit/mcmc/hmc/nuts/unit_e_nuts_test.cpp | 14 +++--- .../derived_static_uniform_test.cpp | 48 +++++++++---------- .../unit/mcmc/hmc/xhmc/base_xhmc_test.cpp | 4 +- .../unit/mcmc/hmc/xhmc/softabs_xhmc_test.cpp | 16 +++---- .../unit/mcmc/hmc/xhmc/unit_e_xhmc_test.cpp | 16 +++---- 8 files changed, 62 insertions(+), 59 deletions(-) diff --git a/src/stan/services/util/create_rng.hpp b/src/stan/services/util/create_rng.hpp index 231b57170b..a88d162695 100644 --- a/src/stan/services/util/create_rng.hpp +++ b/src/stan/services/util/create_rng.hpp @@ -26,7 +26,10 @@ namespace util { * @return an stan::rng_t instance */ inline rng_t create_rng(unsigned int seed, unsigned int chain) { - rng_t rng(seed + chain); + // RNG state is 128 bits, but user only provides 64 total bits + // Additionally, there are issues if all 128 bits are 0, hence + // the 1 as the second argument + rng_t rng(0, 1, seed, chain); return rng; } diff --git a/src/test/unit/mcmc/hmc/nuts/base_nuts_test.cpp b/src/test/unit/mcmc/hmc/nuts/base_nuts_test.cpp index 3f230d01a3..f58ccf0d60 100644 --- a/src/test/unit/mcmc/hmc/nuts/base_nuts_test.cpp +++ b/src/test/unit/mcmc/hmc/nuts/base_nuts_test.cpp @@ -362,7 +362,7 @@ TEST(McmcNutsBaseNuts, transition) { EXPECT_EQ((2 << (sampler.get_max_depth() - 1)) - 1, sampler.n_leapfrog_); EXPECT_FALSE(sampler.divergent_); - EXPECT_EQ(-31 * init_momentum, s.cont_params()(0)); + EXPECT_EQ(23 * init_momentum, s.cont_params()(0)); EXPECT_EQ(0, s.log_prob()); EXPECT_EQ(1, s.accept_stat()); EXPECT_EQ("", debug.str()); @@ -373,7 +373,7 @@ TEST(McmcNutsBaseNuts, transition) { } TEST(McmcNutsBaseNuts, transition_egde_momenta) { - stan::rng_t base_rng = stan::services::util::create_rng(424243, 0); + stan::rng_t base_rng = stan::services::util::create_rng(42424253, 0); int model_size = 1; double init_momentum = 1.5; diff --git a/src/test/unit/mcmc/hmc/nuts/softabs_nuts_test.cpp b/src/test/unit/mcmc/hmc/nuts/softabs_nuts_test.cpp index e9d64cd9fa..8eaa84df6b 100644 --- a/src/test/unit/mcmc/hmc/nuts/softabs_nuts_test.cpp +++ b/src/test/unit/mcmc/hmc/nuts/softabs_nuts_test.cpp @@ -338,15 +338,15 @@ TEST(McmcSoftAbsNuts, transition_test) { stan::mcmc::sample s = sampler.transition(init_sample, logger); - EXPECT_EQ(5, sampler.depth_); - EXPECT_EQ((2 << 4) - 1, sampler.n_leapfrog_); + EXPECT_EQ(3, sampler.depth_); + EXPECT_EQ((2 << 3) - 1, sampler.n_leapfrog_); EXPECT_FALSE(sampler.divergent_); - EXPECT_FLOAT_EQ(-1.7373296, s.cont_params()(0)); - EXPECT_FLOAT_EQ(1.0898665, s.cont_params()(1)); - EXPECT_FLOAT_EQ(-0.38303182, s.cont_params()(2)); - EXPECT_FLOAT_EQ(-2.1764181, s.log_prob()); - EXPECT_FLOAT_EQ(0.9993856, s.accept_stat()); + EXPECT_FLOAT_EQ(0.74693149, s.cont_params()(0)); + EXPECT_FLOAT_EQ(-0.74414188, s.cont_params()(1)); + EXPECT_FLOAT_EQ(0.60859376, s.cont_params()(2)); + EXPECT_FLOAT_EQ(-0.74102008, s.log_prob()); + EXPECT_FLOAT_EQ(0.99934167, s.accept_stat()); EXPECT_EQ("", debug.str()); EXPECT_EQ("", info.str()); EXPECT_EQ("", warn.str()); diff --git a/src/test/unit/mcmc/hmc/nuts/unit_e_nuts_test.cpp b/src/test/unit/mcmc/hmc/nuts/unit_e_nuts_test.cpp index 43f6ecb04d..6e5bade65a 100644 --- a/src/test/unit/mcmc/hmc/nuts/unit_e_nuts_test.cpp +++ b/src/test/unit/mcmc/hmc/nuts/unit_e_nuts_test.cpp @@ -338,15 +338,15 @@ TEST(McmcUnitENuts, transition_test) { stan::mcmc::sample s = sampler.transition(init_sample, logger); - EXPECT_EQ(5, sampler.depth_); - EXPECT_EQ((2 << 4) - 1, sampler.n_leapfrog_); + EXPECT_EQ(3, sampler.depth_); + EXPECT_EQ((2 << 3) - 1, sampler.n_leapfrog_); EXPECT_FALSE(sampler.divergent_); - EXPECT_FLOAT_EQ(-1.7890506, s.cont_params()(0)); - EXPECT_FLOAT_EQ(1.2320533, s.cont_params()(1)); - EXPECT_FLOAT_EQ(-0.62397981, s.cont_params()(2)); - EXPECT_FLOAT_EQ(-2.554004, s.log_prob()); - EXPECT_FLOAT_EQ(0.99910343, s.accept_stat()); + EXPECT_FLOAT_EQ(0.70149082, s.cont_params()(0)); + EXPECT_FLOAT_EQ(-0.69831347, s.cont_params()(1)); + EXPECT_FLOAT_EQ(0.54392564, s.cont_params()(2)); + EXPECT_FLOAT_EQ(-0.63779306, s.log_prob()); + EXPECT_FLOAT_EQ(0.99912512, s.accept_stat()); EXPECT_EQ("", debug.str()); EXPECT_EQ("", info.str()); EXPECT_EQ("", warn.str()); diff --git a/src/test/unit/mcmc/hmc/static_uniform/derived_static_uniform_test.cpp b/src/test/unit/mcmc/hmc/static_uniform/derived_static_uniform_test.cpp index bd01ec70d2..15527eaa8b 100644 --- a/src/test/unit/mcmc/hmc/static_uniform/derived_static_uniform_test.cpp +++ b/src/test/unit/mcmc/hmc/static_uniform/derived_static_uniform_test.cpp @@ -41,9 +41,9 @@ TEST(McmcStaticUniform, unit_e_transition) { stan::mcmc::sample s = sampler.transition(init_sample, logger); - EXPECT_FLOAT_EQ(1.5896972, s.cont_params()(0)); - EXPECT_FLOAT_EQ(-1.2635686, s.log_prob()); - EXPECT_FLOAT_EQ(0.9994188, s.accept_stat()); + EXPECT_FLOAT_EQ(1.0920367, s.cont_params()(0)); + EXPECT_FLOAT_EQ(-0.59627211, s.log_prob()); + EXPECT_FLOAT_EQ(0.99985325, s.accept_stat()); EXPECT_EQ("", debug.str()); EXPECT_EQ("", info.str()); EXPECT_EQ("", warn.str()); @@ -78,9 +78,9 @@ TEST(McmcStaticUniform, diag_e_transition) { stan::mcmc::sample s = sampler.transition(init_sample, logger); - EXPECT_FLOAT_EQ(1.5896972, s.cont_params()(0)); - EXPECT_FLOAT_EQ(-1.2635686, s.log_prob()); - EXPECT_FLOAT_EQ(0.9994188, s.accept_stat()); + EXPECT_FLOAT_EQ(1.0920367, s.cont_params()(0)); + EXPECT_FLOAT_EQ(-0.59627211, s.log_prob()); + EXPECT_FLOAT_EQ(0.99985325, s.accept_stat()); EXPECT_EQ("", debug.str()); EXPECT_EQ("", info.str()); EXPECT_EQ("", warn.str()); @@ -115,9 +115,9 @@ TEST(McmcStaticUniform, dense_e_transition) { stan::mcmc::sample s = sampler.transition(init_sample, logger); - EXPECT_FLOAT_EQ(1.5896972, s.cont_params()(0)); - EXPECT_FLOAT_EQ(-1.2635686, s.log_prob()); - EXPECT_FLOAT_EQ(0.9994188, s.accept_stat()); + EXPECT_FLOAT_EQ(1.0920367, s.cont_params()(0)); + EXPECT_FLOAT_EQ(-0.59627211, s.log_prob()); + EXPECT_FLOAT_EQ(0.99985325, s.accept_stat()); EXPECT_EQ("", debug.str()); EXPECT_EQ("", info.str()); EXPECT_EQ("", warn.str()); @@ -152,9 +152,9 @@ TEST(McmcStaticUniform, softabs_transition) { stan::mcmc::sample s = sampler.transition(init_sample, logger); - EXPECT_FLOAT_EQ(1.5338461, s.cont_params()(0)); - EXPECT_FLOAT_EQ(-1.176342, s.log_prob()); - EXPECT_FLOAT_EQ(0.9996115, s.accept_stat()); + EXPECT_FLOAT_EQ(1.0826443, s.cont_params()(0)); + EXPECT_FLOAT_EQ(-0.58605933, s.log_prob()); + EXPECT_FLOAT_EQ(0.99989599, s.accept_stat()); EXPECT_EQ("", debug.str()); EXPECT_EQ("", info.str()); EXPECT_EQ("", warn.str()); @@ -189,9 +189,9 @@ TEST(McmcStaticUniform, adapt_unit_e_transition) { stan::mcmc::sample s = sampler.transition(init_sample, logger); - EXPECT_FLOAT_EQ(1.5896972, s.cont_params()(0)); - EXPECT_FLOAT_EQ(-1.2635686, s.log_prob()); - EXPECT_FLOAT_EQ(0.9994188, s.accept_stat()); + EXPECT_FLOAT_EQ(1.0920367, s.cont_params()(0)); + EXPECT_FLOAT_EQ(-0.59627211, s.log_prob()); + EXPECT_FLOAT_EQ(0.99985325, s.accept_stat()); EXPECT_EQ("", debug.str()); EXPECT_EQ("", info.str()); EXPECT_EQ("", warn.str()); @@ -226,9 +226,9 @@ TEST(McmcStaticUniform, adapt_diag_e_transition) { stan::mcmc::sample s = sampler.transition(init_sample, logger); - EXPECT_FLOAT_EQ(1.5896972, s.cont_params()(0)); - EXPECT_FLOAT_EQ(-1.2635686, s.log_prob()); - EXPECT_FLOAT_EQ(0.9994188, s.accept_stat()); + EXPECT_FLOAT_EQ(1.0920367, s.cont_params()(0)); + EXPECT_FLOAT_EQ(-0.59627211, s.log_prob()); + EXPECT_FLOAT_EQ(0.99985325, s.accept_stat()); EXPECT_EQ("", debug.str()); EXPECT_EQ("", info.str()); EXPECT_EQ("", warn.str()); @@ -263,9 +263,9 @@ TEST(McmcStaticUniform, adapt_dense_e_transition) { stan::mcmc::sample s = sampler.transition(init_sample, logger); - EXPECT_FLOAT_EQ(1.5896972, s.cont_params()(0)); - EXPECT_FLOAT_EQ(-1.2635686, s.log_prob()); - EXPECT_FLOAT_EQ(0.9994188, s.accept_stat()); + EXPECT_FLOAT_EQ(1.0920367, s.cont_params()(0)); + EXPECT_FLOAT_EQ(-0.59627211, s.log_prob()); + EXPECT_FLOAT_EQ(0.99985325, s.accept_stat()); EXPECT_EQ("", debug.str()); EXPECT_EQ("", info.str()); EXPECT_EQ("", warn.str()); @@ -300,9 +300,9 @@ TEST(McmcStaticUniform, adapt_softabs_e_transition) { stan::mcmc::sample s = sampler.transition(init_sample, logger); - EXPECT_FLOAT_EQ(1.5338461, s.cont_params()(0)); - EXPECT_FLOAT_EQ(-1.176342, s.log_prob()); - EXPECT_FLOAT_EQ(0.9996115, s.accept_stat()); + EXPECT_FLOAT_EQ(1.0826443, s.cont_params()(0)); + EXPECT_FLOAT_EQ(-0.58605933, s.log_prob()); + EXPECT_FLOAT_EQ(0.99989599, s.accept_stat()); EXPECT_EQ("", debug.str()); EXPECT_EQ("", info.str()); EXPECT_EQ("", warn.str()); diff --git a/src/test/unit/mcmc/hmc/xhmc/base_xhmc_test.cpp b/src/test/unit/mcmc/hmc/xhmc/base_xhmc_test.cpp index 619f291a6e..70f9598136 100644 --- a/src/test/unit/mcmc/hmc/xhmc/base_xhmc_test.cpp +++ b/src/test/unit/mcmc/hmc/xhmc/base_xhmc_test.cpp @@ -221,7 +221,7 @@ TEST(McmcXHMCBaseXHMC, divergence_test) { } TEST(McmcXHMCBaseXHMC, transition) { - stan::rng_t base_rng = stan::services::util::create_rng(0, 0); + stan::rng_t base_rng = stan::services::util::create_rng(1234, 0); int model_size = 1; double init_momentum = 1.5; @@ -245,7 +245,7 @@ TEST(McmcXHMCBaseXHMC, transition) { stan::mcmc::sample s = sampler.transition(init_sample, logger); - EXPECT_EQ(-31 * init_momentum, s.cont_params()(0)); + EXPECT_EQ(5 * init_momentum, s.cont_params()(0)); EXPECT_EQ(0, s.log_prob()); EXPECT_EQ(1, s.accept_stat()); EXPECT_EQ("", debug.str()); diff --git a/src/test/unit/mcmc/hmc/xhmc/softabs_xhmc_test.cpp b/src/test/unit/mcmc/hmc/xhmc/softabs_xhmc_test.cpp index f5db914a35..faa9620053 100644 --- a/src/test/unit/mcmc/hmc/xhmc/softabs_xhmc_test.cpp +++ b/src/test/unit/mcmc/hmc/xhmc/softabs_xhmc_test.cpp @@ -58,13 +58,13 @@ TEST(McmcUnitEXHMC, build_tree) { EXPECT_FLOAT_EQ(1.5019561, sampler.z().p(1)); EXPECT_FLOAT_EQ(-1.5019561, sampler.z().p(2)); - EXPECT_FLOAT_EQ(0.42903179, z_propose.q(0)); - EXPECT_FLOAT_EQ(-0.42903179, z_propose.q(1)); - EXPECT_FLOAT_EQ(0.42903179, z_propose.q(2)); + EXPECT_FLOAT_EQ(0.8330583, z_propose.q(0)); + EXPECT_FLOAT_EQ(-0.8330583, z_propose.q(1)); + EXPECT_FLOAT_EQ(0.8330583, z_propose.q(2)); - EXPECT_FLOAT_EQ(-1.4385087, z_propose.p(0)); - EXPECT_FLOAT_EQ(1.4385087, z_propose.p(1)); - EXPECT_FLOAT_EQ(-1.4385087, z_propose.p(2)); + EXPECT_FLOAT_EQ(-1.1836562, z_propose.p(0)); + EXPECT_FLOAT_EQ(1.1836562, z_propose.p(1)); + EXPECT_FLOAT_EQ(-1.1836562, z_propose.p(2)); EXPECT_EQ(8, n_leapfrog); EXPECT_FLOAT_EQ(3.7645235, ave); @@ -79,7 +79,7 @@ TEST(McmcUnitEXHMC, build_tree) { } TEST(McmcUnitEXHMC, transition) { - stan::rng_t base_rng = stan::services::util::create_rng(483294, 0); + stan::rng_t base_rng = stan::services::util::create_rng(4832942, 0); stan::mcmc::softabs_point z_init(3); z_init.q(0) = 1; @@ -112,7 +112,7 @@ TEST(McmcUnitEXHMC, transition) { EXPECT_FLOAT_EQ(-1, s.cont_params()(1)); EXPECT_FLOAT_EQ(1, s.cont_params()(2)); EXPECT_FLOAT_EQ(-1.5, s.log_prob()); - EXPECT_FLOAT_EQ(0.99993229, s.accept_stat()); + EXPECT_FLOAT_EQ(0.99870497, s.accept_stat()); EXPECT_EQ("", debug.str()); EXPECT_EQ("", info.str()); diff --git a/src/test/unit/mcmc/hmc/xhmc/unit_e_xhmc_test.cpp b/src/test/unit/mcmc/hmc/xhmc/unit_e_xhmc_test.cpp index 0f9dcbb8dd..61639c3247 100644 --- a/src/test/unit/mcmc/hmc/xhmc/unit_e_xhmc_test.cpp +++ b/src/test/unit/mcmc/hmc/xhmc/unit_e_xhmc_test.cpp @@ -58,13 +58,13 @@ TEST(McmcUnitEXHMC, build_tree) { EXPECT_FLOAT_EQ(1.4131583, sampler.z().p(1)); EXPECT_FLOAT_EQ(-1.4131583, sampler.z().p(2)); - EXPECT_FLOAT_EQ(0.65928948, z_propose.q(0)); - EXPECT_FLOAT_EQ(-0.65928948, z_propose.q(1)); - EXPECT_FLOAT_EQ(0.65928948, z_propose.q(2)); + EXPECT_FLOAT_EQ(0.11940599, z_propose.q(0)); + EXPECT_FLOAT_EQ(-0.11940599, z_propose.q(1)); + EXPECT_FLOAT_EQ(0.11940599, z_propose.q(2)); - EXPECT_FLOAT_EQ(-1.2505695, z_propose.p(0)); - EXPECT_FLOAT_EQ(1.2505695, z_propose.p(1)); - EXPECT_FLOAT_EQ(-1.2505695, z_propose.p(2)); + EXPECT_FLOAT_EQ(-1.408289, z_propose.p(0)); + EXPECT_FLOAT_EQ(1.408289, z_propose.p(1)); + EXPECT_FLOAT_EQ(-1.408289, z_propose.p(2)); EXPECT_EQ(8, n_leapfrog); EXPECT_FLOAT_EQ(4.2207355, ave); @@ -79,7 +79,7 @@ TEST(McmcUnitEXHMC, build_tree) { } TEST(McmcUnitEXHMC, transition) { - stan::rng_t base_rng = stan::services::util::create_rng(483294, 0); + stan::rng_t base_rng = stan::services::util::create_rng(4832942, 0); stan::mcmc::unit_e_point z_init(3); z_init.q(0) = 1; @@ -112,7 +112,7 @@ TEST(McmcUnitEXHMC, transition) { EXPECT_FLOAT_EQ(-1, s.cont_params()(1)); EXPECT_FLOAT_EQ(1, s.cont_params()(2)); EXPECT_FLOAT_EQ(-1.5, s.log_prob()); - EXPECT_FLOAT_EQ(0.99994934, s.accept_stat()); + EXPECT_FLOAT_EQ(0.99870926, s.accept_stat()); EXPECT_EQ("", debug.str()); EXPECT_EQ("", info.str()); EXPECT_EQ("", warn.str());