From e5a784c9871a238dc69f5ac3ad0983aa235026c4 Mon Sep 17 00:00:00 2001 From: zjowowen Date: Fri, 24 Nov 2023 21:53:51 +0800 Subject: [PATCH 1/2] Add envpool new pipeline --- ding/entry/utils.py | 8 +- ding/envs/env_manager/envpool_env_manager.py | 231 +++++++++++++++++- .../tests/test_envpool_env_manager.py | 11 +- ding/example/dqn_nstep_envpool.py | 119 +++++++++ ding/framework/context.py | 7 + ding/framework/middleware/__init__.py | 2 +- ding/framework/middleware/collector.py | 75 +++++- .../middleware/functional/__init__.py | 2 +- .../middleware/functional/collector.py | 148 +++++++++++ .../middleware/functional/data_processor.py | 12 +- .../middleware/functional/evaluator.py | 10 +- .../framework/middleware/functional/logger.py | 13 + ding/framework/middleware/functional/timer.py | 1 + ding/framework/middleware/learner.py | 7 + .../middleware/tests/test_distributer.py | 3 - ding/model/common/utils.py | 1 + ding/policy/common_utils.py | 2 +- ding/policy/dqn.py | 14 +- ding/utils/default_helper.py | 4 +- .../collector/sample_serial_collector.py | 11 +- .../serial/pong/pong_dqn_envpool_config.py | 17 +- .../spaceinvaders_dqn_envpool_config.py | 63 +++++ 22 files changed, 716 insertions(+), 45 deletions(-) create mode 100644 ding/example/dqn_nstep_envpool.py create mode 100644 dizoo/atari/config/serial/spaceinvaders/spaceinvaders_dqn_envpool_config.py diff --git a/ding/entry/utils.py b/ding/entry/utils.py index bbfbaa83bd..a3b66bfe70 100644 --- a/ding/entry/utils.py +++ b/ding/entry/utils.py @@ -1,4 +1,4 @@ -from typing import Optional, Callable, List, Any +from typing import Optional, Callable, List, Any, Dict from ding.policy import PolicyFactory from ding.worker import IMetric, MetricSerialEvaluator @@ -46,7 +46,8 @@ def random_collect( collector_env: 'BaseEnvManager', # noqa commander: 'BaseSerialCommander', # noqa replay_buffer: 'IBuffer', # noqa - postprocess_data_fn: Optional[Callable] = None + postprocess_data_fn: Optional[Callable] = None, + collect_kwargs: Optional[Dict] = None, ) -> None: # noqa assert policy_cfg.random_collect_size > 0 if policy_cfg.get('transition_with_policy_data', False): @@ -55,7 +56,8 @@ def random_collect( action_space = collector_env.action_space random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space) collector.reset_policy(random_policy) - collect_kwargs = commander.step() + if collect_kwargs is None: + collect_kwargs = commander.step() if policy_cfg.collect.collector.type == 'episode': new_data = collector.collect(n_episode=policy_cfg.random_collect_size, policy_kwargs=collect_kwargs) else: diff --git a/ding/envs/env_manager/envpool_env_manager.py b/ding/envs/env_manager/envpool_env_manager.py index a8d1a4ae03..25618bae2e 100644 --- a/ding/envs/env_manager/envpool_env_manager.py +++ b/ding/envs/env_manager/envpool_env_manager.py @@ -2,7 +2,11 @@ from easydict import EasyDict from copy import deepcopy import numpy as np +import torch +import treetensor.torch as ttorch +import treetensor.numpy as tnp from collections import namedtuple +import enum from typing import Any, Union, List, Tuple, Dict, Callable, Optional from ditk import logging try: @@ -13,21 +17,33 @@ envpool = None from ding.envs import BaseEnvTimestep +from ding.envs.env_manager import BaseEnvManagerV2 from ding.utils import ENV_MANAGER_REGISTRY, deep_merge_dicts from ding.torch_utils import to_ndarray -@ENV_MANAGER_REGISTRY.register('env_pool') +class EnvState(enum.IntEnum): + VOID = 0 + INIT = 1 + RUN = 2 + RESET = 3 + DONE = 4 + ERROR = 5 + 
NEED_RESET = 6 + + +@ENV_MANAGER_REGISTRY.register('envpool') class PoolEnvManager: - ''' + """ Overview: + PoolEnvManager supports old pipeline of DI-engine. Envpool now supports Atari, Classic Control, Toy Text, ViZDoom. Here we list some commonly used env_ids as follows. For more examples, you can refer to . - Atari: "Pong-v5", "SpaceInvaders-v5", "Qbert-v5" - Classic Control: "CartPole-v0", "CartPole-v1", "Pendulum-v1" - ''' + """ @classmethod def default_config(cls) -> EasyDict: @@ -39,10 +55,17 @@ def default_config(cls) -> EasyDict: # Async mode: batch_size < env_num env_num=8, batch_size=8, + image_observation=True, + episodic_life=False, + reward_clip=False, + gray_scale=True, + stack_num=4, + frame_skip=4, ) def __init__(self, cfg: EasyDict) -> None: - self._cfg = cfg + self._cfg = self.default_config() + self._cfg.update(cfg) self._env_num = cfg.env_num self._batch_size = cfg.batch_size self._ready_obs = {} @@ -55,6 +78,7 @@ def launch(self) -> None: seed = 0 else: seed = self._seed + self._envs = envpool.make( task_id=self._cfg.env_id, env_type="gym", @@ -65,8 +89,10 @@ def launch(self) -> None: reward_clip=self._cfg.reward_clip, stack_num=self._cfg.stack_num, gray_scale=self._cfg.gray_scale, - frame_skip=self._cfg.frame_skip + frame_skip=self._cfg.frame_skip, ) + self._action_space = self._envs.action_space + self._observation_space = self._envs.observation_space self._closed = False self.reset() @@ -77,6 +103,8 @@ def reset(self) -> None: obs, _, _, info = self._envs.recv() env_id = info['env_id'] obs = obs.astype(np.float32) + if self._cfg.image_observation: + obs /= 255.0 self._ready_obs = deep_merge_dicts({i: o for i, o in zip(env_id, obs)}, self._ready_obs) if len(self._ready_obs) == self._env_num: break @@ -91,6 +119,8 @@ def step(self, action: dict) -> Dict[int, namedtuple]: obs, rew, done, info = self._envs.recv() obs = obs.astype(np.float32) + if self._cfg.image_observation: + obs /= 255.0 rew = rew.astype(np.float32) env_id = info['env_id'] timesteps = {} @@ -117,6 +147,10 @@ def seed(self, seed: int, dynamic_seed=False) -> None: self._seed = seed logging.warning("envpool doesn't support dynamic_seed in different episode") + @property + def closed(self) -> None: + return self._closed + @property def env_num(self) -> int: return self._env_num @@ -124,3 +158,190 @@ def env_num(self) -> int: @property def ready_obs(self) -> Dict[int, Any]: return self._ready_obs + + @property + def observation_space(self) -> 'gym.spaces.Space': # noqa + try: + return self._observation_space + except AttributeError: + self.launch() + self.close() + return self._observation_space + + @property + def action_space(self) -> 'gym.spaces.Space': # noqa + try: + return self._action_space + except AttributeError: + self.launch() + self.close() + return self._action_space + +@ENV_MANAGER_REGISTRY.register('envpool_v2') +class PoolEnvManagerV2: + """ + Overview: + PoolEnvManagerV2 supports new pipeline of DI-engine. + Envpool now supports Atari, Classic Control, Toy Text, ViZDoom. + Here we list some commonly used env_ids as follows. + For more examples, you can refer to . 
+ + - Atari: "Pong-v5", "SpaceInvaders-v5", "Qbert-v5" + - Classic Control: "CartPole-v0", "CartPole-v1", "Pendulum-v1" + """ + + @classmethod + def default_config(cls) -> EasyDict: + return EasyDict(deepcopy(cls.config)) + + config = dict( + type='envpool_v2', + # Sync mode: batch_size == env_num + # Async mode: batch_size < env_num + env_num=8, + batch_size=8, + image_observation=True, + episodic_life=False, + reward_clip=False, + gray_scale=True, + stack_num=4, + frame_skip=4, + ) + + def __init__(self, cfg: EasyDict) -> None: + self._cfg = self.default_config() + self._cfg.update(cfg) + self._env_num = cfg.env_num + self._batch_size = cfg.batch_size + self._ready_obs = {} + self._closed = True + self._seed = None + + def launch(self) -> None: + assert self._closed, "Please first close the env manager" + if self._seed is None: + seed = 0 + else: + seed = self._seed + + self._envs = envpool.make( + task_id=self._cfg.env_id, + env_type="gym", + num_envs=self._env_num, + batch_size=self._batch_size, + seed=seed, + episodic_life=self._cfg.episodic_life, + reward_clip=self._cfg.reward_clip, + stack_num=self._cfg.stack_num, + gray_scale=self._cfg.gray_scale, + frame_skip=self._cfg.frame_skip, + ) + self._action_space = self._envs.action_space + self._observation_space = self._envs.observation_space + self._closed = False + self.reset() + + def reset(self) -> None: + self._ready_obs = {} + self._envs.async_reset() + while True: + obs, _, _, info = self._envs.recv() + env_id = info['env_id'] + obs = obs.astype(np.float32) + if self._cfg.image_observation: + obs /= 255.0 + self._ready_obs = deep_merge_dicts({i: o for i, o in zip(env_id, obs)}, self._ready_obs) + if len(self._ready_obs) == self._env_num: + break + self._eval_episode_return = [0. for _ in range(self._env_num)] + + def step(self, action: tnp.array) -> Dict[int, namedtuple]: + env_id = np.array(self.ready_obs_id) + action = np.array(action) + if len(action.shape) == 2: + action = action.squeeze(1) + self._envs.send(action, env_id) + + obs, rew, done, info = self._envs.recv() + obs = obs.astype(np.float32) + if self._cfg.image_observation: + obs /= 255.0 + rew = rew.astype(np.float32) + env_id = info['env_id'] + new_data = [] + + self._ready_obs = {} + for i in range(len(env_id)): + d = bool(done[i]) + r = to_ndarray([rew[i]]) + self._eval_episode_return[env_id[i]] += r + + if d: + new_data.append( + tnp.array( + { + 'obs': obs[i], + 'reward': r, + 'done': d, + 'info': { + 'env_id': i, + 'eval_episode_return': self._eval_episode_return[env_id[i]] + }, + 'env_id': i + } + ) + ) + self._eval_episode_return[env_id[i]] = 0. + else: + new_data.append(tnp.array({'obs': obs[i], 'reward': r, 'done': d, 'info': {'env_id': i}, 'env_id': i})) + + self._ready_obs[env_id[i]] = obs[i] + + return new_data + + @property + def ready_obs_id(self) -> List[int]: + # In BaseEnvManager, if env_episode_count equals episode_num, this env is done. 
+ return list(self._ready_obs.keys()) + + @property + def ready_obs(self) -> tnp.array: + obs = list(self._ready_obs.values()) + return tnp.stack(obs) + + def close(self) -> None: + if self._closed: + return + # Envpool has no `close` API + self._closed = True + + def seed(self, seed: int, dynamic_seed=False) -> None: + # The i-th environment seed in Envpool will be set with i+seed, so we don't do extra transformation here + self._seed = seed + logging.warning("envpool doesn't support dynamic_seed in different episode") + + @property + def closed(self) -> None: + return self._closed + + @property + def env_num(self) -> int: + return self._env_num + + @property + def observation_space(self) -> 'gym.spaces.Space': # noqa + try: + return self._observation_space + except AttributeError: + self.launch() + self.close() + return self._observation_space + + @property + def action_space(self) -> 'gym.spaces.Space': # noqa + try: + return self._action_space + except AttributeError: + self.launch() + self.close() + return self._action_space diff --git a/ding/envs/env_manager/tests/test_envpool_env_manager.py b/ding/envs/env_manager/tests/test_envpool_env_manager.py index 9ac7730773..9582b3cfac 100644 --- a/ding/envs/env_manager/tests/test_envpool_env_manager.py +++ b/ding/envs/env_manager/tests/test_envpool_env_manager.py @@ -3,7 +3,7 @@ import numpy as np from easydict import EasyDict -from ..envpool_env_manager import PoolEnvManager +from ding.envs.env_manager.envpool_env_manager import PoolEnvManager env_num_args = [[16, 8], [8, 8]] @@ -30,17 +30,10 @@ def test_naive(self, env_num, batch_size): env_manager = PoolEnvManager(env_manager_cfg) assert env_manager._closed env_manager.launch() - # Test step - start_time = time.time() - for count in range(20): + for count in range(5): env_id = env_manager.ready_obs.keys() action = {i: np.random.randint(4) for i in env_id} timestep = env_manager.step(action) assert len(timestep) == env_manager_cfg.batch_size - print('Count {}'.format(count)) - print([v.info for v in timestep.values()]) - end_time = time.time() - print('total step time: {}'.format(end_time - start_time)) - # Test close env_manager.close() assert env_manager._closed diff --git a/ding/example/dqn_nstep_envpool.py b/ding/example/dqn_nstep_envpool.py new file mode 100644 index 0000000000..c5eb5faef8 --- /dev/null +++ b/ding/example/dqn_nstep_envpool.py @@ -0,0 +1,119 @@ +import datetime +from easydict import EasyDict +from ditk import logging +from ding.model import DQN +from ding.policy import DQNPolicy +from ding.envs.env_manager.envpool_env_manager import PoolEnvManagerV2 +from ding.data import DequeBuffer +from ding.config import compile_config +from ding.framework import task, ding_init +from ding.framework.context import OnlineRLContext +from ding.framework.middleware import interaction_evaluator, data_pusher, \ + eps_greedy_handler, CkptSaver, ContextExchanger, ModelExchanger, online_logger, \ + termination_checker, wandb_online_logger, epoch_timer, StepCollectorAsync, OffPolicyLearner, nstep_reward_enhancer +from ding.utils import set_pkg_seed +from dizoo.atari.config.serial import pong_dqn_envpool_config + + +def main(cfg): + logging.getLogger().setLevel(logging.INFO) + cfg.exp_name = 'Pong-v5-DQN-envpool-standard-' + datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + + collector_env_cfg = EasyDict( + { + 'env_id': cfg.env.env_id, + 'env_num': cfg.env.collector_env_num, + 'batch_size': cfg.env.collector_batch_size, + # env wrappers + 'episodic_life': True, # collector: True + 
'reward_clip': False, # collector: True + 'gray_scale': cfg.env.get('gray_scale', True), + 'stack_num': cfg.env.get('stack_num', 4), + } + ) + cfg.env["collector_env_cfg"] = collector_env_cfg + evaluator_env_cfg = EasyDict( + { + 'env_id': cfg.env.env_id, + 'env_num': cfg.env.evaluator_env_num, + 'batch_size': cfg.env.evaluator_batch_size, + # env wrappers + 'episodic_life': False, # evaluator: False + 'reward_clip': False, # evaluator: False + 'gray_scale': cfg.env.get('gray_scale', True), + 'stack_num': cfg.env.get('stack_num', 4), + } + ) + cfg.env["evaluator_env_cfg"] = evaluator_env_cfg + cfg = compile_config(cfg, PoolEnvManagerV2, DQNPolicy, save_cfg=task.router.node_id == 0) + ding_init(cfg) + with task.start(async_mode=False, ctx=OnlineRLContext()): + collector_env = PoolEnvManagerV2(cfg.env.collector_env_cfg) + evaluator_env = PoolEnvManagerV2(cfg.env.evaluator_env_cfg) + collector_env.seed(cfg.seed) + evaluator_env.seed(cfg.seed) + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = DQN(**cfg.policy.model) + buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) + policy = DQNPolicy(cfg.policy, model=model) + + # Consider the case with multiple processes + if task.router.is_active: + # You can use labels to distinguish between workers with different roles, + # here we use node_id to distinguish. + if task.router.node_id == 0: + task.add_role(task.role.LEARNER) + elif task.router.node_id == 1: + task.add_role(task.role.EVALUATOR) + else: + task.add_role(task.role.COLLECTOR) + + # Sync their context and model between each worker. + task.use(ContextExchanger(skip_n_iter=1)) + task.use(ModelExchanger(model)) + task.use(epoch_timer()) + task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) + task.use(eps_greedy_handler(cfg)) + task.use( + StepCollectorAsync( + cfg, + policy.collect_mode, + collector_env, + random_collect_size=cfg.policy.random_collect_size if hasattr(cfg.policy, 'random_collect_size') else 0, + ) + ) + task.use(nstep_reward_enhancer(cfg)) + task.use(data_pusher(cfg, buffer_)) + task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_)) + task.use(online_logger(train_show_freq=10)) + task.use( + wandb_online_logger( + metric_list=policy._monitor_vars_learn(), + model=policy._model, + exp_config=cfg, + anonymous=True, + project_name=cfg.exp_name, + wandb_sweep=False, + ) + ) + #task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000)) + task.use(termination_checker(max_env_step=10000000)) + task.run() + + +if __name__ == "__main__": + + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--seed", type=int, default=0, help="random seed") + parser.add_argument("--collector_env_num", type=int, default=8, help="collector env number") + parser.add_argument("--collector_batch_size", type=int, default=8, help="collector batch size") + arg = parser.parse_args() + + pong_dqn_envpool_config.env.collector_env_num = arg.collector_env_num + pong_dqn_envpool_config.env.collector_batch_size = arg.collector_batch_size + pong_dqn_envpool_config.seed = arg.seed + pong_dqn_envpool_config.policy.random_collect_size = 256 + + main(pong_dqn_envpool_config) diff --git a/ding/framework/context.py b/ding/framework/context.py index 6fb35eec13..95f82b1649 100644 --- a/ding/framework/context.py +++ b/ding/framework/context.py @@ -70,11 +70,18 @@ class OnlineRLContext(Context): # wandb wandb_url: str = "" + # timer + total_time: float = 0.0 + evaluator_time: float = 0.0 + collector_time: float = 0.0 + learner_time: float = 
0.0 + def __post_init__(self): # This method is called just after __init__ method. Here, concretely speaking, # this method is called just after the object initialize its fields. # We use this method here to keep the fields needed for each iteration. self.keep('env_step', 'env_episode', 'train_iter', 'last_eval_iter', 'last_eval_value', 'wandb_url') + self.keep('total_time', 'evaluator_time', 'collector_time', 'learner_time') @dataclasses.dataclass diff --git a/ding/framework/middleware/__init__.py b/ding/framework/middleware/__init__.py index b9e3c5005d..4c713e81ae 100644 --- a/ding/framework/middleware/__init__.py +++ b/ding/framework/middleware/__init__.py @@ -1,5 +1,5 @@ from .functional import * -from .collector import StepCollector, EpisodeCollector, PPOFStepCollector +from .collector import StepCollector, StepCollectorAsync, EpisodeCollector, PPOFStepCollector from .learner import OffPolicyLearner, HERLearner from .ckpt_handler import CkptSaver from .distributer import ContextExchanger, ModelExchanger, PeriodicalModelExchanger diff --git a/ding/framework/middleware/collector.py b/ding/framework/middleware/collector.py index beb4894ad9..8734560f9e 100644 --- a/ding/framework/middleware/collector.py +++ b/ding/framework/middleware/collector.py @@ -5,11 +5,14 @@ from ding.policy import get_random_policy from ding.envs import BaseEnvManager from ding.framework import task -from .functional import inferencer, rolloutor, TransitionList +from .functional import inferencer, inferencer_async, rolloutor, rolloutor_async, TransitionList if TYPE_CHECKING: from ding.framework import OnlineRLContext +from ding.worker.collector.base_serial_collector import CachePool +import time + class StepCollector: """ @@ -68,6 +71,76 @@ def __call__(self, ctx: "OnlineRLContext") -> None: break +class StepCollectorAsync: + """ + Overview: + The class of the collector running by steps, including model inference and transition \ + process. Use the `__call__` method to execute the whole collection process. + """ + + def __new__(cls, *args, **kwargs): + if task.router.is_active and not task.has_role(task.role.COLLECTOR): + return task.void() + return super(StepCollectorAsync, cls).__new__(cls) + + def __init__(self, cfg: EasyDict, policy, env: BaseEnvManager, random_collect_size: int = 0) -> None: + """ + Arguments: + - cfg (:obj:`EasyDict`): Config. + - policy (:obj:`Policy`): The policy to be collected. + - env (:obj:`BaseEnvManager`): The env for the collection, the BaseEnvManager object or \ + its derivatives are supported. + - random_collect_size (:obj:`int`): The count of samples that will be collected randomly, \ + typically used in initial runs. + """ + self.cfg = cfg + self.env = env + self.policy = policy + self.random_collect_size = random_collect_size + self._transitions = TransitionList(self.env.env_num) + + self._obs_pool = CachePool('obs', self.env.env_num, deepcopy=True) + self._policy_output_pool = CachePool('policy_output', self.env.env_num) + + self._inferencer = task.wrap(inferencer_async(cfg.seed, policy, env, self._obs_pool, self._policy_output_pool)) + self._rolloutor = task.wrap( + rolloutor_async(policy, env, self._transitions, self._obs_pool, self._policy_output_pool) + ) + + def __call__(self, ctx: "OnlineRLContext") -> None: + """ + Overview: + An encapsulation of inference and rollout middleware. Stop when completing \ + the target number of steps. + Input of ctx: + - env_step (:obj:`int`): The env steps which will increase during collection. 
+ """ + + start_time = time.time() + + old = ctx.env_step + if self.random_collect_size > 0 and old < self.random_collect_size: + target_size = self.random_collect_size - old + random_policy = get_random_policy(self.cfg, self.policy, self.env) + current_inferencer = task.wrap( + inferencer_async(self.cfg.seed, random_policy, self.env, self._obs_pool, self._policy_output_pool) + ) + else: + # compatible with old config, a train sample = unroll_len step + target_size = self.cfg.policy.collect.n_sample * self.cfg.policy.collect.unroll_len + current_inferencer = self._inferencer + + while True: + current_inferencer(ctx) + self._rolloutor(ctx) + if ctx.env_step - old >= target_size: + ctx.trajectories, ctx.trajectory_end_idx = self._transitions.to_trajectories() + self._transitions.clear() + break + + ctx.collector_time += time.time() - start_time + + class PPOFStepCollector: """ Overview: diff --git a/ding/framework/middleware/functional/__init__.py b/ding/framework/middleware/functional/__init__.py index 8474f2626e..2355739a72 100644 --- a/ding/framework/middleware/functional/__init__.py +++ b/ding/framework/middleware/functional/__init__.py @@ -1,7 +1,7 @@ from .trainer import trainer, multistep_trainer from .data_processor import offpolicy_data_fetcher, data_pusher, offline_data_fetcher, offline_data_saver, \ offline_data_fetcher_from_mem, sqil_data_pusher, buffer_saver -from .collector import inferencer, rolloutor, TransitionList +from .collector import inferencer, inferencer_async, rolloutor, rolloutor_async, TransitionList from .evaluator import interaction_evaluator, interaction_evaluator_ttorch from .termination_checker import termination_checker, ddp_termination_checker from .logger import online_logger, offline_logger, wandb_online_logger, wandb_offline_logger diff --git a/ding/framework/middleware/functional/collector.py b/ding/framework/middleware/functional/collector.py index d2fb4483b9..22636e1c47 100644 --- a/ding/framework/middleware/functional/collector.py +++ b/ding/framework/middleware/functional/collector.py @@ -86,6 +86,58 @@ def _inference(ctx: "OnlineRLContext"): return _inference +def inferencer_async( + seed: int, + policy: Policy, + env: BaseEnvManager, + obs_pool, + policy_output_pool, +) -> Callable: + """ + Overview: + The middleware that executes the inference process. + Arguments: + - seed (:obj:`int`): Random seed. + - policy (:obj:`Policy`): The policy to be inferred. + - env (:obj:`BaseEnvManager`): The env where the inference process is performed. \ + The env.ready_obs (:obj:`tnp.array`) will be used as model input. + """ + + env.seed(seed) + + def _inference(ctx: "OnlineRLContext"): + """ + Output of ctx: + - obs (:obj:`Union[torch.Tensor, Dict[torch.Tensor]]`): The input observations collected \ + from all collector environments. + - action: (:obj:`List[np.ndarray]`): The inferred actions listed by env_id. + - inference_output (:obj:`Dict[int, Dict]`): The dict of which the key is env_id (int), \ + and the value is inference result (Dict). 
+ """ + + if env.closed: + env.launch() + + ready_obs = env.ready_obs + obs_pool.update(env._ready_obs) + inference_output = policy.forward(env._ready_obs, **ctx.collect_kwargs) + + # obs_pool.update(env._ready_obs) + # obs = ttorch.as_tensor(env.ready_obs) + # ctx.obs = obs + # obs = obs.to(dtype=ttorch.float32) + # # TODO mask necessary rollout + + # obs = {i: obs[i] for i in range(get_shape0(obs))} # TBD + # inference_output = policy.forward(obs, **ctx.collect_kwargs) + + policy_output_pool.update(inference_output) + ctx.action = [to_ndarray(v['action']) for v in inference_output.values()] # TBD + ctx.inference_output = inference_output + + return _inference + + def rolloutor( policy: Policy, env: BaseEnvManager, @@ -178,6 +230,102 @@ def _rollout(ctx: "OnlineRLContext"): return _rollout +def rolloutor_async( + policy: Policy, + env: BaseEnvManager, + transitions: TransitionList, + obs_pool, + policy_output_pool, + collect_print_freq=100, +) -> Callable: + """ + Overview: + The middleware that executes the transition process in the env. + Arguments: + - policy (:obj:`Policy`): The policy to be used during transition. + - env (:obj:`BaseEnvManager`): The env for the collection, the BaseEnvManager object or \ + its derivatives are supported. + - transitions (:obj:`TransitionList`): The transition information which will be filled \ + in this process, including `obs`, `next_obs`, `action`, `logit`, `value`, `reward` \ + and `done`. + """ + + env_episode_id = [_ for _ in range(env.env_num)] + current_id = env.env_num + timer = EasyTimer() + last_train_iter = 0 + total_envstep_count = 0 + total_episode_count = 0 + total_train_sample_count = 0 + env_info = {env_id: {'time': 0., 'step': 0, 'train_sample': 0} for env_id in range(env.env_num)} + episode_info = [] + + def _rollout(ctx: "OnlineRLContext"): + """ + Input of ctx: + - action: (:obj:`List[np.ndarray]`): The inferred actions from previous inference process. + - obs (:obj:`Dict[Tensor]`): The states fed into the transition dict. + - inference_output (:obj:`Dict[int, Dict]`): The inference results to be fed into the \ + transition dict. + - train_iter (:obj:`int`): The train iteration count to be fed into the transition dict. + - env_step (:obj:`int`): The count of env step, which will increase by 1 for a single \ + transition call. + - env_episode (:obj:`int`): The count of env episode, which will increase by 1 if the \ + trajectory stops. 
+ """ + + nonlocal current_id, env_info, episode_info, timer, \ + total_episode_count, total_envstep_count, total_train_sample_count, last_train_iter + timesteps = env.step(ctx.action) + ctx.env_step += len(timesteps) + timesteps = [t.tensor() for t in timesteps] + + collected_sample = 0 + collected_step = 0 + collected_episode = 0 + interaction_duration = timer.value / len(timesteps) + for i, timestep in enumerate(timesteps): + with timer: + transition = policy.process_transition( + obs_pool[timestep.info.env_id.item()], policy_output_pool[timestep.info.env_id.item()], timestep + ) + transition = ttorch.as_tensor(transition) + transition.collect_train_iter = ttorch.as_tensor([ctx.train_iter]) + transition.env_data_id = ttorch.as_tensor([env_episode_id[timestep.info.env_id.item()]]) + transitions.append(timestep.info.env_id.item(), transition) + + collected_step += 1 + collected_sample += len(transition.obs) + env_info[timestep.info.env_id.item()]['step'] += 1 + env_info[timestep.info.env_id.item()]['train_sample'] += len(transition.obs) + + env_info[timestep.info.env_id.item()]['time'] += timer.value + interaction_duration + if timestep.done: + info = { + 'reward': timestep.info['eval_episode_return'], + 'time': env_info[timestep.info.env_id.item()]['time'], + 'step': env_info[timestep.info.env_id.item()]['step'], + 'train_sample': env_info[timestep.info.env_id.item()]['train_sample'], + } + + episode_info.append(info) + policy.reset([timestep.env_id.item()]) + env_episode_id[timestep.env_id.item()] = current_id + collected_episode += 1 + current_id += 1 + ctx.env_episode += 1 + + total_envstep_count += collected_step + total_episode_count += collected_episode + total_train_sample_count += collected_sample + + if (ctx.train_iter - last_train_iter) >= collect_print_freq and len(episode_info) > 0: + output_log(episode_info, total_episode_count, total_envstep_count, total_train_sample_count) + last_train_iter = ctx.train_iter + + return _rollout + + def output_log(episode_info, total_episode_count, total_envstep_count, total_train_sample_count) -> None: """ Overview: diff --git a/ding/framework/middleware/functional/data_processor.py b/ding/framework/middleware/functional/data_processor.py index ab1f1a5544..eae264129f 100644 --- a/ding/framework/middleware/functional/data_processor.py +++ b/ding/framework/middleware/functional/data_processor.py @@ -1,4 +1,5 @@ import os +import torch.multiprocessing as mp from typing import TYPE_CHECKING, Callable, List, Union, Tuple, Dict, Optional from easydict import EasyDict from ditk import logging @@ -11,6 +12,8 @@ if TYPE_CHECKING: from ding.framework import OnlineRLContext, OfflineRLContext +import time + def data_pusher(cfg: EasyDict, buffer_: Buffer, group_by_env: Optional[bool] = None): """ @@ -31,7 +34,6 @@ def _push(ctx: "OnlineRLContext"): - trajectories (:obj:`List[Dict]`): Trajectories. - episodes (:obj:`List[Dict]`): Episodes. 
""" - if ctx.trajectories is not None: # each data in buffer is a transition if group_by_env: for i, t in enumerate(ctx.trajectories): @@ -170,22 +172,20 @@ def _fetch(ctx: "OnlineRLContext"): index = [d.index for d in buffered_data] meta = [d.meta for d in buffered_data] # such as priority - if isinstance(ctx.train_output, List): - priority = ctx.train_output.pop()['priority'] + if isinstance(ctx.train_output_for_post_process, List): + priority = ctx.train_output_for_post_process.pop()['priority'] else: - priority = ctx.train_output['priority'] + priority = ctx.train_output_for_post_process['priority'] for idx, m, p in zip(index, meta, priority): m['priority'] = p buffer_.update(index=idx, data=None, meta=m) return _fetch - def offline_data_fetcher_from_mem(cfg: EasyDict, dataset: Dataset) -> Callable: from threading import Thread from queue import Queue - import time stream = torch.cuda.Stream() def producer(queue, dataset, batch_size, device): diff --git a/ding/framework/middleware/functional/evaluator.py b/ding/framework/middleware/functional/evaluator.py index 611bbcdea6..f820dcb294 100644 --- a/ding/framework/middleware/functional/evaluator.py +++ b/ding/framework/middleware/functional/evaluator.py @@ -15,6 +15,8 @@ from ding.torch_utils import to_ndarray, get_shape0 from ding.utils import lists_to_dicts +import time + class IMetric(ABC): @@ -237,6 +239,8 @@ def _evaluate(ctx: Union["OnlineRLContext", "OfflineRLContext"]): - eval_value (:obj:`float`): The average reward in the current evaluation. """ + start_time = time.time() + # evaluation will be executed if the task begins or enough train_iter after last evaluation if ctx.last_eval_iter != -1 and \ (ctx.train_iter - ctx.last_eval_iter < cfg.policy.eval.evaluator.eval_freq): @@ -263,8 +267,8 @@ def _evaluate(ctx: Union["OnlineRLContext", "OfflineRLContext"]): env_id = timestep.env_id.item() if timestep.done: policy.reset([env_id]) - reward = timestep.info.eval_episode_return - eval_monitor.update_reward(env_id, reward) + return_ = timestep.info.eval_episode_return + eval_monitor.update_reward(env_id, return_) if 'episode_info' in timestep.info: eval_monitor.update_info(env_id, timestep.info.episode_info) episode_return = eval_monitor.get_episode_return() @@ -302,6 +306,8 @@ def _evaluate(ctx: Union["OnlineRLContext", "OfflineRLContext"]): if stop_flag: task.finish = True + ctx.evaluator_time += time.time() - start_time + return _evaluate diff --git a/ding/framework/middleware/functional/logger.py b/ding/framework/middleware/functional/logger.py index 9f62e2f429..419b3f53a4 100644 --- a/ding/framework/middleware/functional/logger.py +++ b/ding/framework/middleware/functional/logger.py @@ -302,6 +302,19 @@ def _plot(ctx: "OnlineRLContext"): "If you want to use wandb to visualize the result, please set plot_logger = True in the config." 
) + if hasattr(ctx, "evaluator_time"): + info_for_logging.update({"evaluator_time": ctx.evaluator_time}) + if hasattr(ctx, "collector_time"): + info_for_logging.update({"collector_time": ctx.collector_time}) + if hasattr(ctx, "learner_time"): + info_for_logging.update({"learner_time": ctx.learner_time}) + if hasattr(ctx, "data_pusher_time"): + info_for_logging.update({"data_pusher_time": ctx.data_pusher_time}) + if hasattr(ctx, "nstep_time"): + info_for_logging.update({"nstep_time": ctx.nstep_time}) + if hasattr(ctx, "total_time"): + info_for_logging.update({"total_time": ctx.total_time}) + if ctx.eval_value != -np.inf: if hasattr(ctx, "eval_value_min"): info_for_logging.update({ diff --git a/ding/framework/middleware/functional/timer.py b/ding/framework/middleware/functional/timer.py index db8a2c0056..7c73b9b809 100644 --- a/ding/framework/middleware/functional/timer.py +++ b/ding/framework/middleware/functional/timer.py @@ -31,5 +31,6 @@ def _epoch_timer(ctx: "Context"): np.mean(records) * 1000 ) ) + ctx.total_time += time_cost return _epoch_timer diff --git a/ding/framework/middleware/learner.py b/ding/framework/middleware/learner.py index 9abf88e9b3..ea68e7d020 100644 --- a/ding/framework/middleware/learner.py +++ b/ding/framework/middleware/learner.py @@ -11,6 +11,8 @@ from ding.policy import Policy from ding.reward_model import BaseRewardModel +import time + class OffPolicyLearner: """ @@ -54,6 +56,9 @@ def __call__(self, ctx: "OnlineRLContext") -> None: Output of ctx: - train_output (:obj:`Deque`): The training output in deque. """ + + start_time = time.time() + train_output_queue = [] for _ in range(self.cfg.policy.learn.update_per_collect): self._fetcher(ctx) @@ -63,8 +68,10 @@ def __call__(self, ctx: "OnlineRLContext") -> None: self._reward_estimator(ctx) self._trainer(ctx) train_output_queue.append(ctx.train_output) + ctx.train_output_for_post_process = ctx.train_output ctx.train_output = train_output_queue + ctx.learner_time += time.time() - start_time class HERLearner: """ diff --git a/ding/framework/middleware/tests/test_distributer.py b/ding/framework/middleware/tests/test_distributer.py index 7651e66ec7..942bbc7621 100644 --- a/ding/framework/middleware/tests/test_distributer.py +++ b/ding/framework/middleware/tests/test_distributer.py @@ -246,18 +246,15 @@ def train(ctx): task.use(train) else: y_pred1 = policy.predict(X) - print("y_pred1: ", y_pred1) stale = 1 def pred(ctx): nonlocal stale y_pred2 = policy.predict(X) - print("y_pred2: ", y_pred2) stale += 1 assert stale <= 3 or all(y_pred1 == y_pred2) if any(y_pred1 != y_pred2): stale = 1 - sleep(0.3) task.use(pred) diff --git a/ding/model/common/utils.py b/ding/model/common/utils.py index f74a179962..fe30a1efe7 100644 --- a/ding/model/common/utils.py +++ b/ding/model/common/utils.py @@ -1,5 +1,6 @@ import copy import torch +import torch.nn as nn from easydict import EasyDict from ding.utils import import_module, MODEL_REGISTRY diff --git a/ding/policy/common_utils.py b/ding/policy/common_utils.py index fd2c7d3d61..e80b57127a 100644 --- a/ding/policy/common_utils.py +++ b/ding/policy/common_utils.py @@ -1,6 +1,6 @@ from typing import List, Any, Dict, Callable -import torch import numpy as np +import torch import treetensor.torch as ttorch from ding.utils.data import default_collate from ding.torch_utils import to_tensor, to_ndarray, unsqueeze, squeeze diff --git a/ding/policy/dqn.py b/ding/policy/dqn.py index d1f6fdbb49..f840f93352 100644 --- a/ding/policy/dqn.py +++ b/ding/policy/dqn.py @@ -199,14 +199,16 @@ def 
_init_learn(self) -> None: # use model_wrapper for specialized demands of different modes self._target_model = copy.deepcopy(self._model) - if 'target_update_freq' in self._cfg.learn: + if 'target_update_freq' in self._cfg.learn and self._cfg.learn.target_update_freq is not None \ + and self._cfg.learn.target_update_freq > 0: self._target_model = model_wrap( self._target_model, wrapper_name='target', update_type='assign', update_kwargs={'freq': self._cfg.learn.target_update_freq} ) - elif 'target_theta' in self._cfg.learn: + elif 'target_theta' in self._cfg.learn and self._cfg.learn.target_theta is not None \ + and self._cfg.learn.target_theta > 0.0: self._target_model = model_wrap( self._target_model, wrapper_name='target', @@ -248,6 +250,7 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: .. note:: For more detailed examples, please refer to our unittest for DQNPolicy: ``ding.policy.tests.test_dqn``. """ + # Data preprocessing operations, such as stack data, cpu to cuda device data = default_preprocess_learn( data, @@ -256,6 +259,7 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: ignore_done=self._cfg.learn.ignore_done, use_nstep=True ) + if self._cuda: data = to_device(data, self._device) # Q-learning forward @@ -284,6 +288,7 @@ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]: # Postprocessing operations, such as updating target model, return logged values and priority. self._target_model.update(self._learn_model.state_dict()) + return { 'cur_lr': self._optimizer.defaults['lr'], 'total_loss': loss.item(), @@ -484,13 +489,15 @@ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]: output = default_decollate(output) return {i: d for i, d in zip(data_id, output)} + def monitor_vars(self) -> List[str]: + return ['cur_lr', 'total_loss', 'q_value'] + def calculate_priority(self, data: Dict[int, Any], update_target_model: bool = False) -> Dict[str, Any]: """ Overview: Calculate priority for replay buffer. Arguments: - data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training. - - update_target_model (:obj:`bool`): Whether to update target model. Returns: - priority (:obj:`Dict[str, Any]`): Dict type priority data, values are python scalar or a list of scalars. 
ArgumentsKeys: @@ -533,7 +540,6 @@ def calculate_priority(self, data: Dict[int, Any], update_target_model: bool = F ) return {'priority': td_error_per_sample.abs().tolist()} - @POLICY_REGISTRY.register('dqn_stdim') class DQNSTDIMPolicy(DQNPolicy): """ diff --git a/ding/utils/default_helper.py b/ding/utils/default_helper.py index d76b6936f3..396a89e191 100644 --- a/ding/utils/default_helper.py +++ b/ding/utils/default_helper.py @@ -8,7 +8,7 @@ import treetensor.torch as ttorch -def get_shape0(data: Union[List, Dict, torch.Tensor, ttorch.Tensor]) -> int: +def get_shape0(data: Union[List, Dict, np.ndarray, torch.Tensor, ttorch.Tensor]) -> int: """ Overview: Get shape[0] of data's torch tensor or treetensor @@ -34,6 +34,8 @@ def fn(t): return fn(item) return fn(data.shape) + elif isinstance(data, np.ndarray): + return data.shape[0] else: raise TypeError("Error in getting shape0, not support type: {}".format(data)) diff --git a/ding/worker/collector/sample_serial_collector.py b/ding/worker/collector/sample_serial_collector.py index 26db458edb..07bac75ae5 100644 --- a/ding/worker/collector/sample_serial_collector.py +++ b/ding/worker/collector/sample_serial_collector.py @@ -25,7 +25,7 @@ class SampleSerialCollector(ISerialCollector): envstep """ - config = dict(deepcopy_obs=False, transform_obs=False, collect_print_freq=100) + config = dict(type='sample', deepcopy_obs=False, transform_obs=False, collect_print_freq=100) def __init__( self, @@ -34,7 +34,8 @@ def __init__( policy: namedtuple = None, tb_logger: 'SummaryWriter' = None, # noqa exp_name: Optional[str] = 'default_experiment', - instance_name: Optional[str] = 'collector' + instance_name: Optional[str] = 'collector', + timer_cuda: bool = False, ) -> None: """ Overview: @@ -44,6 +45,10 @@ def __init__( - env (:obj:`BaseEnvManager`): the subclass of vectorized env_manager(BaseEnvManager) - policy (:obj:`namedtuple`): the api namedtuple of collect_mode policy - tb_logger (:obj:`SummaryWriter`): tensorboard handle + - exp_name (:obj:`Optional[str]`): name of the project folder of this experiment + - instance_name (:obj:`Optional[str]`): instance name, used to specify the saving path of log and model + - timer_cuda (:obj:`bool`): whether to use cuda timer, if True, the timer will measure the time of \ + the forward process on cuda, otherwise, the timer will measure the time of the forward process on cpu. """ self._exp_name = exp_name self._instance_name = instance_name @@ -51,7 +56,7 @@ def __init__( self._deepcopy_obs = cfg.deepcopy_obs # whether to deepcopy each data self._transform_obs = cfg.transform_obs self._cfg = cfg - self._timer = EasyTimer() + self._timer = EasyTimer(cuda=timer_cuda) self._end_flag = False self._rank = get_rank() self._world_size = get_world_size() diff --git a/dizoo/atari/config/serial/pong/pong_dqn_envpool_config.py b/dizoo/atari/config/serial/pong/pong_dqn_envpool_config.py index 0b80e41548..8a9f9c1721 100644 --- a/dizoo/atari/config/serial/pong/pong_dqn_envpool_config.py +++ b/dizoo/atari/config/serial/pong/pong_dqn_envpool_config.py @@ -8,14 +8,16 @@ evaluator_env_num=8, evaluator_batch_size=8, n_evaluator_episode=8, - stop_value=20, - env_id='PongNoFrameskip-v4', + stop_value=21, + env_id='Pong-v5', #'ALE/Pong-v5' is available. But special setting is needed after gym make. 
frame_stack=4, ), + nstep = 3, policy=dict( cuda=True, priority=False, + random_collect_size=50000, model=dict( obs_shape=[4, 84, 84], action_shape=6, @@ -24,10 +26,15 @@ nstep=3, discount_factor=0.99, learn=dict( - update_per_collect=10, + update_per_collect=2, batch_size=32, learning_rate=0.0001, - target_update_freq=500, + # If updating target network by replacement, \ + # target_update_freq should be larger than 0. \ + # If updating target network by changing several percentage of the origin weights, \ + # target_update_freq should be 0 and target_theta should be set. + target_update_freq=None, + target_theta=0.04, ), collect=dict(n_sample=96, ), eval=dict(evaluator=dict(eval_freq=4000, )), @@ -49,7 +56,7 @@ type='atari', import_names=['dizoo.atari.envs.atari_env'], ), - env_manager=dict(type='env_pool'), + env_manager=dict(type='envpool'), policy=dict(type='dqn'), replay_buffer=dict(type='deque'), ) diff --git a/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_dqn_envpool_config.py b/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_dqn_envpool_config.py new file mode 100644 index 0000000000..da56810f0c --- /dev/null +++ b/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_dqn_envpool_config.py @@ -0,0 +1,63 @@ +from easydict import EasyDict + +spaceinvaders_dqn_envpool_config = dict( + exp_name='spaceinvaders_dqn_envpool_seed0', + env=dict( + collector_env_num=8, + collector_batch_size=8, + evaluator_env_num=8, + evaluator_batch_size=8, + n_evaluator_episode=8, + stop_value=10000000000, + env_id='SpaceInvaders-v5', + #'ALE/SpaceInvaders-v5' is available. But special setting is needed after gym make. + frame_stack=4, + ), + policy=dict( + cuda=True, + priority=False, + random_collect_size=5000, + model=dict( + obs_shape=[4, 84, 84], + action_shape=6, + encoder_hidden_size_list=[128, 128, 512], + ), + nstep=3, + discount_factor=0.99, + learn=dict( + update_per_collect=10, + batch_size=32, + learning_rate=0.0001, + target_update_freq=500, + ), + collect=dict(n_sample=100, ), + eval=dict(evaluator=dict(eval_freq=4000, )), + other=dict( + eps=dict( + type='exp', + start=1., + end=0.05, + decay=1000000, + ), + replay_buffer=dict(replay_buffer_size=400000, ), + ), + ), +) +spaceinvaders_dqn_envpool_config = EasyDict(spaceinvaders_dqn_envpool_config) +main_config = spaceinvaders_dqn_envpool_config +spaceinvaders_dqn_envpool_create_config = dict( + env=dict( + type='atari', + import_names=['dizoo.atari.envs.atari_env'], + ), + env_manager=dict(type='env_pool'), + policy=dict(type='dqn'), + replay_buffer=dict(type='deque'), +) +spaceinvaders_dqn_envpool_create_config = EasyDict(spaceinvaders_dqn_envpool_create_config) +create_config = spaceinvaders_dqn_envpool_create_config + +if __name__ == '__main__': + # or you can enter `ding -m serial -c spaceinvaders_dqn_envpool_config.py -s 0` + from ding.entry import serial_pipeline + serial_pipeline((main_config, create_config), seed=0) From 7e4f3f3073289c82bbd54124935eb564ac633763 Mon Sep 17 00:00:00 2001 From: zjowowen Date: Thu, 21 Dec 2023 22:13:40 +0800 Subject: [PATCH 2/2] polish code --- ding/envs/env_manager/envpool_env_manager.py | 1 + ding/framework/middleware/functional/data_processor.py | 1 + ding/framework/middleware/learner.py | 1 + ding/policy/dqn.py | 1 + 4 files changed, 4 insertions(+) diff --git a/ding/envs/env_manager/envpool_env_manager.py b/ding/envs/env_manager/envpool_env_manager.py index 25618bae2e..e4728ac1d7 100644 --- a/ding/envs/env_manager/envpool_env_manager.py +++ 
b/ding/envs/env_manager/envpool_env_manager.py @@ -177,6 +177,7 @@ def action_space(self) -> 'gym.spaces.Space': # noqa self.close() return self._action_space + @ENV_MANAGER_REGISTRY.register('envpool_v2') class PoolEnvManagerV2: """ diff --git a/ding/framework/middleware/functional/data_processor.py b/ding/framework/middleware/functional/data_processor.py index 482dd6d9a9..7d4f675ea6 100644 --- a/ding/framework/middleware/functional/data_processor.py +++ b/ding/framework/middleware/functional/data_processor.py @@ -182,6 +182,7 @@ def _fetch(ctx: "OnlineRLContext"): return _fetch + def offline_data_fetcher_from_mem(cfg: EasyDict, dataset: Dataset) -> Callable: from threading import Thread diff --git a/ding/framework/middleware/learner.py b/ding/framework/middleware/learner.py index ea68e7d020..5cdc855814 100644 --- a/ding/framework/middleware/learner.py +++ b/ding/framework/middleware/learner.py @@ -73,6 +73,7 @@ def __call__(self, ctx: "OnlineRLContext") -> None: ctx.learner_time += time.time() - start_time + class HERLearner: """ Overview: diff --git a/ding/policy/dqn.py b/ding/policy/dqn.py index f840f93352..768de9884f 100644 --- a/ding/policy/dqn.py +++ b/ding/policy/dqn.py @@ -540,6 +540,7 @@ def calculate_priority(self, data: Dict[int, Any], update_target_model: bool = F ) return {'priority': td_error_per_sample.abs().tolist()} + @POLICY_REGISTRY.register('dqn_stdim') class DQNSTDIMPolicy(DQNPolicy): """
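
Usage note (PoolEnvManager, old pipeline): the registry key is renamed from 'env_pool' to 'envpool', so create-configs should declare env_manager=dict(type='envpool'), as the updated Pong config does. The sketch below mirrors the updated unit test in ding/envs/env_manager/tests/test_envpool_env_manager.py; it is a minimal illustration only, and the 'Pong-v5' env id plus the sync setting (env_num == batch_size) are assumptions, not fixed by this patch.

    import numpy as np
    from easydict import EasyDict
    from ding.envs.env_manager.envpool_env_manager import PoolEnvManager

    # Keys omitted here (frame_skip, stack_num, gray_scale, ...) fall back to
    # PoolEnvManager.default_config(); env_id/env_num/batch_size must be given.
    cfg = EasyDict(dict(env_id='Pong-v5', env_num=8, batch_size=8))
    manager = PoolEnvManager(cfg)
    manager.launch()  # builds the envpool instance and resets every sub-env
    for _ in range(5):
        ready_ids = manager.ready_obs.keys()  # ready_obs is a dict keyed by env id
        action = {i: np.random.randint(manager.action_space.n) for i in ready_ids}
        timesteps = manager.step(action)      # dict of BaseEnvTimestep, len == batch_size
    manager.close()
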
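Usage note (PoolEnvManagerV2, new pipeline): the V2 manager exposes the batched interface that inferencer_async and rolloutor_async consume — ready_obs is a stacked treetensor array, ready_obs_id lists the matching env ids, and step takes a flat action array and returns a list of treetensor transitions. A minimal interaction loop, again assuming 'Pong-v5' and sync mode (env_num == batch_size), might look like the sketch below; in training this loop is driven by StepCollectorAsync, as in ding/example/dqn_nstep_envpool.py.

    import numpy as np
    from easydict import EasyDict
    from ding.envs.env_manager.envpool_env_manager import PoolEnvManagerV2

    cfg = EasyDict(dict(env_id='Pong-v5', env_num=8, batch_size=8))
    manager = PoolEnvManagerV2(cfg)
    manager.launch()
    for _ in range(5):
        obs = manager.ready_obs            # tnp.array stacked over ready envs, e.g. (8, 4, 84, 84) here
        env_ids = manager.ready_obs_id     # env ids aligned with the rows of obs
        action = np.random.randint(manager.action_space.n, size=len(env_ids))
        timesteps = manager.step(action)   # list of tnp.array transitions: obs/reward/done/info
    manager.close()

The example script also relies on the new timing fields (total_time, evaluator_time, collector_time, learner_time) kept on OnlineRLContext: epoch_timer, StepCollectorAsync, OffPolicyLearner and the evaluator accumulate them, and wandb_online_logger picks them up via hasattr checks.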