Commit

coma for mpe
wenzhangliu committed Oct 8, 2023
1 parent 4d962b4 commit b1d81fa
Showing 3 changed files with 11 additions and 12 deletions.
xuanpolicy/common/memory_tools_marl.py (4 changes: 2 additions & 2 deletions)
@@ -532,7 +532,7 @@ def finish_path(self, value, i_env, value_normalizer=None):  # when an episode i
     returns = np.zeros_like(vs)
     step_nums = len(path_slice)
     for t in reversed(range(step_nums)):
-        returns[t] = self.td_lambda * self.gamma * returns[t+1] + \
-            rewards[t] + (1 - self.td_lambda) * self.gamma * vs[t+1] * (1 - dones[t])
+        returns[t] = self.td_lambda * self.gamma * returns[t + 1] + \
+            rewards[t] + (1 - self.td_lambda) * self.gamma * vs[t + 1] * (1 - dones[t])
     self.data['returns'][i_env, path_slice] = returns[:-1]
     self.start_ids[i_env] = self.ptr
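
Note: the hunk above only normalizes spacing around t + 1; the recursion itself is the TD(lambda) return used for the critic targets. A minimal standalone sketch of that recursion, with rewards, values and dones as stand-in array names rather than the repository's buffer fields:

import numpy as np

def td_lambda_returns(rewards, values, dones, gamma=0.95, td_lambda=0.8):
    # Backward TD(lambda) recursion mirroring finish_path above.
    # rewards: (T,)     per-step rewards
    # values:  (T + 1,) bootstrap values, including the final state
    # dones:   (T,)     1.0 where the episode terminated at that step
    T = len(rewards)
    returns = np.zeros(T + 1)  # trailing slot stays 0, like np.zeros_like(vs)
    for t in reversed(range(T)):
        returns[t] = td_lambda * gamma * returns[t + 1] + \
            rewards[t] + (1 - td_lambda) * gamma * values[t + 1] * (1 - dones[t])
    return returns[:-1]  # drop the bootstrap slot, like returns[:-1] above
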
xuanpolicy/configs/coma/mpe/simple_spread_v3.yaml (6 changes: 3 additions & 3 deletions)
@@ -23,13 +23,13 @@ parallels: 10
 n_size: 25
 n_epoch: 10
 n_minibatch: 1
-learning_rate_actor: 0.0005
-learning_rate_critic: 0.0005
+learning_rate_actor: 0.0007
+learning_rate_critic: 0.0007

 clip_grad: 10
 clip_type: 1 # Gradient clip for Mindspore: 0: ms.ops.clip_by_value; 1: ms.nn.ClipByNorm()
 gamma: 0.95 # discount factor
-td_lambda: 0.8
+td_lambda: 0.1

 start_greedy: 0.5
 end_greedy: 0.01
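
Note: lowering td_lambda shifts the lambda-return, which weights the n-step return by (1 - lambda) * lambda**(n - 1), toward the one-step TD target. A quick sketch (not part of the repository) that prints the first few weights under the old and new settings:

# Weight the lambda-return places on the n-step return: (1 - lam) * lam**(n - 1).
for lam in (0.8, 0.1):
    weights = [(1 - lam) * lam ** (n - 1) for n in range(1, 6)]
    print("lambda=%.1f:" % lam, ", ".join("%.3f" % w for w in weights))

With lambda = 0.1 roughly 90% of the weight sits on the one-step target, trading variance for more bootstrap bias, while 0.8 spreads the weight across longer rollouts.
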
xuanpolicy/torch/agents/multi_agent_rl/coma_agents.py (13 changes: 6 additions & 7 deletions)
@@ -78,21 +78,20 @@ def act(self, obs_n, *rnn_hidden, avail_actions=None, test_mode=False):
         batch_size = len(obs_n)
         agents_id = torch.eye(self.n_agents).unsqueeze(0).expand(batch_size, -1, -1).to(self.device)
         obs_in = torch.Tensor(obs_n).view([batch_size, self.n_agents, -1]).to(self.device)
+        epsilon = 0.0 if test_mode else self.egreedy
         if self.use_recurrent:
             batch_agents = batch_size * self.n_agents
             hidden_state, action_probs = self.policy(obs_in.view(batch_agents, 1, -1),
                                                      agents_id.view(batch_agents, 1, -1),
                                                      *rnn_hidden,
                                                      avail_actions=avail_actions.reshape(batch_agents, 1, -1),
-                                                     epsilon=self.egreedy)
+                                                     epsilon=epsilon)
             action_probs = action_probs.view(batch_size, self.n_agents)
         else:
-            hidden_state, action_probs = self.policy(obs_in, agents_id, avail_actions=avail_actions)
-
-        if test_mode:
-            _, picked_actions = action_probs.max()
-        else:
-            picked_actions = Categorical(action_probs).sample()
+            hidden_state, action_probs = self.policy(obs_in, agents_id,
+                                                     avail_actions=avail_actions,
+                                                     epsilon=epsilon)
+        picked_actions = Categorical(action_probs).sample()
         onehot_actions = self.learner.onehot_action(picked_actions, self.dim_act)
         return hidden_state, picked_actions.detach().cpu().numpy(), onehot_actions.detach().cpu().numpy()
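
Note: the change above computes a single epsilon (0.0 in test mode), forwards it to the policy in both branches, and always samples actions from Categorical. The diff does not show how the policy consumes epsilon; a common scheme, sketched here with an invented mix_epsilon helper purely for illustration, blends the network's probabilities with a uniform distribution:

import torch

def mix_epsilon(action_probs, epsilon):
    # Hypothetical epsilon mixing over categorical action probabilities.
    # With epsilon = 0.0 (test mode above) the probabilities pass through unchanged.
    n_actions = action_probs.shape[-1]
    uniform = torch.full_like(action_probs, 1.0 / n_actions)
    return (1.0 - epsilon) * action_probs + epsilon * uniform

# Example: one env, 3 agents, 5 actions each.
probs = torch.softmax(torch.randn(1, 3, 5), dim=-1)
explore = torch.distributions.Categorical(probs=mix_epsilon(probs, 0.5)).sample()
on_policy = torch.distributions.Categorical(probs=mix_epsilon(probs, 0.0)).sample()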
