Skip to content

Commit

Permalink
Merge pull request #2 from FredericWantiez/feature/split
Browse files Browse the repository at this point in the history
Reset rng
  • Loading branch information
FredericWantiez authored Aug 17, 2021
2 parents 22fa456 + 5b01908 commit 8134751
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 78 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[compat]
AbstractMCMC = "2, 3"
Distributions = "0.23, 0.24, 0.25"
Libtask = "0.5.3"
Random123 = "1.3"
StatsFuns = "0.9"
julia = "1.3"
47 changes: 24 additions & 23 deletions src/container.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,21 @@ function Trace(f, rng::TracedRNG)
end

function Trace(f, ctask::Libtask.CTask)
rng = TracedRNG()
return Trace(f, ctask, rng)
return Trace(f, ctask, TracedRNG())
end

# Copy task and RNG
Base.copy(trace::Trace) = Trace(trace.f, copy(trace.ctask))
# Copy task
Base.copy(trace::Trace) = Trace(trace.f, copy(trace.ctask), deepcopy(trace.rng))

# step to the next observe statement and
# return the log probability of the transition (or nothing if done)
advance!(t::Trace) = Libtask.consume(t.ctask)
function advance!(t::Trace, isref::Bool)
isref ? reset_rng!(t.rng) : save_state!(t.rng)
inc_count!(t.rng)

# Move to next step
return Libtask.consume(t.ctask)
end

# reset log probability
reset_logprob!(t::Trace) = nothing
Expand All @@ -55,6 +60,7 @@ end
# Create new task and copy randomness
function forkr(trace::Trace)
newf = reset_model(trace.f)
set_count!(trace.rng, 1)

