Skip to content

Commit

Permalink
Merge pull request #298 from VectorInstitute/dbe/some_server_side_che…
Browse files Browse the repository at this point in the history
…ckpointer_consolidation

Consolidating model and state checkpointing on the client and server sides.
  • Loading branch information
emersodb authored Jan 8, 2025
2 parents dce71b2 + 17ef402 commit 44ba556
Show file tree
Hide file tree
Showing 152 changed files with 3,297 additions and 1,532 deletions.
12 changes: 7 additions & 5 deletions examples/ae_examples/cvae_dim_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from flwr.server.strategy import FedAvg

from examples.models.mnist_model import MnistNet
from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer
from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer
from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.servers.base_server import FlServer
from fl4health.utils.config import load_config
Expand Down Expand Up @@ -47,7 +48,10 @@ def main(config: Dict[str, Any]) -> None:
model = MnistNet(int(config["latent_dim"]) * 2)
# To facilitate checkpointing
parameter_exchanger = FullParameterExchanger()
checkpointer = BestLossTorchCheckpointer(config["checkpoint_path"], "best_model.pkl")
checkpointer = BestLossTorchModuleCheckpointer(config["checkpoint_path"], "best_model.pkl")
checkpoint_and_state_module = BaseServerCheckpointAndStateModule(
model=model, parameter_exchanger=parameter_exchanger, model_checkpointers=checkpointer
)

