Skip to content

Commit

Permalink
Merge pull request #274 from VectorInstitute/lr-scheduler-update
Browse files · Browse the repository at this point in the history
Modify curr step of LR Scheduler to start at 0 instead of 1
  • Loading branch information
jewelltaylor authored Oct 31, 2024
2 parents 8e2baed + 5a6ac7a commit a2a21bc
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
8 changes: 5 additions & 3 deletions fl4health/utils/nnunet_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,18 +455,20 @@ def get_lr(self) -> Sequence[float]:
Returns:
Sequence[float]: A uniform sequence of LR for each of the parameter groups in the optimizer.
"""
if self._step_count == self.max_steps + 1:

if self._step_count - 1 == self.max_steps + 1:
log(
WARN,
f"Current LR step of {self._step_count} reached Max Steps of {self.max_steps}. LR will remain fixed.",
)

curr_step = min(self._step_count, self.max_steps)
# Subtract 1 from step count since it starts at 1 (imposed by PyTorch)
curr_step = min(self._step_count - 1, self.max_steps)
curr_window = int(curr_step / self.steps_per_lr)

new_lr = self.initial_lr * (1 - curr_window / self.num_windows) ** self.exponent

if curr_step % self.steps_per_lr == 0 and curr_step != 0 and curr_step != self.max_steps:
log(INFO, f"Decaying LR of optimizer to {new_lr}")
log(INFO, f"Decaying LR of optimizer to {new_lr} at step {curr_step}")

return [new_lr] * len(self.optimizer.param_groups)
13 changes: 7 additions & 6 deletions tests/utils/nnnunet_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@


def test_poly_lr_scheduler(caplog: pytest.LogCaptureFixture) -> None:
caplog.set_level(logging.WARNING)
pattern = r"Current LR step of \d+ reached Max Steps of \d+. LR will remain fixed."

max_steps = 100
exponent = 1
steps_per_lr = 10
Expand All @@ -23,8 +26,8 @@ def test_poly_lr_scheduler(caplog: pytest.LogCaptureFixture) -> None:
assert lr_scheduler.num_windows == 10.0
assert lr_scheduler.initial_lr == initial_lr

prev_lr = initial_lr
for step in range(1, max_steps + 1):
prev_lr = None
for step in range(max_steps):
curr_lr = lr_scheduler.get_lr()[0]

if step % steps_per_lr == 0:
Expand All @@ -34,12 +37,10 @@ def test_poly_lr_scheduler(caplog: pytest.LogCaptureFixture) -> None:

prev_lr = curr_lr

if step < max_steps:
lr_scheduler.step()
lr_scheduler.step()

caplog.set_level(logging.WARNING)
assert not re.search(pattern, caplog.text)

lr_scheduler.step()

pattern = r"Current LR step of \d+ reached Max Steps of \d+. LR will remain fixed."
assert re.search(pattern, caplog.text)

0 comments on commit a2a21bc

Please sign in to comment.