diff --git a/clean_pufferl.py b/clean_pufferl.py index 247a8127..e0f8b34f 100644 --- a/clean_pufferl.py +++ b/clean_pufferl.py @@ -53,7 +53,10 @@ def create(config, vecenv, policy, optimizer=None, wandb=None): if config.compile: policy = torch.compile(policy, mode=config.compile_mode) - optimizer = torch.optim.Adam(policy.parameters(), + if optimizer is None: + optimizer = torch.optim.Adam + + optimizer = optimizer(policy.parameters(), lr=config.learning_rate, eps=1e-5) return pufferlib.namespace(