
Merge branch 'fix-lz-ddp' of https://github.com/opendilab/LightZero into fix-lz-ddp
puyuan1996 committed Jan 27, 2025
2 parents 214f29b + 11ab64d commit f32f118
Showing 10 changed files with 33 additions and 44 deletions.
7 changes: 5 additions & 2 deletions lzero/config/utils.py
@@ -13,6 +13,9 @@ def lz_to_ddp_config(cfg: EasyDict) -> EasyDict:
- cfg (:obj:`EasyDict`): The converted config
"""
w = get_world_size()
-cfg.policy.batch_size = int(np.ceil(cfg.policy.batch_size / w))
-cfg.policy.n_episode = int(np.ceil(cfg.policy.n_episode) / w)
+# Generalized handling for multiple keys
+keys_to_scale = ['batch_size', 'n_episode', 'num_segments']
+for key in keys_to_scale:
+    if key in cfg.policy:
+        cfg.policy[key] = int(np.ceil(cfg.policy[key] / w))
return cfg
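As a side effect of the generalization, np.ceil is now applied after the per-rank division for every key, whereas the old n_episode line rounded before dividing. A minimal, self-contained sketch of the resulting behavior (a plain dict and a hard-coded world size stand in for the EasyDict config and get_world_size(), which this sketch deliberately does not import):

import math

def scale_policy_cfg_for_ddp(policy_cfg: dict, world_size: int) -> dict:
    # Mirror of the generalized loop: divide each per-run quantity across ranks,
    # rounding up so every rank still gets at least 1.
    for key in ('batch_size', 'n_episode', 'num_segments'):
        if key in policy_cfg:
            policy_cfg[key] = int(math.ceil(policy_cfg[key] / world_size))
    return policy_cfg

# Example: with 2 GPUs, a global batch of 256 becomes 128 per rank,
# and 8 collector episodes/segments become 4 per rank.
print(scale_policy_cfg_for_ddp({'batch_size': 256, 'n_episode': 8, 'num_segments': 8}, world_size=2))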
11 changes: 4 additions & 7 deletions lzero/entry/train_muzero.py
@@ -19,7 +19,7 @@
from lzero.policy.random_policy import LightZeroRandomPolicy
from lzero.worker import MuZeroCollector as Collector
from lzero.worker import MuZeroEvaluator as Evaluator
-from .utils import random_collect
+from .utils import random_collect, calculate_update_per_collect


def train_muzero(
@@ -186,12 +186,9 @@ def train_muzero(

# Collect data by default config n_sample/n_episode.
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
-if cfg.policy.update_per_collect is None:
-    # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio.
-    # The length of game_segment (i.e., len(game_segment.action_segment)) can be smaller than cfg.policy.game_segment_length if it represents the final segment of the game.
-    # On the other hand, its length will be less than cfg.policy.game_segment_length + padding_length when it is not the last game segment. Typically, padding_length is the sum of unroll_steps and td_steps.
-    collected_transitions_num = sum(min(len(game_segment), cfg.policy.game_segment_length) for game_segment in new_data[0])
-    update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)

+# Determine updates per collection
+update_per_collect = calculate_update_per_collect(cfg, new_data)

# save returned new_data collected by the collector
replay_buffer.push_game_segments(new_data)
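The entry scripts in this commit now delegate to calculate_update_per_collect from lzero/entry/utils. Its body is not shown in the diff; the following is only a rough sketch reconstructed from the inline logic deleted above, with the signature matched to the call site and everything else an assumption:

def calculate_update_per_collect(cfg, new_data):
    """
    Sketch only: if cfg.policy.update_per_collect is set, use it as-is; otherwise derive it
    from the number of collected transitions times cfg.policy.replay_ratio, as the removed
    inline code did.
    """
    update_per_collect = cfg.policy.update_per_collect
    if update_per_collect is None:
        # new_data[0] holds the collected game segments; when counting transitions,
        # each segment's length is capped at cfg.policy.game_segment_length.
        collected_transitions_num = sum(
            min(len(game_segment), cfg.policy.game_segment_length) for game_segment in new_data[0]
        )
        update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)
    return update_per_collect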
11 changes: 4 additions & 7 deletions lzero/entry/train_muzero_segment.py
@@ -19,7 +19,7 @@
from lzero.policy.random_policy import LightZeroRandomPolicy
from lzero.worker import MuZeroEvaluator as Evaluator
from lzero.worker import MuZeroSegmentCollector as Collector
-from .utils import random_collect
+from .utils import random_collect, calculate_update_per_collect

timer = EasyTimer()

@@ -180,13 +180,10 @@ def train_muzero_segment(

# Collect data by default config n_sample/n_episode.
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
-if cfg.policy.update_per_collect is None:
-    # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio.
-    # The length of game_segment (i.e., len(game_segment.action_segment)) can be smaller than cfg.policy.game_segment_length if it represents the final segment of the game.
-    # On the other hand, its length will be less than cfg.policy.game_segment_length + padding_length when it is not the last game segment. Typically, padding_length is the sum of unroll_steps and td_steps.
-    collected_transitions_num = sum(min(len(game_segment), cfg.policy.game_segment_length) for game_segment in new_data[0])
-    update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)

+# Determine updates per collection
+update_per_collect = calculate_update_per_collect(cfg, new_data)

# save returned new_data collected by the collector
replay_buffer.push_game_segments(new_data)
# remove the oldest data if the replay buffer is full.
10 changes: 3 additions & 7 deletions lzero/entry/train_rezero.py
@@ -17,7 +17,7 @@
from lzero.policy.random_policy import LightZeroRandomPolicy
from lzero.worker import MuZeroCollector as Collector
from lzero.worker import MuZeroEvaluator as Evaluator
-from .utils import random_collect
+from .utils import random_collect, calculate_update_per_collect


def train_rezero(
@@ -152,12 +152,8 @@ def train_rezero(
collect_with_pure_policy=cfg.policy.collect_with_pure_policy
)

-if update_per_collect is None:
-    # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio.
-    # The length of game_segment (i.e., len(game_segment.action_segment)) can be smaller than cfg.policy.game_segment_length if it represents the final segment of the game.
-    # On the other hand, its length will be less than cfg.policy.game_segment_length + padding_length when it is not the last game segment. Typically, padding_length is the sum of unroll_steps and td_steps.
-    collected_transitions_num = sum(min(len(game_segment), cfg.policy.game_segment_length) for game_segment in new_data[0])
-    update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)
+# Determine updates per collection
+update_per_collect = calculate_update_per_collect(cfg, new_data)

# Update replay buffer
replay_buffer.push_game_segments(new_data)
17 changes: 8 additions & 9 deletions lzero/mcts/buffer/game_buffer_unizero.py
@@ -522,15 +522,14 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
m_output = model.initial_inference(batch_obs, batch_action)
# ======================================================================

-if not model.training:
-    # if not in training, obtain the scalars of the value/reward
-    [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
-        [
-            m_output.latent_state,
-            inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
-            m_output.policy_logits
-        ]
-    )
+# if not in training, obtain the scalars of the value/reward
+[m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
+    [
+        m_output.latent_state,
+        inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
+        m_output.policy_logits
+    ]
+)

network_output.append(m_output)

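For reference, inverse_scalar_transform converts the model's categorical value output back to a scalar before it is detached to numpy. The following is a generic sketch of the standard MuZero-style inverse transform, not LightZero's implementation; the support range and epsilon are assumed defaults:

import torch

def inverse_value_transform(value_logits: torch.Tensor, support_scale: int = 300, eps: float = 0.001) -> torch.Tensor:
    # Expectation over the discrete support [-support_scale, support_scale].
    support = torch.arange(-support_scale, support_scale + 1, dtype=value_logits.dtype, device=value_logits.device)
    scalar = (torch.softmax(value_logits, dim=-1) * support).sum(dim=-1, keepdim=True)
    # Invert h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x, the transform applied when training the value head.
    sign, y = torch.sign(scalar), torch.abs(scalar)
    return sign * (((torch.sqrt(1.0 + 4.0 * eps * (y + 1.0 + eps)) - 1.0) / (2.0 * eps)) ** 2 - 1.0)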
9 changes: 4 additions & 5 deletions lzero/policy/unizero.py
@@ -728,11 +728,10 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1
network_output = self._eval_model.initial_inference(self.last_batch_obs, self.last_batch_action, data)
latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output)

-if not self._eval_model.training:
-    # if not in training, obtain the scalars of the value/reward
-    pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1)
-    latent_state_roots = latent_state_roots.detach().cpu().numpy()
-    policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A)
+# if not in training, obtain the scalars of the value/reward
+pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1)
+latent_state_roots = latent_state_roots.detach().cpu().numpy()
+policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A)

legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)]
if self._cfg.mcts_ctree:
2 changes: 1 addition & 1 deletion zoo/atari/config/atari_efficientzero_ddp_config.py
@@ -84,7 +84,7 @@
Overview:
This script should be executed with <nproc_per_node> GPUs.
Run the following command to launch the script:
-python -m torch.distributed.launch --nproc_per_node=2 ./LightZero/zoo/atari/config/atari_efficientzero_ddp_config.py
+torchrun --nproc_per_node=2 ./LightZero/zoo/atari/config/atari_efficientzero_ddp_config.py
"""
from ding.utils import DDPContext
from lzero.entry import train_muzero
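The only change in these config scripts is the launch command: torchrun replaces the older python -m torch.distributed.launch launcher, which PyTorch has deprecated. Either launcher spawns one process per GPU and exports the usual distributed environment variables, which is what the world-size scaling in lzero/config/utils.py ultimately depends on. A tiny sketch of reading them (the variable names are torchrun's; the print is illustrative only):

import os

# torchrun sets these for every process it launches; DI-engine's get_world_size()
# reports the same world size once the process group has been initialized.
world_size = int(os.environ.get('WORLD_SIZE', 1))
rank = int(os.environ.get('RANK', 0))
local_rank = int(os.environ.get('LOCAL_RANK', 0))
print(f'rank {rank}/{world_size} (local_rank {local_rank})')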
2 changes: 1 addition & 1 deletion zoo/atari/config/atari_muzero_ddp_config.py
@@ -100,7 +100,7 @@
Overview:
This script should be executed with <nproc_per_node> GPUs.
Run the following command to launch the script:
-python -m torch.distributed.launch --nproc_per_node=2 ./zoo/atari/config/atari_muzero_ddp_config.py
+torchrun --nproc_per_node=2 ./zoo/atari/config/atari_muzero_ddp_config.py
"""
from ding.utils import DDPContext
from lzero.entry import train_muzero
1 change: 0 additions & 1 deletion zoo/atari/config/atari_unizero_ddp_config.py
@@ -103,7 +103,6 @@
Overview:
This script should be executed with <nproc_per_node> GPUs.
Run the following command to launch the script:
-python -m torch.distributed.launch --nproc_per_node=2 ./zoo/atari/config/atari_unizero_ddp_config.py
torchrun --nproc_per_node=2 ./zoo/atari/config/atari_unizero_ddp_config.py
"""
7 changes: 3 additions & 4 deletions zoo/atari/config/atari_unizero_segment_ddp_config.py
@@ -32,7 +32,7 @@ def main(env_id, seed):
# evaluator_env_num = 2
# num_simulations = 10
# batch_size = 5
-# gpu_num = 4
+# gpu_num = 2
# collector_env_num = 2
# num_segments = int(collector_env_num)
# ==============================================================
@@ -90,7 +90,7 @@ def main(env_id, seed):
num_simulations=num_simulations,
num_segments=num_segments,
td_steps=5,
-train_start_after_envsteps=2000,
+train_start_after_envsteps=0,
game_segment_length=game_segment_length,
grad_clip_value=5,
replay_buffer_size=int(1e6),
@@ -138,8 +138,7 @@ def main(env_id, seed):
Overview:
This script should be executed with <nproc_per_node> GPUs.
Run the following command to launch the script:
-python -m torch.distributed.launch --nproc_per_node=2 ./zoo/atari/config/atari_unizero_segment_ddp_config.py
-torchrun --nproc_per_node=2 ./zoo/atari/config/atari_unizero_ddp_config.py
+torchrun --nproc_per_node=2 ./zoo/atari/config/atari_unizero_segment_ddp_config.py
"""
main('PongNoFrameskip-v4', 0)

