diff --git a/fl4health/clients/basic_client.py b/fl4health/clients/basic_client.py index f66f4ec97..9908b4597 100644 --- a/fl4health/clients/basic_client.py +++ b/fl4health/clients/basic_client.py @@ -114,6 +114,7 @@ def __init__( self.initial_weights: Optional[NDArrays] = None self.total_steps: int = 0 # Need to track total_steps across rounds for WANDB reporting + self.total_epochs: int = 0 # Will remain as 0 if training by steps # Attributes to be initialized in setup_client self.parameter_exchanger: ParameterExchanger @@ -221,9 +222,9 @@ def process_config(self, config: Config) -> Tuple[Union[int, None], Union[int, N config (Config): The config from the server. Returns: - Tuple[Union[int, None], Union[int, None], int, bool]: Returns the local_epochs, local_steps, - current_server_round and evaluate_after_fit. Ensures only one of local_epochs and local_steps - is defined in the config and sets the one that is not to None. + Tuple[Union[int, None], Union[int, None], int, bool, bool]: Returns the local_epochs, local_steps, + current_server_round, evaluate_after_fit and pack_losses_with_val_metrics. Ensures only one of + local_epochs and local_steps is defined in the config and sets the one that is not to None. Raises: ValueError: If the config contains both local_steps and local epochs or if local_steps, local_epochs or @@ -307,15 +308,24 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, dict # We perform a pre-aggregation checkpoint if applicable self._maybe_checkpoint(validation_loss, validation_metrics, CheckpointMode.PRE_AGGREGATION) + # Notes on report values: + # - Train by steps: round metrics/losses are computed using all samples from the round + # - Train by epochs: round metrics/losses computed using only the samples from the final epoch of the round + # - fit_round_metrics: Computed at the end of the round on the samples directly + # - fit_round_losses: The average of the losses computed for each step. 
+ # * (Hence likely higher than the final loss of the round.) self.reports_manager.report( { - "fit_metrics": metrics, - "fit_losses": loss_dict, + "fit_round_metrics": metrics, + "fit_round_losses": loss_dict, "round": current_server_round, "round_start": str(round_start_time), "round_end": str(datetime.datetime.now()), - "fit_start": str(fit_start_time), - "fit_end": str(fit_end_time), + "fit_round_start": str(fit_start_time), + "fit_round_time_elapsed": str(fit_end_time - fit_start_time), + "fit_round_end": str(fit_end_time), + "fit_step": self.total_steps, + "fit_epoch": self.total_epochs, }, current_server_round, ) @@ -364,11 +374,14 @@ def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, Di self.reports_manager.report( { - "eval_metrics": metrics, - "eval_loss": loss, - "eval_start": str(start_time), - "eval_time_elapsed": str(elapsed), - "eval_end": str(end_time), + "eval_round_metrics": metrics, + "eval_round_loss": loss, + "eval_round_start": str(start_time), + "eval_round_time_elapsed": str(elapsed), + "eval_round_end": str(end_time), + "fit_step": self.total_steps, + "fit_epoch": self.total_epochs, + "round": current_server_round, }, current_server_round, ) @@ -607,7 +620,7 @@ def train_by_epochs( # update before epoch hook self.update_before_epoch(epoch=local_epoch) # Update report data dict - report_data.update({"fit_epoch": local_epoch}) + report_data.update({"fit_epoch": self.total_epochs}) for input, target in maybe_progress_bar(self.train_loader, self.progress_bar): self.update_before_step(steps_this_round, current_round) # Assume first dimension is batch size. 
Sampling iterators (such as Poisson batch sampling), can @@ -623,20 +636,22 @@ def train_by_epochs( self.update_metric_manager(preds, target, self.train_metric_manager) self.update_after_step(steps_this_round, current_round) self.update_lr_schedulers(epoch=local_epoch) - report_data.update({"fit_losses": losses.as_dict(), "fit_step": self.total_steps}) + report_data.update({"fit_step_losses": losses.as_dict(), "fit_step": self.total_steps}) report_data.update(self.get_client_specific_reports()) - self.reports_manager.report(report_data, current_round, local_epoch, self.total_steps) + self.reports_manager.report(report_data, current_round, self.total_epochs, self.total_steps) self.total_steps += 1 steps_this_round += 1 + # Log and report results metrics = self.train_metric_manager.compute() loss_dict = self.train_loss_meter.compute().as_dict() - - # Log and report results - self._log_results(loss_dict, metrics, current_round, local_epoch) - report_data.update({"fit_metrics": metrics}) + report_data.update({"fit_epoch_metrics": metrics, "fit_epoch_losses": loss_dict}) report_data.update(self.get_client_specific_reports()) - self.reports_manager.report(report_data, current_round, local_epoch) + self.reports_manager.report(report_data, current_round, self.total_epochs) + self._log_results(loss_dict, metrics, current_round, local_epoch) + + # Update internal epoch counter + self.total_epochs += 1 # Return final training metrics return loss_dict, metrics @@ -690,7 +705,7 @@ def train_by_steps( self.update_metric_manager(preds, target, self.train_metric_manager) self.update_after_step(step, current_round) self.update_lr_schedulers(step=step) - report_data.update({"fit_losses": losses.as_dict(), "fit_step": self.total_steps}) + report_data.update({"fit_step_losses": losses.as_dict(), "fit_step": self.total_steps}) report_data.update(self.get_client_specific_reports()) self.reports_manager.report(report_data, current_round, None, self.total_steps) self.total_steps += 1 diff 
--git a/fl4health/clients/nnunet_client.py b/fl4health/clients/nnunet_client.py index f393e8174..734f6c913 100644 --- a/fl4health/clients/nnunet_client.py +++ b/fl4health/clients/nnunet_client.py @@ -52,6 +52,7 @@ from nnunetv2.experiment_planning.plan_and_preprocess_api import extract_fingerprints, preprocess_dataset from nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw from nnunetv2.training.dataloading.utils import unpack_dataset + from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer from nnunetv2.utilities.dataset_name_id_conversion import convert_id_to_dataset_name @@ -246,9 +247,15 @@ def get_model(self, config: Config) -> nn.Module: return self.nnunet_trainer.network def get_criterion(self, config: Config) -> _Loss: + if isinstance(self.nnunet_trainer.loss, DeepSupervisionWrapper): + self.reports_manager.report({"Criterion": self.nnunet_trainer.loss.loss.__class__.__name__}) + else: + self.reports_manager.report({"Criterion": self.nnunet_trainer.loss.__class__.__name__}) + return Module2LossWrapper(self.nnunet_trainer.loss) def get_optimizer(self, config: Config) -> Optimizer: + self.reports_manager.report({"Optimizer": self.nnunet_trainer.optimizer.__class__.__name__}) return self.nnunet_trainer.optimizer def get_lr_scheduler(self, optimizer_key: str, config: Config) -> _LRScheduler: @@ -289,6 +296,7 @@ def get_lr_scheduler(self, optimizer_key: str, config: Config) -> _LRScheduler: # Create and return LR Scheduler Wrapper for the PolyLRScheduler so that it is # compatible with Torch LRScheduler # Create and return LR Scheduler. 
This is nnunet default for version 2.5.1 + self.reports_manager.report({"LR Scheduler": "PolyLRScheduler"}) return PolyLRSchedulerWrapper( self.optimizers[optimizer_key], initial_lr=self.nnunet_trainer.initial_lr, @@ -685,6 +693,9 @@ def get_client_specific_logs( else: return "", [] + def get_client_specific_reports(self) -> Dict[str, Any]: + return {"learning_rate": float(self.optimizers["global"].param_groups[0]["lr"])} + @use_default_signal_handlers # Experiment planner spawns a process I think def get_properties(self, config: Config) -> Dict[str, Scalar]: """ diff --git a/fl4health/reporting/base_reporter.py b/fl4health/reporting/base_reporter.py index 74438973b..040021130 100644 --- a/fl4health/reporting/base_reporter.py +++ b/fl4health/reporting/base_reporter.py @@ -24,10 +24,12 @@ def report( data (dict): The data to maybe report from the server or client. round (int | None, optional): The current FL round. If None, this indicates that the method was called outside of a round (e.g. for summary information). Defaults to None. - epoch (int | None, optional): The current epoch. If None then this method was not called at or within the - scope of an epoch. Defaults to None. - step (int | None, optional): The current step (total). If None then this method was called outside the - scope of a training or evaluation step (eg. at the end of an epoch or round) Defaults to None. + epoch (int | None, optional): The current epoch (In total across all rounds). If None then this method was + not called at or within the scope of an epoch. Should always be None if training by steps. Defaults to + None. + step (int | None, optional): The current step (In total across all rounds and epochs). If None then this + method was called outside the scope of a training or evaluation step (eg. at the end of an epoch or + round) Defaults to None. 
""" raise NotImplementedError diff --git a/fl4health/reporting/wandb_reporter.py b/fl4health/reporting/wandb_reporter.py index 7a67170d2..edc6f231c 100644 --- a/fl4health/reporting/wandb_reporter.py +++ b/fl4health/reporting/wandb_reporter.py @@ -1,9 +1,11 @@ from enum import Enum +from logging import WARNING from pathlib import Path from typing import Any import wandb import wandb.wandb_run +from flwr.common.logger import log from fl4health.reporting.base_reporter import BaseReporter @@ -28,14 +30,14 @@ def __init__( tags: list[str] | None = None, name: str | None = None, id: str | None = None, - **kwargs: Any + **kwargs: Any, ) -> None: """ _summary_ Args: - wandb_step_type (StepType | str, optional): How frequently to log data. Either every 'round', 'epoch' or - 'step'. Defaults to StepType.ROUND. + wandb_step_type (StepType | str, optional): Whether to use the 'round', 'epoch' or 'step' as the + wandb_step value when logging information to the wandb server. project (str | None, optional): The name of the project where you're sending the new run. If unspecified, wandb will try to infer or set to "uncategorized" entity (str | None, optional): An entity is a username or team name where you're sending runs. This entity @@ -76,22 +78,69 @@ def __init__( self.name = name self.id = id + # Keep track of epoch and step. Initialize as 0. + self.current_epoch = 0 + self.current_step = 0 + # Initialize run later to avoid creating runs while debugging self.run: wandb.wandb_run.Run def initialize(self, **kwargs: Any) -> None: """Checks if an id was provided by the client or server. - If an id was passed to the WandBReporter on init then it takes priority over the - one passed by the client/server. + If an id was passed to the WandBReporter on init then it takes priority over the one passed by the + client/server. 
""" if self.id is None: self.id = kwargs.get("id") self.initialized = True + def define_metrics(self) -> None: + """This method defines some of the metrics we expect to see from Basic Client and server. + + Note that you do not have to define metrics, but it can be useful for determining what should and shouldn't go + into the run summary. + """ + # Note that the hidden argument is not working. Raised issue here: https://github.com/wandb/wandb/issues/8890 + # Round, epoch and step + self.run.define_metric("fit_step", summary="none", hidden=True) # Current fit step + self.run.define_metric("fit_epoch", summary="none", hidden=True) # Current fit epoch + self.run.define_metric("round", summary="none", hidden=True) # Current server round + self.run.define_metric("round_start", summary="none", hidden=True) + self.run.define_metric("round_end", summary="none", hidden=True) + # A server round contains a fit_round and maybe also an evaluate round + self.run.define_metric("fit_round_start", summary="none", hidden=True) + self.run.define_metric("fit_round_time_elapsed", summary="none", hidden=True) + self.run.define_metric("fit_round_end", summary="none", hidden=True) + self.run.define_metric("eval_round_start", summary="none", hidden=True) + self.run.define_metric("eval_round_time_elapsed", summary="none", hidden=True) + self.run.define_metric("eval_round_end", summary="none", hidden=True) + # The metrics computed on all the samples from the final epoch, or the entire round if training by steps + self.run.define_metric("fit_round_metrics", step_metric="round", summary="best") + self.run.define_metric("eval_round_metrics", step_metric="round", summary="best") + # Average of the losses for each step in the final epoch, or the entire round if training by steps. 
+ self.run.define_metric("fit_round_losses", step_metric="round", summary="best", goal="minimize") + self.run.define_metric("eval_round_loss", step_metric="round", summary="best", goal="minimize") + # The metrics computed at the end of the epoch on all the samples from the epoch + self.run.define_metric("fit_epoch_metrics", step_metric="fit_epoch", summary="best") + # Average of the losses for each step in the epoch + self.run.define_metric("fit_epoch_losses", step_metric="fit_epoch", summary="best", goal="minimize") + # The loss and metrics for each individual step + self.run.define_metric("fit_step_metrics", step_metric="fit_step", summary="best") + self.run.define_metric("fit_step_losses", step_metric="fit_step", summary="best", goal="minimize") + # FlServer (Base Server) specific metrics + self.run.define_metric("val - loss - aggregated", step_metric="round", summary="best", goal="minimize") + self.run.define_metric("eval_round_metrics_aggregated", step_metric="round", summary="best") + # The following metrics don't work with wandb since they are currently obtained after training instead of live + self.run.define_metric("val - loss - centralized", step_metric="round", summary="best", goal="minimize") + self.run.define_metric("eval_round_metrics_centralized", step_metric="round", summary="best") + def start_run(self, wandb_init_kwargs: dict[str, Any]) -> None: """Initializes the wandb run. + We avoid doing this in the self.init function so that when debugging, jobs that fail before training starts do + not get uploaded to wandb. + Args: wandb_init_kwargs (dict[str, Any]): Keyword arguments for wandb.init() excluding the ones explicitly accessible through WandBReporter.init().
@@ -108,71 +157,76 @@ def start_run(self, wandb_init_kwargs: dict[str, Any]) -> None: tags=self.tags, name=self.name, id=self.id, - **wandb_init_kwargs # Other less commonly used kwargs + **wandb_init_kwargs, # Other less commonly used kwargs ) self.run_id = self.run._run_id # If run_id was None, we need to reset run id self.run_started = True - def get_wandb_timestep( - self, - round: int | None, - epoch: int | None, - step: int | None, - ) -> int | None: - """Determines the current step based on the timestep type. - - The report method is called every round epoch and step by default. Depending on the wandb_step_type we need to - determine whether or not to ignore the call to avoid reporting to frequently. E.g. if wandb_step_type is EPOCH - then we should not report data that is sent every step, but we should report data that is sent once an epoch or - once a round. We can do this by ignoring calls to report where step is not None. - - Args: - round (int | None): The current round or None if called outside of a round. - epoch (int | None): The current epoch or None if called outside of a epoch. - step (int | None): The current step (total) or None if called outside of - step. - - Returns: - int | None: Returns None if the reporter should not report metrics on this - call. If an integer is returned then it is what the reporter should use - as the current wandb step based on its wandb_step_type. 
- """ - if self.wandb_step_type == StepType.ROUND and epoch is None and step is None: - return round # If epoch or step are integers, we should ignore report whend wandb_step_type is ROUND - elif self.wandb_step_type == StepType.EPOCH and step is None: - return epoch # If step is an integer, we should ignore report when wandb_step_type is EPOCH or ROUND - elif self.wandb_step_type == StepType.STEP: - return step # Since step is the finest granularity step type, we always report for wandb_step_type STEP - - # Return None otherwise - return None + # Wandb metric definitions + self.define_metrics() def report( self, - data: dict, + data: dict[str, Any], round: int | None = None, epoch: int | None = None, - batch: int | None = None, + step: int | None = None, ) -> None: - # If round is None, assume data is summary information. Always report this. - if round is None: - if not self.run_started: - self.start_run(self.wandb_init_kwargs) - self.run.summary.update(data) + """Reports wandb compatible data to the wandb server. - # Get wandb step based on timestep_type - wandb_step = self.get_wandb_timestep(round, epoch, batch) + Data passed to self.report is always reported. If round is None, the data is reported as config information. + If round is specified, the data is logged to the wandb run at the current wandb step which is either the + current round, epoch or step depending on the wandb_step_type passed on initialization. The current epoch and + step are initialized at 0 and updated internally when specified as arguments to report. Therefore leaving epoch + or step as None will overwrite the data for the previous epoch/step if the key is the same, otherwise the new + key-value pairs are added. For example, if {"loss": value} is logged every epoch but wandb_step_type is + 'round', then the value for "loss" at round 1 will be it's value at the last epoch of that round. You can only + update or overwrite the current wandb step, previous steps can not be modified. 
- # If wandb_step is None, then we should not report on this call - if wandb_step is None: - return - - # Check if wandb run has been initialized + Args: + data (dict[str, Any]): Dictionary of wandb compatible data to log + round (int | None, optional): The current FL round. If None, this indicates that the method was called + outside of a round (e.g. for config information). Defaults to None. + epoch (int | None, optional): The current epoch (In total across all rounds). If None then this method was + not called at or within the scope of an epoch. Defaults to None. + step (int | None, optional): The current step (In total across all rounds and epochs). If None then this + method was called outside the scope of a training or evaluation step (e.g. at the end of an epoch or + round). Defaults to None. + """ + # Now that report has been called we are finally forced to start the run. if not self.run_started: self.start_run(self.wandb_init_kwargs) - # Log data - self.run.log(data, step=wandb_step) + # If round is None, assume data is config information.
+ if round is None: + self.run.config.update(data) + return + + # Update current epoch and step if they were specified + if epoch is not None: + if epoch < self.current_epoch: + log( + WARNING, + f"The specified current epoch ({epoch}) is less than a previous " + f"current epoch ({self.current_epoch})", + ) + self.current_epoch = epoch + + if step is not None: + if step < self.current_step: + log( + WARNING, + f"The specified current step ({step}) is less than a previous current step ({self.current_step})", + ) + self.current_step = step + + # Log based on step type + if self.wandb_step_type == StepType.ROUND: + self.run.log(data, step=round) + elif self.wandb_step_type == StepType.EPOCH: + self.run.log(data, step=self.current_epoch) + elif self.wandb_step_type == StepType.STEP: + self.run.log(data, step=self.current_step) def shutdown(self) -> None: self.run.finish() diff --git a/fl4health/server/base_server.py b/fl4health/server/base_server.py index ca15dbb85..56f9c7472 100644 --- a/fl4health/server/base_server.py +++ b/fl4health/server/base_server.py @@ -27,6 +27,7 @@ from fl4health.utils.typing import EvaluateFailures, FitFailures +# TODO: Have the server save the config as an attribute on init so that it has access to training hyperparams. class FlServer(Server): def __init__( self, @@ -81,7 +82,7 @@ def report_centralized_eval(self, history: History, num_rounds: int) -> None: round_metrics = {} for metric, vals in history.metrics_centralized.items(): round_metrics.update({metric: vals[round][1]}) - self.reports_manager.report({"eval_metrics_centralized": round_metrics}, round + 1) + self.reports_manager.report({"eval_round_metrics_centralized": round_metrics}, round + 1) def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float]: """ @@ -115,6 +116,7 @@ def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float } ) + # WARNING: This will not work with wandb. Wandb reporting must be done live.
self.report_centralized_eval(history, num_rounds) return history, elapsed_time @@ -134,7 +136,7 @@ def fit_round( ) if fit_round_results is not None: _, metrics, fit_results_and_failures = fit_round_results - self.reports_manager.report({"fit_metrics": metrics}, server_round) + self.reports_manager.report({"fit_round_metrics": metrics}, server_round) failures = fit_results_and_failures[1] if fit_results_and_failures else None if failures and not self.accept_failures: @@ -340,18 +342,24 @@ def evaluate_round( if loss_aggregated: self._maybe_checkpoint(loss_aggregated, metrics_aggregated, server_round) # Report evaluation results - self.reports_manager.report( - { - "val - loss - aggregated": loss_aggregated, - "round": server_round, - "eval_round_start": str(start_time), - "eval_round_end": str(end_time), - }, - server_round, - ) + report_data = { + "val - loss - aggregated": loss_aggregated, + "round": server_round, + "eval_round_start": str(start_time), + "eval_round_end": str(end_time), + } + dummy_params = Parameters([], "None") + config = self.strategy.configure_evaluate(server_round, dummy_params, self._client_manager)[0][ + 1 + ].config + if config.get("local_epochs", None) is not None: + report_data["fit_epoch"] = server_round * config["local_epochs"] + elif config.get("local_steps", None) is not None: + report_data["fit_step"] = server_round * config["local_steps"] + self.reports_manager.report(report_data, server_round) if len(metrics_aggregated) > 0: self.reports_manager.report( - {"eval_metrics_aggregated": metrics_aggregated}, + {"eval_round_metrics_aggregated": metrics_aggregated}, server_round, ) diff --git a/poetry.lock b/poetry.lock index 6f6e872a2..bbc6d6e83 100644 --- a/poetry.lock +++ b/poetry.lock @@ -7858,21 +7858,21 @@ testing = ["coverage (>=5.0)", "pytest", "pytest-cov"] [[package]] name = "wandb" -version = "0.18.3" +version = "0.18.7" description = "A CLI and library for interacting with the Weights & Biases API." 
optional = false python-versions = ">=3.7" files = [ - {file = "wandb-0.18.3-py3-none-any.whl", hash = "sha256:7da64f7da0ff7572439de10bfd45534e8811e71e78ac2ccc3b818f1c0f3a9aef"}, - {file = "wandb-0.18.3-py3-none-macosx_10_13_x86_64.whl", hash = "sha256:6674d8a5c40c79065b9c7eb765136756d5ebc9457a5f9abc820a660fb23f8b67"}, - {file = "wandb-0.18.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:741f566e409a2684d3047e4cc25e8e914d78196b901190937b24b6abb8b052e5"}, - {file = "wandb-0.18.3-py3-none-macosx_11_0_x86_64.whl", hash = "sha256:8be5e877570b693001c52dcc2089e48e6a4dcbf15f3adf5c9349f95148b59d58"}, - {file = "wandb-0.18.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d788852bd4739fa18de3918f309c3a955b5cef3247fae1c40df3a63af637e1a0"}, - {file = "wandb-0.18.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ab81424eb207d78239a8d69c90521a70074fb81e3709055484e43c76fe44dc08"}, - {file = "wandb-0.18.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:2c91315b8b62423eae18577d66a4b4bb8e4341a7d5c849cb2963e3b3dff0bf6d"}, - {file = "wandb-0.18.3-py3-none-win32.whl", hash = "sha256:92a647dab783938ec87776a9fae8a13e72e6dad939c53e357cdea9d2570f0ad8"}, - {file = "wandb-0.18.3-py3-none-win_amd64.whl", hash = "sha256:29cac2cfa3124241fed22cfedc9a52e1500275ee9bbb0b428ce4bf63c4723bf0"}, - {file = "wandb-0.18.3.tar.gz", hash = "sha256:eb2574cea72bc908c6ce1b37edf7a889619e6e06e1b4714eecfe0662ded43c06"}, + {file = "wandb-0.18.7-py3-none-any.whl", hash = "sha256:c2b9f9fea6daf8b62a505ea5d77d7e5e375c6014947a8882c0497399a9a1e4af"}, + {file = "wandb-0.18.7-py3-none-macosx_10_13_x86_64.whl", hash = "sha256:9fb2d381b20a079d7bb519b1b5cbbd94a10e941a2a0c5ccc044748b00344a294"}, + {file = "wandb-0.18.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:87209f5aed8dbcf4b699ce745d096bc13b3cb66217efa5c44dd772d4f7fe7836"}, + {file = "wandb-0.18.7-py3-none-macosx_11_0_x86_64.whl", hash = 
"sha256:e31d2115c558257406bf9beffe13d42313d958f2809cb15123a8e6a6d18d66c6"}, + {file = "wandb-0.18.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e261e9f87005a4487548137d04bfa10fa14e3306b9901bc6ac2f3335c73df7c6"}, + {file = "wandb-0.18.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3133683a5b3bd3a50cf498e6b5ecc7406738619ae9f245326a9fa2e80ad313f"}, + {file = "wandb-0.18.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:7ca272660d880ba007aa7b4be2f88160692b2f12dccd431bd2f6471c85e68986"}, + {file = "wandb-0.18.7-py3-none-win32.whl", hash = "sha256:a42b63c9b9e552b51e51b35caf26d81675dbc012317bc2701e39b3d84d479354"}, + {file = "wandb-0.18.7-py3-none-win_amd64.whl", hash = "sha256:4ba9fda6dd7db02a23c6b302411fe26c3fcfea4947cc130a65e1de19812d324e"}, + {file = "wandb-0.18.7.tar.gz", hash = "sha256:00f9891558d4833ee47f21ce6c603499f0bd1a7ce117ff55ee1a041e9094f9a2"}, ] [package.dependencies] @@ -7884,9 +7884,10 @@ protobuf = {version = ">=3.19.0,<4.21.0 || >4.21.0,<5.28.0 || >5.28.0,<6", marke psutil = ">=5.0.0" pyyaml = "*" requests = ">=2.0.0,<3" -sentry-sdk = ">=1.0.0" +sentry-sdk = ">=2.0.0" setproctitle = "*" setuptools = "*" +typing-extensions = {version = ">=4.4,<5", markers = "python_version < \"3.12\""} [package.extras] aws = ["boto3"] diff --git a/tests/clients/test_basic_client.py b/tests/clients/test_basic_client.py index a19253885..bcd7a7dae 100644 --- a/tests/clients/test_basic_client.py +++ b/tests/clients/test_basic_client.py @@ -61,8 +61,8 @@ def test_metrics_reporter_fit() -> None: "rounds": { test_current_server_round: { "round_start": str(datetime.datetime(2012, 12, 12, 12, 12, 12)), - "fit_losses": test_loss_dict, - "fit_metrics": test_metrics, + "fit_round_losses": test_loss_dict, + "fit_round_metrics": test_metrics, "round": test_current_server_round, }, }, @@ -103,10 +103,10 @@ def test_metrics_reporter_evaluate() -> None: "initialized": str(datetime.datetime(2012, 12, 12, 12, 12, 12)), 
"rounds": { test_current_server_round: { - "eval_start": str(datetime.datetime(2012, 12, 12, 12, 12, 12)), - "eval_loss": test_loss, - "eval_metrics": test_metrics_final, - "eval_end": str(datetime.datetime(2012, 12, 12, 12, 12, 12)), + "eval_round_start": str(datetime.datetime(2012, 12, 12, 12, 12, 12)), + "eval_round_loss": test_loss, + "eval_round_metrics": test_metrics_final, + "eval_round_end": str(datetime.datetime(2012, 12, 12, 12, 12, 12)), }, }, } diff --git a/tests/server/test_base_server.py b/tests/server/test_base_server.py index 0a06d0919..c2c8c5894 100644 --- a/tests/server/test_base_server.py +++ b/tests/server/test_base_server.py @@ -10,6 +10,7 @@ from flwr.common.parameter import ndarrays_to_parameters from flwr.server.client_proxy import ClientProxy from flwr.server.history import History +from flwr.server.strategy import FedAvg from freezegun import freeze_time from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer @@ -114,11 +115,11 @@ def test_metrics_reporter_fit(mock_fit: Mock) -> None: "fit_end": str(datetime.datetime(2012, 12, 12, 12, 12, 12)), "rounds": { 1: { - "eval_metrics_centralized": {"test_metric1": 123.123}, + "eval_round_metrics_centralized": {"test_metric1": 123.123}, "val - loss - centralized": 123.123, }, 2: { - "eval_metrics_centralized": {"test_metric1": 123}, + "eval_round_metrics_centralized": {"test_metric1": 123}, "val - loss - centralized": 123, }, }, @@ -142,7 +143,7 @@ def test_metrics_reporter_fit_round(mock_fit_round: Mock) -> None: "rounds": { test_round: { "fit_round_start": str(datetime.datetime(2012, 12, 12, 12, 12, 12)), - "fit_metrics": test_metrics_aggregated, + "fit_round_metrics": test_metrics_aggregated, "fit_round_end": str(datetime.datetime(2012, 12, 12, 12, 12, 12)), }, }, @@ -247,9 +248,12 @@ def test_metrics_reporter_evaluate_round(mock_evaluate_round: Mock) -> None: test_metrics_aggregated, (None, None), ) - + client_manager = SimpleClientManager() + 
client_manager.register(CustomClientProxy("test_id", 1)) reporter = JsonReporter() - fl_server = FlServer(SimpleClientManager(), reporters=[reporter]) + fl_server = FlServer( + client_manager, reporters=[reporter], strategy=FedAvg(min_evaluate_clients=1, min_available_clients=1) + ) fl_server.evaluate_round(test_round, None) metrics_to_assert = { @@ -257,7 +261,7 @@ def test_metrics_reporter_evaluate_round(mock_evaluate_round: Mock) -> None: test_round: { "eval_round_start": str(datetime.datetime(2012, 12, 12, 12, 12, 12)), "val - loss - aggregated": test_loss_aggregated, - "eval_metrics_aggregated": test_metrics_aggregated, + "eval_round_metrics_aggregated": test_metrics_aggregated, "eval_round_end": str(datetime.datetime(2012, 12, 12, 12, 12, 12)), }, }, diff --git a/tests/smoke_tests/apfl_client_metrics.json b/tests/smoke_tests/apfl_client_metrics.json index ede630e6d..1b90e0e75 100644 --- a/tests/smoke_tests/apfl_client_metrics.json +++ b/tests/smoke_tests/apfl_client_metrics.json @@ -1,58 +1,58 @@ { "rounds": { "1": { - "fit_metrics": { + "fit_round_metrics": { "train - personal - accuracy": 0.6063, "train - global - accuracy": 0.6078, "train - local - accuracy": 0.4797 }, - "fit_losses": { + "fit_round_losses": { "backward": 1.4428, "global": 1.3272, "local": 1.6362 }, - "eval_metrics": { + "eval_round_metrics": { "val - personal - accuracy": 0.5269, "val - global - accuracy": 0.5331, "val - local - accuracy": 0.4984 }, - "eval_loss": 1.5862 + "eval_round_loss": 1.5862 }, "2": { - "fit_metrics": { + "fit_round_metrics": { "train - personal - accuracy": 0.8063, "train - global - accuracy": 0.8266, "train - local - accuracy": 0.7688 }, - "fit_losses": { + "fit_round_losses": { "backward": 0.6089, "global": 0.6306, "local": 0.7349 }, - "eval_metrics": { + "eval_round_metrics": { "val - personal - accuracy": 0.7117, "val - global - accuracy": 0.7162, "val - local - accuracy": 0.6274 }, - "eval_loss": 0.9102 + "eval_round_loss": 0.9102 }, "3": { - "fit_metrics": { 
+ "fit_round_metrics": { "train - personal - accuracy": 0.85, "train - global - accuracy": 0.8703, "train - local - accuracy": 0.8422 }, - "fit_losses": { + "fit_round_losses": { "backward": 0.4666, "global": 0.4760, "local": 0.6050 }, - "eval_metrics": { + "eval_round_metrics": { "val - personal - accuracy": 0.7681, "val - global - accuracy": 0.7943, "val - local - accuracy": 0.7094 }, - "eval_loss": 0.7815 + "eval_round_loss": 0.7815 } } } diff --git a/tests/smoke_tests/apfl_server_metrics.json b/tests/smoke_tests/apfl_server_metrics.json index ee68c9c64..eac29e9a0 100644 --- a/tests/smoke_tests/apfl_server_metrics.json +++ b/tests/smoke_tests/apfl_server_metrics.json @@ -1,7 +1,7 @@ { "rounds": { "1": { - "eval_metrics_aggregated": { + "eval_round_metrics_aggregated": { "val - personal - accuracy": 0.5269, "val - global - accuracy": 0.5331, "val - local - accuracy": 0.4984 @@ -9,7 +9,7 @@ "val - loss - aggregated": 1.5862 }, "2": { - "eval_metrics_aggregated": { + "eval_round_metrics_aggregated": { "val - personal - accuracy": 0.7117, "val - global - accuracy": 0.7162, "val - local - accuracy": 0.6274 @@ -17,7 +17,7 @@ "val - loss - aggregated": 0.9102 }, "3": { - "eval_metrics_aggregated": { + "eval_round_metrics_aggregated": { "val - personal - accuracy": 0.7681, "val - global - accuracy": 0.7943, "val - local - accuracy": 0.7094 diff --git a/tests/smoke_tests/basic_client_metrics.json b/tests/smoke_tests/basic_client_metrics.json index acd823c42..556e4b7e6 100644 --- a/tests/smoke_tests/basic_client_metrics.json +++ b/tests/smoke_tests/basic_client_metrics.json @@ -1,42 +1,42 @@ { "rounds": { "1": { - "fit_metrics": { + "fit_round_metrics": { "train - prediction - accuracy": 0.084375 }, - "fit_losses": { + "fit_round_losses": { "backward": 3.4583 } }, "2": { - "fit_metrics": { + "fit_round_metrics": { "train - prediction - accuracy": 0.1 }, - "fit_losses": { + "fit_round_losses": { "backward": 2.32976 }, - "eval_metrics": { + "eval_round_metrics": { "val - 
prediction - accuracy": 0.0942, "test - num_examples": 10000, "test - checkpoint": 2.30616, "test - prediction - accuracy": 0.0966 }, - "eval_loss": 2.3042 + "eval_round_loss": 2.3042 }, "3": { - "fit_metrics": { + "fit_round_metrics": { "train - prediction - accuracy": 0.096875 }, - "fit_losses": { + "fit_round_losses": { "backward": 2.31093 }, - "eval_metrics": { + "eval_round_metrics": { "val - prediction - accuracy": 0.0936, "test - num_examples": 10000, "test - checkpoint": 2.30109, "test - prediction - accuracy": 0.0972 }, - "eval_loss": 2.2999 + "eval_round_loss": 2.2999 } } } diff --git a/tests/smoke_tests/basic_server_metrics.json b/tests/smoke_tests/basic_server_metrics.json index b8ad7f447..e71636e2f 100644 --- a/tests/smoke_tests/basic_server_metrics.json +++ b/tests/smoke_tests/basic_server_metrics.json @@ -1,7 +1,7 @@ { "rounds": { "1": { - "eval_metrics_aggregated": { + "eval_round_metrics_aggregated": { "val - prediction - accuracy": 0.1031, "test - prediction - accuracy": 0.1039, "test - loss - aggregated": 2.3613 @@ -9,7 +9,7 @@ "val - loss - aggregated": 2.3567 }, "2": { - "eval_metrics_aggregated": { + "eval_round_metrics_aggregated": { "val - prediction - accuracy": 0.0942, "test - prediction - accuracy": 0.0966, "test - loss - aggregated": 2.3061 @@ -17,7 +17,7 @@ "val - loss - aggregated": 2.3042 }, "3": { - "eval_metrics_aggregated": { + "eval_round_metrics_aggregated": { "val - prediction - accuracy": 0.0936, "test - prediction - accuracy": 0.0972, "test - loss - aggregated": 2.3010 diff --git a/tests/smoke_tests/feddg_ga_client_metrics.json b/tests/smoke_tests/feddg_ga_client_metrics.json index fee31d578..7dedc5eef 100644 --- a/tests/smoke_tests/feddg_ga_client_metrics.json +++ b/tests/smoke_tests/feddg_ga_client_metrics.json @@ -1,7 +1,7 @@ { "rounds": { "1": { - "fit_metrics": { + "fit_round_metrics": { "train - personal - accuracy": 0.6063, "train - global - accuracy": 0.6078, "train - local - accuracy": 0.4797, @@ -10,20 +10,20 @@ "val 
- local - accuracy": 0.4984, "val - checkpoint": 1.5862 }, - "fit_losses": { + "fit_round_losses": { "backward": 1.4428, "global": 1.3272, "local": 1.6362 }, - "eval_metrics": { + "eval_round_metrics": { "val - personal - accuracy": 0.5269, "val - global - accuracy": 0.5331, "val - local - accuracy": 0.4984 }, - "eval_loss": 1.5862 + "eval_round_loss": 1.5862 }, "2": { - "fit_metrics": { + "fit_round_metrics": { "train - personal - accuracy": 0.7906, "train - global - accuracy": {"target_value": 0.8109, "custom_tolerance": 0.005}, "train - local - accuracy": {"target_value": 0.7453, "custom_tolerance": 0.005}, @@ -32,19 +32,19 @@ "val - local - accuracy": {"target_value": 0.614, "custom_tolerance": 0.005}, "val - checkpoint": {"target_value": 1.0008, "custom_tolerance": 0.005} }, - "fit_losses": { + "fit_round_losses": { "global": 0.6785, "local": 0.7770, "backward": 0.6460 }, - "eval_metrics": { + "eval_round_metrics": { "val - personal - accuracy": {"target_value": 0.6584, "custom_tolerance": 0.005}, "val - local - accuracy": {"target_value": 0.614, "custom_tolerance": 0.005} }, - "eval_loss": {"target_value": 1.0008, "custom_tolerance": 0.005} + "eval_round_loss": {"target_value": 1.0008, "custom_tolerance": 0.005} }, "3": { - "fit_metrics": { + "fit_round_metrics": { "train - personal - accuracy": 0.8359, "train - global - accuracy": 0.8656, "train - local - accuracy": {"target_value": 0.8078, "custom_tolerance": 0.005}, @@ -53,17 +53,17 @@ "val - local - accuracy": {"target_value": 0.7508, "custom_tolerance": 0.005}, "val - checkpoint": {"target_value": 0.6042, "custom_tolerance": 0.005} }, - "fit_losses": { + "fit_round_losses": { "global": 0.5084, "local": {"target_value": 0.6561, "custom_tolerance": 0.005}, "backward": {"target_value": 0.5327, "custom_tolerance": 0.005} }, - "eval_metrics": { + "eval_round_metrics": { "val - personal - accuracy": {"target_value": 0.8218, "custom_tolerance": 0.005}, "val - global - accuracy": 0.8497, "val - local - 
accuracy": {"target_value": 0.7508, "custom_tolerance": 0.005} }, - "eval_loss": {"target_value": 0.6042, "custom_tolerance": 0.005} + "eval_round_loss": {"target_value": 0.6042, "custom_tolerance": 0.005} } } } diff --git a/tests/smoke_tests/feddg_ga_server_metrics.json b/tests/smoke_tests/feddg_ga_server_metrics.json index 648943759..f1bcfd938 100644 --- a/tests/smoke_tests/feddg_ga_server_metrics.json +++ b/tests/smoke_tests/feddg_ga_server_metrics.json @@ -1,7 +1,7 @@ { "rounds": { "1": { - "eval_metrics_aggregated": { + "eval_round_metrics_aggregated": { "val - personal - accuracy": 0.5269, "val - global - accuracy": 0.5331, "val - local - accuracy": 0.4984 @@ -9,7 +9,7 @@ "val - loss - aggregated": 1.5862 }, "2": { - "eval_metrics_aggregated": { + "eval_round_metrics_aggregated": { "val - personal - accuracy": {"target_value": 0.6584, "custom_tolerance": 0.005}, "val - global - accuracy": 0.6031, "val - local - accuracy": {"target_value": 0.614, "custom_tolerance": 0.005} @@ -17,7 +17,7 @@ "val - loss - aggregated": {"target_value": 1.0008, "custom_tolerance": 0.005} }, "3": { - "eval_metrics_aggregated": { + "eval_round_metrics_aggregated": { "val - personal - accuracy": {"target_value": 0.8218, "custom_tolerance": 0.005}, "val - global - accuracy": 0.8497, "val - local - accuracy": {"target_value": 0.7508, "custom_tolerance": 0.005} diff --git a/tests/smoke_tests/fedprox_client_metrics.json b/tests/smoke_tests/fedprox_client_metrics.json index 119b9e3c2..41af2ea7f 100644 --- a/tests/smoke_tests/fedprox_client_metrics.json +++ b/tests/smoke_tests/fedprox_client_metrics.json @@ -1,34 +1,34 @@ { "rounds": { "1": { - "fit_metrics": {"train - prediction - accuracy": 0.2484}, - "fit_losses": { + "fit_round_metrics": {"train - prediction - accuracy": 0.2484}, + "fit_round_losses": { "loss": 2.1330, "backward": 2.1598, "penalty_loss": 0.0268 }, - "eval_metrics": {"val - prediction - accuracy": 0.3633}, - "eval_loss": 1.9861 + "eval_round_metrics": {"val - 
prediction - accuracy": 0.3633}, + "eval_round_loss": 1.9861 }, "2": { - "fit_metrics": {"train - prediction - accuracy": 0.4531}, - "fit_losses": { + "fit_round_metrics": {"train - prediction - accuracy": 0.4531}, + "fit_round_losses": { "penalty_loss": 0.0, "loss": 1.7784, "backward": 1.7784 }, - "eval_metrics": {"val - prediction - accuracy": 0.5016}, - "eval_loss": 1.4836 + "eval_round_metrics": {"val - prediction - accuracy": 0.5016}, + "eval_round_loss": 1.4836 }, "3": { - "fit_metrics": {"train - prediction - accuracy": 0.6016}, - "fit_losses": { + "fit_round_metrics": {"train - prediction - accuracy": 0.6016}, + "fit_round_losses": { "penalty_loss": 0.0, "loss": 1.3226, "backward": 1.3226 }, - "eval_metrics": {"val - prediction - accuracy": 0.6901}, - "eval_loss": 1.1124 + "eval_round_metrics": {"val - prediction - accuracy": 0.6901}, + "eval_round_loss": 1.1124 } } } diff --git a/tests/smoke_tests/fedprox_server_metrics.json b/tests/smoke_tests/fedprox_server_metrics.json index 531e53ef5..c671332ba 100644 --- a/tests/smoke_tests/fedprox_server_metrics.json +++ b/tests/smoke_tests/fedprox_server_metrics.json @@ -1,15 +1,15 @@ { "rounds": { "1": { - "eval_metrics_aggregated": {"val - prediction - accuracy": 0.3633}, + "eval_round_metrics_aggregated": {"val - prediction - accuracy": 0.3633}, "val - loss - aggregated": 1.9861 }, "2": { - "eval_metrics_aggregated": {"val - prediction - accuracy": 0.5016}, + "eval_round_metrics_aggregated": {"val - prediction - accuracy": 0.5016}, "val - loss - aggregated": 1.4836 }, "3": { - "eval_metrics_aggregated": {"val - prediction - accuracy": 0.6901}, + "eval_round_metrics_aggregated": {"val - prediction - accuracy": 0.6901}, "val - loss - aggregated": 1.1124 } } diff --git a/tests/smoke_tests/scaffold_client_metrics.json b/tests/smoke_tests/scaffold_client_metrics.json index 32a202f59..cb5670133 100644 --- a/tests/smoke_tests/scaffold_client_metrics.json +++ b/tests/smoke_tests/scaffold_client_metrics.json @@ -1,30 
+1,30 @@ { "rounds": { "0": { - "fit_metrics": {"train - prediction - accuracy": 0.2031}, - "fit_losses": {"backward": 2.2655} + "fit_round_metrics": {"train - prediction - accuracy": 0.2031}, + "fit_round_losses": {"backward": 2.2655} }, "1": { - "fit_metrics": {"train - prediction - accuracy": 0.18125}, - "fit_losses": {"backward": 2.2684}, - "eval_metrics": {"val - prediction - accuracy": {"target_value": 0.1824, "custom_tolerance": 0.005}}, - "eval_loss": 2.2785 + "fit_round_metrics": {"train - prediction - accuracy": 0.18125}, + "fit_round_losses": {"backward": 2.2684}, + "eval_round_metrics": {"val - prediction - accuracy": {"target_value": 0.1824, "custom_tolerance": 0.005}}, + "eval_round_loss": 2.2785 }, "2": { - "fit_metrics": {"train - prediction - accuracy": {"target_value": 0.3906, "custom_tolerance": 0.05}}, - "fit_losses": { + "fit_round_metrics": {"train - prediction - accuracy": {"target_value": 0.3906, "custom_tolerance": 0.05}}, + "fit_round_losses": { "backward": {"target_value": 2.1567, "custom_tolerance": 0.005} }, - "eval_metrics": {"val - prediction - accuracy": {"target_value": 0.3332, "custom_tolerance": 0.05}}, - "eval_loss": {"target_value": 2.2509, "custom_tolerance": 0.005} + "eval_round_metrics": {"val - prediction - accuracy": {"target_value": 0.3332, "custom_tolerance": 0.05}}, + "eval_round_loss": {"target_value": 2.2509, "custom_tolerance": 0.005} }, "3": { - "fit_metrics": {"train - prediction - accuracy": {"target_value": 0.4078, "custom_tolerance": 0.05}}, - "fit_losses": { + "fit_round_metrics": {"train - prediction - accuracy": {"target_value": 0.4078, "custom_tolerance": 0.05}}, + "fit_round_losses": { "backward": {"target_value": 2.0964, "custom_tolerance": 0.005} }, - "eval_metrics": {"val - prediction - accuracy": {"target_value": 0.4062, "custom_tolerance": 0.005}}, - "eval_loss": {"target_value": 2.2070, "custom_tolerance": 0.05} + "eval_round_metrics": {"val - prediction - accuracy": {"target_value": 0.4062, 
"custom_tolerance": 0.005}}, + "eval_round_loss": {"target_value": 2.2070, "custom_tolerance": 0.05} } } } diff --git a/tests/smoke_tests/scaffold_server_metrics.json b/tests/smoke_tests/scaffold_server_metrics.json index cd9d60875..2fab497ff 100644 --- a/tests/smoke_tests/scaffold_server_metrics.json +++ b/tests/smoke_tests/scaffold_server_metrics.json @@ -1,15 +1,15 @@ { "rounds": { "1": { - "eval_metrics_aggregated": {"val - prediction - accuracy": {"target_value": 0.1824, "custom_tolerance": 0.005}}, + "eval_round_metrics_aggregated": {"val - prediction - accuracy": {"target_value": 0.1824, "custom_tolerance": 0.005}}, "val - loss - aggregated": 2.2785 }, "2": { - "eval_metrics_aggregated": {"val - prediction - accuracy": {"target_value": 0.3332, "custom_tolerance": 0.05}}, + "eval_round_metrics_aggregated": {"val - prediction - accuracy": {"target_value": 0.3332, "custom_tolerance": 0.05}}, "val - loss - aggregated": {"target_value": 2.2509, "custom_tolerance": 0.005} }, "3": { - "eval_metrics_aggregated": {"val - prediction - accuracy": 0.4062}, + "eval_round_metrics_aggregated": {"val - prediction - accuracy": 0.4062}, "val - loss - aggregated": {"target_value": 2.2070, "custom_tolerance": 0.05} } }