Commit 0940507: coma starcraft2
wenzhangliu committed Oct 9, 2023
1 parent b1d81fa
Showing 11 changed files with 259 additions and 73 deletions.
4 changes: 2 additions & 2 deletions demo_marl.py
@@ -5,8 +5,8 @@
def parse_args():
parser = argparse.ArgumentParser("Run an MARL demo.")
parser.add_argument("--method", type=str, default="coma")
parser.add_argument("--env", type=str, default="mpe")
parser.add_argument("--env-id", type=str, default="simple_spread_v3")
parser.add_argument("--env", type=str, default="sc2")
parser.add_argument("--env-id", type=str, default="3m")
parser.add_argument("--test", type=int, default=0)
parser.add_argument("--device", type=str, default="cuda:0")
return parser.parse_args()
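After this change, running the demo with no flags starts COMA on the StarCraft II `3m` map instead of the MPE `simple_spread_v3` task. A minimal sketch of how the new defaults resolve (the parser is copied from the file above; `parse_args([])` only mimics an empty command line):

```python
import argparse

parser = argparse.ArgumentParser("Run an MARL demo.")
parser.add_argument("--method", type=str, default="coma")
parser.add_argument("--env", type=str, default="sc2")
parser.add_argument("--env-id", type=str, default="3m")
parser.add_argument("--test", type=int, default=0)
parser.add_argument("--device", type=str, default="cuda:0")

args = parser.parse_args([])               # as if `python demo_marl.py` were run with no flags
print(args.method, args.env, args.env_id)  # -> coma sc2 3m

# The previous MPE setup is still reachable by passing the flags explicitly:
old = parser.parse_args(["--env", "mpe", "--env-id", "simple_spread_v3"])
```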
50 changes: 50 additions & 0 deletions xuanpolicy/common/memory_tools_marl.py
@@ -536,3 +536,53 @@ def finish_path(self, value, i_env, value_normalizer=None):
rewards[t] + (1 - self.td_lambda) * self.gamma * vs[t + 1] * (1 - dones[t])
self.data['returns'][i_env, path_slice] = returns[:-1]
self.start_ids[i_env] = self.ptr


class COMA_Buffer_RNN(MARL_OnPolicyBuffer_RNN):
def __init__(self, n_agents, state_space, obs_space, act_space, rew_space, done_space, n_envs, n_size,
use_gae, use_advnorm, gamma, gae_lam, **kwargs):
self.td_lambda = kwargs['td_lambda']
super(COMA_Buffer_RNN, self).__init__(n_agents, state_space, obs_space, act_space, rew_space, done_space,
n_envs, n_size, use_gae, use_advnorm, gamma, gae_lam, **kwargs)

def clear(self):
self.data = {
'obs': np.zeros((self.buffer_size, self.n_agents, self.max_eps_len + 1) + self.obs_space, np.float32),
'actions': np.zeros((self.buffer_size, self.n_agents, self.max_eps_len) + self.act_space, np.float32),
'actions_onehot': np.zeros((self.buffer_size, self.n_agents, self.max_eps_len, self.dim_act)).astype(np.float32),
'rewards': np.zeros((self.buffer_size, self.n_agents, self.max_eps_len) + self.rew_space, np.float32),
'returns': np.zeros((self.buffer_size, self.n_agents, self.max_eps_len) + self.rew_space, np.float32),
'values': np.zeros((self.buffer_size, self.n_agents, self.max_eps_len) + self.rew_space, np.float32),
'advantages': np.zeros((self.buffer_size, self.n_agents, self.max_eps_len) + self.rew_space, np.float32),
'log_pi_old': np.zeros((self.buffer_size, self.n_agents, self.max_eps_len,), np.float32),
'terminals': np.zeros((self.buffer_size, self.max_eps_len) + self.done_space, np.bool),
'avail_actions': np.ones((self.buffer_size, self.n_agents, self.max_eps_len + 1, self.dim_act), np.bool),
'filled': np.zeros((self.buffer_size, self.max_eps_len, 1), np.bool)
}
if self.state_space is not None:
self.data.update({'state': np.zeros(
(self.buffer_size, self.max_eps_len + 1) + self.state_space, np.float32)})
self.ptr, self.size = 0, 0

def finish_path(self, value, i_env, episode_data=None, current_t=None, value_normalizer=None):
"""
when an episode is finished, build td-lambda targets.
"""
if current_t > self.max_eps_len:
path_slice = np.arange(0, self.max_eps_len).astype(np.int32)
else:
path_slice = np.arange(0, current_t).astype(np.int32)
# calculate advantages and returns
rewards = np.array(episode_data['rewards'][i_env, :, path_slice])
vs = np.append(np.array(episode_data['values'][i_env, :, path_slice]), [value.reshape(self.n_agents, 1)],
axis=0)
dones = np.array(episode_data['terminals'][i_env, path_slice])[:, :, None]
returns = np.zeros_like(vs)
step_nums = len(path_slice)

for t in reversed(range(step_nums)):
returns[t] = self.td_lambda * self.gamma * returns[t + 1] + \
rewards[t] + (1 - self.td_lambda) * self.gamma * vs[t + 1] * (1 - dones[t])

episode_data['returns'][i_env, :, path_slice] = returns[:-1]
self.store(episode_data, i_env)
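In equation form, the backward loop above builds the TD(λ) target used as the critic's regression target. With reward $r_t$, done flag $d_t$, bootstrap values $V_{t+1}$ (the appended `value` at the final step), discount $\gamma$ and mixing coefficient $\lambda$, it computes

$$
G^{\lambda}_t = r_t + (1-\lambda)\,\gamma\,(1-d_t)\,V_{t+1} + \lambda\,\gamma\,G^{\lambda}_{t+1},
\qquad G^{\lambda}_{T} = 0,
$$

which is just a restatement of the loop, not new behaviour; note that the termination mask is applied only to the bootstrap term, exactly as in the code.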
3 changes: 1 addition & 2 deletions xuanpolicy/configs/coma/mpe/simple_spread_v3.yaml
@@ -4,11 +4,10 @@ env_id: "simple_spread_v3"
continuous_action: False
policy: "Categorical_COMA_Policy"
representation: "Basic_MLP"
representation_critic: "Basic_MLP"
vectorize: "Dummy_Pettingzoo"
runner: "Pettingzoo_Runner"

render: True

use_recurrent: False
rnn:
representation_hidden_size: [128, ]
59 changes: 59 additions & 0 deletions xuanpolicy/configs/coma/sc2/3m.yaml
@@ -0,0 +1,59 @@
agent: "COMA" # the learning algorithms_marl
env_name: "StarCraft2"
env_id: "simple_spread_v3"
fps: 15
policy: "Categorical_COMA_Policy"
representation: "Basic_RNN"
vectorize: "Dummy_StarCraft2"
runner: "StarCraft2_Runner"

render: True

use_recurrent: True
rnn: "GRU"
recurrent_layer_N: 1
fc_hidden_sizes: [64, ]
recurrent_hidden_size: 64
N_recurrent_layers: 1
dropout: 0
normalize: "LayerNorm"
initialize: "orthogonal"
gain: 0.01

actor_hidden_size: [64, ]
critic_hidden_size: [128, 128]
activation: "ReLU"

seed: 1
parallels: 1
n_size: 128
n_epoch: 15
n_minibatch: 1
learning_rate_actor: 0.0007
learning_rate_critic: 0.0007

clip_grad: 10
clip_type: 1 # Gradient clip for Mindspore: 0: ms.ops.clip_by_value; 1: ms.nn.ClipByNorm()
gamma: 0.95 # discount factor
td_lambda: 0.8

start_greedy: 0.5
end_greedy: 0.01
decay_step_greedy: 2500000
sync_frequency: 200

use_global_state: True # if use global state to replace merged observations
use_advnorm: True
use_gae: True
gae_lambda: 0.95

start_training: 1
running_steps: 1000000
train_per_step: True
training_frequency: 1

test_steps: 10000
eval_interval: 5000
test_episode: 10
log_dir: "./logs/coma/"
model_dir: "./models/coma/"
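Before launching a run, the new config can be sanity-checked with PyYAML (a minimal sketch; it only assumes PyYAML is installed and the file sits at the path shown in the header above):

```python
import yaml  # PyYAML

with open("xuanpolicy/configs/coma/sc2/3m.yaml") as f:
    cfg = yaml.safe_load(f)

# A few of the COMA-specific settings defined above
print(cfg["agent"], cfg["policy"], cfg["representation"])       # COMA Categorical_COMA_Policy Basic_RNN
print(cfg["td_lambda"], cfg["gamma"], cfg["use_global_state"])  # 0.8 0.95 True
```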
1 change: 0 additions & 1 deletion xuanpolicy/configs/mappo/sc2/3m.yaml
@@ -6,7 +6,6 @@ policy: "Categorical_MAAC_Policy"
representation: "Basic_RNN"
vectorize: "Dummy_StarCraft2"
runner: "StarCraft2_Runner"
on_policy: True

# recurrent settings for Basic_RNN representation
use_recurrent: True
26 changes: 10 additions & 16 deletions xuanpolicy/torch/agents/multi_agent_rl/coma_agents.py
@@ -19,28 +19,26 @@ def __init__(self,
self.n_epoch = config.n_epoch
self.n_minibatch = config.n_minibatch
if config.state_space is not None:
config.dim_state, state_shape = config.state_space.shape, config.state_space.shape
config.dim_state, state_shape = config.state_space.shape[0], config.state_space.shape
else:
config.dim_state, state_shape = None, None

# create representation for COMA actor
input_representation = get_repre_in(config)
self.use_recurrent = config.use_recurrent
self.use_global_state = config.use_global_state
kwargs_rnn = {"N_recurrent_layers": config.N_recurrent_layers,
"dropout": config.dropout,
"rnn": config.rnn} if self.use_recurrent else {}
representation = REGISTRY_Representation[config.representation](*input_representation, **kwargs_rnn)
# create representation for COMA critic
input_representation[0] = config.dim_obs + config.dim_act * config.n_agents
if self.use_global_state:
input_representation[0] += config.dim_state
representation_critic = REGISTRY_Representation[config.representation](*input_representation, **kwargs_rnn)
# create policy
input_policy = get_policy_in_marl(config, (representation, representation_critic))
input_policy = get_policy_in_marl(config, representation)
policy = REGISTRY_Policy[config.policy](*input_policy,
use_recurrent=config.use_recurrent,
rnn=config.rnn,
gain=config.gain)
gain=config.gain,
use_global_state=self.use_global_state,
dim_state=config.dim_state)
optimizer = [torch.optim.Adam(policy.parameters_actor, config.learning_rate_actor, eps=1e-5),
torch.optim.Adam(policy.parameters_critic, config.learning_rate_critic, eps=1e-5)]
scheduler = [torch.optim.lr_scheduler.LinearLR(optimizer[0], start_factor=1.0, end_factor=0.5,
@@ -58,7 +56,7 @@ def __init__(self,
config.dim_state, state_shape = None, None
config.act_onehot_shape = config.act_shape + tuple([config.dim_act])

buffer = MARL_OnPolicyBuffer_RNN if self.use_recurrent else COMA_Buffer
buffer = COMA_Buffer_RNN if self.use_recurrent else COMA_Buffer
input_buffer = (config.n_agents, config.state_space.shape, config.obs_shape, config.act_shape, config.rew_shape,
config.done_shape, envs.num_envs, config.n_size,
config.use_gae, config.use_advnorm, config.gamma, config.gae_lambda)
@@ -86,7 +84,7 @@ def act(self, obs_n, *rnn_hidden, avail_actions=None, test_mode=False):
*rnn_hidden,
avail_actions=avail_actions.reshape(batch_agents, 1, -1),
epsilon=epsilon)
action_probs = action_probs.view(batch_size, self.n_agents)
action_probs = action_probs.view(batch_size, self.n_agents, self.dim_act)
else:
hidden_state, action_probs = self.policy(obs_in, agents_id,
avail_actions=avail_actions,
@@ -95,7 +93,7 @@ def act(self, obs_n, *rnn_hidden, avail_actions=None, test_mode=False):
onehot_actions = self.learner.onehot_action(picked_actions, self.dim_act)
return hidden_state, picked_actions.detach().cpu().numpy(), onehot_actions.detach().cpu().numpy()

def values(self, obs_n, actions_n, actions_onehot, *rnn_hidden, state=None):
def values(self, obs_n, *rnn_hidden, state=None, actions_n=None, actions_onehot=None):
batch_size = len(obs_n)
# build critic input
obs_n = torch.Tensor(obs_n).to(self.device)
@@ -111,11 +109,7 @@ def values(self, obs_n, actions_n, actions_onehot, *rnn_hidden, state=None):
else:
critic_in = torch.concat([obs_n, actions_in])
# get critic values
if self.use_recurrent:
hidden_state, values_n = self.policy.get_values(critic_in.unsqueeze(2), *rnn_hidden, target=True)
values_n = values_n.squeeze(2)
else:
hidden_state, values_n = self.policy.get_values(critic_in, target=True)
hidden_state, values_n = self.policy.get_values(critic_in, target=True)

target_values = values_n.gather(-1, actions_n.long())
return hidden_state, target_values.detach().cpu().numpy()
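For orientation, the per-action critic outputs produced by `get_values` (and reduced to the taken action by `gather`) are exactly the ingredients of COMA's counterfactual advantage, which `update_recurrent` in the learner below computes as `baseline = (pi_probs * q_eval).sum(-1)` and `advantages = q_eval_a - baseline`; in the usual COMA notation (Foerster et al., 2018) this is

$$
A^{a}(s, \mathbf{u}) \;=\; Q\bigl(s, \mathbf{u}\bigr) \;-\; \sum_{u'^{a}} \pi^{a}\bigl(u'^{a} \mid \tau^{a}\bigr)\, Q\bigl(s, (\mathbf{u}^{-a}, u'^{a})\bigr),
$$

i.e. the joint-action value minus the expected value under agent $a$'s own policy, holding the other agents' actions fixed.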
84 changes: 84 additions & 0 deletions xuanpolicy/torch/learners/multi_agent_rl/coma_learner.py
@@ -108,3 +108,87 @@ def update(self, sample, epsilon=0.0):
}

return info

def update_recurrent(self, sample, epsilon=0.0):
self.iterations += 1
state = torch.Tensor(sample['state']).to(self.device)
obs = torch.Tensor(sample['obs']).to(self.device)
actions = torch.Tensor(sample['actions']).to(self.device)
actions_onehot = torch.Tensor(sample['actions_onehot']).to(self.device)
targets = torch.Tensor(sample['returns']).squeeze(-1).to(self.device)
avail_actions = torch.Tensor(sample['avail_actions']).float().to(self.device)
filled = torch.Tensor(sample['filled']).float().to(self.device)
batch_size = obs.shape[0]
episode_length = actions.shape[2]
IDs = torch.eye(self.n_agents).unsqueeze(1).unsqueeze(0).expand(batch_size, -1, episode_length + 1, -1).to(
self.device)

# build critic input
actions_in = actions_onehot.transpose(1, 2).reshape(batch_size, episode_length, -1)
actions_in = actions_in.unsqueeze(1).repeat(1, self.n_agents, 1, 1)
actions_in_mask = 1 - torch.eye(self.n_agents, device=self.device)
actions_in_mask = actions_in_mask.view(-1, 1).repeat(1, self.dim_act).view(self.n_agents, -1)
actions_in_mask = actions_in_mask.unsqueeze(1).repeat(1, episode_length, 1)
actions_in = actions_in * actions_in_mask
if self.use_global_state:
state = state[:, :-1].unsqueeze(1).repeat(1, self.n_agents, 1, 1)
critic_in = torch.concat([state, obs[:, :, :-1], actions_in], dim=-1)
else:
critic_in = torch.concat([obs[:, :-1], actions_in])
# get critic value

_, q_eval = self.policy.get_values(critic_in)
q_eval_a = q_eval.gather(-1, actions.unsqueeze(-1).long()).squeeze(-1)
filled_n = filled.unsqueeze(1).expand(-1, self.n_agents, -1, -1).squeeze(-1)
td_errors = q_eval_a - targets.detach()
td_errors *= filled_n
loss_c = (td_errors ** 2).sum() / filled_n.sum()
self.optimizer['critic'].zero_grad()
loss_c.backward()
grad_norm_critic = torch.nn.utils.clip_grad_norm_(self.policy.parameters_critic, self.args.clip_grad)
self.optimizer['critic'].step()
if self.iterations_critic % self.sync_frequency == 0:
self.policy.copy_target()
self.iterations_critic += 1

if self.scheduler['critic'] is not None:
self.scheduler['critic'].step()

# calculate baselines
rnn_hidden_actor = self.policy.representation.init_hidden(batch_size * self.n_agents)
_, pi_probs = self.policy(obs[:, :, :-1].reshape(-1, episode_length, self.dim_obs),
IDs[:, :, :-1].reshape(-1, episode_length, self.n_agents),
*rnn_hidden_actor,
avail_actions=avail_actions[:, :, :-1].reshape(-1, episode_length, self.dim_act),
epsilon=epsilon)
pi_probs = pi_probs.reshape(batch_size, self.n_agents, episode_length, self.dim_act)
baseline = (pi_probs * q_eval).sum(-1)

pi_a = pi_probs.gather(-1, actions.unsqueeze(-1).long()).squeeze(-1)
log_pi_a = torch.log(pi_a)
advantages = (q_eval_a - baseline).detach()
loss_coma = -(advantages * log_pi_a * filled_n).sum() / filled_n.sum()

self.optimizer['actor'].zero_grad()
loss_coma.backward()
grad_norm_actor = torch.nn.utils.clip_grad_norm_(self.policy.parameters_actor, self.args.clip_grad)
self.optimizer['actor'].step()

if self.scheduler['actor'] is not None:
self.scheduler['actor'].step()

# Logger
lr_a = self.optimizer['actor'].state_dict()['param_groups'][0]['lr']
lr_c = self.optimizer['critic'].state_dict()['param_groups'][0]['lr']

info = {
"learning_rate_actor": lr_a,
"learning_rate_critic": lr_c,
"actor_loss": loss_coma.item(),
"critic_loss": loss_c.item(),
"advantage": advantages.mean().item(),
"actor_gradient_norm": grad_norm_actor.item(),
"critic_gradient_norm": grad_norm_critic.item()
}

return info
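The `actions_in_mask` block above zeroes each agent's own one-hot action inside the flattened joint action, so the critic evaluated for agent $a$ sees only the other agents' actions. A small self-contained illustration of just that masking step (all sizes are arbitrary; only the tensor manipulation mirrors the code):

```python
import torch

batch_size, n_agents, episode_length, dim_act = 2, 3, 4, 5

# one-hot joint actions: (batch, n_agents, T, dim_act)
actions_onehot = torch.nn.functional.one_hot(
    torch.randint(dim_act, (batch_size, n_agents, episode_length)), dim_act).float()

# flatten the joint action per timestep and replicate it once per evaluated agent
actions_in = actions_onehot.transpose(1, 2).reshape(batch_size, episode_length, -1)
actions_in = actions_in.unsqueeze(1).repeat(1, n_agents, 1, 1)   # (B, n_agents, T, n_agents*dim_act)

# row a of the mask zeroes agent a's own dim_act-sized block
mask = 1 - torch.eye(n_agents)
mask = mask.view(-1, 1).repeat(1, dim_act).view(n_agents, -1)    # (n_agents, n_agents*dim_act)
actions_in = actions_in * mask.unsqueeze(0).unsqueeze(2)         # broadcast over batch and time

print(actions_in.shape)  # torch.Size([2, 3, 4, 15])
```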
36 changes: 8 additions & 28 deletions xuanpolicy/torch/policies/categorical_marl.py
@@ -166,18 +166,19 @@ def __init__(self,
self.device = device
self.action_dim = action_space.n
self.n_agents = n_agents
self.representation = representation[0]
self.representation_critic = representation[1]
self.representation = representation
self.representation_info_shape = self.representation.output_shapes
self.lstm = True if kwargs["rnn"] == "LSTM" else False
self.use_rnn = True if kwargs["use_recurrent"] else False
self.actor = ActorNet(self.representation.output_shapes['state'][0], self.action_dim, n_agents,
actor_hidden_size, normalize, initialize, kwargs['gain'], activation, device)
self.critic = COMA_Critic(self.representation_critic.output_shapes['state'][0], self.action_dim,
critic_hidden_size, normalize, initialize, activation, device)
self.target_representation_critic = copy.deepcopy(self.representation_critic)
critic_input_dim = self.representation.input_shape[0] + self.action_dim * self.n_agents
if kwargs["use_global_state"]:
critic_input_dim += kwargs["dim_state"]
self.critic = COMA_Critic(critic_input_dim, self.action_dim, critic_hidden_size,
normalize, initialize, activation, device)
self.target_critic = copy.deepcopy(self.critic)
self.parameters_critic = list(self.representation_critic.parameters()) + list(self.critic.parameters())
self.parameters_critic = list(self.critic.parameters())
self.parameters_actor = list(self.representation.parameters()) + list(self.actor.parameters())
self.pi_dist = CategoricalDistribution(self.action_dim)

@@ -199,34 +200,13 @@ def forward(self, observation: torch.Tensor, agent_ids: torch.Tensor,
return rnn_hidden, act_probs

def get_values(self, critic_in: torch.Tensor, *rnn_hidden: torch.Tensor, target=False):
shape_in = critic_in.shape
# get representation features
if self.use_rnn:
batch_size, n_agent, episode_length, dim_critic_in = tuple(shape_in)
if target:
outputs = self.target_representation_critic(critic_in.reshape(-1, episode_length, dim_critic_in), *rnn_hidden)
else:
outputs = self.representation_critic(critic_in.reshape(-1, episode_length, dim_critic_in), *rnn_hidden)
outputs['state'] = outputs['state'].view(batch_size, n_agent, episode_length, -1)
rnn_hidden = (outputs['rnn_hidden'], outputs['rnn_cell'])
else:
batch_size, n_agent, dim_critic_in = tuple(shape_in)
if target:
outputs = self.target_representation_critic(critic_in.reshape(-1, dim_critic_in))
else:
outputs = self.representation_critic(critic_in.reshape(-1, dim_critic_in))
outputs['state'] = outputs['state'].view(batch_size, n_agent, -1)
rnn_hidden = None
# get critic values
critic_in = outputs['state']
v = self.target_critic(critic_in) if target else self.critic(critic_in)
return rnn_hidden, v
return [None, None], v

def copy_target(self):
for ep, tp in zip(self.critic.parameters(), self.target_critic.parameters()):
tp.data.copy_(ep)
for ep, tp in zip(self.representation_critic.parameters(), self.target_representation_critic.parameters()):
tp.data.copy_(ep)


class MeanFieldActorCriticPolicy(nn.Module):
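`COMA_Critic` itself is not touched by this commit and is not shown here. Purely for orientation, a critic head of this shape is typically a plain MLP that maps the concatenated input (observation, other agents' one-hot actions, optionally the global state) to one Q-value per discrete action; the following is a hypothetical stand-in, not the repository's implementation:

```python
import torch
import torch.nn as nn

class SimpleCOMACritic(nn.Module):
    """Hypothetical stand-in for COMA_Critic: MLP from critic input to per-action Q-values."""

    def __init__(self, input_dim: int, n_actions: int, hidden_sizes=(128, 128)):
        super().__init__()
        layers, last = [], input_dim
        for h in hidden_sizes:
            layers += [nn.Linear(last, h), nn.ReLU()]
            last = h
        layers.append(nn.Linear(last, n_actions))  # one Q-value per discrete action
        self.model = nn.Sequential(*layers)

    def forward(self, critic_in: torch.Tensor) -> torch.Tensor:
        return self.model(critic_in)

# Illustrative sizes only: dim_obs=30, n_agents=3, dim_act=9, dim_state=48
critic_input_dim = 30 + 3 * 9 + 48
critic = SimpleCOMACritic(critic_input_dim, 9)
q = critic(torch.randn(4, 3, critic_input_dim))  # (batch, n_agents, n_actions)
print(q.shape)                                   # torch.Size([4, 3, 9])
```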
2 changes: 1 addition & 1 deletion xuanpolicy/torch/runners/runner_pettingzoo.py
@@ -172,7 +172,7 @@ def get_actions(self, obs_n, test_mode, act_mean_last, agent_mask, state):
log_pi_n.append(None)
elif self.marl_names[h] in ["COMA"]:
_, a, a_onehot = mas_group.act(obs_n[h], test_mode)
_, values = mas_group.values(obs_n[h], a, a_onehot, state=state)
_, values = mas_group.values(obs_n[h], state=state, actions_n=a, actions_onehot=a_onehot)
actions_n_onehot.append(a_onehot)
values_n.append(values)
else: