diff --git a/research/rxrx1/data/data_utils.py b/research/rxrx1/data/data_utils.py
index d5f7a7328..73f995052 100644
--- a/research/rxrx1/data/data_utils.py
+++ b/research/rxrx1/data/data_utils.py
@@ -1,6 +1,5 @@
 from collections import defaultdict
 from pathlib import Path
-from typing import Optional
 
 import numpy as np
 import pandas as pd
@@ -32,7 +31,7 @@ def label_frequency(dataset: Rxrx1Dataset | Subset) -> None:
 
 
 def create_splits(
-    dataset: Rxrx1Dataset, seed: Optional[int] = None, train_fraction: float = 0.8
+    dataset: Rxrx1Dataset, seed: int | None = None, train_fraction: float = 0.8
 ) -> tuple[Subset, Subset]:
     """
     Splits the dataset into training and validation sets.
@@ -70,7 +69,7 @@ def create_splits(
 
 
 def load_rxrx1_data(
-    data_path: Path, client_num: int, batch_size: int, seed: Optional[int] = None, train_val_split: float = 0.8
+    data_path: Path, client_num: int, batch_size: int, seed: int | None = None, train_val_split: float = 0.8
 ) -> tuple[DataLoader, DataLoader, dict[str, int]]:
 
     # Read the CSV file
diff --git a/research/rxrx1/data/dataset.py b/research/rxrx1/data/dataset.py
index 8373b9314..34cf48d11 100644
--- a/research/rxrx1/data/dataset.py
+++ b/research/rxrx1/data/dataset.py
@@ -1,6 +1,6 @@
 import os
+from collections.abc import Callable
 from pathlib import Path
-from typing import Callable, Optional
 
 import pandas as pd
 import torch
@@ -10,7 +10,7 @@
 
 
 class Rxrx1Dataset(Dataset):
-    def __init__(self, metadata: pd.DataFrame, root: Path, dataset_type: str, transform: Optional[Callable] = None):
+    def __init__(self, metadata: pd.DataFrame, root: Path, dataset_type: str, transform: Callable | None = None):
         """
         Args:
             metadata (DataFrame): A DataFrame containing image metadata.
diff --git a/research/rxrx1/ditto/client.py b/research/rxrx1/ditto/client.py
index b3cc4e623..d027cba56 100644
--- a/research/rxrx1/ditto/client.py
+++ b/research/rxrx1/ditto/client.py
@@ -1,8 +1,8 @@
 import argparse
 import os
+from collections.abc import Sequence
 from logging import INFO
 from pathlib import Path
-from typing import Dict, Optional, Sequence, Tuple
 
 import flwr as fl
 import torch
@@ -14,9 +14,10 @@
 from torch.utils.data import DataLoader
 from torchvision import models
 
-from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer
-from fl4health.checkpointing.client_module import ClientCheckpointModule
+from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer
+from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule
 from fl4health.clients.ditto_client import DittoClient
+from fl4health.reporting.base_reporter import BaseReporter
 from fl4health.utils.config import narrow_dict_type
 from fl4health.utils.losses import LossMeterType
 from fl4health.utils.metrics import Accuracy, Metric
@@ -33,14 +34,20 @@ def __init__(
         client_number: int,
         learning_rate: float,
         loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
-        checkpointer: Optional[ClientCheckpointModule] = None,
+        checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None,
+        reporters: Sequence[BaseReporter] | None = None,
+        progress_bar: bool = False,
+        client_name: str | None = None,
     ) -> None:
         super().__init__(
             data_path=data_path,
             metrics=metrics,
             device=device,
             loss_meter_type=loss_meter_type,
-            checkpointer=checkpointer,
+            checkpoint_and_state_module=checkpoint_and_state_module,
+            reporters=reporters,
+            progress_bar=progress_bar,
+            client_name=client_name,
         )
         self.client_number = client_number
         self.learning_rate: float = learning_rate
@@ -53,7 +60,7 @@ def setup_client(self, config: Config) -> None:
         assert 0 <= self.client_number < num_clients
         super().setup_client(config)
 
-    def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
+    def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
         batch_size = narrow_dict_type(config, "batch_size", int)
         train_loader, val_loader, _ = load_rxrx1_data(
             data_path=self.data_path, client_num=self.client_number, batch_size=batch_size, seed=self.client_number
@@ -61,7 +68,7 @@ def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
 
         return train_loader, val_loader
 
-    def get_test_data_loader(self, config: Config) -> Optional[DataLoader]:
+    def get_test_data_loader(self, config: Config) -> DataLoader | None:
         batch_size = narrow_dict_type(config, "batch_size", int)
         test_loader, _ = load_rxrx1_test_data(
             data_path=self.data_path, client_num=self.client_number, batch_size=batch_size
@@ -72,7 +79,7 @@ def get_test_data_loader(self, config: Config) -> Optional[DataLoader]:
     def get_criterion(self, config: Config) -> _Loss:
         return torch.nn.CrossEntropyLoss()
 
-    def get_optimizer(self, config: Config) -> Dict[str, Optimizer]:
+    def get_optimizer(self, config: Config) -> dict[str, Optimizer]:
         # Following the implementation in pFL-Bench : A Comprehensive Benchmark for Personalized
         # Federated Learning (https://arxiv.org/pdf/2405.17724) for cifar10 dataset we use SGD optimizer
         global_optimizer = torch.optim.SGD(self.global_model.parameters(), lr=self.learning_rate, momentum=0.9)
@@ -145,14 +152,14 @@ def get_model(self, config: Config) -> nn.Module:
     pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl"
     post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl"
     post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl"
-    checkpointer = ClientCheckpointModule(
+    checkpoint_and_state_module = ClientCheckpointAndStateModule(
         pre_aggregation=[
-            BestLossTorchCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name),
-            LatestTorchCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name),
+            BestLossTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name),
+            LatestTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name),
         ],
         post_aggregation=[
-            BestLossTorchCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name),
-            LatestTorchCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name),
+            BestLossTorchModuleCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name),
+            LatestTorchModuleCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name),
         ],
     )
 
@@ -163,7 +170,7 @@ def get_model(self, config: Config) -> nn.Module:
         device=device,
         client_number=args.client_number,
         learning_rate=args.learning_rate,
-        checkpointer=checkpointer,
+        checkpoint_and_state_module=checkpoint_and_state_module,
     )
 
     fl.client.start_client(server_address=args.server_address, client=client.to_client())
diff --git a/research/rxrx1/ditto/server.py b/research/rxrx1/ditto/server.py
index ed6cfb37c..96c3b044e 100644
--- a/research/rxrx1/ditto/server.py
+++ b/research/rxrx1/ditto/server.py
@@ -1,7 +1,7 @@
 import argparse
 from functools import partial
 from logging import INFO
-from typing import Any, Dict
+from typing import Any
 
 import flwr as fl
 from flwr.common.logger import log
@@ -33,7 +33,7 @@
     }
 
 
-def main(config: Dict[str, Any], server_address: str, lam: float) -> None:
+def main(config: dict[str, Any], server_address: str, lam: float) -> None:
     # This function will be used to produce a config that is sent to each client to initialize their own environment
     fit_config_fn = partial(
         fit_config,
diff --git a/research/rxrx1/ditto_deep_mmd/client.py b/research/rxrx1/ditto_deep_mmd/client.py
index ed05ca623..fc6f0ee41 100644
--- a/research/rxrx1/ditto_deep_mmd/client.py
+++ b/research/rxrx1/ditto_deep_mmd/client.py
@@ -1,9 +1,9 @@
 import argparse
 import os
 from collections import OrderedDict
+from collections.abc import Sequence
 from logging import INFO
 from pathlib import Path
-from typing import Dict, Optional, Sequence, Tuple
 
 import flwr as fl
 import torch
@@ -15,9 +15,10 @@
 from torch.utils.data import DataLoader
 from torchvision import models
 
-from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer
-from fl4health.checkpointing.client_module import ClientCheckpointModule
+from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer
+from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule
 from fl4health.clients.deep_mmd_clients.ditto_deep_mmd_client import DittoDeepMmdClient
+from fl4health.reporting.base_reporter import BaseReporter
 from fl4health.utils.config import narrow_dict_type
 from fl4health.utils.losses import LossMeterType
 from fl4health.utils.metrics import Accuracy, Metric
@@ -41,9 +42,12 @@ def __init__(
         client_number: int,
         learning_rate: float,
         loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
+        checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None,
+        reporters: Sequence[BaseReporter] | None = None,
+        progress_bar: bool = False,
+        client_name: str | None = None,
         deep_mmd_loss_weight: float = 10,
         deep_mmd_loss_depth: int = 1,
-        checkpointer: Optional[ClientCheckpointModule] = None,
     ) -> None:
         feature_extraction_layers_with_size = OrderedDict(list(BASELINE_LAYERS.items())[-1 * deep_mmd_loss_depth :])
         super().__init__(
@@ -51,7 +55,10 @@
             metrics=metrics,
             device=device,
             loss_meter_type=loss_meter_type,
-            checkpointer=checkpointer,
+            checkpoint_and_state_module=checkpoint_and_state_module,
+            reporters=reporters,
+            progress_bar=progress_bar,
+            client_name=client_name,
             deep_mmd_loss_weight=deep_mmd_loss_weight,
             feature_extraction_layers_with_size=feature_extraction_layers_with_size,
         )
@@ -66,7 +73,7 @@ def setup_client(self, config: Config) -> None:
         assert 0 <= self.client_number < num_clients
         super().setup_client(config)
 
-    def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
+    def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
         batch_size = narrow_dict_type(config, "batch_size", int)
         train_loader, val_loader, _ = load_rxrx1_data(
             data_path=self.data_path, client_num=self.client_number, batch_size=batch_size, seed=self.client_number
@@ -74,7 +81,7 @@ def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
 
         return train_loader, val_loader
 
-    def get_test_data_loader(self, config: Config) -> Optional[DataLoader]:
+    def get_test_data_loader(self, config: Config) -> DataLoader | None:
         batch_size = narrow_dict_type(config, "batch_size", int)
         test_loader, _ = load_rxrx1_test_data(
             data_path=self.data_path, client_num=self.client_number, batch_size=batch_size
@@ -85,7 +92,7 @@ def get_test_data_loader(self, config: Config) -> Optional[DataLoader]:
     def get_criterion(self, config: Config) -> _Loss:
         return torch.nn.CrossEntropyLoss()
 
-    def get_optimizer(self, config: Config) -> Dict[str, Optimizer]:
+    def get_optimizer(self, config: Config) -> dict[str, Optimizer]:
         # Following the implementation in pFL-Bench : A Comprehensive Benchmark for Personalized
         # Federated Learning (https://arxiv.org/pdf/2405.17724) for cifar10 dataset we use SGD optimizer
         global_optimizer = torch.optim.SGD(self.global_model.parameters(), lr=self.learning_rate, momentum=0.9)
@@ -175,14 +182,14 @@ def get_model(self, config: Config) -> nn.Module:
     pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl"
     post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl"
     post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl"
-    checkpointer = ClientCheckpointModule(
+    checkpoint_and_state_module = ClientCheckpointAndStateModule(
         pre_aggregation=[
-            BestLossTorchCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name),
-            LatestTorchCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name),
+            BestLossTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name),
+            LatestTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name),
         ],
         post_aggregation=[
-            BestLossTorchCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name),
-            LatestTorchCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name),
+            BestLossTorchModuleCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name),
+            LatestTorchModuleCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name),
         ],
     )
 
@@ -193,7 +200,7 @@ def get_model(self, config: Config) -> nn.Module:
         device=device,
         client_number=args.client_number,
         learning_rate=args.learning_rate,
-        checkpointer=checkpointer,
+        checkpoint_and_state_module=checkpoint_and_state_module,
         deep_mmd_loss_depth=args.deep_mmd_loss_depth,
         deep_mmd_loss_weight=args.mu,
     )
diff --git a/research/rxrx1/ditto_deep_mmd/server.py b/research/rxrx1/ditto_deep_mmd/server.py
index ed6cfb37c..96c3b044e 100644
--- a/research/rxrx1/ditto_deep_mmd/server.py
+++ b/research/rxrx1/ditto_deep_mmd/server.py
@@ -1,7 +1,7 @@
 import argparse
 from functools import partial
 from logging import INFO
-from typing import Any, Dict
+from typing import Any
 
 import flwr as fl
 from flwr.common.logger import log
@@ -33,7 +33,7 @@
     }
 
 
-def main(config: Dict[str, Any], server_address: str, lam: float) -> None:
+def main(config: dict[str, Any], server_address: str, lam: float) -> None:
     # This function will be used to produce a config that is sent to each client to initialize their own environment
     fit_config_fn = partial(
         fit_config,
diff --git a/research/rxrx1/ditto_mkmmd/client.py b/research/rxrx1/ditto_mkmmd/client.py
index 2cb78924e..dd67f63af 100644
--- a/research/rxrx1/ditto_mkmmd/client.py
+++ b/research/rxrx1/ditto_mkmmd/client.py
@@ -1,8 +1,8 @@
 import argparse
 import os
+from collections.abc import Sequence
 from logging import INFO
 from pathlib import Path
-from typing import Dict, Optional, Sequence, Tuple
 
 import flwr as fl
 import torch
@@ -14,9 +14,10 @@
 from torch.utils.data import DataLoader
 from torchvision import models
 
-from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer
-from fl4health.checkpointing.client_module import ClientCheckpointModule
+from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer
+from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule
 from fl4health.clients.mkmmd_clients.ditto_mkmmd_client import DittoMkMmdClient
+from fl4health.reporting.base_reporter import BaseReporter
 from fl4health.utils.config import narrow_dict_type
 from fl4health.utils.losses import LossMeterType
 from fl4health.utils.metrics import Accuracy, Metric
@@ -39,14 +40,20 @@ def __init__(
         feature_l2_norm_weight: float = 1,
         mkmmd_loss_depth: int = 1,
         beta_global_update_interval: int = 20,
-        checkpointer: Optional[ClientCheckpointModule] = None,
+        checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None,
+        reporters: Sequence[BaseReporter] | None = None,
+        progress_bar: bool = False,
+        client_name: str | None = None,
     ) -> None:
         super().__init__(
             data_path=data_path,
             metrics=metrics,
             device=device,
             loss_meter_type=loss_meter_type,
-            checkpointer=checkpointer,
+            checkpoint_and_state_module=checkpoint_and_state_module,
+            reporters=reporters,
+            progress_bar=progress_bar,
+            client_name=client_name,
             mkmmd_loss_weight=mkmmd_loss_weight,
             feature_extraction_layers=BASELINE_LAYERS[-1 * mkmmd_loss_depth :],
             feature_l2_norm_weight=feature_l2_norm_weight,
@@ -66,7 +73,7 @@ def setup_client(self, config: Config) -> None:
         assert 0 <= self.client_number < num_clients
         super().setup_client(config)
 
-    def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
+    def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
         batch_size = narrow_dict_type(config, "batch_size", int)
         train_loader, val_loader, _ = load_rxrx1_data(
             data_path=self.data_path, client_num=self.client_number, batch_size=batch_size, seed=self.client_number
@@ -74,7 +81,7 @@ def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
 
         return train_loader, val_loader
 
-    def get_test_data_loader(self, config: Config) -> Optional[DataLoader]:
+    def get_test_data_loader(self, config: Config) -> DataLoader | None:
         batch_size = narrow_dict_type(config, "batch_size", int)
         test_loader, _ = load_rxrx1_test_data(
             data_path=self.data_path, client_num=self.client_number, batch_size=batch_size
@@ -85,7 +92,7 @@ def get_test_data_loader(self, config: Config) -> Optional[DataLoader]:
     def get_criterion(self, config: Config) -> _Loss:
         return torch.nn.CrossEntropyLoss()
 
-    def get_optimizer(self, config: Config) -> Dict[str, Optimizer]:
+    def get_optimizer(self, config: Config) -> dict[str, Optimizer]:
         # Following the implementation in pFL-Bench : A Comprehensive Benchmark for Personalized
         # Federated Learning (https://arxiv.org/pdf/2405.17724) for cifar10 dataset we use SGD optimizer
         global_optimizer = torch.optim.SGD(self.global_model.parameters(), lr=self.learning_rate, momentum=0.9)
@@ -192,14 +199,14 @@ def get_model(self, config: Config) -> nn.Module:
     pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl"
     post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl"
     post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl"
-    checkpointer = ClientCheckpointModule(
+    checkpoint_and_state_module = ClientCheckpointAndStateModule(
         pre_aggregation=[
-            BestLossTorchCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name),
-            LatestTorchCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name),
+            BestLossTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name),
+            LatestTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name),
         ],
         post_aggregation=[
-            BestLossTorchCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name),
-            LatestTorchCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name),
+            BestLossTorchModuleCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name),
+            LatestTorchModuleCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name),
         ],
     )
 
@@ -210,7 +217,7 @@ def get_model(self, config: Config) -> nn.Module:
         device=device,
         client_number=args.client_number,
         learning_rate=args.learning_rate,
-        checkpointer=checkpointer,
+        checkpoint_and_state_module=checkpoint_and_state_module,
         feature_l2_norm_weight=args.l2,
         mkmmd_loss_depth=args.mkmmd_loss_depth,
         mkmmd_loss_weight=args.mu,
diff --git a/research/rxrx1/ditto_mkmmd/server.py b/research/rxrx1/ditto_mkmmd/server.py
index ed6cfb37c..96c3b044e 100644
--- a/research/rxrx1/ditto_mkmmd/server.py
+++ b/research/rxrx1/ditto_mkmmd/server.py
@@ -1,7 +1,7 @@
 import argparse
 from functools import partial
 from logging import INFO
-from typing import Any, Dict
+from typing import Any
 
 import flwr as fl
 from flwr.common.logger import log
@@ -33,7 +33,7 @@
     }
 
 
-def main(config: Dict[str, Any], server_address: str, lam: float) -> None:
+def main(config: dict[str, Any], server_address: str, lam: float) -> None:
     # This function will be used to produce a config that is sent to each client to initialize their own environment
     fit_config_fn = partial(
         fit_config,
diff --git a/research/rxrx1/evaluate_on_test.py b/research/rxrx1/evaluate_on_test.py
index ba54f4bef..396c3acb6 100644
--- a/research/rxrx1/evaluate_on_test.py
+++ b/research/rxrx1/evaluate_on_test.py
@@ -1,7 +1,6 @@
 import argparse
 from logging import INFO
 from pathlib import Path
-from typing import Dict
 
 import pandas as pd
 import torch
@@ -42,7 +41,7 @@
 ) -> None:
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     all_run_folder_dir = get_all_run_folders(artifact_dir)
-    test_results: Dict[str, float] = {}
+    test_results: dict[str, float] = {}
     metrics = [Accuracy("rxrx1_accuracy")]
 
     all_pre_best_local_test_metrics = {run_folder_dir: 0.0 for run_folder_dir in all_run_folder_dir}
diff --git a/research/rxrx1/fedavg/client.py b/research/rxrx1/fedavg/client.py
index 793abfe55..6bee554e4 100644
--- a/research/rxrx1/fedavg/client.py
+++ b/research/rxrx1/fedavg/client.py
@@ -1,8 +1,8 @@
 import argparse
 import os
+from collections.abc import Sequence
 from logging import INFO
 from pathlib import Path
-from typing import Optional, Sequence, Tuple
 
 import flwr as fl
 import torch
@@ -14,9 +14,10 @@
 from torch.utils.data import DataLoader
 from torchvision import models
 
-from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer
-from fl4health.checkpointing.client_module import ClientCheckpointModule
+from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer
+from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule
 from fl4health.clients.basic_client import BasicClient
+from fl4health.reporting.base_reporter import BaseReporter
 from fl4health.utils.config import narrow_dict_type
 from fl4health.utils.losses import LossMeterType
 from fl4health.utils.metrics import Accuracy, Metric
@@ -33,14 +34,20 @@ def __init__(
         client_number: int,
         learning_rate: float,
         loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
-        checkpointer: Optional[ClientCheckpointModule] = None,
+        checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None,
+        reporters: Sequence[BaseReporter] | None = None,
+        progress_bar: bool = False,
+        client_name: str | None = None,
     ) -> None:
         super().__init__(
             data_path=data_path,
             metrics=metrics,
             device=device,
             loss_meter_type=loss_meter_type,
-            checkpointer=checkpointer,
+            checkpoint_and_state_module=checkpoint_and_state_module,
+            reporters=reporters,
+            progress_bar=progress_bar,
+            client_name=client_name,
         )
         self.client_number = client_number
         self.learning_rate: float = learning_rate
@@ -53,7 +60,7 @@ def setup_client(self, config: Config) -> None:
         assert 0 <= self.client_number < num_clients
         super().setup_client(config)
 
-    def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
+    def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
         batch_size = narrow_dict_type(config, "batch_size", int)
         train_loader, val_loader, _ = load_rxrx1_data(
             data_path=self.data_path, client_num=self.client_number, batch_size=batch_size, seed=self.client_number
@@ -61,7 +68,7 @@ def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
 
         return train_loader, val_loader
 
-    def get_test_data_loader(self, config: Config) -> Optional[DataLoader]:
+    def get_test_data_loader(self, config: Config) -> DataLoader | None:
         batch_size = narrow_dict_type(config, "batch_size", int)
         test_loader, _ = load_rxrx1_test_data(
             data_path=self.data_path, client_num=self.client_number, batch_size=batch_size
@@ -141,14 +148,14 @@ def get_model(self, config: Config) -> nn.Module:
     pre_aggregation_last_checkpoint_name = f"pre_aggregation_client_{args.client_number}_last_model.pkl"
     post_aggregation_best_checkpoint_name = f"post_aggregation_client_{args.client_number}_best_model.pkl"
     post_aggregation_last_checkpoint_name = f"post_aggregation_client_{args.client_number}_last_model.pkl"
-    checkpointer = ClientCheckpointModule(
+    checkpoint_and_state_module = ClientCheckpointAndStateModule(
         pre_aggregation=[
-            BestLossTorchCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name),
-            LatestTorchCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name),
+            BestLossTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_best_checkpoint_name),
+            LatestTorchModuleCheckpointer(checkpoint_dir, pre_aggregation_last_checkpoint_name),
         ],
         post_aggregation=[
-            BestLossTorchCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name),
-            LatestTorchCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name),
+            BestLossTorchModuleCheckpointer(checkpoint_dir, post_aggregation_best_checkpoint_name),
+            LatestTorchModuleCheckpointer(checkpoint_dir, post_aggregation_last_checkpoint_name),
         ],
     )
 
@@ -159,7 +166,7 @@ def get_model(self, config: Config) -> nn.Module:
         device=device,
         client_number=args.client_number,
         learning_rate=args.learning_rate,
-        checkpointer=checkpointer,
+        checkpoint_and_state_module=checkpoint_and_state_module,
     )
 
     fl.client.start_client(server_address=args.server_address, client=client.to_client())
diff --git a/research/rxrx1/fedavg/server.py b/research/rxrx1/fedavg/server.py
index aefd56b46..659782718 100644
--- a/research/rxrx1/fedavg/server.py
+++ b/research/rxrx1/fedavg/server.py
@@ -2,7 +2,7 @@
 import os
 from functools import partial
 from logging import INFO
-from typing import Any, Dict
+from typing import Any
 
 import flwr as fl
 from flwr.common.logger import log
@@ -11,7 +11,8 @@
 from flwr.server.strategy import FedAvg
 from torchvision import models
 
-from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer
+from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer
+from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule
 from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
 from fl4health.servers.base_server import FlServer
 from fl4health.utils.config import load_config
@@ -36,7 +37,7 @@
     }
 
 
-def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str) -> None:
+def main(config: dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str) -> None:
     # This function will be used to produce a config that is sent to each client to initialize their own environment
     fit_config_fn = partial(
         fit_config,
@@ -45,17 +46,22 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_
         config["n_server_rounds"],
         config["n_clients"],
     )
+    # Initializing the model on the server side
+    model = models.resnet18(pretrained=True)
+    parameter_exchanger = FullParameterExchanger()
     checkpoint_dir = os.path.join(checkpoint_stub, run_name)
     best_checkpoint_name = "server_best_model.pkl"
     last_checkpoint_name = "server_last_model.pkl"
-    checkpointer = [
-        BestLossTorchCheckpointer(checkpoint_dir, best_checkpoint_name),
-        LatestTorchCheckpointer(checkpoint_dir, last_checkpoint_name),
+    checkpointers = [
+        BestLossTorchModuleCheckpointer(checkpoint_dir, best_checkpoint_name),
+        LatestTorchModuleCheckpointer(checkpoint_dir, last_checkpoint_name),
     ]
+
+    checkpoint_and_state_module = BaseServerCheckpointAndStateModule(
+        model=model, parameter_exchanger=parameter_exchanger, model_checkpointers=checkpointers
+    )
+
     client_manager = SimpleClientManager()
-    # Initializing the model on the server side
-    model = models.resnet18(pretrained=True)
-    parameter_exchanger = FullParameterExchanger()
     # Server performs simple FedAveraging as its server-side optimization strategy
     strategy = FedAvg(
         min_fit_clients=config["n_clients"],
@@ -72,10 +78,8 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_
     server = FlServer(
         client_manager=client_manager,
         fl_config=config,
-        parameter_exchanger=parameter_exchanger,
-        model=model,
         strategy=strategy,
-        checkpointer=checkpointer,
+        checkpoint_and_state_module=checkpoint_and_state_module,
     )
 
     fl.server.start_server(
@@ -84,8 +88,8 @@ def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_
         config=fl.server.ServerConfig(num_rounds=config["n_server_rounds"]),
     )
 
-    assert isinstance(checkpointer[0], BestLossTorchCheckpointer)
-    log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{checkpointer[0].best_score}")
+    assert isinstance(checkpointers[0], BestLossTorchModuleCheckpointer)
+    log(INFO, f"Best Aggregated (Weighted) Loss seen by the Server: \n{checkpointers[0].best_score}")
 
     # Shutdown the server gracefully
     server.shutdown()
diff --git a/research/rxrx1/find_best_hp.py b/research/rxrx1/find_best_hp.py
index b44a63ad1..2737bb4fc 100644
--- a/research/rxrx1/find_best_hp.py
+++ b/research/rxrx1/find_best_hp.py
@@ -1,18 +1,17 @@
 import argparse
 import os
 from logging import INFO
-from typing import List, Optional
 
 import numpy as np
 from flwr.common.logger import log
 
 
-def get_hp_folders(hp_sweep_dir: str) -> List[str]:
+def get_hp_folders(hp_sweep_dir: str) -> list[str]:
     paths_in_hp_sweep_dir = [os.path.join(hp_sweep_dir, contents) for contents in os.listdir(hp_sweep_dir)]
     return [hp_folder for hp_folder in paths_in_hp_sweep_dir if os.path.isdir(hp_folder)]
 
 
-def get_run_folders(hp_dir: str) -> List[str]:
+def get_run_folders(hp_dir: str) -> list[str]:
     run_folder_names = [folder_name for folder_name in os.listdir(hp_dir) if "Run" in folder_name]
     return [os.path.join(hp_dir, run_folder_name) for run_folder_name in run_folder_names]
 
@@ -27,7 +26,7 @@ def get_weighted_loss_from_server_log(run_folder_path: str) -> float:
 
 def main(hp_sweep_dir: str) -> None:
     hp_folders = get_hp_folders(hp_sweep_dir)
-    best_avg_loss: Optional[float] = None
+    best_avg_loss: float | None = None
     best_folder = ""
     for hp_folder in hp_folders:
         run_folders = get_run_folders(hp_folder)
diff --git a/research/rxrx1/personal_server.py b/research/rxrx1/personal_server.py
index 6250c0841..ae3469497 100644
--- a/research/rxrx1/personal_server.py
+++ b/research/rxrx1/personal_server.py
@@ -1,5 +1,4 @@
 from logging import INFO
-from typing import Dict, Optional, Tuple
 
 from flwr.common.logger import log
 from flwr.common.typing import Config, Scalar
@@ -24,18 +23,20 @@ def __init__(
         self,
         client_manager: ClientManager,
         fl_config: Config,
-        strategy: Optional[Strategy] = None,
+        strategy: Strategy | None = None,
     ) -> None:
         # Personal approaches don't train a "server" model. Rather, each client trains a client specific model with
         # some globally shared weights. So we don't checkpoint a global model
-        super().__init__(client_manager=client_manager, fl_config=fl_config, strategy=strategy, checkpointer=None)
-        self.best_aggregated_loss: Optional[float] = None
+        super().__init__(
+            client_manager=client_manager, fl_config=fl_config, strategy=strategy, checkpoint_and_state_module=None
+        )
+        self.best_aggregated_loss: float | None = None
 
     def evaluate_round(
         self,
         server_round: int,
-        timeout: Optional[float],
-    ) -> Optional[Tuple[Optional[float], Dict[str, Scalar], EvaluateResultsAndFailures]]:
+        timeout: float | None,
+    ) -> tuple[float | None, dict[str, Scalar], EvaluateResultsAndFailures] | None:
         # loss_aggregated is the aggregated validation per step loss
         # aggregated over each client (weighted by num examples)
         eval_round_results = super().evaluate_round(server_round, timeout)
diff --git a/research/rxrx1/utils.py b/research/rxrx1/utils.py
index eb4844f9c..3255793cd 100644
--- a/research/rxrx1/utils.py
+++ b/research/rxrx1/utils.py
@@ -1,5 +1,5 @@
 import os
-from typing import Dict, List, Sequence, Tuple
+from collections.abc import Sequence
 
 import numpy as np
 import torch
@@ -9,7 +9,7 @@
 from fl4health.utils.metrics import Metric, MetricManager
 
 
-def get_all_run_folders(artifact_dir: str) -> List[str]:
+def get_all_run_folders(artifact_dir: str) -> list[str]:
     run_folder_names = [folder_name for folder_name in os.listdir(artifact_dir) if "Run" in folder_name]
     return [os.path.join(artifact_dir, run_folder_name) for run_folder_name in run_folder_names]
 
@@ -26,13 +26,13 @@ def load_last_global_model(run_folder_dir: str) -> nn.Module:
     return model
 
 
-def get_metric_avg_std(metrics: List[float]) -> Tuple[float, float]:
+def get_metric_avg_std(metrics: list[float]) -> tuple[float, float]:
     mean = float(np.mean(metrics))
     std = float(np.std(metrics, ddof=1))
     return mean, std
 
 
-def write_measurement_results(eval_write_path: str, results: Dict[str, float]) -> None:
+def write_measurement_results(eval_write_path: str, results: dict[str, float]) -> None:
     with open(eval_write_path, "w") as f:
         for key, metric_value in results.items():
             f.write(f"{key}: {metric_value}\n")
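Reviewer note: every client touched by this patch migrates the same checkpointing pattern, replacing ClientCheckpointModule, BestLossTorchCheckpointer, and LatestTorchCheckpointer with the renamed ClientCheckpointAndStateModule, BestLossTorchModuleCheckpointer, and LatestTorchModuleCheckpointer, and threading the result through the constructor as checkpoint_and_state_module. The sketch below restates that wiring once, using only the class names and import paths visible in the hunks above; the checkpoint directory and client number are hypothetical placeholders.

# Minimal sketch of the post-migration client-side checkpointing wiring.
# Assumptions: "artifacts/run_0" and client_number are illustrative placeholders;
# class names and import paths are taken verbatim from the diff above.
from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer
from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule

checkpoint_dir = "artifacts/run_0"  # hypothetical output directory
client_number = 0  # hypothetical client index

checkpoint_and_state_module = ClientCheckpointAndStateModule(
    # Checkpoint the local model both before and after server aggregation,
    # keeping the best-loss and the most recent weights in each phase.
    pre_aggregation=[
        BestLossTorchModuleCheckpointer(checkpoint_dir, f"pre_aggregation_client_{client_number}_best_model.pkl"),
        LatestTorchModuleCheckpointer(checkpoint_dir, f"pre_aggregation_client_{client_number}_last_model.pkl"),
    ],
    post_aggregation=[
        BestLossTorchModuleCheckpointer(checkpoint_dir, f"post_aggregation_client_{client_number}_best_model.pkl"),
        LatestTorchModuleCheckpointer(checkpoint_dir, f"post_aggregation_client_{client_number}_last_model.pkl"),
    ],
)
# The module is then passed to the client constructor as
# checkpoint_and_state_module=checkpoint_and_state_module (see the client hunks above).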