diff --git a/lzero/entry/train_unizero.py b/lzero/entry/train_unizero.py index cd7ff7605..e9270b537 100644 --- a/lzero/entry/train_unizero.py +++ b/lzero/entry/train_unizero.py @@ -20,7 +20,7 @@ from lzero.policy.random_policy import LightZeroRandomPolicy from lzero.worker import MuZeroEvaluator as Evaluator from lzero.worker import MuZeroCollector as Collector -from .utils import random_collect +from .utils import random_collect, calculate_update_per_collect def train_unizero( @@ -154,13 +154,7 @@ def train_unizero( new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) # Determine updates per collection - update_per_collect = cfg.policy.update_per_collect - if update_per_collect is None: - # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio. - # The length of game_segment (i.e., len(game_segment.action_segment)) can be smaller than cfg.policy.game_segment_length if it represents the final segment of the game. - # On the other hand, its length will be less than cfg.policy.game_segment_length + padding_length when it is not the last game segment. Typically, padding_length is the sum of unroll_steps and td_steps. - collected_transitions_num = sum(min(len(game_segment), cfg.policy.game_segment_length) for game_segment in new_data[0]) - update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio) + update_per_collect = calculate_update_per_collect(cfg, new_data) # Update replay buffer replay_buffer.push_game_segments(new_data) diff --git a/lzero/entry/train_unizero_segment.py b/lzero/entry/train_unizero_segment.py index aee8df9ef..7ff466685 100644 --- a/lzero/entry/train_unizero_segment.py +++ b/lzero/entry/train_unizero_segment.py @@ -20,7 +20,7 @@ from lzero.policy.random_policy import LightZeroRandomPolicy from lzero.worker import MuZeroEvaluator as Evaluator from lzero.worker import MuZeroSegmentCollector as Collector -from .utils import random_collect +from .utils import random_collect, calculate_update_per_collect timer = EasyTimer() @@ -151,13 +151,7 @@ def train_unizero_segment( new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) # Determine updates per collection - update_per_collect = cfg.policy.update_per_collect - if update_per_collect is None: - # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio. - # The length of game_segment (i.e., len(game_segment.action_segment)) can be smaller than cfg.policy.game_segment_length if it represents the final segment of the game. - # On the other hand, its length will be less than cfg.policy.game_segment_length + padding_length when it is not the last game segment. Typically, padding_length is the sum of unroll_steps and td_steps. 
- collected_transitions_num = sum(min(len(game_segment), cfg.policy.game_segment_length) for game_segment in new_data[0]) - update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio) + update_per_collect = calculate_update_per_collect(cfg, new_data) # Update replay buffer replay_buffer.push_game_segments(new_data) diff --git a/lzero/entry/utils.py b/lzero/entry/utils.py index d2c23f930..3255b1fef 100644 --- a/lzero/entry/utils.py +++ b/lzero/entry/utils.py @@ -2,11 +2,93 @@ from typing import Optional, Callable import psutil +import torch +import torch.distributed as dist from pympler.asizeof import asizeof from tensorboardX import SummaryWriter -from typing import Optional, Callable + +def is_ddp_enabled(): + """ + Check if Distributed Data Parallel (DDP) is enabled by verifying if + PyTorch's distributed package is available and initialized. + """ + return dist.is_available() and dist.is_initialized() + +def ddp_synchronize(): + """ + Perform a barrier synchronization across all processes in DDP mode. + Ensures all processes reach this point before continuing. + """ + if is_ddp_enabled(): + dist.barrier() + +def ddp_all_reduce_sum(tensor): + """ + Perform an all-reduce operation (sum) on the given tensor across + all processes in DDP mode. Returns the reduced tensor. + Arguments: + - tensor (:obj:`torch.Tensor`): The input tensor to be reduced. + + Returns: + - torch.Tensor: The reduced tensor, summed across all processes. + """ + if is_ddp_enabled(): + dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + return tensor + +def calculate_update_per_collect(cfg, new_data): + """ + Calculate the number of updates to perform per data collection. In a + Distributed Data Parallel (DDP) setting, this ensures that all GPUs + compute the same `update_per_collect` value, synchronized across processes. + + Arguments: + - cfg: Configuration object containing policy settings. + - new_data (list): The newly collected data segments. + + Returns: + - int: The number of updates to perform per collection. + """ + # Retrieve the update_per_collect setting from the configuration + update_per_collect = cfg.policy.update_per_collect + + if update_per_collect is None: + # If update_per_collect is not explicitly set, calculate it based on + # the number of collected transitions and the replay ratio. + + # The length of game_segment (i.e., len(game_segment.action_segment)) can be smaller than cfg.policy.game_segment_length if it represents the final segment of the game. + # On the other hand, its length will be less than cfg.policy.game_segment_length + padding_length when it is not the last game segment. Typically, padding_length is the sum of unroll_steps and td_steps. + collected_transitions_num = sum( + min(len(game_segment), cfg.policy.game_segment_length) + for game_segment in new_data[0] + ) + + if torch.cuda.is_available(): + # Convert the collected transitions count to a GPU tensor for DDP operations. + collected_transitions_tensor = torch.tensor( + collected_transitions_num, dtype=torch.int64, device='cuda' + ) + + # Synchronize the collected transitions count across all GPUs using all-reduce. + total_collected_transitions = ddp_all_reduce_sum( + collected_transitions_tensor + ).item() + + # Calculate update_per_collect based on the total synchronized transitions count. + update_per_collect = int(total_collected_transitions * cfg.policy.replay_ratio) + + # Ensure the computed update_per_collect is positive.
+ assert update_per_collect > 0, "update_per_collect must be positive" + else: + # If not using DDP, calculate update_per_collect directly from the local count. + update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio) + + return update_per_collect def initialize_zeros_batch(observation_shape, batch_size, device): """ diff --git a/zoo/atari/config/atari_efficientzero_multigpu_ddp_config.py b/zoo/atari/config/atari_efficientzero_ddp_config.py similarity index 98% rename from zoo/atari/config/atari_efficientzero_multigpu_ddp_config.py rename to zoo/atari/config/atari_efficientzero_ddp_config.py index 5fe2c25e8..5cfddd8ed 100644 --- a/zoo/atari/config/atari_efficientzero_multigpu_ddp_config.py +++ b/zoo/atari/config/atari_efficientzero_ddp_config.py @@ -84,7 +84,7 @@ Overview: This script should be executed with GPUs. Run the following command to launch the script: - python -m torch.distributed.launch --nproc_per_node=2 ./LightZero/zoo/atari/config/atari_efficientzero_multigpu_ddp_config.py + python -m torch.distributed.launch --nproc_per_node=2 ./LightZero/zoo/atari/config/atari_efficientzero_ddp_config.py """ from ding.utils import DDPContext from lzero.entry import train_muzero diff --git a/zoo/atari/config/atari_muzero_multigpu_ddp_config.py b/zoo/atari/config/atari_muzero_ddp_config.py similarity index 98% rename from zoo/atari/config/atari_muzero_multigpu_ddp_config.py rename to zoo/atari/config/atari_muzero_ddp_config.py index 4ea1809fb..a6bcd6877 100644 --- a/zoo/atari/config/atari_muzero_multigpu_ddp_config.py +++ b/zoo/atari/config/atari_muzero_ddp_config.py @@ -100,7 +100,7 @@ Overview: This script should be executed with GPUs. Run the following command to launch the script: - python -m torch.distributed.launch --nproc_per_node=2 ./zoo/atari/config/atari_muzero_multigpu_ddp_config.py + python -m torch.distributed.launch --nproc_per_node=2 ./zoo/atari/config/atari_muzero_ddp_config.py """ from ding.utils import DDPContext from lzero.entry import train_muzero diff --git a/zoo/atari/config/atari_unizero_multigpu_ddp_config.py b/zoo/atari/config/atari_unizero_ddp_config.py similarity index 98% rename from zoo/atari/config/atari_unizero_multigpu_ddp_config.py rename to zoo/atari/config/atari_unizero_ddp_config.py index 82f64f141..12650d18b 100644 --- a/zoo/atari/config/atari_unizero_multigpu_ddp_config.py +++ b/zoo/atari/config/atari_unizero_ddp_config.py @@ -103,8 +103,8 @@ Overview: This script should be executed with GPUs. 
Run the following command to launch the script: - python -m torch.distributed.launch --nproc_per_node=2 ./zoo/atari/config/atari_unizero_multigpu_ddp_config.py - torchrun --nproc_per_node=2 ./zoo/atari/config/atari_unizero_multigpu_ddp_config.py + python -m torch.distributed.launch --nproc_per_node=2 ./zoo/atari/config/atari_unizero_ddp_config.py + torchrun --nproc_per_node=2 ./zoo/atari/config/atari_unizero_ddp_config.py """ from ding.utils import DDPContext diff --git a/zoo/atari/config/atari_unizero_segment_ddp_config.py b/zoo/atari/config/atari_unizero_segment_ddp_config.py new file mode 100644 index 000000000..fb46a0433 --- /dev/null +++ b/zoo/atari/config/atari_unizero_segment_ddp_config.py @@ -0,0 +1,147 @@ +from easydict import EasyDict +from zoo.atari.config.atari_env_action_space_map import atari_env_action_space_map + + +def main(env_id, seed): + action_space_size = atari_env_action_space_map[env_id] + + # ============================================================== + # begin of the most frequently changed config specified by the user + # ============================================================== + gpu_num = 2 + collector_env_num = 8 + num_segments = int(collector_env_num*gpu_num) + game_segment_length = 20 + evaluator_env_num = 3 + num_simulations = 50 + max_env_step = int(5e5) + batch_size = int(64*gpu_num) + num_layers = 2 + replay_ratio = 0.25 + num_unroll_steps = 10 + infer_context_length = 4 + + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. + buffer_reanalyze_freq = 1/100000 + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + reanalyze_batch_size = 160 + # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. 
+ reanalyze_partition = 0.75 + + # ====== only for debug ===== + # evaluator_env_num = 2 + # num_simulations = 10 + # batch_size = 5 + # gpu_num = 4 + # collector_env_num = 2 + # num_segments = int(collector_env_num) + # ============================================================== + # end of the most frequently changed config specified by the user + # ============================================================== + + atari_unizero_config = dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 96, 96), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + # TODO: only for debug + # collect_max_episode_steps=int(50), + # eval_max_episode_steps=int(50), + ), + policy=dict( + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000, ), ), ), # default is 10000 + model=dict( + observation_shape=(3, 96, 96), + action_space_size=action_space_size, + support_scale=300, + world_model_cfg=dict( + support_size=601, + policy_entropy_weight=5e-3, + continuous_action_space=False, + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, # NOTE: each timestep has 2 tokens: obs and action + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + num_layers=num_layers, + num_heads=8, + embed_dim=768, + obs_type='image', + env_num=max(collector_env_num, evaluator_env_num), + ), + ), + multi_gpu=True, # ======== Very important for ddp ============= + model_path=None, + use_augmentation=False, + manual_temperature_decay=False, + threshold_training_steps_for_final_temperature=int(2.5e4), + use_priority=False, + num_unroll_steps=num_unroll_steps, + update_per_collect=None, + replay_ratio=replay_ratio, + batch_size=batch_size, + optim_type='AdamW', + learning_rate=0.0001, + num_simulations=num_simulations, + num_segments=num_segments, + td_steps=5, + train_start_after_envsteps=2000, + game_segment_length=game_segment_length, + grad_clip_value=5, + replay_buffer_size=int(1e6), + eval_freq=int(5e3), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + # ============= The key different params for reanalyze ============= + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. + buffer_reanalyze_freq=buffer_reanalyze_freq, + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + reanalyze_batch_size=reanalyze_batch_size, + # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. 
+ reanalyze_partition=reanalyze_partition, + ), + ) + atari_unizero_config = EasyDict(atari_unizero_config) + main_config = atari_unizero_config + + atari_unizero_create_config = dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero', + import_names=['lzero.policy.unizero'], + ), + ) + atari_unizero_create_config = EasyDict(atari_unizero_create_config) + create_config = atari_unizero_create_config + + # ============ use muzero_segment_collector instead of muzero_collector ============= + from ding.utils import DDPContext + from lzero.config.utils import lz_to_ddp_config + with DDPContext(): + main_config = lz_to_ddp_config(main_config) + from lzero.entry import train_unizero_segment + main_config.exp_name = f'data_unizero_ddp/{env_id[:-14]}_{gpu_num}gpu/{env_id[:-14]}_uz_ddp_{gpu_num}gpu_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}' + train_unizero_segment([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step) + + +if __name__ == "__main__": + """ + Overview: + This script should be executed with GPUs. + Run the following command to launch the script: + python -m torch.distributed.launch --nproc_per_node=2 ./zoo/atari/config/atari_unizero_segment_ddp_config.py + torchrun --nproc_per_node=2 ./zoo/atari/config/atari_unizero_segment_ddp_config.py + """ + main('PongNoFrameskip-v4', 0) + + + diff --git a/zoo/board_games/tictactoe/config/tictactoe_alphazero_bot_mode_multigpu_ddp_config.py b/zoo/board_games/tictactoe/config/tictactoe_alphazero_bot_mode_ddp_config.py similarity index 99% rename from zoo/board_games/tictactoe/config/tictactoe_alphazero_bot_mode_multigpu_ddp_config.py rename to zoo/board_games/tictactoe/config/tictactoe_alphazero_bot_mode_ddp_config.py index 34ece48ba..c673aeb3b 100644 --- a/zoo/board_games/tictactoe/config/tictactoe_alphazero_bot_mode_multigpu_ddp_config.py +++ b/zoo/board_games/tictactoe/config/tictactoe_alphazero_bot_mode_ddp_config.py @@ -98,7 +98,7 @@ Overview: This script should be executed with GPUs. Run the following command to launch the script: - python -m torch.distributed.launch --nproc_per_node=2 ./LightZero/zoo/board_games/tictactoe/config/tictactoe_alphazero_bot_mode_multigpu_ddp_config.py + python -m torch.distributed.launch --nproc_per_node=2 ./LightZero/zoo/board_games/tictactoe/config/tictactoe_alphazero_bot_mode_ddp_config.py """ from ding.utils import DDPContext from lzero.entry import train_alphazero diff --git a/zoo/board_games/tictactoe/config/tictactoe_alphazero_sp_mode_multigpu_ddp_config.py b/zoo/board_games/tictactoe/config/tictactoe_alphazero_sp_mode_ddp_config.py similarity index 99% rename from zoo/board_games/tictactoe/config/tictactoe_alphazero_sp_mode_multigpu_ddp_config.py rename to zoo/board_games/tictactoe/config/tictactoe_alphazero_sp_mode_ddp_config.py index a9e011a9d..e96b6592d 100644 --- a/zoo/board_games/tictactoe/config/tictactoe_alphazero_sp_mode_multigpu_ddp_config.py +++ b/zoo/board_games/tictactoe/config/tictactoe_alphazero_sp_mode_ddp_config.py @@ -97,7 +97,7 @@ Overview: This script should be executed with GPUs.
Run the following command to launch the script: - python -m torch.distributed.launch --nproc_per_node=2 ./LightZero/zoo/board_games/tictactoe/config/tictactoe_alphazero_sp_mode_multigpu_ddp_config.py + python -m torch.distributed.launch --nproc_per_node=2 ./LightZero/zoo/board_games/tictactoe/config/tictactoe_alphazero_sp_mode_ddp_config.py """ from ding.utils import DDPContext from lzero.entry import train_alphazero
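A minimal, single-process sketch of how the new calculate_update_per_collect helper behaves. The FakeSegment class, the trimmed-down EasyDict config, and the None metadata slot below are illustrative stand-ins; in the real entry points, new_data is the output of MuZeroCollector / MuZeroSegmentCollector and cfg is the full compiled policy config.

from easydict import EasyDict

from lzero.entry.utils import calculate_update_per_collect


class FakeSegment:
    """Illustrative stand-in for a game segment; only its length is inspected."""

    def __init__(self, num_transitions):
        self.num_transitions = num_transitions

    def __len__(self):
        return self.num_transitions


# Only the policy fields read by the helper are provided here.
cfg = EasyDict(dict(policy=dict(
    update_per_collect=None,   # None -> derive the value from replay_ratio
    replay_ratio=0.25,
    game_segment_length=20,
)))

# Mimics collector output: (game_segments, metadata); segment lengths are capped
# at game_segment_length when counting transitions.
new_data = [[FakeSegment(20), FakeSegment(20), FakeSegment(13)], None]

# 20 + 20 + 13 = 53 transitions -> int(53 * 0.25) = 13 updates per collect.
# In DDP mode the per-rank counts are summed with all_reduce before the
# multiplication, so every rank computes the same number of updates.
print(calculate_update_per_collect(cfg, new_data))

Because the per-rank transition counts are all-reduced, the derived update_per_collect reflects the global amount of data collected in that iteration, which keeps the effective replay ratio consistent as gpu_num (and hence num_segments and batch_size in the configs above) scales.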