# sac_two_outputs.py

import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal
from model import ValueNetwork, SoftQNetwork, PolicyNetwork
from common.replay_buffers import BasicBuffer
import numpy as np


class SACAgent:

    def __init__(self, env, gamma, tau, v_lr, q_lr, policy_lr, buffer_maxlen):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.env = env
        # self.action_range = [env.action_space.low, env.action_space.high]
        # TODO: the range is hard-coded here as a simple demo; for the real
        # implementation it should be passed in as a parameter
        self.action_range = [[-1, 1], [-1, 1]]
        self.obs_dim = env.observation_space.shape[0]
        self.action_dim = 2
        # self.action_dim = 1

        # hyperparameters
        self.gamma = gamma
        self.tau = tau
        self.update_step = 0
        self.delay_step = 2

        # initialize networks
        self.value_net = ValueNetwork(self.obs_dim, 1).to(self.device)
        self.target_value_net = ValueNetwork(self.obs_dim, 1).to(self.device)
        self.q_net1 = SoftQNetwork(self.obs_dim, self.action_dim).to(self.device)
        self.q_net2 = SoftQNetwork(self.obs_dim, self.action_dim).to(self.device)
        self.policy_net = PolicyNetwork(self.obs_dim, self.action_dim).to(self.device)

        # copy params to target params
        for target_param, param in zip(self.target_value_net.parameters(), self.value_net.parameters()):
            target_param.data.copy_(param)

        # initialize optimizers
        self.value_optimizer = optim.Adam(self.value_net.parameters(), lr=v_lr)
        self.q1_optimizer = optim.Adam(self.q_net1.parameters(), lr=q_lr)
        self.q2_optimizer = optim.Adam(self.q_net2.parameters(), lr=q_lr)
        self.policy_optimizer = optim.Adam(self.policy_net.parameters(), lr=policy_lr)

        self.replay_buffer = BasicBuffer(buffer_maxlen)

    # pi: state -> action
    def get_action(self, state):
        state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        mean, log_std = self.policy_net.forward(state)
        std = log_std.exp()

        # sample from the Gaussian policy and squash into (-1, 1)
        normal = Normal(mean, std)
        z = normal.sample()
        action = torch.tanh(z)
        action = action.cpu().detach().squeeze(0).numpy()

        return self.rescale_action(action)

    def rescale_action(self, action):
        # if action < 0.5:
        #     return 0
        # else:
        #     return 1

        # map each tanh-squashed component from (-1, 1) into its action range
        scaled_action = []
        for idx, a in enumerate(action):
            action_range = self.action_range[idx]
            a = a * (action_range[1] - action_range[0]) / 2.0 + (action_range[1] + action_range[0]) / 2.0
            scaled_action.append(a)
        return scaled_action
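
    # Illustration of the rescaling above (hypothetical range, not the one set
    # in __init__): with action_range = [0.0, 2.0], a tanh output of 0.5 maps to
    # 0.5 * (2.0 - 0.0) / 2 + (2.0 + 0.0) / 2 = 1.5.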

    def update(self, batch_size):
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(batch_size)
        states = torch.FloatTensor(states).to(self.device)
        actions = torch.FloatTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(next_states).to(self.device)
        dones = torch.FloatTensor(dones).to(self.device)
        dones = dones.view(dones.size(0), -1)

        next_actions, next_log_pi = self.policy_net.sample(next_states)
        next_q1 = self.q_net1(next_states, next_actions)
        next_q2 = self.q_net2(next_states, next_actions)
        next_v = self.target_value_net(next_states)

        # value loss
        next_v_target = torch.min(next_q1, next_q2) - next_log_pi
        curr_v = self.value_net.forward(states)
        v_loss = F.mse_loss(curr_v, next_v_target.detach())

        # TODO: Question: why use two Q-networks?
        # Taking the minimum of two independently trained Q-networks (clipped
        # double-Q) mitigates the overestimation bias of a single Q estimate.
        # q loss
        curr_q1 = self.q_net1.forward(states, actions)
        curr_q2 = self.q_net2.forward(states, actions)
        expected_q = rewards + (1 - dones) * self.gamma * next_v
        q1_loss = F.mse_loss(curr_q1, expected_q.detach())
        q2_loss = F.mse_loss(curr_q2, expected_q.detach())
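
        # Targets used above (SAC with a separate state-value function, as in
        # the first SAC paper):
        #   V_target(s)    = min(Q1(s, a~), Q2(s, a~)) - log pi(a~|s),  a~ ~ pi(.|s)
        #   Q_target(s, a) = r + gamma * (1 - done) * V_targ(s')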

        # update value network and q networks
        self.value_optimizer.zero_grad()
        v_loss.backward()
        self.value_optimizer.step()

        self.q1_optimizer.zero_grad()
        q1_loss.backward()
        self.q1_optimizer.step()

        self.q2_optimizer.zero_grad()
        q2_loss.backward()
        self.q2_optimizer.step()

        # delayed update for policy net and target value net
        # TODO: Question: what does this part do?
        # The original SAC paper mentions two ways to keep the target value
        # network up to date: (1) an exponential moving average (soft update)
        # of the value-network weights, or (2) periodically copying the
        # value-network weights over. This code applies the soft update, but
        # only every `delay_step` gradient steps, together with the policy update.
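        # The policy objective below is J_pi = E[ log pi(a|s) - min(Q1, Q2) ]
        # (the entropy temperature alpha is implicitly fixed to 1 in this code).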
        if self.update_step % self.delay_step == 0:
            new_actions, log_pi = self.policy_net.sample(states)
            min_q = torch.min(
                self.q_net1.forward(states, new_actions),
                self.q_net2.forward(states, new_actions)
            )
            policy_loss = (log_pi - min_q).mean()

            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            self.policy_optimizer.step()

            # target networks
            for target_param, param in zip(self.target_value_net.parameters(), self.value_net.parameters()):
                target_param.data.copy_(self.tau * param + (1 - self.tau) * target_param)

        self.update_step += 1
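

# ---------------------------------------------------------------------------
# Minimal usage sketch, assuming a Gym-style environment with a 2-dimensional
# continuous action space (action_dim is hard-coded to 2 above), whose reset()
# returns an observation and whose step(action) returns
# (next_state, reward, done, info). It also assumes BasicBuffer exposes
# push(state, action, reward, next_state, done) and __len__; the environment
# id and hyperparameters are illustrative only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import gym

    env = gym.make("TwoActionEnv-v0")  # hypothetical id; any 2-action continuous env
    agent = SACAgent(env, gamma=0.99, tau=0.005,
                     v_lr=3e-4, q_lr=3e-4, policy_lr=3e-4,
                     buffer_maxlen=100000)

    batch_size = 64
    for episode in range(100):
        state = env.reset()
        episode_reward = 0.0
        for step in range(200):
            action = agent.get_action(state)
            next_state, reward, done, _ = env.step(action)
            agent.replay_buffer.push(state, action, reward, next_state, done)
            episode_reward += reward

            # start learning once enough transitions have been collected
            if len(agent.replay_buffer) > batch_size:
                agent.update(batch_size)

            state = next_state
            if done:
                break

        print("episode {}: reward {:.1f}".format(episode, episode_reward))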