Skip to content

Commit

Permalink
did some more work
Browse files Browse the repository at this point in the history
  • Loading branch information
zsunberg committed Jul 17, 2024
1 parent 1467e74 commit c857fb9
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 109 deletions.
133 changes: 25 additions & 108 deletions src/basic.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
### Basic Particle Filter ###
# implements the POMDPs.jl Updater interface
"""
BasicParticleFilter(predict_model, reweight_model, resampler, n_init::Integer, rng::AbstractRNG)
BasicParticleFilter(model, resampler, n_init::Integer, rng::AbstractRNG)
Construct a basic particle filter with three steps: predict, reweight, and resample.
BasicParticleFilter(resample, predict, reweight, [propose], [rng::AbstractRNG])
In the second constructor, `model` is used for both the prediction and reweighting.
"""
Expand All @@ -13,123 +10,43 @@ mutable struct BasicParticleFilter{RS,PR,RW,PR,RNG<:AbstractRNG,PMEM} <: Updater
predict::PRE
reweight::RW
propose::PRO
n_init::Int
check_belief::Bool
rng::RNG
end

## Constructors ##
function BasicParticleFilter(model, resampler, n::Integer, rng::AbstractRNG=Random.GLOBAL_RNG)
return BasicParticleFilter(model, model, resampler, n, rng)
function BasicParticleFilter(resample, predict, reweight)
return BasicParticleFilter(resample, predict, reweight, propose, Random.TaskLocalRNG())
end

function BasicParticleFilter(pmodel, rmodel, resampler, n::Integer, rng::AbstractRNG=Random.GLOBAL_RNG)
return BasicParticleFilter(pmodel,
rmodel,
resampler,
n,
rng,
particle_memory(pmodel),
Float64[]
)
end
function update(up::BasicParticleFilter, b::AbstractParticleBelief, a, o)
b_resampled = up.resample(b, a, o, up.rng)
particles = up.predict(b_resampled, a, o, up.rng)
weights = up.reweight(b_resampled, particles, a, o)
bp = WeightedParticleBelief(particles, weights)
new_belief = up.propose(bp, b, a, o, up.rng)

"""
particle_memory(m)
if up.check_belief
check_belief(new_belief)
end

Return a suitable container for particles produced by prediction model `m`.
This should usually be an empty `Vector{S}` where `S` is the type of the state for prediction model `m`. Size does not matter because `resize!` will be called appropriately within `update`.
"""
function particle_memory end
return new_belief
end

function update(up::BasicParticleFilter, b::ParticleCollection, a, o)
b = up.resample(b, a, o, up.rng)
particles = up.predict(b, a, o, up.rng)
weights = up.reweight(b, a, pm, o)
bp = up.propose(particles, weights, b, a, o, up.rng)
return bp
function check_belief(b::AbstractParticleBelief)
if length(particles(b)) != length(weights(b))
@warn "Number of particles and weights do not match" length(particles(b)) length(weights(b))
end
if weight_sum(b) <= 0.0
@warn "Sum of particle filter weights is not greater than zero." weight_sum(b)
end
if sum(weights(b)) ! weight_sum(b)
@warn "Sum of particle filter weights does not match weight_sum." sum(weights(b)) weight_sum(b)
end
end

function Random.seed!(f::BasicParticleFilter, seed)
Random.seed!(f.rng, seed)
return f
end

"""
predict!(pm, m, b, u, rng)
predict!(pm, m, b, u, y, rng)
Fill `pm` with predicted particles for the next time step.
A method of this function should be implemented by prediction models to be used in a [`BasicParticleFilter`](@ref). `pm` should be a correctly-sized vector created by [`particle_memory`](@ref) to hold a one-step-propagated particle for each particle in `b`.
Normally the observation `y` is not needed, so most prediction models should implement the first version, but the second is available for heuristics that use `y`.
# Arguments
- `pm::Vector`: memory for holding the propagated particles; created by [`particle_memory`](@ref) and resized to `n_particles(b)`.
- `m`: prediction model, the "owner" of this function
- `b::ParticleCollection`: current belief; each particle in this belief should be propagated one step and inserted into `pm`.
- `u`: control or action
- `rng::AbstractRNG`: random number generator; should be used for any randomness in propagation for reproducibility.
- `y`: measuerement/observation (usually not needed)
"""
function predict! end

"""
reweight!(wm, m, b, a, pm, y)
reweight!(wm, m, b, a, pm, y, rng)
Fill `wm` likelihood weights for each particle in `pm`.
A method of this function should be implemented by reweighting models to be used in a [`BasicParticleFilter`](@ref). `wm` should be a correctly-sized vector to hold weights for each particle in pm.
Normally `rng` is not needed, so most reweighting models should implement the first version, but the second is available for heuristics that use random numbers.
# Arguments
- `wm::Vector{Float64}`: memory for holding likelihood weights.
- `m`: reweighting model, the "owner" of this function
- `b::ParticleCollection`: previous belief; `pm` should contain a propagated particle for each particle in this belief
- `u`: control or action
- `pm::Vector`: memory for holding current particles; these particle have been propagated by `predict!`.
- `y`: measurement/observation
- `rng::AbstractRNG`: random number generator; should be used for any randomness for reproducibility.
"""
function reweight! end

predict!(pm, m, b, a, o, rng) = predict!(pm, m, b, a, rng)
reweight!(wm, m, b, a, pm, o, rng) = reweight!(wm, m, b, a, pm, o)

"""
predict(m, b, u, rng)
Simulate each of the particles in `b` forward one time step using model `m` and contol input `u` returning a vector of states. Calls [`predict!`](@ref) internally - see that function for documentation.
This function is provided for convenience only. New models should implement `predict!`.
"""
function predict end

function predict(m, b, args...)
pm = particle_memory(m)
resize!(pm, n_particles(b))
predict!(pm, m, b, args...)
return pm
end
predict(f::BasicParticleFilter, args...) = predict(f.predict_model, args...)

"""
reweight(m, b, u, pm, y)
Return a vector of likelihood weights for each particle in `pm` given observation `y`.

`pm` can be generated with `predict(m, b, u, rng)`.
This function is provided for convenience only - new reweighting models should implement `reweight!`.
"""
function reweight end

function reweight(m, b, args...)
wm = Vector{Float64}(undef, n_particles(b))
reweight!(wm, m, b, args...)
return wm
end
reweight(f::BasicParticleFilter, args...) = reweight(f.reweight_model, args...)
1 change: 0 additions & 1 deletion src/beliefs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ Random.gentype(::Type{B}) where B<:AbstractParticleBelief{T} where T = T
### Belief interface ###
########################

# also rand(), pdf(), and mode() from POMDPs.jl are part of the belief interface.
"""
n_particles(b::AbstractParticleBelief)
Expand Down

0 comments on commit c857fb9

Please sign in to comment.