Skip to content

Commit

Permalink
tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
zsunberg committed Sep 28, 2024
1 parent 1fa3354 commit b8231b7
Show file tree
Hide file tree
Showing 10 changed files with 258 additions and 160 deletions.
2 changes: 1 addition & 1 deletion notebooks/Using-a-Particle-Filter-with-POMDPs-jl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pomdp = LightDark1D();
N = 5000;

# ╔═╡ 07460420-f278-11ea-0e96-e9881f22016e
up = BootstrapFilter(pomdp, N, rng);
up = BootstrapFilter(pomdp, N, rng=rng);

# ╔═╡ 074f4968-f278-11ea-026f-c5ceaf12774e
policy = FunctionPolicy(b->1);
Expand Down
34 changes: 22 additions & 12 deletions src/ParticleFilters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ using POMDPTools.ModelTools: weighted_iterator
import Random: rand, gentype
import Statistics: mean, cov, var

# TODO cleanup export

export
AbstractParticleBelief,
ParticleCollection,
Expand All @@ -27,17 +29,17 @@ export
LowVarianceResampler,
UnweightedParticleFilter,
ParticleFilterModel,
PredictModel,
BootstrapFilter,
ReweightModel
# PredictModel,
BootstrapFilter
# ReweightModel

export
resample,
predict,
predict!,
reweight,
reweight!,
particle_memory
# export
# resample,
# predict,
# predict!,
# reweight,
# reweight!,
# particle_memory

export
n_particles,
Expand All @@ -47,8 +49,10 @@ export
weight,
particle,
weights,
obs_weight,
n_init_samples,
obs_weight
# n_init_samples,

export
runfilter

export
Expand All @@ -64,9 +68,15 @@ export

include("beliefs.jl")
include("basic.jl")

export
low_variance_resample

include("resamplers.jl")

include("unweighted.jl")
include("models.jl")
include("postprocessing.jl")
include("bootstrap.jl")
include("pomdps.jl")
include("policies.jl")
Expand Down
18 changes: 10 additions & 8 deletions src/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,27 @@
BasicParticleFilter
"""
struct BasicParticleFilter{F1,F2,F3,F4,F5,RNG<:AbstractRNG} <: Updater
initialize::F1
preprocess::F2
predict::F3
reweight::F4
postprocess::F5
preprocess::F1
predict::F2
reweight::F3
postprocess::F4
initialize::F5
rng::RNG
end



function BasicParticleFilter(initialize, preprocess, predict, reweight, postprocess)
return BasicParticleFilter(preprocess, predict, reweight, postprocess, Random.TaskLocalRNG())
function BasicParticleFilter(preprocess, predict, reweight, postprocess;
rng=Random.default_rng(),
initialize=(b,rng)->b)
return BasicParticleFilter(preprocess, predict, reweight, postprocess, initialize, rng)
end

function update(up::BasicParticleFilter, b::AbstractParticleBelief, a, o)
bb = up.preprocess(b, a, o, up.rng)
particles = up.predict(bb, a, o, up.rng)
weights = up.reweight(bb, a, particles, o)
bp = WeightedParticleBelief(ps, ws)
bp = WeightedParticleBelief(particles, weights)
return up.postprocess(bp, b, a, o, up.rng)
end

Expand Down
12 changes: 11 additions & 1 deletion src/beliefs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ For ParticleCollection and WeightedParticleBelief, the result is cached for effi
"""
function probdict end

# TODO: document, export, and test
function effective_sample_size(b::AbstractParticleBelief)
ws = weight_sum(b)
return 1.0 / sum(w->(w/ws)^2, weights(b))
end

#############################
### Concrete Belief types ###
Expand Down Expand Up @@ -181,7 +186,12 @@ mutable struct WeightedParticleBelief{T} <: AbstractParticleBelief{T}
weight_sum::Float64
_probs::Union{Nothing, Dict{T,Float64}}
end
WeightedParticleBelief(particles::AbstractVector{T}, weights::AbstractVector, weight_sum=sum(weights)) where {T} = WeightedParticleBelief{T}(particles, weights, weight_sum, nothing)

function WeightedParticleBelief(particles::AbstractVector{T},
weights::AbstractVector=ones(length(particles)),
weight_sum=sum(weights)) where {T}
return WeightedParticleBelief{T}(particles, weights, weight_sum, nothing)
end

