diff --git a/examples/feddg_ga_example/README.md b/examples/feddg_ga_example/README.md index adf3956f7..7f677ad7e 100644 --- a/examples/feddg_ga_example/README.md +++ b/examples/feddg_ga_example/README.md @@ -29,6 +29,7 @@ from the FL4Health directory. The following arguments must be present in the spe * `batch_size`: size of the batches each client will train on * `n_server_rounds`: The number of rounds to run FL * `evaluate_after_fit`: Should be set to `True`. Performs an evaluation at the end of each client's fit round. +* `pack_losses_with_val_metrics`: Should be set to `True`. Includes the validation losses with the metrics reported to the server. ## Starting Clients diff --git a/examples/feddg_ga_example/config.yaml b/examples/feddg_ga_example/config.yaml index 50c7aa8c0..71fb1e2cb 100644 --- a/examples/feddg_ga_example/config.yaml +++ b/examples/feddg_ga_example/config.yaml @@ -5,4 +5,7 @@ n_server_rounds: 3 # The number of rounds to run FL n_clients: 2 # The number of clients in the FL experiment local_steps: 5 # The number of local steps (one per batch) to complete for client batch_size: 128 # The batch size for client training -evaluate_after_fit: True # Evaluates model immediately after local training on the validation set (in addition to the training set) +# Evaluates model immediately after local training on the validation set (in addition to the training set) +evaluate_after_fit: True +# Packs the measured validation losses with the metrics (required for FedDG-GA) +pack_losses_with_val_metrics: True diff --git a/examples/feddg_ga_example/server.py b/examples/feddg_ga_example/server.py index 64c3429c1..1964666e9 100644 --- a/examples/feddg_ga_example/server.py +++ b/examples/feddg_ga_example/server.py @@ -11,7 +11,7 @@ from fl4health.model_bases.apfl_base import ApflModule from fl4health.reporting import JsonReporter from fl4health.server.base_server import FlServer -from fl4health.strategies.feddg_ga_strategy import FedDgGaStrategy +from fl4health.strategies.feddg_ga import FedDgGa from fl4health.utils.config import load_config from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn from fl4health.utils.parameter_extraction import get_all_model_parameters @@ -25,6 +25,7 @@ def fit_config( local_epochs: Optional[int] = None, local_steps: Optional[int] = None, evaluate_after_fit: bool = False, + pack_losses_with_val_metrics: bool = False, ) -> Config: return { **make_dict_with_epochs_or_steps(local_epochs, local_steps), @@ -32,6 +33,7 @@ def fit_config( "batch_size": batch_size, "n_server_rounds": n_server_rounds, "evaluate_after_fit": evaluate_after_fit, + "pack_losses_with_val_metrics": pack_losses_with_val_metrics, } @@ -44,12 +46,13 @@ def main(config: Dict[str, Any]) -> None: local_epochs=config.get("local_epochs"), local_steps=config.get("local_steps"), evaluate_after_fit=config.get("evaluate_after_fit", False), + pack_losses_with_val_metrics=config.get("pack_losses_with_val_metrics", False), ) initial_model = ApflModule(MnistNetWithBnAndFrozen()) # Implementation of FedDG-GA as a server side strategy - strategy = FedDgGaStrategy( + strategy = FedDgGa( min_fit_clients=config["n_clients"], min_evaluate_clients=config["n_clients"], # Server waits for min_available_clients before starting FL rounds diff --git a/examples/fenda_ditto_example/client.py b/examples/fenda_ditto_example/client.py index 1ae7d2a1b..5d555282b 100644 --- a/examples/fenda_ditto_example/client.py +++ b/examples/fenda_ditto_example/client.py @@ -21,7 +21,7 @@ from
fl4health.clients.fenda_ditto_client import FendaDittoClient from fl4health.model_bases.fenda_base import FendaModel from fl4health.model_bases.parallel_split_models import ParallelFeatureJoinMode -from fl4health.model_bases.sequential_split_models import SequentiallySplitExchangeBaseModel +from fl4health.model_bases.sequential_split_models import SequentiallySplitModel from fl4health.reporting import JsonReporter from fl4health.utils.config import narrow_dict_type from fl4health.utils.load_data import load_mnist_data @@ -37,8 +37,8 @@ def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size, sampler) return train_loader, val_loader - def get_global_model(self, config: Config) -> SequentiallySplitExchangeBaseModel: - return SequentiallySplitExchangeBaseModel( + def get_global_model(self, config: Config) -> SequentiallySplitModel: + return SequentiallySplitModel( base_module=SequentialGlobalFeatureExtractorMnist(), head_module=SequentialLocalPredictionHeadMnist(), ).to(self.device) diff --git a/fl4health/clients/basic_client.py b/fl4health/clients/basic_client.py index bfed7f1a7..f66f4ec97 100644 --- a/fl4health/clients/basic_client.py +++ b/fl4health/clients/basic_client.py @@ -1,23 +1,18 @@ -import copy import datetime -import os -from collections.abc import Iterable, Sequence -from enum import Enum -from inspect import currentframe, getframeinfo -from logging import INFO, WARNING, LogRecord +from collections.abc import Sequence +from logging import INFO, WARNING from pathlib import Path -from typing import Any, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import torch import torch.nn as nn from flwr.client import NumPyClient -from flwr.common.logger import LOGGER_NAME, console_handler, log +from flwr.common.logger import log from flwr.common.typing import Config, NDArrays, Scalar from torch.nn.modules.loss import _Loss from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import DataLoader -from tqdm import tqdm from fl4health.checkpointing.checkpointer import PerRoundCheckpointer from fl4health.checkpointing.client_module import CheckpointMode, ClientCheckpointModule @@ -25,19 +20,21 @@ from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger from fl4health.reporting.base_reporter import BaseReporter from fl4health.reporting.reports_manager import ReportsManager +from fl4health.utils.client import ( + check_if_batch_is_empty_and_verify_input, + fold_loss_dict_into_metrics, + maybe_progress_bar, + move_data_to_device, + set_pack_losses_with_val_metrics, +) from fl4health.utils.config import narrow_dict_type, narrow_dict_type_and_set_attribute +from fl4health.utils.logging import LoggingMode from fl4health.utils.losses import EvaluationLosses, LossMeter, LossMeterType, TrainingLosses from fl4health.utils.metrics import TEST_LOSS_KEY, TEST_NUM_EXAMPLES_KEY, Metric, MetricManager from fl4health.utils.random import generate_hash from fl4health.utils.typing import LogLevel, TorchFeatureType, TorchInputType, TorchPredType, TorchTargetType -class LoggingMode(Enum): - TRAIN = "Training" - VALIDATION = "Validation" - TEST = "Testing" - - class BasicClient(NumPyClient): def __init__( self, @@ -216,7 +213,7 @@ def shutdown(self) -> None: self.reports_manager.report({"shutdown": str(datetime.datetime.now())}) self.reports_manager.shutdown() - def process_config(self, config: Config) -> 
Tuple[Union[int, None], Union[int, None], int, bool]: + def process_config(self, config: Config) -> Tuple[Union[int, None], Union[int, None], int, bool, bool]: """ Method to ensure the required keys are present in config and extracts values to be returned. @@ -250,8 +247,10 @@ def process_config(self, config: Config) -> Tuple[Union[int, None], Union[int, N except ValueError: evaluate_after_fit = False + pack_losses_with_val_metrics = set_pack_losses_with_val_metrics(config) + # Either local epochs or local steps is none based on what key is passed in the config - return local_epochs, local_steps, current_server_round, evaluate_after_fit + return local_epochs, local_steps, current_server_round, evaluate_after_fit, pack_losses_with_val_metrics def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, dict[str, Scalar]]: """ @@ -271,7 +270,9 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, dict ValueError: If local_steps or local_epochs is not specified in config. """ round_start_time = datetime.datetime.now() - local_epochs, local_steps, current_server_round, evaluate_after_fit = self.process_config(config) + local_epochs, local_steps, current_server_round, evaluate_after_fit, pack_losses_with_val_metrics = ( + self.process_config(config) + ) if not self.initialized: self.setup_client(config) @@ -301,7 +302,7 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, dict # Check if we should run an evaluation with validation data after fit # (for example, this is used by FedDGGA) if self._should_evaluate_after_fit(evaluate_after_fit): - validation_loss, validation_metrics = self.evaluate_after_fit() + validation_loss, validation_metrics = self.validate(pack_losses_with_val_metrics) metrics.update(validation_metrics) # We perform a pre-aggregation checkpoint if applicable self._maybe_checkpoint(validation_loss, validation_metrics, CheckpointMode.PRE_AGGREGATION) @@ -332,22 +333,7 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, dict metrics, ) - def evaluate_after_fit(self) -> Tuple[float, dict[str, Scalar]]: - """ - Run self.validate right after fit to collect metrics on the local model against validation data. - - Returns: (dict[str, Scalar]) a dictionary with the metrics. - - """ - loss, metric_values = self.validate() - # The computed loss value is packed into the metrics dictionary, perhaps for use on the server-side - metrics_after_fit = { - **metric_values, # type: ignore - "val - loss": loss, - } - return loss, metrics_after_fit - - def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, dict[str, Scalar]]: + def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, Dict[str, Scalar]]: """ Evaluates the model on the validation set, and test set (if defined). 
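The new `pack_losses_with_val_metrics` flag is read from the server config by `set_pack_losses_with_val_metrics`, imported above from `fl4health.utils.client`. Its body is not shown in this diff; a minimal sketch of the assumed behavior (default to `False` with a warning when the key is absent) is:

```python
# Illustrative sketch only; the actual helper lives in fl4health.utils.client.
from logging import WARNING

from flwr.common.logger import log
from flwr.common.typing import Config


def set_pack_losses_with_val_metrics(config: Config) -> bool:
    # Read the optional flag from the server config, falling back to False.
    try:
        return bool(config["pack_losses_with_val_metrics"])
    except KeyError:
        log(WARNING, "pack_losses_with_val_metrics is not present in config. Defaulting to False.")
        return False
```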
@@ -365,8 +351,10 @@ def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, di start_time = datetime.datetime.now() current_server_round = narrow_dict_type(config, "current_server_round", int) + pack_losses_with_val_metrics = set_pack_losses_with_val_metrics(config) + self.set_parameters(parameters, config, fitting_round=False) - loss, metrics = self.validate() + loss, metrics = self.validate(pack_losses_with_val_metrics) end_time = datetime.datetime.now() elapsed = end_time - start_time @@ -519,69 +507,6 @@ def get_client_specific_reports(self) -> dict[str, Any]: """ return {} - def _move_data_to_device( - self, data: Union[TorchInputType, TorchTargetType] - ) -> Union[TorchTargetType, TorchInputType]: - """ - Moving data to self.device where data is intended to be either input to - the model or the targets that the model is trying to achieve - - Args: - data (TorchInputType | TorchTargetType): The data to move to - self.device. Can be a TorchInputType or a TorchTargetType - - Raises: - TypeError: Raised if data is not one of the types specified by - TorchInputType or TorchTargetType - - Returns: - Union[TorchTargetType, TorchInputType]: The data argument except now it's been moved to self.device - """ - # Currently we expect both inputs and targets to be either tensors - # or dictionaries of tensors - if isinstance(data, torch.Tensor): - return data.to(self.device) - elif isinstance(data, dict): - return {key: value.to(self.device) for key, value in data.items()} - else: - raise TypeError( - "data must be of type torch.Tensor or dict[str, torch.Tensor]. \ - If definition of TorchInputType or TorchTargetType has \ - changed this method might need to be updated or split into \ - two" - ) - - def is_empty_batch(self, input: Union[torch.Tensor, dict[str, torch.Tensor]]) -> bool: - """ - Check whether input, which represents a batch of inputs to a model, is empty. - - Args: - input (Union[torch.Tensor, dict[str, torch.Tensor]]): input batch. - input can be of type torch.Tensor or dict[str, torch.Tensor], and in the - latter case, the batch is considered to be empty if all tensors in the dictionary - have length zero. - - Raises: - TypeError: raised if input is not of type torch.Tensor or dict[str, torch.Tensor]. - ValueError: raised if input has type dict[str, torch.Tensor] and not all tensors - within the dictionary have the same size. - - Returns: - bool: True if input is an empty batch. - """ - if isinstance(input, torch.Tensor): - return len(input) == 0 - elif isinstance(input, dict): - input_iter = iter(input.items()) - _, first_val = next(input_iter) - first_val_len = len(first_val) - if not all(len(val) == first_val_len for _, val in input_iter): - raise ValueError("Not all tensors in the dictionary have the same size.") - else: - return first_val_len == 0 - else: - raise TypeError("Input must be of type torch.Tensor or dict[str, torch.Tensor].") - def update_metric_manager( self, preds: TorchPredType, @@ -683,16 +608,16 @@ def train_by_epochs( self.update_before_epoch(epoch=local_epoch) # Update report data dict report_data.update({"fit_epoch": local_epoch}) - for input, target in self.maybe_progress_bar(self.train_loader): + for input, target in maybe_progress_bar(self.train_loader, self.progress_bar): self.update_before_step(steps_this_round, current_round) # Assume first dimension is batch size. Sampling iterators (such as Poisson batch sampling), can # construct empty batches. We skip the iteration if this occurs. 
- if self.is_empty_batch(input): + if check_if_batch_is_empty_and_verify_input(input): log(INFO, "Empty batch generated by data loader. Skipping step.") continue - input = self._move_data_to_device(input) - target = self._move_data_to_device(target) + input = move_data_to_device(input, self.device) + target = move_data_to_device(target, self.device) losses, preds = self.train_step(input, target) self.train_loss_meter.update(losses) self.update_metric_manager(preds, target, self.train_metric_manager) @@ -741,7 +666,7 @@ def train_by_steps( self.train_metric_manager.clear() self._log_header_str(current_round) report_data: dict[str, Any] = {"round": current_round} - for step in self.maybe_progress_bar(range(steps)): + for step in maybe_progress_bar(range(steps), self.progress_bar): self.update_before_step(step, current_round) try: @@ -754,12 +679,12 @@ def train_by_steps( # Assume first dimension is batch size. Sampling iterators (such as Poisson batch sampling), can # construct empty batches. We skip the iteration if this occurs. - if self.is_empty_batch(input): + if check_if_batch_is_empty_and_verify_input(input): log(INFO, "Empty batch generated by data loader. Skipping step.") continue - input = self._move_data_to_device(input) - target = self._move_data_to_device(target) + input = move_data_to_device(input, self.device) + target = move_data_to_device(target, self.device) losses, preds = self.train_step(input, target) self.train_loss_meter.update(losses) self.update_metric_manager(preds, target, self.train_metric_manager) @@ -784,7 +709,8 @@ def _validate_or_test( loss_meter: LossMeter, metric_manager: MetricManager, logging_mode: LoggingMode = LoggingMode.VALIDATION, - ) -> Tuple[float, dict[str, Scalar]]: + include_losses_in_metrics: bool = False, + ) -> Tuple[float, Dict[str, Scalar]]: """ Evaluate the model on the given validation or test dataset. @@ -792,8 +718,10 @@ def _validate_or_test( loader (DataLoader): The data loader for the dataset (validation or test). loss_meter (LossMeter): The meter to track the losses. metric_manager (MetricManager): The manager to track the metrics. - logging_mode (LoggingMode): The LoggingMode for whether this evaluation is for validation or test. - Default is for validation. + logging_mode (LoggingMode, optional): The LoggingMode for whether this evaluation is for validation or + test. Defaults to LoggingMode.VALIDATION. + include_losses_in_metrics (bool, optional): Whether or not to pack the additional losses into the metrics + dictionary. Defaults to False. Returns: Tuple[float, dict[str, Scalar]]: The loss and a dictionary of metrics from evaluation. 
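Similarly, the body of `fold_loss_dict_into_metrics` (also imported from `fl4health.utils.client`) does not appear in this diff. Given that the strategy later looks for the key `val - checkpoint` (see `FairnessMetricType.LOSS` below) and that `loss_dict` carries a `checkpoint` entry, a plausible sketch is a mode-dependent prefix merge:

```python
# Hypothetical sketch; the real helper is defined in fl4health.utils.client.
from typing import Dict

from flwr.common.typing import Scalar

from fl4health.utils.logging import LoggingMode


def fold_loss_dict_into_metrics(
    metrics: Dict[str, Scalar], loss_dict: Dict[str, float], logging_mode: LoggingMode
) -> None:
    # Copy each loss into the metrics dictionary under a mode-dependent prefix,
    # e.g. loss_dict["checkpoint"] becomes metrics["val - checkpoint"].
    prefix = "val" if logging_mode == LoggingMode.VALIDATION else "test"
    for loss_name, loss_value in loss_dict.items():
        metrics[f"{prefix} - {loss_name}"] = loss_value
```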
@@ -806,9 +734,9 @@ def _validate_or_test( metric_manager.clear() loss_meter.clear() with torch.no_grad(): - for input, target in self.maybe_progress_bar(loader): - input = self._move_data_to_device(input) - target = self._move_data_to_device(target) + for input, target in maybe_progress_bar(loader, self.progress_bar): + input = move_data_to_device(input, self.device) + target = move_data_to_device(target, self.device) losses, preds = self.val_step(input, target) loss_meter.update(losses) self.update_metric_manager(preds, target, metric_manager) @@ -818,9 +746,12 @@ def _validate_or_test( metrics = metric_manager.compute() self._log_results(loss_dict, metrics, logging_mode=logging_mode) + if include_losses_in_metrics: + fold_loss_dict_into_metrics(metrics, loss_dict, logging_mode) + return loss_dict["checkpoint"], metrics - def validate(self) -> Tuple[float, dict[str, Scalar]]: + def validate(self, include_losses_in_metrics: bool = False) -> Tuple[float, Dict[str, Scalar]]: """ Validate the current model on the entire validation and potentially an entire test dataset if it has been defined. @@ -829,13 +760,19 @@ def validate(self) -> Tuple[float, dict[str, Scalar]]: Tuple[float, dict[str, Scalar]]: The validation loss and a dictionary of metrics from validation (and test if present). """ - val_loss, val_metrics = self._validate_or_test(self.val_loader, self.val_loss_meter, self.val_metric_manager) + val_loss, val_metrics = self._validate_or_test( + self.val_loader, + self.val_loss_meter, + self.val_metric_manager, + include_losses_in_metrics=include_losses_in_metrics, + ) if self.test_loader: test_loss, test_metrics = self._validate_or_test( self.test_loader, self.test_loss_meter, self.test_metric_manager, LoggingMode.TEST, + include_losses_in_metrics=include_losses_in_metrics, ) # There will be no clashes due to the naming convention associated with the metric managers if self.num_test_samples is not None: @@ -1040,24 +977,6 @@ def set_optimizer(self, config: Config) -> None: assert not isinstance(optimizer, dict) self.optimizers = {"global": optimizer} - def clone_and_freeze_model(self, model: nn.Module) -> nn.Module: - """ - Creates a clone of the model with frozen weights to be used in loss calculations so the original model is - preserved in its current state. - - Args: - model (nn.Module): Model to clone and freeze - Returns: - nn.Module: Cloned and frozen model - """ - - cloned_model = copy.deepcopy(model) - for param in cloned_model.parameters(): - param.requires_grad = False - cloned_model.eval() - - return cloned_model - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, ...]: """ User defined method that returns a PyTorch Train DataLoader @@ -1248,48 +1167,6 @@ def update_before_epoch(self, epoch: int) -> None: """ pass - def maybe_progress_bar(self, iterable: Iterable) -> Iterable: - """ - Used to print progress bars during client training and validation. If - self.progress_bar is false, just returns the original input iterable - without modifying it. - - Args: - iterable (Iterable): The iterable to wrap - - Returns: - Iterable: an iterator which acts exactly like the original - iterable, but prints a dynamically updating progress bar every - time a value is requested. 
Or the original iterable if - self.progress_bar is False - """ - if not self.progress_bar: - return iterable - else: - # We can use the flwr console handler to format progress bar - frame = currentframe() - lineno = 0 if frame is None else getframeinfo(frame).lineno - record = LogRecord( - name=LOGGER_NAME, - pathname=os.path.abspath(os.getcwd()), - lineno=lineno, # - args={}, - exc_info=None, - level=INFO, - msg="{l_bar}{bar}{r_bar}", - ) - format = console_handler.format(record) - # Create a clean looking tqdm instance that matches the flwr logging - kwargs: Any = { - "leave": True, - "ascii": " >=", - # "desc": f"{LOG_COLORS['INFO']}INFO{LOG_COLORS['RESET']} ", - "unit": "steps", - "dynamic_ncols": True, - "bar_format": format, - } - return tqdm(iterable, **kwargs) - def transform_gradients(self, losses: TrainingLosses) -> None: """ Hook function for model training only called after backwards pass but before diff --git a/fl4health/clients/constrained_fenda_client.py b/fl4health/clients/constrained_fenda_client.py index 3d9febeff..61cd3d87b 100644 --- a/fl4health/clients/constrained_fenda_client.py +++ b/fl4health/clients/constrained_fenda_client.py @@ -12,6 +12,7 @@ from fl4health.losses.fenda_loss_config import ConstrainedFendaLossContainer from fl4health.model_bases.fenda_base import FendaModelWithFeatureState from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger +from fl4health.utils.client import clone_and_freeze_model from fl4health.utils.losses import EvaluationLosses, LossMeterType from fl4health.utils.metrics import Metric from fl4health.utils.typing import TorchFeatureType, TorchInputType, TorchPredType, TorchTargetType @@ -140,8 +141,8 @@ def update_after_train(self, local_steps: int, loss_dict: Dict[str, float], conf assert isinstance(self.model, FendaModelWithFeatureState) if self.loss_container.has_contrastive_loss() or self.loss_container.has_perfcl_loss(): - self.old_local_module = self.clone_and_freeze_model(self.model.first_feature_extractor) - self.old_global_module = self.clone_and_freeze_model(self.model.second_feature_extractor) + self.old_local_module = clone_and_freeze_model(self.model.first_feature_extractor) + self.old_global_module = clone_and_freeze_model(self.model.second_feature_extractor) super().update_after_train(local_steps, loss_dict, config) @@ -159,7 +160,7 @@ def update_before_train(self, current_server_round: int) -> None: assert isinstance(self.model, FendaModelWithFeatureState) if self.loss_container.has_perfcl_loss(): - self.initial_global_module = self.clone_and_freeze_model(self.model.second_feature_extractor) + self.initial_global_module = clone_and_freeze_model(self.model.second_feature_extractor) super().update_before_train(current_server_round) diff --git a/fl4health/clients/deep_mmd_clients/ditto_deep_mmd_client.py b/fl4health/clients/deep_mmd_clients/ditto_deep_mmd_client.py index 474456b04..55c3a9fc9 100644 --- a/fl4health/clients/deep_mmd_clients/ditto_deep_mmd_client.py +++ b/fl4health/clients/deep_mmd_clients/ditto_deep_mmd_client.py @@ -11,6 +11,7 @@ from fl4health.clients.ditto_client import DittoClient from fl4health.losses.deep_mmd_loss import DeepMmdLoss from fl4health.model_bases.feature_extractor_buffer import FeatureExtractorBuffer +from fl4health.utils.client import clone_and_freeze_model from fl4health.utils.losses import EvaluationLosses, LossMeterType, TrainingLosses from fl4health.utils.metrics import Metric from fl4health.utils.typing import TorchFeatureType, TorchInputType, TorchPredType, 
TorchTargetType @@ -89,7 +90,7 @@ def update_before_train(self, current_server_round: int) -> None: assert isinstance(self.global_model, nn.Module) # Clone and freeze the initial weights GLOBAL MODEL. These are used to form the Ditto local # update penalty term. - self.initial_global_model = self.clone_and_freeze_model(self.global_model) + self.initial_global_model = clone_and_freeze_model(self.global_model) self.initial_global_feature_extractor = FeatureExtractorBuffer( model=self.initial_global_model, flatten_feature_extraction_layers=self.flatten_feature_extraction_layers, @@ -149,7 +150,7 @@ def _maybe_checkpoint(self, loss: float, metrics: Dict[str, Scalar], checkpoint_ # each time. self.local_feature_extractor._maybe_register_hooks() - def validate(self) -> Tuple[float, Dict[str, Scalar]]: + def validate(self, include_losses_in_metrics: bool = False) -> Tuple[float, Dict[str, Scalar]]: """ Validate the current model on the entire validation dataset. @@ -158,7 +159,7 @@ def validate(self) -> Tuple[float, Dict[str, Scalar]]: """ for layer in self.flatten_feature_extraction_layers.keys(): self.deep_mmd_losses[layer].training = False - return super().validate() + return super().validate(include_losses_in_metrics) def compute_training_loss( self, diff --git a/fl4health/clients/ditto_client.py b/fl4health/clients/ditto_client.py index 27f443ced..12e894163 100644 --- a/fl4health/clients/ditto_client.py +++ b/fl4health/clients/ditto_client.py @@ -358,7 +358,7 @@ def compute_training_loss( return TrainingLosses(backward=loss + penalty_loss, additional_losses=additional_losses) - def validate(self) -> Tuple[float, Dict[str, Scalar]]: + def validate(self, include_losses_in_metrics: bool = False) -> Tuple[float, Dict[str, Scalar]]: """ Validate the current model on the entire validation dataset. 
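Client subclasses that override `validate` all follow the same pattern in this changeset: accept the new `include_losses_in_metrics` flag and forward it to `super().validate`. A minimal sketch of that pattern for a hypothetical custom client:

```python
from typing import Dict, Tuple

from flwr.common.typing import Scalar

from fl4health.clients.basic_client import BasicClient


class MyCustomClient(BasicClient):  # hypothetical subclass for illustration
    def validate(self, include_losses_in_metrics: bool = False) -> Tuple[float, Dict[str, Scalar]]:
        # Perform any client-specific preparation first, then forward the flag
        # unchanged so validation losses still reach the packed metrics.
        return super().validate(include_losses_in_metrics=include_losses_in_metrics)
```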
@@ -367,7 +367,7 @@ def validate(self) -> Tuple[float, Dict[str, Scalar]]: """ # Set the global model to evaluate mode self.global_model.eval() - return super().validate() + return super().validate(include_losses_in_metrics=include_losses_in_metrics) def compute_evaluation_loss( self, diff --git a/fl4health/clients/evaluate_client.py b/fl4health/clients/evaluate_client.py index c29c2b98b..39ac6f4c7 100644 --- a/fl4health/clients/evaluate_client.py +++ b/fl4health/clients/evaluate_client.py @@ -175,7 +175,7 @@ def validate_on_model( self._handle_logging(losses, metrics, is_global) return losses, metrics - def validate(self) -> Tuple[float, Dict[str, Scalar]]: + def validate(self, include_losses_in_metrics: bool = False) -> Tuple[float, Dict[str, Scalar]]: local_loss: Optional[EvaluationLosses] = None local_metrics: Optional[Dict[str, Scalar]] = None diff --git a/fl4health/clients/fedrep_client.py b/fl4health/clients/fedrep_client.py index 9181442b1..ed2aafa74 100644 --- a/fl4health/clients/fedrep_client.py +++ b/fl4health/clients/fedrep_client.py @@ -242,7 +242,7 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict # Check if we should run an evaluation with validation data after fit # (for example, this is used by FedDGGA) if self._should_evaluate_after_fit(evaluate_after_fit): - validation_loss, validation_metrics = self.evaluate_after_fit() + validation_loss, validation_metrics = self.validate() metrics.update(validation_metrics) # We perform a pre-aggregation checkpoint if applicable self._maybe_checkpoint(validation_loss, validation_metrics, CheckpointMode.PRE_AGGREGATION) diff --git a/fl4health/clients/fenda_ditto_client.py b/fl4health/clients/fenda_ditto_client.py index 9cad768fb..187e086a7 100644 --- a/fl4health/clients/fenda_ditto_client.py +++ b/fl4health/clients/fenda_ditto_client.py @@ -9,7 +9,7 @@ from fl4health.checkpointing.client_module import ClientCheckpointModule from fl4health.clients.ditto_client import DittoClient from fl4health.model_bases.fenda_base import FendaModel -from fl4health.model_bases.sequential_split_models import SequentiallySplitExchangeBaseModel +from fl4health.model_bases.sequential_split_models import SequentiallySplitModel from fl4health.parameter_exchange.packing_exchanger import FullParameterExchangerWithPacking from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType, TrainingLosses @@ -88,7 +88,7 @@ def __init__( reporters=reporters, progress_bar=progress_bar, ) - self.global_model: SequentiallySplitExchangeBaseModel + self.global_model: SequentiallySplitModel self.model: FendaModel self.freeze_global_feature_extractor = freeze_global_feature_extractor @@ -107,7 +107,7 @@ def get_model(self, config: Config) -> FendaModel: """ raise NotImplementedError("This function must be defined in the inheriting class to use this client") - def get_global_model(self, config: Config) -> SequentiallySplitExchangeBaseModel: + def get_global_model(self, config: Config) -> SequentiallySplitModel: """ User defined method that returns a Global Sequential Model that is compatible with the local FENDA model. @@ -115,7 +115,7 @@ def get_global_model(self, config: Config) -> SequentiallySplitExchangeBaseModel config (Config): The config from the server. Returns: - SequentiallySplitExchangeBaseModel: The global (Ditto) model. + SequentiallySplitModel: The global (Ditto) model. Raises: NotImplementedError: To be defined in child class.
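For reference, a concrete client satisfying this contract mirrors the updated `examples/fenda_ditto_example/client.py` above. A condensed sketch with stand-in modules (the example itself uses the MNIST extractor and head modules shown earlier; the architecture here is only illustrative):

```python
import torch.nn as nn
from flwr.common.typing import Config

from fl4health.clients.fenda_ditto_client import FendaDittoClient
from fl4health.model_bases.sequential_split_models import SequentiallySplitModel


class SketchFendaDittoClient(FendaDittoClient):  # illustrative subclass
    def get_global_model(self, config: Config) -> SequentiallySplitModel:
        # The global Ditto model is now a plain SequentiallySplitModel; the
        # "ExchangeBase" variant is no longer required. Stand-in modules keep
        # this sketch self-contained.
        base_module = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 64), nn.ReLU())
        head_module = nn.Linear(64, 10)
        return SequentiallySplitModel(base_module=base_module, head_module=head_module).to(self.device)
```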
diff --git a/fl4health/clients/flash_client.py b/fl4health/clients/flash_client.py index da8150015..d3f8e1aa4 100644 --- a/fl4health/clients/flash_client.py +++ b/fl4health/clients/flash_client.py @@ -8,6 +8,7 @@ from fl4health.checkpointing.client_module import ClientCheckpointModule from fl4health.clients.basic_client import BasicClient +from fl4health.utils.client import check_if_batch_is_empty_and_verify_input, move_data_to_device from fl4health.utils.config import narrow_dict_type from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import Metric @@ -47,14 +48,16 @@ def __init__( # gamma: Threshold for early stopping based on the change in validation loss. self.gamma: Optional[float] = None - def process_config(self, config: Config) -> Tuple[Union[int, None], Union[int, None], int, bool]: - local_epochs, local_steps, current_server_round, evaluate_after_fit = super().process_config(config) + def process_config(self, config: Config) -> Tuple[Union[int, None], Union[int, None], int, bool, bool]: + local_epochs, local_steps, current_server_round, evaluate_after_fit, pack_losses_with_val_metrics = ( + super().process_config(config) + ) if local_steps is not None: raise ValueError( "Training by steps is not applicable for FLASH clients.\ Please define 'local_epochs' in your config instead" ) - return local_epochs, local_steps, current_server_round, evaluate_after_fit + return local_epochs, local_steps, current_server_round, evaluate_after_fit, pack_losses_with_val_metrics def train_by_epochs( self, epochs: int, current_round: Optional[int] = None @@ -69,12 +72,12 @@ def train_by_epochs( self._log_header_str(current_round, local_epoch) report_data.update({"fit_epoch": local_epoch}) for input, target in self.train_loader: - if self.is_empty_batch(input): + if check_if_batch_is_empty_and_verify_input(input): log(INFO, "Empty batch generated by data loader. Skipping step.") continue - input = self._move_data_to_device(input) - target = self._move_data_to_device(target) + input = move_data_to_device(input, self.device) + target = move_data_to_device(target, self.device) losses, preds = self.train_step(input, target) self.train_loss_meter.update(losses) self.train_metric_manager.update(preds, target) diff --git a/fl4health/clients/mkmmd_clients/ditto_mkmmd_client.py b/fl4health/clients/mkmmd_clients/ditto_mkmmd_client.py index 6b64808f1..4f3482973 100644 --- a/fl4health/clients/mkmmd_clients/ditto_mkmmd_client.py +++ b/fl4health/clients/mkmmd_clients/ditto_mkmmd_client.py @@ -11,6 +11,7 @@ from fl4health.clients.ditto_client import DittoClient from fl4health.losses.mkmmd_loss import MkMmdLoss from fl4health.model_bases.feature_extractor_buffer import FeatureExtractorBuffer +from fl4health.utils.client import clone_and_freeze_model from fl4health.utils.losses import LossMeterType, TrainingLosses from fl4health.utils.metrics import Metric from fl4health.utils.typing import TorchFeatureType, TorchInputType, TorchPredType, TorchTargetType @@ -113,7 +114,7 @@ def update_before_train(self, current_server_round: int) -> None: assert isinstance(self.global_model, nn.Module) # Clone and freeze the initial weights GLOBAL MODEL. These are used to form the Ditto local # update penalty term. 
- self.initial_global_model = self.clone_and_freeze_model(self.global_model) + self.initial_global_model = clone_and_freeze_model(self.global_model) self.initial_global_feature_extractor = FeatureExtractorBuffer( model=self.initial_global_model, flatten_feature_extraction_layers=self.flatten_feature_extraction_layers, diff --git a/fl4health/clients/model_merge_client.py b/fl4health/clients/model_merge_client.py index d953c9a7e..f6778ff99 100644 --- a/fl4health/clients/model_merge_client.py +++ b/fl4health/clients/model_merge_client.py @@ -13,6 +13,7 @@ from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger from fl4health.reporting.base_reporter import BaseReporter from fl4health.reporting.reports_manager import ReportsManager +from fl4health.utils.client import move_data_to_device from fl4health.utils.metrics import Metric, MetricManager from fl4health.utils.random import generate_hash from fl4health.utils.typing import TorchInputType, TorchTargetType @@ -192,8 +193,8 @@ def validate(self) -> Dict[str, Scalar]: self.test_metric_manager.clear() with torch.no_grad(): for input, target in self.test_loader: - input = self._move_data_to_device(input) - target = self._move_data_to_device(target) + input = move_data_to_device(input, self.device) + target = move_data_to_device(target, self.device) preds = {"predictions": self.model(input)} self.test_metric_manager.update(preds, target) diff --git a/fl4health/clients/moon_client.py b/fl4health/clients/moon_client.py index c9501ccd0..38a4f7e04 100644 --- a/fl4health/clients/moon_client.py +++ b/fl4health/clients/moon_client.py @@ -9,6 +9,7 @@ from fl4health.clients.basic_client import BasicClient, Config from fl4health.losses.contrastive_loss import MoonContrastiveLoss from fl4health.model_bases.sequential_split_models import SequentiallySplitModel +from fl4health.utils.client import clone_and_freeze_model from fl4health.utils.losses import EvaluationLosses, LossMeterType, TrainingLosses from fl4health.utils.metrics import Metric from fl4health.utils.typing import TorchFeatureType, TorchInputType, TorchPredType, TorchTargetType @@ -112,7 +113,7 @@ def update_after_train(self, local_steps: int, loss_dict: Dict[str, float], conf """ assert isinstance(self.model, SequentiallySplitModel) # Save the parameters of the old LOCAL model - old_model = self.clone_and_freeze_model(self.model) + old_model = clone_and_freeze_model(self.model) # Current model is appended to the back of the list self.old_models_list.append(old_model) # If the list is longer than desired, the element at the front of the list is removed. @@ -131,7 +132,7 @@ def update_before_train(self, current_server_round: int) -> None: current_server_round (int): Current federated training round being executed. 
""" # Save the parameters of the global model - self.global_model = self.clone_and_freeze_model(self.model) + self.global_model = clone_and_freeze_model(self.model) super().update_before_train(current_server_round) diff --git a/fl4health/clients/mr_mtl_client.py b/fl4health/clients/mr_mtl_client.py index 80f4e231c..6ffa6fe28 100644 --- a/fl4health/clients/mr_mtl_client.py +++ b/fl4health/clients/mr_mtl_client.py @@ -149,7 +149,7 @@ def compute_training_loss( # Use the rest of the training loss computation from the AdaptiveDriftConstraintClient parent return super().compute_training_loss(preds, features, target) - def validate(self) -> Tuple[float, Dict[str, Scalar]]: + def validate(self, include_losses_in_metrics: bool = False) -> Tuple[float, Dict[str, Scalar]]: """ Validate the current model on the entire validation dataset. @@ -158,4 +158,4 @@ def validate(self) -> Tuple[float, Dict[str, Scalar]]: """ # ensure that the initial global model is in eval mode assert not self.initial_global_model.training - return super().validate() + return super().validate(include_losses_in_metrics=include_losses_in_metrics) diff --git a/fl4health/clients/nnunet_client.py b/fl4health/clients/nnunet_client.py index 853b2805e..f393e8174 100644 --- a/fl4health/clients/nnunet_client.py +++ b/fl4health/clients/nnunet_client.py @@ -21,9 +21,10 @@ from torch.utils.data import DataLoader from fl4health.checkpointing.client_module import ClientCheckpointModule -from fl4health.clients.basic_client import BasicClient, LoggingMode +from fl4health.clients.basic_client import BasicClient from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.config import narrow_dict_type +from fl4health.utils.logging import LoggingMode from fl4health.utils.losses import LossMeterType, TrainingLosses from fl4health.utils.metrics import Metric, MetricManager from fl4health.utils.nnunet_utils import ( @@ -275,7 +276,7 @@ def get_lr_scheduler(self, optimizer_key: str, config: Config) -> _LRScheduler: ) # Determine total number of steps throughout all FL rounds - local_epochs, local_steps, _, _ = self.process_config(config) + local_epochs, local_steps, _, _, _ = self.process_config(config) if local_steps is not None: steps_per_round = local_steps elif local_epochs is not None: diff --git a/fl4health/clients/perfcl_client.py b/fl4health/clients/perfcl_client.py index 4d2fe7017..fa08abe47 100644 --- a/fl4health/clients/perfcl_client.py +++ b/fl4health/clients/perfcl_client.py @@ -10,6 +10,7 @@ from fl4health.model_bases.perfcl_base import PerFclModel from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger +from fl4health.utils.client import clone_and_freeze_model from fl4health.utils.losses import EvaluationLosses, LossMeterType from fl4health.utils.metrics import Metric from fl4health.utils.typing import TorchFeatureType, TorchInputType, TorchPredType, TorchTargetType @@ -160,9 +161,9 @@ def update_after_train(self, local_steps: int, loss_dict: Dict[str, float], conf """ assert isinstance(self.model, PerFclModel) # First module is the local feature extractor for PerFcl Models - self.old_local_module = self.clone_and_freeze_model(self.model.first_feature_extractor) + self.old_local_module = clone_and_freeze_model(self.model.first_feature_extractor) # Second module is the global feature extractor for PerFcl Models - self.old_global_module = self.clone_and_freeze_model(self.model.second_feature_extractor) + 
self.old_global_module = clone_and_freeze_model(self.model.second_feature_extractor) super().update_after_train(local_steps, loss_dict, config) @@ -178,7 +179,7 @@ def update_before_train(self, current_server_round: int) -> None: """ # Save the parameters of the aggregated global model assert isinstance(self.model, PerFclModel) - self.initial_global_module = self.clone_and_freeze_model(self.model.second_feature_extractor) + self.initial_global_module = clone_and_freeze_model(self.model.second_feature_extractor) super().update_before_train(current_server_round) diff --git a/fl4health/server/base_server.py b/fl4health/server/base_server.py index afcfac1b6..a9574e780 100644 --- a/fl4health/server/base_server.py +++ b/fl4health/server/base_server.py @@ -21,7 +21,7 @@ from fl4health.server.polling import poll_clients from fl4health.strategies.strategy_with_poll import StrategyWithPolling from fl4health.utils.config import narrow_dict_type_and_set_attribute -from fl4health.utils.metrics import TEST_LOSS_KEY, TEST_NUM_EXAMPLES_KEY, TestMetricPrefix +from fl4health.utils.metrics import TEST_LOSS_KEY, TEST_NUM_EXAMPLES_KEY, MetricPrefix from fl4health.utils.parameter_extraction import get_all_model_parameters from fl4health.utils.random import generate_hash @@ -220,11 +220,9 @@ def _unpack_metrics( for client_proxy, eval_res in results: val_metrics = { - k: v for k, v in eval_res.metrics.items() if not k.startswith(TestMetricPrefix.TEST_PREFIX.value) - } - test_metrics = { - k: v for k, v in eval_res.metrics.items() if k.startswith(TestMetricPrefix.TEST_PREFIX.value) + k: v for k, v in eval_res.metrics.items() if not k.startswith(MetricPrefix.TEST_PREFIX.value) } + test_metrics = {k: v for k, v in eval_res.metrics.items() if k.startswith(MetricPrefix.TEST_PREFIX.value)} if len(test_metrics) > 0: assert TEST_LOSS_KEY in test_metrics and TEST_NUM_EXAMPLES_KEY in test_metrics, ( @@ -268,9 +266,7 @@ def _handle_result_aggregation( for key, value in test_metrics_aggregated.items(): val_metrics_aggregated[key] = value if test_loss_aggregated is not None: - val_metrics_aggregated[f"{TestMetricPrefix.TEST_PREFIX.value} loss - aggregated"] = ( - test_loss_aggregated - ) + val_metrics_aggregated[f"{MetricPrefix.TEST_PREFIX.value} loss - aggregated"] = test_loss_aggregated return val_loss_aggregated, val_metrics_aggregated diff --git a/fl4health/strategies/fedavg_with_adaptive_constraint.py b/fl4health/strategies/fedavg_with_adaptive_constraint.py index 9b1e59030..735fcdd62 100644 --- a/fl4health/strategies/fedavg_with_adaptive_constraint.py +++ b/fl4health/strategies/fedavg_with_adaptive_constraint.py @@ -55,9 +55,6 @@ def __init__( Implementation based on https://arxiv.org/abs/1602.05629. Args: - initial_parameters (Parameters): Initial global model parameters. - init_loss_weight (float): Initial loss weight (mu in FedProx). If adaptivity is false, then this is the - constant weight used for all clients. fraction_fit (float, optional): Fraction of clients used during training. Defaults to 1.0. fraction_evaluate (float, optional): Fraction of clients used during validation. Defaults to 1.0. min_fit_clients (int, optional): Minimum number of clients used during training. Defaults to 2. @@ -74,10 +71,13 @@ def __init__( Function used to configure server-side central validation by providing a Config dictionary. Defaults to None. accept_failures (bool, optional): Whether or not accept rounds containing failures. Defaults to True. + initial_parameters (Parameters): Initial global model parameters.
fit_metrics_aggregation_fn (Optional[MetricsAggregationFn], optional): Metrics aggregation function. Defaults to None. evaluate_metrics_aggregation_fn (Optional[MetricsAggregationFn], optional): Metrics aggregation function. Defaults to None. + initial_loss_weight (float): Initial loss weight (mu in FedProx). If adaptivity is false, then this is the + constant weight used for all clients. adapt_loss_weight (bool, optional): Determines whether the value of mu is adaptively modified by the server based on aggregated train loss. Defaults to False. loss_weight_delta (float, optional): This is the amount by which the server changes the value of mu diff --git a/fl4health/strategies/feddg_ga_strategy.py b/fl4health/strategies/feddg_ga.py similarity index 77% rename from fl4health/strategies/feddg_ga_strategy.py rename to fl4health/strategies/feddg_ga.py index 3d47b34be..97e90f16a 100644 --- a/fl4health/strategies/feddg_ga_strategy.py +++ b/fl4health/strategies/feddg_ga.py @@ -3,7 +3,14 @@ from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np -from flwr.common import MetricsAggregationFn, NDArrays, Parameters, ndarrays_to_parameters, parameters_to_ndarrays +from flwr.common import ( + EvaluateIns, + MetricsAggregationFn, + NDArrays, + Parameters, + ndarrays_to_parameters, + parameters_to_ndarrays, +) from flwr.common.logger import log from flwr.common.typing import EvaluateRes, FitIns, FitRes, Scalar from flwr.server.client_manager import ClientManager @@ -23,7 +30,7 @@ class FairnessMetricType(Enum): """Defines the basic types for fairness metrics, their default names and their default signals""" ACCURACY = "val - prediction - accuracy" - LOSS = "val - loss" + LOSS = "val - checkpoint" CUSTOM = "custom" @classmethod @@ -84,12 +91,10 @@ def __init__( self.signal = FairnessMetricType.signal_for_type(metric_type) -class FedDgGaStrategy(FedAvg): +class FedDgGa(FedAvg): def __init__( self, *, - fraction_fit: float = 1.0, - fraction_evaluate: float = 1.0, min_fit_clients: int = 2, min_evaluate_clients: int = 2, min_available_clients: int = 2, @@ -106,20 +111,16 @@ def __init__( fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, fairness_metric: Optional[FairnessMetric] = None, - weight_step_size: float = 0.2, + adjustment_weight_step_size: float = 0.2, ): - """Strategy for the FedDG-GA algorithm (Federated Domain Generalization with - Generalization Adjustment, Zhang et al. 2023). + """ + Strategy for the FedDG-GA algorithm (Federated Domain Generalization with Generalization Adjustment, Zhang et + al. 2023). This strategy assumes (and checks) that the configuration sent by the server to the clients has the + key "evaluate_after_fit" and it is set to True. It also ensures that the key "pack_losses_with_val_metrics" is + present and its value is set to True. These are to facilitate the exchange of evaluation information needed + for the strategy to work correctly. Args: - fraction_fit : float, optional - Fraction of clients used during training. In case `min_fit_clients` - is larger than `fraction_fit * available_clients`, `min_fit_clients` - will still be sampled. Defaults to 1.0. - fraction_evaluate : float, optional - Fraction of clients used during validation. In case `min_evaluate_clients` - is larger than `fraction_evaluate * available_clients`, `min_evaluate_clients` - will still be sampled. Defaults to 1.0. 
min_fit_clients : int, optional Minimum number of clients used during training. Defaults to 2. min_evaluate_clients : int, optional @@ -133,9 +134,9 @@ def __init__( ] Optional function used for validation. Defaults to None. on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional - Function used to configure training. Defaults to None. + Function used to configure training. Must be specified for this strategy. Defaults to None. on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional - Function used to configure validation. Defaults to None. + Function used to configure validation. Must be specified for this strategy. Defaults to None. accept_failures : bool, optional Whether or not accept rounds containing failures. Defaults to True. initial_parameters : Parameters, optional @@ -149,19 +150,18 @@ def __init__( determine their adjustment weight for aggregation. Can be set to any default metric in FairnessMetricType or set to use a custom metric. Optional, default is FairnessMetric(FairnessMetricType.LOSS). - weight_step_size : float - The step size to determine the magnitude of change for the adjustment weight. It has to be - 0 < weight_step_size < 1. Optional, default is 0.2. + adjustment_weight_step_size : float + The step size to determine the magnitude of change for the generalization adjustment weights. It has + to be 0 < adjustment_weight_step_size < 1. Optional, default is 0.2. """ - if fraction_fit != 1.0 or fraction_evaluate != 1.0: - log( - WARNING, - "fraction_fit or fraction_evaluate are not 1.0. The behaviour of FedDG-GA is unknown in those cases.", - ) + + # NOTE: For FedDG-GA, we require that fraction_fit and fraction_evaluate are 1.0, as behavior of the FedDG-GA + # algorithm is not well-defined when participation in each round of training and evaluation is partial. Thus, + # we force these values to be 1.0 in super and do not allow them to be set by the user. 
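A server-side config function satisfying the checks performed in `configure_fit` below would look like the following sketch (the key names are exactly those asserted by the strategy; the round count is an arbitrary example value):

```python
from flwr.common.typing import Config


def fit_config(current_server_round: int) -> Config:
    # Minimal config passing FedDgGa's assertions: evaluate_after_fit and
    # pack_losses_with_val_metrics must both be True, and n_server_rounds
    # must be an int that stays constant across rounds.
    return {
        "current_server_round": current_server_round,
        "n_server_rounds": 3,
        "evaluate_after_fit": True,
        "pack_losses_with_val_metrics": True,
    }
```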
super().__init__( - fraction_fit=fraction_fit, - fraction_evaluate=fraction_evaluate, + fraction_fit=1.0, + fraction_evaluate=1.0, min_fit_clients=min_fit_clients, min_evaluate_clients=min_evaluate_clients, min_available_clients=min_available_clients, @@ -179,8 +179,10 @@ def __init__( else: self.fairness_metric = fairness_metric - self.weight_step_size = weight_step_size - assert 0 < self.weight_step_size < 1, f"weight_step_size has to be between 0 and 1 ({self.weight_step_size})" + self.adjustment_weight_step_size = adjustment_weight_step_size + assert ( + 0 < self.adjustment_weight_step_size < 1 + ), f"adjustment_weight_step_size has to be between 0 and 1 ({self.adjustment_weight_step_size})" self.train_metrics: Dict[str, Dict[str, Scalar]] = {} self.evaluation_metrics: Dict[str, Dict[str, Scalar]] = {} @@ -220,18 +222,44 @@ def configure_fit( self.initial_adjustment_weight = 1.0 / len(client_fit_ins) - # Setting self.num_rounds + # Setting self.num_rounds once and doing some sanity checks + assert self.on_fit_config_fn is not None, "on_fit_config_fn must be specified" + config = self.on_fit_config_fn(server_round) + assert "evaluate_after_fit" in config, "evaluate_after_fit must be present in config" + assert config["evaluate_after_fit"] is True, "evaluate_after_fit must be set to True" + + assert "pack_losses_with_val_metrics" in config, "pack_losses_with_val_metrics must be present in config" + assert config["pack_losses_with_val_metrics"] is True, "pack_losses_with_val_metrics must be set to True" + + assert "n_server_rounds" in config, "n_server_rounds must be specified" + assert isinstance(config["n_server_rounds"], int), "n_server_rounds is not an integer" + n_server_rounds = config["n_server_rounds"] + if self.num_rounds is None: - assert self.on_fit_config_fn is not None, "on_fit_config_fn must be specified" - config = self.on_fit_config_fn(server_round) - assert "evaluate_after_fit" in config, "evaluate_after_fit must be present in config and set to True" - assert config["evaluate_after_fit"] is True, "evaluate_after_fit must be set to True" - assert "n_server_rounds" in config, "n_server_rounds must be specified" - assert isinstance(config["n_server_rounds"], int), "n_server_rounds is not an integer" - self.num_rounds = config["n_server_rounds"] + self.num_rounds = n_server_rounds + else: + assert ( + n_server_rounds == self.num_rounds + ), f"n_server_rounds has changed from the original value of {self.num_rounds} and is now {n_server_rounds}" return client_fit_ins + def configure_evaluate( + self, server_round: int, parameters: Parameters, client_manager: ClientManager + ) -> List[Tuple[ClientProxy, EvaluateIns]]: + assert isinstance( + client_manager, FixedSamplingClientManager + ), f"Client manager is not of type FixedSamplingClientManager: {type(client_manager)}" + + client_evaluate_ins = super().configure_evaluate(server_round, parameters, client_manager) + + assert self.on_evaluate_config_fn is not None, "on_evaluate_config_fn must be specified" + config = self.on_evaluate_config_fn(server_round) + assert "pack_losses_with_val_metrics" in config, "pack_losses_with_val_metrics must be present in config" + assert config["pack_losses_with_val_metrics"] is True, "pack_losses_with_val_metrics must be set to True" + + return client_evaluate_ins + def aggregate_fit( self, server_round: int, results: List[Tuple[ClientProxy, FitRes]], failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: """ Aggregate fit results by weighing them against the adjustment weights and then summing them. Collects the fit metrics that will be used to change the adjustment weights for the next round. Args: server_round: (int) the current server round. results: (List[Tuple[ClientProxy, FitRes]]) The clients' fit results. failures: (List[Union[Tuple[ClientProxy, FitRes], BaseException]]) the clients' fit failures. Returns: (Tuple[Optional[Parameters], Dict[str, Scalar]]) A tuple containing the aggregated parameters and the aggregated fit metrics.
""" - # The original aggregated parameters is done by the super class (which we want to - # override its behaviour here), so we are discarding it to recalculate them in the lines below - _, metrics_aggregated = super().aggregate_fit(server_round, results, failures) + if not results: + return None, {} + # Do not aggregate if there are failures and failures are not accepted + if not self.accept_failures and failures: + return None, {} + + # Aggregate custom metrics if aggregation fn was provided + metrics_aggregated = {} + if self.fit_metrics_aggregation_fn: + fit_metrics = [(res.num_examples, res.metrics) for _, res in results] + metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics) + elif server_round == 1: # Only log this warning once + log(WARNING, "No fit_metrics_aggregation_fn provided") self.train_metrics = {} for client_proxy, fit_res in results: @@ -290,10 +328,9 @@ def aggregate_evaluate( self.evaluation_metrics = {} for client_proxy, eval_res in results: cid = client_proxy.cid + # make sure that the metrics has the desired loss key + assert FairnessMetricType.LOSS.value in eval_res.metrics self.evaluation_metrics[cid] = eval_res.metrics - # adding the loss to the metrics - val_loss_key = FairnessMetricType.LOSS.value - self.evaluation_metrics[cid][val_loss_key] = eval_res.loss # Updating the weights at the end of the training round cids = [client_proxy.cid for client_proxy, _ in results] @@ -412,8 +449,8 @@ def get_current_weight_step_size(self, server_round: int) -> float: # The implementation of d^r here differs from the definition in the paper # because our server round starts at 1 instead of 0. assert self.num_rounds is not None - weight_step_size_decay = self.weight_step_size / self.num_rounds - weight_step_size_for_round = self.weight_step_size - ((server_round - 1) * weight_step_size_decay) + weight_step_size_decay = self.adjustment_weight_step_size / self.num_rounds + weight_step_size_for_round = self.adjustment_weight_step_size - ((server_round - 1) * weight_step_size_decay) # Omitting an additional scaler here that is present in the reference # implementation but not in the paper: diff --git a/fl4health/strategies/feddg_ga_with_adaptive_constraint.py b/fl4health/strategies/feddg_ga_with_adaptive_constraint.py new file mode 100644 index 000000000..8ab069492 --- /dev/null +++ b/fl4health/strategies/feddg_ga_with_adaptive_constraint.py @@ -0,0 +1,237 @@ +from logging import INFO, WARNING +from typing import Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +from flwr.common import MetricsAggregationFn, NDArrays, Parameters, ndarrays_to_parameters, parameters_to_ndarrays +from flwr.common.logger import log +from flwr.common.typing import FitRes, Scalar +from flwr.server.client_proxy import ClientProxy + +from fl4health.parameter_exchange.parameter_packer import ParameterPackerAdaptiveConstraint +from fl4health.strategies.aggregate_utils import aggregate_losses +from fl4health.strategies.feddg_ga import FairnessMetric, FedDgGa + + +class FedDgGaAdaptiveConstraint(FedDgGa): + def __init__( + self, + *, + min_fit_clients: int = 2, + min_evaluate_clients: int = 2, + min_available_clients: int = 2, + evaluate_fn: Optional[ + Callable[ + [int, NDArrays, Dict[str, Scalar]], + Optional[Tuple[float, Dict[str, Scalar]]], + ] + ] = None, + on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + accept_failures: bool = True, + initial_parameters: Parameters, + 
fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, + evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, + initial_loss_weight: float = 1.0, + adapt_loss_weight: bool = False, + loss_weight_delta: float = 0.1, + loss_weight_patience: int = 5, + weighted_train_losses: bool = False, + fairness_metric: Optional[FairnessMetric] = None, + adjustment_weight_step_size: float = 0.2, + ): + """ + Strategy for the FedDG-GA algorithm (Federated Domain Generalization with Generalization Adjustment, + Zhang et al. 2023) combined with the adaptive strategy for auxiliary constraints like FedProx. See + documentation on FedAvgWithAdaptiveConstraint for more information. + + NOTE: Initial parameters are NOT optional. They must be passed for this strategy. + + Args: + min_fit_clients (int, optional): Minimum number of clients used during training. Defaults to 2. + min_evaluate_clients (int, optional): Minimum number of clients used during validation. Defaults to 2. + min_available_clients (int, optional): Minimum number of total clients in the system. Defaults to 2. + evaluate_fn : + Optional[ + Callable[[int, NDArrays, Dict[str, Scalar]], + Optional[Tuple[float, Dict[str, Scalar]]]] + ] + Optional function used for validation. Defaults to None. + on_fit_config_fn (Optional[Callable[[int], Dict[str, Scalar]]], optional): Function used to configure + training. Defaults to None. + on_evaluate_config_fn (Optional[Callable[[int], Dict[str, Scalar]]], optional): Function used to configure + validation. Defaults to None. + initial_parameters (Parameters): Initial global model parameters. + accept_failures (bool, optional): Whether or not accept rounds containing failures. Defaults to True. + fit_metrics_aggregation_fn (Optional[MetricsAggregationFn], optional): + Metrics aggregation function. Defaults to None. + evaluate_metrics_aggregation_fn (Optional[MetricsAggregationFn], optional): + Metrics aggregation function. Defaults to None. + initial_loss_weight (float, optional): Initial penalty loss weight (mu in FedProx). If adaptivity is false, + then this is the constant weight used for all clients. Defaults to 1.0. + adapt_loss_weight (bool, optional): Determines whether the value of the penalty loss weight is adaptively + modified by the server based on aggregated train loss. Defaults to False. + loss_weight_delta (float, optional): This is the amount by which the server changes the value of the + penalty loss weight based on the modification criteria. Only applicable if adaptivity is on. + Defaults to 0.1. + loss_weight_patience (int, optional): This is the number of rounds a server must see decreasing + aggregated train loss before reducing the value of the penalty loss weight. Only applicable if + adaptivity is on. Defaults to 5. + weighted_train_losses (bool, optional): Determines whether the training losses from the clients should be + aggregated using a weighted or unweighted average. These aggregated losses are used to adjust the + proximal weight in the adaptive setting. Defaults to False. + fairness_metric (Optional[FairnessMetric], optional): The metric to evaluate the local model of each + client against the global model in order to determine their adjustment weight for aggregation. + Can be set to any default metric in FairnessMetricType or set to use a custom metric. + Optional, default is FairnessMetric(FairnessMetricType.LOSS) when specified as None.
+ adjustment_weight_step_size (float, optional): The step size to determine the magnitude of change for + the generalization adjustment weight. It has to be 0 < adjustment_weight_step_size < 1. + Optional, default is 0.2. + """ + + self.loss_weight = initial_loss_weight + self.adapt_loss_weight = adapt_loss_weight + + if self.adapt_loss_weight: + self.loss_weight_delta = loss_weight_delta + self.loss_weight_patience = loss_weight_patience + self.loss_weight_patience_counter: int = 0 + + self.previous_loss = float("inf") + + self.server_model_weights = parameters_to_ndarrays(initial_parameters) + initial_parameters.tensors.extend(ndarrays_to_parameters([np.array(initial_loss_weight)]).tensors) + + super().__init__( + min_fit_clients=min_fit_clients, + min_evaluate_clients=min_evaluate_clients, + min_available_clients=min_available_clients, + evaluate_fn=evaluate_fn, + on_fit_config_fn=on_fit_config_fn, + on_evaluate_config_fn=on_evaluate_config_fn, + accept_failures=accept_failures, + initial_parameters=initial_parameters, + fit_metrics_aggregation_fn=fit_metrics_aggregation_fn, + evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn, + fairness_metric=fairness_metric, + adjustment_weight_step_size=adjustment_weight_step_size, + ) + + self.parameter_packer = ParameterPackerAdaptiveConstraint() + self.weighted_train_losses = weighted_train_losses + + def aggregate_fit( + self, + server_round: int, + results: List[Tuple[ClientProxy, FitRes]], + failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], + ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + """ + Aggregate fit results by weighting them with the adjustment weights and then summing them. + + Collect the fit metrics that will be used to change the adjustment weights for the next round. + + If applicable, determine whether the constraint weight should be updated based on the aggregated loss + seen on the clients. + + Args: + server_round: (int) The current server round. + results: (List[Tuple[ClientProxy, FitRes]]) The clients' fit results. + failures: (List[Union[Tuple[ClientProxy, FitRes], BaseException]]) The clients' fit failures. + + Returns: + (Tuple[Optional[Parameters], Dict[str, Scalar]]) A tuple containing the aggregated parameters + and the aggregated fit metrics. For adaptive constraints, the server also packs a constraint weight + to be sent to the clients. This is sent even if adaptive constraint weights are turned off and + the value simply remains constant. + """ + if not results: + return None, {} + # Do not aggregate if there are failures and failures are not accepted + if not self.accept_failures and failures: + return None, {} + + # Convert results with packed params of model weights and training loss.
The results list is modified in-place + # to only contain model parameters for use in the Fed-DGGA calculations and aggregation + train_losses_and_counts = self._unpack_weights_and_losses(results) + + # Aggregate train loss + train_losses_aggregated = aggregate_losses(train_losses_and_counts, self.weighted_train_losses) + self._maybe_update_constraint_weight_param(train_losses_aggregated) + + # Aggregate custom metrics if aggregation fn was provided + metrics_aggregated = {} + if self.fit_metrics_aggregation_fn: + fit_metrics = [(res.num_examples, res.metrics) for _, res in results] + metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics) + elif server_round == 1: # Only log this warning once + log(WARNING, "No fit_metrics_aggregation_fn provided") + + self.train_metrics = {} + for client_proxy, fit_res in results: + self.train_metrics[client_proxy.cid] = fit_res.metrics + + weights_aggregated = self.weight_and_aggregate_results(results) + + parameters = self.parameter_packer.pack_parameters(weights_aggregated, self.loss_weight) + return ndarrays_to_parameters(parameters), metrics_aggregated + + def _unpack_weights_and_losses(self, results: List[Tuple[ClientProxy, FitRes]]) -> List[Tuple[int, float]]: + """ + This function takes results returned from a fit round from each of the participating clients and unpacks the + information into the appropriate objects. The parameters contained in the FitRes object are unpacked to + separate the model weights from the training losses. The model weights are reinserted into the parameters + of the FitRes objects and the losses (along with sample counts) are placed in a list and returned. + + NOTE: The results that are passed to this function are MODIFIED IN-PLACE. + + Args: + results (List[Tuple[ClientProxy, FitRes]]): The results produced in a fitting round by each of the clients. + Each FitRes object contains both model weights and training losses, which need to be processed. + + Returns: + List[Tuple[int, float]]: A list of the sample counts and training losses produced by client training. + """ + train_losses_and_counts: List[Tuple[int, float]] = [] + for _, fit_res in results: + sample_count = fit_res.num_examples + updated_weights, train_loss = self.parameter_packer.unpack_parameters( + parameters_to_ndarrays(fit_res.parameters) + ) + # Modify the parameters in-place to just be the model weights. + fit_res.parameters = ndarrays_to_parameters(updated_weights) + train_losses_and_counts.append((sample_count, train_loss)) + + return train_losses_and_counts + + def _maybe_update_constraint_weight_param(self, loss: float) -> None: + """ + Update the constraint weight parameter if adapt_loss_weight is set to True. Regardless of whether adaptivity + is turned on at this time, the previous loss seen by the server is updated. + + Args: + loss (float): This is the loss to which we compare the previous loss seen by the server. For Adaptive + Constraint clients this should be the aggregated training loss seen by each client participating in + training. + NOTE: For adaptive constraint losses, including FedProx, this loss is exchanged (along with the weights) + by each client and is the VANILLA loss that does not include the additional penalty losses.
+ """ + + if self.adapt_loss_weight: + if loss <= self.previous_loss: + self.loss_weight_patience_counter += 1 + if self.loss_weight_patience_counter == self.loss_weight_patience: + self.loss_weight -= self.loss_weight_delta + self.loss_weight = max(0.0, self.loss_weight) + self.loss_weight_patience_counter = 0 + log(INFO, f"Aggregate training loss has dropped {self.loss_weight_patience} rounds in a row") + log(INFO, f"Constraint weight is decreased to {self.loss_weight}") + else: + self.loss_weight += self.loss_weight_delta + self.loss_weight_patience_counter = 0 + log( + INFO, + f"Aggregate training loss increased this round: Current loss {loss}, " + f"Previous loss: {self.previous_loss}", + ) + log(INFO, f"Constraint weight is increased by {self.loss_weight_delta} to {self.loss_weight}") + self.previous_loss = loss diff --git a/fl4health/utils/client.py b/fl4health/utils/client.py new file mode 100644 index 000000000..3272415ff --- /dev/null +++ b/fl4health/utils/client.py @@ -0,0 +1,157 @@ +import copy +import os +from inspect import currentframe, getframeinfo +from logging import INFO, LogRecord +from typing import Any, Dict, Iterable, TypeVar + +import torch +import torch.nn as nn +from flwr.common.logger import LOGGER_NAME, console_handler, log +from flwr.common.typing import Config, Scalar +from tqdm import tqdm + +from fl4health.utils.config import narrow_dict_type +from fl4health.utils.logging import LoggingMode +from fl4health.utils.metrics import MetricPrefix +from fl4health.utils.typing import TorchInputType, TorchTargetType + +T = TypeVar("T", TorchInputType, TorchTargetType) + + +def fold_loss_dict_into_metrics( + metrics: Dict[str, Scalar], loss_dict: Dict[str, float], logging_mode: LoggingMode +) -> None: + # Prefixing the loss value keys with the mode from which they are generated + if logging_mode is LoggingMode.VALIDATION: + metrics.update({f"{MetricPrefix.VAL_PREFIX.value} {key}": loss_val for key, loss_val in loss_dict.items()}) + else: + metrics.update({f"{MetricPrefix.TEST_PREFIX.value} {key}": loss_val for key, loss_val in loss_dict.items()}) + + +def set_pack_losses_with_val_metrics(config: Config) -> bool: + try: + pack_losses_with_val_metrics = narrow_dict_type(config, "pack_losses_with_val_metrics", bool) + except ValueError: + pack_losses_with_val_metrics = False + if pack_losses_with_val_metrics: + log(INFO, "As specified in the config, all validation losses will be packed into validation metrics") + return pack_losses_with_val_metrics + + +def move_data_to_device(data: T, device: torch.device) -> T: + """ + _summary_ + + Args: + data (T): The data to move to self.device. Can be a TorchInputType or a TorchTargetType + device (torch.device): Device indicator for where to send the model, batches, labels etc. Often 'cpu' or + 'cuda' + + Raises: + TypeError: Raised if data is not one of the types specified by TorchInputType or TorchTargetType + + Returns: + T: The data argument except now it's been moved to self.device + """ + # Currently we expect both inputs and targets to be either tensors + # or dictionaries of tensors + if isinstance(data, torch.Tensor): + return data.to(device) + elif isinstance(data, dict): + return {key: value.to(device) for key, value in data.items()} + else: + raise TypeError( + "data must be of type torch.Tensor or Dict[str, torch.Tensor]. If definition of TorchInputType or " + "TorchTargetType has changed this method might need to be updated or split into two." 
+ ) + + +def check_if_batch_is_empty_and_verify_input(input: TorchInputType) -> bool: + """ + This function checks whether the provided batch (input) is empty. If the input is a dictionary of inputs, it + first verifies that the length of all inputs is the same, then checks if they are non-empty. + NOTE: This function assumes the input is BATCH FIRST + + Args: + input (TorchInputType): Input batch. input can be of type torch.Tensor or Dict[str, torch.Tensor], and in the + latter case, the batch is considered to be empty if all tensors in the dictionary have length zero. + + Raises: + TypeError: Raised if input is not of type torch.Tensor or Dict[str, torch.Tensor]. + ValueError: Raised if input has type Dict[str, torch.Tensor] and not all tensors within the dictionary have + the same size. + + Returns: + bool: True if input is an empty batch. + """ + if isinstance(input, torch.Tensor): + return len(input) == 0 + elif isinstance(input, dict): + input_iter = iter(input.items()) + _, first_val = next(input_iter) + first_val_len = len(first_val) + if not all(len(val) == first_val_len for _, val in input_iter): + raise ValueError("Not all tensors in the dictionary have the same size.") + else: + return first_val_len == 0 + else: + raise TypeError("Input must be of type torch.Tensor or Dict[str, torch.Tensor].") + + +def clone_and_freeze_model(model: nn.Module) -> nn.Module: + """ + Creates a clone of the model with frozen weights to be used in loss calculations so the original model is + preserved in its current state. + + Args: + model (nn.Module): Model to clone and freeze + Returns: + nn.Module: Cloned and frozen model + """ + + cloned_model = copy.deepcopy(model) + for param in cloned_model.parameters(): + param.requires_grad = False + cloned_model.eval() + + return cloned_model + + +def maybe_progress_bar(iterable: Iterable, display_progress_bar: bool) -> Iterable: + """ + Used to print progress bars during client training and validation. If + display_progress_bar is False, just returns the original input iterable + without modifying it. + Args: + iterable (Iterable): The iterable to wrap + display_progress_bar (bool): Whether or not to display a progress bar while iterating + Returns: + Iterable: an iterator which acts exactly like the original + iterable, but prints a dynamically updating progress bar every + time a value is requested.
Or the original iterable if + display_progress_bar is False + """ + if not display_progress_bar: + return iterable + else: + # We can use the flwr console handler to format progress bar + frame = currentframe() + lineno = 0 if frame is None else getframeinfo(frame).lineno + record = LogRecord( + name=LOGGER_NAME, + pathname=os.path.abspath(os.getcwd()), + lineno=lineno, + args={}, + exc_info=None, + level=INFO, + msg="{l_bar}{bar}{r_bar}", + ) + format = console_handler.format(record) + # Create a clean looking tqdm instance that matches the flwr logging + kwargs: Any = { + "leave": True, + "ascii": " >=", + "unit": "steps", + "dynamic_ncols": True, + "bar_format": format, + } + return tqdm(iterable, **kwargs) diff --git a/fl4health/utils/logging.py b/fl4health/utils/logging.py new file mode 100644 index 000000000..e2eab7d81 --- /dev/null +++ b/fl4health/utils/logging.py @@ -0,0 +1,7 @@ +from enum import Enum + + +class LoggingMode(Enum): + TRAIN = "Training" + VALIDATION = "Validation" + TEST = "Testing" diff --git a/fl4health/utils/metrics.py b/fl4health/utils/metrics.py index 7a33773ce..74af911b6 100644 --- a/fl4health/utils/metrics.py +++ b/fl4health/utils/metrics.py @@ -12,13 +12,13 @@ from fl4health.utils.typing import TorchPredType, TorchTargetType, TorchTransformFunction -class TestMetricPrefix(Enum): - __test__ = False +class MetricPrefix(Enum): TEST_PREFIX = "test -" + VAL_PREFIX = "val -" -TEST_NUM_EXAMPLES_KEY = f"{TestMetricPrefix.TEST_PREFIX.value} num_examples" -TEST_LOSS_KEY = f"{TestMetricPrefix.TEST_PREFIX.value} loss" +TEST_NUM_EXAMPLES_KEY = f"{MetricPrefix.TEST_PREFIX.value} num_examples" +TEST_LOSS_KEY = f"{MetricPrefix.TEST_PREFIX.value} checkpoint" @@ -336,7 +336,7 @@ def __init__( average: Optional[str] = "weighted", ): """ - Computes the F1 score using the sklearn f1_score function. As such, the values of average are correspond to + Computes the F1 score using the sklearn f1_score function. As such, the values of average correspond to those of that function.
Args: diff --git a/tests/clients/test_basic_client.py b/tests/clients/test_basic_client.py index 178bfe248..a19253885 100644 --- a/tests/clients/test_basic_client.py +++ b/tests/clients/test_basic_client.py @@ -9,9 +9,11 @@ from flwr.common import Scalar from freezegun import freeze_time -from fl4health.clients.basic_client import BasicClient, LoggingMode +from fl4health.clients.basic_client import BasicClient from fl4health.reporting import JsonReporter from fl4health.reporting.base_reporter import BaseReporter +from fl4health.utils.client import fold_loss_dict_into_metrics +from fl4health.utils.logging import LoggingMode from tests.test_utils.assert_metrics_dict import assert_metrics_dict freezegun.configure(extend_ignore_list=["transformers"]) # type: ignore @@ -79,17 +81,22 @@ def test_metrics_reporter_evaluate() -> None: test_metrics_final = { "test_metric": 1234, "testing_metric": 1234, - "test - loss": 123.123, + "val - checkpoint": 123.123, + "test - checkpoint": 123.123, "test - num_examples": 0, } reporter = JsonReporter() fl_client = MockBasicClient( loss=test_loss, + loss_dict={"checkpoint": test_loss}, metrics=test_metrics, test_set_metrics=test_metrics_testing, reporters=[reporter], ) - fl_client.evaluate([], {"current_server_round": test_current_server_round, "local_epochs": 0}) + fl_client.evaluate( + [], + {"current_server_round": test_current_server_round, "local_epochs": 0, "pack_losses_with_val_metrics": True}, + ) metric_dict = { "host_type": "client", @@ -103,6 +110,7 @@ def test_metrics_reporter_evaluate() -> None: }, }, } + errors = assert_metrics_dict(metric_dict, reporter.metrics) assert len(errors) == 0, f"Metrics check failed. Errors: {errors}" @@ -179,12 +187,11 @@ def __init__( self._validate_or_test.side_effect = self.mock_validate_or_test def mock_validate_or_test( # type: ignore - self, - loader, - loss_meter, - metric_manager, - logging_mode=LoggingMode.VALIDATION, + self, loader, loss_meter, metric_manager, logging_mode=LoggingMode.VALIDATION, include_losses_in_metrics=False ): + if include_losses_in_metrics: + assert self.mock_loss_dict is not None and self.mock_metrics is not None + fold_loss_dict_into_metrics(self.mock_metrics, self.mock_loss_dict, logging_mode) if logging_mode == LoggingMode.VALIDATION: return self.mock_loss, self.mock_metrics else: diff --git a/tests/server/test_base_server.py b/tests/server/test_base_server.py index 576dabc9d..0a06d0919 100644 --- a/tests/server/test_base_server.py +++ b/tests/server/test_base_server.py @@ -20,7 +20,7 @@ from fl4health.server.base_server import FlServer, FlServerWithCheckpointing from fl4health.strategies.basic_fedavg import BasicFedAvg from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn -from fl4health.utils.metrics import TEST_LOSS_KEY, TEST_NUM_EXAMPLES_KEY, TestMetricPrefix +from fl4health.utils.metrics import TEST_LOSS_KEY, TEST_NUM_EXAMPLES_KEY, MetricPrefix from tests.test_utils.assert_metrics_dict import assert_metrics_dict from tests.test_utils.custom_client_proxy import CustomClientProxy from tests.test_utils.models_for_test import LinearTransform @@ -165,7 +165,7 @@ def test_unpack_metrics() -> None: "val - prediction - accuracy": 0.9, TEST_LOSS_KEY: 0.8, TEST_NUM_EXAMPLES_KEY: 5, - f"{TestMetricPrefix.TEST_PREFIX.value} accuracy": 0.85, + f"{MetricPrefix.TEST_PREFIX.value} accuracy": 0.85, }, ) @@ -180,7 +180,7 @@ def test_unpack_metrics() -> None: # Check the test results assert len(test_results) == 1 - assert 
test_results[0][1].metrics[f"{TestMetricPrefix.TEST_PREFIX.value} accuracy"] == 0.85 + assert test_results[0][1].metrics[f"{MetricPrefix.TEST_PREFIX.value} accuracy"] == 0.85 assert test_results[0][1].loss == 0.8 @@ -198,7 +198,7 @@ def test_handle_result_aggregation() -> None: "val - prediction - accuracy": 0.9, TEST_LOSS_KEY: 0.8, TEST_NUM_EXAMPLES_KEY: 5, - f"{TestMetricPrefix.TEST_PREFIX.value} accuracy": 0.85, + f"{MetricPrefix.TEST_PREFIX.value} accuracy": 0.85, }, ) client_proxy2 = CustomClientProxy("2") @@ -210,7 +210,7 @@ def test_handle_result_aggregation() -> None: "val - prediction - accuracy": 0.8, TEST_LOSS_KEY: 1.6, TEST_NUM_EXAMPLES_KEY: 10, - f"{TestMetricPrefix.TEST_PREFIX.value} accuracy": 0.75, + f"{MetricPrefix.TEST_PREFIX.value} accuracy": 0.75, }, ) @@ -221,17 +221,17 @@ def test_handle_result_aggregation() -> None: failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]] = [] server_round = 1 - val_loss_aggregated, val_metrics_aggregated = fl_server._handle_result_aggregation(server_round, results, failures) + _, val_metrics_aggregated = fl_server._handle_result_aggregation(server_round, results, failures) # Check the aggregated validation metrics assert "val - prediction - accuracy" in val_metrics_aggregated assert val_metrics_aggregated["val - prediction - accuracy"] == pytest.approx(0.8333, rel=1e-3) # Check the aggregated test metrics - assert f"{TestMetricPrefix.TEST_PREFIX.value} accuracy" in val_metrics_aggregated - assert val_metrics_aggregated[f"{TestMetricPrefix.TEST_PREFIX.value} accuracy"] == pytest.approx(0.7833, rel=1e-3) - assert f"{TestMetricPrefix.TEST_PREFIX.value} loss - aggregated" in val_metrics_aggregated - assert val_metrics_aggregated[f"{TestMetricPrefix.TEST_PREFIX.value} loss - aggregated"] == pytest.approx( + assert f"{MetricPrefix.TEST_PREFIX.value} accuracy" in val_metrics_aggregated + assert val_metrics_aggregated[f"{MetricPrefix.TEST_PREFIX.value} accuracy"] == pytest.approx(0.7833, rel=1e-3) + assert f"{MetricPrefix.TEST_PREFIX.value} loss - aggregated" in val_metrics_aggregated + assert val_metrics_aggregated[f"{MetricPrefix.TEST_PREFIX.value} loss - aggregated"] == pytest.approx( 1.333, rel=1e-3 ) diff --git a/tests/smoke_tests/basic_client_metrics.json b/tests/smoke_tests/basic_client_metrics.json index 6fbceeae0..acd823c42 100644 --- a/tests/smoke_tests/basic_client_metrics.json +++ b/tests/smoke_tests/basic_client_metrics.json @@ -18,7 +18,7 @@ "eval_metrics": { "val - prediction - accuracy": 0.0942, "test - num_examples": 10000, - "test - loss": 2.30616, + "test - checkpoint": 2.30616, "test - prediction - accuracy": 0.0966 }, "eval_loss": 2.3042 @@ -33,7 +33,7 @@ "eval_metrics": { "val - prediction - accuracy": 0.0936, "test - num_examples": 10000, - "test - loss": 2.30109, + "test - checkpoint": 2.30109, "test - prediction - accuracy": 0.0972 }, "eval_loss": 2.2999 diff --git a/tests/smoke_tests/feddg_ga_client_metrics.json b/tests/smoke_tests/feddg_ga_client_metrics.json index bfd329d12..fee31d578 100644 --- a/tests/smoke_tests/feddg_ga_client_metrics.json +++ b/tests/smoke_tests/feddg_ga_client_metrics.json @@ -8,7 +8,7 @@ "val - personal - accuracy": 0.5269, "val - global - accuracy": 0.5331, "val - local - accuracy": 0.4984, - "val - loss": 1.5862 + "val - checkpoint": 1.5862 }, "fit_losses": { "backward": 1.4428, @@ -30,7 +30,7 @@ "val - personal - accuracy": {"target_value": 0.6584, "custom_tolerance": 0.005}, "val - global - accuracy": 0.6031, "val - local - accuracy": {"target_value": 0.614, 
"custom_tolerance": 0.005}, - "val - loss": {"target_value": 1.0008, "custom_tolerance": 0.005} + "val - checkpoint": {"target_value": 1.0008, "custom_tolerance": 0.005} }, "fit_losses": { "global": 0.6785, @@ -51,7 +51,7 @@ "val - personal - accuracy": {"target_value": 0.8218, "custom_tolerance": 0.005}, "val - global - accuracy": 0.8497, "val - local - accuracy": {"target_value": 0.7508, "custom_tolerance": 0.005}, - "val - loss": {"target_value": 0.6042, "custom_tolerance": 0.005} + "val - checkpoint": {"target_value": 0.6042, "custom_tolerance": 0.005} }, "fit_losses": { "global": 0.5084, diff --git a/tests/smoke_tests/feddg_ga_config.yaml b/tests/smoke_tests/feddg_ga_config.yaml index 50c7aa8c0..71fb1e2cb 100644 --- a/tests/smoke_tests/feddg_ga_config.yaml +++ b/tests/smoke_tests/feddg_ga_config.yaml @@ -5,4 +5,7 @@ n_server_rounds: 3 # The number of rounds to run FL n_clients: 2 # The number of clients in the FL experiment local_steps: 5 # The number of local steps (one per batch) to complete for client batch_size: 128 # The batch size for client training -evaluate_after_fit: True # Evaluates model immediately after local training on the validation set (in addition to the training set) +# Evaluates model immediately after local training on the validation set (in addition to the training set) +evaluate_after_fit: True +# Packs the measured validation losses with the metrics (required for fed-dgga) +pack_losses_with_val_metrics: True diff --git a/tests/smoke_tests/load_from_checkpoint_example/README.md b/tests/smoke_tests/load_from_checkpoint_example/README.md deleted file mode 100644 index bb2705159..000000000 --- a/tests/smoke_tests/load_from_checkpoint_example/README.md +++ /dev/null @@ -1,34 +0,0 @@ -# Basic Federated Learning Example -This example provides an very simple implementation of a federated learning training setup on the CIFAR dataset. The -FL server expects two clients to be spun up (i.e. it will wait until two clients report in before starting training). -Each client has the same "local" dataset. I.e. they each load the complete CIFAR dataset and therefore have the same -training and validation sets. The server has some custom metrics aggregation, but is otherwise a vanilla FL -implementation using FedAvg as the server side optimization. - -## Running the Example -In order to run the example, first ensure you have [installed the dependencies in your virtual environment according to the main README](/README.md#development-requirements) and it has been activated. - -## Starting Server - -The next step is to start the server by running -``` -python -m examples.basic_example.server --config_path /path/to/config.yaml -``` -from the FL4Health directory. The following arguments must be present in the specified config file: -* `n_clients`: number of clients the server waits for in order to run the FL training -* `local_epochs`: number of epochs each client will train for locally -* `batch_size`: size of the batches each client will train on -* `n_server_rounds`: The number of rounds to run FL - -## Starting Clients - -Once the server has started and logged "FL starting," the next step, in separate terminals, is to start the two -clients. This is done by simply running (remembering to activate your environment) -``` -python -m examples.basic_example.client --dataset_path /path/to/data -``` -**NOTE**: The argument `dataset_path` has two functions, depending on whether the dataset exists locally or not. 
If -the dataset already exists at the path specified, it will be loaded from there. Otherwise, the dataset will be -automatically downloaded to the path specified and used in the run. - -After both clients have been started federated learning should commence. diff --git a/tests/strategies/test_feddg_ga_strategy.py b/tests/strategies/test_feddg_ga.py similarity index 70% rename from tests/strategies/test_feddg_ga_strategy.py rename to tests/strategies/test_feddg_ga.py index 3416caf60..a562a2707 100644 --- a/tests/strategies/test_feddg_ga_strategy.py +++ b/tests/strategies/test_feddg_ga.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Tuple from unittest.mock import Mock import numpy as np @@ -9,11 +9,11 @@ from pytest import approx, raises from fl4health.client_managers.fixed_sampling_client_manager import FixedSamplingClientManager -from fl4health.strategies.feddg_ga_strategy import FairnessMetricType, FedDgGaStrategy +from fl4health.strategies.feddg_ga import FairnessMetricType, FedDgGa from tests.test_utils.custom_client_proxy import CustomClientProxy -def test_configure_fit_success() -> None: +def test_configure_fit_and_evaluate_success() -> None: fixed_sampling_client_manager = _apply_mocks_to_client_manager(FixedSamplingClientManager()) test_n_server_rounds = 3 @@ -21,9 +21,16 @@ def on_fit_config_fn(server_round: int) -> Dict[str, Scalar]: return { "n_server_rounds": test_n_server_rounds, "evaluate_after_fit": True, + "pack_losses_with_val_metrics": True, } - strategy = FedDgGaStrategy(on_fit_config_fn=on_fit_config_fn) + def on_evaluate_config_fn(server_round: int) -> Dict[str, Scalar]: + return { + "n_server_rounds": test_n_server_rounds, + "pack_losses_with_val_metrics": True, + } + + strategy = FedDgGa(on_fit_config_fn=on_fit_config_fn, on_evaluate_config_fn=on_evaluate_config_fn) assert strategy.num_rounds is None try: @@ -41,7 +48,7 @@ def test_configure_fit_fail() -> None: simple_client_manager = _apply_mocks_to_client_manager(SimpleClientManager()) # Fails with no configure fit - strategy = FedDgGaStrategy() + strategy = FedDgGa() with raises(AssertionError): strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) @@ -50,9 +57,10 @@ def on_fit_config_fn(server_round: int) -> Dict[str, Scalar]: return { "n_server_rounds": 2, "evaluate_after_fit": True, + "pack_losses_with_val_metrics": True, } - strategy = FedDgGaStrategy(on_fit_config_fn=on_fit_config_fn) + strategy = FedDgGa(on_fit_config_fn=on_fit_config_fn) with raises(AssertionError): strategy.configure_fit(1, Parameters([], ""), simple_client_manager) @@ -61,9 +69,10 @@ def on_fit_config_fn_1(server_round: int) -> Dict[str, Scalar]: return { "foo": 123, "evaluate_after_fit": True, + "pack_losses_with_val_metrics": True, } - strategy = FedDgGaStrategy(on_fit_config_fn=on_fit_config_fn_1) + strategy = FedDgGa(on_fit_config_fn=on_fit_config_fn_1) assert strategy.num_rounds is None with raises(AssertionError): @@ -74,9 +83,10 @@ def on_fit_config_fn_2(server_round: int) -> Dict[str, Scalar]: return { "n_server_rounds": 1.1, "evaluate_after_fit": True, + "pack_losses_with_val_metrics": True, } - strategy = FedDgGaStrategy(on_fit_config_fn=on_fit_config_fn_2) + strategy = FedDgGa(on_fit_config_fn=on_fit_config_fn_2) assert strategy.num_rounds is None with raises(AssertionError): @@ -86,9 +96,10 @@ def on_fit_config_fn_2(server_round: int) -> Dict[str, Scalar]: def on_fit_config_fn_3(server_round: int) -> Dict[str, Scalar]: return { 
"n_server_rounds": 2, + "pack_losses_with_val_metrics": True, } - strategy = FedDgGaStrategy(on_fit_config_fn=on_fit_config_fn_3) + strategy = FedDgGa(on_fit_config_fn=on_fit_config_fn_3) with raises(AssertionError): strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) @@ -97,9 +108,75 @@ def on_fit_config_fn_4(server_round: int) -> Dict[str, Scalar]: return { "n_server_rounds": 2, "evaluate_after_fit": False, + "pack_losses_with_val_metrics": True, } - strategy = FedDgGaStrategy(on_fit_config_fn=on_fit_config_fn_4) + strategy = FedDgGa(on_fit_config_fn=on_fit_config_fn_4) + with raises(AssertionError): + strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) + + # Fails with pack_losses_with_val_metrics not being there + def on_fit_config_fn_5(server_round: int) -> Dict[str, Scalar]: + return { + "n_server_rounds": 2, + "evaluate_after_fit": True, + } + + strategy = FedDgGa(on_fit_config_fn=on_fit_config_fn_5) + with raises(AssertionError): + strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) + + # Fails with pack_losses_with_val_metrics not being True + def on_fit_config_fn_6(server_round: int) -> Dict[str, Scalar]: + return { + "n_server_rounds": 2, + "evaluate_after_fit": True, + "pack_losses_with_val_metrics": False, + } + + strategy = FedDgGa(on_fit_config_fn=on_fit_config_fn_6) + with raises(AssertionError): + strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) + + +def test_configure_evaluate_fail() -> None: + fixed_sampling_client_manager = _apply_mocks_to_client_manager(FixedSamplingClientManager()) + simple_client_manager = _apply_mocks_to_client_manager(SimpleClientManager()) + + # Fails with no evaluate fit + strategy = FedDgGa() + with raises(AssertionError): + strategy.configure_evaluate(1, Parameters([], ""), fixed_sampling_client_manager) + + # Fails with bad client manager type + def on_evaluate_config_fn(server_round: int) -> Dict[str, Scalar]: + return { + "n_server_rounds": 2, + "pack_losses_with_val_metrics": True, + } + + strategy = FedDgGa(on_evaluate_config_fn=on_evaluate_config_fn) + with raises(AssertionError): + strategy.configure_evaluate(1, Parameters([], ""), simple_client_manager) + + # Fail with no pack_losses_with_val_metrics + def on_evaluate_config_fn_1(server_round: int) -> Dict[str, Scalar]: + return { + "foo": 123, + } + + strategy = FedDgGa(on_evaluate_config_fn=on_evaluate_config_fn_1) + with raises(AssertionError): + strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) + + # Fails with pack_losses_with_val_metrics not being True + def on_fit_config_fn_2(server_round: int) -> Dict[str, Scalar]: + return { + "n_server_rounds": 1.1, + "pack_losses_with_val_metrics": False, + } + + strategy = FedDgGa(on_fit_config_fn=on_fit_config_fn_2) with raises(AssertionError): strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) @@ -112,10 +189,9 @@ def test_aggregate_fit_and_aggregate_evaluate() -> None: test_fit_metrics_2 = test_fit_results[1][1].metrics test_eval_metrics_1 = test_eval_results[0][1].metrics test_eval_metrics_2 = test_eval_results[1][1].metrics - test_val_loss_key = FairnessMetricType.LOSS.value test_initial_adjustment_weight = 1.0 / 3.0 - strategy = FedDgGaStrategy() + strategy = FedDgGa() strategy.num_rounds = 3 strategy.initial_adjustment_weight = test_initial_adjustment_weight @@ -135,16 +211,17 @@ def test_aggregate_fit_and_aggregate_evaluate() -> None: assert parameters_array == [approx(1.0, 
abs=0.0005), approx(1.0666, abs=0.0005)] # test evaluate fit - _, _ = strategy.aggregate_evaluate(2, deepcopy(test_eval_results), []) + loss_aggregated, _ = strategy.aggregate_evaluate(2, deepcopy(test_eval_results), []) assert strategy.evaluation_metrics == { - test_cid_1: {**test_eval_metrics_1, test_val_loss_key: test_eval_results[0][1].loss}, - test_cid_2: {**test_eval_metrics_2, test_val_loss_key: test_eval_results[1][1].loss}, + test_cid_1: {**test_eval_metrics_1}, + test_cid_2: {**test_eval_metrics_2}, } assert strategy.adjustment_weights == { test_cid_1: approx(0.2999, abs=0.0005), test_cid_2: approx(0.7000, abs=0.0005), } + assert approx(loss_aggregated, abs=1e-6) == 1.7 def test_weight_and_aggregate_results_with_default_weights() -> None: @@ -153,7 +230,7 @@ def test_weight_and_aggregate_results_with_default_weights() -> None: test_cid_2 = test_fit_results[1][0].cid test_initial_adjustment_weight = 1.0 / 3.0 - strategy = FedDgGaStrategy() + strategy = FedDgGa() strategy.initial_adjustment_weight = test_initial_adjustment_weight aggregated_results = strategy.weight_and_aggregate_results(test_fit_results) @@ -170,7 +247,7 @@ def test_weight_and_aggregate_results_with_existing_weights() -> None: test_cid_2 = test_fit_results[1][0].cid test_adjustment_weights = {test_cid_1: 0.21, test_cid_2: 0.76} - strategy = FedDgGaStrategy() + strategy = FedDgGa() strategy.adjustment_weights = deepcopy(test_adjustment_weights) aggregated_results = strategy.weight_and_aggregate_results(test_fit_results) @@ -183,7 +260,7 @@ def test_update_weights_by_ga() -> None: test_val_loss_key = FairnessMetricType.LOSS.value test_initial_adjustment_weight = 1.0 / 3.0 - strategy = FedDgGaStrategy() + strategy = FedDgGa() strategy.num_rounds = 3 strategy.initial_adjustment_weight = test_initial_adjustment_weight strategy.train_metrics = { @@ -212,7 +289,7 @@ def test_update_weights_by_ga_with_same_metrics() -> None: test_val_loss_key = FairnessMetricType.LOSS.value test_initial_adjustment_weight = 1.0 / 3.0 - strategy = FedDgGaStrategy() + strategy = FedDgGa() strategy.num_rounds = 3 strategy.initial_adjustment_weight = test_initial_adjustment_weight strategy.train_metrics = { @@ -234,7 +311,7 @@ def test_update_weights_by_ga_with_same_metrics() -> None: def test_get_current_weight_step_size() -> None: - strategy = FedDgGaStrategy() + strategy = FedDgGa() with raises(AssertionError): strategy.get_current_weight_step_size(2) @@ -252,7 +329,7 @@ def test_get_current_weight_step_size() -> None: assert result_step_size == approx(0.1000, abs=0.0005) strategy.num_rounds = 10 - strategy.weight_step_size = 0.5 + strategy.adjustment_weight_step_size = 0.5 result_step_size = strategy.get_current_weight_step_size(6) assert result_step_size == approx(0.2500, abs=0.0005) @@ -270,10 +347,10 @@ def _apply_mocks_to_client_manager(client_manager: ClientManager) -> ClientManag def _make_test_data() -> Tuple[List[Tuple[ClientProxy, FitRes]], List[Tuple[ClientProxy, EvaluateRes]]]: test_val_loss_key = FairnessMetricType.LOSS.value - test_fit_metrics_1: Dict[str, Union[bool, bytes, float, int, str]] = {test_val_loss_key: 1.0} - test_fit_metrics_2: Dict[str, Union[bool, bytes, float, int, str]] = {test_val_loss_key: 2.0} - test_eval_metrics_1: Dict[str, Union[bool, bytes, float, int, str]] = {"metric-1": 1.0} - test_eval_metrics_2: Dict[str, Union[bool, bytes, float, int, str]] = {"metric-2": 2.0} + test_fit_metrics_1: Dict[str, Scalar] = {test_val_loss_key: 1.0} + test_fit_metrics_2: Dict[str, Scalar] = {test_val_loss_key: 2.0} + 
test_eval_metrics_1: Dict[str, Scalar] = {"metric-1": 1.0, test_val_loss_key: 1.2} + test_eval_metrics_2: Dict[str, Scalar] = {"metric-2": 2.0, test_val_loss_key: 2.2} test_parameters_1 = ndarrays_to_parameters([np.array([1.0, 1.1])]) test_parameters_2 = ndarrays_to_parameters([np.array([2.0, 2.1])]) test_fit_results = [ diff --git a/tests/strategies/test_feddg_ga_with_adapt_constraint.py b/tests/strategies/test_feddg_ga_with_adapt_constraint.py new file mode 100644 index 000000000..6797a4c49 --- /dev/null +++ b/tests/strategies/test_feddg_ga_with_adapt_constraint.py @@ -0,0 +1,411 @@ +from copy import deepcopy +from typing import Dict, List, Tuple +from unittest.mock import Mock + +import numpy as np +from flwr.common.parameter import ndarrays_to_parameters, parameters_to_ndarrays +from flwr.common.typing import Code, EvaluateRes, FitRes, Parameters, Scalar, Status +from flwr.server.client_manager import ClientManager, ClientProxy, SimpleClientManager +from pytest import approx, raises + +from fl4health.client_managers.fixed_sampling_client_manager import FixedSamplingClientManager +from fl4health.strategies.feddg_ga import FairnessMetricType +from fl4health.strategies.feddg_ga_with_adaptive_constraint import FedDgGaAdaptiveConstraint +from tests.test_utils.custom_client_proxy import CustomClientProxy + +INITIAL_PARAMETERS = ndarrays_to_parameters([np.array([0.0, 0.0])]) + + +def test_configure_fit_and_evaluate_success() -> None: + fixed_sampling_client_manager = _apply_mocks_to_client_manager(FixedSamplingClientManager()) + test_n_server_rounds = 3 + + def on_fit_config_fn(server_round: int) -> Dict[str, Scalar]: + return { + "n_server_rounds": test_n_server_rounds, + "evaluate_after_fit": True, + "pack_losses_with_val_metrics": True, + } + + def on_evaluate_config_fn(server_round: int) -> Dict[str, Scalar]: + return { + "n_server_rounds": test_n_server_rounds, + "pack_losses_with_val_metrics": True, + } + + strategy = FedDgGaAdaptiveConstraint( + initial_parameters=INITIAL_PARAMETERS, + on_fit_config_fn=on_fit_config_fn, + on_evaluate_config_fn=on_evaluate_config_fn, + ) + assert strategy.num_rounds is None + + try: + strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) + except Exception as e: + assert False, f"initialize_parameters threw an exception: {e}" + + assert strategy.num_rounds == test_n_server_rounds + assert strategy.initial_adjustment_weight == 1.0 / fixed_sampling_client_manager.num_available() + fixed_sampling_client_manager.reset_sample.assert_called_once() # type: ignore + + +def test_configure_fit_fail() -> None: + fixed_sampling_client_manager = _apply_mocks_to_client_manager(FixedSamplingClientManager()) + simple_client_manager = _apply_mocks_to_client_manager(SimpleClientManager()) + + # Fails with no configure fit + strategy = FedDgGaAdaptiveConstraint(initial_parameters=INITIAL_PARAMETERS) + with raises(AssertionError): + strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) + + # Fails with bad client manager type + def on_fit_config_fn(server_round: int) -> Dict[str, Scalar]: + return { + "n_server_rounds": 2, + "evaluate_after_fit": True, + "pack_losses_with_val_metrics": True, + } + + strategy = FedDgGaAdaptiveConstraint(initial_parameters=INITIAL_PARAMETERS, on_fit_config_fn=on_fit_config_fn) + with raises(AssertionError): + strategy.configure_fit(1, Parameters([], ""), simple_client_manager) + + # Fail with no n_server_rounds + def on_fit_config_fn_1(server_round: int) -> Dict[str, Scalar]: + return { + "foo": 
123, + "evaluate_after_fit": True, + "pack_losses_with_val_metrics": True, + } + + strategy = FedDgGaAdaptiveConstraint(initial_parameters=INITIAL_PARAMETERS, on_fit_config_fn=on_fit_config_fn_1) + assert strategy.num_rounds is None + + with raises(AssertionError): + strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) + + # Fails with n_server_rounds not being an integer + def on_fit_config_fn_2(server_round: int) -> Dict[str, Scalar]: + return { + "n_server_rounds": 1.1, + "evaluate_after_fit": True, + "pack_losses_with_val_metrics": True, + } + + strategy = FedDgGaAdaptiveConstraint(initial_parameters=INITIAL_PARAMETERS, on_fit_config_fn=on_fit_config_fn_2) + assert strategy.num_rounds is None + + with raises(AssertionError): + strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) + + # Fails with evaluate_after_fit not being set + def on_fit_config_fn_3(server_round: int) -> Dict[str, Scalar]: + return { + "n_server_rounds": 2, + "pack_losses_with_val_metrics": True, + } + + strategy = FedDgGaAdaptiveConstraint(initial_parameters=INITIAL_PARAMETERS, on_fit_config_fn=on_fit_config_fn_3) + with raises(AssertionError): + strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) + + # Fails with evaluate_after_fit not being True + def on_fit_config_fn_4(server_round: int) -> Dict[str, Scalar]: + return { + "n_server_rounds": 2, + "evaluate_after_fit": False, + "pack_losses_with_val_metrics": True, + } + + strategy = FedDgGaAdaptiveConstraint(initial_parameters=INITIAL_PARAMETERS, on_fit_config_fn=on_fit_config_fn_4) + with raises(AssertionError): + strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) + + # Fails with pack_losses_with_val_metrics not being there + def on_fit_config_fn_5(server_round: int) -> Dict[str, Scalar]: + return { + "n_server_rounds": 2, + "evaluate_after_fit": True, + } + + strategy = FedDgGaAdaptiveConstraint(initial_parameters=INITIAL_PARAMETERS, on_fit_config_fn=on_fit_config_fn_5) + with raises(AssertionError): + strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) + + # Fails with pack_losses_with_val_metrics not being True + def on_fit_config_fn_6(server_round: int) -> Dict[str, Scalar]: + return { + "n_server_rounds": 2, + "evaluate_after_fit": True, + "pack_losses_with_val_metrics": False, + } + + strategy = FedDgGaAdaptiveConstraint(initial_parameters=INITIAL_PARAMETERS, on_fit_config_fn=on_fit_config_fn_6) + with raises(AssertionError): + strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) + + +def test_configure_evaluate_fail() -> None: + fixed_sampling_client_manager = _apply_mocks_to_client_manager(FixedSamplingClientManager()) + simple_client_manager = _apply_mocks_to_client_manager(SimpleClientManager()) + + # Fails with no evaluate fit + strategy = FedDgGaAdaptiveConstraint(initial_parameters=INITIAL_PARAMETERS) + with raises(AssertionError): + strategy.configure_evaluate(1, Parameters([], ""), fixed_sampling_client_manager) + + # Fails with bad client manager type + def on_evaluate_config_fn(server_round: int) -> Dict[str, Scalar]: + return { + "n_server_rounds": 2, + "pack_losses_with_val_metrics": True, + } + + strategy = FedDgGaAdaptiveConstraint( + initial_parameters=INITIAL_PARAMETERS, on_evaluate_config_fn=on_evaluate_config_fn + ) + with raises(AssertionError): + strategy.configure_evaluate(1, Parameters([], ""), simple_client_manager) + + # Fail with no pack_losses_with_val_metrics + def 
on_evaluate_config_fn_1(server_round: int) -> Dict[str, Scalar]: + return { + "foo": 123, + } + + strategy = FedDgGaAdaptiveConstraint( + initial_parameters=INITIAL_PARAMETERS, on_evaluate_config_fn=on_evaluate_config_fn_1 + ) + with raises(AssertionError): + strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) + + # Fails with pack_losses_with_val_metrics not being True + def on_fit_config_fn_2(server_round: int) -> Dict[str, Scalar]: + return { + "n_server_rounds": 1.1, + "pack_losses_with_val_metrics": False, + } + + strategy = FedDgGaAdaptiveConstraint(initial_parameters=INITIAL_PARAMETERS, on_fit_config_fn=on_fit_config_fn_2) + with raises(AssertionError): + strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) + + +def test_aggregate_fit_and_aggregate_evaluate() -> None: + test_fit_results, test_eval_results = _make_test_data() + test_cid_1 = test_fit_results[0][0].cid + test_cid_2 = test_fit_results[1][0].cid + test_fit_metrics_1 = test_fit_results[0][1].metrics + test_fit_metrics_2 = test_fit_results[1][1].metrics + test_eval_metrics_1 = test_eval_results[0][1].metrics + test_eval_metrics_2 = test_eval_results[1][1].metrics + test_initial_adjustment_weight = 1.0 / 3.0 + + strategy = FedDgGaAdaptiveConstraint( + initial_parameters=INITIAL_PARAMETERS, initial_loss_weight=1.0, adapt_loss_weight=True, loss_weight_patience=1 + ) + strategy.num_rounds = 3 + strategy.initial_adjustment_weight = test_initial_adjustment_weight + + # test aggregate fit + parameters_aggregated, _ = strategy.aggregate_fit(2, deepcopy(test_fit_results), []) + + # make sure the loss has been aggregated and stored and the loss weight adjusted + assert strategy.previous_loss == 2.0 + assert strategy.loss_weight == 0.9 + + assert strategy.train_metrics == { + test_cid_1: test_fit_metrics_1, + test_cid_2: test_fit_metrics_2, + } + assert strategy.adjustment_weights == { + test_cid_1: test_initial_adjustment_weight, + test_cid_2: test_initial_adjustment_weight, + } + assert parameters_aggregated is not None + parameters_array = parameters_to_ndarrays(parameters_aggregated)[0].tolist() + assert parameters_array == [approx(1.0, abs=0.0005), approx(1.0666, abs=0.0005)] + + # test evaluate fit + loss_aggregated, _ = strategy.aggregate_evaluate(2, deepcopy(test_eval_results), []) + + assert strategy.evaluation_metrics == { + test_cid_1: {**test_eval_metrics_1}, + test_cid_2: {**test_eval_metrics_2}, + } + assert strategy.adjustment_weights == { + test_cid_1: approx(0.2999, abs=0.0005), + test_cid_2: approx(0.7000, abs=0.0005), + } + assert approx(loss_aggregated, abs=1e-6) == 1.7 + + +def test_weight_and_aggregate_results_with_default_weights() -> None: + test_fit_results, _ = _make_test_data() + test_cid_1 = test_fit_results[0][0].cid + test_cid_2 = test_fit_results[1][0].cid + test_initial_adjustment_weight = 1.0 / 3.0 + + strategy = FedDgGaAdaptiveConstraint(initial_parameters=INITIAL_PARAMETERS) + strategy.initial_adjustment_weight = test_initial_adjustment_weight + strategy._unpack_weights_and_losses(test_fit_results) + aggregated_results = strategy.weight_and_aggregate_results(test_fit_results) + + assert strategy.adjustment_weights == { + test_cid_1: test_initial_adjustment_weight, + test_cid_2: test_initial_adjustment_weight, + } + assert aggregated_results[0].tolist() == [approx(1.0, abs=0.0005), approx(1.0666, abs=0.0005)] + + +def test_weight_and_aggregate_results_with_existing_weights() -> None: + test_fit_results, _ = _make_test_data() + test_cid_1 =
test_fit_results[0][0].cid + test_cid_2 = test_fit_results[1][0].cid + test_adjustment_weights = {test_cid_1: 0.21, test_cid_2: 0.76} + + strategy = FedDgGaAdaptiveConstraint(initial_parameters=INITIAL_PARAMETERS) + strategy.adjustment_weights = deepcopy(test_adjustment_weights) + strategy._unpack_weights_and_losses(test_fit_results) + aggregated_results = strategy.weight_and_aggregate_results(test_fit_results) + + assert strategy.adjustment_weights == test_adjustment_weights + assert aggregated_results[0].tolist() == [approx(1.73, abs=0.0005), approx(1.8270, abs=0.0005)] + + +def test_update_weights_by_ga() -> None: + test_cids = ["1", "2"] + test_val_loss_key = FairnessMetricType.LOSS.value + test_initial_adjustment_weight = 1.0 / 3.0 + + strategy = FedDgGaAdaptiveConstraint(initial_parameters=INITIAL_PARAMETERS) + strategy.num_rounds = 3 + strategy.initial_adjustment_weight = test_initial_adjustment_weight + strategy.train_metrics = { + test_cids[0]: {test_val_loss_key: 0.5467}, + test_cids[1]: {test_val_loss_key: 0.5432}, + } + strategy.evaluation_metrics = { + test_cids[0]: {test_val_loss_key: 0.3556}, + test_cids[1]: {test_val_loss_key: 0.7654}, + } + strategy.adjustment_weights = { + test_cids[0]: test_initial_adjustment_weight, + test_cids[1]: test_initial_adjustment_weight, + } + + strategy.update_weights_by_ga(2, test_cids) + + assert strategy.adjustment_weights == { + test_cids[0]: approx(0.2999, abs=0.0005), + test_cids[1]: approx(0.7000, abs=0.0005), + } + + +def test_update_weights_by_ga_with_same_metrics() -> None: + test_cids = ["1", "2"] + test_val_loss_key = FairnessMetricType.LOSS.value + test_initial_adjustment_weight = 1.0 / 3.0 + + strategy = FedDgGaAdaptiveConstraint(initial_parameters=INITIAL_PARAMETERS) + strategy.num_rounds = 3 + strategy.initial_adjustment_weight = test_initial_adjustment_weight + strategy.train_metrics = { + test_cids[0]: {test_val_loss_key: 0.5467}, + test_cids[1]: {test_val_loss_key: 0.5432}, + } + strategy.evaluation_metrics = { + test_cids[0]: {test_val_loss_key: 0.5467}, + test_cids[1]: {test_val_loss_key: 0.5432}, + } + strategy.adjustment_weights = { + test_cids[0]: test_initial_adjustment_weight, + test_cids[1]: test_initial_adjustment_weight, + } + + strategy.update_weights_by_ga(2, test_cids) + + assert strategy.adjustment_weights == {test_cids[0]: 0.5, test_cids[1]: 0.5} + + +def test_get_current_weight_step_size() -> None: + strategy = FedDgGaAdaptiveConstraint(initial_parameters=INITIAL_PARAMETERS) + + with raises(AssertionError): + strategy.get_current_weight_step_size(2) + + strategy.num_rounds = 3 + result_step_size = strategy.get_current_weight_step_size(1) + assert result_step_size == approx(0.2000, abs=0.0005) + result_step_size = strategy.get_current_weight_step_size(2) + assert result_step_size == approx(0.1333, abs=0.0005) + result_step_size = strategy.get_current_weight_step_size(3) + assert result_step_size == approx(0.0666, abs=0.0005) + + strategy.num_rounds = 10 + result_step_size = strategy.get_current_weight_step_size(6) + assert result_step_size == approx(0.1000, abs=0.0005) + + strategy.num_rounds = 10 + strategy.adjustment_weight_step_size = 0.5 + result_step_size = strategy.get_current_weight_step_size(6) + assert result_step_size == approx(0.2500, abs=0.0005) + + +def test_unpack_weights_and_losses() -> None: + test_fit_results, _ = _make_test_data() + # make sure the results are of length 2 (one for the weights, one for the loss) + assert len(test_fit_results[0][1].parameters.tensors) == 2 + assert 
len(test_fit_results[1][1].parameters.tensors) == 2 + + strategy = FedDgGaAdaptiveConstraint(initial_parameters=INITIAL_PARAMETERS) + train_losses_and_counts = strategy._unpack_weights_and_losses(test_fit_results) + + # Assert that the fit results have been modified in place and properly + assert len(test_fit_results) == 2 + test_ndarrays_1 = parameters_to_ndarrays(test_fit_results[0][1].parameters) + test_ndarrays_2 = parameters_to_ndarrays(test_fit_results[1][1].parameters) + target_ndarray_1 = np.array([1.0, 1.1]) + target_ndarray_2 = np.array([2.0, 2.1]) + # length should be 1, since we've unpacked the loss arrays + assert len(test_ndarrays_1) == 1 + assert len(test_ndarrays_2) == 1 + + assert np.allclose(test_ndarrays_1[0], target_ndarray_1, rtol=0.0, atol=1e-6) + assert np.allclose(test_ndarrays_2[0], target_ndarray_2, rtol=0.0, atol=1e-6) + + # Make sure that the losses have properly been extracted and stored. + assert train_losses_and_counts[0][1] == 1.5 + assert train_losses_and_counts[1][1] == 2.5 + + +def _apply_mocks_to_client_manager(client_manager: ClientManager) -> ClientManager: + client_proxy_1 = CustomClientProxy("1") + client_proxy_2 = CustomClientProxy("2") + client_manager.register(client_proxy_1) + client_manager.register(client_proxy_2) + client_manager.sample = Mock() # type: ignore + client_manager.sample.return_value = [client_proxy_1, client_proxy_2] + client_manager.reset_sample = Mock() # type: ignore + return client_manager + + +def _make_test_data() -> Tuple[List[Tuple[ClientProxy, FitRes]], List[Tuple[ClientProxy, EvaluateRes]]]: + test_val_loss_key = FairnessMetricType.LOSS.value + test_fit_metrics_1: Dict[str, Scalar] = {test_val_loss_key: 1.0} + test_fit_metrics_2: Dict[str, Scalar] = {test_val_loss_key: 2.0} + test_eval_metrics_1: Dict[str, Scalar] = {"metric-1": 1.0, test_val_loss_key: 1.2} + test_eval_metrics_2: Dict[str, Scalar] = {"metric-2": 2.0, test_val_loss_key: 2.2} + test_parameters_1 = ndarrays_to_parameters([np.array([1.0, 1.1]), np.array(1.5)]) + test_parameters_2 = ndarrays_to_parameters([np.array([2.0, 2.1]), np.array(2.5)]) + test_fit_results = [ + (CustomClientProxy("1"), FitRes(Status(Code.OK, ""), test_parameters_1, 1, test_fit_metrics_1)), + (CustomClientProxy("2"), FitRes(Status(Code.OK, ""), test_parameters_2, 1, test_fit_metrics_2)), + ] + test_evaluate_results = [ + (CustomClientProxy("1"), EvaluateRes(Status(Code.OK, ""), 1.2, 1, test_eval_metrics_1)), + (CustomClientProxy("2"), EvaluateRes(Status(Code.OK, ""), 2.2, 1, test_eval_metrics_2)), + ] + + return test_fit_results, test_evaluate_results # type: ignore
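
A few illustrative sketches of the behaviors this patch introduces follow; none of this code is part of the diff itself. First, the packing contract that `FedDgGaAdaptiveConstraint` relies on: each client appends its vanilla (penalty-free) train loss to its model weights, and the server splits them back apart in `_unpack_weights_and_losses` before the Fed-DGGA weighting. The arrays below mirror those used in `test_unpack_weights_and_losses`:
```
import numpy as np

from fl4health.parameter_exchange.parameter_packer import ParameterPackerAdaptiveConstraint

packer = ParameterPackerAdaptiveConstraint()

# A client packs its model weights together with its vanilla train loss.
model_weights = [np.array([1.0, 1.1])]
train_loss = 1.5
packed = packer.pack_parameters(model_weights, train_loss)

# The server recovers the bare model weights and the loss before aggregation.
unpacked_weights, unpacked_loss = packer.unpack_parameters(packed)
assert unpacked_loss == 1.5
assert np.allclose(unpacked_weights[0], np.array([1.0, 1.1]))
```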
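
The adaptive constraint weight logic in `_maybe_update_constraint_weight_param` can be traced by hand. A minimal sketch, assuming `loss_weight_patience=1` as in `test_aggregate_fit_and_aggregate_evaluate`: the first aggregated loss of 2.0 is below the initial `previous_loss` of infinity, so patience is hit and the weight drops by `loss_weight_delta`:
```
loss_weight, delta, patience, counter = 1.0, 0.1, 1, 0
previous_loss = float("inf")

aggregated_loss = 2.0
if aggregated_loss <= previous_loss:
    # Loss did not increase: count toward patience; once patience is hit, shrink the weight.
    counter += 1
    if counter == patience:
        loss_weight = max(0.0, loss_weight - delta)
        counter = 0
else:
    # Loss increased: immediately grow the weight and reset patience.
    loss_weight += delta
    counter = 0
previous_loss = aggregated_loss

# Matches the assertions in the test above.
assert loss_weight == 0.9 and previous_loss == 2.0
```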
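
The values asserted in `test_get_current_weight_step_size` follow the linear decay implemented in `get_current_weight_step_size`, with server rounds starting at 1. A quick numeric check of that schedule:
```
adjustment_weight_step_size = 0.2
num_rounds = 3

for server_round in (1, 2, 3):
    decay = adjustment_weight_step_size / num_rounds
    step_size_for_round = adjustment_weight_step_size - (server_round - 1) * decay
    # Prints 1 -> 0.2, 2 -> 0.1333, 3 -> 0.0667, matching the unit tests.
    print(server_round, round(step_size_for_round, 4))
```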
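
Finally, the renamed metric keys in the smoke-test JSON files ("val - checkpoint", "test - checkpoint") come from folding the client's loss dictionary into the metrics payload when `pack_losses_with_val_metrics` is enabled. A sketch of that flow, assuming a loss dictionary keyed by "checkpoint" as in the updated tests (the numeric values here are illustrative only):
```
from fl4health.utils.client import fold_loss_dict_into_metrics, set_pack_losses_with_val_metrics
from fl4health.utils.logging import LoggingMode

config = {"pack_losses_with_val_metrics": True}

metrics = {"val - prediction - accuracy": 0.0942}
loss_dict = {"checkpoint": 2.30616}

if set_pack_losses_with_val_metrics(config):
    fold_loss_dict_into_metrics(metrics, loss_dict, LoggingMode.VALIDATION)

# metrics now also contains {"val - checkpoint": 2.30616}, which (per the smoke
# tests) is the loss key the FedDgGa strategies assert is present during
# aggregate_evaluate via FairnessMetricType.LOSS.
```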