Merge pull request #288 from VectorInstitute/add_more_reports

Modified some of the reporting keys and overhauled the wandb reporter

scarere authored Nov 18, 2024
2 parents 416e1c6 + f95cf20 commit 3747589

Showing 18 changed files with 286 additions and 191 deletions.
57 changes: 36 additions & 21 deletions fl4health/clients/basic_client.py
@@ -114,6 +114,7 @@ def __init__(
         self.initial_weights: Optional[NDArrays] = None

         self.total_steps: int = 0  # Need to track total_steps across rounds for WANDB reporting
+        self.total_epochs: int = 0  # Will remain as 0 if training by steps

         # Attributes to be initialized in setup_client
         self.parameter_exchanger: ParameterExchanger
@@ -221,9 +222,9 @@ def process_config(self, config: Config) -> Tuple[Union[int, None], Union[int, N
             config (Config): The config from the server.
         Returns:
-            Tuple[Union[int, None], Union[int, None], int, bool]: Returns the local_epochs, local_steps,
-                current_server_round and evaluate_after_fit. Ensures only one of local_epochs and local_steps
-                is defined in the config and sets the one that is not to None.
+            Tuple[Union[int, None], Union[int, None], int, bool, bool]: Returns the local_epochs, local_steps,
+                current_server_round, evaluate_after_fit and pack_losses_with_val_metrics. Ensures only one of
+                local_epochs and local_steps is defined in the config and sets the one that is not to None.
         Raises:
             ValueError: If the config contains both local_steps and local epochs or if local_steps, local_epochs or
@@ -307,15 +308,24 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, dict
         # We perform a pre-aggregation checkpoint if applicable
         self._maybe_checkpoint(validation_loss, validation_metrics, CheckpointMode.PRE_AGGREGATION)

+        # Notes on report values:
+        #   - Train by steps: round metrics/losses are computed using all samples from the round
+        #   - Train by epochs: round metrics/losses computed using only the samples from the final epoch of the round
+        #   - fit_round_metrics: Computed at the end of the round on the samples directly
+        #   - fit_round_losses: The average of the losses computed for each step.
+        #       * (Hence likely higher than the final loss of the round.)
         self.reports_manager.report(
             {
-                "fit_metrics": metrics,
-                "fit_losses": loss_dict,
+                "fit_round_metrics": metrics,
+                "fit_round_losses": loss_dict,
                 "round": current_server_round,
                 "round_start": str(round_start_time),
                 "round_end": str(datetime.datetime.now()),
-                "fit_start": str(fit_start_time),
-                "fit_end": str(fit_end_time),
+                "fit_round_start": str(fit_start_time),
+                "fit_round_time_elapsed": str(fit_end_time - fit_start_time),
+                "fit_round_end": str(fit_end_time),
+                "fit_step": self.total_steps,
+                "fit_epoch": self.total_epochs,
             },
             current_server_round,
         )
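To make the note on fit_round_losses concrete: because it averages the per-step losses, it typically sits above the last step's loss whenever training is converging within the round. A small illustration with made-up numbers:

```python
# Made-up per-step losses for one round of training (illustrative only)
step_losses = [1.0, 0.6, 0.2]

# fit_round_losses reports the average of the per-step losses ...
fit_round_loss = sum(step_losses) / len(step_losses)  # ~0.6

# ... so it exceeds the final step's loss when the loss is decreasing
assert fit_round_loss > step_losses[-1]
```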
@@ -364,11 +374,14 @@ def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, Di

         self.reports_manager.report(
             {
-                "eval_metrics": metrics,
-                "eval_loss": loss,
-                "eval_start": str(start_time),
-                "eval_time_elapsed": str(elapsed),
-                "eval_end": str(end_time),
+                "eval_round_metrics": metrics,
+                "eval_round_loss": loss,
+                "eval_round_start": str(start_time),
+                "eval_round_time_elapsed": str(elapsed),
+                "eval_round_end": str(end_time),
+                "fit_step": self.total_steps,
+                "fit_epoch": self.total_epochs,
+                "round": current_server_round,
             },
             current_server_round,
         )
@@ -607,7 +620,7 @@ def train_by_epochs(
             # update before epoch hook
             self.update_before_epoch(epoch=local_epoch)
             # Update report data dict
-            report_data.update({"fit_epoch": local_epoch})
+            report_data.update({"fit_epoch": self.total_epochs})
             for input, target in maybe_progress_bar(self.train_loader, self.progress_bar):
                 self.update_before_step(steps_this_round, current_round)
                 # Assume first dimension is batch size. Sampling iterators (such as Poisson batch sampling), can
@@ -623,20 +636,22 @@
                 self.update_metric_manager(preds, target, self.train_metric_manager)
                 self.update_after_step(steps_this_round, current_round)
                 self.update_lr_schedulers(epoch=local_epoch)
-                report_data.update({"fit_losses": losses.as_dict(), "fit_step": self.total_steps})
+                report_data.update({"fit_step_losses": losses.as_dict(), "fit_step": self.total_steps})
                 report_data.update(self.get_client_specific_reports())
-                self.reports_manager.report(report_data, current_round, local_epoch, self.total_steps)
+                self.reports_manager.report(report_data, current_round, self.total_epochs, self.total_steps)
                 self.total_steps += 1
                 steps_this_round += 1

-            # Log and report results
             metrics = self.train_metric_manager.compute()
             loss_dict = self.train_loss_meter.compute().as_dict()

-            self._log_results(loss_dict, metrics, current_round, local_epoch)
-            report_data.update({"fit_metrics": metrics})
+            # Log and report results
+            report_data.update({"fit_epoch_metrics": metrics, "fit_epoch_losses": loss_dict})
             report_data.update(self.get_client_specific_reports())
-            self.reports_manager.report(report_data, current_round, local_epoch)
+            self.reports_manager.report(report_data, current_round, self.total_epochs)
+            self._log_results(loss_dict, metrics, current_round, local_epoch)

+            # Update internal epoch counter
+            self.total_epochs += 1
+
         # Return final training metrics
         return loss_dict, metrics
@@ -690,7 +705,7 @@ def train_by_steps(
             self.update_metric_manager(preds, target, self.train_metric_manager)
             self.update_after_step(step, current_round)
             self.update_lr_schedulers(step=step)
-            report_data.update({"fit_losses": losses.as_dict(), "fit_step": self.total_steps})
+            report_data.update({"fit_step_losses": losses.as_dict(), "fit_step": self.total_steps})
             report_data.update(self.get_client_specific_reports())
             self.reports_manager.report(report_data, current_round, None, self.total_steps)
             self.total_steps += 1
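Taken together, the renames above split reporting into three granularities: fit_step_* inside the training loop, fit_epoch_* after each local epoch, and fit_round_*/eval_round_* at round boundaries. A rough sketch of how a reporter sees these payloads; PrintReporter here is a hypothetical stand-in for the real reports manager, not code from this PR:

```python
from typing import Any, Dict, Optional


class PrintReporter:
    """Stand-in for a reports manager; prints instead of logging to wandb."""

    def report(
        self,
        data: Dict[str, Any],
        round: Optional[int] = None,
        epoch: Optional[int] = None,
        step: Optional[int] = None,
    ) -> None:
        # Infer the granularity the caller is reporting at
        granularity = "step" if step is not None else "epoch" if epoch is not None else "round"
        print(f"[{granularity}] round={round} epoch={epoch} step={step}: {sorted(data)}")


reporter = PrintReporter()
# Per-step report, emitted inside the training loop
reporter.report({"fit_step_losses": {"loss": 0.42}, "fit_step": 7}, round=1, epoch=0, step=7)
# Per-epoch report, emitted after each local epoch
reporter.report({"fit_epoch_metrics": {"acc": 0.90}, "fit_epoch_losses": {"loss": 0.40}}, round=1, epoch=0)
# Per-round report, emitted at the end of fit()
reporter.report({"fit_round_metrics": {"acc": 0.91}, "fit_round_losses": {"loss": 0.39}}, round=1)
```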
11 changes: 11 additions & 0 deletions fl4health/clients/nnunet_client.py
@@ -52,6 +52,7 @@
 from nnunetv2.experiment_planning.plan_and_preprocess_api import extract_fingerprints, preprocess_dataset
 from nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw
 from nnunetv2.training.dataloading.utils import unpack_dataset
+from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper
 from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler
 from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
 from nnunetv2.utilities.dataset_name_id_conversion import convert_id_to_dataset_name
@@ -246,9 +247,15 @@ def get_model(self, config: Config) -> nn.Module:
         return self.nnunet_trainer.network

     def get_criterion(self, config: Config) -> _Loss:
+        if isinstance(self.nnunet_trainer.loss, DeepSupervisionWrapper):
+            self.reports_manager.report({"Criterion": self.nnunet_trainer.loss.loss.__class__.__name__})
+        else:
+            self.reports_manager.report({"Criterion": self.nnunet_trainer.loss.__class__.__name__})
+
         return Module2LossWrapper(self.nnunet_trainer.loss)

     def get_optimizer(self, config: Config) -> Optimizer:
+        self.reports_manager.report({"Optimizer": self.nnunet_trainer.optimizer.__class__.__name__})
         return self.nnunet_trainer.optimizer

     def get_lr_scheduler(self, optimizer_key: str, config: Config) -> _LRScheduler:
@@ -289,6 +296,7 @@ def get_lr_scheduler(self, optimizer_key: str, config: Config) -> _LRScheduler:
         # Create and return LR Scheduler Wrapper for the PolyLRScheduler so that it is
         # compatible with Torch LRScheduler
         # Create and return LR Scheduler. This is nnunet default for version 2.5.1
+        self.reports_manager.report({"LR Scheduler": "PolyLRScheduler"})
         return PolyLRSchedulerWrapper(
             self.optimizers[optimizer_key],
             initial_lr=self.nnunet_trainer.initial_lr,
@@ -685,6 +693,9 @@ def get_client_specific_logs(
         else:
             return "", []

+    def get_client_specific_reports(self) -> Dict[str, Any]:
+        return {"learning_rate": float(self.optimizers["global"].param_groups[0]["lr"])}
+
     @use_default_signal_handlers  # Experiment planner spawns a process I think
     def get_properties(self, config: Config) -> Dict[str, Scalar]:
         """
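The new get_client_specific_reports hook gives subclasses a way to merge extra values, such as the current learning rate above, into every report_data payload. A minimal sketch of overriding it in a custom client; MyClient and its grad_norm attribute are hypothetical and not part of this PR:

```python
from typing import Any, Dict

import torch


class MyClient:  # in fl4health this would subclass BasicClient
    def __init__(self) -> None:
        model = torch.nn.Linear(10, 2)
        self.optimizers = {"global": torch.optim.SGD(model.parameters(), lr=0.01)}
        self.grad_norm = 0.0  # hypothetical value, updated elsewhere during training

    def get_client_specific_reports(self) -> Dict[str, Any]:
        # Merged into report_data alongside fit_step_losses / fit_epoch_metrics
        return {
            "learning_rate": float(self.optimizers["global"].param_groups[0]["lr"]),
            "grad_norm": self.grad_norm,
        }


print(MyClient().get_client_specific_reports())  # {'learning_rate': 0.01, 'grad_norm': 0.0}
```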
10 changes: 6 additions & 4 deletions fl4health/reporting/base_reporter.py
@@ -24,10 +24,12 @@ def report(
             data (dict): The data to maybe report from the server or client.
             round (int | None, optional): The current FL round. If None, this indicates that the method was called
                 outside of a round (e.g. for summary information). Defaults to None.
-            epoch (int | None, optional): The current epoch. If None then this method was not called at or within the
-                scope of an epoch. Defaults to None.
-            step (int | None, optional): The current step (total). If None then this method was called outside the
-                scope of a training or evaluation step (eg. at the end of an epoch or round) Defaults to None.
+            epoch (int | None, optional): The current epoch (In total across all rounds). If None then this method was
+                not called at or within the scope of an epoch. Should always be None if training by steps. Defaults to
+                None.
+            step (int | None, optional): The current step (In total across all rounds and epochs). If None then this
+                method was called outside the scope of a training or evaluation step (eg. at the end of an epoch or
+                round) Defaults to None.
         """
         raise NotImplementedError

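For reference, report is the abstract hook that concrete reporters, including the overhauled wandb reporter, implement. Below is a minimal in-memory reporter that follows the docstring's semantics (epoch stays None when training by steps). This sketch is illustrative only and is not the wandb implementation from this PR:

```python
from typing import Any, Dict, List, Optional, Tuple


class InMemoryReporter:  # would subclass fl4health.reporting.base_reporter.BaseReporter
    def __init__(self) -> None:
        # Each record keeps the (round, epoch, step) context alongside the data
        self.records: List[Tuple[Optional[int], Optional[int], Optional[int], Dict[str, Any]]] = []

    def report(
        self,
        data: Dict[str, Any],
        round: Optional[int] = None,
        epoch: Optional[int] = None,  # total across all rounds; None when training by steps
        step: Optional[int] = None,  # total across all rounds and epochs
    ) -> None:
        self.records.append((round, epoch, step, dict(data)))


reporter = InMemoryReporter()
reporter.report({"host_type": "client"})  # outside any round: summary information
reporter.report({"fit_step_losses": {"loss": 0.5}}, round=1, epoch=None, step=12)
assert reporter.records[1][:3] == (1, None, 12)
```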