Skip to content

Commit

Permalink
before swapping postprocess arg order
Browse files Browse the repository at this point in the history
  • Loading branch information
zsunberg committed Jan 5, 2025
1 parent f221d6f commit ed49aa7
Show file tree
Hide file tree
Showing 7 changed files with 44 additions and 31 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ repo = "https://github.com/JuliaPOMDP/ParticleFilters.jl"
version = "0.6.0"

[deps]
AliasTables = "66dad0bd-aa9a-41b7-9441-69ab47430ed8"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
POMDPLinter = "f3bd98c0-eb40-45e2-9eb1-f2763262d755"
POMDPTools = "7588e00f-9cae-40de-98dc-e0c70c48cdd7"
Expand All @@ -13,6 +14,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
AliasTables = "1.1.3"
Documenter = "1.8.0"
POMDPLinter = "0.1"
POMDPTools = "0.1, 1"
Expand Down
24 changes: 8 additions & 16 deletions docs/src/basic.md
Original file line number Diff line number Diff line change
@@ -1,27 +1,19 @@
# Basic Particle Filter

The
The [`BasicParticleFilter`](@ref) type is a flexible structure for building a particle filter. It simply contains functions that carry out each of the steps of a particle filter belief update.

## Update Steps
The basic particle filtering step in ParticleFilters.jl is implemented in the [`update`](@ref) function, and consists of four steps:

The basic particle filtering step in ParticleFilters.jl is implemented in the [`update`](@ref) function, and consists of three steps:

1. Prediction (or propagation) - each state particle is simulated forward one step in time
2. Reweighting - an explicit measurement (observation) model is used to calculate a new weight
3. Resampling - a new collection of state particles is generated with particle frequencies proportional to the new weights

