utils_drl.py

"""Deep Q-learning agent with prioritized experience replay."""

import random
from typing import (
    Optional,
)

import torch
import torch.nn.functional as F
import torch.optim as optim

from utils_types import (
    TensorStack4,
    TorchDevice,
)
from utils_memory import Prioritized_ReplayMemory
from utils_model import DQN


class Agent(object):

    def __init__(
            self,
            action_dim: int,
            device: TorchDevice,
            gamma: float,
            beta_start: float,
            beta_increment_per_sampling: float,
            seed: int,
            eps_start: float,
            eps_final: float,
            eps_decay: float,
            restore: Optional[str] = None,
    ) -> None:
        self.__action_dim = action_dim
        self.__device = device
        self.__gamma = gamma

        # Importance-sampling exponent for prioritized replay, annealed
        # towards 1.0 in learn().
        self.__beta = beta_start
        self.__beta_increment_per_sampling = beta_increment_per_sampling

        # Linear epsilon-greedy exploration schedule.
        self.__eps_start = eps_start
        self.__eps_final = eps_final
        self.__eps_decay = eps_decay
        self.__eps = eps_start

        self.__r = random.Random()
        self.__r.seed(seed)

        # The policy network is trained; the target network is a periodically
        # synchronized copy used only to compute TD targets.
        self.__policy = DQN(action_dim, device).to(device)
        self.__target = DQN(action_dim, device).to(device)
        if restore is None:
            self.__policy.apply(DQN.init_weights)
        else:
            self.__policy.load_state_dict(torch.load(restore))
        self.__target.load_state_dict(self.__policy.state_dict())
        self.__optimizer = optim.Adam(
            self.__policy.parameters(),
            lr=0.000015625,
            eps=1.5e-4,
        )
        self.__target.eval()

    def run(self, state: TensorStack4, training: bool = False) -> int:
        """run suggests an action for the given state."""
        if training:
            # Anneal epsilon linearly from eps_start towards eps_final.
            self.__eps -= \
                (self.__eps_start - self.__eps_final) / self.__eps_decay
            self.__eps = max(self.__eps, self.__eps_final)

        # Epsilon-greedy: act greedily w.r.t. the policy network with
        # probability 1 - epsilon, otherwise take a random action.
        if self.__r.random() > self.__eps:
            with torch.no_grad():
                return self.__policy(state).max(1).indices.item()
        return self.__r.randint(0, self.__action_dim - 1)

    def learn(self, memory: Prioritized_ReplayMemory,
              batch_size: int) -> float:
        """learn trains the value network via TD-learning."""
        # Anneal the importance-sampling exponent beta towards 1.0.
        self.__beta += self.__beta_increment_per_sampling
        self.__beta = min(self.__beta, 1.0)

        state_batch, action_batch, reward_batch, next_batch, done_batch, \
            indices_batch, weight_batch = \
            memory.sample(batch_size, self.__beta)

        # Q(s, a) of the actions actually taken, from the policy network.
        values = self.__policy(state_batch.float()).gather(1, action_batch)
        # max_a' Q_target(s', a'), detached from the graph.
        values_next = self.__target(next_batch.float()).max(1).values.detach()
        # TD target: r + gamma * max_a' Q_target(s', a') * (1 - done).
        expected = (self.__gamma * values_next.unsqueeze(1)) * \
            (1. - done_batch) + reward_batch

        # The per-sample Huber loss doubles as the new priorities (plus a
        # small constant so no transition ends up with zero priority).
        loss = F.smooth_l1_loss(values, expected, reduction='none')
        priorities = (loss + 0.01).detach()
        memory.update_priorities(indices_batch, priorities)

        # Importance-sampling weights correct the prioritized-sampling bias.
        loss = (weight_batch * loss).mean()

        self.__optimizer.zero_grad()
        loss.backward()
        for param in self.__policy.parameters():
            param.grad.data.clamp_(-1, 1)
        self.__optimizer.step()
        return loss.item()

    def sync(self) -> None:
        """sync synchronizes the weights from the policy network to the
        target network."""
        self.__target.load_state_dict(self.__policy.state_dict())

    def save(self, path: str) -> None:
        """save saves the state dict of the policy network."""
        torch.save(self.__policy.state_dict(), path)
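

# A minimal usage sketch of the Agent API above, assuming the standard Atari
# setup of four stacked 84x84 frames. The hyperparameter values, the dummy
# state shape, and the batch size are illustrative assumptions only; the
# Prioritized_ReplayMemory construction is omitted because its signature is
# not part of this file.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    agent = Agent(
        action_dim=4,          # e.g. a four-action Atari game
        device=device,
        gamma=0.99,
        beta_start=0.4,
        beta_increment_per_sampling=1e-6,
        seed=0,
        eps_start=1.0,
        eps_final=0.1,
        eps_decay=1e6,
    )

    # Act on a dummy observation (assumed shape: batch x 4 frames x 84 x 84).
    state = torch.zeros(1, 4, 84, 84, device=device)
    action = agent.run(state, training=True)
    print("chosen action:", action)

    # In a real training loop, transitions would be pushed into a
    # Prioritized_ReplayMemory and the agent updated from it:
    #     loss = agent.learn(memory, batch_size=32)
    # with the target network refreshed and a checkpoint written periodically:
    agent.sync()
    agent.save("model_checkpoint.pth")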