diff --git a/Project.toml b/Project.toml index 7d7e129..9bb2749 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probabilistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "5.6.0" +version = "5.6.1" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" diff --git a/src/sample.jl b/src/sample.jl index 2324604..32aca7d 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -391,13 +391,16 @@ function mcmcsample( # Copy the random number generator, model, and sample for each thread nchunks = min(nchains, Threads.nthreads()) - chunksize = cld(nchains, nchunks) interval = 1:nchunks # `copy` instead of `deepcopy` for RNGs: https://github.com/JuliaLang/julia/issues/42899 rngs = [copy(rng) for _ in interval] models = [deepcopy(model) for _ in interval] samplers = [deepcopy(sampler) for _ in interval] + # If nchains/nchunks = m with remainder n, then the first n chunks will + # have m + 1 chains, and the rest will have m chains. + m, n = divrem(nchains, nchunks) + # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) @@ -437,12 +440,17 @@ function mcmcsample( Distributed.@async begin try Distributed.@sync for (i, _rng, _model, _sampler) in - zip(1:nchunks, rngs, models, samplers) - chainidxs = if i == nchunks - ((i - 1) * chunksize + 1):nchains + zip(interval, rngs, models, samplers) + if i <= n + chainidx_hi = i * (m + 1) + nchains_chunk = m + 1 else - ((i - 1) * chunksize + 1):(i * chunksize) + chainidx_hi = i * m + n # n * (m + 1) + (i - n) * m + nchains_chunk = m end + chainidx_lo = chainidx_hi - nchains_chunk + 1 + chainidxs = chainidx_lo:chainidx_hi + Threads.@spawn for chainidx in chainidxs # Seed the chunk-specific random number generator with the pre-made seed. Random.seed!(_rng, seeds[chainidx])