diff --git a/fl4health/servers/__init__.py b/fl4health/servers/__init__.py index 4fc21eb06..6e565b8f9 100644 --- a/fl4health/servers/__init__.py +++ b/fl4health/servers/__init__.py @@ -1,11 +1,11 @@ # Note: This is commented out until logging issues resolved # Import server classes so that they can be directly imported from fl4health.server -# from fl4health.server.base_server import FlServer, FlServerWithCheckpointing, FlServerWithInitializer -# from fl4health.server.client_level_dp_fed_avg_server import ClientLevelDPFedAvgServer -# from fl4health.server.evaluate_server import EvaluateServer -# from fl4health.server.fedpm_server import FedPmServer -# from fl4health.server.instance_level_dp_server import InstanceLevelDpServer -# from fl4health.server.model_merge_server import ModelMergeServer -# from fl4health.server.nnunet_server import NnunetServer -# from fl4health.server.scaffold_server import ScaffoldServer, DPScaffoldServer -# from fl4health.server.tabular_feature_alignment_server import TabularFeatureAlignmentServer +# from fl4health.servers.base_server import FlServer, FlServerWithCheckpointing, FlServerWithInitializer +# from fl4health.servers.client_level_dp_fed_avg_server import ClientLevelDPFedAvgServer +# from fl4health.servers.evaluate_server import EvaluateServer +# from fl4health.servers.fedpm_server import FedPmServer +# from fl4health.servers.instance_level_dp_server import InstanceLevelDpServer +# from fl4health.servers.model_merge_server import ModelMergeServer +# from fl4health.servers.nnunet_server import NnunetServer +# from fl4health.servers.scaffold_server import ScaffoldServer, DPScaffoldServer +# from fl4health.servers.tabular_feature_alignment_server import TabularFeatureAlignmentServer diff --git a/tests/servers/test_base_server.py b/tests/servers/test_base_server.py index 4becc015f..cbe940696 100644 --- a/tests/servers/test_base_server.py +++ b/tests/servers/test_base_server.py @@ -99,7 +99,7 @@ def test_fl_server_with_checkpointing(tmp_path: Path) -> None: assert torch.equal(updated_model.linear.weight, loaded_model.linear.weight) -@patch("fl4health.server.base_server.Server.fit") +@patch("fl4health.servers.base_server.Server.fit") @freeze_time("2012-12-12 12:12:12") def test_metrics_reporter_fit(mock_fit: Mock) -> None: test_history = History() @@ -128,7 +128,7 @@ def test_metrics_reporter_fit(mock_fit: Mock) -> None: assert len(errors) == 0, f"Metrics check failed. Errors: {errors}, {reporter.metrics}" -@patch("fl4health.server.base_server.Server.fit_round") +@patch("fl4health.servers.base_server.Server.fit_round") @freeze_time("2012-12-12 12:12:12") def test_metrics_reporter_fit_round(mock_fit_round: Mock) -> None: test_round = 2 @@ -237,7 +237,7 @@ def test_handle_result_aggregation() -> None: ) -@patch("fl4health.server.base_server.FlServer._evaluate_round") +@patch("fl4health.servers.base_server.FlServer._evaluate_round") @freeze_time("2012-12-12 12:12:12") def test_metrics_reporter_evaluate_round(mock_evaluate_round: Mock) -> None: test_round = 2 diff --git a/tests/servers/test_evaluate_server.py b/tests/servers/test_evaluate_server.py index 3f987ee95..4280c898c 100644 --- a/tests/servers/test_evaluate_server.py +++ b/tests/servers/test_evaluate_server.py @@ -9,7 +9,7 @@ from tests.test_utils.assert_metrics_dict import assert_metrics_dict -@patch("fl4health.server.evaluate_server.EvaluateServer.federated_evaluate") +@patch("fl4health.servers.evaluate_server.EvaluateServer.federated_evaluate") @freeze_time("2012-12-12 12:12:12") def test_metrics_reporter_fit(mock_federated_evaluate: Mock) -> None: pass