diff --git a/examples/ppo_trainer/run_deepseek7b_llm.sh b/examples/ppo_trainer/run_deepseek7b_llm.sh index f5beea65..1df0da02 100644 --- a/examples/ppo_trainer/run_deepseek7b_llm.sh +++ b/examples/ppo_trainer/run_deepseek7b_llm.sh @@ -35,4 +35,4 @@ python3 -m verl.trainer.main_ppo \ trainer.n_gpus_per_node=8 \ trainer.nnodes=1 \ trainer.save_freq=-1 \ - trainer.total_epochs=15 \ No newline at end of file + trainer.total_epochs=15 $@ diff --git a/examples/ppo_trainer/run_deepseek_full_hh_rlhf.sh b/examples/ppo_trainer/run_deepseek_full_hh_rlhf.sh index 2990c516..f4e25879 100644 --- a/examples/ppo_trainer/run_deepseek_full_hh_rlhf.sh +++ b/examples/ppo_trainer/run_deepseek_full_hh_rlhf.sh @@ -38,4 +38,4 @@ python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megat trainer.nnodes=1 \ trainer.save_freq=-1 \ trainer.test_freq=5 \ - trainer.total_epochs=100 \ No newline at end of file + trainer.total_epochs=100 $@ diff --git a/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron.sh b/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron.sh index 42c65851..ed113b22 100644 --- a/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron.sh +++ b/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron.sh @@ -37,4 +37,4 @@ python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megat trainer.nnodes=1 \ trainer.save_freq=-1 \ trainer.test_freq=5 \ - trainer.total_epochs=100 \ No newline at end of file + trainer.total_epochs=100 $@ diff --git a/examples/ppo_trainer/run_deepseek_megatron.sh b/examples/ppo_trainer/run_deepseek_megatron.sh index 9e8a60aa..2d1cab20 100644 --- a/examples/ppo_trainer/run_deepseek_megatron.sh +++ b/examples/ppo_trainer/run_deepseek_megatron.sh @@ -28,4 +28,4 @@ python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megat trainer.n_gpus_per_node=8 \ trainer.nnodes=1 \ trainer.save_freq=-1 \ - trainer.total_epochs=15 \ No newline at end of file + trainer.total_epochs=15 $@ diff --git a/examples/ppo_trainer/run_gemma.sh b/examples/ppo_trainer/run_gemma.sh index a95e7fa0..bcd54521 100644 --- a/examples/ppo_trainer/run_gemma.sh +++ b/examples/ppo_trainer/run_gemma.sh @@ -36,4 +36,4 @@ python3 -m verl.trainer.main_ppo \ trainer.nnodes=1 \ trainer.save_freq=-1 \ trainer.test_freq=10 \ - trainer.total_epochs=15 \ No newline at end of file + trainer.total_epochs=15 $@ diff --git a/examples/ppo_trainer/run_qwen2-7b.sh b/examples/ppo_trainer/run_qwen2-7b.sh index c7754948..396eb639 100644 --- a/examples/ppo_trainer/run_qwen2-7b.sh +++ b/examples/ppo_trainer/run_qwen2-7b.sh @@ -44,4 +44,4 @@ python3 -m verl.trainer.main_ppo \ trainer.nnodes=1 \ trainer.save_freq=-1 \ trainer.test_freq=10 \ - trainer.total_epochs=15 \ No newline at end of file + trainer.total_epochs=15 $@ diff --git a/examples/ppo_trainer/run_qwen2-7b_rm.sh b/examples/ppo_trainer/run_qwen2-7b_rm.sh index 6132fa79..2f77e87f 100644 --- a/examples/ppo_trainer/run_qwen2-7b_rm.sh +++ b/examples/ppo_trainer/run_qwen2-7b_rm.sh @@ -51,4 +51,4 @@ python3 -m verl.trainer.main_ppo \ trainer.nnodes=1 \ trainer.save_freq=-1 \ trainer.test_freq=5 \ - trainer.total_epochs=15 \ No newline at end of file + trainer.total_epochs=15 $@ diff --git a/examples/ppo_trainer/run_qwen2.5-32b.sh b/examples/ppo_trainer/run_qwen2.5-32b.sh index 358ea366..e7f93cc5 100644 --- a/examples/ppo_trainer/run_qwen2.5-32b.sh +++ b/examples/ppo_trainer/run_qwen2.5-32b.sh @@ -45,4 +45,4 @@ python3 -m verl.trainer.main_ppo \ trainer.nnodes=4 \ trainer.save_freq=-1 \ trainer.test_freq=10 \ - trainer.total_epochs=15 \ No newline at end of file + trainer.total_epochs=15 $@ diff --git a/verl/utils/tracking.py b/verl/utils/tracking.py index 1c8284e2..19aab117 100644 --- a/verl/utils/tracking.py +++ b/verl/utils/tracking.py @@ -15,8 +15,6 @@ A unified tracking interface that supports logging data to different backend """ -import wandb - from typing import List, Union @@ -32,6 +30,7 @@ def __init__(self, project_name, experiment_name, default_backend: Union[str, Li self.logger = {} if 'tracking' in default_backend: + import wandb wandb.init(project=project_name, name=experiment_name, config=config) self.logger['tracking'] = wandb