Skip to content

Commit

Permalink
[example] fix: make wandb optional dependency. allow extra args in ex…
Browse files Browse the repository at this point in the history
…isting scripts (#32)

* [deps] fix: make wandb optional dependency

* allow ppo scripts to take additional args

* fix lint
  • Loading branch information
eric-haibin-lin authored Dec 3, 2024
1 parent 292b60b commit 7f8de22
Show file tree
Hide file tree
Showing 9 changed files with 9 additions and 10 deletions.
2 changes: 1 addition & 1 deletion examples/ppo_trainer/run_deepseek7b_llm.sh
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,4 @@ python3 -m verl.trainer.main_ppo \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.total_epochs=15
trainer.total_epochs=15 $@
2 changes: 1 addition & 1 deletion examples/ppo_trainer/run_deepseek_full_hh_rlhf.sh
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,4 @@ python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megat
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.test_freq=5 \
trainer.total_epochs=100
trainer.total_epochs=100 $@
2 changes: 1 addition & 1 deletion examples/ppo_trainer/run_deepseek_math_gsm8k_megatron.sh
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,4 @@ python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megat
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.test_freq=5 \
trainer.total_epochs=100
trainer.total_epochs=100 $@
2 changes: 1 addition & 1 deletion examples/ppo_trainer/run_deepseek_megatron.sh
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@ python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megat
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.total_epochs=15
trainer.total_epochs=15 $@
2 changes: 1 addition & 1 deletion examples/ppo_trainer/run_gemma.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,4 @@ python3 -m verl.trainer.main_ppo \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.test_freq=10 \
trainer.total_epochs=15
trainer.total_epochs=15 $@
2 changes: 1 addition & 1 deletion examples/ppo_trainer/run_qwen2-7b.sh
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,4 @@ python3 -m verl.trainer.main_ppo \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.test_freq=10 \
trainer.total_epochs=15
trainer.total_epochs=15 $@
2 changes: 1 addition & 1 deletion examples/ppo_trainer/run_qwen2-7b_rm.sh
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,4 @@ python3 -m verl.trainer.main_ppo \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.test_freq=5 \
trainer.total_epochs=15
trainer.total_epochs=15 $@
2 changes: 1 addition & 1 deletion examples/ppo_trainer/run_qwen2.5-32b.sh
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,4 @@ python3 -m verl.trainer.main_ppo \
trainer.nnodes=4 \
trainer.save_freq=-1 \
trainer.test_freq=10 \
trainer.total_epochs=15
trainer.total_epochs=15 $@
3 changes: 1 addition & 2 deletions verl/utils/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
A unified tracking interface that supports logging data to different backend
"""

import wandb

from typing import List, Union


Expand All @@ -32,6 +30,7 @@ def __init__(self, project_name, experiment_name, default_backend: Union[str, Li
self.logger = {}

if 'tracking' in default_backend:
import wandb
wandb.init(project=project_name, name=experiment_name, config=config)
self.logger['tracking'] = wandb

Expand Down

0 comments on commit 7f8de22

Please sign in to comment.