Skip to content

Commit

Permalink
[pre-commit.ci] Add auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 2, 2025
1 parent fe57eea commit 2d939b3
Showing 1 changed file with 17 additions and 12 deletions.
29 changes: 17 additions & 12 deletions tests/utils/snapshotter_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
from collections.abc import Sequence
from pathlib import Path
from typing import Dict, Optional
Expand All @@ -7,27 +8,26 @@
from flwr.common import Scalar

from fl4health.clients.basic_client import BasicClient
from fl4health.reporting import JsonReporter
from fl4health.reporting.base_reporter import BaseReporter
from fl4health.reporting.reports_manager import ReportsManager
from fl4health.utils.client import fold_loss_dict_into_metrics
from fl4health.utils.logging import LoggingMode
from tests.test_utils.assert_metrics_dict import assert_metrics_dict
from fl4health.reporting.reports_manager import ReportsManager
from fl4health.utils.losses import TrainingLosses, LossMeter
from fl4health.utils.losses import LossMeter, TrainingLosses
from fl4health.utils.metrics import MetricManager
from fl4health.reporting import JsonReporter
from fl4health.utils.snapshotter import SerizableObjectSnapshotter
import copy
from tests.test_utils.assert_metrics_dict import assert_metrics_dict


def test_loss_meter_snapshotter() -> None:
metrics: Dict[str, Scalar] = {"test_metric": 1234}
reporter = JsonReporter()
fl_client = MockBasicClient(metrics=metrics,reporters=[reporter])
fl_client = MockBasicClient(metrics=metrics, reporters=[reporter])
ckpt = {}

fl_client.train_loss_meter.update( TrainingLosses(backward=torch.Tensor([35]), additional_losses=None))
fl_client.train_loss_meter.update(TrainingLosses(backward=torch.Tensor([35]), additional_losses=None))
snapshotter = SerizableObjectSnapshotter(fl_client)
ckpt['train_loss_meter'] = snapshotter.save("train_loss_meter", LossMeter)
ckpt["train_loss_meter"] = snapshotter.save("train_loss_meter", LossMeter)
old_loss_meter = copy.deepcopy(fl_client.train_loss_meter)
fl_client.train_loss_meter.update(TrainingLosses(backward=torch.Tensor([10]), additional_losses=None))
assert len(old_loss_meter.losses_list) != len(fl_client.train_loss_meter.losses_list)
Expand All @@ -37,24 +37,29 @@ def test_loss_meter_snapshotter() -> None:
assert len(old_loss_meter.losses_list) == len(fl_client.train_loss_meter.losses_list)
for i in range(len(fl_client.train_loss_meter.losses_list)):
assert old_loss_meter.losses_list[i].backward == fl_client.train_loss_meter.losses_list[i].backward
assert old_loss_meter.losses_list[i].additional_losses == fl_client.train_loss_meter.losses_list[i].additional_losses
assert (
old_loss_meter.losses_list[i].additional_losses
== fl_client.train_loss_meter.losses_list[i].additional_losses
)


def test_reports_manager_snapshotter() -> None:
metrics: Dict[str, Scalar] = {"test_metric": 1234}
reporter = JsonReporter()
fl_client = MockBasicClient(metrics=metrics,reporters=[reporter])
fl_client = MockBasicClient(metrics=metrics, reporters=[reporter])
ckpt = {}

fl_client.reports_manager.report({"start": "2012-12-12 12:12:10"})
snapshotter = SerizableObjectSnapshotter(fl_client)
ckpt['reports_manager'] = snapshotter.save("reports_manager", ReportsManager)
ckpt["reports_manager"] = snapshotter.save("reports_manager", ReportsManager)
old_reports_manager = copy.deepcopy(fl_client.reports_manager)
fl_client.reports_manager.report({"shutdown": "2012-12-12 12:12:12"})
assert old_reports_manager.reporters[0].metrics != fl_client.reports_manager.reporters[0].metrics

snapshotter.load(ckpt, "reports_manager", ReportsManager)
assert old_reports_manager.reporters[0].metrics == fl_client.reports_manager.reporters[0].metrics


## LEFT OFF HERE
# def test_metric_manager_snapshotter() -> None:
# metrics: Dict[str, Scalar] = {"test_metric": 1234}
Expand Down Expand Up @@ -133,4 +138,4 @@ def mock_validate_or_test( # type: ignore
if logging_mode == LoggingMode.VALIDATION:
return self.mock_loss, self.mock_metrics
else:
return self.mock_loss, self.mock_metrics_test
return self.mock_loss, self.mock_metrics_test

0 comments on commit 2d939b3

Please sign in to comment.