Showing 6 changed files with 338 additions and 0 deletions.
35 changes: 35 additions & 0 deletions
ReinforcementLearning/PolicyGradient/SAC/tf2/buffer.py
@@ -0,0 +1,35 @@
import numpy as np


class ReplayBuffer:
    def __init__(self, max_size, input_shape, n_actions):
        self.mem_size = max_size
        self.mem_cntr = 0
        self.state_memory = np.zeros((self.mem_size, *input_shape))
        self.new_state_memory = np.zeros((self.mem_size, *input_shape))
        self.action_memory = np.zeros((self.mem_size, n_actions))
        self.reward_memory = np.zeros(self.mem_size)
        # np.bool was removed from recent NumPy releases; use the bool scalar type
        self.terminal_memory = np.zeros(self.mem_size, dtype=np.bool_)

    def store_transition(self, state, action, reward, state_, done):
        # circular buffer: once full, the oldest transitions are overwritten
        index = self.mem_cntr % self.mem_size

        self.state_memory[index] = state
        self.new_state_memory[index] = state_
        self.action_memory[index] = action
        self.reward_memory[index] = reward
        self.terminal_memory[index] = done

        self.mem_cntr += 1

    def sample_buffer(self, batch_size):
        max_mem = min(self.mem_cntr, self.mem_size)

        # uniform random sample (with replacement) over the filled portion
        batch = np.random.choice(max_mem, batch_size)

        states = self.state_memory[batch]
        states_ = self.new_state_memory[batch]
        actions = self.action_memory[batch]
        rewards = self.reward_memory[batch]
        dones = self.terminal_memory[batch]

        return states, actions, rewards, states_, dones
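
A minimal usage sketch of the buffer (not part of the commit), assuming a hypothetical 4-dimensional observation and a single continuous action:

import numpy as np
from buffer import ReplayBuffer

# hypothetical shapes, for illustration only
buffer = ReplayBuffer(max_size=1000, input_shape=(4,), n_actions=1)

state = np.zeros(4)
action = np.array([0.1])
for _ in range(64):
    buffer.store_transition(state, action, reward=1.0, state_=state, done=False)

states, actions, rewards, states_, dones = buffer.sample_buffer(batch_size=32)
print(states.shape, actions.shape)  # (32, 4) (32, 1)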
53 changes: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
import pybullet_envs
import gym
import numpy as np
from sac_tf2 import Agent
from utils import plot_learning_curve
from gym import wrappers

if __name__ == '__main__':
    env = gym.make('InvertedPendulumBulletEnv-v0')
    agent = Agent(input_dims=env.observation_space.shape, env=env,
                  n_actions=env.action_space.shape[0])
    n_games = 250
    # uncomment this line and do a mkdir tmp && mkdir tmp/video if you want to
    # record video of the agent playing the game.
    # env = wrappers.Monitor(env, 'tmp/video', video_callable=lambda episode_id: True, force=True)
    filename = 'inverted_pendulum.png'

    figure_file = 'plots/' + filename

    best_score = env.reward_range[0]
    score_history = []
    # False trains from scratch; True loads saved checkpoints and only evaluates
    load_checkpoint = False

    if load_checkpoint:
        agent.load_models()
        # PyBullet environments must call render() before the first reset()
        env.render(mode='human')

    for i in range(n_games):
        observation = env.reset()
        done = False
        score = 0
        while not done:
            action = agent.choose_action(observation)
            observation_, reward, done, info = env.step(action)
            score += reward
            agent.remember(observation, action, reward, observation_, done)
            if not load_checkpoint:
                agent.learn()
            observation = observation_
        score_history.append(score)
        avg_score = np.mean(score_history[-100:])

        if avg_score > best_score:
            best_score = avg_score
            if not load_checkpoint:
                agent.save_models()

        print('episode ', i, 'score %.1f' % score, 'avg_score %.1f' % avg_score)

    if not load_checkpoint:
        x = [i+1 for i in range(n_games)]
        plot_learning_curve(x, score_history, figure_file)
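
One practical note: the script saves checkpoints under tmp/sac, writes the learning curve to plots/, and the optional Monitor wrapper records to tmp/video, yet none of these directories are created by the code. A small setup sketch, assuming you run everything from the SAC/tf2 directory of the repository:

import os

# directories assumed by checkpointing, plotting, and optional video recording
for d in ('tmp/sac', 'tmp/video', 'plots'):
    os.makedirs(d, exist_ok=True)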
98 changes: 98 additions & 0 deletions
ReinforcementLearning/PolicyGradient/SAC/tf2/networks.py
@@ -0,0 +1,98 @@
import os
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow_probability as tfp
from tensorflow.keras.layers import Dense


class CriticNetwork(keras.Model):
    def __init__(self, n_actions, fc1_dims=256, fc2_dims=256,
                 name='critic', chkpt_dir='tmp/sac'):
        super(CriticNetwork, self).__init__()
        self.fc1_dims = fc1_dims
        self.fc2_dims = fc2_dims
        self.n_actions = n_actions
        self.model_name = name
        self.checkpoint_dir = chkpt_dir
        self.checkpoint_file = os.path.join(self.checkpoint_dir, name+'_sac')

        self.fc1 = Dense(self.fc1_dims, activation='relu')
        self.fc2 = Dense(self.fc2_dims, activation='relu')
        self.q = Dense(1, activation=None)

    def call(self, state, action):
        # the critic scores the concatenated state-action pair
        action_value = self.fc1(tf.concat([state, action], axis=1))
        action_value = self.fc2(action_value)

        q = self.q(action_value)

        return q


class ValueNetwork(keras.Model):
    def __init__(self, fc1_dims=256, fc2_dims=256,
                 name='value', chkpt_dir='tmp/sac'):
        super(ValueNetwork, self).__init__()
        self.fc1_dims = fc1_dims
        self.fc2_dims = fc2_dims
        self.model_name = name
        self.checkpoint_dir = chkpt_dir
        self.checkpoint_file = os.path.join(self.checkpoint_dir, name+'_sac')

        self.fc1 = Dense(self.fc1_dims, activation='relu')
        self.fc2 = Dense(self.fc2_dims, activation='relu')
        self.v = Dense(1, activation=None)

    def call(self, state):
        state_value = self.fc1(state)
        state_value = self.fc2(state_value)

        v = self.v(state_value)

        return v


class ActorNetwork(keras.Model):
    def __init__(self, max_action, fc1_dims=256,
                 fc2_dims=256, n_actions=2, name='actor', chkpt_dir='tmp/sac'):
        super(ActorNetwork, self).__init__()
        self.fc1_dims = fc1_dims
        self.fc2_dims = fc2_dims
        self.n_actions = n_actions
        self.model_name = name
        self.checkpoint_dir = chkpt_dir
        self.checkpoint_file = os.path.join(self.checkpoint_dir, name+'_sac')
        self.max_action = max_action
        self.noise = 1e-6

        self.fc1 = Dense(self.fc1_dims, activation='relu')
        self.fc2 = Dense(self.fc2_dims, activation='relu')
        self.mu = Dense(self.n_actions, activation=None)
        self.sigma = Dense(self.n_actions, activation=None)

    def call(self, state):
        prob = self.fc1(state)
        prob = self.fc2(prob)

        mu = self.mu(prob)
        sigma = self.sigma(prob)
        # might want to come back and change this; clipping keeps sigma in
        # (1e-6, 1] so the Normal distribution stays well defined
        sigma = tf.clip_by_value(sigma, self.noise, 1)

        return mu, sigma

    def sample_normal(self, state, reparameterize=True):
        mu, sigma = self.call(state)
        probabilities = tfp.distributions.Normal(mu, sigma)

        if reparameterize:
            actions = probabilities.sample()  # + noise here if you want true reparameterization
        else:
            actions = probabilities.sample()

        # squash the Gaussian sample into the environment's action range
        action = tf.math.tanh(actions)*self.max_action
        # change-of-variables correction for the tanh squashing
        log_probs = probabilities.log_prob(actions)
        log_probs -= tf.math.log(1-tf.math.pow(action, 2)+self.noise)
        log_probs = tf.math.reduce_sum(log_probs, axis=1, keepdims=True)

        return action, log_probs
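
A quick sanity check of the tanh change-of-variables correction in sample_normal (a sketch, not part of the commit, assuming max_action = 1 as for the inverted pendulum): the manual correction should closely match the log-probability of a Tanh-transformed distribution from tensorflow_probability.

import tensorflow as tf
import tensorflow_probability as tfp

mu, sigma = tf.constant([[0.0]]), tf.constant([[1.0]])
base = tfp.distributions.Normal(mu, sigma)
u = base.sample()
a = tf.math.tanh(u)

# manual correction, as in ActorNetwork.sample_normal (max_action = 1)
manual = base.log_prob(u) - tf.math.log(1 - tf.math.pow(a, 2) + 1e-6)

# same quantity via a Tanh-transformed distribution
squashed = tfp.distributions.TransformedDistribution(base, tfp.bijectors.Tanh())
print(manual.numpy(), squashed.log_prob(a).numpy())  # should agree closely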
Binary file added (BIN, +22.7 KB)
ReinforcementLearning/PolicyGradient/SAC/tf2/plots/inverted_pendulum.png
142 changes: 142 additions & 0 deletions
ReinforcementLearning/PolicyGradient/SAC/tf2/sac_tf2.py
@@ -0,0 +1,142 @@
import os
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.optimizers import Adam
from buffer import ReplayBuffer
from networks import ActorNetwork, CriticNetwork, ValueNetwork


class Agent:
    def __init__(self, alpha=0.0003, beta=0.0003, input_dims=[8],
                 env=None, gamma=0.99, n_actions=2, max_size=1000000, tau=0.005,
                 layer1_size=256, layer2_size=256, batch_size=256, reward_scale=2):
        self.gamma = gamma
        self.tau = tau
        self.memory = ReplayBuffer(max_size, input_dims, n_actions)
        self.batch_size = batch_size
        self.n_actions = n_actions

        self.actor = ActorNetwork(n_actions=n_actions, name='actor',
                                  max_action=env.action_space.high)
        self.critic_1 = CriticNetwork(n_actions=n_actions, name='critic_1')
        self.critic_2 = CriticNetwork(n_actions=n_actions, name='critic_2')
        self.value = ValueNetwork(name='value')
        self.target_value = ValueNetwork(name='target_value')

        self.actor.compile(optimizer=Adam(learning_rate=alpha))
        self.critic_1.compile(optimizer=Adam(learning_rate=beta))
        self.critic_2.compile(optimizer=Adam(learning_rate=beta))
        self.value.compile(optimizer=Adam(learning_rate=beta))
        self.target_value.compile(optimizer=Adam(learning_rate=beta))

        self.scale = reward_scale
        # tau=1 performs a hard copy so the target starts identical to the value network
        self.update_network_parameters(tau=1)

    def choose_action(self, observation):
        state = tf.convert_to_tensor([observation])
        actions, _ = self.actor.sample_normal(state, reparameterize=False)

        return actions[0]

    def remember(self, state, action, reward, new_state, done):
        self.memory.store_transition(state, action, reward, new_state, done)

    def update_network_parameters(self, tau=None):
        if tau is None:
            tau = self.tau

        # soft update: theta_target <- tau*theta + (1 - tau)*theta_target
        weights = []
        targets = self.target_value.weights
        for i, weight in enumerate(self.value.weights):
            weights.append(weight * tau + targets[i]*(1-tau))

        self.target_value.set_weights(weights)

    def save_models(self):
        print('... saving models ...')
        self.actor.save_weights(self.actor.checkpoint_file)
        self.critic_1.save_weights(self.critic_1.checkpoint_file)
        self.critic_2.save_weights(self.critic_2.checkpoint_file)
        self.value.save_weights(self.value.checkpoint_file)
        self.target_value.save_weights(self.target_value.checkpoint_file)

    def load_models(self):
        print('... loading models ...')
        self.actor.load_weights(self.actor.checkpoint_file)
        self.critic_1.load_weights(self.critic_1.checkpoint_file)
        self.critic_2.load_weights(self.critic_2.checkpoint_file)
        self.value.load_weights(self.value.checkpoint_file)
        self.target_value.load_weights(self.target_value.checkpoint_file)

    def learn(self):
        if self.memory.mem_cntr < self.batch_size:
            return

        state, action, reward, new_state, done = \
            self.memory.sample_buffer(self.batch_size)

        states = tf.convert_to_tensor(state, dtype=tf.float32)
        states_ = tf.convert_to_tensor(new_state, dtype=tf.float32)
        rewards = tf.convert_to_tensor(reward, dtype=tf.float32)
        actions = tf.convert_to_tensor(action, dtype=tf.float32)

        # value network update: V(s) should match min Q(s, a~pi) - log pi(a|s)
        with tf.GradientTape() as tape:
            value = tf.squeeze(self.value(states), 1)
            value_ = tf.squeeze(self.target_value(states_), 1)

            current_policy_actions, log_probs = self.actor.sample_normal(states,
                                                        reparameterize=False)
            log_probs = tf.squeeze(log_probs, 1)
            q1_new_policy = self.critic_1(states, current_policy_actions)
            q2_new_policy = self.critic_2(states, current_policy_actions)
            critic_value = tf.squeeze(
                                tf.math.minimum(q1_new_policy, q2_new_policy), 1)

            value_target = critic_value - log_probs
            value_loss = 0.5 * keras.losses.MSE(value, value_target)

        value_network_gradient = tape.gradient(value_loss,
                                               self.value.trainable_variables)
        self.value.optimizer.apply_gradients(zip(
            value_network_gradient, self.value.trainable_variables))

        # actor update: minimize E[log pi(a|s) - min Q(s, a~pi)]
        with tf.GradientTape() as tape:
            # the original paper uses the reparameterization trick here;
            # sample_normal does not actually implement it, so this is just an
            # ordinary sample from the current policy
            new_policy_actions, log_probs = self.actor.sample_normal(states,
                                                reparameterize=True)
            log_probs = tf.squeeze(log_probs, 1)
            q1_new_policy = self.critic_1(states, new_policy_actions)
            q2_new_policy = self.critic_2(states, new_policy_actions)
            critic_value = tf.squeeze(tf.math.minimum(
                                        q1_new_policy, q2_new_policy), 1)

            actor_loss = log_probs - critic_value
            actor_loss = tf.math.reduce_mean(actor_loss)

        actor_network_gradient = tape.gradient(actor_loss,
                                               self.actor.trainable_variables)
        self.actor.optimizer.apply_gradients(zip(
            actor_network_gradient, self.actor.trainable_variables))

        # critic update toward the scaled-reward TD target; the persistent tape
        # lets us take gradients of both critic losses from a single context
        with tf.GradientTape(persistent=True) as tape:
            q_hat = self.scale*rewards + self.gamma*value_*(1-done)
            q1_old_policy = tf.squeeze(self.critic_1(states, actions), 1)
            q2_old_policy = tf.squeeze(self.critic_2(states, actions), 1)
            critic_1_loss = 0.5 * keras.losses.MSE(q1_old_policy, q_hat)
            critic_2_loss = 0.5 * keras.losses.MSE(q2_old_policy, q_hat)

        critic_1_network_gradient = tape.gradient(critic_1_loss,
                                        self.critic_1.trainable_variables)
        critic_2_network_gradient = tape.gradient(critic_2_loss,
                                        self.critic_2.trainable_variables)

        self.critic_1.optimizer.apply_gradients(zip(
            critic_1_network_gradient, self.critic_1.trainable_variables))
        self.critic_2.optimizer.apply_gradients(zip(
            critic_2_network_gradient, self.critic_2.trainable_variables))

        self.update_network_parameters()
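
A toy numeric sketch (not part of the commit) of what update_network_parameters does to a single weight:

# soft update rule: theta_target <- tau*theta + (1 - tau)*theta_target
theta, theta_target, tau = 1.0, 0.0, 0.005
for _ in range(3):
    theta_target = tau*theta + (1 - tau)*theta_target
print(round(theta_target, 4))  # 0.0149 -- the target drifts slowly toward the online weights
# update_network_parameters(tau=1) in __init__ makes the target an exact copy instead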
10 changes: 10 additions & 0 deletions
ReinforcementLearning/PolicyGradient/SAC/tf2/utils.py
@@ -0,0 +1,10 @@
import numpy as np
import matplotlib.pyplot as plt


def plot_learning_curve(x, scores, figure_file):
    # running average over the previous 100 scores (fewer at the start)
    running_avg = np.zeros(len(scores))
    for i in range(len(running_avg)):
        running_avg[i] = np.mean(scores[max(0, i-100):(i+1)])
    plt.plot(x, running_avg)
    plt.title('Running average of previous 100 scores')
    plt.savefig(figure_file)
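
A small usage example (not part of the commit); the scores and output filename below are made up for illustration:

import numpy as np
from utils import plot_learning_curve

scores = list(np.random.uniform(low=0, high=100, size=250))  # hypothetical scores
x = [i + 1 for i in range(len(scores))]
plot_learning_curve(x, scores, 'plots/example_curve.png')    # assumes plots/ exists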