-
Notifications
You must be signed in to change notification settings - Fork 97
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
48 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,49 @@ | ||
from mava.configs.system.ppo.ff_ippo import FFIPPOConfig as BaseConfig | ||
import jax | ||
import jax.numpy as jnp | ||
import optax | ||
from mava.algorithms.base import Algorithm | ||
from mava.utils import update_policy | ||
from mava.networks import FeedForwardActor as Actor | ||
from mava.networks import FeedForwardQNet as QNetwork | ||
from mava.networks.distributions import TanhTransformedDistribution | ||
from mava.utils.network_utils import get_action_head | ||
from typing import Any, Dict | ||
|
||
class HAPPOConfig(BaseConfig): | ||
def __init__(self): | ||
super().__init__() | ||
self.algorithm = 'HAPPO' | ||
self.clip_param = 0.2 | ||
self.num_agents = 4 | ||
self.lr = 3e-4 | ||
self.network = { | ||
'pre_torso': { | ||
'_target_': 'mava.networks.torsos.MLPTorso', | ||
'layer_sizes': [128, 128], | ||
'use_layer_norm': False, | ||
'activation': 'relu' | ||
}, | ||
'post_torso': { | ||
'_target_': 'mava.networks.torsos.MLPTorso', | ||
'layer_sizes': [128, 128], | ||
'use_layer_norm': False, | ||
'activation': 'relu' | ||
} | ||
} | ||
class HAPPO(Algorithm): | ||
def __init__(self, config: Any) -> None: | ||
super().__init__(config) | ||
self.clip_param = config.clip_param | ||
self.num_agents = config.num_agents | ||
self.actor_networks = [Actor(config.network) for _ in range(self.num_agents)] | ||
self.critic_network = QNetwork(config.network) | ||
self.optimizer = optax.adam(config.lr) | ||
self.opt_state = self.optimizer.init(self.critic_network.params) | ||
|
||
def update(self, data: Dict[str, Any]) -> None: | ||
advantages = data['advantages'] | ||
old_log_probs = data['old_log_probs'] | ||
observations = data['observations'] | ||
actions = data['actions'] | ||
|
||
for agent_id in range(self.num_agents): | ||
agent_advantages = advantages[:, agent_id] | ||
agent_old_log_probs = old_log_probs[:, agent_id] | ||
agent_observations = observations[:, agent_id] | ||
agent_actions = actions[:, agent_id] | ||
|
||
def loss_fn(params: Any, agent_id: int, agent_observations: Any, agent_actions: Any, | ||
agent_old_log_probs: Any, agent_advantages: Any) -> jnp.ndarray: | ||
log_probs = self.actor_networks[agent_id](params, agent_observations, agent_actions) | ||
ratio = jnp.exp(log_probs - agent_old_log_probs) | ||
surr1 = ratio * agent_advantages | ||
surr2 = jnp.clip(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param) * agent_advantages | ||
return -jnp.mean(jnp.minimum(surr1, surr2)) | ||
|
||
new_params, opt_state = update_policy(self.actor_networks[agent_id].params, | ||
loss_fn, self.optimizer, self.opt_state, | ||
agent_id, agent_observations, agent_actions, | ||
agent_old_log_probs, agent_advantages) | ||
self.actor_networks[agent_id].params = new_params | ||
self.opt_state = opt_state | ||
|
||
return self.actor_networks, self.critic_network |