
Adding the ability to specify a checkpointer for models **before** aggregation #128

Merged: 5 commits from dbe/adding_checkpointer_post_train into main on May 8, 2024

Conversation

emersodb (Collaborator) commented:

Note: There are a scary number of files changed because this affects the functionality of BasicClient in a way that requires slight modifications of many of the inheriting classes. Most of the changes should be quite minor in those downstream classes.

PR Type

Feature

Clickup Ticket(s):

  1. https://app.clickup.com/t/86880z10x
  2. https://app.clickup.com/t/860r0e3a0

Refactoring the client-side checkpointing functionality to allow for pre- and post-aggregation (server-side) checkpointing. That is, the user can specify a checkpointer for models after the weights have been aggregated by the server (still supported, but previously the only option), or prior to aggregation (i.e. right after local training). If either checkpointer is not specified, checkpointing at that point is simply skipped. Note that this allows us to mimic "fine-tuning" of models: the model's final training is exclusively local.

For example, with FedProx, the model checkpointed would actually be a "personal" model for each client (potentially after several preceding rounds of aggregation).
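As a rough illustration of the idea, the sketch below shows one way such a module could hold two optional checkpointers. The class and argument names (e.g. pre_aggregation, post_aggregation) are assumptions for illustration, not the exact API introduced by this PR.

# Illustrative sketch only: names and structure are assumptions, not this PR's actual API.
from typing import Optional

import torch
import torch.nn as nn


class SimpleTorchCheckpointer:
    # Minimal stand-in for a TorchCheckpointer-style object.
    def __init__(self, checkpoint_path: str) -> None:
        self.checkpoint_path = checkpoint_path

    def save(self, model: nn.Module) -> None:
        torch.save(model.state_dict(), self.checkpoint_path)


class ClientCheckpointModuleSketch:
    # Holds an optional checkpointer for each of the two checkpointing times.
    def __init__(
        self,
        pre_aggregation: Optional[SimpleTorchCheckpointer] = None,
        post_aggregation: Optional[SimpleTorchCheckpointer] = None,
    ) -> None:
        self.pre_aggregation = pre_aggregation
        self.post_aggregation = post_aggregation

    def maybe_checkpoint_pre_aggregation(self, model: nn.Module) -> None:
        # Right after local training, before weights are sent to the server.
        if self.pre_aggregation is not None:
            self.pre_aggregation.save(model)

    def maybe_checkpoint_post_aggregation(self, model: nn.Module) -> None:
        # After aggregated weights have been received back from the server.
        if self.post_aggregation is not None:
            self.post_aggregation.save(model)

Leaving either hook as None skips checkpointing at that stage, which is how the "fine-tuning" behaviour falls out: configure only the pre-aggregation checkpointer and the saved weights are exactly those produced by the final round of local training.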

Also allowing for more generic checkpointing functionality. That is, given a loss value and a dictionary of metrics, users can define an arbitrary scoring function on those objects to decide when a checkpoint should be saved. The best-loss checkpointer is a specific instantiation of this more general type of checkpointer.
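As a concrete example, a score function of this kind might look roughly like the sketch below. The (loss, metrics) -> float signature and the metric key are assumptions for illustration; the constructor call in the diff further down passes such a function as checkpoint_score_function with maximize=False.

# Illustrative score functions over a loss value and a metrics dictionary.
from typing import Dict

from flwr.common.typing import Scalar


def loss_score_function(loss: float, _metrics: Dict[str, Scalar]) -> float:
    # Reproduces "best loss" checkpointing: the score is just the loss, and with
    # maximize=False the checkpointer keeps the model whenever the score decreases.
    return loss


def loss_minus_accuracy_score(loss: float, metrics: Dict[str, Scalar]) -> float:
    # A hypothetical blended score that also rewards a validation metric.
    accuracy = float(metrics.get("val - accuracy", 0.0))
    return loss - accuracy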

Tests Added

Added new tests covering the new ClientSideCheckpointModule functionality, as well as the new functionality associated with the TorchCheckpointer child classes.

The first review comment is attached to this diff context:

checkpoint_dir, checkpoint_name, checkpoint_score_function=loss_score_function, maximize=False
)

def maybe_checkpoint(self, model: nn.Module, loss: float, metrics: Dict[str, Scalar]) -> None:
emersodb (Collaborator, Author) commented:

Note, I'm overriding this method to replace the logging with something more specific. If anyone has any better ideas on how to do this, let me know.
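For context, the pattern being described is roughly the self-contained sketch below; the class names, attributes, and the parent's behaviour are assumptions rather than this PR's actual code.

# Sketch only: a child class overrides maybe_checkpoint purely to swap the generic
# log message for a loss-specific one, then reuses the parent's comparison/saving logic.
import logging
from typing import Callable, Dict, Optional

import torch
import torch.nn as nn
from flwr.common.typing import Scalar

log = logging.getLogger(__name__)


class FunctionCheckpointerSketch:
    def __init__(
        self,
        checkpoint_path: str,
        checkpoint_score_function: Callable[[float, Dict[str, Scalar]], float],
        maximize: bool = False,
    ) -> None:
        self.checkpoint_path = checkpoint_path
        self.checkpoint_score_function = checkpoint_score_function
        self.maximize = maximize
        self.best_score: Optional[float] = None

    def _compare_and_save(self, model: nn.Module, score: float) -> None:
        # Save whenever the score improves in the configured direction.
        improved = self.best_score is None or (
            score > self.best_score if self.maximize else score < self.best_score
        )
        if improved:
            self.best_score = score
            torch.save(model.state_dict(), self.checkpoint_path)

    def maybe_checkpoint(self, model: nn.Module, loss: float, metrics: Dict[str, Scalar]) -> None:
        score = self.checkpoint_score_function(loss, metrics)
        # Generic message: the parent only knows about abstract "scores".
        log.info(f"Candidate score: {score}, best score so far: {self.best_score}")
        self._compare_and_save(model, score)


class BestLossCheckpointerSketch(FunctionCheckpointerSketch):
    def maybe_checkpoint(self, model: nn.Module, loss: float, metrics: Dict[str, Scalar]) -> None:
        # Overridden purely to replace the generic log line with a loss-specific one;
        # the comparison and saving logic is reused unchanged.
        score = self.checkpoint_score_function(loss, metrics)
        log.info(f"Candidate checkpoint loss: {loss}, best loss so far: {self.best_score}")
        self._compare_and_save(model, score)

Only the log message changes in the child class; the score comparison and saving behaviour stay in the parent.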

A second review comment is attached to this diff context:

@@ -631,8 +657,6 @@ def validate(self) -> Tuple[float, Dict[str, Scalar]]:
metrics = self.val_metric_manager.compute()
self._handle_logging(loss_dict, metrics, is_validation=True)

# Checkpoint based on loss which is output of user defined compute_loss method
self._maybe_checkpoint(loss_dict["checkpoint"])
emersodb (Collaborator, Author) commented:

Note that I moved this into the evaluate function rather than the validate function, as we may not want to checkpoint every time we validate.
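The resulting call ordering looks roughly like the sketch below. The real evaluate on a Flower client also receives parameters and a config, which this sketch elides, and the helper names are assumptions.

# Sketch of the ordering: checkpointing moves out of validate() and into evaluate(),
# so running validation on its own no longer writes a checkpoint.
from typing import Dict, Tuple

from flwr.common.typing import Scalar


class ClientSketch:
    def _run_validation_pass(self) -> Tuple[float, Dict[str, Scalar]]:
        # Placeholder for the actual loop over the validation loader.
        return 0.0, {}

    def _maybe_checkpoint(self, loss: float, metrics: Dict[str, Scalar]) -> None:
        # Placeholder for the client-side checkpointing hook.
        pass

    def validate(self) -> Tuple[float, Dict[str, Scalar]]:
        # Pure validation: compute loss and metrics with no checkpointing side effects.
        return self._run_validation_pass()

    def evaluate(self) -> Tuple[float, Dict[str, Scalar]]:
        # Called at evaluation time by the FL framework; checkpointing now happens here.
        loss, metrics = self.validate()
        self._maybe_checkpoint(loss, metrics)
        return loss, metrics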

@fatemetkl (Collaborator) left a comment:

LGTM! Checkpoint score functions and pre-aggregation vs. post-aggregation checkpointing are nice additions towards flexibility.

Base automatically changed from dbe/implement_fed_rep to main on May 8, 2024 at 17:54
emersodb merged commit 95a60c8 into main on May 8, 2024
6 checks passed
emersodb deleted the dbe/adding_checkpointer_post_train branch on May 8, 2024 at 19:29