Skip to content

Commit

Permalink
removed broken compressors
Browse files Browse the repository at this point in the history
  • Loading branch information
FlyingWorkshop committed Mar 27, 2024
1 parent 72a9db7 commit b870c6d
Show file tree
Hide file tree
Showing 13 changed files with 254 additions and 110 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "1.0.0-DEV"

[deps]
Bijections = "e2ed5e7c-b2de-5872-ae92-c73ca462fb04"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Infiltrator = "5903a43b-9cc3-4c30-8d17-598619ec4e9b"
Lazy = "50d2b5c4-7a5e-59d5-8109-a42b560f39c0"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -21,7 +22,6 @@ POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d"
ParticleFilters = "c8b314e2-9260-5cf8-ae76-3be7461ca6d0"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"

[compat]
julia = "1"
Expand Down
33 changes: 18 additions & 15 deletions src/CompressedBeliefMDPs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,30 +18,33 @@ using Random


export
# Compressor interface
### Compressor Interface ###
Compressor,
fit!,
# StatsAPI Compressors
StatsCompressor,
## MultivariateStats wrappers
### MultivariateStats wrappers ###
MVSCompressor,
PCACompressor,
KernelPCACompressor,
PPCACompressor,
FactorAnalysisCompressor,
MDSCompressor,
## ManifoldLearning wrappers
IsomapCompressor,
LLECompressor,
HLLECompressor,
LEMCompressor,
LTSACompressor,
DiffMapCompressor
# MDSCompressor,
### ManifoldLearning wrappers ###
ManifoldCompressor,
IsomapCompressor
# LLECompressor,
# HLLECompressor,
# LEMCompressor,
# LTSACompressor,
# DiffMapCompressor
include("compressors/compressor.jl")
include("compressors/stats_compressors.jl")
include("compressors/mvs_compressors.jl")
include("compressors/manifold_compressors.jl")

export
sample
include("sampler.jl")
sample,
exploratory_belief_expansion
include("samplers/policy_simulation.jl")
include("samplers/belief_expansion.jl")

export
CompressedBeliefMDP
Expand Down
2 changes: 1 addition & 1 deletion src/compressors/compressor.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
abstract type Compressor end

"""
fit!(compressor::Compressor, beliefs)
fit!(compressor::Compressor, beliefs; kwargs...)
Fit the compressor to beliefs.
"""
Expand Down
45 changes: 45 additions & 0 deletions src/compressors/manifold_compressors.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""
Wrappers for MultivariateStats.jl. See https://juliastats.org/MultivariateStats.jl/stable/.
"""

using ManifoldLearning


mutable struct ManifoldCompressor{T} <: Compressor
const maxoutdim::Integer
M
end

ManifoldCompressor(maxoutdim::Integer, T) = ManifoldCompressor{T}(maxoutdim, missing)

function (c::ManifoldCompressor)(beliefs)
return ndims(beliefs) == 2 ? ManifoldLearning.predict(c.M, beliefs')' : vec(ManifoldLearning.predict(c.M, beliefs))
end

function fit!(compressor::ManifoldCompressor{T}, beliefs; kwargs...) where T
compressor.M = ManifoldLearning.fit(T, beliefs'; maxoutdim=compressor.maxoutdim, kwargs...)
end

### ManifoldLearning.jl Wrappers ###
IsomapCompressor(maxoutdim::Integer) = ManifoldCompressor(maxoutdim, Isomap)

### Discontinued ###

# function (c::ManifoldCompressor{T})(beliefs) where T
# M = ManifoldLearning.fit(T, beliefs'; maxoutdim=c.maxoutdim, c.kwargs...)
# R = ManifoldLearning.predict(M)
# # R = ndims(beliefs) == 2 ? ManifoldLearning.predict(M, beliefs')' : vec(ManifoldLearning.predict(M, beliefs))
# return R
# end

# function fit!(compressor::ManifoldCompressor{T}, beliefs; kwargs...) where T
# compressor.kwargs = kwargs
# end

# LLECompressor(maxoutdim::Integer) = ManifoldCompressor(maxoutdim, LLE)
# HLLECompressor(maxoutdim::Integer) = ManifoldCompressor(maxoutdim, HLLE)
# LEMCompressor(maxoutdim::Integer) = ManifoldCompressor(maxoutdim, LEM)
# LTSACompressor(maxoutdim::Integer) = ManifoldCompressor(maxoutdim, LTSA)
# DiffMapCompressor(maxoutdim::Integer) = ManifoldCompressor(maxoutdim, DiffMap)


35 changes: 35 additions & 0 deletions src/compressors/mvs_compressors.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""
Wrappers for MultivariateStats.jl. See https://juliastats.org/MultivariateStats.jl/stable/.
"""

