Commit 0834cb7: init backup
lilianweng committed Apr 30, 2018 (0 parents)

Showing 21 changed files with 1,422 additions and 0 deletions.
9 changes: 9 additions & 0 deletions .gitignore
@@ -0,0 +1,9 @@
.DS_Store
.idea
.pytest_cache
checkpoints/*
logs/*
figs/*
**/*.png
**/*.pyc
**/*.egg-info
1 change: 1 addition & 0 deletions README.md
@@ -0,0 +1 @@
WIP
Binary file added playground/figs/cartpole-dqn-graph.png
Binary file added playground/figs/cartpole-dqn-hard-monitor.png
Binary file added playground/figs/cartpole-dqn-hard.png
109 changes: 109 additions & 0 deletions playground/learn.py
@@ -0,0 +1,109 @@
import gym
import numpy as np
from gym.wrappers import Monitor

from playground.policies import QlearningPolicy, DqnPolicy, ReinforcePolicy
from playground.utils.misc import plot_from_monitor_results
from playground.utils.wrappers import DigitizedObservationWrapper


def mountain_car_qlearning(model_name='mountain-car-qlearning'):
    """MountainCar-v0: the reward is -1 on every step; only when the episode
    is done can the reward differ.
    """
    env = gym.make('MountainCar-v0')
    env = DigitizedObservationWrapper(env, n_bins=10)
    print(env.action_space, env.observation_space)

    policy = QlearningPolicy(env, model_name, alpha=0.5, gamma=0.99, epsilon=0.1,
                             epsilon_decay=0.98, alpha_decay=0.97)
    policy.build()
    policy.train(100000, every_step=1000, with_monitor=True)

    env.close()
    plot_from_monitor_results('/tmp/' + model_name)


def cartpole_qlearning(model_name='cartpole-qlearning'):
    """The raw CartPole observation bounds cannot be digitized directly,
    because the two velocity components are effectively unbounded:

    In [22]: env.observation_space.low, env.observation_space.high
    Out[22]:
    (array([-4.8000002e+00, -3.4028235e+38, -4.1887903e-01, -3.4028235e+38], dtype=float32),
     array([4.8000002e+00, 3.4028235e+38, 4.1887903e-01, 3.4028235e+38], dtype=float32))

    Hence finite `low` and `high` bounds are passed to the wrapper below.
    """
    env = gym.make('CartPole-v1')
    env = DigitizedObservationWrapper(
        env, n_bins=10,
        low=np.array([-2.4, -2., -0.42, -3.5]),
        high=np.array([2.4, 2., 0.42, 3.5]),
    )
    env = Monitor(env, '/tmp/' + model_name, force=True)

    policy = QlearningPolicy(env, model_name,
                             alpha=0.5, gamma=0.9,
                             epsilon=0.1, epsilon_decay=0.98)
    policy.build()
    policy.train(100000, every_step=5000, done_reward=-100, with_monitor=True)

    env.close()
    plot_from_monitor_results('/tmp/' + model_name)


def cartpole_dqn(model_name='cartpole-dqn'):
    env = gym.make('CartPole-v1')
    env = Monitor(env, '/tmp/' + model_name, force=True)

    policy = DqnPolicy(env, model_name,
                       lr=0.001, epsilon=1.0, epsilon_final=0.02, batch_size=32,
                       # q_model_type='rnn', q_model_params={'step_size': 16, 'lstm_size': 32},
                       q_model_type='cnn', q_model_params={'layer_sizes': [64]},
                       target_update_type='hard')
    policy.build()
    policy.train(500, annealing_episodes=450, every_episode=5)

    env.close()
    plot_from_monitor_results('/tmp/' + model_name)


def cartpole_reinforce(model_name='cartpole-reinforce'):
    env = gym.make('CartPole-v1')
    env = Monitor(env, '/tmp/' + model_name, force=True)

    policy = ReinforcePolicy(env, model_name, lr=0.002, lr_decay=0.999,
                             batch_size=32, layer_sizes=[32, 32])
    policy.build()
    policy.train(750, every_episode=10)

    env.close()
    plot_from_monitor_results('/tmp/' + model_name)


def test_cartpole_dqn(model_name='cartpole-dqn'):
    env = gym.make('CartPole-v1')
    policy = DqnPolicy(env, model_name, training=False,
                       lr=0.001, epsilon=1.0, epsilon_final=0.02, batch_size=32,
                       q_model_type='mlp', q_model_params={'layer_sizes': [32, 32]},
                       target_update_type='hard')
    policy.build()
    assert policy.load_model(), "Failed to load a trained model."
    policy.test(10)


def pacman_dqn(model_name='pacman-dqn'):
    env = gym.make('MsPacman-v0')
    env = Monitor(env, '/tmp/' + model_name, force=True)

    policy = DqnPolicy(env, model_name,
                       lr=0.001, epsilon=1.0, epsilon_final=0.02, batch_size=32,
                       q_model_type='cnn',
                       target_update_type='hard', target_update_params={'every_step': 500})
    policy.build()
    policy.train(1000, annealing_episodes=900, every_episode=10)

    env.close()
    plot_from_monitor_results('/tmp/' + model_name)


if __name__ == '__main__':
    # cartpole_dqn('cartpole-dqn-hard-rnn')
    cartpole_reinforce('cartpole-reinforce')
    # pacman_dqn('pacman-dqn-hard')
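
Note: DigitizedObservationWrapper is imported from playground.utils.wrappers, whose diff is not shown in this excerpt. As a rough sketch of what such a wrapper could look like, assuming it cuts each continuous observation dimension into n_bins intervals with np.digitize and combines the bin indices into one discrete state (an assumption based only on the class name and the n_bins/low/high arguments used above, not the repo's actual implementation):

import gym
import numpy as np
from gym.spaces import Discrete


class DigitizedObservationWrapper(gym.ObservationWrapper):
    """Sketch: bin a continuous observation into a single discrete state."""

    def __init__(self, env, n_bins=10, low=None, high=None):
        super().__init__(env)
        low = self.observation_space.low if low is None else low
        high = self.observation_space.high if high is None else high
        self.n_bins = n_bins
        # n_bins - 1 interior cut points per observation dimension.
        self.val_bins = [np.linspace(l, h, n_bins + 1)[1:-1]
                         for l, h in zip(low.flatten(), high.flatten())]
        self.observation_space = Discrete(n_bins ** len(low))

    def observation(self, obs):
        # Map each dimension to a bin index, then combine the indices
        # into one integer in base `n_bins`.
        digits = [np.digitize(x, bins)
                  for x, bins in zip(obs.flatten(), self.val_bins)]
        state = 0
        for d in digits:
            state = state * self.n_bins + int(d)
        return state
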
3 changes: 3 additions & 0 deletions playground/policies/__init__.py
@@ -0,0 +1,3 @@
from playground.policies.dqn import DqnPolicy
from playground.policies.qlearning import QlearningPolicy
from playground.policies.reinforce import ReinforcePolicy
123 changes: 123 additions & 0 deletions playground/policies/base.py
@@ -0,0 +1,123 @@
import os
import time

import numpy as np
import tensorflow as tf
from gym.utils import colorize

from playground.utils.misc import REPO_ROOT


class Policy:
    def __init__(self, env, name, training=True, gamma=0.99):
        self.env = env
        self.gamma = gamma
        self.training = training
        self.name = name

        np.random.seed(int(time.time()))

    def act(self, state, **kwargs):
        pass

    def build(self):
        pass

    def train(self, *args, **kwargs):
        pass

    def evaluate(self, n_episodes):
        reward_history = []
        reward = 0.

        for i in range(n_episodes):
            ob = self.env.reset()
            done = False
            while not done:
                a = self.act(ob)
                new_ob, r, done, _ = self.env.step(a)
                self.env.render()
                reward += r
                ob = new_ob

            reward_history.append(reward)
            reward = 0.

        print("Avg. reward over {} episodes: {:.4f}".format(n_episodes, np.mean(reward_history)))


class BaseTFModelMixin(object):
    """Mixin providing TensorFlow session, saver, and summary-writer plumbing.
    Code borrowed from: https://github.com/devsisters/DQN-tensorflow/blob/master/dqn/base.py
    with some modifications.
    """

    def __init__(self, model_name, saver_max_to_keep=5):
        self._saver = None
        self._saver_max_to_keep = saver_max_to_keep
        self._writer = None
        self._model_name = model_name
        self._sess = None

    def save_model(self, step=None):
        print(colorize(" [*] Saving checkpoints...", "green"))
        ckpt_file = os.path.join(self.checkpoint_dir, self.model_name)
        self.saver.save(self.sess, ckpt_file, global_step=step)

    def load_model(self):
        print(colorize(" [*] Loading checkpoints...", "green"))

        ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            fname = os.path.join(self.checkpoint_dir, ckpt_name)
            self.saver.restore(self.sess, fname)
            print(colorize(" [*] Load SUCCESS: %s" % fname, "green"))
            return True
        else:
            print(colorize(" [!] Load FAILED: %s" % self.checkpoint_dir, "red"))
            return False

    @property
    def checkpoint_dir(self):
        ckpt_path = os.path.join(REPO_ROOT, 'checkpoints', self.model_name)
        os.makedirs(ckpt_path, exist_ok=True)
        return ckpt_path

    @property
    def model_name(self):
        assert self._model_name, "Not a valid model name."
        return self._model_name

    @property
    def saver(self):
        if self._saver is None:
            self._saver = tf.train.Saver(max_to_keep=self._saver_max_to_keep)
        return self._saver

    @property
    def writer(self):
        if self._writer is None:
            writer_path = os.path.join(REPO_ROOT, "logs", self.model_name)
            os.makedirs(writer_path, exist_ok=True)
            self._writer = tf.summary.FileWriter(writer_path, self.sess.graph)
        return self._writer

    @property
    def sess(self):
        if self._sess is None:
            config = tf.ConfigProto()
            config.intra_op_parallelism_threads = 2
            config.inter_op_parallelism_threads = 2
            self._sess = tf.Session(config=config)
        return self._sess
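
Note: the concrete policies (DqnPolicy, QlearningPolicy, ReinforcePolicy) presumably subclass both Policy and BaseTFModelMixin; their files are truncated from this excerpt. A toy sketch, not the repo's code, of how the two bases compose, using only the hooks defined above (act/build/train from Policy; sess, saver, and save_model from the mixin). The class name ConstantPolicy and its single global-step variable are invented for illustration:

import tensorflow as tf

from playground.policies.base import Policy, BaseTFModelMixin


class ConstantPolicy(Policy, BaseTFModelMixin):
    """Toy subclass: always picks action 0; shows the base-class plumbing."""

    def __init__(self, env, name, training=True, gamma=0.99):
        Policy.__init__(self, env, name, training=training, gamma=gamma)
        BaseTFModelMixin.__init__(self, model_name=name)

    def build(self):
        # A real policy would define its networks here; a single variable
        # is enough to give the lazily created tf.train.Saver something
        # to checkpoint.
        self.step = tf.Variable(0, name='global_step', trainable=False)
        self.sess.run(tf.global_variables_initializer())

    def act(self, state, **kwargs):
        return 0

    def train(self, n_episodes, **kwargs):
        for _ in range(n_episodes):
            self.sess.run(self.step.assign_add(1))
        # Checkpoints land in <REPO_ROOT>/checkpoints/<model_name>/.
        self.save_model(step=self.sess.run(self.step))
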
Empty file added playground/policies/ddpg.py