# Server performs simple FedAveraging as its server-side optimization strategy
strategy = FedAvg(
Expand All @@ -66,10 +70,8 @@ def main(config: Dict[str, Any]) -> None:
server = FlServer(
client_manager=SimpleClientManager(),
fl_config=config,
parameter_exchanger=parameter_exchanger,
model=model,
strategy=strategy,
checkpointer=checkpointer,
checkpoint_and_state_module=checkpoint_and_state_module,
)

fl.server.start_server(
Expand Down
12 changes: 7 additions & 5 deletions examples/ae_examples/cvae_examples/conv_cvae_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from flwr.server.strategy import FedAvg

from examples.ae_examples.cvae_examples.conv_cvae_example.models import ConvConditionalDecoder, ConvConditionalEncoder
from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer
from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer
from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule
from fl4health.model_bases.autoencoders_base import ConditionalVae
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.servers.base_server import FlServer
Expand Down Expand Up @@ -48,7 +49,10 @@ def main(config: Dict[str, Any]) -> None:

# To facilitate checkpointing
parameter_exchanger = FullParameterExchanger()
checkpointer = BestLossTorchCheckpointer(config["checkpoint_path"], model_checkpoint_name)
checkpointer = BestLossTorchModuleCheckpointer(config["checkpoint_path"], model_checkpoint_name)
checkpoint_and_state_module = BaseServerCheckpointAndStateModule(
model=model, parameter_exchanger=parameter_exchanger, model_checkpointers=checkpointer
)

# Server performs simple FedAveraging as its server-side optimization strategy
strategy = FedAvg(
Expand All @@ -67,10 +71,8 @@ def main(config: Dict[str, Any]) -> None:
server = FlServer(
client_manager=SimpleClientManager(),
fl_config=config,
parameter_exchanger=parameter_exchanger,
model=model,
strategy=strategy,
checkpointer=checkpointer,
checkpoint_and_state_module=checkpoint_and_state_module,
)

fl.server.start_server(
Expand Down
12 changes: 7 additions & 5 deletions examples/ae_examples/cvae_examples/mlp_cvae_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from flwr.server.strategy import FedAvg

from examples.ae_examples.cvae_examples.mlp_cvae_example.models import MnistConditionalDecoder, MnistConditionalEncoder
from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer
from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer
from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule
from fl4health.model_bases.autoencoders_base import ConditionalVae
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.servers.base_server import FlServer
Expand Down Expand Up @@ -48,7 +49,10 @@ def main(config: Dict[str, Any]) -> None:

# To facilitate checkpointing
parameter_exchanger = FullParameterExchanger()
checkpointer = BestLossTorchCheckpointer(config["checkpoint_path"], model_checkpoint_name)
checkpointer = BestLossTorchModuleCheckpointer(config["checkpoint_path"], model_checkpoint_name)
checkpoint_and_state_module = BaseServerCheckpointAndStateModule(
model=model, parameter_exchanger=parameter_exchanger, model_checkpointers=checkpointer
)

# Server performs simple FedAveraging as its server-side optimization strategy
strategy = FedAvg(
Expand All @@ -67,10 +71,8 @@ def main(config: Dict[str, Any]) -> None:
server = FlServer(
client_manager=SimpleClientManager(),
fl_config=config,
parameter_exchanger=parameter_exchanger,
model=model,
strategy=strategy,
checkpointer=checkpointer,
checkpoint_and_state_module=checkpoint_and_state_module,
)

fl.server.start_server(
Expand Down
15 changes: 8 additions & 7 deletions examples/ae_examples/fedprox_vae_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from flwr.server.client_manager import SimpleClientManager

from examples.ae_examples.fedprox_vae_example.models import MnistVariationalDecoder, MnistVariationalEncoder
from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer
from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer
from fl4health.checkpointing.server_module import AdaptiveConstraintServerCheckpointAndStateModule
from fl4health.model_bases.autoencoders_base import VariationalAe
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.servers.base_server import FlServer
from fl4health.strategies.fedavg_with_adaptive_constraint import FedAvgWithAdaptiveConstraint
from fl4health.utils.config import load_config
Expand Down Expand Up @@ -47,8 +47,11 @@ def main(config: Dict[str, Any]) -> None:
model_checkpoint_name = "best_VAE_model.pkl"

# To facilitate checkpointing
parameter_exchanger = FullParameterExchanger()
checkpointer = BestLossTorchCheckpointer(config["checkpoint_path"], model_checkpoint_name)
checkpointer = BestLossTorchModuleCheckpointer(config["checkpoint_path"], model_checkpoint_name)

checkpoint_and_state_module = AdaptiveConstraintServerCheckpointAndStateModule(
model=model, model_checkpointers=checkpointer
)

# Server performs simple FedAveraging as its server-side optimization strategy and potentially adapts the
# FedProx proximal weight mu
Expand All @@ -70,10 +73,8 @@ def main(config: Dict[str, Any]) -> None:
server = FlServer(
client_manager=SimpleClientManager(),
fl_config=config,
parameter_exchanger=parameter_exchanger,
model=model,
strategy=strategy,
checkpointer=checkpointer,
checkpoint_and_state_module=checkpoint_and_state_module,
)

fl.server.start_server(
Expand Down
4 changes: 2 additions & 2 deletions examples/apfl_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def main(config: Dict[str, Any]) -> None:
local_steps=config.get("local_steps"),
)

initial_model = ApflModule(MnistNetWithBnAndFrozen())
model = ApflModule(MnistNetWithBnAndFrozen())

# Server performs simple FedAveraging as its server-side optimization strategy
strategy = FedAvg(
Expand All @@ -56,7 +56,7 @@ def main(config: Dict[str, Any]) -> None:
on_evaluate_config_fn=fit_config_fn,
fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
initial_parameters=get_all_model_parameters(initial_model),
initial_parameters=get_all_model_parameters(model),
)

client_manager = SimpleClientManager()
Expand Down
14 changes: 8 additions & 6 deletions examples/basic_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

from examples.models.cnn_model import Net
from examples.utils.functions import make_dict_with_epochs_or_steps
from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer, LatestTorchCheckpointer
from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer, LatestTorchModuleCheckpointer
from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.servers.base_server import FlServer
from fl4health.utils.config import load_config
Expand Down Expand Up @@ -44,9 +45,12 @@ def main(config: Dict[str, Any]) -> None:
# To facilitate checkpointing
parameter_exchanger = FullParameterExchanger()
checkpointers = [
BestLossTorchCheckpointer(config["checkpoint_path"], "best_model.pkl"),
LatestTorchCheckpointer(config["checkpoint_path"], "latest_model.pkl"),
BestLossTorchModuleCheckpointer(config["checkpoint_path"], "best_model.pkl"),
LatestTorchModuleCheckpointer(config["checkpoint_path"], "latest_model.pkl"),
]
checkpoint_and_state_module = BaseServerCheckpointAndStateModule(
model=model, parameter_exchanger=parameter_exchanger, model_checkpointers=checkpointers
)

# Server performs simple FedAveraging as its server-side optimization strategy
strategy = FedAvg(
Expand All @@ -65,10 +69,8 @@ def main(config: Dict[str, Any]) -> None:
server = FlServer(
client_manager=SimpleClientManager(),
fl_config=config,
parameter_exchanger=parameter_exchanger,
model=model,
strategy=strategy,
checkpointer=checkpointers,
checkpoint_and_state_module=checkpoint_and_state_module,
)

fl.server.start_server(
Expand Down
24 changes: 15 additions & 9 deletions examples/docker_basic_example/fl_client/client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import argparse
from pathlib import Path
from typing import Sequence
from typing import Sequence, Tuple

import flwr as fl
import torch
import torch.nn as nn
from flwr.common.typing import Config
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from examples.models.cnn_model import Net
from fl4health.clients.basic_client import BasicClient
Expand All @@ -20,17 +24,19 @@ def __init__(self, data_path: Path, metrics: Sequence[Metric], device: torch.dev
self.model = Net()
self.parameter_exchanger = FullParameterExchanger()

def setup_client(self, config: Config) -> None:
super().setup_client(config)
def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
batch_size = narrow_dict_type(config, "batch_size", int)
train_loader, validation_loader, num_examples = load_cifar10_data(self.data_path, batch_size)
train_loader, val_loader, _ = load_cifar10_data(self.data_path, batch_size)
return train_loader, val_loader

self.train_loader = train_loader
self.val_loader = validation_loader
self.num_examples = num_examples
def get_criterion(self, config: Config) -> _Loss:
return torch.nn.CrossEntropyLoss()

self.criterion = torch.nn.CrossEntropyLoss()
self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)
def get_optimizer(self, config: Config) -> Optimizer:
return torch.optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)

def get_model(self, config: Config) -> nn.Module:
return Net().to(self.device)


if __name__ == "__main__":
Expand Down
8 changes: 5 additions & 3 deletions examples/dp_fed_examples/instance_level_dp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torch.utils.data import DataLoader

from examples.models.cnn_model import Net
from fl4health.checkpointing.client_module import ClientCheckpointModule
from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule
from fl4health.checkpointing.opacus_checkpointer import BestLossOpacusCheckpointer
from fl4health.clients.instance_level_dp_client import InstanceLevelDpClient
from fl4health.utils.config import narrow_dict_type
Expand Down Expand Up @@ -48,12 +48,14 @@ def get_criterion(self, config: Config) -> _Loss:
post_aggregation_checkpointer = BestLossOpacusCheckpointer(
checkpoint_dir=checkpoint_dir, checkpoint_name=checkpoint_name
)
checkpointer = ClientCheckpointModule(post_aggregation=post_aggregation_checkpointer)
checkpoint_and_state_module = ClientCheckpointAndStateModule(post_aggregation=post_aggregation_checkpointer)

# Load model and data
data_path = Path(args.dataset_path)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
client = CifarClient(data_path, [Accuracy("accuracy")], device, checkpointer=checkpointer)
client = CifarClient(
data_path, [Accuracy("accuracy")], device, checkpoint_and_state_module=checkpoint_and_state_module
)
fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())

client.shutdown()
22 changes: 13 additions & 9 deletions examples/dp_fed_examples/instance_level_dp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from examples.models.cnn_model import Net
from examples.utils.functions import make_dict_with_epochs_or_steps
from fl4health.checkpointing.opacus_checkpointer import BestLossOpacusCheckpointer
from fl4health.checkpointing.server_module import OpacusServerCheckpointAndStateModule
from fl4health.client_managers.poisson_sampling_manager import PoissonSamplingClientManager
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.servers.instance_level_dp_server import InstanceLevelDpServer
Expand Down Expand Up @@ -67,15 +68,24 @@ def main(config: Dict[str, Any]) -> None:
local_steps=config.get("local_steps"),
)

initial_model = map_model_to_opacus_model(Net())
model = map_model_to_opacus_model(Net())

client_name = "".join(choices(string.ascii_uppercase, k=5))
checkpoint_dir = "examples/dp_fed_examples/instance_level_dp/"
checkpoint_name = f"server_{client_name}_best_model.pkl"
checkpointer = BestLossOpacusCheckpointer(checkpoint_dir=checkpoint_dir, checkpoint_name=checkpoint_name)

checkpoint_and_state_module = OpacusServerCheckpointAndStateModule(
model=model, parameter_exchanger=FullParameterExchanger(), model_checkpointers=checkpointer
)

# ClientManager that performs Poisson type sampling
client_manager = PoissonSamplingClientManager()

# Server performs simple FedAveraging with Instance Level Differential Privacy
# Must be FedAvg sampling to ensure privacy loss is computed correctly
strategy = OpacusBasicFedAvg(
model=initial_model,
model=model,
fraction_fit=config["client_sampling_rate"],
# Server waits for min_available_clients before starting FL rounds
min_available_clients=config["n_clients"],
Expand All @@ -86,22 +96,16 @@ def main(config: Dict[str, Any]) -> None:
on_evaluate_config_fn=fit_config_fn,
)

client_name = "".join(choices(string.ascii_uppercase, k=5))
checkpoint_dir = "examples/dp_fed_examples/instance_level_dp/"
checkpoint_name = f"server_{client_name}_best_model.pkl"

server = InstanceLevelDpServer(
client_manager=client_manager,
fl_config=config,
model=initial_model,
checkpointer=BestLossOpacusCheckpointer(checkpoint_dir=checkpoint_dir, checkpoint_name=checkpoint_name),
parameter_exchanger=FullParameterExchanger(),
strategy=strategy,
noise_multiplier=config["noise_multiplier"],
local_epochs=config.get("local_epochs"),
local_steps=config.get("local_steps"),
batch_size=config["batch_size"],
num_server_rounds=config["n_server_rounds"],
checkpoint_and_state_module=checkpoint_and_state_module,
)

fl.server.start_server(
Expand Down
6 changes: 3 additions & 3 deletions examples/fedopt_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from examples.fedopt_example.client_data import LabelEncoder, Vocabulary, construct_dataloaders
from examples.fedopt_example.metrics import CompoundMetric
from examples.models.lstm_model import LSTM
from fl4health.checkpointing.client_module import ClientCheckpointModule
from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule
from fl4health.clients.basic_client import BasicClient, TorchInputType
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.losses import LossMeterType
Expand All @@ -27,9 +27,9 @@ def __init__(
metrics: Sequence[Metric],
device: torch.device,
loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
checkpointer: Optional[ClientCheckpointModule] = None,
checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None,
) -> None:
super().__init__(data_path, metrics, device, loss_meter_type, checkpointer)
super().__init__(data_path, metrics, device, loss_meter_type, checkpoint_and_state_module)
self.weight_matrix: torch.Tensor
self.vocabulary: Vocabulary
self.label_encoder: LabelEncoder
Expand Down
12 changes: 7 additions & 5 deletions examples/fedpca_examples/dim_reduction/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from flwr.server.strategy import FedAvg

from examples.models.mnist_model import MnistNet
from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer
from fl4health.checkpointing.checkpointer import BestLossTorchModuleCheckpointer
from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule
from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger
from fl4health.servers.base_server import FlServer
from fl4health.utils.config import load_config
Expand Down Expand Up @@ -47,7 +48,10 @@ def main(config: Dict[str, Any]) -> None:
parameter_exchanger = FullParameterExchanger()

# To facilitate checkpointing
checkpointer = BestLossTorchCheckpointer(config["checkpoint_path"], "best_model.pkl")
checkpointer = BestLossTorchModuleCheckpointer(config["checkpoint_path"], "best_model.pkl")
checkpoint_and_state_module = BaseServerCheckpointAndStateModule(
model=model, parameter_exchanger=parameter_exchanger, model_checkpointers=checkpointer
)

# Server performs simple FedAveraging as its server-side optimization strategy
strategy = FedAvg(
Expand All @@ -66,10 +70,8 @@ def main(config: Dict[str, Any]) -> None:
server = FlServer(
client_manager=SimpleClientManager(),
fl_config=config,
parameter_exchanger=parameter_exchanger,
model=model,
strategy=strategy,
checkpointer=checkpointer,
checkpoint_and_state_module=checkpoint_and_state_module,
)

fl.server.start_server(
Expand Down
4 changes: 1 addition & 3 deletions examples/fedprox_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,7 @@ def main(config: Dict[str, Any], server_address: str) -> None:
reporters = [wandb_reporter, json_reporter]
else:
reporters = [json_reporter]
server = FedProxServer(
client_manager=client_manager, fl_config=config, strategy=strategy, model=None, reporters=reporters
)
server = FedProxServer(client_manager=client_manager, fl_config=config, strategy=strategy, reporters=reporters)

fl.server.start_server(
server=server,
Expand Down
Loading

0 comments on commit 44ba556

Please sign in to comment.