Renaming the type aliases to be better type representations
emersodb committed Jan 6, 2025
1 parent b2227e4 commit 9732abd
Showing 2 changed files with 26 additions and 26 deletions.
10 changes: 5 additions & 5 deletions fl4health/checkpointing/client_module.py
@@ -8,7 +8,7 @@

from fl4health.checkpointing.checkpointer import PerRoundStateCheckpointer, TorchModuleCheckpointer

-CheckpointModuleInput = Union[TorchModuleCheckpointer, Sequence[TorchModuleCheckpointer]] | None
+ModelCheckpointers = Union[TorchModuleCheckpointer, Sequence[TorchModuleCheckpointer]] | None


class CheckpointMode(Enum):
@@ -19,8 +19,8 @@ class CheckpointMode(Enum):
class ClientCheckpointAndStateModule:
def __init__(
self,
-pre_aggregation: CheckpointModuleInput = None,
-post_aggregation: CheckpointModuleInput = None,
+pre_aggregation: ModelCheckpointers = None,
+post_aggregation: ModelCheckpointers = None,
state_checkpointer: PerRoundStateCheckpointer | None = None,
) -> None:
"""
@@ -37,10 +37,10 @@ def __init__(
That's because the target model for these methods is never globally aggregated. That is, they remain local
Args:
-pre_aggregation (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of
+pre_aggregation (ModelCheckpointers, optional): If defined, this checkpointer (or sequence of
checkpointers) is used to checkpoint models based on their validation metrics/losses **BEFORE**
server-side aggregation. Defaults to None.
-post_aggregation (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence
+post_aggregation (ModelCheckpointers, optional): If defined, this checkpointer (or sequence
of checkpointers) is used to checkpoint models based on their validation metrics/losses **AFTER**
server-side aggregation. Defaults to None.
state_checkpointer (PerRoundStateCheckpointer | None, optional): If defined, this checkpointer is used to
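For context, here is a minimal sketch (not part of this commit) of how the renamed `ModelCheckpointers` alias behaves on the client side. The `normalize_checkpointers` helper is hypothetical; the imported names are those visible in the diff above.

```python
from fl4health.checkpointing.checkpointer import TorchModuleCheckpointer
from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule, ModelCheckpointers


def normalize_checkpointers(checkpointers: ModelCheckpointers) -> list[TorchModuleCheckpointer]:
    # Hypothetical helper: coerces the three forms accepted by the alias
    # (a single checkpointer, a sequence of checkpointers, or None) into
    # a uniform list.
    if checkpointers is None:
        return []
    if isinstance(checkpointers, TorchModuleCheckpointer):
        return [checkpointers]
    return list(checkpointers)


# Per the signature above, all arguments default to None, so this no-op
# configuration is valid; pass checkpointers to enable model checkpointing.
module = ClientCheckpointAndStateModule(pre_aggregation=None, post_aggregation=None)
```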
42 changes: 21 additions & 21 deletions fl4health/checkpointing/server_module.py
@@ -19,15 +19,15 @@
SparseCooParameterPacker,
)

-CheckpointModuleInput = Union[TorchModuleCheckpointer, Sequence[TorchModuleCheckpointer]] | None
+ModelCheckpointers = Union[TorchModuleCheckpointer, Sequence[TorchModuleCheckpointer]] | None


class BaseServerCheckpointAndStateModule:
def __init__(
self,
model: nn.Module | None = None,
parameter_exchanger: ExchangerType | None = None,
-model_checkpointers: CheckpointModuleInput = None,
+model_checkpointers: ModelCheckpointers = None,
state_checkpointer: PerRoundStateCheckpointer | None = None,
) -> None:
"""
@@ -48,7 +48,7 @@ def __init__(
server parameters into the right components of the provided model architecture. Note that this
exchanger and the model must match the one used for training and exchange with the servers to ensure
parameters go to the right places. Defaults to None.
-model_checkpointers (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of
+model_checkpointers (ModelCheckpointers, optional): If defined, this checkpointer (or sequence of
checkpointers) is used to checkpoint models based on their defined scoring function. Defaults to None.
state_checkpointer (PerRoundStateCheckpointer | None, optional): If defined, this checkpointer will be
used to preserve FL training state to facilitate restarting training if interrupted. Generally, this
@@ -206,7 +206,7 @@ def __init__(
self,
model: nn.Module | None = None,
parameter_exchanger: FullParameterExchangerWithPacking | None = None,
-model_checkpointers: CheckpointModuleInput = None,
+model_checkpointers: ModelCheckpointers = None,
state_checkpointer: PerRoundStateCheckpointer | None = None,
) -> None:
"""
@@ -227,7 +227,7 @@ def __init__(
should handle any necessary unpacking of the parameters. Note that this exchanger and the model must
match the one used for training and exchange with the servers to ensure parameters go to the right
places. Defaults to None.
-model_checkpointers (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of
+model_checkpointers (ModelCheckpointers, optional): If defined, this checkpointer (or sequence of
checkpointers) is used to checkpoint models based on their defined scoring function. Defaults to None.
state_checkpointer (PerRoundStateCheckpointer | None, optional): If defined, this checkpointer will be
used to preserve FL training state to facilitate restarting training if interrupted. Generally, this
@@ -270,7 +270,7 @@ class ScaffoldServerCheckpointAndStateModule(PackingServerCheckpointAndAndStateM
def __init__(
self,
model: nn.Module | None = None,
-model_checkpointers: CheckpointModuleInput = None,
+model_checkpointers: ModelCheckpointers = None,
state_checkpointer: PerRoundStateCheckpointer | None = None,
) -> None:
"""
@@ -287,7 +287,7 @@ def __init__(
to hold the server parameters and facilitate checkpointing with the help of the parameter exchanger.
Recall that servers only have parameters rather than torch models. So we need to know where to route
these parameters to allow for real models to be saved. Defaults to None.
-model_checkpointers (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of
+model_checkpointers (ModelCheckpointers, optional): If defined, this checkpointer (or sequence of
checkpointers) is used to checkpoint models based on their defined scoring function. Defaults to None.
state_checkpointer (PerRoundStateCheckpointer | None, optional): If defined, this checkpointer will be
used to preserve FL training state to facilitate restarting training if interrupted. Generally, this
@@ -305,7 +305,7 @@ class AdaptiveConstraintServerCheckpointAndStateModule(PackingServerCheckpointAn
def __init__(
self,
model: nn.Module | None = None,
-model_checkpointers: CheckpointModuleInput = None,
+model_checkpointers: ModelCheckpointers = None,
state_checkpointer: PerRoundStateCheckpointer | None = None,
) -> None:
"""
@@ -322,7 +322,7 @@ def __init__(
to hold the server parameters and facilitate checkpointing with the help of the parameter exchanger.
Recall that servers only have parameters rather than torch models. So we need to know where to route
these parameters to allow for real models to be saved. Defaults to None.
-model_checkpointers (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of
+model_checkpointers (ModelCheckpointers, optional): If defined, this checkpointer (or sequence of
checkpointers) is used to checkpoint models based on their defined scoring function. Defaults to None.
state_checkpointer (PerRoundStateCheckpointer | None, optional): If defined, this checkpointer will be
used to preserve FL training state to facilitate restarting training if interrupted. Generally, this
@@ -339,7 +339,7 @@ class ClippingBitServerCheckpointAndStateModule(PackingServerCheckpointAndAndSta
def __init__(
self,
model: nn.Module | None = None,
-model_checkpointers: CheckpointModuleInput = None,
+model_checkpointers: ModelCheckpointers = None,
state_checkpointer: PerRoundStateCheckpointer | None = None,
) -> None:
"""
@@ -356,7 +356,7 @@ def __init__(
to hold the server parameters and facilitate checkpointing with the help of the parameter exchanger.
Recall that servers only have parameters rather than torch models. So we need to know where to route
these parameters to allow for real models to be saved. Defaults to None.
-model_checkpointers (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of
+model_checkpointers (ModelCheckpointers, optional): If defined, this checkpointer (or sequence of
checkpointers) is used to checkpoint models based on their defined scoring function. Defaults to None.
state_checkpointer (PerRoundStateCheckpointer | None, optional): If defined, this checkpointer will be
used to preserve FL training state to facilitate restarting training if interrupted. Generally, this
@@ -373,7 +373,7 @@ class LayerNamesServerCheckpointAndStateModule(PackingServerCheckpointAndAndStat
def __init__(
self,
model: nn.Module | None = None,
-model_checkpointers: CheckpointModuleInput = None,
+model_checkpointers: ModelCheckpointers = None,
state_checkpointer: PerRoundStateCheckpointer | None = None,
) -> None:
"""
@@ -390,7 +390,7 @@ def __init__(
to hold the server parameters and facilitate checkpointing with the help of the parameter exchanger.
Recall that servers only have parameters rather than torch models. So we need to know where to route
these parameters to allow for real models to be saved. Defaults to None.
-model_checkpointers (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of
+model_checkpointers (ModelCheckpointers, optional): If defined, this checkpointer (or sequence of
checkpointers) is used to checkpoint models based on their defined scoring function. Defaults to None.
state_checkpointer (PerRoundStateCheckpointer | None, optional): If defined, this checkpointer will be
used to preserve FL training state to facilitate restarting training if interrupted. Generally, this
@@ -407,7 +407,7 @@ class SparseCooServerCheckpointAndStateModule(PackingServerCheckpointAndAndState
def __init__(
self,
model: nn.Module | None = None,
-model_checkpointers: CheckpointModuleInput = None,
+model_checkpointers: ModelCheckpointers = None,
state_checkpointer: PerRoundStateCheckpointer | None = None,
) -> None:
"""
@@ -425,7 +425,7 @@ def __init__(
to hold the server parameters and facilitate checkpointing with the help of the parameter exchanger.
Recall that servers only have parameters rather than torch models. So we need to know where to route
these parameters to allow for real models to be saved. Defaults to None.
-model_checkpointers (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of
+model_checkpointers (ModelCheckpointers, optional): If defined, this checkpointer (or sequence of
checkpointers) is used to checkpoint models based on their defined scoring function. Defaults to None.
state_checkpointer (PerRoundStateCheckpointer | None, optional): If defined, this checkpointer will be
used to preserve FL training state to facilitate restarting training if interrupted. Generally, this
@@ -443,7 +443,7 @@ def __init__(
self,
model: nn.Module | None = None,
parameter_exchanger: ExchangerType | None = None,
-model_checkpointers: CheckpointModuleInput = None,
+model_checkpointers: ModelCheckpointers = None,
state_checkpointer: PerRoundStateCheckpointer | None = None,
) -> None:
"""
@@ -465,7 +465,7 @@ def __init__(
server parameters into the right components of the provided model architecture. Note that this
exchanger and the model must match the one used for training and exchange with the servers to ensure
parameters go to the right places. Defaults to None.
-model_checkpointers (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of
+model_checkpointers (ModelCheckpointers, optional): If defined, this checkpointer (or sequence of
checkpointers) is used to checkpoint models based on their defined scoring function. Defaults to None.
state_checkpointer (PerRoundStateCheckpointer | None, optional): If defined, this checkpointer will be
used to preserve FL training state to facilitate restarting training if interrupted. Generally, this
@@ -491,7 +491,7 @@ def __init__(
self,
model: nn.Module | None = None,
parameter_exchanger: ExchangerType | None = None,
-model_checkpointers: CheckpointModuleInput = None,
+model_checkpointers: ModelCheckpointers = None,
state_checkpointer: PerRoundStateCheckpointer | None = None,
) -> None:
"""
@@ -517,7 +517,7 @@ def __init__(
server parameters into the right components of the provided model architecture. Note that this
exchanger and the model must match the one used for training and exchange with the servers to ensure
parameters go to the right places. Defaults to None.
-model_checkpointers (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of
+model_checkpointers (ModelCheckpointers, optional): If defined, this checkpointer (or sequence of
checkpointers) is used to checkpoint models based on their defined scoring function. Defaults to None.
state_checkpointer (PerRoundStateCheckpointer | None, optional): If defined, this checkpointer will be
used to preserve FL training state to facilitate restarting training if interrupted. Generally, this
@@ -543,7 +543,7 @@ class DpScaffoldServerCheckpointAndStateModule(ScaffoldServerCheckpointAndStateM
def __init__(
self,
model: nn.Module | None = None,
-model_checkpointers: CheckpointModuleInput = None,
+model_checkpointers: ModelCheckpointers = None,
state_checkpointer: PerRoundStateCheckpointer | None = None,
) -> None:
"""
@@ -560,7 +560,7 @@ def __init__(
to hold the server parameters and facilitate checkpointing with the help of the parameter exchanger.
Recall that servers only have parameters rather than torch models. So we need to know where to route
these parameters to allow for real models to be saved. Defaults to None.
-model_checkpointers (CheckpointModuleInput, optional): If defined, this checkpointer (or sequence of
+model_checkpointers (ModelCheckpointers, optional): If defined, this checkpointer (or sequence of
checkpointers) is used to checkpoint models based on their defined scoring function. Defaults to None.
state_checkpointer (PerRoundStateCheckpointer | None, optional): If defined, this checkpointer will be
used to preserve FL training state to facilitate restarting training if interrupted. Generally, this
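A similar sketch for the server side (again hypothetical, not part of this commit): the renamed keyword can be forwarded straight through. Note that the docstrings above indicate a model and matching parameter exchanger are generally also needed when checkpointers are supplied, so the server's raw parameters can be routed into a real model before saving.

```python
from fl4health.checkpointing.server_module import BaseServerCheckpointAndStateModule, ModelCheckpointers


def build_server_checkpointing(checkpointers: ModelCheckpointers = None) -> BaseServerCheckpointAndStateModule:
    # Hypothetical factory: forwards a single checkpointer, a sequence of
    # checkpointers, or None under the renamed `model_checkpointers` keyword.
    # As the docstrings above note, a model and parameter exchanger are
    # typically also required when real checkpointers are passed.
    return BaseServerCheckpointAndStateModule(model_checkpointers=checkpointers)
```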
