-
Notifications
You must be signed in to change notification settings - Fork 3
/
demo.py
60 lines (43 loc) · 1.35 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import toml
from argparse import ArgumentParser
from os.path import join
from games.carracing import RacingNet, CarRacing
from ppo import PPO
CONFIG_FILE = "config.toml"
def load_config() -> dict:
    """Read the TOML file named by ``CONFIG_FILE`` and return its contents.

    Returns:
        The dictionary produced by ``toml.load`` for the parsed file.

    Raises:
        OSError: If ``CONFIG_FILE`` cannot be opened (e.g. it is missing).
    """
    # TOML documents are defined to be UTF-8, so decode explicitly instead of
    # relying on the platform's default encoding.
    with open(CONFIG_FILE, "r", encoding="utf-8") as f:
        return toml.load(f)
def parse_args(argv=None):
    """Parse the demo's command-line options.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is the
            original behavior; passing a list allows programmatic use.

    Returns:
        An ``argparse.Namespace`` with three attributes:
        ``ckpt`` (str, default ``"ckpt"``), ``num_steps`` (int, default
        100000) and ``delay_ms`` (int, default 10).
    """
    parser = ArgumentParser(description="Run a saved PPO agent in CarRacing.")
    parser.add_argument(
        "--ckpt",
        type=str,
        default="ckpt",
        help="checkpoint handed to PPO.load (default: %(default)s)",
    )
    parser.add_argument(
        "--num_steps",
        type=int,
        default=100_000,
        help="number of collect_trajectory calls to run (default: %(default)s)",
    )
    parser.add_argument(
        "--delay_ms",
        type=int,
        default=10,
        help="delay_ms value forwarded to collect_trajectory (default: %(default)s)",
    )
    return parser.parse_args(argv)
def main():
    """Load a saved PPO agent and run it in the CarRacing environment.

    Hyperparameters are read from the config file via ``load_config()``; the
    checkpoint name, the number of rollout calls and the ``delay_ms`` value
    come from the command line via ``parse_args()``.

    The environment is closed on every exit path, including when model
    construction, checkpoint loading or the rollout loop raises (for example
    on Ctrl-C).

    Raises:
        KeyError: If a required hyperparameter is missing from the config.
    """
    cfg = load_config()
    args = parse_args()

    env = CarRacing(frame_skip=0, frame_stack=4)
    try:
        net = RacingNet(env.observation_space.shape, env.action_space.shape)

        # PPO takes the full set of training hyperparameters even though this
        # script only replays a checkpoint; they are all read from the config.
        ppo = PPO(
            env,
            net,
            lr=cfg["lr"],
            gamma=cfg["gamma"],
            batch_size=cfg["batch_size"],
            gae_lambda=cfg["gae_lambda"],
            clip=cfg["clip"],
            value_coef=cfg["value_coef"],
            entropy_coef=cfg["entropy_coef"],
            epochs_per_step=cfg["epochs_per_step"],
            num_steps=cfg["num_steps"],
            horizon=cfg["horizon"],
            save_dir=cfg["save_dir"],
            save_interval=cfg["save_interval"],
        )
        ppo.load(args.ckpt)

        # NOTE(review): the positional ``1`` is presumably the number of
        # trajectories/steps to collect per call; confirm against
        # PPO.collect_trajectory.
        for _ in range(args.num_steps):
            ppo.collect_trajectory(1, delay_ms=args.delay_ms)
    finally:
        # Always release the environment, even on error or interrupt.
        env.close()
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()