Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes to Support Expanded Experimentation with FedDG-GA #252

Merged
merged 9 commits into from
Nov 4, 2024
1 change: 1 addition & 0 deletions examples/feddg_ga_example/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ from the FL4Health directory. The following arguments must be present in the spe
* `batch_size`: size of the batches each client will train on
* `n_server_rounds`: The number of rounds to run FL
* `evaluate_after_fit`: Should be set to `True`. Performs an evaluation at the end of each client's fit round.
* `pack_losses_with_val_metrics`: Should be set to `True`. Includes validation losses with metrics calculations

## Starting Clients

Expand Down
5 changes: 4 additions & 1 deletion examples/feddg_ga_example/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,7 @@ n_server_rounds: 3 # The number of rounds to run FL
n_clients: 2 # The number of clients in the FL experiment
local_steps: 5 # The number of local steps (one per batch) to complete for client
batch_size: 128 # The batch size for client training
evaluate_after_fit: True # Evaluates model immediately after local training on the validation set (in addition to the training set)
# Evaluates model immediately after local training on the validation set (in addition to the training set)
evaluate_after_fit: True
# Packs the measured validation losses with the metrics (required for fed-dgga)
pack_losses_with_val_metrics: True
7 changes: 5 additions & 2 deletions examples/feddg_ga_example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from fl4health.model_bases.apfl_base import ApflModule
from fl4health.reporting import JsonReporter
from fl4health.server.base_server import FlServer
from fl4health.strategies.feddg_ga_strategy import FedDgGaStrategy
from fl4health.strategies.feddg_ga import FedDgGa
from fl4health.utils.config import load_config
from fl4health.utils.metric_aggregation import evaluate_metrics_aggregation_fn, fit_metrics_aggregation_fn
from fl4health.utils.parameter_extraction import get_all_model_parameters
Expand All @@ -25,13 +25,15 @@ def fit_config(
local_epochs: Optional[int] = None,
local_steps: Optional[int] = None,
evaluate_after_fit: bool = False,
pack_losses_with_val_metrics: bool = False,
) -> Config:
return {
**make_dict_with_epochs_or_steps(local_epochs, local_steps),
"current_server_round": current_round,
"batch_size": batch_size,
"n_server_rounds": n_server_rounds,
"evaluate_after_fit": evaluate_after_fit,
"pack_losses_with_val_metrics": pack_losses_with_val_metrics,
}


Expand All @@ -44,12 +46,13 @@ def main(config: Dict[str, Any]) -> None:
local_epochs=config.get("local_epochs"),
local_steps=config.get("local_steps"),
evaluate_after_fit=config.get("evaluate_after_fit", False),
pack_losses_with_val_metrics=config.get("pack_losses_with_val_metrics", False),
)

initial_model = ApflModule(MnistNetWithBnAndFrozen())

# Implementation of FedDG-GA as a server side strategy
strategy = FedDgGaStrategy(
strategy = FedDgGa(
min_fit_clients=config["n_clients"],
min_evaluate_clients=config["n_clients"],
# Server waits for min_available_clients before starting FL rounds
Expand Down
6 changes: 3 additions & 3 deletions examples/fenda_ditto_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from fl4health.clients.fenda_ditto_client import FendaDittoClient
from fl4health.model_bases.fenda_base import FendaModel
from fl4health.model_bases.parallel_split_models import ParallelFeatureJoinMode
from fl4health.model_bases.sequential_split_models import SequentiallySplitExchangeBaseModel
from fl4health.model_bases.sequential_split_models import SequentiallySplitModel
from fl4health.reporting import JsonReporter
from fl4health.utils.config import narrow_dict_type
from fl4health.utils.load_data import load_mnist_data
Expand All @@ -37,8 +37,8 @@ def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size, sampler)
return train_loader, val_loader

def get_global_model(self, config: Config) -> SequentiallySplitExchangeBaseModel:
return SequentiallySplitExchangeBaseModel(
def get_global_model(self, config: Config) -> SequentiallySplitModel:
return SequentiallySplitModel(
base_module=SequentialGlobalFeatureExtractorMnist(),
head_module=SequentialLocalPredictionHeadMnist(),
).to(self.device)
Expand Down
Loading