Skip to content

Commit

Permalink
feature(nyz): add MADDPG pettingzoo example (#774)
Browse files Browse the repository at this point in the history
* feature(nyz): add MADDPG pettingzoo example

* style(nyz): fix registry style bug
  • Loading branch information
PaParaZz1 authored Feb 4, 2024
1 parent e9c09f6 commit abdf68a
Show file tree
Hide file tree
Showing 18 changed files with 124 additions and 32 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
| 27 | [QTran](https://arxiv.org/abs/1905.05408) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [policy/qtran](https://github.com/opendilab/DI-engine/blob/main/ding/policy/qtran.py) | ding -m serial -c smac_3s5z_qtran_config.py -s 0 |
| 28 | [WQMIX](https://arxiv.org/abs/2006.10800) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [WQMIX doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/wqmix.html)<br>[policy/wqmix](https://github.com/opendilab/DI-engine/blob/main/ding/policy/wqmix.py) | ding -m serial -c smac_3s5z_wqmix_config.py -s 0 |
| 29 | [CollaQ](https://arxiv.org/pdf/2010.08531.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [CollaQ doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/collaq.html)<br>[policy/collaq](https://github.com/opendilab/DI-engine/blob/main/ding/policy/collaq.py) | ding -m serial -c smac_3s5z_collaq_config.py -s 0 |
| 30 | [MADDPG](https://arxiv.org/pdf/1706.02275.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [MADDPG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/ddpg.html)<br>[policy/ddpg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ddpg.py) | ding -m serial -c ant_maddpg_config.py -s 0 |
| 30 | [MADDPG](https://arxiv.org/pdf/1706.02275.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [MADDPG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/ddpg.html)<br>[policy/ddpg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ddpg.py) | ding -m serial -c ptz_simple_spread_maddpg_config.py -s 0 |
| 31 | [GAIL](https://arxiv.org/pdf/1606.03476.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [GAIL doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/gail.html)<br>[reward_model/gail](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/gail_irl_model.py) | ding -m serial_gail -c cartpole_dqn_gail_config.py -s 0 |
| 32 | [SQIL](https://arxiv.org/pdf/1905.11108.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [SQIL doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/sqil.html)<br>[entry/sqil](https://github.com/opendilab/DI-engine/blob/main/ding/entry/serial_entry_sqil.py) | ding -m serial_sqil -c cartpole_sqil_config.py -s 0 |
| 33 | [DQFD](https://arxiv.org/pdf/1704.03732.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [DQFD doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/dqfd.html)<br>[policy/dqfd](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dqfd.py) | ding -m serial_dqfd -c cartpole_dqfd_config.py -s 0 |
Expand Down
36 changes: 23 additions & 13 deletions ding/policy/policy_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections import namedtuple
from easydict import EasyDict
import gym
import gymnasium
import torch

from ding.torch_utils import to_device
Expand Down Expand Up @@ -49,26 +50,35 @@ def forward(data: Dict[int, Any], *args, **kwargs) -> Dict[int, Any]:

actions = {}
for env_id in data:
if not isinstance(action_space, list):
if isinstance(action_space, list):
if 'global_state' in data[env_id].keys():
# for smac
logit = torch.ones_like(data[env_id]['action_mask'])
logit[data[env_id]['action_mask'] == 0.0] = -1e8
dist = torch.distributions.categorical.Categorical(logits=torch.Tensor(logit))
actions[env_id] = {'action': dist.sample(), 'logit': torch.as_tensor(logit)}
else:
# for gfootball
actions[env_id] = {
'action': torch.as_tensor(
[action_space_agent.sample() for action_space_agent in action_space]
),
'logit': torch.ones([len(action_space), action_space[0].n])
}
elif isinstance(action_space, gymnasium.spaces.Dict): # pettingzoo
actions[env_id] = {
'action': torch.as_tensor(
[action_space_agent.sample() for action_space_agent in action_space.values()]
)
}
else:
if isinstance(action_space, gym.spaces.Discrete):
action = torch.LongTensor([action_space.sample()])
elif isinstance(action_space, gym.spaces.MultiDiscrete):
action = [torch.LongTensor([v]) for v in action_space.sample()]
else:
action = torch.as_tensor(action_space.sample())
actions[env_id] = {'action': action}
elif 'global_state' in data[env_id].keys():
# for smac
logit = torch.ones_like(data[env_id]['action_mask'])
logit[data[env_id]['action_mask'] == 0.0] = -1e8
dist = torch.distributions.categorical.Categorical(logits=torch.Tensor(logit))
actions[env_id] = {'action': dist.sample(), 'logit': torch.as_tensor(logit)}
else:
# for gfootball
actions[env_id] = {
'action': torch.as_tensor([action_space_agent.sample() for action_space_agent in action_space]),
'logit': torch.ones([len(action_space), action_space[0].n])
}
return actions

def reset(*args, **kwargs) -> None:
Expand Down
9 changes: 5 additions & 4 deletions ding/utils/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,11 @@ def build(self, obj_type: str, *obj_args, **obj_kwargs) -> object:
global _innest_error
if _innest_error:
argspec = inspect.getfullargspec(build_fn)
message = 'for {}(alias={})'.format(build_fn, obj_type)
message += '\nExpected args are:{}'.format(argspec)
message += '\nGiven args are:{}/{}'.format(argspec, obj_kwargs.keys())
message += '\nGiven args details are:{}/{}'.format(argspec, obj_kwargs)
message = 'Hint: for {}(alias={})'.format(build_fn, obj_type)
message += '\n\nExpected args are:\n {}\nGiven arguments keys are:\n{}\n'.format(
argspec, obj_kwargs.keys()
)
print(message)
_innest_error = False
raise e

Expand Down
2 changes: 1 addition & 1 deletion dizoo/gfootball/config/gfootball_counter_masac_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
import_names=['dizoo.gfootball.envs.gfootball_academy_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='sac_discrete', ),
policy=dict(type='discrete_sac'),
)
gfootball_keeper_masac_default_create_config = EasyDict(gfootball_keeper_masac_default_create_config)
create_config = gfootball_keeper_masac_default_create_config
Expand Down
2 changes: 1 addition & 1 deletion dizoo/gfootball/config/gfootball_keeper_masac_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
import_names=['dizoo.gfootball.envs.gfootball_academy_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='sac_discrete', ),
policy=dict(type='discrete_sac'),
)
gfootball_keeper_masac_default_create_config = EasyDict(gfootball_keeper_masac_default_create_config)
create_config = gfootball_keeper_masac_default_create_config
Expand Down
81 changes: 81 additions & 0 deletions dizoo/petting_zoo/config/ptz_simple_spread_maddpg_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from easydict import EasyDict

n_agent = 3
n_landmark = n_agent
collector_env_num = 8
evaluator_env_num = 8
main_config = dict(
exp_name='ptz_simple_spread_maddpg_seed0',
env=dict(
env_family='mpe',
env_id='simple_spread_v2',
n_agent=n_agent,
n_landmark=n_landmark,
max_cycles=25,
agent_obs_only=False,
agent_specific_global_state=True,
continuous_actions=True, # ddpg only support continuous action space
act_scale=True, # necessary for continuous action space
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
n_evaluator_episode=evaluator_env_num,
stop_value=0,
),
policy=dict(
cuda=True,
multi_agent=True,
random_collect_size=5000,
model=dict(
agent_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2,
global_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + n_agent * (2 + 2) +
n_landmark * 2 + n_agent * (n_agent - 1) * 2,
action_shape=5,
action_space='regression',
twin_critic=False,
),
learn=dict(
update_per_collect=50,
batch_size=320,
# learning_rates
learning_rate_q=5e-4,
learning_rate_policy=5e-4,
target_theta=0.005,
discount_factor=0.99,
),
collect=dict(
n_sample=1600,
env_num=collector_env_num,
),
eval=dict(
env_num=evaluator_env_num,
evaluator=dict(eval_freq=500, ),
),
other=dict(
eps=dict(
type='linear',
start=1,
end=0.05,
decay=100000,
),
replay_buffer=dict(replay_buffer_size=int(1e6), )
),
),
)

main_config = EasyDict(main_config)
create_config = dict(
env=dict(
import_names=['dizoo.petting_zoo.envs.petting_zoo_simple_spread_env'],
type='petting_zoo',
),
env_manager=dict(type='subprocess'),
policy=dict(type='ddpg'),
)
create_config = EasyDict(create_config)
ptz_simple_spread_maddpg_config = main_config
ptz_simple_spread_maddpg_create_config = create_config

if __name__ == '__main__':
# or you can enter `ding -m serial_entry -c ptz_simple_spread_maddpg_config.py -s 0`
from ding.entry import serial_pipeline
serial_pipeline((main_config, create_config), seed=0, max_env_step=int(1e6))
2 changes: 1 addition & 1 deletion dizoo/petting_zoo/config/ptz_simple_spread_masac_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
type='petting_zoo',
),
env_manager=dict(type='subprocess'),
policy=dict(type='sac_discrete'),
policy=dict(type='discrete_sac'),
)
create_config = EasyDict(create_config)
ptz_simple_spread_masac_config = main_config
Expand Down
2 changes: 1 addition & 1 deletion dizoo/smac/config/smac_10m11m_masac_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
import_names=['dizoo.smac.envs.smac_env'],
),
env_manager=dict(type='base'),
policy=dict(type='sac_discrete', ),
policy=dict(type='discrete_sac', ),
)
SMAC_10m11m_masac_default_create_config = EasyDict(SMAC_10m11m_masac_default_create_config)
create_config = SMAC_10m11m_masac_default_create_config
Expand Down
2 changes: 1 addition & 1 deletion dizoo/smac/config/smac_25m_masac_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
import_names=['dizoo.smac.envs.smac_env'],
),
env_manager=dict(type='base'),
policy=dict(type='sac_discrete', ),
policy=dict(type='discrete_sac', ),
)
SMAC_25m_masac_default_create_config = EasyDict(SMAC_25m_masac_default_create_config)
create_config = SMAC_25m_masac_default_create_config
Expand Down
2 changes: 1 addition & 1 deletion dizoo/smac/config/smac_2c64zg_masac_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
import_names=['dizoo.smac.envs.smac_env'],
),
env_manager=dict(type='base'),
policy=dict(type='sac_discrete', ),
policy=dict(type='discrete_sac', ),
)
SMAC_2c64zg_masac_default_create_config = EasyDict(SMAC_2c64zg_masac_default_create_config)
create_config = SMAC_2c64zg_masac_default_create_config
Expand Down
2 changes: 1 addition & 1 deletion dizoo/smac/config/smac_3m_masac_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
import_names=['dizoo.smac.envs.smac_env'],
),
env_manager=dict(type='base'),
policy=dict(type='sac_discrete', ),
policy=dict(type='discrete_sac', ),
)
SMAC_3m_masac_default_create_config = EasyDict(SMAC_3m_masac_default_create_config)
create_config = SMAC_3m_masac_default_create_config
Expand Down
2 changes: 1 addition & 1 deletion dizoo/smac/config/smac_3s5z_masac_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
import_names=['dizoo.smac.envs.smac_env'],
),
env_manager=dict(type='base'),
policy=dict(type='sac_discrete', ),
policy=dict(type='discrete_sac', ),
)
smac_3s5z_masac_default_create_config = EasyDict(smac_3s5z_masac_default_create_config)
create_config = smac_3s5z_masac_default_create_config
Expand Down
2 changes: 1 addition & 1 deletion dizoo/smac/config/smac_3s5zvs3s6z_masac_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
import_names=['dizoo.smac.envs.smac_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='sac_discrete', ),
policy=dict(type='discrete_sac', ),
)
smac_3s5zvs3s6z_masac_default_create_config = EasyDict(smac_3s5zvs3s6z_masac_default_create_config)
create_config = smac_3s5zvs3s6z_masac_default_create_config
Expand Down
2 changes: 1 addition & 1 deletion dizoo/smac/config/smac_5m6m_masac_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
import_names=['dizoo.smac.envs.smac_env'],
),
env_manager=dict(type='base'),
policy=dict(type='sac_discrete', ),
policy=dict(type='discrete_sac', ),
)
SMAC_5m6m_masac_default_create_config = EasyDict(SMAC_5m6m_masac_default_create_config)
create_config = SMAC_5m6m_masac_default_create_config
Expand Down
2 changes: 1 addition & 1 deletion dizoo/smac/config/smac_8m9m_masac_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
import_names=['dizoo.smac.envs.smac_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='sac_discrete', ),
policy=dict(type='discrete_sac', ),
)
SMAC_8m9m_masac_default_create_config = EasyDict(SMAC_8m9m_masac_default_create_config)
create_config = SMAC_8m9m_masac_default_create_config
Expand Down
2 changes: 1 addition & 1 deletion dizoo/smac/config/smac_MMM2_masac_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
import_names=['dizoo.smac.envs.smac_env'],
),
env_manager=dict(type='base'),
policy=dict(type='sac_discrete', ),
policy=dict(type='discrete_sac', ),
)
SMAC_MMM2_masac_default_create_config = EasyDict(SMAC_MMM2_masac_default_create_config)
create_config = SMAC_MMM2_masac_default_create_config
Expand Down
2 changes: 1 addition & 1 deletion dizoo/smac/config/smac_MMM_masac_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
import_names=['dizoo.smac.envs.smac_env'],
),
env_manager=dict(type='base'),
policy=dict(type='sac_discrete', ),
policy=dict(type='discrete_sac', ),
)
MMM_masac_default_create_config = EasyDict(MMM_masac_default_create_config)
create_config = MMM_masac_default_create_config
Expand Down
2 changes: 1 addition & 1 deletion dizoo/smac/config/smac_corridor_masac_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
import_names=['dizoo.smac.envs.smac_env'],
),
env_manager=dict(type='subprocess'),
policy=dict(type='sac_discrete', ),
policy=dict(type='discrete_sac', ),
)
smac_corridor_masac_default_create_config = EasyDict(smac_corridor_masac_default_create_config)
create_config = smac_corridor_masac_default_create_config
Expand Down

0 comments on commit abdf68a

Please sign in to comment.