diff --git a/fl4health/server/base_server.py b/fl4health/server/base_server.py index 2773e55eb..56f9c7472 100644 --- a/fl4health/server/base_server.py +++ b/fl4health/server/base_server.py @@ -349,7 +349,9 @@ def evaluate_round( "eval_round_end": str(end_time), } dummy_params = Parameters([], "None") - config = self.strategy.configure_fit(server_round, dummy_params, self._client_manager)[0][1].config + config = self.strategy.configure_evaluate(server_round, dummy_params, self._client_manager)[0][ + 1 + ].config if config.get("local_epochs", None) is not None: report_data["fit_epoch"] = server_round * config["local_epochs"] elif config.get("local_steps", None) is not None: diff --git a/tests/server/test_base_server.py b/tests/server/test_base_server.py index 4ca55722b..c2c8c5894 100644 --- a/tests/server/test_base_server.py +++ b/tests/server/test_base_server.py @@ -10,6 +10,7 @@ from flwr.common.parameter import ndarrays_to_parameters from flwr.server.client_proxy import ClientProxy from flwr.server.history import History +from flwr.server.strategy import FedAvg from freezegun import freeze_time from fl4health.checkpointing.checkpointer import BestLossTorchCheckpointer @@ -247,9 +248,12 @@ def test_metrics_reporter_evaluate_round(mock_evaluate_round: Mock) -> None: test_metrics_aggregated, (None, None), ) - + client_manager = SimpleClientManager() + client_manager.register(CustomClientProxy("test_id", 1)) reporter = JsonReporter() - fl_server = FlServer(SimpleClientManager(), reporters=[reporter]) + fl_server = FlServer( + client_manager, reporters=[reporter], strategy=FedAvg(min_evaluate_clients=1, min_available_clients=1) + ) fl_server.evaluate_round(test_round, None) metrics_to_assert = {