Skip to content

Commit

Permalink
Merge pull request #280 from VectorInstitute/dbe/server_shutdowns_on_…
Browse files Browse the repository at this point in the history
…client_failures

First pass at optionally shutting down server (and remaining clients) on failures
  • Loading branch information
emersodb authored Nov 11, 2024
2 parents 910c36e + 04ae53f commit 986350c
Show file tree
Hide file tree
Showing 9 changed files with 178 additions and 108 deletions.
109 changes: 77 additions & 32 deletions fl4health/server/base_server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import datetime
from logging import DEBUG, INFO, WARN, WARNING
from logging import DEBUG, ERROR, INFO, WARNING
from pathlib import Path
from typing import Dict, Generic, List, Optional, Sequence, Tuple, TypeVar, Union

Expand All @@ -24,6 +24,7 @@
from fl4health.utils.metrics import TEST_LOSS_KEY, TEST_NUM_EXAMPLES_KEY, MetricPrefix
from fl4health.utils.parameter_extraction import get_all_model_parameters
from fl4health.utils.random import generate_hash
from fl4health.utils.typing import EvaluateFailures, FitFailures


class FlServer(Server):
Expand All @@ -34,33 +35,34 @@ def __init__(
reporters: Sequence[BaseReporter] | None = None,
checkpointer: Optional[Union[TorchCheckpointer, Sequence[TorchCheckpointer]]] = None,
server_name: Optional[str] = None,
accept_failures: bool = True,
) -> None:
"""
Base Server for the library to facilitate strapping additional/useful machinery to the base flwr server.
Args:
client_manager (ClientManager): Determines the mechanism by which clients
are sampled by the server, if they are to be sampled at all.
strategy (Optional[Strategy], optional): The aggregation strategy to be
used by the server to handle. client updates and other information
potentially sent by the participating clients. If None the strategy is
FedAvg as set by the flwr Server.
reporters (Sequence[BaseReporter], optional): A sequence of FL4Health
reporters which the server should send data to before and after each round.
checkpointer (TorchCheckpointer | Sequence [TorchCheckpointer], optional):
To be provided if the server should perform server side checkpointing
based on some criteria. If none, then no server-side checkpointing is
performed. Multiple checkpointers can also be passed in a sequence to
checkpointer based on multiple criteria. Ensure checkpoint names are
different for each checkpoint or they will overwrite on another.
Defaults to None.
server_name (Optional[str]): An optional string name to uniquely identify
server.
client_manager (ClientManager): Determines the mechanism by which clients are sampled by the server, if
they are to be sampled at all.
strategy (Optional[Strategy], optional): The aggregation strategy to be used by the server to handle
client updates and other information potentially sent by the participating clients. If None the
strategy is FedAvg as set by the flwr Server.
reporters (Sequence[BaseReporter], optional): A sequence of FL4Health reporters which the server
should send data to before and after each round.
checkpointer (TorchCheckpointer | Sequence [TorchCheckpointer], optional): To be provided if the server
should perform server side checkpointing based on some criteria. If none, then no server-side
checkpointing is performed. Multiple checkpointers can also be passed in a sequence to checkpointer
based on multiple criteria. Ensure checkpoint names are different for each checkpoint or they will
overwrite one another. Defaults to None.
server_name (Optional[str]): An optional string name to uniquely identify server.
accept_failures (bool, optional): Determines whether the server should accept failures during training or
evaluation from clients or not. If set to False, this will cause the server to shutdown all clients
and throw an exception. Defaults to True.
"""

super().__init__(client_manager=client_manager, strategy=strategy)
self.checkpointer = [checkpointer] if isinstance(checkpointer, TorchCheckpointer) else checkpointer
self.server_name = server_name if server_name is not None else generate_hash()
self.accept_failures = accept_failures

