From 4c15c9399c729644508c0b51b848abbd9d9219ed Mon Sep 17 00:00:00 2001 From: David Emerson <43939939+emersodb@users.noreply.github.com> Date: Thu, 21 Nov 2024 15:25:29 -0500 Subject: [PATCH 01/13] WIP --- fl4health/checkpointing/server_module.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 fl4health/checkpointing/server_module.py diff --git a/fl4health/checkpointing/server_module.py b/fl4health/checkpointing/server_module.py new file mode 100644 index 000000000..e69de29bb From 7cba9386f84446ddfc6fd6d7fa63a35557c5d336 Mon Sep 17 00:00:00 2001 From: David Emerson <43939939+emersodb@users.noreply.github.com> Date: Mon, 25 Nov 2024 08:51:28 -0500 Subject: [PATCH 02/13] WIP checkin to preserve work, not running checks --- .../ae_examples/cvae_dim_example/server.py | 12 +- .../cvae_examples/conv_cvae_example/server.py | 12 +- .../cvae_examples/mlp_cvae_example/server.py | 12 +- .../ae_examples/fedprox_vae_example/server.py | 4 +- examples/basic_example/server.py | 6 +- .../instance_level_dp/client.py | 4 +- examples/fedopt_example/client.py | 4 +- .../fedpca_examples/dim_reduction/server.py | 4 +- .../fedsimclr_finetuning_example/server.py | 4 +- .../fedsimclr_pretraining_example/server.py | 4 +- examples/fenda_ditto_example/client.py | 10 +- examples/model_merge_example/server.py | 4 +- .../warm_up_example/fedavg_warm_up/client.py | 8 +- fl4health/checkpointing/checkpointer.py | 172 ++++++++--- fl4health/checkpointing/client_module.py | 122 ++++++-- .../checkpointing/opacus_checkpointer.py | 10 +- fl4health/checkpointing/server_module.py | 287 ++++++++++++++++++ .../adaptive_drift_constraint_client.py | 4 +- fl4health/clients/apfl_client.py | 4 +- fl4health/clients/basic_client.py | 113 ++++--- fl4health/clients/clipping_client.py | 4 +- fl4health/clients/constrained_fenda_client.py | 4 +- .../deep_mmd_clients/ditto_deep_mmd_client.py | 4 +- fl4health/clients/ditto_client.py | 4 +- fl4health/clients/ensemble_client.py | 4 +- fl4health/clients/fedpm_client.py | 4 +- fl4health/clients/fedrep_client.py | 4 +- fl4health/clients/fenda_client.py | 4 +- fl4health/clients/fenda_ditto_client.py | 4 +- fl4health/clients/flash_client.py | 4 +- fl4health/clients/instance_level_dp_client.py | 4 +- .../mkmmd_clients/ditto_mkmmd_client.py | 4 +- .../mkmmd_clients/mr_mtl_mkmmd_client.py | 4 +- fl4health/clients/moon_client.py | 4 +- fl4health/clients/mr_mtl_client.py | 4 +- fl4health/clients/nnunet_client.py | 4 +- .../clients/partial_weight_exchange_client.py | 4 +- fl4health/clients/perfcl_client.py | 4 +- fl4health/clients/scaffold_client.py | 6 +- .../ditto_server.py | 20 +- .../fedprox_server.py | 4 +- .../mrmtl_server.py | 19 +- fl4health/servers/base_server.py | 163 ++++------ .../servers/client_level_dp_fed_avg_server.py | 14 +- fl4health/servers/fedpm_server.py | 14 +- fl4health/servers/model_merge_server.py | 4 +- fl4health/servers/nnunet_server.py | 4 +- fl4health/servers/scaffold_server.py | 26 +- .../tabular_feature_alignment_server.py | 4 +- fl4health/utils/typing.py | 6 +- .../ag_news/dynamic_layer_exchange/client.py | 10 +- .../ag_news/sparse_tensor_exchange/client.py | 10 +- research/cifar10/adaptive_pfl/ditto/client.py | 16 +- .../cifar10/adaptive_pfl/fedprox/client.py | 16 +- .../cifar10/adaptive_pfl/fedprox/server.py | 8 +- .../adaptive_pfl/fenda_ditto/client.py | 16 +- research/cifar10/adaptive_pfl/mrmtl/client.py | 16 +- research/cifar10/ditto/client.py | 16 +- research/cifar10/ditto_deep_mmd/client.py | 16 +- research/cifar10/ditto_mkmmd/client.py | 16 +- 
research/cifar10/fed_dgga_pfl/ditto/client.py | 16 +- research/cifar10/fed_dgga_pfl/fenda/client.py | 16 +- .../fed_dgga_pfl/fenda_ditto/client.py | 16 +- research/cifar10/fedavg/client.py | 16 +- research/cifar10/fedavg/server.py | 8 +- .../flamby/fed_heart_disease/apfl/client.py | 12 +- .../flamby/fed_heart_disease/ditto/client.py | 12 +- .../fed_heart_disease/fedadam/client.py | 10 +- .../fed_heart_disease/fedadam/server.py | 8 +- .../flamby/fed_heart_disease/fedavg/client.py | 10 +- .../flamby/fed_heart_disease/fedavg/server.py | 8 +- .../flamby/fed_heart_disease/fedper/client.py | 12 +- .../fed_heart_disease/fedprox/client.py | 10 +- .../fed_heart_disease/fedprox/server.py | 8 +- .../flamby/fed_heart_disease/fenda/client.py | 12 +- .../flamby/fed_heart_disease/moon/client.py | 10 +- .../flamby/fed_heart_disease/moon/server.py | 8 +- .../flamby/fed_heart_disease/perfcl/client.py | 12 +- .../fed_heart_disease/scaffold/client.py | 10 +- .../fed_heart_disease/scaffold/server.py | 25 +- research/flamby/fed_isic2019/apfl/client.py | 10 +- research/flamby/fed_isic2019/ditto/client.py | 10 +- .../fed_isic2019/ditto_deep_mmd/client.py | 10 +- .../flamby/fed_isic2019/ditto_mkmmd/client.py | 10 +- .../flamby/fed_isic2019/fedadam/client.py | 10 +- .../flamby/fed_isic2019/fedadam/server.py | 4 +- research/flamby/fed_isic2019/fedavg/client.py | 10 +- research/flamby/fed_isic2019/fedavg/server.py | 4 +- research/flamby/fed_isic2019/fedper/client.py | 10 +- .../flamby/fed_isic2019/fedprox/client.py | 10 +- .../flamby/fed_isic2019/fedprox/server.py | 4 +- research/flamby/fed_isic2019/fenda/client.py | 12 +- research/flamby/fed_isic2019/moon/client.py | 10 +- research/flamby/fed_isic2019/moon/server.py | 4 +- .../fed_isic2019/mr_mtl_mkmmd/client.py | 10 +- research/flamby/fed_isic2019/perfcl/client.py | 12 +- .../flamby/fed_isic2019/scaffold/client.py | 10 +- .../flamby/fed_isic2019/scaffold/server.py | 21 +- research/flamby/fed_ixi/apfl/client.py | 12 +- research/flamby/fed_ixi/ditto/client.py | 12 +- research/flamby/fed_ixi/fedadam/client.py | 10 +- research/flamby/fed_ixi/fedadam/server.py | 8 +- research/flamby/fed_ixi/fedavg/client.py | 10 +- research/flamby/fed_ixi/fedavg/server.py | 8 +- research/flamby/fed_ixi/fedper/client.py | 12 +- research/flamby/fed_ixi/fedprox/client.py | 10 +- research/flamby/fed_ixi/fedprox/server.py | 8 +- research/flamby/fed_ixi/fenda/client.py | 12 +- research/flamby/fed_ixi/moon/client.py | 10 +- research/flamby/fed_ixi/moon/server.py | 8 +- research/flamby/fed_ixi/perfcl/client.py | 12 +- research/flamby/fed_ixi/scaffold/client.py | 10 +- research/flamby/fed_ixi/scaffold/server.py | 25 +- .../flamby_servers/full_exchange_server.py | 4 +- .../flamby/flamby_servers/scaffold_server.py | 52 ---- research/flamby/single_node_trainer.py | 4 +- research/gemini/ditto/client.py | 4 +- research/gemini/fedper/client.py | 4 +- research/gemini/moon/client.py | 4 +- research/gemini/moon/server.py | 4 +- research/gemini/perfcl/client.py | 4 +- .../gemini/servers/full_exchange_server.py | 4 +- research/picai/fedavg/client.py | 4 +- research/picai/reporting/server.py | 6 +- research/picai/single_node_trainer.py | 6 +- tests/checkpointing/test_best_checkpointer.py | 4 +- tests/checkpointing/test_client_module.py | 62 ++-- .../test_function_checkpointer.py | 4 +- .../test_opacus_checkpointers.py | 10 +- .../test_per_round_checkpointer.py | 4 +- tests/checkpointing/test_save_load.py | 18 +- tests/preprocessing/test_warm_up_module.py | 4 +- tests/servers/test_base_server.py | 14 +- 
.../load_from_checkpoint_example/client.py | 4 +- .../load_from_checkpoint_example/server.py | 6 +- 135 files changed, 1266 insertions(+), 837 deletions(-) delete mode 100644 research/flamby/flamby_servers/scaffold_server.py diff --git a/examples/ae_examples/cvae_dim_example/server.py b/examples/ae_examples/cvae_dim_example/server.py index 7f5357ed4..5a5dd9e78 100644 --- a/examples/ae_examples/cvae_dim_example/server.py +++ b/examples/ae_examples/cvae_dim_example/server.py @@ -8,7 +8,8 @@ from flwr.server.strategy import FedAvg from examples.models.mnist_model import MnistNet -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer +from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.servers.base_server import FlServer from fl4health.utils.config import load_config @@ -47,7 +48,10 @@ def main(config: Dict[str, Any]) -> None: model = MnistNet(int(config["latent_dim"]) * 2) # To facilitate checkpointing parameter_exchanger = FullParameterExchanger() - checkpointer = BestLossTorchCheckpointer(config["checkpoint_path"], "best_model.pkl") + checkpointer = BestLossTorchModuleCheckpointer(config["checkpoint_path"], "best_model.pkl") + checkpoint_and_state_module = BaseServerCheckpointAndStateModule( + model=model, parameter_exchanger=parameter_exchanger, checkpointer=checkpointer + ) # Server performs simple FedAveraging as its server-side optimization strategy strategy = FedAvg( @@ -66,10 +70,8 @@ def main(config: Dict[str, Any]) -> None: server = FlServer( client_manager=SimpleClientManager(), fl_config=config, - parameter_exchanger=parameter_exchanger, - model=model, strategy=strategy, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.server.start_server( diff --git a/examples/ae_examples/cvae_examples/conv_cvae_example/server.py b/examples/ae_examples/cvae_examples/conv_cvae_example/server.py index 15a5e7589..f5ab44d9d 100644 --- a/examples/ae_examples/cvae_examples/conv_cvae_example/server.py +++ b/examples/ae_examples/cvae_examples/conv_cvae_example/server.py @@ -8,7 +8,8 @@ from flwr.server.strategy import FedAvg from examples.ae_examples.cvae_examples.conv_cvae_example.models import ConvConditionalDecoder, ConvConditionalEncoder -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer +from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule from fl4health.model_bases.autoencoders_base import ConditionalVae from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.servers.base_server import FlServer @@ -48,7 +49,10 @@ def main(config: Dict[str, Any]) -> None: # To facilitate checkpointing parameter_exchanger = FullParameterExchanger() - checkpointer = BestLossTorchCheckpointer(config["checkpoint_path"], model_checkpoint_name) + checkpointer = BestLossTorchModuleCheckpointer(config["checkpoint_path"], model_checkpoint_name) + checkpoint_and_state_module = BaseServerCheckpointAndStateModule( + model=model, parameter_exchanger=parameter_exchanger, checkpointer=checkpointer + ) # Server performs simple FedAveraging as its server-side optimization strategy strategy = FedAvg( @@ -67,10 +71,8 @@ def main(config: Dict[str, Any]) -> None: server = FlServer( 
client_manager=SimpleClientManager(), fl_config=config, - parameter_exchanger=parameter_exchanger, - model=model, strategy=strategy, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.server.start_server( diff --git a/examples/ae_examples/cvae_examples/mlp_cvae_example/server.py b/examples/ae_examples/cvae_examples/mlp_cvae_example/server.py index efa550404..0f3261531 100644 --- a/examples/ae_examples/cvae_examples/mlp_cvae_example/server.py +++ b/examples/ae_examples/cvae_examples/mlp_cvae_example/server.py @@ -8,7 +8,8 @@ from flwr.server.strategy import FedAvg from examples.ae_examples.cvae_examples.mlp_cvae_example.models import MnistConditionalDecoder, MnistConditionalEncoder -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer +from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule from fl4health.model_bases.autoencoders_base import ConditionalVae from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.servers.base_server import FlServer @@ -48,7 +49,10 @@ def main(config: Dict[str, Any]) -> None: # To facilitate checkpointing parameter_exchanger = FullParameterExchanger() - checkpointer = BestLossTorchCheckpointer(config["checkpoint_path"], model_checkpoint_name) + checkpointer = BestLossTorchModuleCheckpointer(config["checkpoint_path"], model_checkpoint_name) + checkpoint_and_state_module = BaseServerCheckpointAndStateModule( + model=model, parameter_exchanger=parameter_exchanger, checkpointer=checkpointer + ) # Server performs simple FedAveraging as its server-side optimization strategy strategy = FedAvg( @@ -67,10 +71,8 @@ def main(config: Dict[str, Any]) -> None: server = FlServer( client_manager=SimpleClientManager(), fl_config=config, - parameter_exchanger=parameter_exchanger, - model=model, strategy=strategy, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.server.start_server( diff --git a/examples/ae_examples/fedprox_vae_example/server.py b/examples/ae_examples/fedprox_vae_example/server.py index 17d36f3b0..31b3d559d 100644 --- a/examples/ae_examples/fedprox_vae_example/server.py +++ b/examples/ae_examples/fedprox_vae_example/server.py @@ -7,7 +7,7 @@ from flwr.server.client_manager import SimpleClientManager from examples.ae_examples.fedprox_vae_example.models import MnistVariationalDecoder, MnistVariationalEncoder -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer from fl4health.model_bases.autoencoders_base import VariationalAe from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.servers.base_server import FlServer @@ -48,7 +48,7 @@ def main(config: Dict[str, Any]) -> None: # To facilitate checkpointing parameter_exchanger = FullParameterExchanger() - checkpointer = BestLossTorchCheckpointer(config["checkpoint_path"], model_checkpoint_name) + checkpointer = BestLossTorchModuleCheckpointer(config["checkpoint_path"], model_checkpoint_name) # Server performs simple FedAveraging as its server-side optimization strategy and potentially adapts the # FedProx proximal weight mu diff --git a/examples/basic_example/server.py b/examples/basic_example/server.py index c4937759e..5395e06be 100644 --- a/examples/basic_example/server.py +++ b/examples/basic_example/server.py @@ -9,7 +9,7 
@@ from examples.models.cnn_model import Net from examples.utils.functions import make_dict_with_epochs_or_steps -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.servers.base_server import FlServer from fl4health.utils.config import load_config @@ -44,8 +44,8 @@ def main(config: Dict[str, Any]) -> None: # To facilitate checkpointing parameter_exchanger = FullParameterExchanger() checkpointers = [ - BestLossTorchCheckpointer(config["checkpoint_path"], "best_model.pkl"), - LatestTorchCheckpointer(config["checkpoint_path"], "latest_model.pkl"), + BestLossTorchModuleCheckpointer(config["checkpoint_path"], "best_model.pkl"), + LatestTorchModuleCheckpointer(config["checkpoint_path"], "latest_model.pkl"), ] # Server performs simple FedAveraging as its server-side optimization strategy diff --git a/examples/dp_fed_examples/instance_level_dp/client.py b/examples/dp_fed_examples/instance_level_dp/client.py index f66450125..05b754125 100644 --- a/examples/dp_fed_examples/instance_level_dp/client.py +++ b/examples/dp_fed_examples/instance_level_dp/client.py @@ -13,7 +13,7 @@ from torch.utils.data import DataLoader from examples.models.cnn_model import Net -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.checkpointing.opacus_checkpointer import BestLossOpacusCheckpointer from fl4health.clients.instance_level_dp_client import InstanceLevelDpClient from fl4health.utils.config import narrow_dict_type @@ -48,7 +48,7 @@ def get_criterion(self, config: Config) -> _Loss: post_aggregation_checkpointer = BestLossOpacusCheckpointer( checkpoint_dir=checkpoint_dir, checkpoint_name=checkpoint_name ) - checkpointer = ClientCheckpointModule(post_aggregation=post_aggregation_checkpointer) + checkpointer = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) # Load model and data data_path = Path(args.dataset_path) diff --git a/examples/fedopt_example/client.py b/examples/fedopt_example/client.py index a96b173c8..f29b0fd3a 100644 --- a/examples/fedopt_example/client.py +++ b/examples/fedopt_example/client.py @@ -13,7 +13,7 @@ from examples.fedopt_example.client_data import LabelEncoder, Vocabulary, construct_dataloaders from examples.fedopt_example.metrics import CompoundMetric from examples.models.lstm_model import LSTM -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.basic_client import BasicClient, TorchInputType from fl4health.utils.config import narrow_dict_type from fl4health.utils.losses import LossMeterType @@ -27,7 +27,7 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__(data_path, metrics, device, loss_meter_type, checkpointer) self.weight_matrix: torch.Tensor diff --git a/examples/fedpca_examples/dim_reduction/server.py b/examples/fedpca_examples/dim_reduction/server.py index 031ba7fc5..b11959e6d 100644 --- a/examples/fedpca_examples/dim_reduction/server.py +++ 
b/examples/fedpca_examples/dim_reduction/server.py @@ -8,7 +8,7 @@ from flwr.server.strategy import FedAvg from examples.models.mnist_model import MnistNet -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.servers.base_server import FlServer from fl4health.utils.config import load_config @@ -47,7 +47,7 @@ def main(config: Dict[str, Any]) -> None: parameter_exchanger = FullParameterExchanger() # To facilitate checkpointing - checkpointer = BestLossTorchCheckpointer(config["checkpoint_path"], "best_model.pkl") + checkpointer = BestLossTorchModuleCheckpointer(config["checkpoint_path"], "best_model.pkl") # Server performs simple FedAveraging as its server-side optimization strategy strategy = FedAvg( diff --git a/examples/fedsimclr_example/fedsimclr_finetuning_example/server.py b/examples/fedsimclr_example/fedsimclr_finetuning_example/server.py index b12c169c1..29b773a20 100644 --- a/examples/fedsimclr_example/fedsimclr_finetuning_example/server.py +++ b/examples/fedsimclr_example/fedsimclr_finetuning_example/server.py @@ -10,7 +10,7 @@ from flwr.server.strategy import FedAvg from examples.utils.functions import make_dict_with_epochs_or_steps -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer from fl4health.model_bases.fedsimclr_base import FedSimClrModel from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.servers.base_server import FlServer @@ -51,7 +51,7 @@ def main(config: Dict[str, Any]) -> None: model = load_model() # To facilitate checkpointing parameter_exchanger = FullParameterExchanger() - checkpointer = BestLossTorchCheckpointer(config["checkpoint_path"], "best_model.pkl") + checkpointer = BestLossTorchModuleCheckpointer(config["checkpoint_path"], "best_model.pkl") # Server performs simple FedAveraging as its server-side optimization strategy strategy = FedAvg( diff --git a/examples/fedsimclr_example/fedsimclr_pretraining_example/server.py b/examples/fedsimclr_example/fedsimclr_pretraining_example/server.py index 9c13486f2..2f63891b6 100644 --- a/examples/fedsimclr_example/fedsimclr_pretraining_example/server.py +++ b/examples/fedsimclr_example/fedsimclr_pretraining_example/server.py @@ -10,7 +10,7 @@ from examples.models.ssl_models import CifarSslEncoder, CifarSslPredictionHead, CifarSslProjectionHead from examples.utils.functions import make_dict_with_epochs_or_steps -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer from fl4health.model_bases.fedsimclr_base import FedSimClrModel from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.servers.base_server import FlServer @@ -50,7 +50,7 @@ def main(config: Dict[str, Any]) -> None: ) # To facilitate checkpointing parameter_exchanger = FullParameterExchanger() - checkpointer = BestLossTorchCheckpointer(config["checkpoint_path"], "best_model.pkl") + checkpointer = BestLossTorchModuleCheckpointer(config["checkpoint_path"], "best_model.pkl") # Server performs simple FedAveraging as its server-side optimization strategy strategy = FedAvg( diff --git a/examples/fenda_ditto_example/client.py b/examples/fenda_ditto_example/client.py index 
8c7da0221..475ed256d 100644 --- a/examples/fenda_ditto_example/client.py +++ b/examples/fenda_ditto_example/client.py @@ -16,8 +16,8 @@ SequentialGlobalFeatureExtractorMnist, SequentialLocalPredictionHeadMnist, ) -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.fenda_ditto_client import FendaDittoClient from fl4health.model_bases.fenda_base import FendaModel from fl4health.model_bases.parallel_split_models import ParallelFeatureJoinMode @@ -106,15 +106,15 @@ def get_criterion(self, config: Config) -> _Loss: post_aggregation_checkpointer = None if args.checkpointer_type in ["pre", "both"]: - pre_aggregation_checkpointer = BestLossTorchCheckpointer( + pre_aggregation_checkpointer = BestLossTorchModuleCheckpointer( args.checkpoint_path, "fenda_ditto_client_pre_agg.pkl" ) if args.checkpointer_type in ["post", "both"]: - post_aggregation_checkpointer = BestLossTorchCheckpointer( + post_aggregation_checkpointer = BestLossTorchModuleCheckpointer( args.checkpoint_path, "fenda_ditto_client_post_agg.pkl" ) - checkpointer = ClientCheckpointModule( + checkpointer = ClientCheckpointAndStateModule( pre_aggregation=pre_aggregation_checkpointer, post_aggregation=post_aggregation_checkpointer, ) diff --git a/examples/model_merge_example/server.py b/examples/model_merge_example/server.py index 0e25505dd..6bf00cc70 100644 --- a/examples/model_merge_example/server.py +++ b/examples/model_merge_example/server.py @@ -12,7 +12,7 @@ from torch.utils.data import DataLoader from examples.models.cnn_model import MnistNet -from fl4health.checkpointing.checkpointer import LatestTorchCheckpointer +from fl4health.checkpointing.checkpointer import LatestTorchModuleCheckpointer from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.servers.model_merge_server import ModelMergeServer from fl4health.strategies.model_merge_strategy import ModelMergeStrategy @@ -77,7 +77,7 @@ def main(config: Dict[str, Any], data_path: Path) -> None: evaluate_fn=server_side_evaluate_fn_partial, ) - checkpointer = LatestTorchCheckpointer(checkpoint_dir=config["ckpt_path"], checkpoint_name="model_merge.pt") + checkpointer = LatestTorchModuleCheckpointer(checkpoint_dir=config["ckpt_path"], checkpoint_name="model_merge.pt") server = ModelMergeServer( client_manager=SimpleClientManager(), diff --git a/examples/warm_up_example/fedavg_warm_up/client.py b/examples/warm_up_example/fedavg_warm_up/client.py index 243ef2fb1..85cf733ae 100644 --- a/examples/warm_up_example/fedavg_warm_up/client.py +++ b/examples/warm_up_example/fedavg_warm_up/client.py @@ -13,8 +13,8 @@ from torch.utils.data import DataLoader from examples.models.cnn_model import MnistNet -from fl4health.checkpointing.checkpointer import LatestTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import LatestTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.basic_client import BasicClient from fl4health.utils.config import narrow_dict_type from fl4health.utils.load_data import load_mnist_data @@ -39,8 +39,8 @@ def __init__( # Checkpointing is crucial for the warm up process checkpoint_name = 
f"client_{self.client_name}_latest_model.pkl" - post_aggregation_checkpointer = LatestTorchCheckpointer(checkpoint_dir, checkpoint_name) - self.checkpointer = ClientCheckpointModule(post_aggregation=post_aggregation_checkpointer) + post_aggregation_checkpointer = LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) + self.checkpointer = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=1) diff --git a/fl4health/checkpointing/checkpointer.py b/fl4health/checkpointing/checkpointer.py index 329223a1f..79af6d7bf 100644 --- a/fl4health/checkpointing/checkpointer.py +++ b/fl4health/checkpointing/checkpointer.py @@ -1,8 +1,8 @@ import os from abc import ABC, abstractmethod -from logging import ERROR, INFO +from logging import ERROR, INFO, WARNING from pathlib import Path -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, overload import torch import torch.nn as nn @@ -12,7 +12,7 @@ CheckpointScoreFunctionType = Callable[[float, Dict[str, Scalar]], float] -class TorchCheckpointer(ABC): +class TorchModuleCheckpointer(ABC): def __init__(self, checkpoint_dir: str, checkpoint_name: str) -> None: """ Basic abstract base class to handle checkpointing pytorch models. Models are saved with torch.save by default @@ -22,7 +22,7 @@ def __init__(self, checkpoint_dir: str, checkpoint_name: str) -> None: checkpointer will not create it if it does not. checkpoint_name (str): Name of the checkpoint to be saved. """ - self.best_checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name) + self.checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name) @abstractmethod def maybe_checkpoint(self, model: nn.Module, loss: float, metrics: Dict[str, Scalar]) -> None: @@ -38,13 +38,18 @@ def maybe_checkpoint(self, model: nn.Module, loss: float, metrics: Dict[str, Sca Raises: NotImplementedError: Must be implemented by the checkpointer """ - raise NotImplementedError + raise NotImplementedError("maybe_checkpoint must be implemented by inheriting classes") + + @overload + def load_checkpoint(self) -> nn.Module: + return torch.load(self.checkpoint_path) - def load_best_checkpoint(self) -> nn.Module: - return torch.load(self.best_checkpoint_path) + @overload + def load_checkpoint(self, path_to_checkpoint: str) -> nn.Module: + return torch.load(path_to_checkpoint) -class FunctionTorchCheckpointer(TorchCheckpointer): +class FunctionTorchModuleCheckpointer(TorchModuleCheckpointer): def __init__( self, checkpoint_dir: str, @@ -61,7 +66,7 @@ def __init__( checkpoint_dir (str): Directory to which the model is saved. This directory should already exist. The checkpointer will not create it if it does not. checkpoint_name (str): Name of the checkpoint to be saved. - checkpoint_score_function (CheckpointFunctionType): Function taking in a loss value and dictionary of + checkpoint_score_function (CheckpointScoreFunctionType): Function taking in a loss value and dictionary of metrics and produces a score based on these. maximize (bool, optional): Specifies whether we're trying to minimize or maximize the score produced by the scoring function. Defaults to False. 
@@ -74,18 +79,42 @@ def __init__( self.comparison_str = ">=" if self.maximize else "<=" def _should_checkpoint(self, comparison_score: float) -> bool: - # Compares the current score to the best previously recorded, returns true if should checkpoint and false - # otherwise + """ + Compares the current score to the best previously recorded, returns true if should checkpoint and false + otherwise. If the previous best score is None, then we always checkpoint. + + Args: + comparison_score (float): Score that is being maximized or minimized. Will be compared against the previous + best score seen by this checkpointer. + + Returns: + bool: Whether or not to checkpoint the model based on the provided score + """ + if self.best_score: if self.maximize: return self.best_score <= comparison_score - else: - return self.best_score >= comparison_score + return self.best_score >= comparison_score # If best score is none, then this is the first checkpoint return True def maybe_checkpoint(self, model: nn.Module, loss: float, metrics: Dict[str, Scalar]) -> None: + """ + Given the loss/metrics associated with the provided model, the checkpointer uses the scoring function to + produce a score. This score will then be used to determine whether the model should be checkpointed or not. + + Args: + model (nn.Module): Model that might be persisted if the scoring function determines it should be + loss (float): Loss associated with the provided model. Will potentially contribute to checkpointing + decision, based on the score function. + metrics (Dict[str, Scalar]): Metrics associated with the provided model. Will potentially contribute to + the checkpointing decision, based on the score function. + + Raises: + e: Will throw an error if there is an issue saving the model. Torch.save seems to swallow errors in this + context, so we explicitly surface the error with a try/except. + """ # First we use the provided scoring function to produce a score comparison_score = self.checkpoint_score_function(loss, metrics) if self._should_checkpoint(comparison_score): @@ -96,8 +125,8 @@ def maybe_checkpoint(self, model: nn.Module, loss: float, metrics: Dict[str, Sca ) self.best_score = comparison_score try: - log(INFO, f"Saving checkpoint as {str(self.best_checkpoint_path)}") - torch.save(model, self.best_checkpoint_path) + log(INFO, f"Saving checkpoint as {str(self.checkpoint_path)}") + torch.save(model, self.checkpoint_path) except Exception as e: log(ERROR, f"Encountered the following error while saving the checkpoint: {e}") raise e @@ -109,26 +138,50 @@ def maybe_checkpoint(self, model: nn.Module, loss: float, metrics: Dict[str, Sca ) -class LatestTorchCheckpointer(FunctionTorchCheckpointer): +class LatestTorchModuleCheckpointer(FunctionTorchModuleCheckpointer): def __init__(self, checkpoint_dir: str, checkpoint_name: str) -> None: + """ + A checkpointer that always checkpoints the model, regardless of the loss/metrics provided. As such, the score + function is essentially a dummy. + + Args: + checkpoint_dir (str): Directory to which the model is saved. This directory should already exist. The + checkpointer will not create it if it does not. + checkpoint_name (str): Name of the checkpoint to be saved. 
+ """ + # This function is required by the parent class, but not used in the LatestTorchCheckpointer def null_score_function(loss: float, _: Dict[str, Scalar]) -> float: return 0.0 super().__init__(checkpoint_dir, checkpoint_name, null_score_function, False) - def maybe_checkpoint(self, model: nn.Module, loss: float, metrics: Dict[str, Scalar]) -> None: + def maybe_checkpoint(self, model: nn.Module, loss: float, _: Dict[str, Scalar]) -> None: + """ + This function is essentially a pass through, as this class always checkpoints the provided model + + Args: + model (nn.Module): Model to be checkpointed whenever this function is called + loss (float): Loss associated with the provided model. Will potentially contribute to checkpointing + decision, based on the score function. NOT USED. + metrics (Dict[str, Scalar]): Metrics associated with the provided model. Will potentially contribute to + the checkpointing decision, based on the score function. NOT USED. + + Raises: + e: Will throw an error if there is an issue saving the model. Torch.save seems to swallow errors in this + context, so we explicitly surface the error with a try/except. + """ # Always checkpoint the latest model log(INFO, "Saving latest checkpoint with LatestTorchCheckpointer") try: - log(INFO, f"Saving checkpoint as {str(self.best_checkpoint_path)}") - torch.save(model, self.best_checkpoint_path) + log(INFO, f"Saving checkpoint as {str(self.checkpoint_path)}") + torch.save(model, self.checkpoint_path) except Exception as e: log(ERROR, f"Encountered the following error while saving the checkpoint: {e}") raise e -class BestLossTorchCheckpointer(FunctionTorchCheckpointer): +class BestLossTorchModuleCheckpointer(FunctionTorchModuleCheckpointer): def __init__(self, checkpoint_dir: str, checkpoint_name: str) -> None: """ This checkpointer only uses the loss value provided to the maybe_checkpoint function to determine whether a @@ -150,6 +203,21 @@ def loss_score_function(loss: float, _: Dict[str, Scalar]) -> float: ) def maybe_checkpoint(self, model: nn.Module, loss: float, metrics: Dict[str, Scalar]) -> None: + """ + This function will decide whether to checkpoint the provided model based on the loss argument. If the provided + loss is better than any previous losses seen by this checkpointer, the model will be saved. + + Args: + model (nn.Module): Model that might be persisted if the scoring function determines it should be + loss (float): Loss associated with the provided model. This value is used to determine whether to save the + model or not. + metrics (Dict[str, Scalar]): Metrics associated with the provided model. Will not be used by this + checkpointer. + + Raises: + e: Will throw an error if there is an issue saving the model. Torch.save seems to swallow errors in this + context, so we explicitly surface the error with a try/except. 
+ """ # First we use the provided scoring function to produce a score comparison_score = self.checkpoint_score_function(loss, metrics) if self._should_checkpoint(comparison_score): @@ -160,8 +228,8 @@ def maybe_checkpoint(self, model: nn.Module, loss: float, metrics: Dict[str, Sca ) self.best_score = comparison_score try: - log(INFO, f"Saving checkpoint as {str(self.best_checkpoint_path)}") - torch.save(model, self.best_checkpoint_path) + log(INFO, f"Saving checkpoint as {str(self.checkpoint_path)}") + torch.save(model, self.checkpoint_path) except Exception as e: log(ERROR, f"Encountered the following error while saving the checkpoint: {e}") raise e @@ -173,49 +241,69 @@ def maybe_checkpoint(self, model: nn.Module, loss: float, metrics: Dict[str, Sca ) -class PerRoundCheckpointer(ABC): - def __init__(self, checkpoint_dir: Path, checkpoint_name: Path) -> None: +class PerRoundStateCheckpointer: + def __init__(self, checkpoint_dir: Path) -> None: """ - Abstract Base Class that provides a uniform interface for loading, saving and checking - if checkpoints exists. + Base class that provides a uniform interface for loading, saving and checking if checkpoints exists. Args: - checkpoint_dir (Path): Base directory to store checkpoints. - checkpoint_name (Path): The file name in which to save the checkpoint. + checkpoint_dir (Path): Base directory to store checkpoints. This checkpoint directory MUST already exist. + It will not be created by this state checkpointer. """ - self.checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name) + log( + WARNING, + "Creating PerRoundCheckpointer. Currently, this functionality is still experimental and only supported " + "for BasicClient and NnunetClient, along with their associated servers.", + ) + self.checkpoint_dir = checkpoint_dir - def save_checkpoint(self, checkpoint_dict: Dict[str, Any]) -> None: + def save_checkpoint(self, checkpoint_name: str, checkpoint_dict: Dict[str, Any]) -> None: """ - Saves checkpoint_dict to checkpoint path. + Saves checkpoint_dict to checkpoint path form from this classes checkpointer dir and the provided checkpoint + name. + Args: - checkpoint_dict (Dict[str, Any]): A dictionary with string keys and values of type - Any representing the state to checkpoint. + checkpoint_name (str): Name of the state checkpoint file. + checkpoint_dict (Dict[str, Any]): A dictionary with string keys and values of type Any representing the + state to checkpoint. + + Raises: + e: Will throw an error if there is an issue saving the model. Torch.save seems to swallow errors in this + context, so we explicitly surface the error with a try/except. """ + + checkpoint_path = os.path.join(self.checkpoint_dir, checkpoint_name) try: - log(INFO, f"Saving checkpoint as {self.checkpoint_path}") - torch.save(checkpoint_dict, self.checkpoint_path) + log(INFO, f"Saving state as {checkpoint_path}") + torch.save(checkpoint_dict, checkpoint_path) except Exception as e: log(ERROR, f"Encountered the following error while saving the checkpoint: {e}") raise e - def load_checkpoint(self) -> Dict[str, Any]: + def load_checkpoint(self, checkpoint_name: str) -> Dict[str, Any]: """ - Loads and returns the most recent checkpoint if it exists. + Loads and returns the checkpoint stored in checkpoint_dir under the provided name if it exists. + If it doesn't exist, an assertion error will be thrown. + + Args: + checkpoint_name (str): Name of the state checkpoint to be loaded. Returns: - Dict[str, Any] A dictionary representing the checkpointed state. 
+ Dict[str, Any]: A dictionary representing the checkpointed state, as loaded by torch.load. """ - assert self.checkpoint_exists() - return torch.load(self.checkpoint_path) + checkpoint_path = os.path.join(self.checkpoint_dir, checkpoint_name) + log(INFO, f"Loading state from checkpoint at {checkpoint_path}") + assert self.checkpoint_exists(checkpoint_path) + + return torch.load(checkpoint_path) - def checkpoint_exists(self) -> bool: + def checkpoint_exists(self, checkpoint_path: str) -> bool: """ - Checks if a checkpoint exists at the checkpoint_path constructed at initialization. + Checks if a checkpoint exists at the provided checkpoint_path. Returns: bool: Whether or not a checkpoint exists. """ - return os.path.exists(self.checkpoint_path) + return os.path.exists(checkpoint_path) diff --git a/fl4health/checkpointing/client_module.py b/fl4health/checkpointing/client_module.py index 82a88dee3..aa22872dd 100644 --- a/fl4health/checkpointing/client_module.py +++ b/fl4health/checkpointing/client_module.py @@ -1,14 +1,14 @@ from enum import Enum from logging import INFO -from typing import Dict, Optional, Sequence, Union +from typing import Any, Dict, Sequence, Union import torch.nn as nn from flwr.common.logger import log from flwr.common.typing import Scalar -from fl4health.checkpointing.checkpointer import TorchCheckpointer +from fl4health.checkpointing.checkpointer import PerRoundStateCheckpointer, TorchModuleCheckpointer -CheckpointModuleInput = Optional[Union[TorchCheckpointer, Sequence[TorchCheckpointer]]] +CheckpointModuleInput = Union[TorchModuleCheckpointer, Sequence[TorchModuleCheckpointer]] | None class CheckpointMode(Enum): @@ -16,16 +16,22 @@ class CheckpointMode(Enum): POST_AGGREGATION = "post_aggregation" -class ClientCheckpointModule: +class ClientCheckpointAndStateModule: def __init__( - self, pre_aggregation: CheckpointModuleInput = None, post_aggregation: CheckpointModuleInput = None + self, + pre_aggregation: CheckpointModuleInput = None, + post_aggregation: CheckpointModuleInput = None, + state_checkpointer: PerRoundStateCheckpointer | None = None, ) -> None: """ - This module is meant to hold up to two distinct client-side checkpointers. - The first checkpointer, if defined, is used to checkpoint local models BEFORE server-side aggregation. - **NOTE**: This is akin to "further fine-tuning" approaches for global models. - The second checkpointer, if defined, is used to checkpoint local models AFTER server-side aggregation - **NOTE**: This is the "traditional" mechanism for global models. + This module is meant to hold up to three major components that determine how clients handle model and state + checkpointing, where state checkpointing is meant to allow clients to restart if FL training is interrupted. + For model checkpointing, there are two distinct types. + The first type, if defined, is used to checkpoint local models BEFORE server-side aggregation, but + after local training. **NOTE**: This is akin to "further fine-tuning" approaches for global models. + + The second type, if defined, is used to checkpoint local models AFTER server-side aggregation, but + before local training. **NOTE**: This is the "traditional" mechanism for global models. As a final note, for some methods, such as Ditto or MR-MTL, these checkpoints will actually be identical. That's because the target model for these methods is never globally aggregated.
That is, they remain local @@ -34,31 +40,37 @@ def __init__( pre_aggregation (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of checkpointers) is used to checkpoint models based on their validation metrics/losses **BEFORE** server-side aggregation. Defaults to None. - post_aggregation (CheckpointModuleInput, optional], optional): If defined, this checkpointer (or sequence + post_aggregation (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of checkpointers) is used to checkpoint models based on their validation metrics/losses **AFTER** server-side aggregation. Defaults to None. + state_checkpointer (PerRoundStateCheckpointer | None, optional): If defined, this checkpointer is used to + preserve client state (not just models), in the event one wants to restart federated training. + Defaults to None. """ - self.pre_aggregation = [pre_aggregation] if isinstance(pre_aggregation, TorchCheckpointer) else pre_aggregation + self.pre_aggregation = ( + [pre_aggregation] if isinstance(pre_aggregation, TorchModuleCheckpointer) else pre_aggregation + ) self.post_aggregation = ( - [post_aggregation] if isinstance(post_aggregation, TorchCheckpointer) else post_aggregation + [post_aggregation] if isinstance(post_aggregation, TorchModuleCheckpointer) else post_aggregation ) self._check_if_shared_checkpoint_names() + self.state_checkpointer = state_checkpointer def _check_if_shared_checkpoint_names(self) -> None: """ - This function is meant to throw an exception if there is an overlap in the paths to which their checkpointers - will save checkpoints to avoid accidental overwriting. + This function checks whether there is overlap in the paths to which the checkpointers of this module are + supposed to write. This is to ensure that there isn't any accidental overwriting of checkpoints that is + unintended. + + Raises: + ValueError: If any of the pre- or post-aggregation model checkpointer paths are not unique. """ pre_aggregation_paths = ( - [checkpointer.best_checkpoint_path for checkpointer in self.pre_aggregation] - if self.pre_aggregation - else [] + [checkpointer.checkpoint_path for checkpointer in self.pre_aggregation] if self.pre_aggregation else [] ) post_aggregation_paths = ( - [checkpointer.best_checkpoint_path for checkpointer in self.post_aggregation] - if self.post_aggregation - else [] + [checkpointer.checkpoint_path for checkpointer in self.post_aggregation] if self.post_aggregation else [] ) all_paths = pre_aggregation_paths + post_aggregation_paths @@ -75,24 +87,21 @@ def maybe_checkpoint( self, model: nn.Module, loss: float, metrics: Dict[str, Scalar], mode: CheckpointMode ) -> None: """ - If checkpointer or checkpoints indicated by the checkpoint mode exists, maybe checkpoint model based on the - model metrics or loss - - Args: - loss (float): The metric value obtained by the current model. - Used by checkpointer to decide whether to checkpoint the model. - mode (CheckpointMode): Determines which of the checkpointers to use. + Performs model checkpointing for a particular mode (either pre- or post-aggregation) if any checkpointers are + provided for that particular mode in this module. If present, the various checkpointers will decide whether + or not to checkpoint based on their internal criterion and the loss/metrics provided. Args: model (nn.Module): The model that might be checkpointed by the checkpointers. - loss (float): The loss value obtained by the current model. 
Potentially used by checkpointer to decide + loss (float): The loss value obtained by the provided model. Used by the checkpointer(s) to decide whether to checkpoint the model. - metrics (Dict[str, Scalar]): The metrics obtained by the current model. Potentially used by checkpointer + metrics (Dict[str, Scalar]): The metrics obtained by the provided model. Potentially used by checkpointer to decide whether to checkpoint the model. - mode (CheckpointMode): Determines which of the checkpointers to use. + mode (CheckpointMode): Determines which of the types of checkpointers to use. Currently, the only modes + available are pre- and post-aggregation. Raises: - ValueError: Thrown if the model provided is not recognized. + ValueError: Thrown if the model checkpointing mode is not recognized. """ if mode == CheckpointMode.PRE_AGGREGATION: if self.pre_aggregation is not None: @@ -108,3 +117,52 @@ def maybe_checkpoint( log(INFO, "No Post-aggregation checkpoint specified. Skipping.") else: raise ValueError(f"Unrecognized mode for checkpointing: {str(mode)}") + + def save_state(self, state_checkpoint_name: str, state: Dict[str, Any]) -> None: + """ + This function is meant to facilitate saving state required to restart an FL process on the client side. This + function will simply save whatever information is passed in the state variable using the file name in + state_checkpoint_name. This function should only be called if a state_checkpointer exists in this module. + + Args: + state_checkpoint_name (str): Name of the state checkpoint file. The checkpointer itself will have a + directory to which state will be saved. + state (Dict[str, Any]): State to be saved so that training might be resumed on the client if federated + training is interrupted. For example, this might contain things like optimizer states, learning rate + scheduler states, etc. + + Raises: + ValueError: Throws an error if this function is called, but no state checkpointer has been provided + """ + + if self.state_checkpointer is not None: + self.state_checkpointer.save_checkpoint(state_checkpoint_name, state) + else: + raise ValueError("Attempting to save state but no state checkpointer is specified") + + def maybe_load_state(self, state_checkpoint_name: str) -> Dict[str, Any] | None: + """ + This function facilitates loading of any pre-existing state (with the name state_checkpoint_name) in the + directory of the state_checkpointer. If the state already exists at the proper path, the state is loaded + and returned. If it doesn't exist, we return None. + + Args: + state_checkpoint_name (str): Name of the state checkpoint file. The checkpointer itself will have a + directory from which state will be loaded (if it exists). + + Raises: + ValueError: Throws an error if this function is called, but no state checkpointer has been provided + + Returns: + Dict[str, Any] | None: If the state checkpoint properly exists and is loaded correctly, this dictionary + carries that state. Otherwise, we return None (or throw an exception).
+ """ + + if self.state_checkpointer is not None: + if self.state_checkpointer.checkpoint_exists(state_checkpoint_name): + return self.state_checkpointer.load_checkpoint(state_checkpoint_name) + else: + log(INFO, "State checkpointer is defined but no state checkpoint exists.") + return None + else: + raise ValueError("Attempting to load state, but no state checkpointer is specified") diff --git a/fl4health/checkpointing/opacus_checkpointer.py b/fl4health/checkpointing/opacus_checkpointer.py index 7669f4bab..907fe9650 100644 --- a/fl4health/checkpointing/opacus_checkpointer.py +++ b/fl4health/checkpointing/opacus_checkpointer.py @@ -7,10 +7,10 @@ from flwr.common.typing import Scalar from opacus import GradSampleModule -from fl4health.checkpointing.checkpointer import FunctionTorchCheckpointer +from fl4health.checkpointing.checkpointer import FunctionTorchModuleCheckpointer -class OpacusCheckpointer(FunctionTorchCheckpointer): +class OpacusCheckpointer(FunctionTorchModuleCheckpointer): """ This is a specific type of checkpointer to be used in saving models trained using Opacus for differential privacy. Certain layers within Opacus wrapped models do not interact well with torch.save functionality. This checkpointer @@ -71,10 +71,10 @@ def _extract_and_save_state(self, model: nn.Module) -> None: model (nn.Module): Model to be checkpointed via the state dictionary. """ model_state_dict = model.state_dict() - with open(self.best_checkpoint_path, "wb") as handle: + with open(self.checkpoint_path, "wb") as handle: pickle.dump(model_state_dict, handle, protocol=pickle.HIGHEST_PROTOCOL) - def load_best_checkpoint(self) -> nn.Module: + def load_checkpoint(self) -> nn.Module: raise NotImplementedError( "When loading from Opacus checkpointers, you need to provide a model into which state is loaded. " "Please use load_best_checkpoint_into_model instead and provide model architecture to load state into." @@ -92,7 +92,7 @@ def load_best_checkpoint_into_model( target_is_grad_sample_module (bool, optional): Whether the target_model that the state_dict is being loaded into is an Opacus module or just a vanilla Pytorch module. Defaults to False. """ - with open(self.best_checkpoint_path, "rb") as handle: + with open(self.checkpoint_path, "rb") as handle: model_state_dict = pickle.load(handle) # If the target is just a plain PyTorch module, we remove the _module key prefix that Opacus inserts into # its GradSampleModules. 
diff --git a/fl4health/checkpointing/server_module.py b/fl4health/checkpointing/server_module.py index e69de29bb..90a169065 100644 --- a/fl4health/checkpointing/server_module.py +++ b/fl4health/checkpointing/server_module.py @@ -0,0 +1,287 @@ +from logging import INFO +from typing import Any, Dict, Sequence, Union + +import torch.nn as nn +from flwr.common import Parameters +from flwr.common.logger import log +from flwr.common.parameter import parameters_to_ndarrays +from flwr.common.typing import Scalar + +from fl4health.checkpointing.checkpointer import PerRoundStateCheckpointer, TorchModuleCheckpointer +from fl4health.parameter_exchange.packing_exchanger import FullParameterExchangerWithPacking +from fl4health.parameter_exchange.parameter_packer import ParameterPackerAdaptiveConstraint +from fl4health.utils.typing import ExchangerType + +CheckpointModuleInput = Union[TorchModuleCheckpointer, Sequence[TorchModuleCheckpointer]] | None + + +class BaseServerCheckpointAndStateModule: + def __init__( + self, + model: nn.Module | None = None, + parameter_exchanger: ExchangerType | None = None, + model_checkpointers: CheckpointModuleInput = None, + state_checkpointer: PerRoundStateCheckpointer | None = None, + ) -> None: + """ + This module is meant to handle basic model and state checkpointing on the server-side of an FL process. Unlike + the module on the client side, this module has no concept of pre- or post-aggregation checkpointing. It only + considers checkpointing the global server model after aggregation, perhaps based on validation statistics + retrieved on the client side by running a federated evaluation step. Multiple model checkpointers may be + used. For state checkpointing, which saves the state of the entire server-side FL process to help with + FL restarts, we allow only a single checkpointer responsible for saving the state after each fit and eval + round of FL. + + Args: + model (nn.Module | None, optional): Model architecture to be saved. The module will use this architecture + to hold the server parameters and facilitate checkpointing with the help of the parameter exchanger. + Recall that servers only have parameters rather than torch models. So we need to know where to route + these parameters to allow for real models to be saved. Defaults to None. + parameter_exchanger (FullParameterExchangerWithPacking | None, optional): This will facilitate routing the + server parameters into the right components of the provided model architecture. Note that this + exchanger and the model must match the one used for training and exchange with the servers to ensure + parameters go to the right places. Defaults to None. + model_checkpointers (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of + checkpointers) is used to checkpoint models based on their defined scoring function. Defaults to None. + state_checkpointer (PerRoundStateCheckpointer | None, optional): If defined, this checkpointer will be + used to preserve FL training state to facilitate restarting training if interrupted. Generally, this + checkpointer will save much more than just the model being trained. Defaults to None. 
+ """ + self.model = model + self.parameter_exchanger = parameter_exchanger + self.model_checkpointers = ( + [model_checkpointers] if isinstance(model_checkpointers, TorchModuleCheckpointer) else model_checkpointers + ) + self.state_checkpointer = state_checkpointer + if self.model_checkpointers is not None and len(self.model_checkpointers): + # If there are model checkpointers, make sure the the model and parameter exchanger are defined. + self._validate_model_checkpointer_components() + self._check_if_shared_checkpoint_names() + + def _validate_model_checkpointer_components(self) -> None: + assert self.model is not None, ( + "Checkpointer(s) is (are) defined but no model is defined to hydrate. The functionality of " + "this class can be overridden in a child class if checkpointing without a parameter exchanger is " + "possible and desired" + ) + assert self.parameter_exchanger is not None, ( + "Checkpointer(s) is (are) defined but no parameter_exchanger is defined to hydrate. The functionality of " + "this class can be overridden in a child class if checkpointing without a parameter exchanger is " + "possible and desired" + ) + + def _check_if_shared_checkpoint_names(self) -> None: + """ + This function is meant to throw an exception if there is an overlap in the paths to which model checkpointers + will save model checkpoints to avoid accidental overwriting. + """ + + checkpointer_paths = ( + [checkpointer.checkpoint_path for checkpointer in self.model_checkpointers] + if self.model_checkpointers + else [] + ) + unique_paths = set(checkpointer_paths) + + if len(unique_paths) != len(checkpointer_paths): + formatted_all_paths = "\n".join(checkpointer_paths) + raise ValueError( + "The paths of all of your checkpointers should be unique otherwise overwrites are possible and data " + f"will be lost. The current paths are:\n{formatted_all_paths}" + ) + + def maybe_checkpoint(self, server_parameters: Parameters, loss: float, metrics: Dict[str, Scalar]) -> None: + """ + If there are model checkpointers defined in this class, we hydrate the model for checkpointing with the server + parameters and call maybe checkpoint model on each of the checkpointers to decide whether to checkpoint based + on the model metrics or loss and the checkpointer definitions. + + Args: + server_parameters (Parameters): Parameters held by the server that should be injected into the model + loss (float): The aggregated loss value obtained by the current aggregated server model. + Potentially used by checkpointer to decide whether to checkpoint the model. + metrics (Dict[str, Scalar]): The aggregated metrics obtained by the aggregated server model. Potentially + used by checkpointer to decide whether to checkpoint the model. + """ + if self.model_checkpointers is not None and len(self.model_checkpointers) > 0: + self._hydrate_model_for_checkpointing(server_parameters) + for checkpointer in self.model_checkpointers: + assert self.model is not None + checkpointer.maybe_checkpoint(self.model, loss, metrics) + else: + log(INFO, "No model checkpointers specified. Skipping any checkpointing.") + + def _hydrate_model_for_checkpointing(self, server_parameters: Parameters) -> None: + """ + This function is used as a means of saving the server-side model after aggregation in the FL training + trajectory. Presently, the server only holds Flower Parameters, which are essentially just ndarrays. Without + knowledge of a model architecture to which the arrays correspond. 
Thus, in the default implementation, we + require that a torch architecture and a parameter exchanger be provided which handle mapping these numpy + arrays into the architecture properly. + + This function may be overridden in a child class if different behavior is desired. + + NOTE: This function stores the weights directly in the self.model attribute. + Args: + server_parameters (Parameters): Parameters to be injected into the torch model architecture and + checkpointed. + """ + assert self.model is not None, "Hydrate model for checkpoint called but self.model is None" + assert ( + self.parameter_exchanger is not None + ), "Hydrate model for checkpoint called but self.parameter_exchanger is None" + model_ndarrays = parameters_to_ndarrays(server_parameters) + self.parameter_exchanger.pull_parameters(model_ndarrays, self.model) + + def save_state( + self, state_checkpoint_name: str, server_parameters: Parameters, other_state: Dict[str, Any] + ) -> None: + """ + This function is meant to facilitate saving state required to restart an FL process on the server side. By + default, this function will always at least preserve the model being trained. However, it may be desirable to + save additional information, like the current server round, etc. So the other_state dictionary may be provided + to preserve this additional state. + + NOTE: This function will throw an error if you attempt to save the model under the 'model' key in other_state. + + Args: + state_checkpoint_name (str): Name of the state checkpoint file. The checkpointer itself will have a + directory to which state will be saved. + server_parameters (Parameters): Like model checkpointers, these are the aggregated Parameters stored by + the server representing model state. They are mapped to a torch model architecture via the + _hydrate_model_for_checkpointing function. + other_state (Dict[str, Any]): Any additional state (such as current server round) to be checkpointed in + order to allow FL to restart from where it left off. + + Raises: + ValueError: Throws an error if other_state already has a key called 'model' + ValueError: Throws an error if this function is called, but no state checkpointer has been provided + """ + if self.state_checkpointer is not None: + self._hydrate_model_for_checkpointing(server_parameters) + if "model" not in other_state: + other_state["model"] = self.model + else: + raise ValueError("Key 'model' already exists in the other_state dictionary.") + self.state_checkpointer.save_checkpoint(state_checkpoint_name, other_state) + else: + raise ValueError("Attempting to save state but no state checkpointer is specified") + + def maybe_load_state(self, state_checkpoint_name: str) -> Dict[str, Any] | None: + """ + This function facilitates loading of any pre-existing state (with the name state_checkpoint_name) in the + directory of the state_checkpointer. If the state already exists at the proper path, the state is loaded + and returned. If it doesn't exist, we return None. + + Args: + state_checkpoint_name (str): Name of the state checkpoint file. The checkpointer itself will have a + directory from which state will be loaded (if it exists). + + Raises: + ValueError: Throws an error if this function is called, but no state checkpointer has been provided + + Returns: + Dict[str, Any] | None: If the state checkpoint properly exists and is loaded correctly, this dictionary + carries that state. Otherwise, we return None (or throw an exception).
+ """ + if self.state_checkpointer is not None: + if self.state_checkpointer.checkpoint_exists(state_checkpoint_name): + return self.state_checkpointer.load_checkpoint(state_checkpoint_name) + else: + log(INFO, "State checkpointer is defined but no state checkpoint exists.") + return None + else: + raise ValueError("Attempting to load state, but no state checkpointer is specified") + + +class ScaffoldServerCheckpointAndStateModule(BaseServerCheckpointAndStateModule): + def __init__( + self, + model: nn.Module | None = None, + parameter_exchanger: FullParameterExchangerWithPacking | None = None, + model_checkpointers: CheckpointModuleInput = None, + state_checkpointer: PerRoundStateCheckpointer | None = None, + ) -> None: + """ + This module is meant to handle SCAFFOLD model and state checkpointing on the server-side of an FL process. + Unlike the module on the client side, this module has no concept of pre- or post-aggregation checkpointing. + It only considers checkpointing the global server model after aggregation, perhaps based on validation + statistics retrieved on the client side by running a federated evaluation step. Multiple model checkpointers + may be used. For state checkpointing, which saves the state of the entire server-side FL process to help with + FL restarts, we allow only a single checkpointer responsible for saving the state after each fit and eval + round of FL. + + Args: + model (nn.Module | None, optional): Model architecture to be saved. The module will use this architecture + to hold the server parameters and facilitate checkpointing with the help of the parameter exchanger. + Recall that servers only have parameters rather than torch models. So we need to know where to route + these parameters to allow for real models to be saved. Defaults to None. + parameter_exchanger (FullParameterExchangerWithPacking | None, optional): This will facilitate routing the + server parameters into the right components of the provided model architecture. Note that this + exchanger and the model must match the one used for training and exchange with the servers to ensure + parameters go to the right places. Defaults to None. + model_checkpointers (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of + checkpointers) is used to checkpoint models based on their defined scoring function. Defaults to None. + state_checkpointer (PerRoundStateCheckpointer | None, optional): If defined, this checkpointer will be + used to preserve FL training state to facilitate restarting training if interrupted. Generally, this + checkpointer will save much more than just the model being trained. Defaults to None. + """ + super().__init__(model, parameter_exchanger, model_checkpointers, state_checkpointer) + + def _hydrate_model_for_checkpointing(self, server_parameters: Parameters): + assert self.model is not None, "Hydrate model for checkpoint called but self.model is None" + assert ( + self.parameter_exchanger is not None + ), "Hydrate model for checkpoint called but self.parameter_exchanger is None" + packed_parameters = parameters_to_ndarrays(server_parameters) + # Don't need the control variates for checkpointing. 
+
+
+class ScaffoldServerCheckpointAndStateModule(BaseServerCheckpointAndStateModule):
+    def __init__(
+        self,
+        model: nn.Module | None = None,
+        parameter_exchanger: FullParameterExchangerWithPacking | None = None,
+        model_checkpointers: CheckpointModuleInput = None,
+        state_checkpointer: PerRoundStateCheckpointer | None = None,
+    ) -> None:
+        """
+        This module is meant to handle SCAFFOLD model and state checkpointing on the server-side of an FL process.
+        Unlike the module on the client side, this module has no concept of pre- or post-aggregation checkpointing.
+        It only considers checkpointing the global server model after aggregation, perhaps based on validation
+        statistics retrieved on the client side by running a federated evaluation step. Multiple model checkpointers
+        may be used. For state checkpointing, which saves the state of the entire server-side FL process to help with
+        FL restarts, we allow only a single checkpointer responsible for saving the state after each fit and eval
+        round of FL.
+
+        Args:
+            model (nn.Module | None, optional): Model architecture to be saved. The module will use this architecture
+                to hold the server parameters and facilitate checkpointing with the help of the parameter exchanger.
+                Recall that servers only have parameters rather than torch models. So we need to know where to route
+                these parameters to allow for real models to be saved. Defaults to None.
+            parameter_exchanger (FullParameterExchangerWithPacking | None, optional): This will facilitate routing the
+                server parameters into the right components of the provided model architecture. Note that this
+                exchanger and the model must match the one used for training and exchange with the server to ensure
+                parameters go to the right places. Defaults to None.
+            model_checkpointers (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of
+                checkpointers) is used to checkpoint models based on their defined scoring function. Defaults to None.
+            state_checkpointer (PerRoundStateCheckpointer | None, optional): If defined, this checkpointer will be
+                used to preserve FL training state to facilitate restarting training if interrupted. Generally, this
+                checkpointer will save much more than just the model being trained. Defaults to None.
+        """
+        super().__init__(model, parameter_exchanger, model_checkpointers, state_checkpointer)
+
+    def _hydrate_model_for_checkpointing(self, server_parameters: Parameters) -> None:
+        assert self.model is not None, "Hydrate model for checkpoint called but self.model is None"
+        assert (
+            self.parameter_exchanger is not None
+        ), "Hydrate model for checkpoint called but self.parameter_exchanger is None"
+        packed_parameters = parameters_to_ndarrays(server_parameters)
+        # Don't need the control variates for checkpointing.
+        assert isinstance(self.parameter_exchanger, FullParameterExchangerWithPacking)
+        model_ndarrays, _ = self.parameter_exchanger.unpack_parameters(packed_parameters)
+        self.parameter_exchanger.pull_parameters(model_ndarrays, self.model)
+
+
+class AdaptiveConstraintServerCheckpointAndStateModule(BaseServerCheckpointAndStateModule):
+    def __init__(
+        self,
+        model: nn.Module | None = None,
+        parameter_exchanger: FullParameterExchangerWithPacking | None = None,
+        model_checkpointers: CheckpointModuleInput = None,
+        state_checkpointer: PerRoundStateCheckpointer | None = None,
+    ) -> None:
+        """
+        This module is meant to handle FL flows with adaptive constraints, where the server and client communicate
+        a loss weight parameter in addition to the model weights. Unlike the module on the client side, this module
+        has no concept of pre- or post-aggregation checkpointing. It only considers checkpointing the global server
+        model after aggregation, perhaps based on validation statistics retrieved on the client side by running a
+        federated evaluation step. Multiple model checkpointers may be used. For state checkpointing, which saves the
+        state of the entire server-side FL process to help with FL restarts, we allow only a single checkpointer
+        responsible for saving the state after each fit and eval round of FL.
+
+        Args:
+            model (nn.Module | None, optional): Model architecture to be saved. The module will use this architecture
+                to hold the server parameters and facilitate checkpointing with the help of the parameter exchanger.
+                Recall that servers only have parameters rather than torch models. So we need to know where to route
+                these parameters to allow for real models to be saved. Defaults to None.
+            parameter_exchanger (FullParameterExchangerWithPacking | None, optional): This will facilitate routing the
+                server parameters into the right components of the provided model architecture. Note that this
+                exchanger and the model must match the one used for training and exchange with the server to ensure
+                parameters go to the right places. Defaults to None.
+            model_checkpointers (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of
+                checkpointers) is used to checkpoint models based on their defined scoring function. Defaults to None.
+            state_checkpointer (PerRoundStateCheckpointer | None, optional): If defined, this checkpointer will be
+                used to preserve FL training state to facilitate restarting training if interrupted. Generally, this
+                checkpointer will save much more than just the model being trained. Defaults to None.
+        """
+
+        super().__init__(model, parameter_exchanger, model_checkpointers, state_checkpointer)
+
+    def _hydrate_model_for_checkpointing(self, server_parameters: Parameters) -> None:
+        assert self.model is not None, "Hydrate model for checkpoint called but self.model is None"
+        assert (
+            self.parameter_exchanger is not None
+        ), "Hydrate model for checkpoint called but self.parameter_exchanger is None"
+        packed_parameters = parameters_to_ndarrays(server_parameters)
+        # Don't need the extra loss weight variable for checkpointing.
+        assert isinstance(self.parameter_exchanger, FullParameterExchangerWithPacking)
+        model_ndarrays, _ = self.parameter_exchanger.unpack_parameters(packed_parameters)
+        self.parameter_exchanger.pull_parameters(model_ndarrays, self.model)
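
As a rough usage sketch (illustrative only: the PerRoundStateCheckpointer constructor and the fl_config/strategy
objects are assumptions here rather than details taken from this change), one of the modules above is constructed
alongside the model and packing exchanger used for training and handed to the server via its
checkpoint_and_state_module argument:

    from pathlib import Path

    import torch.nn as nn
    from flwr.server.client_manager import SimpleClientManager

    from fl4health.checkpointing.checkpointer import PerRoundStateCheckpointer
    from fl4health.checkpointing.server_module import AdaptiveConstraintServerCheckpointAndStateModule
    from fl4health.parameter_exchange.packing_exchanger import FullParameterExchangerWithPacking
    from fl4health.parameter_exchange.parameter_packer import ParameterPackerAdaptiveConstraint
    from fl4health.servers.base_server import FlServer

    # Placeholder model; fl_config and strategy are assumed to be set up as usual for the chosen method, and
    # PerRoundStateCheckpointer is assumed to take the directory in which per-round state files are written.
    model = nn.Linear(10, 2)
    checkpoint_and_state_module = AdaptiveConstraintServerCheckpointAndStateModule(
        model=model,
        parameter_exchanger=FullParameterExchangerWithPacking(ParameterPackerAdaptiveConstraint()),
        model_checkpointers=None,  # or a list of TorchModuleCheckpointer instances with unique checkpoint paths
        state_checkpointer=PerRoundStateCheckpointer(Path("server_state")),
    )
    server = FlServer(
        client_manager=SimpleClientManager(),
        fl_config=fl_config,
        strategy=strategy,
        checkpoint_and_state_module=checkpoint_and_state_module,
    )

With this wiring, maybe_checkpoint is invoked after each federated evaluation round, while save_state and
maybe_load_state handle per-round state preservation and restarts.
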
diff --git a/fl4health/clients/adaptive_drift_constraint_client.py b/fl4health/clients/adaptive_drift_constraint_client.py
index f3b02bdd2..aa3b11e3a 100644
--- a/fl4health/clients/adaptive_drift_constraint_client.py
+++ b/fl4health/clients/adaptive_drift_constraint_client.py
@@ -6,7 +6,7 @@
 from flwr.common.logger import log
 from flwr.common.typing import Config, NDArrays
 
-from fl4health.checkpointing.client_module import ClientCheckpointModule
+from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule
 from fl4health.clients.basic_client import BasicClient
 from fl4health.losses.weight_drift_loss import WeightDriftLoss
 from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
@@ -26,7 +26,7 @@ def __init__(
         metrics: Sequence[Metric],
         device: torch.device,
         loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
-        checkpointer: Optional[ClientCheckpointModule] = None,
+        checkpointer: Optional[ClientCheckpointAndStateModule] = None,
         reporters: Sequence[BaseReporter] | None = None,
         progress_bar: bool = False,
     ) -> None:
diff --git a/fl4health/clients/apfl_client.py b/fl4health/clients/apfl_client.py
index 458d6d0a2..9df872e67 100644
--- a/fl4health/clients/apfl_client.py
+++ b/fl4health/clients/apfl_client.py
@@ -5,7 +5,7 @@
 from flwr.common.typing import Config
 from torch.optim import Optimizer
 
-from fl4health.checkpointing.client_module import ClientCheckpointModule
+from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule
 from fl4health.clients.basic_client import BasicClient
 from fl4health.model_bases.apfl_base import ApflModule
 from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger
@@ -22,7 +22,7 @@ def __init__(
         metrics: Sequence[Metric],
         device: torch.device,
         loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
-        checkpointer: Optional[ClientCheckpointModule] = None,
+        checkpointer: Optional[ClientCheckpointAndStateModule] = None,
         reporters: Sequence[BaseReporter] | None = None,
     ) -> None:
         super().__init__(data_path, metrics, device, loss_meter_type, checkpointer, reporters)
diff --git a/fl4health/clients/basic_client.py b/fl4health/clients/basic_client.py
index 92b278e96..fc6044417 100644
--- a/fl4health/clients/basic_client.py
+++ b/fl4health/clients/basic_client.py
@@ -14,8 +14,8 @@
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
 
-from fl4health.checkpointing.checkpointer import PerRoundCheckpointer
-from fl4health.checkpointing.client_module import CheckpointMode, ClientCheckpointModule
+from fl4health.checkpointing.checkpointer import PerRoundStateCheckpointer
+from fl4health.checkpointing.client_module import CheckpointMode, ClientCheckpointAndStateModule
 from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
 from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger
 from fl4health.reporting.base_reporter import BaseReporter
@@ -42,10 +42,9 @@ def __init__(
         metrics: Sequence[Metric],
         device: torch.device,
         loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
-        checkpointer: Optional[ClientCheckpointModule] = None,
+        checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None,
         reporters: Sequence[BaseReporter] | None = None,
         progress_bar: bool
= False, - intermediate_client_state_dir: Optional[Path] = None, client_name: Optional[str] = None, ) -> None: """ @@ -60,40 +59,33 @@ def __init__( 'cuda' loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE. - checkpointer (Optional[ClientCheckpointModule], optional): Checkpointer module defining when and how to - do checkpointing during client-side training. No checkpointing is done if not provided. Defaults to - None. - reporters (Sequence[BaseReporter], optional): A sequence of FL4Health - reporters which the client should send data to. - progress_bar (bool): Whether or not to display a progress bar - during client training and validation. Uses tqdm. Defaults to - False - intermediate_client_state_dir (Optional[Path]): An optional path to store per round - checkpoints. - client_name (str): An optional client name that uniquely identifies a client. - If not passed, a hash is randomly generated. + checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + both checkpointing and state saving. The module, and its underlying model and state checkpointing + components will determine when and how to do checkpointing during client-side training. + No checkpointing (state or model) is done if not provided. Defaults to None. + reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health reporters which the client + should send data to. Defaults to None. + progress_bar (bool, optional): Whether or not to display a progress bar during client training and + validation. Uses tqdm. Defaults to False + client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + If not passed, a hash is randomly generated. Client state will use this as part of its state file + name. Defaults to None. """ self.data_path = data_path self.device = device self.metrics = metrics - self.checkpointer = checkpointer self.progress_bar = progress_bar self.client_name = client_name if client_name is not None else generate_hash() + self.state_checkpoint_name = f"client_{self.client_name}_state.pt" - self.per_round_checkpointer: Union[None, PerRoundCheckpointer] - - if intermediate_client_state_dir is not None: - log( - WARNING, - "intermediate_client_state_dir is not None. Creating PerRoundCheckpointer. This functionality is " - "still experimental and only supported with the base FlServer and NnunetServers at the moment", - ) - self.per_round_checkpointer = PerRoundCheckpointer( - intermediate_client_state_dir, Path(f"client_{self.client_name}.pt") - ) + if checkpoint_and_state_module is not None: + self.checkpoint_and_state_module = checkpoint_and_state_module else: - self.per_round_checkpointer = None + # Define a default module that does nothing. + self.checkpoint_and_state_module = ClientCheckpointAndStateModule( + model=None, parameter_exchanger=None, model_checkpointers=None, state_checkpointer=None + ) # Initialize reporters with client information. 
self.reports_manager = ReportsManager(reporters) @@ -135,8 +127,7 @@ def _maybe_checkpoint(self, loss: float, metrics: dict[str, Scalar], checkpoint_ loss (float): validation loss to potentially be used for checkpointing metrics (dict[str, float]): validation metrics to potentially be used for checkpointing """ - if self.checkpointer: - self.checkpointer.maybe_checkpoint(self.model, loss, metrics, checkpoint_mode) + self.checkpoint_and_state_module.maybe_checkpoint(self.model, loss, metrics, checkpoint_mode) def get_parameters(self, config: Config) -> NDArrays: """ @@ -277,10 +268,15 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, dict if not self.initialized: self.setup_client(config) - # If per_round_checkpointer not None and checkpoint exists load it and set attributes. - # Model not updated because FL restarted from most recent FL round (redo preempted round) - if self.per_round_checkpointer is not None and self.per_round_checkpointer.checkpoint_exists(): - self.load_client_state() + if self.checkpoint_and_state_module.state_checkpointer is not None: + # If this is the first time the client is being setup, we also attempt to load any existing state + # If no state exists, we assume this is a fresh run. State is useful, for example, in restarting FL + # training that was interrupted or failed part way through. + state_load_success = self._load_client_state() + if state_load_success: + log(INFO, "Successfully loaded client state.") + else: + log(INFO, "Client state was not loaded.") self.set_parameters(parameters, config, fitting_round=True) @@ -329,10 +325,9 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, dict current_server_round, ) - # After local client training has finished, checkpoint client state - # if per_round_checkpointer is not None - if self.per_round_checkpointer is not None: - self.save_client_state() + # After local client training has finished, checkpoint client state if a state checkpointer is defined + if self.checkpoint_and_state_module.state_checkpointer is not None: + self._save_client_state() # FitRes should contain local parameters, number of examples on client, and a dictionary holding metrics # calculation results. @@ -1192,15 +1187,13 @@ def transform_gradients(self, losses: TrainingLosses) -> None: """ pass - def save_client_state(self) -> None: + def _save_client_state(self) -> None: """ Saves checkpoint dict consisting of client name, total steps, lr schedulers, - metrics reporter and optimizers state. Method can be overridden to augment saved checkpointed state. + metrics reporter and optimizers state. Method can be overridden to augment saved checkpointed state. 
""" - assert self.per_round_checkpointer is not None - - ckpt = { + state = { "lr_schedulers_state": {key: scheduler.state_dict() for key, scheduler in self.lr_schedulers.items()}, "total_steps": self.total_steps, "client_name": self.client_name, @@ -1208,34 +1201,30 @@ def save_client_state(self) -> None: "optimizers_state": {key: optimizer.state_dict()["state"] for key, optimizer in self.optimizers.items()}, } - self.per_round_checkpointer.save_checkpoint(ckpt) - - log( - INFO, - f"Saving client state to checkpoint at {self.per_round_checkpointer.checkpoint_path}", - ) + self.checkpoint_and_state_module.save_state(self.state_checkpoint_name, state) - def load_client_state(self) -> None: + def _load_client_state(self) -> bool: """ Load checkpoint dict consisting of client name, total steps, lr schedulers, metrics - reporter and optimizers state. Method can be overridden to augment loaded checkpointed state. + reporter and optimizers state. Method can be overridden to augment loaded checkpointed state. """ - assert self.per_round_checkpointer is not None and self.per_round_checkpointer.checkpoint_exists() + client_state = self.checkpoint_and_state_module.maybe_load_state(self.state_checkpoint_name) - ckpt = self.per_round_checkpointer.load_checkpoint() + if client_state is None: + return False - narrow_dict_type_and_set_attribute(self, ckpt, "client_name", "client_name", str) - narrow_dict_type_and_set_attribute(self, ckpt, "total_steps", "total_steps", int) - narrow_dict_type_and_set_attribute(self, ckpt, "reports_manager", "reports_manager", ReportsManager) + narrow_dict_type_and_set_attribute(self, client_state, "client_name", "client_name", str) + narrow_dict_type_and_set_attribute(self, client_state, "total_steps", "total_steps", int) + narrow_dict_type_and_set_attribute(self, client_state, "reports_manager", "reports_manager", ReportsManager) - assert "lr_schedulers_state" in ckpt and isinstance(ckpt["lr_schedulers_state"], dict) - assert "optimizers_state" in ckpt and isinstance(ckpt["optimizers_state"], dict) + assert "lr_schedulers_state" in client_state and isinstance(client_state["lr_schedulers_state"], dict) + assert "optimizers_state" in client_state and isinstance(client_state["optimizers_state"], dict) # Optimizer is updated in setup_client to reference model weights from server # Thus, only optimizer state (per parameter values such as momentum) # should be loaded for key, optimizer in self.optimizers.items(): - optimizer_state = ckpt["optimizers_state"][key] + optimizer_state = client_state["optimizers_state"][key] optimizer_state_dict = optimizer.state_dict() optimizer_state_dict["state"] = optimizer_state optimizer.load_state_dict(optimizer_state_dict) @@ -1243,4 +1232,6 @@ def load_client_state(self) -> None: # Schedulers initialized in setup_client to reference correct optimizers # Here we load in all other aspects of the scheduler state for key in self.lr_schedulers: - self.lr_schedulers[key].load_state_dict(ckpt["lr_schedulers_state"][key]) + self.lr_schedulers[key].load_state_dict(client_state["lr_schedulers_state"][key]) + + return True diff --git a/fl4health/clients/clipping_client.py b/fl4health/clients/clipping_client.py index 032f5754b..9c557350e 100644 --- a/fl4health/clients/clipping_client.py +++ b/fl4health/clients/clipping_client.py @@ -8,7 +8,7 @@ from flwr.common.typing import Config from numpy import linalg -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.client_module import 
ClientCheckpointAndStateModule from fl4health.clients.basic_client import BasicClient from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.parameter_exchange.packing_exchanger import FullParameterExchangerWithPacking @@ -31,7 +31,7 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, diff --git a/fl4health/clients/constrained_fenda_client.py b/fl4health/clients/constrained_fenda_client.py index 61cd3d87b..25e5e30a9 100644 --- a/fl4health/clients/constrained_fenda_client.py +++ b/fl4health/clients/constrained_fenda_client.py @@ -7,7 +7,7 @@ from flwr.common.logger import log from flwr.common.typing import Config -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.fenda_client import FendaClient from fl4health.losses.fenda_loss_config import ConstrainedFendaLossContainer from fl4health.model_bases.fenda_base import FendaModelWithFeatureState @@ -25,7 +25,7 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, loss_container: Optional[ConstrainedFendaLossContainer] = None, ) -> None: """ diff --git a/fl4health/clients/deep_mmd_clients/ditto_deep_mmd_client.py b/fl4health/clients/deep_mmd_clients/ditto_deep_mmd_client.py index 55c3a9fc9..c7438c88c 100644 --- a/fl4health/clients/deep_mmd_clients/ditto_deep_mmd_client.py +++ b/fl4health/clients/deep_mmd_clients/ditto_deep_mmd_client.py @@ -7,7 +7,7 @@ from flwr.common.logger import log from flwr.common.typing import Config, Scalar -from fl4health.checkpointing.client_module import CheckpointMode, ClientCheckpointModule +from fl4health.checkpointing.client_module import CheckpointMode, ClientCheckpointAndStateModule from fl4health.clients.ditto_client import DittoClient from fl4health.losses.deep_mmd_loss import DeepMmdLoss from fl4health.model_bases.feature_extractor_buffer import FeatureExtractorBuffer @@ -24,7 +24,7 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, deep_mmd_loss_weight: float = 10.0, feature_extraction_layers_with_size: Optional[Dict[str, int]] = None, ) -> None: diff --git a/fl4health/clients/ditto_client.py b/fl4health/clients/ditto_client.py index 12e894163..e46839a01 100644 --- a/fl4health/clients/ditto_client.py +++ b/fl4health/clients/ditto_client.py @@ -8,7 +8,7 @@ from flwr.common.typing import Config, NDArrays, Scalar from torch.optim import Optimizer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.adaptive_drift_constraint_client import AdaptiveDriftConstraintClient from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.reporting.base_reporter import BaseReporter @@ -25,7 +25,7 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType 
= LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, ) -> None: diff --git a/fl4health/clients/ensemble_client.py b/fl4health/clients/ensemble_client.py index c254a2fcb..d06bec5d2 100644 --- a/fl4health/clients/ensemble_client.py +++ b/fl4health/clients/ensemble_client.py @@ -5,7 +5,7 @@ from flwr.common.typing import Config from torch.optim import Optimizer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.basic_client import BasicClient from fl4health.model_bases.ensemble_base import EnsembleModel from fl4health.utils.losses import EvaluationLosses, LossMeterType, TrainingLosses @@ -20,7 +20,7 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: """ This client enables the training of ensemble models in a federated manner. diff --git a/fl4health/clients/fedpm_client.py b/fl4health/clients/fedpm_client.py index be9653c53..c46220fef 100644 --- a/fl4health/clients/fedpm_client.py +++ b/fl4health/clients/fedpm_client.py @@ -4,7 +4,7 @@ import torch from flwr.common.typing import Config -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.basic_client import BasicClient from fl4health.model_bases.masked_layers.masked_layers_utils import convert_to_masked_model from fl4health.parameter_exchange.fedpm_exchanger import FedPmExchanger @@ -22,7 +22,7 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, reporters: Sequence[BaseReporter] | None = None, ) -> None: super().__init__( diff --git a/fl4health/clients/fedrep_client.py b/fl4health/clients/fedrep_client.py index ed2aafa74..338314ed7 100644 --- a/fl4health/clients/fedrep_client.py +++ b/fl4health/clients/fedrep_client.py @@ -9,7 +9,7 @@ from flwr.common.typing import Config, NDArrays, Scalar from torch.optim import Optimizer -from fl4health.checkpointing.client_module import CheckpointMode, ClientCheckpointModule +from fl4health.checkpointing.client_module import CheckpointMode, ClientCheckpointAndStateModule from fl4health.clients.basic_client import BasicClient from fl4health.model_bases.fedrep_base import FedRepModel from fl4health.model_bases.sequential_split_models import SequentiallySplitExchangeBaseModel @@ -36,7 +36,7 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, reporters: Sequence[BaseReporter] | None = None, ) -> None: super().__init__(data_path, metrics, device, loss_meter_type, checkpointer, reporters) diff --git a/fl4health/clients/fenda_client.py b/fl4health/clients/fenda_client.py index 3999c0c30..644aae513 100644 --- a/fl4health/clients/fenda_client.py +++ b/fl4health/clients/fenda_client.py @@ -4,7 +4,7 @@ import torch from 
flwr.common.typing import Config -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.basic_client import BasicClient from fl4health.model_bases.fenda_base import FendaModel from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger @@ -20,7 +20,7 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, diff --git a/fl4health/clients/fenda_ditto_client.py b/fl4health/clients/fenda_ditto_client.py index 187e086a7..8775cc2ab 100644 --- a/fl4health/clients/fenda_ditto_client.py +++ b/fl4health/clients/fenda_ditto_client.py @@ -6,7 +6,7 @@ from flwr.common.logger import log from flwr.common.typing import Config, NDArrays -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.ditto_client import DittoClient from fl4health.model_bases.fenda_base import FendaModel from fl4health.model_bases.sequential_split_models import SequentiallySplitModel @@ -25,7 +25,7 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, freeze_global_feature_extractor: bool = False, diff --git a/fl4health/clients/flash_client.py b/fl4health/clients/flash_client.py index d3f8e1aa4..a5dda30f6 100644 --- a/fl4health/clients/flash_client.py +++ b/fl4health/clients/flash_client.py @@ -6,7 +6,7 @@ from flwr.common.logger import log from flwr.common.typing import Config, Scalar -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.basic_client import BasicClient from fl4health.utils.client import check_if_batch_is_empty_and_verify_input, move_data_to_device from fl4health.utils.config import narrow_dict_type @@ -21,7 +21,7 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: """ This client is used to perform client-side training associated with the Flash method described in diff --git a/fl4health/clients/instance_level_dp_client.py b/fl4health/clients/instance_level_dp_client.py index 9b093300d..56fccd156 100644 --- a/fl4health/clients/instance_level_dp_client.py +++ b/fl4health/clients/instance_level_dp_client.py @@ -5,7 +5,7 @@ from flwr.common.typing import Config from opacus import PrivacyEngine -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.basic_client import BasicClient from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.config import narrow_dict_type @@ -25,7 +25,7 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, 
- checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, reporters: Sequence[BaseReporter] | None = None, ) -> None: super().__init__( diff --git a/fl4health/clients/mkmmd_clients/ditto_mkmmd_client.py b/fl4health/clients/mkmmd_clients/ditto_mkmmd_client.py index 4f3482973..327e46482 100644 --- a/fl4health/clients/mkmmd_clients/ditto_mkmmd_client.py +++ b/fl4health/clients/mkmmd_clients/ditto_mkmmd_client.py @@ -7,7 +7,7 @@ from flwr.common.logger import log from flwr.common.typing import Config, Scalar -from fl4health.checkpointing.client_module import CheckpointMode, ClientCheckpointModule +from fl4health.checkpointing.client_module import CheckpointMode, ClientCheckpointAndStateModule from fl4health.clients.ditto_client import DittoClient from fl4health.losses.mkmmd_loss import MkMmdLoss from fl4health.model_bases.feature_extractor_buffer import FeatureExtractorBuffer @@ -24,7 +24,7 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, mkmmd_loss_weight: float = 10.0, feature_extraction_layers: Optional[Sequence[str]] = None, feature_l2_norm_weight: float = 0.0, diff --git a/fl4health/clients/mkmmd_clients/mr_mtl_mkmmd_client.py b/fl4health/clients/mkmmd_clients/mr_mtl_mkmmd_client.py index 8d95d0543..ca3e691f3 100644 --- a/fl4health/clients/mkmmd_clients/mr_mtl_mkmmd_client.py +++ b/fl4health/clients/mkmmd_clients/mr_mtl_mkmmd_client.py @@ -6,7 +6,7 @@ from flwr.common.logger import log from flwr.common.typing import Config, Scalar -from fl4health.checkpointing.client_module import CheckpointMode, ClientCheckpointModule +from fl4health.checkpointing.client_module import CheckpointMode, ClientCheckpointAndStateModule from fl4health.clients.mr_mtl_client import MrMtlClient from fl4health.losses.mkmmd_loss import MkMmdLoss from fl4health.model_bases.feature_extractor_buffer import FeatureExtractorBuffer @@ -22,7 +22,7 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, mkmmd_loss_weight: float = 10.0, feature_extraction_layers: Optional[Sequence[str]] = None, feature_l2_norm_weight: float = 0.0, diff --git a/fl4health/clients/moon_client.py b/fl4health/clients/moon_client.py index 38a4f7e04..e6da07ec8 100644 --- a/fl4health/clients/moon_client.py +++ b/fl4health/clients/moon_client.py @@ -5,7 +5,7 @@ import torch from flwr.common.logger import log -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.basic_client import BasicClient, Config from fl4health.losses.contrastive_loss import MoonContrastiveLoss from fl4health.model_bases.sequential_split_models import SequentiallySplitModel @@ -22,7 +22,7 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, temperature: float = 0.5, contrastive_weight: float = 1.0, len_old_models_buffer: int = 1, diff --git a/fl4health/clients/mr_mtl_client.py b/fl4health/clients/mr_mtl_client.py index 
6ffa6fe28..3b2894d19 100644 --- a/fl4health/clients/mr_mtl_client.py +++ b/fl4health/clients/mr_mtl_client.py @@ -7,7 +7,7 @@ from flwr.common.logger import log from flwr.common.typing import Config, NDArrays, Scalar -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.adaptive_drift_constraint_client import AdaptiveDriftConstraintClient from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType, TrainingLosses @@ -22,7 +22,7 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, ) -> None: diff --git a/fl4health/clients/nnunet_client.py b/fl4health/clients/nnunet_client.py index f877940cc..efcd41789 100644 --- a/fl4health/clients/nnunet_client.py +++ b/fl4health/clients/nnunet_client.py @@ -20,7 +20,7 @@ from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import DataLoader -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.basic_client import BasicClient from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.config import narrow_dict_type @@ -75,7 +75,7 @@ def __init__( progress_bar: bool = False, intermediate_client_state_dir: Optional[Path] = None, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, reporters: Sequence[BaseReporter] | None = None, client_name: Optional[str] = None, ) -> None: diff --git a/fl4health/clients/partial_weight_exchange_client.py b/fl4health/clients/partial_weight_exchange_client.py index 7ae7037a2..f52233726 100644 --- a/fl4health/clients/partial_weight_exchange_client.py +++ b/fl4health/clients/partial_weight_exchange_client.py @@ -8,7 +8,7 @@ from flwr.common.logger import log from flwr.common.typing import Config, NDArrays -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.basic_client import BasicClient from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger @@ -25,7 +25,7 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, reporters: Sequence[BaseReporter] | None = None, store_initial_model: bool = False, ) -> None: diff --git a/fl4health/clients/perfcl_client.py b/fl4health/clients/perfcl_client.py index fa08abe47..e6ffdd950 100644 --- a/fl4health/clients/perfcl_client.py +++ b/fl4health/clients/perfcl_client.py @@ -4,7 +4,7 @@ import torch from flwr.common.typing import Config -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.basic_client import BasicClient from 
fl4health.losses.perfcl_loss import PerFclLoss from fl4health.model_bases.perfcl_base import PerFclModel @@ -23,7 +23,7 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, global_feature_loss_temperature: float = 0.5, local_feature_loss_temperature: float = 0.5, global_feature_contrastive_loss_weight: float = 1.0, diff --git a/fl4health/clients/scaffold_client.py b/fl4health/clients/scaffold_client.py index 7bdfad00b..b76d67a15 100644 --- a/fl4health/clients/scaffold_client.py +++ b/fl4health/clients/scaffold_client.py @@ -8,7 +8,7 @@ from flwr.common.typing import Config, NDArrays from opacus.optimizers.optimizer import DPOptimizer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.basic_client import BasicClient from fl4health.clients.instance_level_dp_client import InstanceLevelDpClient from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger @@ -35,7 +35,7 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, reporters: Sequence[BaseReporter] | None = None, ) -> None: super().__init__( @@ -257,7 +257,7 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: ScaffoldClient.__init__( self, diff --git a/fl4health/servers/adaptive_constraint_servers/ditto_server.py b/fl4health/servers/adaptive_constraint_servers/ditto_server.py index f35a16303..21b568b18 100644 --- a/fl4health/servers/adaptive_constraint_servers/ditto_server.py +++ b/fl4health/servers/adaptive_constraint_servers/ditto_server.py @@ -1,9 +1,9 @@ -from typing import Optional, Sequence, Union +from typing import Sequence from flwr.common.typing import Config from flwr.server.client_manager import ClientManager -from fl4health.checkpointing.checkpointer import TorchCheckpointer +from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule from fl4health.reporting.base_reporter import BaseReporter from fl4health.servers.base_server import FlServer from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint @@ -15,7 +15,7 @@ def __init__( client_manager: ClientManager, fl_config: Config, strategy: FedAvgWithAdaptiveConstraint, - checkpointer: Optional[Union[TorchCheckpointer, Sequence[TorchCheckpointer]]] = None, + checkpoint_and_state_module: BaseServerCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, ) -> None: """ @@ -32,11 +32,13 @@ def __init__( strategy (FedAvgWithAdaptiveConstraint): The aggregation strategy to be used by the server to handle. client updates and other information potentially sent by the participating clients. For Ditto, the strategy must be a derivative of the FedAvgWithAdaptiveConstraint class. - checkpointer (Optional[Union[TorchCheckpointer, Sequence [TorchCheckpointer]]], optional): To be provided - if the server should perform server side checkpointing based on some criteria. 
If none, then no - server-side checkpointing is performed. Multiple checkpointers can also be passed in a sequence to - checkpointer based on multiple criteria. Ensure checkpoint names are different for each checkpoint - or they will overwrite on another. Defaults to None. + checkpoint_and_state_module (BaseServerCheckpointAndStateModule | None, optional): This module is used + to handle both model checkpointing and state checkpointing. The former is aimed at saving model + artifacts to be used or evaluated after training. The later is used to preserve training state + (including models) such that if FL training is interrupted, the process may be restarted. If no + module is provided, no checkpointing or state preservation will happen. Defaults to None. + NOTE: For Ditto, the model shared with the server is the GLOBAL MODEL, which isn't the target of FL + training for this algorithm. However, one may still want to save this model for other purposes. reporters (Sequence[BaseReporter], optional): A sequence of FL4Health reporters which the server should send data to before and after each round. """ @@ -47,6 +49,6 @@ def __init__( client_manager=client_manager, fl_config=fl_config, strategy=strategy, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, reporters=reporters, ) diff --git a/fl4health/servers/adaptive_constraint_servers/fedprox_server.py b/fl4health/servers/adaptive_constraint_servers/fedprox_server.py index 278c56538..08bf62cc3 100644 --- a/fl4health/servers/adaptive_constraint_servers/fedprox_server.py +++ b/fl4health/servers/adaptive_constraint_servers/fedprox_server.py @@ -5,7 +5,7 @@ from flwr.common.typing import Config from flwr.server.client_manager import ClientManager -from fl4health.checkpointing.checkpointer import TorchCheckpointer +from fl4health.checkpointing.checkpointer import TorchModuleCheckpointer from fl4health.parameter_exchange.packing_exchanger import FullParameterExchangerWithPacking from fl4health.parameter_exchange.parameter_packer import ParameterPackerAdaptiveConstraint from fl4health.reporting.base_reporter import BaseReporter @@ -20,7 +20,7 @@ def __init__( fl_config: Config, strategy: FedAvgWithAdaptiveConstraint, model: Optional[nn.Module] = None, - checkpointer: Optional[Union[TorchCheckpointer, Sequence[TorchCheckpointer]]] = None, + checkpointer: Optional[Union[TorchModuleCheckpointer, Sequence[TorchModuleCheckpointer]]] = None, reporters: Sequence[BaseReporter] | None = None, ) -> None: """ diff --git a/fl4health/servers/adaptive_constraint_servers/mrmtl_server.py b/fl4health/servers/adaptive_constraint_servers/mrmtl_server.py index b8da665c5..0ad2fd573 100644 --- a/fl4health/servers/adaptive_constraint_servers/mrmtl_server.py +++ b/fl4health/servers/adaptive_constraint_servers/mrmtl_server.py @@ -3,7 +3,8 @@ from flwr.common.typing import Config from flwr.server.client_manager import ClientManager -from fl4health.checkpointing.checkpointer import TorchCheckpointer +from fl4health.checkpointing.checkpointer import TorchModuleCheckpointer +from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule from fl4health.reporting.base_reporter import BaseReporter from fl4health.servers.base_server import FlServer from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint @@ -15,7 +16,7 @@ def __init__( client_manager: ClientManager, fl_config: Config, strategy: FedAvgWithAdaptiveConstraint, - checkpointer: Optional[Union[TorchCheckpointer, 
Sequence[TorchCheckpointer]]] = None, + checkpoint_and_state_module: BaseServerCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, ) -> None: """ @@ -32,11 +33,13 @@ def __init__( strategy (FedAvgWithAdaptiveConstraint): The aggregation strategy to be used by the server to handle. client updates and other information potentially sent by the participating clients. For MR-MTL, the strategy must be a derivative of the FedAvgWithAdaptiveConstraint class. - checkpointer (Optional[Union[TorchCheckpointer, Sequence [TorchCheckpointer]]], optional): To be provided - if the server should perform server side checkpointing based on some criteria. If none, then no - server-side checkpointing is performed. Multiple checkpointers can also be passed in a sequence to - checkpointer based on multiple criteria. Ensure checkpoint names are different for each checkpoint - or they will overwrite on another. Defaults to None. + checkpoint_and_state_module (BaseServerCheckpointAndStateModule | None, optional): This module is used + to handle both model checkpointing and state checkpointing. The former is aimed at saving model + artifacts to be used or evaluated after training. The later is used to preserve training state + (including models) such that if FL training is interrupted, the process may be restarted. If no + module is provided, no checkpointing or state preservation will happen. Defaults to None. + NOTE: For MR-MTL, the server model is an aggregation of the personal models, which isn't the target of + FL training for this algorithm. However, one may still want to save this model for other purposes. reporters (Sequence[BaseReporter], optional): A sequence of FL4Health reporters which the server should send data to before and after each round. 
""" @@ -47,6 +50,6 @@ def __init__( client_manager=client_manager, fl_config=fl_config, strategy=strategy, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, reporters=reporters, ) diff --git a/fl4health/servers/base_server.py b/fl4health/servers/base_server.py index cb651eb46..0e071e14a 100644 --- a/fl4health/servers/base_server.py +++ b/fl4health/servers/base_server.py @@ -1,12 +1,11 @@ import datetime from logging import DEBUG, ERROR, INFO, WARNING from pathlib import Path -from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union import torch.nn as nn from flwr.common import EvaluateRes, Parameters from flwr.common.logger import log -from flwr.common.parameter import parameters_to_ndarrays from flwr.common.typing import Code, Config, GetParametersIns, Scalar from flwr.server.client_manager import ClientManager from flwr.server.client_proxy import ClientProxy @@ -14,8 +13,7 @@ from flwr.server.server import EvaluateResultsAndFailures, FitResultsAndFailures, Server, evaluate_clients from flwr.server.strategy import Strategy -from fl4health.checkpointing.checkpointer import PerRoundCheckpointer, TorchCheckpointer -from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger +from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule from fl4health.reporting.base_reporter import BaseReporter from fl4health.reporting.reports_manager import ReportsManager from fl4health.servers.polling import poll_clients @@ -26,8 +24,6 @@ from fl4health.utils.random import generate_hash from fl4health.utils.typing import EvaluateFailures, FitFailures -ExchangerType = TypeVar("ExchangerType", bound=ParameterExchanger) - class FlServer(Server): def __init__( @@ -36,10 +32,7 @@ def __init__( fl_config: Config, strategy: Optional[Strategy] = None, reporters: Sequence[BaseReporter] | None = None, - model: nn.Module | None = None, - checkpointer: Union[TorchCheckpointer, Sequence[TorchCheckpointer]] | None = None, - parameter_exchanger: ExchangerType | None = None, - intermediate_server_state_dir: Path | None = None, + checkpoint_and_state_module: BaseServerCheckpointAndStateModule | None = None, on_init_parameters_config_fn: Callable[[int], Dict[str, Scalar]] | None = None, server_name: str | None = None, accept_failures: bool = True, @@ -57,31 +50,19 @@ def __init__( strategy (Optional[Strategy], optional): The aggregation strategy to be used by the server to handle. client updates and other information potentially sent by the participating clients. If None the strategy is FedAvg as set by the flwr Server. Defaults to None. - reporters (Sequence[BaseReporter] | None, optional): sequence of FL4Health reporters which the server + reporters (Sequence[BaseReporter] | None, optional): sequence of FL4Health reporters which the server should send data to before and after each round. Defaults to None. - model (Optional[nn.Module]): This is the torch model to be checkpointed. It will be hydrated by the - _hydrate_model_for_checkpointing function so that it has the proper weights to be saved. If no model - is defined and checkpointing is attempted an error will throw. Defaults to None. - checkpointer (Optional[Union[TorchCheckpointer, Sequence[TorchCheckpointer]]], optional): To be provided - if the server should perform server side checkpointing based on some criteria. If none, then no - server-side checkpointing is performed. 
Multiple checkpointers can also be passed in a sequence to - checkpointer based on multiple criterion. Ensure checkpoint names are different for each checkpoint or - they will overwrite on another. Defaults to None. - parameter_exchanger (Optional[ExchangerType], optional): A parameter exchanger used to facilitate - server-side model checkpointing if a checkpointer has been defined. If not provided then checkpointing - will not be done unless the _hydrate_model_for_checkpointing function is overridden. Because the - server only sees numpy arrays, the parameter exchanger is used to insert the numpy arrays into a - provided model. Defaults to None. - intermediate_server_state_dir (Path): A directory to store and load state from for the server - during an FL experiment. This allows for the saving of server state in case federated training is - interrupted and needs to be restarted from the same point. If none, then no state is saved during each - server round. Defaults to None. - on_init_parameters_config_fn (Optional[Callable[[int], Dict[str, Scalar]]], optional): - Function used to configure how one asks a client to provide parameters from which to initialize all - other clients by providing a Config dictionary. If this is none, then a blank config is sent with the - parameter request (which is default behavior for flower servers). Defaults to None. - server_name (Optional[str], optional): An optional string name to uniquely identify server. - Defaults to None. + checkpoint_and_state_module (BaseServerCheckpointAndStateModule | None, optional): This module is used + to handle both model checkpointing and state checkpointing. The former is aimed at saving model + artifacts to be used or evaluated after training. The later is used to preserve training state + (including models) such that if FL training is interrupted, the process may be restarted. If no + module is provided, no checkpointing or state preservation will happen. Defaults to None. + on_init_parameters_config_fn (Callable[[int], Dict[str, Scalar]] | None, optional): Function used to + configure how one asks a client to provide parameters from which to initialize all other clients by + providing a Config dictionary. If this is none, then a blank config is sent with the parameter request + (which is default behavior for flower servers). Defaults to None. + server_name (str | None, optional): An optional string name to uniquely identify server. This name is also + used as part of any state checkpointing done by the server. Defaults to None. accept_failures (bool, optional): Determines whether the server should accept failures during training or evaluation from clients or not. If set to False, this will cause the server to shutdown all clients and throw an exception. Defaults to True. @@ -89,28 +70,18 @@ def __init__( super().__init__(client_manager=client_manager, strategy=strategy) self.fl_config = fl_config - self.server_model = model + if checkpoint_and_state_module is not None: + self.checkpoint_and_state_module = checkpoint_and_state_module + else: + # Define a default module that does nothing. 
+ self.checkpoint_and_state_module = BaseServerCheckpointAndStateModule( + model=None, parameter_exchanger=None, model_checkpointers=None, state_checkpointer=None + ) self.on_init_parameters_config_fn = on_init_parameters_config_fn - self.checkpointer = [checkpointer] if isinstance(checkpointer, TorchCheckpointer) else checkpointer - # To facilitate model rehydration from server-side state for checkpointing - self.parameter_exchanger = parameter_exchanger self.server_name = server_name if server_name is not None else generate_hash() + self.state_checkpoint_name = f"server_{self.server_name}_state.pt" self.accept_failures = accept_failures - self.per_round_checkpointer: PerRoundCheckpointer | None - - if intermediate_server_state_dir is not None: - log( - WARNING, - "intermediate_server_state_dir is not None. Creating PerRoundCheckpointer. This functionality is " - "still experimental and only supported for BasicClient and NnunetClient currently.", - ) - self.per_round_checkpointer = PerRoundCheckpointer( - intermediate_server_state_dir, Path(f"{self.server_name}.ckpt") - ) - else: - self.per_round_checkpointer = None - self.current_round: int self.history: History @@ -164,16 +135,14 @@ def fit_with_per_round_checkpointing(self, num_rounds: int, timeout: Optional[fl metrics computed during training and validation. The second element of the tuple is the elapsed time in seconds. """ - # Initialize parameters - log(INFO, "Initializing global parameters") - assert self.per_round_checkpointer is not None - - # if checkpoint exists, update history, server round and model accordingly - if self.per_round_checkpointer.checkpoint_exists(): - self._load_server_state() + # Attempt to load the server state if it exists. If the state checkpoint exists, update history, server + # round and model accordingly + state_load_success = self._load_server_state() + if state_load_success: + log(INFO, "Server state checkpoint successfully loaded.") else: - log(INFO, "Initializing server state") + log(INFO, "Initializing server state and global parameters") self.parameters = self._get_initial_parameters(server_round=0, timeout=timeout) self.history = History() self.current_round = 1 @@ -231,7 +200,6 @@ def fit_with_per_round_checkpointing(self, num_rounds: int, timeout: Optional[fl self.current_round += 1 # Save checkpoint after training and testing - self._hydrate_model_for_checkpointing() self._save_server_state() # Bookkeeping @@ -266,7 +234,7 @@ def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float self.update_before_fit(num_rounds, timeout) - if self.per_round_checkpointer is not None: + if self.checkpoint_and_state_module.state_checkpointer is not None: history, elapsed_time = self.fit_with_per_round_checkpointing(num_rounds, timeout) else: history, elapsed_time = super().fit(num_rounds, timeout) @@ -386,7 +354,7 @@ def evaluate_round( self._terminate_after_unacceptable_failures(timeout) if loss_aggregated: - self._maybe_checkpoint(loss_aggregated, metrics_aggregated, server_round) + self.checkpoint_and_state_module.maybe_checkpoint(self.parameters, loss_aggregated, metrics_aggregated) # Report evaluation results report_data = { "val - loss - aggregated": loss_aggregated, @@ -415,72 +383,49 @@ def _log_fl_config(self) -> None: if not isinstance(config_value, bytes): log(INFO, f"Key: {config_key} Value: {config_value!r}") - def _hydrate_model_for_checkpointing(self) -> None: - """ - This function is used as a means of saving the server-side model after aggregation in the FL training - 
trajectory. In the current implementation, the server only holds numpy arrays. Without knowledge of a model - architecture to which the arrays correspond. Thus, in the default implementation, we require that a torch - architecture have been provided (self.server_model) and a parameter exchanger (self.parameter_exchanger) be - provided which handles mapping these numpy arrays into the architecture properly. - - This function may be overridden if different behavior is desired. - - NOTE: This function stores the weights directly in the self.server_model attribute - """ - - assert self.server_model is not None, ( - "Model hydration has been called but no server_model is defined to hydrate. The functionality of " - "_hydrate_model_for_checkpointing can be overridden if checkpointing without a torch architecture is " - "possible and desired" - ) - assert self.parameter_exchanger is not None, ( - "Model hydration has been called but no parameter_exchanger is defined to hydrate. The functionality of " - "_hydrate_model_for_checkpointing can be overridden if checkpointing without a parameter exchanger is " - "possible and desired" - ) - model_ndarrays = parameters_to_ndarrays(self.parameters) - self.parameter_exchanger.pull_parameters(model_ndarrays, self.server_model) - def _save_server_state(self) -> None: """ Save server checkpoint consisting of model, history, server round, metrics reporter and server name. This - method can be overridden to add any necessary state to the checkpoint. + method can be overridden to add any necessary state to the checkpoint. The model will be injected into the + ckpt state by the checkpoint module """ - - assert self.per_round_checkpointer is not None - - ckpt = { - "model": self.server_model, + other_state_to_save = { "history": self.history, "current_round": self.current_round, "reports_manager": self.reports_manager, "server_name": self.server_name, } - self.per_round_checkpointer.save_checkpoint(ckpt) - - log(INFO, f"Saving server state to checkpoint at {self.per_round_checkpointer.checkpoint_path}") + self.checkpoint_and_state_module.save_state( + state_checkpoint_name=self.state_checkpoint_name, + server_parameters=self.parameters, + other_state=other_state_to_save, + ) - def _load_server_state(self) -> None: + def _load_server_state(self) -> bool: """ Load server checkpoint consisting of model, history, server name, current round and metrics reporter. The method can be overridden to add any necessary state when loading the checkpoint. """ - assert self.per_round_checkpointer is not None and self.per_round_checkpointer.checkpoint_exists() - ckpt = self.per_round_checkpointer.load_checkpoint() + # Attempt to load the server state if it exists. This variable will be None if it does not. 
+ server_state = self.checkpoint_and_state_module.maybe_load_state(self.state_checkpoint_name) - log(INFO, f"Loading server state from checkpoint at {self.per_round_checkpointer.checkpoint_path}") + if server_state is None: + return False - narrow_dict_type_and_set_attribute(self, ckpt, "server_name", "server_name", str) - narrow_dict_type_and_set_attribute(self, ckpt, "current_round", "current_round", int) - narrow_dict_type_and_set_attribute(self, ckpt, "reports_manager", "reports_manager", ReportsManager) - narrow_dict_type_and_set_attribute(self, ckpt, "history", "history", History) - narrow_dict_type_and_set_attribute(self, ckpt, "model", "parameters", nn.Module, func=get_all_model_parameters) + narrow_dict_type_and_set_attribute(self, server_state, "server_name", "server_name", str) + narrow_dict_type_and_set_attribute(self, server_state, "current_round", "current_round", int) + narrow_dict_type_and_set_attribute(self, server_state, "reports_manager", "reports_manager", ReportsManager) + narrow_dict_type_and_set_attribute(self, server_state, "history", "history", History) + narrow_dict_type_and_set_attribute( + self, server_state, "model", "parameters", nn.Module, func=get_all_model_parameters + ) # Needed for when _hydrate_model_for_checkpointing is called - narrow_dict_type_and_set_attribute(self, ckpt, "model", "server_model", nn.Module) + narrow_dict_type_and_set_attribute(self, server_state, "model", "server_model", nn.Module) - self.parameters = get_all_model_parameters(ckpt["model"]) + self.parameters = get_all_model_parameters(server_state["model"]) + return True def _terminate_after_unacceptable_failures(self, timeout: Optional[float]) -> None: assert not self.accept_failures diff --git a/fl4health/servers/client_level_dp_fed_avg_server.py b/fl4health/servers/client_level_dp_fed_avg_server.py index 4430822a3..ba4122d05 100644 --- a/fl4health/servers/client_level_dp_fed_avg_server.py +++ b/fl4health/servers/client_level_dp_fed_avg_server.py @@ -8,7 +8,7 @@ from flwr.server.client_manager import ClientManager from flwr.server.history import History -from fl4health.checkpointing.checkpointer import TorchCheckpointer +from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule from fl4health.client_managers.fixed_without_replacement_manager import FixedSamplingByFractionClientManager from fl4health.client_managers.poisson_sampling_manager import PoissonSamplingClientManager from fl4health.privacy.fl_accountants import ( @@ -29,7 +29,7 @@ def __init__( strategy: ClientLevelDPFedAvgM, server_noise_multiplier: float, num_server_rounds: int, - checkpointer: Optional[TorchCheckpointer] = None, + checkpoint_and_state_module: BaseServerCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, delta: Optional[int] = None, accept_failures: bool = True, @@ -48,9 +48,11 @@ def __init__( client updates and other information potentially sent by the participating clients. server_noise_multiplier (float): Magnitude of noise added to the weights aggregation process by the server. num_server_rounds (int): Number of rounds of FL training carried out by the server. - checkpointer (Optional[TorchCheckpointer], optional): To be provided if the server should perform - server side checkpointing based on some criteria. If none, then no server-side checkpointing is - performed. Defaults to None. 
+ checkpoint_and_state_module (BaseServerCheckpointAndStateModule | None, optional): This module is used
+ to handle both model checkpointing and state checkpointing. The former is aimed at saving model
+ artifacts to be used or evaluated after training. The latter is used to preserve training state
+ (including models) such that if FL training is interrupted, the process may be restarted. If no
+ module is provided, no checkpointing or state preservation will happen. Defaults to None.
reporters (Sequence[BaseReporter], optional): A sequence of FL4Health reporters which the server should
send data to before and after each round.
delta (Optional[float], optional): The delta value for epsilon-delta DP accounting. If None it defaults to
@@ -63,7 +65,7 @@
client_manager=client_manager,
fl_config=fl_config,
strategy=strategy,
- checkpointer=checkpointer,
+ checkpoint_and_state_module=checkpoint_and_state_module,
reporters=reporters,
accept_failures=accept_failures,
)
diff --git a/fl4health/servers/fedpm_server.py b/fl4health/servers/fedpm_server.py
index fa12fbf4a..8df75ff0e 100644
--- a/fl4health/servers/fedpm_server.py
+++ b/fl4health/servers/fedpm_server.py
@@ -6,7 +6,7 @@
from flwr.server.client_manager import ClientManager
from flwr.server.server import FitResultsAndFailures
-from fl4health.checkpointing.checkpointer import TorchCheckpointer
+from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule
from fl4health.reporting.base_reporter import BaseReporter
from fl4health.servers.base_server import FlServer
from fl4health.strategies.fedpm import FedPm
@@ -18,7 +18,7 @@ def __init__(
client_manager: ClientManager,
fl_config: Config,
strategy: FedPm,
- checkpointer: Optional[TorchCheckpointer] = None,
+ checkpoint_and_state_module: BaseServerCheckpointAndStateModule | None = None,
reset_frequency: int = 1,
reporters: Sequence[BaseReporter] | None = None,
accept_failures: bool = True,
@@ -36,9 +36,11 @@ def __init__(
NOTE: This config is DISTINCT from the Flwr server config, which is extremely minimal.
strategy (Scaffold): The aggregation strategy to be used by the server to handle client updates and other
information potentially sent by the participating clients. This strategy must be of SCAFFOLD type.
- checkpointer (Optional[TorchCheckpointer], optional): To be provided if the server should perform
- server side checkpointing based on some criteria. If none, then no server-side checkpointing is
- performed. Defaults to None.
+ checkpoint_and_state_module (BaseServerCheckpointAndStateModule | None, optional): This module is used
+ to handle both model checkpointing and state checkpointing. The former is aimed at saving model
+ artifacts to be used or evaluated after training. The latter is used to preserve training state
+ (including models) such that if FL training is interrupted, the process may be restarted. If no
+ module is provided, no checkpointing or state preservation will happen. Defaults to None.
reset_frequency (int): Determines the frequency with which the beta priors are reset. Defaults to 1.
reporters (Sequence[BaseReporter], optional): A sequence of FL4Health reporters which the server should
send data to before and after each round.
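For context, the sketch below shows how the new argument described above might be wired into a server. It is illustrative only and not part of the patch: the BaseServerCheckpointAndStateModule keyword arguments (model, parameter_exchanger, model_checkpointers, state_checkpointer) are taken from the constructor call visible earlier in this patch, while the stand-in model, checkpoint directory, and the remaining server inputs are assumed placeholders.

import torch.nn as nn

from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer
from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger

# Stand-in architecture; a real server would use the same architecture as its clients.
model = nn.Linear(10, 2)

checkpoint_and_state_module = BaseServerCheckpointAndStateModule(
    model=model,
    # Maps aggregated numpy arrays back into the torch architecture before model checkpointing.
    parameter_exchanger=FullParameterExchanger(),
    # Model artifacts saved for later use or evaluation (best aggregated loss in this sketch).
    model_checkpointers=[BestLossTorchModuleCheckpointer("./artifacts", "server_best_model.pkl")],
    # Supply a state checkpointer here to enable restartable training state; None disables it.
    state_checkpointer=None,
)

# The module then replaces the old checkpointer argument, e.g.
# FedPmServer(client_manager=..., fl_config=..., strategy=..., checkpoint_and_state_module=checkpoint_and_state_module)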
@@ -51,7 +53,7 @@ def __init__( client_manager=client_manager, fl_config=fl_config, strategy=strategy, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, reporters=reporters, accept_failures=accept_failures, ) diff --git a/fl4health/servers/model_merge_server.py b/fl4health/servers/model_merge_server.py index 2e1987452..e12452bfb 100644 --- a/fl4health/servers/model_merge_server.py +++ b/fl4health/servers/model_merge_server.py @@ -12,7 +12,7 @@ from flwr.server.server import Server from flwr.server.strategy import Strategy -from fl4health.checkpointing.checkpointer import LatestTorchCheckpointer +from fl4health.checkpointing.checkpointer import LatestTorchModuleCheckpointer from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger from fl4health.reporting.base_reporter import BaseReporter from fl4health.reporting.reports_manager import ReportsManager @@ -27,7 +27,7 @@ def __init__( client_manager: ClientManager, strategy: Optional[Strategy] = None, server_model: Optional[nn.Module] = None, - checkpointer: Optional[LatestTorchCheckpointer] = None, + checkpointer: Optional[LatestTorchModuleCheckpointer] = None, parameter_exchanger: Optional[ParameterExchanger] = None, reporters: Sequence[BaseReporter] | None = None, server_name: Optional[str] = None, diff --git a/fl4health/servers/nnunet_server.py b/fl4health/servers/nnunet_server.py index 628fdfd8f..7022c18b9 100644 --- a/fl4health/servers/nnunet_server.py +++ b/fl4health/servers/nnunet_server.py @@ -14,7 +14,7 @@ from flwr.server.history import History from flwr.server.strategy import Strategy -from fl4health.checkpointing.checkpointer import TorchCheckpointer +from fl4health.checkpointing.checkpointer import TorchModuleCheckpointer from fl4health.reporting.base_reporter import BaseReporter from fl4health.reporting.reports_manager import ReportsManager from fl4health.servers.base_server import ExchangerType, FlServer @@ -64,7 +64,7 @@ def __init__( on_init_parameters_config_fn: Callable[[int], Dict[str, Scalar]], model: nn.Module | None = None, strategy: Strategy | None = None, - checkpointer: TorchCheckpointer | Sequence[TorchCheckpointer] | None = None, + checkpointer: TorchModuleCheckpointer | Sequence[TorchModuleCheckpointer] | None = None, reporters: Sequence[BaseReporter] | None = None, parameter_exchanger: ExchangerType | None = None, intermediate_server_state_dir: Path | None = None, diff --git a/fl4health/servers/scaffold_server.py b/fl4health/servers/scaffold_server.py index d484875b7..e3988b7f4 100644 --- a/fl4health/servers/scaffold_server.py +++ b/fl4health/servers/scaffold_server.py @@ -9,7 +9,7 @@ from flwr.server.history import History from flwr.server.server import fit_clients -from fl4health.checkpointing.checkpointer import TorchCheckpointer +from fl4health.checkpointing.server_module import ScaffoldServerCheckpointAndStateModule from fl4health.reporting.base_reporter import BaseReporter from fl4health.servers.base_server import FlServer from fl4health.servers.instance_level_dp_server import InstanceLevelDpServer @@ -22,7 +22,7 @@ def __init__( client_manager: ClientManager, fl_config: Config, strategy: Scaffold, - checkpointer: Optional[TorchCheckpointer] = None, + checkpoint_and_state_module: ScaffoldServerCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, warm_start: bool = False, accept_failures: bool = True, @@ -41,9 +41,11 @@ def __init__( strategy (Scaffold): The aggregation strategy to be used by the server to 
handle client updates and other information potentially sent by the participating clients. This strategy
must be of SCAFFOLD type.
- checkpointer (Optional[TorchCheckpointer], optional): To be provided if the server should perform server
- side checkpointing based on some criteria. If none, then no server-side checkpointing is performed.
- Defaults to None.
+ checkpoint_and_state_module (ScaffoldServerCheckpointAndStateModule | None, optional): This module is used
+ to handle both model checkpointing and state checkpointing. The former is aimed at saving model
+ artifacts to be used or evaluated after training. The latter is used to preserve training state
+ (including models) such that if FL training is interrupted, the process may be restarted. If no
+ module is provided, no checkpointing or state preservation will happen. Defaults to None.
reporters (Sequence[BaseReporter], optional): A sequence of FL4Health reporters which the server should
send data to before and after each round.
warm_start (bool, optional): Whether or not to initialize control variates of each client as local
@@ -59,7 +61,7 @@ def __init__(
client_manager=client_manager,
fl_config=fl_config,
strategy=strategy,
- checkpointer=checkpointer,
+ checkpoint_and_state_module=checkpoint_and_state_module,
reporters=reporters,
accept_failures=accept_failures,
)
@@ -172,7 +174,7 @@ def __init__(
local_epochs: Optional[int] = None,
local_steps: Optional[int] = None,
delta: Optional[float] = None,
- checkpointer: Optional[TorchCheckpointer] = None,
+ checkpoint_and_state_module: ScaffoldServerCheckpointAndStateModule | None = None,
warm_start: bool = False,
reporters: Sequence[BaseReporter] | None = None,
) -> None:
@@ -196,9 +198,11 @@ def __init__(
strategy (Scaffold): The aggregation strategy to be used by the server to handle client updates and other
information potentially sent by the participating clients. This strategy must be of SCAFFOLD type.
- checkpointer (Optional[TorchCheckpointer], optional): To be provided if the server should perform
- server side checkpointing based on some criteria. If none, then no server-side checkpointing is
- performed. Defaults to None.
+ checkpoint_and_state_module (ScaffoldServerCheckpointAndStateModule | None, optional): This module is used
+ to handle both model checkpointing and state checkpointing. The former is aimed at saving model
+ artifacts to be used or evaluated after training. The latter is used to preserve training state
+ (including models) such that if FL training is interrupted, the process may be restarted. If no
+ module is provided, no checkpointing or state preservation will happen. Defaults to None.
warm_start (bool, optional): Whether or not to initialize control variates of each client as local
gradients. The clients will perform a training pass (without updating the weights) in order to
provide a "warm" estimate of the SCAFFOLD control variates. If false, variates are initialized to 0.
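As a rough usage sketch (again illustrative, not part of the patch), a ScaffoldServer can be assembled with the new module and warm start enabled. The ScaffoldServer constructor arguments below are those visible in this patch; the checkpoint-and-state module, FL config, and strategy are assumed placeholders built elsewhere.

from flwr.server.client_manager import SimpleClientManager

from fl4health.servers.scaffold_server import ScaffoldServer

# `checkpoint_and_state_module` is assumed to be a ScaffoldServerCheckpointAndStateModule constructed
# analogously to the base module sketched earlier in this patch; `config` and `strategy` are
# placeholders for the FL config dict and the Scaffold strategy instance.
server = ScaffoldServer(
    client_manager=SimpleClientManager(),
    fl_config=config,
    strategy=strategy,
    checkpoint_and_state_module=checkpoint_and_state_module,
    warm_start=True,  # clients estimate control variates with a weight-frozen training pass first
)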
@@ -217,7 +221,7 @@ def __init__( client_manager=client_manager, fl_config=fl_config, strategy=strategy, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, warm_start=warm_start, reporters=reporters, ) diff --git a/fl4health/servers/tabular_feature_alignment_server.py b/fl4health/servers/tabular_feature_alignment_server.py index afa98c502..6188a6032 100644 --- a/fl4health/servers/tabular_feature_alignment_server.py +++ b/fl4health/servers/tabular_feature_alignment_server.py @@ -9,7 +9,7 @@ from flwr.server.client_manager import ClientManager from flwr.server.history import History -from fl4health.checkpointing.checkpointer import TorchCheckpointer +from fl4health.checkpointing.checkpointer import TorchModuleCheckpointer from fl4health.feature_alignment.constants import ( CURRENT_SERVER_ROUND, FEATURE_INFO, @@ -31,7 +31,7 @@ def __init__( config: Config, initialize_parameters: Callable[..., Parameters], strategy: BasicFedAvg, - checkpointer: Optional[TorchCheckpointer] = None, + checkpointer: Optional[TorchModuleCheckpointer] = None, tabular_features_source_of_truth: Optional[TabularFeaturesInfoEncoder] = None, reporters: Sequence[BaseReporter] | None = None, accept_failures: bool = True, diff --git a/fl4health/utils/typing.py b/fl4health/utils/typing.py index a298bd86d..ffb098eb0 100644 --- a/fl4health/utils/typing.py +++ b/fl4health/utils/typing.py @@ -1,7 +1,7 @@ import logging from collections.abc import Callable from enum import Enum -from typing import List, Tuple, Union +from typing import List, Tuple, TypeVar, Union import torch import torch.nn as nn @@ -9,6 +9,8 @@ from flwr.common.typing import NDArrays from flwr.server.client_proxy import ClientProxy +from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger + TorchInputType = torch.Tensor | dict[str, torch.Tensor] TorchTargetType = torch.Tensor | dict[str, torch.Tensor] TorchPredType = dict[str, torch.Tensor] @@ -19,6 +21,8 @@ FitFailures = List[Union[Tuple[ClientProxy, FitRes], BaseException]] EvaluateFailures = List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]] +ExchangerType = TypeVar("ExchangerType", bound=ParameterExchanger) + class LogLevel(Enum): NOTSET = logging.NOTSET diff --git a/research/ag_news/dynamic_layer_exchange/client.py b/research/ag_news/dynamic_layer_exchange/client.py index d2b04f0b6..844b8a972 100644 --- a/research/ag_news/dynamic_layer_exchange/client.py +++ b/research/ag_news/dynamic_layer_exchange/client.py @@ -14,8 +14,8 @@ from torch.utils.data import DataLoader from transformers import BertForSequenceClassification -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.basic_client import TorchInputType from fl4health.clients.partial_weight_exchange_client import PartialWeightExchangeClient from fl4health.parameter_exchange.layer_exchanger import DynamicLayerExchanger @@ -38,7 +38,7 @@ def __init__( exchange_percentage: float, norm_threshold: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, reporters: Sequence[BaseReporter] | None = None, store_initial_model: bool = True, ) -> None: @@ -180,7 +180,9 @@ def 
predict(self, input: TorchInputType) -> Tuple[Dict[str, torch.Tensor], Dict[ # Checkpointing checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointModule(post_aggregation=BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name)) + checkpointer = ClientCheckpointAndStateModule( + post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) + ) client = BertDynamicLayerExchangeClient( data_path, diff --git a/research/ag_news/sparse_tensor_exchange/client.py b/research/ag_news/sparse_tensor_exchange/client.py index 97010afd5..451b999a9 100644 --- a/research/ag_news/sparse_tensor_exchange/client.py +++ b/research/ag_news/sparse_tensor_exchange/client.py @@ -14,8 +14,8 @@ from torch.utils.data import DataLoader from transformers import BertForSequenceClassification -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.basic_client import TorchInputType from fl4health.clients.partial_weight_exchange_client import PartialWeightExchangeClient from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger @@ -37,7 +37,7 @@ def __init__( learning_rate: float, sparsity_level: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, reporters: Sequence[BaseReporter] | None = None, store_initial_model: bool = True, ) -> None: @@ -149,7 +149,9 @@ def predict(self, input: TorchInputType) -> Tuple[Dict[str, torch.Tensor], Dict[ # Checkpointing checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointModule(post_aggregation=BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name)) + checkpointer = ClientCheckpointAndStateModule( + post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) + ) client = BertSparseTensorExchangeClient( data_path, diff --git a/research/cifar10/adaptive_pfl/ditto/client.py b/research/cifar10/adaptive_pfl/ditto/client.py index 7759d93ca..fce7a2236 100644 --- a/research/cifar10/adaptive_pfl/ditto/client.py +++ b/research/cifar10/adaptive_pfl/ditto/client.py @@ -13,8 +13,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.ditto_client import DittoClient from fl4health.utils.config import narrow_dict_type from fl4health.utils.losses import LossMeterType @@ -34,7 +34,7 @@ def __init__( learning_rate: float, heterogeneity_level: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, @@ -138,14 +138,14 @@ def 
get_model(self, config: Config) -> nn.Module: pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl" post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl" post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl" - checkpointer = ClientCheckpointModule( + checkpointer = ClientCheckpointAndStateModule( pre_aggregation=[ - BestLossTorchCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), - LatestTorchCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), + BestLossTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), + LatestTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), ], post_aggregation=[ - BestLossTorchCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name), - LatestTorchCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name), + BestLossTorchModuleCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name), + LatestTorchModuleCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name), ], ) diff --git a/research/cifar10/adaptive_pfl/fedprox/client.py b/research/cifar10/adaptive_pfl/fedprox/client.py index 720cc0011..72e3736f6 100644 --- a/research/cifar10/adaptive_pfl/fedprox/client.py +++ b/research/cifar10/adaptive_pfl/fedprox/client.py @@ -13,8 +13,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.fed_prox_client import FedProxClient from fl4health.utils.config import narrow_dict_type from fl4health.utils.losses import LossMeterType @@ -34,7 +34,7 @@ def __init__( learning_rate: float, heterogeneity_level: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, @@ -142,14 +142,14 @@ def get_model(self, config: Config) -> nn.Module: pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl" post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl" post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl" - checkpointer = ClientCheckpointModule( + checkpointer = ClientCheckpointAndStateModule( pre_aggregation=[ - BestLossTorchCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), - LatestTorchCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), + BestLossTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), + LatestTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), ], post_aggregation=[ - BestLossTorchCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name), - LatestTorchCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name), + BestLossTorchModuleCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name), + LatestTorchModuleCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name), ], 
) diff --git a/research/cifar10/adaptive_pfl/fedprox/server.py b/research/cifar10/adaptive_pfl/fedprox/server.py index 8450e109f..16667615f 100644 --- a/research/cifar10/adaptive_pfl/fedprox/server.py +++ b/research/cifar10/adaptive_pfl/fedprox/server.py @@ -9,7 +9,7 @@ from flwr.common.typing import Config from flwr.server.client_manager import SimpleClientManager -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer from fl4health.servers.adaptive_constraint_servers.fedprox_server import FedProxServer from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint from fl4health.utils.config import load_config @@ -55,8 +55,8 @@ def main( best_checkpoint_name = "server_best_model.pkl" last_checkpoint_name = "server_last_model.pkl" checkpointer = [ - BestLossTorchCheckpointer(checkpoint_dir, best_checkpoint_name), - LatestTorchCheckpointer(checkpoint_dir, last_checkpoint_name), + BestLossTorchModuleCheckpointer(checkpoint_dir, best_checkpoint_name), + LatestTorchModuleCheckpointer(checkpoint_dir, last_checkpoint_name), ] client_manager = SimpleClientManager() @@ -93,7 +93,7 @@ def main( ) log(INFO, "Training Complete") - assert isinstance(checkpointer[0], BestLossTorchCheckpointer) + assert isinstance(checkpointer[0], BestLossTorchModuleCheckpointer) log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{checkpointer[0].best_score}") # Shutdown the server gracefully diff --git a/research/cifar10/adaptive_pfl/fenda_ditto/client.py b/research/cifar10/adaptive_pfl/fenda_ditto/client.py index 947fad771..670e78eb8 100644 --- a/research/cifar10/adaptive_pfl/fenda_ditto/client.py +++ b/research/cifar10/adaptive_pfl/fenda_ditto/client.py @@ -12,8 +12,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.fenda_ditto_client import FendaDittoClient from fl4health.model_bases.fenda_base import FendaModel from fl4health.model_bases.sequential_split_models import SequentiallySplitModel @@ -35,7 +35,7 @@ def __init__( learning_rate: float, heterogeneity_level: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, freeze_global_feature_extractor: bool = False, ) -> None: super().__init__( @@ -152,14 +152,14 @@ def get_global_model(self, config: Config) -> SequentiallySplitModel: pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl" post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl" post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl" - checkpointer = ClientCheckpointModule( + checkpointer = ClientCheckpointAndStateModule( pre_aggregation=[ - BestLossTorchCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), - LatestTorchCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), + 
BestLossTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), + LatestTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), ], post_aggregation=[ - BestLossTorchCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name), - LatestTorchCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name), + BestLossTorchModuleCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name), + LatestTorchModuleCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name), ], ) diff --git a/research/cifar10/adaptive_pfl/mrmtl/client.py b/research/cifar10/adaptive_pfl/mrmtl/client.py index f0b38fe0e..6899072c8 100644 --- a/research/cifar10/adaptive_pfl/mrmtl/client.py +++ b/research/cifar10/adaptive_pfl/mrmtl/client.py @@ -13,8 +13,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.mr_mtl_client import MrMtlClient from fl4health.utils.config import narrow_dict_type from fl4health.utils.losses import LossMeterType @@ -34,7 +34,7 @@ def __init__( learning_rate: float, heterogeneity_level: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, @@ -136,14 +136,14 @@ def get_model(self, config: Config) -> nn.Module: pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl" post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl" post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl" - checkpointer = ClientCheckpointModule( + checkpointer = ClientCheckpointAndStateModule( pre_aggregation=[ - BestLossTorchCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), - LatestTorchCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), + BestLossTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), + LatestTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), ], post_aggregation=[ - BestLossTorchCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name), - LatestTorchCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name), + BestLossTorchModuleCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name), + LatestTorchModuleCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name), ], ) diff --git a/research/cifar10/ditto/client.py b/research/cifar10/ditto/client.py index f4b44d12c..2224bfe14 100644 --- a/research/cifar10/ditto/client.py +++ b/research/cifar10/ditto/client.py @@ -13,8 +13,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer +from fl4health.checkpointing.client_module import 
ClientCheckpointAndStateModule from fl4health.clients.ditto_client import DittoClient from fl4health.utils.config import narrow_dict_type from fl4health.utils.load_data import load_cifar10_data, load_cifar10_test_data @@ -36,7 +36,7 @@ def __init__( learning_rate: float, heterogeneity_level: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, use_partitioned_data: bool = False, ) -> None: super().__init__( @@ -199,14 +199,14 @@ def get_model(self, config: Config) -> nn.Module: pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl" post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl" post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl" - checkpointer = ClientCheckpointModule( + checkpointer = ClientCheckpointAndStateModule( pre_aggregation=[ - BestLossTorchCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), - LatestTorchCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), + BestLossTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), + LatestTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), ], post_aggregation=[ - BestLossTorchCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name), - LatestTorchCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name), + BestLossTorchModuleCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name), + LatestTorchModuleCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name), ], ) diff --git a/research/cifar10/ditto_deep_mmd/client.py b/research/cifar10/ditto_deep_mmd/client.py index dca688a68..eac5c8f4d 100644 --- a/research/cifar10/ditto_deep_mmd/client.py +++ b/research/cifar10/ditto_deep_mmd/client.py @@ -14,8 +14,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.deep_mmd_clients.ditto_deep_mmd_client import DittoDeepMmdClient from fl4health.utils.config import narrow_dict_type from fl4health.utils.load_data import load_cifar10_data, load_cifar10_test_data @@ -44,7 +44,7 @@ def __init__( loss_meter_type: LossMeterType = LossMeterType.AVERAGE, deep_mmd_loss_weight: float = 10, deep_mmd_loss_depth: int = 1, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, use_partitioned_data: bool = False, ) -> None: feature_extraction_layers_with_size = OrderedDict(list(BASELINE_LAYERS.items())[-1 * deep_mmd_loss_depth :]) @@ -226,14 +226,14 @@ def get_model(self, config: Config) -> nn.Module: pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl" post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl" post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl" - checkpointer = ClientCheckpointModule( + checkpointer = ClientCheckpointAndStateModule( 
pre_aggregation=[ - BestLossTorchCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), - LatestTorchCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), + BestLossTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), + LatestTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), ], post_aggregation=[ - BestLossTorchCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name), - LatestTorchCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name), + BestLossTorchModuleCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name), + LatestTorchModuleCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name), ], ) diff --git a/research/cifar10/ditto_mkmmd/client.py b/research/cifar10/ditto_mkmmd/client.py index 7f5d5f034..44f3a21bd 100644 --- a/research/cifar10/ditto_mkmmd/client.py +++ b/research/cifar10/ditto_mkmmd/client.py @@ -13,8 +13,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.mkmmd_clients.ditto_mkmmd_client import DittoMkMmdClient from fl4health.utils.config import narrow_dict_type from fl4health.utils.load_data import load_cifar10_data, load_cifar10_test_data @@ -42,7 +42,7 @@ def __init__( feature_l2_norm_weight: float = 1, mkmmd_loss_depth: int = 1, beta_global_update_interval: int = 20, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, use_partitioned_data: bool = False, ) -> None: super().__init__( @@ -244,14 +244,14 @@ def get_model(self, config: Config) -> nn.Module: pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl" post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl" post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl" - checkpointer = ClientCheckpointModule( + checkpointer = ClientCheckpointAndStateModule( pre_aggregation=[ - BestLossTorchCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), - LatestTorchCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), + BestLossTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), + LatestTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), ], post_aggregation=[ - BestLossTorchCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name), - LatestTorchCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name), + BestLossTorchModuleCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name), + LatestTorchModuleCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name), ], ) diff --git a/research/cifar10/fed_dgga_pfl/ditto/client.py b/research/cifar10/fed_dgga_pfl/ditto/client.py index 7759d93ca..fce7a2236 100644 --- a/research/cifar10/fed_dgga_pfl/ditto/client.py +++ b/research/cifar10/fed_dgga_pfl/ditto/client.py @@ -13,8 +13,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, 
LatestTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.ditto_client import DittoClient from fl4health.utils.config import narrow_dict_type from fl4health.utils.losses import LossMeterType @@ -34,7 +34,7 @@ def __init__( learning_rate: float, heterogeneity_level: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, @@ -138,14 +138,14 @@ def get_model(self, config: Config) -> nn.Module: pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl" post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl" post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl" - checkpointer = ClientCheckpointModule( + checkpointer = ClientCheckpointAndStateModule( pre_aggregation=[ - BestLossTorchCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), - LatestTorchCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), + BestLossTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), + LatestTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), ], post_aggregation=[ - BestLossTorchCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name), - LatestTorchCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name), + BestLossTorchModuleCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name), + LatestTorchModuleCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name), ], ) diff --git a/research/cifar10/fed_dgga_pfl/fenda/client.py b/research/cifar10/fed_dgga_pfl/fenda/client.py index 4e5549b69..e75493055 100644 --- a/research/cifar10/fed_dgga_pfl/fenda/client.py +++ b/research/cifar10/fed_dgga_pfl/fenda/client.py @@ -12,8 +12,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.fenda_client import FendaClient from fl4health.model_bases.fenda_base import FendaModel from fl4health.utils.config import narrow_dict_type @@ -34,7 +34,7 @@ def __init__( learning_rate: float, heterogeneity_level: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, @@ -136,14 +136,14 @@ def get_model(self, config: Config) -> FendaModel: pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl" post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl" post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl" - checkpointer = 
ClientCheckpointModule( + checkpointer = ClientCheckpointAndStateModule( pre_aggregation=[ - BestLossTorchCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), - LatestTorchCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), + BestLossTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), + LatestTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), ], post_aggregation=[ - BestLossTorchCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name), - LatestTorchCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name), + BestLossTorchModuleCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name), + LatestTorchModuleCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name), ], ) diff --git a/research/cifar10/fed_dgga_pfl/fenda_ditto/client.py b/research/cifar10/fed_dgga_pfl/fenda_ditto/client.py index 19f891ba3..5ac46d079 100644 --- a/research/cifar10/fed_dgga_pfl/fenda_ditto/client.py +++ b/research/cifar10/fed_dgga_pfl/fenda_ditto/client.py @@ -12,8 +12,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.fenda_ditto_client import FendaDittoClient from fl4health.model_bases.fenda_base import FendaModel from fl4health.model_bases.sequential_split_models import SequentiallySplitModel @@ -35,7 +35,7 @@ def __init__( learning_rate: float, heterogeneity_level: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, freeze_global_feature_extractor: bool = False, ) -> None: super().__init__( @@ -152,14 +152,14 @@ def get_global_model(self, config: Config) -> SequentiallySplitModel: pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl" post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl" post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl" - checkpointer = ClientCheckpointModule( + checkpointer = ClientCheckpointAndStateModule( pre_aggregation=[ - BestLossTorchCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), - LatestTorchCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), + BestLossTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), + LatestTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), ], post_aggregation=[ - BestLossTorchCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name), - LatestTorchCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name), + BestLossTorchModuleCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name), + LatestTorchModuleCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name), ], ) diff --git a/research/cifar10/fedavg/client.py b/research/cifar10/fedavg/client.py index c206f9c40..168b9d1f8 100644 --- a/research/cifar10/fedavg/client.py +++ b/research/cifar10/fedavg/client.py @@ -13,8 +13,8 @@ from torch.optim import Optimizer from 
torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.basic_client import BasicClient from fl4health.utils.config import narrow_dict_type from fl4health.utils.load_data import load_cifar10_data, load_cifar10_test_data @@ -36,7 +36,7 @@ def __init__( learning_rate: float, heterogeneity_level: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, use_partitioned_data: bool = False, ) -> None: super().__init__( @@ -198,14 +198,14 @@ def get_model(self, config: Config) -> nn.Module: pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl" post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl" post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl" - checkpointer = ClientCheckpointModule( + checkpointer = ClientCheckpointAndStateModule( pre_aggregation=[ - BestLossTorchCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), - LatestTorchCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), + BestLossTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), + LatestTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), ], post_aggregation=[ - BestLossTorchCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name), - LatestTorchCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name), + BestLossTorchModuleCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name), + LatestTorchModuleCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name), ], ) diff --git a/research/cifar10/fedavg/server.py b/research/cifar10/fedavg/server.py index d9a78514e..31459f7b5 100644 --- a/research/cifar10/fedavg/server.py +++ b/research/cifar10/fedavg/server.py @@ -10,7 +10,7 @@ from flwr.server.client_manager import SimpleClientManager from flwr.server.strategy import FedAvg -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.servers.base_server import FlServer from fl4health.utils.config import load_config @@ -49,8 +49,8 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_ best_checkpoint_name = "server_best_model.pkl" last_checkpoint_name = "server_last_model.pkl" checkpointer = [ - BestLossTorchCheckpointer(checkpoint_dir, best_checkpoint_name), - LatestTorchCheckpointer(checkpoint_dir, last_checkpoint_name), + BestLossTorchModuleCheckpointer(checkpoint_dir, best_checkpoint_name), + LatestTorchModuleCheckpointer(checkpoint_dir, last_checkpoint_name), ] client_manager = SimpleClientManager() # Initializing the model on the server side @@ -84,7 +84,7 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_ 
config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]), ) - assert isinstance(checkpointer[0], BestLossTorchCheckpointer) + assert isinstance(checkpointer[0], BestLossTorchModuleCheckpointer) log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{checkpointer[0].best_score}") # Shutdown the server gracefully diff --git a/research/flamby/fed_heart_disease/apfl/client.py b/research/flamby/fed_heart_disease/apfl/client.py index 6862ef25b..50b6de939 100644 --- a/research/flamby/fed_heart_disease/apfl/client.py +++ b/research/flamby/fed_heart_disease/apfl/client.py @@ -13,8 +13,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.apfl_client import ApflClient from fl4health.model_bases.apfl_base import ApflModule from fl4health.utils.losses import LossMeterType @@ -32,7 +32,7 @@ def __init__( learning_rate: float, alpha_learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, @@ -128,11 +128,11 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" post_aggregation_checkpointer = ( - BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name) + BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) if federated_checkpointing - else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name) + else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) - checkpointer = ClientCheckpointModule(post_aggregation=post_aggregation_checkpointer) + checkpointer = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) client = FedHeartDiseaseApflClient( data_path=args.dataset_dir, diff --git a/research/flamby/fed_heart_disease/ditto/client.py b/research/flamby/fed_heart_disease/ditto/client.py index 292cfdfdf..fbddce0f6 100644 --- a/research/flamby/fed_heart_disease/ditto/client.py +++ b/research/flamby/fed_heart_disease/ditto/client.py @@ -14,8 +14,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.ditto_client import DittoClient from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import Accuracy, Metric @@ -32,7 +32,7 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, @@ -136,12 +136,12 @@ def get_criterion(self, config: Config) -> _Loss: 
checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" post_aggregation_checkpointer = ( - BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name) + BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) if federated_checkpointing - else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name) + else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) - checkpointer = ClientCheckpointModule(post_aggregation=post_aggregation_checkpointer) + checkpointer = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) client = FedHeartDiseaseDittoClient( data_path=Path(args.dataset_dir), diff --git a/research/flamby/fed_heart_disease/fedadam/client.py b/research/flamby/fed_heart_disease/fedadam/client.py index 825049603..a59222c58 100644 --- a/research/flamby/fed_heart_disease/fedadam/client.py +++ b/research/flamby/fed_heart_disease/fedadam/client.py @@ -14,8 +14,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.basic_client import BasicClient from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import Accuracy, Metric @@ -31,7 +31,7 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, @@ -113,7 +113,9 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointModule(post_aggregation=BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name)) + checkpointer = ClientCheckpointAndStateModule( + post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) + ) client = FedHeartDiseaseFedAdamClient( data_path=Path(args.dataset_dir), diff --git a/research/flamby/fed_heart_disease/fedadam/server.py b/research/flamby/fed_heart_disease/fedadam/server.py index cfcf2e012..5fe3e3071 100644 --- a/research/flamby/fed_heart_disease/fedadam/server.py +++ b/research/flamby/fed_heart_disease/fedadam/server.py @@ -10,7 +10,7 @@ from flwr.server.client_manager import SimpleClientManager from flwr.server.strategy import FedAdam -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer from fl4health.utils.config import load_config from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn from fl4health.utils.parameter_extraction import get_all_model_parameters @@ -33,9 +33,9 @@ def main( federated_checkpointing: bool = config.get("federated_checkpointing", True) log(INFO, f"Performing Federated Checkpointing: {federated_checkpointing}") checkpointer = ( - BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name) + BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) if federated_checkpointing - else 
LatestTorchCheckpointer(checkpoint_dir, checkpoint_name) + else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) client_manager = SimpleClientManager() @@ -67,7 +67,7 @@ def main( ) if federated_checkpointing: - assert isinstance(checkpointer, BestLossTorchCheckpointer) + assert isinstance(checkpointer, BestLossTorchModuleCheckpointer) log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{checkpointer.best_score}") # Shutdown the server gracefully diff --git a/research/flamby/fed_heart_disease/fedavg/client.py b/research/flamby/fed_heart_disease/fedavg/client.py index 894157270..1ca53e5a2 100644 --- a/research/flamby/fed_heart_disease/fedavg/client.py +++ b/research/flamby/fed_heart_disease/fedavg/client.py @@ -14,8 +14,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.basic_client import BasicClient from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import Accuracy, Metric @@ -31,7 +31,7 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, @@ -113,7 +113,9 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointModule(post_aggregation=BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name)) + checkpointer = ClientCheckpointAndStateModule( + post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) + ) client = FedHeartDiseaseFedAvgClient( data_path=Path(args.dataset_dir), diff --git a/research/flamby/fed_heart_disease/fedavg/server.py b/research/flamby/fed_heart_disease/fedavg/server.py index c433d3b23..a95b205ed 100644 --- a/research/flamby/fed_heart_disease/fedavg/server.py +++ b/research/flamby/fed_heart_disease/fedavg/server.py @@ -10,7 +10,7 @@ from flwr.server.client_manager import SimpleClientManager from flwr.server.strategy import FedAvg -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer from fl4health.utils.config import load_config from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn from fl4health.utils.parameter_extraction import get_all_model_parameters @@ -31,9 +31,9 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_ federated_checkpointing: bool = config.get("federated_checkpointing", True) log(INFO, f"Performing Federated Checkpointing: {federated_checkpointing}") checkpointer = ( - BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name) + BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) if federated_checkpointing - else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name) + else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) client_manager = 
SimpleClientManager() @@ -65,7 +65,7 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_ ) if federated_checkpointing: - assert isinstance(checkpointer, BestLossTorchCheckpointer) + assert isinstance(checkpointer, BestLossTorchModuleCheckpointer) log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{checkpointer.best_score}") # Shutdown the server gracefully diff --git a/research/flamby/fed_heart_disease/fedper/client.py b/research/flamby/fed_heart_disease/fedper/client.py index 60b9b6897..64e81f0ef 100644 --- a/research/flamby/fed_heart_disease/fedper/client.py +++ b/research/flamby/fed_heart_disease/fedper/client.py @@ -14,8 +14,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.moon_client import MoonClient from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger @@ -35,7 +35,7 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, @@ -140,11 +140,11 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" post_aggregation_checkpointer = ( - BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name) + BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) if federated_checkpointing - else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name) + else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) - checkpointer = ClientCheckpointModule(post_aggregation=post_aggregation_checkpointer) + checkpointer = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) client = FedHeartDiseaseFedPerClient( data_path=Path(args.dataset_dir), diff --git a/research/flamby/fed_heart_disease/fedprox/client.py b/research/flamby/fed_heart_disease/fedprox/client.py index ce31f8ab1..288664161 100644 --- a/research/flamby/fed_heart_disease/fedprox/client.py +++ b/research/flamby/fed_heart_disease/fedprox/client.py @@ -14,8 +14,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.fed_prox_client import FedProxClient from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import Accuracy, Metric @@ -31,7 +31,7 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: 
super().__init__( data_path=data_path, @@ -113,7 +113,9 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointModule(post_aggregation=BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name)) + checkpointer = ClientCheckpointAndStateModule( + post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) + ) client = FedHeartDiseaseFedProxClient( data_path=Path(args.dataset_dir), diff --git a/research/flamby/fed_heart_disease/fedprox/server.py b/research/flamby/fed_heart_disease/fedprox/server.py index 646b8d179..9eff20859 100644 --- a/research/flamby/fed_heart_disease/fedprox/server.py +++ b/research/flamby/fed_heart_disease/fedprox/server.py @@ -9,7 +9,7 @@ from flwr.common.logger import log from flwr.server.client_manager import SimpleClientManager -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer from fl4health.servers.adaptive_constraint_servers.fedprox_server import FedProxServer from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint from fl4health.utils.config import load_config @@ -31,9 +31,9 @@ def main(config: Dict[str, Any], server_address: str, mu: float, checkpoint_stub federated_checkpointing: bool = config.get("federated_checkpointing", True) log(INFO, f"Performing Federated Checkpointing: {federated_checkpointing}") checkpointer = ( - BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name) + BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) if federated_checkpointing - else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name) + else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) client_manager = SimpleClientManager() @@ -66,7 +66,7 @@ def main(config: Dict[str, Any], server_address: str, mu: float, checkpoint_stub ) if federated_checkpointing: - assert isinstance(checkpointer, BestLossTorchCheckpointer) + assert isinstance(checkpointer, BestLossTorchModuleCheckpointer) log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{checkpointer.best_score}") # Shutdown the server gracefully diff --git a/research/flamby/fed_heart_disease/fenda/client.py b/research/flamby/fed_heart_disease/fenda/client.py index 63787db6a..45e4e433b 100644 --- a/research/flamby/fed_heart_disease/fenda/client.py +++ b/research/flamby/fed_heart_disease/fenda/client.py @@ -14,8 +14,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.fenda_client import FendaClient from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import Accuracy, Metric @@ -33,7 +33,7 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, @@ 
-133,11 +133,11 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" post_aggregation_checkpointer = ( - BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name) + BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) if federated_checkpointing - else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name) + else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) - checkpointer = ClientCheckpointModule(post_aggregation=post_aggregation_checkpointer) + checkpointer = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) client = FedHeartDiseaseFendaClient( data_path=Path(args.dataset_dir), diff --git a/research/flamby/fed_heart_disease/moon/client.py b/research/flamby/fed_heart_disease/moon/client.py index 5da0aca18..ae7e4a3d3 100644 --- a/research/flamby/fed_heart_disease/moon/client.py +++ b/research/flamby/fed_heart_disease/moon/client.py @@ -14,8 +14,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.moon_client import MoonClient from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import Accuracy, Metric @@ -34,7 +34,7 @@ def __init__( learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, contrastive_weight: float = 10, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, @@ -134,7 +134,9 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointModule(post_aggregation=BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name)) + checkpointer = ClientCheckpointAndStateModule( + post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) + ) client = FedHeartDiseaseMoonClient( data_path=Path(args.dataset_dir), diff --git a/research/flamby/fed_heart_disease/moon/server.py b/research/flamby/fed_heart_disease/moon/server.py index 44f8c6a53..2748458bb 100644 --- a/research/flamby/fed_heart_disease/moon/server.py +++ b/research/flamby/fed_heart_disease/moon/server.py @@ -9,7 +9,7 @@ from flwr.server.client_manager import SimpleClientManager from flwr.server.strategy import FedAvg -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer from fl4health.utils.config import load_config from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn from fl4health.utils.parameter_extraction import get_all_model_parameters @@ -32,9 +32,9 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_ federated_checkpointing: bool = config.get("federated_checkpointing", True) log(INFO, f"Performing Federated Checkpointing: {federated_checkpointing}") checkpointer = ( - 
BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name) + BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) if federated_checkpointing - else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name) + else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) client_manager = SimpleClientManager() @@ -66,7 +66,7 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_ ) if federated_checkpointing: - assert isinstance(checkpointer, BestLossTorchCheckpointer) + assert isinstance(checkpointer, BestLossTorchModuleCheckpointer) log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{checkpointer.best_score}") # Shutdown the server gracefully diff --git a/research/flamby/fed_heart_disease/perfcl/client.py b/research/flamby/fed_heart_disease/perfcl/client.py index 6862d142a..749eb7093 100644 --- a/research/flamby/fed_heart_disease/perfcl/client.py +++ b/research/flamby/fed_heart_disease/perfcl/client.py @@ -14,8 +14,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.perfcl_client import PerFclClient from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import Accuracy, Metric @@ -33,7 +33,7 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, mu: float = 10.0, gamma: float = 10.0, ) -> None: @@ -151,11 +151,11 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" post_aggregation_checkpointer = ( - BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name) + BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) if federated_checkpointing - else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name) + else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) - checkpointer = ClientCheckpointModule(post_aggregation=post_aggregation_checkpointer) + checkpointer = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) client = FedHeartDiseasePerFclClient( data_path=Path(args.dataset_dir), diff --git a/research/flamby/fed_heart_disease/scaffold/client.py b/research/flamby/fed_heart_disease/scaffold/client.py index a5d7aff20..0dc3d5f1e 100644 --- a/research/flamby/fed_heart_disease/scaffold/client.py +++ b/research/flamby/fed_heart_disease/scaffold/client.py @@ -14,8 +14,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.scaffold_client import ScaffoldClient from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import Accuracy, Metric @@ -31,7 +31,7 @@ def __init__( 
         client_number: int,
         learning_rate: float,
         loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
-        checkpointer: Optional[ClientCheckpointModule] = None,
+        checkpointer: Optional[ClientCheckpointAndStateModule] = None,
     ) -> None:
         super().__init__(
             data_path=data_path,
@@ -113,7 +113,9 @@ def get_criterion(self, config: Config) -> _Loss:
     checkpoint_dir = os.path.join(args.artifact_dir, args.run_name)
     checkpoint_name = f"client_{args.client_number}_best_model.pkl"
-    checkpointer = ClientCheckpointModule(post_aggregation=BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name))
+    checkpointer = ClientCheckpointAndStateModule(
+        post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name)
+    )

     client = FedHeartDiseaseScaffoldClient(
         data_path=Path(args.dataset_dir),
diff --git a/research/flamby/fed_heart_disease/scaffold/server.py b/research/flamby/fed_heart_disease/scaffold/server.py
index e7b0a8654..35b5a2e3c 100644
--- a/research/flamby/fed_heart_disease/scaffold/server.py
+++ b/research/flamby/fed_heart_disease/scaffold/server.py
@@ -8,12 +8,15 @@
 from flamby.datasets.fed_heart_disease import Baseline
 from flwr.common.logger import log

-from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer
+from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer
+from fl4health.checkpointing.server_module import ScaffoldServerCheckpointAndStateModule
 from fl4health.client_managers.fixed_without_replacement_manager import FixedSamplingByFractionClientManager
+from fl4health.parameter_exchange.packing_exchanger import FullParameterExchangerWithPacking
+from fl4health.parameter_exchange.parameter_packer import ParameterPackerWithControlVariates
+from fl4health.servers.scaffold_server import ScaffoldServer
 from fl4health.strategies.scaffold import Scaffold
 from fl4health.utils.config import load_config
 from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
-from research.flamby.flamby_servers.scaffold_server import ScaffoldServer
 from research.flamby.utils import fit_config, get_initial_model_info_with_control_variates, summarize_model_info
@@ -32,15 +35,22 @@ def main(
     federated_checkpointing: bool = config.get("federated_checkpointing", True)
     log(INFO, f"Performing Federated Checkpointing: {federated_checkpointing}")
     checkpointer = (
-        BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name)
+        BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name)
         if federated_checkpointing
-        else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name)
+        else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name)
     )

     client_manager = FixedSamplingByFractionClientManager()
     model = Baseline()
     summarize_model_info(model)

+    model_size = len(model.state_dict())
+    checkpoint_and_state_module = ScaffoldServerCheckpointAndStateModule(
+        model=model,
+        parameter_exchanger=FullParameterExchangerWithPacking(ParameterPackerWithControlVariates(model_size)),
+        model_checkpointers=checkpointer,
+    )
+
     initial_parameters, initial_control_variates = get_initial_model_info_with_control_variates(model)

     strategy = Scaffold(
@@ -59,7 +69,10 @@ def main(
     )

     server = ScaffoldServer(
-        client_manager=client_manager, fl_config=config, model=model, strategy=strategy, checkpointer=checkpointer
+        client_manager=client_manager,
+        fl_config=config,
+        strategy=strategy,
+        checkpoint_and_state_module=checkpoint_and_state_module,
     )

     fl.server.start_server(
@@ -69,7
+82,7 @@ def main( ) if federated_checkpointing: - assert isinstance(checkpointer, BestLossTorchCheckpointer) + assert isinstance(checkpointer, BestLossTorchModuleCheckpointer) log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{checkpointer.best_score}") # Shutdown the server gracefully diff --git a/research/flamby/fed_isic2019/apfl/client.py b/research/flamby/fed_isic2019/apfl/client.py index 6c622ab72..011115d25 100644 --- a/research/flamby/fed_isic2019/apfl/client.py +++ b/research/flamby/fed_isic2019/apfl/client.py @@ -14,8 +14,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.apfl_client import ApflClient from fl4health.model_bases.apfl_base import ApflModule from fl4health.utils.losses import LossMeterType @@ -34,7 +34,7 @@ def __init__( learning_rate: float, alpha_learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, @@ -126,7 +126,9 @@ def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointModule(post_aggregation=BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name)) + checkpointer = ClientCheckpointAndStateModule( + post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) + ) client = FedIsic2019ApflClient( data_path=Path(args.dataset_dir), diff --git a/research/flamby/fed_isic2019/ditto/client.py b/research/flamby/fed_isic2019/ditto/client.py index 7b1249eae..dcefd58af 100644 --- a/research/flamby/fed_isic2019/ditto/client.py +++ b/research/flamby/fed_isic2019/ditto/client.py @@ -14,8 +14,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.ditto_client import DittoClient from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BalancedAccuracy, Metric @@ -32,7 +32,7 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, @@ -128,7 +128,9 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointModule(post_aggregation=BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name)) + checkpointer = ClientCheckpointAndStateModule( + post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) + ) client = FedIsic2019DittoClient( 
data_path=Path(args.dataset_dir), diff --git a/research/flamby/fed_isic2019/ditto_deep_mmd/client.py b/research/flamby/fed_isic2019/ditto_deep_mmd/client.py index de413e0f0..7d5dae85d 100644 --- a/research/flamby/fed_isic2019/ditto_deep_mmd/client.py +++ b/research/flamby/fed_isic2019/ditto_deep_mmd/client.py @@ -15,8 +15,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.deep_mmd_clients.ditto_deep_mmd_client import DittoDeepMmdClient from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BalancedAccuracy, Metric @@ -40,7 +40,7 @@ def __init__( loss_meter_type: LossMeterType = LossMeterType.AVERAGE, deep_mmd_loss_weight: float = 10, deep_mmd_loss_depth: int = 1, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: feature_extraction_layers_with_size = OrderedDict( list(FED_ISIC2019_BASELINE_LAYERS.items())[-1 * deep_mmd_loss_depth :] @@ -158,7 +158,9 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointModule(post_aggregation=BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name)) + checkpointer = ClientCheckpointAndStateModule( + post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) + ) client = FedIsic2019DittoClient( data_path=Path(args.dataset_dir), diff --git a/research/flamby/fed_isic2019/ditto_mkmmd/client.py b/research/flamby/fed_isic2019/ditto_mkmmd/client.py index 401afec2c..0f6c74c98 100644 --- a/research/flamby/fed_isic2019/ditto_mkmmd/client.py +++ b/research/flamby/fed_isic2019/ditto_mkmmd/client.py @@ -14,8 +14,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.mkmmd_clients.ditto_mkmmd_client import DittoMkMmdClient from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BalancedAccuracy, Metric @@ -41,7 +41,7 @@ def __init__( feature_l2_norm_weight: float = 1, mkmmd_loss_depth: int = 1, beta_global_update_interval: int = 20, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, @@ -175,7 +175,9 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointModule(post_aggregation=BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name)) + checkpointer = ClientCheckpointAndStateModule( + post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) + ) client = FedIsic2019DittoClient( data_path=Path(args.dataset_dir), diff --git 
a/research/flamby/fed_isic2019/fedadam/client.py b/research/flamby/fed_isic2019/fedadam/client.py index 51a00470c..5f2880894 100644 --- a/research/flamby/fed_isic2019/fedadam/client.py +++ b/research/flamby/fed_isic2019/fedadam/client.py @@ -14,8 +14,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.basic_client import BasicClient from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BalancedAccuracy, Metric @@ -32,7 +32,7 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, @@ -114,7 +114,9 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointModule(post_aggregation=BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name)) + checkpointer = ClientCheckpointAndStateModule( + post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) + ) client = FedIsic2019FedAdamClient( data_path=Path(args.dataset_dir), diff --git a/research/flamby/fed_isic2019/fedadam/server.py b/research/flamby/fed_isic2019/fedadam/server.py index 3a9c65351..d72d80411 100644 --- a/research/flamby/fed_isic2019/fedadam/server.py +++ b/research/flamby/fed_isic2019/fedadam/server.py @@ -9,7 +9,7 @@ from flwr.server.client_manager import SimpleClientManager from flwr.server.strategy import FedAdam -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer from fl4health.utils.config import load_config from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn from fl4health.utils.parameter_extraction import get_all_model_parameters @@ -30,7 +30,7 @@ def main( checkpoint_dir = os.path.join(checkpoint_stub, run_name) checkpoint_name = "server_best_model.pkl" - checkpointer = BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name) + checkpointer = BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) client_manager = SimpleClientManager() model = FedAdamEfficientNet() diff --git a/research/flamby/fed_isic2019/fedavg/client.py b/research/flamby/fed_isic2019/fedavg/client.py index a2ff5af6f..940ee26da 100644 --- a/research/flamby/fed_isic2019/fedavg/client.py +++ b/research/flamby/fed_isic2019/fedavg/client.py @@ -14,8 +14,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.basic_client import BasicClient from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BalancedAccuracy, Metric 
@@ -31,7 +31,7 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, @@ -113,7 +113,9 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointModule(post_aggregation=BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name)) + checkpointer = ClientCheckpointAndStateModule( + post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) + ) client = FedIsic2019FedAvgClient( data_path=Path(args.dataset_dir), diff --git a/research/flamby/fed_isic2019/fedavg/server.py b/research/flamby/fed_isic2019/fedavg/server.py index a682b40d5..232237fc8 100644 --- a/research/flamby/fed_isic2019/fedavg/server.py +++ b/research/flamby/fed_isic2019/fedavg/server.py @@ -10,7 +10,7 @@ from flwr.server.client_manager import SimpleClientManager from flwr.server.strategy import FedAvg -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer from fl4health.utils.config import load_config from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn from fl4health.utils.parameter_extraction import get_all_model_parameters @@ -28,7 +28,7 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_ checkpoint_dir = os.path.join(checkpoint_stub, run_name) checkpoint_name = "server_best_model.pkl" - checkpointer = BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name) + checkpointer = BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) client_manager = SimpleClientManager() model = Baseline() diff --git a/research/flamby/fed_isic2019/fedper/client.py b/research/flamby/fed_isic2019/fedper/client.py index 1a8f3564d..5dd1a561d 100644 --- a/research/flamby/fed_isic2019/fedper/client.py +++ b/research/flamby/fed_isic2019/fedper/client.py @@ -14,8 +14,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.moon_client import MoonClient from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger @@ -35,7 +35,7 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, @@ -132,7 +132,9 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointModule(post_aggregation=BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name)) + checkpointer = ClientCheckpointAndStateModule( + 
post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) + ) client = FedIsic2019FedPerClient( data_path=Path(args.dataset_dir), diff --git a/research/flamby/fed_isic2019/fedprox/client.py b/research/flamby/fed_isic2019/fedprox/client.py index de53f6a67..54b661bbf 100644 --- a/research/flamby/fed_isic2019/fedprox/client.py +++ b/research/flamby/fed_isic2019/fedprox/client.py @@ -14,8 +14,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.fed_prox_client import FedProxClient from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BalancedAccuracy, Metric @@ -31,7 +31,7 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, @@ -113,7 +113,9 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointModule(post_aggregation=BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name)) + checkpointer = ClientCheckpointAndStateModule( + post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) + ) client = FedIsic2019FedProxClient( data_path=Path(args.dataset_dir), diff --git a/research/flamby/fed_isic2019/fedprox/server.py b/research/flamby/fed_isic2019/fedprox/server.py index 26863c194..b3a54eaca 100644 --- a/research/flamby/fed_isic2019/fedprox/server.py +++ b/research/flamby/fed_isic2019/fedprox/server.py @@ -9,7 +9,7 @@ from flwr.common.logger import log from flwr.server.client_manager import SimpleClientManager -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer from fl4health.servers.adaptive_constraint_servers.fedprox_server import FedProxServer from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint from fl4health.utils.config import load_config @@ -28,7 +28,7 @@ def main(config: Dict[str, Any], server_address: str, mu: float, checkpoint_stub checkpoint_dir = os.path.join(checkpoint_stub, run_name) checkpoint_name = "server_best_model.pkl" - checkpointer = BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name) + checkpointer = BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) client_manager = SimpleClientManager() model = Baseline() diff --git a/research/flamby/fed_isic2019/fenda/client.py b/research/flamby/fed_isic2019/fenda/client.py index c3c1c2656..a7e8133ca 100644 --- a/research/flamby/fed_isic2019/fenda/client.py +++ b/research/flamby/fed_isic2019/fenda/client.py @@ -14,8 +14,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, 
LatestTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.fenda_client import FendaClient from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BalancedAccuracy, Metric @@ -33,7 +33,7 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, @@ -133,11 +133,11 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" post_aggregation_checkpointer = ( - BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name) + BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) if federated_checkpointing - else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name) + else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) - checkpointer = ClientCheckpointModule(post_aggregation=post_aggregation_checkpointer) + checkpointer = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) client = FedIsic2019FendaClient( data_path=Path(args.dataset_dir), diff --git a/research/flamby/fed_isic2019/moon/client.py b/research/flamby/fed_isic2019/moon/client.py index 6a3e8dfdf..4c52f0229 100644 --- a/research/flamby/fed_isic2019/moon/client.py +++ b/research/flamby/fed_isic2019/moon/client.py @@ -14,8 +14,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.moon_client import MoonClient from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BalancedAccuracy, Metric @@ -34,7 +34,7 @@ def __init__( learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, contrastive_weight: float = 10, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, @@ -134,7 +134,9 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointModule(post_aggregation=BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name)) + checkpointer = ClientCheckpointAndStateModule( + post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) + ) client = FedIsic2019MoonClient( data_path=Path(args.dataset_dir), diff --git a/research/flamby/fed_isic2019/moon/server.py b/research/flamby/fed_isic2019/moon/server.py index efc7cb7f6..7e0d1bbdb 100644 --- a/research/flamby/fed_isic2019/moon/server.py +++ b/research/flamby/fed_isic2019/moon/server.py @@ -9,7 +9,7 @@ from flwr.server.client_manager import SimpleClientManager from flwr.server.strategy import FedAvg -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer from fl4health.utils.config import load_config from 
fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn from fl4health.utils.parameter_extraction import get_all_model_parameters @@ -29,7 +29,7 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_ checkpoint_dir = os.path.join(checkpoint_stub, run_name) checkpoint_name = "server_best_model.pkl" - checkpointer = BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name) + checkpointer = BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) client_manager = SimpleClientManager() model = FedIsic2019MoonModel() diff --git a/research/flamby/fed_isic2019/mr_mtl_mkmmd/client.py b/research/flamby/fed_isic2019/mr_mtl_mkmmd/client.py index 01368c84f..3e441dd58 100644 --- a/research/flamby/fed_isic2019/mr_mtl_mkmmd/client.py +++ b/research/flamby/fed_isic2019/mr_mtl_mkmmd/client.py @@ -14,8 +14,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.mkmmd_clients.mr_mtl_mkmmd_client import MrMtlMkMmdClient from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BalancedAccuracy, Metric @@ -41,7 +41,7 @@ def __init__( feature_l2_norm_weight: float = 1, mkmmd_loss_depth: int = 1, beta_global_update_interval: int = 20, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, @@ -171,7 +171,9 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointModule(post_aggregation=BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name)) + checkpointer = ClientCheckpointAndStateModule( + post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) + ) client = FedIsic2019MrMtlClient( data_path=Path(args.dataset_dir), diff --git a/research/flamby/fed_isic2019/perfcl/client.py b/research/flamby/fed_isic2019/perfcl/client.py index a6f7d0e80..84568fd66 100644 --- a/research/flamby/fed_isic2019/perfcl/client.py +++ b/research/flamby/fed_isic2019/perfcl/client.py @@ -14,8 +14,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.perfcl_client import PerFclClient from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BalancedAccuracy, Metric @@ -33,7 +33,7 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, mu: float = 10.0, gamma: float = 10.0, ) -> None: @@ -151,11 +151,11 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir 
= os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" post_aggregation_checkpointer = ( - BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name) + BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) if federated_checkpointing - else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name) + else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) - checkpointer = ClientCheckpointModule(post_aggregation=post_aggregation_checkpointer) + checkpointer = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) client = FedIsic2019PerFclClient( data_path=Path(args.dataset_dir), diff --git a/research/flamby/fed_isic2019/scaffold/client.py b/research/flamby/fed_isic2019/scaffold/client.py index 94a8ae8f7..66b6d8c79 100644 --- a/research/flamby/fed_isic2019/scaffold/client.py +++ b/research/flamby/fed_isic2019/scaffold/client.py @@ -14,8 +14,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.scaffold_client import ScaffoldClient from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BalancedAccuracy, Metric @@ -31,7 +31,7 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, @@ -113,7 +113,9 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointModule(post_aggregation=BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name)) + checkpointer = ClientCheckpointAndStateModule( + post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) + ) client = FedIsic2019ScaffoldClient( data_path=Path(args.dataset_dir), diff --git a/research/flamby/fed_isic2019/scaffold/server.py b/research/flamby/fed_isic2019/scaffold/server.py index 90417c906..6fe00bba5 100644 --- a/research/flamby/fed_isic2019/scaffold/server.py +++ b/research/flamby/fed_isic2019/scaffold/server.py @@ -8,12 +8,15 @@ from flamby.datasets.fed_isic2019 import Baseline from flwr.common.logger import log -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer +from fl4health.checkpointing.server_module import ScaffoldServerCheckpointAndStateModule from fl4health.client_managers.fixed_without_replacement_manager import FixedSamplingByFractionClientManager +from fl4health.parameter_exchange.packing_exchanger import FullParameterExchangerWithPacking +from fl4health.parameter_exchange.parameter_packer import ParameterPackerWithControlVariates +from fl4health.servers.scaffold_server import ScaffoldServer from fl4health.strategies.scaffold import Scaffold from fl4health.utils.config import load_config from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn -from 
research.flamby.flamby_servers.scaffold_server import ScaffoldServer from research.flamby.utils import fit_config, get_initial_model_info_with_control_variates @@ -29,11 +32,18 @@ def main( checkpoint_dir = os.path.join(checkpoint_stub, run_name) checkpoint_name = "server_best_model.pkl" - checkpointer = BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name) + checkpointer = BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) client_manager = FixedSamplingByFractionClientManager() model = Baseline() + model_size = len(model.state_dict()) + checkpoint_and_state_module = ScaffoldServerCheckpointAndStateModule( + model=model, + parameter_exchanger=FullParameterExchangerWithPacking(ParameterPackerWithControlVariates(model_size)), + model_checkpointers=checkpointer, + ) + initial_parameters, initial_control_variates = get_initial_model_info_with_control_variates(model) strategy = Scaffold( @@ -52,7 +62,10 @@ def main( ) server = ScaffoldServer( - client_manager=client_manager, fl_config=config, model=model, strategy=strategy, checkpointer=checkpointer + client_manager=client_manager, + fl_config=config, + strategy=strategy, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.server.start_server( diff --git a/research/flamby/fed_ixi/apfl/client.py b/research/flamby/fed_ixi/apfl/client.py index 4f1b49357..400b385ba 100644 --- a/research/flamby/fed_ixi/apfl/client.py +++ b/research/flamby/fed_ixi/apfl/client.py @@ -14,8 +14,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.apfl_client import ApflClient from fl4health.model_bases.apfl_base import ApflModule from fl4health.utils.losses import LossMeterType @@ -34,7 +34,7 @@ def __init__( learning_rate: float, alpha_learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, @@ -130,11 +130,11 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" post_aggregation_checkpointer = ( - BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name) + BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) if federated_checkpointing - else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name) + else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) - checkpointer = ClientCheckpointModule(post_aggregation=post_aggregation_checkpointer) + checkpointer = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) client = FedIxiApflClient( data_path=Path(args.dataset_dir), diff --git a/research/flamby/fed_ixi/ditto/client.py b/research/flamby/fed_ixi/ditto/client.py index 9656d7ef4..263fb1488 100644 --- a/research/flamby/fed_ixi/ditto/client.py +++ b/research/flamby/fed_ixi/ditto/client.py @@ -14,8 +14,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, 
LatestTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.ditto_client import DittoClient from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BinarySoftDiceCoefficient, Metric @@ -32,7 +32,7 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, @@ -139,11 +139,11 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" post_aggregation_checkpointer = ( - BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name) + BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) if federated_checkpointing - else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name) + else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) - checkpointer = ClientCheckpointModule(post_aggregation=post_aggregation_checkpointer) + checkpointer = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) client = FedIxiDittoClient( data_path=Path(args.dataset_dir), diff --git a/research/flamby/fed_ixi/fedadam/client.py b/research/flamby/fed_ixi/fedadam/client.py index e6c6f785e..d6ea8a8d1 100644 --- a/research/flamby/fed_ixi/fedadam/client.py +++ b/research/flamby/fed_ixi/fedadam/client.py @@ -14,8 +14,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.basic_client import BasicClient from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BinarySoftDiceCoefficient, Metric @@ -32,7 +32,7 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, @@ -114,7 +114,9 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointModule(post_aggregation=BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name)) + checkpointer = ClientCheckpointAndStateModule( + post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) + ) client = FedIxiFedAdamClient( data_path=Path(args.dataset_dir), diff --git a/research/flamby/fed_ixi/fedadam/server.py b/research/flamby/fed_ixi/fedadam/server.py index 58444b599..ca0e98c64 100644 --- a/research/flamby/fed_ixi/fedadam/server.py +++ b/research/flamby/fed_ixi/fedadam/server.py @@ -9,7 +9,7 @@ from flwr.server.client_manager import SimpleClientManager from flwr.server.strategy import FedAdam -from fl4health.checkpointing.checkpointer 
import BestLossTorchCheckpointer, LatestTorchCheckpointer +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer from fl4health.utils.config import load_config from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn from fl4health.utils.parameter_extraction import get_all_model_parameters @@ -33,9 +33,9 @@ def main( federated_checkpointing: bool = config.get("federated_checkpointing", True) log(INFO, f"Performing Federated Checkpointing: {federated_checkpointing}") checkpointer = ( - BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name) + BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) if federated_checkpointing - else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name) + else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) client_manager = SimpleClientManager() @@ -67,7 +67,7 @@ def main( ) if federated_checkpointing: - assert isinstance(checkpointer, BestLossTorchCheckpointer) + assert isinstance(checkpointer, BestLossTorchModuleCheckpointer) log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{checkpointer.best_score}") # Shutdown the server gracefully diff --git a/research/flamby/fed_ixi/fedavg/client.py b/research/flamby/fed_ixi/fedavg/client.py index 4fbd24695..5a56508f7 100644 --- a/research/flamby/fed_ixi/fedavg/client.py +++ b/research/flamby/fed_ixi/fedavg/client.py @@ -14,8 +14,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.basic_client import BasicClient from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BinarySoftDiceCoefficient, Metric @@ -31,7 +31,7 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, @@ -116,7 +116,9 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointModule(post_aggregation=BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name)) + checkpointer = ClientCheckpointAndStateModule( + post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) + ) client = FedIxiFedAvgClient( data_path=Path(args.dataset_dir), diff --git a/research/flamby/fed_ixi/fedavg/server.py b/research/flamby/fed_ixi/fedavg/server.py index 424f0c9fc..e9d7c0fbb 100644 --- a/research/flamby/fed_ixi/fedavg/server.py +++ b/research/flamby/fed_ixi/fedavg/server.py @@ -10,7 +10,7 @@ from flwr.server.client_manager import SimpleClientManager from flwr.server.strategy import FedAvg -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer from fl4health.utils.config import load_config from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, 
fit_metrics_aggregation_fn from fl4health.utils.parameter_extraction import get_all_model_parameters @@ -31,9 +31,9 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_ federated_checkpointing: bool = config.get("federated_checkpointing", True) log(INFO, f"Performing Federated Checkpointing: {federated_checkpointing}") checkpointer = ( - BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name) + BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) if federated_checkpointing - else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name) + else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) client_manager = SimpleClientManager() @@ -68,7 +68,7 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_ ) if federated_checkpointing: - assert isinstance(checkpointer, BestLossTorchCheckpointer) + assert isinstance(checkpointer, BestLossTorchModuleCheckpointer) log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{checkpointer.best_score}") # Shutdown the server gracefully diff --git a/research/flamby/fed_ixi/fedper/client.py b/research/flamby/fed_ixi/fedper/client.py index fb2c3aa78..4a953ccac 100644 --- a/research/flamby/fed_ixi/fedper/client.py +++ b/research/flamby/fed_ixi/fedper/client.py @@ -14,8 +14,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.moon_client import MoonClient from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger @@ -35,7 +35,7 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, @@ -140,11 +140,11 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" post_aggregation_checkpointer = ( - BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name) + BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) if federated_checkpointing - else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name) + else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) - checkpointer = ClientCheckpointModule(post_aggregation=post_aggregation_checkpointer) + checkpointer = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) client = FedIxiFedPerClient( data_path=Path(args.dataset_dir), diff --git a/research/flamby/fed_ixi/fedprox/client.py b/research/flamby/fed_ixi/fedprox/client.py index 26921012e..58d3e18f0 100644 --- a/research/flamby/fed_ixi/fedprox/client.py +++ b/research/flamby/fed_ixi/fedprox/client.py @@ -14,8 +14,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule 
+from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.fed_prox_client import FedProxClient from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BinarySoftDiceCoefficient, Metric @@ -31,7 +31,7 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, @@ -116,7 +116,9 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointModule(post_aggregation=BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name)) + checkpointer = ClientCheckpointAndStateModule( + post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) + ) client = FedIxiFedProxClient( data_path=Path(args.dataset_dir), diff --git a/research/flamby/fed_ixi/fedprox/server.py b/research/flamby/fed_ixi/fedprox/server.py index bbe7771ff..016a665ce 100644 --- a/research/flamby/fed_ixi/fedprox/server.py +++ b/research/flamby/fed_ixi/fedprox/server.py @@ -9,7 +9,7 @@ from flwr.common.logger import log from flwr.server.client_manager import SimpleClientManager -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer from fl4health.servers.adaptive_constraint_servers.fedprox_server import FedProxServer from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint from fl4health.utils.config import load_config @@ -31,9 +31,9 @@ def main(config: Dict[str, Any], server_address: str, mu: float, checkpoint_stub federated_checkpointing: bool = config.get("federated_checkpointing", True) log(INFO, f"Performing Federated Checkpointing: {federated_checkpointing}") checkpointer = ( - BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name) + BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) if federated_checkpointing - else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name) + else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) client_manager = SimpleClientManager() @@ -70,7 +70,7 @@ def main(config: Dict[str, Any], server_address: str, mu: float, checkpoint_stub ) if federated_checkpointing: - assert isinstance(checkpointer, BestLossTorchCheckpointer) + assert isinstance(checkpointer, BestLossTorchModuleCheckpointer) log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{checkpointer.best_score}") # Shutdown the server gracefully diff --git a/research/flamby/fed_ixi/fenda/client.py b/research/flamby/fed_ixi/fenda/client.py index 876fc600a..446bb5803 100644 --- a/research/flamby/fed_ixi/fenda/client.py +++ b/research/flamby/fed_ixi/fenda/client.py @@ -14,8 +14,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer +from 
fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.fenda_client import FendaClient from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BinarySoftDiceCoefficient, Metric @@ -33,7 +33,7 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, @@ -133,11 +133,11 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" post_aggregation_checkpointer = ( - BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name) + BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) if federated_checkpointing - else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name) + else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) - checkpointer = ClientCheckpointModule(post_aggregation=post_aggregation_checkpointer) + checkpointer = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) client = FedIxiFendaClient( data_path=Path(args.dataset_dir), diff --git a/research/flamby/fed_ixi/moon/client.py b/research/flamby/fed_ixi/moon/client.py index a3b5ac5e9..d758633cd 100644 --- a/research/flamby/fed_ixi/moon/client.py +++ b/research/flamby/fed_ixi/moon/client.py @@ -14,8 +14,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.moon_client import MoonClient from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BinarySoftDiceCoefficient, Metric @@ -34,7 +34,7 @@ def __init__( learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, contrastive_weight: float = 10, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, @@ -134,7 +134,9 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointModule(post_aggregation=BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name)) + checkpointer = ClientCheckpointAndStateModule( + post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) + ) client = FedIxiMoonClient( data_path=Path(args.dataset_dir), diff --git a/research/flamby/fed_ixi/moon/server.py b/research/flamby/fed_ixi/moon/server.py index cd381a2f6..e96830aaa 100644 --- a/research/flamby/fed_ixi/moon/server.py +++ b/research/flamby/fed_ixi/moon/server.py @@ -9,7 +9,7 @@ from flwr.server.client_manager import SimpleClientManager from flwr.server.strategy import FedAvg -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer from fl4health.utils.config import load_config from 
fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn from fl4health.utils.parameter_extraction import get_all_model_parameters @@ -32,9 +32,9 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_ federated_checkpointing: bool = config.get("federated_checkpointing", True) log(INFO, f"Performing Federated Checkpointing: {federated_checkpointing}") checkpointer = ( - BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name) + BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) if federated_checkpointing - else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name) + else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) client_manager = SimpleClientManager() @@ -66,7 +66,7 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_ ) if federated_checkpointing: - assert isinstance(checkpointer, BestLossTorchCheckpointer) + assert isinstance(checkpointer, BestLossTorchModuleCheckpointer) log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{checkpointer.best_score}") # Shutdown the server gracefully diff --git a/research/flamby/fed_ixi/perfcl/client.py b/research/flamby/fed_ixi/perfcl/client.py index e93098aa7..15ee18b40 100644 --- a/research/flamby/fed_ixi/perfcl/client.py +++ b/research/flamby/fed_ixi/perfcl/client.py @@ -14,8 +14,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.perfcl_client import PerFclClient from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BinarySoftDiceCoefficient, Metric @@ -33,7 +33,7 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, mu: float = 10.0, gamma: float = 10.0, ) -> None: @@ -151,11 +151,11 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" post_aggregation_checkpointer = ( - BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name) + BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) if federated_checkpointing - else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name) + else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) - checkpointer = ClientCheckpointModule(post_aggregation=post_aggregation_checkpointer) + checkpointer = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) client = FedIxiPerFclClient( data_path=Path(args.dataset_dir), diff --git a/research/flamby/fed_ixi/scaffold/client.py b/research/flamby/fed_ixi/scaffold/client.py index 4cb8a6aa5..d8c74ea62 100644 --- a/research/flamby/fed_ixi/scaffold/client.py +++ b/research/flamby/fed_ixi/scaffold/client.py @@ -14,8 +14,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer -from fl4health.checkpointing.client_module import 
ClientCheckpointModule +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.scaffold_client import ScaffoldClient from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BinarySoftDiceCoefficient, Metric @@ -31,7 +31,7 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, @@ -116,7 +116,9 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointModule(post_aggregation=BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name)) + checkpointer = ClientCheckpointAndStateModule( + post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) + ) client = FedIxiScaffoldClient( data_path=Path(args.dataset_dir), diff --git a/research/flamby/fed_ixi/scaffold/server.py b/research/flamby/fed_ixi/scaffold/server.py index b38f0e287..d5e1edf12 100644 --- a/research/flamby/fed_ixi/scaffold/server.py +++ b/research/flamby/fed_ixi/scaffold/server.py @@ -8,12 +8,15 @@ from flamby.datasets.fed_ixi import Baseline from flwr.common.logger import log -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer +from fl4health.checkpointing.server_module import ScaffoldServerCheckpointAndStateModule from fl4health.client_managers.fixed_without_replacement_manager import FixedSamplingByFractionClientManager +from fl4health.parameter_exchange.packing_exchanger import FullParameterExchangerWithPacking +from fl4health.parameter_exchange.parameter_packer import ParameterPackerWithControlVariates +from fl4health.servers.scaffold_server import ScaffoldServer from fl4health.strategies.scaffold import Scaffold from fl4health.utils.config import load_config from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn -from research.flamby.flamby_servers.scaffold_server import ScaffoldServer from research.flamby.utils import fit_config, get_initial_model_info_with_control_variates, summarize_model_info @@ -32,9 +35,9 @@ def main( federated_checkpointing: bool = config.get("federated_checkpointing", True) log(INFO, f"Performing Federated Checkpointing: {federated_checkpointing}") checkpointer = ( - BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name) + BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) if federated_checkpointing - else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name) + else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) client_manager = FixedSamplingByFractionClientManager() @@ -45,6 +48,13 @@ def main( model = Baseline(out_channels_first_layer=12) summarize_model_info(model) + model_size = len(model.state_dict()) + checkpoint_and_state_module = ScaffoldServerCheckpointAndStateModule( + model=model, + parameter_exchanger=FullParameterExchangerWithPacking(ParameterPackerWithControlVariates(model_size)), + model_checkpointers=checkpointer, + ) + initial_parameters, initial_control_variates = 
get_initial_model_info_with_control_variates(model) strategy = Scaffold( @@ -63,7 +73,10 @@ def main( ) server = ScaffoldServer( - client_manager=client_manager, fl_config=config, model=model, strategy=strategy, checkpointer=checkpointer + client_manager=client_manager, + fl_config=config, + strategy=strategy, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.server.start_server( @@ -73,7 +86,7 @@ def main( ) if federated_checkpointing: - assert isinstance(checkpointer, BestLossTorchCheckpointer) + assert isinstance(checkpointer, BestLossTorchModuleCheckpointer) log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{checkpointer.best_score}") # Shutdown the server gracefully diff --git a/research/flamby/flamby_servers/full_exchange_server.py b/research/flamby/flamby_servers/full_exchange_server.py index dd0dc6985..b204542cd 100644 --- a/research/flamby/flamby_servers/full_exchange_server.py +++ b/research/flamby/flamby_servers/full_exchange_server.py @@ -5,7 +5,7 @@ from flwr.server.client_manager import ClientManager from flwr.server.strategy import Strategy -from fl4health.checkpointing.checkpointer import TorchCheckpointer +from fl4health.checkpointing.checkpointer import TorchModuleCheckpointer from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.servers.base_server import FlServer @@ -17,7 +17,7 @@ def __init__( fl_config: Config, model: Optional[nn.Module] = None, strategy: Optional[Strategy] = None, - checkpointer: Optional[TorchCheckpointer] = None, + checkpointer: Optional[TorchModuleCheckpointer] = None, ) -> None: # To help with model rehydration parameter_exchanger = FullParameterExchanger() diff --git a/research/flamby/flamby_servers/scaffold_server.py b/research/flamby/flamby_servers/scaffold_server.py deleted file mode 100644 index 3521caea4..000000000 --- a/research/flamby/flamby_servers/scaffold_server.py +++ /dev/null @@ -1,52 +0,0 @@ -from typing import Optional - -import torch.nn as nn -from flwr.common.parameter import parameters_to_ndarrays -from flwr.common.typing import Config -from flwr.server.client_manager import ClientManager -from flwr.server.strategy import Strategy - -from fl4health.checkpointing.checkpointer import TorchCheckpointer -from fl4health.parameter_exchange.packing_exchanger import FullParameterExchangerWithPacking -from fl4health.parameter_exchange.parameter_packer import ParameterPackerWithControlVariates -from fl4health.servers.base_server import FlServer - - -class ScaffoldServer(FlServer): - def __init__( - self, - client_manager: ClientManager, - fl_config: Config, - model: Optional[nn.Module] = None, - strategy: Optional[Strategy] = None, - checkpointer: Optional[TorchCheckpointer] = None, - ) -> None: - assert model is not None - # To help with model rehydration - model_size = len(model.state_dict()) - parameter_exchanger = FullParameterExchangerWithPacking(ParameterPackerWithControlVariates(model_size)) - super().__init__( - client_manager=client_manager, - fl_config=fl_config, - parameter_exchanger=parameter_exchanger, - model=model, - strategy=strategy, - checkpointer=checkpointer, - ) - - def _hydrate_model_for_checkpointing(self) -> None: - assert self.server_model is not None, ( - "Model hydration has been called but no server_model is defined to hydrate. 
The functionality of " - "_hydrate_model_for_checkpointing can be overridden if checkpointing without a torch architecture is " - "possible and desired" - ) - assert self.parameter_exchanger is not None, ( - "Model hydration has been called but no parameter_exchanger is defined to hydrate. The functionality of " - "_hydrate_model_for_checkpointing can be overridden if checkpointing without a parameter exchanger is " - "possible and desired" - ) - packed_parameters = parameters_to_ndarrays(self.parameters) - # Don't need the control variates for checkpointing. - assert isinstance(self.parameter_exchanger, FullParameterExchangerWithPacking) - model_ndarrays, _ = self.parameter_exchanger.unpack_parameters(packed_parameters) - self.parameter_exchanger.pull_parameters(model_ndarrays, self.server_model) diff --git a/research/flamby/single_node_trainer.py b/research/flamby/single_node_trainer.py index 0b5bee501..e49e9108b 100644 --- a/research/flamby/single_node_trainer.py +++ b/research/flamby/single_node_trainer.py @@ -9,7 +9,7 @@ from torch.nn.modules.loss import _Loss from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer from fl4health.utils.metrics import MetricManager @@ -25,7 +25,7 @@ def __init__( checkpoint_dir = os.path.join(checkpoint_stub, run_name) # This is called the "server model" so that it can be found by the evaluate_on_holdout.py script checkpoint_name = "server_best_model.pkl" - self.checkpointer = BestLossTorchCheckpointer(checkpoint_dir, checkpoint_name) + self.checkpointer = BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) self.dataset_dir = dataset_dir self.model: nn.Module self.criterion: _Loss diff --git a/research/gemini/ditto/client.py b/research/gemini/ditto/client.py index 4f7cb305b..75b81f3f1 100644 --- a/research/gemini/ditto/client.py +++ b/research/gemini/ditto/client.py @@ -17,7 +17,7 @@ from torch.utils.data import DataLoader from utils.random import set_all_random_seeds -from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer, TorchCheckpointer +from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer, TorchModuleCheckpointer from fl4health.clients.ditto_client import DittoClient from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import Metric @@ -41,7 +41,7 @@ def __init__( checkpoint_stub: str, run_name: str = "", loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[TorchCheckpointer] = None, + checkpointer: Optional[TorchModuleCheckpointer] = None, ) -> None: # Checkpointing: create a string of the names of the hospitals self.hospitals = hospital_id diff --git a/research/gemini/fedper/client.py b/research/gemini/fedper/client.py index 3d1b82117..f3877ca81 100644 --- a/research/gemini/fedper/client.py +++ b/research/gemini/fedper/client.py @@ -21,7 +21,7 @@ from torch.utils.data import DataLoader from utils.random import set_all_random_seeds -from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer, TorchCheckpointer +from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer, TorchModuleCheckpointer from fl4health.clients.moon_client import MoonClient from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger @@ -42,7 +42,7 @@ def __init__( checkpoint_stub: str, 
run_name: str = "", loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[TorchCheckpointer] = None, + checkpointer: Optional[TorchModuleCheckpointer] = None, ) -> None: # Checkpointing: create a string of the names of the hospitals self.hospitals = hospital_id diff --git a/research/gemini/moon/client.py b/research/gemini/moon/client.py index f6cb82d92..576f8ddf7 100644 --- a/research/gemini/moon/client.py +++ b/research/gemini/moon/client.py @@ -22,7 +22,7 @@ from torch.utils.data import DataLoader from utils.random import set_all_random_seeds -from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer, TorchCheckpointer +from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer, TorchModuleCheckpointer from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import Metric from research.gemini.metrics.metrics import Accuracy, Binary_F1, Binary_ROC_AUC @@ -41,7 +41,7 @@ def __init__( run_name: str = "", loss_meter_type: LossMeterType = LossMeterType.AVERAGE, contrastive_weight: float = 10, - checkpointer: Optional[TorchCheckpointer] = None, + checkpointer: Optional[TorchModuleCheckpointer] = None, ) -> None: # Checkpointing: create a string of the names of the hospitals self.hospitals = hospital_id diff --git a/research/gemini/moon/server.py b/research/gemini/moon/server.py index 0bf3b1762..8133da667 100644 --- a/research/gemini/moon/server.py +++ b/research/gemini/moon/server.py @@ -18,7 +18,7 @@ from servers.full_exchange_server import FullExchangeServer from utils.random import set_all_random_seeds -from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer, LatestTorchCheckpointer +from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer, LatestTorchModuleCheckpointer from fl4health.utils.config import load_config from research.gemini.simple_metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn @@ -58,7 +58,7 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_ checkpointer = ( BestMetricTorchCheckpointer(checkpoint_dir, checkpoint_name) if federated_checkpointing - else LatestTorchCheckpointer(checkpoint_dir, checkpoint_name) + else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) client_manager = SimpleClientManager() diff --git a/research/gemini/perfcl/client.py b/research/gemini/perfcl/client.py index b69e0bd9b..841baf4a8 100644 --- a/research/gemini/perfcl/client.py +++ b/research/gemini/perfcl/client.py @@ -21,7 +21,7 @@ from torch.utils.data import DataLoader from utils.random import set_all_random_seeds -from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer, TorchCheckpointer +from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer, TorchModuleCheckpointer from fl4health.clients.fenda_client import FendaClient from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import Metric @@ -40,7 +40,7 @@ def __init__( checkpoint_stub: str, run_name: str = "", loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[TorchCheckpointer] = None, + checkpointer: Optional[TorchModuleCheckpointer] = None, extra_loss_weights: Tuple[float, float] = (10, 10), ) -> None: # Checkpointing: create a string of the names of the hospitals diff --git a/research/gemini/servers/full_exchange_server.py b/research/gemini/servers/full_exchange_server.py index 374f30fd4..47ab4d2a5 100644 --- 
a/research/gemini/servers/full_exchange_server.py +++ b/research/gemini/servers/full_exchange_server.py @@ -4,7 +4,7 @@ from flwr.server.client_manager import ClientManager from flwr.server.strategy import Strategy -from fl4health.checkpointing.checkpointer import TorchCheckpointer +from fl4health.checkpointing.checkpointer import TorchModuleCheckpointer from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.servers.base_server import FlServerWithCheckpointing @@ -15,7 +15,7 @@ def __init__( client_manager: ClientManager, model: nn.Module, strategy: Optional[Strategy] = None, - checkpointer: Optional[TorchCheckpointer] = None, + checkpointer: Optional[TorchModuleCheckpointer] = None, ) -> None: # To help with model rehydration parameter_exchanger = FullParameterExchanger() diff --git a/research/picai/fedavg/client.py b/research/picai/fedavg/client.py index 11a88c21b..39758d83b 100644 --- a/research/picai/fedavg/client.py +++ b/research/picai/fedavg/client.py @@ -13,7 +13,7 @@ from torch.optim import Optimizer from torchmetrics.classification import MultilabelAveragePrecision -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.basic_client import BasicClient from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType @@ -38,7 +38,7 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, intermediate_client_state_dir: Optional[Path] = None, diff --git a/research/picai/reporting/server.py b/research/picai/reporting/server.py index bc87a11b0..987db9a5f 100644 --- a/research/picai/reporting/server.py +++ b/research/picai/reporting/server.py @@ -9,7 +9,7 @@ from examples.models.cnn_model import Net from examples.utils.functions import make_dict_with_epochs_or_steps -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.reporting import WandBReporter from fl4health.servers.base_server import FlServer @@ -45,8 +45,8 @@ def main(config: Dict[str, Any]) -> None: # To facilitate checkpointing parameter_exchanger = FullParameterExchanger() checkpointers = [ - BestLossTorchCheckpointer(config["checkpoint_path"], "best_model.pkl"), - LatestTorchCheckpointer(config["checkpoint_path"], "latest_model.pkl"), + BestLossTorchModuleCheckpointer(config["checkpoint_path"], "best_model.pkl"), + LatestTorchModuleCheckpointer(config["checkpoint_path"], "latest_model.pkl"), ] # Server performs simple FedAveraging as its server-side optimization strategy diff --git a/research/picai/single_node_trainer.py b/research/picai/single_node_trainer.py index e884a5224..a75b81a58 100644 --- a/research/picai/single_node_trainer.py +++ b/research/picai/single_node_trainer.py @@ -11,7 +11,7 @@ from torch.nn.modules.loss import _Loss from torch.optim import Optimizer -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, PerRoundCheckpointer +from fl4health.checkpointing.checkpointer import 
BestLossTorchModuleCheckpointer, PerRoundStateCheckpointer from fl4health.utils.metrics import MetricManager @@ -34,9 +34,9 @@ def __init__( os.mkdir(checkpoint_dir) per_round_checkpoint_name = "ckpt.pkl" - self.per_epoch_checkpointer = PerRoundCheckpointer(Path(checkpoint_dir), Path(per_round_checkpoint_name)) + self.per_epoch_checkpointer = PerRoundStateCheckpointer(Path(checkpoint_dir), Path(per_round_checkpoint_name)) best_metric_checkpoint_name = "best_ckpt.pkl" - self.checkpointer = BestLossTorchCheckpointer(checkpoint_dir, best_metric_checkpoint_name) + self.checkpointer = BestLossTorchModuleCheckpointer(checkpoint_dir, best_metric_checkpoint_name) self.train_loader = train_loader self.val_loader = val_loader diff --git a/tests/checkpointing/test_best_checkpointer.py b/tests/checkpointing/test_best_checkpointer.py index c98dccd71..fe1088ed7 100644 --- a/tests/checkpointing/test_best_checkpointer.py +++ b/tests/checkpointing/test_best_checkpointer.py @@ -1,8 +1,8 @@ -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer def test_best_metric_checkpointer() -> None: - best_loss_checkpointer = BestLossTorchCheckpointer("", "") + best_loss_checkpointer = BestLossTorchModuleCheckpointer("", "") # First checkpoint should happen since the best metric is None none_checkpoint = best_loss_checkpointer._should_checkpoint(0.95) assert none_checkpoint diff --git a/tests/checkpointing/test_client_module.py b/tests/checkpointing/test_client_module.py index 22f029144..4cf1a7439 100644 --- a/tests/checkpointing/test_client_module.py +++ b/tests/checkpointing/test_client_module.py @@ -4,8 +4,12 @@ import pytest import torch -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer, TorchCheckpointer -from fl4health.checkpointing.client_module import CheckpointMode, ClientCheckpointModule +from fl4health.checkpointing.checkpointer import ( + BestLossTorchModuleCheckpointer, + LatestTorchModuleCheckpointer, + TorchModuleCheckpointer, +) +from fl4health.checkpointing.client_module import CheckpointMode, ClientCheckpointAndStateModule from fl4health.checkpointing.opacus_checkpointer import BestLossOpacusCheckpointer from fl4health.utils.privacy_utilities import convert_model_to_opacus_model from tests.test_utils.models_for_test import LinearTransform @@ -16,7 +20,7 @@ def test_client_checkpointer_module_opacus(tmp_path: Path) -> None: checkpoint_dir.mkdir() pre_aggregation_checkpointer = BestLossOpacusCheckpointer(str(checkpoint_dir), "pre_agg.pkl") post_aggregation_checkpointer = BestLossOpacusCheckpointer(str(checkpoint_dir), "post_agg.pkl") - checkpointer = ClientCheckpointModule( + checkpointer = ClientCheckpointAndStateModule( pre_aggregation=pre_aggregation_checkpointer, post_aggregation=post_aggregation_checkpointer ) @@ -59,9 +63,9 @@ def test_client_checkpointer_module_opacus(tmp_path: Path) -> None: def test_client_checkpointer_module(tmp_path: Path) -> None: checkpoint_dir = tmp_path.joinpath("resources") checkpoint_dir.mkdir() - pre_aggregation_checkpointer = BestLossTorchCheckpointer(str(checkpoint_dir), "pre_agg.pkl") - post_aggregation_checkpointer = BestLossTorchCheckpointer(str(checkpoint_dir), "post_agg.pkl") - checkpointer = ClientCheckpointModule( + pre_aggregation_checkpointer = BestLossTorchModuleCheckpointer(str(checkpoint_dir), "pre_agg.pkl") + post_aggregation_checkpointer = BestLossTorchModuleCheckpointer(str(checkpoint_dir), 
"post_agg.pkl") + checkpointer = ClientCheckpointAndStateModule( pre_aggregation=pre_aggregation_checkpointer, post_aggregation=post_aggregation_checkpointer ) @@ -73,8 +77,8 @@ def test_client_checkpointer_module(tmp_path: Path) -> None: assert checkpointer.pre_aggregation is not None assert checkpointer.post_aggregation is not None - loaded_pre_model = pre_aggregation_checkpointer.load_best_checkpoint() - loaded_post_model = post_aggregation_checkpointer.load_best_checkpoint() + loaded_pre_model = pre_aggregation_checkpointer.load_checkpoint() + loaded_post_model = post_aggregation_checkpointer.load_checkpoint() assert isinstance(loaded_pre_model, LinearTransform) # pre aggregation model should be the same as model_1 @@ -86,8 +90,8 @@ def test_client_checkpointer_module(tmp_path: Path) -> None: checkpointer.maybe_checkpoint(model_2, 0.68, {"test_1": 1.0}, CheckpointMode.PRE_AGGREGATION) checkpointer.maybe_checkpoint(model_1, 0.68, {"test_1": 1.0}, CheckpointMode.POST_AGGREGATION) - loaded_pre_model = pre_aggregation_checkpointer.load_best_checkpoint() - loaded_post_model = post_aggregation_checkpointer.load_best_checkpoint() + loaded_pre_model = pre_aggregation_checkpointer.load_checkpoint() + loaded_post_model = post_aggregation_checkpointer.load_checkpoint() assert isinstance(loaded_pre_model, LinearTransform) # pre aggregation model should be the same as model_1 @@ -101,12 +105,12 @@ def test_client_checkpointer_module(tmp_path: Path) -> None: def test_client_checkpointer_module_with_sequence_of_checkpointers(tmp_path: Path) -> None: checkpoint_dir = tmp_path.joinpath("resources") checkpoint_dir.mkdir() - pre_aggregation_checkpointer: List[TorchCheckpointer] = [ - BestLossTorchCheckpointer(str(checkpoint_dir), "pre_agg_best.pkl"), - LatestTorchCheckpointer(str(checkpoint_dir), "pre_agg_latest.pkl"), + pre_aggregation_checkpointer: List[TorchModuleCheckpointer] = [ + BestLossTorchModuleCheckpointer(str(checkpoint_dir), "pre_agg_best.pkl"), + LatestTorchModuleCheckpointer(str(checkpoint_dir), "pre_agg_latest.pkl"), ] - post_aggregation_checkpointer = BestLossTorchCheckpointer(str(checkpoint_dir), "post_agg.pkl") - checkpoint_module = ClientCheckpointModule( + post_aggregation_checkpointer = BestLossTorchModuleCheckpointer(str(checkpoint_dir), "post_agg.pkl") + checkpoint_module = ClientCheckpointAndStateModule( pre_aggregation=pre_aggregation_checkpointer, post_aggregation=post_aggregation_checkpointer ) @@ -119,9 +123,9 @@ def test_client_checkpointer_module_with_sequence_of_checkpointers(tmp_path: Pat assert checkpoint_module.pre_aggregation is not None assert checkpoint_module.post_aggregation is not None - loaded_pre_model_best = pre_aggregation_checkpointer[0].load_best_checkpoint() - loaded_pre_model_latest = pre_aggregation_checkpointer[1].load_best_checkpoint() - loaded_post_model = post_aggregation_checkpointer.load_best_checkpoint() + loaded_pre_model_best = pre_aggregation_checkpointer[0].load_checkpoint() + loaded_pre_model_latest = pre_aggregation_checkpointer[1].load_checkpoint() + loaded_post_model = post_aggregation_checkpointer.load_checkpoint() assert isinstance(loaded_pre_model_best, LinearTransform) # pre aggregation model should be the same as model_1 @@ -137,9 +141,9 @@ def test_client_checkpointer_module_with_sequence_of_checkpointers(tmp_path: Pat checkpoint_module.maybe_checkpoint(model_2, 0.88, {"test_1": 1.0}, CheckpointMode.PRE_AGGREGATION) checkpoint_module.maybe_checkpoint(model_1, 0.68, {"test_1": 1.0}, CheckpointMode.POST_AGGREGATION) - 
loaded_pre_model_best = pre_aggregation_checkpointer[0].load_best_checkpoint() - loaded_pre_model_latest = pre_aggregation_checkpointer[1].load_best_checkpoint() - loaded_post_model = post_aggregation_checkpointer.load_best_checkpoint() + loaded_pre_model_best = pre_aggregation_checkpointer[0].load_checkpoint() + loaded_pre_model_latest = pre_aggregation_checkpointer[1].load_checkpoint() + loaded_post_model = post_aggregation_checkpointer.load_checkpoint() assert isinstance(loaded_pre_model_best, LinearTransform) # pre aggregation model should be the same as model_1 since the metric isn't better than the previous one @@ -158,23 +162,23 @@ def test_path_duplication_check(tmp_path: Path) -> None: checkpoint_dir = tmp_path.joinpath("resources") checkpoint_dir.mkdir() pre_aggregation_checkpointer = [ - BestLossTorchCheckpointer(str(checkpoint_dir), "pre_agg_best.pkl"), - LatestTorchCheckpointer(str(checkpoint_dir), "pre_agg_best.pkl"), + BestLossTorchModuleCheckpointer(str(checkpoint_dir), "pre_agg_best.pkl"), + LatestTorchModuleCheckpointer(str(checkpoint_dir), "pre_agg_best.pkl"), ] - post_aggregation_checkpointer = BestLossTorchCheckpointer(str(checkpoint_dir), "post_agg.pkl") + post_aggregation_checkpointer = BestLossTorchModuleCheckpointer(str(checkpoint_dir), "post_agg.pkl") # We have duplicate names, so we want to raise an error to prevent data loss/checkpoint overwrites with pytest.raises(ValueError): - ClientCheckpointModule( + ClientCheckpointAndStateModule( pre_aggregation=pre_aggregation_checkpointer, post_aggregation=post_aggregation_checkpointer ) pre_aggregation_checkpointer = [ - BestLossTorchCheckpointer(str(checkpoint_dir), "pre_agg_best.pkl"), - LatestTorchCheckpointer(str(checkpoint_dir), "pre_agg_latest.pkl"), + BestLossTorchModuleCheckpointer(str(checkpoint_dir), "pre_agg_best.pkl"), + LatestTorchModuleCheckpointer(str(checkpoint_dir), "pre_agg_latest.pkl"), ] - post_aggregation_checkpointer = BestLossTorchCheckpointer(str(checkpoint_dir), "pre_agg_best.pkl") + post_aggregation_checkpointer = BestLossTorchModuleCheckpointer(str(checkpoint_dir), "pre_agg_best.pkl") # We have duplicate names, so we want to raise an error to prevent data loss/checkpoint overwrites with pytest.raises(ValueError): - ClientCheckpointModule( + ClientCheckpointAndStateModule( pre_aggregation=pre_aggregation_checkpointer, post_aggregation=post_aggregation_checkpointer ) diff --git a/tests/checkpointing/test_function_checkpointer.py b/tests/checkpointing/test_function_checkpointer.py index 9e348d007..cf3cdb28f 100644 --- a/tests/checkpointing/test_function_checkpointer.py +++ b/tests/checkpointing/test_function_checkpointer.py @@ -2,7 +2,7 @@ from flwr.common.typing import Scalar -from fl4health.checkpointing.checkpointer import FunctionTorchCheckpointer +from fl4health.checkpointing.checkpointer import FunctionTorchModuleCheckpointer def score_function(_: float, metrics: Dict[str, Scalar]) -> float: @@ -15,7 +15,7 @@ def score_function(_: float, metrics: Dict[str, Scalar]) -> float: def test_function_checkpointer() -> None: - function_checkpointer = FunctionTorchCheckpointer("", "", score_function, maximize=True) + function_checkpointer = FunctionTorchModuleCheckpointer("", "", score_function, maximize=True) loss_1, loss_2 = 1.0, 0.9 metrics_1: Dict[str, Scalar] = {"accuracy": 0.87, "precision": 0.67, "f1": 0.76} metrics_2: Dict[str, Scalar] = {"accuracy": 0.87, "precision": 0.9, "f1": 0.6} diff --git a/tests/checkpointing/test_opacus_checkpointers.py 
b/tests/checkpointing/test_opacus_checkpointers.py index e58104e08..ae40f92ca 100644 --- a/tests/checkpointing/test_opacus_checkpointers.py +++ b/tests/checkpointing/test_opacus_checkpointers.py @@ -5,7 +5,7 @@ import torch from flwr.common.typing import Scalar -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer from fl4health.checkpointing.opacus_checkpointer import ( BestLossOpacusCheckpointer, LatestOpacusCheckpointer, @@ -37,7 +37,7 @@ def test_save_and_load_best_loss_checkpoint(tmp_path: Path) -> None: # Should throw a not implemented error with pytest.raises(NotImplementedError): - _ = checkpointer.load_best_checkpoint() + _ = checkpointer.load_checkpoint() checkpointer.load_best_checkpoint_into_model(target_model) @@ -64,7 +64,7 @@ def test_save_and_load_latest_checkpoint(tmp_path: Path) -> None: # Should throw a not implemented error with pytest.raises(NotImplementedError): - _ = checkpointer.load_best_checkpoint() + _ = checkpointer.load_checkpoint() checkpointer.load_best_checkpoint_into_model(target_model) assert isinstance(target_model, LinearTransform) @@ -109,7 +109,7 @@ def test_save_and_load_function_checkpoint(tmp_path: Path) -> None: # Should throw a not implemented error with pytest.raises(NotImplementedError): - _ = opacus_checkpointer.load_best_checkpoint() + _ = opacus_checkpointer.load_checkpoint() opacus_checkpointer.load_best_checkpoint_into_model(target_model) assert isinstance(target_model, LinearTransform) @@ -143,7 +143,7 @@ def test_fix_of_loss_stateless_model_exception(tmp_path: Path) -> None: model = create_opacus_model_via_functorch(model) opacus_target_model = convert_model_to_opacus_model(opacus_target_model) - torch_checkpointer = BestLossTorchCheckpointer(str(checkpoint_dir), checkpoint_name) + torch_checkpointer = BestLossTorchModuleCheckpointer(str(checkpoint_dir), checkpoint_name) # This should throw an error along the lines of # AttributeError: Can't pickle local object 'vmap..wrapped' with pytest.raises(AttributeError) as attribute_exception: diff --git a/tests/checkpointing/test_per_round_checkpointer.py b/tests/checkpointing/test_per_round_checkpointer.py index 5f81adbcb..77ab6d733 100644 --- a/tests/checkpointing/test_per_round_checkpointer.py +++ b/tests/checkpointing/test_per_round_checkpointer.py @@ -4,7 +4,7 @@ import torch from torch.optim import Optimizer -from fl4health.checkpointing.checkpointer import PerRoundCheckpointer +from fl4health.checkpointing.checkpointer import PerRoundStateCheckpointer from tests.test_utils.models_for_test import LinearModel @@ -12,7 +12,7 @@ def test_per_round_checkpointer() -> None: model: torch.nn.Module = LinearModel() optimizer: Optimizer = torch.optim.SGD(model.parameters(), lr=0.01) with tempfile.TemporaryDirectory() as results_dir: - checkpointer = PerRoundCheckpointer(checkpoint_dir=Path(results_dir), checkpoint_name=Path("ckpt.pt")) + checkpointer = PerRoundStateCheckpointer(checkpoint_dir=Path(results_dir), checkpoint_name=Path("ckpt.pt")) assert not checkpointer.checkpoint_exists() diff --git a/tests/checkpointing/test_save_load.py b/tests/checkpointing/test_save_load.py index c50f60622..d31043abb 100644 --- a/tests/checkpointing/test_save_load.py +++ b/tests/checkpointing/test_save_load.py @@ -5,9 +5,9 @@ from flwr.common.typing import Scalar from fl4health.checkpointing.checkpointer import ( - BestLossTorchCheckpointer, - FunctionTorchCheckpointer, - LatestTorchCheckpointer, + 
BestLossTorchModuleCheckpointer, + FunctionTorchModuleCheckpointer, + LatestTorchModuleCheckpointer, ) from tests.test_utils.models_for_test import LinearTransform @@ -19,14 +19,14 @@ def test_save_and_load_best_loss_checkpoint(tmp_path: Path) -> None: model_1 = LinearTransform() model_2 = LinearTransform() - checkpointer = BestLossTorchCheckpointer(str(checkpoint_dir), "best_model.pkl") + checkpointer = BestLossTorchModuleCheckpointer(str(checkpoint_dir), "best_model.pkl") checkpointer.maybe_checkpoint(model_1, 1.23, {"test": 1.2}) checkpointer.maybe_checkpoint(model_2, 0.98, {"test": 1.2}) # Correct metric saved. assert checkpointer.best_score == 0.98 - loaded_model = checkpointer.load_best_checkpoint() + loaded_model = checkpointer.load_checkpoint() assert isinstance(loaded_model, LinearTransform) # Correct loading tensors of the second model with better loss value assert torch.equal(model_2.linear.weight, loaded_model.linear.weight) @@ -39,11 +39,11 @@ def test_save_and_load_latest_checkpoint(tmp_path: Path) -> None: model_1 = LinearTransform() model_2 = LinearTransform() - checkpointer = LatestTorchCheckpointer(str(checkpoint_dir), "best_model.pkl") + checkpointer = LatestTorchModuleCheckpointer(str(checkpoint_dir), "best_model.pkl") checkpointer.maybe_checkpoint(model_2, 0.7, {"test": 1.2}) checkpointer.maybe_checkpoint(model_1, 0.6, {"test": 1.2}) - loaded_model = checkpointer.load_best_checkpoint() + loaded_model = checkpointer.load_checkpoint() assert isinstance(loaded_model, LinearTransform) # Correct loading tensors of the first model since each should be saved and model_1 is the "latest" assert torch.equal(model_1.linear.weight, loaded_model.linear.weight) @@ -66,7 +66,7 @@ def test_save_and_load_function_checkpoint(tmp_path: Path) -> None: model_1 = LinearTransform() model_2 = LinearTransform() - function_checkpointer = FunctionTorchCheckpointer( + function_checkpointer = FunctionTorchModuleCheckpointer( str(checkpoint_dir), checkpoint_name, score_function, maximize=True ) loss_1, loss_2 = 1.0, 0.9 @@ -81,7 +81,7 @@ def test_save_and_load_function_checkpoint(tmp_path: Path) -> None: # Should be true since the average of accuracy and precision provided in the dictionary is larger than 0.85 function_checkpointer.maybe_checkpoint(model_2, loss_2, metrics_2) - loaded_model = function_checkpointer.load_best_checkpoint() + loaded_model = function_checkpointer.load_checkpoint() assert isinstance(loaded_model, LinearTransform) # Correct loading tensors of the first model since each should be saved and model_1 is the "latest" assert torch.equal(model_2.linear.weight, loaded_model.linear.weight) diff --git a/tests/preprocessing/test_warm_up_module.py b/tests/preprocessing/test_warm_up_module.py index 45406c870..e90369c13 100644 --- a/tests/preprocessing/test_warm_up_module.py +++ b/tests/preprocessing/test_warm_up_module.py @@ -3,7 +3,7 @@ import os from pathlib import Path -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer from fl4health.model_bases.apfl_base import ApflModule from fl4health.model_bases.fenda_base import FendaModel from fl4health.model_bases.moon_base import MoonModel @@ -18,7 +18,7 @@ def test_initializing_warm_up_module(tmp_path: Path) -> None: # Save a temporary model using checkpointer saved_model = SmallCnn() - checkpointer = BestLossTorchCheckpointer(str(checkpoint_dir), "best_model.pkl") + checkpointer = BestLossTorchModuleCheckpointer(str(checkpoint_dir), 
"best_model.pkl") checkpointer.maybe_checkpoint(saved_model, 0.7, {}) # Save a temporary weights mapping dict diff --git a/tests/servers/test_base_server.py b/tests/servers/test_base_server.py index e0de78b5c..4538a0bc1 100644 --- a/tests/servers/test_base_server.py +++ b/tests/servers/test_base_server.py @@ -12,7 +12,7 @@ from flwr.server.strategy import FedAvg from freezegun import freeze_time -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer from fl4health.client_managers.base_sampling_manager import SimpleClientManager from fl4health.client_managers.poisson_sampling_manager import PoissonSamplingClientManager from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger @@ -37,7 +37,7 @@ def test_hydration_no_model_with_checkpointer(tmp_path: Path) -> None: # Temporary path to write pkl to, will be cleaned up at the end of the test. checkpoint_dir = tmp_path.joinpath("resources") checkpoint_dir.mkdir() - checkpointer = BestLossTorchCheckpointer(str(checkpoint_dir), "best_model.pkl") + checkpointer = BestLossTorchModuleCheckpointer(str(checkpoint_dir), "best_model.pkl") # Checkpointer is defined but there is no server-side model defined to produce a model from the server state. # An assertion error should be throw stating this @@ -53,7 +53,7 @@ def test_hydration_no_exchanger_with_checkpointer(tmp_path: Path) -> None: # Temporary path to write pkl to, will be cleaned up at the end of the test. checkpoint_dir = tmp_path.joinpath("resources") checkpoint_dir.mkdir() - checkpointer = BestLossTorchCheckpointer(str(checkpoint_dir), "best_model.pkl") + checkpointer = BestLossTorchModuleCheckpointer(str(checkpoint_dir), "best_model.pkl") # Checkpointer is defined but there is no parameter exchanger defined to produce a model from the server state. # An assertion error should be throw stating this @@ -79,7 +79,7 @@ def test_hydration_and_checkpointer(tmp_path: Path) -> None: # Temporary path to write pkl to, will be cleaned up at the end of the test. checkpoint_dir = tmp_path.joinpath("resources") checkpoint_dir.mkdir() - checkpointer = BestLossTorchCheckpointer(str(checkpoint_dir), "best_model.pkl") + checkpointer = BestLossTorchModuleCheckpointer(str(checkpoint_dir), "best_model.pkl") # Server-side hydration to convert server state to model and checkpointing behavior are both defined, a model # should be saved and be loaded successfully. @@ -87,7 +87,7 @@ def test_hydration_and_checkpointer(tmp_path: Path) -> None: client_manager=PoissonSamplingClientManager(), fl_config={}, checkpointer=checkpointer ) fl_server_both._maybe_checkpoint(1.0, {}, server_round=5) - loaded_model = checkpointer.load_best_checkpoint() + loaded_model = checkpointer.load_checkpoint() assert isinstance(loaded_model, LinearTransform) # Correct loading tensors of the saved model assert torch.equal(model.linear.weight, loaded_model.linear.weight) @@ -97,7 +97,7 @@ def test_fl_server_with_checkpointing(tmp_path: Path) -> None: # Temporary path to write pkl to, will be cleaned up at the end of the test. 
checkpoint_dir = tmp_path.joinpath("resources") checkpoint_dir.mkdir() - checkpointer = BestLossTorchCheckpointer(str(checkpoint_dir), "best_model.pkl") + checkpointer = BestLossTorchModuleCheckpointer(str(checkpoint_dir), "best_model.pkl") # Initial model held by server initial_model = LinearTransform() # represents the model computed by the clients aggregation @@ -116,7 +116,7 @@ def test_fl_server_with_checkpointing(tmp_path: Path) -> None: server.parameters = ndarrays_to_parameters(parameter_exchanger.push_parameters(updated_model)) server._maybe_checkpoint(1.0, {}, server_round=5) - loaded_model = checkpointer.load_best_checkpoint() + loaded_model = checkpointer.load_checkpoint() assert isinstance(loaded_model, LinearTransform) # Correct loading tensors of the saved model assert torch.equal(updated_model.linear.weight, loaded_model.linear.weight) diff --git a/tests/smoke_tests/load_from_checkpoint_example/client.py b/tests/smoke_tests/load_from_checkpoint_example/client.py index 132e0e000..662195ba5 100644 --- a/tests/smoke_tests/load_from_checkpoint_example/client.py +++ b/tests/smoke_tests/load_from_checkpoint_example/client.py @@ -11,7 +11,7 @@ from torch.utils.data import DataLoader from examples.models.cnn_model import Net -from fl4health.checkpointing.client_module import ClientCheckpointModule +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.basic_client import BasicClient from fl4health.reporting import JsonReporter from fl4health.reporting.base_reporter import BaseReporter @@ -29,7 +29,7 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointModule] = None, + checkpointer: Optional[ClientCheckpointAndStateModule] = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, intermediate_client_state_dir: Optional[Path] = None, diff --git a/tests/smoke_tests/load_from_checkpoint_example/server.py b/tests/smoke_tests/load_from_checkpoint_example/server.py index ab5615c2e..c47ab58bc 100644 --- a/tests/smoke_tests/load_from_checkpoint_example/server.py +++ b/tests/smoke_tests/load_from_checkpoint_example/server.py @@ -10,7 +10,7 @@ from examples.models.cnn_model import Net from examples.utils.functions import make_dict_with_epochs_or_steps -from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.reporting import JsonReporter from fl4health.servers.base_server import FlServer @@ -47,8 +47,8 @@ def main(config: Dict[str, Any], intermediate_server_state_dir: str, server_name # To facilitate checkpointing parameter_exchanger = FullParameterExchanger() checkpointers = [ - BestLossTorchCheckpointer(config["checkpoint_path"], "best_model.pkl"), - LatestTorchCheckpointer(config["checkpoint_path"], "latest_model.pkl"), + BestLossTorchModuleCheckpointer(config["checkpoint_path"], "best_model.pkl"), + LatestTorchModuleCheckpointer(config["checkpoint_path"], "latest_model.pkl"), ] # Server performs simple FedAveraging as its server-side optimization strategy From 12ff1bb7b1394ab60cc283cde674b1a6f1766b10 Mon Sep 17 00:00:00 2001 From: David Emerson <43939939+emersodb@users.noreply.github.com> Date: Tue, 26 Nov 2024 12:39:42 -0500 Subject: [PATCH 03/13] Committing 
an initial full migration to the state checkpointing modules --- .../ae_examples/fedprox_vae_example/server.py | 13 +- examples/basic_example/server.py | 8 +- .../docker_basic_example/fl_client/client.py | 24 +- examples/fedopt_example/client.py | 4 +- .../fedpca_examples/dim_reduction/server.py | 8 +- .../fedsimclr_finetuning_example/server.py | 8 +- examples/nnunet_example/client.py | 14 +- .../server.py | 4 +- .../warm_up_example/fedavg_warm_up/client.py | 11 +- fl4health/checkpointing/server_module.py | 311 ++++++++++++++++-- .../adaptive_drift_constraint_client.py | 26 +- fl4health/clients/apfl_client.py | 39 ++- fl4health/clients/basic_client.py | 3 +- fl4health/clients/clipping_client.py | 39 ++- fl4health/clients/constrained_fenda_client.py | 28 +- .../deep_mmd_clients/ditto_deep_mmd_client.py | 33 +- fl4health/clients/ditto_client.py | 38 ++- fl4health/clients/ensemble_client.py | 25 +- fl4health/clients/evaluate_client.py | 29 +- fl4health/clients/fed_pca_client.py | 6 + fl4health/clients/fedper_client.py | 8 + fl4health/clients/fedpm_client.py | 35 +- fl4health/clients/fedrep_client.py | 41 ++- fl4health/clients/fenda_client.py | 51 ++- fl4health/clients/fenda_ditto_client.py | 25 +- fl4health/clients/flash_client.py | 35 +- fl4health/clients/instance_level_dp_client.py | 34 +- .../mkmmd_clients/ditto_mkmmd_client.py | 25 +- .../mkmmd_clients/mr_mtl_mkmmd_client.py | 25 +- fl4health/clients/model_merge_client.py | 4 +- fl4health/clients/moon_client.py | 36 +- fl4health/clients/mr_mtl_client.py | 24 +- fl4health/clients/nnunet_client.py | 116 +++---- .../clients/partial_weight_exchange_client.py | 24 +- fl4health/clients/perfcl_client.py | 36 +- fl4health/clients/scaffold_client.py | 38 ++- fl4health/clients/tabular_data_client.py | 48 ++- .../parameter_exchanger_base.py | 5 +- .../ditto_server.py | 31 +- .../fedprox_server.py | 82 ++--- .../mrmtl_server.py | 39 ++- fl4health/servers/base_server.py | 21 +- .../servers/client_level_dp_fed_avg_server.py | 24 +- fl4health/servers/fedpm_server.py | 30 +- fl4health/servers/instance_level_dp_server.py | 52 +-- fl4health/servers/nnunet_server.py | 117 +++---- fl4health/servers/scaffold_server.py | 61 +++- fl4health/servers/sparse_coo_server.py | 68 ++++ .../tabular_feature_alignment_server.py | 31 +- fl4health/utils/typing.py | 2 - research/cifar10/fedavg/server.py | 23 +- research/cifar10/personal_server.py | 4 +- .../flamby_servers/full_exchange_server.py | 14 +- .../flamby/flamby_servers/personal_server.py | 4 +- research/picai/fedavg/client.py | 18 +- research/picai/fedavg/server.py | 21 +- research/picai/fl_nnunet/start_client.py | 14 +- research/picai/reporting/server.py | 8 +- tests/servers/test_base_server.py | 46 ++- .../load_from_checkpoint_example/client.py | 30 +- .../load_from_checkpoint_example/server.py | 16 +- 61 files changed, 1458 insertions(+), 579 deletions(-) create mode 100644 fl4health/servers/sparse_coo_server.py diff --git a/examples/ae_examples/fedprox_vae_example/server.py b/examples/ae_examples/fedprox_vae_example/server.py index 31b3d559d..055933b56 100644 --- a/examples/ae_examples/fedprox_vae_example/server.py +++ b/examples/ae_examples/fedprox_vae_example/server.py @@ -8,8 +8,10 @@ from examples.ae_examples.fedprox_vae_example.models import MnistVariationalDecoder, MnistVariationalEncoder from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer +from fl4health.checkpointing.server_module import AdaptiveConstraintServerCheckpointAndStateModule from 
fl4health.model_bases.autoencoders_base import VariationalAe -from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger +from fl4health.parameter_exchange.packing_exchanger import FullParameterExchangerWithPacking +from fl4health.parameter_exchange.parameter_packer import ParameterPackerAdaptiveConstraint from fl4health.servers.base_server import FlServer from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint from fl4health.utils.config import load_config @@ -47,8 +49,11 @@ def main(config: Dict[str, Any]) -> None: model_checkpoint_name = "best_VAE_model.pkl" # To facilitate checkpointing - parameter_exchanger = FullParameterExchanger() + parameter_exchanger = FullParameterExchangerWithPacking(ParameterPackerAdaptiveConstraint()) checkpointer = BestLossTorchModuleCheckpointer(config["checkpoint_path"], model_checkpoint_name) + checkpoint_and_state_module = AdaptiveConstraintServerCheckpointAndStateModule( + model=model, parameter_exchanger=parameter_exchanger, model_checkpointers=checkpointer + ) # Server performs simple FedAveraging as its server-side optimization strategy and potentially adapts the # FedProx proximal weight mu @@ -70,10 +75,8 @@ def main(config: Dict[str, Any]) -> None: server = FlServer( client_manager=SimpleClientManager(), fl_config=config, - parameter_exchanger=parameter_exchanger, - model=model, strategy=strategy, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.server.start_server( diff --git a/examples/basic_example/server.py b/examples/basic_example/server.py index 5395e06be..040c35655 100644 --- a/examples/basic_example/server.py +++ b/examples/basic_example/server.py @@ -10,6 +10,7 @@ from examples.models.cnn_model import Net from examples.utils.functions import make_dict_with_epochs_or_steps from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer +from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.servers.base_server import FlServer from fl4health.utils.config import load_config @@ -47,6 +48,9 @@ def main(config: Dict[str, Any]) -> None: BestLossTorchModuleCheckpointer(config["checkpoint_path"], "best_model.pkl"), LatestTorchModuleCheckpointer(config["checkpoint_path"], "latest_model.pkl"), ] + checkpoint_and_state_module = BaseServerCheckpointAndStateModule( + model=model, parameter_exchanger=parameter_exchanger, model_checkpointers=checkpointers + ) # Server performs simple FedAveraging as its server-side optimization strategy strategy = FedAvg( @@ -65,10 +69,8 @@ def main(config: Dict[str, Any]) -> None: server = FlServer( client_manager=SimpleClientManager(), fl_config=config, - parameter_exchanger=parameter_exchanger, - model=model, strategy=strategy, - checkpointer=checkpointers, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.server.start_server( diff --git a/examples/docker_basic_example/fl_client/client.py b/examples/docker_basic_example/fl_client/client.py index 5bf59948c..678e138f6 100644 --- a/examples/docker_basic_example/fl_client/client.py +++ b/examples/docker_basic_example/fl_client/client.py @@ -1,10 +1,14 @@ import argparse from pathlib import Path -from typing import Sequence +from typing import Sequence, Tuple import flwr as fl import torch +import torch.nn as nn from flwr.common.typing import Config +from torch.nn.modules.loss import _Loss 
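Stepping back, the recurring server-side change in this patch series replaces the separate model, parameter_exchanger, and checkpointer constructor arguments with a single checkpoint-and-state module. A condensed sketch of the new wiring follows; the FedAvg() defaults, the empty fl_config, and the checkpoint directory are placeholders, not the examples' real settings.

from flwr.server.client_manager import SimpleClientManager
from flwr.server.strategy import FedAvg

from examples.models.cnn_model import Net
from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer
from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.servers.base_server import FlServer

model = Net()
checkpoint_and_state_module = BaseServerCheckpointAndStateModule(
    model=model,
    parameter_exchanger=FullParameterExchanger(),
    model_checkpointers=BestLossTorchModuleCheckpointer("checkpoints", "best_model.pkl"),
)
server = FlServer(
    client_manager=SimpleClientManager(),
    fl_config={},  # placeholder; the examples load their config from file
    strategy=FedAvg(),
    checkpoint_and_state_module=checkpoint_and_state_module,
)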
+from torch.optim import Optimizer +from torch.utils.data import DataLoader from examples.models.cnn_model import Net from fl4health.clients.basic_client import BasicClient @@ -20,17 +24,19 @@ def __init__(self, data_path: Path, metrics: Sequence[Metric], device: torch.dev self.model = Net() self.parameter_exchanger = FullParameterExchanger() - def setup_client(self, config: Config) -> None: - super().setup_client(config) + def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) - train_loader, validation_loader, num_examples = load_cifar10_data(self.data_path, batch_size) + train_loader, val_loader, _ = load_cifar10_data(self.data_path, batch_size) + return train_loader, val_loader - self.train_loader = train_loader - self.val_loader = validation_loader - self.num_examples = num_examples + def get_criterion(self, config: Config) -> _Loss: + return torch.nn.CrossEntropyLoss() - self.criterion = torch.nn.CrossEntropyLoss() - self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9) + def get_optimizer(self, config: Config) -> Optimizer: + return torch.optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9) + + def get_model(self, config: Config) -> nn.Module: + return Net().to(self.device) if __name__ == "__main__": diff --git a/examples/fedopt_example/client.py b/examples/fedopt_example/client.py index f29b0fd3a..81cd90886 100644 --- a/examples/fedopt_example/client.py +++ b/examples/fedopt_example/client.py @@ -27,9 +27,9 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, ) -> None: - super().__init__(data_path, metrics, device, loss_meter_type, checkpointer) + super().__init__(data_path, metrics, device, loss_meter_type, checkpoint_and_state_module) self.weight_matrix: torch.Tensor self.vocabulary: Vocabulary self.label_encoder: LabelEncoder diff --git a/examples/fedpca_examples/dim_reduction/server.py b/examples/fedpca_examples/dim_reduction/server.py index b11959e6d..57ef8f0bf 100644 --- a/examples/fedpca_examples/dim_reduction/server.py +++ b/examples/fedpca_examples/dim_reduction/server.py @@ -9,6 +9,7 @@ from examples.models.mnist_model import MnistNet from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer +from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.servers.base_server import FlServer from fl4health.utils.config import load_config @@ -48,6 +49,9 @@ def main(config: Dict[str, Any]) -> None: # To facilitate checkpointing checkpointer = BestLossTorchModuleCheckpointer(config["checkpoint_path"], "best_model.pkl") + checkpoint_and_state_module = BaseServerCheckpointAndStateModule( + model=model, parameter_exchanger=parameter_exchanger, model_checkpointers=checkpointer + ) # Server performs simple FedAveraging as its server-side optimization strategy strategy = FedAvg( @@ -66,10 +70,8 @@ def main(config: Dict[str, Any]) -> None: server = FlServer( client_manager=SimpleClientManager(), fl_config=config, - parameter_exchanger=parameter_exchanger, - model=model, strategy=strategy, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.server.start_server( diff --git 
a/examples/fedsimclr_example/fedsimclr_finetuning_example/server.py b/examples/fedsimclr_example/fedsimclr_finetuning_example/server.py index 29b773a20..1a263997b 100644 --- a/examples/fedsimclr_example/fedsimclr_finetuning_example/server.py +++ b/examples/fedsimclr_example/fedsimclr_finetuning_example/server.py @@ -11,6 +11,7 @@ from examples.utils.functions import make_dict_with_epochs_or_steps from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer +from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule from fl4health.model_bases.fedsimclr_base import FedSimClrModel from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.servers.base_server import FlServer @@ -52,6 +53,9 @@ def main(config: Dict[str, Any]) -> None: # To facilitate checkpointing parameter_exchanger = FullParameterExchanger() checkpointer = BestLossTorchModuleCheckpointer(config["checkpoint_path"], "best_model.pkl") + checkpoint_and_state_module = BaseServerCheckpointAndStateModule( + model=model, parameter_exchanger=parameter_exchanger, model_checkpointers=checkpointer + ) # Server performs simple FedAveraging as its server-side optimization strategy strategy = FedAvg( @@ -70,10 +74,8 @@ def main(config: Dict[str, Any]) -> None: server = FlServer( client_manager=SimpleClientManager(), fl_config=config, - parameter_exchanger=parameter_exchanger, - model=model, strategy=strategy, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.server.start_server( diff --git a/examples/nnunet_example/client.py b/examples/nnunet_example/client.py index 9cdd4d513..cd7974a2e 100644 --- a/examples/nnunet_example/client.py +++ b/examples/nnunet_example/client.py @@ -6,6 +6,9 @@ from pathlib import Path from typing import Optional, Union +from fl4health.checkpointing.checkpointer import PerRoundStateCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule + with warnings.catch_warnings(): # Silence deprecation warnings from sentry sdk due to flwr and wandb # https://github.com/adap/flower/issues/4086 @@ -68,6 +71,13 @@ def main( pred_transforms=[torch.sigmoid, get_segs_from_probs], ) + if intermediate_client_state_dir is not None: + checkpoint_and_state_module = ClientCheckpointAndStateModule( + state_checkpointer=PerRoundStateCheckpointer(Path(intermediate_client_state_dir)) + ) + else: + checkpoint_and_state_module = None + # Create client client = NnunetClient( # Args specific to nnUNetClient @@ -80,9 +90,7 @@ def main( device=device, metrics=[dice], progress_bar=verbose, - intermediate_client_state_dir=( - Path(intermediate_client_state_dir) if intermediate_client_state_dir is not None else None - ), + checkpoint_and_state_module=checkpoint_and_state_module, client_name=client_name, ) diff --git a/examples/sparse_tensor_partial_exchange_example/server.py b/examples/sparse_tensor_partial_exchange_example/server.py index 780b33a28..f2d18cb82 100644 --- a/examples/sparse_tensor_partial_exchange_example/server.py +++ b/examples/sparse_tensor_partial_exchange_example/server.py @@ -8,7 +8,7 @@ from examples.models.cnn_model import Net from examples.utils.functions import make_dict_with_epochs_or_steps -from fl4health.servers.base_server import FlServer +from fl4health.servers.sparse_coo_server import SparseCooServer from fl4health.strategies.fedavg_sparse_coo_tensor import FedAvgSparseCooTensor from fl4health.utils.config import load_config from 
fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn @@ -59,7 +59,7 @@ def main(config: Dict[str, Any]) -> None: ) client_manager = SimpleClientManager() - server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy) + server = SparseCooServer(client_manager=client_manager, fl_config=config, strategy=strategy) fl.server.start_server( server=server, diff --git a/examples/warm_up_example/fedavg_warm_up/client.py b/examples/warm_up_example/fedavg_warm_up/client.py index 85cf733ae..5b9aa463f 100644 --- a/examples/warm_up_example/fedavg_warm_up/client.py +++ b/examples/warm_up_example/fedavg_warm_up/client.py @@ -31,17 +31,18 @@ def __init__( device: torch.device, checkpoint_dir: str, ) -> None: + # Checkpointing is crucial for the warm up process + checkpoint_name = f"client_{self.client_name}_latest_model.pkl" + post_aggregation_checkpointer = LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) + checkpoint_and_state_module = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) + super().__init__( data_path=data_path, metrics=metrics, device=device, + checkpoint_and_state_module=checkpoint_and_state_module, ) - # Checkpointing is crucial for the warm up process - checkpoint_name = f"client_{self.client_name}_latest_model.pkl" - post_aggregation_checkpointer = LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) - self.checkpointer = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=1) batch_size = narrow_dict_type(config, "batch_size", int) diff --git a/fl4health/checkpointing/server_module.py b/fl4health/checkpointing/server_module.py index 90a169065..55af4484d 100644 --- a/fl4health/checkpointing/server_module.py +++ b/fl4health/checkpointing/server_module.py @@ -5,12 +5,19 @@ from flwr.common import Parameters from flwr.common.logger import log from flwr.common.parameter import parameters_to_ndarrays -from flwr.common.typing import Scalar +from flwr.common.typing import NDArrays, Scalar from fl4health.checkpointing.checkpointer import PerRoundStateCheckpointer, TorchModuleCheckpointer +from fl4health.checkpointing.opacus_checkpointer import OpacusCheckpointer from fl4health.parameter_exchange.packing_exchanger import FullParameterExchangerWithPacking -from fl4health.parameter_exchange.parameter_packer import ParameterPackerAdaptiveConstraint -from fl4health.utils.typing import ExchangerType +from fl4health.parameter_exchange.parameter_exchanger_base import ExchangerType +from fl4health.parameter_exchange.parameter_packer import ( + ParameterPackerAdaptiveConstraint, + ParameterPackerWithClippingBit, + ParameterPackerWithControlVariates, + ParameterPackerWithLayerNames, + SparseCooParameterPacker, +) CheckpointModuleInput = Union[TorchModuleCheckpointer, Sequence[TorchModuleCheckpointer]] | None @@ -37,7 +44,7 @@ def __init__( to hold the server parameters and facilitate checkpointing with the help of the parameter exchanger. Recall that servers only have parameters rather than torch models. So we need to know where to route these parameters to allow for real models to be saved. Defaults to None. 
- parameter_exchanger (FullParameterExchangerWithPacking | None, optional): This will facilitate routing the + parameter_exchanger (ExchangerType | None, optional): This will facilitate routing the server parameters into the right components of the provided model architecture. Note that this exchanger and the model must match the one used for training and exchange with the servers to ensure parameters go to the right places. Defaults to None. @@ -104,9 +111,9 @@ def maybe_checkpoint(self, server_parameters: Parameters, loss: float, metrics: used by checkpointer to decide whether to checkpoint the model. """ if self.model_checkpointers is not None and len(self.model_checkpointers) > 0: + assert self.model is not None self._hydrate_model_for_checkpointing(server_parameters) for checkpointer in self.model_checkpointers: - assert self.model is not None checkpointer.maybe_checkpoint(self.model, loss, metrics) else: log(INFO, "No model checkpointers specified. Skipping any checkpointing.") @@ -194,7 +201,7 @@ def maybe_load_state(self, state_checkpoint_name: str) -> Dict[str, Any] | None: raise ValueError("Attempting to load state, but no state checkpointer is specified") -class ScaffoldServerCheckpointAndStateModule(BaseServerCheckpointAndStateModule): +class PackingServerCheckpointAndAndStateModule(BaseServerCheckpointAndStateModule): def __init__( self, model: nn.Module | None = None, @@ -203,13 +210,12 @@ def __init__( state_checkpointer: PerRoundStateCheckpointer | None = None, ) -> None: """ - This module is meant to handle SCAFFOLD model and state checkpointing on the server-side of an FL process. - Unlike the module on the client side, this module has no concept of pre- or post-aggregation checkpointing. - It only considers checkpointing the global server model after aggregation, perhaps based on validation - statistics retrieved on the client side by running a federated evaluation step. Multiple model checkpointers - may be used. For state checkpointing, which saves the state of the entire server-side FL process to help with - FL restarts, we allow only a single checkpointer responsible for saving the state after each fit and eval - round of FL. + + This module is meant to be a base class for any server-side checkpointing module that relies on unpacking + of parameters to hydrate models for checkpointing. The specifics of the unpacking are handled by the + parameter packer within the parameter exchanger of each child class. + NOTE: This class ASSUMES full parameter exchange with packing. If more complex unpacking/parameter exchange + is used, this is not the right parent class. Args: model (nn.Module | None, optional): Model architecture to be saved. The module will use this architecture to hold the server parameters and facilitate checkpointing with the help of the parameter exchanger. Recall that servers only have parameters rather than torch models. So we need to know where to route these parameters to allow for real models to be saved. Defaults to None. parameter_exchanger (FullParameterExchangerWithPacking | None, optional): This will facilitate routing the - server parameters into the right components of the provided model architecture. Note that this - exchanger and the model must match the one used for training and exchange with the servers to ensure - parameters go to the right places. Defaults to None. + server parameters into the right components of the provided model architecture. It specifically also
Note that this exchanger and the model must + match the one used for training and exchange with the servers to ensure parameters go to the right + places. Defaults to None. model_checkpointers (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of checkpointers) is used to checkpoint models based on their defined scoring function. Defaults to None. state_checkpointer (PerRoundStateCheckpointer | None, optional): If defined, this checkpointer will be used to preserve FL training state to facilitate restarting training if interrupted. Generally, this checkpointer will save much more than just the model being trained. Defaults to None. """ + if parameter_exchanger is not None: + assert isinstance( + parameter_exchanger, FullParameterExchangerWithPacking + ), "Parameter exchanger must be of based type FullParameterExchangerWithPacking" super().__init__(model, parameter_exchanger, model_checkpointers, state_checkpointer) def _hydrate_model_for_checkpointing(self, server_parameters: Parameters): + """ + This function is used as a means of saving the server-side model after aggregation in the FL training + trajectory. Presently, the server only holds Flower Parameters, which are essentially just ndarrays. Without + knowledge of a model architecture to which the arrays correspond. Thus, in the default implementation, we + require that a torch architecture and a parameter exchanger be provided which handles mapping these numpy + arrays into the architecture properly. + + This function overrides the base functionality of model hydration to insert an additional unpacking step + using the unpacking function of the specific type of parameter exchanger. + + NOTE: This function stores the weights directly in the self.model attribute + Args: + server_parameters (Parameters): Parameters to be injected into the torch model architecture and + checkpointed. + """ assert self.model is not None, "Hydrate model for checkpoint called but self.model is None" assert ( self.parameter_exchanger is not None ), "Hydrate model for checkpoint called but self.parameter_exchanger is None" packed_parameters = parameters_to_ndarrays(server_parameters) - # Don't need the control variates for checkpointing. assert isinstance(self.parameter_exchanger, FullParameterExchangerWithPacking) + # Use the unpacking function of the parameter exchange to handle extraction of model parameters model_ndarrays, _ = self.parameter_exchanger.unpack_parameters(packed_parameters) - self.parameter_exchanger.pull_parameters(model_ndarrays, self.server_model) + self.parameter_exchanger.pull_parameters(model_ndarrays, self.model) -class AdaptiveConstraintServerCheckpointAndStateModule(BaseServerCheckpointAndStateModule): +class ScaffoldServerCheckpointAndStateModule(PackingServerCheckpointAndAndStateModule): + def __init__( + self, + model: nn.Module | None = None, + model_checkpointers: CheckpointModuleInput = None, + state_checkpointer: PerRoundStateCheckpointer | None = None, + ) -> None: + """ + This module is meant to handle SCAFFOLD model and state checkpointing on the server-side of an FL process. + Unlike the module on the client side, this module has no concept of pre- or post-aggregation checkpointing. + It only considers checkpointing the global server model after aggregation, perhaps based on validation + statistics retrieved on the client side by running a federated evaluation step. Multiple model checkpointers + may be used. 
For state checkpointing, which saves the state of the entire server-side FL process to help with + FL restarts, we allow only a single checkpointer responsible for saving the state after each fit and eval + round of FL. + + Args: + model (nn.Module | None, optional): Model architecture to be saved. The module will use this architecture + to hold the server parameters and facilitate checkpointing with the help of the parameter exchanger. + Recall that servers only have parameters rather than torch models. So we need to know where to route + these parameters to allow for real models to be saved. Defaults to None. + model_checkpointers (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of + checkpointers) is used to checkpoint models based on their defined scoring function. Defaults to None. + state_checkpointer (PerRoundStateCheckpointer | None, optional): If defined, this checkpointer will be + used to preserve FL training state to facilitate restarting training if interrupted. Generally, this + checkpointer will save much more than just the model being trained. Defaults to None. + """ + if model is not None: + model_size = len(model.state_dict()) + parameter_exchanger = FullParameterExchangerWithPacking(ParameterPackerWithControlVariates(model_size)) + else: + parameter_exchanger = None + super().__init__(model, parameter_exchanger, model_checkpointers, state_checkpointer) + + +class AdaptiveConstraintServerCheckpointAndStateModule(PackingServerCheckpointAndAndStateModule): + def __init__( + self, + model: nn.Module | None = None, + model_checkpointers: CheckpointModuleInput = None, + state_checkpointer: PerRoundStateCheckpointer | None = None, + ) -> None: + """ + This module is meant to handle FL flows with adaptive constraints, where the server and client communicate + a loss weight parameter in addition to the model weights. Unlike the module on the client side, this module + has no concept of pre- or post-aggregation checkpointing. It only considers checkpointing the global server + model after aggregation, perhaps based on validation statistics retrieved on the client side by running a + federated evaluation step. Multiple model checkpointers may be used. For state checkpointing, which saves the + state of the entire server-side FL process to help with FL restarts, we allow only a single checkpointer + responsible for saving the state after each fit and eval round of FL. + + Args: + model (nn.Module | None, optional): Model architecture to be saved. The module will use this architecture + to hold the server parameters and facilitate checkpointing with the help of the parameter exchanger. + Recall that servers only have parameters rather than torch models. So we need to know where to route + these parameters to allow for real models to be saved. Defaults to None. + model_checkpointers (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of + checkpointers) is used to checkpoint models based on their defined scoring function. Defaults to None. + state_checkpointer (PerRoundStateCheckpointer | None, optional): If defined, this checkpointer will be + used to preserve FL training state to facilitate restarting training if interrupted. Generally, this + checkpointer will save much more than just the model being trained. Defaults to None.
+ """ + if model is not None: + parameter_exchanger = FullParameterExchangerWithPacking(ParameterPackerAdaptiveConstraint()) + else: + parameter_exchanger = None + super().__init__(model, parameter_exchanger, model_checkpointers, state_checkpointer) + + +class ClippingBitServerCheckpointAndStateModule(PackingServerCheckpointAndAndStateModule): + def __init__( + self, + model: nn.Module | None = None, + model_checkpointers: CheckpointModuleInput = None, + state_checkpointer: PerRoundStateCheckpointer | None = None, + ) -> None: + """ + This module is meant to handle FL flows with clipping bits being passed to the server along with the model + weights. This is used for DP-FL with adaptive clipping. Unlike the module on the client side, this module + has no concept of pre- or post-aggregation checkpointing. It only considers checkpointing the global server + model after aggregation, perhaps based on validation statistics retrieved on the client side by running a + federated evaluation step. Multiple model checkpointers may be used. For state checkpointing, which saves the + state of the entire server-side FL process to help with FL restarts, we allow only a single checkpointer + responsible for saving the state after each fit and eval round of FL. + + Args: + model (nn.Module | None, optional): Model architecture to be saved. The module will use this architecture + to hold the server parameters and facilitate checkpointing with the help of the parameter exchanger. + Recall that servers only have parameters rather than torch models. So we need to know where to route + these parameters to allow for real models to be saved. Defaults to None. + model_checkpointers (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of + checkpointers) is used to checkpoint models based on their defined scoring function. Defaults to None. + state_checkpointer (PerRoundStateCheckpointer | None, optional): If defined, this checkpointer will be + used to preserve FL training state to facilitate restarting training if interrupted. Generally, this + checkpointer will save much more than just the model being trained. Defaults to None. + """ + if model is not None: + parameter_exchanger = FullParameterExchangerWithPacking(ParameterPackerWithClippingBit()) + else: + parameter_exchanger = None + super().__init__(model, parameter_exchanger, model_checkpointers, state_checkpointer) + + +class LayerNamesServerCheckpointAndStateModule(PackingServerCheckpointAndAndStateModule): + def __init__( + self, + model: nn.Module | None = None, + model_checkpointers: CheckpointModuleInput = None, + state_checkpointer: PerRoundStateCheckpointer | None = None, + ) -> None: + """ + This module is meant to handle FL flows with layer names being passed to the server along with the model + weights. This is used for adaptive layer exchange FL. Unlike the module on the client side, this module + has no concept of pre- or post-aggregation checkpointing. It only considers checkpointing the global server + model after aggregation, perhaps based on validation statistics retrieved on the client side by running a + federated evaluation step. Multiple model checkpointers may be used. For state checkpointing, which saves the + state of the entire server-side FL process to help with FL restarts, we allow only a single checkpointer + responsible for saving the state after each fit and eval round of FL. + + Args: + model (nn.Module | None, optional): Model architecture to be saved. 
The module will use this architecture + to hold the server parameters and facilitate checkpointing with the help of the parameter exchanger. + Recall that servers only have parameters rather than torch models. So we need to know where to route + these parameters to allow for real models to be saved. Defaults to None. + model_checkpointers (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of + checkpointers) is used to checkpoint models based on their defined scoring function. Defaults to None. + state_checkpointer (PerRoundStateCheckpointer | None, optional): If defined, this checkpointer will be + used to preserve FL training state to facilitate restarting training if interrupted. Generally, this + checkpointer will save much more than just the model being trained. Defaults to None. + """ + if model is not None: + parameter_exchanger = FullParameterExchangerWithPacking(ParameterPackerWithLayerNames()) + else: + parameter_exchanger = None + super().__init__(model, parameter_exchanger, model_checkpointers, state_checkpointer) + + +class SparseCooServerCheckpointAndStateModule(PackingServerCheckpointAndAndStateModule): + def __init__( + self, + model: nn.Module | None = None, + model_checkpointers: CheckpointModuleInput = None, + state_checkpointer: PerRoundStateCheckpointer | None = None, + ) -> None: + """ + This module is meant to handle FL flows with parameters encoded in a sparse COO format being passed to the + server as the model weights. This is used for adaptive parameter-wise exchange (i.e. unstructured subsets of + parameters). Unlike the module on the client side, this module has no concept of pre- or post-aggregation + checkpointing. It only considers checkpointing the global server model after aggregation, perhaps based on + validation statistics retrieved on the client side by running a federated evaluation step. Multiple model + checkpointers may be used. For state checkpointing, which saves the state of the entire server-side FL process + to help with FL restarts, we allow only a single checkpointer responsible for saving the state after each fit + and eval round of FL. + + Args: + model (nn.Module | None, optional): Model architecture to be saved. The module will use this architecture + to hold the server parameters and facilitate checkpointing with the help of the parameter exchanger. + Recall that servers only have parameters rather than torch models. So we need to know where to route + these parameters to allow for real models to be saved. Defaults to None. + model_checkpointers (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of + checkpointers) is used to checkpoint models based on their defined scoring function. Defaults to None. + state_checkpointer (PerRoundStateCheckpointer | None, optional): If defined, this checkpointer will be + used to preserve FL training state to facilitate restarting training if interrupted. Generally, this + checkpointer will save much more than just the model being trained. Defaults to None.
+ """ + if model is not None: + parameter_exchanger = FullParameterExchangerWithPacking(SparseCooParameterPacker()) + else: + parameter_exchanger = None + super().__init__(model, parameter_exchanger, model_checkpointers, state_checkpointer) + + +class OpacusServerCheckpointAndStateModule(BaseServerCheckpointAndStateModule): + def __init__( + self, + model: nn.Module | None = None, + parameter_exchanger: ExchangerType | None = None, model_checkpointers: CheckpointModuleInput = None, state_checkpointer: PerRoundStateCheckpointer | None = None, ) -> None: @@ -274,14 +472,67 @@ def __init__( """ super().__init__(model, parameter_exchanger, model_checkpointers, state_checkpointer) + self._ensure_checkpointers_are_of_opacus_type() - def _hydrate_model_for_checkpointing(self, server_parameters: Parameters): - assert self.model is not None, "Hydrate model for checkpoint called but self.model is None" - assert ( - self.parameter_exchanger is not None - ), "Hydrate model for checkpoint called but self.parameter_exchanger is None" - packed_parameters = parameters_to_ndarrays(server_parameters) - # Don't need the extra loss weight variable for checkpointing. - assert isinstance(self.parameter_exchanger, FullParameterExchangerWithPacking) - model_ndarrays, _ = self.parameter_exchanger.unpack_parameters(packed_parameters) - self.parameter_exchanger.pull_parameters(model_ndarrays, self.server_model) + def _ensure_checkpointers_are_of_opacus_type(self) -> None: + """ + Helper function to ensure that the provided checkpointers are explicitly compatible with Opacus + """ + if self.model_checkpointers is not None: + for checkpointer in self.model_checkpointers: + assert isinstance( + checkpointer, OpacusCheckpointer + ), "Provided checkpointers must have base class OpacusCheckpointer" + + +class NnUnetServerCheckpointAndStateModule(BaseServerCheckpointAndStateModule): + def __init__( + self, + model: nn.Module | None = None, + parameter_exchanger: ExchangerType | None = None, + model_checkpointers: CheckpointModuleInput = None, + state_checkpointer: PerRoundStateCheckpointer | None = None, + ) -> None: + """ + This module is meant to be used with the NnUnetServer class to handle model and state checkpointing on the + server-side of an FL process. Unlike the module on the client side, this module has no concept of pre- or + post-aggregation checkpointing. It only considers checkpointing the global server model after aggregation, + perhaps based on validation statistics retrieved on the client side by running a federated evaluation step. + Multiple model checkpointers may be used. For state checkpointing, which saves the state of the entire + server-side FL process to help with FL restarts, we allow only a single checkpointer responsible for saving + the state after each fit and eval round of FL. + + This implementation differs from the base class in that the federated NnUnet server only initializes its model + after an initial communication phase with the clients. As such, the model is not necessarily available upon + initialization, but may be set later. + + Args: + model (nn.Module | None, optional): Model architecture to be saved. The module will use this architecture + to hold the server parameters and facilitate checkpointing with the help of the parameter exchanger. + Recall that servers only have parameters rather than torch models. So we need to know where to route + these parameters to allow for real models to be saved. Defaults to None.
+ NOTE: For NnUnet, this need not be set upon creation, as the model architecture may only be known later + parameter_exchanger (FullParameterExchangerWithPacking | None, optional): This will facilitate routing the + server parameters into the right components of the provided model architecture. Note that this + exchanger and the model must match the one used for training and exchange with the servers to ensure + parameters go to the right places. Defaults to None. + model_checkpointers (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of + checkpointers) is used to checkpoint models based on their defined scoring function. Defaults to None. + state_checkpointer (PerRoundStateCheckpointer | None, optional): If defined, this checkpointer will be + used to preserve FL training state to facilitate restarting training if interrupted. Generally, this + checkpointer will save much more than just the model being trained. Defaults to None. + """ + self.model = model + self.parameter_exchanger = parameter_exchanger + self.model_checkpointers = ( + [model_checkpointers] if isinstance(model_checkpointers, TorchModuleCheckpointer) else model_checkpointers + ) + self.state_checkpointer = state_checkpointer + if self.model_checkpointers is not None and len(self.model_checkpointers): + # NOTE: We only check if the parameter exchanger is present. Model may be set later. + assert self.parameter_exchanger is not None, ( + "Checkpointer(s) is (are) defined but no parameter_exchanger is defined to hydrate. The functionality " + "of this class can be overridden in a child class if checkpointing without a parameter exchanger is " + "possible and desired" + ) + self._check_if_shared_checkpoint_names() diff --git a/fl4health/clients/adaptive_drift_constraint_client.py b/fl4health/clients/adaptive_drift_constraint_client.py index aa3b11e3a..2a482cbb7 100644 --- a/fl4health/clients/adaptive_drift_constraint_client.py +++ b/fl4health/clients/adaptive_drift_constraint_client.py @@ -26,9 +26,10 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: """ This client serves as a base for FL methods implementing an auxiliary loss penalty with a weight coefficient @@ -45,28 +46,33 @@ def __init__( 'cuda' loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE. - checkpointer (Optional[ClientCheckpointModule], optional): Checkpointer module defining when and how to - do checkpointing during client-side training. No checkpointing is done if not provided. Defaults to - None. - reporters (Sequence[BaseReporter], optional): A sequence of FL4Health - reporters which the client should send data to. - progress_bar (bool): Whether or not to display a progress bar during client training and validation. - Uses tqdm. Defaults to False + checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + both checkpointing and state saving. The module, and its underlying model and state checkpointing + components will determine when and how to do checkpointing during client-side training. + No checkpointing (state or model) is done if not provided. 
Defaults to None. + reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health reporters which the client + should send data to. Defaults to None. + progress_bar (bool, optional): Whether or not to display a progress bar during client training and + validation. Uses tqdm. Defaults to False + client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + If not passed, a hash is randomly generated. Client state will use this as part of its state file + name. Defaults to None. """ super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, reporters=reporters, progress_bar=progress_bar, + client_name=client_name, ) # These are the tensors that will be used to compute the penalty loss self.drift_penalty_tensors: List[torch.Tensor] # Exchanger with packing to be able to exchange the weights and auxiliary information with the server for # adaptation - self.parameter_exchanger: FullParameterExchangerWithPacking + self.parameter_exchanger: FullParameterExchangerWithPacking[float] # Weight on the penalty loss to be used in backprop. This is what might be adapted via server calculations self.drift_penalty_weight: float # This is the loss value to be sent back to the server on which adaptation decisions will be made. diff --git a/fl4health/clients/apfl_client.py b/fl4health/clients/apfl_client.py index 9df872e67..351418ea8 100644 --- a/fl4health/clients/apfl_client.py +++ b/fl4health/clients/apfl_client.py @@ -22,10 +22,45 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: - super().__init__(data_path, metrics, device, loss_meter_type, checkpointer, reporters) + """ + Client specifically implementing the APFL Algorithm: https://arxiv.org/abs/2003.13461 + Twin models are trained. One of them is globally shared by all clients and aggregated on the server. + The other is strictly trained locally by each client. Predictions are made by a convex combination of the models. + + Args: + data_path (Path): path to the data to be used to load the data for client-side training + metrics (Sequence[Metric]): Metrics to be computed based on the labels and predictions of the client model + device (torch.device): Device indicator for where to send the model, batches, labels etc. Often 'cpu' or + 'cuda' + loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over + each batch. Defaults to LossMeterType.AVERAGE. + checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + both checkpointing and state saving. The module, and its underlying model and state checkpointing + components will determine when and how to do checkpointing during client-side training. + No checkpointing (state or model) is done if not provided. Defaults to None. + reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health reporters which the client + should send data to. Defaults to None. + progress_bar (bool, optional): Whether or not to display a progress bar during client training and + validation. Uses tqdm. 
Defaults to False + client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + If not passed, a hash is randomly generated. Client state will use this as part of its state file + name. Defaults to None. + """ + super().__init__( + data_path=data_path, + metrics=metrics, + device=device, + loss_meter_type=loss_meter_type, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, + ) self.model: ApflModule self.learning_rate: float diff --git a/fl4health/clients/basic_client.py b/fl4health/clients/basic_client.py index fc6044417..9b15c21b2 100644 --- a/fl4health/clients/basic_client.py +++ b/fl4health/clients/basic_client.py @@ -1,6 +1,6 @@ import datetime from collections.abc import Sequence -from logging import INFO, WARNING +from logging import INFO from pathlib import Path from typing import Any, Dict, Optional, Tuple, Union @@ -14,7 +14,6 @@ from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import DataLoader -from fl4health.checkpointing.checkpointer import PerRoundStateCheckpointer from fl4health.checkpointing.client_module import CheckpointMode, ClientCheckpointAndStateModule from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger diff --git a/fl4health/clients/clipping_client.py b/fl4health/clients/clipping_client.py index 9c557350e..857ab370a 100644 --- a/fl4health/clients/clipping_client.py +++ b/fl4health/clients/clipping_client.py @@ -14,31 +14,56 @@ from fl4health.parameter_exchange.packing_exchanger import FullParameterExchangerWithPacking from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger from fl4health.parameter_exchange.parameter_packer import ParameterPackerWithClippingBit +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.config import narrow_dict_type from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import Metric class NumpyClippingClient(BasicClient): - """ - Client that clips updates being sent to the server where noise is added. - Used to obtain Client Level Differential Privacy in FL setting. - """ - def __init__( self, data_path: Path, metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: + """ + Client that clips updates being sent to the server where noise is added. Used to obtain Client Level + Differential Privacy in FL setting. + + Args: + data_path (Path): path to the data to be used to load the data for client-side training + metrics (Sequence[Metric]): Metrics to be computed based on the labels and predictions of the client model + device (torch.device): Device indicator for where to send the model, batches, labels etc. Often 'cpu' or + 'cuda' + loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over + each batch. Defaults to LossMeterType.AVERAGE. + checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + both checkpointing and state saving. 
The module, and its underlying model and state checkpointing + components will determine when and how to do checkpointing during client-side training. + No checkpointing (state or model) is done if not provided. Defaults to None. + reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health reporters which the client + should send data to. Defaults to None. + progress_bar (bool, optional): Whether or not to display a progress bar during client training and + validation. Uses tqdm. Defaults to False + client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + If not passed, a hash is randomly generated. Client state will use this as part of its state file + name. Defaults to None. + """ super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) self.parameter_exchanger: FullParameterExchangerWithPacking[float] self.clipping_bound: Optional[float] = None diff --git a/fl4health/clients/constrained_fenda_client.py b/fl4health/clients/constrained_fenda_client.py index 25e5e30a9..fe4e3f51e 100644 --- a/fl4health/clients/constrained_fenda_client.py +++ b/fl4health/clients/constrained_fenda_client.py @@ -12,6 +12,7 @@ from fl4health.losses.fenda_loss_config import ConstrainedFendaLossContainer from fl4health.model_bases.fenda_base import FendaModelWithFeatureState from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.client import clone_and_freeze_model from fl4health.utils.losses import EvaluationLosses, LossMeterType from fl4health.utils.metrics import Metric @@ -25,7 +26,10 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, loss_container: Optional[ConstrainedFendaLossContainer] = None, ) -> None: """ @@ -39,18 +43,30 @@ def __init__( 'cuda' loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE. - checkpointer (Optional[ClientCheckpointModule], optional): Checkpointer module defining when and how to - do checkpointing during client-side training. No checkpointing is done if not provided. Defaults to - None. - loss_configuration (Optional[ConstrainedFendaLossContainer], optional): Configuration that determines which + checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + both checkpointing and state saving. The module, and its underlying model and state checkpointing + components will determine when and how to do checkpointing during client-side training. + No checkpointing (state or model) is done if not provided. Defaults to None. + reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health reporters which the client + should send data to. Defaults to None. + progress_bar (bool, optional): Whether or not to display a progress bar during client training and + validation. Uses tqdm. 
Defaults to False + client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + If not passed, a hash is randomly generated. Client state will use this as part of its state file + name. Defaults to None. + loss_container (Optional[ConstrainedFendaLossContainer], optional): Configuration that determines which losses will be applied during FENDA training. Defaults to None. """ + super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) if loss_container: diff --git a/fl4health/clients/deep_mmd_clients/ditto_deep_mmd_client.py b/fl4health/clients/deep_mmd_clients/ditto_deep_mmd_client.py index 950916266..ed9de5c78 100644 --- a/fl4health/clients/deep_mmd_clients/ditto_deep_mmd_client.py +++ b/fl4health/clients/deep_mmd_clients/ditto_deep_mmd_client.py @@ -11,6 +11,7 @@ from fl4health.clients.ditto_client import DittoClient from fl4health.losses.deep_mmd_loss import DeepMmdLoss from fl4health.model_bases.feature_extractor_buffer import FeatureExtractorBuffer +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.client import clone_and_freeze_model from fl4health.utils.losses import EvaluationLosses, LossMeterType, TrainingLosses from fl4health.utils.metrics import Metric @@ -25,7 +26,10 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, deep_mmd_loss_weight: float = 10.0, feature_extraction_layers_with_size: Optional[Dict[str, int]] = None, mmd_kernel_train_interval: int = 20, @@ -38,21 +42,29 @@ def __init__( global model. Args: - data_path (Path): path to the data to be used to load the data for client-side training. - metrics (Sequence[Metric]): Metrics to be computed based on the labels and predictions of the client model. + data_path (Path): path to the data to be used to load the data for client-side training + metrics (Sequence[Metric]): Metrics to be computed based on the labels and predictions of the client model device (torch.device): Device indicator for where to send the model, batches, labels etc. Often 'cpu' or - 'cuda'. + 'cuda' loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE. - checkpointer (Optional[ClientCheckpointModule], optional): Checkpointer module defining when and how to - do checkpointing during client-side training. No checkpointing is done if not provided. Defaults to - None. + checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + both checkpointing and state saving. The module, and its underlying model and state checkpointing + components will determine when and how to do checkpointing during client-side training. + No checkpointing (state or model) is done if not provided. Defaults to None. + reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health reporters which the client + should send data to. Defaults to None. 
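# --- Illustrative usage sketch (editorial aside, not part of the patch). It shows how the new
# `checkpoint_and_state_module` argument replaces the old `checkpointer` argument on clients. The
# paths and file names are hypothetical, and combining a post-aggregation model checkpointer with a
# per-round state checkpointer in a single module is an assumption that mirrors the separate uses in
# the nnunet and warm-up examples updated earlier in this patch.
from pathlib import Path

from fl4health.checkpointing.checkpointer import LatestTorchModuleCheckpointer, PerRoundStateCheckpointer
from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule

checkpoint_and_state_module = ClientCheckpointAndStateModule(
    # Saves the latest model received from the server after each aggregation round
    post_aggregation=LatestTorchModuleCheckpointer("./client_checkpoints", "client_latest_model.pkl"),
    # Saves per-round client state so that an interrupted FL run can be resumed
    state_checkpointer=PerRoundStateCheckpointer(Path("./client_state")),
)
# A client is then constructed with the module and an explicit name, e.g.
# MyClient(..., checkpoint_and_state_module=checkpoint_and_state_module, client_name="client_0"),
# where MyClient stands in for any of the clients updated in this patch.
# --- End of illustrative sketch.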
+ progress_bar (bool, optional): Whether or not to display a progress bar during client training and + validation. Uses tqdm. Defaults to False + client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + If not passed, a hash is randomly generated. Client state will use this as part of its state file + name. Defaults to None. deep_mmd_loss_weight (float, optional): weight applied to the Deep MMD loss. Defaults to 10.0. feature_extraction_layers_with_size (Optional[Dict[str, int]], optional): Dictionary of layers to extract features from them and their respective feature size. Defaults to None. mmd_kernel_update_interval (int, optional): interval at which to train and update the Deep MMD kernel. If set to above 0, the kernel will be train based on whole distribution of latent features of data with - the given train interval. If set to 0, the kernal will not be trained. If set to -1, the kernel will + the given train interval. If set to 0, the kernel will not be trained. If set to -1, the kernel will be trained after each individual batch based on only that individual batch. Defaults to 20. num_accumulating_batches (int, optional): Number of batches to accumulate features to approximate the whole distribution of the latent features for updating Deep MMD kernel. This parameter is only used @@ -63,7 +75,10 @@ def __init__( metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) self.deep_mmd_loss_weight = deep_mmd_loss_weight if self.deep_mmd_loss_weight == 0: diff --git a/fl4health/clients/ditto_client.py b/fl4health/clients/ditto_client.py index e46839a01..6ed5ee8ff 100644 --- a/fl4health/clients/ditto_client.py +++ b/fl4health/clients/ditto_client.py @@ -25,9 +25,10 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: """ This client implements the Ditto algorithm from Ditto: Fair and Robust Federated Learning Through @@ -40,30 +41,33 @@ def __init__( corresponding strategy used by the server Args: - data_path (Path): path to the data to be used to load the data for - client-side training - metrics (Sequence[Metric]): Metrics to be computed based on the labels and - predictions of the client model - device (torch.device): Device indicator for where to send the model, - batches, labels etc. Often 'cpu' or 'cuda' - loss_meter_type (LossMeterType, optional): Type of meter used to track and - compute the losses over each batch. Defaults to LossMeterType.AVERAGE. - checkpointer (Optional[ClientCheckpointModule], optional): Checkpointer - module defining when and how to do checkpointing during client-side - training. No checkpointing is done if not provided. Defaults to None. - reporters (Sequence[BaseReporter], optional): A sequence of FL4Health - reporters which the client should send data to. - progress_bar (bool): Whether or not to display a progress bar during client training and validation. - Uses tqdm. 
Defaults to False + data_path (Path): path to the data to be used to load the data for client-side training + metrics (Sequence[Metric]): Metrics to be computed based on the labels and predictions of the client model + device (torch.device): Device indicator for where to send the model, batches, labels etc. Often 'cpu' or + 'cuda' + loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over + each batch. Defaults to LossMeterType.AVERAGE. + checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + both checkpointing and state saving. The module, and its underlying model and state checkpointing + components will determine when and how to do checkpointing during client-side training. + No checkpointing (state or model) is done if not provided. Defaults to None. + reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health reporters which the client + should send data to. Defaults to None. + progress_bar (bool, optional): Whether or not to display a progress bar during client training and + validation. Uses tqdm. Defaults to False + client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + If not passed, a hash is randomly generated. Client state will use this as part of its state file + name. Defaults to None. """ super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, reporters=reporters, progress_bar=progress_bar, + client_name=client_name, ) self.global_model: nn.Module diff --git a/fl4health/clients/ensemble_client.py b/fl4health/clients/ensemble_client.py index d06bec5d2..bc2e1df8a 100644 --- a/fl4health/clients/ensemble_client.py +++ b/fl4health/clients/ensemble_client.py @@ -8,6 +8,7 @@ from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.basic_client import BasicClient from fl4health.model_bases.ensemble_base import EnsembleModel +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import EvaluationLosses, LossMeterType, TrainingLosses from fl4health.utils.metrics import Metric from fl4health.utils.typing import TorchFeatureType, TorchInputType, TorchPredType, TorchTargetType @@ -20,7 +21,10 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: """ This client enables the training of ensemble models in a federated manner. @@ -32,16 +36,27 @@ def __init__( 'cuda' loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE. - checkpointer (Optional[ClientCheckpointModule], optional): Checkpointer module defining when and how to - do checkpointing during client-side training. No checkpointing is done if not provided. Defaults to - None. + checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + both checkpointing and state saving. 
The module, and its underlying model and state checkpointing + components will determine when and how to do checkpointing during client-side training. + No checkpointing (state or model) is done if not provided. Defaults to None. + reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health reporters which the client + should send data to. Defaults to None. + progress_bar (bool, optional): Whether or not to display a progress bar during client training and + validation. Uses tqdm. Defaults to False + client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + If not passed, a hash is randomly generated. Client state will use this as part of its state file + name. Defaults to None. """ super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) self.model: EnsembleModel diff --git a/fl4health/clients/evaluate_client.py b/fl4health/clients/evaluate_client.py index 39ac6f4c7..08954fc08 100644 --- a/fl4health/clients/evaluate_client.py +++ b/fl4health/clients/evaluate_client.py @@ -21,13 +21,6 @@ class EvaluateClient(BasicClient): - """ - This client implements an evaluation only flow. That is, there is no expectation of parameter exchange with the - server past the model initialization stage. The implementing client should instantiate a global model if one is - expected from the server, which will be loaded using the passed parameters. If a model checkpoint path is provided - the client attempts to load a local model from the specified path. - """ - def __init__( self, data_path: Path, @@ -36,10 +29,30 @@ def __init__( loss_meter_type: LossMeterType = LossMeterType.AVERAGE, model_checkpoint_path: Optional[Path] = None, reporters: Sequence[BaseReporter] | None = None, + client_name: Optional[str] = None, ) -> None: + """ + This client implements an evaluation only flow. That is, there is no expectation of parameter exchange with + the server past the model initialization stage. The implementing client should instantiate a global model if + one is expected from the server, which will be loaded using the passed parameters. If a model checkpoint path + is provided the client attempts to load a local model from the specified path. + + Args: + data_path (Path): path to the data to be used to load the data for client-side training + metrics (Sequence[Metric]): Metrics to be computed based on the labels and predictions of the client model + device (torch.device): Device indicator for where to send the model, batches, labels etc. Often 'cpu' or + 'cuda' + loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over + each batch. Defaults to LossMeterType.AVERAGE. + model_checkpoint_path (Optional[Path], optional): _description_. Defaults to None. + reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health reporters which the client + should send data to. Defaults to None. + client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + If not passed, a hash is randomly generated. Defaults to None. 
+ """ # EvaluateClient does not call BasicClient constructor and sets attributes # in a custom way to account for the fact it does not involve any training - self.client_name = generate_hash() + self.client_name = generate_hash() if client_name is None else client_name self.data_path = data_path self.device = device self.model_checkpoint_path = model_checkpoint_path diff --git a/fl4health/clients/fed_pca_client.py b/fl4health/clients/fed_pca_client.py index c4ce3cede..773fb7966 100644 --- a/fl4health/clients/fed_pca_client.py +++ b/fl4health/clients/fed_pca_client.py @@ -20,6 +20,12 @@ class FedPCAClient(NumPyClient): def __init__(self, data_path: Path, device: torch.device, model_save_path: Path) -> None: """ Client that facilitates the execution of federated PCA. + + Args: + data_path (Path): path to the data to be used to load the data for client-side training + device (torch.device): Device indicator for where to send the model, batches, labels etc. Often 'cpu' or + 'cuda' + model_save_path (Path): Path to save the PCA components for use later, perhaps in dimensionality reduction """ self.client_name = self.generate_hash() self.model: PcaModule diff --git a/fl4health/clients/fedper_client.py b/fl4health/clients/fedper_client.py index fd56a5e29..f96ab5b1f 100644 --- a/fl4health/clients/fedper_client.py +++ b/fl4health/clients/fedper_client.py @@ -7,6 +7,14 @@ class FedPerClient(BasicClient): + """ + Client to implement the FedPer method (https://arxiv.org/abs/1912.00818). Trains a global feature extractor + shared by all clients through FedAvg and a private classifier that is unique to each client. The training is + nearly identical to the BasicClient with the exception that our parameter exchanger needs to be a fixed layer + exchanger that only exchanges the feature extraction base, which relies on the model being of + type SequentiallySplitExchangeBaseModel. + """ + def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: assert isinstance(self.model, SequentiallySplitExchangeBaseModel), ( "Models for FedPer must be of type SequentiallySplitExchangeBaseModel to facilitate partial weight " diff --git a/fl4health/clients/fedpm_client.py b/fl4health/clients/fedpm_client.py index c46220fef..b0c2deffe 100644 --- a/fl4health/clients/fedpm_client.py +++ b/fl4health/clients/fedpm_client.py @@ -22,16 +22,47 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: + """ + Client implementing the FedPM algorithm (https://arxiv.org/pdf/2209.15328). FedPM is a recent sparse, + communication efficient approach to federated learning. The method has been shown to have exceptional + information compression while maintaining good performance. Interestingly, it is also connected to the + Lottery Ticket Hypothesis. Training on the client-side is effectively the same as BasicClient. The two + components that change are ensuring that the model to be training is a Masked Model compatible with FedPM + (or to convert it to one). Second, we use the FedPM exchanger to facilitate exchange with the server. 
+ + Args: + data_path (Path): path to the data to be used to load the data for client-side training + metrics (Sequence[Metric]): Metrics to be computed based on the labels and predictions of the client model + device (torch.device): Device indicator for where to send the model, batches, labels etc. Often 'cpu' or + 'cuda' + loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over + each batch. Defaults to LossMeterType.AVERAGE. + checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + both checkpointing and state saving. The module, and its underlying model and state checkpointing + components will determine when and how to do checkpointing during client-side training. + No checkpointing (state or model) is done if not provided. Defaults to None. + reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health reporters which the client + should send data to. Defaults to None. + progress_bar (bool, optional): Whether or not to display a progress bar during client training and + validation. Uses tqdm. Defaults to False + client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + If not passed, a hash is randomly generated. Client state will use this as part of its state file + name. Defaults to None. + """ super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) def setup_client(self, config: Config) -> None: diff --git a/fl4health/clients/fedrep_client.py b/fl4health/clients/fedrep_client.py index 338314ed7..6ae9f1681 100644 --- a/fl4health/clients/fedrep_client.py +++ b/fl4health/clients/fedrep_client.py @@ -36,10 +36,47 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: - super().__init__(data_path, metrics, device, loss_meter_type, checkpointer, reporters) + """ + Client implementing the training of FedRep (https://arxiv.org/abs/2303.05206). + Similar to FedPer, FedRep trains a global feature extractor shared by all clients through FedAvg and a + private classifier that is unique to each client. However, FedRep breaks up the client-side training of + these components into two phases. First the local classifier is trained with the feature extractor frozen. + Next, the classifier is frozen and the feature extractor is trained. + + Args: + data_path (Path): path to the data to be used to load the data for client-side training + metrics (Sequence[Metric]): Metrics to be computed based on the labels and predictions of the client model + device (torch.device): Device indicator for where to send the model, batches, labels etc. Often 'cpu' or + 'cuda' + loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over + each batch. Defaults to LossMeterType.AVERAGE. + checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + both checkpointing and state saving. 
The module, and its underlying model and state checkpointing + components will determine when and how to do checkpointing during client-side training. + No checkpointing (state or model) is done if not provided. Defaults to None. + reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health reporters which the client + should send data to. Defaults to None. + progress_bar (bool, optional): Whether or not to display a progress bar during client training and + validation. Uses tqdm. Defaults to False + client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + If not passed, a hash is randomly generated. Client state will use this as part of its state file + name. Defaults to None. + """ + super().__init__( + data_path=data_path, + metrics=metrics, + device=device, + loss_meter_type=loss_meter_type, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, + ) self.fedrep_train_mode = FedRepTrainMode.HEAD def _prepare_train_representations(self) -> None: diff --git a/fl4health/clients/fenda_client.py b/fl4health/clients/fenda_client.py index 644aae513..8abaeb0a9 100644 --- a/fl4health/clients/fenda_client.py +++ b/fl4health/clients/fenda_client.py @@ -9,6 +9,7 @@ from fl4health.model_bases.fenda_base import FendaModel from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import Metric @@ -20,31 +21,47 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: - super().__init__( - data_path=data_path, - metrics=metrics, - device=device, - loss_meter_type=loss_meter_type, - checkpointer=checkpointer, - ) """ This client is used to perform client-side training associated with the FENDA method described in https://arxiv.org/pdf/2309.16825. The approach splits a model being trained into parallel feature extractors whose latent feature spaces are then further processed by a classification head. The global feature extractor is federally trained with FedAvg and the local feature extractor and classification head are exclusively - trained locally. This is closely related (and essentially an ablation of the PerFCL method). + trained locally. This is closely related (and is essentially an ablation of) the PerFCL method. + Args: - data_path (Path): Path to the data directory. - metrics (Sequence[Metric]): List of metrics to be used for evaluation. - device (torch.device): Device to be used for training. - loss_meter_type (LossMeterType, optional): Type of loss meter to be used. Defaults to - LossMeterType.AVERAGE. - checkpointer (Optional[ClientCheckpointModule], optional): Checkpointer module defining when and how to - do checkpointing during client-side training. No checkpointing is done if not provided. Defaults to - None. 
+ data_path (Path): path to the data to be used to load the data for client-side training + metrics (Sequence[Metric]): Metrics to be computed based on the labels and predictions of the client model + device (torch.device): Device indicator for where to send the model, batches, labels etc. Often 'cpu' or + 'cuda' + loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over + each batch. Defaults to LossMeterType.AVERAGE. + checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + both checkpointing and state saving. The module, and its underlying model and state checkpointing + components will determine when and how to do checkpointing during client-side training. + No checkpointing (state or model) is done if not provided. Defaults to None. + reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health reporters which the client + should send data to. Defaults to None. + progress_bar (bool, optional): Whether or not to display a progress bar during client training and + validation. Uses tqdm. Defaults to False + client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + If not passed, a hash is randomly generated. Client state will use this as part of its state file + name. Defaults to None. """ + super().__init__( + data_path=data_path, + metrics=metrics, + device=device, + loss_meter_type=loss_meter_type, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, + ) def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: assert isinstance(self.model, FendaModel) diff --git a/fl4health/clients/fenda_ditto_client.py b/fl4health/clients/fenda_ditto_client.py index 8775cc2ab..5df746ca9 100644 --- a/fl4health/clients/fenda_ditto_client.py +++ b/fl4health/clients/fenda_ditto_client.py @@ -25,9 +25,10 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, + client_name: Optional[str] = None, freeze_global_feature_extractor: bool = False, ) -> None: """ @@ -64,14 +65,17 @@ def __init__( 'cuda' loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE. - checkpointer (Optional[ClientCheckpointModule], optional): Checkpointer module defining when and how to - do checkpointing during client-side training. No checkpointing is done if not provided. Defaults to - None. - reporters (Sequence[BaseReporter], optional): A sequence of FL4Health - reporters which the client should send data to. - progress_bar (bool): Whether or not to display a progress bar during client training and validation. - Uses tqdm. Defaults to False - + checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + both checkpointing and state saving. The module, and its underlying model and state checkpointing + components will determine when and how to do checkpointing during client-side training. + No checkpointing (state or model) is done if not provided. Defaults to None. 
+ reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health reporters which the client + should send data to. Defaults to None. + progress_bar (bool, optional): Whether or not to display a progress bar during client training and + validation. Uses tqdm. Defaults to False + client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + If not passed, a hash is randomly generated. Client state will use this as part of its state file + name. Defaults to None. freeze_global_feature_extractor (bool, optional): Determines whether we freeze the FENDA global feature extractor during training. If freeze_global_feature_extractor is False, both the global and the local feature extractor in the local FENDA model will be trained. Otherwise, the global feature extractor @@ -84,9 +88,10 @@ def __init__( metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, reporters=reporters, progress_bar=progress_bar, + client_name=client_name, ) self.global_model: SequentiallySplitModel self.model: FendaModel diff --git a/fl4health/clients/flash_client.py b/fl4health/clients/flash_client.py index a5dda30f6..ef4a496c4 100644 --- a/fl4health/clients/flash_client.py +++ b/fl4health/clients/flash_client.py @@ -8,6 +8,7 @@ from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.basic_client import BasicClient +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.client import check_if_batch_is_empty_and_verify_input, move_data_to_device from fl4health.utils.config import narrow_dict_type from fl4health.utils.losses import LossMeterType @@ -21,7 +22,10 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: """ This client is used to perform client-side training associated with the Flash method described in @@ -30,20 +34,33 @@ def __init__( drift-aware adaptive optimization. Args: - data_path (Path): Path to the data directory. - metrics (Sequence[Metric]): List of metrics to be used for evaluation. - device (torch.device): Device to be used for training. - loss_meter_type (LossMeterType, optional): - Type of loss meter to be used. Defaults to LossMeterType.AVERAGE. - checkpointer (Optional[ClientCheckpointModule], optional): Checkpointer module defining when and how to do - checkpointing during client-side training. No checkpointing is done if not provided. Defaults to None. + data_path (Path): path to the data to be used to load the data for client-side training + metrics (Sequence[Metric]): Metrics to be computed based on the labels and predictions of the client model + device (torch.device): Device indicator for where to send the model, batches, labels etc. Often 'cpu' or + 'cuda' + loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over + each batch. Defaults to LossMeterType.AVERAGE. + checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + both checkpointing and state saving. 
The module, and its underlying model and state checkpointing + components will determine when and how to do checkpointing during client-side training. + No checkpointing (state or model) is done if not provided. Defaults to None. + reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health reporters which the client + should send data to. Defaults to None. + progress_bar (bool, optional): Whether or not to display a progress bar during client training and + validation. Uses tqdm. Defaults to False + client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + If not passed, a hash is randomly generated. Client state will use this as part of its state file + name. Defaults to None. """ super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) # gamma: Threshold for early stopping based on the change in validation loss. self.gamma: Optional[float] = None diff --git a/fl4health/clients/instance_level_dp_client.py b/fl4health/clients/instance_level_dp_client.py index 56fccd156..f903e66eb 100644 --- a/fl4health/clients/instance_level_dp_client.py +++ b/fl4health/clients/instance_level_dp_client.py @@ -15,26 +15,48 @@ class InstanceLevelDpClient(BasicClient): - """ - Client for Instance/Record level Differentially Private Federated Averaging - """ - def __init__( self, data_path: Path, metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: + """ + Client for Instance/Record level Differentially Private Federated Averaging + + Args: + data_path (Path): path to the data to be used to load the data for client-side training + metrics (Sequence[Metric]): Metrics to be computed based on the labels and predictions of the client model + device (torch.device): Device indicator for where to send the model, batches, labels etc. Often 'cpu' or + 'cuda' + loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over + each batch. Defaults to LossMeterType.AVERAGE. + checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + both checkpointing and state saving. The module, and its underlying model and state checkpointing + components will determine when and how to do checkpointing during client-side training. + No checkpointing (state or model) is done if not provided. Defaults to None. + reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health reporters which the client + should send data to. Defaults to None. + progress_bar (bool, optional): Whether or not to display a progress bar during client training and + validation. Uses tqdm. Defaults to False + client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + If not passed, a hash is randomly generated. Client state will use this as part of its state file + name. Defaults to None. 
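Because nearly every client in this patch now accepts checkpoint_and_state_module, a sketch of assembling one may be useful. The keyword names below (pre_aggregation, post_aggregation, state_checkpointer) and the checkpointer variables are assumptions about the refactored ClientCheckpointAndStateModule and are not taken verbatim from this diff.

# Illustrative only: best_loss_checkpointer and per_round_state_checkpointer are assumed
# to be checkpointer instances constructed elsewhere from fl4health.checkpointing.checkpointer,
# and the keyword names are assumptions about the refactored module's signature.
from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule

checkpoint_and_state_module = ClientCheckpointAndStateModule(
    pre_aggregation=best_loss_checkpointer,           # model checkpointing before aggregation
    post_aggregation=None,                            # skip post-aggregation model checkpointing
    state_checkpointer=per_round_state_checkpointer,  # per-round state for resuming clients
)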
+ """ super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) self.clipping_bound: float self.noise_multiplier: float diff --git a/fl4health/clients/mkmmd_clients/ditto_mkmmd_client.py b/fl4health/clients/mkmmd_clients/ditto_mkmmd_client.py index 327e46482..9fa724c91 100644 --- a/fl4health/clients/mkmmd_clients/ditto_mkmmd_client.py +++ b/fl4health/clients/mkmmd_clients/ditto_mkmmd_client.py @@ -11,6 +11,7 @@ from fl4health.clients.ditto_client import DittoClient from fl4health.losses.mkmmd_loss import MkMmdLoss from fl4health.model_bases.feature_extractor_buffer import FeatureExtractorBuffer +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.client import clone_and_freeze_model from fl4health.utils.losses import LossMeterType, TrainingLosses from fl4health.utils.metrics import Metric @@ -24,7 +25,10 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, mkmmd_loss_weight: float = 10.0, feature_extraction_layers: Optional[Sequence[str]] = None, feature_l2_norm_weight: float = 0.0, @@ -43,9 +47,17 @@ def __init__( 'cuda' loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE. - checkpointer (Optional[ClientCheckpointModule], optional): Checkpointer module defining when and how to - do checkpointing during client-side training. No checkpointing is done if not provided. Defaults to - None. + checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + both checkpointing and state saving. The module, and its underlying model and state checkpointing + components will determine when and how to do checkpointing during client-side training. + No checkpointing (state or model) is done if not provided. Defaults to None. + reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health reporters which the client + should send data to. Defaults to None. + progress_bar (bool, optional): Whether or not to display a progress bar during client training and + validation. Uses tqdm. Defaults to False + client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + If not passed, a hash is randomly generated. Client state will use this as part of its state file + name. Defaults to None. mkmmd_loss_weight (float, optional): weight applied to the MK-MMD loss. Defaults to 10.0. feature_extraction_layers (Optional[Sequence[str]], optional): List of layers from which to extract and flatten features. Defaults to None. 
@@ -64,7 +76,10 @@ def __init__( metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) self.mkmmd_loss_weight = mkmmd_loss_weight if self.mkmmd_loss_weight == 0: diff --git a/fl4health/clients/mkmmd_clients/mr_mtl_mkmmd_client.py b/fl4health/clients/mkmmd_clients/mr_mtl_mkmmd_client.py index ca3e691f3..642f6752c 100644 --- a/fl4health/clients/mkmmd_clients/mr_mtl_mkmmd_client.py +++ b/fl4health/clients/mkmmd_clients/mr_mtl_mkmmd_client.py @@ -10,6 +10,7 @@ from fl4health.clients.mr_mtl_client import MrMtlClient from fl4health.losses.mkmmd_loss import MkMmdLoss from fl4health.model_bases.feature_extractor_buffer import FeatureExtractorBuffer +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType, TrainingLosses from fl4health.utils.metrics import Metric from fl4health.utils.typing import TorchFeatureType, TorchInputType, TorchPredType, TorchTargetType @@ -22,7 +23,10 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, mkmmd_loss_weight: float = 10.0, feature_extraction_layers: Optional[Sequence[str]] = None, feature_l2_norm_weight: float = 0.0, @@ -42,9 +46,17 @@ def __init__( 'cuda' loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE. - checkpointer (Optional[ClientCheckpointModule], optional): Checkpointer module defining when and how to - do checkpointing during client-side training. No checkpointing is done if not provided. Defaults to - None. + checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + both checkpointing and state saving. The module, and its underlying model and state checkpointing + components will determine when and how to do checkpointing during client-side training. + No checkpointing (state or model) is done if not provided. Defaults to None. + reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health reporters which the client + should send data to. Defaults to None. + progress_bar (bool, optional): Whether or not to display a progress bar during client training and + validation. Uses tqdm. Defaults to False + client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + If not passed, a hash is randomly generated. Client state will use this as part of its state file + name. Defaults to None. mkmmd_loss_weight (float, optional): weight applied to the MK-MMD loss. Defaults to 10.0. feature_extraction_layers (Optional[Sequence[str]], optional): List of layers from which to extract and flatten features. Defaults to None. 
@@ -63,7 +75,10 @@ def __init__( metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) self.mkmmd_loss_weight = mkmmd_loss_weight if self.mkmmd_loss_weight == 0: diff --git a/fl4health/clients/model_merge_client.py b/fl4health/clients/model_merge_client.py index f6778ff99..6bed66108 100644 --- a/fl4health/clients/model_merge_client.py +++ b/fl4health/clients/model_merge_client.py @@ -64,7 +64,7 @@ def __init__( def setup_client(self, config: Config) -> None: """ Sets up Merge Client by initializing model, dataloader and parameter exchanger - with user defined methods. Subsquently, sets initialized attribute to True. + with user defined methods. Subsequently, sets initialized attribute to True. Args: config (Config): The configuration from the server. @@ -223,7 +223,7 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: Parameter exchange is assumed to always be full for model merging clients. However, this functionality may be overridden if a different exchanger is needed. - Used in non-standard way for ModelMergClient as set_parameters is only called for evaluate as + Used in non-standard way for ModelMergeClient as set_parameters is only called for evaluate as parameters should initially be set to the parameters in the nn.Module returned by get_model. Args: diff --git a/fl4health/clients/moon_client.py b/fl4health/clients/moon_client.py index e6da07ec8..7acb00b02 100644 --- a/fl4health/clients/moon_client.py +++ b/fl4health/clients/moon_client.py @@ -9,6 +9,7 @@ from fl4health.clients.basic_client import BasicClient, Config from fl4health.losses.contrastive_loss import MoonContrastiveLoss from fl4health.model_bases.sequential_split_models import SequentiallySplitModel +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.client import clone_and_freeze_model from fl4health.utils.losses import EvaluationLosses, LossMeterType, TrainingLosses from fl4health.utils.metrics import Metric @@ -22,7 +23,10 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, temperature: float = 0.5, contrastive_weight: float = 1.0, len_old_models_buffer: int = 1, @@ -33,14 +37,23 @@ def __init__( loss to constrain the local training of individual parties in the non-IID setting. Args: - data_path (Path): Path to the data directory. - metrics (Sequence[Metric]): List of metrics to be used for evaluation. - device (torch.device): Device to be used for training. - loss_meter_type (LossMeterType, optional): Type of loss meter to be used. Defaults to - LossMeterType.AVERAGE. - checkpointer (Optional[ClientCheckpointModule], optional): Checkpointer module defining when and how to - do checkpointing during client-side training. No checkpointing is done if not provided. Defaults to - None. 
+ data_path (Path): path to the data to be used to load the data for client-side training + metrics (Sequence[Metric]): Metrics to be computed based on the labels and predictions of the client model + device (torch.device): Device indicator for where to send the model, batches, labels etc. Often 'cpu' or + 'cuda' + loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over + each batch. Defaults to LossMeterType.AVERAGE. + checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + both checkpointing and state saving. The module, and its underlying model and state checkpointing + components will determine when and how to do checkpointing during client-side training. + No checkpointing (state or model) is done if not provided. Defaults to None. + reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health reporters which the client + should send data to. Defaults to None. + progress_bar (bool, optional): Whether or not to display a progress bar during client training and + validation. Uses tqdm. Defaults to False + client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + If not passed, a hash is randomly generated. Client state will use this as part of its state file + name. Defaults to None. temperature (float, optional): Temperature used in the calculation of the contrastive loss. Defaults to 0.5. contrastive_weight (float, optional): Weight placed on the contrastive loss function. Referred to as mu @@ -53,7 +66,10 @@ def __init__( metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) self.temperature = temperature self.contrastive_weight = contrastive_weight diff --git a/fl4health/clients/mr_mtl_client.py b/fl4health/clients/mr_mtl_client.py index 3b2894d19..cd671771c 100644 --- a/fl4health/clients/mr_mtl_client.py +++ b/fl4health/clients/mr_mtl_client.py @@ -22,9 +22,10 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: """ This client implements the MR-MTL algorithm from MR-MTL: On Privacy and Personalization in Cross-Silo @@ -45,22 +46,27 @@ def __init__( 'cuda' loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE. - checkpointer (Optional[ClientCheckpointModule], optional): Checkpointer module defining when and how to - do checkpointing during client-side training. No checkpointing is done if not provided. Defaults to - None. - reporters (Sequence[BaseReporter], optional): A sequence of FL4Health - reporters which the client should send data to. - progress_bar (bool): Whether or not to display a progress bar during client training and validation. - Uses tqdm. Defaults to False + checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + both checkpointing and state saving. 
The module, and its underlying model and state checkpointing + components will determine when and how to do checkpointing during client-side training. + No checkpointing (state or model) is done if not provided. Defaults to None. + reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health reporters which the client + should send data to. Defaults to None. + progress_bar (bool, optional): Whether or not to display a progress bar during client training and + validation. Uses tqdm. Defaults to False + client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + If not passed, a hash is randomly generated. Client state will use this as part of its state file + name. Defaults to None. """ super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, reporters=reporters, progress_bar=progress_bar, + client_name=client_name, ) # NOTE: The initial global model is used to house the aggregate weight updates at the beginning of a round, # because in MR-MTL, the local models are not updated with these aggregates. diff --git a/fl4health/clients/nnunet_client.py b/fl4health/clients/nnunet_client.py index 131e8af7a..755eb7b98 100644 --- a/fl4health/clients/nnunet_client.py +++ b/fl4health/clients/nnunet_client.py @@ -73,93 +73,75 @@ def __init__( verbose: bool = True, metrics: Optional[Sequence[Metric]] = None, progress_bar: bool = False, - intermediate_client_state_dir: Optional[Path] = None, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, reporters: Sequence[BaseReporter] | None = None, client_name: Optional[str] = None, nnunet_trainer_class: Type[nnUNetTrainer] = nnUNetTrainer, nnunet_trainer_class_kwargs: Optional[dict[str, Any]] = {}, ) -> None: """ - A client for training nnunet models. Requires the nnunet environment variables - to be set. Also requires the following additional keys in the config sent from - the server + A client for training nnunet models. Requires the nnunet environment variables to be set. Also requires the + following additional keys in the config sent from the server: 'nnunet_plans': (serialized dict) 'nnunet_config': (str) Args: - device (torch.device): Device indicator for where to send the - model, batches, labels etc. Often 'cpu' or 'cuda' or 'mps' - dataset_id (int): The nnunet dataset id for the local client dataset - to use for training and validation. - fold (Union[int, str]): Which fold of the local client dataset to - use for validation. nnunet defaults to 5 folds (0 to 4). Can - also be set to 'all' to use all the data for both training - and validation. - data_identifier (Optional[str], optional): The nnunet data - identifier prefix to use. The final data identifier will be - {data_identifier}_config where 'config' is the nnunet config - (eg. 2d, 3d_fullres, etc.). If preprocessed data already exists - can be used to specify which preprocessed data to use. The - default data_identifier prefix is the plans name used during - training (see the plans_identifier argument). - plans_identifier (Optional[str], optional): Specify what to save - the client's local copy of the plans file as. The client - modifies the source plans json file sent from the server and - makes a local copy. 
If left as default None, the plans - identifier will be set as 'FL_Dataset000_plansname' where 000 - is the dataset_id and plansname is the 'plans_name' value of - the source plans file. - compile (bool, optional): If True, the client will jit compile the pytorch - model. This requires some overhead time at the beginning of training to - compile the model, but results in faster training times. Defaults to - True - always_preprocess (bool, optional): If True, will preprocess the - local client dataset even if the preprocessed data already - seems to exist. Defaults to False. The existence of the - preprocessed data is determined by matching the provided - data_identifier with that of the preprocessed data already on + device (torch.device): Device indicator for where to send the model, batches, labels etc. Often 'cpu' or + 'cuda' or 'mps' + dataset_id (int): The nnunet dataset id for the local client dataset to use for training and validation. + fold (Union[int, str]): Which fold of the local client dataset to use for validation. nnunet defaults to + 5 folds (0 to 4). Can also be set to 'all' to use all the data for both training and validation. + data_identifier (Optional[str], optional): The nnunet data identifier prefix to use. The final data + identifier will be {data_identifier}_config where 'config' is the nnunet config (eg. 2d, 3d_fullres, + etc.). If preprocessed data already exists can be used to specify which preprocessed data to use. + The default data_identifier prefix is the plans name used during training (see the plans_identifier + argument). + plans_identifier (Optional[str], optional): Specify what to save the client's local copy of the plans file + as. The client modifies the source plans json file sent from the server and makes a local copy. + If left as default None, the plans identifier will be set as 'FL_Dataset000_plansname' where 000 is + the dataset_id and plansname is the 'plans_name' value of the source plans file. + compile (bool, optional): If True, the client will jit compile the pytorch model. This requires some + overhead time at the beginning of training to compile the model, but results in faster training times. + Defaults to True + always_preprocess (bool, optional): If True, will preprocess the local client dataset even if the + preprocessed data already seems to exist. Defaults to False. The existence of the preprocessed data + is determined by matching the provided data_identifier with that of the preprocessed data already on the client. - max_grad_norm (float, optional): The maximum gradient norm to use for - gradient clipping. Defaults to 12 which is the nnunetv2 2.5.1 default. - n_dataload_processes (Optional[int], optional): The number of processes to - spawn for each nnunet dataloader. If left as None we use the nnunetv2 - version 2.5.1 defaults for each config - verbose (bool, optional): If True the client will log some extra INFO logs. - Defaults to False unless the log level is DEBUG or lower. - metrics (Sequence[Metric], optional): Metrics to be computed based - on the labels and predictions of the client model. Defaults to []. - progress_bar (bool, optional): Whether or not to print a progress bar to - stdout for training. Defaults to False - intermediate_client_state_dir (Optional[Path]): An optional path to store per round - checkpoints. - loss_meter_type (LossMeterType, optional): Type of meter used to - track and compute the losses over each batch. Defaults to - LossMeterType.AVERAGE. 
- checkpointer (Optional[ClientCheckpointModule], optional): - Checkpointer module defining when and how to do checkpointing - during client-side training. No checkpointing is done if not - provided. Defaults to None. - reporters (Sequence[BaseReporter], optional): A sequence of FL4Health - reporters which the client should send data to. - nnunet_trainer_class (Type[nnUNetTrainer]): A nnUNetTrainer constructor. - Useful for passing custom nnUNetTrainer. Defaults to the standard nnUNetTrainer class. - Must match the nnunet_trainer_class passed to the NnunetServer. - nnunet_trainer_class_kwargs (dict[str, Any]): Additonal kwargs to pass to nnunet_trainer_class. - Defaults to empty dictionary. + max_grad_norm (float, optional): The maximum gradient norm to use for gradient clipping. Defaults to 12 + which is the nnunetv2 2.5.1 default. + n_dataload_processes (Optional[int], optional): The number of processes to spawn for each nnunet + dataloader. If left as None we use the nnunetv2 version 2.5.1 defaults for each config + verbose (bool, optional): If True the client will log some extra INFO logs. Defaults to False unless + the log level is DEBUG or lower. + metrics (Sequence[Metric], optional): Metrics to be computed based on the labels and predictions of the + client model. Defaults to None. + progress_bar (bool, optional): Whether or not to print a progress bar to stdout for training. Defaults + to False + loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over each + batch. Defaults to LossMeterType.AVERAGE. + checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + both checkpointing and state saving. The module, and its underlying model and state checkpointing + components will determine when and how to do checkpointing during client-side training. + No checkpointing (state or model) is done if not provided. Defaults to None. + reporters (Sequence[BaseReporter], optional): A sequence of FL4Health reporters which the client should + send data to. + nnunet_trainer_class (Type[nnUNetTrainer]): A nnUNetTrainer constructor. Useful for passing custom + nnUNetTrainer. Defaults to the standard nnUNetTrainer class. Must match the nnunet_trainer_class + passed to the NnunetServer. + nnunet_trainer_class_kwargs (dict[str, Any]): Additional kwargs to pass to nnunet_trainer_class. Defaults + to empty dictionary. 
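As a quick illustration of the two extra keys the docstring above says the server config must carry, a hedged sketch follows; the json serialization and the stand-in plans dictionary are assumptions for illustration only.

# Illustrative only: the two additional config entries NnunetClient expects from the server.
# json is just one plausible way to produce a "serialized dict"; real plans come from nnunet
# preprocessing rather than the stand-in below.
import json

plans_dict = {"plans_name": "nnUNetPlans"}   # stand-in for the real nnunet plans dictionary

extra_nnunet_config = {
    "nnunet_config": "2d",                   # nnunet configuration, eg. 2d or 3d_fullres
    "nnunet_plans": json.dumps(plans_dict),  # serialized plans dict
}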
""" metrics = metrics if metrics else [] # Parent method sets up several class attributes super().__init__( data_path=Path("dummy/path"), # data_path not used by NnunetClient - metrics=metrics, # self.metrics - device=device, # self.device + metrics=metrics, + device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, # self.checkpointer + checkpoint_and_state_module=checkpoint_and_state_module, reporters=reporters, progress_bar=progress_bar, - intermediate_client_state_dir=intermediate_client_state_dir, client_name=client_name, ) diff --git a/fl4health/clients/partial_weight_exchange_client.py b/fl4health/clients/partial_weight_exchange_client.py index f52233726..1cdd45de7 100644 --- a/fl4health/clients/partial_weight_exchange_client.py +++ b/fl4health/clients/partial_weight_exchange_client.py @@ -25,8 +25,10 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, store_initial_model: bool = False, ) -> None: """ @@ -42,11 +44,17 @@ def __init__( 'cuda' loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE. - checkpointer (Optional[ClientCheckpointModule], optional): Checkpointer module defining when and how to - do checkpointing during client-side training. No checkpointing is done if not provided. Defaults to - None. - reporters (Sequence[BaseReporter], optional): A sequence of FL4Health - reporters which the client should send data to. + checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + both checkpointing and state saving. The module, and its underlying model and state checkpointing + components will determine when and how to do checkpointing during client-side training. + No checkpointing (state or model) is done if not provided. Defaults to None. + reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health reporters which the client + should send data to. Defaults to None. + progress_bar (bool, optional): Whether or not to display a progress bar during client training and + validation. Uses tqdm. Defaults to False + client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + If not passed, a hash is randomly generated. Client state will use this as part of its state file + name. Defaults to None. store_initial_model (bool): Indicates whether the client should store a copy of the model weights at the beginning of each training round. The model copy might be required to select the subset of model parameters to be exchanged with the server, depending on the selection criterion used. @@ -57,8 +65,10 @@ def __init__( metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) # Initial model parameters to be used in selecting parameters to be exchanged during training. 
self.initial_model: Optional[nn.Module] diff --git a/fl4health/clients/perfcl_client.py b/fl4health/clients/perfcl_client.py index e6ffdd950..695e67e38 100644 --- a/fl4health/clients/perfcl_client.py +++ b/fl4health/clients/perfcl_client.py @@ -10,6 +10,7 @@ from fl4health.model_bases.perfcl_base import PerFclModel from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.client import clone_and_freeze_model from fl4health.utils.losses import EvaluationLosses, LossMeterType from fl4health.utils.metrics import Metric @@ -23,7 +24,10 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, global_feature_loss_temperature: float = 0.5, local_feature_loss_temperature: float = 0.5, global_feature_contrastive_loss_weight: float = 1.0, @@ -37,14 +41,23 @@ def __init__( related to FENDA, but with additional losses on the latent spaces of the local and global feature extractors. Args: - data_path (Path): Path to the data directory. - metrics (Sequence[Metric]): List of metrics to be used for evaluation. - device (torch.device): Device to be used for training. - loss_meter_type (LossMeterType, optional): Type of loss meter to be used. Defaults to - LossMeterType.AVERAGE. - checkpointer (Optional[ClientCheckpointModule], optional): Checkpointer module defining when and how to - do checkpointing during client-side training. No checkpointing is done if not provided. Defaults to - None. + data_path (Path): path to the data to be used to load the data for client-side training + metrics (Sequence[Metric]): Metrics to be computed based on the labels and predictions of the client model + device (torch.device): Device indicator for where to send the model, batches, labels etc. Often 'cpu' or + 'cuda' + loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over + each batch. Defaults to LossMeterType.AVERAGE. + checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + both checkpointing and state saving. The module, and its underlying model and state checkpointing + components will determine when and how to do checkpointing during client-side training. + No checkpointing (state or model) is done if not provided. Defaults to None. + reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health reporters which the client + should send data to. Defaults to None. + progress_bar (bool, optional): Whether or not to display a progress bar during client training and + validation. Uses tqdm. Defaults to False + client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + If not passed, a hash is randomly generated. Client state will use this as part of its state file + name. Defaults to None. global_feature_loss_temperature (float, optional): Temperature to be used in the contrastive loss associated with constraining the global feature extractor in the PerFCL loss. Defaults to 0.5. 
local_feature_loss_temperature (float, optional): Temperature to be used in the contrastive loss @@ -59,7 +72,10 @@ def __init__( metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) self.global_feature_contrastive_loss_weight = global_feature_contrastive_loss_weight self.local_feature_contrastive_loss_weight = local_feature_contrastive_loss_weight diff --git a/fl4health/clients/scaffold_client.py b/fl4health/clients/scaffold_client.py index b76d67a15..e159b2e44 100644 --- a/fl4health/clients/scaffold_client.py +++ b/fl4health/clients/scaffold_client.py @@ -23,28 +23,50 @@ class ScaffoldClient(BasicClient): - """ - Federated Learning Client for Scaffold strategy. - - Implementation based on https://arxiv.org/pdf/1910.06378.pdf. - """ - def __init__( self, data_path: Path, metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: + """ + Federated Learning Client for Scaffold strategy. + + Implementation based on https://arxiv.org/pdf/1910.06378.pdf. + + Args: + data_path (Path): path to the data to be used to load the data for client-side training + metrics (Sequence[Metric]): Metrics to be computed based on the labels and predictions of the client model + device (torch.device): Device indicator for where to send the model, batches, labels etc. Often 'cpu' or + 'cuda' + loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over + each batch. Defaults to LossMeterType.AVERAGE. + checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + both checkpointing and state saving. The module, and its underlying model and state checkpointing + components will determine when and how to do checkpointing during client-side training. + No checkpointing (state or model) is done if not provided. Defaults to None. + reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health reporters which the client + should send data to. Defaults to None. + progress_bar (bool, optional): Whether or not to display a progress bar during client training and + validation. Uses tqdm. Defaults to False + client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + If not passed, a hash is randomly generated. Client state will use this as part of its state file + name. Defaults to None. 
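For readers unfamiliar with SCAFFOLD, the drift correction that the client's learning rate (eta_l in the paper) and client control variates (c_i) implement can be summarised with a toy tensor computation; the values are made up and this is not the client's actual training loop.

# Toy illustration of the SCAFFOLD update y_i <- y_i - eta_l * (g_i - c_i + c) from
# https://arxiv.org/pdf/1910.06378.pdf; shapes and values are invented for demonstration.
import torch

eta_l = 0.1                          # local learning rate (eta_l in the paper)
y_i = torch.zeros(3)                 # local model parameters
g_i = torch.tensor([1.0, 2.0, 3.0])  # local gradient evaluated at y_i
c_i = torch.tensor([0.5, 0.5, 0.5])  # client control variate (c_i in the paper)
c = torch.tensor([0.2, 0.2, 0.2])    # server control variate (c in the paper)
y_i = y_i - eta_l * (g_i - c_i + c)  # drift-corrected local step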
+ """ super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) self.learning_rate: float # eta_l in paper self.client_control_variates: Optional[NDArrays] = None # c_i in paper diff --git a/fl4health/clients/tabular_data_client.py b/fl4health/clients/tabular_data_client.py index 59938cc87..c74617414 100644 --- a/fl4health/clients/tabular_data_client.py +++ b/fl4health/clients/tabular_data_client.py @@ -5,14 +5,17 @@ import pandas as pd import torch from flwr.common.logger import log -from flwr.common.typing import Config, NDArray, Scalar +from flwr.common.typing import Config, NDArray, Optional, Scalar from sklearn.pipeline import Pipeline +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.basic_client import BasicClient from fl4health.feature_alignment.constants import FEATURE_INFO, INPUT_DIMENSION, OUTPUT_DIMENSION, SOURCE_SPECIFIED from fl4health.feature_alignment.tab_features_info_encoder import TabularFeaturesInfoEncoder from fl4health.feature_alignment.tab_features_preprocessor import TabularFeaturesPreprocessor +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.config import narrow_dict_type +from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import Metric @@ -24,8 +27,49 @@ def __init__( device: torch.device, id_column: str, targets: Union[str, List[str]], + loss_meter_type: LossMeterType = LossMeterType.AVERAGE, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: - super().__init__(data_path, metrics, device) + """ + Client to facilitate federated feature space alignment, specifically for tabular data, and then perform + federated training. + + Args: + data_path (Path): path to the data to be used to load the data for client-side training + metrics (Sequence[Metric]): Metrics to be computed based on the labels and predictions of the client model + device (torch.device): Device indicator for where to send the model, batches, labels etc. Often 'cpu' or + 'cuda' + id_column (str): ID column. This is required for tabular encoding in cyclops, which we leverage. It should + be unique per row, but need not necessarily be a meaningful identifier (i.e. could be row number) + targets (Union[str, List[str]]): The target column or columns name. This allows for multiple targets to + be specified if desired. + loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over + each batch. Defaults to LossMeterType.AVERAGE. + checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + both checkpointing and state saving. The module, and its underlying model and state checkpointing + components will determine when and how to do checkpointing during client-side training. + No checkpointing (state or model) is done if not provided. Defaults to None. + reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health reporters which the client + should send data to. Defaults to None. + progress_bar (bool, optional): Whether or not to display a progress bar during client training and + validation. Uses tqdm. 
Defaults to False + client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + If not passed, a hash is randomly generated. Client state will use this as part of its state file + name. Defaults to None. + """ + super().__init__( + data_path=data_path, + metrics=metrics, + device=device, + loss_meter_type=loss_meter_type, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, + ) self.tabular_features_info_encoder: TabularFeaturesInfoEncoder self.tabular_features_preprocessor: TabularFeaturesPreprocessor self.df: pd.DataFrame diff --git a/fl4health/parameter_exchange/parameter_exchanger_base.py b/fl4health/parameter_exchange/parameter_exchanger_base.py index 738b63c12..8a16db996 100644 --- a/fl4health/parameter_exchange/parameter_exchanger_base.py +++ b/fl4health/parameter_exchange/parameter_exchanger_base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Optional +from typing import Optional, TypeVar import torch.nn as nn from flwr.common.typing import Config, NDArrays @@ -15,3 +15,6 @@ def push_parameters( @abstractmethod def pull_parameters(self, parameters: NDArrays, model: nn.Module, config: Optional[Config] = None) -> None: raise NotImplementedError + + +ExchangerType = TypeVar("ExchangerType", bound=ParameterExchanger) diff --git a/fl4health/servers/adaptive_constraint_servers/ditto_server.py b/fl4health/servers/adaptive_constraint_servers/ditto_server.py index 21b568b18..c0a860ed1 100644 --- a/fl4health/servers/adaptive_constraint_servers/ditto_server.py +++ b/fl4health/servers/adaptive_constraint_servers/ditto_server.py @@ -1,9 +1,9 @@ -from typing import Sequence +from typing import Callable, Dict, Sequence -from flwr.common.typing import Config +from flwr.common.typing import Config, Scalar from flwr.server.client_manager import ClientManager -from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule +from fl4health.checkpointing.server_module import AdaptiveConstraintServerCheckpointAndStateModule from fl4health.reporting.base_reporter import BaseReporter from fl4health.servers.base_server import FlServer from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint @@ -15,8 +15,11 @@ def __init__( client_manager: ClientManager, fl_config: Config, strategy: FedAvgWithAdaptiveConstraint, - checkpoint_and_state_module: BaseServerCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, + checkpoint_and_state_module: AdaptiveConstraintServerCheckpointAndStateModule | None = None, + on_init_parameters_config_fn: Callable[[int], Dict[str, Scalar]] | None = None, + server_name: str | None = None, + accept_failures: bool = True, ) -> None: """ This is a very basic wrapper class over the FlServer to ensure that the strategy used for Ditto is of type @@ -30,8 +33,10 @@ def __init__( example, the config used to produce the on_fit_config_fn and on_evaluate_config_fn for the strategy. NOTE: This config is DISTINCT from the Flwr server config, which is extremely minimal. strategy (FedAvgWithAdaptiveConstraint): The aggregation strategy to be used by the server to handle. - client updates and other information potentially sent by the participating clients. For Ditto, the + client updates and other information potentially sent by the participating clients. For MR-MTL, the strategy must be a derivative of the FedAvgWithAdaptiveConstraint class. 
+ reporters (Sequence[BaseReporter], optional): A sequence of FL4Health reporters which the server should + send data to before and after each round. Defaults to None. checkpoint_and_state_module (BaseServerCheckpointAndStateModule | None, optional): This module is used to handle both model checkpointing and state checkpointing. The former is aimed at saving model artifacts to be used or evaluated after training. The later is used to preserve training state @@ -39,8 +44,15 @@ def __init__( module is provided, no checkpointing or state preservation will happen. Defaults to None. NOTE: For Ditto, the model shared with the server is the GLOBAL MODEL, which isn't the target of FL training for this algorithm. However, one may still want to save this model for other purposes. - reporters (Sequence[BaseReporter], optional): A sequence of FL4Health - reporters which the server should send data to before and after each round. + on_init_parameters_config_fn (Callable[[int], Dict[str, Scalar]] | None, optional): Function used to + configure how one asks a client to provide parameters from which to initialize all other clients by + providing a Config dictionary. If this is none, then a blank config is sent with the parameter request + (which is default behavior for flower servers). Defaults to None. + server_name (str | None, optional): An optional string name to uniquely identify server. This name is also + used as part of any state checkpointing done by the server. Defaults to None. + accept_failures (bool, optional): Determines whether the server should accept failures during training or + evaluation from clients or not. If set to False, this will cause the server to shutdown all clients + and throw an exception. Defaults to True. """ assert isinstance( strategy, FedAvgWithAdaptiveConstraint @@ -49,6 +61,9 @@ def __init__( client_manager=client_manager, fl_config=fl_config, strategy=strategy, - checkpoint_and_state_module=checkpoint_and_state_module, reporters=reporters, + checkpoint_and_state_module=checkpoint_and_state_module, + on_init_parameters_config_fn=on_init_parameters_config_fn, + server_name=server_name, + accept_failures=accept_failures, ) diff --git a/fl4health/servers/adaptive_constraint_servers/fedprox_server.py b/fl4health/servers/adaptive_constraint_servers/fedprox_server.py index 08bf62cc3..20e7bc67c 100644 --- a/fl4health/servers/adaptive_constraint_servers/fedprox_server.py +++ b/fl4health/servers/adaptive_constraint_servers/fedprox_server.py @@ -1,13 +1,10 @@ -from typing import Optional, Sequence, Union +from typing import Callable, Dict, Sequence -import torch.nn as nn -from flwr.common.parameter import parameters_to_ndarrays -from flwr.common.typing import Config +from flwr.common.typing import Config, Scalar from flwr.server.client_manager import ClientManager -from fl4health.checkpointing.checkpointer import TorchModuleCheckpointer +from fl4health.checkpointing.server_module import AdaptiveConstraintServerCheckpointAndStateModule from fl4health.parameter_exchange.packing_exchanger import FullParameterExchangerWithPacking -from fl4health.parameter_exchange.parameter_packer import ParameterPackerAdaptiveConstraint from fl4health.reporting.base_reporter import BaseReporter from fl4health.servers.base_server import FlServer from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint @@ -19,14 +16,16 @@ def __init__( client_manager: ClientManager, fl_config: Config, strategy: FedAvgWithAdaptiveConstraint, - model: Optional[nn.Module] = None, - 
checkpointer: Optional[Union[TorchModuleCheckpointer, Sequence[TorchModuleCheckpointer]]] = None,
         reporters: Sequence[BaseReporter] | None = None,
+        checkpoint_and_state_module: AdaptiveConstraintServerCheckpointAndStateModule | None = None,
+        on_init_parameters_config_fn: Callable[[int], Dict[str, Scalar]] | None = None,
+        server_name: str | None = None,
+        accept_failures: bool = True,
     ) -> None:
         """
         This is a wrapper class around FlServer for using the FedProx method that enforces that the
-        parameter exchanger is a FullParameterExchangerWithPacking of the right type for model rehydration and that
-        the strategy is of type FedAvgWithAdaptiveConstraint.
+        strategy is of type FedAvgWithAdaptiveConstraint and that any checkpointing is done with the right server-side
+        model and state checkpointers.

         Args:
             client_manager (ClientManager): Determines the mechanism by which clients are sampled by the server, if
@@ -35,49 +34,40 @@
             In most cases it should be the "source of truth" for how FL training/evaluation should proceed. For
             example, the config used to produce the on_fit_config_fn and on_evaluate_config_fn for the strategy.
             NOTE: This config is DISTINCT from the Flwr server config, which is extremely minimal.
-            strategy (FedAvgWithAdaptiveConstraint): The aggregation strategy to be used
-                by the server to handle. client updates and other information
-                potentially sent by the participating clients. For FedProx, the strategy
-                must be a derivative of the FedAvgWithAdaptiveConstraint class.
-            model (Optional[nn.Module], optional): This is the torch model to be
-                hydrated by the _hydrate_model_for_checkpointing function, Defaults to
-                None
-            checkpointer (Optional[Union[TorchCheckpointer, Sequence[TorchCheckpointer]]], optional): To be provided
-                if the server should perform server side checkpointing based on some
-                criteria. If none, then no server-side checkpointing is performed.
-                Multiple checkpointers can also be passed in a sequence to checkpoint
-                based on multiple criteria. Defaults to None.
-            reporters (Sequence[BaseReporter], optional): A sequence of FL4Health reporters which the server should
-                send data to before and after each round.
+            strategy (FedAvgWithAdaptiveConstraint): The aggregation strategy to be used by the server to handle
+                client updates and other information potentially sent by the participating clients. This is required
+                to be of type FedAvgWithAdaptiveConstraint to use FedProx.
+            reporters (Sequence[BaseReporter] | None, optional): sequence of FL4Health reporters which the server
+                should send data to before and after each round. Defaults to None.
+            checkpoint_and_state_module (AdaptiveConstraintServerCheckpointAndStateModule | None, optional): This
+                module is used to handle both model checkpointing and state checkpointing. The former is aimed at
+                saving model artifacts to be used or evaluated after training. The latter is used to preserve training
+                state (including models) such that if FL training is interrupted, the process may be restarted. If no
+                module is provided, no checkpointing or state preservation will happen. Defaults to None.
+            on_init_parameters_config_fn (Callable[[int], Dict[str, Scalar]] | None, optional): Function used to
+                configure how one asks a client to provide parameters from which to initialize all other clients by
+                providing a Config dictionary. If this is none, then a blank config is sent with the parameter request
+                (which is default behavior for flower servers). Defaults to None.
+            server_name (str | None, optional): An optional string name to uniquely identify server. This name is also
+                used as part of any state checkpointing done by the server. Defaults to None.
+            accept_failures (bool, optional): Determines whether the server should accept failures during training or
+                evaluation from clients or not. If set to False, this will cause the server to shutdown all clients
+                and throw an exception. Defaults to True.
         """
         assert isinstance(
             strategy, FedAvgWithAdaptiveConstraint
         ), "Strategy must be of base type FedAvgWithAdaptiveConstraint"
-        parameter_exchanger = FullParameterExchangerWithPacking(ParameterPackerAdaptiveConstraint())
+        assert checkpoint_and_state_module is None or isinstance(
+            checkpoint_and_state_module,
+            AdaptiveConstraintServerCheckpointAndStateModule,
+        ), "checkpoint_and_state_module must have type AdaptiveConstraintServerCheckpointAndStateModule"
         super().__init__(
             client_manager=client_manager,
             fl_config=fl_config,
-            parameter_exchanger=parameter_exchanger,
-            model=model,
             strategy=strategy,
-            checkpointer=checkpointer,
             reporters=reporters,
+            checkpoint_and_state_module=checkpoint_and_state_module,
+            on_init_parameters_config_fn=on_init_parameters_config_fn,
+            server_name=server_name,
+            accept_failures=accept_failures,
         )
-
-    def _hydrate_model_for_checkpointing(self) -> None:
-        assert self.server_model is not None, (
-            "Model hydration has been called but no server_model is defined to hydrate. The functionality of "
-            "_hydrate_model_for_checkpointing can be overridden if checkpointing without a torch architecture is "
-            "possible and desired"
-        )
-        assert self.parameter_exchanger is not None, (
-            "Model hydration has been called but no parameter_exchanger is defined to hydrate. The functionality of "
-            "_hydrate_model_for_checkpointing can be overridden if checkpointing without a parameter exchanger is "
-            "possible and desired"
-        )
-        # Overriding the standard hydration method to account for the unpacking
-        packed_parameters = parameters_to_ndarrays(self.parameters)
-        # Don't need the extra loss weight variable for checkpointing.
- assert isinstance(self.parameter_exchanger, FullParameterExchangerWithPacking) - model_ndarrays, _ = self.parameter_exchanger.unpack_parameters(packed_parameters) - self.parameter_exchanger.pull_parameters(model_ndarrays, self.server_model) diff --git a/fl4health/servers/adaptive_constraint_servers/mrmtl_server.py b/fl4health/servers/adaptive_constraint_servers/mrmtl_server.py index 0ad2fd573..ca0c23023 100644 --- a/fl4health/servers/adaptive_constraint_servers/mrmtl_server.py +++ b/fl4health/servers/adaptive_constraint_servers/mrmtl_server.py @@ -1,10 +1,9 @@ -from typing import Optional, Sequence, Union +from typing import Callable, Dict, Sequence -from flwr.common.typing import Config +from flwr.common.typing import Config, Scalar from flwr.server.client_manager import ClientManager -from fl4health.checkpointing.checkpointer import TorchModuleCheckpointer -from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule +from fl4health.checkpointing.server_module import AdaptiveConstraintServerCheckpointAndStateModule from fl4health.reporting.base_reporter import BaseReporter from fl4health.servers.base_server import FlServer from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint @@ -16,8 +15,11 @@ def __init__( client_manager: ClientManager, fl_config: Config, strategy: FedAvgWithAdaptiveConstraint, - checkpoint_and_state_module: BaseServerCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, + checkpoint_and_state_module: AdaptiveConstraintServerCheckpointAndStateModule | None = None, + on_init_parameters_config_fn: Callable[[int], Dict[str, Scalar]] | None = None, + server_name: str | None = None, + accept_failures: bool = True, ) -> None: """ This is a very basic wrapper class over the FlServer to ensure that the strategy used for MR-MTL is of type @@ -33,15 +35,25 @@ def __init__( strategy (FedAvgWithAdaptiveConstraint): The aggregation strategy to be used by the server to handle. client updates and other information potentially sent by the participating clients. For MR-MTL, the strategy must be a derivative of the FedAvgWithAdaptiveConstraint class. - checkpoint_and_state_module (BaseServerCheckpointAndStateModule | None, optional): This module is used - to handle both model checkpointing and state checkpointing. The former is aimed at saving model - artifacts to be used or evaluated after training. The later is used to preserve training state - (including models) such that if FL training is interrupted, the process may be restarted. If no + reporters (Sequence[BaseReporter], optional): A sequence of FL4Health reporters which the server should + send data to before and after each round. Defaults to None. + checkpoint_and_state_module (AdaptiveConstraintServerCheckpointAndStateModule | None, optional): This + module is used to handle both model checkpointing and state checkpointing. The former is aimed at + saving model artifacts to be used or evaluated after training. The later is used to preserve training + state (including models) such that if FL training is interrupted, the process may be restarted. If no module is provided, no checkpointing or state preservation will happen. Defaults to None. NOTE: For MR-MTL, the server model is an aggregation of the personal models, which isn't the target of FL training for this algorithm. However, one may still want to save this model for other purposes. 
- reporters (Sequence[BaseReporter], optional): A sequence of FL4Health - reporters which the server should send data to before and after each round. + on_init_parameters_config_fn (Callable[[int], Dict[str, Scalar]] | None, optional): Function used to + configure how one asks a client to provide parameters from which to initialize all other clients by + providing a Config dictionary. If this is none, then a blank config is sent with the parameter request + (which is default behavior for flower servers). Defaults to None. + server_name (str | None, optional): An optional string name to uniquely identify server. This name is also + used as part of any state checkpointing done by the server. Defaults to None. + accept_failures (bool, optional): Determines whether the server should accept failures during training or + evaluation from clients or not. If set to False, this will cause the server to shutdown all clients + and throw an exception. Defaults to True. + """ assert isinstance( strategy, FedAvgWithAdaptiveConstraint @@ -50,6 +62,9 @@ def __init__( client_manager=client_manager, fl_config=fl_config, strategy=strategy, - checkpoint_and_state_module=checkpoint_and_state_module, reporters=reporters, + checkpoint_and_state_module=checkpoint_and_state_module, + on_init_parameters_config_fn=on_init_parameters_config_fn, + server_name=server_name, + accept_failures=accept_failures, ) diff --git a/fl4health/servers/base_server.py b/fl4health/servers/base_server.py index 0e071e14a..63024baaa 100644 --- a/fl4health/servers/base_server.py +++ b/fl4health/servers/base_server.py @@ -354,7 +354,7 @@ def evaluate_round( self._terminate_after_unacceptable_failures(timeout) if loss_aggregated: - self.checkpoint_and_state_module.maybe_checkpoint(self.parameters, loss_aggregated, metrics_aggregated) + self._maybe_checkpoint(loss_aggregated, metrics_aggregated, server_round) # Report evaluation results report_data = { "val - loss - aggregated": loss_aggregated, @@ -465,28 +465,15 @@ def _maybe_checkpoint( server_round: int, ) -> None: """ - This function will run through any provided checkpointers to save the server-side model. If no checkpointers - are present, this function simply logs that no server-side checkpointing is performed. - - NOTE: The proper components for model hydration need to be in place for this implementation. If they are - not an exception will be thrown. + This function simply runs the maybe_checkpoint functionality of the checkpoint_and_state_module. If additional + functionality is desired, this function may be overridden. Args: loss_aggregated (float): aggregated loss value that can be used to determine whether to checkpoint metrics_aggregated (Dict[str, Scalar]): aggregated metrics from each of the clients for checkpointing server_round (int): What round of federated training we're on. This is just for logging purposes. """ - if self.checkpointer: - self._hydrate_model_for_checkpointing() - assert self.server_model is not None - for checkpointer in self.checkpointer: - checkpointer.maybe_checkpoint(self.server_model, loss_aggregated, metrics_aggregated) - elif server_round == 1: - # No checkpointer, just log message on the first round - log( - INFO, - "No checkpointer present. 
Models will not be checkpointed on server-side.", - ) + self.checkpoint_and_state_module.maybe_checkpoint(self.parameters, loss_aggregated, metrics_aggregated) def _get_initial_parameters(self, server_round: int, timeout: Optional[float]) -> Parameters: """ diff --git a/fl4health/servers/client_level_dp_fed_avg_server.py b/fl4health/servers/client_level_dp_fed_avg_server.py index ba4122d05..e6cda5536 100644 --- a/fl4health/servers/client_level_dp_fed_avg_server.py +++ b/fl4health/servers/client_level_dp_fed_avg_server.py @@ -1,14 +1,14 @@ from collections.abc import Sequence from logging import INFO from math import ceil -from typing import List, Optional, Tuple +from typing import Callable, Dict, List, Optional, Tuple from flwr.common.logger import log -from flwr.common.typing import Config +from flwr.common.typing import Config, Scalar from flwr.server.client_manager import ClientManager from flwr.server.history import History -from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule +from fl4health.checkpointing.server_module import ClippingBitServerCheckpointAndStateModule from fl4health.client_managers.fixed_without_replacement_manager import FixedSamplingByFractionClientManager from fl4health.client_managers.poisson_sampling_manager import PoissonSamplingClientManager from fl4health.privacy.fl_accountants import ( @@ -29,9 +29,11 @@ def __init__( strategy: ClientLevelDPFedAvgM, server_noise_multiplier: float, num_server_rounds: int, - checkpoint_and_state_module: BaseServerCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, + checkpoint_and_state_module: ClippingBitServerCheckpointAndStateModule | None = None, delta: Optional[int] = None, + on_init_parameters_config_fn: Callable[[int], Dict[str, Scalar]] | None = None, + server_name: str | None = None, accept_failures: bool = True, ) -> None: """ @@ -48,15 +50,21 @@ def __init__( client updates and other information potentially sent by the participating clients. server_noise_multiplier (float): Magnitude of noise added to the weights aggregation process by the server. num_server_rounds (int): Number of rounds of FL training carried out by the server. + reporters (Sequence[BaseReporter], optional): A sequence of FL4Health reporters which the server should + send data to before and after each round. checkpoint_and_state_module (BaseServerCheckpointAndStateModule | None, optional): This module is used to handle both model checkpointing and state checkpointing. The former is aimed at saving model artifacts to be used or evaluated after training. The later is used to preserve training state (including models) such that if FL training is interrupted, the process may be restarted. If no module is provided, no checkpointing or state preservation will happen. Defaults to None. - reporters (Sequence[BaseReporter], optional): A sequence of FL4Health - reporters which the server should send data to before and after each round. delta (Optional[float], optional): The delta value for epsilon-delta DP accounting. If None it defaults to being 1/total_samples in the FL run. Defaults to None. + on_init_parameters_config_fn (Callable[[int], Dict[str, Scalar]] | None, optional): Function used to + configure how one asks a client to provide parameters from which to initialize all other clients by + providing a Config dictionary. If this is none, then a blank config is sent with the parameter request + (which is default behavior for flower servers). Defaults to None. 
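[Editor's note] For orientation, the following is a minimal example of the kind of callable that the new on_init_parameters_config_fn argument expects across these servers. It is not part of the patch and the config keys are illustrative only.

from typing import Dict

from flwr.common.typing import Scalar


def initial_parameters_config(current_server_round: int) -> Dict[str, Scalar]:
    # Sent along with the server's request for initial parameters from a sampled client.
    return {"current_server_round": current_server_round, "batch_size": 32}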
+ server_name (str | None, optional): An optional string name to uniquely identify server. This name is also + used as part of any state checkpointing done by the server. Defaults to None. accept_failures (bool, optional): Determines whether the server should accept failures during training or evaluation from clients or not. If set to False, this will cause the server to shutdown all clients and throw an exception. Defaults to True. @@ -65,8 +73,10 @@ def __init__( client_manager=client_manager, fl_config=fl_config, strategy=strategy, - checkpoint_and_state_module=checkpoint_and_state_module, reporters=reporters, + checkpoint_and_state_module=checkpoint_and_state_module, + on_init_parameters_config_fn=on_init_parameters_config_fn, + server_name=server_name, accept_failures=accept_failures, ) self.accountant: ClientLevelAccountant diff --git a/fl4health/servers/fedpm_server.py b/fl4health/servers/fedpm_server.py index 8df75ff0e..2a35411ec 100644 --- a/fl4health/servers/fedpm_server.py +++ b/fl4health/servers/fedpm_server.py @@ -1,12 +1,12 @@ from collections.abc import Sequence -from typing import Dict, Optional, Tuple +from typing import Callable, Dict, Optional, Tuple from flwr.common import Parameters from flwr.common.typing import Config, Scalar from flwr.server.client_manager import ClientManager from flwr.server.server import FitResultsAndFailures -from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule +from fl4health.checkpointing.server_module import LayerNamesServerCheckpointAndStateModule from fl4health.reporting.base_reporter import BaseReporter from fl4health.servers.base_server import FlServer from fl4health.strategies.fedpm import FedPm @@ -18,9 +18,11 @@ def __init__( client_manager: ClientManager, fl_config: Config, strategy: FedPm, - checkpoint_and_state_module: BaseServerCheckpointAndStateModule | None = None, - reset_frequency: int = 1, reporters: Sequence[BaseReporter] | None = None, + checkpoint_and_state_module: LayerNamesServerCheckpointAndStateModule | None = None, + reset_frequency: int = 1, + on_init_parameters_config_fn: Callable[[int], Dict[str, Scalar]] | None = None, + server_name: str | None = None, accept_failures: bool = True, ) -> None: """ @@ -34,27 +36,35 @@ def __init__( In most cases it should be the "source of truth" for how FL training/evaluation should proceed. For example, the config used to produce the on_fit_config_fn and on_evaluate_config_fn for the strategy. NOTE: This config is DISTINCT from the Flwr server config, which is extremely minimal. - strategy (Scaffold): The aggregation strategy to be used by the server to handle client updates and other - information potentially sent by the participating clients. This strategy must be of SCAFFOLD type. + strategy (FedPm): The aggregation strategy to be used by the server to handle client updates and other + information potentially sent by the participating clients. This strategy must be of FedPm type. + reporters (Sequence[BaseReporter], optional): A sequence of FL4Health reporters which the server should + send data to before and after each round. checkpoint_and_state_module (BaseServerCheckpointAndStateModule | None, optional): This module is used to handle both model checkpointing and state checkpointing. The former is aimed at saving model artifacts to be used or evaluated after training. The later is used to preserve training state (including models) such that if FL training is interrupted, the process may be restarted. 
                If no module is provided, no checkpointing or state preservation will happen. Defaults to None.
             reset_frequency (int): Determines the frequency with which the beta priors are reset. Defaults to 1.
-            reporters (Sequence[BaseReporter], optional): A sequence of FL4Health reporters which the server should
-                send data to before and after each round.
+            on_init_parameters_config_fn (Callable[[int], Dict[str, Scalar]] | None, optional): Function used to
+                configure how one asks a client to provide parameters from which to initialize all other clients by
+                providing a Config dictionary. If this is none, then a blank config is sent with the parameter request
+                (which is default behavior for flower servers). Defaults to None.
+            server_name (str | None, optional): An optional string name to uniquely identify server. This name is also
+                used as part of any state checkpointing done by the server. Defaults to None.
             accept_failures (bool, optional): Determines whether the server should accept failures during training or
                 evaluation from clients or not. If set to False, this will cause the server to shutdown all clients
                 and throw an exception. Defaults to True.
         """
-        FlServer.__init__(
-            self,
+        super().__init__(
             client_manager=client_manager,
             fl_config=fl_config,
             strategy=strategy,
-            checkpoint_and_state_module=checkpoint_and_state_module,
             reporters=reporters,
+            checkpoint_and_state_module=checkpoint_and_state_module,
+            on_init_parameters_config_fn=on_init_parameters_config_fn,
+            server_name=server_name,
             accept_failures=accept_failures,
         )
         self.reset_frequency = reset_frequency
diff --git a/fl4health/servers/instance_level_dp_server.py b/fl4health/servers/instance_level_dp_server.py
index 5cb9f3ebc..74b4a50ce 100644
--- a/fl4health/servers/instance_level_dp_server.py
+++ b/fl4health/servers/instance_level_dp_server.py
@@ -1,19 +1,19 @@
 from collections.abc import Sequence
 from logging import INFO
 from math import ceil
-from typing import List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple

 import torch.nn as nn
 from flwr.common.logger import log
-from flwr.common.typing import Config
+from flwr.common.typing import Config, Scalar
 from flwr.server.client_manager import ClientManager
 from flwr.server.history import History

-from fl4health.checkpointing.opacus_checkpointer import OpacusCheckpointer
+from fl4health.checkpointing.server_module import OpacusServerCheckpointAndStateModule
 from fl4health.client_managers.poisson_sampling_manager import PoissonSamplingClientManager
 from fl4health.privacy.fl_accountants import FlInstanceLevelAccountant
 from fl4health.reporting.base_reporter import BaseReporter
-from fl4health.servers.base_server import ExchangerType, FlServer
+from fl4health.servers.base_server import FlServer
 from fl4health.strategies.basic_fedavg import BasicFedAvg
 from fl4health.strategies.strategy_with_poll import StrategyWithPolling
@@ -29,11 +29,11 @@
         strategy: BasicFedAvg,
         local_epochs: Optional[int] = None,
         local_steps: Optional[int] = None,
-        model: nn.Module | None = None,
-        checkpointer: Optional[OpacusCheckpointer] = None,
-        parameter_exchanger: ExchangerType | None = None,
+        checkpoint_and_state_module: OpacusServerCheckpointAndStateModule | None = None,
         reporters: Sequence[BaseReporter] | None = None,
         delta: Optional[float] = None,
+        on_init_parameters_config_fn: Callable[[int], Dict[str, Scalar]] | None = None,
+        server_name: str | None = None,
         accept_failures: bool = True,
     ) -> None:
         """
@@ -47,34 +47,34 @@
             In most cases it should be
the "source of truth" for how FL training/evaluation should proceed. For example, the config used to produce the on_fit_config_fn and on_evaluate_config_fn for the strategy. NOTE: This config is DISTINCT from the Flwr server config, which is extremely minimal. - noise_multiplier (int): The amount of Gaussian noise to be added to the per sample gradient during + noise_multiplier (float): The amount of Gaussian noise to be added to the per sample gradient during DP-SGD. batch_size (int): The batch size to be used in training on the client-side. Used in privacy accounting. num_server_rounds (int): The number of server rounds to be done in FL training. Used in privacy accounting + strategy (BasicFedAvg): The aggregation strategy to be used by the server to handle + client updates and other information potentially sent by the participating clients. this must be an + OpacusBasicFedAvg strategy to ensure proper treatment of the model in the Opacus framework local_epochs (Optional[int], optional): Number of local epochs to be performed on the client-side. This is used in privacy accounting. One of local_epochs or local_steps should be defined, but not both. Defaults to None. local_steps (Optional[int], optional): Number of local steps to be performed on the client-side. This is used in privacy accounting. One of local_epochs or local_steps should be defined, but not both. Defaults to None. - strategy (OpacusBasicFedAvg): The aggregation strategy to be used by the server to handle - client updates and other information potentially sent by the participating clients. this must be an - OpacusBasicFedAvg strategy to ensure proper treatment of the model in the Opacus framework - model (Optional[nn.Module]): This is the torch model to be checkpointed. It will be hydrated by the - _hydrate_model_for_checkpointing function so that it has the proper weights to be saved. If no model - is defined and checkpointing is attempted an error will throw. Defaults to None. - checkpointer (Optional[OpacusCheckpointer], optional): To be provided if the server should perform - server side checkpointing based on some criteria. If none, then no server-side checkpointing is - performed. Defaults to None. - parameter_exchanger (Optional[ExchangerType], optional): A parameter exchanger used to facilitate - server-side model checkpointing if a checkpointer has been defined. If not provided then checkpointing - will not be done unless the _hydrate_model_for_checkpointing function is overridden. Because the - server only sees numpy arrays, the parameter exchanger is used to insert the numpy arrays into a - provided model. Defaults to None. - reporters (Sequence[BaseReporter], optional): A sequence of FL4Health + checkpoint_and_state_module (OpacusServerCheckpointAndStateModule | None, optional): This module is used + to handle both model checkpointing and state checkpointing. The former is aimed at saving model + artifacts to be used or evaluated after training. The later is used to preserve training state + (including models) such that if FL training is interrupted, the process may be restarted. If no + module is provided, no checkpointing or state preservation will happen. Defaults to None. + reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health reporters which the client should send data to. delta (Optional[float], optional): The delta value for epsilon-delta DP accounting. If None it defaults to being 1/total_samples in the FL run. Defaults to None. 
+ on_init_parameters_config_fn (Callable[[int], Dict[str, Scalar]] | None, optional): Function used to + configure how one asks a client to provide parameters from which to initialize all other clients by + providing a Config dictionary. If this is none, then a blank config is sent with the parameter request + (which is default behavior for flower servers). Defaults to None. + server_name (str | None, optional): An optional string name to uniquely identify server. This name is also + used as part of any state checkpointing done by the server. Defaults to None. accept_failures (bool, optional): Determines whether the server should accept failures during training or evaluation from clients or not. If set to False, this will cause the server to shutdown all clients and throw an exception. Defaults to True. @@ -83,10 +83,10 @@ def __init__( client_manager=client_manager, fl_config=fl_config, strategy=strategy, - model=model, - checkpointer=checkpointer, - parameter_exchanger=parameter_exchanger, + checkpoint_and_state_module=checkpoint_and_state_module, reporters=reporters, + on_init_parameters_config_fn=on_init_parameters_config_fn, + server_name=server_name, accept_failures=accept_failures, ) diff --git a/fl4health/servers/nnunet_server.py b/fl4health/servers/nnunet_server.py index e1a5ea94d..0af362efd 100644 --- a/fl4health/servers/nnunet_server.py +++ b/fl4health/servers/nnunet_server.py @@ -14,10 +14,11 @@ from flwr.server.history import History from flwr.server.strategy import Strategy -from fl4health.checkpointing.checkpointer import TorchModuleCheckpointer +from fl4health.checkpointing.server_module import NnUnetServerCheckpointAndStateModule +from fl4health.parameter_exchange.parameter_exchanger_base import ExchangerType from fl4health.reporting.base_reporter import BaseReporter from fl4health.reporting.reports_manager import ReportsManager -from fl4health.servers.base_server import ExchangerType, FlServer +from fl4health.servers.base_server import FlServer from fl4health.utils.config import narrow_dict_type, narrow_dict_type_and_set_attribute from fl4health.utils.nnunet_utils import NnunetConfig from fl4health.utils.parameter_extraction import get_all_model_parameters @@ -62,12 +63,9 @@ def __init__( client_manager: ClientManager, fl_config: Config, on_init_parameters_config_fn: Callable[[int], Dict[str, Scalar]], - model: nn.Module | None = None, strategy: Strategy | None = None, - checkpointer: TorchModuleCheckpointer | Sequence[TorchModuleCheckpointer] | None = None, reporters: Sequence[BaseReporter] | None = None, - parameter_exchanger: ExchangerType | None = None, - intermediate_server_state_dir: Path | None = None, + checkpoint_and_state_module: NnUnetServerCheckpointAndStateModule | None = None, server_name: str | None = None, accept_failures: bool = True, nnunet_trainer_class: Type[nnUNetTrainer] = nnUNetTrainer, @@ -83,28 +81,24 @@ def __init__( In most cases it should be the "source of truth" for how FL training/evaluation should proceed. For example, the config used to produce the on_fit_config_fn and on_evaluate_config_fn for the strategy. NOTE: This config is DISTINCT from the Flwr server config, which is extremely minimal. 
- model (nn.Module): This is the torch model to be hydrated by the _hydrate_model_for_checkpointing function - on_init_parameters_config_fn (Callable[[int], Dict[str, Scalar]]]): Function used to configure how one + on_init_parameters_config_fn (Callable[[int], Dict[str, Scalar]]): Function used to configure how one asks a client to provide parameters from which to initialize all other clients by providing a Config dictionary. For NnunetServers this is a required function to provide the additional information necessary to a client for parameter initialization - strategy (Optional[Strategy], optional): The aggregation strategy to be used by the server to handle + strategy (Strategy | None, optional): The aggregation strategy to be used by the server to handle client updates and other information potentially sent by the participating clients. If None the - strategy is FedAvg as set by the flwr Server. - checkpointer (TorchCheckpointer | Sequence[TorchCheckpointer], optional): To be provided if the server - should perform server side checkpointing based on some criteria. If none, then no server-side - checkpointing is performed. Multiple checkpointers can also be passed in a sequence to checkpoint - based on multiple criteria. Defaults to None. - reporters (Sequence[BaseReporter], optional): A sequence of FL4Health reporters which the client should - send data to. - intermediate_server_state_dir (Path): A directory to store and load checkpoints from for the server - during an FL experiment. - parameter_exchanger (Optional[ExchangerType], optional): A parameter exchanger used to facilitate - server-side model checkpointing if a checkpointer has been defined. If not provided then checkpointing - will not be done unless the _hydrate_model_for_checkpointing function is overridden. Because the - server only sees numpy arrays, the parameter exchanger is used to insert the numpy arrays into a - provided model. Defaults to None. - server_name (Optional[str]): An optional string name to uniquely identify server. + strategy is FedAvg as set by the flwr Server. Defaults to None. + reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health reporters which the client + should send data to. Defaults to None. + checkpoint_and_state_module (NnUnetServerCheckpointAndStateModule | None, optional): This module is used + to handle both model checkpointing and state checkpointing. The former is aimed at saving model + artifacts to be used or evaluated after training. The later is used to preserve training state + (including models) such that if FL training is interrupted, the process may be restarted. If no + module is provided, no checkpointing or state preservation will happen. Defaults to None. + NOTE: For NnUnet, this module is allowed to have all components defined other than the model, as it + may be set later when the server asks the clients to provide the architecture. + server_name (str | None, optional): An optional string name to uniquely identify server. This name is also + used as part of any state checkpointing done by the server. Defaults to None. accept_failures (bool, optional): Determines whether the server should accept failures during training or evaluation from clients or not. If set to False, this will cause the server to shutdown all clients and throw an exception. Defaults to True. @@ -112,16 +106,13 @@ def __init__( Useful for passing custom nnUNetTrainer. Defaults to the standard nnUNetTrainer class. Must match the nnunet_trainer_class passed to the NnunetClient. 
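[Editor's note] A sketch of how the nnunet_trainer_class hook described above might be used, assuming the class is exported as NnunetServer and that the strategy, checkpoint/state module and config function have been built elsewhere (all placeholders); it is not part of the patch.

from flwr.server.client_manager import SimpleClientManager
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer

from fl4health.servers.nnunet_server import NnunetServer


class MyNnUNetTrainer(nnUNetTrainer):
    # Illustrative subclass; any training customizations would go here.
    pass


def build_nnunet_server(fl_config, strategy, state_module, init_config_fn):
    # The same trainer class must also be passed to the participating NnunetClients.
    return NnunetServer(
        client_manager=SimpleClientManager(),
        fl_config=fl_config,
        on_init_parameters_config_fn=init_config_fn,
        strategy=strategy,
        checkpoint_and_state_module=state_module,
        nnunet_trainer_class=MyNnUNetTrainer,
    )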
""" - FlServer.__init__( + super().__init__( self, client_manager=client_manager, fl_config=fl_config, strategy=strategy, reporters=reporters, - model=model, - checkpointer=checkpointer, - parameter_exchanger=parameter_exchanger, - intermediate_server_state_dir=intermediate_server_state_dir, + checkpoint_and_state_module=checkpoint_and_state_module, on_init_parameters_config_fn=on_init_parameters_config_fn, server_name=server_name, accept_failures=accept_failures, @@ -156,7 +147,7 @@ def initialize_server_model(self) -> None: self.enable_deep_supervision, ) - self.server_model = model + self.checkpoint_and_state_module.model = model def update_before_fit(self, num_rounds: int, timeout: Optional[float]) -> None: """ @@ -179,7 +170,10 @@ def update_before_fit(self, num_rounds: int, timeout: Optional[float]) -> None: # If no prior checkpoints exist, initialize server by sampling clients to get required properties to set # NOTE: Inherent assumption that if checkpoint exists for server that it also will exist for client. - if self.per_round_checkpointer is None or not self.per_round_checkpointer.checkpoint_exists(): + if ( + self.checkpoint_and_state_module.state_checkpointer is None + or self.checkpoint_and_state_module.state_checkpointer.checkpoint_exists(self.state_checkpoint_name) + ): # Sample properties from a random client to initialize plans log(INFO, "") log(INFO, "[PRE-INIT]") @@ -232,13 +226,11 @@ def update_before_fit(self, num_rounds: int, timeout: Optional[float]) -> None: # subclass could call parent method and not have to copy entire state. def _save_server_state(self) -> None: """ - Save server checkpoint consisting of model, history, server round, metrics reporter and server name. - This method overrides parent to also checkpoint nnunet_plans, num_input_channels, - num_segmentation_heads and enable_deep_supervision. + Save server checkpoint consisting of model, history, server round, metrics reporter and server name. This + method overrides parent to also checkpoint nnunet_plans, num_input_channels, num_segmentation_heads and + enable_deep_supervision. """ - assert self.per_round_checkpointer is not None - assert ( self.nnunet_plans_bytes is not None and self.num_input_channels is not None @@ -247,8 +239,7 @@ def _save_server_state(self) -> None: and self.nnunet_config is not None ) - ckpt = { - "model": self.server_model, + other_state_to_save = { "history": self.history, "current_round": self.current_round, "reports_manager": self.reports_manager, @@ -260,39 +251,39 @@ def _save_server_state(self) -> None: "nnunet_config": self.nnunet_config, } - self.per_round_checkpointer.save_checkpoint(ckpt) - - log( - INFO, - f"Saving server state to checkpoint at {self.per_round_checkpointer.checkpoint_path}", + self.checkpoint_and_state_module.save_state( + state_checkpoint=self.state_checkpoint_name, + server_parameters=self.parameters, + other_state=other_state_to_save, ) - def _load_server_state(self) -> None: + def _load_server_state(self) -> bool: """ Load server checkpoint consisting of model, history, server name, current round and metrics reporter. - The method overrides parent to add any necessary state when loading the checkpoint. + The method overrides parent to add any necessary state when loading the checkpoint. """ - assert self.per_round_checkpointer is not None and self.per_round_checkpointer.checkpoint_exists() - - ckpt = self.per_round_checkpointer.load_checkpoint() + # Attempt to load the server state if it exists. This variable will be None if it does not. 
+        server_state = self.checkpoint_and_state_module.maybe_load_state(self.state_checkpoint_name)

-        log(
-            INFO,
-            f"Loading server state from checkpoint at {self.per_round_checkpointer.checkpoint_path}",
-        )
+        if server_state is None:
+            return False

         # Standard attributes to load
-        narrow_dict_type_and_set_attribute(self, ckpt, "current_round", "current_round", int)
-        narrow_dict_type_and_set_attribute(self, ckpt, "server_name", "server_name", str)
-        narrow_dict_type_and_set_attribute(self, ckpt, "reports_manager", "reports_manager", ReportsManager)
-        narrow_dict_type_and_set_attribute(self, ckpt, "history", "history", History)
-        narrow_dict_type_and_set_attribute(self, ckpt, "model", "parameters", nn.Module, func=get_all_model_parameters)
+        narrow_dict_type_and_set_attribute(self, server_state, "current_round", "current_round", int)
+        narrow_dict_type_and_set_attribute(self, server_state, "server_name", "server_name", str)
+        narrow_dict_type_and_set_attribute(self, server_state, "reports_manager", "reports_manager", ReportsManager)
+        narrow_dict_type_and_set_attribute(self, server_state, "history", "history", History)
+        narrow_dict_type_and_set_attribute(
+            self, server_state, "model", "parameters", nn.Module, func=get_all_model_parameters
+        )
         # Needed for when _hydrate_model_for_checkpointing is called
-        narrow_dict_type_and_set_attribute(self, ckpt, "model", "server_model", nn.Module)
+        narrow_dict_type_and_set_attribute(self, server_state, "model", "server_model", nn.Module)

         # NnunetServer specific attributes to load
-        narrow_dict_type_and_set_attribute(self, ckpt, "nnunet_plans_bytes", "nnunet_plans_bytes", bytes)
-        narrow_dict_type_and_set_attribute(self, ckpt, "num_segmentation_heads", "num_segmentation_heads", int)
-        narrow_dict_type_and_set_attribute(self, ckpt, "num_input_channels", "num_input_channels", int)
-        narrow_dict_type_and_set_attribute(self, ckpt, "enable_deep_supervision", "enable_deep_supervision", bool)
-        narrow_dict_type_and_set_attribute(self, ckpt, "nnunet_config", "nnunet_config", NnunetConfig)
+        narrow_dict_type_and_set_attribute(self, server_state, "nnunet_plans_bytes", "nnunet_plans_bytes", bytes)
+        narrow_dict_type_and_set_attribute(self, server_state, "num_segmentation_heads", "num_segmentation_heads", int)
+        narrow_dict_type_and_set_attribute(self, server_state, "num_input_channels", "num_input_channels", int)
+        narrow_dict_type_and_set_attribute(
+            self, server_state, "enable_deep_supervision", "enable_deep_supervision", bool
+        )
+        narrow_dict_type_and_set_attribute(self, server_state, "nnunet_config", "nnunet_config", NnunetConfig)
+
+        return True
diff --git a/fl4health/servers/scaffold_server.py b/fl4health/servers/scaffold_server.py
index e3988b7f4..2aa688235 100644
--- a/fl4health/servers/scaffold_server.py
+++ b/fl4health/servers/scaffold_server.py
@@ -1,10 +1,10 @@
 from collections.abc import Sequence
 from logging import DEBUG, ERROR, INFO
-from typing import Optional, Tuple
+from typing import Callable, Dict, Optional, Tuple

 from flwr.common import Parameters, ndarrays_to_parameters, parameters_to_ndarrays
 from flwr.common.logger import log
-from flwr.common.typing import Config
+from flwr.common.typing import Config, Scalar
 from flwr.server.client_manager import ClientManager
 from flwr.server.history import History
 from flwr.server.server import fit_clients
@@ -22,10 +22,12 @@ def __init__(
         self,
         client_manager: ClientManager,
         fl_config: Config,
         strategy: Scaffold,
-        checkpoint_and_state_module: ScaffoldServerCheckpointAndStateModule | None = None,
         reporters:
Sequence[BaseReporter] | None = None, - warm_start: bool = False, + checkpoint_and_state_module: ScaffoldServerCheckpointAndStateModule | None = None, + on_init_parameters_config_fn: Callable[[int], Dict[str, Scalar]] | None = None, + server_name: str | None = None, accept_failures: bool = True, + warm_start: bool = False, ) -> None: """ Custom FL Server for scaffold algorithm to handle warm initialization of control variates as specified in @@ -41,28 +43,36 @@ def __init__( strategy (Scaffold): The aggregation strategy to be used by the server to handle client updates and other information potentially sent by the participating clients. This strategy must be of SCAFFOLD type. + reporters (Sequence[BaseReporter] | None, optional): sequence of FL4Health reporters which the server + should send data to before and after each round. Defaults to None. checkpoint_and_state_module (ScaffoldServerCheckpointAndStateModule | None, optional): This module is used to handle both model checkpointing and state checkpointing. The former is aimed at saving model artifacts to be used or evaluated after training. The later is used to preserve training state (including models) such that if FL training is interrupted, the process may be restarted. If no module is provided, no checkpointing or state preservation will happen. Defaults to None. - reporters (Sequence[BaseReporter], optional): A sequence of FL4Health reporters which the server should - send data to before and after each round. - warm_start (bool, optional): Whether or not to initialize control variates of each client as local - gradients. The clients will perform a training pass (without updating the weights) in order to provide - a "warm" estimate of the SCAFFOLD control variates. If false, variates are initialized to 0. - Defaults to False. + on_init_parameters_config_fn (Callable[[int], Dict[str, Scalar]] | None, optional): Function used to + configure how one asks a client to provide parameters from which to initialize all other clients by + providing a Config dictionary. If this is none, then a blank config is sent with the parameter request + (which is default behavior for flower servers). Defaults to None. + server_name (str | None, optional): An optional string name to uniquely identify server. This name is also + used as part of any state checkpointing done by the server. Defaults to None. accept_failures (bool, optional): Determines whether the server should accept failures during training or evaluation from clients or not. If set to False, this will cause the server to shutdown all clients and throw an exception. Defaults to True. + warm_start (bool, optional): Whether or not to initialize control variates of each client as local + gradients. The clients will perform a training pass (without updating the weights) in order to provide + a "warm" estimate of the SCAFFOLD control variates. If false, variates are initialized to 0. + Defaults to False. 
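[Editor's note] A sketch of wiring the reworked ScaffoldServer constructor, assuming scaffold_strategy and state_module are an already-configured Scaffold strategy and ScaffoldServerCheckpointAndStateModule (placeholders here); it is not part of the patch.

from flwr.server.client_manager import SimpleClientManager

from fl4health.servers.scaffold_server import ScaffoldServer


def build_scaffold_server(fl_config, scaffold_strategy, state_module):
    # With warm_start=True, clients run one pass over their data to estimate control variates
    # before training begins; with warm_start=False the variates start at zero.
    return ScaffoldServer(
        client_manager=SimpleClientManager(),
        fl_config=fl_config,
        strategy=scaffold_strategy,
        checkpoint_and_state_module=state_module,
        warm_start=True,
    )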
""" - FlServer.__init__( + super().__init__( self, client_manager=client_manager, fl_config=fl_config, strategy=strategy, - checkpoint_and_state_module=checkpoint_and_state_module, reporters=reporters, + checkpoint_and_state_module=checkpoint_and_state_module, + on_init_parameters_config_fn=on_init_parameters_config_fn, + server_name=server_name, accept_failures=accept_failures, ) self.warm_start = warm_start @@ -177,6 +187,9 @@ def __init__( checkpoint_and_state_module: ScaffoldServerCheckpointAndStateModule | None = None, warm_start: bool = False, reporters: Sequence[BaseReporter] | None = None, + on_init_parameters_config_fn: Callable[[int], Dict[str, Scalar]] | None = None, + server_name: str | None = None, + accept_failures: bool = True, ) -> None: """ Custom FL Server for Instance Level Differentially Private Scaffold algorithm as specified in @@ -185,6 +198,10 @@ def __init__( Args: client_manager (ClientManager): Determines the mechanism by which clients are sampled by the server, if they are to be sampled at all. + fl_config (Config): This should be the configuration that was used to setup the federated training. + In most cases it should be the "source of truth" for how FL training/evaluation should proceed. For + example, the config used to produce the on_fit_config_fn and on_evaluate_config_fn for the strategy. + NOTE: This config is DISTINCT from the Flwr server config, which is extremely minimal. noise_multiplier (int): The amount of Gaussian noise to be added to the per sample gradient during DP-SGD. batch_size (int): The batch size to be used in training on the client-side. Used in privacy accounting. @@ -203,14 +220,21 @@ def __init__( artifacts to be used or evaluated after training. The later is used to preserve training state (including models) such that if FL training is interrupted, the process may be restarted. If no module is provided, no checkpointing or state preservation will happen. Defaults to None. - warm_start (bool, optional): Whether or not to initialize control variates of each client as - local gradients. The clients will perform a training pass (without updating the weights) in order to - provide a "warm" estimate of the SCAFFOLD control variates. If false, variates are initialized to 0. + warm_start (bool, optional): Whether or not to initialize control variates of each client as local + gradients. The clients will perform a training pass (without updating the weights) in order to provide + a "warm" estimate of the SCAFFOLD control variates. If false, variates are initialized to 0. Defaults to False. delta (Optional[float], optional): The delta value for epsilon-delta DP accounting. If None it defaults to being 1/total_samples in the FL run. Defaults to None. reporters (Sequence[BaseReporter], optional): A sequence of FL4Health reporters which the client should send data to. + on_init_parameters_config_fn (Callable[[int], Dict[str, Scalar]] | None, optional): Function used to + configure how one asks a client to provide parameters from which to initialize all other clients by + providing a Config dictionary. If this is none, then a blank config is sent with the parameter request + (which is default behavior for flower servers). Defaults to None. + server_name (str | None, optional): An optional string name to uniquely identify server. This name is also + used as part of any state checkpointing done by the server. Defaults to None. 
+            accept_failures (bool, optional): Determines whether the server should accept failures during training
+                or evaluation from clients or not. If set to False, this will cause the server to shutdown all
+                clients and throw an exception. Defaults to True.
         """
         # Require the strategy to be an OpacusStrategy to handle the Opacus model conversion etc.
         assert isinstance(
@@ -224,6 +248,9 @@ def __init__(
             checkpoint_and_state_module=checkpoint_and_state_module,
             warm_start=warm_start,
             reporters=reporters,
+            on_init_parameters_config_fn=on_init_parameters_config_fn,
+            server_name=server_name,
+            accept_failures=accept_failures,
         )
         InstanceLevelDpServer.__init__(
             self,
@@ -236,6 +263,10 @@ def __init__(
             batch_size=batch_size,
             delta=delta,
             num_server_rounds=num_server_rounds,
+            checkpoint_and_state_module=checkpoint_and_state_module,
+            on_init_parameters_config_fn=on_init_parameters_config_fn,
+            server_name=server_name,
+            accept_failures=accept_failures,
         )

     def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float]:
diff --git a/fl4health/servers/sparse_coo_server.py b/fl4health/servers/sparse_coo_server.py
new file mode 100644
index 000000000..ed8129a09
--- /dev/null
+++ b/fl4health/servers/sparse_coo_server.py
@@ -0,0 +1,68 @@
+from collections.abc import Sequence
+from typing import Callable, Dict, Optional, Tuple
+
+from flwr.common import Parameters
+from flwr.common.typing import Config, Scalar
+from flwr.server.client_manager import ClientManager
+from flwr.server.server import FitResultsAndFailures
+
+from fl4health.checkpointing.server_module import SparseCooServerCheckpointAndStateModule
+from fl4health.reporting.base_reporter import BaseReporter
+from fl4health.servers.base_server import FlServer
+from fl4health.strategies.fedpm import FedPm
+
+
+class SparseCooServer(FlServer):
+    def __init__(
+        self,
+        client_manager: ClientManager,
+        fl_config: Config,
+        strategy: FedPm,
+        reporters: Sequence[BaseReporter] | None = None,
+        checkpoint_and_state_module: SparseCooServerCheckpointAndStateModule | None = None,
+        on_init_parameters_config_fn: Callable[[int], Dict[str, Scalar]] | None = None,
+        server_name: str | None = None,
+        accept_failures: bool = True,
+    ) -> None:
+        """
+        Custom FL Server meant to be used when model parameters or masks are exchanged as sparse COO tensors, for
+        example with the FedPm strategy specified in http://arxiv.org/pdf/2209.15328. It ensures that any server-side
+        checkpointing is done with a SparseCooServerCheckpointAndStateModule.
+
+        Args:
+            client_manager (ClientManager): Determines the mechanism by which clients are sampled by the server, if
+                they are to be sampled at all.
+            fl_config (Config): This should be the configuration that was used to setup the federated training.
+                In most cases it should be the "source of truth" for how FL training/evaluation should proceed. For
+                example, the config used to produce the on_fit_config_fn and on_evaluate_config_fn for the strategy.
+                NOTE: This config is DISTINCT from the Flwr server config, which is extremely minimal.
+            strategy (FedPm): The aggregation strategy to be used by the server to handle client updates and other
+                information potentially sent by the participating clients. This strategy must be of FedPm type.
+            reporters (Sequence[BaseReporter], optional): A sequence of FL4Health reporters which the server should
+                send data to before and after each round. Defaults to None.
+            checkpoint_and_state_module (SparseCooServerCheckpointAndStateModule | None, optional): This module is used
+                to handle both model checkpointing and state checkpointing. The former is aimed at saving model
+                artifacts to be used or evaluated after training. The latter is used to preserve training state
+                (including models) such that if FL training is interrupted, the process may be restarted. If no
+                module is provided, no checkpointing or state preservation will happen. Defaults to None.
+            on_init_parameters_config_fn (Callable[[int], Dict[str, Scalar]] | None, optional): Function used to
+                configure how one asks a client to provide parameters from which to initialize all other clients by
+                providing a Config dictionary. If this is none, then a blank config is sent with the parameter request
+                (which is default behavior for flower servers). Defaults to None.
+            server_name (str | None, optional): An optional string name to uniquely identify server. This name is also
+                used as part of any state checkpointing done by the server. Defaults to None.
+            accept_failures (bool, optional): Determines whether the server should accept failures during training or
+                evaluation from clients or not. If set to False, this will cause the server to shutdown all clients
+                and throw an exception. Defaults to True.
+        """
+        super().__init__(
+            client_manager=client_manager,
+            fl_config=fl_config,
+            strategy=strategy,
+            reporters=reporters,
+            checkpoint_and_state_module=checkpoint_and_state_module,
+            on_init_parameters_config_fn=on_init_parameters_config_fn,
+            server_name=server_name,
+            accept_failures=accept_failures,
+        )
diff --git a/fl4health/servers/tabular_feature_alignment_server.py b/fl4health/servers/tabular_feature_alignment_server.py
index 6188a6032..8801bc49d 100644
--- a/fl4health/servers/tabular_feature_alignment_server.py
+++ b/fl4health/servers/tabular_feature_alignment_server.py
@@ -5,11 +5,11 @@
 from flwr.common import Parameters
 from flwr.common.logger import log
-from flwr.common.typing import Config
+from flwr.common.typing import Config, Scalar
 from flwr.server.client_manager import ClientManager
 from flwr.server.history import History

-from fl4health.checkpointing.checkpointer import TorchModuleCheckpointer
+from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule
 from fl4health.feature_alignment.constants import (
     CURRENT_SERVER_ROUND,
     FEATURE_INFO,
@@ -31,9 +31,11 @@ def __init__(
         config: Config,
         initialize_parameters: Callable[..., Parameters],
         strategy: BasicFedAvg,
-        checkpointer: Optional[TorchModuleCheckpointer] = None,
         tabular_features_source_of_truth: Optional[TabularFeaturesInfoEncoder] = None,
         reporters: Sequence[BaseReporter] | None = None,
+        checkpoint_and_state_module: BaseServerCheckpointAndStateModule | None = None,
+        on_init_parameters_config_fn: Callable[[int], Dict[str, Scalar]] | None = None,
+        server_name: str | None = None,
         accept_failures: bool = True,
     ) -> None:
         """
@@ -57,9 +59,22 @@
             tab_features_source_of_truth (Optional[TabularFeaturesInfoEncoder]): The information that is required
                 for aligning client features. If it is not specified, then the server will randomly poll a client
                 and gather this information from its data source.
+            reporters (Sequence[BaseReporter] | None, optional): sequence of FL4Health reporters which the server
+                should send data to before and after each round. Defaults to None.
+            checkpoint_and_state_module (BaseServerCheckpointAndStateModule | None, optional): This module is used
+                to handle both model checkpointing and state checkpointing. The former is aimed at saving model
+                artifacts to be used or evaluated after training. The latter is used to preserve training state
+                (including models) such that if FL training is interrupted, the process may be restarted.
If no + module is provided, no checkpointing or state preservation will happen. Defaults to None. + on_init_parameters_config_fn (Callable[[int], Dict[str, Scalar]] | None, optional): Function used to + configure how one asks a client to provide parameters from which to initialize all other clients by + providing a Config dictionary. If this is none, then a blank config is sent with the parameter request + (which is default behavior for flower servers). Defaults to None. + server_name (str | None, optional): An optional string name to uniquely identify server. This name is also + used as part of any state checkpointing done by the server. Defaults to None. accept_failures (bool, optional): Determines whether the server should accept failures during training or - evaluation from clients or not. If set to False, this will cause the server to shutdown all clients - and throw an exception. Defaults to True. + evaluation from clients or not. If set to False, this will cause the server to shutdown all clients + and throw an exception. Defaults to True. """ if strategy.on_fit_config_fn is not None: @@ -71,8 +86,10 @@ def __init__( client_manager=client_manager, fl_config=config, strategy=strategy, - checkpointer=checkpointer, reporters=reporters, + checkpoint_and_state_module=checkpoint_and_state_module, + on_init_parameters_config_fn=on_init_parameters_config_fn, + server_name=server_name, accept_failures=accept_failures, ) # The server performs one or two rounds of polls before the normal federated training. @@ -84,7 +101,7 @@ def __init__( self.source_info_gathered = False self.dimension_info: Dict[str, int] = {} # ensure that self.strategy has type BasicFedAvg so its on_fit_config_fn can be specified. - assert isinstance(self.strategy, BasicFedAvg) + assert isinstance(self.strategy, BasicFedAvg), "This server is only compatible with BasicFedAvg at this time" self.strategy.on_fit_config_fn = partial(fit_config, self.fl_config, self.source_info_gathered) def _set_dimension_info(self, input_dimension: int, output_dimension: int) -> None: diff --git a/fl4health/utils/typing.py b/fl4health/utils/typing.py index ffb098eb0..ee0149d66 100644 --- a/fl4health/utils/typing.py +++ b/fl4health/utils/typing.py @@ -21,8 +21,6 @@ FitFailures = List[Union[Tuple[ClientProxy, FitRes], BaseException]] EvaluateFailures = List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]] -ExchangerType = TypeVar("ExchangerType", bound=ParameterExchanger) - class LogLevel(Enum): NOTSET = logging.NOTSET diff --git a/research/cifar10/fedavg/server.py b/research/cifar10/fedavg/server.py index 31459f7b5..9eaee7252 100644 --- a/research/cifar10/fedavg/server.py +++ b/research/cifar10/fedavg/server.py @@ -11,6 +11,7 @@ from flwr.server.strategy import FedAvg from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer +from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.servers.base_server import FlServer from fl4health.utils.config import load_config @@ -45,17 +46,23 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_ config["n_server_rounds"], config["n_clients"], ) + + # Initializing the model on the server side + model = ConvNet(in_channels=3) + parameter_exchanger = FullParameterExchanger() checkpoint_dir = os.path.join(checkpoint_stub, run_name) best_checkpoint_name = "server_best_model.pkl" last_checkpoint_name = 
"server_last_model.pkl" - checkpointer = [ + checkpointers = [ BestLossTorchModuleCheckpointer(checkpoint_dir, best_checkpoint_name), LatestTorchModuleCheckpointer(checkpoint_dir, last_checkpoint_name), ] + + checkpoint_and_state_module = BaseServerCheckpointAndStateModule( + model=model, parameter_exchanger=parameter_exchanger, model_checkpointers=checkpointers + ) + client_manager = SimpleClientManager() - # Initializing the model on the server side - model = ConvNet(in_channels=3) - parameter_exchanger = FullParameterExchanger() # Server performs simple FedAveraging as its server-side optimization strategy strategy = FedAvg( min_fit_clients=config["n_clients"], @@ -72,10 +79,8 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_ server = FlServer( client_manager=client_manager, fl_config=config, - parameter_exchanger=parameter_exchanger, - model=model, strategy=strategy, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.server.start_server( @@ -84,8 +89,8 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_ config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]), ) - assert isinstance(checkpointer[0], BestLossTorchModuleCheckpointer) - log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{checkpointer[0].best_score}") + assert isinstance(checkpointers[0], BestLossTorchModuleCheckpointer) + log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{checkpointers[0].best_score}") # Shutdown the server gracefully server.shutdown() diff --git a/research/cifar10/personal_server.py b/research/cifar10/personal_server.py index 6250c0841..a4d0334f7 100644 --- a/research/cifar10/personal_server.py +++ b/research/cifar10/personal_server.py @@ -28,7 +28,9 @@ def __init__( ) -> None: # Personal approaches don't train a "server" model. Rather, each client trains a client specific model with # some globally shared weights. 
So we don't checkpoint a global model - super().__init__(client_manager=client_manager, fl_config=fl_config, strategy=strategy, checkpointer=None) + super().__init__( + client_manager=client_manager, fl_config=fl_config, strategy=strategy, checkpoint_and_state_module=None + ) self.best_aggregated_loss: Optional[float] = None def evaluate_round( diff --git a/research/flamby/flamby_servers/full_exchange_server.py b/research/flamby/flamby_servers/full_exchange_server.py index b204542cd..9e8dfa961 100644 --- a/research/flamby/flamby_servers/full_exchange_server.py +++ b/research/flamby/flamby_servers/full_exchange_server.py @@ -5,7 +5,7 @@ from flwr.server.client_manager import ClientManager from flwr.server.strategy import Strategy -from fl4health.checkpointing.checkpointer import TorchModuleCheckpointer +from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.servers.base_server import FlServer @@ -15,17 +15,15 @@ def __init__( self, client_manager: ClientManager, fl_config: Config, - model: Optional[nn.Module] = None, strategy: Optional[Strategy] = None, - checkpointer: Optional[TorchModuleCheckpointer] = None, + checkpoint_and_state_module: BaseServerCheckpointAndStateModule | None = None, ) -> None: - # To help with model rehydration - parameter_exchanger = FullParameterExchanger() super().__init__( client_manager=client_manager, fl_config=fl_config, - parameter_exchanger=parameter_exchanger, - model=model, strategy=strategy, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) + # If parameter exchanger has been defined, it needs to be of type FullParameterExchanger + if self.checkpoint_and_state_module.parameter_exchanger is not None: + assert isinstance(self.checkpoint_and_state_module.parameter_exchanger, FullParameterExchanger) diff --git a/research/flamby/flamby_servers/personal_server.py b/research/flamby/flamby_servers/personal_server.py index 48aa8d8ae..77f561e00 100644 --- a/research/flamby/flamby_servers/personal_server.py +++ b/research/flamby/flamby_servers/personal_server.py @@ -29,7 +29,9 @@ def __init__( ) -> None: # Personal approaches don't train a "server" model. Rather, each client trains a client specific model with # some globally shared weights. 
So we don't checkpoint a global model - super().__init__(client_manager=client_manager, fl_config=fl_config, strategy=strategy, checkpointer=None) + super().__init__( + client_manager=client_manager, fl_config=fl_config, strategy=strategy, checkpoint_and_state_module=None + ) self.best_aggregated_loss: Optional[float] = None def evaluate_round( diff --git a/research/picai/fedavg/client.py b/research/picai/fedavg/client.py index 39758d83b..aa6c12628 100644 --- a/research/picai/fedavg/client.py +++ b/research/picai/fedavg/client.py @@ -13,6 +13,7 @@ from torch.optim import Optimizer from torchmetrics.classification import MultilabelAveragePrecision +from fl4health.checkpointing.checkpointer import PerRoundStateCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.basic_client import BasicClient from fl4health.reporting.base_reporter import BaseReporter @@ -38,10 +39,10 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - intermediate_client_state_dir: Optional[Path] = None, + client_name: Optional[str] = None, overviews_dir: Path = Path("./"), data_partition: Optional[int] = None, ) -> None: @@ -50,10 +51,10 @@ def __init__( metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, reporters=reporters, progress_bar=progress_bar, - intermediate_client_state_dir=intermediate_client_state_dir, + client_name=client_name, ) self.data_partition = data_partition @@ -161,11 +162,18 @@ def get_optimizer(self, config: Config) -> Optimizer: ) ] + if args.artifact_dir is not None: + checkpoint_and_state_module = ClientCheckpointAndStateModule( + state_checkpointer=PerRoundStateCheckpointer(Path(args.artifact_dir)) + ) + else: + checkpoint_and_state_module = None + client = PicaiFedAvgClient( data_path=Path(args.base_dir), metrics=metrics, device=device, - intermediate_client_state_dir=args.artifact_dir, + checkpoint_and_state_module=checkpoint_and_state_module, overviews_dir=args.overviews_dir, data_partition=args.data_partition, ) diff --git a/research/picai/fedavg/server.py b/research/picai/fedavg/server.py index 7bc3a48f9..e54ceecb5 100644 --- a/research/picai/fedavg/server.py +++ b/research/picai/fedavg/server.py @@ -1,6 +1,7 @@ import argparse from functools import partial from logging import INFO +from pathlib import Path from typing import Any, Dict import flwr as fl @@ -9,6 +10,8 @@ from flwr.server.client_manager import SimpleClientManager from flwr.server.strategy import FedAvg +from fl4health.checkpointing.checkpointer import PerRoundStateCheckpointer +from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.servers.base_server import FlServer from fl4health.utils.config import load_config @@ -35,7 +38,7 @@ def fit_config( } -def main(config: Dict[str, Any], server_address: str, n_clients: int) -> None: +def main(config: Dict[str, Any], server_address: str, n_clients: int, artifact_dir: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = 
partial( fit_config, @@ -47,7 +50,16 @@ def main(config: Dict[str, Any], server_address: str, n_clients: int) -> None: ) client_manager = SimpleClientManager() + model = get_model() + parameter_exchanger = FullParameterExchanger() + state_checkpointer = PerRoundStateCheckpointer(checkpoint_dir=Path(artifact_dir)) + checkpoint_and_state_module = BaseServerCheckpointAndStateModule( + model=model, + parameter_exchanger=parameter_exchanger, + model_checkpointers=None, + state_checkpointer=state_checkpointer, + ) # Server performs simple FedAveraging as its server-side optimization strategy strategy = FedAvg( @@ -66,10 +78,8 @@ def main(config: Dict[str, Any], server_address: str, n_clients: int) -> None: server = FlServer( client_manager=client_manager, fl_config=config, - model=model, - parameter_exchanger=FullParameterExchanger(), strategy=strategy, - intermediate_server_state_dir=args.artifact_dir, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.server.start_server( @@ -112,4 +122,5 @@ def main(config: Dict[str, Any], server_address: str, n_clients: int) -> None: config = load_config(args.config_path) log(INFO, f"Server Address: {args.server_address}") - main(config, args.server_address, args.n_clients) + log(INFO, f"Artifact Directory: {args.artifact_dir}") + main(config, args.server_address, args.n_clients, args.artifact_dir) diff --git a/research/picai/fl_nnunet/start_client.py b/research/picai/fl_nnunet/start_client.py index c60887ce4..8aa656b33 100644 --- a/research/picai/fl_nnunet/start_client.py +++ b/research/picai/fl_nnunet/start_client.py @@ -6,6 +6,9 @@ from pathlib import Path from typing import Optional, Union +from fl4health.checkpointing.checkpointer import PerRoundStateCheckpointer +from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule + with warnings.catch_warnings(): # Silence deprecation warnings from sentry sdk due to flwr and wandb # https://github.com/adap/flower/issues/4086 @@ -58,6 +61,13 @@ def main( metrics = [dice1, dice2] # Oddly each of these dice metrics is drastically different. 
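+    # State checkpointing is now opt-in: a PerRoundStateCheckpointer is only built when an intermediate state
+    # directory is supplied, and it is handed to the client wrapped in a ClientCheckpointAndStateModule rather
+    # than passing the directory to the client directly.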
+ if intermediate_client_state_dir is not None: + checkpoint_and_state_module = ClientCheckpointAndStateModule( + state_checkpointer=PerRoundStateCheckpointer(Path(intermediate_client_state_dir)) + ) + else: + checkpoint_and_state_module = None + # Create and start client client = NnunetClient( # Args specific to nnUNetClient @@ -72,9 +82,7 @@ def main( device=device, metrics=metrics, progress_bar=verbose, - intermediate_client_state_dir=( - Path(intermediate_client_state_dir) if intermediate_client_state_dir is not None else None - ), + checkpoint_and_state_module=checkpoint_and_state_module, client_name=client_name, ) diff --git a/research/picai/reporting/server.py b/research/picai/reporting/server.py index 987db9a5f..9c9d698a4 100644 --- a/research/picai/reporting/server.py +++ b/research/picai/reporting/server.py @@ -10,6 +10,7 @@ from examples.models.cnn_model import Net from examples.utils.functions import make_dict_with_epochs_or_steps from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer +from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.reporting import WandBReporter from fl4health.servers.base_server import FlServer @@ -48,6 +49,9 @@ def main(config: Dict[str, Any]) -> None: BestLossTorchModuleCheckpointer(config["checkpoint_path"], "best_model.pkl"), LatestTorchModuleCheckpointer(config["checkpoint_path"], "latest_model.pkl"), ] + checkpoint_and_state_module = BaseServerCheckpointAndStateModule( + model=model, parameter_exchanger=parameter_exchanger, model_checkpointers=checkpointers + ) # Server performs simple FedAveraging as its server-side optimization strategy strategy = FedAvg( @@ -77,11 +81,9 @@ def main(config: Dict[str, Any]) -> None: server = FlServer( client_manager=SimpleClientManager(), fl_config=config, - parameter_exchanger=parameter_exchanger, - model=model, reporters=[reporter], + checkpoint_and_state_module=checkpoint_and_state_module, strategy=strategy, - checkpointer=checkpointers, ) fl.server.start_server( diff --git a/tests/servers/test_base_server.py b/tests/servers/test_base_server.py index 4538a0bc1..287740a1c 100644 --- a/tests/servers/test_base_server.py +++ b/tests/servers/test_base_server.py @@ -12,7 +12,8 @@ from flwr.server.strategy import FedAvg from freezegun import freeze_time -from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer +from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, PerRoundStateCheckpointer +from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule from fl4health.client_managers.base_sampling_manager import SimpleClientManager from fl4health.client_managers.poisson_sampling_manager import PoissonSamplingClientManager from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger @@ -28,21 +29,25 @@ model = LinearTransform() -class DummyFLServer(FlServer): - def _hydrate_model_for_checkpointing(self) -> None: - self.server_model = model - - def test_hydration_no_model_with_checkpointer(tmp_path: Path) -> None: # Temporary path to write pkl to, will be cleaned up at the end of the test. 
checkpoint_dir = tmp_path.joinpath("resources") checkpoint_dir.mkdir() checkpointer = BestLossTorchModuleCheckpointer(str(checkpoint_dir), "best_model.pkl") + state_checkpointer = PerRoundStateCheckpointer() + checkpoint_and_state_module = BaseServerCheckpointAndStateModule( + model=None, + parameter_exchanger=None, + model_checkpointers=checkpointer, + state_checkpointer=state_checkpointer, + ) # Checkpointer is defined but there is no server-side model defined to produce a model from the server state. # An assertion error should be throw stating this fl_server_no_hydration = FlServer( - client_manager=PoissonSamplingClientManager(), fl_config={}, checkpointer=checkpointer + client_manager=PoissonSamplingClientManager(), + fl_config={}, + checkpoint_and_state_module=checkpoint_and_state_module, ) with pytest.raises(AssertionError) as assertion_error: fl_server_no_hydration._maybe_checkpoint(1.0, {}, server_round=1) @@ -54,11 +59,16 @@ def test_hydration_no_exchanger_with_checkpointer(tmp_path: Path) -> None: checkpoint_dir = tmp_path.joinpath("resources") checkpoint_dir.mkdir() checkpointer = BestLossTorchModuleCheckpointer(str(checkpoint_dir), "best_model.pkl") + checkpoint_and_state_module = BaseServerCheckpointAndStateModule( + model=model, parameter_exchanger=None, model_checkpointers=checkpointer + ) # Checkpointer is defined but there is no parameter exchanger defined to produce a model from the server state. # An assertion error should be throw stating this fl_server_no_hydration = FlServer( - client_manager=PoissonSamplingClientManager(), fl_config={}, model=model, checkpointer=checkpointer + client_manager=PoissonSamplingClientManager(), + fl_config={}, + checkpoint_and_state_module=checkpoint_and_state_module, ) with pytest.raises(AssertionError) as assertion_error: fl_server_no_hydration._maybe_checkpoint(1.0, {}, server_round=1) @@ -68,7 +78,9 @@ def test_hydration_no_exchanger_with_checkpointer(tmp_path: Path) -> None: def test_no_checkpointer_maybe_checkpoint(caplog: pytest.LogCaptureFixture) -> None: - fl_server_no_checkpointer = FlServer(client_manager=PoissonSamplingClientManager(), fl_config={}) + fl_server_no_checkpointer = FlServer( + client_manager=PoissonSamplingClientManager(), fl_config={}, checkpoint_and_state_module=None + ) # Neither checkpointing nor hydration is defined, we'll have no server-side checkpointing for the FL run. fl_server_no_checkpointer._maybe_checkpoint(1.0, {}, server_round=1) @@ -80,11 +92,16 @@ def test_hydration_and_checkpointer(tmp_path: Path) -> None: checkpoint_dir = tmp_path.joinpath("resources") checkpoint_dir.mkdir() checkpointer = BestLossTorchModuleCheckpointer(str(checkpoint_dir), "best_model.pkl") + checkpoint_and_state_module = BaseServerCheckpointAndStateModule( + model=model, parameter_exchanger=None, model_checkpointers=checkpointer + ) # Server-side hydration to convert server state to model and checkpointing behavior are both defined, a model # should be saved and be loaded successfully. 
- fl_server_both = DummyFLServer( - client_manager=PoissonSamplingClientManager(), fl_config={}, checkpointer=checkpointer + fl_server_both = FlServer( + client_manager=PoissonSamplingClientManager(), + fl_config={}, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl_server_both._maybe_checkpoint(1.0, {}, server_round=5) loaded_model = checkpointer.load_checkpoint() @@ -103,14 +120,15 @@ def test_fl_server_with_checkpointing(tmp_path: Path) -> None: # represents the model computed by the clients aggregation updated_model = LinearTransform() parameter_exchanger = FullParameterExchanger() + checkpoint_and_state_module = BaseServerCheckpointAndStateModule( + model=initial_model, parameter_exchanger=parameter_exchanger, model_checkpointers=checkpointer + ) server = FlServer( client_manager=PoissonSamplingClientManager(), fl_config={}, - parameter_exchanger=parameter_exchanger, - model=initial_model, strategy=None, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) # Parameters after aggregation (i.e. the updated server-side model) server.parameters = ndarrays_to_parameters(parameter_exchanger.push_parameters(updated_model)) diff --git a/tests/smoke_tests/load_from_checkpoint_example/client.py b/tests/smoke_tests/load_from_checkpoint_example/client.py index 662195ba5..f8209c48f 100644 --- a/tests/smoke_tests/load_from_checkpoint_example/client.py +++ b/tests/smoke_tests/load_from_checkpoint_example/client.py @@ -11,6 +11,7 @@ from torch.utils.data import DataLoader from examples.models.cnn_model import Net +from fl4health.checkpointing.checkpointer import PerRoundStateCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.basic_client import BasicClient from fl4health.reporting import JsonReporter @@ -29,23 +30,21 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - intermediate_client_state_dir: Optional[Path] = None, client_name: Optional[str] = None, seed: int = 42, ) -> None: super().__init__( - data_path, - metrics, - device, - loss_meter_type, - checkpointer, - reporters, - progress_bar, - intermediate_client_state_dir, - client_name, + data_path=data_path, + metrics=metrics, + device=device, + loss_meter_type=loss_meter_type, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) self.seed = seed @@ -104,11 +103,18 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict # Set the random seed for reproducibility set_all_random_seeds(args.seed) + if args.intermediate_client_state_dir is not None: + checkpoint_and_state_module = ClientCheckpointAndStateModule( + state_checkpointer=PerRoundStateCheckpointer(Path(args.intermediate_client_state_dir)) + ) + else: + checkpoint_and_state_module = None + client = CifarClient( data_path, [Accuracy("accuracy")], device, - intermediate_client_state_dir=args.intermediate_client_state_dir, + checkpoint_and_state_module=checkpoint_and_state_module, client_name=args.client_name, seed=args.seed, reporters=[JsonReporter()], diff --git a/tests/smoke_tests/load_from_checkpoint_example/server.py 
b/tests/smoke_tests/load_from_checkpoint_example/server.py index c47ab58bc..0d57947b4 100644 --- a/tests/smoke_tests/load_from_checkpoint_example/server.py +++ b/tests/smoke_tests/load_from_checkpoint_example/server.py @@ -10,7 +10,12 @@ from examples.models.cnn_model import Net from examples.utils.functions import make_dict_with_epochs_or_steps -from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer +from fl4health.checkpointing.checkpointer import ( + BestLossTorchModuleCheckpointer, + LatestTorchModuleCheckpointer, + PerRoundStateCheckpointer, +) +from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.reporting import JsonReporter from fl4health.servers.base_server import FlServer @@ -50,6 +55,13 @@ def main(config: Dict[str, Any], intermediate_server_state_dir: str, server_name BestLossTorchModuleCheckpointer(config["checkpoint_path"], "best_model.pkl"), LatestTorchModuleCheckpointer(config["checkpoint_path"], "latest_model.pkl"), ] + state_checkpointer = PerRoundStateCheckpointer(Path(intermediate_server_state_dir)) + checkpoint_and_state_module = BaseServerCheckpointAndStateModule( + model=model, + parameter_exchanger=parameter_exchanger, + model_checkpointers=checkpointers, + state_checkpointer=state_checkpointer, + ) # Server performs simple FedAveraging as its server-side optimization strategy strategy = FedAvg( @@ -73,7 +85,7 @@ def main(config: Dict[str, Any], intermediate_server_state_dir: str, server_name strategy=strategy, checkpointer=checkpointers, reporters=[JsonReporter()], - intermediate_server_state_dir=Path(intermediate_server_state_dir), + checkpoint_and_state_module=checkpoint_and_state_module, server_name=server_name, ) From eeaf36432a7fc98d58460755b75e622cfd871b78 Mon Sep 17 00:00:00 2001 From: David Emerson <43939939+emersodb@users.noreply.github.com> Date: Wed, 27 Nov 2024 09:59:37 -0500 Subject: [PATCH 04/13] Fixing precommit issues/mypy stuff. 
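The broader pattern in this series: FlServer and its subclasses no longer take model, parameter_exchanger, or
checkpointer arguments directly; those pieces are composed into a checkpoint and state module that is passed to
the server instead. A minimal sketch of the new server-side wiring (illustrative only, assuming the Net model from
the examples, a placeholder fl_config dict, a bare FedAvg strategy, and local checkpoint directories):

    from pathlib import Path

    from flwr.server.client_manager import SimpleClientManager
    from flwr.server.strategy import FedAvg

    from examples.models.cnn_model import Net
    from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, PerRoundStateCheckpointer
    from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule
    from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
    from fl4health.servers.base_server import FlServer

    fl_config = {}  # normally loaded via load_config(...)

    # The model and parameter exchanger let the module rehydrate a torch model from the server's raw Flower
    # parameters; model checkpointers save artifacts, while the state checkpointer preserves per-round state
    # so an interrupted run can be restarted.
    checkpoint_and_state_module = BaseServerCheckpointAndStateModule(
        model=Net(),
        parameter_exchanger=FullParameterExchanger(),
        model_checkpointers=BestLossTorchModuleCheckpointer("./checkpoints", "best_model.pkl"),
        state_checkpointer=PerRoundStateCheckpointer(Path("./server_state")),
    )
    server = FlServer(
        client_manager=SimpleClientManager(),
        fl_config=fl_config,
        strategy=FedAvg(),
        checkpoint_and_state_module=checkpoint_and_state_module,
    )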
--- .../ae_examples/cvae_dim_example/server.py | 2 +- .../cvae_examples/conv_cvae_example/server.py | 2 +- .../cvae_examples/mlp_cvae_example/server.py | 2 +- .../ae_examples/fedprox_vae_example/server.py | 6 +- .../instance_level_dp/client.py | 6 +- .../instance_level_dp/server.py | 22 +++--- examples/fedprox_example/server.py | 4 +- .../fedsimclr_pretraining_example/server.py | 8 ++- examples/fenda_ditto_example/client.py | 4 +- examples/nnunet_example/server.py | 17 +++-- .../server.py | 4 +- fl4health/checkpointing/checkpointer.py | 20 ++++-- fl4health/checkpointing/client_module.py | 2 +- .../checkpointing/opacus_checkpointer.py | 2 +- fl4health/checkpointing/server_module.py | 66 ++++++++++++++---- fl4health/clients/apfl_client.py | 3 +- fl4health/clients/basic_client.py | 5 +- fl4health/clients/scaffold_client.py | 15 +++- .../fedprox_server.py | 1 - fl4health/servers/base_server.py | 1 - fl4health/servers/fedpm_server.py | 21 +++--- fl4health/servers/instance_level_dp_server.py | 1 - fl4health/servers/nnunet_server.py | 6 +- fl4health/servers/scaffold_server.py | 15 ++-- fl4health/servers/sparse_coo_server.py | 68 ------------------- fl4health/utils/typing.py | 4 +- .../ag_news/dynamic_layer_exchange/client.py | 12 ++-- .../ag_news/sparse_tensor_exchange/client.py | 12 ++-- research/cifar10/adaptive_pfl/ditto/client.py | 15 ++-- .../cifar10/adaptive_pfl/fedprox/client.py | 15 ++-- .../cifar10/adaptive_pfl/fedprox/server.py | 14 ++-- .../adaptive_pfl/fenda_ditto/client.py | 15 ++-- research/cifar10/adaptive_pfl/mrmtl/client.py | 15 ++-- research/cifar10/ditto/client.py | 15 ++-- research/cifar10/ditto_deep_mmd/client.py | 15 ++-- research/cifar10/ditto_mkmmd/client.py | 15 ++-- research/cifar10/fed_dgga_pfl/ditto/client.py | 15 ++-- research/cifar10/fed_dgga_pfl/fenda/client.py | 15 ++-- .../fed_dgga_pfl/fenda_ditto/client.py | 15 ++-- research/cifar10/fedavg/client.py | 15 ++-- .../flamby/fed_heart_disease/apfl/client.py | 15 ++-- .../flamby/fed_heart_disease/ditto/client.py | 15 ++-- .../fed_heart_disease/fedadam/client.py | 15 ++-- .../fed_heart_disease/fedadam/server.py | 16 ++++- .../flamby/fed_heart_disease/fedavg/client.py | 15 ++-- .../flamby/fed_heart_disease/fedavg/server.py | 16 ++++- .../flamby/fed_heart_disease/fedper/client.py | 15 ++-- .../fed_heart_disease/fedprox/client.py | 15 ++-- .../fed_heart_disease/fedprox/server.py | 14 +++- .../flamby/fed_heart_disease/fenda/client.py | 15 ++-- .../flamby/fed_heart_disease/moon/client.py | 15 ++-- .../flamby/fed_heart_disease/moon/server.py | 16 ++++- .../flamby/fed_heart_disease/perfcl/client.py | 15 ++-- .../fed_heart_disease/scaffold/client.py | 15 ++-- .../fed_heart_disease/scaffold/server.py | 4 -- research/flamby/fed_isic2019/apfl/client.py | 15 ++-- research/flamby/fed_isic2019/ditto/client.py | 15 ++-- .../fed_isic2019/ditto_deep_mmd/client.py | 15 ++-- .../flamby/fed_isic2019/ditto_mkmmd/client.py | 15 ++-- .../flamby/fed_isic2019/fedadam/client.py | 15 ++-- .../flamby/fed_isic2019/fedadam/server.py | 14 +++- research/flamby/fed_isic2019/fedavg/client.py | 15 ++-- research/flamby/fed_isic2019/fedavg/server.py | 14 +++- research/flamby/fed_isic2019/fedper/client.py | 15 ++-- .../flamby/fed_isic2019/fedprox/client.py | 15 ++-- .../flamby/fed_isic2019/fedprox/server.py | 13 +++- research/flamby/fed_isic2019/fenda/client.py | 15 ++-- research/flamby/fed_isic2019/moon/client.py | 15 ++-- research/flamby/fed_isic2019/moon/server.py | 16 ++++- .../fed_isic2019/mr_mtl_mkmmd/client.py | 15 ++-- 
research/flamby/fed_isic2019/perfcl/client.py | 15 ++-- .../flamby/fed_isic2019/scaffold/client.py | 15 ++-- .../flamby/fed_isic2019/scaffold/server.py | 4 -- research/flamby/fed_ixi/apfl/client.py | 15 ++-- research/flamby/fed_ixi/ditto/client.py | 15 ++-- research/flamby/fed_ixi/fedadam/client.py | 15 ++-- research/flamby/fed_ixi/fedadam/server.py | 15 +++- research/flamby/fed_ixi/fedavg/client.py | 15 ++-- research/flamby/fed_ixi/fedavg/server.py | 20 ++++-- research/flamby/fed_ixi/fedper/client.py | 15 ++-- research/flamby/fed_ixi/fedprox/client.py | 15 ++-- research/flamby/fed_ixi/fedprox/server.py | 22 ++++-- research/flamby/fed_ixi/fenda/client.py | 15 ++-- research/flamby/fed_ixi/moon/client.py | 15 ++-- research/flamby/fed_ixi/moon/server.py | 16 ++++- research/flamby/fed_ixi/perfcl/client.py | 15 ++-- research/flamby/fed_ixi/scaffold/client.py | 15 ++-- research/flamby/fed_ixi/scaffold/server.py | 4 -- .../flamby_servers/full_exchange_server.py | 1 - research/picai/fedavg/client.py | 1 + research/picai/fl_nnunet/start_server.py | 19 ++++-- research/picai/single_node_trainer.py | 15 ++-- .../test_per_round_checkpointer.py | 16 +++-- tests/servers/test_base_server.py | 2 +- .../load_from_checkpoint_example/client.py | 1 + .../load_from_checkpoint_example/server.py | 3 - 96 files changed, 850 insertions(+), 417 deletions(-) delete mode 100644 fl4health/servers/sparse_coo_server.py diff --git a/examples/ae_examples/cvae_dim_example/server.py b/examples/ae_examples/cvae_dim_example/server.py index 5a5dd9e78..0ab919672 100644 --- a/examples/ae_examples/cvae_dim_example/server.py +++ b/examples/ae_examples/cvae_dim_example/server.py @@ -50,7 +50,7 @@ def main(config: Dict[str, Any]) -> None: parameter_exchanger = FullParameterExchanger() checkpointer = BestLossTorchModuleCheckpointer(config["checkpoint_path"], "best_model.pkl") checkpoint_and_state_module = BaseServerCheckpointAndStateModule( - model=model, parameter_exchanger=parameter_exchanger, checkpointer=checkpointer + model=model, parameter_exchanger=parameter_exchanger, model_checkpointers=checkpointer ) # Server performs simple FedAveraging as its server-side optimization strategy diff --git a/examples/ae_examples/cvae_examples/conv_cvae_example/server.py b/examples/ae_examples/cvae_examples/conv_cvae_example/server.py index f5ab44d9d..0a6ed0b3c 100644 --- a/examples/ae_examples/cvae_examples/conv_cvae_example/server.py +++ b/examples/ae_examples/cvae_examples/conv_cvae_example/server.py @@ -51,7 +51,7 @@ def main(config: Dict[str, Any]) -> None: parameter_exchanger = FullParameterExchanger() checkpointer = BestLossTorchModuleCheckpointer(config["checkpoint_path"], model_checkpoint_name) checkpoint_and_state_module = BaseServerCheckpointAndStateModule( - model=model, parameter_exchanger=parameter_exchanger, checkpointer=checkpointer + model=model, parameter_exchanger=parameter_exchanger, model_checkpointers=checkpointer ) # Server performs simple FedAveraging as its server-side optimization strategy diff --git a/examples/ae_examples/cvae_examples/mlp_cvae_example/server.py b/examples/ae_examples/cvae_examples/mlp_cvae_example/server.py index 0f3261531..b4ab86523 100644 --- a/examples/ae_examples/cvae_examples/mlp_cvae_example/server.py +++ b/examples/ae_examples/cvae_examples/mlp_cvae_example/server.py @@ -51,7 +51,7 @@ def main(config: Dict[str, Any]) -> None: parameter_exchanger = FullParameterExchanger() checkpointer = BestLossTorchModuleCheckpointer(config["checkpoint_path"], model_checkpoint_name) 
checkpoint_and_state_module = BaseServerCheckpointAndStateModule( - model=model, parameter_exchanger=parameter_exchanger, checkpointer=checkpointer + model=model, parameter_exchanger=parameter_exchanger, model_checkpointers=checkpointer ) # Server performs simple FedAveraging as its server-side optimization strategy diff --git a/examples/ae_examples/fedprox_vae_example/server.py b/examples/ae_examples/fedprox_vae_example/server.py index 055933b56..83c18afe1 100644 --- a/examples/ae_examples/fedprox_vae_example/server.py +++ b/examples/ae_examples/fedprox_vae_example/server.py @@ -10,8 +10,6 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer from fl4health.checkpointing.server_module import AdaptiveConstraintServerCheckpointAndStateModule from fl4health.model_bases.autoencoders_base import VariationalAe -from fl4health.parameter_exchange.packing_exchanger import FullParameterExchangerWithPacking -from fl4health.parameter_exchange.parameter_packer import ParameterPackerAdaptiveConstraint from fl4health.servers.base_server import FlServer from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint from fl4health.utils.config import load_config @@ -49,10 +47,10 @@ def main(config: Dict[str, Any]) -> None: model_checkpoint_name = "best_VAE_model.pkl" # To facilitate checkpointing - parameter_exchanger = FullParameterExchangerWithPacking(ParameterPackerAdaptiveConstraint()) checkpointer = BestLossTorchModuleCheckpointer(config["checkpoint_path"], model_checkpoint_name) + checkpoint_and_state_module = AdaptiveConstraintServerCheckpointAndStateModule( - model=model, parameter_exchanger=parameter_exchanger, model_checkpointers=checkpointer + model=model, model_checkpointers=checkpointer ) # Server performs simple FedAveraging as its server-side optimization strategy and potentially adapts the diff --git a/examples/dp_fed_examples/instance_level_dp/client.py b/examples/dp_fed_examples/instance_level_dp/client.py index 05b754125..b1dfa2e22 100644 --- a/examples/dp_fed_examples/instance_level_dp/client.py +++ b/examples/dp_fed_examples/instance_level_dp/client.py @@ -48,12 +48,14 @@ def get_criterion(self, config: Config) -> _Loss: post_aggregation_checkpointer = BestLossOpacusCheckpointer( checkpoint_dir=checkpoint_dir, checkpoint_name=checkpoint_name ) - checkpointer = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) + checkpoint_and_state_module = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) # Load model and data data_path = Path(args.dataset_path) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - client = CifarClient(data_path, [Accuracy("accuracy")], device, checkpointer=checkpointer) + client = CifarClient( + data_path, [Accuracy("accuracy")], device, checkpoint_and_state_module=checkpoint_and_state_module + ) fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client()) client.shutdown() diff --git a/examples/dp_fed_examples/instance_level_dp/server.py b/examples/dp_fed_examples/instance_level_dp/server.py index 4e3037c6c..f13800116 100644 --- a/examples/dp_fed_examples/instance_level_dp/server.py +++ b/examples/dp_fed_examples/instance_level_dp/server.py @@ -10,6 +10,7 @@ from examples.models.cnn_model import Net from examples.utils.functions import make_dict_with_epochs_or_steps from fl4health.checkpointing.opacus_checkpointer import BestLossOpacusCheckpointer +from fl4health.checkpointing.server_module import 
OpacusServerCheckpointAndStateModule from fl4health.client_managers.poisson_sampling_manager import PoissonSamplingClientManager from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.servers.instance_level_dp_server import InstanceLevelDpServer @@ -67,7 +68,16 @@ def main(config: Dict[str, Any]) -> None: local_steps=config.get("local_steps"), ) - initial_model = map_model_to_opacus_model(Net()) + model = map_model_to_opacus_model(Net()) + + client_name = "".join(choices(string.ascii_uppercase, k=5)) + checkpoint_dir = "examples/dp_fed_examples/instance_level_dp/" + checkpoint_name = f"server_{client_name}_best_model.pkl" + checkpointer = BestLossOpacusCheckpointer(checkpoint_dir=checkpoint_dir, checkpoint_name=checkpoint_name) + + checkpoint_and_state_module = OpacusServerCheckpointAndStateModule( + model=model, parameter_exchanger=FullParameterExchanger(), model_checkpointers=checkpointer + ) # ClientManager that performs Poisson type sampling client_manager = PoissonSamplingClientManager() @@ -75,7 +85,7 @@ def main(config: Dict[str, Any]) -> None: # Server performs simple FedAveraging with Instance Level Differential Privacy # Must be FedAvg sampling to ensure privacy loss is computed correctly strategy = OpacusBasicFedAvg( - model=initial_model, + model=model, fraction_fit=config["client_sampling_rate"], # Server waits for min_available_clients before starting FL rounds min_available_clients=config["n_clients"], @@ -86,22 +96,16 @@ def main(config: Dict[str, Any]) -> None: on_evaluate_config_fn=fit_config_fn, ) - client_name = "".join(choices(string.ascii_uppercase, k=5)) - checkpoint_dir = "examples/dp_fed_examples/instance_level_dp/" - checkpoint_name = f"server_{client_name}_best_model.pkl" - server = InstanceLevelDpServer( client_manager=client_manager, fl_config=config, - model=initial_model, - checkpointer=BestLossOpacusCheckpointer(checkpoint_dir=checkpoint_dir, checkpoint_name=checkpoint_name), - parameter_exchanger=FullParameterExchanger(), strategy=strategy, noise_multiplier=config["noise_multiplier"], local_epochs=config.get("local_epochs"), local_steps=config.get("local_steps"), batch_size=config["batch_size"], num_server_rounds=config["n_server_rounds"], + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.server.start_server( diff --git a/examples/fedprox_example/server.py b/examples/fedprox_example/server.py index 2c6590041..d5003ae65 100644 --- a/examples/fedprox_example/server.py +++ b/examples/fedprox_example/server.py @@ -81,9 +81,7 @@ def main(config: Dict[str, Any], server_address: str) -> None: reporters = [wandb_reporter, json_reporter] else: reporters = [json_reporter] - server = FedProxServer( - client_manager=client_manager, fl_config=config, strategy=strategy, model=None, reporters=reporters - ) + server = FedProxServer(client_manager=client_manager, fl_config=config, strategy=strategy, reporters=reporters) fl.server.start_server( server=server, diff --git a/examples/fedsimclr_example/fedsimclr_pretraining_example/server.py b/examples/fedsimclr_example/fedsimclr_pretraining_example/server.py index 2f63891b6..1e45df1f3 100644 --- a/examples/fedsimclr_example/fedsimclr_pretraining_example/server.py +++ b/examples/fedsimclr_example/fedsimclr_pretraining_example/server.py @@ -11,6 +11,7 @@ from examples.models.ssl_models import CifarSslEncoder, CifarSslPredictionHead, CifarSslProjectionHead from examples.utils.functions import make_dict_with_epochs_or_steps from fl4health.checkpointing.checkpointer import 
BestLossTorchModuleCheckpointer +from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule from fl4health.model_bases.fedsimclr_base import FedSimClrModel from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.servers.base_server import FlServer @@ -51,6 +52,9 @@ def main(config: Dict[str, Any]) -> None: # To facilitate checkpointing parameter_exchanger = FullParameterExchanger() checkpointer = BestLossTorchModuleCheckpointer(config["checkpoint_path"], "best_model.pkl") + checkpoint_and_state_module = BaseServerCheckpointAndStateModule( + model=model, parameter_exchanger=parameter_exchanger, model_checkpointers=checkpointer + ) # Server performs simple FedAveraging as its server-side optimization strategy strategy = FedAvg( @@ -69,10 +73,8 @@ def main(config: Dict[str, Any]) -> None: server = FlServer( client_manager=SimpleClientManager(), fl_config=config, - parameter_exchanger=parameter_exchanger, - model=model, strategy=strategy, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.server.start_server( diff --git a/examples/fenda_ditto_example/client.py b/examples/fenda_ditto_example/client.py index 475ed256d..f90f26a81 100644 --- a/examples/fenda_ditto_example/client.py +++ b/examples/fenda_ditto_example/client.py @@ -114,7 +114,7 @@ def get_criterion(self, config: Config) -> _Loss: args.checkpoint_path, "fenda_ditto_client_post_agg.pkl" ) - checkpointer = ClientCheckpointAndStateModule( + checkpoint_and_state_module = ClientCheckpointAndStateModule( pre_aggregation=pre_aggregation_checkpointer, post_aggregation=post_aggregation_checkpointer, ) @@ -123,7 +123,7 @@ def get_criterion(self, config: Config) -> _Loss: [Accuracy()], device, args.checkpoint_path, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, reporters=[JsonReporter()], ) fl.client.start_client(server_address=args.server_address, client=client.to_client()) diff --git a/examples/nnunet_example/server.py b/examples/nnunet_example/server.py index 29d9b19de..cedae28ac 100644 --- a/examples/nnunet_example/server.py +++ b/examples/nnunet_example/server.py @@ -22,6 +22,8 @@ from flwr.server.strategy import FedAvg from examples.utils.functions import make_dict_with_epochs_or_steps +from fl4health.checkpointing.checkpointer import PerRoundStateCheckpointer +from fl4health.checkpointing.server_module import NnUnetServerCheckpointAndStateModule from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.servers.nnunet_server import NnunetServer from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn @@ -91,17 +93,22 @@ def main( initial_parameters=params, ) + state_checkpointer = ( + PerRoundStateCheckpointer(Path(intermediate_server_state_dir)) + if intermediate_server_state_dir is not None + else None + ) + checkpoint_and_state_module = NnUnetServerCheckpointAndStateModule( + model=None, parameter_exchanger=FullParameterExchanger(), state_checkpointer=state_checkpointer + ) + server = NnunetServer( client_manager=SimpleClientManager(), fl_config=config, # The fit_config_fn contains all of the necessary information for param initialization, so we reuse it here on_init_parameters_config_fn=fit_config_fn, - model=None, - parameter_exchanger=FullParameterExchanger(), strategy=strategy, - intermediate_server_state_dir=( - Path(intermediate_server_state_dir) if intermediate_server_state_dir is not None else None 
- ), + checkpoint_and_state_module=checkpoint_and_state_module, server_name=server_name, ) diff --git a/examples/sparse_tensor_partial_exchange_example/server.py b/examples/sparse_tensor_partial_exchange_example/server.py index f2d18cb82..780b33a28 100644 --- a/examples/sparse_tensor_partial_exchange_example/server.py +++ b/examples/sparse_tensor_partial_exchange_example/server.py @@ -8,7 +8,7 @@ from examples.models.cnn_model import Net from examples.utils.functions import make_dict_with_epochs_or_steps -from fl4health.servers.sparse_coo_server import SparseCooServer +from fl4health.servers.base_server import FlServer from fl4health.strategies.fedavg_sparse_coo_tensor import FedAvgSparseCooTensor from fl4health.utils.config import load_config from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn @@ -59,7 +59,7 @@ def main(config: Dict[str, Any]) -> None: ) client_manager = SimpleClientManager() - server = SparseCooServer(client_manager=client_manager, fl_config=config, strategy=strategy) + server = FlServer(client_manager=client_manager, fl_config=config, strategy=strategy) fl.server.start_server( server=server, diff --git a/fl4health/checkpointing/checkpointer.py b/fl4health/checkpointing/checkpointer.py index 79af6d7bf..2379ce4c1 100644 --- a/fl4health/checkpointing/checkpointer.py +++ b/fl4health/checkpointing/checkpointer.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from logging import ERROR, INFO, WARNING from pathlib import Path -from typing import Any, Callable, Dict, Optional, overload +from typing import Any, Callable, Dict, Optional import torch import torch.nn as nn @@ -40,12 +40,20 @@ def maybe_checkpoint(self, model: nn.Module, loss: float, metrics: Dict[str, Sca """ raise NotImplementedError("maybe_checkpoint must be implemented by inheriting classes") - @overload - def load_checkpoint(self) -> nn.Module: - return torch.load(self.checkpoint_path) + def load_checkpoint(self, path_to_checkpoint: str | None = None) -> nn.Module: + """ + Checkpointer with the option to either specify a checkpoint path or fall back on the internal path of the + checkpointer + + Args: + path_to_checkpoint (str | None, optional): If provided, the checkpoint will be loaded from this path. + If not specified, the checkpointer will load from self.checkpoint_path. Defaults to None. 
- @overload - def load_checkpoint(self, path_to_checkpoint: str) -> nn.Module: + Returns: + nn.Module: _description_ + """ + if path_to_checkpoint is None: + return torch.load(self.checkpoint_path) return torch.load(path_to_checkpoint) diff --git a/fl4health/checkpointing/client_module.py b/fl4health/checkpointing/client_module.py index aa22872dd..76eda7436 100644 --- a/fl4health/checkpointing/client_module.py +++ b/fl4health/checkpointing/client_module.py @@ -136,7 +136,7 @@ def save_state(self, state_checkpoint_name: str, state: Dict[str, Any]) -> None: """ if self.state_checkpointer is not None: - self.state_checkpointer.save_checkpoint(state_checkpoint_name) + self.state_checkpointer.save_checkpoint(state_checkpoint_name, state) else: raise ValueError("Attempting to save state but no state checkpointer is specified") diff --git a/fl4health/checkpointing/opacus_checkpointer.py b/fl4health/checkpointing/opacus_checkpointer.py index 907fe9650..cb0c9a73c 100644 --- a/fl4health/checkpointing/opacus_checkpointer.py +++ b/fl4health/checkpointing/opacus_checkpointer.py @@ -74,7 +74,7 @@ def _extract_and_save_state(self, model: nn.Module) -> None: with open(self.checkpoint_path, "wb") as handle: pickle.dump(model_state_dict, handle, protocol=pickle.HIGHEST_PROTOCOL) - def load_checkpoint(self) -> nn.Module: + def load_checkpoint(self, path_to_checkpoint: str | None = None) -> nn.Module: raise NotImplementedError( "When loading from Opacus checkpointers, you need to provide a model into which state is loaded. " "Please use load_best_checkpoint_into_model instead and provide model architecture to load state into." diff --git a/fl4health/checkpointing/server_module.py b/fl4health/checkpointing/server_module.py index 55af4484d..fa91ff7b5 100644 --- a/fl4health/checkpointing/server_module.py +++ b/fl4health/checkpointing/server_module.py @@ -5,7 +5,7 @@ from flwr.common import Parameters from flwr.common.logger import log from flwr.common.parameter import parameters_to_ndarrays -from flwr.common.typing import NDArrays, Scalar +from flwr.common.typing import Scalar from fl4health.checkpointing.checkpointer import PerRoundStateCheckpointer, TorchModuleCheckpointer from fl4health.checkpointing.opacus_checkpointer import OpacusCheckpointer @@ -170,7 +170,7 @@ def save_state( other_state["model"] = self.model else: raise ValueError("Key 'model' already exists in the other_state dictionary.") - self.state_checkpointer.save_checkpoint(state_checkpoint_name) + self.state_checkpointer.save_checkpoint(state_checkpoint_name, checkpoint_dict=other_state) else: raise ValueError("Attempting to save state but no state checkpointer is specified") @@ -239,7 +239,7 @@ def __init__( ), "Parameter exchanger must be of based type FullParameterExchangerWithPacking" super().__init__(model, parameter_exchanger, model_checkpointers, state_checkpointer) - def _hydrate_model_for_checkpointing(self, server_parameters: Parameters): + def _hydrate_model_for_checkpointing(self, server_parameters: Parameters) -> None: """ This function is used as a means of saving the server-side model after aggregation in the FL training trajectory. Presently, the server only holds Flower Parameters, which are essentially just ndarrays. Without @@ -294,7 +294,7 @@ def __init__( checkpointer will save much more than just the model being trained. Defaults to None. 
""" if model is not None: - model_size = len(self.model.state_dict()) + model_size = len(model.state_dict()) parameter_exchanger = FullParameterExchangerWithPacking(ParameterPackerWithControlVariates(model_size)) else: parameter_exchanger = None @@ -397,7 +397,7 @@ def __init__( checkpointer will save much more than just the model being trained. Defaults to None. """ if model is not None: - parameter_exchanger = FullParameterExchangerWithPacking(ParameterPackerWithClippingBit()) + parameter_exchanger = FullParameterExchangerWithPacking(ParameterPackerWithLayerNames()) else: parameter_exchanger = None super().__init__(model, parameter_exchanger, model_checkpointers, state_checkpointer) @@ -447,13 +447,14 @@ def __init__( state_checkpointer: PerRoundStateCheckpointer | None = None, ) -> None: """ - This module is meant to handle FL flows with adaptive constraints, where the server and client communicate - a loss weight parameter in addition to the model weights. Unlike the module on the client side, this module - has no concept of pre- or post-aggregation checkpointing. It only considers checkpointing the global server - model after aggregation, perhaps based on validation statistics retrieved on the client side by running a - federated evaluation step. Multiple model checkpointers may be used. For state checkpointing, which saves the - state of the entire server-side FL process to help with FL restarts, we allow only a single checkpointer - responsible for saving the state after each fit and eval round of FL. + This module is meant to handle FL flows with Opacus models where special treatment by the checkpointers is + required. This module simply ensures the checkpointers are of the proper type before proceeding. + Unlike the module on the client side, this module has no concept of pre- or post-aggregation checkpointing. + It only considers checkpointing the global server model after aggregation, perhaps based on validation + statistics retrieved on the client side by running a federated evaluation step. Multiple model checkpointers + may be used. For state checkpointing, which saves the state of the entire server-side FL process to help with + FL restarts, we allow only a single checkpointer responsible for saving the state after each fit and eval + round of FL. Args: model (nn.Module | None, optional): Model architecture to be saved. The module will use this architecture @@ -536,3 +537,44 @@ def __init__( "possible and desired" ) self._check_if_shared_checkpoint_names() + + +class DpScaffoldServerCheckpointAndStateModule(ScaffoldServerCheckpointAndStateModule): + def __init__( + self, + model: nn.Module | None = None, + model_checkpointers: CheckpointModuleInput = None, + state_checkpointer: PerRoundStateCheckpointer | None = None, + ) -> None: + """ + This module is meant to handle DP SCAFFOLD model and state checkpointing on the server-side of an FL process. + Unlike the module on the client side, this module has no concept of pre- or post-aggregation checkpointing. + It only considers checkpointing the global server model after aggregation, perhaps based on validation + statistics retrieved on the client side by running a federated evaluation step. Multiple model checkpointers + may be used. For state checkpointing, which saves the state of the entire server-side FL process to help with + FL restarts, we allow only a single checkpointer responsible for saving the state after each fit and eval + round of FL. + + Args: + model (nn.Module | None, optional): Model architecture to be saved. 
The module will use this architecture + to hold the server parameters and facilitate checkpointing with the help of the parameter exchanger. + Recall that servers only have parameters rather than torch models. So we need to know where to route + these parameters to allow for real models to be saved. Defaults to None. + model_checkpointers (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of + checkpointers) is used to checkpoint models based on their defined scoring function. Defaults to None. + state_checkpointer (PerRoundStateCheckpointer | None, optional): If defined, this checkpointer will be + used to preserve FL training state to facilitate restarting training if interrupted. Generally, this + checkpointer will save much more than just the model being trained. Defaults to None. + """ + super().__init__(model, model_checkpointers, state_checkpointer) + self._ensure_checkpointers_are_of_opacus_type() + + def _ensure_checkpointers_are_of_opacus_type(self) -> None: + """ + Helper function to ensure that the provided checkpointers are explicitly compatible with Opacus + """ + if self.model_checkpointers is not None: + for checkpointer in self.model_checkpointers: + assert isinstance( + checkpointer, OpacusCheckpointer + ), "Provided checkpointers must have base class OpacusCheckpointer" diff --git a/fl4health/clients/apfl_client.py b/fl4health/clients/apfl_client.py index 351418ea8..32338fac9 100644 --- a/fl4health/clients/apfl_client.py +++ b/fl4health/clients/apfl_client.py @@ -30,7 +30,8 @@ def __init__( """ Client specifically implementing the APFL Algorithm: https://arxiv.org/abs/2003.13461 Twin models are trained. One of them is globally shared by all clients and aggregated on the server. - The other is strictly trained locally by each client. Predictions are made by a convex combination of the models. + The other is strictly trained locally by each client. Predictions are made by a convex combination of the + models. Args: data_path (Path): path to the data to be used to load the data for client-side training diff --git a/fl4health/clients/basic_client.py b/fl4health/clients/basic_client.py index 9b15c21b2..170eb2996 100644 --- a/fl4health/clients/basic_client.py +++ b/fl4health/clients/basic_client.py @@ -83,7 +83,7 @@ def __init__( else: # Define a default module that does nothing. self.checkpoint_and_state_module = ClientCheckpointAndStateModule( - model=None, parameter_exchanger=None, model_checkpointers=None, state_checkpointer=None + pre_aggregation=None, post_aggregation=None, state_checkpointer=None ) # Initialize reporters with client information. @@ -401,7 +401,8 @@ def _should_evaluate_after_fit(self, evaluate_after_fit: bool) -> bool: bool: Whether to perform an evaluation on the client validation set after fitting. 
""" pre_aggregation_checkpointing_enabled = ( - self.checkpointer is not None and self.checkpointer.pre_aggregation is not None + self.checkpoint_and_state_module is not None + and self.checkpoint_and_state_module.pre_aggregation is not None ) return evaluate_after_fit or pre_aggregation_checkpointing_enabled diff --git a/fl4health/clients/scaffold_client.py b/fl4health/clients/scaffold_client.py index e159b2e44..ee373563f 100644 --- a/fl4health/clients/scaffold_client.py +++ b/fl4health/clients/scaffold_client.py @@ -279,7 +279,10 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: ScaffoldClient.__init__( self, @@ -287,7 +290,10 @@ def __init__( metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) InstanceLevelDpClient.__init__( @@ -296,5 +302,8 @@ def __init__( metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) diff --git a/fl4health/servers/adaptive_constraint_servers/fedprox_server.py b/fl4health/servers/adaptive_constraint_servers/fedprox_server.py index 20e7bc67c..a17e96b3a 100644 --- a/fl4health/servers/adaptive_constraint_servers/fedprox_server.py +++ b/fl4health/servers/adaptive_constraint_servers/fedprox_server.py @@ -4,7 +4,6 @@ from flwr.server.client_manager import ClientManager from fl4health.checkpointing.server_module import AdaptiveConstraintServerCheckpointAndStateModule -from fl4health.parameter_exchange.packing_exchanger import FullParameterExchangerWithPacking from fl4health.reporting.base_reporter import BaseReporter from fl4health.servers.base_server import FlServer from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint diff --git a/fl4health/servers/base_server.py b/fl4health/servers/base_server.py index 63024baaa..ad6d7cd9e 100644 --- a/fl4health/servers/base_server.py +++ b/fl4health/servers/base_server.py @@ -1,6 +1,5 @@ import datetime from logging import DEBUG, ERROR, INFO, WARNING -from pathlib import Path from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union import torch.nn as nn diff --git a/fl4health/servers/fedpm_server.py b/fl4health/servers/fedpm_server.py index 2a35411ec..dae383a2e 100644 --- a/fl4health/servers/fedpm_server.py +++ b/fl4health/servers/fedpm_server.py @@ -20,10 +20,10 @@ def __init__( strategy: FedPm, reporters: Sequence[BaseReporter] | None = None, checkpoint_and_state_module: LayerNamesServerCheckpointAndStateModule | None = None, - reset_frequency: int = 1, on_init_parameters_config_fn: Callable[[int], Dict[str, Scalar]] | None = None, server_name: str | None = None, accept_failures: bool = True, + reset_frequency: int = 1, ) -> None: """ Custom FL Server for the FedPM algorithm to allow for resetting the beta priors in Bayesian aggregation, @@ -32,20 +32,16 @@ def __init__( Args: client_manager (ClientManager): Determines the mechanism by which clients are sampled by the server, 
if they are to be sampled at all. - fl_config (Config): This should be the configuration that was used to setup the federated training. - In most cases it should be the "source of truth" for how FL training/evaluation should proceed. For - example, the config used to produce the on_fit_config_fn and on_evaluate_config_fn for the strategy. - NOTE: This config is DISTINCT from the Flwr server config, which is extremely minimal. + fl_config (Config): This should be the configuration that was used to setup the federated training. + In most cases it should be the "source of truth" for how FL training/evaluation should proceed. For + example, the config used to produce the on_fit_config_fn and on_evaluate_config_fn for the strategy. + NOTE: This config is DISTINCT from the Flwr server config, which is extremely minimal. strategy (FedPm): The aggregation strategy to be used by the server to handle client updates and other - information potentially sent by the participating clients. This strategy must be of FedPm type. - reporters (Sequence[BaseReporter], optional): A sequence of FL4Health reporters which the server should - send data to before and after each round. - checkpoint_and_state_module (BaseServerCheckpointAndStateModule | None, optional): This module is used - to handle both model checkpointing and state checkpointing. The former is aimed at saving model + information potentially sent by the participating clients. This strategy must be of FedPm type. + reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health reporters which the server + should send data to before and after each round. + checkpoint_and_state_module (LayerNamesServerCheckpointAndStateModule | None, optional): This module is + used to handle both model checkpointing and state checkpointing. The former is aimed at saving model artifacts to be used or evaluated after training. The latter is used to preserve training state (including models) such that if FL training is interrupted, the process may be restarted. If no module is provided, no checkpointing or state preservation will happen. Defaults to None. - reset_frequency (int): Determines the frequency with which the beta priors are reset. Defaults to 1. on_init_parameters_config_fn (Callable[[int], Dict[str, Scalar]] | None, optional): Function used to configure how one asks a client to provide parameters from which to initialize all other clients by providing a Config dictionary. If this is none, then a blank config is sent with the parameter request @@ -55,9 +51,10 @@ def __init__( accept_failures (bool, optional): Determines whether the server should accept failures during training or evaluation from clients or not. If set to False, this will cause the server to shutdown all clients and throw an exception. Defaults to True. + reset_frequency (int, optional): Determines the frequency with which the beta priors are reset. + Defaults to 1.
""" super().__init__( - self, client_manager=client_manager, fl_config=fl_config, strategy=strategy, diff --git a/fl4health/servers/instance_level_dp_server.py b/fl4health/servers/instance_level_dp_server.py index 74b4a50ce..16cca8db5 100644 --- a/fl4health/servers/instance_level_dp_server.py +++ b/fl4health/servers/instance_level_dp_server.py @@ -3,7 +3,6 @@ from math import ceil from typing import Callable, Dict, List, Optional, Tuple -import torch.nn as nn from flwr.common.logger import log from flwr.common.typing import Config, Scalar from flwr.server.client_manager import ClientManager diff --git a/fl4health/servers/nnunet_server.py b/fl4health/servers/nnunet_server.py index 0af362efd..c8f9019b6 100644 --- a/fl4health/servers/nnunet_server.py +++ b/fl4health/servers/nnunet_server.py @@ -2,7 +2,6 @@ import warnings from collections.abc import Callable, Sequence from logging import INFO -from pathlib import Path from typing import Any, Dict, Optional, Tuple, Type, Union import torch.nn as nn @@ -15,7 +14,6 @@ from flwr.server.strategy import Strategy from fl4health.checkpointing.server_module import NnUnetServerCheckpointAndStateModule -from fl4health.parameter_exchange.parameter_exchanger_base import ExchangerType from fl4health.reporting.base_reporter import BaseReporter from fl4health.reporting.reports_manager import ReportsManager from fl4health.servers.base_server import FlServer @@ -107,7 +105,6 @@ def __init__( Must match the nnunet_trainer_class passed to the NnunetClient. """ super().__init__( - self, client_manager=client_manager, fl_config=fl_config, strategy=strategy, @@ -252,7 +249,7 @@ def _save_server_state(self) -> None: } self.checkpoint_and_state_module.save_state( - state_checkpoint=self.state_checkpoint_name, + state_checkpoint_name=self.state_checkpoint_name, server_parameters=self.parameters, other_state=other_state_to_save, ) @@ -287,3 +284,4 @@ def _load_server_state(self) -> bool: self, server_state, "enable_deep_supervision", "enable_deep_supervision", bool ) narrow_dict_type_and_set_attribute(self, server_state, "nnunet_config", "nnunet_config", NnunetConfig) + return True diff --git a/fl4health/servers/scaffold_server.py b/fl4health/servers/scaffold_server.py index 2aa688235..2d3bf9416 100644 --- a/fl4health/servers/scaffold_server.py +++ b/fl4health/servers/scaffold_server.py @@ -9,7 +9,10 @@ from flwr.server.history import History from flwr.server.server import fit_clients -from fl4health.checkpointing.server_module import ScaffoldServerCheckpointAndStateModule +from fl4health.checkpointing.server_module import ( + DpScaffoldServerCheckpointAndStateModule, + ScaffoldServerCheckpointAndStateModule, +) from fl4health.reporting.base_reporter import BaseReporter from fl4health.servers.base_server import FlServer from fl4health.servers.instance_level_dp_server import InstanceLevelDpServer @@ -65,7 +68,6 @@ def __init__( Defaults to False. 
""" super().__init__( - self, client_manager=client_manager, fl_config=fl_config, strategy=strategy, @@ -184,7 +186,7 @@ def __init__( local_epochs: Optional[int] = None, local_steps: Optional[int] = None, delta: Optional[float] = None, - checkpoint_and_state_module: ScaffoldServerCheckpointAndStateModule | None = None, + checkpoint_and_state_module: DpScaffoldServerCheckpointAndStateModule | None = None, warm_start: bool = False, reporters: Sequence[BaseReporter] | None = None, on_init_parameters_config_fn: Callable[[int], Dict[str, Scalar]] | None = None, @@ -215,8 +217,8 @@ def __init__( strategy (Scaffold): The aggregation strategy to be used by the server to handle client updates and other information potentially sent by the participating clients. This strategy must be of SCAFFOLD type. - checkpoint_and_state_module (ScaffoldServerCheckpointAndStateModule | None, optional): This module is used - to handle both model checkpointing and state checkpointing. The former is aimed at saving model + checkpoint_and_state_module (DpScaffoldServerCheckpointAndStateModule | None, optional): This module is + used to handle both model checkpointing and state checkpointing. The former is aimed at saving model artifacts to be used or evaluated after training. The later is used to preserve training state (including models) such that if FL training is interrupted, the process may be restarted. If no module is provided, no checkpointing or state preservation will happen. Defaults to None. @@ -245,10 +247,10 @@ def __init__( client_manager=client_manager, fl_config=fl_config, strategy=strategy, - checkpoint_and_state_module=checkpoint_and_state_module, warm_start=warm_start, reporters=reporters, on_init_parameters_config_fn=on_init_parameters_config_fn, + checkpoint_and_state_module=checkpoint_and_state_module, server_name=server_name, accept_failures=accept_failures, ) @@ -263,7 +265,6 @@ def __init__( batch_size=batch_size, delta=delta, num_server_rounds=num_server_rounds, - checkpoint_and_state_module=checkpoint_and_state_module, on_init_parameters_config_fn=on_init_parameters_config_fn, server_name=server_name, accept_failures=accept_failures, diff --git a/fl4health/servers/sparse_coo_server.py b/fl4health/servers/sparse_coo_server.py deleted file mode 100644 index ed8129a09..000000000 --- a/fl4health/servers/sparse_coo_server.py +++ /dev/null @@ -1,68 +0,0 @@ -from collections.abc import Sequence -from typing import Callable, Dict, Optional, Tuple - -from flwr.common import Parameters -from flwr.common.typing import Config, Scalar -from flwr.server.client_manager import ClientManager -from flwr.server.server import FitResultsAndFailures - -from fl4health.checkpointing.server_module import SparseCooServerCheckpointAndStateModule -from fl4health.reporting.base_reporter import BaseReporter -from fl4health.servers.base_server import FlServer -from fl4health.strategies.fedpm import FedPm - - -class SparseCooServer(FlServer): - def __init__( - self, - client_manager: ClientManager, - fl_config: Config, - strategy: FedPm, - reporters: Sequence[BaseReporter] | None = None, - checkpoint_and_state_module: SparseCooServerCheckpointAndStateModule | None = None, - on_init_parameters_config_fn: Callable[[int], Dict[str, Scalar]] | None = None, - server_name: str | None = None, - accept_failures: bool = True, - ) -> None: - """ - Custom FL Server for the FedPM algorithm to allow for resetting the beta priors in Bayesian aggregation, - as specified in http://arxiv.org/pdf/2209.15328. 
- - Args: - client_manager (ClientManager): Determines the mechanism by which clients are sampled by the server, if - they are to be sampled at all. - fl_config (Config): This should be the configuration that was used to setup the federated training. - In most cases it should be the "source of truth" for how FL training/evaluation should proceed. For - example, the config used to produce the on_fit_config_fn and on_evaluate_config_fn for the strategy. - NOTE: This config is DISTINCT from the Flwr server config, which is extremely minimal. - strategy (FedPm): The aggregation strategy to be used by the server to handle client updates and other - information potentially sent by the participating clients. This strategy must be of FedPm type. - reporters (Sequence[BaseReporter], optional): A sequence of FL4Health reporters which the server should - send data to before and after each round. - checkpoint_and_state_module (SparseCooServerCheckpointAndStateModule | None, optional): This module is used - to handle both model checkpointing and state checkpointing. The former is aimed at saving model - artifacts to be used or evaluated after training. The later is used to preserve training state - (including models) such that if FL training is interrupted, the process may be restarted. If no - module is provided, no checkpointing or state preservation will happen. Defaults to None. - reset_frequency (int): Determines the frequency with which the beta priors are reset. Defaults to 1. - on_init_parameters_config_fn (Callable[[int], Dict[str, Scalar]] | None, optional): Function used to - configure how one asks a client to provide parameters from which to initialize all other clients by - providing a Config dictionary. If this is none, then a blank config is sent with the parameter request - (which is default behavior for flower servers). Defaults to None. - server_name (str | None, optional): An optional string name to uniquely identify server. This name is also - used as part of any state checkpointing done by the server. Defaults to None. - accept_failures (bool, optional): Determines whether the server should accept failures during training or - evaluation from clients or not. If set to False, this will cause the server to shutdown all clients - and throw an exception. Defaults to True. 
- """ - super().__init__( - self, - client_manager=client_manager, - fl_config=fl_config, - strategy=strategy, - reporters=reporters, - checkpoint_and_state_module=checkpoint_and_state_module, - on_init_parameters_config_fn=on_init_parameters_config_fn, - server_name=server_name, - accept_failures=accept_failures, - ) diff --git a/fl4health/utils/typing.py b/fl4health/utils/typing.py index ee0149d66..a298bd86d 100644 --- a/fl4health/utils/typing.py +++ b/fl4health/utils/typing.py @@ -1,7 +1,7 @@ import logging from collections.abc import Callable from enum import Enum -from typing import List, Tuple, TypeVar, Union +from typing import List, Tuple, Union import torch import torch.nn as nn @@ -9,8 +9,6 @@ from flwr.common.typing import NDArrays from flwr.server.client_proxy import ClientProxy -from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger - TorchInputType = torch.Tensor | dict[str, torch.Tensor] TorchTargetType = torch.Tensor | dict[str, torch.Tensor] TorchPredType = dict[str, torch.Tensor] diff --git a/research/ag_news/dynamic_layer_exchange/client.py b/research/ag_news/dynamic_layer_exchange/client.py index 844b8a972..4dc920aee 100644 --- a/research/ag_news/dynamic_layer_exchange/client.py +++ b/research/ag_news/dynamic_layer_exchange/client.py @@ -38,8 +38,10 @@ def __init__( exchange_percentage: float, norm_threshold: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, store_initial_model: bool = True, ) -> None: super().__init__( @@ -47,8 +49,10 @@ def __init__( metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, store_initial_model=store_initial_model, ) assert 0 < exchange_percentage <= 1.0 and norm_threshold > 0 @@ -180,7 +184,7 @@ def predict(self, input: TorchInputType) -> Tuple[Dict[str, torch.Tensor], Dict[ # Checkpointing checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointAndStateModule( + checkpoint_and_state_module = ClientCheckpointAndStateModule( post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) @@ -191,7 +195,7 @@ def predict(self, input: TorchInputType) -> Tuple[Dict[str, torch.Tensor], Dict[ learning_rate=args.learning_rate, exchange_percentage=args.exchange_percentage, norm_threshold=args.norm_threshold, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) # grpc_max_message_length is reset here so the entire model can be exchanged between the server and clients. # Note that the server must be started with the same grpc_max_message_length. 
Otherwise communication diff --git a/research/ag_news/sparse_tensor_exchange/client.py b/research/ag_news/sparse_tensor_exchange/client.py index 451b999a9..2b1aba8d9 100644 --- a/research/ag_news/sparse_tensor_exchange/client.py +++ b/research/ag_news/sparse_tensor_exchange/client.py @@ -37,8 +37,10 @@ def __init__( learning_rate: float, sparsity_level: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, store_initial_model: bool = True, ) -> None: super().__init__( @@ -46,8 +48,10 @@ def __init__( metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, store_initial_model=store_initial_model, ) self.sparsity_level = sparsity_level @@ -149,7 +153,7 @@ def predict(self, input: TorchInputType) -> Tuple[Dict[str, torch.Tensor], Dict[ # Checkpointing checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointAndStateModule( + checkpoint_and_state_module = ClientCheckpointAndStateModule( post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) @@ -159,7 +163,7 @@ def predict(self, input: TorchInputType) -> Tuple[Dict[str, torch.Tensor], Dict[ device, learning_rate=args.learning_rate, sparsity_level=args.sparsity_level, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) # grpc_max_message_length is reset here so the entire model can be exchanged between the server and clients. # Note that the server must be started with the same grpc_max_message_length. 
Otherwise communication diff --git a/research/cifar10/adaptive_pfl/ditto/client.py b/research/cifar10/adaptive_pfl/ditto/client.py index fce7a2236..adf14587d 100644 --- a/research/cifar10/adaptive_pfl/ditto/client.py +++ b/research/cifar10/adaptive_pfl/ditto/client.py @@ -16,6 +16,7 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.ditto_client import DittoClient +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.config import narrow_dict_type from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import F1, Accuracy, Metric @@ -34,14 +35,20 @@ def __init__( learning_rate: float, heterogeneity_level: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) self.client_number = client_number self.heterogeneity_level = heterogeneity_level @@ -138,7 +145,7 @@ def get_model(self, config: Config) -> nn.Module: pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl" post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl" post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl" - checkpointer = ClientCheckpointAndStateModule( + checkpoint_and_state_module = ClientCheckpointAndStateModule( pre_aggregation=[ BestLossTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), LatestTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), @@ -161,7 +168,7 @@ def get_model(self, config: Config) -> nn.Module: client_number=args.client_number, learning_rate=args.learning_rate, heterogeneity_level=args.beta, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.client.start_client(server_address=args.server_address, client=client.to_client()) diff --git a/research/cifar10/adaptive_pfl/fedprox/client.py b/research/cifar10/adaptive_pfl/fedprox/client.py index 72e3736f6..3c8de5279 100644 --- a/research/cifar10/adaptive_pfl/fedprox/client.py +++ b/research/cifar10/adaptive_pfl/fedprox/client.py @@ -16,6 +16,7 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.fed_prox_client import FedProxClient +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.config import narrow_dict_type from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import F1, Accuracy, Metric @@ -34,14 +35,20 @@ def __init__( learning_rate: float, heterogeneity_level: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: 
Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) self.client_number = client_number self.heterogeneity_level = heterogeneity_level @@ -142,7 +149,7 @@ def get_model(self, config: Config) -> nn.Module: pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl" post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl" post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl" - checkpointer = ClientCheckpointAndStateModule( + checkpoint_and_state_module = ClientCheckpointAndStateModule( pre_aggregation=[ BestLossTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), LatestTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), @@ -165,7 +172,7 @@ def get_model(self, config: Config) -> nn.Module: client_number=args.client_number, learning_rate=args.learning_rate, heterogeneity_level=args.beta, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.client.start_client(server_address=args.server_address, client=client.to_client()) diff --git a/research/cifar10/adaptive_pfl/fedprox/server.py b/research/cifar10/adaptive_pfl/fedprox/server.py index 16667615f..e204aa575 100644 --- a/research/cifar10/adaptive_pfl/fedprox/server.py +++ b/research/cifar10/adaptive_pfl/fedprox/server.py @@ -10,6 +10,7 @@ from flwr.server.client_manager import SimpleClientManager from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer +from fl4health.checkpointing.server_module import AdaptiveConstraintServerCheckpointAndStateModule from fl4health.servers.adaptive_constraint_servers.fedprox_server import FedProxServer from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint from fl4health.utils.config import load_config @@ -51,6 +52,10 @@ def main( config["n_server_rounds"], config["n_clients"], ) + + # Initializing the model on the server side + model = ConvNet(in_channels=3, use_bn=False, dropout=0.1, hidden=512) + checkpoint_dir = os.path.join(checkpoint_stub, run_name) best_checkpoint_name = "server_best_model.pkl" last_checkpoint_name = "server_last_model.pkl" @@ -58,10 +63,12 @@ def main( BestLossTorchModuleCheckpointer(checkpoint_dir, best_checkpoint_name), LatestTorchModuleCheckpointer(checkpoint_dir, last_checkpoint_name), ] + checkpoint_and_state_module = AdaptiveConstraintServerCheckpointAndStateModule( + model=model, model_checkpointers=checkpointer + ) client_manager = SimpleClientManager() - # Initializing the model on the server side - model = ConvNet(in_channels=3, use_bn=False, dropout=0.1, hidden=512) + # Server performs simple FedAveraging as its server-side optimization strategy strategy = FedAvgWithAdaptiveConstraint( min_fit_clients=config["n_clients"], @@ -81,9 +88,8 @@ def main( server = FedProxServer( client_manager=client_manager, fl_config=config, - model=model, strategy=strategy, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.server.start_server( 
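The server scripts above all follow the same migration pattern introduced by this change: the model is instantiated up front, wrapped together with its checkpointers in a server checkpoint-and-state module, and that module is handed to the server in place of the former model=/checkpointer= arguments. Below is a minimal sketch of that wiring, not one of the repository's actual scripts; the nn.Linear stand-in model, the artifact directory, and the file names are illustrative assumptions, while the class names and keyword arguments come from the diffs above.

import torch.nn as nn

from fl4health.checkpointing.checkpointer import (
    BestLossTorchModuleCheckpointer,
    LatestTorchModuleCheckpointer,
)
from fl4health.checkpointing.server_module import AdaptiveConstraintServerCheckpointAndStateModule

# Stand-in architecture for illustration only; the research scripts construct their own
# ConvNet/Baseline models before building the module.
model = nn.Linear(10, 2)

# Model checkpointers score and persist the server-side model, as in the server diffs above.
model_checkpointers = [
    BestLossTorchModuleCheckpointer("artifacts/example_run", "server_best_model.pkl"),
    LatestTorchModuleCheckpointer("artifacts/example_run", "server_last_model.pkl"),
]

# The module holds the model so that aggregated server parameters can be routed into a real
# torch model when saving checkpoints.
checkpoint_and_state_module = AdaptiveConstraintServerCheckpointAndStateModule(
    model=model, model_checkpointers=model_checkpointers
)

# The module then replaces the old model=/checkpointer= keywords on the server, e.g.
# FedProxServer(client_manager=..., fl_config=..., strategy=...,
#               checkpoint_and_state_module=checkpoint_and_state_module).

This keeps checkpointing concerns out of the server constructor itself and lets a single module cover both model artifacts (best/latest) and, where configured, per-round training state.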
diff --git a/research/cifar10/adaptive_pfl/fenda_ditto/client.py b/research/cifar10/adaptive_pfl/fenda_ditto/client.py index 670e78eb8..447e2b9b0 100644 --- a/research/cifar10/adaptive_pfl/fenda_ditto/client.py +++ b/research/cifar10/adaptive_pfl/fenda_ditto/client.py @@ -17,6 +17,7 @@ from fl4health.clients.fenda_ditto_client import FendaDittoClient from fl4health.model_bases.fenda_base import FendaModel from fl4health.model_bases.sequential_split_models import SequentiallySplitModel +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.config import narrow_dict_type from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import F1, Accuracy, Metric @@ -35,7 +36,10 @@ def __init__( learning_rate: float, heterogeneity_level: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, freeze_global_feature_extractor: bool = False, ) -> None: super().__init__( @@ -43,7 +47,10 @@ def __init__( metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, freeze_global_feature_extractor=freeze_global_feature_extractor, ) self.client_number = client_number @@ -152,7 +159,7 @@ def get_global_model(self, config: Config) -> SequentiallySplitModel: pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl" post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl" post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl" - checkpointer = ClientCheckpointAndStateModule( + checkpoint_and_state_module = ClientCheckpointAndStateModule( pre_aggregation=[ BestLossTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), LatestTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), @@ -175,7 +182,7 @@ def get_global_model(self, config: Config) -> SequentiallySplitModel: client_number=args.client_number, learning_rate=args.learning_rate, heterogeneity_level=args.beta, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, freeze_global_feature_extractor=args.freeze_global_extractor, ) diff --git a/research/cifar10/adaptive_pfl/mrmtl/client.py b/research/cifar10/adaptive_pfl/mrmtl/client.py index 6899072c8..2e9135f9f 100644 --- a/research/cifar10/adaptive_pfl/mrmtl/client.py +++ b/research/cifar10/adaptive_pfl/mrmtl/client.py @@ -16,6 +16,7 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.mr_mtl_client import MrMtlClient +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.config import narrow_dict_type from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import F1, Accuracy, Metric @@ -34,14 +35,20 @@ def __init__( learning_rate: float, heterogeneity_level: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + 
checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) self.client_number = client_number self.heterogeneity_level = heterogeneity_level @@ -136,7 +143,7 @@ def get_model(self, config: Config) -> nn.Module: pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl" post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl" post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl" - checkpointer = ClientCheckpointAndStateModule( + checkpoint_and_state_module = ClientCheckpointAndStateModule( pre_aggregation=[ BestLossTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), LatestTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), @@ -159,7 +166,7 @@ def get_model(self, config: Config) -> nn.Module: client_number=args.client_number, learning_rate=args.learning_rate, heterogeneity_level=args.beta, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.client.start_client(server_address=args.server_address, client=client.to_client()) diff --git a/research/cifar10/ditto/client.py b/research/cifar10/ditto/client.py index 2224bfe14..15328b301 100644 --- a/research/cifar10/ditto/client.py +++ b/research/cifar10/ditto/client.py @@ -16,6 +16,7 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.ditto_client import DittoClient +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.config import narrow_dict_type from fl4health.utils.load_data import load_cifar10_data, load_cifar10_test_data from fl4health.utils.losses import LossMeterType @@ -36,7 +37,10 @@ def __init__( learning_rate: float, heterogeneity_level: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, use_partitioned_data: bool = False, ) -> None: super().__init__( @@ -44,7 +48,10 @@ def __init__( metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) self.use_partitioned_data = use_partitioned_data self.client_number = client_number @@ -199,7 +206,7 @@ def get_model(self, config: Config) -> nn.Module: pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl" post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl" post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl" - checkpointer = ClientCheckpointAndStateModule( + 
checkpoint_and_state_module = ClientCheckpointAndStateModule( pre_aggregation=[ BestLossTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), LatestTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), @@ -218,7 +225,7 @@ def get_model(self, config: Config) -> nn.Module: client_number=args.client_number, learning_rate=args.learning_rate, heterogeneity_level=args.beta, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, use_partitioned_data=args.use_partitioned_data, ) diff --git a/research/cifar10/ditto_deep_mmd/client.py b/research/cifar10/ditto_deep_mmd/client.py index eac5c8f4d..c1a2d54ad 100644 --- a/research/cifar10/ditto_deep_mmd/client.py +++ b/research/cifar10/ditto_deep_mmd/client.py @@ -17,6 +17,7 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.deep_mmd_clients.ditto_deep_mmd_client import DittoDeepMmdClient +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.config import narrow_dict_type from fl4health.utils.load_data import load_cifar10_data, load_cifar10_test_data from fl4health.utils.losses import LossMeterType @@ -42,9 +43,12 @@ def __init__( learning_rate: float, heterogeneity_level: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, deep_mmd_loss_weight: float = 10, deep_mmd_loss_depth: int = 1, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, use_partitioned_data: bool = False, ) -> None: feature_extraction_layers_with_size = OrderedDict(list(BASELINE_LAYERS.items())[-1 * deep_mmd_loss_depth :]) @@ -53,7 +57,10 @@ def __init__( metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, deep_mmd_loss_weight=deep_mmd_loss_weight, feature_extraction_layers_with_size=feature_extraction_layers_with_size, ) @@ -226,7 +233,7 @@ def get_model(self, config: Config) -> nn.Module: pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl" post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl" post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl" - checkpointer = ClientCheckpointAndStateModule( + checkpoint_and_state_module = ClientCheckpointAndStateModule( pre_aggregation=[ BestLossTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), LatestTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), @@ -245,7 +252,7 @@ def get_model(self, config: Config) -> nn.Module: client_number=args.client_number, learning_rate=args.learning_rate, heterogeneity_level=args.beta, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, deep_mmd_loss_depth=args.deep_mmd_loss_depth, deep_mmd_loss_weight=args.mu, use_partitioned_data=args.use_partitioned_data, diff --git a/research/cifar10/ditto_mkmmd/client.py b/research/cifar10/ditto_mkmmd/client.py index 44f3a21bd..ed8d18821 100644 --- 
a/research/cifar10/ditto_mkmmd/client.py +++ b/research/cifar10/ditto_mkmmd/client.py @@ -16,6 +16,7 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.mkmmd_clients.ditto_mkmmd_client import DittoMkMmdClient +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.config import narrow_dict_type from fl4health.utils.load_data import load_cifar10_data, load_cifar10_test_data from fl4health.utils.losses import LossMeterType @@ -42,7 +43,10 @@ def __init__( feature_l2_norm_weight: float = 1, mkmmd_loss_depth: int = 1, beta_global_update_interval: int = 20, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, use_partitioned_data: bool = False, ) -> None: super().__init__( @@ -50,7 +54,10 @@ def __init__( metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, mkmmd_loss_weight=mkmmd_loss_weight, feature_extraction_layers=BASELINE_LAYERS[-1 * mkmmd_loss_depth :], feature_l2_norm_weight=feature_l2_norm_weight, @@ -244,7 +251,7 @@ def get_model(self, config: Config) -> nn.Module: pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl" post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl" post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl" - checkpointer = ClientCheckpointAndStateModule( + checkpoint_and_state_module = ClientCheckpointAndStateModule( pre_aggregation=[ BestLossTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), LatestTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), @@ -263,7 +270,7 @@ def get_model(self, config: Config) -> nn.Module: client_number=args.client_number, learning_rate=args.learning_rate, heterogeneity_level=args.beta, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, feature_l2_norm_weight=args.l2, mkmmd_loss_depth=args.mkmmd_loss_depth, mkmmd_loss_weight=args.mu, diff --git a/research/cifar10/fed_dgga_pfl/ditto/client.py b/research/cifar10/fed_dgga_pfl/ditto/client.py index fce7a2236..adf14587d 100644 --- a/research/cifar10/fed_dgga_pfl/ditto/client.py +++ b/research/cifar10/fed_dgga_pfl/ditto/client.py @@ -16,6 +16,7 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.ditto_client import DittoClient +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.config import narrow_dict_type from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import F1, Accuracy, Metric @@ -34,14 +35,20 @@ def __init__( learning_rate: float, heterogeneity_level: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + 
reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) self.client_number = client_number self.heterogeneity_level = heterogeneity_level @@ -138,7 +145,7 @@ def get_model(self, config: Config) -> nn.Module: pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl" post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl" post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl" - checkpointer = ClientCheckpointAndStateModule( + checkpoint_and_state_module = ClientCheckpointAndStateModule( pre_aggregation=[ BestLossTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), LatestTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), @@ -161,7 +168,7 @@ def get_model(self, config: Config) -> nn.Module: client_number=args.client_number, learning_rate=args.learning_rate, heterogeneity_level=args.beta, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.client.start_client(server_address=args.server_address, client=client.to_client()) diff --git a/research/cifar10/fed_dgga_pfl/fenda/client.py b/research/cifar10/fed_dgga_pfl/fenda/client.py index e75493055..1913bb819 100644 --- a/research/cifar10/fed_dgga_pfl/fenda/client.py +++ b/research/cifar10/fed_dgga_pfl/fenda/client.py @@ -16,6 +16,7 @@ from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.fenda_client import FendaClient from fl4health.model_bases.fenda_base import FendaModel +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.config import narrow_dict_type from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import F1, Accuracy, Metric @@ -34,14 +35,20 @@ def __init__( learning_rate: float, heterogeneity_level: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) self.client_number = client_number self.heterogeneity_level = heterogeneity_level @@ -136,7 +143,7 @@ def get_model(self, config: Config) -> FendaModel: pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl" post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl" post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl" - checkpointer = ClientCheckpointAndStateModule( + checkpoint_and_state_module = ClientCheckpointAndStateModule( pre_aggregation=[ BestLossTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), 
LatestTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), @@ -159,7 +166,7 @@ def get_model(self, config: Config) -> FendaModel: client_number=args.client_number, learning_rate=args.learning_rate, heterogeneity_level=args.beta, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.client.start_client(server_address=args.server_address, client=client.to_client()) diff --git a/research/cifar10/fed_dgga_pfl/fenda_ditto/client.py b/research/cifar10/fed_dgga_pfl/fenda_ditto/client.py index 5ac46d079..a8bad44b8 100644 --- a/research/cifar10/fed_dgga_pfl/fenda_ditto/client.py +++ b/research/cifar10/fed_dgga_pfl/fenda_ditto/client.py @@ -17,6 +17,7 @@ from fl4health.clients.fenda_ditto_client import FendaDittoClient from fl4health.model_bases.fenda_base import FendaModel from fl4health.model_bases.sequential_split_models import SequentiallySplitModel +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.config import narrow_dict_type from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import F1, Accuracy, Metric @@ -35,7 +36,10 @@ def __init__( learning_rate: float, heterogeneity_level: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, freeze_global_feature_extractor: bool = False, ) -> None: super().__init__( @@ -43,7 +47,10 @@ def __init__( metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, freeze_global_feature_extractor=freeze_global_feature_extractor, ) self.client_number = client_number @@ -152,7 +159,7 @@ def get_global_model(self, config: Config) -> SequentiallySplitModel: pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl" post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl" post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl" - checkpointer = ClientCheckpointAndStateModule( + checkpoint_and_state_module = ClientCheckpointAndStateModule( pre_aggregation=[ BestLossTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), LatestTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), @@ -175,7 +182,7 @@ def get_global_model(self, config: Config) -> SequentiallySplitModel: client_number=args.client_number, learning_rate=args.learning_rate, heterogeneity_level=args.beta, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, freeze_global_feature_extractor=args.freeze_global_extractor, ) diff --git a/research/cifar10/fedavg/client.py b/research/cifar10/fedavg/client.py index 168b9d1f8..3e8c18078 100644 --- a/research/cifar10/fedavg/client.py +++ b/research/cifar10/fedavg/client.py @@ -16,6 +16,7 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.basic_client import BasicClient +from fl4health.reporting.base_reporter import BaseReporter from 
fl4health.utils.config import narrow_dict_type from fl4health.utils.load_data import load_cifar10_data, load_cifar10_test_data from fl4health.utils.losses import LossMeterType @@ -36,7 +37,10 @@ def __init__( learning_rate: float, heterogeneity_level: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, use_partitioned_data: bool = False, ) -> None: super().__init__( @@ -44,7 +48,10 @@ def __init__( metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) self.use_partitioned_data = use_partitioned_data self.client_number = client_number @@ -198,7 +205,7 @@ def get_model(self, config: Config) -> nn.Module: pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl" post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl" post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl" - checkpointer = ClientCheckpointAndStateModule( + checkpoint_and_state_module = ClientCheckpointAndStateModule( pre_aggregation=[ BestLossTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name), LatestTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name), @@ -217,7 +224,7 @@ def get_model(self, config: Config) -> nn.Module: client_number=args.client_number, learning_rate=args.learning_rate, heterogeneity_level=args.beta, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, use_partitioned_data=args.use_partitioned_data, ) diff --git a/research/flamby/fed_heart_disease/apfl/client.py b/research/flamby/fed_heart_disease/apfl/client.py index 50b6de939..f0a0ba002 100644 --- a/research/flamby/fed_heart_disease/apfl/client.py +++ b/research/flamby/fed_heart_disease/apfl/client.py @@ -17,6 +17,7 @@ from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.apfl_client import ApflClient from fl4health.model_bases.apfl_base import ApflModule +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import Accuracy, Metric from research.flamby.flamby_data_utils import construct_fed_heard_disease_train_val_datasets @@ -32,14 +33,20 @@ def __init__( learning_rate: float, alpha_learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) assert 0 <= client_number < NUM_CLIENTS @@ -132,7 +139,7 @@ def get_criterion(self, config: Config) -> _Loss: if federated_checkpointing else 
LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) - checkpointer = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) + checkpoint_and_state_module = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) client = FedHeartDiseaseApflClient( data_path=args.dataset_dir, @@ -141,7 +148,7 @@ def get_criterion(self, config: Config) -> _Loss: client_number=args.client_number, learning_rate=args.learning_rate, alpha_learning_rate=args.alpha_learning_rate, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.client.start_client(server_address=args.server_address, client=client.to_client()) diff --git a/research/flamby/fed_heart_disease/ditto/client.py b/research/flamby/fed_heart_disease/ditto/client.py index fbddce0f6..0663d23ef 100644 --- a/research/flamby/fed_heart_disease/ditto/client.py +++ b/research/flamby/fed_heart_disease/ditto/client.py @@ -17,6 +17,7 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.ditto_client import DittoClient +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import Accuracy, Metric from fl4health.utils.random import set_all_random_seeds @@ -32,14 +33,20 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) self.client_number = client_number self.learning_rate: float = learning_rate @@ -141,7 +148,7 @@ def get_criterion(self, config: Config) -> _Loss: else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) - checkpointer = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) + checkpoint_and_state_module = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) client = FedHeartDiseaseDittoClient( data_path=Path(args.dataset_dir), @@ -149,7 +156,7 @@ def get_criterion(self, config: Config) -> _Loss: device=device, client_number=args.client_number, learning_rate=args.learning_rate, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.client.start_client(server_address=args.server_address, client=client.to_client()) diff --git a/research/flamby/fed_heart_disease/fedadam/client.py b/research/flamby/fed_heart_disease/fedadam/client.py index a59222c58..ea3563c3c 100644 --- a/research/flamby/fed_heart_disease/fedadam/client.py +++ b/research/flamby/fed_heart_disease/fedadam/client.py @@ -17,6 +17,7 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.basic_client import BasicClient +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import 
LossMeterType from fl4health.utils.metrics import Accuracy, Metric from research.flamby.flamby_data_utils import construct_fed_heard_disease_train_val_datasets @@ -31,14 +32,20 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) self.client_number = client_number self.learning_rate: float = learning_rate @@ -113,7 +120,7 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointAndStateModule( + checkpoint_and_state_module = ClientCheckpointAndStateModule( post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) @@ -123,7 +130,7 @@ def get_criterion(self, config: Config) -> _Loss: device=device, client_number=args.client_number, learning_rate=args.learning_rate, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.client.start_client(server_address=args.server_address, client=client.to_client()) diff --git a/research/flamby/fed_heart_disease/fedadam/server.py b/research/flamby/fed_heart_disease/fedadam/server.py index 5fe3e3071..017f61b58 100644 --- a/research/flamby/fed_heart_disease/fedadam/server.py +++ b/research/flamby/fed_heart_disease/fedadam/server.py @@ -11,6 +11,8 @@ from flwr.server.strategy import FedAdam from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer +from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule +from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.utils.config import load_config from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn from fl4health.utils.parameter_extraction import get_all_model_parameters @@ -28,6 +30,9 @@ def main( config["n_server_rounds"], ) + model = Baseline() + summarize_model_info(model) + checkpoint_dir = os.path.join(checkpoint_stub, run_name) checkpoint_name = "server_best_model.pkl" federated_checkpointing: bool = config.get("federated_checkpointing", True) @@ -38,9 +43,11 @@ def main( else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) + checkpoint_and_state_module = BaseServerCheckpointAndStateModule( + model=model, parameter_exchanger=FullParameterExchanger(), model_checkpointers=checkpointer + ) + client_manager = SimpleClientManager() - model = Baseline() - summarize_model_info(model) strategy = FedAdam( min_fit_clients=config["n_clients"], @@ -57,7 +64,10 @@ def main( ) server = FullExchangeServer( - client_manager=client_manager, fl_config=config, model=model, strategy=strategy, checkpointer=checkpointer + client_manager=client_manager, + fl_config=config, + strategy=strategy, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.server.start_server( diff --git 
a/research/flamby/fed_heart_disease/fedavg/client.py b/research/flamby/fed_heart_disease/fedavg/client.py index 1ca53e5a2..dd4f0a261 100644 --- a/research/flamby/fed_heart_disease/fedavg/client.py +++ b/research/flamby/fed_heart_disease/fedavg/client.py @@ -17,6 +17,7 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.basic_client import BasicClient +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import Accuracy, Metric from research.flamby.flamby_data_utils import construct_fed_heard_disease_train_val_datasets @@ -31,14 +32,20 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) self.client_number = client_number self.learning_rate: float = learning_rate @@ -113,7 +120,7 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointAndStateModule( + checkpoint_and_state_module = ClientCheckpointAndStateModule( post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) @@ -123,7 +130,7 @@ def get_criterion(self, config: Config) -> _Loss: device=device, client_number=args.client_number, learning_rate=args.learning_rate, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.client.start_client(server_address=args.server_address, client=client.to_client()) diff --git a/research/flamby/fed_heart_disease/fedavg/server.py b/research/flamby/fed_heart_disease/fedavg/server.py index a95b205ed..1859f0711 100644 --- a/research/flamby/fed_heart_disease/fedavg/server.py +++ b/research/flamby/fed_heart_disease/fedavg/server.py @@ -11,6 +11,8 @@ from flwr.server.strategy import FedAvg from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer +from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule +from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.utils.config import load_config from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn from fl4health.utils.parameter_extraction import get_all_model_parameters @@ -26,6 +28,9 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_ config["n_server_rounds"], ) + model = Baseline() + summarize_model_info(model) + checkpoint_dir = os.path.join(checkpoint_stub, run_name) checkpoint_name = "server_best_model.pkl" federated_checkpointing: bool = config.get("federated_checkpointing", True) @@ -36,9 +41,11 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_ else 
LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) + checkpoint_and_state_module = BaseServerCheckpointAndStateModule( + model=model, parameter_exchanger=FullParameterExchanger(), model_checkpointers=checkpointer + ) + client_manager = SimpleClientManager() - model = Baseline() - summarize_model_info(model) # Server performs simple FedAveraging as its server-side optimization strategy strategy = FedAvg( @@ -55,7 +62,10 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_ ) server = FullExchangeServer( - client_manager=client_manager, fl_config=config, model=model, strategy=strategy, checkpointer=checkpointer + client_manager=client_manager, + fl_config=config, + strategy=strategy, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.server.start_server( diff --git a/research/flamby/fed_heart_disease/fedper/client.py b/research/flamby/fed_heart_disease/fedper/client.py index 64e81f0ef..bbd9ea45e 100644 --- a/research/flamby/fed_heart_disease/fedper/client.py +++ b/research/flamby/fed_heart_disease/fedper/client.py @@ -19,6 +19,7 @@ from fl4health.clients.moon_client import MoonClient from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import Accuracy, Metric from fl4health.utils.random import set_all_random_seeds @@ -35,14 +36,20 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) # MOON contrastive loss weight is set to None by default since we are not using it self.client_number = client_number @@ -144,7 +151,7 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: if federated_checkpointing else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) - checkpointer = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) + checkpoint_and_state_module = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) client = FedHeartDiseaseFedPerClient( data_path=Path(args.dataset_dir), @@ -152,7 +159,7 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: device=device, client_number=args.client_number, learning_rate=args.learning_rate, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.client.start_client(server_address=args.server_address, client=client.to_client()) diff --git a/research/flamby/fed_heart_disease/fedprox/client.py b/research/flamby/fed_heart_disease/fedprox/client.py index 288664161..ea8915003 100644 --- a/research/flamby/fed_heart_disease/fedprox/client.py +++ b/research/flamby/fed_heart_disease/fedprox/client.py @@ -17,6 +17,7 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer from 
fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.fed_prox_client import FedProxClient +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import Accuracy, Metric from research.flamby.flamby_data_utils import construct_fed_heard_disease_train_val_datasets @@ -31,14 +32,20 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) self.client_number = client_number self.learning_rate: float = learning_rate @@ -113,7 +120,7 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointAndStateModule( + checkpoint_and_state_module = ClientCheckpointAndStateModule( post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) @@ -123,7 +130,7 @@ def get_criterion(self, config: Config) -> _Loss: device=device, client_number=args.client_number, learning_rate=args.learning_rate, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.client.start_client(server_address=args.server_address, client=client.to_client()) diff --git a/research/flamby/fed_heart_disease/fedprox/server.py b/research/flamby/fed_heart_disease/fedprox/server.py index 9eff20859..e1526c21c 100644 --- a/research/flamby/fed_heart_disease/fedprox/server.py +++ b/research/flamby/fed_heart_disease/fedprox/server.py @@ -10,6 +10,7 @@ from flwr.server.client_manager import SimpleClientManager from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer +from fl4health.checkpointing.server_module import AdaptiveConstraintServerCheckpointAndStateModule from fl4health.servers.adaptive_constraint_servers.fedprox_server import FedProxServer from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint from fl4health.utils.config import load_config @@ -26,6 +27,9 @@ def main(config: Dict[str, Any], server_address: str, mu: float, checkpoint_stub config["n_server_rounds"], ) + model = Baseline() + summarize_model_info(model) + checkpoint_dir = os.path.join(checkpoint_stub, run_name) checkpoint_name = "server_best_model.pkl" federated_checkpointing: bool = config.get("federated_checkpointing", True) @@ -35,10 +39,11 @@ def main(config: Dict[str, Any], server_address: str, mu: float, checkpoint_stub if federated_checkpointing else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) + checkpoint_and_state_module = AdaptiveConstraintServerCheckpointAndStateModule( + model=model, model_checkpointers=checkpointer + ) client_manager = SimpleClientManager() - model = Baseline() - summarize_model_info(model) # Server performs simple FedAveraging as its server-side optimization strategy strategy = FedAvgWithAdaptiveConstraint( 
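Note (not part of the patch): the two fedprox/server.py hunks above and below amount to the following wiring. The best/latest-loss checkpointer is now handed to an AdaptiveConstraintServerCheckpointAndStateModule that also owns the model, and FedProxServer takes that module in place of the old separate model= and checkpointer= arguments. A minimal sketch of the resulting code, assuming Baseline, config, and the FedAvgWithAdaptiveConstraint strategy are constructed exactly as in the surrounding context lines; the checkpoint directory below is a placeholder:

    from flwr.server.client_manager import SimpleClientManager

    from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer
    from fl4health.checkpointing.server_module import AdaptiveConstraintServerCheckpointAndStateModule
    from fl4health.servers.adaptive_constraint_servers.fedprox_server import FedProxServer

    # Model and checkpointer built as in the surrounding hunks; the path is illustrative.
    model = Baseline()
    checkpointer = BestLossTorchModuleCheckpointer("<checkpoint_dir>", "server_best_model.pkl")

    # The checkpointer is wrapped in a server-side module that also holds the model,
    # replacing the model=/checkpointer= arguments previously passed to the server.
    checkpoint_and_state_module = AdaptiveConstraintServerCheckpointAndStateModule(
        model=model, model_checkpointers=checkpointer
    )

    server = FedProxServer(
        client_manager=SimpleClientManager(),
        fl_config=config,    # loaded via load_config, as before
        strategy=strategy,   # FedAvgWithAdaptiveConstraint, unchanged by this hunk
        checkpoint_and_state_module=checkpoint_and_state_module,
    )

This mirrors the BaseServerCheckpointAndStateModule pattern used for the FedAvg and FedAdam servers earlier in the patch, except that the adaptive-constraint module is constructed without an explicit parameter exchanger.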
@@ -56,7 +61,10 @@ def main(config: Dict[str, Any], server_address: str, mu: float, checkpoint_stub ) server = FedProxServer( - client_manager=client_manager, fl_config=config, strategy=strategy, model=model, checkpointer=checkpointer + client_manager=client_manager, + fl_config=config, + strategy=strategy, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.server.start_server( diff --git a/research/flamby/fed_heart_disease/fenda/client.py b/research/flamby/fed_heart_disease/fenda/client.py index 45e4e433b..7834a3bf2 100644 --- a/research/flamby/fed_heart_disease/fenda/client.py +++ b/research/flamby/fed_heart_disease/fenda/client.py @@ -17,6 +17,7 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.fenda_client import FendaClient +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import Accuracy, Metric from fl4health.utils.random import set_all_random_seeds @@ -33,14 +34,20 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) self.client_number = client_number self.learning_rate: float = learning_rate @@ -137,7 +144,7 @@ def get_criterion(self, config: Config) -> _Loss: if federated_checkpointing else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) - checkpointer = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) + checkpoint_and_state_module = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) client = FedHeartDiseaseFendaClient( data_path=Path(args.dataset_dir), @@ -145,7 +152,7 @@ def get_criterion(self, config: Config) -> _Loss: device=device, client_number=args.client_number, learning_rate=args.learning_rate, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.client.start_client( diff --git a/research/flamby/fed_heart_disease/moon/client.py b/research/flamby/fed_heart_disease/moon/client.py index ae7e4a3d3..16143e358 100644 --- a/research/flamby/fed_heart_disease/moon/client.py +++ b/research/flamby/fed_heart_disease/moon/client.py @@ -17,6 +17,7 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.moon_client import MoonClient +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import Accuracy, Metric from fl4health.utils.random import set_all_random_seeds @@ -33,15 +34,21 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: 
Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, contrastive_weight: float = 10, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, contrastive_weight=contrastive_weight, ) self.client_number = client_number @@ -134,7 +141,7 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointAndStateModule( + checkpoint_and_state_module = ClientCheckpointAndStateModule( post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) @@ -144,7 +151,7 @@ def get_criterion(self, config: Config) -> _Loss: device=device, client_number=args.client_number, learning_rate=args.learning_rate, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, contrastive_weight=args.mu, ) diff --git a/research/flamby/fed_heart_disease/moon/server.py b/research/flamby/fed_heart_disease/moon/server.py index 2748458bb..85bd9d2bb 100644 --- a/research/flamby/fed_heart_disease/moon/server.py +++ b/research/flamby/fed_heart_disease/moon/server.py @@ -10,6 +10,8 @@ from flwr.server.strategy import FedAvg from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer +from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule +from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.utils.config import load_config from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn from fl4health.utils.parameter_extraction import get_all_model_parameters @@ -27,6 +29,9 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_ config["n_server_rounds"], ) + model = FedHeartDiseaseMoonModel() + summarize_model_info(model) + checkpoint_dir = os.path.join(checkpoint_stub, run_name) checkpoint_name = "server_best_model.pkl" federated_checkpointing: bool = config.get("federated_checkpointing", True) @@ -37,9 +42,11 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_ else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) + checkpoint_and_state_module = BaseServerCheckpointAndStateModule( + model=model, parameter_exchanger=FullParameterExchanger(), model_checkpointers=checkpointer + ) + client_manager = SimpleClientManager() - model = FedHeartDiseaseMoonModel() - summarize_model_info(model) # Server performs simple FedAveraging as its server-side optimization strategy strategy = FedAvg( @@ -56,7 +63,10 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_ ) server = FullExchangeServer( - client_manager=client_manager, fl_config=config, model=model, strategy=strategy, checkpointer=checkpointer + client_manager=client_manager, + fl_config=config, + strategy=strategy, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.server.start_server( diff --git a/research/flamby/fed_heart_disease/perfcl/client.py b/research/flamby/fed_heart_disease/perfcl/client.py index 749eb7093..fa4d6fd4c 100644 --- 
a/research/flamby/fed_heart_disease/perfcl/client.py +++ b/research/flamby/fed_heart_disease/perfcl/client.py @@ -17,6 +17,7 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.perfcl_client import PerFclClient +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import Accuracy, Metric from fl4health.utils.random import set_all_random_seeds @@ -33,7 +34,10 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, mu: float = 10.0, gamma: float = 10.0, ) -> None: @@ -42,7 +46,10 @@ def __init__( metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, global_feature_contrastive_loss_weight=mu, local_feature_contrastive_loss_weight=gamma, ) @@ -155,7 +162,7 @@ def get_criterion(self, config: Config) -> _Loss: if federated_checkpointing else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) - checkpointer = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) + checkpoint_and_state_module = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) client = FedHeartDiseasePerFclClient( data_path=Path(args.dataset_dir), @@ -163,7 +170,7 @@ def get_criterion(self, config: Config) -> _Loss: device=device, client_number=args.client_number, learning_rate=args.learning_rate, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, mu=args.mu, gamma=args.gamma, ) diff --git a/research/flamby/fed_heart_disease/scaffold/client.py b/research/flamby/fed_heart_disease/scaffold/client.py index 0dc3d5f1e..42bb46ffe 100644 --- a/research/flamby/fed_heart_disease/scaffold/client.py +++ b/research/flamby/fed_heart_disease/scaffold/client.py @@ -17,6 +17,7 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.scaffold_client import ScaffoldClient +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import Accuracy, Metric from research.flamby.flamby_data_utils import construct_fed_heard_disease_train_val_datasets @@ -31,14 +32,20 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + 
client_name=client_name, ) self.client_number = client_number self.learning_rate: float = learning_rate @@ -113,7 +120,7 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointAndStateModule( + checkpoint_and_state_module = ClientCheckpointAndStateModule( post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) @@ -123,7 +130,7 @@ def get_criterion(self, config: Config) -> _Loss: device=device, client_number=args.client_number, learning_rate=args.learning_rate, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.client.start_client(server_address=args.server_address, client=client.to_client()) diff --git a/research/flamby/fed_heart_disease/scaffold/server.py b/research/flamby/fed_heart_disease/scaffold/server.py index 35b5a2e3c..6d9a0b546 100644 --- a/research/flamby/fed_heart_disease/scaffold/server.py +++ b/research/flamby/fed_heart_disease/scaffold/server.py @@ -11,8 +11,6 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer from fl4health.checkpointing.server_module import ScaffoldServerCheckpointAndStateModule from fl4health.client_managers.fixed_without_replacement_manager import FixedSamplingByFractionClientManager -from fl4health.parameter_exchange.packing_exchanger import FullParameterExchangerWithPacking -from fl4health.parameter_exchange.parameter_packer import ParameterPackerWithControlVariates from fl4health.servers.scaffold_server import ScaffoldServer from fl4health.strategies.scaffold import Scaffold from fl4health.utils.config import load_config @@ -44,10 +42,8 @@ def main( model = Baseline() summarize_model_info(model) - model_size = len(model.state_dict()) checkpoint_and_state_module = ScaffoldServerCheckpointAndStateModule( model=model, - parameter_exchanger=FullParameterExchangerWithPacking(ParameterPackerWithControlVariates(model_size)), model_checkpointers=checkpointer, ) diff --git a/research/flamby/fed_isic2019/apfl/client.py b/research/flamby/fed_isic2019/apfl/client.py index 011115d25..224f7cb6e 100644 --- a/research/flamby/fed_isic2019/apfl/client.py +++ b/research/flamby/fed_isic2019/apfl/client.py @@ -18,6 +18,7 @@ from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.apfl_client import ApflClient from fl4health.model_bases.apfl_base import ApflModule +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BalancedAccuracy, Metric from research.flamby.fed_isic2019.apfl.apfl_model import ApflEfficientNet @@ -34,14 +35,20 @@ def __init__( learning_rate: float, alpha_learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) assert 0 <= client_number < NUM_CLIENTS @@ -126,7 +133,7 @@ def 
get_optimizer(self, config: Config) -> Dict[str, Optimizer]: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointAndStateModule( + checkpoint_and_state_module = ClientCheckpointAndStateModule( post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) @@ -137,7 +144,7 @@ def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: client_number=args.client_number, learning_rate=args.learning_rate, alpha_learning_rate=args.alpha_learning_rate, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.client.start_client(server_address=args.server_address, client=client.to_client()) diff --git a/research/flamby/fed_isic2019/ditto/client.py b/research/flamby/fed_isic2019/ditto/client.py index dcefd58af..a22a57de8 100644 --- a/research/flamby/fed_isic2019/ditto/client.py +++ b/research/flamby/fed_isic2019/ditto/client.py @@ -17,6 +17,7 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.ditto_client import DittoClient +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BalancedAccuracy, Metric from fl4health.utils.random import set_all_random_seeds @@ -32,14 +33,20 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) self.client_number = client_number self.learning_rate: float = learning_rate @@ -128,7 +135,7 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointAndStateModule( + checkpoint_and_state_module = ClientCheckpointAndStateModule( post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) @@ -138,7 +145,7 @@ def get_criterion(self, config: Config) -> _Loss: device=device, client_number=args.client_number, learning_rate=args.learning_rate, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.client.start_client(server_address=args.server_address, client=client.to_client()) diff --git a/research/flamby/fed_isic2019/ditto_deep_mmd/client.py b/research/flamby/fed_isic2019/ditto_deep_mmd/client.py index 7d5dae85d..8ceefd44e 100644 --- a/research/flamby/fed_isic2019/ditto_deep_mmd/client.py +++ b/research/flamby/fed_isic2019/ditto_deep_mmd/client.py @@ -18,6 +18,7 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.deep_mmd_clients.ditto_deep_mmd_client import DittoDeepMmdClient +from fl4health.reporting.base_reporter import BaseReporter from 
fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BalancedAccuracy, Metric from fl4health.utils.random import set_all_random_seeds @@ -40,7 +41,10 @@ def __init__( loss_meter_type: LossMeterType = LossMeterType.AVERAGE, deep_mmd_loss_weight: float = 10, deep_mmd_loss_depth: int = 1, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: feature_extraction_layers_with_size = OrderedDict( list(FED_ISIC2019_BASELINE_LAYERS.items())[-1 * deep_mmd_loss_depth :] @@ -50,7 +54,10 @@ def __init__( metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, deep_mmd_loss_weight=deep_mmd_loss_weight, feature_extraction_layers_with_size=feature_extraction_layers_with_size, ) @@ -158,7 +165,7 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointAndStateModule( + checkpoint_and_state_module = ClientCheckpointAndStateModule( post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) @@ -168,7 +175,7 @@ def get_criterion(self, config: Config) -> _Loss: device=device, client_number=args.client_number, learning_rate=args.learning_rate, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, deep_mmd_loss_depth=args.deep_mmd_loss_depth, deep_mmd_loss_weight=args.mu, ) diff --git a/research/flamby/fed_isic2019/ditto_mkmmd/client.py b/research/flamby/fed_isic2019/ditto_mkmmd/client.py index 0f6c74c98..6add1e5da 100644 --- a/research/flamby/fed_isic2019/ditto_mkmmd/client.py +++ b/research/flamby/fed_isic2019/ditto_mkmmd/client.py @@ -17,6 +17,7 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.mkmmd_clients.ditto_mkmmd_client import DittoMkMmdClient +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BalancedAccuracy, Metric from fl4health.utils.random import set_all_random_seeds @@ -37,18 +38,24 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, mkmmd_loss_weight: float = 10, feature_l2_norm_weight: float = 1, mkmmd_loss_depth: int = 1, beta_global_update_interval: int = 20, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, mkmmd_loss_weight=mkmmd_loss_weight, feature_extraction_layers=FED_ISIC2019_BASELINE_LAYERS[-1 * mkmmd_loss_depth :], feature_l2_norm_weight=feature_l2_norm_weight, @@ -175,7 +182,7 @@ 
def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointAndStateModule( + checkpoint_and_state_module = ClientCheckpointAndStateModule( post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) @@ -185,7 +192,7 @@ def get_criterion(self, config: Config) -> _Loss: device=device, client_number=args.client_number, learning_rate=args.learning_rate, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, feature_l2_norm_weight=args.l2, mkmmd_loss_depth=args.mkmmd_loss_depth, mkmmd_loss_weight=args.mu, diff --git a/research/flamby/fed_isic2019/fedadam/client.py b/research/flamby/fed_isic2019/fedadam/client.py index 5f2880894..7ff8f181e 100644 --- a/research/flamby/fed_isic2019/fedadam/client.py +++ b/research/flamby/fed_isic2019/fedadam/client.py @@ -17,6 +17,7 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.basic_client import BasicClient +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BalancedAccuracy, Metric from research.flamby.fed_isic2019.fedadam.fedadam_model import FedAdamEfficientNet @@ -32,14 +33,20 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) self.client_number = client_number self.learning_rate: float = learning_rate @@ -114,7 +121,7 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointAndStateModule( + checkpoint_and_state_module = ClientCheckpointAndStateModule( post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) @@ -124,7 +131,7 @@ def get_criterion(self, config: Config) -> _Loss: device=device, client_number=args.client_number, learning_rate=args.learning_rate, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.client.start_client(server_address=args.server_address, client=client.to_client()) diff --git a/research/flamby/fed_isic2019/fedadam/server.py b/research/flamby/fed_isic2019/fedadam/server.py index d72d80411..6ca3d0e30 100644 --- a/research/flamby/fed_isic2019/fedadam/server.py +++ b/research/flamby/fed_isic2019/fedadam/server.py @@ -10,6 +10,8 @@ from flwr.server.strategy import FedAdam from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer +from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule +from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.utils.config import load_config from 
fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn from fl4health.utils.parameter_extraction import get_all_model_parameters @@ -28,12 +30,17 @@ def main( config["n_server_rounds"], ) + model = FedAdamEfficientNet() + checkpoint_dir = os.path.join(checkpoint_stub, run_name) checkpoint_name = "server_best_model.pkl" checkpointer = BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) + checkpoint_and_state_module = BaseServerCheckpointAndStateModule( + model=model, parameter_exchanger=FullParameterExchanger(), model_checkpointers=checkpointer + ) + client_manager = SimpleClientManager() - model = FedAdamEfficientNet() strategy = FedAdam( min_fit_clients=config["n_clients"], @@ -50,7 +57,10 @@ def main( ) server = FullExchangeServer( - client_manager=client_manager, fl_config=config, model=model, strategy=strategy, checkpointer=checkpointer + client_manager=client_manager, + fl_config=config, + strategy=strategy, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.server.start_server( diff --git a/research/flamby/fed_isic2019/fedavg/client.py b/research/flamby/fed_isic2019/fedavg/client.py index 940ee26da..514fe20fb 100644 --- a/research/flamby/fed_isic2019/fedavg/client.py +++ b/research/flamby/fed_isic2019/fedavg/client.py @@ -17,6 +17,7 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.basic_client import BasicClient +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BalancedAccuracy, Metric from research.flamby.flamby_data_utils import construct_fedisic_train_val_datasets @@ -31,14 +32,20 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) self.client_number = client_number self.learning_rate: float = learning_rate @@ -113,7 +120,7 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointAndStateModule( + checkpoint_and_state_module = ClientCheckpointAndStateModule( post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) @@ -123,7 +130,7 @@ def get_criterion(self, config: Config) -> _Loss: device=device, client_number=args.client_number, learning_rate=args.learning_rate, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.client.start_client(server_address=args.server_address, client=client.to_client()) diff --git a/research/flamby/fed_isic2019/fedavg/server.py b/research/flamby/fed_isic2019/fedavg/server.py index 232237fc8..67ce57138 100644 --- a/research/flamby/fed_isic2019/fedavg/server.py +++ b/research/flamby/fed_isic2019/fedavg/server.py @@ -11,6 +11,8 @@ from 
flwr.server.strategy import FedAvg from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer +from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule +from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.utils.config import load_config from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn from fl4health.utils.parameter_extraction import get_all_model_parameters @@ -26,12 +28,17 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_ config["n_server_rounds"], ) + model = Baseline() + checkpoint_dir = os.path.join(checkpoint_stub, run_name) checkpoint_name = "server_best_model.pkl" checkpointer = BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) + checkpoint_and_state_module = BaseServerCheckpointAndStateModule( + model=model, parameter_exchanger=FullParameterExchanger(), model_checkpointers=checkpointer + ) + client_manager = SimpleClientManager() - model = Baseline() # Server performs simple FedAveraging as its server-side optimization strategy strategy = FedAvg( @@ -48,7 +55,10 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_ ) server = FullExchangeServer( - client_manager=client_manager, fl_config=config, model=model, strategy=strategy, checkpointer=checkpointer + client_manager=client_manager, + fl_config=config, + strategy=strategy, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.server.start_server( diff --git a/research/flamby/fed_isic2019/fedper/client.py b/research/flamby/fed_isic2019/fedper/client.py index 5dd1a561d..af6198451 100644 --- a/research/flamby/fed_isic2019/fedper/client.py +++ b/research/flamby/fed_isic2019/fedper/client.py @@ -19,6 +19,7 @@ from fl4health.clients.moon_client import MoonClient from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BalancedAccuracy, Metric from fl4health.utils.random import set_all_random_seeds @@ -35,14 +36,20 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) # MOON contrastive loss weight is set to None by default since we are not using it self.client_number = client_number @@ -132,7 +139,7 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointAndStateModule( + checkpoint_and_state_module = ClientCheckpointAndStateModule( post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) @@ -142,7 +149,7 @@ def get_parameter_exchanger(self, config: 
Config) -> ParameterExchanger: device=device, client_number=args.client_number, learning_rate=args.learning_rate, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.client.start_client(server_address=args.server_address, client=client.to_client()) diff --git a/research/flamby/fed_isic2019/fedprox/client.py b/research/flamby/fed_isic2019/fedprox/client.py index 54b661bbf..4e50bea61 100644 --- a/research/flamby/fed_isic2019/fedprox/client.py +++ b/research/flamby/fed_isic2019/fedprox/client.py @@ -17,6 +17,7 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.fed_prox_client import FedProxClient +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BalancedAccuracy, Metric from research.flamby.flamby_data_utils import construct_fedisic_train_val_datasets @@ -31,14 +32,20 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) self.client_number = client_number self.learning_rate: float = learning_rate @@ -113,7 +120,7 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointAndStateModule( + checkpoint_and_state_module = ClientCheckpointAndStateModule( post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) @@ -123,7 +130,7 @@ def get_criterion(self, config: Config) -> _Loss: device=device, client_number=args.client_number, learning_rate=args.learning_rate, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.client.start_client(server_address=args.server_address, client=client.to_client()) diff --git a/research/flamby/fed_isic2019/fedprox/server.py b/research/flamby/fed_isic2019/fedprox/server.py index b3a54eaca..ebffacfa9 100644 --- a/research/flamby/fed_isic2019/fedprox/server.py +++ b/research/flamby/fed_isic2019/fedprox/server.py @@ -10,6 +10,7 @@ from flwr.server.client_manager import SimpleClientManager from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer +from fl4health.checkpointing.server_module import AdaptiveConstraintServerCheckpointAndStateModule from fl4health.servers.adaptive_constraint_servers.fedprox_server import FedProxServer from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint from fl4health.utils.config import load_config @@ -26,12 +27,17 @@ def main(config: Dict[str, Any], server_address: str, mu: float, checkpoint_stub config["n_server_rounds"], ) + model = Baseline() + checkpoint_dir = os.path.join(checkpoint_stub, run_name) checkpoint_name = "server_best_model.pkl" checkpointer = 
BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) + checkpoint_and_state_module = AdaptiveConstraintServerCheckpointAndStateModule( + model=model, model_checkpointers=checkpointer + ) + client_manager = SimpleClientManager() - model = Baseline() # Server performs simple FedAveraging as its server-side optimization strategy strategy = FedAvgWithAdaptiveConstraint( @@ -49,7 +55,10 @@ def main(config: Dict[str, Any], server_address: str, mu: float, checkpoint_stub ) server = FedProxServer( - client_manager=client_manager, fl_config=config, strategy=strategy, model=model, checkpointer=checkpointer + client_manager=client_manager, + fl_config=config, + strategy=strategy, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.server.start_server( diff --git a/research/flamby/fed_isic2019/fenda/client.py b/research/flamby/fed_isic2019/fenda/client.py index a7e8133ca..c237b06b6 100644 --- a/research/flamby/fed_isic2019/fenda/client.py +++ b/research/flamby/fed_isic2019/fenda/client.py @@ -17,6 +17,7 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.fenda_client import FendaClient +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BalancedAccuracy, Metric from fl4health.utils.random import set_all_random_seeds @@ -33,14 +34,20 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) self.client_number = client_number self.learning_rate: float = learning_rate @@ -137,7 +144,7 @@ def get_criterion(self, config: Config) -> _Loss: if federated_checkpointing else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) - checkpointer = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) + checkpoint_and_state_module = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) client = FedIsic2019FendaClient( data_path=Path(args.dataset_dir), @@ -145,7 +152,7 @@ def get_criterion(self, config: Config) -> _Loss: device=device, client_number=args.client_number, learning_rate=args.learning_rate, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.client.start_client(server_address=args.server_address, client=client.to_client()) diff --git a/research/flamby/fed_isic2019/moon/client.py b/research/flamby/fed_isic2019/moon/client.py index 4c52f0229..e12a759fa 100644 --- a/research/flamby/fed_isic2019/moon/client.py +++ b/research/flamby/fed_isic2019/moon/client.py @@ -17,6 +17,7 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.moon_client import MoonClient +from fl4health.reporting.base_reporter import 
BaseReporter from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BalancedAccuracy, Metric from fl4health.utils.random import set_all_random_seeds @@ -33,15 +34,21 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, contrastive_weight: float = 10, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, contrastive_weight=contrastive_weight, ) self.client_number = client_number @@ -134,7 +141,7 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointAndStateModule( + checkpoint_and_state_module = ClientCheckpointAndStateModule( post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) @@ -144,7 +151,7 @@ def get_criterion(self, config: Config) -> _Loss: device=device, client_number=args.client_number, learning_rate=args.learning_rate, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, contrastive_weight=args.mu, ) diff --git a/research/flamby/fed_isic2019/moon/server.py b/research/flamby/fed_isic2019/moon/server.py index 7e0d1bbdb..b501ef17d 100644 --- a/research/flamby/fed_isic2019/moon/server.py +++ b/research/flamby/fed_isic2019/moon/server.py @@ -10,6 +10,8 @@ from flwr.server.strategy import FedAvg from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer +from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule +from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.utils.config import load_config from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn from fl4health.utils.parameter_extraction import get_all_model_parameters @@ -27,13 +29,18 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_ config["n_server_rounds"], ) + model = FedIsic2019MoonModel() + summarize_model_info(model) + checkpoint_dir = os.path.join(checkpoint_stub, run_name) checkpoint_name = "server_best_model.pkl" checkpointer = BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) + checkpoint_and_state_module = BaseServerCheckpointAndStateModule( + model=model, parameter_exchanger=FullParameterExchanger(), model_checkpointers=checkpointer + ) + client_manager = SimpleClientManager() - model = FedIsic2019MoonModel() - summarize_model_info(model) # Server performs simple FedAveraging as its server-side optimization strategy strategy = FedAvg( @@ -50,7 +57,10 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_ ) server = FullExchangeServer( - client_manager=client_manager, fl_config=config, model=model, strategy=strategy, checkpointer=checkpointer + client_manager=client_manager, + fl_config=config, + strategy=strategy, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.server.start_server( diff --git 
a/research/flamby/fed_isic2019/mr_mtl_mkmmd/client.py b/research/flamby/fed_isic2019/mr_mtl_mkmmd/client.py index 3e441dd58..1163cf9bb 100644 --- a/research/flamby/fed_isic2019/mr_mtl_mkmmd/client.py +++ b/research/flamby/fed_isic2019/mr_mtl_mkmmd/client.py @@ -17,6 +17,7 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.mkmmd_clients.mr_mtl_mkmmd_client import MrMtlMkMmdClient +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BalancedAccuracy, Metric from fl4health.utils.random import set_all_random_seeds @@ -41,14 +42,20 @@ def __init__( feature_l2_norm_weight: float = 1, mkmmd_loss_depth: int = 1, beta_global_update_interval: int = 20, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, mkmmd_loss_weight=mkmmd_loss_weight, feature_extraction_layers=FED_ISIC2019_BASELINE_LAYERS[-1 * mkmmd_loss_depth :], feature_l2_norm_weight=feature_l2_norm_weight, @@ -171,7 +178,7 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointAndStateModule( + checkpoint_and_state_module = ClientCheckpointAndStateModule( post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) @@ -181,7 +188,7 @@ def get_criterion(self, config: Config) -> _Loss: device=device, client_number=args.client_number, learning_rate=args.learning_rate, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, feature_l2_norm_weight=args.l2, mkmmd_loss_depth=args.mkmmd_loss_depth, mkmmd_loss_weight=args.mu, diff --git a/research/flamby/fed_isic2019/perfcl/client.py b/research/flamby/fed_isic2019/perfcl/client.py index 84568fd66..92fc99a82 100644 --- a/research/flamby/fed_isic2019/perfcl/client.py +++ b/research/flamby/fed_isic2019/perfcl/client.py @@ -17,6 +17,7 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.perfcl_client import PerFclClient +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BalancedAccuracy, Metric from fl4health.utils.random import set_all_random_seeds @@ -33,7 +34,10 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, mu: float = 10.0, gamma: float = 10.0, ) -> None: @@ -42,7 +46,10 @@ def __init__( 
metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, global_feature_contrastive_loss_weight=mu, local_feature_contrastive_loss_weight=gamma, ) @@ -155,7 +162,7 @@ def get_criterion(self, config: Config) -> _Loss: if federated_checkpointing else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) - checkpointer = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) + checkpoint_and_state_module = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) client = FedIsic2019PerFclClient( data_path=Path(args.dataset_dir), @@ -163,7 +170,7 @@ def get_criterion(self, config: Config) -> _Loss: device=device, client_number=args.client_number, learning_rate=args.learning_rate, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, mu=args.mu, gamma=args.gamma, ) diff --git a/research/flamby/fed_isic2019/scaffold/client.py b/research/flamby/fed_isic2019/scaffold/client.py index 66b6d8c79..8b0d4e51d 100644 --- a/research/flamby/fed_isic2019/scaffold/client.py +++ b/research/flamby/fed_isic2019/scaffold/client.py @@ -17,6 +17,7 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.scaffold_client import ScaffoldClient +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BalancedAccuracy, Metric from research.flamby.flamby_data_utils import construct_fedisic_train_val_datasets @@ -31,14 +32,20 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) self.client_number = client_number self.learning_rate: float = learning_rate @@ -113,7 +120,7 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointAndStateModule( + checkpoint_and_state_module = ClientCheckpointAndStateModule( post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) @@ -123,7 +130,7 @@ def get_criterion(self, config: Config) -> _Loss: device=device, client_number=args.client_number, learning_rate=args.learning_rate, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.client.start_client(server_address=args.server_address, client=client.to_client()) diff --git a/research/flamby/fed_isic2019/scaffold/server.py b/research/flamby/fed_isic2019/scaffold/server.py index 6fe00bba5..705ccff9b 100644 --- a/research/flamby/fed_isic2019/scaffold/server.py +++ b/research/flamby/fed_isic2019/scaffold/server.py @@ -11,8 +11,6 @@ from 
fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer from fl4health.checkpointing.server_module import ScaffoldServerCheckpointAndStateModule from fl4health.client_managers.fixed_without_replacement_manager import FixedSamplingByFractionClientManager -from fl4health.parameter_exchange.packing_exchanger import FullParameterExchangerWithPacking -from fl4health.parameter_exchange.parameter_packer import ParameterPackerWithControlVariates from fl4health.servers.scaffold_server import ScaffoldServer from fl4health.strategies.scaffold import Scaffold from fl4health.utils.config import load_config @@ -37,10 +35,8 @@ def main( client_manager = FixedSamplingByFractionClientManager() model = Baseline() - model_size = len(model.state_dict()) checkpoint_and_state_module = ScaffoldServerCheckpointAndStateModule( model=model, - parameter_exchanger=FullParameterExchangerWithPacking(ParameterPackerWithControlVariates(model_size)), model_checkpointers=checkpointer, ) diff --git a/research/flamby/fed_ixi/apfl/client.py b/research/flamby/fed_ixi/apfl/client.py index 400b385ba..157b3714f 100644 --- a/research/flamby/fed_ixi/apfl/client.py +++ b/research/flamby/fed_ixi/apfl/client.py @@ -18,6 +18,7 @@ from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.apfl_client import ApflClient from fl4health.model_bases.apfl_base import ApflModule +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BinarySoftDiceCoefficient, Metric from research.flamby.fed_ixi.apfl.apfl_model import ApflUNet @@ -34,14 +35,20 @@ def __init__( learning_rate: float, alpha_learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) assert 0 <= client_number < NUM_CLIENTS @@ -134,7 +141,7 @@ def get_criterion(self, config: Config) -> _Loss: if federated_checkpointing else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) - checkpointer = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) + checkpoint_and_state_module = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) client = FedIxiApflClient( data_path=Path(args.dataset_dir), @@ -143,7 +150,7 @@ def get_criterion(self, config: Config) -> _Loss: client_number=args.client_number, learning_rate=args.learning_rate, alpha_learning_rate=args.alpha_learning_rate, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.client.start_client(server_address=args.server_address, client=client.to_client()) diff --git a/research/flamby/fed_ixi/ditto/client.py b/research/flamby/fed_ixi/ditto/client.py index 263fb1488..1c4686367 100644 --- a/research/flamby/fed_ixi/ditto/client.py +++ b/research/flamby/fed_ixi/ditto/client.py @@ -17,6 +17,7 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer from 
fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.ditto_client import DittoClient +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BinarySoftDiceCoefficient, Metric from fl4health.utils.random import set_all_random_seeds @@ -32,14 +33,20 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) self.client_number = client_number self.learning_rate: float = learning_rate @@ -143,7 +150,7 @@ def get_criterion(self, config: Config) -> _Loss: if federated_checkpointing else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) - checkpointer = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) + checkpoint_and_state_module = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) client = FedIxiDittoClient( data_path=Path(args.dataset_dir), @@ -151,7 +158,7 @@ def get_criterion(self, config: Config) -> _Loss: device=device, client_number=args.client_number, learning_rate=args.learning_rate, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.client.start_client(server_address=args.server_address, client=client.to_client()) diff --git a/research/flamby/fed_ixi/fedadam/client.py b/research/flamby/fed_ixi/fedadam/client.py index d6ea8a8d1..130b7f269 100644 --- a/research/flamby/fed_ixi/fedadam/client.py +++ b/research/flamby/fed_ixi/fedadam/client.py @@ -17,6 +17,7 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.basic_client import BasicClient +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BinarySoftDiceCoefficient, Metric from research.flamby.fed_ixi.fedadam.fedadam_model import FedAdamUNet @@ -32,14 +33,20 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) self.client_number = client_number self.learning_rate: float = learning_rate @@ -114,7 +121,7 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" 
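The client-side hunks in this patch all make the same substitution: the old `checkpointer` keyword argument becomes `checkpoint_and_state_module`, and the new `reporters`, `progress_bar`, and `client_name` arguments are forwarded to the base client. A minimal sketch of the resulting wiring follows; the artifact directory, checkpoint file name, data path, and empty metric list are placeholders, and FedIxiDittoClient simply stands in for any of the clients updated here.

from pathlib import Path

import torch

from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer
from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule
from research.flamby.fed_ixi.ditto.client import FedIxiDittoClient

# Checkpoint the client model after server-side aggregation, keeping the best
# model according to validation loss.
checkpoint_and_state_module = ClientCheckpointAndStateModule(
    post_aggregation=BestLossTorchModuleCheckpointer("artifacts/example_run", "client_0_best_model.pkl")
)

client = FedIxiDittoClient(
    data_path=Path("/path/to/fed_ixi"),  # placeholder dataset location
    metrics=[],                          # metrics elided in this sketch
    device=torch.device("cpu"),
    client_number=0,
    learning_rate=0.001,
    checkpoint_and_state_module=checkpoint_and_state_module,
)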
- checkpointer = ClientCheckpointAndStateModule( + checkpoint_and_state_module = ClientCheckpointAndStateModule( post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) @@ -124,7 +131,7 @@ def get_criterion(self, config: Config) -> _Loss: device=device, client_number=args.client_number, learning_rate=args.learning_rate, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.client.start_client(server_address=args.server_address, client=client.to_client()) diff --git a/research/flamby/fed_ixi/fedadam/server.py b/research/flamby/fed_ixi/fedadam/server.py index ca0e98c64..8f7cf5105 100644 --- a/research/flamby/fed_ixi/fedadam/server.py +++ b/research/flamby/fed_ixi/fedadam/server.py @@ -10,6 +10,8 @@ from flwr.server.strategy import FedAdam from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer +from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule +from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.utils.config import load_config from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn from fl4health.utils.parameter_extraction import get_all_model_parameters @@ -28,6 +30,9 @@ def main( config["n_server_rounds"], ) + model = FedAdamUNet() + summarize_model_info(model) + checkpoint_dir = os.path.join(checkpoint_stub, run_name) checkpoint_name = "server_best_model.pkl" federated_checkpointing: bool = config.get("federated_checkpointing", True) @@ -37,10 +42,11 @@ def main( if federated_checkpointing else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) + checkpoint_and_state_module = BaseServerCheckpointAndStateModule( + model=model, parameter_exchanger=FullParameterExchanger(), model_checkpointers=checkpointer + ) client_manager = SimpleClientManager() - model = FedAdamUNet() - summarize_model_info(model) strategy = FedAdam( min_fit_clients=config["n_clients"], @@ -57,7 +63,10 @@ def main( ) server = FullExchangeServer( - client_manager=client_manager, fl_config=config, model=model, strategy=strategy, checkpointer=checkpointer + client_manager=client_manager, + fl_config=config, + strategy=strategy, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.server.start_server( diff --git a/research/flamby/fed_ixi/fedavg/client.py b/research/flamby/fed_ixi/fedavg/client.py index 5a56508f7..c59f3bc7b 100644 --- a/research/flamby/fed_ixi/fedavg/client.py +++ b/research/flamby/fed_ixi/fedavg/client.py @@ -17,6 +17,7 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.basic_client import BasicClient +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BinarySoftDiceCoefficient, Metric from research.flamby.flamby_data_utils import construct_fed_ixi_train_val_datasets @@ -31,14 +32,20 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: super().__init__( data_path=data_path, 
metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) self.client_number = client_number self.learning_rate: float = learning_rate @@ -116,7 +123,7 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointAndStateModule( + checkpoint_and_state_module = ClientCheckpointAndStateModule( post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) @@ -126,7 +133,7 @@ def get_criterion(self, config: Config) -> _Loss: device=device, client_number=args.client_number, learning_rate=args.learning_rate, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.client.start_client(server_address=args.server_address, client=client.to_client()) diff --git a/research/flamby/fed_ixi/fedavg/server.py b/research/flamby/fed_ixi/fedavg/server.py index e9d7c0fbb..e8b6126ba 100644 --- a/research/flamby/fed_ixi/fedavg/server.py +++ b/research/flamby/fed_ixi/fedavg/server.py @@ -11,6 +11,8 @@ from flwr.server.strategy import FedAvg from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer +from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule +from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.utils.config import load_config from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn from fl4health.utils.parameter_extraction import get_all_model_parameters @@ -26,6 +28,11 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_ config["n_server_rounds"], ) + # NOTE: We set the out_channels_first_layer to 12 rather than the default of 8. This roughly doubles the size of + # the baseline model to be used (1106520 DOF). This is to allow for a fair parameter comparison with FENDA and APFL + model = Baseline(out_channels_first_layer=12) + summarize_model_info(model) + checkpoint_dir = os.path.join(checkpoint_stub, run_name) checkpoint_name = "server_best_model.pkl" federated_checkpointing: bool = config.get("federated_checkpointing", True) @@ -35,14 +42,12 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_ if federated_checkpointing else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) + checkpoint_and_state_module = BaseServerCheckpointAndStateModule( + model=model, parameter_exchanger=FullParameterExchanger(), model_checkpointers=checkpointer + ) client_manager = SimpleClientManager() - # NOTE: We set the out_channels_first_layer to 12 rather than the default of 8. This roughly doubles the size of - # the baseline model to be used (1106520 DOF). 
This is to allow for a fair parameter comparison with FENDA and APFL - model = Baseline(out_channels_first_layer=12) - summarize_model_info(model) - # Server performs simple FedAveraging as its server-side optimization strategy strategy = FedAvg( min_fit_clients=config["n_clients"], @@ -58,7 +63,10 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_ ) server = FullExchangeServer( - client_manager=client_manager, fl_config=config, model=model, strategy=strategy, checkpointer=checkpointer + client_manager=client_manager, + fl_config=config, + strategy=strategy, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.server.start_server( diff --git a/research/flamby/fed_ixi/fedper/client.py b/research/flamby/fed_ixi/fedper/client.py index 4a953ccac..b7b5fb035 100644 --- a/research/flamby/fed_ixi/fedper/client.py +++ b/research/flamby/fed_ixi/fedper/client.py @@ -19,6 +19,7 @@ from fl4health.clients.moon_client import MoonClient from fl4health.parameter_exchange.layer_exchanger import FixedLayerExchanger from fl4health.parameter_exchange.parameter_exchanger_base import ParameterExchanger +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BinarySoftDiceCoefficient, Metric from fl4health.utils.random import set_all_random_seeds @@ -35,14 +36,20 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) # MOON contrastive loss weight is set to None by default since we are not using it self.client_number = client_number @@ -144,7 +151,7 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: if federated_checkpointing else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) - checkpointer = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) + checkpoint_and_state_module = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) client = FedIxiFedPerClient( data_path=Path(args.dataset_dir), @@ -152,7 +159,7 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: device=device, client_number=args.client_number, learning_rate=args.learning_rate, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.client.start_client(server_address=args.server_address, client=client.to_client()) diff --git a/research/flamby/fed_ixi/fedprox/client.py b/research/flamby/fed_ixi/fedprox/client.py index 58d3e18f0..bc5fbe1d0 100644 --- a/research/flamby/fed_ixi/fedprox/client.py +++ b/research/flamby/fed_ixi/fedprox/client.py @@ -17,6 +17,7 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.fed_prox_client import FedProxClient +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import 
LossMeterType from fl4health.utils.metrics import BinarySoftDiceCoefficient, Metric from research.flamby.flamby_data_utils import construct_fed_ixi_train_val_datasets @@ -31,14 +32,20 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) self.client_number = client_number self.learning_rate: float = learning_rate @@ -116,7 +123,7 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointAndStateModule( + checkpoint_and_state_module = ClientCheckpointAndStateModule( post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) @@ -126,7 +133,7 @@ def get_criterion(self, config: Config) -> _Loss: device=device, client_number=args.client_number, learning_rate=args.learning_rate, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.client.start_client(server_address=args.server_address, client=client.to_client()) diff --git a/research/flamby/fed_ixi/fedprox/server.py b/research/flamby/fed_ixi/fedprox/server.py index 016a665ce..17c988d0e 100644 --- a/research/flamby/fed_ixi/fedprox/server.py +++ b/research/flamby/fed_ixi/fedprox/server.py @@ -10,6 +10,7 @@ from flwr.server.client_manager import SimpleClientManager from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer +from fl4health.checkpointing.server_module import AdaptiveConstraintServerCheckpointAndStateModule from fl4health.servers.adaptive_constraint_servers.fedprox_server import FedProxServer from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint from fl4health.utils.config import load_config @@ -26,6 +27,12 @@ def main(config: Dict[str, Any], server_address: str, mu: float, checkpoint_stub config["n_server_rounds"], ) + # NOTE: We set the out_channels_first_layer to 12 rather than the default of 8. This roughly doubles the size of + # the baseline model to be used (1106520 DOF). This is to allow for a fair parameter comparison with FENDA + # and APFL + model = Baseline(out_channels_first_layer=12) + summarize_model_info(model) + checkpoint_dir = os.path.join(checkpoint_stub, run_name) checkpoint_name = "server_best_model.pkl" federated_checkpointing: bool = config.get("federated_checkpointing", True) @@ -36,13 +43,11 @@ def main(config: Dict[str, Any], server_address: str, mu: float, checkpoint_stub else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) - client_manager = SimpleClientManager() + checkpoint_and_state_module = AdaptiveConstraintServerCheckpointAndStateModule( + model=model, model_checkpointers=checkpointer + ) - # NOTE: We set the out_channels_first_layer to 12 rather than the default of 8. This roughly doubles the size of - # the baseline model to be used (1106520 DOF). 
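The fedprox server hunk above follows the same refactoring as the other server scripts: the model is constructed first, paired with its checkpointer inside an AdaptiveConstraintServerCheckpointAndStateModule, and the server no longer receives `model=` or `checkpointer=` directly. A rough sketch of that pattern, using a stand-in nn.Linear model and a placeholder config in place of the FLamby Baseline and the loaded YAML config, and omitting most strategy keyword arguments, is:

import torch.nn as nn
from flwr.server.client_manager import SimpleClientManager

from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer
from fl4health.checkpointing.server_module import AdaptiveConstraintServerCheckpointAndStateModule
from fl4health.servers.adaptive_constraint_servers.fedprox_server import FedProxServer
from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint
from fl4health.utils.parameter_extraction import get_all_model_parameters

model = nn.Linear(10, 2)  # stand-in for the Baseline model used in the script
checkpointer = BestLossTorchModuleCheckpointer("artifacts/example_run", "server_best_model.pkl")

# The module owns the server-side model and its checkpointers; no parameter
# exchanger is passed for the adaptive-constraint variant in this patch.
checkpoint_and_state_module = AdaptiveConstraintServerCheckpointAndStateModule(
    model=model, model_checkpointers=checkpointer
)

strategy = FedAvgWithAdaptiveConstraint(
    initial_parameters=get_all_model_parameters(model),  # remaining strategy kwargs omitted here
)

server = FedProxServer(
    client_manager=SimpleClientManager(),
    fl_config={},  # placeholder for the loaded config
    strategy=strategy,
    checkpoint_and_state_module=checkpoint_and_state_module,
)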
This is to allow for a fair parameter comparison with FENDA - # and APFL - model = Baseline(out_channels_first_layer=12) - summarize_model_info(model) + client_manager = SimpleClientManager() # Server performs simple FedAveraging as its server-side optimization strategy strategy = FedAvgWithAdaptiveConstraint( @@ -60,7 +65,10 @@ def main(config: Dict[str, Any], server_address: str, mu: float, checkpoint_stub ) server = FedProxServer( - client_manager=client_manager, fl_config=config, strategy=strategy, model=model, checkpointer=checkpointer + client_manager=client_manager, + fl_config=config, + strategy=strategy, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.server.start_server( diff --git a/research/flamby/fed_ixi/fenda/client.py b/research/flamby/fed_ixi/fenda/client.py index 446bb5803..02897334c 100644 --- a/research/flamby/fed_ixi/fenda/client.py +++ b/research/flamby/fed_ixi/fenda/client.py @@ -17,6 +17,7 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.fenda_client import FendaClient +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BinarySoftDiceCoefficient, Metric from fl4health.utils.random import set_all_random_seeds @@ -33,14 +34,20 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) self.client_number = client_number self.learning_rate: float = learning_rate @@ -137,7 +144,7 @@ def get_criterion(self, config: Config) -> _Loss: if federated_checkpointing else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) - checkpointer = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) + checkpoint_and_state_module = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) client = FedIxiFendaClient( data_path=Path(args.dataset_dir), @@ -145,7 +152,7 @@ def get_criterion(self, config: Config) -> _Loss: device=device, client_number=args.client_number, learning_rate=args.learning_rate, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.client.start_client(server_address=args.server_address, client=client.to_client()) diff --git a/research/flamby/fed_ixi/moon/client.py b/research/flamby/fed_ixi/moon/client.py index d758633cd..e4bfc68df 100644 --- a/research/flamby/fed_ixi/moon/client.py +++ b/research/flamby/fed_ixi/moon/client.py @@ -17,6 +17,7 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.moon_client import MoonClient +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import 
BinarySoftDiceCoefficient, Metric from fl4health.utils.random import set_all_random_seeds @@ -33,15 +34,21 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, contrastive_weight: float = 10, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, ) -> None: super().__init__( data_path=data_path, metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, contrastive_weight=contrastive_weight, ) self.client_number = client_number @@ -134,7 +141,7 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointAndStateModule( + checkpoint_and_state_module = ClientCheckpointAndStateModule( post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) @@ -144,7 +151,7 @@ def get_criterion(self, config: Config) -> _Loss: device=device, client_number=args.client_number, learning_rate=args.learning_rate, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, contrastive_weight=args.mu, ) diff --git a/research/flamby/fed_ixi/moon/server.py b/research/flamby/fed_ixi/moon/server.py index e96830aaa..e2062b55b 100644 --- a/research/flamby/fed_ixi/moon/server.py +++ b/research/flamby/fed_ixi/moon/server.py @@ -10,6 +10,8 @@ from flwr.server.strategy import FedAvg from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer +from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule +from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger from fl4health.utils.config import load_config from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn from fl4health.utils.parameter_extraction import get_all_model_parameters @@ -27,6 +29,9 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_ config["n_server_rounds"], ) + model = FedIxiMoonModel() + summarize_model_info(model) + checkpoint_dir = os.path.join(checkpoint_stub, run_name) checkpoint_name = "server_best_model.pkl" federated_checkpointing: bool = config.get("federated_checkpointing", True) @@ -37,9 +42,11 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_ else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) + checkpoint_and_state_module = BaseServerCheckpointAndStateModule( + model=model, parameter_exchanger=FullParameterExchanger(), model_checkpointers=checkpointer + ) + client_manager = SimpleClientManager() - model = FedIxiMoonModel() - summarize_model_info(model) # Server performs simple FedAveraging as its server-side optimization strategy strategy = FedAvg( @@ -56,7 +63,10 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_ ) server = FullExchangeServer( - client_manager=client_manager, fl_config=config, model=model, strategy=strategy, checkpointer=checkpointer + client_manager=client_manager, + fl_config=config, + strategy=strategy, + 
checkpoint_and_state_module=checkpoint_and_state_module, ) fl.server.start_server( diff --git a/research/flamby/fed_ixi/perfcl/client.py b/research/flamby/fed_ixi/perfcl/client.py index 15ee18b40..7e66bf3e3 100644 --- a/research/flamby/fed_ixi/perfcl/client.py +++ b/research/flamby/fed_ixi/perfcl/client.py @@ -17,6 +17,7 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.perfcl_client import PerFclClient +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BinarySoftDiceCoefficient, Metric from fl4health.utils.random import set_all_random_seeds @@ -33,7 +34,10 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, mu: float = 10.0, gamma: float = 10.0, ) -> None: @@ -42,7 +46,10 @@ def __init__( metrics=metrics, device=device, loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, global_feature_contrastive_loss_weight=mu, local_feature_contrastive_loss_weight=gamma, ) @@ -155,7 +162,7 @@ def get_criterion(self, config: Config) -> _Loss: if federated_checkpointing else LatestTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) - checkpointer = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) + checkpoint_and_state_module = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer) client = FedIxiPerFclClient( data_path=Path(args.dataset_dir), @@ -163,7 +170,7 @@ def get_criterion(self, config: Config) -> _Loss: device=device, client_number=args.client_number, learning_rate=args.learning_rate, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, mu=args.mu, gamma=args.gamma, ) diff --git a/research/flamby/fed_ixi/scaffold/client.py b/research/flamby/fed_ixi/scaffold/client.py index d8c74ea62..f1e26d877 100644 --- a/research/flamby/fed_ixi/scaffold/client.py +++ b/research/flamby/fed_ixi/scaffold/client.py @@ -17,6 +17,7 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule from fl4health.clients.scaffold_client import ScaffoldClient +from fl4health.reporting.base_reporter import BaseReporter from fl4health.utils.losses import LossMeterType from fl4health.utils.metrics import BinarySoftDiceCoefficient, Metric from research.flamby.flamby_data_utils import construct_fed_ixi_train_val_datasets @@ -31,14 +32,20 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + reporters: Sequence[BaseReporter] | None = None, + progress_bar: bool = False, + client_name: Optional[str] = None, ) -> None: super().__init__( data_path=data_path, metrics=metrics, device=device, 
loss_meter_type=loss_meter_type, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, + reporters=reporters, + progress_bar=progress_bar, + client_name=client_name, ) self.client_number = client_number self.learning_rate: float = learning_rate @@ -116,7 +123,7 @@ def get_criterion(self, config: Config) -> _Loss: checkpoint_dir = os.path.join(args.artifact_dir, args.run_name) checkpoint_name = f"client_{args.client_number}_best_model.pkl" - checkpointer = ClientCheckpointAndStateModule( + checkpoint_and_state_module = ClientCheckpointAndStateModule( post_aggregation=BestLossTorchModuleCheckpointer(checkpoint_dir, checkpoint_name) ) @@ -126,7 +133,7 @@ def get_criterion(self, config: Config) -> _Loss: device=device, client_number=args.client_number, learning_rate=args.learning_rate, - checkpointer=checkpointer, + checkpoint_and_state_module=checkpoint_and_state_module, ) fl.client.start_client(server_address=args.server_address, client=client.to_client()) diff --git a/research/flamby/fed_ixi/scaffold/server.py b/research/flamby/fed_ixi/scaffold/server.py index d5e1edf12..ea4adf6d7 100644 --- a/research/flamby/fed_ixi/scaffold/server.py +++ b/research/flamby/fed_ixi/scaffold/server.py @@ -11,8 +11,6 @@ from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer from fl4health.checkpointing.server_module import ScaffoldServerCheckpointAndStateModule from fl4health.client_managers.fixed_without_replacement_manager import FixedSamplingByFractionClientManager -from fl4health.parameter_exchange.packing_exchanger import FullParameterExchangerWithPacking -from fl4health.parameter_exchange.parameter_packer import ParameterPackerWithControlVariates from fl4health.servers.scaffold_server import ScaffoldServer from fl4health.strategies.scaffold import Scaffold from fl4health.utils.config import load_config @@ -48,10 +46,8 @@ def main( model = Baseline(out_channels_first_layer=12) summarize_model_info(model) - model_size = len(model.state_dict()) checkpoint_and_state_module = ScaffoldServerCheckpointAndStateModule( model=model, - parameter_exchanger=FullParameterExchangerWithPacking(ParameterPackerWithControlVariates(model_size)), model_checkpointers=checkpointer, ) diff --git a/research/flamby/flamby_servers/full_exchange_server.py b/research/flamby/flamby_servers/full_exchange_server.py index 9e8dfa961..7e5097498 100644 --- a/research/flamby/flamby_servers/full_exchange_server.py +++ b/research/flamby/flamby_servers/full_exchange_server.py @@ -1,6 +1,5 @@ from typing import Optional -import torch.nn as nn from flwr.common.typing import Config from flwr.server.client_manager import ClientManager from flwr.server.strategy import Strategy diff --git a/research/picai/fedavg/client.py b/research/picai/fedavg/client.py index aa6c12628..bcd227b11 100644 --- a/research/picai/fedavg/client.py +++ b/research/picai/fedavg/client.py @@ -162,6 +162,7 @@ def get_optimizer(self, config: Config) -> Optimizer: ) ] + checkpoint_and_state_module: ClientCheckpointAndStateModule | None if args.artifact_dir is not None: checkpoint_and_state_module = ClientCheckpointAndStateModule( state_checkpointer=PerRoundStateCheckpointer(Path(args.artifact_dir)) diff --git a/research/picai/fl_nnunet/start_server.py b/research/picai/fl_nnunet/start_server.py index d7a363f6d..730d56f7c 100644 --- a/research/picai/fl_nnunet/start_server.py +++ b/research/picai/fl_nnunet/start_server.py @@ -6,6 +6,9 @@ from pathlib import Path from typing import Optional 
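The picai client hunk above also shows the other half of the new module: state preservation. When an artifact directory is supplied, a PerRoundStateCheckpointer is attached so that a preempted run can be resumed; otherwise the module is left as None. A small sketch of that branch, with a placeholder directory, is:

from pathlib import Path

from fl4health.checkpointing.checkpointer import PerRoundStateCheckpointer
from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule

artifact_dir: str | None = "artifacts/example_run"  # placeholder; may be None

# Annotate up front so both branches type-check, mirroring the hunk above.
checkpoint_and_state_module: ClientCheckpointAndStateModule | None
if artifact_dir is not None:
    checkpoint_and_state_module = ClientCheckpointAndStateModule(
        state_checkpointer=PerRoundStateCheckpointer(Path(artifact_dir))
    )
else:
    checkpoint_and_state_module = None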
+from fl4health.checkpointing.checkpointer import PerRoundStateCheckpointer +from fl4health.checkpointing.server_module import NnUnetServerCheckpointAndStateModule + with warnings.catch_warnings(): # Silence deprecation warnings from sentry sdk due to flwr and wandb # https://github.com/adap/flower/issues/4086 @@ -92,17 +95,23 @@ def main( initial_parameters=params, ) + checkpoint_and_state_model: NnUnetServerCheckpointAndStateModule | None + if intermediate_server_state_dir is None: + checkpoint_and_state_model = None + else: + checkpoint_and_state_model = NnUnetServerCheckpointAndStateModule( + model=None, + parameter_exchanger=FullParameterExchanger(), + state_checkpointer=PerRoundStateCheckpointer(Path(intermediate_server_state_dir)), + ) + server = NnunetServer( client_manager=SimpleClientManager(), fl_config=config, # The fit_config_fn contains all of the necessary information for param initialization, so we reuse it here on_init_parameters_config_fn=fit_config_fn, - parameter_exchanger=FullParameterExchanger(), - model=None, strategy=strategy, - intermediate_server_state_dir=( - Path(intermediate_server_state_dir) if intermediate_server_state_dir is not None else None - ), + checkpoint_and_state_module=checkpoint_and_state_model, server_name=server_name, ) diff --git a/research/picai/single_node_trainer.py b/research/picai/single_node_trainer.py index a75b81a58..085d2cd5d 100644 --- a/research/picai/single_node_trainer.py +++ b/research/picai/single_node_trainer.py @@ -33,8 +33,9 @@ def __init__( if not os.path.exists(checkpoint_dir): os.mkdir(checkpoint_dir) - per_round_checkpoint_name = "ckpt.pkl" - self.per_epoch_checkpointer = PerRoundStateCheckpointer(Path(checkpoint_dir), Path(per_round_checkpoint_name)) + self.state_checkpoint_name = "ckpt.pkl" + self.state_checkpoint_path = os.path.join(checkpoint_dir, self.state_checkpoint_name) + self.per_epoch_checkpointer = PerRoundStateCheckpointer(Path(checkpoint_dir)) best_metric_checkpoint_name = "best_ckpt.pkl" self.checkpointer = BestLossTorchModuleCheckpointer(checkpoint_dir, best_metric_checkpoint_name) @@ -46,10 +47,12 @@ def __init__( self.device = device self.epoch: int - if not self.per_epoch_checkpointer.checkpoint_exists(): - self.per_epoch_checkpointer.save_checkpoint({"model": self.model, "optimizer": self.optimizer, "epoch": 0}) + if not self.per_epoch_checkpointer.checkpoint_exists(self.state_checkpoint_path): + self.per_epoch_checkpointer.save_checkpoint( + self.state_checkpoint_name, {"model": self.model, "optimizer": self.optimizer, "epoch": 0} + ) - ckpt = self.per_epoch_checkpointer.load_checkpoint() + ckpt = self.per_epoch_checkpointer.load_checkpoint(self.state_checkpoint_path) self.model, self.optimizer, self.epoch = ckpt["model"], ckpt["optimizer"], ckpt["epoch"] def _maybe_checkpoint(self, loss: float, metrics: Dict[str, Scalar]) -> None: @@ -103,7 +106,7 @@ def train_by_epochs(self, epochs: int, train_metric_mngr: MetricManager, val_met # Save checkpoint in case run gets pre-empted self.per_epoch_checkpointer.save_checkpoint( - {"model": self.model, "optimizer": self.optimizer, "epoch": epoch + 1} + self.state_checkpoint_name, {"model": self.model, "optimizer": self.optimizer, "epoch": epoch + 1} ) def validate(self, val_metric_mngr: MetricManager) -> None: diff --git a/tests/checkpointing/test_per_round_checkpointer.py b/tests/checkpointing/test_per_round_checkpointer.py index 77ab6d733..556b9be60 100644 --- a/tests/checkpointing/test_per_round_checkpointer.py +++ 
b/tests/checkpointing/test_per_round_checkpointer.py @@ -1,3 +1,4 @@ +import os import tempfile from pathlib import Path @@ -12,15 +13,20 @@ def test_per_round_checkpointer() -> None: model: torch.nn.Module = LinearModel() optimizer: Optimizer = torch.optim.SGD(model.parameters(), lr=0.01) with tempfile.TemporaryDirectory() as results_dir: - checkpointer = PerRoundStateCheckpointer(checkpoint_dir=Path(results_dir), checkpoint_name=Path("ckpt.pt")) + checkpoint_name = "ckpt.pt" + checkpoint_path = os.path.join(results_dir, checkpoint_name) + checkpointer = PerRoundStateCheckpointer(checkpoint_dir=Path(results_dir)) - assert not checkpointer.checkpoint_exists() + assert not checkpointer.checkpoint_exists(checkpoint_path) - checkpointer.save_checkpoint({"model": model, "optimizer": optimizer, "current_round": 0}) + checkpointer.save_checkpoint( + checkpoint_name=checkpoint_name, + checkpoint_dict={"model": model, "optimizer": optimizer, "current_round": 0}, + ) - assert checkpointer.checkpoint_exists() + assert checkpointer.checkpoint_exists(checkpoint_path) - ckpt = checkpointer.load_checkpoint() + ckpt = checkpointer.load_checkpoint(checkpoint_path) assert "model" in ckpt and isinstance(ckpt["model"], torch.nn.Module) assert "optimizer" in ckpt and isinstance(ckpt["optimizer"], torch.optim.Optimizer) diff --git a/tests/servers/test_base_server.py b/tests/servers/test_base_server.py index 287740a1c..6d5905df6 100644 --- a/tests/servers/test_base_server.py +++ b/tests/servers/test_base_server.py @@ -34,7 +34,7 @@ def test_hydration_no_model_with_checkpointer(tmp_path: Path) -> None: checkpoint_dir = tmp_path.joinpath("resources") checkpoint_dir.mkdir() checkpointer = BestLossTorchModuleCheckpointer(str(checkpoint_dir), "best_model.pkl") - state_checkpointer = PerRoundStateCheckpointer() + state_checkpointer = PerRoundStateCheckpointer(checkpoint_dir=checkpoint_dir) checkpoint_and_state_module = BaseServerCheckpointAndStateModule( model=None, parameter_exchanger=None, diff --git a/tests/smoke_tests/load_from_checkpoint_example/client.py b/tests/smoke_tests/load_from_checkpoint_example/client.py index f8209c48f..a21a48de0 100644 --- a/tests/smoke_tests/load_from_checkpoint_example/client.py +++ b/tests/smoke_tests/load_from_checkpoint_example/client.py @@ -103,6 +103,7 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict # Set the random seed for reproducibility set_all_random_seeds(args.seed) + checkpoint_and_state_module: ClientCheckpointAndStateModule | None if args.intermediate_client_state_dir is not None: checkpoint_and_state_module = ClientCheckpointAndStateModule( state_checkpointer=PerRoundStateCheckpointer(Path(args.intermediate_client_state_dir)) diff --git a/tests/smoke_tests/load_from_checkpoint_example/server.py b/tests/smoke_tests/load_from_checkpoint_example/server.py index 0d57947b4..7b8787751 100644 --- a/tests/smoke_tests/load_from_checkpoint_example/server.py +++ b/tests/smoke_tests/load_from_checkpoint_example/server.py @@ -80,10 +80,7 @@ def main(config: Dict[str, Any], intermediate_server_state_dir: str, server_name server = FlServer( client_manager=SimpleClientManager(), fl_config=config, - model=model, - parameter_exchanger=parameter_exchanger, strategy=strategy, - checkpointer=checkpointers, reporters=[JsonReporter()], checkpoint_and_state_module=checkpoint_and_state_module, server_name=server_name, From f93c03b1dad6c0eae730cd4dac3aa9e7c5ebed1c Mon Sep 17 00:00:00 2001 From: David Emerson <43939939+emersodb@users.noreply.github.com> Date: 
Wed, 27 Nov 2024 10:24:35 -0500 Subject: [PATCH 05/13] Fixing issues with the unit tests --- tests/servers/test_base_server.py | 42 +++++++++++-------------------- 1 file changed, 15 insertions(+), 27 deletions(-) diff --git a/tests/servers/test_base_server.py b/tests/servers/test_base_server.py index 6d5905df6..8264c55bb 100644 --- a/tests/servers/test_base_server.py +++ b/tests/servers/test_base_server.py @@ -22,6 +22,7 @@ from fl4health.strategies.basic_fedavg import BasicFedAvg from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn from fl4health.utils.metrics import TEST_LOSS_KEY, TEST_NUM_EXAMPLES_KEY, MetricPrefix +from fl4health.utils.parameter_extraction import get_all_model_parameters from tests.test_utils.assert_metrics_dict import assert_metrics_dict from tests.test_utils.custom_client_proxy import CustomClientProxy from tests.test_utils.models_for_test import LinearTransform @@ -35,23 +36,16 @@ def test_hydration_no_model_with_checkpointer(tmp_path: Path) -> None: checkpoint_dir.mkdir() checkpointer = BestLossTorchModuleCheckpointer(str(checkpoint_dir), "best_model.pkl") state_checkpointer = PerRoundStateCheckpointer(checkpoint_dir=checkpoint_dir) - checkpoint_and_state_module = BaseServerCheckpointAndStateModule( - model=None, - parameter_exchanger=None, - model_checkpointers=checkpointer, - state_checkpointer=state_checkpointer, - ) - # Checkpointer is defined but there is no server-side model defined to produce a model from the server state. # An assertion error should be throw stating this - fl_server_no_hydration = FlServer( - client_manager=PoissonSamplingClientManager(), - fl_config={}, - checkpoint_and_state_module=checkpoint_and_state_module, - ) with pytest.raises(AssertionError) as assertion_error: - fl_server_no_hydration._maybe_checkpoint(1.0, {}, server_round=1) - assert "Model hydration has been called but no server_model is defined to hydrate" in str(assertion_error.value) + BaseServerCheckpointAndStateModule( + model=None, + parameter_exchanger=None, + model_checkpointers=checkpointer, + state_checkpointer=state_checkpointer, + ) + assert "Checkpointer(s) is (are) defined but no model is defined to hydrate" in str(assertion_error.value) def test_hydration_no_exchanger_with_checkpointer(tmp_path: Path) -> None: @@ -59,20 +53,11 @@ def test_hydration_no_exchanger_with_checkpointer(tmp_path: Path) -> None: checkpoint_dir = tmp_path.joinpath("resources") checkpoint_dir.mkdir() checkpointer = BestLossTorchModuleCheckpointer(str(checkpoint_dir), "best_model.pkl") - checkpoint_and_state_module = BaseServerCheckpointAndStateModule( - model=model, parameter_exchanger=None, model_checkpointers=checkpointer - ) - # Checkpointer is defined but there is no parameter exchanger defined to produce a model from the server state. # An assertion error should be throw stating this - fl_server_no_hydration = FlServer( - client_manager=PoissonSamplingClientManager(), - fl_config={}, - checkpoint_and_state_module=checkpoint_and_state_module, - ) with pytest.raises(AssertionError) as assertion_error: - fl_server_no_hydration._maybe_checkpoint(1.0, {}, server_round=1) - assert "Model hydration has been called but no parameter_exchanger is defined to hydrate." in str( + BaseServerCheckpointAndStateModule(model=model, parameter_exchanger=None, model_checkpointers=checkpointer) + assert "Checkpointer(s) is (are) defined but no parameter_exchanger is defined to hydrate." 
in str( assertion_error.value ) @@ -84,7 +69,7 @@ def test_no_checkpointer_maybe_checkpoint(caplog: pytest.LogCaptureFixture) -> N # Neither checkpointing nor hydration is defined, we'll have no server-side checkpointing for the FL run. fl_server_no_checkpointer._maybe_checkpoint(1.0, {}, server_round=1) - assert "No checkpointer present. Models will not be checkpointed on server-side." in caplog.text + assert "No model checkpointers specified. Skipping any checkpointing." in caplog.text def test_hydration_and_checkpointer(tmp_path: Path) -> None: @@ -93,7 +78,7 @@ def test_hydration_and_checkpointer(tmp_path: Path) -> None: checkpoint_dir.mkdir() checkpointer = BestLossTorchModuleCheckpointer(str(checkpoint_dir), "best_model.pkl") checkpoint_and_state_module = BaseServerCheckpointAndStateModule( - model=model, parameter_exchanger=None, model_checkpointers=checkpointer + model=model, parameter_exchanger=FullParameterExchanger(), model_checkpointers=checkpointer ) # Server-side hydration to convert server state to model and checkpointing behavior are both defined, a model @@ -103,6 +88,9 @@ def test_hydration_and_checkpointer(tmp_path: Path) -> None: fl_config={}, checkpoint_and_state_module=checkpoint_and_state_module, ) + # Need to mock set the parameters as no FL or exchange is happening. + fl_server_both.parameters = get_all_model_parameters(model) + fl_server_both._maybe_checkpoint(1.0, {}, server_round=5) loaded_model = checkpointer.load_checkpoint() assert isinstance(loaded_model, LinearTransform) From fd6b34ed51c774026008e142d04d656a4a017167 Mon Sep 17 00:00:00 2001 From: David Emerson <43939939+emersodb@users.noreply.github.com> Date: Wed, 27 Nov 2024 13:14:47 -0500 Subject: [PATCH 06/13] Fixing some small smoke test issues. --- examples/apfl_example/server.py | 4 ++-- .../adaptive_constraint_servers/ditto_server.py | 13 +++++++++---- .../adaptive_constraint_servers/fedprox_server.py | 9 +++++---- .../adaptive_constraint_servers/mrmtl_server.py | 5 +++++ .../servers/client_level_dp_fed_avg_server.py | 5 +++++ fl4health/servers/fedpm_server.py | 5 +++++ fl4health/servers/instance_level_dp_server.py | 5 +++++ fl4health/servers/nnunet_server.py | 5 +++++ fl4health/servers/scaffold_server.py | 14 ++++++++++---- 9 files changed, 51 insertions(+), 14 deletions(-) diff --git a/examples/apfl_example/server.py b/examples/apfl_example/server.py index 8a58b9545..573d9bc8a 100644 --- a/examples/apfl_example/server.py +++ b/examples/apfl_example/server.py @@ -43,7 +43,7 @@ def main(config: Dict[str, Any]) -> None: local_steps=config.get("local_steps"), ) - initial_model = ApflModule(MnistNetWithBnAndFrozen()) + model = ApflModule(MnistNetWithBnAndFrozen()) # Server performs simple FedAveraging as its server-side optimization strategy strategy = FedAvg( @@ -56,7 +56,7 @@ def main(config: Dict[str, Any]) -> None: on_evaluate_config_fn=fit_config_fn, fit_metrics_aggregation_fn=fit_metrics_aggregation_fn, evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn, - initial_parameters=get_all_model_parameters(initial_model), + initial_parameters=get_all_model_parameters(model), ) client_manager = SimpleClientManager() diff --git a/fl4health/servers/adaptive_constraint_servers/ditto_server.py b/fl4health/servers/adaptive_constraint_servers/ditto_server.py index c0a860ed1..31d61f23b 100644 --- a/fl4health/servers/adaptive_constraint_servers/ditto_server.py +++ b/fl4health/servers/adaptive_constraint_servers/ditto_server.py @@ -37,10 +37,10 @@ def __init__( strategy must be a derivative of the 
FedAvgWithAdaptiveConstraint class. reporters (Sequence[BaseReporter], optional): A sequence of FL4Health reporters which the server should send data to before and after each round. Defaults to None. - checkpoint_and_state_module (BaseServerCheckpointAndStateModule | None, optional): This module is used - to handle both model checkpointing and state checkpointing. The former is aimed at saving model - artifacts to be used or evaluated after training. The later is used to preserve training state - (including models) such that if FL training is interrupted, the process may be restarted. If no + checkpoint_and_state_module (AdaptiveConstraintServerCheckpointAndStateModule | None, optional): This + module is used to handle both model checkpointing and state checkpointing. The former is aimed at + saving model artifacts to be used or evaluated after training. The later is used to preserve training + state (including models) such that if FL training is interrupted, the process may be restarted. If no module is provided, no checkpointing or state preservation will happen. Defaults to None. NOTE: For Ditto, the model shared with the server is the GLOBAL MODEL, which isn't the target of FL training for this algorithm. However, one may still want to save this model for other purposes. @@ -57,6 +57,11 @@ def __init__( assert isinstance( strategy, FedAvgWithAdaptiveConstraint ), "Strategy must be of base type FedAvgWithAdaptiveConstraint" + if checkpoint_and_state_module is not None: + assert isinstance( + checkpoint_and_state_module, + AdaptiveConstraintServerCheckpointAndStateModule, + ), "checkpoint_and_state_module must have type AdaptiveConstraintServerCheckpointAndStateModule" super().__init__( client_manager=client_manager, fl_config=fl_config, diff --git a/fl4health/servers/adaptive_constraint_servers/fedprox_server.py b/fl4health/servers/adaptive_constraint_servers/fedprox_server.py index a17e96b3a..8b6a5a88c 100644 --- a/fl4health/servers/adaptive_constraint_servers/fedprox_server.py +++ b/fl4health/servers/adaptive_constraint_servers/fedprox_server.py @@ -56,10 +56,11 @@ def __init__( assert isinstance( strategy, FedAvgWithAdaptiveConstraint ), "Strategy must be of base type FedAvgWithAdaptiveConstraint" - assert isinstance( - checkpoint_and_state_module, - AdaptiveConstraintServerCheckpointAndStateModule, - ), "checkpoint_and_state_module must have type AdaptiveConstraintServerCheckpointAndStateModule" + if checkpoint_and_state_module is not None: + assert isinstance( + checkpoint_and_state_module, + AdaptiveConstraintServerCheckpointAndStateModule, + ), "checkpoint_and_state_module must have type AdaptiveConstraintServerCheckpointAndStateModule" super().__init__( client_manager=client_manager, fl_config=fl_config, diff --git a/fl4health/servers/adaptive_constraint_servers/mrmtl_server.py b/fl4health/servers/adaptive_constraint_servers/mrmtl_server.py index ca0c23023..3316df94d 100644 --- a/fl4health/servers/adaptive_constraint_servers/mrmtl_server.py +++ b/fl4health/servers/adaptive_constraint_servers/mrmtl_server.py @@ -58,6 +58,11 @@ def __init__( assert isinstance( strategy, FedAvgWithAdaptiveConstraint ), "Strategy must be of base type FedAvgWithAdaptiveConstraint" + if checkpoint_and_state_module is not None: + assert isinstance( + checkpoint_and_state_module, + AdaptiveConstraintServerCheckpointAndStateModule, + ), "checkpoint_and_state_module must have type AdaptiveConstraintServerCheckpointAndStateModule" super().__init__( client_manager=client_manager, fl_config=fl_config, diff 
--git a/fl4health/servers/client_level_dp_fed_avg_server.py b/fl4health/servers/client_level_dp_fed_avg_server.py index e6cda5536..635ce23f5 100644 --- a/fl4health/servers/client_level_dp_fed_avg_server.py +++ b/fl4health/servers/client_level_dp_fed_avg_server.py @@ -69,6 +69,11 @@ def __init__( evaluation from clients or not. If set to False, this will cause the server to shutdown all clients and throw an exception. Defaults to True. """ + if checkpoint_and_state_module is not None: + assert isinstance( + checkpoint_and_state_module, + ClippingBitServerCheckpointAndStateModule, + ), "checkpoint_and_state_module must have type ClippingBitServerCheckpointAndStateModule" super().__init__( client_manager=client_manager, fl_config=fl_config, diff --git a/fl4health/servers/fedpm_server.py b/fl4health/servers/fedpm_server.py index dae383a2e..2abaa6ad3 100644 --- a/fl4health/servers/fedpm_server.py +++ b/fl4health/servers/fedpm_server.py @@ -54,6 +54,11 @@ def __init__( reset_frequency (int, optional): Determines the frequency with which the beta priors are reset. Defaults to 1. """ + if checkpoint_and_state_module is not None: + assert isinstance( + checkpoint_and_state_module, + LayerNamesServerCheckpointAndStateModule, + ), "checkpoint_and_state_module must have type LayerNamesServerCheckpointAndStateModule" super().__init__( client_manager=client_manager, fl_config=fl_config, diff --git a/fl4health/servers/instance_level_dp_server.py b/fl4health/servers/instance_level_dp_server.py index 16cca8db5..d75755e13 100644 --- a/fl4health/servers/instance_level_dp_server.py +++ b/fl4health/servers/instance_level_dp_server.py @@ -78,6 +78,11 @@ def __init__( evaluation from clients or not. If set to False, this will cause the server to shutdown all clients and throw an exception. Defaults to True. """ + if checkpoint_and_state_module is not None: + assert isinstance( + checkpoint_and_state_module, + OpacusServerCheckpointAndStateModule, + ), "checkpoint_and_state_module must have type OpacusServerCheckpointAndStateModule" super().__init__( client_manager=client_manager, fl_config=fl_config, diff --git a/fl4health/servers/nnunet_server.py b/fl4health/servers/nnunet_server.py index c8f9019b6..e4d5a045a 100644 --- a/fl4health/servers/nnunet_server.py +++ b/fl4health/servers/nnunet_server.py @@ -104,6 +104,11 @@ def __init__( Useful for passing custom nnUNetTrainer. Defaults to the standard nnUNetTrainer class. Must match the nnunet_trainer_class passed to the NnunetClient. """ + if checkpoint_and_state_module is not None: + assert isinstance( + checkpoint_and_state_module, + NnUnetServerCheckpointAndStateModule, + ), "checkpoint_and_state_module must have type NnUnetServerCheckpointAndStateModule" super().__init__( client_manager=client_manager, fl_config=fl_config, diff --git a/fl4health/servers/scaffold_server.py b/fl4health/servers/scaffold_server.py index 2d3bf9416..5ff50c7b6 100644 --- a/fl4health/servers/scaffold_server.py +++ b/fl4health/servers/scaffold_server.py @@ -67,7 +67,13 @@ def __init__( a "warm" estimate of the SCAFFOLD control variates. If false, variates are initialized to 0. Defaults to False. 
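For SCAFFOLD, the earlier server hunks drop the explicit FullParameterExchangerWithPacking, since the ScaffoldServerCheckpointAndStateModule signature in this patch no longer takes a parameter exchanger, and the guard added here rejects any other module type. A short sketch of the construction, with a stand-in model and placeholder paths, is:

import torch.nn as nn

from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer
from fl4health.checkpointing.server_module import ScaffoldServerCheckpointAndStateModule

model = nn.Linear(10, 2)  # stand-in for the Baseline model used in the research scripts
checkpointer = BestLossTorchModuleCheckpointer("artifacts/example_run", "server_best_model.pkl")

# No parameter exchanger argument: per this patch, the module handles the
# packing needed for SCAFFOLD control variates on its own.
checkpoint_and_state_module = ScaffoldServerCheckpointAndStateModule(
    model=model,
    model_checkpointers=checkpointer,
)
# The module is then handed to ScaffoldServer via checkpoint_and_state_module=...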
""" - super().__init__( + if checkpoint_and_state_module is not None: + assert isinstance( + checkpoint_and_state_module, + ScaffoldServerCheckpointAndStateModule, + ), "checkpoint_and_state_module must have type ScaffoldServerCheckpointAndStateModule" + FlServer.__init__( + self, client_manager=client_manager, fl_config=fl_config, strategy=strategy, @@ -258,13 +264,13 @@ def __init__( self, client_manager=client_manager, fl_config=fl_config, - strategy=strategy, noise_multiplier=noise_multiplier, + num_server_rounds=num_server_rounds, + batch_size=batch_size, + strategy=strategy, local_epochs=local_epochs, local_steps=local_steps, - batch_size=batch_size, delta=delta, - num_server_rounds=num_server_rounds, on_init_parameters_config_fn=on_init_parameters_config_fn, server_name=server_name, accept_failures=accept_failures, From 2a4c2a21a489c9cd39e632776e9b245df43a6d48 Mon Sep 17 00:00:00 2001 From: David Emerson <43939939+emersodb@users.noreply.github.com> Date: Tue, 3 Dec 2024 09:44:44 -0500 Subject: [PATCH 07/13] Fixing small issue introduced in merge --- fl4health/servers/nnunet_server.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/fl4health/servers/nnunet_server.py b/fl4health/servers/nnunet_server.py index 3e0c5ae5d..299e395db 100644 --- a/fl4health/servers/nnunet_server.py +++ b/fl4health/servers/nnunet_server.py @@ -221,7 +221,10 @@ def update_before_fit(self, num_rounds: int, timeout: Optional[float]) -> None: self.num_input_channels = narrow_dict_type(properties, "num_input_channels", int) self.enable_deep_supervision = narrow_dict_type(properties, "enable_deep_supervision", bool) - if self.per_round_checkpointer is None or not self.per_round_checkpointer.checkpoint_exists(): + if ( + self.checkpoint_and_state_module.state_checkpointer is None + or not self.checkpoint_and_state_module.state_checkpointer.checkpoint_exists(self.state_checkpoint_name) + ): # If we're starting training from scratch, set the nnunet_config property and initialize the server model self.nnunet_config = NnunetConfig(self.fl_config["nnunet_config"]) self.initialize_server_model() From 9732abd3defd23bc6e09be0621b94c2c4910ab9d Mon Sep 17 00:00:00 2001 From: David Emerson <43939939+emersodb@users.noreply.github.com> Date: Mon, 6 Jan 2025 14:26:31 -0500 Subject: [PATCH 08/13] Renaming the type aliases to be better type representations --- fl4health/checkpointing/client_module.py | 10 +++--- fl4health/checkpointing/server_module.py | 42 ++++++++++++------------ 2 files changed, 26 insertions(+), 26 deletions(-) diff --git a/fl4health/checkpointing/client_module.py b/fl4health/checkpointing/client_module.py index 76eda7436..0387240e5 100644 --- a/fl4health/checkpointing/client_module.py +++ b/fl4health/checkpointing/client_module.py @@ -8,7 +8,7 @@ from fl4health.checkpointing.checkpointer import PerRoundStateCheckpointer, TorchModuleCheckpointer -CheckpointModuleInput = Union[TorchModuleCheckpointer, Sequence[TorchModuleCheckpointer]] | None +ModelCheckpointers = Union[TorchModuleCheckpointer, Sequence[TorchModuleCheckpointer]] | None class CheckpointMode(Enum): @@ -19,8 +19,8 @@ class CheckpointMode(Enum): class ClientCheckpointAndStateModule: def __init__( self, - pre_aggregation: CheckpointModuleInput = None, - post_aggregation: CheckpointModuleInput = None, + pre_aggregation: ModelCheckpointers = None, + post_aggregation: ModelCheckpointers = None, state_checkpointer: PerRoundStateCheckpointer | None = None, ) -> None: """ @@ -37,10 +37,10 @@ def __init__( That's because the target 
model for these methods is never globally aggregated. That is, they remain local Args: - pre_aggregation (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of + pre_aggregation (ModelCheckpointers, optional): If defined, this checkpointer (or sequence of checkpointers) is used to checkpoint models based on their validation metrics/losses **BEFORE** server-side aggregation. Defaults to None. - post_aggregation (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence + post_aggregation (ModelCheckpointers, optional): If defined, this checkpointer (or sequence of checkpointers) is used to checkpoint models based on their validation metrics/losses **AFTER** server-side aggregation. Defaults to None. state_checkpointer (PerRoundStateCheckpointer | None, optional): If defined, this checkpointer is used to diff --git a/fl4health/checkpointing/server_module.py b/fl4health/checkpointing/server_module.py index fa91ff7b5..44aee0b96 100644 --- a/fl4health/checkpointing/server_module.py +++ b/fl4health/checkpointing/server_module.py @@ -19,7 +19,7 @@ SparseCooParameterPacker, ) -CheckpointModuleInput = Union[TorchModuleCheckpointer, Sequence[TorchModuleCheckpointer]] | None +ModelCheckpointers = Union[TorchModuleCheckpointer, Sequence[TorchModuleCheckpointer]] | None class BaseServerCheckpointAndStateModule: @@ -27,7 +27,7 @@ def __init__( self, model: nn.Module | None = None, parameter_exchanger: ExchangerType | None = None, - model_checkpointers: CheckpointModuleInput = None, + model_checkpointers: ModelCheckpointers = None, state_checkpointer: PerRoundStateCheckpointer | None = None, ) -> None: """ @@ -48,7 +48,7 @@ def __init__( server parameters into the right components of the provided model architecture. Note that this exchanger and the model must match the one used for training and exchange with the servers to ensure parameters go to the right places. Defaults to None. - model_checkpointers (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of + model_checkpointers (ModelCheckpointers, optional): If defined, this checkpointer (or sequence of checkpointers) is used to checkpoint models based on their defined scoring function. Defaults to None. state_checkpointer (PerRoundStateCheckpointer | None, optional): If defined, this checkpointer will be used to preserve FL training state to facilitate restarting training if interrupted. Generally, this @@ -206,7 +206,7 @@ def __init__( self, model: nn.Module | None = None, parameter_exchanger: FullParameterExchangerWithPacking | None = None, - model_checkpointers: CheckpointModuleInput = None, + model_checkpointers: ModelCheckpointers = None, state_checkpointer: PerRoundStateCheckpointer | None = None, ) -> None: """ @@ -227,7 +227,7 @@ def __init__( should handle any necessary unpacking of the parameters. Note that this exchanger and the model must match the one used for training and exchange with the servers to ensure parameters go to the right places. Defaults to None. - model_checkpointers (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of + model_checkpointers (ModelCheckpointers, optional): If defined, this checkpointer (or sequence of checkpointers) is used to checkpoint models based on their defined scoring function. Defaults to None. state_checkpointer (PerRoundStateCheckpointer | None, optional): If defined, this checkpointer will be used to preserve FL training state to facilitate restarting training if interrupted. 
Generally, this @@ -270,7 +270,7 @@ class ScaffoldServerCheckpointAndStateModule(PackingServerCheckpointAndAndStateM def __init__( self, model: nn.Module | None = None, - model_checkpointers: CheckpointModuleInput = None, + model_checkpointers: ModelCheckpointers = None, state_checkpointer: PerRoundStateCheckpointer | None = None, ) -> None: """ @@ -287,7 +287,7 @@ def __init__( to hold the server parameters and facilitate checkpointing with the help of the parameter exchanger. Recall that servers only have parameters rather than torch models. So we need to know where to route these parameters to allow for real models to be saved. Defaults to None. - model_checkpointers (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of + model_checkpointers (ModelCheckpointers, optional): If defined, this checkpointer (or sequence of checkpointers) is used to checkpoint models based on their defined scoring function. Defaults to None. state_checkpointer (PerRoundStateCheckpointer | None, optional): If defined, this checkpointer will be used to preserve FL training state to facilitate restarting training if interrupted. Generally, this @@ -305,7 +305,7 @@ class AdaptiveConstraintServerCheckpointAndStateModule(PackingServerCheckpointAn def __init__( self, model: nn.Module | None = None, - model_checkpointers: CheckpointModuleInput = None, + model_checkpointers: ModelCheckpointers = None, state_checkpointer: PerRoundStateCheckpointer | None = None, ) -> None: """ @@ -322,7 +322,7 @@ def __init__( to hold the server parameters and facilitate checkpointing with the help of the parameter exchanger. Recall that servers only have parameters rather than torch models. So we need to know where to route these parameters to allow for real models to be saved. Defaults to None. - model_checkpointers (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of + model_checkpointers (ModelCheckpointers, optional): If defined, this checkpointer (or sequence of checkpointers) is used to checkpoint models based on their defined scoring function. Defaults to None. state_checkpointer (PerRoundStateCheckpointer | None, optional): If defined, this checkpointer will be used to preserve FL training state to facilitate restarting training if interrupted. Generally, this @@ -339,7 +339,7 @@ class ClippingBitServerCheckpointAndStateModule(PackingServerCheckpointAndAndSta def __init__( self, model: nn.Module | None = None, - model_checkpointers: CheckpointModuleInput = None, + model_checkpointers: ModelCheckpointers = None, state_checkpointer: PerRoundStateCheckpointer | None = None, ) -> None: """ @@ -356,7 +356,7 @@ def __init__( to hold the server parameters and facilitate checkpointing with the help of the parameter exchanger. Recall that servers only have parameters rather than torch models. So we need to know where to route these parameters to allow for real models to be saved. Defaults to None. - model_checkpointers (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of + model_checkpointers (ModelCheckpointers, optional): If defined, this checkpointer (or sequence of checkpointers) is used to checkpoint models based on their defined scoring function. Defaults to None. state_checkpointer (PerRoundStateCheckpointer | None, optional): If defined, this checkpointer will be used to preserve FL training state to facilitate restarting training if interrupted. 
Generally, this @@ -373,7 +373,7 @@ class LayerNamesServerCheckpointAndStateModule(PackingServerCheckpointAndAndStat def __init__( self, model: nn.Module | None = None, - model_checkpointers: CheckpointModuleInput = None, + model_checkpointers: ModelCheckpointers = None, state_checkpointer: PerRoundStateCheckpointer | None = None, ) -> None: """ @@ -390,7 +390,7 @@ def __init__( to hold the server parameters and facilitate checkpointing with the help of the parameter exchanger. Recall that servers only have parameters rather than torch models. So we need to know where to route these parameters to allow for real models to be saved. Defaults to None. - model_checkpointers (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of + model_checkpointers (ModelCheckpointers, optional): If defined, this checkpointer (or sequence of checkpointers) is used to checkpoint models based on their defined scoring function. Defaults to None. state_checkpointer (PerRoundStateCheckpointer | None, optional): If defined, this checkpointer will be used to preserve FL training state to facilitate restarting training if interrupted. Generally, this @@ -407,7 +407,7 @@ class SparseCooServerCheckpointAndStateModule(PackingServerCheckpointAndAndState def __init__( self, model: nn.Module | None = None, - model_checkpointers: CheckpointModuleInput = None, + model_checkpointers: ModelCheckpointers = None, state_checkpointer: PerRoundStateCheckpointer | None = None, ) -> None: """ @@ -425,7 +425,7 @@ def __init__( to hold the server parameters and facilitate checkpointing with the help of the parameter exchanger. Recall that servers only have parameters rather than torch models. So we need to know where to route these parameters to allow for real models to be saved. Defaults to None. - model_checkpointers (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of + model_checkpointers (ModelCheckpointers, optional): If defined, this checkpointer (or sequence of checkpointers) is used to checkpoint models based on their defined scoring function. Defaults to None. state_checkpointer (PerRoundStateCheckpointer | None, optional): If defined, this checkpointer will be used to preserve FL training state to facilitate restarting training if interrupted. Generally, this @@ -443,7 +443,7 @@ def __init__( self, model: nn.Module | None = None, parameter_exchanger: ExchangerType | None = None, - model_checkpointers: CheckpointModuleInput = None, + model_checkpointers: ModelCheckpointers = None, state_checkpointer: PerRoundStateCheckpointer | None = None, ) -> None: """ @@ -465,7 +465,7 @@ def __init__( server parameters into the right components of the provided model architecture. Note that this exchanger and the model must match the one used for training and exchange with the servers to ensure parameters go to the right places. Defaults to None. - model_checkpointers (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of + model_checkpointers (ModelCheckpointers, optional): If defined, this checkpointer (or sequence of checkpointers) is used to checkpoint models based on their defined scoring function. Defaults to None. state_checkpointer (PerRoundStateCheckpointer | None, optional): If defined, this checkpointer will be used to preserve FL training state to facilitate restarting training if interrupted. 
Generally, this @@ -491,7 +491,7 @@ def __init__( self, model: nn.Module | None = None, parameter_exchanger: ExchangerType | None = None, - model_checkpointers: CheckpointModuleInput = None, + model_checkpointers: ModelCheckpointers = None, state_checkpointer: PerRoundStateCheckpointer | None = None, ) -> None: """ @@ -517,7 +517,7 @@ def __init__( server parameters into the right components of the provided model architecture. Note that this exchanger and the model must match the one used for training and exchange with the servers to ensure parameters go to the right places. Defaults to None. - model_checkpointers (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of + model_checkpointers (ModelCheckpointers, optional): If defined, this checkpointer (or sequence of checkpointers) is used to checkpoint models based on their defined scoring function. Defaults to None. state_checkpointer (PerRoundStateCheckpointer | None, optional): If defined, this checkpointer will be used to preserve FL training state to facilitate restarting training if interrupted. Generally, this @@ -543,7 +543,7 @@ class DpScaffoldServerCheckpointAndStateModule(ScaffoldServerCheckpointAndStateM def __init__( self, model: nn.Module | None = None, - model_checkpointers: CheckpointModuleInput = None, + model_checkpointers: ModelCheckpointers = None, state_checkpointer: PerRoundStateCheckpointer | None = None, ) -> None: """ @@ -560,7 +560,7 @@ def __init__( to hold the server parameters and facilitate checkpointing with the help of the parameter exchanger. Recall that servers only have parameters rather than torch models. So we need to know where to route these parameters to allow for real models to be saved. Defaults to None. - model_checkpointers (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of + model_checkpointers (ModelCheckpointers, optional): If defined, this checkpointer (or sequence of checkpointers) is used to checkpoint models based on their defined scoring function. Defaults to None. state_checkpointer (PerRoundStateCheckpointer | None, optional): If defined, this checkpointer will be used to preserve FL training state to facilitate restarting training if interrupted. Generally, this From 5cf2df686fb49a6172ca3d8e1a769f0fc3d2b160 Mon Sep 17 00:00:00 2001 From: David Emerson <43939939+emersodb@users.noreply.github.com> Date: Mon, 6 Jan 2025 17:14:06 -0500 Subject: [PATCH 09/13] Avoiding some code duplication --- fl4health/checkpointing/server_module.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/fl4health/checkpointing/server_module.py b/fl4health/checkpointing/server_module.py index 44aee0b96..7c2447b08 100644 --- a/fl4health/checkpointing/server_module.py +++ b/fl4health/checkpointing/server_module.py @@ -523,20 +523,15 @@ def __init__( used to preserve FL training state to facilitate restarting training if interrupted. Generally, this checkpointer will save much more than just the model being trained. Defaults to None. """ - self.model = model - self.parameter_exchanger = parameter_exchanger - self.model_checkpointers = ( - [model_checkpointers] if isinstance(model_checkpointers, TorchModuleCheckpointer) else model_checkpointers + super().__init__(model, parameter_exchanger, model_checkpointers, state_checkpointer) + + def _validate_model_checkpointer_components(self) -> None: + # NOTE: We only check if the parameter exchanger is present. Model may be set later. 
+ assert self.parameter_exchanger is not None, ( + "Checkpointer(s) is (are) defined but no parameter_exchanger is defined to hydrate. The functionality of " + "this class can be overridden in a child class if checkpointing without a parameter exchanger is " + "possible and desired" ) - self.state_checkpointer = state_checkpointer - if self.model_checkpointers is not None and len(self.model_checkpointers): - # NOTE: We only check if the parameter exchanger is present. Model may be set later. - assert self.parameter_exchanger is not None, ( - "Checkpointer(s) is (are) defined but no parameter_exchanger is defined to hydrate. The functionality " - "of this class can be overridden in a child class if checkpointing without a parameter exchanger is " - "possible and desired" - ) - self._check_if_shared_checkpoint_names() class DpScaffoldServerCheckpointAndStateModule(ScaffoldServerCheckpointAndStateModule): From a25aae4cc242161411e9e0a051b42a0f6a931222 Mon Sep 17 00:00:00 2001 From: David Emerson <43939939+emersodb@users.noreply.github.com> Date: Tue, 7 Jan 2025 17:12:37 -0500 Subject: [PATCH 10/13] A few small updates from John J.'s PR suggestions. --- fl4health/checkpointing/checkpointer.py | 5 +++-- fl4health/checkpointing/server_module.py | 8 ++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/fl4health/checkpointing/checkpointer.py b/fl4health/checkpointing/checkpointer.py index 2379ce4c1..1d29ad274 100644 --- a/fl4health/checkpointing/checkpointer.py +++ b/fl4health/checkpointing/checkpointer.py @@ -43,14 +43,15 @@ def maybe_checkpoint(self, model: nn.Module, loss: float, metrics: Dict[str, Sca def load_checkpoint(self, path_to_checkpoint: str | None = None) -> nn.Module: """ Checkpointer with the option to either specify a checkpoint path or fall back on the internal path of the - checkpointer + checkpointer. The flexibility to specify a load path is useful, for example, if you are not overwriting + checkpoints when saving and need to load a specific past checkpoint for whatever reason. Args: path_to_checkpoint (str | None, optional): If provided, the checkpoint will be loaded from this path. If not specified, the checkpointer will load from self.checkpoint_path. Defaults to None. Returns: - nn.Module: _description_ + nn.Module: Returns a torch module loaded from the proper checkpoint path. 
""" if path_to_checkpoint is None: return torch.load(self.checkpoint_path) diff --git a/fl4health/checkpointing/server_module.py b/fl4health/checkpointing/server_module.py index 7c2447b08..f4a94f5c6 100644 --- a/fl4health/checkpointing/server_module.py +++ b/fl4health/checkpointing/server_module.py @@ -166,11 +166,11 @@ def save_state( """ if self.state_checkpointer is not None: self._hydrate_model_for_checkpointing(server_parameters) - if "model" not in other_state: - other_state["model"] = self.model - else: + if "model" in other_state: raise ValueError("Key 'model' already exists in the other_state dictionary.") - self.state_checkpointer.save_checkpoint(state_checkpoint_name, checkpoint_dict=other_state) + + checkpoint_dict = other_state | {"model": self.model} + self.state_checkpointer.save_checkpoint(state_checkpoint_name, checkpoint_dict=checkpoint_dict) else: raise ValueError("Attempting to save state but no state checkpointer is specified") From e6ab253fae573129117024e1c57fcd4708ced5b4 Mon Sep 17 00:00:00 2001 From: David Emerson <43939939+emersodb@users.noreply.github.com> Date: Wed, 8 Jan 2025 10:31:13 -0500 Subject: [PATCH 11/13] Grammatical correction to documentation --- fl4health/servers/adaptive_constraint_servers/ditto_server.py | 2 +- .../servers/adaptive_constraint_servers/fedprox_server.py | 2 +- fl4health/servers/adaptive_constraint_servers/mrmtl_server.py | 2 +- fl4health/servers/base_server.py | 2 +- fl4health/servers/client_level_dp_fed_avg_server.py | 2 +- fl4health/servers/fedpm_server.py | 2 +- fl4health/servers/instance_level_dp_server.py | 2 +- fl4health/servers/nnunet_server.py | 2 +- fl4health/servers/scaffold_server.py | 4 ++-- fl4health/servers/tabular_feature_alignment_server.py | 2 +- 10 files changed, 11 insertions(+), 11 deletions(-) diff --git a/fl4health/servers/adaptive_constraint_servers/ditto_server.py b/fl4health/servers/adaptive_constraint_servers/ditto_server.py index 31d61f23b..27a03bf20 100644 --- a/fl4health/servers/adaptive_constraint_servers/ditto_server.py +++ b/fl4health/servers/adaptive_constraint_servers/ditto_server.py @@ -39,7 +39,7 @@ def __init__( send data to before and after each round. Defaults to None. checkpoint_and_state_module (AdaptiveConstraintServerCheckpointAndStateModule | None, optional): This module is used to handle both model checkpointing and state checkpointing. The former is aimed at - saving model artifacts to be used or evaluated after training. The later is used to preserve training + saving model artifacts to be used or evaluated after training. The latter is used to preserve training state (including models) such that if FL training is interrupted, the process may be restarted. If no module is provided, no checkpointing or state preservation will happen. Defaults to None. NOTE: For Ditto, the model shared with the server is the GLOBAL MODEL, which isn't the target of FL diff --git a/fl4health/servers/adaptive_constraint_servers/fedprox_server.py b/fl4health/servers/adaptive_constraint_servers/fedprox_server.py index 8b6a5a88c..9acfc4b55 100644 --- a/fl4health/servers/adaptive_constraint_servers/fedprox_server.py +++ b/fl4health/servers/adaptive_constraint_servers/fedprox_server.py @@ -40,7 +40,7 @@ def __init__( should send data to before and after each round. Defaults to None. checkpoint_and_state_module (AdaptiveConstraintServerCheckpointAndStateModule | None, optional): This module is used to handle both model checkpointing and state checkpointing. 
The former is aimed at - saving model artifacts to be used or evaluated after training. The later is used to preserve training + saving model artifacts to be used or evaluated after training. The latter is used to preserve training state (including models) such that if FL training is interrupted, the process may be restarted. If no module is provided, no checkpointing or state preservation will happen. Defaults to None. on_init_parameters_config_fn (Callable[[int], Dict[str, Scalar]] | None, optional): Function used to diff --git a/fl4health/servers/adaptive_constraint_servers/mrmtl_server.py b/fl4health/servers/adaptive_constraint_servers/mrmtl_server.py index 3316df94d..2f4aba07b 100644 --- a/fl4health/servers/adaptive_constraint_servers/mrmtl_server.py +++ b/fl4health/servers/adaptive_constraint_servers/mrmtl_server.py @@ -39,7 +39,7 @@ def __init__( send data to before and after each round. Defaults to None. checkpoint_and_state_module (AdaptiveConstraintServerCheckpointAndStateModule | None, optional): This module is used to handle both model checkpointing and state checkpointing. The former is aimed at - saving model artifacts to be used or evaluated after training. The later is used to preserve training + saving model artifacts to be used or evaluated after training. The latter is used to preserve training state (including models) such that if FL training is interrupted, the process may be restarted. If no module is provided, no checkpointing or state preservation will happen. Defaults to None. NOTE: For MR-MTL, the server model is an aggregation of the personal models, which isn't the target of diff --git a/fl4health/servers/base_server.py b/fl4health/servers/base_server.py index ad6d7cd9e..53b929648 100644 --- a/fl4health/servers/base_server.py +++ b/fl4health/servers/base_server.py @@ -53,7 +53,7 @@ def __init__( should send data to before and after each round. Defaults to None. checkpoint_and_state_module (BaseServerCheckpointAndStateModule | None, optional): This module is used to handle both model checkpointing and state checkpointing. The former is aimed at saving model - artifacts to be used or evaluated after training. The later is used to preserve training state + artifacts to be used or evaluated after training. The latter is used to preserve training state (including models) such that if FL training is interrupted, the process may be restarted. If no module is provided, no checkpointing or state preservation will happen. Defaults to None. on_init_parameters_config_fn (Callable[[int], Dict[str, Scalar]] | None, optional): Function used to diff --git a/fl4health/servers/client_level_dp_fed_avg_server.py b/fl4health/servers/client_level_dp_fed_avg_server.py index 635ce23f5..c080fbbd7 100644 --- a/fl4health/servers/client_level_dp_fed_avg_server.py +++ b/fl4health/servers/client_level_dp_fed_avg_server.py @@ -54,7 +54,7 @@ def __init__( send data to before and after each round. checkpoint_and_state_module (BaseServerCheckpointAndStateModule | None, optional): This module is used to handle both model checkpointing and state checkpointing. The former is aimed at saving model - artifacts to be used or evaluated after training. The later is used to preserve training state + artifacts to be used or evaluated after training. The latter is used to preserve training state (including models) such that if FL training is interrupted, the process may be restarted. If no module is provided, no checkpointing or state preservation will happen. Defaults to None. 
delta (Optional[float], optional): The delta value for epsilon-delta DP accounting. If None it defaults to diff --git a/fl4health/servers/fedpm_server.py b/fl4health/servers/fedpm_server.py index 2abaa6ad3..23c5a5d4a 100644 --- a/fl4health/servers/fedpm_server.py +++ b/fl4health/servers/fedpm_server.py @@ -39,7 +39,7 @@ def __init__( should send data to before and after each round. checkpoint_and_state_module (LayerNamesServerCheckpointAndStateModule | None, optional): This module is used to handle both model checkpointing and state checkpointing. The former is aimed at saving model - artifacts to be used or evaluated after training. The later is used to preserve training state + artifacts to be used or evaluated after training. The latter is used to preserve training state (including models) such that if FL training is interrupted, the process may be restarted. If no module is provided, no checkpointing or state preservation will happen. Defaults to None. on_init_parameters_config_fn (Callable[[int], Dict[str, Scalar]] | None, optional): Function used to diff --git a/fl4health/servers/instance_level_dp_server.py b/fl4health/servers/instance_level_dp_server.py index d75755e13..f01c446ba 100644 --- a/fl4health/servers/instance_level_dp_server.py +++ b/fl4health/servers/instance_level_dp_server.py @@ -61,7 +61,7 @@ def __init__( Defaults to None. checkpoint_and_state_module (OpacusServerCheckpointAndStateModule | None, optional): This module is used to handle both model checkpointing and state checkpointing. The former is aimed at saving model - artifacts to be used or evaluated after training. The later is used to preserve training state + artifacts to be used or evaluated after training. The latter is used to preserve training state (including models) such that if FL training is interrupted, the process may be restarted. If no module is provided, no checkpointing or state preservation will happen. Defaults to None. reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health diff --git a/fl4health/servers/nnunet_server.py b/fl4health/servers/nnunet_server.py index 299e395db..018a2d691 100644 --- a/fl4health/servers/nnunet_server.py +++ b/fl4health/servers/nnunet_server.py @@ -90,7 +90,7 @@ def __init__( should send data to. Defaults to None. checkpoint_and_state_module (NnUnetServerCheckpointAndStateModule | None, optional): This module is used to handle both model checkpointing and state checkpointing. The former is aimed at saving model - artifacts to be used or evaluated after training. The later is used to preserve training state + artifacts to be used or evaluated after training. The latter is used to preserve training state (including models) such that if FL training is interrupted, the process may be restarted. If no module is provided, no checkpointing or state preservation will happen. Defaults to None. NOTE: For NnUnet, this module is allowed to have all components defined other than the model, as it diff --git a/fl4health/servers/scaffold_server.py b/fl4health/servers/scaffold_server.py index 5ff50c7b6..111924fd3 100644 --- a/fl4health/servers/scaffold_server.py +++ b/fl4health/servers/scaffold_server.py @@ -50,7 +50,7 @@ def __init__( should send data to before and after each round. Defaults to None. checkpoint_and_state_module (ScaffoldServerCheckpointAndStateModule | None, optional): This module is used to handle both model checkpointing and state checkpointing. The former is aimed at saving model - artifacts to be used or evaluated after training. 
The later is used to preserve training state + artifacts to be used or evaluated after training. The latter is used to preserve training state (including models) such that if FL training is interrupted, the process may be restarted. If no module is provided, no checkpointing or state preservation will happen. Defaults to None. on_init_parameters_config_fn (Callable[[int], Dict[str, Scalar]] | None, optional): Function used to @@ -225,7 +225,7 @@ def __init__( type. checkpoint_and_state_module (DpScaffoldServerCheckpointAndStateModule | None, optional): This module is used to handle both model checkpointing and state checkpointing. The former is aimed at saving model - artifacts to be used or evaluated after training. The later is used to preserve training state + artifacts to be used or evaluated after training. The latter is used to preserve training state (including models) such that if FL training is interrupted, the process may be restarted. If no module is provided, no checkpointing or state preservation will happen. Defaults to None. warm_start (bool, optional): Whether or not to initialize control variates of each client as local diff --git a/fl4health/servers/tabular_feature_alignment_server.py b/fl4health/servers/tabular_feature_alignment_server.py index 8801bc49d..3a689d84e 100644 --- a/fl4health/servers/tabular_feature_alignment_server.py +++ b/fl4health/servers/tabular_feature_alignment_server.py @@ -63,7 +63,7 @@ def __init__( should send data to before and after each round. Defaults to None. checkpoint_and_state_module (BaseServerCheckpointAndStateModule | None, optional): This module is used to handle both model checkpointing and state checkpointing. The former is aimed at saving model - artifacts to be used or evaluated after training. The later is used to preserve training state + artifacts to be used or evaluated after training. The latter is used to preserve training state (including models) such that if FL training is interrupted, the process may be restarted. If no module is provided, no checkpointing or state preservation will happen. Defaults to None. on_init_parameters_config_fn (Callable[[int], Dict[str, Scalar]] | None, optional): Function used to From 7066e70fe23a2bcfb3ab7a6a86eb4c2146be58e6 Mon Sep 17 00:00:00 2001 From: David Emerson <43939939+emersodb@users.noreply.github.com> Date: Wed, 8 Jan 2025 10:39:35 -0500 Subject: [PATCH 12/13] Changing ubuntu to latest to see if its fixed --- .github/workflows/smoke_tests.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/smoke_tests.yaml b/.github/workflows/smoke_tests.yaml index 1da21d605..2d0c48828 100644 --- a/.github/workflows/smoke_tests.yaml +++ b/.github/workflows/smoke_tests.yaml @@ -10,7 +10,7 @@ on: jobs: test: - runs-on: ubuntu-22.04 + runs-on: ubuntu-latest steps: - name: Checkout code uses: actions/checkout@v4 From 17ef40295c6add5c90ec5eb489500cc93b905756 Mon Sep 17 00:00:00 2001 From: David Emerson <43939939+emersodb@users.noreply.github.com> Date: Wed, 8 Jan 2025 11:02:03 -0500 Subject: [PATCH 13/13] Still a problem --- .github/workflows/smoke_tests.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/smoke_tests.yaml b/.github/workflows/smoke_tests.yaml index 2d0c48828..1da21d605 100644 --- a/.github/workflows/smoke_tests.yaml +++ b/.github/workflows/smoke_tests.yaml @@ -10,7 +10,7 @@ on: jobs: test: - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 steps: - name: Checkout code uses: actions/checkout@v4