-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathrun_experiments.py
98 lines (77 loc) · 3.25 KB
/
run_experiments.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
"""
CLI app that takes a given environment and RL algorithm and:
- 1. trains the RL algorithm on the environment (trajectories discarded).
- 2. upon convergence, runs the RL algorithm in the environment and logs
the resulting trajectories.
"""
import argparse
from datetime import datetime
import logging.config
import os
import pickle
import subprocess
import ray
from pirl import config, experiments
# Module-level logger; handlers/levels are installed later by node_setup()
# via logging.config.dictConfig.
logger = logging.getLogger('pirl.experiments.cli')
def _check_in(cats, kind):
def f(s):
if s in cats:
return s
else:
raise argparse.ArgumentTypeError("'{}' is not an {}".format(s, kind))
return f
# argparse `type=` validator that accepts only experiment names declared
# in config.EXPERIMENTS.
experiment_type = _check_in(config.EXPERIMENTS.keys(), 'experiment')
def parse_args():
    """Parse CLI options for the trajectory-logging app.

    Returns the argparse namespace with: ``seed`` (string seed),
    ``num_cpu``/``num_gpu`` (optional resource overrides),
    ``ray_server`` (Redis address, ``None`` for local, or ``"DEBUG"``)
    and ``experiments`` (one or more validated experiment names).
    """
    parser = argparse.ArgumentParser(
        description='Log trajectories from an RL algorithm.')
    parser.add_argument('--seed', metavar='STR', default='foobar', type=str)
    parser.add_argument('--num-cpu', metavar='N', default=None, type=int)
    parser.add_argument('--num-gpu', metavar='N', default=None, type=int)
    parser.add_argument('--ray-server', metavar='HOST', type=str,
                        default=config.RAY_SERVER)
    parser.add_argument('experiments', metavar='experiment', nargs='+',
                        type=experiment_type)
    return parser.parse_args()
def git_hash():
    """Return the current HEAD commit hash of the project repository.

    Runs ``git rev-parse HEAD`` inside ``config.PROJECT_DIR`` so each
    experiment's output directory records the exact code version.
    """
    # Renamed local from `hash` to `commit` — the original shadowed the
    # builtin `hash()`.
    commit = subprocess.check_output(['git', 'rev-parse', 'HEAD'],
                                     cwd=config.PROJECT_DIR)
    return commit.decode().strip()
# We pretend we have more GPUs than autodetected to work around Ray
# issue #402 (over-provisioning lets more tasks share a physical GPU).
# This multiplier is bypassed entirely when --num-gpu is given.
GPU_MULTIPLIER = 4
# strftime format used to timestamp per-experiment output directories.
ISO_TIMESTAMP = "%Y%m%d_%H%M%S"
def node_setup(_cfg):
    """Install the project logging configuration on the current process.

    Called directly on the driver and registered with Ray's
    ``run_function_on_all_workers`` so every worker logs consistently;
    the argument Ray passes is unused (hence ``_cfg``).
    """
    logging.config.dictConfig(config.LOG_CFG)
if __name__ == '__main__':
    # Argument parsing
    args = parse_args()

    # Ray initialization: local cluster, single-process debug mode, or
    # attach to an existing Redis server.
    if args.ray_server is None:  # run locally
        num_gpu = args.num_gpu
        if num_gpu is None:
            # NOTE: private Ray API; inflated by GPU_MULTIPLIER to work
            # around Ray issue #402 (see constant above).
            num_gpu = ray.services._autodetect_num_gpus() * GPU_MULTIPLIER
        ray.init(num_cpus=args.num_cpu, num_gpus=num_gpu,
                 redirect_worker_output=True)
    elif args.ray_server == "DEBUG":  # run in "Python" mode (single process)
        ray.init(driver_mode=ray.worker.PYTHON_MODE)
    else:  # connect to existing server (could still be a single machine)
        ray.init(redis_address=args.ray_server)

    # Setup logging on the driver and, unless single-process, on workers.
    node_setup(None)
    if ray.worker.global_worker.mode != ray.worker.PYTHON_MODE:
        ray.worker.global_worker.run_function_on_all_workers(node_setup)
    logger.info('CLI args: %s', args)

    # Experiment loop
    for experiment in args.experiments:
        # NOTE(review): an earlier comment claimed experiments are
        # "reseeded" so run order does not matter, but the same args.seed
        # string is passed to every experiment here — any per-experiment
        # reseeding must happen inside run_experiment. Confirm.
        timestamp = datetime.now().strftime(ISO_TIMESTAMP)
        version = git_hash()
        # Output directory is name-timestamp-commit, for traceability.
        out_dir = '{}-{}-{}'.format(experiment, timestamp, version)
        path = os.path.join(config.EXPERIMENTS_DIR, out_dir)
        # Deliberately fails if the directory already exists: the
        # timestamped name should be unique per run.
        os.makedirs(path)
        cfg = config.EXPERIMENTS[experiment]
        res = experiments.run_experiment(cfg, path, args.seed)
        logger.info('Experiment %s completed. Outcome:\n %s. Saving to %s.',
                    experiment, res['values'], path)
        # Fix: build the path with os.path.join rather than a hard-coded
        # '/' separator, for portability and consistency with `path` above.
        with open(os.path.join(path, 'results.pkl'), 'wb') as f:
            pickle.dump(res, f)