train_not_working_yet.py
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.monitor import Monitor
import os
import shadow_gym
import gymnasium
"""
Created by Ethan Cheam
Much more advanced training code. Still in development.
"""
# SETTINGS
# Use RecurrentPPO (True) or plain PPO (False)
recurrent = False
vectorized_env = False
normalized_env = False
start_from_existing = False
existing_model_file = ""  # no .zip extension needed
# Run name should include the model name, a unique number, and optionally a short description
run_name = "PPO" + "-" + "17" + "-" + "shadowgym"
saving_timesteps_interval = 500_000
start_saving = 1_000_000
# Set up folders to store models and logs
models_dir = os.path.join(os.path.dirname(__file__), 'models')
logs_dir = os.path.join(os.path.dirname(__file__), 'logs')
normalize_stats = os.path.join(os.path.dirname(__file__), 'normalize_stats')
if not start_from_existing and os.path.exists(f"{models_dir}/{run_name}"):
    raise Exception("Error: model folder already exists. Change run_name to prevent overriding existing model folder")
if not start_from_existing and os.path.exists(f"{logs_dir}/{run_name}"):
    raise Exception("Error: log folder already exists. Change run_name to prevent overriding existing log folder")
if not start_from_existing and os.path.exists(f"{normalize_stats}/{run_name}"):
    raise Exception("Error: normalize_stats folder already exists. Change run_name")
if normalized_env:
    os.makedirs(f"{normalize_stats}/{run_name}")


class TensorboardCallback(BaseCallback):
    """Logs episode length and per-step mean reward to TensorBoard for the vectorized env (which has no Monitor wrapper)."""

    def __init__(self, verbose=1):
        super(TensorboardCallback, self).__init__(verbose)
        self.episode_length = 0
        self.episode_reward = 0

    def _on_rollout_end(self) -> None:
        # Error: _on_rollout_end and episodes are different time periods.
        self.episode_length = self.training_env.get_attr("num_steps")[0]
        self.logger.record("rollout/ep_len_mean", self.episode_length)
        self.logger.record("rollout/ep_rew_mean", self.episode_reward / self.episode_length)
        # Reset vars
        self.episode_length = 0
        self.episode_reward = 0

    def _on_step(self) -> bool:
        self.episode_reward += self.training_env.get_attr("reward")[0]
        return True
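

# Note: the callback above logs at rollout boundaries rather than true episode boundaries
# (see the error comment). One possible alternative, sketched below and not part of the
# original script, is to accumulate reward in _on_step and log whenever an episode ends,
# using the "rewards"/"dones" values that SB3 exposes to callbacks through self.locals
# during rollout collection. The class name and logged keys are illustrative only.
class EpisodeTensorboardCallback(BaseCallback):
    """Untested sketch: logs length and total reward of each finished episode of sub-env 0."""

    def __init__(self, verbose=1):
        super().__init__(verbose)
        self.episode_length = 0
        self.episode_reward = 0.0

    def _on_step(self) -> bool:
        # self.locals holds the local variables of collect_rollouts(), including
        # the per-sub-env rewards and done flags for the current step.
        self.episode_reward += float(self.locals["rewards"][0])
        self.episode_length += 1
        if self.locals["dones"][0]:
            self.logger.record("episode/length", self.episode_length)
            self.logger.record("episode/reward", self.episode_reward)
            self.episode_length = 0
            self.episode_reward = 0.0
        return True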


rewards_callback = None
if vectorized_env:
    env = DummyVecEnv([lambda: gymnasium.make("ShadowEnv-v0", GUI=False)])
    if normalized_env:
        env = VecNormalize(env)
    # No Monitor wrapper on the VecEnv, so use the custom callback for episode logging
    rewards_callback = TensorboardCallback()
else:
    env = gymnasium.make("ShadowEnv-v0", GUI=False)
    # Monitor provides rollout/ep_len_mean and rollout/ep_rew_mean in TensorBoard
    env = Monitor(env)
full_model_path = None
if start_from_existing:
    full_model_path = os.path.join(models_dir, run_name, existing_model_file)
if recurrent:
    from sb3_contrib import RecurrentPPO
    if start_from_existing:
        model = RecurrentPPO.load(full_model_path, env)
    else:
        model = RecurrentPPO(policy="MlpLstmPolicy", env=env, tensorboard_log=logs_dir, verbose=1)
else:
    if start_from_existing:
        model = PPO.load(full_model_path, env)
    else:
        model = PPO(policy="MlpPolicy", env=env, tensorboard_log=logs_dir, verbose=1)
timesteps = 0
while True:
    model.learn(saving_timesteps_interval, callback=rewards_callback, tb_log_name=run_name, reset_num_timesteps=False)
    timesteps += saving_timesteps_interval
    if timesteps >= start_saving:
        model.save(f"{models_dir}/{run_name}/{timesteps}")
        if vectorized_env and normalized_env:
            # Save VecNormalize running statistics alongside the model checkpoint
            normalize_stats_path = os.path.join(normalize_stats, run_name, str(timesteps) + '.pkl')
            env.save(normalize_stats_path)
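
# For reference, a rough (untested) sketch of how a saved checkpoint and its VecNormalize
# statistics might be reloaded later for evaluation. Not part of the original script; the
# checkpoint timestep 1500000 is only an example value:
#
#   eval_env = DummyVecEnv([lambda: gymnasium.make("ShadowEnv-v0", GUI=True)])
#   eval_env = VecNormalize.load(f"{normalize_stats}/{run_name}/1500000.pkl", eval_env)
#   eval_env.training = False      # do not update running statistics during evaluation
#   eval_env.norm_reward = False   # report raw (unnormalized) rewards
#   eval_model = PPO.load(f"{models_dir}/{run_name}/1500000", env=eval_env)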