fix(pu): fix ddp config when update_per_collect is None in config
puyuan1996 committed Jan 23, 2025
1 parent 5143f08 commit 056b22f
Showing 9 changed files with 240 additions and 23 deletions.
10 changes: 2 additions & 8 deletions lzero/entry/train_unizero.py
@@ -20,7 +20,7 @@
from lzero.policy.random_policy import LightZeroRandomPolicy
from lzero.worker import MuZeroEvaluator as Evaluator
from lzero.worker import MuZeroCollector as Collector
-from .utils import random_collect
+from .utils import random_collect, calculate_update_per_collect


def train_unizero(
@@ -154,13 +154,7 @@ def train_unizero(
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)

# Determine updates per collection
-update_per_collect = cfg.policy.update_per_collect
-if update_per_collect is None:
-    # If update_per_collect is None, it is set to the number of collected transitions multiplied by the replay_ratio.
-    # The length of game_segment (i.e., len(game_segment.action_segment)) can be smaller than cfg.policy.game_segment_length if it represents the final segment of the game.
-    # On the other hand, when it is not the last game segment, its length will be less than cfg.policy.game_segment_length + padding_length, where padding_length is typically the sum of unroll_steps and td_steps.
-    collected_transitions_num = sum(min(len(game_segment), cfg.policy.game_segment_length) for game_segment in new_data[0])
-    update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)
+update_per_collect = calculate_update_per_collect(cfg, new_data)

# Update replay buffer
replay_buffer.push_game_segments(new_data)
10 changes: 2 additions & 8 deletions lzero/entry/train_unizero_segment.py
@@ -20,7 +20,7 @@
from lzero.policy.random_policy import LightZeroRandomPolicy
from lzero.worker import MuZeroEvaluator as Evaluator
from lzero.worker import MuZeroSegmentCollector as Collector
-from .utils import random_collect
+from .utils import random_collect, calculate_update_per_collect

timer = EasyTimer()

@@ -151,13 +151,7 @@ def train_unizero_segment(
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)

# Determine updates per collection
-update_per_collect = cfg.policy.update_per_collect
-if update_per_collect is None:
-    # If update_per_collect is None, it is set to the number of collected transitions multiplied by the replay_ratio.
-    # The length of game_segment (i.e., len(game_segment.action_segment)) can be smaller than cfg.policy.game_segment_length if it represents the final segment of the game.
-    # On the other hand, when it is not the last game segment, its length will be less than cfg.policy.game_segment_length + padding_length, where padding_length is typically the sum of unroll_steps and td_steps.
-    collected_transitions_num = sum(min(len(game_segment), cfg.policy.game_segment_length) for game_segment in new_data[0])
-    update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)
+update_per_collect = calculate_update_per_collect(cfg, new_data)

# Update replay buffer
replay_buffer.push_game_segments(new_data)
84 changes: 83 additions & 1 deletion lzero/entry/utils.py
@@ -2,11 +2,93 @@
from typing import Optional, Callable

import psutil
import torch
import torch.distributed as dist
from pympler.asizeof import asizeof
from tensorboardX import SummaryWriter

def is_ddp_enabled():
    """
    Check if Distributed Data Parallel (DDP) is enabled by verifying if
    PyTorch's distributed package is available and initialized.
    """
    return dist.is_available() and dist.is_initialized()

def ddp_synchronize():
    """
    Perform a barrier synchronization across all processes in DDP mode.
    Ensures all processes reach this point before continuing.
    """
    if is_ddp_enabled():
        dist.barrier()

def ddp_all_reduce_sum(tensor):
    """
    Perform an all-reduce operation (sum) on the given tensor across
    all processes in DDP mode. Returns the reduced tensor.
    Arguments:
        - tensor (:obj:`torch.Tensor`): The input tensor to be reduced.
    Returns:
        - torch.Tensor: The reduced tensor, summed across all processes.
    """
    if is_ddp_enabled():
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return tensor

def calculate_update_per_collect(cfg, new_data):
    """
    Calculate the number of updates to perform per data collection in a
    Distributed Data Parallel (DDP) setting. This ensures that all GPUs
    compute the same `update_per_collect` value, synchronized across processes.
    Arguments:
        - cfg: Configuration object containing policy settings.
        - new_data (list): The newly collected data segments.
    Returns:
        - int: The number of updates to perform per collection.
    """
    # Retrieve the update_per_collect setting from the configuration.
    update_per_collect = cfg.policy.update_per_collect

    if update_per_collect is None:
        # If update_per_collect is not explicitly set, calculate it based on
        # the number of collected transitions and the replay ratio.

        # The length of game_segment (i.e., len(game_segment.action_segment)) can be smaller than
        # cfg.policy.game_segment_length if it represents the final segment of the game. On the other
        # hand, its length will be less than cfg.policy.game_segment_length + padding_length when it
        # is not the last game segment. Typically, padding_length is the sum of unroll_steps and td_steps.
        collected_transitions_num = sum(
            min(len(game_segment), cfg.policy.game_segment_length)
            for game_segment in new_data[0]
        )

        if torch.cuda.is_available():
            # Convert the collected transitions count to a GPU tensor for DDP operations.
            collected_transitions_tensor = torch.tensor(
                collected_transitions_num, dtype=torch.int64, device='cuda'
            )

            # Synchronize the collected transitions count across all GPUs using all-reduce.
            total_collected_transitions = ddp_all_reduce_sum(
                collected_transitions_tensor
            ).item()

            # Calculate update_per_collect based on the total synchronized transitions count.
            update_per_collect = int(total_collected_transitions * cfg.policy.replay_ratio)

            # Ensure the computed update_per_collect is positive.
            assert update_per_collect > 0, "update_per_collect must be positive"
        else:
            # If not using DDP, calculate update_per_collect directly from the local count.
            update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)

    return update_per_collect

def initialize_zeros_batch(observation_shape, batch_size, device):
    """
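For reference, the minimal sketch below shows one way the new helpers could be exercised end to end under torchrun. It is not part of the commit: the dummy cfg, the list-based stand-ins for game segments, and the demo() wrapper are illustrative assumptions; only calculate_update_per_collect and ddp_all_reduce_sum come from lzero/entry/utils.py above.

# Hypothetical usage sketch (not from this commit). Launch with, e.g.:
#   torchrun --nproc_per_node=2 demo_update_per_collect.py
import os

import torch
import torch.distributed as dist
from easydict import EasyDict

from lzero.entry.utils import calculate_update_per_collect


def demo():
    # Initialize the process group only when launched in DDP mode with GPUs.
    if torch.cuda.is_available() and 'RANK' in os.environ and not dist.is_initialized():
        dist.init_process_group(backend='nccl')
        torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', 0)))

    # Minimal stand-in for cfg.policy; only the fields read by
    # calculate_update_per_collect are provided.
    cfg = EasyDict(dict(policy=dict(
        update_per_collect=None,   # None triggers the replay_ratio-based path
        game_segment_length=20,
        replay_ratio=0.25,
    )))

    # Stand-in for collector output: new_data[0] is a list of game segments;
    # plain lists are used here so that len(game_segment) works.
    new_data = [[list(range(20)) for _ in range(8)], None]

    # Each rank contributes 8 * 20 = 160 transitions. With 2 GPUs under torchrun
    # the all-reduced total is 320, so the result is int(320 * 0.25) = 80; in a
    # single-process run it falls back to the local count, giving 40.
    print(calculate_update_per_collect(cfg, new_data))


if __name__ == '__main__':
    demo()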
@@ -84,7 +84,7 @@
Overview:
This script should be executed with <nproc_per_node> GPUs.
Run the following command to launch the script:
-python -m torch.distributed.launch --nproc_per_node=2 ./LightZero/zoo/atari/config/atari_efficientzero_multigpu_ddp_config.py
+python -m torch.distributed.launch --nproc_per_node=2 ./LightZero/zoo/atari/config/atari_efficientzero_ddp_config.py
"""
from ding.utils import DDPContext
from lzero.entry import train_muzero
@@ -100,7 +100,7 @@
Overview:
This script should be executed with <nproc_per_node> GPUs.
Run the following command to launch the script:
-python -m torch.distributed.launch --nproc_per_node=2 ./zoo/atari/config/atari_muzero_multigpu_ddp_config.py
+python -m torch.distributed.launch --nproc_per_node=2 ./zoo/atari/config/atari_muzero_ddp_config.py
"""
from ding.utils import DDPContext
from lzero.entry import train_muzero
@@ -103,8 +103,8 @@
Overview:
This script should be executed with <nproc_per_node> GPUs.
Run the following command to launch the script:
-python -m torch.distributed.launch --nproc_per_node=2 ./zoo/atari/config/atari_unizero_multigpu_ddp_config.py
-torchrun --nproc_per_node=2 ./zoo/atari/config/atari_unizero_multigpu_ddp_config.py
+python -m torch.distributed.launch --nproc_per_node=2 ./zoo/atari/config/atari_unizero_ddp_config.py
+torchrun --nproc_per_node=2 ./zoo/atari/config/atari_unizero_ddp_config.py
"""
from ding.utils import DDPContext
147 changes: 147 additions & 0 deletions zoo/atari/config/atari_unizero_segment_ddp_config.py
@@ -0,0 +1,147 @@
from easydict import EasyDict
from zoo.atari.config.atari_env_action_space_map import atari_env_action_space_map


def main(env_id, seed):
    action_space_size = atari_env_action_space_map[env_id]

    # ==============================================================
    # begin of the most frequently changed config specified by the user
    # ==============================================================
    gpu_num = 2
    collector_env_num = 8
    num_segments = int(collector_env_num*gpu_num)
    game_segment_length = 20
    evaluator_env_num = 3
    num_simulations = 50
    max_env_step = int(5e5)
    batch_size = int(64*gpu_num)
    num_layers = 2
    replay_ratio = 0.25
    num_unroll_steps = 10
    infer_context_length = 4

    # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs.
    buffer_reanalyze_freq = 1/100000
    # Each reanalyze process will reanalyze <reanalyze_batch_size> sequences (<cfg.policy.num_unroll_steps> transitions per sequence).
    reanalyze_batch_size = 160
    # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer.
    reanalyze_partition = 0.75

    # ====== only for debug =====
    # evaluator_env_num = 2
    # num_simulations = 10
    # batch_size = 5
    # gpu_num = 4
    # collector_env_num = 2
    # num_segments = int(collector_env_num)
    # ==============================================================
    # end of the most frequently changed config specified by the user
    # ==============================================================

    atari_unizero_config = dict(
        env=dict(
            stop_value=int(1e6),
            env_id=env_id,
            observation_shape=(3, 96, 96),
            gray_scale=False,
            collector_env_num=collector_env_num,
            evaluator_env_num=evaluator_env_num,
            n_evaluator_episode=evaluator_env_num,
            manager=dict(shared_memory=False, ),
            # TODO: only for debug
            # collect_max_episode_steps=int(50),
            # eval_max_episode_steps=int(50),
        ),
        policy=dict(
            learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=1000000, ), ), ),  # default is 10000
            model=dict(
                observation_shape=(3, 96, 96),
                action_space_size=action_space_size,
                support_scale=300,
                world_model_cfg=dict(
                    support_size=601,
                    policy_entropy_weight=5e-3,
                    continuous_action_space=False,
                    max_blocks=num_unroll_steps,
                    max_tokens=2 * num_unroll_steps,  # NOTE: each timestep has 2 tokens: obs and action
                    context_length=2 * infer_context_length,
                    device='cuda',
                    action_space_size=action_space_size,
                    num_layers=num_layers,
                    num_heads=8,
                    embed_dim=768,
                    obs_type='image',
                    env_num=max(collector_env_num, evaluator_env_num),
                ),
            ),
            multi_gpu=True,  # ======== Very important for ddp =============
            model_path=None,
            use_augmentation=False,
            manual_temperature_decay=False,
            threshold_training_steps_for_final_temperature=int(2.5e4),
            use_priority=False,
            num_unroll_steps=num_unroll_steps,
            update_per_collect=None,
            replay_ratio=replay_ratio,
            batch_size=batch_size,
            optim_type='AdamW',
            learning_rate=0.0001,
            num_simulations=num_simulations,
            num_segments=num_segments,
            td_steps=5,
            train_start_after_envsteps=2000,
            game_segment_length=game_segment_length,
            grad_clip_value=5,
            replay_buffer_size=int(1e6),
            eval_freq=int(5e3),
            collector_env_num=collector_env_num,
            evaluator_env_num=evaluator_env_num,
            # ============= The key different params for reanalyze =============
            # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs.
            buffer_reanalyze_freq=buffer_reanalyze_freq,
            # Each reanalyze process will reanalyze <reanalyze_batch_size> sequences (<cfg.policy.num_unroll_steps> transitions per sequence).
            reanalyze_batch_size=reanalyze_batch_size,
            # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer.
            reanalyze_partition=reanalyze_partition,
        ),
    )
    atari_unizero_config = EasyDict(atari_unizero_config)
    main_config = atari_unizero_config

    atari_unizero_create_config = dict(
        env=dict(
            type='atari_lightzero',
            import_names=['zoo.atari.envs.atari_lightzero_env'],
        ),
        env_manager=dict(type='subprocess'),
        policy=dict(
            type='unizero',
            import_names=['lzero.policy.unizero'],
        ),
    )
    atari_unizero_create_config = EasyDict(atari_unizero_create_config)
    create_config = atari_unizero_create_config

    # ============ use muzero_segment_collector instead of muzero_collector =============
    from ding.utils import DDPContext
    from lzero.config.utils import lz_to_ddp_config
    with DDPContext():
        main_config = lz_to_ddp_config(main_config)
        from lzero.entry import train_unizero_segment
        main_config.exp_name = f'data_unizero_ddp/{env_id[:-14]}_{gpu_num}gpu/{env_id[:-14]}_uz_ddp_{gpu_num}gpu_brf{buffer_reanalyze_freq}-rbs{reanalyze_batch_size}-rp{reanalyze_partition}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
        train_unizero_segment([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step)


if __name__ == "__main__":
    """
    Overview:
        This script should be executed with <nproc_per_node> GPUs.
        Run one of the following commands to launch the script:
        python -m torch.distributed.launch --nproc_per_node=2 ./zoo/atari/config/atari_unizero_segment_ddp_config.py
        torchrun --nproc_per_node=2 ./zoo/atari/config/atari_unizero_segment_ddp_config.py
    """
    main('PongNoFrameskip-v4', 0)



@@ -98,7 +98,7 @@
Overview:
This script should be executed with <nproc_per_node> GPUs.
Run the following command to launch the script:
-python -m torch.distributed.launch --nproc_per_node=2 ./LightZero/zoo/board_games/tictactoe/config/tictactoe_alphazero_bot_mode_multigpu_ddp_config.py
+python -m torch.distributed.launch --nproc_per_node=2 ./LightZero/zoo/board_games/tictactoe/config/tictactoe_alphazero_bot_mode_ddp_config.py
"""
from ding.utils import DDPContext
from lzero.entry import train_alphazero
@@ -97,7 +97,7 @@
Overview:
This script should be executed with <nproc_per_node> GPUs.
Run the following command to launch the script:
-python -m torch.distributed.launch --nproc_per_node=2 ./LightZero/zoo/board_games/tictactoe/config/tictactoe_alphazero_sp_mode_multigpu_ddp_config.py
+python -m torch.distributed.launch --nproc_per_node=2 ./LightZero/zoo/board_games/tictactoe/config/tictactoe_alphazero_sp_mode_ddp_config.py
"""
from ding.utils import DDPContext
from lzero.entry import train_alphazero