n_particles(b::WeightedParticleBelief) = length(b.particles)
particles(p::WeightedParticleBelief) = p.particles
Expand Down
65 changes: 58 additions & 7 deletions src/bootstrap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,73 @@ TODO: update with ess
For a more flexible particle filter structure see [`BasicParticleFilter`](@ref).
"""
function BootstrapFilter(m::POMDP, n::Int, resample_threshold=0.9, postprocess=(bp, args...)->bp, rng::AbstractRNG=Random.TaskLocalRNG())
function BootstrapFilter(m::POMDP, n::Int; resample_threshold=0.9, postprocess=(bp, args...)->bp, rng::AbstractRNG=Random.default_rng())
return BasicParticleFilter(
NormalizedESSConditionalResample(LowVarianceResampler(n), normalized_ess),
POMDPPredictor(m),
NormalizedESSConditionalResampler(LowVarianceResampler(n), resample_threshold),
POMDPPredicter(m),
POMDPReweighter(m),
PostprocessChain(postprocess, check_particle_belief),
rng
initialize=(d, rng)->initialize_to(WeightedParticleBelief, n, d, rng),
# initialize=(d, rng)->WeightedParticleBelief(rand(rng, d, n), fill(1.0/n, n)),
rng=rng
)
end

function BootstrapFilter(m::ParticleFilterModel, n::Int, resample_threshold=0.9, postprocess=(bp, args...)->bp, rng::AbstractRNG=Random.TaskLocalRNG())
function BootstrapFilter(m::ParticleFilterModel, n::Int; resample_threshold=0.9, postprocess=(bp, args...)->bp, rng::AbstractRNG=Random.default_rng())
return BasicParticleFilter(
NormalizedESSConditionalResample(LowVarianceResampler(n), normalized_ess),
NormalizedESSConditionalResampler(LowVarianceResampler(n), resample_threshold),
BasicPredictor(m),
BasicReweighter(m),
PostprocessChain(postprocess, check_particle_belief),
rng
initialize=(d, rng)->initialize_to(WeightedParticleBelief, n, d, rng),
rng=rng
)
end

function initialize_to(B::Type{<:AbstractParticleBelief}, n, d::AbstractParticleBelief, rng)
if isa(d, B) && n_particles(d) == n
return d
else
return B(low_variance_resample(d, n, rng))
end
end

function initialize_to(B::Type{<:AbstractParticleBelief}, n, d, rng)
return B(sample_non_particle(d, n, rng))
end

function sample_non_particle(d, n, rng)
# using weighted iterator here is more likely to be order n than just calling rand() repeatedly
# but, this implementation is problematic and may change in the future
D = typeof(d)
try
if @implemented(support(::D)) &&
@implemented(iterate(::typeof(support(d)))) &&
@implemented(pdf(::D, ::typeof(first(support(d)))))
S = typeof(first(support(d)))
particles = S[]
weights = Float64[]
for (s, w) in weighted_iterator(d)
push!(particles, s)
push!(weights, w)
end
return low_variance_resample(WeightedParticleBelief(particles, weights), n, rng)
end
catch ex
if ex isa MethodError
@warn("""
Suppressing MethodError in ParticleFilters.jl. Please file an issue here:
https://github.com/JuliaPOMDP/ParticleFilters.jl/issues/new
The error was
$(sprint(showerror, ex))
""", maxlog=1)
else
rethrow(ex)
end
end

return collect(rand(rng, d) for i in 1:n)
end
90 changes: 38 additions & 52 deletions src/models.jl
Original file line number Diff line number Diff line change
@@ -1,55 +1,3 @@
struct BasicPredictor{F<:Function} <: Function
dynamics::F
end

BasicPredictor(m::ParticleFilterModel) = BasicPredictor(m.f)

# """
# Predictor(f::Function)
#
# Create a prediction model for use in a [`BasicParticleFilter`](@ref)
#
# See [`ParticleFilterModel`](@ref)
# """
# PredictModel{S}(f::F) where {S, F<:Function} = PredictModel{S, F}(f)

# function predict(pm, m::PredictModel, b, u, rng)
(p::BasicPredictor)(b, u, y, rng) = map(x -> p.dynamics(x, u, rng), particles(b))

struct BasicReweighter{G<:Function} <: Function
reweight::G
end

BasicReweighter(m::ParticleFilterModel) = Reweighter(m.g)

function (r::BasicReweighter)(b, u, ps, y)
map(1:length(ps)) do i
x1 = particle(b, i)
x2 = ps[i]
r.reweight(x1, u, x2, y)
end
end


# """
# ReweightModel(g::Function)
#
# Create a reweighting model for us in a [`BasicParticleFilter`](@ref).
#
# See [`ParticleFilterModel`](@ref) for a description of `g`.
# """
# struct ReweightModel{G}
# g::G
# end
#
# function reweight!(wm, m::ReweightModel, b, u, pm, y)
# for i in 1:n_particles(b)
# x1 = particle(b, i)
# x2 = pm[i]
# wm[i] = m.g(x1, u, x2, y)
# end
# end

struct ParticleFilterModel{S, F, G}
f::F
g::G
Expand Down Expand Up @@ -83,3 +31,41 @@ function reweight!(wm, m::ParticleFilterModel, b, u, pm, y)
end

particle_memory(m::ParticleFilterModel{S}) where S = S[]



# =====================================


struct BasicPredictor{F<:Function} <: Function
dynamics::F
end

BasicPredictor(m::ParticleFilterModel) = BasicPredictor(m.f)

# """
# Predictor(f::Function)
#
# Create a prediction model for use in a [`BasicParticleFilter`](@ref)
#
# See [`ParticleFilterModel`](@ref)
# """
# PredictModel{S}(f::F) where {S, F<:Function} = PredictModel{S, F}(f)

# function predict(pm, m::PredictModel, b, u, rng)
(p::BasicPredictor)(b, u, y, rng) = map(x -> p.dynamics(x, u, rng), particles(b))

struct BasicReweighter{G<:Function} <: Function
reweight::G
end

BasicReweighter(m::ParticleFilterModel) = BasicReweighter(m.g)

function (r::BasicReweighter)(b, u, ps, y)
map(1:length(ps)) do i
x1 = particle(b, i)
x2 = ps[i]
r.reweight(x1, u, x2, y)
end
end

Loading

0 comments on commit b8231b7

Please sign in to comment.