ctask = let f = trace.ctask.task.code
Libtask.CTask() do
Expand Down Expand Up @@ -89,23 +95,15 @@ Data structure for particle filters
- normalise!(pc::ParticleContainer)
- consume(pc::ParticleContainer): return incremental likelihood
"""
mutable struct ParticleContainer{T<:Particle,R<:Random.AbstractRNG}
mutable struct ParticleContainer{T<:Particle}
"Particles."
vals::Vector{T}
"Unnormalized logarithmic weights."
logWs::Vector{Float64}
"TracedRNG to track the resampling step"
rng::TracedRNG{R}
end

function ParticleContainer(particles::Vector{<:Particle})
return ParticleContainer(particles, zeros(length(particles)), TracedRNG())
end

function ParticleContainer(
particles::Vector{<:Particle}, rng::T
) where {T<:Random.AbstractRNG}
return ParticleContainer(particles, zeros(length(particles)), TracedRNG(rng))
return ParticleContainer(particles, zeros(length(particles)))
end

Base.collect(pc::ParticleContainer) = pc.vals
Expand All @@ -132,7 +130,7 @@ function Base.copy(pc::ParticleContainer)
# copy weights
logWs = copy(pc.logWs)

return ParticleContainer(vals, logWs, pc.rng)
return ParticleContainer(vals, logWs)
end

"""
Expand Down Expand Up @@ -231,9 +229,12 @@ function resample_propagate!(
p = isref ? fork(pi, isref) : pi
children[j += 1] = p

seeds = split(pi.rng, ni)
# fork additional children
for _ in 2:ni
children[j += 1] = fork(p, isref)
for k in 2:ni
part = fork(p, isref)
seed!(part.rng, seeds[k])
children[j += 1] = part
end
end
end
Expand Down Expand Up @@ -274,7 +275,7 @@ end
Check if the final time step is reached, and otherwise reweight the particles by
considering the next observation.
"""
function reweight!(pc::ParticleContainer)
function reweight!(pc::ParticleContainer, ref::Union{Particle,Nothing}=nothing)
n = length(pc)

particles = collect(pc)
Expand All @@ -286,7 +287,8 @@ function reweight!(pc::ParticleContainer)
# the execution of the model is finished.
# Here ``yᵢ`` are observations, ``xᵢ`` variables of the particle filter, and
# ``θᵢ`` are variables of other samplers.
score = advance!(p)
isref = p === ref
score = advance!(p, isref)

if score === nothing
numdone += 1
Expand Down Expand Up @@ -337,7 +339,6 @@ function sweep!(
ref::Union{Particle,Nothing}=nothing,
)
# Initial step:

# Resample and propagate particles.
resample_propagate!(rng, pc, resampler, ref)

Expand All @@ -349,7 +350,7 @@ function sweep!(
logZ0 = logZ(pc)

# Reweight the particles by including the first observation ``y₁``.
isdone = reweight!(pc)
isdone = reweight!(pc, ref)

# Compute the normalizing constant ``Z₁`` after reweighting.
logZ1 = logZ(pc)
Expand All @@ -367,7 +368,7 @@ function sweep!(
logZ0 = logZ(pc)

# Reweight the particles by including the next observation ``yₜ``.
isdone = reweight!(pc)
isdone = reweight!(pc, ref)

# Compute the normalizing constant ``Z₁`` after reweighting.
logZ1 = logZ(pc)
Expand Down
100 changes: 74 additions & 26 deletions src/rng.jl
Original file line number Diff line number Diff line change
@@ -1,40 +1,88 @@
using Random123
using Random
using Distributions

import Base.rand
import Random.seed!

# Use Philox2x for now
BASE_RNG = Philox2x

"""
Data structure to keep track of the history of the random stream
produced by RNG.
TracedRNG{R,T}
Wrapped random number generator from Random123 to keep track of random streams during model evaluation
"""
mutable struct TracedRNG{T} <: Random.AbstractRNG where {T<:Random.AbstractRNG}
count::Base.RefValue{Int}
mutable struct TracedRNG{T} <:
Random.AbstractRNG where {T<:(Random123.AbstractR123{R} where {R})}
count::Int
rng::T
seed::Array
states::Array{T}
keys
counters
end

# Set seed manually, for init ?
function Random.seed!(rng::TracedRNG, seed)
rng.rng.seed = seed
return Random.seed!(rng.rng, seed)
function TracedRNG(r::Random123.AbstractR123)
return TracedRNG(1, r, typeof(r.key)[], typeof(r.ctr1)[])
end

# Reset the rng to the initial seed
Random.seed!(rng::TracedRNG) = Random.seed!(rng.rng, rng.seed)
"""
TracedRNG()
TracedRNG() = TracedRNG(Random.MersenneTwister()) # Pick up an explicit RNG from Random
TracedRNG(rng::Random.AbstractRNG) = TracedRNG(Ref(0), rng, rng.seed, [rng])
TracedRNG(rng::Random._GLOBAL_RNG) = TracedRNG(Random.default_rng())
Create a default TracedRNG
"""
function TracedRNG()
r = BASE_RNG()
return TracedRNG(r)
end

# Intercept rand
# https://github.com/JuliaLang/julia/issues/30732
Random.rng_native_52(r::TracedRNG) = UInt64
# Plug into Random
Random.rng_native_52(rng::TracedRNG{U}) where {U} = Random.rng_native_52(rng.rng)
Base.rand(rng::TracedRNG{U}, ::Type{T}) where {U,T} = Base.rand(rng.rng, T)

function Base.rand(rng::TracedRNG, ::Type{T}) where {T}
res = Base.rand(rng.rng, T)
inc_count!(rng, length(res))
push!(rng.states, copy(rng.rng))
return res
"""
split(r::TracedRNG, n::Integer)
Split keys of the internal Philox2x into n distinct seeds
"""
function split(r::TracedRNG{T}, n::Integer) where {T}
n == 1 && return [r.rng.key]
return map(i -> hash(r.rng.key, convert(UInt, r.rng.ctr1 + i)), 1:n)
end

inc_count!(rng::TracedRNG) = inc_count!(rng, 1)
"""
update_rng!(r::TracedRNG, seed::Number)
inc_count!(rng::TracedRNG, n::Int) = rng.count[] += n
Set the key of the wrapped Philox2x rng
"""
function seed!(r::TracedRNG{T}, seed) where {T}
return seed!(r.rng, seed)
end

"""
reset_rng(r::TracedRNG, seed)
Reset the rng to the running model step
"""
function reset_rng!(rng::TracedRNG{T}) where {T}
key = rng.keys[rng.count]
ctr = rng.counters[rng.count]
Random.seed!(rng.rng, key)
return set_counter!(rng.rng, ctr)
end

function save_state!(r::TracedRNG{T}) where {T}
push!(r.keys, r.rng.key)
return push!(r.counters, r.rng.ctr1)
end

Base.copy(r::TracedRNG{T}) where {T} = TracedRNG(r.count, copy(r.rng), copy(r.keys))

"""
set_count!(r::TracedRNG, n::Integer)
Set the counter of the TracedRNG, used to keep track of the current model step
"""
set_count!(r::TracedRNG, n::Integer) = r.count = n

curr_count(t::TracedRNG) = t.count[]
inc_count!(r::TracedRNG, n::Integer) = r.count += n
inc_count!(r::TracedRNG) = inc_count!(r, 1)
22 changes: 10 additions & 12 deletions src/smc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ function AbstractMCMC.sample(
end

# Create a set of particles.
particles = ParticleContainer(
[Trace(model, TracedRNG()) for _ in 1:(sampler.nparticles)], rng
)
particles = ParticleContainer([
Trace(model, TracedRNG()) for _ in 1:(sampler.nparticles)
])

# Perform particle sweep.
logevidence = sweep!(particles.rng, particles, sampler.resampler)
logevidence = sweep!(rng, particles, sampler.resampler)

return SMCSample(collect(particles), getweights(particles), logevidence)
end
Expand Down Expand Up @@ -85,12 +85,12 @@ function AbstractMCMC.step(
rng::Random.AbstractRNG, model::AbstractMCMC.AbstractModel, sampler::PG; kwargs...
)
# Create a new set of particles.
particles = ParticleContainer(
[Trace(model, TracedRNG()) for _ in 1:(sampler.nparticles)], rng
)
particles = ParticleContainer([
Trace(model, TracedRNG()) for _ in 1:(sampler.nparticles)
])

# Perform a particle sweep.
logevidence = sweep!(particles.rng, particles, sampler.resampler)
logevidence = sweep!(rng, particles, sampler.resampler)

# Pick a particle to be retained.
trajectory = rand(rng, particles)
Expand All @@ -115,12 +115,10 @@ function AbstractMCMC.step(
Trace(model, TracedRNG())
end
end
particles = ParticleContainer(x, rng)
particles = ParticleContainer(x)

# Perform a particle sweep.
logevidence = sweep!(
particles.rng, particles, sampler.resampler, particles.vals[nparticles]
)
logevidence = sweep!(rng, particles, sampler.resampler, particles.vals[nparticles])

# Pick a particle to be retained.
newtrajectory = rand(rng, particles)
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
AbstractMCMC = "2, 3"
Distributions = "0.24, 0.25"
Libtask = "0.5"
julia = "1.3"
Random123 = "1.3"
20 changes: 3 additions & 17 deletions test/rng.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,12 @@
@testset "sample distribution" begin
rng = AdvancedPS.TracedRNG()
vns = rand(rng, Distributions.Normal())

@test AdvancedPS.curr_count(rng) === 1
AdvancedPS.save_state!(rng)

rand(rng, Distributions.Normal())
Random.seed!(rng)

AdvancedPS.reset_rng!(rng)
new_vns = rand(rng, Distributions.Normal())
@test new_vns vns
end

@testset "inc count" begin
rng = AdvancedPS.TracedRNG()
AdvancedPS.inc_count!(rng)
@test AdvancedPS.curr_count(rng) == 1

AdvancedPS.inc_count!(rng, 2)
@test AdvancedPS.curr_count(rng) == 3
end

@testset "curr count" begin
rng = AdvancedPS.TracedRNG()
@test AdvancedPS.curr_count(rng) == 0
end
end

0 comments on commit 8134751

Please sign in to comment.