From 1b83471693ce455c89622f304a8ea4d9d0430cf0 Mon Sep 17 00:00:00 2001
From: jewelltaylor
Date: Mon, 28 Oct 2024 20:09:12 -0400
Subject: [PATCH 1/4] Modify LR scheduler to have fixed values over defined
 window size and add test.

---
 fl4health/utils/nnunet_utils.py   | 40 ++++++++++++++++++++++++--
 tests/utils/nnnunet_utils_test.py | 48 +++++++++++++++++++++++++++++++
 2 files changed, 85 insertions(+), 3 deletions(-)
 create mode 100644 tests/utils/nnnunet_utils_test.py

diff --git a/fl4health/utils/nnunet_utils.py b/fl4health/utils/nnunet_utils.py
index 84844be6d..3514070de 100644
--- a/fl4health/utils/nnunet_utils.py
+++ b/fl4health/utils/nnunet_utils.py
@@ -5,7 +5,7 @@
 import warnings
 from enum import Enum
 from importlib import reload
-from logging import DEBUG, INFO, Logger
+from logging import DEBUG, INFO, WARN, Logger
 from typing import Any, Callable, Dict, List, Sequence, Tuple, Union, no_type_check
 
 import numpy as np
@@ -414,12 +414,31 @@ def flush(self) -> None:
 
 class PolyLRSchedulerWrapper(_LRScheduler):
     def __init__(
-        self, optimizer: torch.optim.Optimizer, initial_lr: float, max_steps: int, exponent: float = 0.9
+        self,
+        optimizer: torch.optim.Optimizer,
+        initial_lr: float,
+        max_steps: int,
+        exponent: float = 0.9,
+        steps_per_lr: int = 250,
     ) -> None:
+        """
+        Learning rate (LR) scheduler with polynomial decay across fixed windows of size steps_per_lr.
+
+        Args:
+            optimizer (Optimizer): The optimizer to apply the LR scheduler to.
+            initial_lr (float): The initial learning rate of the optimizer.
+            max_steps (int): The maximum total number of steps across all FL rounds.
+            exponent (float): Controls how quickly the LR decreases over time. Higher values
+                lead to more rapid descent.
+            steps_per_lr (int): The number of steps the LR is held constant before decaying.
+        """
         self.optimizer = optimizer
         self.initial_lr = initial_lr
         self.max_steps = max_steps
         self.exponent = exponent
+        self.steps_per_lr = steps_per_lr
+        # Number of windows with constant LR across training
+        self.num_windows = round(max_steps / self.steps_per_lr)
         self._step_count: int
         super().__init__(optimizer, -1, False)
 
@@ -427,6 +446,21 @@ def __init__(
     # Documented issue https://github.com/pytorch/pytorch/issues/100804
     @no_type_check
     def get_lr(self) -> Sequence[float]:
+        """
+        Get the current LR of the scheduler.
+
+        Returns:
+            Sequence[float]: A uniform sequence of LRs, one for each of the parameter groups in the optimizer.
+        """
+        if self._step_count > self.max_steps:
+            log(WARN, f"Current LR step of {self._step_count} exceeds Max Steps of {self.max_steps}")
+
         curr_step = min(self._step_count, self.max_steps)
-        new_lr = self.initial_lr * (1 - curr_step / self.max_steps) ** self.exponent
+        curr_window = int(curr_step / self.steps_per_lr)
+
+        new_lr = self.initial_lr * (1 - curr_window / self.num_windows) ** self.exponent
+
+        if curr_step % self.steps_per_lr == 0 and curr_step != 0 and curr_step != self.max_steps:
+            log(INFO, f"Decaying LR of optimizer to {new_lr}")
+
         return [new_lr] * len(self.optimizer.param_groups)
diff --git a/tests/utils/nnnunet_utils_test.py b/tests/utils/nnnunet_utils_test.py
new file mode 100644
index 000000000..9773970d2
--- /dev/null
+++ b/tests/utils/nnnunet_utils_test.py
@@ -0,0 +1,48 @@
+import pytest
+import re
+import logging
+from torch.optim import SGD
+
+from tests.test_utils.models_for_test import MnistNetWithBnAndFrozen
+from fl4health.utils.nnunet_utils import PolyLRSchedulerWrapper
+
+
+def test_poly_lr_scheduler(caplog: pytest.LogCaptureFixture) -> None:
+    max_steps = 1000
+    exponent = 1
+    steps_per_lr = 100
+    initial_lr = 0.5
+
+    model = MnistNetWithBnAndFrozen()
+    opt = SGD(model.parameters(), lr=initial_lr)
+    lr_scheduler = PolyLRSchedulerWrapper(
+        optimizer=opt,
+        max_steps=max_steps,
+        initial_lr=initial_lr,
+        exponent=exponent,
+        steps_per_lr=steps_per_lr
+    )
+
+    assert lr_scheduler.num_windows == round(max_steps / steps_per_lr)
+    assert lr_scheduler.initial_lr == initial_lr
+
+    prev_lr = initial_lr
+    for step in range(1, max_steps + 1):
+        curr_lr = lr_scheduler.get_lr()[0]
+
+        if step % steps_per_lr == 0:
+            assert curr_lr != prev_lr
+        else:
+            assert curr_lr == prev_lr
+
+        prev_lr = curr_lr
+
+        if step < max_steps:
+            lr_scheduler.step()
+
+    caplog.set_level(logging.WARNING)
+
+    lr_scheduler.step()
+
+    pattern = r"Current LR step of \d+ exceeds Max Steps of \d+"
+    assert re.search(pattern, caplog.text)
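
To make the windowed decay concrete, here is a minimal standalone sketch of the schedule PATCH 1/4 introduces. The helper name poly_window_lr and the example values are illustrative, not part of the patch, and num_windows still uses round at this point in the series:

    from math import isclose

    def poly_window_lr(step: int, initial_lr: float, max_steps: int, exponent: float, steps_per_lr: int) -> float:
        # All steps inside the same fixed-size window share one LR;
        # the LR only drops at window boundaries.
        num_windows = round(max_steps / steps_per_lr)
        curr_step = min(step, max_steps)
        curr_window = int(curr_step / steps_per_lr)
        return initial_lr * (1 - curr_window / num_windows) ** exponent

    # With the test's settings (initial_lr=0.5, max_steps=1000, exponent=1, steps_per_lr=100),
    # the LR is constant for 100 steps at a time and reaches zero at max_steps.
    assert poly_window_lr(50, 0.5, 1000, 1, 100) == 0.5
    assert isclose(poly_window_lr(150, 0.5, 1000, 1, 100), 0.45)
    assert poly_window_lr(1000, 0.5, 1000, 1, 100) == 0.0
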
+ """ + if self._step_count > self.max_steps: + log(WARN, f"Current LR step of {self._step_count} exceeds Max Steps of {self.max_steps}") + curr_step = min(self._step_count, self.max_steps) - new_lr = self.initial_lr * (1 - curr_step / self.max_steps) ** self.exponent + curr_window = int(curr_step / self.steps_per_lr) + + new_lr = self.initial_lr * (1 - curr_window / self.num_windows) ** self.exponent + + if curr_step % self.steps_per_lr == 0 and curr_step != 0 and curr_step != self.max_steps: + log(INFO, f"Decaying LR of optimizer to {new_lr}") + return [new_lr] * len(self.optimizer.param_groups) diff --git a/tests/utils/nnnunet_utils_test.py b/tests/utils/nnnunet_utils_test.py new file mode 100644 index 000000000..9773970d2 --- /dev/null +++ b/tests/utils/nnnunet_utils_test.py @@ -0,0 +1,48 @@ +import pytest +import re +import logging +from torch.optim import SGD + +from tests.test_utils.models_for_test import MnistNetWithBnAndFrozen +from fl4health.utils.nnunet_utils import PolyLRSchedulerWrapper + + +def test_poly_lr_scheduler(caplog: pytest.LogCaptureFixture) -> None: + max_steps = 1000 + exponent = 1 + steps_per_lr = 100 + initial_lr = 0.5 + + model = MnistNetWithBnAndFrozen() + opt = SGD(model.parameters(), lr=initial_lr) + lr_scheduler = PolyLRSchedulerWrapper( + optimizer=opt, + max_steps=max_steps, + initial_lr=initial_lr, + exponent=exponent, + steps_per_lr=steps_per_lr + ) + + assert lr_scheduler.num_windows == round(max_steps / steps_per_lr) + assert lr_scheduler.initial_lr == initial_lr + + prev_lr = initial_lr + for step in range(1, max_steps + 1): + curr_lr = lr_scheduler.get_lr()[0] + + if step % steps_per_lr == 0: + assert curr_lr != prev_lr + else: + assert curr_lr == prev_lr + + prev_lr = curr_lr + + if step < max_steps: + lr_scheduler.step() + + caplog.set_level(logging.WARNING) + + lr_scheduler.step() + + pattern = r"Current LR step of \d+ exceeds Max Steps of \d+" + assert re.search(pattern, caplog.text) From 3c8e2eb49e7dd16ae33a732a1cc91228f04bea53 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 29 Oct 2024 00:13:10 +0000 Subject: [PATCH 2/4] [pre-commit.ci] Add auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/utils/nnnunet_utils_test.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/tests/utils/nnnunet_utils_test.py b/tests/utils/nnnunet_utils_test.py index 9773970d2..5311d54e8 100644 --- a/tests/utils/nnnunet_utils_test.py +++ b/tests/utils/nnnunet_utils_test.py @@ -1,10 +1,11 @@ -import pytest -import re import logging +import re + +import pytest from torch.optim import SGD -from tests.test_utils.models_for_test import MnistNetWithBnAndFrozen from fl4health.utils.nnunet_utils import PolyLRSchedulerWrapper +from tests.test_utils.models_for_test import MnistNetWithBnAndFrozen def test_poly_lr_scheduler(caplog: pytest.LogCaptureFixture) -> None: @@ -16,11 +17,7 @@ def test_poly_lr_scheduler(caplog: pytest.LogCaptureFixture) -> None: model = MnistNetWithBnAndFrozen() opt = SGD(model.parameters(), lr=initial_lr) lr_scheduler = PolyLRSchedulerWrapper( - optimizer=opt, - max_steps=max_steps, - initial_lr=initial_lr, - exponent=exponent, - steps_per_lr=steps_per_lr + optimizer=opt, max_steps=max_steps, initial_lr=initial_lr, exponent=exponent, steps_per_lr=steps_per_lr ) assert lr_scheduler.num_windows == round(max_steps / steps_per_lr) From bb84076b1214ae947e24ba10bf84e5d0b501c769 Mon Sep 17 00:00:00 2001 From: jewelltaylor 
From bb84076b1214ae947e24ba10bf84e5d0b501c769 Mon Sep 17 00:00:00 2001
From: jewelltaylor
Date: Mon, 28 Oct 2024 20:34:56 -0400
Subject: [PATCH 3/4] Round up when calculating number of windows to ensure no
 divide by zero

---
 fl4health/utils/nnunet_utils.py   |  3 ++-
 tests/utils/nnnunet_utils_test.py | 13 +++++--------
 2 files changed, 7 insertions(+), 9 deletions(-)

diff --git a/fl4health/utils/nnunet_utils.py b/fl4health/utils/nnunet_utils.py
index 3514070de..606a5762a 100644
--- a/fl4health/utils/nnunet_utils.py
+++ b/fl4health/utils/nnunet_utils.py
@@ -6,6 +6,7 @@
 from enum import Enum
 from importlib import reload
 from logging import DEBUG, INFO, WARN, Logger
+from math import ceil
 from typing import Any, Callable, Dict, List, Sequence, Tuple, Union, no_type_check
 
 import numpy as np
@@ -438,7 +439,7 @@ def __init__(
         self.exponent = exponent
         self.steps_per_lr = steps_per_lr
         # Number of windows with constant LR across training
-        self.num_windows = round(max_steps / self.steps_per_lr)
+        self.num_windows = ceil(max_steps / self.steps_per_lr)
         self._step_count: int
         super().__init__(optimizer, -1, False)
 
diff --git a/tests/utils/nnnunet_utils_test.py b/tests/utils/nnnunet_utils_test.py
index 9773970d2..5311d54e8 100644
--- a/tests/utils/nnnunet_utils_test.py
+++ b/tests/utils/nnnunet_utils_test.py
@@ -1,10 +1,11 @@
-import pytest
-import re
 import logging
+import re
+
+import pytest
 from torch.optim import SGD
 
-from tests.test_utils.models_for_test import MnistNetWithBnAndFrozen
 from fl4health.utils.nnunet_utils import PolyLRSchedulerWrapper
+from tests.test_utils.models_for_test import MnistNetWithBnAndFrozen
 
 
 def test_poly_lr_scheduler(caplog: pytest.LogCaptureFixture) -> None:
@@ -16,11 +17,7 @@ def test_poly_lr_scheduler(caplog: pytest.LogCaptureFixture) -> None:
     model = MnistNetWithBnAndFrozen()
     opt = SGD(model.parameters(), lr=initial_lr)
     lr_scheduler = PolyLRSchedulerWrapper(
-        optimizer=opt,
-        max_steps=max_steps,
-        initial_lr=initial_lr,
-        exponent=exponent,
-        steps_per_lr=steps_per_lr
+        optimizer=opt, max_steps=max_steps, initial_lr=initial_lr, exponent=exponent, steps_per_lr=steps_per_lr
     )
 
     assert lr_scheduler.num_windows == round(max_steps / steps_per_lr)
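
The divide-by-zero that PATCH 3/4 guards against shows up whenever max_steps is small relative to steps_per_lr. A quick sketch with illustrative values (not taken from the test suite):

    from math import ceil

    max_steps, steps_per_lr = 100, 250  # fewer total steps than one default-sized window

    round(max_steps / steps_per_lr)  # 0: get_lr would raise ZeroDivisionError on curr_window / num_windows
    ceil(max_steps / steps_per_lr)   # 1: a single constant-LR window, so the schedule stays well defined
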
From 1e8f31c71cd8ac04c921f01d6af649ca2c5e5e37 Mon Sep 17 00:00:00 2001
From: jewelltaylor
Date: Tue, 29 Oct 2024 09:57:18 -0400
Subject: [PATCH 4/4] Address CR by David

Only log warning on first occurrence of step count exceeding max steps.
Update test parameters.
---
 fl4health/utils/nnunet_utils.py   | 7 +++++--
 tests/utils/nnnunet_utils_test.py | 8 ++++----
 2 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/fl4health/utils/nnunet_utils.py b/fl4health/utils/nnunet_utils.py
index 606a5762a..db86bedb0 100644
--- a/fl4health/utils/nnunet_utils.py
+++ b/fl4health/utils/nnunet_utils.py
@@ -453,8 +453,11 @@ def get_lr(self) -> Sequence[float]:
         Returns:
             Sequence[float]: A uniform sequence of LRs, one for each of the parameter groups in the optimizer.
         """
-        if self._step_count > self.max_steps:
-            log(WARN, f"Current LR step of {self._step_count} exceeds Max Steps of {self.max_steps}")
+        if self._step_count == self.max_steps + 1:
+            log(
+                WARN,
+                f"Current LR step of {self._step_count} reached Max Steps of {self.max_steps}. LR will remain fixed.",
+            )
 
         curr_step = min(self._step_count, self.max_steps)
         curr_window = int(curr_step / self.steps_per_lr)
diff --git a/tests/utils/nnnunet_utils_test.py b/tests/utils/nnnunet_utils_test.py
index 5311d54e8..99c2ddb97 100644
--- a/tests/utils/nnnunet_utils_test.py
+++ b/tests/utils/nnnunet_utils_test.py
@@ -9,9 +9,9 @@
 
 def test_poly_lr_scheduler(caplog: pytest.LogCaptureFixture) -> None:
-    max_steps = 1000
+    max_steps = 100
     exponent = 1
-    steps_per_lr = 100
+    steps_per_lr = 10
     initial_lr = 0.5
 
     model = MnistNetWithBnAndFrozen()
     opt = SGD(model.parameters(), lr=initial_lr)
@@ -20,7 +20,7 @@ def test_poly_lr_scheduler(caplog: pytest.LogCaptureFixture) -> None:
     lr_scheduler = PolyLRSchedulerWrapper(
         optimizer=opt, max_steps=max_steps, initial_lr=initial_lr, exponent=exponent, steps_per_lr=steps_per_lr
     )
 
-    assert lr_scheduler.num_windows == round(max_steps / steps_per_lr)
+    assert lr_scheduler.num_windows == 10.0
     assert lr_scheduler.initial_lr == initial_lr
 
     prev_lr = initial_lr
@@ -41,5 +41,5 @@ def test_poly_lr_scheduler(caplog: pytest.LogCaptureFixture) -> None:
 
     lr_scheduler.step()
 
-    pattern = r"Current LR step of \d+ exceeds Max Steps of \d+"
+    pattern = r"Current LR step of \d+ reached Max Steps of \d+. LR will remain fixed."
     assert re.search(pattern, caplog.text)
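
With the series applied, the wrapper can be driven like any other PyTorch LR scheduler. A minimal usage sketch, assuming the fl4health package is importable; the Linear model is an illustrative stand-in for any torch.nn.Module:

    from torch.nn import Linear
    from torch.optim import SGD

    from fl4health.utils.nnunet_utils import PolyLRSchedulerWrapper

    model = Linear(10, 2)  # illustrative stand-in model
    opt = SGD(model.parameters(), lr=0.5)
    scheduler = PolyLRSchedulerWrapper(
        optimizer=opt, initial_lr=0.5, max_steps=100, exponent=1, steps_per_lr=10
    )

    for _ in range(99):
        opt.step()        # a real training step would go here
        scheduler.step()  # LR is constant within each 10-step window

    # The base _LRScheduler already counted one step at construction, so the
    # scheduler now sits at max_steps; the next step logs the WARN exactly once
    # and the LR remains fixed thereafter.
    scheduler.step()
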