
Commit

update
johnjim0816 committed Dec 2, 2023
1 parent 2b6a788 commit 53a74dc
Showing 27 changed files with 478 additions and 80 deletions.
2 changes: 1 addition & 1 deletion algos/A2C/data_handler.py
@@ -1,5 +1,5 @@
import numpy as np
from algos.base.data_handlers import BaseDataHandler
from algos.base.data_handler import BaseDataHandler
class DataHandler(BaseDataHandler):
def __init__(self, cfg):
super().__init__(cfg)
2 changes: 1 addition & 1 deletion algos/DDPG/data_handler.py
@@ -1,4 +1,4 @@
from algos.base.data_handlers import BaseDataHandler
from algos.base.data_handler import BaseDataHandler
class DataHandler(BaseDataHandler):
def __init__(self, cfg):
super().__init__(cfg)
2 changes: 1 addition & 1 deletion algos/DQN/data_handler.py
@@ -1,4 +1,4 @@
from algos.base.data_handlers import BaseDataHandler
from algos.base.data_handler import BaseDataHandler
class DataHandler(BaseDataHandler):
def __init__(self, cfg):
super().__init__(cfg)
5 changes: 3 additions & 2 deletions algos/DQN/policy.py
@@ -16,6 +16,7 @@ def __init__(self,cfg) -> None:
self.epsilon_end = cfg.epsilon_end
self.epsilon_decay = cfg.epsilon_decay
self.target_update = cfg.target_update
self.update_step = 0
self.create_graph() # create graph and optimizer
self.create_summary() # create summary
self.to(self.device)
@@ -52,7 +53,6 @@ def learn(self, **kwargs):
''' learn policy
'''
states, actions, next_states, rewards, dones = kwargs.get('states'), kwargs.get('actions'), kwargs.get('next_states'), kwargs.get('rewards'), kwargs.get('dones')
update_step = kwargs.get('update_step')
# convert numpy to tensor
states = torch.tensor(states, device=self.device, dtype=torch.float32)
actions = torch.tensor(actions, device=self.device, dtype=torch.int64).unsqueeze(dim=1)
@@ -74,7 +74,8 @@ def learn(self, **kwargs):
param.grad.data.clamp_(-1, 1)
self.optimizer.step()
# update target net every C steps
if update_step % self.target_update == 0:
if self.update_step % self.target_update == 0:
self.target_net.load_state_dict(self.policy_net.state_dict())
self.update_step += 1
self.update_summary() # update summary
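
The net effect of the policy.py change is that the DQN policy now counts its own updates: self.update_step is initialized in __init__ and incremented inside learn(), so callers no longer thread an update_step through the kwargs, and the target network is still synchronized every target_update steps. A minimal sketch of the calling side before and after this commit (the loop and the data_handler sampling call are illustrative assumptions, not code from this repository):

# before: the caller had to pass its own counter into learn()
for update_step in range(n_updates):                  # hypothetical training loop
    batch = data_handler.sample_training_data()       # assumed sampling API
    policy.learn(**batch, update_step=update_step)

# after: the policy increments self.update_step internally,
# so the caller only supplies the sampled batch
for _ in range(n_updates):
    batch = data_handler.sample_training_data()
    policy.learn(**batch)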

2 changes: 1 addition & 1 deletion algos/DoubleDQN/data_handler.py
@@ -1,4 +1,4 @@
from algos.base.data_handlers import BaseDataHandler
from algos.base.data_handler import BaseDataHandler
class DataHandler(BaseDataHandler):
def __init__(self, cfg):
super().__init__(cfg)
2 changes: 1 addition & 1 deletion algos/DuelingDQN/data_handler.py
@@ -9,7 +9,7 @@
Discription:
'''

from algos.base.data_handlers import BaseDataHandler
from algos.base.data_handler import BaseDataHandler
class DataHandler(BaseDataHandler):
def __init__(self, cfg):
super().__init__(cfg)
2 changes: 1 addition & 1 deletion algos/NoisyDQN/data_handler.py
@@ -8,7 +8,7 @@
LastEditTime: 2023-05-18 13:31:14
Discription:
'''
from algos.base.data_handlers import BaseDataHandler
from algos.base.data_handler import BaseDataHandler
class DataHandler(BaseDataHandler):
def __init__(self, cfg):
super().__init__(cfg)
2 changes: 1 addition & 1 deletion algos/PER_DQN/data_handler.py
@@ -9,7 +9,7 @@
Discription:
'''
import numpy as np
from algos.base.data_handlers import BaseDataHandler
from algos.base.data_handler import BaseDataHandler
class DataHandler(BaseDataHandler):
def __init__(self, cfg):
super().__init__(cfg)
2 changes: 1 addition & 1 deletion algos/PPO/data_handler.py
@@ -1,5 +1,5 @@
import numpy as np
from algos.base.data_handlers import BaseDataHandler
from algos.base.data_handler import BaseDataHandler
class DataHandler(BaseDataHandler):
def __init__(self, cfg):
super().__init__(cfg)
2 changes: 1 addition & 1 deletion algos/QLearning/data_handler.py
@@ -1,5 +1,5 @@
import numpy as np
from algos.base.data_handlers import BaseDataHandler
from algos.base.data_handler import BaseDataHandler
from algos.base.exps import Exp
class DataHandler(BaseDataHandler):
def __init__(self,cfg) -> None:
2 changes: 1 addition & 1 deletion algos/SAC/data_handler.py
@@ -1,5 +1,5 @@
import numpy as np
from algos.base.data_handlers import BaseDataHandler
from algos.base.data_handler import BaseDataHandler
class DataHandler(BaseDataHandler):
def __init__(self, cfg):
super().__init__(cfg)
2 changes: 1 addition & 1 deletion algos/Sarsa/data_handler.py
@@ -9,7 +9,7 @@
Discription:
'''
import numpy as np
from algos.base.data_handlers import BaseDataHandler
from algos.base.data_handler import BaseDataHandler
from algos.base.exps import Exp
class DataHandler(BaseDataHandler):
def __init__(self,cfg) -> None:
2 changes: 1 addition & 1 deletion algos/SoftQ/data_handler.py
@@ -1,4 +1,4 @@
from algos.base.data_handlers import BaseDataHandler
from algos.base.data_handler import BaseDataHandler
class DataHandler(BaseDataHandler):
def __init__(self, cfg):
super().__init__(cfg)
2 changes: 1 addition & 1 deletion algos/TD3/data_handler.py
@@ -1,4 +1,4 @@
from algos.base.data_handlers import BaseDataHandler
from algos.base.data_handler import BaseDataHandler

class DataHandler(BaseDataHandler):
def __init__(self, cfg):
File renamed without changes.
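The renamed file is presumably algos/base/data_handlers.py to algos/base/data_handler.py, which is the rename that all of the import updates in the data_handler.py files above track.
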
64 changes: 63 additions & 1 deletion framework/collector.py
@@ -1,9 +1,71 @@
import ray
from ray.util.queue import Queue, Empty, Full
import threading
from framework.message import Msg, MsgType
from config.general_config import MergedConfig
from algos.base.data_handler import BaseDataHandler

@ray.remote(num_cpus=0)
class Collector:
def __init__(self, cfg: MergedConfig, data_handler: BaseDataHandler) -> None:
self.cfg = cfg
self.data_handler = data_handler
self.training_data_queue = Queue(maxsize = 128)
self._t_sample_training_data = threading.Thread(target=self._sample_training_data)

def pub_msg(self, msg: Msg):
''' publish message
'''
msg_type, msg_data = msg.type, msg.data
if msg_type == MsgType.COLLECTOR_PUT_EXPS:
exps = msg_data
self._put_exps(exps)
elif msg_type == MsgType.COLLECTOR_GET_TRAINING_DATA:
if self.training_data_queue.empty(): return None
return self.training_data_queue.get()
return self._get_training_data()
elif msg_type == MsgType.COLLECTOR_GET_BUFFER_LENGTH:
return self.get_buffer_length()
else:
raise NotImplementedError

def run(self):
''' start
'''
self._t_sample_training_data.start()

def _sample_training_data(self):
''' async run
'''
while True:
training_data = self._get_training_data()
if training_data is None: continue
while not self.training_data_queue.full():
self.training_data_queue.put(training_data)
break

def _put_exps(self, exps):
''' add exps to data handler
'''
self.data_handler.add_exps(exps) # add exps to data handler

def _get_training_data(self):
training_data = self.data_handler.sample_training_data() # sample training data
return training_data

def handle_data_after_learn(self, policy_data_after_learn, *args, **kwargs):
return

def get_buffer_length(self):
return len(self.data_handler.buffer)


class BaseCollector:
def __init__(self, cfg, data_handler = None) -> None:
self.cfg = cfg
self.n_learners = cfg.n_learners
if data_handler is None: raise NotImplementedError("data_handler must be specified!")
self.data_handler = data_handler

def pub_msg(self, msg: Msg):
''' publish message
'''
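
Collector is now a Ray actor (num_cpus=0), so other components talk to it exclusively through pub_msg messages, while a background thread keeps the training-data queue filled. A minimal usage sketch, assuming ray has already been initialized and that cfg, data_handler, exps and policy exist elsewhere (none of this setup is part of the commit):

import ray
from framework.message import Msg, MsgType
from framework.collector import Collector

collector = Collector.remote(cfg, data_handler)   # remote actor handle
collector.run.remote()                            # starts the _sample_training_data thread

# an interactor pushes experiences fire-and-forget
collector.pub_msg.remote(Msg(type=MsgType.COLLECTOR_PUT_EXPS, data=exps))

# a learner polls for a batch; None means the queue is still empty
batch = ray.get(collector.pub_msg.remote(Msg(type=MsgType.COLLECTOR_GET_TRAINING_DATA)))
if batch is not None:
    policy.learn(**batch)
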
75 changes: 74 additions & 1 deletion framework/interactor.py
@@ -1,4 +1,5 @@
import gymnasium as gym
import ray
import copy
from typing import Tuple
from algos.base.exps import Exp
@@ -62,7 +63,79 @@ def _sample_data(self,*args, **kwargs):
if run_step >= self.cfg.n_sample_steps:
run_step = 0
break

@ray.remote(num_cpus = 1)
class Interactor:
def __init__(self, cfg: MergedConfig, id = 0, env = None, policy = None, *args, **kwargs) -> None:
self.cfg = cfg
self.id = id
self.env = env
self.policy = policy
self.seed = self.cfg.seed + self.id
self.exps = []
self.seed = self.cfg.seed + self.id
self.exps = [] # reset experiences
self.summary = [] # reset summary
self.ep_reward, self.ep_step = 0, 0 # reset params per episode
self.curr_obs, self.curr_info = self.env.reset(seed = self.seed) # reset env

def run(self, *args, **kwargs):
''' run in sync mode
'''
tracker = kwargs['tracker']
collector = kwargs['collector']
recorder = kwargs['recorder']
model_mgr = kwargs['model_mgr']
logger = kwargs['logger']

def start(self, *args, **kwargs):
''' start in async mode
'''
tracker = kwargs['tracker']
collector = kwargs['collector']
recorder = kwargs['recorder']
model_mgr = kwargs['model_mgr']
logger = kwargs['logger']
while True:
model_params = ray.get(model_mgr.pub_msg.remote(Msg(type = MsgType.MODEL_MGR_GET_MODEL_PARAMS))) # get model params
self.policy.put_model_params(model_params)
action = self.policy.get_action(self.curr_obs)
obs, reward, terminated, truncated, info = self.env.step(action)
interact_transition = {'interactor_id': self.id, 'state': self.curr_obs, 'action': action,'reward': reward, 'next_state': obs, 'done': terminated or truncated, 'info': info}
policy_transition = self.policy.get_policy_transition()
# create exp
self.exps.append(Exp(**interact_transition, **policy_transition))
self.curr_obs, self.curr_info = obs, info
self.ep_reward += reward
self.ep_step += 1
if len(self.exps) >= 1 or terminated or truncated or self.ep_step >= self.cfg.max_step:
collector.pub_msg.remote(Msg(type = MsgType.COLLECTOR_PUT_EXPS, data = self.exps))
self.exps = []
if terminated or truncated or self.ep_step >= self.cfg.max_step:
global_episode = ray.get(tracker.pub_msg.remote(Msg(type = MsgType.TRACKER_GET_EPISODE)))
tracker.pub_msg.remote(Msg(MsgType.TRACKER_INCREASE_EPISODE))
if global_episode % self.cfg.interact_summary_fre == 0:
logger.info.remote(f"Interactor {self.id} finished episode {global_episode} with reward {self.ep_reward:.3f} in {self.ep_step} steps")
interact_summary = {'reward':self.ep_reward,'step':self.ep_step}
self.summary.append((global_episode, interact_summary))
recorder.pub_msg.remote(Msg(type = MsgType.RECORDER_PUT_INTERACT_SUMMARY, data = self.summary)) # put summary to stats recorder
self.ep_reward, self.ep_step = 0, 0
self.curr_obs, self.curr_info = self.env.reset(seed = self.seed)

@ray.remote(num_cpus = 0)
class InteractorMgr:
def __init__(self, cfg: MergedConfig, env = None , policy = None, *args, **kwargs) -> None:
if env is None: raise NotImplementedError("env must be specified!")
if policy is None: raise NotImplementedError("policy must be specified!")
self.cfg = cfg
self.n_envs = cfg.n_workers
self.interactors = [Interactor.remote(cfg, id = i, env = copy.deepcopy(env), policy = copy.deepcopy(policy), *args, **kwargs) for i in range(self.n_envs)]

def start(self, *args, **kwargs):
for i in range(self.n_envs):
self.interactors[i].start.remote(*args, **kwargs)




class BaseWorker:
def __init__(self, cfg: MergedConfig, policy = None, *args, **kwargs) -> None:
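
InteractorMgr fans a deep copy of the env and policy out to cfg.n_workers Interactor actors and starts each one in async mode. A minimal wiring sketch, assuming the tracker, collector, recorder, model_mgr and logger actor handles have been created elsewhere (they are not defined in this file):

from framework.interactor import InteractorMgr

interactor_mgr = InteractorMgr.remote(cfg, env=env, policy=policy)
interactor_mgr.start.remote(
    tracker=tracker,
    collector=collector,
    recorder=recorder,
    model_mgr=model_mgr,
    logger=logger,
)
# each Interactor then loops forever: pull the latest model params, step the env,
# push experiences to the collector, and report episode summaries to the recorder
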
60 changes: 59 additions & 1 deletion framework/learner.py
@@ -1,8 +1,10 @@
import ray
import copy
import time
from queue import Queue
from typing import Tuple
from framework.message import Msg, MsgType
from config.general_config import MergedConfig

class BaseLearner:
def __init__(self, cfg, id = 0, policy = None, *args, **kwargs) -> None:
@@ -72,4 +74,60 @@ def run(self, *args, **kwargs):
if curr_update_step % self.cfg.policy_summary_fre == 0:
policy_summary = [(curr_update_step,self.policy.get_summary())]
recorder = kwargs['recorder']
recorder.pub_msg(Msg(type = MsgType.RECORDER_PUT_POLICY_SUMMARY, data = policy_summary))
recorder.pub_msg(Msg(type = MsgType.RECORDER_PUT_POLICY_SUMMARY, data = policy_summary))

@ray.remote
class Learner:
def __init__(self, cfg : MergedConfig, id = 0, policy = None, *args, **kwargs) -> None:
self.cfg = cfg
self.id = id
self.policy = policy

def pub_msg(self, msg: Msg):
msg_type, msg_data = msg.type, msg.data
if msg_type == MsgType.LEARNER_UPDATE_POLICY:
model_params = msg_data
self._put_model_params(model_params)
self._update_policy()
elif msg_type == MsgType.LEARNER_GET_UPDATED_MODEL_PARAMS_QUEUE:
return self._get_updated_model_params_queue()
else:
raise NotImplementedError

def run(self, *args, **kwargs):
model_mgr = kwargs['model_mgr']
collector = kwargs['collector']
tracker = kwargs['tracker']
logger = kwargs['logger']
recorder = kwargs['recorder']
while True:
training_data = ray.get(collector.pub_msg.remote(Msg(type = MsgType.COLLECTOR_GET_TRAINING_DATA)))
if training_data is None: continue
s_t = time.time()
model_params = ray.get(model_mgr.pub_msg.remote(Msg(type = MsgType.MODEL_MGR_GET_MODEL_PARAMS)))
e_t = time.time()
# logger.info.remote(f"Get model params finished in {e_t - s_t:.3f} s")
self.policy.put_model_params(model_params)
self.policy.learn(**training_data)
global_update_step = ray.get(tracker.pub_msg.remote(Msg(type = MsgType.TRACKER_GET_UPDATE_STEP)))
tracker.pub_msg.remote(Msg(type = MsgType.TRACKER_INCREASE_UPDATE_STEP))
# put updated model params to model_mgr
model_params = self.policy.get_model_params()
model_mgr.pub_msg.remote(Msg(type = MsgType.MODEL_MGR_PUT_MODEL_PARAMS, data = (global_update_step, model_params)))
# put policy summary to recorder
if global_update_step % self.cfg.policy_summary_fre == 0:
policy_summary = [(global_update_step,self.policy.get_summary())]
recorder.pub_msg.remote(Msg(type = MsgType.RECORDER_PUT_POLICY_SUMMARY, data = policy_summary))
# logger.info.remote(f"Update step {global_update_step} finished in {e_t - s_t:.3f} s")
def _get_id(self):
return self.id

@ray.remote(num_cpus = 0)
class LearnerMgr:
def __init__(self, cfg : MergedConfig, policy = None, *args, **kwargs) -> None:
if policy is None: raise NotImplementedError("[LearnerMgr] policy must be specified!")
self.cfg = cfg
self.learner = Learner.remote(cfg = cfg, policy = copy.deepcopy(policy))

def run(self, *args, **kwargs):
self.learner.run.remote(*args, **kwargs)
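
LearnerMgr wraps a single Learner actor whose run loop pulls a batch from the collector, refreshes the policy with the latest parameters from model_mgr, learns, and then publishes the updated parameters and periodic summaries. A sketch of how the pieces could be wired together, assuming ModelMgr, Tracker, Recorder and Logger actors exist with the pub_msg interfaces used above (they are not shown in this commit):

import ray
from framework.collector import Collector
from framework.interactor import InteractorMgr
from framework.learner import LearnerMgr

ray.init()
collector = Collector.remote(cfg, data_handler)
interactor_mgr = InteractorMgr.remote(cfg, env=env, policy=policy)
learner_mgr = LearnerMgr.remote(cfg, policy=policy)

collector.run.remote()   # start filling the training-data queue
interactor_mgr.start.remote(tracker=tracker, collector=collector, recorder=recorder,
                            model_mgr=model_mgr, logger=logger)
learner_mgr.run.remote(model_mgr=model_mgr, collector=collector, tracker=tracker,
                       logger=logger, recorder=recorder)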