Merge pull request #268 from VectorInstitute/update-nnunet-lr-scheduler
Modify LR Scheduler + Test
scarere authored Oct 29, 2024
2 parents 1507e4c + 1e8f31c commit c0e177c
Showing 2 changed files with 86 additions and 3 deletions.
44 changes: 41 additions & 3 deletions fl4health/utils/nnunet_utils.py
@@ -5,7 +5,8 @@
import warnings
from enum import Enum
from importlib import reload
from logging import DEBUG, INFO, Logger
from logging import DEBUG, INFO, WARN, Logger
from math import ceil
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union, no_type_check

import numpy as np
@@ -414,19 +415,56 @@ def flush(self) -> None:

class PolyLRSchedulerWrapper(_LRScheduler):
def __init__(
self, optimizer: torch.optim.Optimizer, initial_lr: float, max_steps: int, exponent: float = 0.9
self,
optimizer: torch.optim.Optimizer,
initial_lr: float,
max_steps: int,
exponent: float = 0.9,
steps_per_lr: int = 250,
) -> None:
"""
Learning rate (LR) scheduler with polynomial decay across fixed windows of size steps_per_lr.
Args:
optimizer (Optimizer): The optimizer to apply LR scheduler to.
initial_lr (float): The initial learning rate of the optimizer.
max_steps (int): The maximum total number of steps across all FL rounds.
exponent (float): Controls how quickly the LR decreases over time. Higher values
lead to more rapid descent.
steps_per_lr (int): The number of steps per LR before decaying.
"""
self.optimizer = optimizer
self.initial_lr = initial_lr
self.max_steps = max_steps
self.exponent = exponent
self.steps_per_lr = steps_per_lr
# Number of windows with constant LR across training
self.num_windows = ceil(max_steps / self.steps_per_lr)
self._step_count: int
super().__init__(optimizer, -1, False)

# mypy incorrectly infers get_lr returns a float
# Documented issue https://github.com/pytorch/pytorch/issues/100804
@no_type_check
def get_lr(self) -> Sequence[float]:
"""
Get the current LR of the scheduler.
Returns:
Sequence[float]: A uniform sequence of LR for each of the parameter groups in the optimizer.
"""
if self._step_count == self.max_steps + 1:
log(
WARN,
f"Current LR step of {self._step_count} reached Max Steps of {self.max_steps}. LR will remain fixed.",
)

curr_step = min(self._step_count, self.max_steps)
new_lr = self.initial_lr * (1 - curr_step / self.max_steps) ** self.exponent
curr_window = int(curr_step / self.steps_per_lr)

new_lr = self.initial_lr * (1 - curr_window / self.num_windows) ** self.exponent

if curr_step % self.steps_per_lr == 0 and curr_step != 0 and curr_step != self.max_steps:
log(INFO, f"Decaying LR of optimizer to {new_lr}")

return [new_lr] * len(self.optimizer.param_groups)
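For reference, a minimal sketch (not part of this commit) of how the windowed polynomial decay behaves, using a hypothetical toy configuration. It reproduces the formula from get_lr above to print the LR chosen at each window boundary:

from math import ceil

# Hypothetical configuration, chosen only for illustration
initial_lr, max_steps, exponent, steps_per_lr = 0.5, 100, 0.9, 25
num_windows = ceil(max_steps / steps_per_lr)  # 4 windows of constant LR

for step in range(0, max_steps + 1, steps_per_lr):
    curr_step = min(step, max_steps)
    curr_window = int(curr_step / steps_per_lr)
    # Same polynomial decay as PolyLRSchedulerWrapper.get_lr
    lr = initial_lr * (1 - curr_window / num_windows) ** exponent
    print(f"step {curr_step:3d} -> window {curr_window} -> lr {lr:.4f}")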
45 changes: 45 additions & 0 deletions tests/utils/nnunet_utils_test.py
@@ -0,0 +1,45 @@
import logging
import re

import pytest
from torch.optim import SGD

from fl4health.utils.nnunet_utils import PolyLRSchedulerWrapper
from tests.test_utils.models_for_test import MnistNetWithBnAndFrozen


def test_poly_lr_scheduler(caplog: pytest.LogCaptureFixture) -> None:
max_steps = 100
exponent = 1
steps_per_lr = 10
initial_lr = 0.5

model = MnistNetWithBnAndFrozen()
opt = SGD(model.parameters(), lr=initial_lr)
lr_scheduler = PolyLRSchedulerWrapper(
optimizer=opt, max_steps=max_steps, initial_lr=initial_lr, exponent=exponent, steps_per_lr=steps_per_lr
)

assert lr_scheduler.num_windows == 10.0
assert lr_scheduler.initial_lr == initial_lr

prev_lr = initial_lr
for step in range(1, max_steps + 1):
curr_lr = lr_scheduler.get_lr()[0]

if step % steps_per_lr == 0:
assert curr_lr != prev_lr
else:
assert curr_lr == prev_lr

prev_lr = curr_lr

if step < max_steps:
lr_scheduler.step()

caplog.set_level(logging.WARNING)

lr_scheduler.step()

pattern = r"Current LR step of \d+ reached Max Steps of \d+. LR will remain fixed."
assert re.search(pattern, caplog.text)
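Separately from the test above, a hedged usage sketch, assuming a stand-in model and optimizer, of how the wrapper would typically be stepped once per optimizer step so that the LR decays every steps_per_lr steps:

import torch
from torch.optim import SGD

from fl4health.utils.nnunet_utils import PolyLRSchedulerWrapper

model = torch.nn.Linear(10, 2)  # stand-in model for illustration
optimizer = SGD(model.parameters(), lr=0.01)
scheduler = PolyLRSchedulerWrapper(
    optimizer=optimizer, initial_lr=0.01, max_steps=1000, steps_per_lr=250
)

for _ in range(1000):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 10)).sum()
    loss.backward()
    optimizer.step()
    scheduler.step()  # LR is decayed once every 250 steps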
