
Changes to Support Expanded Experimentation with FedDG-GA #252

Merged
merged 9 commits into main from dbe/support_changes_for_feddgga
Nov 4, 2024

Conversation

emersodb
Collaborator

PR Type

Feature/Experimentation.

Short Description

The changes in this PR are targeted at enabling expanded experimentation with the FedDG-GA strategy for a wider set of FL approaches. Also included is a rename, FedDgGaStrategy -> FedDgGa, to better match the naming of our other strategies, along with renaming the associated file from feddg_ga_strategy.py to feddg_ga.py, since it already lives under the strategies folder.

For FENDA+Ditto, there is also a fix moving from SequentiallySplitExchangeBaseModel to SequentiallySplitModel. In this setting we want to exchange the whole Ditto model, not just the feature extractor component. It wasn't causing a functional bug, since a FullParameterExchanger was being used anyway, but the typing was misleading.

Finally, I moved some "client agnostic" functionality out of basic client and into a utils file to help trim a few functions from the BasicClient class.

Tests Added

Added a test for the new FedDG-GA strategy that is compatible with adaptive constraint server-client pairs.

from fl4health.utils.losses import EvaluationLosses, LossMeter, LossMeterType, TrainingLosses
from fl4health.utils.metrics import TEST_LOSS_KEY, TEST_NUM_EXAMPLES_KEY, Metric, MetricManager
from fl4health.utils.random import generate_hash
from fl4health.utils.typing import LogLevel, TorchFeatureType, TorchInputType, TorchPredType, TorchTargetType


class LoggingMode(Enum):
Collaborator Author

Moved this into its own file, utils/logging.py. In future work, we'll probably want to create a logging module to abstract some of the logging components that currently reside in BasicClient anyway. So this is a small step in that direction.

@@ -247,8 +240,13 @@ def process_config(self, config: Config) -> Tuple[Union[int, None], Union[int, N
except ValueError:
evaluate_after_fit = False

try:
Collaborator Author

This will control whether the contents of the loss dictionary should be packed into the metrics for communication with the server or not. Packing the losses makes the full set of additional losses available to the server, which is essential for FedDG-GA to work for an expanded set of FL techniques. However, it also means the server is made aware of these losses, which may be advantageous in other settings as well.
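A minimal sketch of what this packing amounts to. The function name, key prefix, and values below are illustrative assumptions for the sketch, not the actual fl4health API:

```python
from typing import Dict


def pack_losses_with_metrics(metrics: Dict[str, float], loss_dict: Dict[str, float]) -> Dict[str, float]:
    # Copy so the caller's metrics dictionary is left untouched.
    packed = dict(metrics)
    for loss_name, loss_value in loss_dict.items():
        # Prefix loss entries so they are identifiable server-side.
        packed[f"val - {loss_name}"] = loss_value
    return packed


# Hypothetical client-side values for illustration.
client_metrics = {"val - prediction - accuracy": 0.91}
client_losses = {"checkpoint": 0.42, "global": 0.40}
packed = pack_losses_with_metrics(client_metrics, client_losses)
```

The server then sees the losses alongside ordinary metrics in the same dictionary, which is what lets a strategy like FedDG-GA read them without any side channel.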

@@ -326,21 +326,6 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict
metrics,
)

def evaluate_after_fit(self) -> Tuple[float, Dict[str, Scalar]]:
Collaborator Author

This is being removed in favor of just using validate with pack_losses_with_val_metrics = True. We were packing the loss here specifically to facilitate FedDG-GA.

Collaborator Author

It's possible that in the future we'll want to reinstate this function so it can be overridden in child classes, but for now I think it can be dropped.

@@ -536,69 +527,6 @@ def get_client_specific_reports(self) -> Dict[str, Any]:
"""
return {}

def _move_data_to_device(
Collaborator Author

@emersodb emersodb Oct 10, 2024

This is a pretty generic function. Moved it to utils/client.py and made a slight typing improvement. Let me know if you disagree with the change.
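For context, a framework-agnostic sketch of what such a helper does: anything with a `.to(device)` method (e.g. a torch.Tensor) or a dict of such objects gets moved. The `FakeTensor` class is a stand-in so the sketch runs without torch; the real utility is typed against torch types:

```python
from typing import Any, Dict, Union


def move_data_to_device(data: Union[Any, Dict[str, Any]], device: str) -> Union[Any, Dict[str, Any]]:
    # Dispatch on dicts of tensors vs. a single tensor, mirroring the
    # union input types the client code handles.
    if isinstance(data, dict):
        return {key: value.to(device) for key, value in data.items()}
    return data.to(device)


class FakeTensor:
    """Tiny stand-in for torch.Tensor so this sketch runs without torch."""

    def __init__(self) -> None:
        self.device = "cpu"

    def to(self, device: str) -> "FakeTensor":
        self.device = device
        return self


batch = move_data_to_device({"image": FakeTensor(), "mask": FakeTensor()}, "cuda:0")
single = move_data_to_device(FakeTensor(), "cuda:1")
```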

Collaborator

Nice!


def is_empty_batch(self, input: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> bool:
Collaborator Author

This is a pretty generic function. Moved it to utils/client.py. Let me know if you disagree with the change.
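A sketch of the check's logic, under the assumption that a batch is "empty" when a tensor-like input (or any entry of a dict batch) has length zero. Plain lists stand in for tensors here so the sketch runs without torch:

```python
from typing import Dict, Sized, Union


def is_empty_batch(input: Union[Sized, Dict[str, Sized]]) -> bool:
    # For dictionary batches, treat the batch as empty if any entry has
    # length zero; otherwise check the single tensor-like input directly.
    if isinstance(input, dict):
        return any(len(tensor) == 0 for tensor in input.values())
    return len(input) == 0
```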

@@ -818,9 +749,12 @@ def _validate_or_test(
metrics = metric_manager.compute()
self._log_results(loss_dict, metrics, logging_mode=logging_mode)

if include_losses_in_metrics:
Collaborator Author

This is where we inject the loss_dict results into the metrics during validation.

@@ -1034,24 +977,6 @@ def set_optimizer(self, config: Config) -> None:
assert not isinstance(optimizer, dict)
self.optimizers = {"global": optimizer}

def clone_and_freeze_model(self, model: nn.Module) -> nn.Module:
Collaborator Author

This is a pretty generic function. Moved it to utils/client.py. Let me know if you disagree with the change.
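Conceptually, clone-and-freeze deep-copies the model, disables gradients on every parameter, and puts the copy in eval mode. The classes below are stand-ins for torch.nn.Module so this sketch runs without torch; the real helper operates on actual modules:

```python
import copy


class FakeParam:
    """Stand-in for a torch parameter."""

    def __init__(self) -> None:
        self.requires_grad = True


class FakeModel:
    """Stand-in for torch.nn.Module."""

    def __init__(self) -> None:
        self._params = [FakeParam(), FakeParam()]
        self.training = True

    def parameters(self):
        return self._params

    def eval(self) -> None:
        self.training = False


def clone_and_freeze_model(model: FakeModel) -> FakeModel:
    # Deep-copy so the original model keeps training normally.
    cloned = copy.deepcopy(model)
    for parameter in cloned.parameters():
        parameter.requires_grad = False
    cloned.eval()
    return cloned


model = FakeModel()
frozen = clone_and_freeze_model(model)
```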

@@ -1241,35 +1166,6 @@ def update_before_epoch(self, epoch: int) -> None:
"""
pass

def maybe_progress_bar(self, iterable: Iterable) -> Iterable:
Collaborator Author

This is a pretty generic function. Moved it to utils/client.py. Let me know if you disagree with the change.
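A sketch of the idea: wrap the iterable in tqdm when progress display is enabled, otherwise return it untouched. The enable flag is passed as a plain argument here for illustration; the actual helper reads client state:

```python
from typing import Iterable


def maybe_progress_bar(iterable: Iterable, display_progress_bar: bool) -> Iterable:
    if not display_progress_bar:
        return iterable
    try:
        # tqdm is an optional dependency in this sketch; fall back to the
        # bare iterable if it isn't installed.
        from tqdm import tqdm

        return tqdm(iterable)
    except ImportError:
        return iterable
```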

@@ -9,7 +9,7 @@
from fl4health.checkpointing.client_module import ClientCheckpointModule
from fl4health.clients.ditto_client import DittoClient
from fl4health.model_bases.fenda_base import FendaModel
from fl4health.model_bases.sequential_split_models import SequentiallySplitExchangeBaseModel
from fl4health.model_bases.sequential_split_models import SequentiallySplitModel
Collaborator Author

While this doesn't actually change the flow of the example, since it uses a FullParameterExchanger, the previous type was misleading, given that it's meant to facilitate exchanging only the feature extraction module.

@@ -23,7 +30,7 @@ class FairnessMetricType(Enum):
"""Defines the basic types for fairness metrics, their default names and their default signals"""

ACCURACY = "val - prediction - accuracy"
LOSS = "val - loss"
Collaborator Author

Changing this to checkpoint, as that is the name of the vanilla loss for basic clients during validation.

# Setting self.num_rounds once and doing some sanity checks
assert self.on_fit_config_fn is not None, "on_fit_config_fn must be specified"
config = self.on_fit_config_fn(server_round)
assert "evaluate_after_fit" in config, "evaluate_after_fit must be present in config"
Collaborator Author

FedDG-GA requires evaluate_after_fit to be present and true. Similarly, it requires pack_losses_with_val_metrics to be present and true to allow the server to see the measured validation loss. So when using the strategy we check for them and throw if not present.


assert self.on_evaluate_config_fn is not None, "on_evaluate_config_fn must be specified"
config = self.on_evaluate_config_fn(server_round)
assert "pack_losses_with_val_metrics" in config, "pack_losses_with_val_metrics must be present in config"
Collaborator Author

During validation, we need the server to have access to the losses. So pack_losses_with_val_metrics must be present and true for FedDG-GA to work properly.

fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
elif server_round == 1: # Only log this warning once
log(WARNING, "No fit_metrics_aggregation_fn provided")
Collaborator Author

Rather than relying on the original aggregation approach and discarding the weight aggregation, which wastes computation, we just do the metrics aggregation here, since that's all we want anyway.
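A minimal sketch of example-count-weighted metrics aggregation, similar in spirit to what a fit_metrics_aggregation_fn might do. This exact function is an illustration, not the fl4health implementation:

```python
from typing import Dict, List, Tuple


def weighted_average_metrics(fit_metrics: List[Tuple[int, Dict[str, float]]]) -> Dict[str, float]:
    # Each entry pairs a client's example count with its metrics dictionary.
    total_examples = sum(num_examples for num_examples, _ in fit_metrics)
    aggregated: Dict[str, float] = {}
    for num_examples, metrics in fit_metrics:
        for name, value in metrics.items():
            # Accumulate each metric weighted by the client's share of examples.
            aggregated[name] = aggregated.get(name, 0.0) + value * num_examples / total_examples
    return aggregated


# Hypothetical results from two clients: (num_examples, metrics).
fit_metrics = [(10, {"accuracy": 0.8}), (30, {"accuracy": 0.9})]
aggregated = weighted_average_metrics(fit_metrics)
```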

self.evaluation_metrics[cid] = eval_res.metrics
# adding the loss to the metrics
val_loss_key = FairnessMetricType.LOSS.value
self.evaluation_metrics[cid][val_loss_key] = eval_res.loss
Collaborator Author

No need to force the loss into the dictionary here. It should be packed with the metrics. If it's not, something else has gone wrong.

@@ -0,0 +1,268 @@
from logging import INFO, WARNING
Collaborator Author

This class is needed in order to facilitate FedDG-GA for clients that also require adaptive constraint considerations, such as Ditto, MR-MTL, and FedProx clients. This strategy needs to coordinate the exchange of the loss weights in addition to the Generalization Adjustments for aggregation.

updated_weights, train_loss = self.parameter_packer.unpack_parameters(
parameters_to_ndarrays(fit_res.parameters)
)
# Modify the parameters in-place to just be the model weights.
Collaborator Author

Since the training losses are packed with the weights, we need to extract the losses and repack the bare weights into the FitRes parameters object so that FedDG-GA can proceed unabated.
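A sketch of the unpack step, assuming (for illustration only) a payload layout where the final array in the packed list holds the scalar training loss and everything before it is model weights:

```python
from typing import List, Tuple


def unpack_weights_and_loss(packed: List[List[float]]) -> Tuple[List[List[float]], float]:
    # By the assumed convention, the last array carries the packed train loss;
    # the rest are the model's weight arrays.
    *weights, loss_array = packed
    return weights, loss_array[0]


# Two hypothetical weight arrays plus the packed loss.
packed = [[0.1, 0.2], [0.3], [1.25]]
weights, train_loss = unpack_weights_and_loss(packed)
```

After unpacking, the bare `weights` would be written back into the FitRes parameters so the downstream aggregation only ever sees model parameters.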

Collaborator Author

Deleted this, as it seemed to be a carbon copy of the BasicExample readme? Perhaps that's not true. Just let me know.

@emersodb emersodb marked this pull request as ready for review October 10, 2024 15:35
@emersodb emersodb requested a review from lotif October 10, 2024 15:52
Collaborator

@jewelltaylor jewelltaylor left a comment

Everything looks pretty much good to go for me! I will wait until you have a chance to take a look at the minor comments I made and then take one last pass tomorrow.

@emersodb emersodb requested a review from jewelltaylor November 1, 2024 16:00
Collaborator

@jewelltaylor jewelltaylor left a comment

Changes look good to me!

@emersodb emersodb merged commit 3e237cf into main Nov 4, 2024
6 checks passed
@emersodb emersodb deleted the dbe/support_changes_for_feddgga branch November 4, 2024 13:40