sarsa_Q.py
import env_for_p1.envs.lirobot as lr
import matplotlib.pyplot as plt
import parameters as param
import numpy as np
import agent
import time
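
# Note: env_for_p1.envs.lirobot, parameters and agent are project-local modules
# (the grid-world environment, the hyper-parameter container and the tabular agent).
# Judging from the wall checks below, action indices 0-3 appear to correspond to
# moves along +y, +x, -y and -x respectively.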


def sarsa_q(map_size, iteration_lim):
    # variables to store the statistics for plotting
    # average q value
    average_q_value_list = []
    average_q_value = 0
    q_value_counter = 0
    # average reward
    average_reward_list = []
    average_reward = 0
    # episode
    episode_list = []
    episode = 0
    # instantiate the robot and the environment (only 10x10 and 4x4 maps are supported)
    if map_size == 10:
        robot_li = agent.robot()
        env = lr.LiRobot(size=10)
    elif map_size == 4:
        robot_li = agent.robot_size_4()
        env = lr.LiRobot(size=4)
    # the original map size plus one wall cell on each side
    map_plus_wall_size = map_size + 2
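    # Note: the playable cells are indexed 1..map_size; indices 0 and map_size + 1
    # form a one-cell wall border, which the observation grid marks with -2. The
    # start cell is (1, 1) and, judging from the greedy rollout below, the goal is
    # the cell where env.world equals 1.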
    for iteration in range(0, iteration_lim):
        if (iteration + 1) % 100 == 0:
            print("Training episode: ", iteration + 1)
        # prediction
        # robot initialization
        robot_li.obser, robot_li.pos = env.reset()
        reward = 0
        robot_li.sample_list = [[1, 1]]
        done = False
        first_time = True
        current_action_index = 0
        action_value = 0
        index = 0
        while not done:
            # control
            for i in range(1, map_plus_wall_size - 1):
                for j in range(1, map_plus_wall_size - 1):
                    next_value_list = [robot_li.value_Q[i][j][0],
                                       robot_li.value_Q[i][j][1],
                                       robot_li.value_Q[i][j][2],
                                       robot_li.value_Q[i][j][3]]
                    # if the robot would hit the wall after an action,
                    # the q value of that action in this state is set to -10
                    # so that it will not be chosen
                    if robot_li.obser[i][j + 1] == -2:
                        robot_li.value_Q[i][j][0] = -10
                    if robot_li.obser[i + 1][j] == -2:
                        robot_li.value_Q[i][j][1] = -10
                    if robot_li.obser[i][j - 1] == -2:
                        robot_li.value_Q[i][j][2] = -10
                    if robot_li.obser[i - 1][j] == -2:
                        robot_li.value_Q[i][j][3] = -10
                    # record the direction (index) with the maximum q value to update the probability matrix
                    action_index = next_value_list.index(max(next_value_list))
                    # use the epsilon-greedy method to update the action probabilities of each state and
                    # save them in the probability matrix
                    for k in range(0, 4):
                        if action_index == k:
                            robot_li.probs[i][j][k] = 1 - param.AGENT_ACTION.EPSILON + param.AGENT_ACTION.EPSILON / 4
                        else:
                            robot_li.probs[i][j][k] = param.AGENT_ACTION.EPSILON / 4
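            # Note: the loop above is a standard epsilon-greedy policy improvement step:
            # the greedy action receives probability 1 - EPSILON + EPSILON / 4 and each of
            # the other three actions receives EPSILON / 4, so the four probabilities sum to 1.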
            # randomly choose the action according to the action probabilities of the current state
            if first_time:
                while reward == 0:
                    index = np.random.choice([0, 1, 2, 3], 1, p=robot_li.probs[robot_li.pos[0]][robot_li.pos[1]]).item()
                    # take an action and get the reward and the new state
                    done, robot_li.pos, reward = env.step(index)
                    first_time = False
            else:
                index = current_action_index
                # take an action and get the reward and the new state
                done, robot_li.pos, reward = env.step(index)
            if reward != -1 and reward != 1:
                reward = 0
            else:
                average_reward += reward
                episode += 1
                if episode % 20 == 0:
                    average_reward_list.append(average_reward / episode)
                    episode_list.append(episode)
            # save the trajectory
            robot_li.sample_list.append(list(robot_li.pos))
            # robot_li.action_list.append(index)
            # (x, y) is the state the action was taken from (the second-to-last sample)
            x = robot_li.sample_list[-2][0]
            y = robot_li.sample_list[-2][1]
            valid_action = False
            # choose a valid action in the next state
            while (not valid_action) and (reward == 0):
                action_value = np.random.choice(robot_li.value_Q[robot_li.pos[0]][robot_li.pos[1]], 1,
                                                p=robot_li.probs[robot_li.pos[0]][robot_li.pos[1]]).item()
                for k in range(0, 4):
                    if action_value == robot_li.value_Q[robot_li.pos[0]][robot_li.pos[1]][k]:
                        current_action_index = k
                        break
                assume_x = robot_li.pos[0] + param.AGENT_ACTION.ACTION_SPACE[current_action_index][0]
                assume_y = robot_li.pos[1] + param.AGENT_ACTION.ACTION_SPACE[current_action_index][1]
                if robot_li.obser[assume_x][assume_y] == -2:
                    valid_action = False
                else:
                    valid_action = True
            # at a terminal transition there is no next action, so its value is 0
            if reward == 1 or reward == -1:
                action_value = 0
            # update the current q value
            robot_li.value_Q[x][y][index] += param.AGENT_ACTION.L_RATE * (reward +
                                                                          param.AGENT_ACTION.DISCOUNT_FACTOR * action_value -
                                                                          robot_li.value_Q[x][y][index])
            reward = 0
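            # Note: the update above is the usual SARSA (on-policy TD) rule,
            #   Q(S, A) <- Q(S, A) + alpha * (R + gamma * Q(S', A') - Q(S, A)),
            # where action_value plays the role of Q(S', A') for the action that will
            # actually be taken next (current_action_index), and L_RATE and
            # DISCOUNT_FACTOR are alpha and gamma.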
        # copy the learned q values into the shared table and track their average
        for i in range(0, map_plus_wall_size):
            for j in range(0, map_plus_wall_size):
                for k in range(0, 4):
                    param.ENV_SETTINGS.STATE_ACTION_VALUE[i][j][k] = robot_li.value_Q[i][j][k]
                    if robot_li.value_Q[i][j][k] != -10 and episode % 20 == 0:
                        average_q_value += robot_li.value_Q[i][j][k]
                        q_value_counter += 1
        if episode % 20 == 0:
            average_q_value_list.append(average_q_value / q_value_counter)
            average_q_value = 0
            q_value_counter = 0
    # test
    # try to find a route from the start point to the destination
    # initialize the position
    robot_li.pos = [1, 1]
    # initialize the trajectory
    route = []
    x = robot_li.pos[0]
    y = robot_li.pos[1]
    step_count = 0
    success = True
    # follow the route greedily according to the learned q values
    while env.world[x][y] == 0:
        robot_li.pos = [x, y]
        route.append((x, y))
        x = robot_li.pos[0]
        y = robot_li.pos[1]
        direction_value = max(param.ENV_SETTINGS.STATE_ACTION_VALUE[x][y])
        for i in range(0, 4):
            if param.ENV_SETTINGS.STATE_ACTION_VALUE[x][y][i] == direction_value:
                direction = i
                break
        x += param.AGENT_ACTION.ACTION_SPACE[direction][0]
        y += param.AGENT_ACTION.ACTION_SPACE[direction][1]
        step_count += 1
        # if the number of steps exceeds the number of cells in the map, stop
        if step_count > pow(map_plus_wall_size - 2, 2):
            success = False
            break
    route.append((map_plus_wall_size - 2, map_plus_wall_size - 2))
    # if the final cell is not marked 1, path finding failed
    if env.world[x][y] != 1:
        success = False
    # print the world
    print(env.world)
    # print the result
    if success:
        print(route)
    print("SARSA: ", success)
    return episode_list, average_reward_list, average_q_value_list, env.world, route, env, success
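

# Example call (sketch): a quick run on the small 4x4 map with an illustrative
# iteration count; the returned episode, reward and q-value lists are the ones
# plotted under __main__ below.
#   e_list, ar_list, ar_q_list, world, route, env, ok = sarsa_q(map_size=4, iteration_lim=500)
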
# count the numbers of successes and failures for a particular number of
# training iterations and map size
def successful_times_test():
    test_iteration = 100
    successful_num = 0
    for i in range(0, test_iteration):
        e_list, ar_list, ar_q_list, world, route, env, result = sarsa_q(map_size=10, iteration_lim=3000)
        if result:
            successful_num += 1
    failed_num = test_iteration - successful_num
    successful_list = [successful_num]
    failed_list = [failed_num]
    # plot the bar chart
    bar_width = 0.8
    bar_label = ("success", "failure")
    successful_index = 1
    failed_index = successful_index + bar_width
    plt.bar(successful_index, successful_list, width=bar_width, label='success')
    plt.bar(failed_index, failed_list, width=bar_width, label='failure')
    plt.legend()
    plt.xticks([successful_index, failed_index], bar_label)
    plt.ylim(0, test_iteration)
    plt.ylabel("Number")
    plt.show()


if __name__ == '__main__':
    map_plus_wall_size = 12
    if param.TEST:
        # count the times of success and then plot them
        successful_times_test()
    else:
        # run once for the illustration
        e_list, ar_list, ar_q_list, world, route, env, result = sarsa_q(map_size=10, iteration_lim=10000)
        # plot the average reward
        plt.plot(e_list, ar_list, label="SARSA")
        plt.xlabel("Episode")
        plt.ylabel("Average Reward")
        plt.ylim(-1.5, 1.5)
        plt.legend()
        plt.show()
        # plot the average q value
        plt.plot(e_list, ar_q_list, label="SARSA")
        plt.xlabel("Episode")
        plt.ylabel("Average Q Value")
        plt.ylim(-0.1, 0.1)
        plt.legend()
        plt.show()
        # if path finding succeeded, render the route
        if result:
            for pos in route:
                if pos != (map_plus_wall_size - 2, map_plus_wall_size - 2):
                    if map_plus_wall_size - 2 == 10:
                        env.render_10(pos[0], pos[1])
                    elif map_plus_wall_size - 2 == 4:
                        env.render_4(pos[0], pos[1])
                    time.sleep(1)
        # else:
        #     tkinter.messagebox.showinfo(title='Note', message='Finding route failed!')