
Commit

update framework
johnjim0816 committed Nov 30, 2023
1 parent b0ff0f4 commit 8b29b05
Showing 9 changed files with 105 additions and 78 deletions.
10 changes: 8 additions & 2 deletions algos/base/policies.py
@@ -38,11 +38,17 @@ def get_state_action_size(self):
return self.state_size, self.action_size
def create_optimizer(self):
self.optimizer = optim.Adam(self.parameters(), lr=self.cfg.lr)

def get_model_params(self):
model_params = self.state_dict()
return model_params
''' get model params
'''
return self.state_dict()

def put_model_params(self, model_params):
''' put model params
'''
self.load_state_dict(model_params)

def get_optimizer_params(self):
return self.optimizer.state_dict()
def set_optimizer_params(self, optim_params_dict):
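The commit makes `get_model_params`/`put_model_params` a symmetric pair around `state_dict()`/`load_state_dict()`, so parameter hand-off between components becomes a simple round trip. A minimal, self-contained sketch of that round trip (`TinyPolicy` is a hypothetical stand-in for `BasePolicy`, assuming the policy subclasses `torch.nn.Module` as the `state_dict` calls imply):

```python
import torch
import torch.nn as nn

class TinyPolicy(nn.Module):
    ''' Hypothetical stand-in for BasePolicy, just to exercise the param round trip. '''
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def get_model_params(self):
        ''' get model params '''
        return self.state_dict()

    def put_model_params(self, model_params):
        ''' put model params '''
        self.load_state_dict(model_params)

learner, worker = TinyPolicy(), TinyPolicy()
worker.put_model_params(learner.get_model_params())   # sync worker to learner
assert torch.equal(learner.fc.weight, worker.fc.weight)
```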
3 changes: 2 additions & 1 deletion framework/dataserver.py
@@ -36,8 +36,9 @@ def pub_msg(self, msg: Msg):
elif msg_type == MsgType.DATASERVER_INCREASE_UPDATE_STEP:
update_step_delta = 1 if msg_data is None else msg_data
self._increase_update_step(i = update_step_delta)

elif msg_type == MsgType.DATASERVER_CHECK_TASK_END:
self._check_task_end()
return self._check_task_end()
else:
raise NotImplementedError

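The one-line fix matters because `pub_msg` is the only channel callers have: without the `return`, a `DATASERVER_CHECK_TASK_END` query always yielded `None`. A self-contained toy illustrating the pattern (`ToyServer` is a simplified stand-in for the framework's `SimpleDataServer`):

```python
class ToyServer:
    ''' Simplified stand-in for SimpleDataServer's message dispatch. '''
    def __init__(self, max_step):
        self.curr_step, self.max_step = 0, max_step

    def _check_task_end(self):
        return self.curr_step >= self.max_step

    def pub_msg(self, msg_type):
        if msg_type == "CHECK_TASK_END":
            return self._check_task_end()  # the commit adds this `return`
        raise NotImplementedError

server = ToyServer(max_step=10)
server.curr_step = 10
assert server.pub_msg("CHECK_TASK_END") is True  # would be None without the fix
```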
38 changes: 11 additions & 27 deletions framework/interactor.py
@@ -1,4 +1,5 @@
import gymnasium as gym
import copy
from typing import Tuple
from algos.base.exps import Exp
from framework.message import Msg, MsgType
@@ -7,36 +8,28 @@
class BaseInteractor:
''' Interactor for a gym env, supporting sampling of n-step or n-episode training data
'''
def __init__(self, cfg: MergedConfig, id = 0, policy = None, *args, **kwargs) -> None:
def __init__(self, cfg: MergedConfig, id = 0, env = None, policy = None, *args, **kwargs) -> None:
self.cfg = cfg
self.id = id
self.policy = policy
self.env = gym.make(self.cfg.env_cfg.id)
self.env = env
self.seed = self.cfg.seed + self.id
self.data = None
self.exps = [] # reset experiences
self.summary = [] # reset summary
self.ep_reward, self.ep_step = 0, 0 # reset params per episode
self.curr_obs, self.curr_info = self.env.reset(seed = self.seed) # reset env

def pub_msg(self, msg: Msg):
msg_type, msg_data = msg.type, msg.data
if msg_type == MsgType.INTERACTOR_SAMPLE:
model_params = msg_data
self._put_model_params(model_params)
self._sample_data()
elif msg_type == MsgType.INTERACTOR_GET_SAMPLE_DATA:
return self._get_sample_data()

def run(self, model_params, *args, **kwargs):
def run(self, *args, **kwargs):
collector = kwargs['collector']
stats_recorder = kwargs['stats_recorder']
model_mgr = kwargs['model_mgr']
model_params = model_mgr.pub_msg(Msg(type = MsgType.MODEL_MGR_GET_MODEL_PARAMS)) # get model params
self.policy.put_model_params(model_params)
self._sample_data(*args, **kwargs)
collector.pub_msg(Msg(type = MsgType.COLLECTOR_PUT_EXPS, data = self.exps)) # put exps to collector
self.exps = [] # reset exps
if len(self.summary) > 0:
stats_recorder.pub_msg(Msg(type = MsgType.STATS_RECORDER_PUT_INTERACT_SUMMARY, data = self.summary)) # put summary to stats recorder
stats_recorder.pub_msg(Msg(type = MsgType.RECORDER_PUT_INTERACT_SUMMARY, data = self.summary)) # put summary to stats recorder
self.summary = [] # reset summary

def _sample_data(self,*args, **kwargs):
@@ -77,24 +70,15 @@ class BaseVecInteractor:
def __init__(self, cfg: MergedConfig, policy = None, *args, **kwargs) -> None:
self.cfg = cfg
self.n_envs = cfg.n_workers
self.reset_interact_outputs()
def reset_interact_outputs(self):
self.interact_outputs = []

class DummyVecInteractor(BaseVecInteractor):
def __init__(self, cfg: MergedConfig, policy = None, *args, **kwargs) -> None:
super().__init__(cfg, policy = policy, *args, **kwargs)
self.interactors = [BaseInteractor(cfg, id = i, policy = policy, *args, **kwargs) for i in range(self.n_envs)]
def __init__(self, cfg: MergedConfig, env = None, policy = None, *args, **kwargs) -> None:
super().__init__(cfg, env = env, policy = policy, *args, **kwargs)
self.interactors = [BaseInteractor(cfg, id = i, env = copy.deepcopy(env), policy = copy.deepcopy(policy), *args, **kwargs) for i in range(self.n_envs)]

def run(self, *args, **kwargs):
model_mgr = kwargs['model_mgr']
model_params = model_mgr.pub_msg(Msg(type = MsgType.MODEL_MGR_GET_MODEL_PARAMS)) # get model params
for i in range(self.n_envs):
self.interactors[i].run(model_params, *args, **kwargs)

def close_envs(self):
for i in range(self.n_envs):
self.interactors[i].close_env()
self.interactors[i].run(*args, **kwargs)

class RayVecInteractor(BaseVecInteractor):
def __init__(self, cfg) -> None:
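The switch from `gym.make` inside each interactor to `copy.deepcopy(env)` in `DummyVecInteractor` keeps env construction in one place while still giving every worker private mutable state. A short sketch of why the copy is needed (assumes `gymnasium` with the standard `CartPole-v1` env):

```python
import copy
import gymnasium as gym

env = gym.make("CartPole-v1")
# One private copy per interactor; sharing a single instance would let
# workers overwrite each other's episode state mid-rollout.
envs = [copy.deepcopy(env) for _ in range(4)]
observations = [e.reset(seed=i)[0] for i, e in enumerate(envs)]
```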
3 changes: 2 additions & 1 deletion framework/learner.py
@@ -1,4 +1,5 @@
import ray
import copy
from queue import Queue
from typing import Tuple
from framework.message import Msg, MsgType
@@ -7,7 +8,7 @@ class BaseLearner:
def __init__(self, cfg, id = 0, policy = None, *args, **kwargs) -> None:
self.cfg = cfg
self.id = id
self.policy = policy
self.policy = copy.deepcopy(policy)
self.collector = kwargs['collector']
self.dataserver = kwargs['dataserver']
self.updated_model_params_queue = Queue(maxsize = 128)
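`copy.deepcopy(policy)` gives the learner its own module rather than an alias shared with the interactors and tester; parameters then flow explicitly through `ModelMgr` messages. A minimal sketch of the aliasing hazard the copy avoids (a plain `nn.Linear` standing in for a policy):

```python
import copy
import torch
import torch.nn as nn

shared = nn.Linear(4, 2)
owned = copy.deepcopy(shared)          # learner-private copy, as in BaseLearner
with torch.no_grad():
    owned.weight.zero_()               # a learner-side update...
assert shared.weight.abs().sum() > 0   # ...does not leak back into the original
```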
6 changes: 4 additions & 2 deletions framework/message.py
@@ -1,7 +1,8 @@
from enum import Enum
from enum import Enum, unique
from typing import Optional, Any
from dataclasses import dataclass

@unique
class MsgType(Enum):
# dataserver
DATASERVER_GET_EPISODE = 0
@@ -24,11 +25,12 @@ class MsgType(Enum):
COLLECTOR_GET_BUFFER_LENGTH = 32

# recorder
STATS_RECORDER_PUT_INTERACT_SUMMARY = 40
RECORDER_PUT_INTERACT_SUMMARY = 40

# model_mgr
MODEL_MGR_PUT_MODEL_PARAMS = 70
MODEL_MGR_GET_MODEL_PARAMS = 71

@dataclass
class Msg(object):
type: MsgType
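With dozens of hand-numbered message codes, `@unique` turns an accidental duplicate into an import-time `ValueError` instead of a silent enum alias. A quick demonstration of what it catches:

```python
from enum import Enum, unique

try:
    @unique
    class BadMsgType(Enum):
        DATASERVER_GET_EPISODE = 0
        DATASERVER_GET_UPDATE_STEP = 0  # accidental duplicate value
except ValueError as err:
    print(err)  # duplicate values found in <enum 'BadMsgType'>: ...
```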
2 changes: 1 addition & 1 deletion framework/policy_mgr.py → framework/model_mgr.py
@@ -15,7 +15,7 @@ def __init__(self, cfg, model_params, **kwargs) -> None:
self._saved_policy_bundles: Dict[int, int] = {}
self._saved_policy_queue = Queue(maxsize = 128)
self._thread_save_policy = threading.Thread(target=self._save_policy)
# self._thread_save_policy.setDaemon(True)
self._thread_save_policy.setDaemon(True)
self.start()

def pub_msg(self, msg: Msg):
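Re-enabling the daemon flag means the save thread dies with the process instead of blocking interpreter shutdown. Note that `Thread.setDaemon()` is deprecated since Python 3.10; a sketch of the modern spelling (the loop body is a placeholder for draining the save queue):

```python
import threading
import time

def _save_policy_loop():
    while True:
        time.sleep(0.1)  # placeholder: drain the saved-policy queue here

t = threading.Thread(target=_save_policy_loop, daemon=True)  # preferred over setDaemon(True)
t.start()
# The process can now exit even while the loop is still running.
```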
3 changes: 2 additions & 1 deletion framework/recorder.py
@@ -24,11 +24,12 @@ def pub_msg(self, msg: Msg):
''' publish message
'''
msg_type, msg_data = msg.type, msg.data
if msg_type == MsgType.STATS_RECORDER_PUT_INTERACT_SUMMARY:
if msg_type == MsgType.RECORDER_PUT_INTERACT_SUMMARY:
interact_summary_list = msg_data
self._add_summary(interact_summary_list, writter_type = 'interact')
else:
raise NotImplementedError

def _init_writter(self):
self.writters = {}
self.writter_types = ['interact','policy']
93 changes: 60 additions & 33 deletions framework/tester.py
@@ -1,11 +1,19 @@
import ray
import torch
import time
import copy
import os
import threading
class BaseTester:
''' Base class for online tester
'''
def __init__(self, cfg, env = None) -> None:
def __init__(self, cfg, env = None, policy = None, *args, **kwargs) -> None:
self.cfg = cfg
self.env = env
self.env = copy.deepcopy(env)
self.policy = copy.deepcopy(policy)
self.logger = kwargs['logger']
self.best_eval_reward = -float('inf')

def run(self, policy, *args, **kwargs):
''' Run online tester
'''
@@ -15,38 +23,57 @@ def run(self, policy, *args, **kwargs):
class SimpleTester(BaseTester):
''' Simple online tester
'''
def __init__(self, cfg, env = None) -> None:
super().__init__(cfg, env)
def eval(self, policy, global_update_step = 0, logger = None):
sum_eval_reward = 0
for _ in range(self.cfg.online_eval_episode):
state, info = self.env.reset(seed = self.cfg.seed)
ep_reward, ep_step = 0, 0 # reward per episode, step per episode
while True:
action = policy.get_action(state, mode = 'predict')
next_state, reward, terminated, truncated, info = self.env.step(action)
state = next_state
ep_reward += reward
ep_step += 1
if terminated or (0<= self.cfg.max_step <= ep_step):
sum_eval_reward += ep_reward
break
mean_eval_reward = sum_eval_reward / self.cfg.online_eval_episode
logger.info(f"update_step: {global_update_step}, online_eval_reward: {mean_eval_reward:.3f}")
if mean_eval_reward >= self.best_eval_reward:
logger.info(f"current update step obtain a better online_eval_reward: {mean_eval_reward:.3f}, save the best model!")
policy.save_model(f"{self.cfg.model_dir}/best")
self.best_eval_reward = mean_eval_reward
summary_data = [(global_update_step,{"online_eval_reward": mean_eval_reward})]
output = {"summary":summary_data}
return output
def run(self, policy, *args, **kwargs):
''' Run online tester
def __init__(self, cfg, env = None, policy = None, *args, **kwargs) -> None:
super().__init__(cfg, env, policy, *args, **kwargs)
self.curr_test_step = -1
self._thread_eval_policy = threading.Thread(target=self._eval_policy)
self._thread_eval_policy.setDaemon(True)
self.start()

def _check_updated_model(self):
    model_step_list = os.listdir(self.cfg.model_dir)
    model_step_list = [int(model_step) for model_step in model_step_list if model_step.isdigit()]
    model_step_list.sort()
    if len(model_step_list) == 0:
        return False, -1
    elif model_step_list[-1] <= self.curr_test_step:
        return False, -1
    return True, model_step_list[-1]

def start(self):
self._thread_eval_policy.start()

def _eval_policy(self):
''' Evaluate policy
'''
dataserver, logger = kwargs['dataserver'], kwargs['logger']
global_update_step = dataserver.get_update_step() # get global update step
if global_update_step % self.cfg.model_save_fre == 0 and self.cfg.online_eval == True:
return self.eval(policy, global_update_step = global_update_step, logger = logger)
while True:
updated, model_step = self._check_updated_model()
if updated:
self.curr_test_step = model_step
model_params = torch.load(f"{self.cfg.model_dir}/{self.curr_test_step}")
self.policy.put_model_params(model_params)
sum_eval_reward = 0
for _ in range(self.cfg.online_eval_episode):
state, info = self.env.reset()
ep_reward, ep_step = 0, 0
while True:
action = self.policy.get_action(state, mode = 'predict')
next_state, reward, terminated, truncated, info = self.env.step(action)
state = next_state
ep_reward += reward
ep_step += 1
if terminated or truncated or (0<= self.cfg.max_step <= ep_step):
sum_eval_reward += ep_reward
break
mean_eval_reward = sum_eval_reward / self.cfg.online_eval_episode
self.logger.info(f"test_step: {self.curr_test_step}, online_eval_reward: {mean_eval_reward:.3f}")
if mean_eval_reward >= self.best_eval_reward:
self.logger.info(f"current test step obtain a better online_eval_reward: {mean_eval_reward:.3f}, save the best model!")
torch.save(model_params, f"{self.cfg.model_dir}/best")
self.best_eval_reward = mean_eval_reward
time.sleep(1)

@ray.remote
class RayTester(BaseTester):
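The rewritten tester no longer receives a policy per call; it runs a daemon thread that polls `model_dir` for new step-numbered checkpoints and evaluates whichever is newest. A self-contained sketch of that polling check, assuming checkpoints are saved under their update step as the file name (non-numeric entries such as `best` are ignored); `latest_checkpoint_step` is a hypothetical helper mirroring `_check_updated_model`:

```python
import os
import tempfile

def latest_checkpoint_step(model_dir, curr_test_step):
    ''' Return (True, step) if a checkpoint newer than curr_test_step exists. '''
    steps = sorted(int(name) for name in os.listdir(model_dir) if name.isdigit())
    if steps and steps[-1] > curr_test_step:
        return True, steps[-1]
    return False, -1

with tempfile.TemporaryDirectory() as d:
    open(os.path.join(d, "100"), "w").close()
    open(os.path.join(d, "best"), "w").close()  # ignored: not a step number
    assert latest_checkpoint_step(d, -1) == (True, 100)
    assert latest_checkpoint_step(d, 100) == (False, -1)
```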
25 changes: 15 additions & 10 deletions main.py
@@ -4,7 +4,6 @@
# parent_path = os.path.dirname(curr_path) # parent path
# sys.path.append(parent_path) # add path to system path
import sys,os
import copy
import argparse,datetime,importlib,yaml,time
import gymnasium as gym
import torch.multiprocessing as mp
@@ -123,7 +122,7 @@ def config_dir(dir,name = None):
for k,v in dirs_dic.items():
config_dir(v,name=k)

def create_single_env(self):
def env_config(self):
''' create single env
'''
env_cfg_dic = self.env_cfg.__dict__
@@ -180,24 +179,30 @@ def check_sample_length(self,cfg):
setattr(self.cfg, 'n_sample_episodes', n_sample_episodes)

def run(self) -> None:
test_env = self.create_single_env() # create single env
env = self.env_config() # create single env
policy, data_handler = self.policy_config(self.cfg) # configure policy and data_handler
dataserver = SimpleDataServer(self.cfg)
logger = SimpleLogger(self.cfg.log_dir)
collector = SimpleCollector(self.cfg, data_handler = data_handler)
vec_interactor = DummyVecInteractor(self.cfg,
policy = copy.deepcopy(policy),
env = env,
policy = policy,
)
learner = SimpleLearner(self.cfg,
policy = copy.deepcopy(policy),
policy = policy,
dataserver = dataserver,
collector = collector
)
online_tester = SimpleTester(self.cfg, test_env) # create online tester
model_mgr = ModelMgr(self.cfg, policy.get_model_params(),
dataserver = dataserver,
logger = logger
)
online_tester = SimpleTester(self.cfg,
env = env,
policy = policy,
logger = logger
) # create online tester
model_mgr = ModelMgr(self.cfg,
model_params = policy.get_model_params(),
dataserver = dataserver,
logger = logger
)
stats_recorder = SimpleStatsRecorder(self.cfg) # create stats recorder
self.print_cfgs(logger = logger) # print config
trainer = SimpleTrainer(self.cfg,
