diff --git a/docs/source/documents/api/learners/marl/matd3.rst b/docs/source/documents/api/learners/marl/matd3.rst
index f1b4979a8..b9f9bf8d5 100644
--- a/docs/source/documents/api/learners/marl/matd3.rst
+++ b/docs/source/documents/api/learners/marl/matd3.rst
@@ -1,6 +1,185 @@
MATD3_Learner
=====================================
+MATD3_Learner implements the Multi-Agent Twin Delayed Deep Deterministic Policy Gradient (MATD3) algorithm. It trains deterministic actors against two centralized critics and delays the actor and target-network updates relative to the critic updates to reduce overestimation.
+
+.. raw:: html
+
+
+
+**PyTorch:**
+
+.. py:class::
+ xuance.torch.learners.multi_agent_rl.matd3_learner.MATD3_Learner(config, policy, optimizer, scheduler, device, model_dir, gamma, sync_frequency, delay)
+
+  :param config: the configuration arguments that contain the hyper-parameters of the learner.
+  :type config: Namespace
+  :param policy: the MATD3 policy whose parameters are optimized.
+  :type policy: nn.Module
+  :param optimizer: the optimizers for the actor and the two critic networks.
+  :type optimizer: Sequence[torch.optim.Optimizer]
+  :param scheduler: the learning-rate schedulers for the actor and critic optimizers.
+  :type scheduler: Sequence[torch.optim.lr_scheduler._LRScheduler]
+  :param device: the computing device (e.g., CPU or GPU) on which the model is placed.
+  :type device: str, int, or torch.device
+  :param model_dir: the directory where the model is saved.
+  :type model_dir: str
+  :param gamma: the discount factor of future rewards.
+  :type gamma: float
+  :param sync_frequency: the frequency (in training iterations) with which the target networks are synchronized.
+  :type sync_frequency: int
+  :param delay: the number of critic updates performed between two actor (policy) updates.
+  :type delay: int
+
+.. py:function::
+ xuance.torch.learners.multi_agent_rl.matd3_learner.MATD3_Learner.update(sample)
+
+  Updates the parameters of the actor and the two critic networks with a sampled batch of transitions.
+
+  :param sample: a batch of transitions sampled from the experience replay buffer.
+  :type sample: dict
+  :return: the training information, including the losses and the current learning rates.
+  :rtype: dict
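+
+The core of ``update`` is a one-step TD target for the centralized critics: the target critics evaluate the
+target actor's actions at the next observations, terminal transitions are masked out, and the masked TD error
+is averaged over the active agents. Below is a minimal, stand-alone sketch of that target computation; the
+tensors are illustrative placeholders rather than outputs of a real MATD3 policy.
+
+.. code-block:: python
+
+    import torch
+
+    batch_size, n_agents = 32, 3
+    gamma = 0.99
+
+    # Illustrative placeholders for the quantities update() builds from the sampled batch
+    rewards = torch.randn(batch_size, n_agents, 1)
+    terminals = torch.zeros(batch_size, n_agents, 1)   # 1.0 where an episode terminated
+    agent_mask = torch.ones(batch_size, n_agents, 1)   # masks out inactive (padded) agents
+    target_q = torch.randn(batch_size, n_agents, 1)    # target critics on next observations and target actions
+    action_q = torch.randn(batch_size, n_agents, 1)    # online critics on the stored observation-action pairs
+
+    # TD target and masked mean-squared TD error, as in MATD3_Learner.update
+    q_target = rewards + (1 - terminals) * gamma * target_q
+    td_error = (action_q - q_target.detach()) * agent_mask
+    loss_c = (td_error ** 2).sum() / agent_mask.sum()
+
+The actor is updated only once every ``delay`` iterations, after which the target networks are softly updated
+with rate ``tau``.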
+
+.. raw:: html
+
+
+
+**TensorFlow:**
+
+.. raw:: html
+
+
+
+**MindSpore:**
+
+.. raw:: html
+
+
+
+Source Code
+-----------------
+
+.. tabs::
+
+ .. group-tab:: PyTorch
+
+ .. code-block:: python
+
+ """
+ Multi-Agent TD3
+ """
+ from xuance.torch.learners import *
+
+
+ class MATD3_Learner(LearnerMAS):
+ def __init__(self,
+ config: Namespace,
+ policy: nn.Module,
+ optimizer: Sequence[torch.optim.Optimizer],
+ scheduler: Sequence[torch.optim.lr_scheduler._LRScheduler] = None,
+ device: Optional[Union[int, str, torch.device]] = None,
+ model_dir: str = "./",
+ gamma: float = 0.99,
+ sync_frequency: int = 100,
+ delay: int = 3
+ ):
+ self.gamma = gamma
+ self.tau = config.tau
+ self.delay = delay
+ self.sync_frequency = sync_frequency
+ self.mse_loss = nn.MSELoss()
+ super(MATD3_Learner, self).__init__(config, policy, optimizer, scheduler, device, model_dir)
+ self.optimizer = {
+ 'actor': optimizer[0],
+ 'critic_A': optimizer[1],
+ 'critic_B': optimizer[2]
+ }
+ self.scheduler = {
+ 'actor': scheduler[0],
+ 'critic_A': scheduler[1],
+ 'critic_B': scheduler[2]
+ }
+
+ def update(self, sample):
+ self.iterations += 1
+ obs = torch.Tensor(sample['obs']).to(self.device)
+ actions = torch.Tensor(sample['actions']).to(self.device)
+ obs_next = torch.Tensor(sample['obs_next']).to(self.device)
+ rewards = torch.Tensor(sample['rewards']).to(self.device)
+ terminals = torch.Tensor(sample['terminals']).float().reshape(-1, self.n_agents, 1).to(self.device)
+ agent_mask = torch.Tensor(sample['agent_mask']).float().reshape(-1, self.n_agents, 1).to(self.device)
+ IDs = torch.eye(self.n_agents).unsqueeze(0).expand(self.args.batch_size, -1, -1).to(self.device)
+
+ # train critic
+ _, action_q = self.policy.Qaction(obs, actions, IDs)
+ actions_next = self.policy.target_actor(obs_next, IDs)
+ _, target_q = self.policy.Qtarget(obs_next, actions_next, IDs)
+ q_target = rewards + (1 - terminals) * self.args.gamma * target_q
+ td_error = (action_q - q_target.detach()) * agent_mask
+ loss_c = (td_error ** 2).sum() / agent_mask.sum()
+ # loss_c = F.mse_loss(torch.tile(q_target.detach(), (1, 2)), action_q)
+ self.optimizer['critic_B'].zero_grad()
+ self.optimizer['critic_A'].zero_grad()
+ loss_c.backward()
+ torch.nn.utils.clip_grad_norm_(self.policy.parameters_critic, self.args.grad_clip_norm)
+ self.optimizer['critic_A'].step()
+ self.optimizer['critic_B'].step()
+ if self.scheduler['critic_A'] is not None:
+ self.scheduler['critic_A'].step()
+ self.scheduler['critic_B'].step()
+
+ # actor update
+ if self.iterations % self.delay == 0:
+ _, actions_eval = self.policy(obs, IDs)
+ _, policy_q = self.policy.Qpolicy(obs, actions_eval, IDs)
+ p_loss = -policy_q.mean()
+ self.optimizer['actor'].zero_grad()
+ p_loss.backward()
+ self.optimizer['actor'].step()
+ if self.scheduler is not None:
+ self.scheduler['actor'].step()
+ self.policy.soft_update(self.tau)
+
+ lr_a = self.optimizer['actor'].state_dict()['param_groups'][0]['lr']
+ lr_c_A = self.optimizer['critic_A'].state_dict()['param_groups'][0]['lr']
+ lr_c_B = self.optimizer['critic_B'].state_dict()['param_groups'][0]['lr']
+
+ info = {
+ "learning_rate_actor": lr_a,
+ "learning_rate_critic_A": lr_c_A,
+ "learning_rate_critic_B": lr_c_B,
+ "loss_critic_A": loss_c.item(),
+ "loss_critic_B": loss_c.item()
+ }
+ if self.iterations % self.delay == 0:
+ info["loss_actor"] = p_loss.item()
+
+ return info
+
+ .. group-tab:: TensorFlow
+
+ .. code-block:: python
+
+
+ .. group-tab:: MindSpore
+
+ .. code-block:: python
+
+
+
.. raw:: html
@@ -29,19 +208,18 @@ Source Code
-----------------
.. tabs::
-
+
.. group-tab:: PyTorch
-
- .. code-block:: python3
+ .. code-block:: python3
.. group-tab:: TensorFlow
-
- .. code-block:: python3
+ .. code-block:: python3
.. group-tab:: MindSpore
.. code-block:: python3
+
diff --git a/docs/source/documents/api/learners/marl/mfac.rst b/docs/source/documents/api/learners/marl/mfac.rst
index b35500777..e02c2a7bd 100644
--- a/docs/source/documents/api/learners/marl/mfac.rst
+++ b/docs/source/documents/api/learners/marl/mfac.rst
@@ -1,6 +1,194 @@
MFAC_Learner
======================
+MFAC_Learner implements the Mean-Field Actor-Critic (MFAC) algorithm, which approximates the joint interaction among agents by the mean action of the other agents and trains an actor-critic pair on top of this mean-field approximation.
+
+.. raw:: html
+
+
+
+**PyTorch:**
+
+.. py:class::
+ xuance.torch.learners.multi_agent_rl.mfac_learner.MFAC_Learner(config, policy, optimizer, scheduler, device, model_dir, gamma)
+
+  :param config: the configuration arguments that contain the hyper-parameters of the learner.
+  :type config: Namespace
+  :param policy: the MFAC policy whose parameters are optimized.
+  :type policy: nn.Module
+  :param optimizer: the optimizers for the actor and critic networks.
+  :type optimizer: Sequence[torch.optim.Optimizer]
+  :param scheduler: the learning-rate schedulers for the actor and critic optimizers.
+  :type scheduler: Sequence[torch.optim.lr_scheduler._LRScheduler]
+  :param device: the computing device (e.g., CPU or GPU) on which the model is placed.
+  :type device: str, int, or torch.device
+  :param model_dir: the directory where the model is saved.
+  :type model_dir: str
+  :param gamma: the discount factor of future rewards.
+  :type gamma: float
+
+.. py:function::
+ xuance.torch.learners.multi_agent_rl.mfac_learner.MFAC_Learner.update(sample)
+
+  Updates the parameters of the actor and critic networks with a sampled batch of transitions.
+
+  :param sample: a batch of transitions sampled from the experience replay buffer.
+  :type sample: dict
+  :return: the training information, including the losses, the actor gradient norm, and the current learning rates.
+  :rtype: dict
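+
+In ``update``, the critic target uses a mean-field value: the expectation of the target critic's Q-values
+under the target policy at the next step, computed with a batched matrix product. The sketch below reproduces
+just that contraction with stand-alone tensors; the values are illustrative placeholders rather than outputs
+of a real MFAC policy.
+
+.. code-block:: python
+
+    import torch
+
+    batch_size, n_agents, dim_act = 32, 3, 5
+
+    # Illustrative placeholders: target Q-values and target policy probabilities at the next step
+    q_eval_next = torch.randn(batch_size, n_agents, dim_act)
+    target_pi_next = torch.softmax(torch.randn(batch_size, n_agents, dim_act), dim=-1)
+
+    # v_mf[b, i] = sum_a pi(a | s') * Q_target(s', mean action', a), as in MFAC_Learner.update
+    shape = q_eval_next.shape
+    v_mf = torch.bmm(q_eval_next.reshape(-1, 1, shape[-1]),
+                     target_pi_next.reshape(-1, shape[-1], 1))
+    v_mf = v_mf.reshape(*(list(shape[0:-1]) + [1]))    # back to [batch, n_agents, 1]
+
+The actor is then trained with a policy-gradient loss of the form ``-(log_pi * advantage)``, where the
+advantage is read from the target critic at the sampled mean action.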
+
+.. raw:: html
+
+
+
+**TensorFlow:**
+
+.. raw:: html
+
+
+
+**MindSpore:**
+
+.. raw:: html
+
+
+
+Source Code
+-----------------
+
+.. tabs::
+
+ .. group-tab:: PyTorch
+
+ .. code-block:: python
+
+ """
+ MFAC: Mean Field Actor-Critic
+ Paper link:
+ http://proceedings.mlr.press/v80/yang18d/yang18d.pdf
+ Implementation: Pytorch
+ """
+ import torch
+
+ from xuance.torch.learners import *
+
+
+ class MFAC_Learner(LearnerMAS):
+ def __init__(self,
+ config: Namespace,
+ policy: nn.Module,
+ optimizer: Sequence[torch.optim.Optimizer],
+ scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+ device: Optional[Union[int, str, torch.device]] = None,
+ model_dir: str = "./",
+ gamma: float = 0.99,
+ ):
+ self.gamma = gamma
+ self.tau = config.tau
+ self.mse_loss = nn.MSELoss()
+ super(MFAC_Learner, self).__init__(config, policy, optimizer, scheduler, device, model_dir)
+ self.optimizer = {
+ 'actor': optimizer[0],
+ 'critic': optimizer[1]
+ }
+ self.scheduler = {
+ 'actor': scheduler[0],
+ 'critic': scheduler[1]
+ }
+
+ def update(self, sample):
+ self.iterations += 1
+ obs = torch.Tensor(sample['obs']).to(self.device)
+ actions = torch.Tensor(sample['actions']).to(self.device)
+ obs_next = torch.Tensor(sample['obs_next']).to(self.device)
+ act_mean = torch.Tensor(sample['act_mean']).to(self.device)
+ # act_mean_next = torch.Tensor(sample['act_mean_next']).to(self.device)
+ rewards = torch.Tensor(sample['rewards']).to(self.device)
+ terminals = torch.Tensor(sample['terminals']).float().reshape(-1, self.n_agents, 1).to(self.device)
+ agent_mask = torch.Tensor(sample['agent_mask']).float().reshape(-1, self.n_agents, 1).to(self.device)
+ batch_size = obs.shape[0]
+ IDs = torch.eye(self.n_agents).unsqueeze(0).expand(batch_size, -1, -1).to(self.device)
+
+ act_mean_n = act_mean.unsqueeze(1).repeat([1, self.n_agents, 1])
+
+ # train critic network
+ target_pi_dist_next = self.policy.target_actor(obs_next, IDs)
+ target_pi_next = target_pi_dist_next.logits.softmax(dim=-1)
+ actions_next = target_pi_dist_next.stochastic_sample()
+ actions_next_onehot = self.onehot_action(actions_next, self.dim_act).type(torch.float)
+ act_mean_next = actions_next_onehot.mean(dim=-2, keepdim=False)
+ act_mean_n_next = act_mean_next.unsqueeze(1).repeat([1, self.n_agents, 1])
+
+ q_eval = self.policy.critic(obs, act_mean_n, IDs)
+ q_eval_a = q_eval.gather(-1, actions.long().reshape([batch_size, self.n_agents, 1]))
+
+ q_eval_next = self.policy.target_critic(obs_next, act_mean_n_next, IDs)
+ shape = q_eval_next.shape
+ v_mf = torch.bmm(q_eval_next.reshape(-1, 1, shape[-1]), target_pi_next.reshape(-1, shape[-1], 1))
+ v_mf = v_mf.reshape(*(list(shape[0:-1]) + [1]))
+ q_target = rewards + (1 - terminals) * self.args.gamma * v_mf
+ td_error = (q_eval_a - q_target.detach()) * agent_mask
+ loss_c = (td_error ** 2).sum() / agent_mask.sum()
+ self.optimizer["critic"].zero_grad()
+ loss_c.backward()
+ self.optimizer["critic"].step()
+ if self.scheduler['critic'] is not None:
+ self.scheduler['critic'].step()
+
+ # train actor network
+ _, pi_dist = self.policy(obs, IDs)
+ actions_ = pi_dist.stochastic_sample()
+ advantages = self.policy.target_critic(obs, act_mean_n, IDs)
+ advantages = advantages.gather(-1, actions_.long().reshape([batch_size, self.n_agents, 1]))
+ log_pi_prob = pi_dist.log_prob(actions_).unsqueeze(-1)
+ advantages = log_pi_prob * advantages.detach()
+ loss_a = -(advantages.sum() / agent_mask.sum())
+ self.optimizer["actor"].zero_grad()
+ loss_a.backward()
+ grad_norm_actor = torch.nn.utils.clip_grad_norm_(self.policy.parameters_actor, self.args.clip_grad)
+ self.optimizer["actor"].step()
+ if self.scheduler['actor'] is not None:
+ self.scheduler['actor'].step()
+
+ self.policy.soft_update(self.tau)
+ # Logger
+ lr_a = self.optimizer['actor'].state_dict()['param_groups'][0]['lr']
+ lr_c = self.optimizer['critic'].state_dict()['param_groups'][0]['lr']
+
+ info = {
+ "learning_rate_actor": lr_a,
+ "learning_rate_critic": lr_c,
+ "actor_loss": loss_a.item(),
+ "critic_loss": loss_c.item(),
+ "actor_gradient_norm": grad_norm_actor.item()
+ }
+
+ return info
+
+ .. group-tab:: TensorFlow
+
+ .. code-block:: python
+
+
+ .. group-tab:: MindSpore
+
+ .. code-block:: python
+
+
+
.. raw:: html
@@ -29,17 +217,15 @@ Source Code
-----------------
.. tabs::
-
+
.. group-tab:: PyTorch
-
- .. code-block:: python3
+ .. code-block:: python3
.. group-tab:: TensorFlow
-
- .. code-block:: python3
+ .. code-block:: python3
.. group-tab:: MindSpore
diff --git a/docs/source/documents/api/learners/marl/mfq.rst b/docs/source/documents/api/learners/marl/mfq.rst
index 9db36da1d..90357dfce 100644
--- a/docs/source/documents/api/learners/marl/mfq.rst
+++ b/docs/source/documents/api/learners/marl/mfq.rst
@@ -1,6 +1,177 @@
MFQ_Learner
=====================================
+MFQ_Learner implements the Mean-Field Q-Learning (MFQ) algorithm, which conditions each agent's Q-function on the mean action of the other agents and acts with a Boltzmann (softmax) policy over the resulting Q-values.
+
+.. raw:: html
+
+
+
+**PyTorch:**
+
+.. py:class::
+ xuance.torch.learners.multi_agent_rl.mfq_learner.MFQ_Learner(config, policy, optimizer, scheduler, device, model_dir, gamma, sync_frequency)
+
+  :param config: the configuration arguments that contain the hyper-parameters of the learner.
+  :type config: Namespace
+  :param policy: the MFQ policy whose parameters are optimized.
+  :type policy: nn.Module
+  :param optimizer: the optimizer that updates the policy parameters.
+  :type optimizer: torch.optim.Optimizer
+  :param scheduler: the learning-rate scheduler for the optimizer.
+  :type scheduler: torch.optim.lr_scheduler._LRScheduler
+  :param device: the computing device (e.g., CPU or GPU) on which the model is placed.
+  :type device: str, int, or torch.device
+  :param model_dir: the directory where the model is saved.
+  :type model_dir: str
+  :param gamma: the discount factor of future rewards.
+  :type gamma: float
+  :param sync_frequency: the frequency (in training iterations) with which the target network is synchronized.
+  :type sync_frequency: int
+
+.. py:function::
+ xuance.torch.learners.multi_agent_rl.mfq_learner.MFQ_Learner.get_boltzmann_policy(q)
+
+  Computes the Boltzmann (softmax) policy from the given Q-values, using the configured temperature.
+
+  :param q: the Q-values of the available actions.
+  :type q: torch.Tensor
+  :return: the action probabilities given by the Boltzmann policy.
+  :rtype: torch.Tensor
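+
+The Boltzmann policy simply rescales the Q-values by the temperature before applying a softmax, so a lower
+temperature concentrates the policy on the greedy action. A short sketch with illustrative Q-values:
+
+.. code-block:: python
+
+    import torch
+
+    temperature = 0.1
+    q = torch.tensor([1.0, 2.0, 0.5])            # illustrative Q-values for three actions
+    pi = torch.softmax(q / temperature, dim=-1)  # Boltzmann (softmax) policy over actions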
+
+.. py:function::
+ xuance.torch.learners.multi_agent_rl.mfq_learner.MFQ_Learner.update(sample)
+
+  Updates the parameters of the Q-network with a sampled batch of transitions, and periodically synchronizes the target network.
+
+  :param sample: a batch of transitions sampled from the experience replay buffer.
+  :type sample: dict
+  :return: the training information, including the Q-loss, the mean predicted Q-value, and the current learning rate.
+  :rtype: dict
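+
+In ``update``, the Boltzmann policy over the next-step Q-values yields the mean-field value that enters the
+TD target. The sketch below walks through those two steps with stand-alone tensors; the values are
+illustrative placeholders rather than outputs of a real MFQ policy.
+
+.. code-block:: python
+
+    import torch
+
+    batch_size, n_agents, dim_act = 32, 3, 5
+    gamma, temperature = 0.99, 0.1
+
+    # Illustrative placeholders for quantities computed inside MFQ_Learner.update
+    q_next = torch.randn(batch_size, n_agents, dim_act)   # target Q-values at the next step
+    rewards = torch.randn(batch_size, n_agents, 1)
+    terminals = torch.zeros(batch_size, n_agents, 1)       # 1.0 where an episode terminated
+
+    # Boltzmann policy and mean-field value v_mf = E_{a ~ pi}[Q(s', mean action', a)]
+    pi = torch.softmax(q_next / temperature, dim=-1)
+    shape = q_next.shape
+    v_mf = torch.bmm(q_next.reshape(-1, 1, shape[-1]), pi.reshape(-1, shape[-1], 1))
+    v_mf = v_mf.reshape(*(list(shape[0:-1]) + [1]))
+
+    # One-step TD target for the masked Q-loss
+    q_target = rewards + (1 - terminals) * gamma * v_mf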
+
+.. raw:: html
+
+
+
+**TensorFlow:**
+
+.. raw:: html
+
+
+
+**MindSpore:**
+
+.. raw:: html
+
+
+
+Source Code
+-----------------
+
+.. tabs::
+
+ .. group-tab:: PyTorch
+
+ .. code-block:: python
+
+ """
+ MFQ: Mean Field Q-Learning
+ Paper link:
+ http://proceedings.mlr.press/v80/yang18d/yang18d.pdf
+ Implementation: Pytorch
+ """
+ from xuance.torch.learners import *
+
+
+ class MFQ_Learner(LearnerMAS):
+ def __init__(self,
+ config: Namespace,
+ policy: nn.Module,
+ optimizer: torch.optim.Optimizer,
+ scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+ device: Optional[Union[int, str, torch.device]] = None,
+ model_dir: str = "./",
+ gamma: float = 0.99,
+ sync_frequency: int = 100
+ ):
+ self.gamma = gamma
+ self.temperature = config.temperature
+ self.sync_frequency = sync_frequency
+ self.mse_loss = nn.MSELoss()
+ self.softmax = torch.nn.Softmax(dim=-1)
+ super(MFQ_Learner, self).__init__(config, policy, optimizer, scheduler, device, model_dir)
+
+ def get_boltzmann_policy(self, q):
+ return self.softmax(q / self.temperature)
+
+ def update(self, sample):
+ self.iterations += 1
+ obs = torch.Tensor(sample['obs']).to(self.device)
+ actions = torch.Tensor(sample['actions']).to(self.device)
+ obs_next = torch.Tensor(sample['obs_next']).to(self.device)
+ act_mean = torch.Tensor(sample['act_mean']).to(self.device)
+ act_mean_next = torch.Tensor(sample['act_mean_next']).to(self.device)
+ rewards = torch.Tensor(sample['rewards']).to(self.device)
+ terminals = torch.Tensor(sample['terminals']).float().reshape(-1, self.n_agents, 1).to(self.device)
+ agent_mask = torch.Tensor(sample['agent_mask']).float().reshape(-1, self.n_agents, 1).to(self.device)
+ IDs = torch.eye(self.n_agents).unsqueeze(0).expand(self.args.batch_size, -1, -1).to(self.device)
+
+ act_mean = act_mean.unsqueeze(1).repeat([1, self.n_agents, 1])
+ act_mean_next = act_mean_next.unsqueeze(1).repeat([1, self.n_agents, 1])
+ _, _, q_eval = self.policy(obs, act_mean, IDs)
+ q_eval_a = q_eval.gather(-1, actions.long().reshape([self.args.batch_size, self.n_agents, 1]))
+ q_next = self.policy.target_Q(obs_next, act_mean_next, IDs)
+ shape = q_next.shape
+ pi = self.get_boltzmann_policy(q_next)
+ v_mf = torch.bmm(q_next.reshape(-1, 1, shape[-1]), pi.unsqueeze(-1).reshape(-1, shape[-1], 1))
+ v_mf = v_mf.reshape(*(list(shape[0:-1]) + [1]))
+ q_target = rewards + (1 - terminals) * self.args.gamma * v_mf
+
+ # calculate the loss function
+ td_error = (q_eval_a - q_target.detach()) * agent_mask
+ loss = (td_error ** 2).sum() / agent_mask.sum()
+ self.optimizer.zero_grad()
+ loss.backward()
+ self.optimizer.step()
+ if self.scheduler is not None:
+ self.scheduler.step()
+
+ if self.iterations % self.sync_frequency == 0:
+ self.policy.copy_target()
+
+ lr = self.optimizer.state_dict()['param_groups'][0]['lr']
+
+ info = {
+ "learning_rate": lr,
+ "loss_Q": loss.item(),
+ "predictQ": q_eval_a.mean().item()
+ }
+
+ return info
+
+ .. group-tab:: TensorFlow
+
+ .. code-block:: python
+
+
+ .. group-tab:: MindSpore
+
+ .. code-block:: python
+
+
+
.. raw:: html
@@ -31,19 +202,16 @@ Source Code
.. tabs::
.. group-tab:: PyTorch
-
- .. code-block:: python3
+ .. code-block:: python3
.. group-tab:: TensorFlow
-
- .. code-block:: python3
+ .. code-block:: python3
.. group-tab:: MindSpore
.. code-block:: python3
-