Commit de9315d: ippo
wenzhangliu committed Oct 19, 2023
1 parent 3b217b9
Showing 2 changed files with 331 additions and 0 deletions.
118 changes: 118 additions & 0 deletions xuanpolicy/torch/agents/multi_agent_rl/ippo_agents.py
@@ -0,0 +1,118 @@
import torch

from xuanpolicy.torch.agents import *


class IPPO_Agents(MARLAgents):
def __init__(self,
config: Namespace,
envs: DummyVecEnv_Pettingzoo,
device: Optional[Union[int, str, torch.device]] = None):
self.gamma = config.gamma
self.n_envs = envs.num_envs
self.n_size = config.n_size
self.n_epoch = config.n_epoch
self.n_minibatch = config.n_minibatch
if config.state_space is not None:
config.dim_state, state_shape = config.state_space.shape[0], config.state_space.shape
else:
config.dim_state, state_shape = None, None

input_representation = get_repre_in(config)
self.use_recurrent = config.use_recurrent
self.use_global_state = config.use_global_state
# create representation for actor
kwargs_rnn = {"N_recurrent_layers": config.N_recurrent_layers,
"dropout": config.dropout,
"rnn": config.rnn} if self.use_recurrent else {}
representation = REGISTRY_Representation[config.representation](*input_representation, **kwargs_rnn)
# create representation for critic
input_representation[0] = (config.dim_state,) if self.use_global_state else (config.dim_obs * config.n_agents,)
representation_critic = REGISTRY_Representation[config.representation](*input_representation, **kwargs_rnn)
# create policy
input_policy = get_policy_in_marl(config, (representation, representation_critic))
policy = REGISTRY_Policy[config.policy](*input_policy,
use_recurrent=config.use_recurrent,
rnn=config.rnn,
gain=config.gain)
optimizer = torch.optim.Adam(policy.parameters(),
lr=config.learning_rate, eps=1e-5,
weight_decay=config.weight_decay)
self.observation_space = envs.observation_space
self.action_space = envs.action_space
self.auxiliary_info_shape = {}

buffer = MARL_OnPolicyBuffer_RNN if self.use_recurrent else MARL_OnPolicyBuffer
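        # buffer constructor arguments: shapes, parallel-env count, buffer size, and GAE/advantage-normalization settings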
        input_buffer = (config.n_agents, state_shape, config.obs_shape, config.act_shape, config.rew_shape,
config.done_shape, envs.num_envs, config.n_size,
config.use_gae, config.use_advnorm, config.gamma, config.gae_lambda)
memory = buffer(*input_buffer, max_episode_length=envs.max_episode_length, dim_act=config.dim_act)
self.buffer_size = memory.buffer_size
self.batch_size = self.buffer_size // self.n_minibatch

        learner = IPPO_Clip_Learner(config, policy, optimizer, None,
config.device, config.model_dir, config.gamma)
        super(IPPO_Agents, self).__init__(config, envs, policy, memory, learner, device,
config.log_dir, config.model_dir)
        self.share_values = (config.rew_shape[0] == 1)
self.on_policy = True

def act(self, obs_n, *rnn_hidden, avail_actions=None, state=None, test_mode=False):
batch_size = len(obs_n)
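        # one-hot agent IDs are passed to the policy alongside each agent's observation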
agents_id = torch.eye(self.n_agents).unsqueeze(0).expand(batch_size, -1, -1).to(self.device)
obs_in = torch.Tensor(obs_n).view([batch_size, self.n_agents, -1]).to(self.device)
if self.use_recurrent:
batch_agents = batch_size * self.n_agents
hidden_state, dists = self.policy(obs_in.view(batch_agents, 1, -1),
agents_id.view(batch_agents, 1, -1),
*rnn_hidden,
avail_actions=avail_actions.reshape(batch_agents, 1, -1))
actions = dists.stochastic_sample()
log_pi_a = dists.log_prob(actions).reshape(batch_size, self.n_agents)
actions = actions.reshape(batch_size, self.n_agents)
else:
hidden_state, dists = self.policy(obs_in, agents_id, avail_actions=avail_actions)
actions = dists.stochastic_sample()
log_pi_a = dists.log_prob(actions)
return hidden_state, actions.detach().cpu().numpy(), log_pi_a.detach().cpu().numpy()

def values(self, obs_n, *rnn_hidden, state=None):
batch_size = len(obs_n)
agents_id = torch.eye(self.n_agents).unsqueeze(0).expand(batch_size, -1, -1).to(self.device)
# build critic input
if self.use_global_state:
state = torch.Tensor(state).unsqueeze(1).to(self.device)
critic_in = state.expand(-1, self.n_agents, -1)
else:
critic_in = torch.Tensor(obs_n).view([batch_size, 1, -1]).to(self.device)
critic_in = critic_in.expand(-1, self.n_agents, -1)
# get critic values
if self.use_recurrent:
hidden_state, values_n = self.policy.get_values(critic_in.unsqueeze(2), # add a sequence length axis.
agents_id.unsqueeze(2),
*rnn_hidden)
values_n = values_n.squeeze(2)
else:
hidden_state, values_n = self.policy.get_values(critic_in, agents_id)

return hidden_state, values_n.detach().cpu().numpy()

def train(self, i_step):
if self.memory.full:
info_train = {}
indexes = np.arange(self.buffer_size)
for _ in range(self.n_epoch):
np.random.shuffle(indexes)
for start in range(0, self.buffer_size, self.batch_size):
end = start + self.batch_size
sample_idx = indexes[start:end]
sample = self.memory.sample(sample_idx)
if self.use_recurrent:
info_train = self.learner.update_recurrent(sample)
else:
info_train = self.learner.update(sample)
self.learner.lr_decay(i_step)
self.memory.clear()
return info_train
else:
return {}
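
For reference, the epoch/minibatch scheme used by train() above reduces to the following self-contained sketch (plain NumPy; buffer_size, n_epoch and n_minibatch here are toy values for illustration, not the library defaults):

import numpy as np

buffer_size, n_epoch, n_minibatch = 12, 2, 3        # toy values for illustration only
batch_size = buffer_size // n_minibatch
indexes = np.arange(buffer_size)
for epoch in range(n_epoch):
    np.random.shuffle(indexes)                      # reshuffle the whole buffer every epoch
    for start in range(0, buffer_size, batch_size):
        sample_idx = indexes[start:start + batch_size]
        print(epoch, sample_idx)                    # stands in for learner.update(memory.sample(sample_idx))
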
213 changes: 213 additions & 0 deletions xuanpolicy/torch/learners/multi_agent_rl/ippo_learner.py
@@ -0,0 +1,213 @@
"""
Independent Proximal Policy Optimization (IPPO)
Paper link:
https://arxiv.org/pdf/2103.01955.pdf
Implementation: PyTorch
"""
from xuanpolicy.torch.learners import *
from xuanpolicy.torch.utils.value_norm import ValueNorm
from xuanpolicy.torch.utils.operations import update_linear_decay


class IPPO_Clip_Learner(LearnerMAS):
def __init__(self,
config: Namespace,
policy: nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
device: Optional[Union[int, str, torch.device]] = None,
model_dir: str = "./",
gamma: float = 0.99,
):
self.gamma = gamma
self.clip_range = config.clip_range
self.use_linear_lr_decay = config.use_linear_lr_decay
self.use_grad_norm, self.max_grad_norm = config.use_grad_norm, config.max_grad_norm
self.use_value_clip, self.value_clip_range = config.use_value_clip, config.value_clip_range
self.use_huber_loss, self.huber_delta = config.use_huber_loss, config.huber_delta
self.use_value_norm = config.use_value_norm
self.use_global_state = config.use_global_state
self.vf_coef, self.ent_coef = config.vf_coef, config.ent_coef
self.mse_loss = nn.MSELoss()
self.huber_loss = nn.HuberLoss(reduction="none", delta=self.huber_delta)
        super(IPPO_Clip_Learner, self).__init__(config, policy, optimizer, scheduler, device, model_dir)
if self.use_value_norm:
self.value_normalizer = ValueNorm(1).to(device)
else:
self.value_normalizer = None
self.lr = config.learning_rate
self.end_factor_lr_decay = config.end_factor_lr_decay

def lr_decay(self, i_step):
if self.use_linear_lr_decay:
update_linear_decay(self.optimizer, i_step, self.running_steps, self.lr, self.end_factor_lr_decay)

def update(self, sample):
info = {}
self.iterations += 1
state = torch.Tensor(sample['state']).to(self.device)
obs = torch.Tensor(sample['obs']).to(self.device)
actions = torch.Tensor(sample['actions']).to(self.device)
values = torch.Tensor(sample['values']).to(self.device)
returns = torch.Tensor(sample['returns']).to(self.device)
advantages = torch.Tensor(sample['advantages']).to(self.device)
log_pi_old = torch.Tensor(sample['log_pi_old']).to(self.device)
agent_mask = torch.Tensor(sample['agent_mask']).float().reshape(-1, self.n_agents, 1).to(self.device)
batch_size = obs.shape[0]
IDs = torch.eye(self.n_agents).unsqueeze(0).expand(batch_size, -1, -1).to(self.device)

# actor loss
_, pi_dist = self.policy(obs, IDs)
log_pi = pi_dist.log_prob(actions)
ratio = torch.exp(log_pi - log_pi_old).reshape(batch_size, self.n_agents, 1)
advantages_mask = advantages.detach() * agent_mask
surrogate1 = ratio * advantages_mask
surrogate2 = torch.clip(ratio, 1 - self.clip_range, 1 + self.clip_range) * advantages_mask
loss_a = -torch.sum(torch.min(surrogate1, surrogate2), dim=-2, keepdim=True).mean()

# entropy loss
entropy = pi_dist.entropy().reshape(agent_mask.shape) * agent_mask
loss_e = entropy.mean()

# critic loss
        critic_in = obs.reshape([batch_size, 1, -1])
        critic_in = critic_in.expand(-1, self.n_agents, -1)
        _, value_pred = self.policy.get_values(critic_in, IDs)
        value_target = returns
if self.use_value_clip:
value_clipped = values + (value_pred - values).clamp(-self.value_clip_range, self.value_clip_range)
if self.use_huber_loss:
loss_v = self.huber_loss(value_pred, value_target)
loss_v_clipped = self.huber_loss(value_clipped, value_target)
else:
loss_v = (value_pred - value_target) ** 2
loss_v_clipped = (value_clipped - value_target) ** 2
loss_c = torch.max(loss_v, loss_v_clipped) * agent_mask
loss_c = loss_c.sum() / agent_mask.sum()
else:
if self.use_huber_loss:
loss_v = self.huber_loss(value_pred, value_target) * agent_mask
else:
loss_v = ((value_pred - value_target) ** 2) * agent_mask
loss_c = loss_v.sum() / agent_mask.sum()

loss = loss_a + self.vf_coef * loss_c - self.ent_coef * loss_e
self.optimizer.zero_grad()
loss.backward()
if self.use_grad_norm:
grad_norm = torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
info["gradient_norm"] = grad_norm.item()
self.optimizer.step()
if self.scheduler is not None:
self.scheduler.step()

# Logger
lr = self.optimizer.state_dict()['param_groups'][0]['lr']

info.update({
"learning_rate": lr,
"actor_loss": loss_a.item(),
"critic_loss": loss_c.item(),
"entropy": loss_e.item(),
"loss": loss.item(),
"predict_value": value_pred.mean().item()
})

return info

def update_recurrent(self, sample):
info = {}
self.iterations += 1
state = torch.Tensor(sample['state']).to(self.device)
if self.use_global_state:
state = state.unsqueeze(1).expand(-1, self.n_agents, -1, -1)
obs = torch.Tensor(sample['obs']).to(self.device)
actions = torch.Tensor(sample['actions']).to(self.device)
values = torch.Tensor(sample['values']).to(self.device)
returns = torch.Tensor(sample['returns']).to(self.device)
advantages = torch.Tensor(sample['advantages']).to(self.device)
log_pi_old = torch.Tensor(sample['log_pi_old']).to(self.device)
avail_actions = torch.Tensor(sample['avail_actions']).float().to(self.device)
filled = torch.Tensor(sample['filled']).float().to(self.device)
batch_size = obs.shape[0]
episode_length = actions.shape[2]
IDs = torch.eye(self.n_agents).unsqueeze(1).unsqueeze(0).expand(batch_size, -1, episode_length + 1, -1).to(
self.device)

# actor loss
rnn_hidden_actor = self.policy.representation.init_hidden(batch_size * self.n_agents)
_, pi_dist = self.policy(obs[:, :, :-1].reshape(-1, episode_length, self.dim_obs),
IDs[:, :, :-1].reshape(-1, episode_length, self.n_agents),
*rnn_hidden_actor,
avail_actions=avail_actions[:, :, :-1].reshape(-1, episode_length, self.dim_act))
log_pi = pi_dist.log_prob(actions.reshape(-1, episode_length)).reshape(batch_size, self.n_agents, episode_length)
ratio = torch.exp(log_pi - log_pi_old).unsqueeze(-1)
filled_n = filled.unsqueeze(1).expand(batch_size, self.n_agents, episode_length, 1)
surrogate1 = ratio * advantages
surrogate2 = torch.clip(ratio, 1 - self.clip_range, 1 + self.clip_range) * advantages
loss_a = -(torch.min(surrogate1, surrogate2) * filled_n).sum() / filled_n.sum()

# entropy loss
entropy = pi_dist.entropy().reshape(batch_size, self.n_agents, episode_length, 1)
entropy = entropy * filled_n
loss_e = entropy.sum() / filled_n.sum()

# critic loss
rnn_hidden_critic = self.policy.representation_critic.init_hidden(batch_size * self.n_agents)
if self.use_global_state:
_, value_pred = self.policy.get_values(state[:, :, :-1], IDs[:, :, :-1], *rnn_hidden_critic)
else:
critic_in = obs[:, :, :-1].transpose(1, 2).reshape(batch_size, episode_length, -1)
critic_in = critic_in.unsqueeze(1).expand(-1, self.n_agents, -1, -1)
_, value_pred = self.policy.get_values(critic_in, IDs[:, :, :-1], *rnn_hidden_critic)
value_target = returns.reshape(-1, 1)
values = values.reshape(-1, 1)
value_pred = value_pred.reshape(-1, 1)
filled_all = filled_n.reshape(-1, 1)
if self.use_value_clip:
value_clipped = values + (value_pred - values).clamp(-self.value_clip_range, self.value_clip_range)
if self.use_value_norm:
self.value_normalizer.update(value_target)
value_target = self.value_normalizer.normalize(value_target)
if self.use_huber_loss:
loss_v = self.huber_loss(value_pred, value_target)
loss_v_clipped = self.huber_loss(value_clipped, value_target)
else:
loss_v = (value_pred - value_target) ** 2
loss_v_clipped = (value_clipped - value_target) ** 2
loss_c = torch.max(loss_v, loss_v_clipped) * filled_all
loss_c = loss_c.sum() / filled_all.sum()
else:
if self.use_value_norm:
self.value_normalizer.update(value_target)
value_pred = self.value_normalizer.normalize(value_pred)
if self.use_huber_loss:
loss_v = self.huber_loss(value_pred, value_target)
else:
loss_v = (value_pred - value_target) ** 2
loss_c = (loss_v * filled_all).sum() / filled_all.sum()

loss = loss_a + self.vf_coef * loss_c - self.ent_coef * loss_e
self.optimizer.zero_grad()
loss.backward()
if self.use_grad_norm:
grad_norm = torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
info["gradient_norm"] = grad_norm.item()
self.optimizer.step()
if self.scheduler is not None:
self.scheduler.step()

# Logger
lr = self.optimizer.state_dict()['param_groups'][0]['lr']

info.update({
"learning_rate": lr,
"actor_loss": loss_a.item(),
"critic_loss": loss_c.item(),
"entropy": loss_e.item(),
"loss": loss.item(),
"predict_value": value_pred.mean().item()
})

return info
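
Below is a minimal, self-contained PyTorch sketch of the two clipped terms computed in update(): the clipped policy surrogate and the clipped value loss. All tensors are random dummies, and eps / value_clip_range are placeholder values rather than the config defaults:

import torch

eps, value_clip_range = 0.2, 0.2                    # placeholder hyperparameters
log_pi = torch.randn(4, 3)                          # new log-probabilities (batch, n_agents)
log_pi_old = torch.randn(4, 3)                      # log-probabilities stored at rollout time
advantages = torch.randn(4, 3, 1)
values_old = torch.randn(4, 3, 1)                   # value predictions stored at rollout time
value_pred = torch.randn(4, 3, 1)                   # current critic predictions
returns = torch.randn(4, 3, 1)

# clipped policy (actor) term
ratio = torch.exp(log_pi - log_pi_old).unsqueeze(-1)
surrogate1 = ratio * advantages
surrogate2 = torch.clip(ratio, 1 - eps, 1 + eps) * advantages
loss_a = -torch.min(surrogate1, surrogate2).mean()

# clipped value (critic) term
value_clipped = values_old + (value_pred - values_old).clamp(-value_clip_range, value_clip_range)
loss_c = torch.max((value_pred - returns) ** 2, (value_clipped - returns) ** 2).mean()
print(loss_a.item(), loss_c.item())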
