Changes to Support Expanded Experimentation with FedDG-GA #252
Conversation
```python
from fl4health.utils.losses import EvaluationLosses, LossMeter, LossMeterType, TrainingLosses
from fl4health.utils.metrics import TEST_LOSS_KEY, TEST_NUM_EXAMPLES_KEY, Metric, MetricManager
from fl4health.utils.random import generate_hash
from fl4health.utils.typing import LogLevel, TorchFeatureType, TorchInputType, TorchPredType, TorchTargetType


class LoggingMode(Enum):
```
Moved this into its own file, `utils/logging.py`. In future work, we probably want to create a logging module to abstract some of the logging components that currently reside in the `BasicClient` anyway, so this is a very small step in that direction.
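For reference, a minimal sketch of what the relocated enum might look like in `utils/logging.py` (the diff only shows the class declaration, so the member names here are assumptions):

```python
from enum import Enum


class LoggingMode(Enum):
    # Hypothetical members; the actual enum in the PR may differ.
    TRAIN = "Training"
    VALIDATION = "Validation"
    TEST = "Testing"
```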
fl4health/clients/basic_client.py (Outdated)

```diff
@@ -247,8 +240,13 @@ def process_config(self, config: Config) -> Tuple[Union[int, None], Union[int, N
         except ValueError:
             evaluate_after_fit = False

+        try:
```
This will control whether the contents of the loss dictionary should be packed into the metrics for communication with the server. Packing the losses makes the full set of additional losses available to the server, which is essential for FedDG-GA to work with an expanded set of FL techniques. However, it also means that the server is aware of these losses, which is perhaps advantageous in other settings as well.
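A sketch of how such a flag might be parsed out of the Flower config, mirroring the `evaluate_after_fit` handling visible in the hunk above (the helper name is hypothetical; the PR's actual parsing may differ):

```python
from flwr.common.typing import Config


def parse_pack_losses_flag(config: Config) -> bool:
    # Hypothetical helper: default to False when the flag is absent, so that
    # losses are only packed into the metrics when explicitly requested.
    try:
        return bool(config["pack_losses_with_val_metrics"])
    except KeyError:
        return False
```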
fl4health/clients/basic_client.py (Outdated)

```diff
@@ -326,21 +326,6 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict
             metrics,
         )

-    def evaluate_after_fit(self) -> Tuple[float, Dict[str, Scalar]]:
```
This is being torched in favor of just using validate with `pack_losses_with_val_metrics = True`. We were packing the loss here specifically to facilitate FedDG-GA.
It's possible that in the future we'll want to reinstate this function to allow it to be overridden in inheriting classes, but for now I think it can be dropped.
```diff
@@ -536,69 +527,6 @@ def get_client_specific_reports(self) -> Dict[str, Any]:
         """
         return {}

-    def _move_data_to_device(
```
This is a pretty generic function. Moved it to `utils/client.py` and made a slight typing improvement. Let me know if you disagree with the change.
Nice!
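For reference, the relocated helper might look roughly like the following (a sketch under the assumption that it handles both bare tensors and dictionaries of tensors; the actual implementation in `utils/client.py` may differ):

```python
from typing import Dict, Union

import torch


def move_data_to_device(
    data: Union[torch.Tensor, Dict[str, torch.Tensor]], device: torch.device
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
    # Bare tensors are moved directly; dictionaries of tensors are moved
    # entry by entry, preserving their keys.
    if isinstance(data, torch.Tensor):
        return data.to(device)
    return {key: value.to(device) for key, value in data.items()}
```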
fl4health/clients/basic_client.py (Outdated)

```diff
                 two"
             )

-    def is_empty_batch(self, input: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> bool:
```
This is a pretty generic function. Moved it to `utils/client.py`. Let me know if you disagree with the change.
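Based on the signature in the diff, a plausible sketch of the relocated helper (the emptiness check on the batch dimension is an assumption):

```python
from typing import Dict, Union

import torch


def is_empty_batch(input: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> bool:
    # A batch is treated as empty when its leading (batch) dimension is zero.
    if isinstance(input, torch.Tensor):
        return input.shape[0] == 0
    return all(tensor.shape[0] == 0 for tensor in input.values())
```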
```diff
@@ -818,9 +749,12 @@ def _validate_or_test(
         metrics = metric_manager.compute()
         self._log_results(loss_dict, metrics, logging_mode=logging_mode)

+        if include_losses_in_metrics:
```
This is where we inject the `loss_dict` results into the metrics during validation.
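Schematically, the injection amounts to folding the loss dictionary into the metrics mapping (a sketch; the function name is hypothetical and the real code operates on the local variables shown in the hunk):

```python
from typing import Dict


def fold_losses_into_metrics(loss_dict: Dict[str, float], metrics: Dict[str, float]) -> None:
    # Copy each validation loss into the metrics so it is reported to the
    # server alongside the ordinary evaluation metrics.
    for loss_name, loss_value in loss_dict.items():
        metrics[loss_name] = loss_value
```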
```diff
@@ -1034,24 +977,6 @@ def set_optimizer(self, config: Config) -> None:
         assert not isinstance(optimizer, dict)
         self.optimizers = {"global": optimizer}

-    def clone_and_freeze_model(self, model: nn.Module) -> nn.Module:
```
This is a pretty generic function. Moved it to `utils/client.py`. Let me know if you disagree with the change.
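A sketch of the usual clone-and-freeze pattern, assuming the helper deep copies the model and disables gradients (the actual body in `utils/client.py` may differ):

```python
import copy

import torch.nn as nn


def clone_and_freeze_model(model: nn.Module) -> nn.Module:
    # Deep copy the model, then detach the copy from training by disabling
    # gradients and switching it to eval mode.
    cloned_model = copy.deepcopy(model)
    for param in cloned_model.parameters():
        param.requires_grad = False
    cloned_model.eval()
    return cloned_model
```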
```diff
@@ -1241,35 +1166,6 @@ def update_before_epoch(self, epoch: int) -> None:
         """
         pass

-    def maybe_progress_bar(self, iterable: Iterable) -> Iterable:
```
This is a pretty generic function. Moved it to `utils/client.py`. Let me know if you disagree with the change.
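A sketch of the relocated helper, assuming it wraps the iterable in a `tqdm` progress bar when enabled (the explicit `display_progress_bar` flag stands in for whatever client attribute actually controls this):

```python
from typing import Iterable

from tqdm import tqdm


def maybe_progress_bar(iterable: Iterable, display_progress_bar: bool) -> Iterable:
    # Only wrap the iterable in a progress bar when requested; otherwise
    # return it untouched so logging stays quiet.
    return tqdm(iterable) if display_progress_bar else iterable
```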
```diff
@@ -9,7 +9,7 @@
 from fl4health.checkpointing.client_module import ClientCheckpointModule
 from fl4health.clients.ditto_client import DittoClient
 from fl4health.model_bases.fenda_base import FendaModel
-from fl4health.model_bases.sequential_split_models import SequentiallySplitExchangeBaseModel
+from fl4health.model_bases.sequential_split_models import SequentiallySplitModel
```
While it didn't actually change the flow of the example, since it uses a `FullParameterExchanger`, this type is misleading, given that it's meant to facilitate only exchanging the feature extraction module.
```diff
@@ -23,7 +30,7 @@ class FairnessMetricType(Enum):
     """Defines the basic types for fairness metrics, their default names and their default signals"""

     ACCURACY = "val - prediction - accuracy"
     LOSS = "val - loss"
```
Changing this to `checkpoint`, as this is the name of the vanilla loss for basic clients during validation.
```python
# Setting self.num_rounds once and doing some sanity checks
assert self.on_fit_config_fn is not None, "on_fit_config_fn must be specified"
config = self.on_fit_config_fn(server_round)
assert "evaluate_after_fit" in config, "evaluate_after_fit must be present in config"
```
FedDG-GA requires `evaluate_after_fit` to be present and true. Similarly, it requires `pack_losses_with_val_metrics` to be present and true to allow the server to see the measured validation loss. So when using the strategy, we check for them and throw if they are not present.
```python
assert self.on_evaluate_config_fn is not None, "on_evaluate_config_fn must be specified"
config = self.on_evaluate_config_fn(server_round)
assert "pack_losses_with_val_metrics" in config, "pack_losses_with_val_metrics must be present in config"
```
During validation, we need the server to have access to the losses. So `pack_losses_with_val_metrics` must be present and true for FedDG-GA to work properly.
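A hedged example of config functions that would satisfy both assertions (the key names come from the snippets above; `current_server_round` and the function names are illustrative):

```python
from flwr.common.typing import Config


def fit_config(server_round: int) -> Config:
    # FedDG-GA needs clients to evaluate after fitting so a validation loss is
    # reported back to the server at the end of each fit round.
    return {"current_server_round": server_round, "evaluate_after_fit": True}


def evaluate_config(server_round: int) -> Config:
    # Packing the losses with the validation metrics lets the server see the
    # measured validation loss, which FedDG-GA requires.
    return {"current_server_round": server_round, "pack_losses_with_val_metrics": True}
```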
```python
if self.fit_metrics_aggregation_fn:
    fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
    metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
elif server_round == 1:  # Only log this warning once
    log(WARNING, "No fit_metrics_aggregation_fn provided")
```
Rather than relying on the original aggregation approach and discarding the weights aggregation, which is a bit of a wasted set of calculations, just do the metrics aggregation here, since that is all we want anyway.
```python
self.evaluation_metrics[cid] = eval_res.metrics
# adding the loss to the metrics
val_loss_key = FairnessMetricType.LOSS.value
self.evaluation_metrics[cid][val_loss_key] = eval_res.loss
```
No need to force the loss into the dictionary here. It should be packed with the metrics. If it's not, something else has gone wrong.
```diff
@@ -0,0 +1,268 @@
+from logging import INFO, WARNING
```
This class is needed to facilitate FedDG-GA for clients that also require adaptive constraint considerations, such as Ditto, MR-MTL, and FedProx. This strategy needs to coordinate the exchange of the loss weights in addition to the Generalization Adjustments for aggregation.
```python
updated_weights, train_loss = self.parameter_packer.unpack_parameters(
    parameters_to_ndarrays(fit_res.parameters)
)
# Modify the parameters in-place to just be the model weights.
```
Since the training losses are packed with the weights, we need to extract the losses and jam the weights back into the FitRes parameters object so that FedDG-GA can happen unabated.
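Schematically, the extract-then-repack step might look like this (a sketch inside the strategy's aggregation; `fit_res` is a Flower `FitRes` and `self.parameter_packer` follows the snippet quoted above):

```python
from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays

# Unpack the model weights and the packed training loss from the client's FitRes.
updated_weights, train_loss = self.parameter_packer.unpack_parameters(
    parameters_to_ndarrays(fit_res.parameters)
)
# Overwrite the FitRes parameters with just the model weights so the standard
# FedDG-GA aggregation can proceed as if no losses were ever packed.
fit_res.parameters = ndarrays_to_parameters(updated_weights)
```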
Deleted this, as it seemed to be a carbon copy of the BasicExample README. Perhaps that's not true; just let me know.
Everything looks pretty much good to go to me! I will wait until you have a chance to take a look at the minor comments I made and then take one last pass tomorrow.
Changes look good to me!
PR Type
Feature/Experimentation.
Short Description
The changes in this PR are targeted at enabling expanded experimentation with the FedDG-GA strategy for a wider set of FL approaches. Also included is a name change, `FedDgGaStrategy` -> `FedDgGa`, to fit other strategy formats a bit better, and a renaming of the associated file from `feddg_ga_strategy.py` to `feddg_ga.py`, as it is already under the strategies folder.

For FENDA+Ditto, there is also a bug fix moving from `SequentiallySplitExchangeBaseModel` to `SequentiallySplitModel`. In this setting we want to exchange the whole Ditto model, not just the feature extractor component. It wasn't causing a real bug, as a `FullParameterExchanger` was being used anyway, but the typing was dissonant.

Finally, I moved some "client agnostic" functionality out of basic client and into a utils file to help trim a few functions from the `BasicClient` class.

Tests Added
Added a test for the new FedDG-GA strategy that is compatible with adaptive constraint server-client pairs.