From 0940507e5f8fc3e3e9e96aec439eb3ab645e47ee Mon Sep 17 00:00:00 2001 From: wenzhangliu Date: Mon, 9 Oct 2023 20:33:53 +0800 Subject: [PATCH] coma starcraft2 --- demo_marl.py | 4 +- xuanpolicy/common/memory_tools_marl.py | 50 +++++++++++ .../configs/coma/mpe/simple_spread_v3.yaml | 3 +- xuanpolicy/configs/coma/sc2/3m.yaml | 59 +++++++++++++ xuanpolicy/configs/mappo/sc2/3m.yaml | 1 - .../agents/multi_agent_rl/coma_agents.py | 26 +++--- .../learners/multi_agent_rl/coma_learner.py | 84 +++++++++++++++++++ xuanpolicy/torch/policies/categorical_marl.py | 36 ++------ xuanpolicy/torch/runners/runner_pettingzoo.py | 2 +- xuanpolicy/torch/runners/runner_sc2.py | 63 +++++++++----- xuanpolicy/torch/utils/input_reformat.py | 4 +- 11 files changed, 259 insertions(+), 73 deletions(-) create mode 100644 xuanpolicy/configs/coma/sc2/3m.yaml diff --git a/demo_marl.py b/demo_marl.py index 0940bd4e4..d228410ad 100644 --- a/demo_marl.py +++ b/demo_marl.py @@ -5,8 +5,8 @@ def parse_args(): parser = argparse.ArgumentParser("Run an MARL demo.") parser.add_argument("--method", type=str, default="coma") - parser.add_argument("--env", type=str, default="mpe") - parser.add_argument("--env-id", type=str, default="simple_spread_v3") + parser.add_argument("--env", type=str, default="sc2") + parser.add_argument("--env-id", type=str, default="3m") parser.add_argument("--test", type=int, default=0) parser.add_argument("--device", type=str, default="cuda:0") return parser.parse_args() diff --git a/xuanpolicy/common/memory_tools_marl.py b/xuanpolicy/common/memory_tools_marl.py index 3c4c8e2c6..df07267d2 100644 --- a/xuanpolicy/common/memory_tools_marl.py +++ b/xuanpolicy/common/memory_tools_marl.py @@ -536,3 +536,53 @@ def finish_path(self, value, i_env, value_normalizer=None): # when an episode i rewards[t] + (1 - self.td_lambda) * self.gamma * vs[t + 1] * (1 - dones[t]) self.data['returns'][i_env, path_slice] = returns[:-1] self.start_ids[i_env] = self.ptr + + +class COMA_Buffer_RNN(MARL_OnPolicyBuffer_RNN): + def __init__(self, n_agents, state_space, obs_space, act_space, rew_space, done_space, n_envs, n_size, + use_gae, use_advnorm, gamma, gae_lam, **kwargs): + self.td_lambda = kwargs['td_lambda'] + super(COMA_Buffer_RNN, self).__init__(n_agents, state_space, obs_space, act_space, rew_space, done_space, + n_envs, n_size, use_gae, use_advnorm, gamma, gae_lam, **kwargs) + + def clear(self): + self.data = { + 'obs': np.zeros((self.buffer_size, self.n_agents, self.max_eps_len + 1) + self.obs_space, np.float32), + 'actions': np.zeros((self.buffer_size, self.n_agents, self.max_eps_len) + self.act_space, np.float32), + 'actions_onehot': np.zeros((self.buffer_size, self.n_agents, self.max_eps_len, self.dim_act)).astype(np.float32), + 'rewards': np.zeros((self.buffer_size, self.n_agents, self.max_eps_len) + self.rew_space, np.float32), + 'returns': np.zeros((self.buffer_size, self.n_agents, self.max_eps_len) + self.rew_space, np.float32), + 'values': np.zeros((self.buffer_size, self.n_agents, self.max_eps_len) + self.rew_space, np.float32), + 'advantages': np.zeros((self.buffer_size, self.n_agents, self.max_eps_len) + self.rew_space, np.float32), + 'log_pi_old': np.zeros((self.buffer_size, self.n_agents, self.max_eps_len,), np.float32), + 'terminals': np.zeros((self.buffer_size, self.max_eps_len) + self.done_space, np.bool), + 'avail_actions': np.ones((self.buffer_size, self.n_agents, self.max_eps_len + 1, self.dim_act), np.bool), + 'filled': np.zeros((self.buffer_size, self.max_eps_len, 1), np.bool) + } + if self.state_space is 
not None:
+            self.data.update({'state': np.zeros(
+                (self.buffer_size, self.max_eps_len + 1) + self.state_space, np.float32)})
+        self.ptr, self.size = 0, 0
+
+    def finish_path(self, value, i_env, episode_data=None, current_t=None, value_normalizer=None):
+        """
+        Build TD-lambda targets when an episode is finished.
+        """
+        if current_t > self.max_eps_len:
+            path_slice = np.arange(0, self.max_eps_len).astype(np.int32)
+        else:
+            path_slice = np.arange(0, current_t).astype(np.int32)
+        # calculate td-lambda returns
+        rewards = np.array(episode_data['rewards'][i_env, :, path_slice])
+        vs = np.append(np.array(episode_data['values'][i_env, :, path_slice]), [value.reshape(self.n_agents, 1)],
+                       axis=0)
+        dones = np.array(episode_data['terminals'][i_env, path_slice])[:, :, None]
+        returns = np.zeros_like(vs)
+        step_nums = len(path_slice)
+
+        for t in reversed(range(step_nums)):
+            returns[t] = self.td_lambda * self.gamma * returns[t + 1] + \
+                rewards[t] + (1 - self.td_lambda) * self.gamma * vs[t + 1] * (1 - dones[t])
+
+        episode_data['returns'][i_env, :, path_slice] = returns[:-1]
+        self.store(episode_data, i_env)
diff --git a/xuanpolicy/configs/coma/mpe/simple_spread_v3.yaml b/xuanpolicy/configs/coma/mpe/simple_spread_v3.yaml
index b72cbe11a..a89465dbb 100644
--- a/xuanpolicy/configs/coma/mpe/simple_spread_v3.yaml
+++ b/xuanpolicy/configs/coma/mpe/simple_spread_v3.yaml
@@ -4,11 +4,10 @@ env_id: "simple_spread_v3"
 continuous_action: False
 policy: "Categorical_COMA_Policy"
 representation: "Basic_MLP"
+representation_critic: "Basic_MLP"
 vectorize: "Dummy_Pettingzoo"
 runner: "Pettingzoo_Runner"
 
-render: True
-
 use_recurrent: False
 rnn:
 representation_hidden_size: [128, ]
diff --git a/xuanpolicy/configs/coma/sc2/3m.yaml b/xuanpolicy/configs/coma/sc2/3m.yaml
new file mode 100644
index 000000000..f1461ac76
--- /dev/null
+++ b/xuanpolicy/configs/coma/sc2/3m.yaml
@@ -0,0 +1,59 @@
+agent: "COMA" # the MARL learning algorithm
+env_name: "StarCraft2"
+env_id: "3m"
+fps: 15
+policy: "Categorical_COMA_Policy"
+representation: "Basic_RNN"
+vectorize: "Dummy_StarCraft2"
+runner: "StarCraft2_Runner"
+
+render: True
+
+use_recurrent: True
+rnn: "GRU"
+recurrent_layer_N: 1
+fc_hidden_sizes: [64, ]
+recurrent_hidden_size: 64
+N_recurrent_layers: 1
+dropout: 0
+normalize: "LayerNorm"
+initialize: "orthogonal"
+gain: 0.01
+
+actor_hidden_size: [64, ]
+critic_hidden_size: [128, 128]
+activation: "ReLU"
+
+seed: 1
+parallels: 1
+n_size: 128
+n_epoch: 15
+n_minibatch: 1
+learning_rate_actor: 0.0007
+learning_rate_critic: 0.0007
+
+clip_grad: 10
+clip_type: 1 # Gradient clip for Mindspore: 0: ms.ops.clip_by_value; 1: ms.nn.ClipByNorm()
+gamma: 0.95 # discount factor
+td_lambda: 0.8
+
+start_greedy: 0.5
+end_greedy: 0.01
+decay_step_greedy: 2500000
+sync_frequency: 200
+
+use_global_state: True # whether to use the global state instead of concatenated observations
+use_advnorm: True
+use_gae: True
+gae_lambda: 0.95
+
+start_training: 1
+running_steps: 1000000
+train_per_step: True
+training_frequency: 1
+
+test_steps: 10000
+eval_interval: 5000
+test_episode: 10
+log_dir: "./logs/coma/"
+model_dir: "./models/coma/"
diff --git a/xuanpolicy/configs/mappo/sc2/3m.yaml b/xuanpolicy/configs/mappo/sc2/3m.yaml
index c3c5ca688..1425f02d2 100644
--- a/xuanpolicy/configs/mappo/sc2/3m.yaml
+++ b/xuanpolicy/configs/mappo/sc2/3m.yaml
@@ -6,7 +6,6 @@ policy: "Categorical_MAAC_Policy"
 representation: "Basic_RNN"
 vectorize: "Dummy_StarCraft2"
 runner: "StarCraft2_Runner"
-on_policy: True
 
 # recurrent settings for Basic_RNN representation
use_recurrent: True diff --git a/xuanpolicy/torch/agents/multi_agent_rl/coma_agents.py b/xuanpolicy/torch/agents/multi_agent_rl/coma_agents.py index f0231073a..b4d667b45 100644 --- a/xuanpolicy/torch/agents/multi_agent_rl/coma_agents.py +++ b/xuanpolicy/torch/agents/multi_agent_rl/coma_agents.py @@ -19,10 +19,11 @@ def __init__(self, self.n_epoch = config.n_epoch self.n_minibatch = config.n_minibatch if config.state_space is not None: - config.dim_state, state_shape = config.state_space.shape, config.state_space.shape + config.dim_state, state_shape = config.state_space.shape[0], config.state_space.shape else: config.dim_state, state_shape = None, None + # create representation for COMA actor input_representation = get_repre_in(config) self.use_recurrent = config.use_recurrent self.use_global_state = config.use_global_state @@ -30,17 +31,14 @@ def __init__(self, "dropout": config.dropout, "rnn": config.rnn} if self.use_recurrent else {} representation = REGISTRY_Representation[config.representation](*input_representation, **kwargs_rnn) - # create representation for COMA critic - input_representation[0] = config.dim_obs + config.dim_act * config.n_agents - if self.use_global_state: - input_representation[0] += config.dim_state - representation_critic = REGISTRY_Representation[config.representation](*input_representation, **kwargs_rnn) # create policy - input_policy = get_policy_in_marl(config, (representation, representation_critic)) + input_policy = get_policy_in_marl(config, representation) policy = REGISTRY_Policy[config.policy](*input_policy, use_recurrent=config.use_recurrent, rnn=config.rnn, - gain=config.gain) + gain=config.gain, + use_global_state=self.use_global_state, + dim_state=config.dim_state) optimizer = [torch.optim.Adam(policy.parameters_actor, config.learning_rate_actor, eps=1e-5), torch.optim.Adam(policy.parameters_critic, config.learning_rate_critic, eps=1e-5)] scheduler = [torch.optim.lr_scheduler.LinearLR(optimizer[0], start_factor=1.0, end_factor=0.5, @@ -58,7 +56,7 @@ def __init__(self, config.dim_state, state_shape = None, None config.act_onehot_shape = config.act_shape + tuple([config.dim_act]) - buffer = MARL_OnPolicyBuffer_RNN if self.use_recurrent else COMA_Buffer + buffer = COMA_Buffer_RNN if self.use_recurrent else COMA_Buffer input_buffer = (config.n_agents, config.state_space.shape, config.obs_shape, config.act_shape, config.rew_shape, config.done_shape, envs.num_envs, config.n_size, config.use_gae, config.use_advnorm, config.gamma, config.gae_lambda) @@ -86,7 +84,7 @@ def act(self, obs_n, *rnn_hidden, avail_actions=None, test_mode=False): *rnn_hidden, avail_actions=avail_actions.reshape(batch_agents, 1, -1), epsilon=epsilon) - action_probs = action_probs.view(batch_size, self.n_agents) + action_probs = action_probs.view(batch_size, self.n_agents, self.dim_act) else: hidden_state, action_probs = self.policy(obs_in, agents_id, avail_actions=avail_actions, @@ -95,7 +93,7 @@ def act(self, obs_n, *rnn_hidden, avail_actions=None, test_mode=False): onehot_actions = self.learner.onehot_action(picked_actions, self.dim_act) return hidden_state, picked_actions.detach().cpu().numpy(), onehot_actions.detach().cpu().numpy() - def values(self, obs_n, actions_n, actions_onehot, *rnn_hidden, state=None): + def values(self, obs_n, *rnn_hidden, state=None, actions_n=None, actions_onehot=None): batch_size = len(obs_n) # build critic input obs_n = torch.Tensor(obs_n).to(self.device) @@ -111,11 +109,7 @@ def values(self, obs_n, actions_n, actions_onehot, *rnn_hidden, 
state=None): else: critic_in = torch.concat([obs_n, actions_in]) # get critic values - if self.use_recurrent: - hidden_state, values_n = self.policy.get_values(critic_in.unsqueeze(2), *rnn_hidden, target=True) - values_n = values_n.squeeze(2) - else: - hidden_state, values_n = self.policy.get_values(critic_in, target=True) + hidden_state, values_n = self.policy.get_values(critic_in, target=True) target_values = values_n.gather(-1, actions_n.long()) return hidden_state, target_values.detach().cpu().numpy() diff --git a/xuanpolicy/torch/learners/multi_agent_rl/coma_learner.py b/xuanpolicy/torch/learners/multi_agent_rl/coma_learner.py index f6bf225a0..01aea212c 100644 --- a/xuanpolicy/torch/learners/multi_agent_rl/coma_learner.py +++ b/xuanpolicy/torch/learners/multi_agent_rl/coma_learner.py @@ -108,3 +108,87 @@ def update(self, sample, epsilon=0.0): } return info + + def update_recurrent(self, sample, epsilon=0.0): + self.iterations += 1 + state = torch.Tensor(sample['state']).to(self.device) + obs = torch.Tensor(sample['obs']).to(self.device) + actions = torch.Tensor(sample['actions']).to(self.device) + actions_onehot = torch.Tensor(sample['actions_onehot']).to(self.device) + targets = torch.Tensor(sample['returns']).squeeze(-1).to(self.device) + avail_actions = torch.Tensor(sample['avail_actions']).float().to(self.device) + filled = torch.Tensor(sample['filled']).float().to(self.device) + batch_size = obs.shape[0] + episode_length = actions.shape[2] + IDs = torch.eye(self.n_agents).unsqueeze(1).unsqueeze(0).expand(batch_size, -1, episode_length + 1, -1).to( + self.device) + + # build critic input + actions_in = actions_onehot.transpose(1, 2).reshape(batch_size, episode_length, -1) + actions_in = actions_in.unsqueeze(1).repeat(1, self.n_agents, 1, 1) + actions_in_mask = 1 - torch.eye(self.n_agents, device=self.device) + actions_in_mask = actions_in_mask.view(-1, 1).repeat(1, self.dim_act).view(self.n_agents, -1) + actions_in_mask = actions_in_mask.unsqueeze(1).repeat(1, episode_length, 1) + actions_in = actions_in * actions_in_mask + if self.use_global_state: + state = state[:, :-1].unsqueeze(1).repeat(1, self.n_agents, 1, 1) + critic_in = torch.concat([state, obs[:, :, :-1], actions_in], dim=-1) + else: + critic_in = torch.concat([obs[:, :-1], actions_in]) + # get critic value + + _, q_eval = self.policy.get_values(critic_in) + q_eval_a = q_eval.gather(-1, actions.unsqueeze(-1).long()).squeeze(-1) + filled_n = filled.unsqueeze(1).expand(-1, self.n_agents, -1, -1).squeeze(-1) + td_errors = q_eval_a - targets.detach() + td_errors *= filled_n + loss_c = (td_errors ** 2).sum() / filled_n.sum() + self.optimizer['critic'].zero_grad() + loss_c.backward() + grad_norm_critic = torch.nn.utils.clip_grad_norm_(self.policy.parameters_critic, self.args.clip_grad) + self.optimizer['critic'].step() + if self.iterations_critic % self.sync_frequency == 0: + self.policy.copy_target() + self.iterations_critic += 1 + + if self.scheduler['critic'] is not None: + self.scheduler['critic'].step() + + # calculate baselines + rnn_hidden_actor = self.policy.representation.init_hidden(batch_size * self.n_agents) + _, pi_probs = self.policy(obs[:, :, :-1].reshape(-1, episode_length, self.dim_obs), + IDs[:, :, :-1].reshape(-1, episode_length, self.n_agents), + *rnn_hidden_actor, + avail_actions=avail_actions[:, :, :-1].reshape(-1, episode_length, self.dim_act), + epsilon=epsilon) + pi_probs = pi_probs.reshape(batch_size, self.n_agents, episode_length, self.dim_act) + baseline = (pi_probs * q_eval).sum(-1) + + pi_a = 
pi_probs.gather(-1, actions.unsqueeze(-1).long()).squeeze(-1) + log_pi_a = torch.log(pi_a) + advantages = (q_eval_a - baseline).detach() + loss_coma = -(advantages * log_pi_a * filled_n).sum() / filled_n.sum() + + self.optimizer['actor'].zero_grad() + loss_coma.backward() + grad_norm_actor = torch.nn.utils.clip_grad_norm_(self.policy.parameters_actor, self.args.clip_grad) + self.optimizer['actor'].step() + + if self.scheduler['actor'] is not None: + self.scheduler['actor'].step() + + # Logger + lr_a = self.optimizer['actor'].state_dict()['param_groups'][0]['lr'] + lr_c = self.optimizer['critic'].state_dict()['param_groups'][0]['lr'] + + info = { + "learning_rate_actor": lr_a, + "learning_rate_critic": lr_c, + "actor_loss": loss_coma.item(), + "critic_loss": loss_c.item(), + "advantage": advantages.mean().item(), + "actor_gradient_norm": grad_norm_actor.item(), + "critic_gradient_norm": grad_norm_critic.item() + } + + return info diff --git a/xuanpolicy/torch/policies/categorical_marl.py b/xuanpolicy/torch/policies/categorical_marl.py index 8f6e52a64..2f9b2583b 100644 --- a/xuanpolicy/torch/policies/categorical_marl.py +++ b/xuanpolicy/torch/policies/categorical_marl.py @@ -166,18 +166,19 @@ def __init__(self, self.device = device self.action_dim = action_space.n self.n_agents = n_agents - self.representation = representation[0] - self.representation_critic = representation[1] + self.representation = representation self.representation_info_shape = self.representation.output_shapes self.lstm = True if kwargs["rnn"] == "LSTM" else False self.use_rnn = True if kwargs["use_recurrent"] else False self.actor = ActorNet(self.representation.output_shapes['state'][0], self.action_dim, n_agents, actor_hidden_size, normalize, initialize, kwargs['gain'], activation, device) - self.critic = COMA_Critic(self.representation_critic.output_shapes['state'][0], self.action_dim, - critic_hidden_size, normalize, initialize, activation, device) - self.target_representation_critic = copy.deepcopy(self.representation_critic) + critic_input_dim = self.representation.input_shape[0] + self.action_dim * self.n_agents + if kwargs["use_global_state"]: + critic_input_dim += kwargs["dim_state"] + self.critic = COMA_Critic(critic_input_dim, self.action_dim, critic_hidden_size, + normalize, initialize, activation, device) self.target_critic = copy.deepcopy(self.critic) - self.parameters_critic = list(self.representation_critic.parameters()) + list(self.critic.parameters()) + self.parameters_critic = list(self.critic.parameters()) self.parameters_actor = list(self.representation.parameters()) + list(self.actor.parameters()) self.pi_dist = CategoricalDistribution(self.action_dim) @@ -199,34 +200,13 @@ def forward(self, observation: torch.Tensor, agent_ids: torch.Tensor, return rnn_hidden, act_probs def get_values(self, critic_in: torch.Tensor, *rnn_hidden: torch.Tensor, target=False): - shape_in = critic_in.shape - # get representation features - if self.use_rnn: - batch_size, n_agent, episode_length, dim_critic_in = tuple(shape_in) - if target: - outputs = self.target_representation_critic(critic_in.reshape(-1, episode_length, dim_critic_in), *rnn_hidden) - else: - outputs = self.representation_critic(critic_in.reshape(-1, episode_length, dim_critic_in), *rnn_hidden) - outputs['state'] = outputs['state'].view(batch_size, n_agent, episode_length, -1) - rnn_hidden = (outputs['rnn_hidden'], outputs['rnn_cell']) - else: - batch_size, n_agent, dim_critic_in = tuple(shape_in) - if target: - outputs = 
self.target_representation_critic(critic_in.reshape(-1, dim_critic_in)) - else: - outputs = self.representation_critic(critic_in.reshape(-1, dim_critic_in)) - outputs['state'] = outputs['state'].view(batch_size, n_agent, -1) - rnn_hidden = None # get critic values - critic_in = outputs['state'] v = self.target_critic(critic_in) if target else self.critic(critic_in) - return rnn_hidden, v + return [None, None], v def copy_target(self): for ep, tp in zip(self.critic.parameters(), self.target_critic.parameters()): tp.data.copy_(ep) - for ep, tp in zip(self.representation_critic.parameters(), self.target_representation_critic.parameters()): - tp.data.copy_(ep) class MeanFieldActorCriticPolicy(nn.Module): diff --git a/xuanpolicy/torch/runners/runner_pettingzoo.py b/xuanpolicy/torch/runners/runner_pettingzoo.py index 531cf5011..49b711a7b 100644 --- a/xuanpolicy/torch/runners/runner_pettingzoo.py +++ b/xuanpolicy/torch/runners/runner_pettingzoo.py @@ -172,7 +172,7 @@ def get_actions(self, obs_n, test_mode, act_mean_last, agent_mask, state): log_pi_n.append(None) elif self.marl_names[h] in ["COMA"]: _, a, a_onehot = mas_group.act(obs_n[h], test_mode) - _, values = mas_group.values(obs_n[h], a, a_onehot, state=state) + _, values = mas_group.values(obs_n[h], state=state, actions_n=a, actions_onehot=a_onehot) actions_n_onehot.append(a_onehot) values_n.append(values) else: diff --git a/xuanpolicy/torch/runners/runner_sc2.py b/xuanpolicy/torch/runners/runner_sc2.py index 11868971c..ebffe339e 100644 --- a/xuanpolicy/torch/runners/runner_sc2.py +++ b/xuanpolicy/torch/runners/runner_sc2.py @@ -51,7 +51,6 @@ def __init__(self, args): else: raise "No logger is implemented." - self.on_policy = self.args.on_policy self.running_steps = args.running_steps self.training_frequency = args.training_frequency self.current_step = 0 @@ -63,36 +62,44 @@ def __init__(self, args): args.n_agents = self.num_agents self.dim_obs, self.dim_act, self.dim_state = self.envs.dim_obs, self.envs.dim_act, self.envs.dim_state args.dim_obs, args.dim_act = self.dim_obs, self.dim_act - args.obs_shape, args.act_shape = (self.dim_obs, ), () - args.rew_shape = args.done_shape = (1, ) + args.obs_shape, args.act_shape = (self.dim_obs,), () + args.rew_shape = args.done_shape = (1,) args.action_space = self.envs.action_space args.state_space = self.envs.state_space + + # environment details, representations, policies, optimizers, and agents. 
+ self.agents = REGISTRY_Agent[args.agent](args, self.envs, args.device) + self.on_policy = self.agents.on_policy self.episode_buffer = { 'obs': np.zeros((self.n_envs, self.num_agents, self.episode_length + 1) + args.obs_shape, dtype=np.float32), 'actions': np.zeros((self.n_envs, self.num_agents, self.episode_length) + args.act_shape, dtype=np.float32), 'state': np.zeros((self.n_envs, self.episode_length + 1) + args.state_space.shape, dtype=np.float32), 'rewards': np.zeros((self.n_envs, self.num_agents, self.episode_length) + args.rew_shape, dtype=np.float32), 'terminals': np.zeros((self.n_envs, self.episode_length) + args.done_shape, dtype=np.bool), - 'avail_actions': np.ones((self.n_envs, self.num_agents, self.episode_length + 1, self.dim_act), dtype=np.bool), + 'avail_actions': np.ones((self.n_envs, self.num_agents, self.episode_length + 1, self.dim_act), + dtype=np.bool), 'filled': np.zeros((self.n_envs, self.episode_length, 1), dtype=np.bool), } if self.on_policy: self.episode_buffer.update({ 'values': np.zeros((self.n_envs, self.num_agents, self.episode_length) + args.rew_shape, np.float32), 'returns': np.zeros((self.n_envs, self.num_agents, self.episode_length) + args.rew_shape, np.float32), - 'advantages': np.zeros((self.n_envs, self.num_agents, self.episode_length) + args.rew_shape, np.float32), - 'log_pi_old': np.zeros((self.n_envs, self.num_agents, self.episode_length,), np.float32) + 'advantages': np.zeros((self.n_envs, self.num_agents, self.episode_length) + args.rew_shape, + np.float32), + 'log_pi_old': np.zeros((self.n_envs, self.num_agents, self.episode_length,), np.float32), }) + if self.args.agent == "COMA": + self.episode_buffer.update({ + 'actions_onehot': np.zeros((self.n_envs, self.num_agents, self.episode_length, self.dim_act), + dtype=np.float32)}) self.env_ptr = range(self.n_envs) - # environment details, representations, policies, optimizers, and agents. - self.agents = REGISTRY_Agent[args.agent](args, self.envs, args.device) # initialize hidden units for RNN. 
self.rnn_hidden = self.agents.policy.representation.init_hidden(self.n_envs * self.num_agents) - if self.on_policy: + if self.on_policy and self.args.agent != "COMA": self.rnn_hidden_critic = self.agents.policy.representation_critic.init_hidden(self.n_envs * self.num_agents) else: - self.rnn_hidden_critic = None + self.rnn_hidden_critic = [None, None] def get_agent_num(self): self.num_agents, self.num_enemies = self.envs.num_agents, self.envs.num_enemies @@ -124,14 +131,19 @@ def get_actions(self, obs_n, avail_actions, *rnn_hidden, state=None, test_mode=F log_pi_n, values_n, actions_n_onehot = None, None, None rnn_hidden_policy, rnn_hidden_critic = rnn_hidden[0], rnn_hidden[1] if self.on_policy: - rnn_hidden_next, actions_n, log_pi_n = self.agents.act(obs_n, *rnn_hidden_policy, - avail_actions=avail_actions, - test_mode=test_mode) + if self.args.agent == "COMA": + rnn_hidden_next, actions_n, actions_n_onehot = self.agents.act(obs_n, *rnn_hidden_policy, + avail_actions=avail_actions, + test_mode=test_mode) + else: + rnn_hidden_next, actions_n, log_pi_n = self.agents.act(obs_n, *rnn_hidden_policy, + avail_actions=avail_actions, + test_mode=test_mode) if test_mode: rnn_hidden_critic_next, values_n = None, 0 else: - rnn_hidden_critic_next, values_n = self.agents.values(obs_n, *rnn_hidden_critic, - state=state) + kwargs = {"state": state, "actions_n": actions_n, "actions_onehot": actions_n_onehot} + rnn_hidden_critic_next, values_n = self.agents.values(obs_n, *rnn_hidden_critic, **kwargs) else: rnn_hidden_next, actions_n = self.agents.act(obs_n, *rnn_hidden_policy, avail_actions=avail_actions, test_mode=test_mode) @@ -150,6 +162,8 @@ def store_data(self, t_envs, obs_n, actions_dict, state, rewards, terminated, av if self.on_policy: self.episode_buffer['values'][self.env_ptr, :, t_envs] = actions_dict['values'] self.episode_buffer['log_pi_old'][self.env_ptr, :, t_envs] = actions_dict['log_pi'] + if self.args.agent == "COMA": + self.episode_buffer['actions_onehot'][self.env_ptr, :, t_envs] = actions_dict['act_n_onehot'] def store_terminal_data(self, i_env, t_env, obs_n, state, last_avail_actions, filled): self.episode_buffer['obs'][i_env, :, t_env] = obs_n[i_env] @@ -189,9 +203,15 @@ def train_episode(self, n_episodes): else: rnn_h_critic_i = self.agents.policy.representation_critic.get_hidden_item(batch_select, *rnn_hidden_critic) - _, values_next = self.agents.values([obs_n[i_env]], *rnn_h_critic_i, state=[state[i_env]]) - rnn_hidden_critic = self.agents.policy.representation_critic.init_hidden_item(batch_select, - *rnn_hidden_critic) + kwargs = {"state": [state[i_env]], + "actions_n": actions_dict['actions_n'][i_env], + "actions_onehot": actions_dict['act_n_onehot'][i_env]} + _, values_next = self.agents.values([obs_n[i_env]], *rnn_h_critic_i, **kwargs) + if self.args.agent != "COMA": + rnn_hidden_critic = self.agents.policy.representation_critic.init_hidden_item(batch_select, + *rnn_hidden_critic) + else: + rnn_hidden_critic = [None, None] self.agents.memory.finish_path(values_next, i_env, episode_data=self.episode_buffer, current_t=self.envs_step[i_env], value_normalizer=self.agents.learner.value_normalizer) @@ -211,7 +231,8 @@ def train_episode(self, n_episodes): step_info["Train-Episode-Rewards/env-%d" % i_env] = info[i_env]["episode_score"] else: step_info["Train-Results/Episode-Steps"] = {"env-%d" % i_env: info[i_env]["episode_step"]} - step_info["Train-Results/Episode-Rewards"] = {"env-%d" % i_env: info[i_env]["episode_score"]} + step_info["Train-Results/Episode-Rewards"] = { + "env-%d" 
% i_env: info[i_env]["episode_score"]} self.log_infos(step_info, self.current_step) self.current_step += self.n_envs @@ -255,7 +276,8 @@ def test_episode(self, n_episodes): for step in range(self.episode_length): available_actions = self.test_envs.get_avail_actions() actions_dict = self.get_actions(obs_n, available_actions, rnn_hidden, None, test_mode=True) - next_obs_n, next_state, rewards, terminated, truncated, info = self.test_envs.step(actions_dict['actions_n']) + next_obs_n, next_state, rewards, terminated, truncated, info = self.test_envs.step( + actions_dict['actions_n']) if self.args.render_mode == "rgb_array" and self.render: images = self.test_envs.render(self.args.render_mode) for idx, img in enumerate(images): @@ -374,4 +396,3 @@ def benchmark(self): wandb.finish() else: self.writer.close() - diff --git a/xuanpolicy/torch/utils/input_reformat.py b/xuanpolicy/torch/utils/input_reformat.py index 347b9b953..7229a6fe5 100644 --- a/xuanpolicy/torch/utils/input_reformat.py +++ b/xuanpolicy/torch/utils/input_reformat.py @@ -7,8 +7,8 @@ import torch -def get_repre_in(args): - representation_name = args.representation +def get_repre_in(args, name=None): + representation_name = args.representation if name is None else name input_dict = deepcopy(Representation_Inputs_All) if args.env_name in ["StarCraft2", "Football", "MAgent2"]: input_dict["input_shape"] = (args.dim_obs, )
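Note on the critic targets: COMA_Buffer_RNN.finish_path above builds the returns with the TD-lambda recursion G_t = r_t + lambda * gamma * G_{t+1} + (1 - lambda) * gamma * V_{t+1} * (1 - d_t), bootstrapping from the value passed in when an episode is cut off. The snippet below is a minimal standalone sketch of that recursion on dummy numpy arrays; the array names and shapes are illustrative only, not the buffer's API.

    # Standalone sketch of the td-lambda recursion used in COMA_Buffer_RNN.finish_path.
    # Dummy shapes: T steps, N agents, one reward/value per agent.
    import numpy as np

    T, N = 5, 3
    gamma, td_lambda = 0.95, 0.8                               # same values as configs/coma/sc2/3m.yaml
    rewards = np.random.rand(T, N, 1).astype(np.float32)
    values = np.random.rand(T + 1, N, 1).astype(np.float32)    # index T holds the bootstrap value
    dones = np.zeros((T, N, 1), dtype=np.float32)

    returns = np.zeros_like(values)
    for t in reversed(range(T)):
        returns[t] = rewards[t] + td_lambda * gamma * returns[t + 1] \
                     + (1 - td_lambda) * gamma * values[t + 1] * (1 - dones[t])
    td_targets = returns[:-1]                                  # what finish_path stores in data['returns']
    print(td_targets.shape)                                    # (5, 3, 1)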
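Note on the actor update: update_recurrent uses COMA's counterfactual advantage A^a = Q(s, u) - sum_{u'} pi(u'|tau^a) Q(s, (u^{-a}, u')), i.e. the chosen-action Q-value minus the policy-weighted average over agent a's own actions, with the other agents' actions held fixed in the critic input. Below is a minimal sketch with dummy tensors of shape (batch, agents, time, actions); tensor names are illustrative, not the learner's variables.

    # Standalone sketch of the counterfactual baseline used in update_recurrent.
    import torch

    B, N, T, A = 2, 3, 5, 9                        # batch, agents, episode length, actions
    q_eval = torch.rand(B, N, T, A)                # per-action critic values Q(s, (u^{-a}, .))
    pi_probs = torch.rand(B, N, T, A).softmax(-1)  # current policy pi(.|tau^a)
    actions = torch.randint(0, A, (B, N, T))       # actions actually taken

    q_eval_a = q_eval.gather(-1, actions.unsqueeze(-1)).squeeze(-1)  # Q for the chosen actions
    baseline = (pi_probs * q_eval).sum(-1)                           # policy-weighted baseline
    advantages = (q_eval_a - baseline).detach()
    log_pi_a = torch.log(pi_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1))
    actor_loss = -(advantages * log_pi_a).mean()   # the patch additionally masks with 'filled' before averaging
    print(float(actor_loss))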