# LightDark.jl
# A one-dimensional light-dark problem, originally used to test MCVI
# A very simple POMDP with continuous state and observation spaces.
# maintained by @zsunberg
import Base: ==, +, *, -
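# NOTE (not in the original file): this file is meant to be included from the
# POMDPModels module, which loads the dependencies and imports the POMDPs.jl
# interface functions extended below. A rough sketch of what running it on its
# own would require (exact package and module names are an assumption):
#
#     import POMDPs
#     import POMDPs: discount, isterminal, actions, transition, observation,
#                    reward, initialstate, initialobs, convert_s, action
#     using Distributions: Normal
#     using Random: AbstractRNG
#     using POMDPTools: Deterministic
#     using Statistics: mean, std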
"""
LightDark1DState
## Fields
- `y`: position
- `status`: 0 = normal, negative = terminal
"""
struct LightDark1DState
    status::Int64
    y::Float64
end
*(n::Number, s::LightDark1DState) = LightDark1DState(s.status, n*s.y)
"""
LightDark1D
A one-dimensional light dark problem. The goal is to be near 0. Observations are noisy measurements of the position.
Model
-----
-3-2-1 0 1 2 3
...| | | | | | | | ...
G S
Here G is the goal. S is the starting location
"""
mutable struct LightDark1D{F<:Function} <: POMDPs.POMDP{LightDark1DState,Int,Float64}
    discount_factor::Float64
    correct_r::Float64
    incorrect_r::Float64
    step_size::Float64
    movement_cost::Float64
    sigma::F
end
default_sigma(x::Float64) = abs(x - 5)/sqrt(2) + 1e-2
LightDark1D() = LightDark1D(0.9, 10.0, -10.0, 1.0, 0.0, default_sigma)
discount(p::LightDark1D) = p.discount_factor
isterminal(::LightDark1D, act::Int64) = act == 0
isterminal(::LightDark1D, s::LightDark1DState) = s.status < 0
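# The action space is -1:1: -1 steps left, +1 steps right, and 0 stops and ends the episode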
actions(::LightDark1D) = -1:1
struct LDNormalStateDist
    mean::Float64
    std::Float64
end
sampletype(::Type{LDNormalStateDist}) = LightDark1DState
rand(rng::AbstractRNG, d::LDNormalStateDist) = LightDark1DState(0, d.mean + randn(rng)*d.std)
initialstate(pomdp::LightDark1D) = LDNormalStateDist(2, 3)
initialobs(m::LightDark1D, s) = observation(m, s)
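# Observations are Normal(y, sigma(y)); with default_sigma the noise is smallest
# near y = 5 (the bright "light" region), where the standard deviation is about 1e-2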
observation(p::LightDark1D, sp::LightDark1DState) = Normal(sp.y, p.sigma(sp.y))
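# Action 0 moves to a terminal state (status = -1); otherwise the position shifts
# deterministically by a*step_size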
function transition(p::LightDark1D, s::LightDark1DState, a::Int)
    if a == 0
        return Deterministic(LightDark1DState(-1, s.y+a*p.step_size))
    else
        return Deterministic(LightDark1DState(s.status, s.y+a*p.step_size))
    end
end
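# Reward: stopping (a == 0) within 1 unit of the origin gives correct_r, stopping
# elsewhere gives incorrect_r; moving gives -movement_cost*a (zero with the default
# movement_cost), and terminal states give no reward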
function reward(p::LightDark1D, s::LightDark1DState, a::Int)
    if s.status < 0
        return 0.0
    elseif a == 0
        if abs(s.y) < 1
            return p.correct_r
        else
            return p.incorrect_r
        end
    else
        return -p.movement_cost*a
    end
end
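# Convert between LightDark1DState and a length-2 vector [status, y] for solvers
# that require array-valued states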
convert_s(::Type{A}, s::LightDark1DState, p::LightDark1D) where A<:AbstractArray = eltype(A)[s.status, s.y]
convert_s(::Type{LightDark1DState}, s::A, p::LightDark1D) where A<:AbstractArray = LightDark1DState(Int64(s[1]), s[2])
# Define some simple policies based on particle beliefs
mutable struct DummyHeuristic1DPolicy <: POMDPs.Policy
    thres::Float64
end
DummyHeuristic1DPolicy() = DummyHeuristic1DPolicy(0.1)
mutable struct SmartHeuristic1DPolicy <: POMDPs.Policy
    thres::Float64
end
SmartHeuristic1DPolicy() = SmartHeuristic1DPolicy(0.1)
function action(p::DummyHeuristic1DPolicy, b::B) where {B}
    target = 0.0
    μ = mean(b)
    σ = std(b, μ)
    if σ.y < p.thres && -0.5 < μ.y < 0.5
        a = 0
    elseif μ.y < target
        a = 1 # Right
    else # μ.y >= target, so `a` is always assigned
        a = -1 # Left
    end
    return a
end
function action(p::SmartHeuristic1DPolicy, b::B) where {B}
    μ = mean(b)
    σ = std(b, μ)
    target = 0.0
    if σ.y > p.thres
        target = 5.0
    end
    if σ.y < p.thres && -0.5 < μ.y < 0.5
        a = 0
    elseif μ.y < target
        a = 1 # Right
    else # μ.y >= target, so `a` is always assigned
        a = -1 # Left
    end
    return a
end
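# A minimal usage sketch (not part of the original file), assuming the POMDPs.jl
# interface imports sketched at the top of the file:
#
#     using Random
#     rng = MersenneTwister(1)
#     p  = LightDark1D()
#     d  = initialstate(p)                # LDNormalStateDist(2, 3)
#     s  = rand(rng, d)                   # LightDark1DState(0, y)
#     sp = rand(transition(p, s, 1))      # step right by step_size
#     o  = rand(rng, observation(p, sp))  # noisy measurement of sp.y
#     r  = reward(p, s, 1)                # 0.0 with the default movement_cost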