Commit 0834cb7 (initial commit, no parents) — 21 changed files with 1,422 additions and 0 deletions.
New file (evidently the repository's .gitignore):

@@ -0,0 +1,9 @@
.DS_Store
.idea
.pytest_cache
checkpoints/*
logs/*
figs/*
**/*.png
**/*.pyc
**/*.egg-info
New file (likely the README placeholder):

@@ -0,0 +1 @@
WIP
(Three binary files in this commit could not be rendered in the diff view.)
New file (apparently the top-level experiments script):

@@ -0,0 +1,109 @@
import gym
import numpy as np
from gym.wrappers import Monitor

from playground.policies import QlearningPolicy, DqnPolicy, ReinforcePolicy
from playground.utils.misc import plot_from_monitor_results
from playground.utils.wrappers import DigitizedObservationWrapper


def mountain_car_qlearning(model_name='mountain-car-qlearning'):
    """Mountain car: the reward is -1 on every step; only when the episode
    is done does the reward differ."""
    env = gym.make('MountainCar-v0')
    env = DigitizedObservationWrapper(env, n_bins=10)
    print(env.action_space, env.observation_space)

    policy = QlearningPolicy(env, model_name, alpha=0.5, gamma=0.99, epsilon=0.1,
                             epsilon_decay=0.98, alpha_decay=0.97)
    policy.build()
    policy.train(100000, every_step=1000, with_monitor=True)

    env.close()
    plot_from_monitor_results('/tmp/' + model_name)


def cartpole_qlearning(model_name='cartpole-qlearning'):
    """The raw observation bounds are too wide to digitize directly:

    In [22]: env.observation_space.low, env.observation_space.high
    Out[22]:
    (array([-4.8000002e+00, -3.4028235e+38, -4.1887903e-01, -3.4028235e+38], dtype=float32),
     array([4.8000002e+00, 3.4028235e+38, 4.1887903e-01, 3.4028235e+38], dtype=float32))

    The velocity dimensions are effectively unbounded, so explicit low/high
    clipping ranges are passed to the digitizer below.
    """
    env = gym.make('CartPole-v1')
    env = DigitizedObservationWrapper(
        env, n_bins=10,
        low=np.array([-2.4, -2., -0.42, -3.5]),
        high=np.array([2.4, 2., 0.42, 3.5]),
    )
    env = Monitor(env, '/tmp/' + model_name, force=True)

    policy = QlearningPolicy(env, model_name,
                             alpha=0.5, gamma=0.9,
                             epsilon=0.1, epsilon_decay=0.98)
    policy.build()
    policy.train(100000, every_step=5000, done_reward=-100, with_monitor=True)

    env.close()
    plot_from_monitor_results('/tmp/' + model_name)


def cartpole_dqn(model_name='cartpole-dqn'):
    env = gym.make('CartPole-v1')
    env = Monitor(env, '/tmp/' + model_name, force=True)

    policy = DqnPolicy(env, model_name,
                       lr=0.001, epsilon=1.0, epsilon_final=0.02, batch_size=32,
                       # q_model_type='rnn', q_model_params={'step_size': 16, 'lstm_size': 32},
                       q_model_type='cnn', q_model_params={'layer_sizes': [64]},
                       target_update_type='hard')
    policy.build()
    policy.train(500, annealing_episodes=450, every_episode=5)

    env.close()
    plot_from_monitor_results('/tmp/' + model_name)


def cartpole_reinforce(model_name='cartpole-reinforce'):
    env = gym.make('CartPole-v1')
    env = Monitor(env, '/tmp/' + model_name, force=True)

    policy = ReinforcePolicy(env, model_name, lr=0.002, lr_decay=0.999,
                             batch_size=32, layer_sizes=[32, 32])
    policy.build()
    policy.train(750, every_episode=10)

    env.close()
    plot_from_monitor_results('/tmp/' + model_name)


def test_cartpole_dqn(model_name='cartpole-dqn'):
    env = gym.make('CartPole-v1')
    policy = DqnPolicy(env, model_name, training=False,
                       lr=0.001, epsilon=1.0, epsilon_final=0.02, batch_size=32,
                       q_model_type='mlp', q_model_params={'layer_sizes': [32, 32]},
                       target_update_type='hard')
    policy.build()
    assert policy.load_model(), "Failed to load a trained model."
    policy.test(10)


def pacman_dqn(model_name='pacman-dqn'):
    env = gym.make('MsPacman-v0')
    env = Monitor(env, '/tmp/' + model_name, force=True)

    policy = DqnPolicy(env, model_name,
                       lr=0.001, epsilon=1.0, epsilon_final=0.02, batch_size=32,
                       q_model_type='cnn',
                       target_update_type='hard', target_update_params={'every_step': 500})
    policy.build()
    policy.train(1000, annealing_episodes=900, every_episode=10)

    env.close()
    plot_from_monitor_results('/tmp/' + model_name)


if __name__ == '__main__':
    # cartpole_dqn('cartpole-dqn-hard-rnn')
    cartpole_reinforce('cartpole-reinforce')
    # pacman_dqn('pacman-dqn-hard')
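The DigitizedObservationWrapper imported above lives in playground/utils/wrappers.py, which is not shown in this excerpt. The idea is to clip and discretize each dimension of a continuous Box observation into n_bins buckets, then encode the per-dimension bucket indices as a single integer so tabular Q-learning can index a finite Q-table. Below is a minimal sketch of such a wrapper, assuming the old-style gym API and a constructor matching the calls above; the index encoding is a guess for illustration, not the repository's actual implementation.

import gym
import numpy as np
from gym.spaces import Discrete


class DigitizedObservationWrapper(gym.ObservationWrapper):
    """Hypothetical sketch: map a Box observation to one discrete index."""

    def __init__(self, env, n_bins=10, low=None, high=None):
        super().__init__(env)
        assert isinstance(env.observation_space, gym.spaces.Box)
        low = env.observation_space.low if low is None else low
        high = env.observation_space.high if high is None else high
        self.n_bins = n_bins
        # One array of interior bin edges per observation dimension.
        self.val_bins = [np.linspace(l, h, n_bins + 1)[1:-1]
                         for l, h in zip(low.flatten(), high.flatten())]
        self.observation_space = Discrete(n_bins ** len(self.val_bins))

    def observation(self, observation):
        # Digitize each dimension, then combine the per-dimension bin
        # indices into one integer, treating them as base-n_bins digits.
        digits = [np.digitize(x, bins)
                  for x, bins in zip(observation.flatten(), self.val_bins)]
        return int(sum(d * (self.n_bins ** i) for i, d in enumerate(digits)))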
New file (evidently playground/policies/__init__.py, re-exporting the policy classes):

@@ -0,0 +1,3 @@
from playground.policies.dqn import DqnPolicy
from playground.policies.qlearning import QlearningPolicy
from playground.policies.reinforce import ReinforcePolicy
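These re-exports give the package a flat import surface, which is exactly what the experiments script above relies on:

from playground.policies import QlearningPolicy, DqnPolicy, ReinforcePolicy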
Empty file.
@@ -0,0 +1,123 @@ | ||
import os | ||
import time | ||
|
||
import numpy as np | ||
import tensorflow as tf | ||
from gym.utils import colorize | ||
|
||
from playground.utils.misc import REPO_ROOT | ||
|
||
|
||
class Policy: | ||
def __init__(self, env, name, training=True, gamma=0.99): | ||
self.env = env | ||
self.gamma = gamma | ||
self.training = training | ||
self.name = name | ||
|
||
np.random.seed(int(time.time())) | ||
|
||
def act(self, state, **kwargs): | ||
pass | ||
|
||
def build(self): | ||
pass | ||
|
||
def train(self, *args, **kwargs): | ||
pass | ||
|
||
def evaluate(self, n_episodes): | ||
reward_history = [] | ||
reward = 0. | ||
|
||
for i in range(n_episodes): | ||
ob = self.env.reset() | ||
done = False | ||
while not done: | ||
a = self.act(ob) | ||
new_ob, r, done, _ = self.env.step(a) | ||
self.env.render() | ||
reward += r | ||
ob = new_ob | ||
|
||
reward_history.append(reward) | ||
reward = 0. | ||
|
||
print("Avg. reward over {} episodes: {:.4f}".format(n_episodes, np.mean(reward_history))) | ||
|
||
|
||
class BaseTFModelMixin(object): | ||
"""Abstract object representing an Reader model. | ||
Code borrowed from: https://github.com/devsisters/DQN-tensorflow/blob/master/dqn/base.py | ||
with some modifications. | ||
""" | ||
|
||
def __init__(self, model_name, saver_max_to_keep=5): | ||
self._saver = None | ||
self._saver_max_to_keep = saver_max_to_keep | ||
self._writer = None | ||
self._model_name = model_name | ||
self._sess = None | ||
|
||
# for attr in self._attrs: | ||
# name = attr if not attr.startswith('_') else attr[1:] | ||
# setattr(self, name, getattr(self.config, attr)) | ||
|
||
def save_model(self, step=None): | ||
print(colorize(" [*] Saving checkpoints...", "green")) | ||
ckpt_file = os.path.join(self.checkpoint_dir, self.model_name) | ||
self.saver.save(self.sess, ckpt_file, global_step=step) | ||
|
||
def load_model(self): | ||
print(colorize(" [*] Loading checkpoints...", "green")) | ||
|
||
ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir) | ||
print(self.checkpoint_dir, ckpt) | ||
if ckpt and ckpt.model_checkpoint_path: | ||
ckpt_name = os.path.basename(ckpt.model_checkpoint_path) | ||
print(ckpt_name) | ||
fname = os.path.join(self.checkpoint_dir, ckpt_name) | ||
print(fname) | ||
self.saver.restore(self.sess, fname) | ||
print(colorize(" [*] Load SUCCESS: %s" % fname, "green")) | ||
return True | ||
else: | ||
print(colorize(" [!] Load FAILED: %s" % self.checkpoint_dir, "red")) | ||
return False | ||
|
||
@property | ||
def checkpoint_dir(self): | ||
ckpt_path = os.path.join(REPO_ROOT, 'checkpoints', self.model_name) | ||
os.makedirs(ckpt_path, exist_ok=True) | ||
return ckpt_path | ||
|
||
@property | ||
def model_name(self): | ||
assert self._model_name, "Not a valid model name." | ||
return self._model_name | ||
|
||
@property | ||
def saver(self): | ||
if self._saver is None: | ||
self._saver = tf.train.Saver(max_to_keep=self._saver_max_to_keep) | ||
return self._saver | ||
|
||
@property | ||
def writer(self): | ||
if self._writer is None: | ||
writer_path = os.path.join(REPO_ROOT, "logs", self.model_name) | ||
os.makedirs(writer_path, exist_ok=True) | ||
self._writer = tf.summary.FileWriter(writer_path, self.sess.graph) | ||
return self._writer | ||
|
||
@property | ||
def sess(self): | ||
if self._sess is None: | ||
config = tf.ConfigProto() | ||
|
||
config.intra_op_parallelism_threads = 2 | ||
config.inter_op_parallelism_threads = 2 | ||
self._sess = tf.Session(config=config) | ||
|
||
return self._sess |
Empty file.
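Since act, build, and train are no-op hooks on Policy, the evaluate loop above is its only concrete behavior, and the smallest useful subclass just needs to implement act. Here is a hedged usage sketch; RandomPolicy is illustrative only and not part of the commit.

import gym

from playground.policies.base import Policy


class RandomPolicy(Policy):
    """Illustrative subclass: sample a uniformly random action."""

    def act(self, state, **kwargs):
        return self.env.action_space.sample()


if __name__ == '__main__':
    env = gym.make('CartPole-v1')
    policy = RandomPolicy(env, 'random-baseline', training=False)
    policy.build()      # no-op here, kept for interface symmetry
    policy.evaluate(5)  # renders 5 episodes, prints the average reward
    env.close()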