config(nyz): add lunarlander ppo config and example
PaParaZz1 committed Aug 23, 2023
1 parent 3fa8a01 commit e9d8194
Showing 3 changed files with 97 additions and 1 deletion.
3 changes: 2 additions & 1 deletion ding/bonus/config.py
@@ -11,7 +11,8 @@ def get_instance_config(env: str, algorithm: str) -> EasyDict:
     if algorithm == 'PPO':
         cfg = PPOFPolicy.default_config()
         if env == 'lunarlander_discrete':
-            cfg.n_sample = 400
+            cfg.n_sample = 512
+            cfg.entropy_weight = 1e-3
         elif env == 'lunarlander_continuous':
             cfg.action_space = 'continuous'
             cfg.n_sample = 400
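For context, a minimal sketch (not part of this commit) of how the patched defaults surface through get_instance_config; the keyword names follow the signature shown in the hunk header above, and the printed values assume no further overrides:

from ding.bonus.config import get_instance_config

# Read back the discrete LunarLander PPO preset after this commit.
cfg = get_instance_config(env='lunarlander_discrete', algorithm='PPO')
print(cfg.n_sample)        # 512 (previously 400)
print(cfg.entropy_weight)  # 1e-3, newly added here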
45 changes: 45 additions & 0 deletions ding/example/ppo_lunarlander.py
@@ -0,0 +1,45 @@
import gym
from ditk import logging
from ding.model import VAC
from ding.policy import PPOPolicy
from ding.envs import DingEnvWrapper, BaseEnvManagerV2
from ding.data import DequeBuffer
from ding.config import compile_config
from ding.framework import task, ding_init
from ding.framework.context import OnlineRLContext
from ding.framework.middleware import multistep_trainer, StepCollector, interaction_evaluator, CkptSaver, \
    gae_estimator, online_logger
from ding.utils import set_pkg_seed
from dizoo.box2d.lunarlander.config.lunarlander_ppo_config import main_config, create_config


def main():
    logging.getLogger().setLevel(logging.INFO)
    cfg = compile_config(main_config, create_cfg=create_config, auto=True)
    ding_init(cfg)
    with task.start(async_mode=False, ctx=OnlineRLContext()):
        collector_env = BaseEnvManagerV2(
            env_fn=[lambda: DingEnvWrapper(gym.make("LunarLander-v2")) for _ in range(cfg.env.collector_env_num)],
            cfg=cfg.env.manager
        )
        evaluator_env = BaseEnvManagerV2(
            env_fn=[lambda: DingEnvWrapper(gym.make("LunarLander-v2")) for _ in range(cfg.env.evaluator_env_num)],
            cfg=cfg.env.manager
        )

        set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

        model = VAC(**cfg.policy.model)
        policy = PPOPolicy(cfg.policy, model=model)

        task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
        task.use(StepCollector(cfg, policy.collect_mode, collector_env))
        task.use(gae_estimator(cfg, policy.collect_mode))
        task.use(multistep_trainer(policy.learn_mode, log_freq=50))
        task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
        task.use(online_logger(train_show_freq=3))
        task.run()


if __name__ == "__main__":
    main()
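The script above wires the pipeline from middlewares: interaction_evaluator for periodic evaluation, StepCollector for sample collection, gae_estimator for advantage computation, multistep_trainer for PPO updates, plus checkpointing and logging; it can be launched directly with python ding/example/ppo_lunarlander.py from the repository root. As a standalone illustration of the computation the gae_estimator step performs, here is a minimal Generalized Advantage Estimation sketch (the function name and array layout are assumptions for illustration, not DI-engine internals):

import numpy as np

def gae_advantages(rewards, values, next_values, dones, gamma=0.99, lam=0.95):
    # Backward recursion: delta_t = r_t + gamma * V(s_{t+1}) * (1 - d_t) - V(s_t)
    #                     A_t     = delta_t + gamma * lam * (1 - d_t) * A_{t+1}
    adv = np.zeros_like(rewards, dtype=np.float64)
    running = 0.0
    for t in reversed(range(len(rewards))):
        nonterminal = 1.0 - dones[t]
        delta = rewards[t] + gamma * next_values[t] * nonterminal - values[t]
        running = delta + gamma * lam * nonterminal * running
        adv[t] = running
    return adv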
50 changes: 50 additions & 0 deletions dizoo/box2d/lunarlander/config/lunarlander_ppo_config.py
@@ -0,0 +1,50 @@
from easydict import EasyDict

lunarlander_ppo_config = dict(
    exp_name='lunarlander_ppo_seed0',
    env=dict(
        collector_env_num=8,
        evaluator_env_num=5,
        env_id='LunarLander-v2',
        n_evaluator_episode=5,
        stop_value=200,
    ),
    policy=dict(
        recompute_adv=True,
        cuda=True,
        action_space='discrete',
        model=dict(
            obs_shape=8,
            action_shape=4,
            action_space='discrete',
        ),
        learn=dict(
            epoch_per_collect=10,
            batch_size=64,
            learning_rate=3e-4,
            entropy_weight=0.01,
            adv_norm=True,
            value_norm=True,
        ),
        collect=dict(
            n_sample=512,
            discount_factor=0.99,
        ),
    ),
)
lunarlander_ppo_config = EasyDict(lunarlander_ppo_config)
main_config = lunarlander_ppo_config
lunarlander_ppo_create_config = dict(
    env=dict(
        type='lunarlander',
        import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
    ),
    env_manager=dict(type='subprocess'),
    policy=dict(type='ppo'),
)
lunarlander_ppo_create_config = EasyDict(lunarlander_ppo_create_config)
create_config = lunarlander_ppo_create_config

if __name__ == "__main__":
    from ding.entry import serial_pipeline_onpolicy
    serial_pipeline_onpolicy([main_config, create_config], seed=0)
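The __main__ block runs this config through the on-policy serial entry point. A hypothetical pre-launch tweak (assumption: only fields already present in the config above are touched) shows that the EasyDict config can be edited in place before being handed to the pipeline:

from dizoo.box2d.lunarlander.config.lunarlander_ppo_config import main_config, create_config
from ding.entry import serial_pipeline_onpolicy

main_config.exp_name = 'lunarlander_ppo_cpu_seed0'  # write results to a separate directory
main_config.policy.cuda = False                     # fall back to CPU training
serial_pipeline_onpolicy([main_config, create_config], seed=0)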
