From 08278e526a95b48522e569c29e9cad048e70ca8e Mon Sep 17 00:00:00 2001 From: Lucas Nestler <39779310+ClashLuke@users.noreply.github.com> Date: Sun, 29 Dec 2024 09:30:48 +0100 Subject: [PATCH] fix(clean_pufferl): don't ignore `optimizer` arg --- clean_pufferl.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/clean_pufferl.py b/clean_pufferl.py index 247a8127..e0f8b34f 100644 --- a/clean_pufferl.py +++ b/clean_pufferl.py @@ -53,7 +53,10 @@ def create(config, vecenv, policy, optimizer=None, wandb=None): if config.compile: policy = torch.compile(policy, mode=config.compile_mode) - optimizer = torch.optim.Adam(policy.parameters(), + if optimizer is None: + optimizer = torch.optim.Adam + + optimizer = optimizer(policy.parameters(), lr=config.learning_rate, eps=1e-5) return pufferlib.namespace(