WIP: Add action(policy, s) interface to exploration policies #510

Open · wants to merge 1 commit into base: master
3 changes: 2 additions & 1 deletion lib/POMDPTools/src/Policies/Policies.jl

@@ -68,7 +68,8 @@ export LinearDecaySchedule,
     EpsGreedyPolicy,
     SoftmaxPolicy,
     ExplorationPolicy,
-    loginfo
+    loginfo,
+    update!

 include("exploration_policies.jl")

40 changes: 31 additions & 9 deletions lib/POMDPTools/src/Policies/exploration_policies.jl

@@ -47,38 +47,60 @@ The evolution of epsilon can be controlled using a schedule. This feature is use

 If a function is passed for `eps`, `eps(k)` is called to compute the value of epsilon when calling `action(exploration_policy, on_policy, k, s)`.

-# Fields 
+# Fields

 - `eps::Function`
 - `rng::AbstractRNG`
 - `m::M` POMDPs or MDPs problem
+- `on_policy::P` a policy to use for the greedy part
+- `k::Int` the current training step to use for computing eps(k)
 """
-struct EpsGreedyPolicy{T<:Function, R<:AbstractRNG, M<:Union{MDP,POMDP}} <: ExplorationPolicy
+mutable struct EpsGreedyPolicy{P<:Union{Nothing,Policy},T<:Function,R<:AbstractRNG,M<:Union{MDP,POMDP}} <: ExplorationPolicy
+    on_policy::P
+    k::Int
     eps::T
     rng::R
     m::M
 end

-function EpsGreedyPolicy(problem::Union{MDP,POMDP}, eps::Function; 
+function EpsGreedyPolicy(problem::Union{MDP,POMDP}, eps::Function;
                          rng::AbstractRNG=Random.default_rng())
-    return EpsGreedyPolicy(eps, rng, problem)
+    return EpsGreedyPolicy(nothing, 1, eps, rng, problem)
 end
-function EpsGreedyPolicy(problem::Union{MDP,POMDP}, eps::Real; 
+function EpsGreedyPolicy(problem::Union{MDP,POMDP}, eps::Real;
                          rng::AbstractRNG=Random.default_rng())
-    return EpsGreedyPolicy(x->eps, rng, problem)
+    return EpsGreedyPolicy(problem, x -> eps, rng=rng)
 end
+function EpsGreedyPolicy(problem::Union{MDP,POMDP}, on_policy::Policy, eps::Function;
+                         k::Int=1, rng::AbstractRNG=Random.default_rng())
+    return EpsGreedyPolicy(on_policy, k, eps, rng, problem)
+end
+function EpsGreedyPolicy(problem::Union{MDP,POMDP}, on_policy::Policy, eps::Real;
+                         k::Int=1, rng::AbstractRNG=Random.default_rng())
+    return EpsGreedyPolicy(problem, on_policy, x -> eps, k=k, rng=rng)
+end


 function POMDPs.action(p::EpsGreedyPolicy, on_policy::Policy, k, s)
     if rand(p.rng) < p.eps(k)
         return rand(p.rng, actions(p.m,s))
-    else 
+    else
         return action(on_policy, s)
     end
 end
+POMDPs.action(p::EpsGreedyPolicy{<:Policy}, s) = action(p, p.on_policy, p.k, s)

 loginfo(p::EpsGreedyPolicy, k) = (eps=p.eps(k),)
+loginfo(p::EpsGreedyPolicy) = loginfo(p, p.k)
+
+function update!(p::EpsGreedyPolicy, k::Int)
+    p.k = k
+    return p
+end
+function update!(p::EpsGreedyPolicy{P}, on_policy::P) where {P<:Policy}
+    p.on_policy = on_policy
+    return p
+end

 # softmax
 """
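
For reviewers, here is a minimal usage sketch of the interface this PR adds (not part of the diff). It assumes POMDPModels' `SimpleGridWorld` as the problem, a `FunctionPolicy` as a stand-in greedy policy, and that `update!` ends up re-exported from POMDPTools once the export above lands; any MDP/`Policy` pair would be used the same way.

```julia
using POMDPs: action
using POMDPTools: EpsGreedyPolicy, FunctionPolicy, LinearDecaySchedule, update!
using POMDPModels: SimpleGridWorld, GWPos

problem = SimpleGridWorld()
greedy = FunctionPolicy(s -> :up)          # stand-in for a learned greedy policy
schedule = LinearDecaySchedule(start=1.0, stop=0.01, steps=1_000)

# New-style construction: the exploration policy carries the on-policy and the step counter k.
policy = EpsGreedyPolicy(problem, greedy, schedule)

s = GWPos(1, 1)
for k in 1:1_000
    update!(policy, k)      # advance the training step used to evaluate eps(k)
    a = action(policy, s)   # new action(policy, s) method; no need to thread on_policy and k through
    # ... step the environment here; if the greedy policy is re-learned,
    # update!(policy, new_greedy) swaps it in for subsequent calls
end
```

Storing `on_policy` and `k` on the exploration policy lets solver code use the plain `action(policy, s)` signature, while the existing `action(policy, on_policy, k, s)` form keeps working as before.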
16 changes: 11 additions & 5 deletions lib/POMDPTools/test/policies/test_exploration_policies.jl

@@ -7,18 +7,24 @@ a = first(actions(problem))
 @inferred action(policy, FunctionPolicy(s->a::Symbol), 1, GWPos(1,1))
 policy = EpsGreedyPolicy(problem, 0.0)
 @test action(policy, FunctionPolicy(s->a), 1, GWPos(1,1)) == a
+policy = EpsGreedyPolicy(problem, FunctionPolicy(s->a), 0.0)
+@test action(policy, GWPos(1,1)) == a

-# softmax 
+# softmax
 policy = SoftmaxPolicy(problem, 0.5)
 @test loginfo(policy, 1).temperature == 0.5
 on_policy = ValuePolicy(problem)
 @inferred action(policy, on_policy, 1, GWPos(1,1))

-# test linear schedule 
-policy = EpsGreedyPolicy(problem, LinearDecaySchedule(start=1.0, stop=0.0, steps=10))
-for i=1:11
+# test linear schedule
+schedule = LinearDecaySchedule(start=1.0, stop=0.0, steps=10)
+policy = EpsGreedyPolicy(problem, FunctionPolicy(s->a), schedule)
+for i=1:11
     action(policy, FunctionPolicy(s->a), i, GWPos(1,1))
-    @test policy.eps(i) < 1.0
+    @test policy.eps(i) < 1.0
+    @test loginfo(policy, i).eps == policy.eps(i)
 end
 @test policy.eps(11) ≈ 0.0
+update!(policy, 11)
+@test policy.eps(policy.k) ≈ 0.0
+@test action(policy, FunctionPolicy(s->a), 11, GWPos(1,1)) == action(policy, GWPos(1,1))
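
Side note on the schedule assertions above: `LinearDecaySchedule(start=1.0, stop=0.0, steps=10)` decays epsilon linearly toward `stop` and then holds it there. A tiny stand-in with the same observable behavior (a sketch for illustration, not POMDPTools' implementation):

```julia
# Hypothetical stand-in for LinearDecaySchedule: linear decay clamped at `stop`.
linear_eps(k; start=1.0, stop=0.0, steps=10) = max(start - k * (start - stop) / steps, stop)

@assert all(linear_eps(i) < 1.0 for i in 1:11)   # mirrors `@test policy.eps(i) < 1.0`
@assert linear_eps(11) ≈ 0.0                     # mirrors `@test policy.eps(11) ≈ 0.0`
```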