using MultivariateStats
using MultivariateStats: predict


mutable struct MVSCompressor{T} <: Compressor
const maxoutdim::Integer
M
end

MVSCompressor(maxoutdim::Integer, T) = MVSCompressor{T}(maxoutdim, missing)

function (c::MVSCompressor)(beliefs)
return ndims(beliefs) == 2 ? predict(c.M, beliefs')' : vec(predict(c.M, beliefs))
end

function fit!(compressor::MVSCompressor{T}, beliefs; kwargs...) where T
compressor.M = fit(T, beliefs'; maxoutdim=compressor.maxoutdim, kwargs...)
end

### MultivariateStats.jl Wrappers ###
# PCA Compressors
PCACompressor(maxoutdim::Integer) = MVSCompressor(maxoutdim, PCA)
KernelPCACompressor(maxoutdim::Integer) = MVSCompressor(maxoutdim, KernelPCA)
PPCACompressor(maxoutdim::Integer) = MVSCompressor(maxoutdim, PPCA)

# Factor Analysis Compressor
FactorAnalysisCompressor(maxoutdim::Integer) = MVSCompressor(maxoutdim, FactorAnalysis)

# Multidimensional Scaling
# MDSCompressor(maxoutdim::Integer) = MVSCompressor(maxoutdim, MDS)

46 changes: 0 additions & 46 deletions src/compressors/stats_compressors.jl

This file was deleted.

15 changes: 0 additions & 15 deletions src/sampler.jl

This file was deleted.

67 changes: 67 additions & 0 deletions src/samplers/belief_expansion.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
using Distances


### Utilities ###

function _make_numeric(b, pomdp::POMDP)
b = convert_s(AbstractArray{Float64}, b, pomdp)
return SVector{length(b)}(b)
end

"""
Adapted from PointBasedValueIteration.jl: https://github.com/JuliaPOMDP/PointBasedValueIteration.jl/blob/master/src/pbvi.jl
"""
function _successors(pomdp::POMDP, b, updater::Updater)
succs = []
for a in actions(pomdp, b)
for o in observations(pomdp)
s = update(updater, b, a, o)
push!(succs, s)
end
end
return unique!(succs)
end

### Body ###

"""
Effecient adaptation of algorithm 21.13 from AFDM that uses KDTree.
Only works for finite S, A, O.
"""
function exploratory_belief_expansion!(pomdp::POMDP, B::Set, B_numeric, updater::Updater; metric::NearestNeighbors.MinkowskiMetric=Euclidean())
tree = KDTree(B_numeric, metric)
B_new = typeof(B)()
for b in B
if isterminal(pomdp, b)
println("woops")
end
succs = _successors(pomdp, b, updater)
succs_numeric = map(s->_make_numeric(s, pomdp), succs)
if !isempty(succs)
_, dists = nn(tree, succs_numeric)
i = argmax(dists)
b_new = succs[i]
b_numeric_new = succs_numeric[i]
if !in(b_new, B)
push!(B_new, b_new)
push!(B_numeric, b_numeric_new)
end
end
end
union!(B, B_new)
end

"""
Wrapper for exploratory_belief_expansion!.
Creates an initial belief set and calls exploratory_belief_expansion! on POMDP n times.
"""
function exploratory_belief_expansion(pomdp::POMDP, updater::Updater; n::Integer=10, metric::NearestNeighbors.MinkowskiMetric=Euclidean())
b0 = initialize_belief(updater, initialstate(pomdp))
b0_numeric = _make_numeric(b0, pomdp)
B = Set([b0])
B_numeric = [b0_numeric]
for _ in 1:n
exploratory_belief_expansion!(pomdp, B, B_numeric, updater; metric)
end
return B
end
29 changes: 29 additions & 0 deletions src/samplers/policy_simulation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# TODO: add RNG support for seeding

function sample(pomdp::POMDP, explorer::Policy, updater::Updater, n::Integer; rng::AbstractRNG=Random.GLOBAL_RNG)
mdp = GenerativeBeliefMDP(pomdp, updater)
iter = stepthrough(mdp, explorer, "s"; max_steps=n)
B = collect(Iterators.take(Iterators.cycle(iter), n))
return unique!(B)
end

function sample(pomdp::POMDP, explorer::ExplorationPolicy, updater::Updater, n::Integer; rng::AbstractRNG=Random.GLOBAL_RNG)
samples = []
mdp = GenerativeBeliefMDP(pomdp, updater)
on_policy = RandomPolicy(mdp)
while true
b = initialstate(mdp).val
for k in 1:n
if length(samples) == n
return unique!(samples)
end

if isterminal(mdp, b)
break
end
a = action(explorer, on_policy, k, b)
b = @gen(:sp)(mdp, b, a, rng)
push!(samples, b)
end
end
end
47 changes: 32 additions & 15 deletions src/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,16 @@ function _make_compressed_belief_MDP(
updater::Updater,
compressor::Compressor,
n::Integer,
fit_kwargs::Union{Nothing, Dict}=nothing
expansion::Bool,
fit_kwargs::Union{Nothing, Dict}=nothing,
metric::NearestNeighbors.MinkowskiMetric=Euclidean()
)
# sample beliefs
B = sample(pomdp, explorer, updater, n)
if expansion
B = exploratory_belief_expansion(pomdp, updater; n=n, metric=metric)
else
B = sample(pomdp, explorer, updater, n)
end

# compress beliefs and cache mapping
B_numerical = mapreduce(b->convert_s(AbstractArray{Float64}, b, pomdp), hcat, B)'
Expand All @@ -50,19 +56,25 @@ end
# TODO: add seeding
function CompressedBeliefSolver(
pomdp::POMDP;
explorer::Union{Policy, ExplorationPolicy}=RandomPolicy(pomdp),
updater::Updater=applicable(POMDPs.states, pomdp) ? DiscreteUpdater(pomdp) : BootstrapFilter(pomdp, 5000), # hack to determine default updater, may select incompatible Updater for complex custom POMDPs
compressor::Compressor=PCACompressor(1),
n::Integer=50, # max number of belief samples to compress

# sampling arguments
explorer::Union{Policy, ExplorationPolicy}=RandomPolicy(pomdp), # explorer policy; only used if expansion is false
updater::Updater=DiscreteUpdater(pomdp), # only works for discrete S
compressor::Compressor=PCACompressor(1),
expansion=true, # only works for discrete S, A, O
n::Integer=5, # if expansion, then n is the number of times we expand; otherwise, n is max number of belief samples
metric::NearestNeighbors.MinkowskiMetric=Euclidean(),
fit_kwargs::Union{Nothing, Dict}=nothing,

# base policy arguments
interp::Union{Nothing, LocalFunctionApproximator}=nothing,
k=1, # k nearest neighbors; only used if interp is nothing
verbose=false,
max_iterations=1000, # for value iteration
n_generative_samples=10, # number of steps to look ahead when calculated expected reward
belres::Float64=1e-3,
fit_kwargs::Union{Nothing, Dict}=nothing
belres::Float64=1e-3
)
m, B̃ = _make_compressed_belief_MDP(pomdp, explorer, updater, compressor, n, fit_kwargs)
m, B̃ = _make_compressed_belief_MDP(pomdp, explorer, updater, compressor, n, expansion, fit_kwargs, metric)

# define the interpolator for the solver
if isnothing(interp)
Expand All @@ -87,16 +99,21 @@ end
function CompressedBeliefSolver(
pomdp::POMDP,
base_solver::Solver;
explorer::Union{Policy, ExplorationPolicy}=RandomPolicy(pomdp),
updater::Updater=applicable(POMDPs.states, pomdp) ? DiscreteUpdater(pomdp) : BootstrapFilter(pomdp, 5000), # hack to determine default updater, may select incompatible Updater for complex custom POMDPs
compressor::Compressor=PCACompressor(1),
n::Integer=50, # max number of belief samples to compress
fit_kwargs::Union{Nothing, Dict}=nothing,

# sampling arguments
explorer::Union{Policy, ExplorationPolicy}=RandomPolicy(pomdp), # explorer policy; only used if expansion is false
updater::Updater=DiscreteUpdater(pomdp), # only works for discrete S
compressor::Compressor=PCACompressor(1),
expansion=true, # only works for discrete S, A, O
n::Integer=5, # if expansion, then n is the number of times we expand; otherwise, n is max number of belief samples
metric::NearestNeighbors.MinkowskiMetric=Euclidean(),
fit_kwargs::Union{Nothing, Dict}=nothing
)
m, _ = _make_compressed_belief_MDP(pomdp, explorer, updater, compressor, n, fit_kwargs)
m, _ = _make_compressed_belief_MDP(pomdp, explorer, updater, compressor, n, expansion, fit_kwargs, metric)
return CompressedBeliefSolver(m, base_solver)
end


function POMDPs.solve(solver::CompressedBeliefSolver, pomdp::POMDP)
if solver.m.bmdp.pomdp !== pomdp
@warn "Got $pomdp, but solver.m.bmdp.pomdp $(solver.m.bmdp.pomdp) isn't identical"
Expand Down
Loading

0 comments on commit b870c6d

Please sign in to comment.