Adding the ability to specify checkpointing for models **before** aggregation. #128
Conversation
```python
    checkpoint_dir, checkpoint_name, checkpoint_score_function=loss_score_function, maximize=False
)
```

```python
def maybe_checkpoint(self, model: nn.Module, loss: float, metrics: Dict[str, Scalar]) -> None:
```
Note: I'm overriding this method to replace the logging with something more specific. If anyone has better ideas on how to do this, let me know.
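For context, here is a minimal sketch of the override pattern being described, assuming a function-scored base checkpointer. The class names, base-class internals, and log messages are illustrative assumptions, not the PR's actual code:

```python
import logging
from typing import Callable, Dict, Optional

import torch.nn as nn

# Stand-in for the repo's Scalar metric-value alias (an assumption).
Scalar = float

log = logging.getLogger(__name__)


class FunctionCheckpointer:
    """Hypothetical base: checkpoints whenever a user-supplied score improves."""

    def __init__(
        self, score_fn: Callable[[float, Dict[str, Scalar]], float], maximize: bool
    ) -> None:
        self.score_fn = score_fn
        self.maximize = maximize
        self.best_score: Optional[float] = None

    def _improved(self, score: float) -> bool:
        if self.best_score is None:
            return True
        return score > self.best_score if self.maximize else score < self.best_score

    def maybe_checkpoint(self, model: nn.Module, loss: float, metrics: Dict[str, Scalar]) -> None:
        score = self.score_fn(loss, metrics)
        if self._improved(score):
            self.best_score = score
            # Generic message: refers only to the abstract "score".
            log.info(f"Score improved to {score:.4f}: saving checkpoint")


class BestLossCheckpointer(FunctionCheckpointer):
    """Loss-based checkpointer: overrides maybe_checkpoint only so the log
    message is loss-specific, as the comment above describes."""

    def __init__(self) -> None:
        super().__init__(score_fn=lambda loss, _metrics: loss, maximize=False)

    def maybe_checkpoint(self, model: nn.Module, loss: float, metrics: Dict[str, Scalar]) -> None:
        score = self.score_fn(loss, metrics)
        if self._improved(score):
            self.best_score = score
            log.info(f"Validation loss improved to {loss:.4f}: saving checkpoint")
```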
```diff
@@ -631,8 +657,6 @@ def validate(self) -> Tuple[float, Dict[str, Scalar]]:
         metrics = self.val_metric_manager.compute()
         self._handle_logging(loss_dict, metrics, is_validation=True)
 
-        # Checkpoint based on loss which is output of user defined compute_loss method
-        self._maybe_checkpoint(loss_dict["checkpoint"])
```
Note that I moved this into the evaluate function rather than the validate function, since we may not want to checkpoint every time we validate.
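A toy sketch of the resulting control flow; the method names mirror the diff above, but the bodies and signatures here are illustrative only:

```python
from typing import Dict, Tuple

Scalar = float  # stand-in for the repo's Scalar alias (assumption)


class ClientSketch:
    def validate(self) -> Tuple[Dict[str, float], Dict[str, Scalar]]:
        # Compute and log validation losses/metrics; no checkpointing here,
        # so validate can run as often as needed without saving models.
        loss_dict = {"checkpoint": 0.42}  # placeholder loss values
        metrics: Dict[str, Scalar] = {"accuracy": 0.91}  # placeholder metrics
        return loss_dict, metrics

    def evaluate(self) -> None:
        loss_dict, _metrics = self.validate()
        # The checkpoint decision now lives at the evaluation boundary.
        self._maybe_checkpoint(loss_dict["checkpoint"])

    def _maybe_checkpoint(self, loss: float) -> None:
        print(f"would consider checkpointing at loss {loss}")
```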
LGTM! Checkpoint score functions and pre- vs. post-aggregation checkpointing are nice additions for flexibility.
Note: a scary number of files are changed because this affects the functionality of BasicClient in a way that requires slight modifications to many of the inheriting classes. Most of the changes in those downstream classes should be quite minor.
PR Type
Feature
Clickup Ticket(s):
Refactoring the client-side checkpointing functionality to allow for pre- and post-aggregation (server-side) checkpointing. That is, the user can specify a checkpointer for models after the weights have been aggregated by the server (still supported, and previously the only option), or before aggregation happens (i.e. right after local training). If either checkpointer is not specified, checkpointing at that point is skipped. Note that this allows us to mimic "fine-tuning" of models: the model's final training is exclusively local.
For example, with FedProx, the model checkpointed would actually be a "personal" model for each client (potentially after several preceding rounds of aggregation).
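A self-contained sketch of how the two hooks could be structured; ClientSideCheckpointModule is named in this PR, but everything below (argument names, method names) is an assumption for illustration:

```python
from typing import Any, Dict, Optional

import torch.nn as nn

Scalar = float  # stand-in for the repo's Scalar alias (assumption)


class CheckpointModuleSketch:
    """Toy stand-in for the PR's ClientSideCheckpointModule."""

    def __init__(
        self, pre_aggregation: Optional[Any] = None, post_aggregation: Optional[Any] = None
    ) -> None:
        # Either hook may be None, in which case checkpointing at that point
        # in the round is simply skipped.
        self.pre_aggregation = pre_aggregation
        self.post_aggregation = post_aggregation

    def maybe_checkpoint_pre_aggregation(
        self, model: nn.Module, loss: float, metrics: Dict[str, Scalar]
    ) -> None:
        # Fires right after local training, before weights go to the server:
        # the saved model's final training step is purely local ("fine-tuning").
        if self.pre_aggregation is not None:
            self.pre_aggregation.maybe_checkpoint(model, loss, metrics)

    def maybe_checkpoint_post_aggregation(
        self, model: nn.Module, loss: float, metrics: Dict[str, Scalar]
    ) -> None:
        # Fires after server-aggregated weights arrive (the previous default).
        if self.post_aggregation is not None:
            self.post_aggregation.maybe_checkpoint(model, loss, metrics)
```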
This PR also allows for more generic checkpointing functionality: given a loss value and a dictionary of metrics, users can define an arbitrary scoring function on those objects to decide when to checkpoint. The best-loss checkpointer is a specific instantiation of this type of checkpointer.
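Concretely, a score function is just a callable over the loss and the metrics dictionary. The constructor call below mirrors the signature visible in the diff at the top of this thread; the class name and metric key are assumptions:

```python
from typing import Dict

Scalar = float  # stand-in for the repo's Scalar alias (assumption)


def loss_score_function(loss: float, metrics: Dict[str, Scalar]) -> float:
    # The "best loss" checkpointer is just this score with maximize=False:
    # lower losses win.
    return loss


def f1_score_function(loss: float, metrics: Dict[str, Scalar]) -> float:
    # A user-defined alternative: checkpoint on a validation metric instead.
    # "val - F1" is a hypothetical key; use whatever keys your metrics carry.
    return float(metrics["val - F1"])


# Mirroring the constructor call in the diff above (class name assumed):
# checkpointer = FunctionTorchCheckpointer(
#     checkpoint_dir, checkpoint_name,
#     checkpoint_score_function=f1_score_function, maximize=True,
# )
```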
Tests Added
Added new tests covering the ClientSideCheckpointModule functionality and the new behavior of the TorchCheckpointer child classes.