From 8b29b05ed95148cb8418a4c2c4540b69e36dd703 Mon Sep 17 00:00:00 2001 From: johnjim0816 Date: Thu, 30 Nov 2023 21:39:40 +0800 Subject: [PATCH] update framwork --- algos/base/policies.py | 10 ++- framework/dataserver.py | 3 +- framework/interactor.py | 38 +++------ framework/learner.py | 3 +- framework/message.py | 6 +- framework/{policy_mgr.py => model_mgr.py} | 2 +- framework/recorder.py | 3 +- framework/tester.py | 93 +++++++++++++++-------- main.py | 25 +++--- 9 files changed, 105 insertions(+), 78 deletions(-) rename framework/{policy_mgr.py => model_mgr.py} (97%) diff --git a/algos/base/policies.py b/algos/base/policies.py index fcf4770..068696f 100644 --- a/algos/base/policies.py +++ b/algos/base/policies.py @@ -38,11 +38,17 @@ def get_state_action_size(self): return self.state_size, self.action_size def create_optimizer(self): self.optimizer = optim.Adam(self.parameters(), lr=self.cfg.lr) + def get_model_params(self): - model_params = self.state_dict() - return model_params + ''' get model params + ''' + return self.state_dict() + def put_model_params(self, model_params): + ''' put model params + ''' self.load_state_dict(model_params) + def get_optimizer_params(self): return self.optimizer.state_dict() def set_optimizer_params(self, optim_params_dict): diff --git a/framework/dataserver.py b/framework/dataserver.py index 126d2e1..d54ac4c 100644 --- a/framework/dataserver.py +++ b/framework/dataserver.py @@ -36,8 +36,9 @@ def pub_msg(self, msg: Msg): elif msg_type == MsgType.DATASERVER_INCREASE_UPDATE_STEP: update_step_delta = 1 if msg_data is None else msg_data self._increase_update_step(i = update_step_delta) + elif msg_type == MsgType.DATASERVER_CHECK_TASK_END: - self._check_task_end() + return self._check_task_end() else: raise NotImplementedError diff --git a/framework/interactor.py b/framework/interactor.py index b3110c5..75265bd 100644 --- a/framework/interactor.py +++ b/framework/interactor.py @@ -1,4 +1,5 @@ import gymnasium as gym +import copy from typing import Tuple from algos.base.exps import Exp from framework.message import Msg, MsgType @@ -7,36 +8,28 @@ class BaseInteractor: ''' Interactor for gym env to support sample n-steps or n-episodes traning data ''' - def __init__(self, cfg: MergedConfig, id = 0, policy = None, *args, **kwargs) -> None: + def __init__(self, cfg: MergedConfig, id = 0, env = None, policy = None, *args, **kwargs) -> None: self.cfg = cfg self.id = id self.policy = policy - self.env = gym.make(self.cfg.env_cfg.id) + self.env = env self.seed = self.cfg.seed + self.id - self.data = None self.exps = [] # reset experiences self.summary = [] # reset summary self.ep_reward, self.ep_step = 0, 0 # reset params per episode self.curr_obs, self.curr_info = self.env.reset(seed = self.seed) # reset env - def pub_msg(self, msg: Msg): - msg_type, msg_data = msg.type, msg.data - if msg_type == MsgType.INTERACTOR_SAMPLE: - model_params = msg_data - self._put_model_params(model_params) - self._sample_data() - elif msg_type == MsgType.INTERACTOR_GET_SAMPLE_DATA: - return self._get_sample_data() - - def run(self, model_params, *args, **kwargs): + def run(self, *args, **kwargs): collector = kwargs['collector'] stats_recorder = kwargs['stats_recorder'] + model_mgr = kwargs['model_mgr'] + model_params = model_mgr.pub_msg(Msg(type = MsgType.MODEL_MGR_GET_MODEL_PARAMS)) # get model params self.policy.put_model_params(model_params) self._sample_data(*args, **kwargs) collector.pub_msg(Msg(type = MsgType.COLLECTOR_PUT_EXPS, data = self.exps)) # put exps to collector self.exps = [] # reset exps if len(self.summary) > 0: - stats_recorder.pub_msg(Msg(type = MsgType.STATS_RECORDER_PUT_INTERACT_SUMMARY, data = self.summary)) # put summary to stats recorder + stats_recorder.pub_msg(Msg(type = MsgType.RECORDER_PUT_INTERACT_SUMMARY, data = self.summary)) # put summary to stats recorder self.summary = [] # reset summary def _sample_data(self,*args, **kwargs): @@ -77,24 +70,15 @@ class BaseVecInteractor: def __init__(self, cfg: MergedConfig, policy = None, *args, **kwargs) -> None: self.cfg = cfg self.n_envs = cfg.n_workers - self.reset_interact_outputs() - def reset_interact_outputs(self): - self.interact_outputs = [] class DummyVecInteractor(BaseVecInteractor): - def __init__(self, cfg: MergedConfig, policy = None, *args, **kwargs) -> None: - super().__init__(cfg, policy = policy, *args, **kwargs) - self.interactors = [BaseInteractor(cfg, id = i, policy = policy, *args, **kwargs) for i in range(self.n_envs)] + def __init__(self, cfg: MergedConfig, env = None, policy = None, *args, **kwargs) -> None: + super().__init__(cfg, env = env, policy = policy, *args, **kwargs) + self.interactors = [BaseInteractor(cfg, id = i, env = copy.deepcopy(env), policy = copy.deepcopy(policy), *args, **kwargs) for i in range(self.n_envs)] def run(self, *args, **kwargs): - model_mgr = kwargs['model_mgr'] - model_params = model_mgr.pub_msg(Msg(type = MsgType.MODEL_MGR_GET_MODEL_PARAMS)) # get model params - for i in range(self.n_envs): - self.interactors[i].run(model_params, *args, **kwargs) - - def close_envs(self): for i in range(self.n_envs): - self.interactors[i].close_env() + self.interactors[i].run(*args, **kwargs) class RayVecInteractor(BaseVecInteractor): def __init__(self, cfg) -> None: diff --git a/framework/learner.py b/framework/learner.py index a852fc4..a0998cd 100644 --- a/framework/learner.py +++ b/framework/learner.py @@ -1,4 +1,5 @@ import ray +import copy from queue import Queue from typing import Tuple from framework.message import Msg, MsgType @@ -7,7 +8,7 @@ class BaseLearner: def __init__(self, cfg, id = 0, policy = None, *args, **kwargs) -> None: self.cfg = cfg self.id = id - self.policy = policy + self.policy = copy.deepcopy(policy) self.collector = kwargs['collector'] self.dataserver = kwargs['dataserver'] self.updated_model_params_queue = Queue(maxsize = 128) diff --git a/framework/message.py b/framework/message.py index fafa750..fce0fde 100644 --- a/framework/message.py +++ b/framework/message.py @@ -1,7 +1,8 @@ -from enum import Enum +from enum import Enum, unique from typing import Optional, Any from dataclasses import dataclass +@unique class MsgType(Enum): # dataserver DATASERVER_GET_EPISODE = 0 @@ -24,11 +25,12 @@ class MsgType(Enum): COLLECTOR_GET_BUFFER_LENGTH = 32 # recorder - STATS_RECORDER_PUT_INTERACT_SUMMARY = 40 + RECORDER_PUT_INTERACT_SUMMARY = 40 # model_mgr MODEL_MGR_PUT_MODEL_PARAMS = 70 MODEL_MGR_GET_MODEL_PARAMS = 71 + @dataclass class Msg(object): type: MsgType diff --git a/framework/policy_mgr.py b/framework/model_mgr.py similarity index 97% rename from framework/policy_mgr.py rename to framework/model_mgr.py index f0c218d..af9df5e 100644 --- a/framework/policy_mgr.py +++ b/framework/model_mgr.py @@ -15,7 +15,7 @@ def __init__(self, cfg, model_params, **kwargs) -> None: self._saved_policy_bundles: Dict[int, int] = {} self._saved_policy_queue = Queue(maxsize = 128) self._thread_save_policy = threading.Thread(target=self._save_policy) - # self._thread_save_policy.setDaemon(True) + self._thread_save_policy.setDaemon(True) self.start() def pub_msg(self, msg: Msg): diff --git a/framework/recorder.py b/framework/recorder.py index eac075c..4a47ba7 100644 --- a/framework/recorder.py +++ b/framework/recorder.py @@ -24,11 +24,12 @@ def pub_msg(self, msg: Msg): ''' publish message ''' msg_type, msg_data = msg.type, msg.data - if msg_type == MsgType.STATS_RECORDER_PUT_INTERACT_SUMMARY: + if msg_type == MsgType.RECORDER_PUT_INTERACT_SUMMARY: interact_summary_list = msg_data self._add_summary(interact_summary_list, writter_type = 'interact') else: raise NotImplementedError + def _init_writter(self): self.writters = {} self.writter_types = ['interact','policy'] diff --git a/framework/tester.py b/framework/tester.py index 3b9b74f..7ef605b 100644 --- a/framework/tester.py +++ b/framework/tester.py @@ -1,11 +1,19 @@ import ray +import torch +import time +import copy +import os +import threading class BaseTester: ''' Base class for online tester ''' - def __init__(self, cfg, env = None) -> None: + def __init__(self, cfg, env = None, policy = None, *args, **kwargs) -> None: self.cfg = cfg - self.env = env + self.env = copy.deepcopy(env) + self.policy = copy.deepcopy(policy) + self.logger = kwargs['logger'] self.best_eval_reward = -float('inf') + def run(self, policy, *args, **kwargs): ''' Run online tester ''' @@ -15,38 +23,57 @@ def run(self, policy, *args, **kwargs): class SimpleTester(BaseTester): ''' Simple online tester ''' - def __init__(self, cfg, env = None) -> None: - super().__init__(cfg, env) - def eval(self, policy, global_update_step = 0, logger = None): - sum_eval_reward = 0 - for _ in range(self.cfg.online_eval_episode): - state, info = self.env.reset(seed = self.cfg.seed) - ep_reward, ep_step = 0, 0 # reward per episode, step per episode - while True: - action = policy.get_action(state, mode = 'predict') - next_state, reward, terminated, truncated, info = self.env.step(action) - state = next_state - ep_reward += reward - ep_step += 1 - if terminated or (0<= self.cfg.max_step <= ep_step): - sum_eval_reward += ep_reward - break - mean_eval_reward = sum_eval_reward / self.cfg.online_eval_episode - logger.info(f"update_step: {global_update_step}, online_eval_reward: {mean_eval_reward:.3f}") - if mean_eval_reward >= self.best_eval_reward: - logger.info(f"current update step obtain a better online_eval_reward: {mean_eval_reward:.3f}, save the best model!") - policy.save_model(f"{self.cfg.model_dir}/best") - self.best_eval_reward = mean_eval_reward - summary_data = [(global_update_step,{"online_eval_reward": mean_eval_reward})] - output = {"summary":summary_data} - return output - def run(self, policy, *args, **kwargs): - ''' Run online tester + def __init__(self, cfg, env = None, policy = None, *args, **kwargs) -> None: + super().__init__(cfg, env, policy, *args, **kwargs) + self.curr_test_step = -1 + self._thread_eval_policy = threading.Thread(target=self._eval_policy) + self._thread_eval_policy.setDaemon(True) + self.start() + + def _check_updated_model(self): + + model_step_list = os.listdir(self.cfg.model_dir) + model_step_list = [int(model_step) for model_step in model_step_list if model_step.isdigit()] + model_step_list.sort() + if len(model_step_list) == 0: + return False, -1 + elif model_step_list[-1] == self.curr_test_step: + return False, -1 + elif model_step_list[-1] > self.curr_test_step: + return True, model_step_list[-1] + + def start(self): + self._thread_eval_policy.start() + + def _eval_policy(self): + ''' Evaluate policy ''' - dataserver, logger = kwargs['dataserver'], kwargs['logger'] - global_update_step = dataserver.get_update_step() # get global update step - if global_update_step % self.cfg.model_save_fre == 0 and self.cfg.online_eval == True: - return self.eval(policy, global_update_step = global_update_step, logger = logger) + while True: + updated, model_step = self._check_updated_model() + if updated: + self.curr_test_step = model_step + model_params = torch.load(f"{self.cfg.model_dir}/{self.curr_test_step}") + self.policy.put_model_params(model_params) + sum_eval_reward = 0 + for _ in range(self.cfg.online_eval_episode): + state, info = self.env.reset() + ep_reward, ep_step = 0, 0 + while True: + action = self.policy.get_action(state, mode = 'predict') + next_state, reward, terminated, truncated, info = self.env.step(action) + state = next_state + ep_reward += reward + ep_step += 1 + if terminated or truncated or (0<= self.cfg.max_step <= ep_step): + sum_eval_reward += ep_reward + break + mean_eval_reward = sum_eval_reward / self.cfg.online_eval_episode + self.logger.info(f"test_step: {self.curr_test_step}, online_eval_reward: {mean_eval_reward:.3f}") + if mean_eval_reward >= self.best_eval_reward: + self.logger.info(f"current test step obtain a better online_eval_reward: {mean_eval_reward:.3f}, save the best model!") + torch.save(model_params, f"{self.cfg.model_dir}/best") + self.best_eval_reward = mean_eval_reward + time.sleep(1) @ray.remote class RayTester(BaseTester): diff --git a/main.py b/main.py index db1bfa0..b971099 100644 --- a/main.py +++ b/main.py @@ -4,7 +4,6 @@ # parent_path = os.path.dirname(curr_path) # parent path # sys.path.append(parent_path) # add path to system path import sys,os -import copy import argparse,datetime,importlib,yaml,time import gymnasium as gym import torch.multiprocessing as mp @@ -123,7 +122,7 @@ def config_dir(dir,name = None): for k,v in dirs_dic.items(): config_dir(v,name=k) - def create_single_env(self): + def env_config(self): ''' create single env ''' env_cfg_dic = self.env_cfg.__dict__ @@ -180,24 +179,30 @@ def check_sample_length(self,cfg): setattr(self.cfg, 'n_sample_episodes', n_sample_episodes) def run(self) -> None: - test_env = self.create_single_env() # create single env + env = self.env_config() # create single env policy, data_handler = self.policy_config(self.cfg) # configure policy and data_handler dataserver = SimpleDataServer(self.cfg) logger = SimpleLogger(self.cfg.log_dir) collector = SimpleCollector(self.cfg, data_handler = data_handler) vec_interactor = DummyVecInteractor(self.cfg, - policy = copy.deepcopy(policy), + env = env, + policy = policy, ) learner = SimpleLearner(self.cfg, - policy = copy.deepcopy(policy), + policy = policy, dataserver = dataserver, collector = collector ) - online_tester = SimpleTester(self.cfg, test_env) # create online tester - model_mgr = ModelMgr(self.cfg, policy.get_model_params(), - dataserver = dataserver, - logger = logger - ) + online_tester = SimpleTester(self.cfg, + env = env, + policy = policy, + logger = logger + ) # create online tester + model_mgr = ModelMgr(self.cfg, + model_params = policy.get_model_params(), + dataserver = dataserver, + logger = logger + ) stats_recorder = SimpleStatsRecorder(self.cfg) # create stats recorder self.print_cfgs(logger = logger) # print config trainer = SimpleTrainer(self.cfg,