crawler.py
import os
from absl import app
from absl import logging
from absl import flags
from unityagents import UnityEnvironment
from agents.ppo import PPOAgent
config = flags.FLAGS
flags.DEFINE_string(name='env', default='../../deep-reinforcement-learning/p2_continuous-control/Crawler_Windows_x86_64/Crawler.exe',
                    help='Unity Environment to load')
flags.DEFINE_boolean(name='render', default=False, help='execute Unity Environment with display')
flags.DEFINE_string(name='load', default=None,
                    help='model file to load with path')
flags.DEFINE_bool(name='play', default=None,
                  help='play environment with model')
flags.DEFINE_bool(name='train', default=None,
                  help='train the agent')
flags.DEFINE_integer(name='episodes', default=20,
                     help='number of episodes to run')
flags.DEFINE_float(name='gamma', default=0.995,
                   help='discount factor for future rewards (0,1]')
flags.DEFINE_integer(name='trajectories', default=2048,
                     help='number of trajectories to sample per iteration')
flags.DEFINE_integer(name='policy_optimization_epochs', default=160,
                     help='number of epochs to run (K in paper)')
flags.DEFINE_float(name='policy_stopping_kl', default=0.3,
                   help='KL divergence threshold to early-stop PPO policy improvements')
flags.DEFINE_float(name='policy_clip_range', default=0.2,
                   help='clipping threshold for PPO policy optimization')
flags.DEFINE_float(name='gae_lambda', default=0.85,
                   help='lambda coefficient for generalized advantage estimate')
flags.DEFINE_float(name='entropy_beta', default=0.002,
                   help='coefficient to multiply entropy loss in PPO step')
flags.DEFINE_float(name='vf_coeff', default=0.05,
                   help='coefficient to multiply value loss in PPO step')
flags.DEFINE_integer(name='memory_batch_size', default=512,
                     help='batch size of memory samples per epoch')
flags.DEFINE_bool(name='tb', default=True,
                  help='enable tensorboard logging')
flags.DEFINE_string(name='device', default='cpu',
                    help='Device to use for torch')
# NOTE: main() reads config.training_iterations, but no such flag was defined in the
# original file; the definition below is an assumption (default value included) so the
# script runs end to end.
flags.DEFINE_integer(name='training_iterations', default=100,
                     help='number of training iterations to run')
def main(argv):
    del argv
    # Route absl logs to a file when --log_dir is given (the log_dir flag is provided by absl.logging).
    if config.log_dir != '':
        if not os.path.exists(config.log_dir):
            os.makedirs(config.log_dir)
        logging.get_absl_handler().use_absl_log_file()
    logging.set_verbosity('info')
    # NOTE: the original passed no_graphics=config.render, which disables graphics exactly
    # when --render is requested; negating it here is assumed to be the intent.
    env = UnityEnvironment(file_name=config.env, worker_id=2, no_graphics=not config.render)
    model = PPOAgent(device=config.device, env=env)
    if config.load is not None:
        model.load_model(load_model=config.load)
    if config.play is not None:
        model.play(episodes=config.episodes)
    if config.train is not None:
        model.train(
            training_iterations=config.training_iterations,
            log_dir=config.log_dir,
            render=config.render)
    env.close()
if __name__ == '__main__':
app.run(main)
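
A possible invocation, assuming the default Crawler build path set in the --env flag above; the flag values and the checkpoint path shown are examples only:

python crawler.py --train --device=cpu --tb --log_dir=./logs
python crawler.py --play --load=checkpoints/ppo_crawler.pth --episodes=20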
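
For context, a generic sketch of how several of these hyperparameters usually enter PPO training: gamma and gae_lambda in generalized advantage estimation, and policy_clip_range, entropy_beta, and vf_coeff in the clipped surrogate objective. This is not the PPOAgent implementation from agents/ppo.py; the function and variable names below are illustrative only.

import torch

def generalized_advantage_estimate(rewards, values, dones, gamma=0.995, gae_lambda=0.85):
    """GAE for one trajectory: rewards/dones have length T, values has T+1 entries."""
    advantages = [0.0] * len(rewards)
    gae = 0.0
    for t in reversed(range(len(rewards))):
        not_done = 1.0 - float(dones[t])
        # TD error: r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = rewards[t] + gamma * values[t + 1] * not_done - values[t]
        # Exponentially weighted sum of TD errors, discounted by gamma * lambda
        gae = delta + gamma * gae_lambda * not_done * gae
        advantages[t] = gae
    return advantages

def ppo_loss(new_log_probs, old_log_probs, advantages, returns, values, entropy,
             clip_range=0.2, entropy_beta=0.002, vf_coeff=0.05):
    """Clipped-surrogate PPO loss with an entropy bonus and a value-function term."""
    ratio = torch.exp(new_log_probs - old_log_probs)
    clipped = torch.clamp(ratio, 1.0 - clip_range, 1.0 + clip_range)
    policy_loss = -torch.min(ratio * advantages, clipped * advantages).mean()
    value_loss = (returns - values).pow(2).mean()
    return policy_loss + vf_coeff * value_loss - entropy_beta * entropy.mean()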