diff --git a/fl4health/utils/nnunet_utils.py b/fl4health/utils/nnunet_utils.py
index db86bedb0..0dd4447c8 100644
--- a/fl4health/utils/nnunet_utils.py
+++ b/fl4health/utils/nnunet_utils.py
@@ -430,8 +430,10 @@ def __init__(
             initial_lr (float): The initial learning rate of the optimizer.
             max_steps (int): The maximum total number of steps across all FL rounds.
             exponent (float): Controls how quickly LR descreases over time. Higher values
-                lead to more rapdid descent.
+                lead to more rapid descent. Defaults to 0.9.
             steps_per_lr (int): The number of steps per LR before decaying.
+                (i.e. 10 means the LR will be constant for 10 steps before being decreased to the subsequent value).
+                Defaults to 250 as that is the nnunet default (the LR is decayed once per epoch and an epoch is 250 steps).
         """
         self.optimizer = optimizer
         self.initial_lr = initial_lr
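
For context, below is a minimal sketch of how these parameters could interact, assuming the nnU-Net-style polynomial decay formula `initial_lr * (1 - step / max_steps) ** exponent`. The class and method names (`PolyLRSketch`, `step`) are hypothetical illustrations, not the fl4health implementation.

```python
import torch


class PolyLRSketch:
    """Hypothetical sketch of a polynomial LR decay with a configurable hold period."""

    def __init__(self, optimizer, initial_lr=0.01, max_steps=25000, exponent=0.9, steps_per_lr=250):
        self.optimizer = optimizer
        self.initial_lr = initial_lr
        self.max_steps = max_steps
        self.exponent = exponent
        self.steps_per_lr = steps_per_lr

    def step(self, current_step: int) -> None:
        # Quantize the step so the LR stays constant for `steps_per_lr` steps
        # before dropping to the next value (e.g. once per 250-step nnunet epoch).
        effective_step = (current_step // self.steps_per_lr) * self.steps_per_lr
        # Assumed nnU-Net-style polynomial decay; higher `exponent` decays faster.
        new_lr = self.initial_lr * (1 - effective_step / self.max_steps) ** self.exponent
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = new_lr


# Hypothetical usage with a toy model and SGD optimizer.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = PolyLRSketch(optimizer, initial_lr=0.01, max_steps=1000, exponent=0.9, steps_per_lr=250)
for step in range(1000):
    scheduler.step(step)  # LR changes only at steps 0, 250, 500, 750
```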