Commit

parallel smac
wenzhangliu committed Nov 2, 2023
1 parent 0ee9baa commit cd563a7
Showing 101 changed files with 343 additions and 340 deletions.
78 changes: 44 additions & 34 deletions xuance/common/memory_tools_marl.py
@@ -109,6 +109,24 @@ def __init__(self, n_agents, state_space, obs_space, act_space, rew_space, done_
super(MARL_OffPolicyBuffer_RNN, self).__init__(n_agents, state_space, obs_space, act_space, rew_space,
done_space, n_envs, buffer_size, batch_size)

self.episode_data = {}
self.clear_episodes()

def clear(self):
self.data = {
'obs': np.zeros((self.buffer_size, self.n_agents, self.max_eps_len + 1) + self.obs_space, np.float32),
'actions': np.zeros((self.buffer_size, self.n_agents, self.max_eps_len) + self.act_space, np.float32),
'rewards': np.zeros((self.buffer_size, self.n_agents, self.max_eps_len) + self.rew_space, np.float32),
'terminals': np.zeros((self.buffer_size, self.max_eps_len) + self.done_space, np.bool_),
'avail_actions': np.ones((self.buffer_size, self.n_agents, self.max_eps_len + 1, self.dim_act), np.bool_),
'filled': np.zeros((self.buffer_size, self.max_eps_len, 1)).astype(np.bool_)
}
if self.state_space is not None:
self.data.update({'state': np.zeros(
(self.buffer_size, self.max_eps_len + 1) + self.state_space).astype(np.float32)})
self.ptr, self.size = 0, 0

def clear_episodes(self):
self.episode_data = {
'obs': np.zeros((self.n_envs, self.n_agents, self.max_eps_len + 1) + self.obs_space, dtype=np.float32),
'actions': np.zeros((self.n_envs, self.n_agents, self.max_eps_len) + self.act_space, dtype=np.float32),
@@ -126,20 +144,6 @@ def __init__(self, n_agents, state_space, obs_space, act_space, rew_space, done_
'state': np.zeros((self.n_envs, self.max_eps_len + 1) + self.state_space, dtype=np.float32),
})

def clear(self):
self.data = {
'obs': np.zeros((self.buffer_size, self.n_agents, self.max_eps_len + 1) + self.obs_space, np.float32),
'actions': np.zeros((self.buffer_size, self.n_agents, self.max_eps_len) + self.act_space, np.float32),
'rewards': np.zeros((self.buffer_size, self.n_agents, self.max_eps_len) + self.rew_space, np.float32),
'terminals': np.zeros((self.buffer_size, self.max_eps_len) + self.done_space, np.bool_),
'avail_actions': np.ones((self.buffer_size, self.n_agents, self.max_eps_len + 1, self.dim_act), np.bool_),
'filled': np.zeros((self.buffer_size, self.max_eps_len, 1)).astype(np.bool_)
}
if self.state_space is not None:
self.data.update({'state': np.zeros(
(self.buffer_size, self.max_eps_len + 1) + self.state_space).astype(np.float32)})
self.ptr, self.size = 0, 0

def store_transitions(self, t_envs, *transition_data):
obs_n, actions_dict, state, rewards, terminated, avail_actions = transition_data
self.episode_data['obs'][:, :, t_envs] = obs_n
@@ -158,6 +162,7 @@ def store_episodes(self):
self.data[k][self.ptr] = self.episode_data[k][i_env]
self.ptr = (self.ptr + 1) % self.buffer_size
self.size = np.min([self.size + 1, self.buffer_size])
self.clear_episodes()

def finish_path(self, i_env, next_t, *terminal_data):
obs_next, state_next, available_actions, filled = terminal_data
@@ -335,26 +340,8 @@ def __init__(self, n_agents, state_space, obs_space, act_space, rew_space, done_
done_space, n_envs, buffer_size,
use_gae, use_advnorm, gamma, gae_lam,
**kwargs)
self.episode_data = {
'obs': np.zeros((self.n_envs, self.n_agents, self.max_eps_len + 1) + self.obs_space, dtype=np.float32),
'actions': np.zeros((self.n_envs, self.n_agents, self.max_eps_len) + self.act_space, dtype=np.float32),
'rewards': np.zeros((self.n_envs, self.n_agents, self.max_eps_len) + self.rew_space, dtype=np.float32),
'returns': np.zeros((self.n_envs, self.n_agents, self.max_eps_len) + self.rew_space, np.float32),
'values': np.zeros((self.n_envs, self.n_agents, self.max_eps_len) + self.rew_space, np.float32),
'advantages': np.zeros((self.n_envs, self.n_agents, self.max_eps_len) + self.rew_space, np.float32),
'log_pi_old': np.zeros((self.n_envs, self.n_agents, self.max_eps_len,), np.float32),
'terminals': np.zeros((self.n_envs, self.max_eps_len) + self.done_space, dtype=np.bool_),
'avail_actions': np.ones((self.n_envs, self.n_agents, self.max_eps_len + 1, self.dim_act), dtype=np.bool_),
'filled': np.zeros((self.n_envs, self.max_eps_len, 1), dtype=np.bool_),
}
if self.state_space is not None:
self.episode_data.update({
'state': np.zeros((self.n_envs, self.max_eps_len + 1) + self.state_space, dtype=np.float32),
})
# if self.args.agent == "COMA":
# self.episode_data.update({
# 'actions_onehot': np.zeros((self.n_envs, self.n_agents, self.max_eps_len, self.dim_act),
# dtype=np.float32)})
self.episode_data = {}
self.clear_episodes()

@property
def full(self):
@@ -379,6 +366,28 @@ def clear(self):
})
self.ptr, self.size = 0, 0

def clear_episodes(self):
self.episode_data = {
'obs': np.zeros((self.n_envs, self.n_agents, self.max_eps_len + 1) + self.obs_space, dtype=np.float32),
'actions': np.zeros((self.n_envs, self.n_agents, self.max_eps_len) + self.act_space, dtype=np.float32),
'rewards': np.zeros((self.n_envs, self.n_agents, self.max_eps_len) + self.rew_space, dtype=np.float32),
'returns': np.zeros((self.n_envs, self.n_agents, self.max_eps_len) + self.rew_space, np.float32),
'values': np.zeros((self.n_envs, self.n_agents, self.max_eps_len) + self.rew_space, np.float32),
'advantages': np.zeros((self.n_envs, self.n_agents, self.max_eps_len) + self.rew_space, np.float32),
'log_pi_old': np.zeros((self.n_envs, self.n_agents, self.max_eps_len,), np.float32),
'terminals': np.zeros((self.n_envs, self.max_eps_len) + self.done_space, dtype=np.bool_),
'avail_actions': np.ones((self.n_envs, self.n_agents, self.max_eps_len + 1, self.dim_act), dtype=np.bool_),
'filled': np.zeros((self.n_envs, self.max_eps_len, 1), dtype=np.bool_),
}
if self.state_space is not None:
self.episode_data.update({
'state': np.zeros((self.n_envs, self.max_eps_len + 1) + self.state_space, dtype=np.float32),
})
# if self.args.agent == "COMA":
# self.episode_data.update({
# 'actions_onehot': np.zeros((self.n_envs, self.n_agents, self.max_eps_len, self.dim_act),
# dtype=np.float32)})

def store_transitions(self, t_envs, *transition_data):
obs_n, actions_dict, state, rewards, terminated, avail_actions = transition_data
self.episode_data['obs'][:, :, t_envs] = obs_n
@@ -399,6 +408,7 @@ def store_episodes(self):
self.data[k][self.ptr] = self.episode_data[k][i_env].copy()
self.ptr = (self.ptr + 1) % self.buffer_size
self.size = min(self.size + 1, self.buffer_size)
self.clear_episodes()

def finish_path(self, i_env, next_t, *terminal_data, value_next=None, value_normalizer=None):
obs_next, state_next, available_actions, filled = terminal_data
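Note on the buffer refactor above: per-environment episode storage now lives in a dedicated clear_episodes() helper that is called once in __init__ and again at the end of store_episodes(), so every parallel environment starts its next episode from zeroed arrays instead of stale data. Below is a minimal, self-contained sketch of that pattern; the class name, keys, and array shapes are simplified stand-ins for illustration, not the actual xuance buffer API.

import numpy as np

class EpisodeBufferSketch:
    """Illustrative per-environment episode buffer (simplified stand-in)."""

    def __init__(self, n_envs, n_agents, max_eps_len, obs_dim, buffer_size):
        self.n_envs, self.n_agents = n_envs, n_agents
        self.max_eps_len, self.obs_dim = max_eps_len, obs_dim
        self.buffer_size = buffer_size
        # Ring buffer holding completed episodes.
        self.data = {
            'obs': np.zeros((buffer_size, n_agents, max_eps_len + 1, obs_dim), np.float32),
            'filled': np.zeros((buffer_size, max_eps_len, 1), np.bool_),
        }
        self.ptr, self.size = 0, 0
        self.episode_data = {}
        self.clear_episodes()  # working storage for episodes still being collected

    def clear_episodes(self):
        # One slot per parallel environment, reset to zeros between episodes.
        self.episode_data = {
            'obs': np.zeros((self.n_envs, self.n_agents, self.max_eps_len + 1, self.obs_dim), np.float32),
            'filled': np.zeros((self.n_envs, self.max_eps_len, 1), np.bool_),
        }

    def store_transitions(self, t, obs_n):
        # Write step t of every environment/agent into the working episodes.
        self.episode_data['obs'][:, :, t] = obs_n
        self.episode_data['filled'][:, t] = True

    def store_episodes(self):
        # Copy each environment's finished episode into the ring buffer, then reset.
        for i_env in range(self.n_envs):
            for k in self.data:
                self.data[k][self.ptr] = self.episode_data[k][i_env]
            self.ptr = (self.ptr + 1) % self.buffer_size
            self.size = min(self.size + 1, self.buffer_size)
        self.clear_episodes()

buf = EpisodeBufferSketch(n_envs=8, n_agents=3, max_eps_len=5, obs_dim=4, buffer_size=32)
buf.store_transitions(0, np.ones((8, 3, 4), np.float32))
buf.store_episodes()  # eight episodes stored; working storage is zeroed again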
4 changes: 2 additions & 2 deletions xuance/configs/coma/sc2/1c3s5z.yaml
@@ -23,7 +23,7 @@ critic_hidden_size: [128, 128]
activation: "ReLU"

seed: 1
parallels: 1
parallels: 8
n_size: 128
n_epoch: 15
n_minibatch: 1
@@ -52,6 +52,6 @@ training_frequency: 1

test_steps: 10000
eval_interval: 10000
test_episode: 10
test_episode: 16
log_dir: "./logs/coma/"
model_dir: "./models/coma/"
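In the COMA SC2 configs only parallels and test_episode change: evaluation now collects 16 episodes across 8 parallel environment copies instead of 10 on a single copy. If test episodes are split evenly across copies, keeping test_episode a multiple of parallels gives every copy the same share; the snippet below is only a quick illustration of that assumption, not a xuance utility.

# Values taken from the updated SC2 configs; the even split is an assumption.
parallels, test_episode = 8, 16
episodes_per_env, remainder = divmod(test_episode, parallels)
assert remainder == 0, "test_episode is expected to be a multiple of parallels"
print(f"{episodes_per_env} evaluation episodes per environment copy")  # -> 2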
4 changes: 2 additions & 2 deletions xuance/configs/coma/sc2/25m.yaml
@@ -23,7 +23,7 @@ critic_hidden_size: [128, 128]
activation: "ReLU"

seed: 1
parallels: 1
parallels: 8
n_size: 128
n_epoch: 15
n_minibatch: 1
@@ -52,6 +52,6 @@ training_frequency: 1

test_steps: 10000
eval_interval: 25000
test_episode: 10
test_episode: 16
log_dir: "./logs/coma/"
model_dir: "./models/coma/"
4 changes: 2 additions & 2 deletions xuance/configs/coma/sc2/2m_vs_1z.yaml
@@ -23,7 +23,7 @@ critic_hidden_size: [128, 128]
activation: "ReLU"

seed: 1
parallels: 1
parallels: 8
n_size: 128
n_epoch: 15
n_minibatch: 1
@@ -52,6 +52,6 @@ training_frequency: 1

test_steps: 10000
eval_interval: 5000
test_episode: 10
test_episode: 16
log_dir: "./logs/coma/"
model_dir: "./models/coma/"
4 changes: 2 additions & 2 deletions xuance/configs/coma/sc2/2s3z.yaml
@@ -23,7 +23,7 @@ critic_hidden_size: [128, 128]
activation: "ReLU"

seed: 1
parallels: 1
parallels: 8
n_size: 128
n_epoch: 15
n_minibatch: 1
@@ -52,6 +52,6 @@ training_frequency: 1

test_steps: 10000
eval_interval: 10000
test_episode: 10
test_episode: 16
log_dir: "./logs/coma/"
model_dir: "./models/coma/"
4 changes: 2 additions & 2 deletions xuance/configs/coma/sc2/3m.yaml
@@ -23,7 +23,7 @@ critic_hidden_size: [128, 128]
activation: "ReLU"

seed: 1
parallels: 1
parallels: 8
n_size: 128
n_epoch: 15
n_minibatch: 1
@@ -52,6 +52,6 @@ training_frequency: 1

test_steps: 10000
eval_interval: 5000
test_episode: 10
test_episode: 16
log_dir: "./logs/coma/"
model_dir: "./models/coma/"
4 changes: 2 additions & 2 deletions xuance/configs/coma/sc2/5m_vs_6m.yaml
@@ -23,7 +23,7 @@ critic_hidden_size: [128, 128]
activation: "ReLU"

seed: 1
parallels: 1
parallels: 8
n_size: 128
n_epoch: 15
n_minibatch: 1
@@ -52,6 +52,6 @@ training_frequency: 1

test_steps: 10000
eval_interval: 50000
test_episode: 10
test_episode: 16
log_dir: "./logs/coma/"
model_dir: "./models/coma/"
4 changes: 2 additions & 2 deletions xuance/configs/coma/sc2/8m.yaml
@@ -23,7 +23,7 @@ critic_hidden_size: [128, 128]
activation: "ReLU"

seed: 1
parallels: 1
parallels: 8
n_size: 128
n_epoch: 15
n_minibatch: 1
@@ -52,6 +52,6 @@ training_frequency: 1

test_steps: 10000
eval_interval: 5000
test_episode: 10
test_episode: 16
log_dir: "./logs/coma/"
model_dir: "./models/coma/"
4 changes: 2 additions & 2 deletions xuance/configs/coma/sc2/8m_vs_9m.yaml
@@ -23,7 +23,7 @@ critic_hidden_size: [128, 128]
activation: "ReLU"

seed: 1
parallels: 1
parallels: 8
n_size: 128
n_epoch: 15
n_minibatch: 1
@@ -52,6 +52,6 @@ training_frequency: 1

test_steps: 10000
eval_interval: 50000
test_episode: 10
test_episode: 16
log_dir: "./logs/coma/"
model_dir: "./models/coma/"
4 changes: 2 additions & 2 deletions xuance/configs/coma/sc2/MMM2.yaml
@@ -23,7 +23,7 @@ critic_hidden_size: [128, 128]
activation: "ReLU"

seed: 1
parallels: 1
parallels: 8
n_size: 128
n_epoch: 15
n_minibatch: 1
@@ -52,6 +52,6 @@ training_frequency: 1

test_steps: 10000
eval_interval: 50000
test_episode: 10
test_episode: 16
log_dir: "./logs/coma/"
model_dir: "./models/coma/"
4 changes: 2 additions & 2 deletions xuance/configs/coma/sc2/corridor.yaml
@@ -23,7 +23,7 @@ critic_hidden_size: [128, 128]
activation: "ReLU"

seed: 1
parallels: 1
parallels: 8
n_size: 128
n_epoch: 15
n_minibatch: 1
@@ -52,6 +52,6 @@ training_frequency: 1

test_steps: 10000
eval_interval: 50000
test_episode: 10
test_episode: 16
log_dir: "./logs/coma/"
model_dir: "./models/coma/"
6 changes: 3 additions & 3 deletions xuance/configs/dcg/sc2/1c3s5z.yaml
@@ -4,7 +4,7 @@ env_id: "1c3s5z"
fps: 15
policy: "DCG_policy"
representation: "Basic_RNN"
vectorize: "Dummy_StarCraft2"
vectorize: "Subproc_StarCraft2
runner: "StarCraft2_Runner"
on_policy: False

@@ -32,7 +32,7 @@ n_msg_iterations: 8 # number of iterations for message passing during belief pr
msg_normalized: True # Message normalization during greedy action selection (Kok and Vlassis, 2006)

seed: 1
parallels: 1
parallels: 8
buffer_size: 5000
batch_size: 32
learning_rate: 0.0007
@@ -52,6 +52,6 @@ use_grad_clip: False
grad_clip_norm: 0.5

eval_interval: 20000
test_episode: 10
test_episode: 16
log_dir: "./logs/dcg/"
model_dir: "./models/dcg/"
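The DCG configs additionally swap the vectorizer from the serial Dummy_StarCraft2 wrapper to Subproc_StarCraft2 while raising parallels to 8, so the SC2 instances advance in separate worker processes rather than one after another in the main process. The sketch below illustrates the generic dummy-versus-subprocess vectorization idea with a placeholder environment; ToyEnv and DummyVecEnv are hypothetical names for illustration, not xuance classes.

import numpy as np

class ToyEnv:
    """Placeholder environment standing in for an SC2 wrapper."""
    def __init__(self, obs_dim=4):
        self.obs_dim, self.t = obs_dim, 0
    def reset(self):
        self.t = 0
        return np.zeros(self.obs_dim, np.float32)
    def step(self, action):
        self.t += 1
        obs = np.full(self.obs_dim, float(self.t), np.float32)
        return obs, 1.0, self.t >= 5, {}

class DummyVecEnv:
    """Serial vectorization: every copy is stepped in the current process."""
    def __init__(self, env_fns):
        self.envs = [fn() for fn in env_fns]
    def reset(self):
        return np.stack([env.reset() for env in self.envs])
    def step(self, actions):
        obs, rews, dones = [], [], []
        for env, act in zip(self.envs, actions):
            o, r, d, _ = env.step(act)
            if d:
                o = env.reset()  # auto-reset finished copies
            obs.append(o); rews.append(r); dones.append(d)
        return np.stack(obs), np.array(rews), np.array(dones)

# parallels: 8 corresponds to eight copies stepped as one batch; a subprocess
# vectorizer keeps the same reset/step interface but runs each copy in its own
# worker process, so the SC2 simulations advance concurrently.
vec = DummyVecEnv([ToyEnv for _ in range(8)])
obs = vec.reset()
obs, rewards, dones = vec.step(np.zeros(8, dtype=np.int64))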
6 changes: 3 additions & 3 deletions xuance/configs/dcg/sc2/25m.yaml
@@ -4,7 +4,7 @@ env_id: "25m"
fps: 15
policy: "DCG_policy"
representation: "Basic_RNN"
vectorize: "Dummy_StarCraft2"
vectorize: "Subproc_StarCraft2
runner: "StarCraft2_Runner"
on_policy: False

@@ -32,7 +32,7 @@ n_msg_iterations: 8 # number of iterations for message passing during belief pr
msg_normalized: True # Message normalization during greedy action selection (Kok and Vlassis, 2006)

seed: 1
parallels: 1
parallels: 8
buffer_size: 5000
batch_size: 32
learning_rate: 0.0007
@@ -52,6 +52,6 @@ use_grad_clip: False
grad_clip_norm: 0.5

eval_interval: 50000
test_episode: 10
test_episode: 16
log_dir: "./logs/dcg/"
model_dir: "./models/dcg/"
6 changes: 3 additions & 3 deletions xuance/configs/dcg/sc2/2m_vs_1z.yaml
@@ -4,7 +4,7 @@ env_id: "2m_vs_1z"
fps: 15
policy: "DCG_policy"
representation: "Basic_RNN"
vectorize: "Dummy_StarCraft2"
vectorize: "Subproc_StarCraft2
runner: "StarCraft2_Runner"
on_policy: False

@@ -32,7 +32,7 @@ n_msg_iterations: 8 # number of iterations for message passing during belief pr
msg_normalized: True # Message normalization during greedy action selection (Kok and Vlassis, 2006)

seed: 1
parallels: 1
parallels: 8
buffer_size: 5000
batch_size: 32
learning_rate: 0.0007
@@ -52,6 +52,6 @@ use_grad_clip: False
grad_clip_norm: 0.5

eval_interval: 10000
test_episode: 10
test_episode: 16
log_dir: "./logs/dcg/"
model_dir: "./models/dcg/"
6 changes: 3 additions & 3 deletions xuance/configs/dcg/sc2/2s3z.yaml
@@ -4,7 +4,7 @@ env_id: "2s3z"
fps: 15
policy: "DCG_policy"
representation: "Basic_RNN"
vectorize: "Dummy_StarCraft2"
vectorize: "Subproc_StarCraft2
runner: "StarCraft2_Runner"
on_policy: False

@@ -32,7 +32,7 @@ n_msg_iterations: 8 # number of iterations for message passing during belief pr
msg_normalized: True # Message normalization during greedy action selection (Kok and Vlassis, 2006)

seed: 1
parallels: 1
parallels: 8
buffer_size: 5000
batch_size: 32
learning_rate: 0.0007
@@ -52,6 +52,6 @@ use_grad_clip: False
grad_clip_norm: 0.5

eval_interval: 20000
test_episode: 10
test_episode: 16
log_dir: "./logs/dcg/"
model_dir: "./models/dcg/"