Merge branch 'main' into dbe/migrate_some_mypy_typing

emersodb committed Jan 8, 2025
2 parents 50c6530 + 44ba556 commit 5514652
Showing 13 changed files with 27 additions and 31 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -30,7 +30,7 @@ repos:
       - id: isort

   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.14.0
+    rev: v1.14.1
     hooks:
       - id: mypy
         name: mypy
5 changes: 3 additions & 2 deletions fl4health/checkpointing/checkpointer.py
@@ -44,14 +44,15 @@ def maybe_checkpoint(self, model: nn.Module, loss: float, metrics: dict[str, Sca
     def load_checkpoint(self, path_to_checkpoint: str | None = None) -> nn.Module:
         """
         Checkpointer with the option to either specify a checkpoint path or fall back on the internal path of the
-        checkpointer
+        checkpointer. The flexibility to specify a load path is useful, for example, if you are not overwriting
+        checkpoints when saving and need to load a specific past checkpoint for whatever reason.

         Args:
             path_to_checkpoint (str | None, optional): If provided, the checkpoint will be loaded from this path.
                 If not specified, the checkpointer will load from self.checkpoint_path. Defaults to None.

         Returns:
-            nn.Module: _description_
+            nn.Module: Returns a torch module loaded from the proper checkpoint path.
         """
         if path_to_checkpoint is None:
             return torch.load(self.checkpoint_path)
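The docstring change above describes a simple load-path fallback that is easy to see in isolation. Below is a minimal, runnable sketch of that behavior; `SimpleCheckpointer` and the file name are illustrative stand-ins, not the fl4health API:

```python
import torch
import torch.nn as nn


class SimpleCheckpointer:
    """Minimal sketch of the load-path fallback described in the docstring above."""

    def __init__(self, checkpoint_path: str) -> None:
        self.checkpoint_path = checkpoint_path

    def save_checkpoint(self, model: nn.Module) -> None:
        # Save the whole module so it can be restored without rebuilding the class.
        torch.save(model, self.checkpoint_path)

    def load_checkpoint(self, path_to_checkpoint: str | None = None) -> nn.Module:
        # Fall back on the checkpointer's internal path when no explicit path is given.
        # Note: recent PyTorch releases may require torch.load(..., weights_only=False)
        # to deserialize a full nn.Module.
        return torch.load(path_to_checkpoint if path_to_checkpoint is not None else self.checkpoint_path)


checkpointer = SimpleCheckpointer("model_round_1.pt")
checkpointer.save_checkpoint(nn.Linear(4, 2))
model = checkpointer.load_checkpoint()  # equivalent to load_checkpoint("model_round_1.pt")
```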
29 changes: 12 additions & 17 deletions fl4health/checkpointing/server_module.py
@@ -167,11 +167,11 @@ def save_state(
"""
if self.state_checkpointer is not None:
self._hydrate_model_for_checkpointing(server_parameters)
if "model" not in other_state:
other_state["model"] = self.model
else:
if "model" in other_state:
raise ValueError("Key 'model' already exists in the other_state dictionary.")
self.state_checkpointer.save_checkpoint(state_checkpoint_name, checkpoint_dict=other_state)

checkpoint_dict = other_state | {"model": self.model}
self.state_checkpointer.save_checkpoint(state_checkpoint_name, checkpoint_dict=checkpoint_dict)
else:
raise ValueError("Attempting to save state but no state checkpointer is specified")
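The rewritten branch above fails fast on a key collision and then builds the checkpoint dictionary with a dict union rather than mutating other_state in place. A standalone sketch of just that merge logic, with illustrative names:

```python
from typing import Any


def build_checkpoint_dict(other_state: dict[str, Any], model: Any) -> dict[str, Any]:
    # Reject a caller-supplied "model" key instead of silently overwriting it.
    if "model" in other_state:
        raise ValueError("Key 'model' already exists in the other_state dictionary.")
    # The dict union operator (Python 3.9+) returns a new dict, so other_state is not mutated.
    return other_state | {"model": model}


print(build_checkpoint_dict({"server_round": 3}, model="model_stub"))
# {'server_round': 3, 'model': 'model_stub'}
```

One practical consequence of the union: the caller's dictionary is left unchanged, avoiding the subtle side effect of the old in-place assignment.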

@@ -524,20 +524,15 @@ def __init__(
                 used to preserve FL training state to facilitate restarting training if interrupted. Generally, this
                 checkpointer will save much more than just the model being trained. Defaults to None.
         """
-        self.model = model
-        self.parameter_exchanger = parameter_exchanger
-        self.model_checkpointers = (
-            [model_checkpointers] if isinstance(model_checkpointers, TorchModuleCheckpointer) else model_checkpointers
-        )
-        self.state_checkpointer = state_checkpointer
-        if self.model_checkpointers is not None and len(self.model_checkpointers):
-            # NOTE: We only check if the parameter exchanger is present. Model may be set later.
-            assert self.parameter_exchanger is not None, (
-                "Checkpointer(s) is (are) defined but no parameter_exchanger is defined to hydrate. The functionality "
-                "of this class can be overridden in a child class if checkpointing without a parameter exchanger is "
-                "possible and desired"
-            )
-            self._check_if_shared_checkpoint_names()
+        super().__init__(model, parameter_exchanger, model_checkpointers, state_checkpointer)
+
+    def _validate_model_checkpointer_components(self) -> None:
+        # NOTE: We only check if the parameter exchanger is present. Model may be set later.
+        assert self.parameter_exchanger is not None, (
+            "Checkpointer(s) is (are) defined but no parameter_exchanger is defined to hydrate. The functionality of "
+            "this class can be overridden in a child class if checkpointing without a parameter exchanger is "
+            "possible and desired"
+        )
+        self._check_if_shared_checkpoint_names()


class DpScaffoldServerCheckpointAndStateModule(ScaffoldServerCheckpointAndStateModule):
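The refactor above collapses duplicated constructor logic into a super().__init__ call plus an overridable validation hook. A simplified sketch of that template-method pattern; the class names and constructor signature here are stand-ins, not the fl4health classes:

```python
class BaseCheckpointModule:
    def __init__(self, parameter_exchanger=None, model_checkpointers=None) -> None:
        self.parameter_exchanger = parameter_exchanger
        self.model_checkpointers = model_checkpointers or []
        if self.model_checkpointers:
            # Hook: subclasses decide what a valid checkpointing configuration requires.
            self._validate_model_checkpointer_components()

    def _validate_model_checkpointer_components(self) -> None:
        # The base sketch imposes no extra requirements.
        pass


class ExchangerRequiredModule(BaseCheckpointModule):
    def _validate_model_checkpointer_components(self) -> None:
        # Only the parameter exchanger must be present up front; the model may be set later.
        assert self.parameter_exchanger is not None, (
            "Checkpointer(s) defined but no parameter_exchanger to hydrate them."
        )
```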
@@ -39,7 +39,7 @@ def __init__(
                 send data to before and after each round. Defaults to None.
             checkpoint_and_state_module (AdaptiveConstraintServerCheckpointAndStateModule | None, optional): This
                 module is used to handle both model checkpointing and state checkpointing. The former is aimed at
-                saving model artifacts to be used or evaluated after training. The later is used to preserve training
+                saving model artifacts to be used or evaluated after training. The latter is used to preserve training
                 state (including models) such that if FL training is interrupted, the process may be restarted. If no
                 module is provided, no checkpointing or state preservation will happen. Defaults to None.
             NOTE: For Ditto, the model shared with the server is the GLOBAL MODEL, which isn't the target of FL
@@ -40,7 +40,7 @@ def __init__(
                 should send data to before and after each round. Defaults to None.
             checkpoint_and_state_module (AdaptiveConstraintServerCheckpointAndStateModule | None, optional): This
                 module is used to handle both model checkpointing and state checkpointing. The former is aimed at
-                saving model artifacts to be used or evaluated after training. The later is used to preserve training
+                saving model artifacts to be used or evaluated after training. The latter is used to preserve training
                 state (including models) such that if FL training is interrupted, the process may be restarted. If no
                 module is provided, no checkpointing or state preservation will happen. Defaults to None.
             on_init_parameters_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to
Expand Down
                 send data to before and after each round. Defaults to None.
             checkpoint_and_state_module (AdaptiveConstraintServerCheckpointAndStateModule | None, optional): This
                 module is used to handle both model checkpointing and state checkpointing. The former is aimed at
-                saving model artifacts to be used or evaluated after training. The later is used to preserve training
+                saving model artifacts to be used or evaluated after training. The latter is used to preserve training
                 state (including models) such that if FL training is interrupted, the process may be restarted. If no
                 module is provided, no checkpointing or state preservation will happen. Defaults to None.
             NOTE: For MR-MTL, the server model is an aggregation of the personal models, which isn't the target of
2 changes: 1 addition & 1 deletion fl4health/servers/base_server.py
@@ -53,7 +53,7 @@ def __init__(
                 should send data to before and after each round. Defaults to None.
             checkpoint_and_state_module (BaseServerCheckpointAndStateModule | None, optional): This module is used
                 to handle both model checkpointing and state checkpointing. The former is aimed at saving model
-                artifacts to be used or evaluated after training. The later is used to preserve training state
+                artifacts to be used or evaluated after training. The latter is used to preserve training state
                 (including models) such that if FL training is interrupted, the process may be restarted. If no
                 module is provided, no checkpointing or state preservation will happen. Defaults to None.
             on_init_parameters_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to
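The corrected docstring distinguishes model checkpointing (artifacts saved for later use or evaluation) from state checkpointing (everything needed to restart interrupted training). A generic illustration of that split; the function names and saved fields are assumptions, not fl4health's implementation:

```python
import torch


def save_model_artifact(model: torch.nn.Module, path: str) -> None:
    # Model checkpointing: persist only the trained weights for later evaluation or deployment.
    torch.save(model.state_dict(), path)


def save_training_state(
    model: torch.nn.Module, optimizer: torch.optim.Optimizer, server_round: int, path: str
) -> None:
    # State checkpointing: persist enough to resume FL training where it left off.
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "round": server_round,
        },
        path,
    )
```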
2 changes: 1 addition & 1 deletion fl4health/servers/client_level_dp_fed_avg_server.py
@@ -53,7 +53,7 @@ def __init__(
                 send data to before and after each round.
             checkpoint_and_state_module (BaseServerCheckpointAndStateModule | None, optional): This module is used
                 to handle both model checkpointing and state checkpointing. The former is aimed at saving model
-                artifacts to be used or evaluated after training. The later is used to preserve training state
+                artifacts to be used or evaluated after training. The latter is used to preserve training state
                 (including models) such that if FL training is interrupted, the process may be restarted. If no
                 module is provided, no checkpointing or state preservation will happen. Defaults to None.
             delta (float | None, optional): The delta value for epsilon-delta DP accounting. If None it defaults to
2 changes: 1 addition & 1 deletion fl4health/servers/fedpm_server.py
@@ -38,7 +38,7 @@ def __init__(
                 should send data to before and after each round.
             checkpoint_and_state_module (LayerNamesServerCheckpointAndStateModule | None, optional): This module is
                 used to handle both model checkpointing and state checkpointing. The former is aimed at saving model
-                artifacts to be used or evaluated after training. The later is used to preserve training state
+                artifacts to be used or evaluated after training. The latter is used to preserve training state
                 (including models) such that if FL training is interrupted, the process may be restarted. If no
                 module is provided, no checkpointing or state preservation will happen. Defaults to None.
             on_init_parameters_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to
2 changes: 1 addition & 1 deletion fl4health/servers/instance_level_dp_server.py
@@ -60,7 +60,7 @@ def __init__(
                 Defaults to None.
             checkpoint_and_state_module (OpacusServerCheckpointAndStateModule | None, optional): This module is used
                 to handle both model checkpointing and state checkpointing. The former is aimed at saving model
-                artifacts to be used or evaluated after training. The later is used to preserve training state
+                artifacts to be used or evaluated after training. The latter is used to preserve training state
                 (including models) such that if FL training is interrupted, the process may be restarted. If no
                 module is provided, no checkpointing or state preservation will happen. Defaults to None.
             reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health
2 changes: 1 addition & 1 deletion fl4health/servers/nnunet_server.py
@@ -90,7 +90,7 @@ def __init__(
                 should send data to. Defaults to None.
             checkpoint_and_state_module (NnUnetServerCheckpointAndStateModule | None, optional): This module is used
                 to handle both model checkpointing and state checkpointing. The former is aimed at saving model
-                artifacts to be used or evaluated after training. The later is used to preserve training state
+                artifacts to be used or evaluated after training. The latter is used to preserve training state
                 (including models) such that if FL training is interrupted, the process may be restarted. If no
                 module is provided, no checkpointing or state preservation will happen. Defaults to None.
             NOTE: For NnUnet, this module is allowed to have all components defined other than the model, as it
4 changes: 2 additions & 2 deletions fl4health/servers/scaffold_server.py
@@ -49,7 +49,7 @@ def __init__(
                 should send data to before and after each round. Defaults to None.
             checkpoint_and_state_module (ScaffoldServerCheckpointAndStateModule | None, optional): This module is used
                 to handle both model checkpointing and state checkpointing. The former is aimed at saving model
-                artifacts to be used or evaluated after training. The later is used to preserve training state
+                artifacts to be used or evaluated after training. The latter is used to preserve training state
                 (including models) such that if FL training is interrupted, the process may be restarted. If no
                 module is provided, no checkpointing or state preservation will happen. Defaults to None.
             on_init_parameters_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to
@@ -224,7 +224,7 @@ def __init__(
                 type.
             checkpoint_and_state_module (DpScaffoldServerCheckpointAndStateModule | None, optional): This module is
                 used to handle both model checkpointing and state checkpointing. The former is aimed at saving model
-                artifacts to be used or evaluated after training. The later is used to preserve training state
+                artifacts to be used or evaluated after training. The latter is used to preserve training state
                 (including models) such that if FL training is interrupted, the process may be restarted. If no
                 module is provided, no checkpointing or state preservation will happen. Defaults to None.
             warm_start (bool, optional): Whether or not to initialize control variates of each client as local
2 changes: 1 addition & 1 deletion fl4health/servers/tabular_feature_alignment_server.py
@@ -63,7 +63,7 @@ def __init__(
                 should send data to before and after each round. Defaults to None.
             checkpoint_and_state_module (BaseServerCheckpointAndStateModule | None, optional): This module is used
                 to handle both model checkpointing and state checkpointing. The former is aimed at saving model
-                artifacts to be used or evaluated after training. The later is used to preserve training state
+                artifacts to be used or evaluated after training. The latter is used to preserve training state
                 (including models) such that if FL training is interrupted, the process may be restarted. If no
                 module is provided, no checkpointing or state preservation will happen. Defaults to None.
             on_init_parameters_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to
