# deep_utils.py
import torch
import random
import numpy as np
from collections import namedtuple, deque
from tic_plot import plot_grid
import matplotlib.pyplot as plt

# A single transition stored in the replay buffer: board state, chosen action,
# resulting state, reward, and the player who acted
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward', 'player'))

# Map player symbols to their numeric board values
p2v = {'X': 1, 'O': -1}


# Replay buffer
# uses code from https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html#replay-memory
class ReplayBuffer(object):
    """
    A class for the replay buffer.
    Uses code from https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html#replay-memory

    Attributes
    ----------
    buffer : deque
        The buffer
    batch_size : int
        The batch size used when sampling

    Methods
    -------
    push(state, action, next_state, reward, player)
        Save a transition to the buffer
    get_batch(batch_size=None)
        Get a batch of transitions from the buffer
    __len__()
        Get the length of the buffer
    has_one_batch(batch_size=None)
        Check if the buffer has at least one batch
    """

    def __init__(self, buffer_size: int, batch_size: int) -> None:
        self.buffer = deque([], maxlen=buffer_size)
        self.batch_size = batch_size

    def push(self, *args) -> None:
        """Save a transition to the buffer

        Parameters
        ----------
        *args : Transition
            Transition arguments: state, action, next_state, reward, player
        """
        self.buffer.append(Transition(*args))

    def get_batch(self, batch_size=None) -> list:
        """Get a batch of transitions sampled uniformly at random from the buffer

        Parameters
        ----------
        batch_size : int, optional
            The batch size used when sampling (the default is None, which uses the buffer's batch size)

        Returns
        -------
        list of Transition
        """
        if batch_size is None:
            batch_size = self.batch_size
        return random.sample(self.buffer, batch_size)

    def __len__(self) -> int:
        """Get the number of transitions currently stored in the buffer"""
        return len(self.buffer)

    def has_one_batch(self, batch_size=None) -> bool:
        """Check if the buffer holds at least one batch of transitions"""
        if batch_size is None:
            batch_size = self.batch_size
        return len(self) >= batch_size
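

# Minimal usage sketch of ReplayBuffer (illustrative, not part of the original module):
# push a few transitions and sample a batch once enough are stored. The boards, actions and
# rewards below are made-up placeholders.
def _replay_buffer_example() -> None:
    buffer = ReplayBuffer(buffer_size=1000, batch_size=2)
    empty_board = np.zeros((3, 3), dtype=np.int64)
    next_board = empty_board.copy()
    next_board[0, 0] = 1
    # Arguments follow the Transition fields: state, action, next_state, reward, player
    buffer.push(empty_board, 0, next_board, 0.0, 1)
    buffer.push(next_board, 4, None, 1.0, 1)
    if buffer.has_one_batch():
        batch = buffer.get_batch()
        # Transpose the list of Transitions into a Transition of tuples, as in the PyTorch DQN tutorial
        batch = Transition(*zip(*batch))
        print(batch.reward)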


def state_to_tensor(state: np.ndarray, player) -> torch.Tensor:
    """Convert the state representation of the board to the corresponding tensor

    Parameters
    ----------
    state : np.ndarray
        The state
    player : str or int
        The player, either 'X'/'O' or 1/-1

    Returns
    -------
    torch.Tensor
        The tensor representation of the state: one plane for the player's pieces
        and one for the opponent's pieces
    """
    # Accept both symbolic ('X'/'O') and numeric (1/-1) player representations
    if player in p2v:
        player = p2v[player]
    if player == 1:
        opponent_player = -1
    elif player == -1:
        opponent_player = 1
    else:
        raise ValueError(f"Player should be 'X', 'O', 1 or -1, player={player}")
    t = np.zeros((3, 3, 2), dtype=np.float32)
    t[:, :, 0] = (state == player)
    t[:, :, 1] = (state == opponent_player)
    return torch.tensor(t, dtype=torch.float32)
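

# Illustrative sketch of the two-plane encoding produced by state_to_tensor (the helper name
# below is hypothetical, not part of the original module): plane 0 marks the acting player's
# pieces, plane 1 the opponent's.
def _state_to_tensor_example() -> torch.Tensor:
    board = np.array([[1, -1, 0],
                      [0, 1, 0],
                      [0, 0, -1]])
    t = state_to_tensor(board, 'X')  # 'X' is mapped to 1 via p2v
    # t[:, :, 0] is 1 where board == 1 (X pieces), t[:, :, 1] is 1 where board == -1 (O pieces)
    assert t.shape == (3, 3, 2)
    return t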


# Policies
class DeepEpsilonGreedy:
    """Epsilon-greedy policy

    Attributes
    ----------
    net : torch.nn.Module
        The neural network used to choose an action when exploiting
    epsilon : float
        The epsilon value used to determine whether to explore or exploit
    player : str or int
        The player to use for the policy, either 'X' or 'O'

    Methods
    -------
    set_epsilon(epsilon)
        Set the epsilon value
    set_player(player)
        Set the player to use for the policy
    act(state)
        Choose an action given the state of the board
    """

    def __init__(self,
                 net: torch.nn.Module,
                 epsilon: float = 0,
                 player: str = 'X') -> None:
        self.net = net
        self.epsilon = epsilon
        self.player = player

    def set_epsilon(self, epsilon: float) -> None:
        """Set the epsilon value"""
        self.epsilon = epsilon

    def set_player(self, player: str) -> None:
        """Set the player to use for the policy"""
        self.player = player

    def act(self, state) -> int:
        """Choose an action given the state of the board"""
        # Exploit: pick the action with the highest predicted value
        if random.random() > self.epsilon:
            state = state_to_tensor(state, self.player)
            with torch.no_grad():
                return torch.argmax(self.net(state)).item()
        # Explore: pick a random empty cell
        else:
            available = np.nonzero(state.flatten() == 0)
            return int(random.choice(available[0]))
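

# Minimal sketch of using DeepEpsilonGreedy (illustrative; the two-layer network below is a
# stand-in, not the project's actual architecture). The network is expected to map a
# (3, 3, 2) state tensor to 9 action values, one per board cell.
def _policy_example() -> int:
    net = torch.nn.Sequential(
        torch.nn.Flatten(start_dim=0),  # (3, 3, 2) -> (18,)
        torch.nn.Linear(18, 9),
    )
    policy = DeepEpsilonGreedy(net, epsilon=0.1, player='X')
    empty_board = np.zeros((3, 3))
    action = policy.act(empty_board)  # index in [0, 8] of the chosen cell
    return action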


class DeepEpsilonGreedyDecreasingExploration(DeepEpsilonGreedy):
    """Epsilon-greedy policy with a decreasing exploration rate

    Attributes
    ----------
    See DeepEpsilonGreedy
    epsilon_min : float
        The minimum epsilon value to use in the decreasing exploration formula
    epsilon_max : float
        The maximum epsilon value to use in the decreasing exploration formula
    n_star : int
        The n* parameter to use in the decreasing exploration formula

    Methods
    -------
    See DeepEpsilonGreedy
    update_epsilon(n)
        Update the epsilon value using the decreasing exploration formula, depending on the step n
    """

    def __init__(self,
                 net: torch.nn.Module,
                 player: str = 'X',
                 epsilon_min: float = 0.1,
                 epsilon_max: float = 0.8,
                 n_star: int = 20000) -> None:
        super().__init__(net, player=player)
        self.epsilon_min = epsilon_min
        self.epsilon_max = epsilon_max
        self.n_star = n_star

    def update_epsilon(self, n: int) -> None:
        """Update the epsilon value using the decreasing exploration formula, depending on the step n"""
        new_epsilon = max(self.epsilon_min, self.epsilon_max * (1 - (n / self.n_star)))
        self.set_epsilon(epsilon=new_epsilon)
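

# Illustrative check of the decreasing-exploration schedule
# epsilon(n) = max(epsilon_min, epsilon_max * (1 - n / n_star)); the dummy network is a placeholder
# that is never queried here.
def _decay_example() -> None:
    policy = DeepEpsilonGreedyDecreasingExploration(
        net=torch.nn.Linear(18, 9),
        epsilon_min=0.1, epsilon_max=0.8, n_star=20000,
    )
    for n in (0, 10000, 20000, 40000):
        policy.update_epsilon(n)
        print(n, policy.epsilon)  # 0.8, 0.4, 0.1, 0.1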


# Debug
def examples_output_images(model: torch.nn.Module) -> np.ndarray:
    """Generate example images of the model's output

    Parameters
    ----------
    model : torch.nn.Module
        The model to produce the output images from

    Returns
    -------
    np.ndarray
        Image of the model's outputs for the three specified examples, stacked vertically
    """
    examples = [
        ((0, 0, 0, 0, 0, 0, 0, 0, 0), 1),
        ((0, -1, -1, 0, 1, 1, 0, 0, 0), 1),
        ((1, -1, -1, 1, 1, 0, 0, 0, 0), -1),
    ]
    imgs = []
    for state, player in examples:
        a = np.array(state).reshape((3, 3))
        t = state_to_tensor(a, player)
        with torch.no_grad():
            out = model(t).cpu().numpy().reshape((3, 3))
        plot_grid(a, out, clim=(-2, 2))
        fig = plt.gcf()
        # Render the canvas, grab its RGB pixels, then close the figure
        fig.canvas.draw()
        imgs.append(np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8))
        plt.close(fig)
    imgs = np.array(imgs)
    # get_width_height() returns (width, height); reverse it to get (height, width)
    height, width = fig.canvas.get_width_height()[::-1]
    N = len(examples)
    # Stack the example images vertically into a single (N * height, width, 3) image
    img = imgs.reshape((N * height, width, 3))
    return img
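

# Hedged usage sketch for examples_output_images: it assumes tic_plot.plot_grid is importable
# and that the active matplotlib backend supports canvas.tostring_rgb(). The tiny model below
# is a placeholder mapping the (3, 3, 2) state tensor to 9 values.
def _debug_images_example() -> None:
    model = torch.nn.Sequential(torch.nn.Flatten(start_dim=0), torch.nn.Linear(18, 9))
    img = examples_output_images(model)
    plt.imshow(img)
    plt.axis('off')
    plt.show()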


def debug_table(d: dict) -> str:
    """Format a table of the given dictionary

    Parameters
    ----------
    d : dict
        The nested dictionary to format, which should have the following structure:
        {'X': {'win': 0, 'draw': 0, 'loss': 0},
         'O': {'win': 0, 'draw': 0, 'loss': 0}}

    Returns
    -------
    str
        Markdown-style table of the given dictionary
    """
    header = '|||||'
    separator = '|-|-|-|-|'
    rows = [
        '|'.join([''] + [f'player: {p}'] + [f"{n}: {o}" for n, o in m.items()] + [''])
        for p, m in d.items()
    ]
    return ' \n'.join([header, separator] + rows)
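

# Example of the Markdown-style table produced by debug_table (illustrative values):
def _debug_table_example() -> None:
    stats = {'X': {'win': 10, 'draw': 3, 'loss': 2},
             'O': {'win': 4, 'draw': 5, 'loss': 6}}
    print(debug_table(stats))
    # |||||
    # |-|-|-|-|
    # |player: X|win: 10|draw: 3|loss: 2|
    # |player: O|win: 4|draw: 5|loss: 6|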