Merge pull request #301 from VectorInstitute/sa_early_stop
Add early stop module
sanaAyrml authored Jan 25, 2025
2 parents 8dcf29b + 171fcc1 commit 1866aab
Showing 5 changed files with 647 additions and 5 deletions.
34 changes: 30 additions & 4 deletions fl4health/clients/basic_client.py
@@ -11,7 +11,7 @@
from flwr.common.typing import Config, NDArrays, Scalar
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader

from fl4health.checkpointing.client_module import CheckpointMode, ClientCheckpointAndStateModule
@@ -27,6 +27,7 @@
set_pack_losses_with_val_metrics,
)
from fl4health.utils.config import narrow_dict_type, narrow_dict_type_and_set_attribute
from fl4health.utils.early_stopper import EarlyStopper
from fl4health.utils.logging import LoggingMode
from fl4health.utils.losses import EvaluationLosses, LossMeter, LossMeterType, TrainingLosses
from fl4health.utils.metrics import TEST_LOSS_KEY, TEST_NUM_EXAMPLES_KEY, Metric, MetricManager
@@ -117,6 +118,11 @@ def __init__(
self.num_val_samples: int
self.num_test_samples: int | None
self.learning_rate: float | None

# The user can set the early stopper for the client by instantiating the EarlyStopper class
# and setting its patience and interval_steps attributes. The early stopper will be used to
# stop training if the validation loss does not improve for a certain number of steps.
self.early_stopper: EarlyStopper | None = None
# Config can contain max_num_validation_steps key, which determines an upper bound
# for the validation steps taken. If not specified, no upper bound will be enforced.
# By specifying this in the config we cannot guarantee the validation set is the same
@@ -160,8 +166,16 @@ def get_parameters(self, config: Config) -> NDArrays:
return FullParameterExchanger().push_parameters(self.model, config=config)
else:
assert self.model is not None and self.parameter_exchanger is not None
# If the client has an early stopping module and its patience is None, load the best saved state
# so that the best checkpointed local model's parameters are sent to the server
self._maybe_load_saved_best_local_model_state()
return self.parameter_exchanger.push_parameters(self.model, config=config)

def _maybe_load_saved_best_local_model_state(self) -> None:
if self.early_stopper is not None and self.early_stopper.patience is None:
log(INFO, "Loading saved best model's state before sending model to server.")
self.early_stopper.load_snapshot(["model"])

def set_parameters(self, parameters: NDArrays, config: Config, fitting_round: bool) -> None:
"""
Sets the local model parameters transferred from the server using a parameter exchanger to coordinate how
@@ -612,6 +626,7 @@ def train_by_epochs(
self.model.train()
steps_this_round = 0 # Reset number of steps this round
report_data: dict[str, Any] = {"round": current_round}
continue_training = True
for local_epoch in range(epochs):
self.train_metric_manager.clear()
self.train_loss_meter.clear()
@@ -641,6 +656,11 @@
self.reports_manager.report(report_data, current_round, self.total_epochs, self.total_steps)
self.total_steps += 1
steps_this_round += 1
if self.early_stopper is not None and self.early_stopper.should_stop(steps_this_round):
log(INFO, "Early stopping criterion met. Stopping training.")
self.early_stopper.load_snapshot()
continue_training = False
break

# Log and report results
metrics = self.train_metric_manager.compute()
@@ -653,6 +673,9 @@
# Update internal epoch counter
self.total_epochs += 1

if not continue_training:
break

# Return final training metrics
return loss_dict, metrics

@@ -709,6 +732,10 @@ def train_by_steps(
report_data.update(self.get_client_specific_reports())
self.reports_manager.report(report_data, current_round, None, self.total_steps)
self.total_steps += 1
if self.early_stopper is not None and self.early_stopper.should_stop(step):
log(INFO, "Early stopping criterion met. Stopping training.")
self.early_stopper.load_snapshot()
break

loss_dict = self.train_loss_meter.compute().as_dict()
metrics = self.train_metric_manager.compute()
@@ -879,7 +906,6 @@ def setup_client(self, config: Config) -> None:
self.parameter_exchanger = self.get_parameter_exchanger(config)

self.reports_manager.report({"host_type": "client", "initialized": str(datetime.datetime.now())})

self.initialized = True

def get_parameter_exchanger(self, config: Config) -> ParameterExchanger:
@@ -1113,7 +1139,7 @@ def get_model(self, config: Config) -> nn.Module:
"""
raise NotImplementedError

def get_lr_scheduler(self, optimizer_key: str, config: Config) -> _LRScheduler | None:
def get_lr_scheduler(self, optimizer_key: str, config: Config) -> LRScheduler | None:
"""
Optional user defined method that returns learning rate scheduler
to be used throughout training for the given optimizer. Defaults to None.
@@ -1125,7 +1151,7 @@ def get_lr_scheduler(self, optimizer_key: str, config: Config) -> _LRScheduler |
config (Config): The config from the server.
Returns:
_LRScheduler | None: Client learning rate schedulers.
LRScheduler | None: Client learning rate schedulers.
"""
return None

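The comment added to BasicClient.__init__ above describes the intended usage: construct an EarlyStopper and assign it to the client's early_stopper attribute. Below is a minimal sketch of that wiring (not part of this commit); the helper name attach_early_stopper and the parameter values are purely illustrative, and it assumes a BasicClient subclass instance has already been constructed elsewhere.

from pathlib import Path

from fl4health.clients.basic_client import BasicClient
from fl4health.utils.early_stopper import EarlyStopper


def attach_early_stopper(client: BasicClient, snapshot_dir: Path | None = None) -> None:
    # Evaluate the validation loss every 10 training steps and stop after 3 consecutive
    # evaluations without improvement, restoring the best snapshotted client state.
    client.early_stopper = EarlyStopper(
        client=client,
        patience=3,
        interval_steps=10,
        snapshot_dir=snapshot_dir,  # None keeps the best state in memory
    )

Setting patience=None would instead disable early stopping while still restoring the best checkpointed model before its parameters are sent to the server (see _maybe_load_saved_best_local_model_state above).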
188 changes: 188 additions & 0 deletions fl4health/utils/early_stopper.py
@@ -0,0 +1,188 @@
from __future__ import annotations

import copy
from collections.abc import Callable
from logging import INFO, WARNING
from pathlib import Path
from typing import TYPE_CHECKING, Any

import torch.nn as nn
from flwr.common.logger import log
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

from fl4health.checkpointing.checkpointer import PerRoundStateCheckpointer
from fl4health.reporting.reports_manager import ReportsManager
from fl4health.utils.logging import LoggingMode
from fl4health.utils.losses import TrainingLosses
from fl4health.utils.metrics import MetricManager
from fl4health.utils.snapshotter import (
AbstractSnapshotter,
LRSchedulerSnapshotter,
NumberSnapshotter,
OptimizerSnapshotter,
SerializableObjectSnapshotter,
T,
TorchModuleSnapshotter,
)

if TYPE_CHECKING:
from fl4health.clients.basic_client import BasicClient


class EarlyStopper:
def __init__(
self,
client: BasicClient,
patience: int | None = 1,
interval_steps: int = 5,
snapshot_dir: Path | None = None,
) -> None:
"""
Early stopping plugin for a client that halts local training based on the validation loss.
Whenever the validation loss improves, this class snapshots the best client state; if the client
starts to overfit and training is stopped early, that best state is restored before the model is
sent to the server.

Args:
client (BasicClient): The client to be monitored.
patience (int | None, optional): Number of validation checks without improvement to wait before
stopping training. If None, the client never stops early, but it still loads the best state
before sending the model to the server. Defaults to 1.
interval_steps (int): Number of training steps between evaluations of the validation loss by
the early stopping mechanism. Defaults to 5.
snapshot_dir (Path | None, optional): Directory to which the best client state is checkpointed.
If not provided, the best state is kept in memory. Defaults to None.
"""

self.client = client

self.patience = patience
self.count_down = patience
self.interval_steps = interval_steps

self.best_score: float | None = None
self.snapshot_ckpt: dict[str, tuple[AbstractSnapshotter, Any]] = {}

self.snapshot_attrs: dict = {
"model": (TorchModuleSnapshotter(self.client), nn.Module),
"optimizers": (OptimizerSnapshotter(self.client), Optimizer),
"lr_schedulers": (
LRSchedulerSnapshotter(self.client),
LRScheduler,
),
"learning_rate": (NumberSnapshotter(self.client), float),
"total_steps": (NumberSnapshotter(self.client), int),
"total_epochs": (NumberSnapshotter(self.client), int),
"reports_manager": (
SerializableObjectSnapshotter(self.client),
ReportsManager,
),
"train_loss_meter": (
SerializableObjectSnapshotter(self.client),
TrainingLosses,
),
"train_metric_manager": (
SerializableObjectSnapshotter(self.client),
MetricManager,
),
}

self.checkpointer: PerRoundStateCheckpointer | None = None
if snapshot_dir is not None:
# TODO: Move to generic checkpointer
self.checkpointer = PerRoundStateCheckpointer(snapshot_dir)
self.checkpoint_name = f"temp_{self.client.client_name}.pt"
else:
# Without a snapshot directory there is no checkpointer; the best state lives in snapshot_ckpt
log(INFO, "Snapshot is being persisted in memory")

def add_default_snapshot_attr(
self, name: str, snapshot_class: Callable[[BasicClient], AbstractSnapshotter], input_type: type[T]
) -> None:
"""Register an additional client attribute to be included in the snapshot."""
self.snapshot_attrs.update({name: (snapshot_class(self.client), input_type)})

def delete_default_snapshot_attr(self, name: str) -> None:
"""Remove a client attribute from the set of snapshotted attributes."""
del self.snapshot_attrs[name]

def save_snapshot(self) -> None:
"""
Creates a snapshot of the client state and, if a checkpointer is configured, saves it to disk. Otherwise the snapshot is kept in memory.
"""
for attr, (snapshotter_function, expected_type) in self.snapshot_attrs.items():
self.snapshot_ckpt.update(snapshotter_function.save(attr, expected_type))

if self.checkpointer is not None:
log(
INFO,
f"Saving client best state to checkpoint at {self.checkpointer.checkpoint_dir} "
f"with name {self.checkpoint_name}.",
)
self.checkpointer.save_checkpoint(self.checkpoint_name, self.snapshot_ckpt)
self.snapshot_ckpt.clear()

else:
log(
WARNING,
"Checkpointing directory is not provided. Client best state will be kept in the memory.",
)
self.snapshot_ckpt = copy.deepcopy(self.snapshot_ckpt)

def load_snapshot(self, attributes: list[str] | None = None) -> None:
"""
Load the checkpointed snapshot dict and restore the requested client attributes.
Args:
attributes (list[str] | None): List of attributes to load from the checkpoint.
If None, all attributes are loaded. Defaults to None.
"""
assert (
self.checkpointer is not None and self.checkpointer.checkpoint_exists(self.checkpoint_name)
) or self.snapshot_ckpt != {}, "No checkpoint to load"

if attributes is None:
attributes = list(self.snapshot_attrs.keys())

if self.checkpointer is not None and self.checkpointer.checkpoint_exists(self.checkpoint_name):
log(INFO, f"Loading client best state {attributes} from checkpoint at {self.checkpointer.checkpoint_dir}")
self.snapshot_ckpt = self.checkpointer.load_checkpoint(self.checkpoint_name)
else:
log(INFO, f"Loading client best state {attributes} from the in-memory snapshot")

for attr in attributes:
snapshotter, expected_type = self.snapshot_attrs[attr]
snapshotter.load(self.snapshot_ckpt, attr, expected_type)

def should_stop(self, steps: int) -> bool:
"""
Determine if the client should stop training based on early stopping criteria.
Args:
steps (int): Number of training steps taken so far in the current round.
Returns:
bool: True if training should stop, otherwise False.
"""
if steps % self.interval_steps != 0:
return False

val_loss, _ = self.client._validate_or_test(
loader=self.client.val_loader,
loss_meter=self.client.val_loss_meter,
metric_manager=self.client.val_metric_manager,
logging_mode=LoggingMode.EARLY_STOP_VALIDATION,
include_losses_in_metrics=False,
)

if val_loss is None:
return False

if self.best_score is None or val_loss < self.best_score:
self.best_score = val_loss
self.count_down = self.patience
self.save_snapshot()
return False

if self.count_down is not None:
self.count_down -= 1
if self.count_down <= 0:
return True

return False
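To make the count-down semantics of should_stop easier to trace, here is a self-contained sketch (not part of the module) in which a hypothetical list of validation losses stands in for the _validate_or_test call; the function trace_early_stopping is an illustrative name only.

def trace_early_stopping(
    val_losses: list[float], patience: int | None = 1, interval_steps: int = 5
) -> int | None:
    # Mirrors the bookkeeping of EarlyStopper.should_stop: a validation check happens every
    # interval_steps training steps, an improvement resets the count-down, and training stops
    # once patience consecutive checks pass without improvement.
    best_score: float | None = None
    count_down = patience
    step = 0
    for val_loss in val_losses:
        step += interval_steps
        if best_score is None or val_loss < best_score:
            best_score = val_loss  # improvement: the real class saves a snapshot here
            count_down = patience
            continue
        if count_down is not None:
            count_down -= 1
            if count_down <= 0:
                return step  # the real class returns True and the client restores the snapshot
    return None


# With patience=2, the third and fourth checks show no improvement, so training stops at step 20.
assert trace_early_stopping([0.9, 0.8, 0.85, 0.86, 0.87], patience=2, interval_steps=5) == 20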
3 changes: 2 additions & 1 deletion fl4health/utils/logging.py
@@ -1,7 +1,8 @@
from enum import Enum


class LoggingMode(Enum):
class LoggingMode(str, Enum):
TRAIN = "Training"
EARLY_STOP_VALIDATION = "Early_Stop_Validation"
VALIDATION = "Validation"
TEST = "Testing"
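Mixing str into LoggingMode presumably allows logging modes to be compared with, and serialized as, plain strings. A standalone illustration of the behaviour this enables (not part of the commit):

from enum import Enum


class LoggingMode(str, Enum):
    TRAIN = "Training"
    EARLY_STOP_VALIDATION = "Early_Stop_Validation"
    VALIDATION = "Validation"
    TEST = "Testing"


# Members of a (str, Enum) subclass are themselves strings, so direct comparison works ...
assert LoggingMode.EARLY_STOP_VALIDATION == "Early_Stop_Validation"
# ... whereas a plain Enum member would only compare equal via its .value attribute.
assert LoggingMode.EARLY_STOP_VALIDATION.value == "Early_Stop_Validation"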