From de9315d69f1fb231cbfbb4002386205c7947221c Mon Sep 17 00:00:00 2001
From: wenzhangliu
Date: Thu, 19 Oct 2023 10:21:20 +0800
Subject: [PATCH] ippo

---
 .../agents/multi_agent_rl/ippo_agents.py    | 118 ++++++++++
 .../learners/multi_agent_rl/ippo_learner.py | 213 ++++++++++++++++++
 2 files changed, 331 insertions(+)
 create mode 100644 xuanpolicy/torch/agents/multi_agent_rl/ippo_agents.py
 create mode 100644 xuanpolicy/torch/learners/multi_agent_rl/ippo_learner.py

diff --git a/xuanpolicy/torch/agents/multi_agent_rl/ippo_agents.py b/xuanpolicy/torch/agents/multi_agent_rl/ippo_agents.py
new file mode 100644
index 000000000..c7fa01691
--- /dev/null
+++ b/xuanpolicy/torch/agents/multi_agent_rl/ippo_agents.py
@@ -0,0 +1,118 @@
+import torch
+
+from xuanpolicy.torch.agents import *
+
+
+class IPPO_Agents(MARLAgents):
+    def __init__(self,
+                 config: Namespace,
+                 envs: DummyVecEnv_Pettingzoo,
+                 device: Optional[Union[int, str, torch.device]] = None):
+        self.gamma = config.gamma
+        self.n_envs = envs.num_envs
+        self.n_size = config.n_size
+        self.n_epoch = config.n_epoch
+        self.n_minibatch = config.n_minibatch
+        if config.state_space is not None:
+            config.dim_state, state_shape = config.state_space.shape[0], config.state_space.shape
+        else:
+            config.dim_state, state_shape = None, None
+
+        input_representation = get_repre_in(config)
+        self.use_recurrent = config.use_recurrent
+        self.use_global_state = config.use_global_state
+        # create representation for actor
+        kwargs_rnn = {"N_recurrent_layers": config.N_recurrent_layers,
+                      "dropout": config.dropout,
+                      "rnn": config.rnn} if self.use_recurrent else {}
+        representation = REGISTRY_Representation[config.representation](*input_representation, **kwargs_rnn)
+        # create representation for critic
+        input_representation[0] = (config.dim_state,) if self.use_global_state else (config.dim_obs * config.n_agents,)
+        representation_critic = REGISTRY_Representation[config.representation](*input_representation, **kwargs_rnn)
+        # create policy
+        input_policy = get_policy_in_marl(config, (representation, representation_critic))
+        policy = REGISTRY_Policy[config.policy](*input_policy,
+                                                use_recurrent=config.use_recurrent,
+                                                rnn=config.rnn,
+                                                gain=config.gain)
+        optimizer = torch.optim.Adam(policy.parameters(),
+                                     lr=config.learning_rate, eps=1e-5,
+                                     weight_decay=config.weight_decay)
+        self.observation_space = envs.observation_space
+        self.action_space = envs.action_space
+        self.auxiliary_info_shape = {}
+
+        buffer = MARL_OnPolicyBuffer_RNN if self.use_recurrent else MARL_OnPolicyBuffer
+        input_buffer = (config.n_agents, config.state_space.shape, config.obs_shape, config.act_shape, config.rew_shape,
+                        config.done_shape, envs.num_envs, config.n_size,
+                        config.use_gae, config.use_advnorm, config.gamma, config.gae_lambda)
+        memory = buffer(*input_buffer, max_episode_length=envs.max_episode_length, dim_act=config.dim_act)
+        self.buffer_size = memory.buffer_size
+        self.batch_size = self.buffer_size // self.n_minibatch
+
+        learner = IPPO_Clip_Learner(config, policy, optimizer, None,
+                                    config.device, config.model_dir, config.gamma)
+        super(IPPO_Agents, self).__init__(config, envs, policy, memory, learner, device,
+                                          config.log_dir, config.model_dir)
+        self.share_values = True if config.rew_shape[0] == 1 else False
+        self.on_policy = True
+
+    def act(self, obs_n, *rnn_hidden, avail_actions=None, state=None, test_mode=False):
+        batch_size = len(obs_n)
+        agents_id = torch.eye(self.n_agents).unsqueeze(0).expand(batch_size, -1, -1).to(self.device)
+        obs_in = torch.Tensor(obs_n).view([batch_size, self.n_agents, -1]).to(self.device)
+        if self.use_recurrent:
+            batch_agents = batch_size * self.n_agents
+            hidden_state, dists = self.policy(obs_in.view(batch_agents, 1, -1),
+                                              agents_id.view(batch_agents, 1, -1),
+                                              *rnn_hidden,
+                                              avail_actions=avail_actions.reshape(batch_agents, 1, -1))
+            actions = dists.stochastic_sample()
+            log_pi_a = dists.log_prob(actions).reshape(batch_size, self.n_agents)
+            actions = actions.reshape(batch_size, self.n_agents)
+        else:
+            hidden_state, dists = self.policy(obs_in, agents_id, avail_actions=avail_actions)
+            actions = dists.stochastic_sample()
+            log_pi_a = dists.log_prob(actions)
+        return hidden_state, actions.detach().cpu().numpy(), log_pi_a.detach().cpu().numpy()
+
+    def values(self, obs_n, *rnn_hidden, state=None):
+        batch_size = len(obs_n)
+        agents_id = torch.eye(self.n_agents).unsqueeze(0).expand(batch_size, -1, -1).to(self.device)
+        # build critic input
+        if self.use_global_state:
+            state = torch.Tensor(state).unsqueeze(1).to(self.device)
+            critic_in = state.expand(-1, self.n_agents, -1)
+        else:
+            critic_in = torch.Tensor(obs_n).view([batch_size, 1, -1]).to(self.device)
+            critic_in = critic_in.expand(-1, self.n_agents, -1)
+        # get critic values
+        if self.use_recurrent:
+            hidden_state, values_n = self.policy.get_values(critic_in.unsqueeze(2),  # add a sequence length axis.
+                                                            agents_id.unsqueeze(2),
+                                                            *rnn_hidden)
+            values_n = values_n.squeeze(2)
+        else:
+            hidden_state, values_n = self.policy.get_values(critic_in, agents_id)
+
+        return hidden_state, values_n.detach().cpu().numpy()
+
+    def train(self, i_step):
+        if self.memory.full:  # update only after a full on-policy rollout has been collected
+            info_train = {}
+            indexes = np.arange(self.buffer_size)
+            for _ in range(self.n_epoch):
+                np.random.shuffle(indexes)
+                for start in range(0, self.buffer_size, self.batch_size):
+                    end = start + self.batch_size
+                    sample_idx = indexes[start:end]
+                    sample = self.memory.sample(sample_idx)
+                    if self.use_recurrent:
+                        info_train = self.learner.update_recurrent(sample)
+                    else:
+                        info_train = self.learner.update(sample)
+            self.learner.lr_decay(i_step)
+            self.memory.clear()
+            return info_train
+        else:
+            return {}
diff --git a/xuanpolicy/torch/learners/multi_agent_rl/ippo_learner.py b/xuanpolicy/torch/learners/multi_agent_rl/ippo_learner.py
new file mode 100644
index 000000000..7ad5353b2
--- /dev/null
+++ b/xuanpolicy/torch/learners/multi_agent_rl/ippo_learner.py
@@ -0,0 +1,213 @@
+"""
+Independent Proximal Policy Optimization (IPPO)
+Paper link:
+https://arxiv.org/pdf/2011.09533.pdf
+Implementation: PyTorch
+"""
+from xuanpolicy.torch.learners import *
+from xuanpolicy.torch.utils.value_norm import ValueNorm
+from xuanpolicy.torch.utils.operations import update_linear_decay
+
+
+class IPPO_Clip_Learner(LearnerMAS):
+    def __init__(self,
+                 config: Namespace,
+                 policy: nn.Module,
+                 optimizer: torch.optim.Optimizer,
+                 scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+                 device: Optional[Union[int, str, torch.device]] = None,
+                 model_dir: str = "./",
+                 gamma: float = 0.99,
+                 ):
+        self.gamma = gamma
+        self.clip_range = config.clip_range
+        self.use_linear_lr_decay = config.use_linear_lr_decay
+        self.use_grad_norm, self.max_grad_norm = config.use_grad_norm, config.max_grad_norm
+        self.use_value_clip, self.value_clip_range = config.use_value_clip, config.value_clip_range
+        self.use_huber_loss, self.huber_delta = config.use_huber_loss, config.huber_delta
+        self.use_value_norm = config.use_value_norm
+        self.use_global_state = config.use_global_state
+        self.vf_coef, self.ent_coef = config.vf_coef, config.ent_coef
+        self.mse_loss = nn.MSELoss()
+        self.huber_loss = nn.HuberLoss(reduction="none", delta=self.huber_delta)
+        super(IPPO_Clip_Learner, self).__init__(config, policy, optimizer, scheduler, device, model_dir)
+        if self.use_value_norm:
+            self.value_normalizer = ValueNorm(1).to(device)
+        else:
+            self.value_normalizer = None
+        self.lr = config.learning_rate
+        self.end_factor_lr_decay = config.end_factor_lr_decay
+
+    def lr_decay(self, i_step):
+        if self.use_linear_lr_decay:
+            update_linear_decay(self.optimizer, i_step, self.running_steps, self.lr, self.end_factor_lr_decay)
+
+    def update(self, sample):
+        info = {}
+        self.iterations += 1
+        state = torch.Tensor(sample['state']).to(self.device)
+        obs = torch.Tensor(sample['obs']).to(self.device)
+        actions = torch.Tensor(sample['actions']).to(self.device)
+        values = torch.Tensor(sample['values']).to(self.device)
+        returns = torch.Tensor(sample['returns']).to(self.device)
+        advantages = torch.Tensor(sample['advantages']).to(self.device)
+        log_pi_old = torch.Tensor(sample['log_pi_old']).to(self.device)
+        agent_mask = torch.Tensor(sample['agent_mask']).float().reshape(-1, self.n_agents, 1).to(self.device)
+        batch_size = obs.shape[0]
+        IDs = torch.eye(self.n_agents).unsqueeze(0).expand(batch_size, -1, -1).to(self.device)
+
+        # actor loss
+        _, pi_dist = self.policy(obs, IDs)
+        log_pi = pi_dist.log_prob(actions)
+        ratio = torch.exp(log_pi - log_pi_old).reshape(batch_size, self.n_agents, 1)
+        advantages_mask = advantages.detach() * agent_mask
+        surrogate1 = ratio * advantages_mask
+        surrogate2 = torch.clip(ratio, 1 - self.clip_range, 1 + self.clip_range) * advantages_mask
+        loss_a = -torch.sum(torch.min(surrogate1, surrogate2), dim=-2, keepdim=True).mean()
+
+        # entropy loss
+        entropy = pi_dist.entropy().reshape(agent_mask.shape) * agent_mask
+        loss_e = entropy.mean()
+
+        # critic loss
+        critic_in = torch.Tensor(obs).reshape([batch_size, 1, -1]).to(self.device)
+        critic_in = critic_in.expand(-1, self.n_agents, -1)
+        _, value_pred = self.policy.get_values(critic_in, IDs)
+        value_pred = value_pred
+        value_target = returns
+        if self.use_value_clip:
+            value_clipped = values + (value_pred - values).clamp(-self.value_clip_range, self.value_clip_range)
+            if self.use_huber_loss:
+                loss_v = self.huber_loss(value_pred, value_target)
+                loss_v_clipped = self.huber_loss(value_clipped, value_target)
+            else:
+                loss_v = (value_pred - value_target) ** 2
+                loss_v_clipped = (value_clipped - value_target) ** 2
+            loss_c = torch.max(loss_v, loss_v_clipped) * agent_mask
+            loss_c = loss_c.sum() / agent_mask.sum()
+        else:
+            if self.use_huber_loss:
+                loss_v = self.huber_loss(value_pred, value_target) * agent_mask
+            else:
+                loss_v = ((value_pred - value_target) ** 2) * agent_mask
+            loss_c = loss_v.sum() / agent_mask.sum()
+
+        loss = loss_a + self.vf_coef * loss_c - self.ent_coef * loss_e
+        self.optimizer.zero_grad()
+        loss.backward()
+        if self.use_grad_norm:
+            grad_norm = torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
+            info["gradient_norm"] = grad_norm.item()
+        self.optimizer.step()
+        if self.scheduler is not None:
+            self.scheduler.step()
+
+        # Logger
+        lr = self.optimizer.state_dict()['param_groups'][0]['lr']
+
+        info.update({
+            "learning_rate": lr,
+            "actor_loss": loss_a.item(),
+            "critic_loss": loss_c.item(),
+            "entropy": loss_e.item(),
+            "loss": loss.item(),
+            "predict_value": value_pred.mean().item()
+        })
+
+        return info
+
+    def update_recurrent(self, sample):
+        info = {}
+        self.iterations += 1
+        state = torch.Tensor(sample['state']).to(self.device)
+        if self.use_global_state:
+            state = state.unsqueeze(1).expand(-1, self.n_agents, -1, -1)
+        obs = torch.Tensor(sample['obs']).to(self.device)
+        actions = torch.Tensor(sample['actions']).to(self.device)
+        values = torch.Tensor(sample['values']).to(self.device)
+        returns = torch.Tensor(sample['returns']).to(self.device)
+        advantages = torch.Tensor(sample['advantages']).to(self.device)
+        log_pi_old = torch.Tensor(sample['log_pi_old']).to(self.device)
+        avail_actions = torch.Tensor(sample['avail_actions']).float().to(self.device)
+        filled = torch.Tensor(sample['filled']).float().to(self.device)
+        batch_size = obs.shape[0]
+        episode_length = actions.shape[2]
+        IDs = torch.eye(self.n_agents).unsqueeze(1).unsqueeze(0).expand(batch_size, -1, episode_length + 1, -1).to(
+            self.device)
+
+        # actor loss
+        rnn_hidden_actor = self.policy.representation.init_hidden(batch_size * self.n_agents)
+        _, pi_dist = self.policy(obs[:, :, :-1].reshape(-1, episode_length, self.dim_obs),
+                                 IDs[:, :, :-1].reshape(-1, episode_length, self.n_agents),
+                                 *rnn_hidden_actor,
+                                 avail_actions=avail_actions[:, :, :-1].reshape(-1, episode_length, self.dim_act))
+        log_pi = pi_dist.log_prob(actions.reshape(-1, episode_length)).reshape(batch_size, self.n_agents, episode_length)
+        ratio = torch.exp(log_pi - log_pi_old).unsqueeze(-1)
+        filled_n = filled.unsqueeze(1).expand(batch_size, self.n_agents, episode_length, 1)
+        surrogate1 = ratio * advantages
+        surrogate2 = torch.clip(ratio, 1 - self.clip_range, 1 + self.clip_range) * advantages
+        loss_a = -(torch.min(surrogate1, surrogate2) * filled_n).sum() / filled_n.sum()
+
+        # entropy loss
+        entropy = pi_dist.entropy().reshape(batch_size, self.n_agents, episode_length, 1)
+        entropy = entropy * filled_n
+        loss_e = entropy.sum() / filled_n.sum()
+
+        # critic loss
+        rnn_hidden_critic = self.policy.representation_critic.init_hidden(batch_size * self.n_agents)
+        if self.use_global_state:
+            _, value_pred = self.policy.get_values(state[:, :, :-1], IDs[:, :, :-1], *rnn_hidden_critic)
+        else:
+            critic_in = obs[:, :, :-1].transpose(1, 2).reshape(batch_size, episode_length, -1)
+            critic_in = critic_in.unsqueeze(1).expand(-1, self.n_agents, -1, -1)
+            _, value_pred = self.policy.get_values(critic_in, IDs[:, :, :-1], *rnn_hidden_critic)
+        value_target = returns.reshape(-1, 1)
+        values = values.reshape(-1, 1)
+        value_pred = value_pred.reshape(-1, 1)
+        filled_all = filled_n.reshape(-1, 1)
+        if self.use_value_clip:
+            value_clipped = values + (value_pred - values).clamp(-self.value_clip_range, self.value_clip_range)
+            if self.use_value_norm:
+                self.value_normalizer.update(value_target)
+                value_target = self.value_normalizer.normalize(value_target)
+            if self.use_huber_loss:
+                loss_v = self.huber_loss(value_pred, value_target)
+                loss_v_clipped = self.huber_loss(value_clipped, value_target)
+            else:
+                loss_v = (value_pred - value_target) ** 2
+                loss_v_clipped = (value_clipped - value_target) ** 2
+            loss_c = torch.max(loss_v, loss_v_clipped) * filled_all
+            loss_c = loss_c.sum() / filled_all.sum()
+        else:
+            if self.use_value_norm:
+                self.value_normalizer.update(value_target)
+                value_pred = self.value_normalizer.normalize(value_pred)
+            if self.use_huber_loss:
+                loss_v = self.huber_loss(value_pred, value_target)
+            else:
+                loss_v = (value_pred - value_target) ** 2
+            loss_c = (loss_v * filled_all).sum() / filled_all.sum()
+
+        loss = loss_a + self.vf_coef * loss_c - self.ent_coef * loss_e
+        self.optimizer.zero_grad()
+        loss.backward()
+        if self.use_grad_norm:
+            grad_norm = torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
+            info["gradient_norm"] = grad_norm.item()
+        self.optimizer.step()
+        if self.scheduler is not None:
+            self.scheduler.step()
+
+        # Logger
+        lr = self.optimizer.state_dict()['param_groups'][0]['lr']
+
+        info.update({
+            "learning_rate": lr,
+            "actor_loss": loss_a.item(),
+            "critic_loss": loss_c.item(),
+            "entropy": loss_e.item(),
+            "loss": loss.item(),
+            "predict_value": value_pred.mean().item()
+        })
+
+        return info
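
Not part of the patch above: a minimal sketch of how the IPPO_Agents interface added here (act / values / train) might be driven by a rollout loop. The reset/step calls, the omitted buffer-store step, and the loop itself are assumptions for illustration only; the actual xuanpolicy runner wires these pieces together internally.

def rollout_and_train(agents, envs, n_steps):
    """Hypothetical driver loop: `agents` is an IPPO_Agents instance, `envs` a vectorized PettingZoo env."""
    obs_n = envs.reset()                         # assumed vectorized reset returning per-agent observations
    info_train = {}
    for i_step in range(n_steps):
        # sample actions and their log-probabilities from the current policy
        _, actions, log_pi = agents.act(obs_n)
        # query the critic values later used for the GAE targets
        _, values_n = agents.values(obs_n)
        next_obs, rewards, dones, infos = envs.step(actions)
        # agents.memory.store(...) would go here; the buffer signature is omitted on purpose
        obs_n = next_obs
        # runs n_epoch x n_minibatch PPO updates once the on-policy buffer is full, else returns {}
        info_train = agents.train(i_step)
    return info_train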