Skip to content

Commit

Permalink
test smoke tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sanaAyrml committed Dec 5, 2024
1 parent 8938bf5 commit 6136a90
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions fl4health/clients/basic_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
set_pack_losses_with_val_metrics,
)
from fl4health.utils.config import narrow_dict_type, narrow_dict_type_and_set_attribute
from fl4health.utils.early_stopper import EarlyStopper
from fl4health.utils.logging import LoggingMode
from fl4health.utils.losses import EvaluationLosses, LossMeter, LossMeterType, TrainingLosses
from fl4health.utils.metrics import TEST_LOSS_KEY, TEST_NUM_EXAMPLES_KEY, Metric, MetricManager
Expand All @@ -48,7 +47,6 @@ def __init__(
progress_bar: bool = False,
intermediate_client_state_dir: Optional[Path] = None,
client_name: Optional[str] = None,
early_stopper: Optional[EarlyStopper] = None,
) -> None:
"""
Base FL Client with functionality to train, evaluate, log, report and checkpoint.
Expand Down Expand Up @@ -98,8 +96,6 @@ def __init__(
else:
self.per_round_checkpointer = None

self.early_stopper = early_stopper

# Initialize reporters with client information.
self.reports_manager = ReportsManager(reporters)
self.reports_manager.initialize(id=self.client_name)
Expand Down Expand Up @@ -645,7 +641,7 @@ def train_by_epochs(
self.reports_manager.report(report_data, current_round, self.total_epochs, self.total_steps)
self.total_steps += 1
steps_this_round += 1
if self.early_stopper is not None:
if hasattr(self, "early_stopper"):
if self.total_steps % self.early_stopper.interval_steps == 0 and self.early_stopper.should_stop():
log(INFO, "Early stopping criterion met. Stopping training.")
break
Expand Down Expand Up @@ -717,7 +713,7 @@ def train_by_steps(
report_data.update(self.get_client_specific_reports())
self.reports_manager.report(report_data, current_round, None, self.total_steps)
self.total_steps += 1
if self.early_stopper is not None:
if hasattr(self, "early_stopper"):
if self.total_steps % self.early_stopper.interval_steps == 0 and self.early_stopper.should_stop():
log(INFO, "Early stopping criterion met. Stopping training.")
break
Expand Down Expand Up @@ -866,9 +862,18 @@ def setup_client(self, config: Config) -> None:
self.parameter_exchanger = self.get_parameter_exchanger(config)

self.reports_manager.report({"host_type": "client", "initialized": str(datetime.datetime.now())})

self.initialized = True

def setup_early_stopper(
self,
patience: int = -1,
interval_steps: int = 5,
snapshot_dir: Optional[Path] = None,
) -> None:
from fl4health.utils.early_stopper import EarlyStopper

self.early_stopper = EarlyStopper(self, patience, interval_steps, snapshot_dir)

def get_parameter_exchanger(self, config: Config) -> ParameterExchanger:
"""
Returns Full Parameter Exchangers. Subclasses that require custom Parameter Exchangers can override this.
Expand Down

0 comments on commit 6136a90

Please sign in to comment.