From e9d8194ba6669c367b52c51df0b369adc8ad493a Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Wed, 23 Aug 2023 15:18:50 +0800 Subject: [PATCH] config(nyz): add lunarlander ppo config and example --- ding/bonus/config.py | 3 +- ding/example/ppo_lunarlander.py | 45 +++++++++++++++++ .../config/lunarlander_ppo_config.py | 50 +++++++++++++++++++ 3 files changed, 97 insertions(+), 1 deletion(-) create mode 100644 ding/example/ppo_lunarlander.py create mode 100644 dizoo/box2d/lunarlander/config/lunarlander_ppo_config.py diff --git a/ding/bonus/config.py b/ding/bonus/config.py index 041a37e653..94e14750ce 100644 --- a/ding/bonus/config.py +++ b/ding/bonus/config.py @@ -11,7 +11,8 @@ def get_instance_config(env: str, algorithm: str) -> EasyDict: if algorithm == 'PPO': cfg = PPOFPolicy.default_config() if env == 'lunarlander_discrete': - cfg.n_sample = 400 + cfg.n_sample = 512 + cfg.entropy_weight = 1e-3 elif env == 'lunarlander_continuous': cfg.action_space = 'continuous' cfg.n_sample = 400 diff --git a/ding/example/ppo_lunarlander.py b/ding/example/ppo_lunarlander.py new file mode 100644 index 0000000000..b2e60fe7d6 --- /dev/null +++ b/ding/example/ppo_lunarlander.py @@ -0,0 +1,45 @@ +import gym +from ditk import logging +from ding.model import VAC +from ding.policy import PPOPolicy +from ding.envs import DingEnvWrapper, BaseEnvManagerV2 +from ding.data import DequeBuffer +from ding.config import compile_config +from ding.framework import task, ding_init +from ding.framework.context import OnlineRLContext +from ding.framework.middleware import multistep_trainer, StepCollector, interaction_evaluator, CkptSaver, \ + gae_estimator, online_logger +from ding.utils import set_pkg_seed +from dizoo.box2d.lunarlander.config.lunarlander_ppo_config import main_config, create_config + + +def main(): + logging.getLogger().setLevel(logging.INFO) + cfg = compile_config(main_config, create_cfg=create_config, auto=True) + ding_init(cfg) + with task.start(async_mode=False, ctx=OnlineRLContext()): + collector_env = BaseEnvManagerV2( + env_fn=[lambda: DingEnvWrapper(gym.make("LunarLander-v2")) for _ in range(cfg.env.collector_env_num)], + cfg=cfg.env.manager + ) + evaluator_env = BaseEnvManagerV2( + env_fn=[lambda: DingEnvWrapper(gym.make("LunarLander-v2")) for _ in range(cfg.env.evaluator_env_num)], + cfg=cfg.env.manager + ) + + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + + model = VAC(**cfg.policy.model) + policy = PPOPolicy(cfg.policy, model=model) + + task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env)) + task.use(StepCollector(cfg, policy.collect_mode, collector_env)) + task.use(gae_estimator(cfg, policy.collect_mode)) + task.use(multistep_trainer(policy.learn_mode, log_freq=50)) + task.use(CkptSaver(policy, cfg.exp_name, train_freq=100)) + task.use(online_logger(train_show_freq=3)) + task.run() + + +if __name__ == "__main__": + main() diff --git a/dizoo/box2d/lunarlander/config/lunarlander_ppo_config.py b/dizoo/box2d/lunarlander/config/lunarlander_ppo_config.py new file mode 100644 index 0000000000..ad622c444d --- /dev/null +++ b/dizoo/box2d/lunarlander/config/lunarlander_ppo_config.py @@ -0,0 +1,50 @@ +from easydict import EasyDict + +lunarlander_ppo_config = dict( + exp_name='lunarlander_ppo_seed0', + env=dict( + collector_env_num=8, + evaluator_env_num=5, + env_id='LunarLander-v2', + n_evaluator_episode=5, + stop_value=200, + ), + policy=dict( + recompute_adv=True, + cuda=True, + action_space='discrete', + model=dict( + obs_shape=8, + action_shape=4, + action_space='discrete', + ), + learn=dict( + epoch_per_collect=10, + batch_size=64, + learning_rate=3e-4, + entropy_weight=0.01, + adv_norm=True, + value_norm=True, + ), + collect=dict( + n_sample=512, + discount_factor=0.99, + ), + ), +) +lunarlander_ppo_config = EasyDict(lunarlander_ppo_config) +main_config = lunarlander_ppo_config +lunarlander_ppo_create_config = dict( + env=dict( + type='lunarlander', + import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='ppo'), +) +lunarlander_ppo_create_config = EasyDict(lunarlander_ppo_create_config) +create_config = lunarlander_ppo_create_config + +if __name__ == "__main__": + from ding.entry import serial_pipeline_onpolicy + serial_pipeline_onpolicy([main_config, create_config], seed=0)