This is an example of [sequential importance resampling](https://en.wikipedia.org/wiki/Particle_filter#Sequential_Importance_Resampling_(SIR)) using the state transition distribution as the proposal distribution, and the [`BootstrapFilter`](@ref) constructor can be used to construct such a filter with a `model` that controls the prediction and reweighting steps, and a number of particles to create in the resampling phase.

A more flexible structure for building a particle filter is the [`BasicParticleFilter`](@ref). It contains three models, one for each step:

1. The `predict_model` controls prediction through [`predict!`](@ref)
2. The `reweight_model` controls reweighting through [`reweight!`](@ref)
3. The `resampler` controls resampling through [`resample`](@ref)
1. Preprocessing
2. Prediction (or propagation) - each state particle is simulated forward one step in time
3. Reweighting - an explicit measurement (observation) model is used to calculate a new weight
4. Postprocessing

!!! note
In the future

## Docstrings

```@docs
BasicParticleFilter
update
```
4 changes: 3 additions & 1 deletion src/ParticleFilters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ using POMDPTools.ModelTools: weighted_iterator
import Random: rand, gentype
import Statistics: mean, cov, var

using AliasTables: AliasTable

# TODO cleanup export

export
Expand Down Expand Up @@ -54,7 +56,7 @@ export
set_pair!,
push_pair!,
effective_sample_size,
low_variance_sample,
low_variance_sample

# n_init_samples,

Expand Down
4 changes: 1 addition & 3 deletions src/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ struct BasicParticleFilter{F1,F2,F3,F4,F5,RNG<:AbstractRNG} <: Updater
rng::RNG
end



function BasicParticleFilter(preprocess, predict, reweight, postprocess;
rng=Random.default_rng(),
initialize=(b,rng)->b)
Expand All @@ -25,7 +23,7 @@ function update(up::BasicParticleFilter, b::AbstractParticleBelief, a, o)
particles = up.predict(bb, a, o, up.rng)
weights = up.reweight(bb, a, particles, o)
bp = WeightedParticleBelief(particles, weights)
return up.postprocess(bp, b, a, o, up.rng) # TODO XXX should this also have bb as an arg?
return up.postprocess(bp, bb, a, o, up.rng) # TODO XXX should this also have bb as an arg? (bp, a, o, b, bb, up.rng)
end

function Random.seed!(f::BasicParticleFilter, seed)
Expand Down
27 changes: 23 additions & 4 deletions src/beliefs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,13 @@ mutable struct WeightedParticleBelief{T} <: AbstractParticleBelief{T}
weights::Vector{Float64}
weight_sum::Float64
_probs::Union{Nothing, Dict{T,Float64}}
_alias_table::Union{Nothing, AliasTable{UInt, Int}}
end

function WeightedParticleBelief(particles::AbstractVector{T},
weights::AbstractVector=ones(length(particles)),
weight_sum=sum(weights)) where {T}
return WeightedParticleBelief{T}(particles, weights, weight_sum, nothing)
return WeightedParticleBelief{T}(particles, weights, weight_sum, nothing, nothing)
end

n_particles(b::WeightedParticleBelief) = length(b.particles)
Expand Down Expand Up @@ -240,20 +241,37 @@ function set_pair!(b::WeightedParticleBelief, i, sw)
fraction = w / weight_sum(b)
b._probs[s] = get(b._probs, s, 0.0) + fraction
end
b._alias_table = nothing # invalidate alias table
return sw
end

function push_pair!(b::WeightedParticleBelief, sw)
push!(b.particles, first(sw))
push!(b.weights, last(sw))
# XXX _probs
b._probs = nothing # invalidate _probs cache
b._probs = nothing # invalidate _probs cache XXX this should be modified to update efficiently without throwing it away.
b._alias_table = nothing # invalidate alias table
return b
end

# XXX there should be a version that uses an alias table
# Made the decision for now to have this always use alias sampling because this is more efficient when many samples are drawn, and usually many samples will be drawn.
# This should probably be upgraded to hook better into the Random API
# https://docs.julialang.org/en/v1/stdlib/Random/#An-optimized-sampler-with-pre-computed-data
# But many of the online solvers use rand(b) on every iteration. Perhaps they should be changed to use rand(b, number_of_samples) instead.
function Random.rand(rng::AbstractRNG, sampler::Random.SamplerTrivial{<:WeightedParticleBelief})
b = sampler[]
alias_single_sample(b, rng)
end

function alias_single_sample(b::WeightedParticleBelief, rng)
if b._alias_table == nothing
b._alias_table = AliasTable(b.weights)
end
at = b._alias_table::AliasTable{UInt, Int}

return b.particles[rand(rng, at)]
end

function naive_single_sample(b::WeightedParticleBelief, rng)
t = rand(rng) * weight_sum(b)
i = 1
cw = b.weights[1]
Expand All @@ -263,6 +281,7 @@ function Random.rand(rng::AbstractRNG, sampler::Random.SamplerTrivial{<:Weighted
end
return particles(b)[i]
end

Statistics.mean(b::WeightedParticleBelief{T}) where {T <: Number} = dot(b.weights, b.particles) / weight_sum(b)
Statistics.mean(b::WeightedParticleBelief{T}) where {T <: Vector} = reduce(hcat, b.particles) * b.weights / weight_sum(b)
function Statistics.cov(b::WeightedParticleBelief{T}) where {T <: Number} # uncorrected covariance
Expand Down
2 changes: 1 addition & 1 deletion src/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ function BootstrapFilter(m::ParticleFilterModel, n::Int; resample_threshold=0.9,
)
end

@deprecate low_variance_resample(b::AbstractParticleBelief, n::Int, rng::AbstractRNG) = low_variance_sample(b, n, rng)
@deprecate low_variance_resample(b::AbstractParticleBelief, n::Int, rng::AbstractRNG) low_variance_sample(b, n, rng)

struct LowVarianceResampler <: Function
n::Int
Expand Down
12 changes: 6 additions & 6 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,22 +37,22 @@ struct ContinuousPOMDP <: POMDP{Float64,Float64,Float64} end
@inferred weighted_particles(b)
end
@testset "lowvar" begin
@inferred low_variance_resample(b, 100, Random.default_rng())
@test all(s in support(b) for s in low_variance_resample(b, 100, Random.default_rng()))
@inferred low_variance_sample(b, 100, Random.default_rng())
@test all(s in support(b) for s in low_variance_sample(b, 100, Random.default_rng()))

rs = LowVarianceResampler(1000)
@inferred rs(b, TIGER_LISTEN, true, MersenneTwister(3))

ps = particles(b)
ws = ones(length(ps))
@inferred low_variance_resample(WeightedParticleBelief(ps, ws, sum(ws)), 100, MersenneTwister(3))
@inferred low_variance_resample(WeightedParticleBelief{Bool}(ps, ws, sum(ws), nothing), 100, MersenneTwister(3))
@inferred low_variance_sample(WeightedParticleBelief(ps, ws, sum(ws)), 100, MersenneTwister(3))
@inferred low_variance_sample(WeightedParticleBelief{Bool}(ps, ws, sum(ws), nothing, nothing), 100, MersenneTwister(3))
end
# test that the special method for ParticleCollections works
@testset "collection" begin
b = ParticleCollection(1:100)
rb1 = @inferred low_variance_resample(b, 100, MersenneTwister(3))
rb2 = @inferred low_variance_resample(WeightedParticleBelief(particles(b), ones(n_particles(b))), 100, MersenneTwister(3))
rb1 = @inferred low_variance_sample(b, 100, MersenneTwister(3))
rb2 = @inferred low_variance_sample(WeightedParticleBelief(particles(b), ones(n_particles(b))), 100, MersenneTwister(3))
@test all(rb1 .== rb2)
end

Expand Down

0 comments on commit ed49aa7

Please sign in to comment.