# Initialize reporters with server name information.
self.reports_manager = ReportsManager(reporters)
Expand Down Expand Up @@ -131,8 +133,13 @@ def fit_round(
server_round,
)
if fit_round_results is not None:
_, metrics, _ = fit_round_results
_, metrics, fit_results_and_failures = fit_round_results
self.reports_manager.report({"fit_metrics": metrics}, server_round)
failures = fit_results_and_failures[1] if fit_results_and_failures else None

if failures and not self.accept_failures:
self._log_client_failures(failures)
self._terminate_after_unacceptable_failures(timeout)

return fit_round_results

Expand Down Expand Up @@ -324,7 +331,12 @@ def evaluate_round(
eval_round_results = self._evaluate_round(server_round, timeout)
end_time = datetime.datetime.now()
if eval_round_results:
loss_aggregated, metrics_aggregated, _ = eval_round_results
loss_aggregated, metrics_aggregated, (_, failures) = eval_round_results

if failures and not self.accept_failures:
self._log_client_failures(failures)
self._terminate_after_unacceptable_failures(timeout)

if loss_aggregated:
self._maybe_checkpoint(loss_aggregated, metrics_aggregated, server_round)
# Report evaluation results
Expand All @@ -345,6 +357,37 @@ def evaluate_round(

return eval_round_results

def _terminate_after_unacceptable_failures(self, timeout: Optional[float]) -> None:
    """
    Shut down all connected clients and the server itself, then raise an exception to end the FL process.

    Only called when ``accept_failures`` is False and at least one client reported a failure during a
    fit or evaluate round (see the callers in ``fit_round`` and ``evaluate_round``).

    Args:
        timeout (Optional[float]): Timeout, in seconds, used when instructing clients to disconnect
            before the server shuts down.

    Raises:
        ValueError: Always raised, to alert the user that client-side failures caused termination.
    """
    assert not self.accept_failures
    # First we shutdown all clients involved in the FL training/evaluation if they can be.
    self.disconnect_all_clients(timeout=timeout)
    # Then shut down the server's own machinery before terminating.
    self.shutdown()
    # Throw an exception alerting the user to failures on the client-side causing termination
    raise ValueError(
        f"The server encountered failures from the clients and accept_failures is set to {self.accept_failures}"
    )

def _log_client_failures(self, failures: FitFailures | EvaluateFailures) -> None:
    """
    Log details of each client failure received during a fit or evaluate round.

    Failures come in two forms: a raw ``BaseException`` (the client ID is unknown in this case) or a
    ``(ClientProxy, result)`` tuple, where the client reported failure but still returned partial results.

    Args:
        failures (FitFailures | EvaluateFailures): Failures returned by clients during either the fitting
            or the evaluation process.
    """
    # NOTE: This helper is called from both fit and evaluate rounds, so the summary message must not
    # claim the failures happened specifically during fitting.
    log(
        ERROR,
        f"There were {len(failures)} failures in the fit or evaluation process. This will result in "
        "termination of the FL process",
    )
    for failure in failures:
        if isinstance(failure, BaseException):
            # A bare exception carries no client proxy, so the failing client cannot be identified here.
            log(
                ERROR,
                "An exception was returned instead of any failed results. As such the client ID is unknown. "
                "Please check the client logs to determine which failed.\n"
                f"The exception thrown was {repr(failure)}",
            )
        else:
            client_proxy, _ = failure
            log(
                ERROR,
                f"Client {client_proxy.cid} failed but did not return an exception. Partial results were received",
            )


ExchangerType = TypeVar("ExchangerType", bound=ParameterExchanger)

Expand All @@ -360,6 +403,7 @@ def __init__(
checkpointer: Optional[Union[TorchCheckpointer, Sequence[TorchCheckpointer]]] = None,
intermediate_server_state_dir: Optional[Path] = None,
server_name: Optional[str] = None,
accept_failures: bool = True,
) -> None:
"""
This is a standard FL server but equipped with the assumption that the parameter exchanger is capable of
Expand All @@ -386,13 +430,12 @@ def __init__(
intermediate_server_state_dir (Path): A directory to store and load checkpoints from for the server
during an FL experiment.
server_name (Optional[str]): An optional string name to uniquely identify server.
accept_failures (bool, optional): Determines whether the server should accept failures during training or
evaluation from clients or not. If set to False, this will cause the server to shutdown all clients
and throw an exception. Defaults to True.
"""
super().__init__(
client_manager,
strategy,
reporters,
checkpointer,
server_name=server_name,
client_manager, strategy, reporters, checkpointer, server_name=server_name, accept_failures=accept_failures
)
self.server_model = model
# To facilitate model rehydration from server-side state for checkpointing
Expand Down Expand Up @@ -608,6 +651,7 @@ def __init__(
strategy: Optional[Strategy] = None,
reporters: Sequence[BaseReporter] | None = None,
checkpointer: Optional[Union[TorchCheckpointer, Sequence[TorchCheckpointer]]] = None,
accept_failures: bool = True,
) -> None:
"""
Server with an initialize hook method that is called prior to fit. Override the self.initialize method to do
Expand All @@ -623,13 +667,14 @@ def __init__(
strategy is FedAvg as set by the flwr Server.
reporters (Sequence[BaseReporter], optional): A sequence of FL4Health
reporters which the server should send data to before and after each round.
checkpointer (Optional[Union[TorchCheckpointer, Sequence
[TorchCheckpointer]]], optional): To be provided if the server
should perform server side checkpointing based on some
criteria. If none, then no server-side checkpointing is
performed. Defaults to None.
checkpointer (Optional[Union[TorchCheckpointer, Sequence[TorchCheckpointer]]], optional): To be provided
if the server should perform server side checkpointing based on some criteria. If none, then no
server-side checkpointing is performed. Defaults to None.
accept_failures (bool, optional): Determines whether the server should accept failures during training or
evaluation from clients or not. If set to False, this will cause the server to shutdown all clients
and throw an exception. Defaults to True.
"""
super().__init__(client_manager, strategy, reporters, checkpointer)
super().__init__(client_manager, strategy, reporters, checkpointer, accept_failures=accept_failures)
self.initialized = False

def _get_initial_parameters(self, server_round: int, timeout: Optional[float]) -> Parameters:
Expand Down Expand Up @@ -663,7 +708,7 @@ def _get_initial_parameters(self, server_round: int, timeout: Optional[float]) -
log(INFO, "Received initial parameters from one random client")
else:
log(
WARN,
WARNING,
"Failed to receive initial parameters from the client." " Empty initial parameters will be used.",
)
return get_parameters_res.parameters
Expand Down
5 changes: 5 additions & 0 deletions fl4health/server/client_level_dp_fed_avg_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(
checkpointer: Optional[TorchCheckpointer] = None,
reporters: Sequence[BaseReporter] | None = None,
delta: Optional[int] = None,
accept_failures: bool = True,
) -> None:
"""
Server to be used in case of Client Level Differential Privacy with Federated Averaging.
Expand All @@ -48,12 +49,16 @@ def __init__(
reporters which the server should send data to before and after each round.
delta (Optional[float], optional): The delta value for epsilon-delta DP accounting. If None it defaults to
being 1/total_samples in the FL run. Defaults to None.
accept_failures (bool, optional): Determines whether the server should accept failures during training or
evaluation from clients or not. If set to False, this will cause the server to shutdown all clients
and throw an exception. Defaults to True.
"""
super().__init__(
client_manager=client_manager,
strategy=strategy,
checkpointer=checkpointer,
reporters=reporters,
accept_failures=accept_failures,
)
self.accountant: ClientLevelAccountant
self.server_noise_multiplier = server_noise_multiplier
Expand Down
4 changes: 2 additions & 2 deletions fl4health/server/evaluate_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def __init__(
accept_failures (bool, optional): Whether or not accept rounds containing failures. Defaults to True.
min_available_clients (int, optional): Minimum number of total clients in the system. Defaults to 1.
Defaults to 1.
reporters (Sequence[BaseReporter], optional): A sequence of FL4Health
reporters which the client should send data to.
reporters (Sequence[BaseReporter], optional): A sequence of FL4Health reporters which the client should
send data to.
"""
# We aren't aggregating model weights, so setting the strategy to be none.
super().__init__(client_manager=client_manager, strategy=None)
Expand Down
29 changes: 15 additions & 14 deletions fl4health/server/fedpm_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,33 +20,34 @@ def __init__(
checkpointer: Optional[TorchCheckpointer] = None,
reset_frequency: int = 1,
reporters: Sequence[BaseReporter] | None = None,
accept_failures: bool = True,
) -> None:
"""
Custom FL Server for the FedPM algorithm to allow for resetting the beta priors in Bayesian aggregation,
as specified in http://arxiv.org/pdf/2209.15328.
Args:
client_manager (ClientManager): Determines the mechanism by which clients
are sampled by the server, if they are to be sampled at all.
strategy (Scaffold): The aggregation strategy to be used by the server to
handle client updates and other information potentially sent by the
participating clients. This strategy must be of SCAFFOLD type.
checkpointer (Optional[TorchCheckpointer], optional): To be provided if the
server should perform server side checkpointing based on some criteria.
If none, then no server-side checkpointing is performed. Defaults to
None.
reset_frequency (int): Determines the frequency with which the beta priors
are reset. Defaults to 1.
reporters (Sequence[BaseReporter], optional): A sequence of FL4Health
reporters which the server should send data to before and after each
round.
client_manager (ClientManager): Determines the mechanism by which clients are sampled by the server, if
they are to be sampled at all.
strategy (Scaffold): The aggregation strategy to be used by the server to handle client updates and other
information potentially sent by the participating clients. This strategy must be of SCAFFOLD type.
checkpointer (Optional[TorchCheckpointer], optional): To be provided if the server should perform
server side checkpointing based on some criteria. If none, then no server-side checkpointing is
performed. Defaults to None.
reset_frequency (int): Determines the frequency with which the beta priors are reset. Defaults to 1.
reporters (Sequence[BaseReporter], optional): A sequence of FL4Health reporters which the server should
send data to before and after each round.
accept_failures (bool, optional): Determines whether the server should accept failures during training or
evaluation from clients or not. If set to False, this will cause the server to shutdown all clients
and throw an exception. Defaults to True.
"""
FlServer.__init__(
self,
client_manager=client_manager,
strategy=strategy,
checkpointer=checkpointer,
reporters=reporters,
accept_failures=accept_failures,
)
self.reset_frequency = reset_frequency

Expand Down
5 changes: 5 additions & 0 deletions fl4health/server/instance_level_dp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(
checkpointer: Optional[OpacusCheckpointer] = None,
reporters: Sequence[BaseReporter] | None = None,
delta: Optional[float] = None,
accept_failures: bool = True,
) -> None:
"""
Server to be used in case of Instance Level Differential Privacy with Federated Averaging.
Expand Down Expand Up @@ -57,12 +58,16 @@ def __init__(
reporters which the client should send data to.
delta (Optional[float], optional): The delta value for epsilon-delta DP accounting. If None it defaults to
being 1/total_samples in the FL run. Defaults to None.
accept_failures (bool, optional): Determines whether the server should accept failures during training or
evaluation from clients or not. If set to False, this will cause the server to shutdown all clients
and throw an exception. Defaults to True.
"""
super().__init__(
client_manager=client_manager,
strategy=strategy,
checkpointer=checkpointer,
reporters=reporters,
accept_failures=accept_failures,
)

# Ensure that one of local_epochs and local_steps is passed (and not both)
Expand Down
Loading

0 comments on commit 986350c

Please sign in to comment.