diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6579e5597..5e60d1351 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -46,6 +46,14 @@ repos: - id: nbqa-flake8 - id: nbqa-mypy + - repo: local + hooks: + - id: mypy legacy type check + name: mypy legacy type check + entry: python mypy_disallow_legacy_types.py + language: python + pass_filenames: true + ci: autofix_commit_msg: | [pre-commit.ci] Add auto fixes from pre-commit.com hooks diff --git a/CONTRIBUTING.MD b/CONTRIBUTING.MD index ffc5aef63..493f6099f 100644 --- a/CONTRIBUTING.MD +++ b/CONTRIBUTING.MD @@ -44,6 +44,17 @@ The settings for `mypy` are in the `mypy.ini`, settings for `flake8` are contain All of these checks and formatters are invoked by pre-commit hooks. These hooks are run remotely on GitHub. In order to ensure that your code conforms to these standards, and, therefore, passes the remote checks, you can install the pre-commit hooks to be run locally. This is done by running (with your environment active) +**Note**: We use the modern mypy types introduced in Python 3.10 and above. See some of the [documentation here](https://mypy.readthedocs.io/en/stable/builtin_types.html) + +For example, this means that we're using `list[str], tuple[int, int], tuple[int, ...], dict[str, int], type[C]` as built-in types and `Iterable[int], Sequence[bool], Mapping[str, int], Callable[[...], ...]` from collections.abc (as now recommended by mypy). + +We are also moving to the new Optional and Union specification style: +```python +Optional[typing_stuff] -> typing_stuff | None +Union[typing1, typing2] -> typing1 | typing2 +Optional[Union[typing1, typing2]] -> typing1 | typing2 | None +``` + ```bash pre-commit install ``` diff --git a/examples/ae_examples/cvae_dim_example/README.md b/examples/ae_examples/cvae_dim_example/README.md index c9a0ca227..2db4161e3 100644 --- a/examples/ae_examples/cvae_dim_example/README.md +++ b/examples/ae_examples/cvae_dim_example/README.md @@ -17,7 +17,7 @@ from the FL4Health directory. The following arguments must be present in the spe * `n_server_rounds`: The number of rounds to run FL * `checkpoint_path`: path to save the best server model * `latent_dim`: size of the latent vector in the CVAE or VAE model -* `cvae_model_path`: path to the saved CVAE model for dimesionality reduction +* `cvae_model_path`: path to the saved CVAE model for dimensionality reduction **NOTE**: Instead of using a global CVAE for all the clients, you can pass personalized CVAE models to each client, but make sure that these models are previously trained in an FL setting, and are not very different, otherwise, that can lead the dimensionality reduction to map the data samples into different latent spaces which might increase the heterogeneity. diff --git a/examples/ae_examples/cvae_dim_example/client.py b/examples/ae_examples/cvae_dim_example/client.py index c7d21fff5..a6d88071e 100644 --- a/examples/ae_examples/cvae_dim_example/client.py +++ b/examples/ae_examples/cvae_dim_example/client.py @@ -1,6 +1,6 @@ import argparse +from collections.abc import Sequence from pathlib import Path -from typing import Sequence, Tuple import flwr as fl import torch @@ -26,7 +26,7 @@ def __init__(self, data_path: Path, metrics: Sequence[Metric], device: torch.dev super().__init__(data_path, metrics, device) self.condition = condition - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) cvae_model_path = Path(narrow_dict_type(config, "cvae_model_path", str)) sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=100) diff --git a/examples/ae_examples/cvae_dim_example/server.py b/examples/ae_examples/cvae_dim_example/server.py index 0ab919672..8e41136cf 100644 --- a/examples/ae_examples/cvae_dim_example/server.py +++ b/examples/ae_examples/cvae_dim_example/server.py @@ -1,6 +1,6 @@ import argparse from functools import partial -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.typing import Config @@ -33,7 +33,7 @@ def fit_config( } -def main(config: Dict[str, Any]) -> None: +def main(config: dict[str, Any]) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/examples/ae_examples/cvae_examples/conv_cvae_example/client.py b/examples/ae_examples/cvae_examples/conv_cvae_example/client.py index aa2aed303..9969f88d2 100644 --- a/examples/ae_examples/cvae_examples/conv_cvae_example/client.py +++ b/examples/ae_examples/cvae_examples/conv_cvae_example/client.py @@ -1,6 +1,6 @@ import argparse +from collections.abc import Sequence from pathlib import Path -from typing import Sequence, Tuple import flwr as fl import torch @@ -25,7 +25,7 @@ def binary_class_condition_data_converter( data: torch.Tensor, target: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor]: # Create a condition for each data sample. # Condition is the binary representation of the target. binary_representation = bin(int(target))[2:] # Convert to binary and remove the '0b' prefix @@ -56,7 +56,7 @@ def setup_client(self, config: Config) -> None: assert isinstance(self.model, ConditionalVae) self.model.unpack_input_condition = self.autoencoder_converter.get_unpacking_function() - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=100) # To make sure pixels stay in the range [0.0, 1.0]. diff --git a/examples/ae_examples/cvae_examples/conv_cvae_example/models.py b/examples/ae_examples/cvae_examples/conv_cvae_example/models.py index c1678ccd3..5f6643165 100644 --- a/examples/ae_examples/cvae_examples/conv_cvae_example/models.py +++ b/examples/ae_examples/cvae_examples/conv_cvae_example/models.py @@ -1,5 +1,3 @@ -from typing import Tuple - import torch import torch.nn as nn import torch.nn.functional as F @@ -23,7 +21,7 @@ def __init__( self.fc_mu = nn.Linear(64, latent_dim) self.fc_logvar = nn.Linear(64, latent_dim) - def forward(self, input: torch.Tensor, condition: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, input: torch.Tensor, condition: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: x = self.conv(input) # Flatten the tensor x = x.view(x.size(0), -1) diff --git a/examples/ae_examples/cvae_examples/conv_cvae_example/server.py b/examples/ae_examples/cvae_examples/conv_cvae_example/server.py index 0a6ed0b3c..0e6c5c16e 100644 --- a/examples/ae_examples/cvae_examples/conv_cvae_example/server.py +++ b/examples/ae_examples/cvae_examples/conv_cvae_example/server.py @@ -1,6 +1,6 @@ import argparse from functools import partial -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.typing import Config @@ -32,7 +32,7 @@ def fit_config( } -def main(config: Dict[str, Any]) -> None: +def main(config: dict[str, Any]) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/examples/ae_examples/cvae_examples/mlp_cvae_example/client.py b/examples/ae_examples/cvae_examples/mlp_cvae_example/client.py index 866197d54..23cf2ee40 100644 --- a/examples/ae_examples/cvae_examples/mlp_cvae_example/client.py +++ b/examples/ae_examples/cvae_examples/mlp_cvae_example/client.py @@ -1,6 +1,6 @@ import argparse +from collections.abc import Sequence from pathlib import Path -from typing import Sequence, Tuple import flwr as fl import torch @@ -44,7 +44,7 @@ def setup_client(self, config: Config) -> None: assert isinstance(self.model, ConditionalVae) self.model.unpack_input_condition = self.autoencoder_converter.get_unpacking_function() - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=100) # ToTensor transform is used to make sure pixels stay in the range [0.0, 1.0]. diff --git a/examples/ae_examples/cvae_examples/mlp_cvae_example/models.py b/examples/ae_examples/cvae_examples/mlp_cvae_example/models.py index f705881aa..50a8dbe8f 100644 --- a/examples/ae_examples/cvae_examples/mlp_cvae_example/models.py +++ b/examples/ae_examples/cvae_examples/mlp_cvae_example/models.py @@ -1,5 +1,3 @@ -from typing import Tuple - import torch import torch.nn as nn import torch.nn.functional as F @@ -19,7 +17,7 @@ def __init__( self.fc_mu = nn.Linear(256, latent_dim) self.fc_logvar = nn.Linear(256, latent_dim) - def forward(self, input: torch.Tensor, condition: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, input: torch.Tensor, condition: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: input = torch.cat((input, condition), dim=-1) x = F.relu(self.fc1(input)) x = F.relu(self.fc2(x)) diff --git a/examples/ae_examples/cvae_examples/mlp_cvae_example/server.py b/examples/ae_examples/cvae_examples/mlp_cvae_example/server.py index b4ab86523..4389c73d1 100644 --- a/examples/ae_examples/cvae_examples/mlp_cvae_example/server.py +++ b/examples/ae_examples/cvae_examples/mlp_cvae_example/server.py @@ -1,6 +1,6 @@ import argparse from functools import partial -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.typing import Config @@ -32,7 +32,7 @@ def fit_config( } -def main(config: Dict[str, Any]) -> None: +def main(config: dict[str, Any]) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/examples/ae_examples/fedprox_vae_example/client.py b/examples/ae_examples/fedprox_vae_example/client.py index 69a5e3bc3..066031621 100644 --- a/examples/ae_examples/fedprox_vae_example/client.py +++ b/examples/ae_examples/fedprox_vae_example/client.py @@ -1,6 +1,5 @@ import argparse from pathlib import Path -from typing import Tuple import flwr as fl import torch @@ -22,7 +21,7 @@ class VaeFedProxClient(FedProxClient): - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=100) # Flattening the input images to use an MLP-based variational autoencoder. diff --git a/examples/ae_examples/fedprox_vae_example/models.py b/examples/ae_examples/fedprox_vae_example/models.py index 1bae83e25..df50a52ec 100644 --- a/examples/ae_examples/fedprox_vae_example/models.py +++ b/examples/ae_examples/fedprox_vae_example/models.py @@ -1,5 +1,3 @@ -from typing import Tuple - import torch import torch.nn as nn import torch.nn.functional as F @@ -18,7 +16,7 @@ def __init__( self.fc_mu = nn.Linear(256, latent_dim) self.fc_logvar = nn.Linear(256, latent_dim) - def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: x = F.relu(self.fc1(input)) x = F.relu(self.fc2(x)) return self.fc_mu(x), self.fc_logvar(x) diff --git a/examples/ae_examples/fedprox_vae_example/server.py b/examples/ae_examples/fedprox_vae_example/server.py index 83c18afe1..ee5ca375b 100644 --- a/examples/ae_examples/fedprox_vae_example/server.py +++ b/examples/ae_examples/fedprox_vae_example/server.py @@ -1,6 +1,6 @@ import argparse from functools import partial -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.typing import Config @@ -31,7 +31,7 @@ def fit_config( } -def main(config: Dict[str, Any]) -> None: +def main(config: dict[str, Any]) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/examples/apfl_example/client.py b/examples/apfl_example/client.py index 75818a3d9..893adad54 100644 --- a/examples/apfl_example/client.py +++ b/examples/apfl_example/client.py @@ -1,6 +1,5 @@ import argparse from pathlib import Path -from typing import Dict, Tuple import flwr as fl import torch @@ -22,7 +21,7 @@ class MnistApflClient(ApflClient): - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75) train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size, sampler) @@ -31,7 +30,7 @@ def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: def get_model(self, config: Config) -> nn.Module: return ApflModule(MnistNetWithBnAndFrozen()).to(self.device) - def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: + def get_optimizer(self, config: Config) -> dict[str, Optimizer]: local_optimizer = torch.optim.AdamW(self.model.local_model.parameters(), lr=0.01) global_optimizer = torch.optim.AdamW(self.model.global_model.parameters(), lr=0.01) return {"local": local_optimizer, "global": global_optimizer} diff --git a/examples/apfl_example/server.py b/examples/apfl_example/server.py index 573d9bc8a..2aa950a71 100644 --- a/examples/apfl_example/server.py +++ b/examples/apfl_example/server.py @@ -1,6 +1,6 @@ import argparse from functools import partial -from typing import Any, Dict, Optional +from typing import Any import flwr as fl from flwr.common.typing import Config @@ -22,8 +22,8 @@ def fit_config( batch_size: int, n_server_rounds: int, current_round: int, - local_epochs: Optional[int] = None, - local_steps: Optional[int] = None, + local_epochs: int | None = None, + local_steps: int | None = None, ) -> Config: return { **make_dict_with_epochs_or_steps(local_epochs, local_steps), @@ -33,7 +33,7 @@ def fit_config( } -def main(config: Dict[str, Any]) -> None: +def main(config: dict[str, Any]) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/examples/basic_example/client.py b/examples/basic_example/client.py index 5d44e5db8..50ea582d0 100644 --- a/examples/basic_example/client.py +++ b/examples/basic_example/client.py @@ -1,6 +1,5 @@ import argparse from pathlib import Path -from typing import Optional, Tuple import flwr as fl import torch @@ -18,12 +17,12 @@ class CifarClient(BasicClient): - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) train_loader, val_loader, _ = load_cifar10_data(self.data_path, batch_size) return train_loader, val_loader - def get_test_data_loader(self, config: Config) -> Optional[DataLoader]: + def get_test_data_loader(self, config: Config) -> DataLoader | None: batch_size = narrow_dict_type(config, "batch_size", int) test_loader, _ = load_cifar10_test_data(self.data_path, batch_size) return test_loader diff --git a/examples/basic_example/server.py b/examples/basic_example/server.py index 040c35655..2472daed6 100644 --- a/examples/basic_example/server.py +++ b/examples/basic_example/server.py @@ -1,6 +1,6 @@ import argparse from functools import partial -from typing import Any, Dict, Optional +from typing import Any import flwr as fl from flwr.common.typing import Config @@ -21,8 +21,8 @@ def fit_config( batch_size: int, current_server_round: int, - local_epochs: Optional[int] = None, - local_steps: Optional[int] = None, + local_epochs: int | None = None, + local_steps: int | None = None, ) -> Config: return { **make_dict_with_epochs_or_steps(local_epochs, local_steps), @@ -31,7 +31,7 @@ def fit_config( } -def main(config: Dict[str, Any]) -> None: +def main(config: dict[str, Any]) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/examples/datasets/dataset_partitioners.py b/examples/datasets/dataset_partitioners.py index 7724cb324..5d9fc6369 100644 --- a/examples/datasets/dataset_partitioners.py +++ b/examples/datasets/dataset_partitioners.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from enum import Enum from pathlib import Path -from typing import Dict, List, Optional, cast +from typing import cast import numpy as np import pandas as pd @@ -21,27 +21,27 @@ def __init__(self, dataset_path: Path, partition_dir: Path) -> None: @abstractmethod def partition_dataset( - self, n_partiions: int, label_column_name: Optional[str] = None, label_map: Optional[Dict[int, str]] = None + self, n_partitions: int, label_column_name: str | None = None, label_map: dict[int, str] | None = None ) -> None: pass class JsonToPandasDatasetPartitioner(DatasetPartitioner): - def __init__(self, dataset_path: Path, partition_dir: Path, config: Dict[str, str]) -> None: + def __init__(self, dataset_path: Path, partition_dir: Path, config: dict[str, str]) -> None: self._parse_config(config) super().__init__(dataset_path, partition_dir) - def _parse_config(self, config: Dict[str, str]) -> None: + def _parse_config(self, config: dict[str, str]) -> None: if "json_lines" in config: self.json_lines = config["json_lines"] == "True" def partition_dataset( - self, n_partitions: int, label_column_name: Optional[str] = None, label_map: Optional[Dict[int, str]] = None + self, n_partitions: int, label_column_name: str | None = None, label_map: dict[int, str] | None = None ) -> None: df = pd.read_json(self.dataset_path, lines=self.json_lines) # Shuffle the dataframe rows df = df.sample(frac=1).reset_index(drop=True) - paritioned_dfs = cast(List[pd.DataFrame], np.array_split(df, n_partitions)) + paritioned_dfs = cast(list[pd.DataFrame], np.array_split(df, n_partitions)) for chunk, df in enumerate(paritioned_dfs): df.to_json( @@ -53,14 +53,14 @@ def partition_dataset( class CsvToPandasDatasetPartitioner(DatasetPartitioner): def partition_dataset( - self, n_partitions: int, label_column_name: Optional[str] = None, label_map: Optional[Dict[int, str]] = None + self, n_partitions: int, label_column_name: str | None = None, label_map: dict[int, str] | None = None ) -> None: df = pd.read_csv(self.dataset_path, names=["label", "title", "body"]) # Shuffle the dataframe rows df = df.sample(frac=1).reset_index(drop=True) if label_column_name and label_map: df["category"] = df[label_column_name].map(label_map) - paritioned_dfs = cast(List[pd.DataFrame], np.array_split(df, n_partitions)) + paritioned_dfs = cast(list[pd.DataFrame], np.array_split(df, n_partitions)) for chunk, df in enumerate(paritioned_dfs): df.to_json( @@ -71,7 +71,7 @@ def partition_dataset( def construct_dataset_partitioner( - dataset_path: Path, partition_dir: Path, config: Dict[str, str] + dataset_path: Path, partition_dir: Path, config: dict[str, str] ) -> DatasetPartitioner: data_loader_enum = DatasetPartitionerEnum(config["dataset_partitioner_type"]) if data_loader_enum == DatasetPartitionerEnum.JSON_TO_PANDAS: diff --git a/examples/ditto_example/client.py b/examples/ditto_example/client.py index b91c572d3..0a385755a 100644 --- a/examples/ditto_example/client.py +++ b/examples/ditto_example/client.py @@ -1,7 +1,6 @@ import argparse from logging import INFO from pathlib import Path -from typing import Dict, Tuple import flwr as fl import torch @@ -23,7 +22,7 @@ class MnistDittoClient(DittoClient): - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=1) batch_size = narrow_dict_type(config, "batch_size", int) train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size, sampler) @@ -32,7 +31,7 @@ def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: def get_model(self, config: Config) -> nn.Module: return MnistNet().to(self.device) - def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: + def get_optimizer(self, config: Config) -> dict[str, Optimizer]: # Note that the global optimizer operates on self.global_model.parameters() global_optimizer = torch.optim.AdamW(self.global_model.parameters(), lr=0.01) local_optimizer = torch.optim.AdamW(self.model.parameters(), lr=0.01) diff --git a/examples/ditto_example/server.py b/examples/ditto_example/server.py index 4137c68a1..0bf485ff3 100644 --- a/examples/ditto_example/server.py +++ b/examples/ditto_example/server.py @@ -1,6 +1,6 @@ import argparse from functools import partial -from typing import Any, Dict, Optional +from typing import Any import flwr as fl from flwr.common.typing import Config @@ -20,8 +20,8 @@ def fit_config( n_server_rounds: int, downsampling_ratio: float, current_round: int, - local_epochs: Optional[int] = None, - local_steps: Optional[int] = None, + local_epochs: int | None = None, + local_steps: int | None = None, ) -> Config: return { **make_dict_with_epochs_or_steps(local_epochs, local_steps), @@ -32,7 +32,7 @@ def fit_config( } -def main(config: Dict[str, Any]) -> None: +def main(config: dict[str, Any]) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/examples/docker_basic_example/README.md b/examples/docker_basic_example/README.md index 450d34734..ea53a2548 100644 --- a/examples/docker_basic_example/README.md +++ b/examples/docker_basic_example/README.md @@ -5,7 +5,7 @@ In order to run the demo, first ensure that Docker Desktop is running. Instructi ``` docker compose up ``` -This will initiate the services specified in the file `docker-compose.yml`. Namely, the fl_server and fl_client services are built and run according to the Dockerfiles in the `fl_server` and `fl_client` directories, respectively. Each of these directories also include a `requirement.txt` file seperate from the `requirement.txt` in the root of the repository. These files include the python packages required to run the respective containers. +This will initiate the services specified in the file `docker-compose.yml`. Namely, the fl_server and fl_client services are built and run according to the Dockerfiles in the `fl_server` and `fl_client` directories, respectively. Each of these directories also include a `requirement.txt` file separate from the `requirement.txt` in the root of the repository. These files include the python packages required to run the respective containers. A config.yaml must be present in the root of this directory with the following arguments: * `n_clients`: number of clients the server waits for in order to run the FL training diff --git a/examples/docker_basic_example/fl_client/client.py b/examples/docker_basic_example/fl_client/client.py index 678e138f6..918e1b89b 100644 --- a/examples/docker_basic_example/fl_client/client.py +++ b/examples/docker_basic_example/fl_client/client.py @@ -1,6 +1,6 @@ import argparse +from collections.abc import Sequence from pathlib import Path -from typing import Sequence, Tuple import flwr as fl import torch @@ -24,7 +24,7 @@ def __init__(self, data_path: Path, metrics: Sequence[Metric], device: torch.dev self.model = Net() self.parameter_exchanger = FullParameterExchanger() - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) train_loader, val_loader, _ = load_cifar10_data(self.data_path, batch_size) return train_loader, val_loader diff --git a/examples/docker_basic_example/fl_server/server.py b/examples/docker_basic_example/fl_server/server.py index eed311a3f..34bc59ed3 100644 --- a/examples/docker_basic_example/fl_server/server.py +++ b/examples/docker_basic_example/fl_server/server.py @@ -1,6 +1,6 @@ import argparse from functools import partial -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.typing import Config @@ -24,7 +24,7 @@ def fit_config( } -def main(config: Dict[str, Any]) -> None: +def main(config: dict[str, Any]) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/examples/dp_fed_examples/client_level_dp/README.md b/examples/dp_fed_examples/client_level_dp/README.md index 57f544eac..e89080659 100644 --- a/examples/dp_fed_examples/client_level_dp/README.md +++ b/examples/dp_fed_examples/client_level_dp/README.md @@ -1,6 +1,6 @@ # Client Level Differential Privacy Federated Learning Example -This example shows how to implement Differential Privacy into the Federated Learning framework. In this case we focus on *client level* privacy which is a more substantial version of instance level DP, where the participation of an entire client's set of data is protected from training dataset membership inference. This example uses the FedAvgM implementation with unweighted averaging (To be implemented) suggested in Differentially Private Learning with Adaptive Clipping. The example uses an accountant specifically tailered to this approach. The clients are Poisson sampled by default. +This example shows how to implement Differential Privacy into the Federated Learning framework. In this case we focus on *client level* privacy which is a more substantial version of instance level DP, where the participation of an entire client's set of data is protected from training dataset membership inference. This example uses the FedAvgM implementation with unweighted averaging (To be implemented) suggested in Differentially Private Learning with Adaptive Clipping. The example uses an accountant specifically tailored to this approach. The clients are Poisson sampled by default. ## Running the Example In order to run the example, first ensure you have [installed the dependencies in your virtual environment according to the main README](/README.md#development-requirements) and it has been activated. diff --git a/examples/dp_fed_examples/client_level_dp/client.py b/examples/dp_fed_examples/client_level_dp/client.py index 090ad6013..948058285 100644 --- a/examples/dp_fed_examples/client_level_dp/client.py +++ b/examples/dp_fed_examples/client_level_dp/client.py @@ -1,6 +1,5 @@ import argparse from pathlib import Path -from typing import Tuple import flwr as fl import torch @@ -18,7 +17,7 @@ class CifarClient(NumpyClippingClient): - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) train_loader, val_loader, _ = load_cifar10_data(self.data_path, batch_size) return train_loader, val_loader diff --git a/examples/dp_fed_examples/client_level_dp/server.py b/examples/dp_fed_examples/client_level_dp/server.py index 01e32b92f..3fb2ca448 100644 --- a/examples/dp_fed_examples/client_level_dp/server.py +++ b/examples/dp_fed_examples/client_level_dp/server.py @@ -1,6 +1,6 @@ import argparse from functools import partial -from typing import Any, Dict, Optional +from typing import Any import flwr as fl from flwr.common.typing import Config @@ -19,8 +19,8 @@ def construct_config( current_round: int, batch_size: int, adaptive_clipping: bool, - local_epochs: Optional[int] = None, - local_steps: Optional[int] = None, + local_epochs: int | None = None, + local_steps: int | None = None, ) -> Config: # NOTE: The omitted variable is server_round which allows for dynamically changing the config each round return { @@ -35,13 +35,13 @@ def fit_config( batch_size: int, adaptive_clipping: bool, current_round: int, - local_epochs: Optional[int] = None, - local_steps: Optional[int] = None, + local_epochs: int | None = None, + local_steps: int | None = None, ) -> Config: return construct_config(current_round, batch_size, adaptive_clipping, local_epochs, local_steps) -def main(config: Dict[str, Any]) -> None: +def main(config: dict[str, Any]) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/examples/dp_fed_examples/client_level_dp_weighted/README.md b/examples/dp_fed_examples/client_level_dp_weighted/README.md index cb2d229bd..39ff65600 100644 --- a/examples/dp_fed_examples/client_level_dp_weighted/README.md +++ b/examples/dp_fed_examples/client_level_dp_weighted/README.md @@ -1,6 +1,6 @@ # Client Level Differential Privacy Federated Learning Example -This example shows how to implement Differential Privacy into the Federated Learning framework. In this case we focus on *client level* privacy which is a more substantial version of instance level DP, where the participation of an entire client's set of data is protected from training dataset membership inference. This example uses the FedAvgM implementation with weighted averaging suggested in Learning Differentially Private Recurrent Language Models along with the adaptive clipping scheme proposed in Differentially Private Learning with Adaptive Clipping. The example uses an accountant specifically tailered to this approach. The clients are Poisson sampled by default. +This example shows how to implement Differential Privacy into the Federated Learning framework. In this case we focus on *client level* privacy which is a more substantial version of instance level DP, where the participation of an entire client's set of data is protected from training dataset membership inference. This example uses the FedAvgM implementation with weighted averaging suggested in Learning Differentially Private Recurrent Language Models along with the adaptive clipping scheme proposed in Differentially Private Learning with Adaptive Clipping. The example uses an accountant specifically tailored to this approach. The clients are Poisson sampled by default. The example involves collaboratively learning a logistic regression model across multiple hospitals to classify breast cancer given 31 features. The dataset is sourced from [kaggle](https://www.kaggle.com/competitions/breast-cancer-classification/overview). A processed federated version of the dataset is available in the repository. diff --git a/examples/dp_fed_examples/client_level_dp_weighted/client.py b/examples/dp_fed_examples/client_level_dp_weighted/client.py index 8d2130894..ff3d4f2ae 100644 --- a/examples/dp_fed_examples/client_level_dp_weighted/client.py +++ b/examples/dp_fed_examples/client_level_dp_weighted/client.py @@ -1,6 +1,5 @@ import argparse from pathlib import Path -from typing import Tuple import flwr as fl import torch @@ -21,7 +20,7 @@ class HospitalClient(NumpyClippingClient): def get_model(self, config: Config) -> nn.Module: return LogisticRegression(input_dim=31, output_dim=1).to(self.device) - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) scaler_bytes = narrow_dict_type(config, "scaler", bytes) train_loader, val_loader, _ = load_data(self.data_path, batch_size, scaler_bytes) diff --git a/examples/dp_fed_examples/client_level_dp_weighted/data.py b/examples/dp_fed_examples/client_level_dp_weighted/data.py index 827d74535..a72f8c904 100644 --- a/examples/dp_fed_examples/client_level_dp_weighted/data.py +++ b/examples/dp_fed_examples/client_level_dp_weighted/data.py @@ -1,6 +1,5 @@ import pickle from pathlib import Path -from typing import Dict, Tuple import numpy as np import pandas as pd @@ -13,13 +12,13 @@ class Scaler: def __init__(self) -> None: self.scaler = MinMaxScaler() - def __call__(self, train_x: np.ndarray, val_x: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + def __call__(self, train_x: np.ndarray, val_x: np.ndarray) -> tuple[np.ndarray, np.ndarray]: scaled_train_x = self.scaler.fit_transform(train_x) scaled_val_x = self.scaler.transform(val_x) return scaled_train_x, scaled_val_x -def load_data(data_dir: Path, batch_size: int, scaler_bytes: bytes) -> Tuple[DataLoader, DataLoader, Dict[str, int]]: +def load_data(data_dir: Path, batch_size: int, scaler_bytes: bytes) -> tuple[DataLoader, DataLoader, dict[str, int]]: data = pd.read_csv(data_dir, index_col=False) features = data.loc[:, data.columns != "label"].values labels = data["label"].values diff --git a/examples/dp_fed_examples/client_level_dp_weighted/server.py b/examples/dp_fed_examples/client_level_dp_weighted/server.py index 033a85c4b..32236a635 100644 --- a/examples/dp_fed_examples/client_level_dp_weighted/server.py +++ b/examples/dp_fed_examples/client_level_dp_weighted/server.py @@ -1,7 +1,7 @@ import argparse import pickle from functools import partial -from typing import Any, Dict, Optional +from typing import Any import flwr as fl from flwr.common.typing import Config @@ -21,8 +21,8 @@ def construct_config( current_round: int, batch_size: int, adaptive_clipping: bool, - local_epochs: Optional[int] = None, - local_steps: Optional[int] = None, + local_epochs: int | None = None, + local_steps: int | None = None, ) -> Config: # NOTE: The omitted variable is server_round which allows for dynamically changing the config each round return { @@ -38,13 +38,13 @@ def fit_config( batch_size: int, adaptive_clipping: bool, server_round: int, - local_epochs: Optional[int] = None, - local_steps: Optional[int] = None, + local_epochs: int | None = None, + local_steps: int | None = None, ) -> Config: return construct_config(server_round, batch_size, adaptive_clipping, local_epochs, local_steps) -def main(config: Dict[str, Any]) -> None: +def main(config: dict[str, Any]) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/examples/dp_fed_examples/instance_level_dp/client.py b/examples/dp_fed_examples/instance_level_dp/client.py index b1dfa2e22..54f057037 100644 --- a/examples/dp_fed_examples/instance_level_dp/client.py +++ b/examples/dp_fed_examples/instance_level_dp/client.py @@ -2,7 +2,6 @@ import string from pathlib import Path from random import choices -from typing import Tuple import flwr as fl import torch @@ -22,7 +21,7 @@ class CifarClient(InstanceLevelDpClient): - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) train_loader, val_loader, _ = load_cifar10_data(self.data_path, batch_size) return train_loader, val_loader diff --git a/examples/dp_fed_examples/instance_level_dp/server.py b/examples/dp_fed_examples/instance_level_dp/server.py index f13800116..a7068ffdb 100644 --- a/examples/dp_fed_examples/instance_level_dp/server.py +++ b/examples/dp_fed_examples/instance_level_dp/server.py @@ -2,7 +2,7 @@ import string from functools import partial from random import choices -from typing import Any, Dict, Optional +from typing import Any import flwr as fl from flwr.common.typing import Config @@ -25,8 +25,8 @@ def construct_config( batch_size: int, noise_multiplier: float, clipping_bound: float, - local_epochs: Optional[int] = None, - local_steps: Optional[int] = None, + local_epochs: int | None = None, + local_steps: int | None = None, ) -> Config: # NOTE: a new client is created in each round # NOTE: The omitted variable is server_round which allows for dynamically changing the config each round @@ -44,8 +44,8 @@ def fit_config( noise_multiplier: float, clipping_bound: float, current_round: int, - local_epochs: Optional[int] = None, - local_steps: Optional[int] = None, + local_epochs: int | None = None, + local_steps: int | None = None, ) -> Config: return construct_config( current_round, @@ -57,7 +57,7 @@ def fit_config( ) -def main(config: Dict[str, Any]) -> None: +def main(config: dict[str, Any]) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/examples/dp_scaffold_example/README.md b/examples/dp_scaffold_example/README.md index 57d208897..d965a1429 100644 --- a/examples/dp_scaffold_example/README.md +++ b/examples/dp_scaffold_example/README.md @@ -1,6 +1,6 @@ # DP-SCAFFOLD Federated Learning Example This is an example of [Differentially Private Federated Learning on Heterogeneous Data -](https://arxiv.org/abs/2111.09278)(DP-SCAFFOLD). DP-SCAFFOLD is a differentially private adaption of SCAFFOLD - a method that corrects for client drift during the optimization procedure. In particular, DP-SCAFFOLD offers instance level privacy towards the server or a third party with access to the final model. At a given level of noise, DP-SCAFFOLD offers the same privacy guarentees as DP-FedAvg while offering better convergence. We leverage Opacus DP-SGD algorithm to impose DP guarantees and accounting is done using an instance-level privacy accountants. +](https://arxiv.org/abs/2111.09278)(DP-SCAFFOLD). DP-SCAFFOLD is a differentially private adaption of SCAFFOLD - a method that corrects for client drift during the optimization procedure. In particular, DP-SCAFFOLD offers instance level privacy towards the server or a third party with access to the final model. At a given level of noise, DP-SCAFFOLD offers the same privacy guarantees as DP-FedAvg while offering better convergence. We leverage Opacus DP-SGD algorithm to impose DP guarantees and accounting is done using an instance-level privacy accountants. In this demo, DP-SCAFFOLD is applied to an augmented version of the MNIST dataset that is non--IID. The FL server expects three clients to be spun up (i.e. it will wait until three clients report in before starting training). Each client has a modified version of the MNIST dataset. This modification essentially subsamples a certain number from the original training and validation sets of MNIST in order to synthetically induce local variations in the statistical properties of the clients training/validation data. In theory, the models should be able to perform well on their local data while learning from other clients data that has different statistical properties. The proportion of labels at each client is determined by Dirichlet distribution across the classes. The lower the beta parameter is for each class, the higher the degree of the label heterogeneity. diff --git a/examples/dp_scaffold_example/client.py b/examples/dp_scaffold_example/client.py index f9e6f072d..1d8626b64 100644 --- a/examples/dp_scaffold_example/client.py +++ b/examples/dp_scaffold_example/client.py @@ -1,6 +1,5 @@ import argparse from pathlib import Path -from typing import Tuple import flwr as fl import torch @@ -19,7 +18,7 @@ class MnistDPScaffoldClient(DPScaffoldClient): - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=1.0) train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size, sampler) diff --git a/examples/dp_scaffold_example/server.py b/examples/dp_scaffold_example/server.py index 238bc3bb7..95565832c 100644 --- a/examples/dp_scaffold_example/server.py +++ b/examples/dp_scaffold_example/server.py @@ -1,6 +1,6 @@ import argparse from functools import partial -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.typing import Config @@ -32,7 +32,7 @@ def fit_config( } -def main(config: Dict[str, Any]) -> None: +def main(config: dict[str, Any]) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/examples/dynamic_layer_exchange_example/client.py b/examples/dynamic_layer_exchange_example/client.py index 9f2ff697d..da727b5b7 100644 --- a/examples/dynamic_layer_exchange_example/client.py +++ b/examples/dynamic_layer_exchange_example/client.py @@ -1,6 +1,5 @@ import argparse from pathlib import Path -from typing import Tuple import flwr as fl import torch @@ -22,7 +21,7 @@ class CifarDynamicLayerClient(PartialWeightExchangeClient): - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) sample_percentage = narrow_dict_type(config, "sample_percentage", float) beta = narrow_dict_type(config, "beta", float) diff --git a/examples/dynamic_layer_exchange_example/server.py b/examples/dynamic_layer_exchange_example/server.py index 8324ca313..eff4b2c1b 100644 --- a/examples/dynamic_layer_exchange_example/server.py +++ b/examples/dynamic_layer_exchange_example/server.py @@ -1,6 +1,6 @@ import argparse from functools import partial -from typing import Any, Dict, Optional +from typing import Any import flwr as fl from flwr.common.typing import Config @@ -25,8 +25,8 @@ def fit_config( sample_percentage: float, beta: float, current_server_round: int, - local_epochs: Optional[int] = None, - local_steps: Optional[int] = None, + local_epochs: int | None = None, + local_steps: int | None = None, ) -> Config: config: Config = { **make_dict_with_epochs_or_steps(local_epochs, local_steps), @@ -43,7 +43,7 @@ def fit_config( return config -def main(config: Dict[str, Any]) -> None: +def main(config: dict[str, Any]) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/examples/ensemble_example/client.py b/examples/ensemble_example/client.py index def6af758..c5348258b 100644 --- a/examples/ensemble_example/client.py +++ b/examples/ensemble_example/client.py @@ -1,6 +1,5 @@ import argparse from pathlib import Path -from typing import Dict, Tuple import flwr as fl import torch @@ -20,22 +19,22 @@ class MnistEnsembleClient(EnsembleClient): - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=float(config["sample_percentage"])) train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size, sampler=sampler) return train_loader, val_loader def get_model(self, config: Config) -> EnsembleModel: - ensemble_models: Dict[str, nn.Module] = { + ensemble_models: dict[str, nn.Module] = { "model_0": ConfigurableMnistNet(out_channel_mult=1).to(self.device), "model_1": ConfigurableMnistNet(out_channel_mult=2).to(self.device), "model_2": ConfigurableMnistNet(out_channel_mult=3).to(self.device), } return EnsembleModel(ensemble_models) - def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: - ensemble_optimizers: Dict[str, torch.optim.Optimizer] = { + def get_optimizer(self, config: Config) -> dict[str, Optimizer]: + ensemble_optimizers: dict[str, torch.optim.Optimizer] = { "model_0": torch.optim.AdamW(self.model.ensemble_models["model_0"].parameters(), lr=0.01), "model_1": torch.optim.AdamW(self.model.ensemble_models["model_1"].parameters(), lr=0.01), "model_2": torch.optim.AdamW(self.model.ensemble_models["model_2"].parameters(), lr=0.01), diff --git a/examples/ensemble_example/server.py b/examples/ensemble_example/server.py index 539aabc33..c3a8c623a 100644 --- a/examples/ensemble_example/server.py +++ b/examples/ensemble_example/server.py @@ -1,6 +1,6 @@ import argparse from functools import partial -from typing import Any, Dict +from typing import Any import flwr as fl import torch.nn as nn @@ -30,7 +30,7 @@ def fit_config( } -def main(config: Dict[str, Any]) -> None: +def main(config: dict[str, Any]) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, @@ -40,7 +40,7 @@ def main(config: Dict[str, Any]) -> None: config["n_server_rounds"], ) - ensemble_models: Dict[str, nn.Module] = { + ensemble_models: dict[str, nn.Module] = { "model_0": ConfigurableMnistNet(out_channel_mult=1), "model_1": ConfigurableMnistNet(out_channel_mult=2), "model_2": ConfigurableMnistNet(out_channel_mult=3), diff --git a/examples/feature_alignment_example/client.py b/examples/feature_alignment_example/client.py index cd2e88b85..ba2f42954 100644 --- a/examples/feature_alignment_example/client.py +++ b/examples/feature_alignment_example/client.py @@ -1,7 +1,7 @@ import argparse +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import List, Sequence, Tuple, Union import flwr as fl import numpy as np @@ -28,11 +28,11 @@ def __init__( metrics: Sequence[Metric], device: torch.device, id_column: str, - targets: Union[str, List[str]], + targets: str | list[str], ) -> None: super().__init__(data_path, metrics, device, id_column, targets) - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) # random train-valid split. indices = np.random.permutation(self.aligned_features.shape[0]) diff --git a/examples/feature_alignment_example/misalign_data.py b/examples/feature_alignment_example/misalign_data.py index ecf614a72..6a327fd4c 100644 --- a/examples/feature_alignment_example/misalign_data.py +++ b/examples/feature_alignment_example/misalign_data.py @@ -1,13 +1,12 @@ import random from logging import INFO from pathlib import Path -from typing import List import pandas as pd from flwr.common.logger import log -def random_split_data(df: pd.DataFrame, n: int) -> List[pd.DataFrame]: +def random_split_data(df: pd.DataFrame, n: int) -> list[pd.DataFrame]: df_rand = df.sample(frac=1.0, random_state=42) num_rows_per_df = len(df_rand) // n smaller_dfs = [df_rand.iloc[i * num_rows_per_df : (i + 1) * num_rows_per_df] for i in range(n - 1)] diff --git a/examples/feature_alignment_example/server.py b/examples/feature_alignment_example/server.py index 662507001..6756bcbe0 100644 --- a/examples/feature_alignment_example/server.py +++ b/examples/feature_alignment_example/server.py @@ -1,6 +1,6 @@ import argparse from pathlib import Path -from typing import Any, Dict +from typing import Any import flwr as fl import pandas as pd @@ -34,7 +34,7 @@ def construct_tab_feature_info_encoder( return TabularFeaturesInfoEncoder.encoder_from_dataframe(df, id_column, target_column) -def main(config: Dict[str, Any]) -> None: +def main(config: dict[str, Any]) -> None: client_manager = PoissonSamplingClientManager() strategy = BasicFedAvg( min_fit_clients=config["n_clients"], diff --git a/examples/fedbn_example/client.py b/examples/fedbn_example/client.py index 30658a1a8..93b1c365e 100644 --- a/examples/fedbn_example/client.py +++ b/examples/fedbn_example/client.py @@ -1,7 +1,7 @@ import argparse +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Sequence, Tuple import flwr as fl import torch @@ -24,7 +24,7 @@ class MnistFedBNClient(BasicClient): - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=1) train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size, sampler) @@ -49,7 +49,7 @@ def __init__(self, data_path: Path, metrics: Sequence[Metric], device: torch.dev super().__init__(data_path, metrics, device) self.dataset_name = dataset_name - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) train_loader, val_loader, _, _ = load_skin_cancer_data(self.data_path, self.dataset_name, batch_size) return train_loader, val_loader diff --git a/examples/fedbn_example/server.py b/examples/fedbn_example/server.py index 4214ad28a..46a9aad01 100644 --- a/examples/fedbn_example/server.py +++ b/examples/fedbn_example/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict, Optional +from typing import Any import flwr as fl import torch.nn as nn @@ -22,8 +22,8 @@ def fit_config( batch_size: int, n_server_rounds: int, current_round: int, - local_epochs: Optional[int] = None, - local_steps: Optional[int] = None, + local_epochs: int | None = None, + local_steps: int | None = None, ) -> Config: return { **make_dict_with_epochs_or_steps(local_epochs, local_steps), @@ -33,7 +33,7 @@ def fit_config( } -def main(config: Dict[str, Any], server_address: str, dataset_name: str) -> None: +def main(config: dict[str, Any], server_address: str, dataset_name: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/examples/feddg_ga_example/client.py b/examples/feddg_ga_example/client.py index 1a84fbee1..9908ef80c 100644 --- a/examples/feddg_ga_example/client.py +++ b/examples/feddg_ga_example/client.py @@ -1,6 +1,5 @@ import argparse from pathlib import Path -from typing import Dict, Tuple import flwr as fl import torch @@ -22,7 +21,7 @@ class MnistApflClient(ApflClient): - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75) train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size, sampler) @@ -31,7 +30,7 @@ def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: def get_model(self, config: Config) -> nn.Module: return ApflModule(MnistNetWithBnAndFrozen()).to(self.device) - def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: + def get_optimizer(self, config: Config) -> dict[str, Optimizer]: local_optimizer = torch.optim.AdamW(self.model.local_model.parameters(), lr=0.01) global_optimizer = torch.optim.AdamW(self.model.global_model.parameters(), lr=0.01) return {"local": local_optimizer, "global": global_optimizer} diff --git a/examples/feddg_ga_example/server.py b/examples/feddg_ga_example/server.py index a235069b8..7d7c2775b 100644 --- a/examples/feddg_ga_example/server.py +++ b/examples/feddg_ga_example/server.py @@ -1,6 +1,6 @@ import argparse from functools import partial -from typing import Any, Dict, Optional +from typing import Any import flwr as fl from flwr.common.typing import Config @@ -22,8 +22,8 @@ def fit_config( batch_size: int, n_server_rounds: int, current_round: int, - local_epochs: Optional[int] = None, - local_steps: Optional[int] = None, + local_epochs: int | None = None, + local_steps: int | None = None, evaluate_after_fit: bool = False, pack_losses_with_val_metrics: bool = False, ) -> Config: @@ -37,7 +37,7 @@ def fit_config( } -def main(config: Dict[str, Any]) -> None: +def main(config: dict[str, Any]) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/examples/federated_eval_example/client.py b/examples/federated_eval_example/client.py index a3f1190c3..63e28ba59 100644 --- a/examples/federated_eval_example/client.py +++ b/examples/federated_eval_example/client.py @@ -1,6 +1,6 @@ import argparse +from collections.abc import Sequence from pathlib import Path -from typing import Optional, Sequence, Tuple import flwr as fl import torch @@ -24,7 +24,7 @@ def __init__( data_path: Path, metrics: Sequence[Metric], device: torch.device, - model_checkpoint_path: Optional[Path], + model_checkpoint_path: Path | None, reporters: Sequence[BaseReporter] | None = None, ) -> None: super().__init__( @@ -36,11 +36,11 @@ def __init__( reporters=reporters, ) - def initialize_global_model(self, config: Config) -> Optional[nn.Module]: + def initialize_global_model(self, config: Config) -> nn.Module | None: # Initialized a global model to be hydrated with a server-side model if the parameters are passed return Net().to(self.device) - def get_data_loader(self, config: Config) -> Tuple[DataLoader]: + def get_data_loader(self, config: Config) -> tuple[DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) evaluation_loader, _ = load_cifar10_test_data(self.data_path, batch_size) return (evaluation_loader,) diff --git a/examples/federated_eval_example/server.py b/examples/federated_eval_example/server.py index 6d2511ace..9423513dc 100644 --- a/examples/federated_eval_example/server.py +++ b/examples/federated_eval_example/server.py @@ -1,6 +1,6 @@ import argparse from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any import flwr as fl @@ -10,7 +10,7 @@ from fl4health.utils.metric_aggregation import uniform_evaluate_metrics_aggregation_fn -def main(config: Dict[str, Any], server_checkpoint_path: Optional[Path]) -> None: +def main(config: dict[str, Any], server_checkpoint_path: Path | None) -> None: evaluate_config = {"batch_size": config["batch_size"]} # ClientManager that performs Poisson type sampling diff --git a/examples/fedopt_example/client.py b/examples/fedopt_example/client.py index 81cd90886..9e10c00c9 100644 --- a/examples/fedopt_example/client.py +++ b/examples/fedopt_example/client.py @@ -1,6 +1,6 @@ import argparse +from collections.abc import Sequence from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple import flwr as fl import torch @@ -27,7 +27,7 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, ) -> None: super().__init__(data_path, metrics, device, loss_meter_type, checkpoint_and_state_module) self.weight_matrix: torch.Tensor @@ -35,7 +35,7 @@ def __init__( self.label_encoder: LabelEncoder self.batch_size: int - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: sequence_length = narrow_dict_type(config, "sequence_length", int) self.batch_size = narrow_dict_type(config, "batch_size", int) # NOTE: self.vocabulary and self.label_encoder are initialized in setup_client before the call to @@ -61,7 +61,7 @@ def get_model(self, config: Config) -> nn.Module: def setup_client(self, config: Config) -> None: self.vocabulary = Vocabulary.from_json(narrow_dict_type(config, "vocabulary", str)) self.label_encoder = LabelEncoder.from_json(narrow_dict_type(config, "label_encoder", str)) - # Since the label_encoder is required for CompundMetric but it is not available until after we receive + # Since the label_encoder is required for CompoundMetric but it is not available until after we receive # it from the Server, we pass it to the CompoundMetric through the CompoundMetric._setup method once its # available for metric in self.metrics: @@ -72,13 +72,13 @@ def setup_client(self, config: Config) -> None: def predict( self, input: TorchInputType, - ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: + ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: """ Computes the prediction(s), and potentially features, of the model(s) given the input. Args: input (TorchInputType): the input to self.model's forward pass. TorchInputType is simply an alias - for the union of torch.Tensor and Dict[str, torch.Tensor]. + for the union of torch.Tensor and dict[str, torch.Tensor]. """ # While this isn't optimal, this is a good example of a custom predict function to manipulate the predictions assert isinstance(self.model, LSTM) and isinstance(input, torch.Tensor) diff --git a/examples/fedopt_example/client_data.py b/examples/fedopt_example/client_data.py index e0117dba7..43cfa9b15 100644 --- a/examples/fedopt_example/client_data.py +++ b/examples/fedopt_example/client_data.py @@ -2,7 +2,6 @@ import json from pathlib import Path -from typing import Dict, List, Optional, Tuple import nltk import numpy as np @@ -16,7 +15,7 @@ class LabelEncoder: - def __init__(self, classes: List[str], label_to_class: Dict[int, str], class_to_label: Dict[str, int]) -> None: + def __init__(self, classes: list[str], label_to_class: dict[int, str], class_to_label: dict[str, int]) -> None: self.classes = classes self.label_to_class = label_to_class self.class_to_label = class_to_label @@ -55,7 +54,7 @@ def label_dataframe(self, df: pd.DataFrame, class_column: str) -> pd.DataFrame: class Vocabulary: - def __init__(self, vocabulary_dict: Optional[Dict[str, int]], train_set: Optional[List[List[str]]]) -> None: + def __init__(self, vocabulary_dict: dict[str, int] | None, train_set: list[list[str]] | None) -> None: if vocabulary_dict is not None: self.word2index = vocabulary_dict elif train_set is not None: @@ -64,7 +63,7 @@ def __init__(self, vocabulary_dict: Optional[Dict[str, int]], train_set: Optiona raise ValueError("Must provide either precomputed dictionary or training set to create vocabulary") self.vocabulary_size = len(self.word2index) - def _create_vocabulary(self, train_set: List[List[str]]) -> None: + def _create_vocabulary(self, train_set: list[list[str]]) -> None: word2index = {"": 0, "": 1, "": 2, "": 3} current_index = 4 for tokenized_text in train_set: @@ -74,7 +73,7 @@ def _create_vocabulary(self, train_set: List[List[str]]) -> None: current_index += 1 self.word2index = word2index - def encode_and_pad(self, tokenized_text: List[str], seq_length: int) -> List[int]: + def encode_and_pad(self, tokenized_text: list[str], seq_length: int) -> list[int]: sos = [self.word2index[""]] eos = [self.word2index[""]] pad = [self.word2index[""]] @@ -101,7 +100,7 @@ def from_json(json_str: str) -> Vocabulary: return Vocabulary(json.loads(json_str), None) -def tokenize_labeled_text(df: pd.DataFrame) -> List[Tuple[int, List[str]]]: +def tokenize_labeled_text(df: pd.DataFrame) -> list[tuple[int, list[str]]]: # Assumes the dataframe has two columns (label and text to be tokenized) return [(label, word_tokenize(text)) for label, text in list(df.to_records(index=False))] @@ -118,7 +117,7 @@ def create_weight_matrix(train_df: pd.DataFrame) -> torch.Tensor: def construct_dataloaders( path: Path, vocabulary: Vocabulary, label_encoder: LabelEncoder, sequence_length: int, batch_size: int -) -> Tuple[DataLoader, DataLoader, Dict[str, int], torch.Tensor]: +) -> tuple[DataLoader, DataLoader, dict[str, int], torch.Tensor]: df = get_local_data(path) # lower case the headlines and description and concatenate df["article_text"] = df["title"].str.lower() + " " + df["body"].str.lower() diff --git a/examples/fedopt_example/metrics.py b/examples/fedopt_example/metrics.py index b0a37905c..dc99fcd05 100644 --- a/examples/fedopt_example/metrics.py +++ b/examples/fedopt_example/metrics.py @@ -2,7 +2,6 @@ import json from logging import INFO -from typing import Dict, List, Optional import torch from flwr.common.logger import log @@ -37,7 +36,7 @@ def get_f1(self) -> float: return 0.0 return (2 * precision * recall) / (precision + recall) - def summarize(self) -> Dict[str, float]: + def summarize(self) -> dict[str, float]: return { f"{self.class_name}_precision": self.get_precision(), f"{self.class_name}_recall": self.get_recall(), @@ -63,7 +62,7 @@ def merge_outcomes(outcome_1: "Outcome", outcome_2: "Outcome") -> Outcome: class ServerMetrics: - def __init__(self, true_preds: int, total_preds: int, outcomes: List[Outcome]) -> None: + def __init__(self, true_preds: int, total_preds: int, outcomes: list[Outcome]) -> None: self.true_preds = true_preds self.total_preds = total_preds self.outcomes = outcomes @@ -91,10 +90,10 @@ def __init__(self, name: str) -> None: super().__init__(name) self.true_preds = 0 self.total_preds = 0 - self.classes: List[str] - self.label_to_class: Dict[int, str] + self.classes: list[str] + self.label_to_class: dict[int, str] self.n_classes: int - self.outcome_dict: Dict[str, Outcome] + self.outcome_dict: dict[str, Outcome] def setup(self, label_encoder: LabelEncoder) -> None: """ @@ -109,7 +108,7 @@ def setup(self, label_encoder: LabelEncoder) -> None: self.label_to_class = label_encoder.label_to_class self.n_classes = len(self.classes) - def _initialize_outcomes(self, classes: List[str]) -> Dict[str, Outcome]: + def _initialize_outcomes(self, classes: list[str]) -> dict[str, Outcome]: return {topic: Outcome(topic) for topic in classes} def update(self, input: torch.Tensor, target: torch.Tensor) -> None: @@ -132,7 +131,7 @@ def update(self, input: torch.Tensor, target: torch.Tensor) -> None: self.outcome_dict[true_class].false_negative += count self.outcome_dict[pred_class].false_positive += count - def compute(self, name: Optional[str]) -> Metrics: + def compute(self, name: str | None) -> Metrics: sum_f1 = 0.0 results: Metrics = {"total_preds": self.total_preds, "true_preds": self.true_preds} log_string = "" diff --git a/examples/fedopt_example/server.py b/examples/fedopt_example/server.py index cffafcda6..ba91b197f 100644 --- a/examples/fedopt_example/server.py +++ b/examples/fedopt_example/server.py @@ -3,7 +3,7 @@ from functools import partial from logging import INFO from pathlib import Path -from typing import Any, Dict, List, Tuple +from typing import Any import flwr as fl from flwr.common.logger import log @@ -20,10 +20,10 @@ from fl4health.utils.parameter_extraction import get_all_model_parameters -def metric_aggregation(all_client_metrics: List[Tuple[int, Metrics]]) -> Metrics: +def metric_aggregation(all_client_metrics: list[tuple[int, Metrics]]) -> Metrics: total_preds = 0 true_preds = 0 - outcome_dict: Dict[str, Outcome] = {} + outcome_dict: dict[str, Outcome] = {} # Run through all of the metrics for _, client_metrics in all_client_metrics: for metric_name, metric_value in client_metrics.items(): @@ -93,7 +93,7 @@ def fit_config( ) -def pretrain_vocabulary(path: Path) -> Tuple[Vocabulary, LabelEncoder]: +def pretrain_vocabulary(path: Path) -> tuple[Vocabulary, LabelEncoder]: df = get_local_data(path) # Drop 20% of the texts to artificially create some UNK tokens processed_df, _ = train_test_split(df, test_size=0.8) @@ -103,7 +103,7 @@ def pretrain_vocabulary(path: Path) -> Tuple[Vocabulary, LabelEncoder]: return Vocabulary(None, headline_text + body_text), label_encoder -def main(config: Dict[str, Any]) -> None: +def main(config: dict[str, Any]) -> None: log(INFO, "Fitting vocabulary to a centralized text sample") data_path = Path( os.path.join( diff --git a/examples/fedpca_examples/dim_reduction/README.md b/examples/fedpca_examples/dim_reduction/README.md index 4c1613594..c3dc13f26 100644 --- a/examples/fedpca_examples/dim_reduction/README.md +++ b/examples/fedpca_examples/dim_reduction/README.md @@ -1,7 +1,7 @@ # PCA Dimensionality Reduction Example This example leverages federally computed principal components of the MNIST dataset to perform dimensionality reduction on the images, before proceeding with normal training. -This example assumes that the principal components of MNIST have already been computed and saved (run the example in `exampes/fedpca_examples/perform_pca` to do this), and the user supplies a path to the saved principal components to perform dimensionality reduction. +This example assumes that the principal components of MNIST have already been computed and saved (run the example in `examples/fedpca_examples/perform_pca` to do this), and the user supplies a path to the saved principal components to perform dimensionality reduction. Each client performs Dirichlet subsampling on the whole dataset to produce heterogeneous local datasets. diff --git a/examples/fedpca_examples/dim_reduction/client.py b/examples/fedpca_examples/dim_reduction/client.py index f049e1777..09c2ce00d 100644 --- a/examples/fedpca_examples/dim_reduction/client.py +++ b/examples/fedpca_examples/dim_reduction/client.py @@ -1,6 +1,5 @@ import argparse from pathlib import Path -from typing import Tuple import flwr as fl import torch @@ -22,7 +21,7 @@ class MnistFedPcaClient(BasicClient): - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) pca_path = Path(narrow_dict_type(config, "pca_path", str)) new_dimension = narrow_dict_type(config, "new_dimension", int) diff --git a/examples/fedpca_examples/dim_reduction/server.py b/examples/fedpca_examples/dim_reduction/server.py index 57ef8f0bf..10bfc5e4b 100644 --- a/examples/fedpca_examples/dim_reduction/server.py +++ b/examples/fedpca_examples/dim_reduction/server.py @@ -1,6 +1,6 @@ import argparse from functools import partial -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.typing import Config @@ -33,7 +33,7 @@ def fit_config( } -def main(config: Dict[str, Any]) -> None: +def main(config: dict[str, Any]) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/examples/fedpca_examples/perform_pca/README.md b/examples/fedpca_examples/perform_pca/README.md index 1c4db6d43..bff30fc7b 100644 --- a/examples/fedpca_examples/perform_pca/README.md +++ b/examples/fedpca_examples/perform_pca/README.md @@ -3,7 +3,7 @@ This example performs federated principal component analysis. The goal is to com This is achieved by each client performing PCA locally at first, then the principal components are sent to a central server to be merged. -Each client performs Dirichlet subsampling on the whole dataset to produce heterogeneous local datasets. This is done to ensure that local principal compoents are distinct across different clients. +Each client performs Dirichlet subsampling on the whole dataset to produce heterogeneous local datasets. This is done to ensure that local principal components are distinct across different clients. ## Running the Example In order to run the example, first ensure you have [installed the dependencies in your virtual environment according to the main README](/README.md#development-requirements) and it has been activated. @@ -36,6 +36,6 @@ python -m examples.fedpca_examples.perform_pca.client --dataset_path /path/to/da the dataset already exists at the path specified, it will be loaded from there. Otherwise, the dataset will be automatically downloaded to the path specified and used in the run. -* The argument `components_save_path` specifies the directory in which the merged principal components will be saved, so they can be leveraged for other downstream tasks. An example of dimensionality reduction can be found at `exampes/fedpca_examples/dim_reduction`. +* The argument `components_save_path` specifies the directory in which the merged principal components will be saved, so they can be leveraged for other downstream tasks. An example of dimensionality reduction can be found at `examples/fedpca_examples/dim_reduction`. After the clients have been started federated pca should commence. diff --git a/examples/fedpca_examples/perform_pca/client.py b/examples/fedpca_examples/perform_pca/client.py index fa08fd882..32ac867f0 100644 --- a/examples/fedpca_examples/perform_pca/client.py +++ b/examples/fedpca_examples/perform_pca/client.py @@ -1,6 +1,5 @@ import argparse from pathlib import Path -from typing import Tuple import flwr as fl import torch @@ -16,7 +15,7 @@ class MnistFedPCAClient(FedPCAClient): - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.5, beta=0.5) train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size, sampler) diff --git a/examples/fedpca_examples/perform_pca/server.py b/examples/fedpca_examples/perform_pca/server.py index 7fdbc1e88..9f99ed8cd 100644 --- a/examples/fedpca_examples/perform_pca/server.py +++ b/examples/fedpca_examples/perform_pca/server.py @@ -1,6 +1,6 @@ import argparse from functools import partial -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.typing import Config @@ -32,7 +32,7 @@ def fit_config( } -def main(config: Dict[str, Any]) -> None: +def main(config: dict[str, Any]) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/examples/fedper_example/client.py b/examples/fedper_example/client.py index 990e4ea14..9d2de9fc7 100644 --- a/examples/fedper_example/client.py +++ b/examples/fedper_example/client.py @@ -1,6 +1,6 @@ import argparse +from collections.abc import Sequence, Set from pathlib import Path -from typing import Sequence, Set, Tuple import flwr as fl import torch @@ -33,7 +33,7 @@ def __init__( super().__init__(data_path=data_path, metrics=metrics, device=device) self.minority_numbers = minority_numbers - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) downsample_percentage = narrow_dict_type(config, "downsampling_ratio", float) sampler = MinorityLabelBasedSampler(list(range(10)), downsample_percentage, self.minority_numbers) diff --git a/examples/fedper_example/server.py b/examples/fedper_example/server.py index ba670c738..8a952055a 100644 --- a/examples/fedper_example/server.py +++ b/examples/fedper_example/server.py @@ -1,6 +1,6 @@ import argparse from functools import partial -from typing import Any, Dict, Optional +from typing import Any import flwr as fl from flwr.common.typing import Config @@ -24,8 +24,8 @@ def fit_config( n_server_rounds: int, downsampling_ratio: float, current_round: int, - local_epochs: Optional[int] = None, - local_steps: Optional[int] = None, + local_epochs: int | None = None, + local_steps: int | None = None, ) -> Config: return { **make_dict_with_epochs_or_steps(local_epochs, local_steps), @@ -36,7 +36,7 @@ def fit_config( } -def main(config: Dict[str, Any]) -> None: +def main(config: dict[str, Any]) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/examples/fedpm_example/client.py b/examples/fedpm_example/client.py index 3e59bd56c..bc793c120 100644 --- a/examples/fedpm_example/client.py +++ b/examples/fedpm_example/client.py @@ -1,6 +1,6 @@ import argparse +from collections.abc import Sequence, Set from pathlib import Path -from typing import Sequence, Set, Tuple import flwr as fl import torch @@ -29,7 +29,7 @@ def __init__( super().__init__(data_path=data_path, metrics=metrics, device=device) self.minority_numbers = minority_numbers - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) downsample_percentage = narrow_dict_type(config, "downsampling_ratio", float) sampler = MinorityLabelBasedSampler(list(range(10)), downsample_percentage, self.minority_numbers) diff --git a/examples/fedpm_example/server.py b/examples/fedpm_example/server.py index ec568c23d..dd3629519 100644 --- a/examples/fedpm_example/server.py +++ b/examples/fedpm_example/server.py @@ -1,6 +1,6 @@ import argparse from functools import partial -from typing import Any, Dict, Optional +from typing import Any import flwr as fl from flwr.common.typing import Config @@ -22,8 +22,8 @@ def fit_config( is_masked_model: bool, priors_reset_frequency: int, current_round: int, - local_epochs: Optional[int] = None, - local_steps: Optional[int] = None, + local_epochs: int | None = None, + local_steps: int | None = None, ) -> Config: return { **make_dict_with_epochs_or_steps(local_epochs, local_steps), @@ -36,7 +36,7 @@ def fit_config( } -def main(config: Dict[str, Any]) -> None: +def main(config: dict[str, Any]) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/examples/fedprox_example/client.py b/examples/fedprox_example/client.py index fe67cc53f..95628b005 100644 --- a/examples/fedprox_example/client.py +++ b/examples/fedprox_example/client.py @@ -6,7 +6,7 @@ import torch import torch.nn as nn from flwr.common.logger import log -from flwr.common.typing import Config, Tuple +from flwr.common.typing import Config from torch.nn.modules.loss import _Loss from torch.optim import Optimizer from torch.utils.data import DataLoader @@ -22,7 +22,7 @@ class MnistFedProxClient(FedProxClient): - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=1) batch_size = narrow_dict_type(config, "batch_size", int) train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size, sampler) diff --git a/examples/fedprox_example/server.py b/examples/fedprox_example/server.py index d5003ae65..274a6ffe5 100644 --- a/examples/fedprox_example/server.py +++ b/examples/fedprox_example/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict, Optional +from typing import Any import flwr as fl from flwr.common.logger import log @@ -23,9 +23,9 @@ def fit_config( batch_size: int, n_server_rounds: int, current_round: int, - reporting_config: Optional[Dict[str, str]] = None, - local_epochs: Optional[int] = None, - local_steps: Optional[int] = None, + reporting_config: dict[str, str] | None = None, + local_epochs: int | None = None, + local_steps: int | None = None, ) -> Config: base_config: Config = { **make_dict_with_epochs_or_steps(local_epochs, local_steps), @@ -42,7 +42,7 @@ def fit_config( return base_config -def main(config: Dict[str, Any], server_address: str) -> None: +def main(config: dict[str, Any], server_address: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/examples/fedrep_example/client.py b/examples/fedrep_example/client.py index bb5d4479b..2147e9ada 100644 --- a/examples/fedrep_example/client.py +++ b/examples/fedrep_example/client.py @@ -1,6 +1,5 @@ import argparse from pathlib import Path -from typing import Dict, Tuple import flwr as fl import torch @@ -23,7 +22,7 @@ class CifarFedRepClient(FedRepClient): - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) sample_percentage = narrow_dict_type(config, "sample_percentage", float) beta = narrow_dict_type(config, "beta", float) @@ -39,7 +38,7 @@ def get_model(self, config: Config) -> nn.Module: ).to(self.device) return model - def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: + def get_optimizer(self, config: Config) -> dict[str, Optimizer]: # We have two optimizers that are used for the head and representation optimization stages respectively assert isinstance(self.model, FedRepModel) representation_optimizer = torch.optim.AdamW(self.model.base_module.parameters(), lr=0.001) diff --git a/examples/fedrep_example/server.py b/examples/fedrep_example/server.py index 06a2eb146..be2754782 100644 --- a/examples/fedrep_example/server.py +++ b/examples/fedrep_example/server.py @@ -1,6 +1,6 @@ import argparse from functools import partial -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.typing import Config @@ -38,7 +38,7 @@ def fit_config( } -def main(config: Dict[str, Any]) -> None: +def main(config: dict[str, Any]) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/examples/fedsimclr_example/fedsimclr_finetuning_example/client.py b/examples/fedsimclr_example/fedsimclr_finetuning_example/client.py index 8a777b39e..fb61a40ed 100644 --- a/examples/fedsimclr_example/fedsimclr_finetuning_example/client.py +++ b/examples/fedsimclr_example/fedsimclr_finetuning_example/client.py @@ -1,6 +1,5 @@ import argparse from pathlib import Path -from typing import Tuple import flwr as fl import torch @@ -20,7 +19,7 @@ from fl4health.utils.metrics import Accuracy -def get_finetune_dataset(data_dir: Path, batch_size: int) -> Tuple[DataLoader, DataLoader]: +def get_finetune_dataset(data_dir: Path, batch_size: int) -> tuple[DataLoader, DataLoader]: # Select test data (ie train=False) because train data was used in the pretraining stage data, targets = get_cifar10_data_and_target_tensors(data_dir, train=False) train_data, train_targets, val_data, val_targets = split_data_and_targets(data, targets) @@ -42,7 +41,7 @@ def get_finetune_dataset(data_dir: Path, batch_size: int) -> Tuple[DataLoader, D class CifarClient(BasicClient): - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) train_loader, val_loader = get_finetune_dataset(self.data_path, batch_size) return train_loader, val_loader diff --git a/examples/fedsimclr_example/fedsimclr_finetuning_example/server.py b/examples/fedsimclr_example/fedsimclr_finetuning_example/server.py index 1a263997b..f87060b61 100644 --- a/examples/fedsimclr_example/fedsimclr_finetuning_example/server.py +++ b/examples/fedsimclr_example/fedsimclr_finetuning_example/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any import flwr as fl import torch.nn as nn @@ -23,8 +23,8 @@ def fit_config( batch_size: int, current_server_round: int, - local_epochs: Optional[int] = None, - local_steps: Optional[int] = None, + local_epochs: int | None = None, + local_steps: int | None = None, ) -> Config: return { **make_dict_with_epochs_or_steps(local_epochs, local_steps), @@ -39,7 +39,7 @@ def load_model( return FedSimClrModel.load_pretrained_model(model_path) -def main(config: Dict[str, Any]) -> None: +def main(config: dict[str, Any]) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/examples/fedsimclr_example/fedsimclr_pretraining_example/client.py b/examples/fedsimclr_example/fedsimclr_pretraining_example/client.py index 07ab98aeb..efa0bf5fa 100644 --- a/examples/fedsimclr_example/fedsimclr_pretraining_example/client.py +++ b/examples/fedsimclr_example/fedsimclr_pretraining_example/client.py @@ -1,6 +1,6 @@ import argparse +from collections.abc import Callable from pathlib import Path -from typing import Callable, Tuple import flwr as fl import torch @@ -20,7 +20,7 @@ from fl4health.utils.typing import TorchTargetType -def get_transforms() -> Tuple[Callable, Callable]: +def get_transforms() -> tuple[Callable, Callable]: input_transform = transforms.Compose( [ ToNumpy(), @@ -47,7 +47,7 @@ def get_transforms() -> Tuple[Callable, Callable]: return input_transform, target_transform -def get_pretrain_dataset(data_dir: Path, batch_size: int) -> Tuple[DataLoader, DataLoader]: +def get_pretrain_dataset(data_dir: Path, batch_size: int) -> tuple[DataLoader, DataLoader]: data, targets = get_cifar10_data_and_target_tensors(data_dir, True) train_data, _, val_data, _ = split_data_and_targets(data, targets) @@ -65,7 +65,7 @@ def get_pretrain_dataset(data_dir: Path, batch_size: int) -> Tuple[DataLoader, D class SslCifarClient(BasicClient): - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) train_loader, val_loader = get_pretrain_dataset(self.data_path, batch_size) return train_loader, val_loader diff --git a/examples/fedsimclr_example/fedsimclr_pretraining_example/server.py b/examples/fedsimclr_example/fedsimclr_pretraining_example/server.py index 1e45df1f3..0821936fa 100644 --- a/examples/fedsimclr_example/fedsimclr_pretraining_example/server.py +++ b/examples/fedsimclr_example/fedsimclr_pretraining_example/server.py @@ -1,6 +1,6 @@ import argparse from functools import partial -from typing import Any, Dict, Optional +from typing import Any import flwr as fl import torch.nn as nn @@ -23,8 +23,8 @@ def fit_config( batch_size: int, current_server_round: int, - local_epochs: Optional[int] = None, - local_steps: Optional[int] = None, + local_epochs: int | None = None, + local_steps: int | None = None, ) -> Config: return { **make_dict_with_epochs_or_steps(local_epochs, local_steps), @@ -33,7 +33,7 @@ def fit_config( } -def main(config: Dict[str, Any]) -> None: +def main(config: dict[str, Any]) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/examples/fenda_ditto_example/client.py b/examples/fenda_ditto_example/client.py index f90f26a81..acf8cabb2 100644 --- a/examples/fenda_ditto_example/client.py +++ b/examples/fenda_ditto_example/client.py @@ -1,12 +1,11 @@ import argparse from logging import INFO from pathlib import Path -from typing import Dict import flwr as fl import torch from flwr.common.logger import log -from flwr.common.typing import Config, Tuple +from flwr.common.typing import Config from torch.nn.modules.loss import _Loss from torch.optim import Optimizer from torch.utils.data import DataLoader @@ -31,7 +30,7 @@ class MnistFendaDittoClient(FendaDittoClient): - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=1) batch_size = narrow_dict_type(config, "batch_size", int) train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size, sampler) @@ -50,7 +49,7 @@ def get_model(self, config: Config) -> FendaModel: ParallelSplitHeadClassifier(ParallelFeatureJoinMode.CONCATENATE), ).to(self.device) - def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: + def get_optimizer(self, config: Config) -> dict[str, Optimizer]: global_optimizer = torch.optim.AdamW(self.global_model.parameters(), lr=0.01) local_optimizer = torch.optim.AdamW(self.model.parameters(), lr=0.01) return {"global": global_optimizer, "local": local_optimizer} diff --git a/examples/fenda_ditto_example/server.py b/examples/fenda_ditto_example/server.py index 1daa29dfe..f68871f24 100644 --- a/examples/fenda_ditto_example/server.py +++ b/examples/fenda_ditto_example/server.py @@ -1,6 +1,6 @@ import argparse from functools import partial -from typing import Any, Dict, Optional +from typing import Any import flwr as fl from flwr.common.typing import Config @@ -24,8 +24,8 @@ def fit_config( n_server_rounds: int, downsampling_ratio: float, current_round: int, - local_epochs: Optional[int] = None, - local_steps: Optional[int] = None, + local_epochs: int | None = None, + local_steps: int | None = None, ) -> Config: return { **make_dict_with_epochs_or_steps(local_epochs, local_steps), @@ -36,7 +36,7 @@ def fit_config( } -def main(config: Dict[str, Any]) -> None: +def main(config: dict[str, Any]) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/examples/fenda_example/client.py b/examples/fenda_example/client.py index fbcd84f87..a23b27663 100644 --- a/examples/fenda_example/client.py +++ b/examples/fenda_example/client.py @@ -1,6 +1,6 @@ import argparse +from collections.abc import Sequence, Set from pathlib import Path -from typing import Sequence, Set, Tuple import flwr as fl import torch @@ -31,7 +31,7 @@ def __init__( super().__init__(data_path=data_path, metrics=metrics, device=device) self.minority_numbers = minority_numbers - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) downsample_percentage = narrow_dict_type(config, "downsampling_ratio", float) sampler = MinorityLabelBasedSampler(list(range(10)), downsample_percentage, self.minority_numbers) diff --git a/examples/fenda_example/server.py b/examples/fenda_example/server.py index 9baa6fa7f..64c837bd3 100644 --- a/examples/fenda_example/server.py +++ b/examples/fenda_example/server.py @@ -1,6 +1,6 @@ import argparse from functools import partial -from typing import Any, Dict, Optional +from typing import Any import flwr as fl from flwr.common.typing import Config @@ -22,8 +22,8 @@ def fit_config( n_server_rounds: int, downsampling_ratio: float, current_round: int, - local_epochs: Optional[int] = None, - local_steps: Optional[int] = None, + local_epochs: int | None = None, + local_steps: int | None = None, ) -> Config: return { **make_dict_with_epochs_or_steps(local_epochs, local_steps), @@ -34,7 +34,7 @@ def fit_config( } -def main(config: Dict[str, Any]) -> None: +def main(config: dict[str, Any]) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/examples/fl_plus_local_ft_example/client.py b/examples/fl_plus_local_ft_example/client.py index 1e74ef5c4..7d2e07649 100644 --- a/examples/fl_plus_local_ft_example/client.py +++ b/examples/fl_plus_local_ft_example/client.py @@ -1,7 +1,6 @@ import argparse from logging import INFO from pathlib import Path -from typing import Tuple import flwr as fl import torch @@ -20,7 +19,7 @@ class CifarClient(BasicClient): - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) train_loader, val_loader, _ = load_cifar10_data(self.data_path, batch_size) return train_loader, val_loader diff --git a/examples/fl_plus_local_ft_example/server.py b/examples/fl_plus_local_ft_example/server.py index b4a2d8ca8..dd13a0d64 100644 --- a/examples/fl_plus_local_ft_example/server.py +++ b/examples/fl_plus_local_ft_example/server.py @@ -1,6 +1,6 @@ import argparse from functools import partial -from typing import Any, Dict, Optional +from typing import Any import flwr as fl from flwr.common.typing import Config @@ -18,8 +18,8 @@ def fit_config( batch_size: int, current_round: int, - local_epochs: Optional[int] = None, - local_steps: Optional[int] = None, + local_epochs: int | None = None, + local_steps: int | None = None, ) -> Config: return { **make_dict_with_epochs_or_steps(local_epochs, local_steps), @@ -28,7 +28,7 @@ def fit_config( } -def main(config: Dict[str, Any]) -> None: +def main(config: dict[str, Any]) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/examples/flash_example/client.py b/examples/flash_example/client.py index ea47fbe3f..702a38094 100644 --- a/examples/flash_example/client.py +++ b/examples/flash_example/client.py @@ -4,7 +4,7 @@ import flwr as fl import torch import torch.nn as nn -from flwr.common.typing import Config, Tuple +from flwr.common.typing import Config from torch.nn.modules.loss import _Loss from torch.optim import Optimizer from torch.utils.data import DataLoader @@ -17,7 +17,7 @@ class CifarFlashClient(FlashClient): - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) train_loader, val_loader, _ = load_cifar10_data( self.data_path, diff --git a/examples/flash_example/server.py b/examples/flash_example/server.py index 97da1ac4e..535c73bb1 100644 --- a/examples/flash_example/server.py +++ b/examples/flash_example/server.py @@ -1,6 +1,6 @@ import argparse from functools import partial -from typing import Any, Dict, Optional +from typing import Any import flwr as fl from flwr.common.typing import Config @@ -18,8 +18,8 @@ def fit_config( batch_size: int, current_server_round: int, - local_epochs: Optional[int] = None, - local_steps: Optional[int] = None, + local_epochs: int | None = None, + local_steps: int | None = None, ) -> Config: return { **make_dict_with_epochs_or_steps(local_epochs, local_steps), @@ -28,7 +28,7 @@ def fit_config( } -def main(config: Dict[str, Any]) -> None: +def main(config: dict[str, Any]) -> None: fit_config_fn = partial(fit_config, config["batch_size"], local_epochs=config.get("local_epochs")) model = Net() diff --git a/examples/model_merge_example/README.md b/examples/model_merge_example/README.md index e884de761..0814c5746 100644 --- a/examples/model_merge_example/README.md +++ b/examples/model_merge_example/README.md @@ -5,7 +5,7 @@ average these weights and perform evaluation on the client side and the server s evaluation function. The server expects two clients to be spun up (i.e. it will wait until two clients report in before starting model merging and evaluation). For convenience, pre-trained models on the MNIST train set have been provided for each of the clients in `assets/checkpoints_for_examples/model_merge_example` -under `0.pt` and `1.pt`. The model merging and subsequent evaluation can be perfomed with these weights +under `0.pt` and `1.pt`. The model merging and subsequent evaluation can be performed with these weights out-of-the-box. ## Running the Example diff --git a/examples/model_merge_example/server.py b/examples/model_merge_example/server.py index 6bf00cc70..813828a1e 100644 --- a/examples/model_merge_example/server.py +++ b/examples/model_merge_example/server.py @@ -1,8 +1,9 @@ import argparse from collections import OrderedDict +from collections.abc import Sequence from functools import partial from pathlib import Path -from typing import Any, Dict, Optional, Sequence, Tuple +from typing import Any import flwr as fl import torch @@ -36,7 +37,7 @@ def server_side_evaluate_fn( _: int, parameters: NDArrays, config: Config, -) -> Optional[Tuple[float, Dict[str, Scalar]]]: +) -> tuple[float, dict[str, Scalar]] | None: model.to(device) model.eval() evaluate_metric_manager = MetricManager(metrics, "evaluate") @@ -54,7 +55,7 @@ def server_side_evaluate_fn( return 0.0, evaluate_metric_manager.compute() -def main(config: Dict[str, Any], data_path: Path) -> None: +def main(config: dict[str, Any], data_path: Path) -> None: _, val_loader, _ = load_mnist_data(data_path, config["batch_size"]) device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") server_side_evaluate_fn_partial = partial(server_side_evaluate_fn, MnistNet(), val_loader, [Accuracy("")], device) diff --git a/examples/models/lstm_model.py b/examples/models/lstm_model.py index 5ef61ea4f..e37584c5e 100644 --- a/examples/models/lstm_model.py +++ b/examples/models/lstm_model.py @@ -1,5 +1,3 @@ -from typing import Tuple - import torch import torch.nn as nn @@ -20,7 +18,7 @@ def __init__(self, vocab_size: int, vocab_dimension: int = 128, lstm_dimension: self.drop = nn.Dropout(p=0.3) self.fc = nn.Linear(2 * lstm_dimension, 4) - def forward(self, x: torch.Tensor, hidden: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + def forward(self, x: torch.Tensor, hidden: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: text_emb = self.embedding(x) out, _ = self.lstm(text_emb, hidden) @@ -34,6 +32,6 @@ def forward(self, x: torch.Tensor, hidden: Tuple[torch.Tensor, torch.Tensor]) -> return text_out - def init_hidden(self, batch_size: int) -> Tuple[torch.Tensor, torch.Tensor]: + def init_hidden(self, batch_size: int) -> tuple[torch.Tensor, torch.Tensor]: # 4 since the number of layers is 2 and it is bidirectional (so 1 per layer per direction) return (torch.zeros(4, batch_size, self.lstm_dimension), torch.zeros(4, batch_size, self.lstm_dimension)) diff --git a/examples/models/masked_model.py b/examples/models/masked_model.py index 6427bb23e..5acc3c7c8 100644 --- a/examples/models/masked_model.py +++ b/examples/models/masked_model.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch import torch.nn as nn import torch.nn.functional as F @@ -9,7 +7,7 @@ class Masked4Cnn(nn.Module): - def __init__(self, device: Optional[torch.device] = None) -> None: + def __init__(self, device: torch.device | None = None) -> None: super().__init__() self.conv1 = MaskedConv2d( in_channels=1, out_channels=64, kernel_size=3, stride=1, padding="same", device=device diff --git a/examples/moon_example/client.py b/examples/moon_example/client.py index 4813ce48a..66aff4831 100644 --- a/examples/moon_example/client.py +++ b/examples/moon_example/client.py @@ -1,6 +1,6 @@ import argparse +from collections.abc import Sequence, Set from pathlib import Path -from typing import Sequence, Set, Tuple import flwr as fl import torch @@ -31,7 +31,7 @@ def __init__( super().__init__(data_path=data_path, metrics=metrics, device=device, contrastive_weight=contrastive_weight) self.minority_numbers = minority_numbers - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) downsample_percentage = narrow_dict_type(config, "downsampling_ratio", float) sampler = MinorityLabelBasedSampler(list(range(10)), downsample_percentage, self.minority_numbers) diff --git a/examples/moon_example/server.py b/examples/moon_example/server.py index 353b7e253..21437650f 100644 --- a/examples/moon_example/server.py +++ b/examples/moon_example/server.py @@ -1,6 +1,6 @@ import argparse from functools import partial -from typing import Any, Dict, Optional +from typing import Any import flwr as fl from flwr.common.typing import Config @@ -21,8 +21,8 @@ def fit_config( n_server_rounds: int, downsampling_ratio: float, current_round: int, - local_epochs: Optional[int] = None, - local_steps: Optional[int] = None, + local_epochs: int | None = None, + local_steps: int | None = None, ) -> Config: return { **make_dict_with_epochs_or_steps(local_epochs, local_steps), @@ -33,7 +33,7 @@ def fit_config( } -def main(config: Dict[str, Any]) -> None: +def main(config: dict[str, Any]) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/examples/mr_mtl_example/client.py b/examples/mr_mtl_example/client.py index 833b9ee97..e38734f8e 100644 --- a/examples/mr_mtl_example/client.py +++ b/examples/mr_mtl_example/client.py @@ -6,7 +6,7 @@ import torch import torch.nn as nn from flwr.common.logger import log -from flwr.common.typing import Config, Tuple +from flwr.common.typing import Config from torch.nn.modules.loss import _Loss from torch.optim import Optimizer from torch.utils.data import DataLoader @@ -22,7 +22,7 @@ class MnistMrMtlClient(MrMtlClient): - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=1) batch_size = narrow_dict_type(config, "batch_size", int) train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size, sampler) diff --git a/examples/mr_mtl_example/server.py b/examples/mr_mtl_example/server.py index 958d117cd..c6933bc66 100644 --- a/examples/mr_mtl_example/server.py +++ b/examples/mr_mtl_example/server.py @@ -1,6 +1,6 @@ import argparse from functools import partial -from typing import Any, Dict, Optional +from typing import Any import flwr as fl from flwr.common.typing import Config @@ -20,8 +20,8 @@ def fit_config( n_server_rounds: int, downsampling_ratio: float, current_round: int, - local_epochs: Optional[int] = None, - local_steps: Optional[int] = None, + local_epochs: int | None = None, + local_steps: int | None = None, ) -> Config: return { **make_dict_with_epochs_or_steps(local_epochs, local_steps), @@ -32,7 +32,7 @@ def fit_config( } -def main(config: Dict[str, Any]) -> None: +def main(config: dict[str, Any]) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/examples/nnunet_example/README.md b/examples/nnunet_example/README.md index 26cf81702..e5f890955 100644 --- a/examples/nnunet_example/README.md +++ b/examples/nnunet_example/README.md @@ -31,7 +31,7 @@ To run a federated learning experiment with nnunet models, first ensure you are python -m examples.nnunet_example.server --config_path examples/nnunet_example/config.yaml ``` -Once the server has started, start the necessary number of clients specified by the n_clients key in the config file. Each client can be started by running the following command in a seperate session. To view a list of optional flags use the --help flag. +Once the server has started, start the necessary number of clients specified by the n_clients key in the config file. Each client can be started by running the following command in a separate session. To view a list of optional flags use the --help flag. ```bash python -m examples.nnunet_example.client --dataset_path examples/datasets/nnunet diff --git a/examples/nnunet_example/client.py b/examples/nnunet_example/client.py index cd7974a2e..12acdfdba 100644 --- a/examples/nnunet_example/client.py +++ b/examples/nnunet_example/client.py @@ -4,7 +4,6 @@ from logging import DEBUG, INFO from os.path import exists, join from pathlib import Path -from typing import Optional, Union from fl4health.checkpointing.checkpointer import PerRoundStateCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule @@ -32,12 +31,12 @@ def main( dataset_path: Path, msd_dataset_name: str, server_address: str, - fold: Union[int, str], + fold: int | str, always_preprocess: bool = False, verbose: bool = True, compile: bool = True, - intermediate_client_state_dir: Optional[str] = None, - client_name: Optional[str] = None, + intermediate_client_state_dir: str | None = None, + client_name: str | None = None, ) -> None: # Log device and server address device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -208,7 +207,7 @@ def main( ) # Check fold argument and start main method - fold: Union[int, str] = "all" if args.fold == "all" else int(args.fold) + fold: int | str = "all" if args.fold == "all" else int(args.fold) main( dataset_path=Path(args.dataset_path), msd_dataset_name=args.msd_dataset_name, diff --git a/examples/nnunet_example/server.py b/examples/nnunet_example/server.py index cedae28ac..0b38f7fa2 100644 --- a/examples/nnunet_example/server.py +++ b/examples/nnunet_example/server.py @@ -4,7 +4,6 @@ import warnings from functools import partial from pathlib import Path -from typing import Optional import yaml @@ -35,9 +34,9 @@ def get_config( n_server_rounds: int, batch_size: int, n_clients: int, - nnunet_plans: Optional[str] = None, - local_epochs: Optional[int] = None, - local_steps: Optional[int] = None, + nnunet_plans: str | None = None, + local_epochs: int | None = None, + local_steps: int | None = None, ) -> Config: # Create config config: Config = { @@ -60,8 +59,8 @@ def get_config( def main( config: dict, server_address: str, - intermediate_server_state_dir: Optional[str] = None, - server_name: Optional[str] = None, + intermediate_server_state_dir: str | None = None, + server_name: str | None = None, ) -> None: # Partial function with everything set except current server round fit_config_fn = partial( diff --git a/examples/perfcl_example/client.py b/examples/perfcl_example/client.py index c659f453a..588e208a7 100644 --- a/examples/perfcl_example/client.py +++ b/examples/perfcl_example/client.py @@ -1,6 +1,6 @@ import argparse +from collections.abc import Sequence, Set from pathlib import Path -from typing import Sequence, Set, Tuple import flwr as fl import torch @@ -37,7 +37,7 @@ def __init__( ) self.minority_numbers = minority_numbers - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) downsample_percentage = narrow_dict_type(config, "downsampling_ratio", float) sampler = MinorityLabelBasedSampler(list(range(10)), downsample_percentage, self.minority_numbers) diff --git a/examples/perfcl_example/server.py b/examples/perfcl_example/server.py index 27f75b3ee..563c11332 100644 --- a/examples/perfcl_example/server.py +++ b/examples/perfcl_example/server.py @@ -1,6 +1,6 @@ import argparse from functools import partial -from typing import Any, Dict, Optional +from typing import Any import flwr as fl from flwr.common.typing import Config @@ -22,8 +22,8 @@ def fit_config( n_server_rounds: int, downsampling_ratio: float, current_round: int, - local_epochs: Optional[int] = None, - local_steps: Optional[int] = None, + local_epochs: int | None = None, + local_steps: int | None = None, ) -> Config: return { **make_dict_with_epochs_or_steps(local_epochs, local_steps), @@ -34,7 +34,7 @@ def fit_config( } -def main(config: Dict[str, Any]) -> None: +def main(config: dict[str, Any]) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/examples/scaffold_example/client.py b/examples/scaffold_example/client.py index 2974b1961..35a123f1e 100644 --- a/examples/scaffold_example/client.py +++ b/examples/scaffold_example/client.py @@ -1,6 +1,5 @@ import argparse from pathlib import Path -from typing import Tuple import flwr as fl import torch @@ -21,7 +20,7 @@ class MnistScaffoldClient(ScaffoldClient): - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75) train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size, sampler) diff --git a/examples/scaffold_example/server.py b/examples/scaffold_example/server.py index 2ebdfecf6..3512df2df 100644 --- a/examples/scaffold_example/server.py +++ b/examples/scaffold_example/server.py @@ -1,6 +1,6 @@ import argparse from functools import partial -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.typing import Config @@ -25,7 +25,7 @@ def fit_config(local_steps: int, batch_size: int, n_server_rounds: int, current_ } -def main(config: Dict[str, Any]) -> None: +def main(config: dict[str, Any]) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/examples/sparse_tensor_partial_exchange_example/client.py b/examples/sparse_tensor_partial_exchange_example/client.py index cff31dd85..080713436 100644 --- a/examples/sparse_tensor_partial_exchange_example/client.py +++ b/examples/sparse_tensor_partial_exchange_example/client.py @@ -1,6 +1,5 @@ import argparse from pathlib import Path -from typing import Tuple import flwr as fl import torch @@ -21,7 +20,7 @@ class CifarSparseCooTensorClient(PartialWeightExchangeClient): - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) train_loader, val_loader, _ = load_cifar10_data(self.data_path, batch_size) return train_loader, val_loader diff --git a/examples/sparse_tensor_partial_exchange_example/server.py b/examples/sparse_tensor_partial_exchange_example/server.py index 780b33a28..e41c6948b 100644 --- a/examples/sparse_tensor_partial_exchange_example/server.py +++ b/examples/sparse_tensor_partial_exchange_example/server.py @@ -1,6 +1,6 @@ import argparse from functools import partial -from typing import Any, Dict, Optional +from typing import Any import flwr as fl from flwr.common.typing import Config @@ -19,8 +19,8 @@ def fit_config( batch_size: int, sparsity_level: float, current_server_round: int, - local_epochs: Optional[int] = None, - local_steps: Optional[int] = None, + local_epochs: int | None = None, + local_steps: int | None = None, ) -> Config: config: Config = { **make_dict_with_epochs_or_steps(local_epochs, local_steps), @@ -31,7 +31,7 @@ def fit_config( return config -def main(config: Dict[str, Any]) -> None: +def main(config: dict[str, Any]) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/examples/utils/functions.py b/examples/utils/functions.py index df0332145..46cfec234 100644 --- a/examples/utils/functions.py +++ b/examples/utils/functions.py @@ -1,9 +1,6 @@ -from typing import Optional - - def make_dict_with_epochs_or_steps( - local_epochs: Optional[int] = None, - local_steps: Optional[int] = None, + local_epochs: int | None = None, + local_steps: int | None = None, ) -> dict[str, int]: if local_epochs is not None: return {"local_epochs": local_epochs} diff --git a/examples/warm_up_example/fedavg_warm_up/client.py b/examples/warm_up_example/fedavg_warm_up/client.py index 5b9aa463f..3948a5398 100644 --- a/examples/warm_up_example/fedavg_warm_up/client.py +++ b/examples/warm_up_example/fedavg_warm_up/client.py @@ -1,13 +1,13 @@ import argparse +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Sequence import flwr as fl import torch import torch.nn as nn from flwr.common.logger import log -from flwr.common.typing import Config, Tuple +from flwr.common.typing import Config from torch.nn.modules.loss import _Loss from torch.optim import Optimizer from torch.utils.data import DataLoader @@ -43,7 +43,7 @@ def __init__( checkpoint_and_state_module=checkpoint_and_state_module, ) - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=1) batch_size = narrow_dict_type(config, "batch_size", int) train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size, sampler) diff --git a/examples/warm_up_example/fedavg_warm_up/server.py b/examples/warm_up_example/fedavg_warm_up/server.py index a13c8b529..38574b617 100644 --- a/examples/warm_up_example/fedavg_warm_up/server.py +++ b/examples/warm_up_example/fedavg_warm_up/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict, Optional +from typing import Any import flwr as fl from flwr.common.logger import log @@ -26,8 +26,8 @@ def fit_config( group: str, entity: str, current_round: int, - local_epochs: Optional[int] = None, - local_steps: Optional[int] = None, + local_epochs: int | None = None, + local_steps: int | None = None, ) -> Config: return { **make_dict_with_epochs_or_steps(local_epochs, local_steps), @@ -40,7 +40,7 @@ def fit_config( } -def main(config: Dict[str, Any], server_address: str) -> None: +def main(config: dict[str, Any], server_address: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/examples/warm_up_example/warmed_up_fedprox/client.py b/examples/warm_up_example/warmed_up_fedprox/client.py index 8ecd22345..1dddf1e00 100644 --- a/examples/warm_up_example/warmed_up_fedprox/client.py +++ b/examples/warm_up_example/warmed_up_fedprox/client.py @@ -1,14 +1,14 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Optional, Sequence import flwr as fl import torch import torch.nn as nn from flwr.common.logger import log -from flwr.common.typing import Config, NDArrays, Tuple +from flwr.common.typing import Config, NDArrays from torch.nn.modules.loss import _Loss from torch.optim import Optimizer from torch.utils.data import DataLoader @@ -30,7 +30,7 @@ def __init__( metrics: Sequence[Metric], device: torch.device, pretrained_model_dir: Path, - weights_mapping_path: Optional[Path], + weights_mapping_path: Path | None, ) -> None: super().__init__( data_path=data_path, @@ -45,7 +45,7 @@ def __init__( weights_mapping_path=weights_mapping_path, ) - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=1) batch_size = narrow_dict_type(config, "batch_size", int) train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size, sampler) diff --git a/examples/warm_up_example/warmed_up_fedprox/server.py b/examples/warm_up_example/warmed_up_fedprox/server.py index 9bb6c48ce..98e3c56f1 100644 --- a/examples/warm_up_example/warmed_up_fedprox/server.py +++ b/examples/warm_up_example/warmed_up_fedprox/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict, Optional +from typing import Any import flwr as fl from flwr.common.logger import log @@ -26,8 +26,8 @@ def fit_config( group: str, entity: str, current_round: int, - local_epochs: Optional[int] = None, - local_steps: Optional[int] = None, + local_epochs: int | None = None, + local_steps: int | None = None, ) -> Config: return { **make_dict_with_epochs_or_steps(local_epochs, local_steps), @@ -41,7 +41,7 @@ def fit_config( } -def main(config: Dict[str, Any], server_address: str) -> None: +def main(config: dict[str, Any], server_address: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/examples/warm_up_example/warmed_up_fenda/client.py b/examples/warm_up_example/warmed_up_fenda/client.py index 3d65a641d..b57974e9b 100644 --- a/examples/warm_up_example/warmed_up_fenda/client.py +++ b/examples/warm_up_example/warmed_up_fenda/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Optional, Sequence, Tuple import flwr as fl import torch @@ -32,7 +32,7 @@ def __init__( metrics: Sequence[Metric], device: torch.device, pretrained_model_dir: Path, - weights_mapping_path: Optional[Path], + weights_mapping_path: Path | None, ) -> None: super().__init__( data_path=data_path, @@ -47,7 +47,7 @@ def __init__( weights_mapping_path=weights_mapping_path, ) - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=1) batch_size = narrow_dict_type(config, "batch_size", int) train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size, sampler) diff --git a/examples/warm_up_example/warmed_up_fenda/server.py b/examples/warm_up_example/warmed_up_fenda/server.py index e8801df3d..a622175ce 100644 --- a/examples/warm_up_example/warmed_up_fenda/server.py +++ b/examples/warm_up_example/warmed_up_fenda/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict, Optional +from typing import Any import flwr as fl from flwr.common.logger import log @@ -28,8 +28,8 @@ def fit_config( group: str, entity: str, current_round: int, - local_epochs: Optional[int] = None, - local_steps: Optional[int] = None, + local_epochs: int | None = None, + local_steps: int | None = None, ) -> Config: return { **make_dict_with_epochs_or_steps(local_epochs, local_steps), @@ -42,7 +42,7 @@ def fit_config( } -def main(config: Dict[str, Any], server_address: str) -> None: +def main(config: dict[str, Any], server_address: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/fl4health/checkpointing/checkpointer.py b/fl4health/checkpointing/checkpointer.py index 1d29ad274..0494de59c 100644 --- a/fl4health/checkpointing/checkpointer.py +++ b/fl4health/checkpointing/checkpointer.py @@ -1,15 +1,16 @@ import os from abc import ABC, abstractmethod +from collections.abc import Callable from logging import ERROR, INFO, WARNING from pathlib import Path -from typing import Any, Callable, Dict, Optional +from typing import Any import torch import torch.nn as nn from flwr.common.logger import log from flwr.common.typing import Scalar -CheckpointScoreFunctionType = Callable[[float, Dict[str, Scalar]], float] +CheckpointScoreFunctionType = Callable[[float, dict[str, Scalar]], float] class TorchModuleCheckpointer(ABC): @@ -25,7 +26,7 @@ def __init__(self, checkpoint_dir: str, checkpoint_name: str) -> None: self.checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name) @abstractmethod - def maybe_checkpoint(self, model: nn.Module, loss: float, metrics: Dict[str, Scalar]) -> None: + def maybe_checkpoint(self, model: nn.Module, loss: float, metrics: dict[str, Scalar]) -> None: """ Abstract method to be implemented by every TorchCheckpointer. Based on the loss and metrics provided it should determine whether to produce a checkpoint AND save it if applicable. @@ -33,7 +34,7 @@ def maybe_checkpoint(self, model: nn.Module, loss: float, metrics: Dict[str, Sca Args: model (nn.Module): Model to potentially save via the checkpointer loss (float): Computed loss associated with the model. - metrics (Dict[str, float]): Computed metrics associated with the model. + metrics (dict[str, float]): Computed metrics associated with the model. Raises: NotImplementedError: Must be implemented by the checkpointer @@ -81,7 +82,7 @@ def __init__( by the scoring function. Defaults to False. """ super().__init__(checkpoint_dir, checkpoint_name) - self.best_score: Optional[float] = None + self.best_score: float | None = None self.checkpoint_score_function = checkpoint_score_function # Whether we're looking to maximize (or minimize) the score produced by the checkpoint score function self.maximize = maximize @@ -108,7 +109,7 @@ def _should_checkpoint(self, comparison_score: float) -> bool: # If best score is none, then this is the first checkpoint return True - def maybe_checkpoint(self, model: nn.Module, loss: float, metrics: Dict[str, Scalar]) -> None: + def maybe_checkpoint(self, model: nn.Module, loss: float, metrics: dict[str, Scalar]) -> None: """ Given the loss/metrics associated with the provided model, the checkpointer uses the scoring function to produce a score. This score will then be used to determine whether the model should be checkpointed or not. @@ -117,7 +118,7 @@ def maybe_checkpoint(self, model: nn.Module, loss: float, metrics: Dict[str, Sca model (nn.Module): Model that might be persisted if the scoring function determines it should be loss (float): Loss associated with the provided model. Will potentially contribute to checkpointing decision, based on the score function. - metrics (Dict[str, Scalar]): Metrics associated with the provided model. Will potentially contribute to + metrics (dict[str, Scalar]): Metrics associated with the provided model. Will potentially contribute to the checkpointing decision, based on the score function. Raises: @@ -160,12 +161,12 @@ def __init__(self, checkpoint_dir: str, checkpoint_name: str) -> None: """ # This function is required by the parent class, but not used in the LatestTorchCheckpointer - def null_score_function(loss: float, _: Dict[str, Scalar]) -> float: + def null_score_function(loss: float, _: dict[str, Scalar]) -> float: return 0.0 super().__init__(checkpoint_dir, checkpoint_name, null_score_function, False) - def maybe_checkpoint(self, model: nn.Module, loss: float, _: Dict[str, Scalar]) -> None: + def maybe_checkpoint(self, model: nn.Module, loss: float, _: dict[str, Scalar]) -> None: """ This function is essentially a pass through, as this class always checkpoints the provided model @@ -173,7 +174,7 @@ def maybe_checkpoint(self, model: nn.Module, loss: float, _: Dict[str, Scalar]) model (nn.Module): Model to be checkpointed whenever this function is called loss (float): Loss associated with the provided model. Will potentially contribute to checkpointing decision, based on the score function. NOT USED. - metrics (Dict[str, Scalar]): Metrics associated with the provided model. Will potentially contribute to + metrics (dict[str, Scalar]): Metrics associated with the provided model. Will potentially contribute to the checkpointing decision, based on the score function. NOT USED. Raises: @@ -204,14 +205,14 @@ def __init__(self, checkpoint_dir: str, checkpoint_name: str) -> None: # The BestLossTorchCheckpointer just uses the provided loss to scoring checkpoints. More complicated # approaches may be used by other classes. - def loss_score_function(loss: float, _: Dict[str, Scalar]) -> float: + def loss_score_function(loss: float, _: dict[str, Scalar]) -> float: return loss super().__init__( checkpoint_dir, checkpoint_name, checkpoint_score_function=loss_score_function, maximize=False ) - def maybe_checkpoint(self, model: nn.Module, loss: float, metrics: Dict[str, Scalar]) -> None: + def maybe_checkpoint(self, model: nn.Module, loss: float, metrics: dict[str, Scalar]) -> None: """ This function will decide whether to checkpoint the provided model based on the loss argument. If the provided loss is better than any previous losses seen by this checkpointer, the model will be saved. @@ -220,7 +221,7 @@ def maybe_checkpoint(self, model: nn.Module, loss: float, metrics: Dict[str, Sca model (nn.Module): Model that might be persisted if the scoring function determines it should be loss (float): Loss associated with the provided model. This value is used to determine whether to save the model or not. - metrics (Dict[str, Scalar]): Metrics associated with the provided model. Will not be used by this + metrics (dict[str, Scalar]): Metrics associated with the provided model. Will not be used by this checkpointer. Raises: @@ -266,7 +267,7 @@ def __init__(self, checkpoint_dir: Path) -> None: ) self.checkpoint_dir = checkpoint_dir - def save_checkpoint(self, checkpoint_name: str, checkpoint_dict: Dict[str, Any]) -> None: + def save_checkpoint(self, checkpoint_name: str, checkpoint_dict: dict[str, Any]) -> None: """ Saves checkpoint_dict to checkpoint path form from this classes checkpointer dir and the provided checkpoint name. @@ -274,7 +275,7 @@ def save_checkpoint(self, checkpoint_name: str, checkpoint_dict: Dict[str, Any]) Args: checkpoint_name (str): Name of the state checkpoint file. - checkpoint_dict (Dict[str, Any]): A dictionary with string keys and values of type Any representing the + checkpoint_dict (dict[str, Any]): A dictionary with string keys and values of type Any representing the state to checkpoint. Raises: @@ -290,7 +291,7 @@ def save_checkpoint(self, checkpoint_name: str, checkpoint_dict: Dict[str, Any]) log(ERROR, f"Encountered the following error while saving the checkpoint: {e}") raise e - def load_checkpoint(self, checkpoint_name: str) -> Dict[str, Any]: + def load_checkpoint(self, checkpoint_name: str) -> dict[str, Any]: """ Loads and returns the checkpoint stored in checkpoint_dir under the provided name if it exists. If it doesn't exist, an assertion error will be thrown. @@ -299,7 +300,7 @@ def load_checkpoint(self, checkpoint_name: str) -> Dict[str, Any]: checkpoint_name (str): Name of the state checkpoint to be loaded. Returns: - Dict[str, Any]: A dictionary representing the checkpointed state, as loaded by torch.load. + dict[str, Any]: A dictionary representing the checkpointed state, as loaded by torch.load. """ checkpoint_path = os.path.join(self.checkpoint_dir, checkpoint_name) diff --git a/fl4health/checkpointing/client_module.py b/fl4health/checkpointing/client_module.py index 0387240e5..f33591309 100644 --- a/fl4health/checkpointing/client_module.py +++ b/fl4health/checkpointing/client_module.py @@ -1,6 +1,7 @@ +from collections.abc import Sequence from enum import Enum from logging import INFO -from typing import Any, Dict, Sequence, Union +from typing import Any import torch.nn as nn from flwr.common.logger import log @@ -8,7 +9,7 @@ from fl4health.checkpointing.checkpointer import PerRoundStateCheckpointer, TorchModuleCheckpointer -ModelCheckpointers = Union[TorchModuleCheckpointer, Sequence[TorchModuleCheckpointer]] | None +ModelCheckpointers = TorchModuleCheckpointer | Sequence[TorchModuleCheckpointer] | None class CheckpointMode(Enum): @@ -84,7 +85,7 @@ def _check_if_shared_checkpoint_names(self) -> None: ) def maybe_checkpoint( - self, model: nn.Module, loss: float, metrics: Dict[str, Scalar], mode: CheckpointMode + self, model: nn.Module, loss: float, metrics: dict[str, Scalar], mode: CheckpointMode ) -> None: """ Performs model checkpointing for a particular mode (either pre- or post-aggregation) if any checkpointers are @@ -95,7 +96,7 @@ def maybe_checkpoint( model (nn.Module): The model that might be checkpointed by the checkpointers. loss (float): The metric value obtained by the provided model. Used by the checkpointer(s) to decide whether to checkpoint the model. - metrics (Dict[str, Scalar]): The metrics obtained by the provided model. Potentially used by checkpointer + metrics (dict[str, Scalar]): The metrics obtained by the provided model. Potentially used by checkpointer to decide whether to checkpoint the model. mode (CheckpointMode): Determines which of the types of checkpointers to use. Currently, the only modes available are pre- and post-aggregation. @@ -118,7 +119,7 @@ def maybe_checkpoint( else: raise ValueError(f"Unrecognized mode for checkpointing: {str(mode)}") - def save_state(self, state_checkpoint_name: str, state: Dict[str, Any]) -> None: + def save_state(self, state_checkpoint_name: str, state: dict[str, Any]) -> None: """ This function is meant to facilitate saving state required to restart an FL process on the client side. This function will simply save whatever information is passed in the state variable using the file name in @@ -127,7 +128,7 @@ def save_state(self, state_checkpoint_name: str, state: Dict[str, Any]) -> None: Args: state_checkpoint_name (str): Name of the state checkpoint file. The checkpointer itself will have a directory to which state will be saved. - state (Dict[str, Any]): State to be saved so that training might be resumed on the client if federated + state (dict[str, Any]): State to be saved so that training might be resumed on the client if federated training is interrupted. For example, this might contain things like optimizer states, learning rate scheduler states, etc. @@ -140,7 +141,7 @@ def save_state(self, state_checkpoint_name: str, state: Dict[str, Any]) -> None: else: raise ValueError("Attempting to save state but no state checkpointer is specified") - def maybe_load_state(self, state_checkpoint_name: str) -> Dict[str, Any] | None: + def maybe_load_state(self, state_checkpoint_name: str) -> dict[str, Any] | None: """ This function facilitates loading of any pre-existing state (with the name state_checkpoint_name) in the directory of the state_checkpointer. If the state already exists at the proper path, the state is loaded @@ -154,7 +155,7 @@ def maybe_load_state(self, state_checkpoint_name: str) -> Dict[str, Any] | None: ValueError: Throws an error if this function is called, but no state checkpointer has been provided Returns: - Dict[str, Any] | None: If the state checkpoint properly exists and is loaded correctly, this dictionary + dict[str, Any] | None: If the state checkpoint properly exists and is loaded correctly, this dictionary carries that state. Otherwise, we return a None (or throw an exception). """ diff --git a/fl4health/checkpointing/opacus_checkpointer.py b/fl4health/checkpointing/opacus_checkpointer.py index cb0c9a73c..ff8442e93 100644 --- a/fl4health/checkpointing/opacus_checkpointer.py +++ b/fl4health/checkpointing/opacus_checkpointer.py @@ -1,6 +1,6 @@ import pickle from logging import INFO -from typing import Any, Dict +from typing import Any import torch.nn as nn from flwr.common.logger import log @@ -17,7 +17,7 @@ class OpacusCheckpointer(FunctionTorchModuleCheckpointer): fixes this issue. """ - def maybe_checkpoint(self, model: GradSampleModule, loss: float, metrics: Dict[str, Scalar]) -> None: + def maybe_checkpoint(self, model: GradSampleModule, loss: float, metrics: dict[str, Scalar]) -> None: """ Overriding the checkpointing strategy of the FunctionTorchCheckpointer to save model state dictionaries instead of using the torch.save workflow. @@ -25,7 +25,7 @@ def maybe_checkpoint(self, model: GradSampleModule, loss: float, metrics: Dict[s Args: model (nn.Module): Model to be potentially saved (should be an Opacus wrapped model) loss (float): Loss value associated with the model to be used in checkpointing decisions. - metrics (Dict[str, Scalar]): Metrics associated with the model to be used in checkpointing decisions. + metrics (dict[str, Scalar]): Metrics associated with the model to be used in checkpointing decisions. """ assert isinstance( model, GradSampleModule @@ -47,16 +47,16 @@ def maybe_checkpoint(self, model: GradSampleModule, loss: float, metrics: Dict[s f"{self.comparison_str} Best score ({self.best_score})", ) - def _process_state_dict_keys(self, opacus_state_dict: Dict[str, Any]) -> Dict[str, Any]: + def _process_state_dict_keys(self, opacus_state_dict: dict[str, Any]) -> dict[str, Any]: """ State dictionary keys for Opacus modules will be prefixed with an _module. So we remove these when loading the state information into a standard torch model. Args: - opacus_state_dict (Dict[str, Any]): A state dictionary produced by an Opacus GradSamplingModule + opacus_state_dict (dict[str, Any]): A state dictionary produced by an Opacus GradSamplingModule Returns: - Dict[str, Any]: A state dictionary with the _module. removed from the key prefixes to facilitate loading + dict[str, Any]: A state dictionary with the _module. removed from the key prefixes to facilitate loading the state dictionary into a non-Opacus model. """ @@ -113,12 +113,12 @@ def __init__(self, checkpoint_dir: str, checkpoint_name: str) -> None: """ # This function is required by the parent class, but not used in the LatestOpacusCheckpointer - def latest_score_function(loss: float, _: Dict[str, Scalar]) -> float: + def latest_score_function(loss: float, _: dict[str, Scalar]) -> float: return 0.0 super().__init__(checkpoint_dir, checkpoint_name, latest_score_function, False) - def maybe_checkpoint(self, model: GradSampleModule, loss: float, _: Dict[str, Scalar]) -> None: + def maybe_checkpoint(self, model: GradSampleModule, loss: float, _: dict[str, Scalar]) -> None: assert isinstance( model, GradSampleModule ), f"Model is of type: {type(model)}. This checkpointer need only be used to checkpoint Opacus modules" @@ -141,14 +141,14 @@ def __init__(self, checkpoint_dir: str, checkpoint_name: str) -> None: # The BestLossOpacusCheckpointer just uses the provided loss to scoring checkpoints. More complicated # approaches may be used by other classes. - def loss_score_function(loss: float, _: Dict[str, Scalar]) -> float: + def loss_score_function(loss: float, _: dict[str, Scalar]) -> float: return loss super().__init__( checkpoint_dir, checkpoint_name, checkpoint_score_function=loss_score_function, maximize=False ) - def maybe_checkpoint(self, model: GradSampleModule, loss: float, metrics: Dict[str, Scalar]) -> None: + def maybe_checkpoint(self, model: GradSampleModule, loss: float, metrics: dict[str, Scalar]) -> None: assert isinstance( model, GradSampleModule ), f"Model is of type: {type(model)}. This checkpointer need only be used to checkpoint Opacus modules" diff --git a/fl4health/checkpointing/server_module.py b/fl4health/checkpointing/server_module.py index f4a94f5c6..cc7ad3171 100644 --- a/fl4health/checkpointing/server_module.py +++ b/fl4health/checkpointing/server_module.py @@ -1,5 +1,6 @@ +from collections.abc import Sequence from logging import INFO -from typing import Any, Dict, Sequence, Union +from typing import Any import torch.nn as nn from flwr.common import Parameters @@ -19,7 +20,7 @@ SparseCooParameterPacker, ) -ModelCheckpointers = Union[TorchModuleCheckpointer, Sequence[TorchModuleCheckpointer]] | None +ModelCheckpointers = TorchModuleCheckpointer | Sequence[TorchModuleCheckpointer] | None class BaseServerCheckpointAndStateModule: @@ -97,7 +98,7 @@ def _check_if_shared_checkpoint_names(self) -> None: f"will be lost. The current paths are:\n{formatted_all_paths}" ) - def maybe_checkpoint(self, server_parameters: Parameters, loss: float, metrics: Dict[str, Scalar]) -> None: + def maybe_checkpoint(self, server_parameters: Parameters, loss: float, metrics: dict[str, Scalar]) -> None: """ If there are model checkpointers defined in this class, we hydrate the model for checkpointing with the server parameters and call maybe checkpoint model on each of the checkpointers to decide whether to checkpoint based @@ -107,7 +108,7 @@ def maybe_checkpoint(self, server_parameters: Parameters, loss: float, metrics: server_parameters (Parameters): Parameters held by the server that should be injected into the model loss (float): The aggregated loss value obtained by the current aggregated server model. Potentially used by checkpointer to decide whether to checkpoint the model. - metrics (Dict[str, Scalar]): The aggregated metrics obtained by the aggregated server model. Potentially + metrics (dict[str, Scalar]): The aggregated metrics obtained by the aggregated server model. Potentially used by checkpointer to decide whether to checkpoint the model. """ if self.model_checkpointers is not None and len(self.model_checkpointers) > 0: @@ -141,7 +142,7 @@ def _hydrate_model_for_checkpointing(self, server_parameters: Parameters) -> Non self.parameter_exchanger.pull_parameters(model_ndarrays, self.model) def save_state( - self, state_checkpoint_name: str, server_parameters: Parameters, other_state: Dict[str, Any] + self, state_checkpoint_name: str, server_parameters: Parameters, other_state: dict[str, Any] ) -> None: """ This function is meant to facilitate saving state required to restart on FL process on the server side. By @@ -157,7 +158,7 @@ def save_state( server_parameters (Parameters): Like model checkpointers, these are the aggregated Parameters stored by the server representing model state. They are mapped to a torch model architecture via the _hydrate_model_for_checkpointing function. - other_state (Dict[str, Any]): Any additional state (such as current server round) to be checkpointed in + other_state (dict[str, Any]): Any additional state (such as current server round) to be checkpointed in order to allow FL to restart from where it left off. Raises: @@ -174,7 +175,7 @@ def save_state( else: raise ValueError("Attempting to save state but no state checkpointer is specified") - def maybe_load_state(self, state_checkpoint_name: str) -> Dict[str, Any] | None: + def maybe_load_state(self, state_checkpoint_name: str) -> dict[str, Any] | None: """ This function facilitates loading of any pre-existing state (with the name state_checkpoint_name) in the directory of the state_checkpointer. If the state already exists at the proper path, the state is loaded @@ -188,7 +189,7 @@ def maybe_load_state(self, state_checkpoint_name: str) -> Dict[str, Any] | None: ValueError: Throws an error if this function is called, but no state checkpointer has been provided Returns: - Dict[str, Any] | None: If the state checkpoint properly exists and is loaded correctly, this dictionary + dict[str, Any] | None: If the state checkpoint properly exists and is loaded correctly, this dictionary carries that state. Otherwise, we return a None (or throw an exception). """ if self.state_checkpointer is not None: diff --git a/fl4health/client_managers/base_sampling_manager.py b/fl4health/client_managers/base_sampling_manager.py index 0d8051132..2f9962e55 100644 --- a/fl4health/client_managers/base_sampling_manager.py +++ b/fl4health/client_managers/base_sampling_manager.py @@ -1,5 +1,3 @@ -from typing import List, Optional, Union - from flwr.server.client_manager import SimpleClientManager from flwr.server.client_proxy import ClientProxy from flwr.server.criterion import Criterion @@ -9,8 +7,8 @@ class BaseFractionSamplingManager(SimpleClientManager): """Overrides the Simple Client Manager to Provide Fixed Sampling without replacement for Clients""" def sample( - self, num_clients: int, min_num_clients: Optional[int] = None, criterion: Optional[Criterion] = None - ) -> List[ClientProxy]: + self, num_clients: int, min_num_clients: int | None = None, criterion: Criterion | None = None + ) -> list[ClientProxy]: raise NotImplementedError( "The basic sampling function is not implemented for these managers. " "Please use the fraction sample function instead" @@ -19,12 +17,12 @@ def sample( def sample_fraction( self, sample_fraction: float, - min_num_clients: Optional[int] = None, - criterion: Optional[Criterion] = None, - ) -> List[ClientProxy]: + min_num_clients: int | None = None, + criterion: Criterion | None = None, + ) -> list[ClientProxy]: raise NotImplementedError - def wait_and_filter(self, min_num_clients: Union[int, None], criterion: Optional[Criterion] = None) -> List: + def wait_and_filter(self, min_num_clients: int | None, criterion: Criterion | None = None) -> list[str]: if min_num_clients is not None: self.wait_for(min_num_clients) else: @@ -36,9 +34,7 @@ def wait_and_filter(self, min_num_clients: Union[int, None], criterion: Optional return available_cids - def sample_all( - self, min_num_clients: Optional[int] = None, criterion: Optional[Criterion] = None - ) -> List[ClientProxy]: + def sample_all(self, min_num_clients: int | None = None, criterion: Criterion | None = None) -> list[ClientProxy]: available_cids = self.wait_and_filter(min_num_clients, criterion) return [self.clients[cid] for cid in available_cids] diff --git a/fl4health/client_managers/fixed_sampling_client_manager.py b/fl4health/client_managers/fixed_sampling_client_manager.py index 4aec65a8b..ad9d60af8 100644 --- a/fl4health/client_managers/fixed_sampling_client_manager.py +++ b/fl4health/client_managers/fixed_sampling_client_manager.py @@ -1,5 +1,3 @@ -from typing import List, Optional - from flwr.server.client_manager import SimpleClientManager from flwr.server.client_proxy import ClientProxy from flwr.server.criterion import Criterion @@ -10,7 +8,7 @@ class FixedSamplingClientManager(SimpleClientManager): def __init__(self) -> None: super().__init__() - self.current_sample: Optional[List[ClientProxy]] = None + self.current_sample: list[ClientProxy] | None = None def reset_sample(self) -> None: """Resets the saved sample so self.sample produces a new sample again.""" @@ -19,22 +17,22 @@ def reset_sample(self) -> None: def sample( self, num_clients: int, - min_num_clients: Optional[int] = None, - criterion: Optional[Criterion] = None, - ) -> List[ClientProxy]: + min_num_clients: int | None = None, + criterion: Criterion | None = None, + ) -> list[ClientProxy]: """ Return a new client sample for the first time it runs. For subsequent runs, it will return the same sampling until self.reset_sampling() is called. Args: num_clients: (int) The number of clients to sample. - min_num_clients: (Optional[int]) The minimum number of clients to return in the sample. + min_num_clients: (int | None) The minimum number of clients to return in the sample. Optional, default is num_clients. - criterion: (Optional[Criterion]) A criterion to filter clients to sample. + criterion: (Criterion | None) A criterion to filter clients to sample. Optional, default is no criterion (no filter). Returns: - List[ClientProxy]: A list of sampled clients as ClientProxy instances. + list[ClientProxy]: A list of sampled clients as ClientProxy instances. """ if self.current_sample is None: self.current_sample = super().sample(num_clients, min_num_clients, criterion) diff --git a/fl4health/client_managers/fixed_without_replacement_manager.py b/fl4health/client_managers/fixed_without_replacement_manager.py index 29c40c60a..fd9e1a62d 100644 --- a/fl4health/client_managers/fixed_without_replacement_manager.py +++ b/fl4health/client_managers/fixed_without_replacement_manager.py @@ -1,6 +1,5 @@ import random from logging import WARNING -from typing import List, Optional from flwr.common.logger import log from flwr.server.client_proxy import ClientProxy @@ -16,9 +15,9 @@ def sample_fraction( self, sample_fraction: float, # minimum number of clients required to be available - min_num_clients: Optional[int] = None, - criterion: Optional[Criterion] = None, - ) -> List[ClientProxy]: + min_num_clients: int | None = None, + criterion: Criterion | None = None, + ) -> list[ClientProxy]: """Sample a number of Flower ClientProxy instances.""" available_cids = self.wait_and_filter(min_num_clients, criterion) diff --git a/fl4health/client_managers/poisson_sampling_manager.py b/fl4health/client_managers/poisson_sampling_manager.py index ec5c10bfb..d1263f5a0 100644 --- a/fl4health/client_managers/poisson_sampling_manager.py +++ b/fl4health/client_managers/poisson_sampling_manager.py @@ -1,5 +1,4 @@ from logging import WARNING -from typing import List, Optional import numpy as np from flwr.common.logger import log @@ -13,7 +12,7 @@ class PoissonSamplingClientManager(BaseFractionSamplingManager): """Overrides the Simple Client Manager to Provide Poisson Sampling for Clients rather than fixed without replacement sampling""" - def _poisson_sample(self, sampling_probability: float, available_cids: List[str]) -> List[str]: + def _poisson_sample(self, sampling_probability: float, available_cids: list[str]) -> list[str]: poisson_trials = np.random.binomial(1, sampling_probability, len(available_cids)) poisson_mask = poisson_trials.astype(dtype=bool) return list(np.array(available_cids)[poisson_mask]) @@ -21,9 +20,9 @@ def _poisson_sample(self, sampling_probability: float, available_cids: List[str] def sample_fraction( self, sample_fraction: float, - min_num_clients: Optional[int] = None, - criterion: Optional[Criterion] = None, - ) -> List[ClientProxy]: + min_num_clients: int | None = None, + criterion: Criterion | None = None, + ) -> list[ClientProxy]: """Poisson Sampling of Flower ClientProxy instances with a probability determine by sample_fraction.""" available_cids = self.wait_and_filter(min_num_clients, criterion) diff --git a/fl4health/clients/adaptive_drift_constraint_client.py b/fl4health/clients/adaptive_drift_constraint_client.py index 2a482cbb7..e5e50a669 100644 --- a/fl4health/clients/adaptive_drift_constraint_client.py +++ b/fl4health/clients/adaptive_drift_constraint_client.py @@ -1,6 +1,6 @@ +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Dict, List, Optional, Sequence import torch from flwr.common.logger import log @@ -26,10 +26,10 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: """ This client serves as a base for FL methods implementing an auxiliary loss penalty with a weight coefficient @@ -46,7 +46,7 @@ def __init__( 'cuda' loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE. - checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional): A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None. @@ -54,7 +54,7 @@ def __init__( should send data to. Defaults to None. progress_bar (bool, optional): Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False - client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + client_name (str | None, optional): An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None. """ @@ -69,7 +69,7 @@ def __init__( client_name=client_name, ) # These are the tensors that will be used to compute the penalty loss - self.drift_penalty_tensors: List[torch.Tensor] + self.drift_penalty_tensors: list[torch.Tensor] # Exchanger with packing to be able to exchange the weights and auxiliary information with the server for # adaptation self.parameter_exchanger: FullParameterExchangerWithPacking[float] @@ -187,7 +187,7 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: return FullParameterExchangerWithPacking(ParameterPackerAdaptiveConstraint()) - def update_after_train(self, local_steps: int, loss_dict: Dict[str, float], config: Config) -> None: + def update_after_train(self, local_steps: int, loss_dict: dict[str, float], config: Config) -> None: """ Called after training with the number of local_steps performed over the FL round and the corresponding loss dictionary. We use this to store the training loss that we want to use to adapt the penalty weight parameter @@ -195,7 +195,7 @@ def update_after_train(self, local_steps: int, loss_dict: Dict[str, float], conf Args: local_steps (int): The number of steps so far in the round in the local training. - loss_dict (Dict[str, float]): A dictionary of losses from local training. + loss_dict (dict[str, float]): A dictionary of losses from local training. config (Config): The config from the server """ assert "loss_for_adaptation" in loss_dict diff --git a/fl4health/clients/apfl_client.py b/fl4health/clients/apfl_client.py index 32338fac9..0ec4b8313 100644 --- a/fl4health/clients/apfl_client.py +++ b/fl4health/clients/apfl_client.py @@ -1,5 +1,5 @@ +from collections.abc import Sequence from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple import torch from flwr.common.typing import Config @@ -22,10 +22,10 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: """ Client specifically implementing the APFL Algorithm: https://arxiv.org/abs/2003.13461 @@ -40,7 +40,7 @@ def __init__( 'cuda' loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE. - checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional): A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None. @@ -48,7 +48,7 @@ def __init__( should send data to. Defaults to None. progress_bar (bool, optional): Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False - client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + client_name (str | None, optional): An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None. """ @@ -65,12 +65,12 @@ def __init__( self.model: ApflModule self.learning_rate: float - self.optimizers: Dict[str, torch.optim.Optimizer] + self.optimizers: dict[str, torch.optim.Optimizer] def is_start_of_local_training(self, step: int) -> bool: return step == 0 - def update_after_step(self, step: int, current_round: Optional[int] = None) -> None: + def update_after_step(self, step: int, current_round: int | None = None) -> None: """ Called after local train step on client. step is an integer that represents the local training step that was most recently completed. @@ -78,7 +78,7 @@ def update_after_step(self, step: int, current_round: Optional[int] = None) -> N if self.is_start_of_local_training(step) and self.model.adaptive_alpha: self.model.update_alpha() - def train_step(self, input: TorchInputType, target: TorchTargetType) -> Tuple[TrainingLosses, TorchPredType]: + def train_step(self, input: TorchInputType, target: TorchTargetType) -> tuple[TrainingLosses, TorchPredType]: # Return preds value thats Dict of torch.Tensor containing personal, global and local predictions # Mechanics of training loop follow from original implementation @@ -121,18 +121,18 @@ def compute_loss_and_additional_losses( preds: TorchPredType, features: TorchFeatureType, target: TorchTargetType, - ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: """ Computes the loss and any additional losses given predictions of the model and ground truth data. For APFL, the loss will be the personal loss and the additional losses are the global and local loss. Args: - preds (Dict[str, torch.Tensor]): Prediction(s) of the model(s) indexed by name. - features (Dict[str, torch.Tensor]): Feature(s) of the model(s) indexed by name. + preds (dict[str, torch.Tensor]): Prediction(s) of the model(s) indexed by name. + features (dict[str, torch.Tensor]): Feature(s) of the model(s) indexed by name. target (torch.Tensor): Ground truth data to evaluate predictions against. Returns: - Tuple[torch.Tensor, Union[Dict[str, torch.Tensor], None]]; A tuple with: + tuple[torch.Tensor, dict[str, torch.Tensor]]; A tuple with: - The tensor for the personal loss - A dictionary of with `global_loss` and `local_loss` keys and their calculated values """ @@ -150,7 +150,7 @@ def set_optimizer(self, config: Config) -> None: assert isinstance(optimizers, dict) and set(("global", "local")) == set(optimizers.keys()) self.optimizers = optimizers - def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: + def get_optimizer(self, config: Config) -> dict[str, Optimizer]: """ Returns a dictionary with global and local optimizers with string keys 'global' and 'local' respectively. """ diff --git a/fl4health/clients/basic_client.py b/fl4health/clients/basic_client.py index d69ae140b..e0f528f69 100644 --- a/fl4health/clients/basic_client.py +++ b/fl4health/clients/basic_client.py @@ -2,7 +2,7 @@ from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any import torch import torch.nn as nn @@ -41,10 +41,10 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: """ Base FL Client with functionality to train, evaluate, log, report and checkpoint. @@ -58,7 +58,7 @@ def __init__( 'cuda' loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE. - checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional): A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None. @@ -66,7 +66,7 @@ def __init__( should send data to. Defaults to None. progress_bar (bool, optional): Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False - client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + client_name (str | None, optional): An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None. """ @@ -101,7 +101,7 @@ def __init__( self.test_metric_manager = MetricManager(metrics=self.metrics, metric_manager_name="test") # Optional variable to store the weights that the client was initialized with during each round of training - self.initial_weights: Optional[NDArrays] = None + self.initial_weights: NDArrays | None = None self.total_steps: int = 0 # Need to track total_steps across rounds for WANDB reporting self.total_epochs: int = 0 # Will remain as 0 if training by steps @@ -112,15 +112,15 @@ def __init__( self.optimizers: dict[str, torch.optim.Optimizer] self.train_loader: DataLoader self.val_loader: DataLoader - self.test_loader: Optional[DataLoader] + self.test_loader: DataLoader | None self.num_train_samples: int self.num_val_samples: int - self.num_test_samples: Optional[int] = None - self.learning_rate: Optional[float] = None + self.num_test_samples: int | None + self.learning_rate: float | None # Config can contain max_num_validation_steps key, which determines an upper bound # for the validation steps taken. If not specified, no upper bound will be enforced. # By specifying this in the config we cannot guarantee the validation set is the same - # accross rounds for clients. + # across rounds for clients. self.max_num_validation_steps: int | None = None def _maybe_checkpoint(self, loss: float, metrics: dict[str, Scalar], checkpoint_mode: CheckpointMode) -> None: @@ -208,7 +208,7 @@ def shutdown(self) -> None: self.reports_manager.report({"shutdown": str(datetime.datetime.now())}) self.reports_manager.shutdown() - def process_config(self, config: Config) -> Tuple[Union[int, None], Union[int, None], int, bool, bool]: + def process_config(self, config: Config) -> tuple[int | None, int | None, int, bool, bool]: """ Method to ensure the required keys are present in config and extracts values to be returned. @@ -216,7 +216,7 @@ def process_config(self, config: Config) -> Tuple[Union[int, None], Union[int, N config (Config): The config from the server. Returns: - Tuple[Union[int, None], Union[int, None], int, bool, bool]: Returns the local_epochs, local_steps, + tuple[int | None, int | None, int, bool, bool]: Returns the local_epochs, local_steps, current_server_round, evaluate_after_fit and pack_losses_with_val_metrics. Ensures only one of local_epochs and local_steps is defined in the config and sets the one that is not to None. @@ -248,7 +248,7 @@ def process_config(self, config: Config) -> Tuple[Union[int, None], Union[int, N # Either local epochs or local steps is none based on what key is passed in the config return local_epochs, local_steps, current_server_round, evaluate_after_fit, pack_losses_with_val_metrics - def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, dict[str, Scalar]]: + def fit(self, parameters: NDArrays, config: Config) -> tuple[NDArrays, int, dict[str, Scalar]]: """ Processes config, initializes client (if first round) and performs training based on the passed config. If per_round_checkpointer is not None, on initialization the client checks if a checkpointed client state @@ -259,7 +259,7 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, dict config (NDArrays): The config from the server. Returns: - Tuple[NDArrays, int, dict[str, Scalar]]: The parameters following the local training along with the + tuple[NDArrays, int, dict[str, Scalar]]: The parameters following the local training along with the number of samples in the local training dataset and the computed metrics throughout the fit. Raises: @@ -342,7 +342,7 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, dict metrics, ) - def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, Dict[str, Scalar]]: + def evaluate(self, parameters: NDArrays, config: Config) -> tuple[float, int, dict[str, Scalar]]: """ Evaluates the model on the validation set, and test set (if defined). @@ -351,7 +351,7 @@ def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, Di config (NDArrays): The config object from the server. Returns: - Tuple[float, int, dict[str, Scalar]]: A loss associated with the evaluation, the number of samples in the + tuple[float, int, dict[str, Scalar]]: A loss associated with the evaluation, the number of samples in the validation/test set and the metric_values associated with evaluation. """ if not self.initialized: @@ -414,8 +414,8 @@ def _should_evaluate_after_fit(self, evaluate_after_fit: bool) -> bool: def _log_header_str( self, - current_round: Optional[int] = None, - current_epoch: Optional[int] = None, + current_round: int | None = None, + current_epoch: int | None = None, logging_mode: LoggingMode = LoggingMode.TRAIN, ) -> None: """ @@ -423,9 +423,9 @@ def _log_header_str( epoch or at the beginning of the round if training by steps Args: - current_round (Optional[int], optional): The current FL round. (Ie current + current_round (int | None, optional): The current FL round. (Ie current server round). Defaults to None. - current_epoch (Optional[int], optional): The current epoch of local + current_epoch (int | None, optional): The current epoch of local training. Defaults to None. """ @@ -444,8 +444,8 @@ def _log_results( self, loss_dict: dict[str, float], metrics_dict: dict[str, Scalar], - current_round: Optional[int] = None, - current_epoch: Optional[int] = None, + current_round: int | None = None, + current_epoch: int | None = None, logging_mode: LoggingMode = LoggingMode.TRAIN, ) -> None: """ @@ -455,8 +455,8 @@ def _log_results( Args: loss_dict (dict[str, float]): A dictionary of losses to log. metrics_dict (dict[str, Scalar]): A dictionary of the metric to log. - current_round (Optional[int]): The current FL round (i.e., current server round). - current_epoch (Optional[int]): The current epoch of local training. + current_round (int | None): The current FL round (i.e., current server round). + current_epoch (int | None): The current epoch of local training. logging_mode (LoggingMode): The logging mode (Training, Validation, or Testing). """ _, client_logs = self.get_client_specific_logs(current_round, current_epoch, logging_mode) @@ -480,10 +480,10 @@ def _log_results( def get_client_specific_logs( self, - current_round: Optional[int], - current_epoch: Optional[int], + current_round: int | None, + current_epoch: int | None, logging_mode: LoggingMode, - ) -> Tuple[str, list[Tuple[LogLevel, str]]]: + ) -> tuple[str, list[tuple[LogLevel, str]]]: """ This function can be overridden to provide any client specific information to the basic client logging. For example, perhaps a client @@ -492,17 +492,17 @@ def get_client_specific_logs( of validation/testing. Args: - current_round (Optional[int]): The current FL round (i.e., current + current_round (int | None): The current FL round (i.e., current server round). - current_epoch (Optional[int]): The current epoch of local training. + current_epoch (int | None): The current epoch of local training. logging_mode (LoggingMode): The logging mode (Training, Validation, or Testing). Returns: - Optional[str]: A string to append to the header log string that + str | None: A string to append to the header log string that typically announces the current server round and current epoch at the beginning of each round or local epoch. - Optional[list[Tuple[LogLevel, str]]]]: A list of tuples where the + list[tuple[LogLevel, str]]] | None: A list of tuples where the first element is a LogLevel as defined in fl4health.utils. typing and the second element is a string message. Each item in the list will be logged at the end of each server round or epoch. @@ -542,7 +542,7 @@ def update_metric_manager( """ metric_manager.update(preds, target) - def train_step(self, input: TorchInputType, target: TorchTargetType) -> Tuple[TrainingLosses, TorchPredType]: + def train_step(self, input: TorchInputType, target: TorchTargetType) -> tuple[TrainingLosses, TorchPredType]: """ Given a single batch of input and target data, generate predictions, compute loss, update parameters and optionally update metrics if they exist. (ie backprop on a single batch of data). @@ -553,7 +553,7 @@ def train_step(self, input: TorchInputType, target: TorchTargetType) -> Tuple[Tr target (TorchTargetType): The target corresponding to the input. Returns: - Tuple[TrainingLosses, TorchPredType]: The losses object from the train step along with + tuple[TrainingLosses, TorchPredType]: The losses object from the train step along with a dictionary of any predictions produced by the model. """ # Clear gradients from optimizer if they exist @@ -571,7 +571,7 @@ def train_step(self, input: TorchInputType, target: TorchTargetType) -> Tuple[Tr return losses, preds - def val_step(self, input: TorchInputType, target: TorchTargetType) -> Tuple[EvaluationLosses, TorchPredType]: + def val_step(self, input: TorchInputType, target: TorchTargetType) -> tuple[EvaluationLosses, TorchPredType]: """ Given input and target, compute loss, update loss and metrics. Assumes self.model is in eval mode already. @@ -581,7 +581,7 @@ def val_step(self, input: TorchInputType, target: TorchTargetType) -> Tuple[Eval target (TorchTargetType): The target corresponding to the input. Returns: - Tuple[EvaluationLosses, TorchPredType]: The losses object from the val step along with + tuple[EvaluationLosses, TorchPredType]: The losses object from the val step along with a dictionary of the predictions produced by the model. """ @@ -596,17 +596,17 @@ def val_step(self, input: TorchInputType, target: TorchTargetType) -> Tuple[Eval def train_by_epochs( self, epochs: int, - current_round: Optional[int] = None, - ) -> Tuple[dict[str, float], dict[str, Scalar]]: + current_round: int | None = None, + ) -> tuple[dict[str, float], dict[str, Scalar]]: """ Train locally for the specified number of epochs. Args: epochs (int): The number of epochs for local training. - current_round (Optional[int], optional): The current FL round. + current_round (int | None, optional): The current FL round. Returns: - Tuple[dict[str, float], dict[str, Scalar]]: The loss and metrics dictionary from the local training. + tuple[dict[str, float], dict[str, Scalar]]: The loss and metrics dictionary from the local training. Loss is a dictionary of one or more losses that represent the different components of the loss. """ self.model.train() @@ -659,17 +659,17 @@ def train_by_epochs( def train_by_steps( self, steps: int, - current_round: Optional[int] = None, - ) -> Tuple[dict[str, float], dict[str, Scalar]]: + current_round: int | None = None, + ) -> tuple[dict[str, float], dict[str, Scalar]]: """ Train locally for the specified number of steps. Args: steps (int): The number of steps to train locally. - current_round (Optional[int], optional): The current FL round + current_round (int | None, optional): The current FL round Returns: - Tuple[dict[str, float], dict[str, Scalar]]: The loss and metrics dictionary from the local training. + tuple[dict[str, float], dict[str, Scalar]]: The loss and metrics dictionary from the local training. Loss is a dictionary of one or more losses that represent the different components of the loss. """ self.model.train() @@ -725,7 +725,7 @@ def _validate_or_test( metric_manager: MetricManager, logging_mode: LoggingMode = LoggingMode.VALIDATION, include_losses_in_metrics: bool = False, - ) -> Tuple[float, Dict[str, Scalar]]: + ) -> tuple[float, dict[str, Scalar]]: """ Evaluate the model on the given validation or test dataset. If max_num_validation_steps attribute is not None and in validation phase, steps are limited to the value of max_num_validation_steps. @@ -740,7 +740,7 @@ def _validate_or_test( dictionary. Defaults to False. Returns: - Tuple[float, dict[str, Scalar]]: The loss and a dictionary of metrics from evaluation. + tuple[float, dict[str, Scalar]]: The loss and a dictionary of metrics from evaluation. """ assert logging_mode in [ LoggingMode.VALIDATION, @@ -770,13 +770,13 @@ def _validate_or_test( return loss_dict["checkpoint"], metrics - def validate(self, include_losses_in_metrics: bool = False) -> Tuple[float, Dict[str, Scalar]]: + def validate(self, include_losses_in_metrics: bool = False) -> tuple[float, dict[str, Scalar]]: """ Validate the current model on the entire validation and potentially an entire test dataset if it has been defined. Returns: - Tuple[float, dict[str, Scalar]]: The validation loss and a dictionary of metrics + tuple[float, dict[str, Scalar]]: The validation loss and a dictionary of metrics from validation (and test if present). """ val_loss, val_metrics = self._validate_or_test( @@ -894,7 +894,7 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: """ return FullParameterExchanger() - def predict(self, input: TorchInputType) -> Tuple[TorchPredType, TorchFeatureType]: + def predict(self, input: TorchInputType) -> tuple[TorchPredType, TorchFeatureType]: """ Computes the prediction(s), and potentially features, of the model(s) given the input. @@ -905,7 +905,7 @@ def predict(self, input: TorchInputType) -> Tuple[TorchPredType, TorchFeatureTyp forward(). Returns: - Tuple[TorchPredType, TorchFeatureType]: A tuple in which the + tuple[TorchPredType, TorchFeatureType]: A tuple in which the first element contains a dictionary of predictions indexed by name and the second element contains intermediate activations indexed by name. By passing features, we can compute losses @@ -943,7 +943,7 @@ def predict(self, input: TorchInputType) -> Tuple[TorchPredType, TorchFeatureTyp def compute_loss_and_additional_losses( self, preds: TorchPredType, features: TorchFeatureType, target: TorchTargetType - ) -> Tuple[torch.Tensor, Optional[dict[str, torch.Tensor]]]: + ) -> tuple[torch.Tensor, dict[str, torch.Tensor] | None]: """ Computes the loss and any additional losses given predictions of the model and ground truth data. @@ -953,7 +953,7 @@ def compute_loss_and_additional_losses( target (TorchTargetType): Ground truth data to evaluate predictions against. Returns: - Tuple[torch.Tensor, Union[dict[str, torch.Tensor], None]]; A tuple with: + tuple[torch.Tensor, dict[str, torch.Tensor] | None]; A tuple with: - The tensor for the loss - A dictionary of additional losses with their names and values, or None if there are no additional losses. @@ -1017,7 +1017,7 @@ def set_optimizer(self, config: Config) -> None: assert not isinstance(optimizer, dict) self.optimizers = {"global": optimizer} - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, ...]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, ...]: """ User defined method that returns a PyTorch Train DataLoader and a PyTorch Validation DataLoader @@ -1026,14 +1026,14 @@ def get_data_loaders(self, config: Config) -> Tuple[DataLoader, ...]: config (Config): The config from the server. Returns: - Tuple[DataLoader, ...]: Tuple of length 2. The client train and validation loader. + tuple[DataLoader, ...]: Tuple of length 2. The client train and validation loader. Raises: NotImplementedError: To be defined in child class. """ raise NotImplementedError - def get_test_data_loader(self, config: Config) -> Optional[DataLoader]: + def get_test_data_loader(self, config: Config) -> DataLoader | None: """ User defined method that returns a PyTorch Test DataLoader. By default, this function returns None, assuming that there is no test dataset to be used. @@ -1044,7 +1044,7 @@ def get_test_data_loader(self, config: Config) -> Optional[DataLoader]: config (Config): The config from the server. Returns: - Optional[DataLoader]. The optional client test loader. Returns None. + DataLoader | None. The optional client test loader. Returns None. """ return None @@ -1079,7 +1079,7 @@ def get_criterion(self, config: Config) -> _Loss: """ raise NotImplementedError - def get_optimizer(self, config: Config) -> Union[Optimizer, dict[str, Optimizer]]: + def get_optimizer(self, config: Config) -> Optimizer | dict[str, Optimizer]: """ Method to be defined by user that returns the PyTorch optimizer used to train models locally Return value can be a single torch optimizer or a dictionary of string and torch optimizer. @@ -1090,7 +1090,7 @@ def get_optimizer(self, config: Config) -> Union[Optimizer, dict[str, Optimizer] config (Config): The config sent from the server. Returns: - Union[Optimizer, dict[str, Optimizer]]: An optimizer or dictionary of optimizers to + Optimizer | dict[str, Optimizer]: An optimizer or dictionary of optimizers to train model. Raises: @@ -1113,7 +1113,7 @@ def get_model(self, config: Config) -> nn.Module: """ raise NotImplementedError - def get_lr_scheduler(self, optimizer_key: str, config: Config) -> Union[None, _LRScheduler]: + def get_lr_scheduler(self, optimizer_key: str, config: Config) -> _LRScheduler | None: """ Optional user defined method that returns learning rate scheduler to be used throughout training for the given optimizer. Defaults to None. @@ -1125,18 +1125,18 @@ def get_lr_scheduler(self, optimizer_key: str, config: Config) -> Union[None, _L config (Config): The config from the server. Returns: - Union[None, _LRScheduler]: Client learning rate schedulers. + _LRScheduler | None: Client learning rate schedulers. """ return None - def update_lr_schedulers(self, step: Union[int, None] = None, epoch: Union[int, None] = None) -> None: + def update_lr_schedulers(self, step: int | None = None, epoch: int | None = None) -> None: """ Updates any schedulers that exist. Can be overridden to customize update logic based on client state (ie self.total_steps). Args: - step (Union[int, None]): If using local_steps, current step of this round. Otherwise None. - epoch (Union[int, None]): If using local_epochs current epoch of this round. Otherwise None. + step (int | None): If using local_steps, current step of this round. Otherwise None. + epoch (int | None): If using local_epochs current epoch of this round. Otherwise None. """ assert (step is None) ^ (epoch is None) @@ -1171,18 +1171,18 @@ def update_after_train(self, local_steps: int, loss_dict: dict[str, float], conf """ pass - def update_before_step(self, step: int, current_round: Optional[int] = None) -> None: + def update_before_step(self, step: int, current_round: int | None = None) -> None: """ Hook method called before local train step. Args: step (int): The local training step that was most recently completed. Resets only at the end of the round. - current_round (Optional[int], optional): The current FL server round + current_round (int | None, optional): The current FL server round """ pass - def update_after_step(self, step: int, current_round: Optional[int] = None) -> None: + def update_after_step(self, step: int, current_round: int | None = None) -> None: """ Hook method called after local train step on client. step is an integer that represents the local training step that was most recently completed. For example, used by the APFL @@ -1192,7 +1192,7 @@ def update_after_step(self, step: int, current_round: Optional[int] = None) -> N Args: step (int): The step number in local training that was most recently completed. Resets only at the end of the round. - current_round (Optional[int], optional): The current FL server round + current_round (int | None, optional): The current FL server round """ pass diff --git a/fl4health/clients/clipping_client.py b/fl4health/clients/clipping_client.py index 857ab370a..f20801539 100644 --- a/fl4health/clients/clipping_client.py +++ b/fl4health/clients/clipping_client.py @@ -1,6 +1,6 @@ +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Optional, Sequence, Tuple import torch from flwr.common import NDArrays @@ -27,10 +27,10 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: """ Client that clips updates being sent to the server where noise is added. Used to obtain Client Level @@ -43,7 +43,7 @@ def __init__( 'cuda' loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE. - checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional): A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None. @@ -51,7 +51,7 @@ def __init__( should send data to. Defaults to None. progress_bar (bool, optional): Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False - client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + client_name (str | None, optional): An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None. """ @@ -66,15 +66,15 @@ def __init__( client_name=client_name, ) self.parameter_exchanger: FullParameterExchangerWithPacking[float] - self.clipping_bound: Optional[float] = None - self.adaptive_clipping: Optional[bool] = None + self.clipping_bound: float | None = None + self.adaptive_clipping: bool | None = None def calculate_parameters_norm(self, parameters: NDArrays) -> float: layer_inner_products = [pow(linalg.norm(layer_weights), 2) for layer_weights in parameters] # network Frobenius norm return pow(sum(layer_inner_products), 0.5) - def clip_parameters(self, parameters: NDArrays) -> Tuple[NDArrays, float]: + def clip_parameters(self, parameters: NDArrays) -> tuple[NDArrays, float]: assert self.clipping_bound is not None assert self.adaptive_clipping is not None # performs flat clipping (i.e. parameters * min(1, C/||parameters||_2)) @@ -89,7 +89,7 @@ def clip_parameters(self, parameters: NDArrays) -> Tuple[NDArrays, float]: # parameters and clipping bit return [layer_weights * clip_scalar for layer_weights in parameters], 0.0 - def compute_weight_update_and_clip(self, parameters: NDArrays) -> Tuple[NDArrays, float]: + def compute_weight_update_and_clip(self, parameters: NDArrays) -> tuple[NDArrays, float]: assert self.initial_weights is not None assert len(parameters) == len(self.initial_weights) weight_update: NDArrays = [ diff --git a/fl4health/clients/constrained_fenda_client.py b/fl4health/clients/constrained_fenda_client.py index fe4e3f51e..14c7d3b61 100644 --- a/fl4health/clients/constrained_fenda_client.py +++ b/fl4health/clients/constrained_fenda_client.py @@ -1,6 +1,6 @@ +from collections.abc import Sequence from logging import WARNING from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple import torch import torch.nn as nn @@ -26,11 +26,11 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, - loss_container: Optional[ConstrainedFendaLossContainer] = None, + client_name: str | None = None, + loss_container: ConstrainedFendaLossContainer | None = None, ) -> None: """ This class extends the functionality of FENDA training to include various kinds of constraints applied during @@ -43,7 +43,7 @@ def __init__( 'cuda' loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE. - checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional): A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None. @@ -51,10 +51,10 @@ def __init__( should send data to. Defaults to None. progress_bar (bool, optional): Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False - client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + client_name (str | None, optional): An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None. - loss_container (Optional[ConstrainedFendaLossContainer], optional): Configuration that determines which + loss_container (ConstrainedFendaLossContainer | None, optional): Configuration that determines which losses will be applied during FENDA training. Defaults to None. """ @@ -82,9 +82,9 @@ def __init__( # Need to save previous local module, global module and aggregated global module at each communication round # to compute contrastive loss. - self.old_local_module: Optional[nn.Module] = None - self.old_global_module: Optional[nn.Module] = None - self.initial_global_module: Optional[nn.Module] = None + self.old_local_module: nn.Module | None = None + self.old_global_module: nn.Module | None = None + self.initial_global_module: nn.Module | None = None def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: assert isinstance(self.model, FendaModelWithFeatureState) @@ -102,7 +102,7 @@ def _flatten(self, features: torch.Tensor) -> torch.Tensor: """ return features.reshape(len(features), -1) - def _perfcl_keys_present(self, features: Dict[str, torch.Tensor]) -> bool: + def _perfcl_keys_present(self, features: dict[str, torch.Tensor]) -> bool: target_keys = { "old_local_features", "old_global_features", @@ -110,16 +110,16 @@ def _perfcl_keys_present(self, features: Dict[str, torch.Tensor]) -> bool: } return target_keys.issubset(features.keys()) - def predict(self, input: TorchInputType) -> Tuple[TorchPredType, TorchFeatureType]: + def predict(self, input: TorchInputType) -> tuple[TorchPredType, TorchFeatureType]: """ Computes the prediction(s) and features of the model(s) given the input. Args: input (TorchInputType): Inputs to be fed into the model. TorchInputType is simply an alias - for the union of torch.Tensor and Dict[str, torch.Tensor]. + for the union of torch.Tensor and dict[str, torch.Tensor]. Returns: - Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: A tuple in which the first element + tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: A tuple in which the first element contains predictions indexed by name and the second element contains intermediate activations index by name. Specifically the features of the model, features of the global model and features of the old model are returned. All predictions included in dictionary will be used to compute metrics. @@ -143,7 +143,7 @@ def predict(self, input: TorchInputType) -> Tuple[TorchPredType, TorchFeatureTyp return preds, features - def update_after_train(self, local_steps: int, loss_dict: Dict[str, float], config: Config) -> None: + def update_after_train(self, local_steps: int, loss_dict: dict[str, float], config: Config) -> None: """ This function is called after client-side training concludes. If a contrastive or PerFCL loss function has been defined, it is used to save the local and global feature extraction weights/modules to be used in the @@ -151,7 +151,7 @@ def update_after_train(self, local_steps: int, loss_dict: Dict[str, float], conf Args: local_steps (int): Number of steps performed during training - loss_dict (Dict[str, float]): Losses computed during training. + loss_dict (dict[str, float]): Losses computed during training. """ # Save the parameters of the old model assert isinstance(self.model, FendaModelWithFeatureState) @@ -185,19 +185,19 @@ def compute_loss_and_additional_losses( preds: TorchPredType, features: TorchFeatureType, target: TorchTargetType, - ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: """ Computes the loss and any additional losses given predictions of the model and ground truth data. For FENDA, the loss is the total loss and the additional losses are the loss, total loss and, based on client attributes set from server config, cosine similarity loss, contrastive loss and perfcl losses. Args: - preds (Dict[str, torch.Tensor]): Prediction(s) of the model(s) indexed by name. - features (Dict[str, torch.Tensor]): Feature(s) of the model(s) indexed by name. + preds (dict[str, torch.Tensor]): Prediction(s) of the model(s) indexed by name. + features (dict[str, torch.Tensor]): Feature(s) of the model(s) indexed by name. target (torch.Tensor): Ground truth data to evaluate predictions against. Returns: - Tuple[torch.Tensor, Union[Dict[str, torch.Tensor], None]]; A tuple with: + tuple[torch.Tensor, dict[str, torch.Tensor]]; A tuple with: - The tensor for the total loss - A dictionary with `loss`, `total_loss` and, based on client attributes set from server config, also `cos_sim_loss`, `contrastive_loss`, `contrastive_loss_minimize` and `contrastive_loss_minimize` @@ -252,9 +252,9 @@ def compute_evaluation_loss( client attributes set from server config. Args: - preds (Dict[str, torch.Tensor]): Prediction(s) of the model(s) indexed by name. + preds (dict[str, torch.Tensor]): Prediction(s) of the model(s) indexed by name. All predictions included in dictionary will be used to compute metrics. - features: (Dict[str, torch.Tensor]): Feature(s) of the model(s) indexed by name. + features: (dict[str, torch.Tensor]): Feature(s) of the model(s) indexed by name. target: (torch.Tensor): Ground truth data to evaluate predictions against. Returns: diff --git a/fl4health/clients/deep_mmd_clients/ditto_deep_mmd_client.py b/fl4health/clients/deep_mmd_clients/ditto_deep_mmd_client.py index ed9de5c78..5acfb44d5 100644 --- a/fl4health/clients/deep_mmd_clients/ditto_deep_mmd_client.py +++ b/fl4health/clients/deep_mmd_clients/ditto_deep_mmd_client.py @@ -1,6 +1,6 @@ +from collections.abc import Sequence from logging import ERROR from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple import torch import torch.nn as nn @@ -26,14 +26,14 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, deep_mmd_loss_weight: float = 10.0, - feature_extraction_layers_with_size: Optional[Dict[str, int]] = None, + feature_extraction_layers_with_size: dict[str, int] | None = None, mmd_kernel_train_interval: int = 20, - num_accumulating_batches: Optional[int] = None, + num_accumulating_batches: int | None = None, ) -> None: """ This client implements the Deep MMD loss function in the Ditto framework. The Deep MMD loss is a measure of @@ -48,7 +48,7 @@ def __init__( 'cuda' loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE. - checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional): A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None. @@ -56,11 +56,11 @@ def __init__( should send data to. Defaults to None. progress_bar (bool, optional): Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False - client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + client_name (str | None, optional): An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None. deep_mmd_loss_weight (float, optional): weight applied to the Deep MMD loss. Defaults to 10.0. - feature_extraction_layers_with_size (Optional[Dict[str, int]], optional): Dictionary of layers to extract + feature_extraction_layers_with_size (dict[str, int] | None, optional): Dictionary of layers to extract features from them and their respective feature size. Defaults to None. mmd_kernel_update_interval (int, optional): interval at which to train and update the Deep MMD kernel. If set to above 0, the kernel will be train based on whole distribution of latent features of data with @@ -91,7 +91,7 @@ def __init__( if feature_extraction_layers_with_size is None: feature_extraction_layers_with_size = {} self.flatten_feature_extraction_layers = {layer: True for layer in feature_extraction_layers_with_size.keys()} - self.deep_mmd_losses: Dict[str, DeepMmdLoss] = {} + self.deep_mmd_losses: dict[str, DeepMmdLoss] = {} # Save the random state to be restored after initializing the Deep MMD loss layers. random_state, numpy_state, torch_state = save_random_state() for layer, feature_size in feature_extraction_layers_with_size.items(): @@ -143,7 +143,7 @@ def _should_optimize_betas(self, step: int) -> bool: weighted_deep_mmd_loss = self.deep_mmd_loss_weight != 0 return step_at_interval and valid_components_present and weighted_deep_mmd_loss - def update_after_step(self, step: int, current_round: Optional[int] = None) -> None: + def update_after_step(self, step: int, current_round: int | None = None) -> None: if self.mmd_kernel_train_interval > 0 and self._should_optimize_betas(step): # Get the feature distribution of the local and initial global features with evaluation # mode @@ -161,7 +161,7 @@ def update_after_step(self, step: int, current_round: Optional[int] = None) -> N def update_buffers( self, local_model: torch.nn.Module, initial_global_model: torch.nn.Module - ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: + ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: """ Update the feature buffer of the local and global features. @@ -170,7 +170,7 @@ def update_buffers( initial_global_model (torch.nn.Module): Initial global model to extract features from. Returns: - Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: A tuple containing the extracted + tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: A tuple containing the extracted features using the local and initial global models. """ @@ -224,18 +224,18 @@ def update_buffers( def predict( self, input: TorchInputType, - ) -> Tuple[TorchPredType, TorchFeatureType]: + ) -> tuple[TorchPredType, TorchFeatureType]: """ Computes the predictions for both the GLOBAL and LOCAL models and pack them into the prediction dictionary Args: input (TorchInputType): Inputs to be fed into the model. If input is - of type Dict[str, torch.Tensor], it is assumed that the keys of + of type dict[str, torch.Tensor], it is assumed that the keys of input match the names of the keyword arguments of self.model. forward(). Returns: - Tuple[TorchPredType, TorchFeatureType]: A tuple in which the + tuple[TorchPredType, TorchFeatureType]: A tuple in which the first element contains a dictionary of predictions indexed by name and the second element contains intermediate activations indexed by name. By passing features, we can compute all the @@ -262,7 +262,7 @@ def predict( return {"global": global_preds, "local": local_preds}, features - def _maybe_checkpoint(self, loss: float, metrics: Dict[str, Scalar], checkpoint_mode: CheckpointMode) -> None: + def _maybe_checkpoint(self, loss: float, metrics: dict[str, Scalar], checkpoint_mode: CheckpointMode) -> None: # Hooks need to be removed before checkpointing the model self.local_feature_extractor.remove_hooks() super()._maybe_checkpoint(loss=loss, metrics=metrics, checkpoint_mode=checkpoint_mode) @@ -270,12 +270,12 @@ def _maybe_checkpoint(self, loss: float, metrics: Dict[str, Scalar], checkpoint_ # each time. self.local_feature_extractor._maybe_register_hooks() - def validate(self, include_losses_in_metrics: bool = False) -> Tuple[float, Dict[str, Scalar]]: + def validate(self, include_losses_in_metrics: bool = False) -> tuple[float, dict[str, Scalar]]: """ Validate the current model on the entire validation dataset. Returns: - Tuple[float, Dict[str, Scalar]]: The validation loss and a dictionary of metrics from validation. + tuple[float, dict[str, Scalar]]: The validation loss and a dictionary of metrics from validation. """ for layer in self.flatten_feature_extraction_layers.keys(): self.deep_mmd_losses[layer].training = False @@ -358,7 +358,7 @@ def compute_evaluation_loss( def compute_loss_and_additional_losses( self, preds: TorchPredType, features: TorchFeatureType, target: TorchTargetType - ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: """ Computes the loss and any additional losses given predictions of the model and ground truth data. @@ -368,7 +368,7 @@ def compute_loss_and_additional_losses( target (TorchTargetType): Ground truth data to evaluate predictions against. Returns: - Tuple[torch.Tensor, Dict[str, torch.Tensor]]: A tuple with: + tuple[torch.Tensor, dict[str, torch.Tensor]]: A tuple with: - The tensor for the loss - A dictionary of additional losses with their names and values, or None if there are no additional losses. diff --git a/fl4health/clients/ditto_client.py b/fl4health/clients/ditto_client.py index 6ed5ee8ff..183f0b3a3 100644 --- a/fl4health/clients/ditto_client.py +++ b/fl4health/clients/ditto_client.py @@ -1,6 +1,6 @@ +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple import torch import torch.nn as nn @@ -25,10 +25,10 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: """ This client implements the Ditto algorithm from Ditto: Fair and Robust Federated Learning Through @@ -47,7 +47,7 @@ def __init__( 'cuda' loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE. - checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional): A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None. @@ -55,7 +55,7 @@ def __init__( should send data to. Defaults to None. progress_bar (bool, optional): Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False - client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + client_name (str | None, optional): An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None. """ @@ -71,7 +71,7 @@ def __init__( ) self.global_model: nn.Module - def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: + def get_optimizer(self, config: Config) -> dict[str, Optimizer]: """ Returns a dictionary with global and local optimizers with string keys 'global' and 'local' respectively. @@ -79,7 +79,7 @@ def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: config (Config): The config from the server. """ raise NotImplementedError( - "User Clients must define a function that returns a Dict[str, Optimizer] with keys 'global' and 'local' " + "User Clients must define a function that returns a dict[str, Optimizer] with keys 'global' and 'local' " "defining separate optimizers for the global and local models of Ditto." ) @@ -216,7 +216,7 @@ def update_before_train(self, current_server_round: int) -> None: super().update_before_train(current_server_round) - def train_step(self, input: TorchInputType, target: TorchTargetType) -> Tuple[TrainingLosses, TorchPredType]: + def train_step(self, input: TorchInputType, target: TorchTargetType) -> tuple[TrainingLosses, TorchPredType]: """ Mechanics of training loop follow from original Ditto implementation: https://github.com/litian96/ditto As in the implementation there, steps of the global and local models are done in tandem and for the same @@ -224,11 +224,11 @@ def train_step(self, input: TorchInputType, target: TorchTargetType) -> Tuple[Tr Args: input (TorchInputType): input tensor to be run through both the global and local models. Here, - TorchInputType is simply an alias for the union of torch.Tensor and Dict[str, torch.Tensor]. + TorchInputType is simply an alias for the union of torch.Tensor and dict[str, torch.Tensor]. target (TorchTargetType): target tensor to be used to compute a loss given each models outputs. Returns: - Tuple[TrainingLosses, TorchPredType]: Returns relevant loss values from both the global and local + tuple[TrainingLosses, TorchPredType]: Returns relevant loss values from both the global and local model optimization steps. The prediction dictionary contains predictions indexed a "global" and "local" corresponding to predictions from the global and local Ditto models for metric evaluations. """ @@ -258,7 +258,7 @@ def train_step(self, input: TorchInputType, target: TorchTargetType) -> Tuple[Tr def predict( self, input: TorchInputType, - ) -> Tuple[TorchPredType, TorchFeatureType]: + ) -> tuple[TorchPredType, TorchFeatureType]: """ Computes the predictions for both the GLOBAL and LOCAL models and pack them into the prediction dictionary @@ -266,7 +266,7 @@ def predict( input (TorchInputType): Inputs to be fed into both models. Returns: - Tuple[TorchPredType, TorchFeatureType]: A tuple in which the first element + tuple[TorchPredType, TorchFeatureType]: A tuple in which the first element contains predictions indexed by name and the second element contains intermediate activations index by name. For Ditto, we only need the predictions, so the second dictionary is simply empty. @@ -295,7 +295,7 @@ def compute_loss_and_additional_losses( preds: TorchPredType, features: TorchFeatureType, target: TorchTargetType, - ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: """ Computes the local model loss and the global Ditto model loss (stored in additional losses) for reporting and training of the global model @@ -305,7 +305,7 @@ def compute_loss_and_additional_losses( features (TorchFeatureType): Feature(s) of the model(s) indexed by name. target (TorchTargetType): Ground truth data to evaluate predictions against. Returns: - Tuple[torch.Tensor, Union[Dict[str, torch.Tensor], None]]; A tuple with: + tuple[torch.Tensor, dict[str, torch.Tensor]]; A tuple with: - The tensor for the model loss - A dictionary with `local_loss`, `global_loss` as additionally reported loss values. """ @@ -362,12 +362,12 @@ def compute_training_loss( return TrainingLosses(backward=loss + penalty_loss, additional_losses=additional_losses) - def validate(self, include_losses_in_metrics: bool = False) -> Tuple[float, Dict[str, Scalar]]: + def validate(self, include_losses_in_metrics: bool = False) -> tuple[float, dict[str, Scalar]]: """ Validate the current model on the entire validation dataset. Returns: - Tuple[float, Dict[str, Scalar]]: The validation loss and a dictionary of metrics from validation. + tuple[float, dict[str, Scalar]]: The validation loss and a dictionary of metrics from validation. """ # Set the global model to evaluate mode self.global_model.eval() diff --git a/fl4health/clients/ensemble_client.py b/fl4health/clients/ensemble_client.py index bc2e1df8a..5f6341158 100644 --- a/fl4health/clients/ensemble_client.py +++ b/fl4health/clients/ensemble_client.py @@ -1,5 +1,5 @@ +from collections.abc import Sequence from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple import torch from flwr.common.typing import Config @@ -21,10 +21,10 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: """ This client enables the training of ensemble models in a federated manner. @@ -36,7 +36,7 @@ def __init__( 'cuda' loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE. - checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional): A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None. @@ -44,7 +44,7 @@ def __init__( should send data to. Defaults to None. progress_bar (bool, optional): Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False - client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + client_name (str | None, optional): An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None. """ @@ -90,22 +90,22 @@ def set_optimizer(self, config: Config) -> None: assert isinstance(optimizers, dict) self.optimizers = optimizers - def train_step(self, input: TorchInputType, target: TorchTargetType) -> Tuple[TrainingLosses, TorchPredType]: + def train_step(self, input: TorchInputType, target: TorchTargetType) -> tuple[TrainingLosses, TorchPredType]: """ Given a single batch of input and target data, generate predictions (both individual models and ensemble prediction), compute loss, update parameters and - optionally update metrics if they exist. (ie backprop on a single batch of data). + optionally update metrics if they exist. (ie backpropagation on a single batch of data). Assumes self.model is in train mode already. Differs from parent method in that, there are multiple losses that we have to do backward passes on and multiple optimizers to update parameters each train step. Args: input (TorchInputType): The input to be fed into the model. TorchInputType is simply an alias for the union of torch.Tensor and - Dict[str, torch.Tensor]. + dict[str, torch.Tensor]. target (torch.Tensor): The target corresponding to the input. Returns: - Tuple[TrainingLosses, Dict[str, torch.Tensor]]: The losses object from the train step along with + tuple[TrainingLosses, dict[str, torch.Tensor]]: The losses object from the train step along with a dictionary of any predictions produced by the model. """ assert isinstance(input, torch.Tensor) @@ -136,9 +136,9 @@ def compute_training_loss( Since the ensemble client has more than one model, there are multiple backward losses that exist. Args: - preds (Dict[str, torch.Tensor]): Prediction(s) of the model(s) indexed by name. Anything stored + preds (dict[str, torch.Tensor]): Prediction(s) of the model(s) indexed by name. Anything stored in preds will be used to compute metrics. - features: (Dict[str, torch.Tensor]): Feature(s) of the model(s) indexed by name. + features: (dict[str, torch.Tensor]): Feature(s) of the model(s) indexed by name. target: (torch.Tensor): Ground truth data to evaluate predictions against. Returns: @@ -163,9 +163,9 @@ def compute_evaluation_loss( Since the ensemble client has more than one model, there are multiple backward losses that exist. Args: - preds (Dict[str, torch.Tensor]): Prediction(s) of the model(s) indexed by name. Anything stored + preds (dict[str, torch.Tensor]): Prediction(s) of the model(s) indexed by name. Anything stored in preds will be used to compute metrics. - features: (Dict[str, torch.Tensor]): Feature(s) of the model(s) indexed by name. + features: (dict[str, torch.Tensor]): Feature(s) of the model(s) indexed by name. target: (torch.Tensor): Ground truth data to evaluate predictions against. Returns: @@ -179,7 +179,7 @@ def compute_evaluation_loss( checkpoint_loss = loss_dict["ensemble-pred"] return EvaluationLosses(checkpoint=checkpoint_loss) - def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: + def get_optimizer(self, config: Config) -> dict[str, Optimizer]: """ Method to be defined by user that returns dictionary of optimizers with keys corresponding to the keys of the models in EnsembleModel that the optimizer applies too. @@ -188,7 +188,7 @@ def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: config (Config): The config sent from the server. Returns: - Dict[str, Optimizer]: An optimizer or dictionary of optimizers to + dict[str, Optimizer]: An optimizer or dictionary of optimizers to train model. Raises: diff --git a/fl4health/clients/evaluate_client.py b/fl4health/clients/evaluate_client.py index 08954fc08..619d91e79 100644 --- a/fl4health/clients/evaluate_client.py +++ b/fl4health/clients/evaluate_client.py @@ -1,7 +1,7 @@ import datetime +from collections.abc import Sequence from logging import INFO, WARNING from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple import torch import torch.nn as nn @@ -27,9 +27,9 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - model_checkpoint_path: Optional[Path] = None, + model_checkpoint_path: Path | None = None, reporters: Sequence[BaseReporter] | None = None, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: """ This client implements an evaluation only flow. That is, there is no expectation of parameter exchange with @@ -44,10 +44,10 @@ def __init__( 'cuda' loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE. - model_checkpoint_path (Optional[Path], optional): _description_. Defaults to None. + model_checkpoint_path (Path | None, optional): _description_. Defaults to None. reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health reporters which the client should send data to. Defaults to None. - client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + client_name (str | None, optional): An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Defaults to None. """ # EvaluateClient does not call BasicClient constructor and sets attributes @@ -74,13 +74,13 @@ def __init__( # if they exist, to be evaluated on the client's dataset. self.data_loader: DataLoader self.criterion: _Loss - self.local_model: Optional[nn.Module] = None - self.global_model: Optional[nn.Module] = None + self.local_model: nn.Module | None = None + self.global_model: nn.Module | None = None - def get_parameters(self, config: Dict[str, Scalar]) -> NDArrays: + def get_parameters(self, config: dict[str, Scalar]) -> NDArrays: raise ValueError("Get Parameters is not implemented for an Evaluation-Only Client") - def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict[str, Scalar]]: + def fit(self, parameters: NDArrays, config: Config) -> tuple[NDArrays, int, dict[str, Scalar]]: raise ValueError("Fit is not implemented for an Evaluation-Only Client") def setup_client(self, config: Config) -> None: @@ -106,7 +106,7 @@ def setup_client(self, config: Config) -> None: def set_parameters(self, parameters: NDArrays, config: Config, fitting_round: bool) -> None: assert not fitting_round - # Sets the global model parameters transfered from the server using a parameter exchanger to coordinate how + # Sets the global model parameters transferred from the server using a parameter exchanger to coordinate how # parameters are set if len(parameters) > 0: # If a non-empty set of parameters are passed, then they are inserted into a global model to be evaluated. @@ -118,7 +118,7 @@ def set_parameters(self, parameters: NDArrays, config: Config, fitting_round: bo # be initialized with trained weights. self.global_model = None - def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, Dict[str, Scalar]]: + def evaluate(self, parameters: NDArrays, config: Config) -> tuple[float, int, dict[str, Scalar]]: if not self.initialized: self.setup_client(config) @@ -151,7 +151,7 @@ def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, Di ) def _handle_logging( # type: ignore - self, losses: EvaluationLosses, metrics_dict: Dict[str, Scalar], is_global: bool + self, losses: EvaluationLosses, metrics_dict: dict[str, Scalar], is_global: bool ) -> None: metric_string = "\t".join([f"{key}: {str(val)}" for key, val in metrics_dict.items()]) loss_string = "\t".join([f"{key}: {str(val)}" for key, val in losses.as_dict().items()]) @@ -168,7 +168,7 @@ def validate_on_model( metric_meter: MetricManager, loss_meter: LossMeter, is_global: bool, - ) -> Tuple[EvaluationLosses, Dict[str, Scalar]]: + ) -> tuple[EvaluationLosses, dict[str, Scalar]]: model.eval() metric_meter.clear() loss_meter.clear() @@ -188,12 +188,12 @@ def validate_on_model( self._handle_logging(losses, metrics, is_global) return losses, metrics - def validate(self, include_loss_in_metrics: bool = False) -> Tuple[float, Dict[str, Scalar]]: - local_loss: Optional[EvaluationLosses] = None - local_metrics: Optional[Dict[str, Scalar]] = None + def validate(self, include_loss_in_metrics: bool = False) -> tuple[float, dict[str, Scalar]]: + local_loss: EvaluationLosses | None = None + local_metrics: dict[str, Scalar] | None = None - global_loss: Optional[EvaluationLosses] = None - global_metrics: Optional[Dict[str, Scalar]] = None + global_loss: EvaluationLosses | None = None + global_metrics: dict[str, Scalar] | None = None if self.local_model: log(INFO, "Performing evaluation on local model") @@ -225,9 +225,9 @@ def validate(self, include_loss_in_metrics: bool = False) -> Tuple[float, Dict[s @staticmethod def merge_metrics( - global_metrics: Optional[Dict[str, Scalar]], - local_metrics: Optional[Dict[str, Scalar]], - ) -> Dict[str, Scalar]: + global_metrics: dict[str, Scalar] | None, + local_metrics: dict[str, Scalar] | None, + ) -> dict[str, Scalar]: # Merge metrics if necessary if global_metrics: metrics = global_metrics @@ -256,20 +256,20 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: """ return FullParameterExchanger() - def get_data_loader(self, config: Config) -> Tuple[DataLoader]: + def get_data_loader(self, config: Config) -> tuple[DataLoader]: """ User defined method that returns a PyTorch DataLoader for validation """ raise NotImplementedError - def initialize_global_model(self, config: Config) -> Optional[nn.Module]: + def initialize_global_model(self, config: Config) -> nn.Module | None: """ User defined method that to initializes a global model to potentially be hydrated by parameters sent by the server, by default, no global model is assumed to exist unless specified by the user """ return None - def get_local_model(self, config: Config) -> Optional[nn.Module]: + def get_local_model(self, config: Config) -> nn.Module | None: """ Functionality for initializing a model from a local checkpoint. This can be overridden for custom behavior """ diff --git a/fl4health/clients/fed_pca_client.py b/fl4health/clients/fed_pca_client.py index 773fb7966..2fe42e59a 100644 --- a/fl4health/clients/fed_pca_client.py +++ b/fl4health/clients/fed_pca_client.py @@ -2,7 +2,6 @@ import string from logging import INFO from pathlib import Path -from typing import Dict, Tuple import torch from flwr.client.numpy_client import NumPyClient @@ -58,14 +57,14 @@ def get_parameters(self, config: Config) -> NDArrays: def set_parameters(self, parameters: NDArrays, config: Config) -> None: """ - Sets the merged principal components transfered from the server. + Sets the merged principal components transferred from the server. Since federated PCA only runs for one round, the principal components obtained here are in fact the final result, so they are saved locally by each client for downstream tasks. """ self.parameter_exchanger.pull_parameters(parameters, self.model, config) self.save_model() - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: """ User defined method that returns a PyTorch Train DataLoader and a PyTorch Validation DataLoader @@ -99,7 +98,7 @@ def setup_client(self, config: Config) -> None: def get_data_tensor(self, data_loader: DataLoader) -> Tensor: raise NotImplementedError - def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict[str, Scalar]]: + def fit(self, parameters: NDArrays, config: Config) -> tuple[NDArrays, int, dict[str, Scalar]]: """Perform PCA using the locally held dataset.""" if not self.initialized: self.setup_client(config) @@ -110,19 +109,19 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict cumulative_explained_variance = self.model.compute_cumulative_explained_variance() explained_variance_ratios = self.model.compute_explained_variance_ratios() - metrics: Dict[str, Scalar] = { + metrics: dict[str, Scalar] = { "cumulative_explained_variance": cumulative_explained_variance, "top_explained_variance_ratio": explained_variance_ratios[0].item(), } return (self.get_parameters(config), self.num_train_samples, metrics) - def evaluate(self, parameters: NDArrays, config: Dict[str, Scalar]) -> Tuple[float, int, Dict[str, Scalar]]: + def evaluate(self, parameters: NDArrays, config: dict[str, Scalar]) -> tuple[float, int, dict[str, Scalar]]: """ Evaluate merged principal components on the local validation set. Args: parameters (NDArrays): Server-merged principal components. - config (Dict[str, Scalar]): Config file. + config (dict[str, Scalar]): Config file. """ if not self.initialized: self.setup_client(config) @@ -136,7 +135,7 @@ def evaluate(self, parameters: NDArrays, config: Dict[str, Scalar]) -> Tuple[flo ) reconstruction_loss = self.model.compute_reconstruction_error(val_data_tensor_prepared, num_components_eval) projection_variance = self.model.compute_projection_variance(val_data_tensor_prepared, num_components_eval) - metrics: Dict[str, Scalar] = {"projection_variance": projection_variance} + metrics: dict[str, Scalar] = {"projection_variance": projection_variance} return (reconstruction_loss, self.num_val_samples, metrics) def save_model(self) -> None: diff --git a/fl4health/clients/fedpm_client.py b/fl4health/clients/fedpm_client.py index b0c2deffe..7e0dc42ce 100644 --- a/fl4health/clients/fedpm_client.py +++ b/fl4health/clients/fedpm_client.py @@ -1,5 +1,5 @@ +from collections.abc import Sequence from pathlib import Path -from typing import Optional, Sequence import torch from flwr.common.typing import Config @@ -22,10 +22,10 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: """ Client implementing the FedPM algorithm (https://arxiv.org/pdf/2209.15328). FedPM is a recent sparse, @@ -42,7 +42,7 @@ def __init__( 'cuda' loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE. - checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional): A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None. @@ -50,7 +50,7 @@ def __init__( should send data to. Defaults to None. progress_bar (bool, optional): Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False - client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + client_name (str | None, optional): An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None. """ diff --git a/fl4health/clients/fedrep_client.py b/fl4health/clients/fedrep_client.py index 6ae9f1681..8c8fdfb35 100644 --- a/fl4health/clients/fedrep_client.py +++ b/fl4health/clients/fedrep_client.py @@ -1,8 +1,8 @@ import datetime +from collections.abc import Sequence from enum import Enum from logging import INFO from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple import torch from flwr.common.logger import log @@ -21,7 +21,7 @@ from fl4health.utils.metrics import Metric from fl4health.utils.typing import TorchInputType, TorchPredType, TorchTargetType -EpochsAndStepsTuple = Tuple[Optional[int], Optional[int], Optional[int], Optional[int]] +EpochsAndStepsTuple = tuple[int | None, int | None, int | None, int | None] class FedRepTrainMode(Enum): @@ -36,10 +36,10 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: """ Client implementing the training of FedRep (https://arxiv.org/abs/2303.05206). @@ -55,7 +55,7 @@ def __init__( 'cuda' loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE. - checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional): A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None. @@ -63,7 +63,7 @@ def __init__( should send data to. Defaults to None. progress_bar (bool, optional): Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False - client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + client_name (str | None, optional): An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None. """ @@ -106,7 +106,7 @@ def _prepare_train_head(self) -> None: self.model.freeze_base_module() def _prefix_loss_and_metrics_dictionaries( - self, prefix: str, loss_dict: Dict[str, float], metrics_dict: Dict[str, Scalar] + self, prefix: str, loss_dict: dict[str, float], metrics_dict: dict[str, Scalar] ) -> None: """ This method is used to added the provided prefix string to the keys of both the loss_dict and the metrics_dict @@ -115,8 +115,8 @@ def _prefix_loss_and_metrics_dictionaries( Args: prefix (str): Prefix to be attached to all keys of the provided dictionaries. - loss_dict (Dict[str, float]): Dictionary of loss values obtained during training. - metrics (Dict[str, Scalar]): Dictionary of metrics values measured during training + loss_dict (dict[str, float]): Dictionary of loss values obtained during training. + metrics (dict[str, Scalar]): Dictionary of metrics values measured during training """ for loss_key in list(loss_dict): loss_dict[f"{prefix}_{loss_key}"] = loss_dict.pop(loss_key) @@ -172,11 +172,11 @@ def _extract_epochs_or_steps_specified(self, config: Config) -> EpochsAndStepsTu else: raise ValueError( "Either configuration keys not properly present or a mix of steps and epochs based training was " - "specified and is not admissable. Keys should be one of {local_head_epochs, local_rep_epochs} or " + "specified and is not admissible. Keys should be one of {local_head_epochs, local_rep_epochs} or " "{local_head_steps, local_rep_steps}" ) - def process_fed_rep_config(self, config: Config) -> Tuple[EpochsAndStepsTuple, int, bool]: + def process_fed_rep_config(self, config: Config) -> tuple[EpochsAndStepsTuple, int, bool]: """ Method to ensure the required keys are present in config and extracts values to be returned. We override this functionality from the BasicClient, because FedRep has slightly different requirements. That is, one needs @@ -186,7 +186,7 @@ def process_fed_rep_config(self, config: Config) -> Tuple[EpochsAndStepsTuple, i config (Config): The config from the server. Returns: - Tuple[Union[int, None], Union[int, None], int, bool]: Returns the local_epochs, local_steps, + tuple[int | None, int | None, int | None, int | None, int, bool]: Returns the local_epochs, local_steps, current_server_round and evaluate_after_fit. Ensures only one of local_epochs and local_steps is defined in the config and sets the one that is not to None. @@ -205,7 +205,7 @@ def process_fed_rep_config(self, config: Config) -> Tuple[EpochsAndStepsTuple, i # Either local epochs or local steps is none based on what key is passed in the config return steps_or_epochs_tuple, current_server_round, evaluate_after_fit - def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: + def get_optimizer(self, config: Config) -> dict[str, Optimizer]: """ Returns a dictionary with global and local optimizers with string keys 'representation' and 'head' respectively. @@ -231,7 +231,7 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: assert isinstance(self.model, SequentiallySplitExchangeBaseModel) return FixedLayerExchanger(self.model.layers_to_exchange()) - def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict[str, Scalar]]: + def fit(self, parameters: NDArrays, config: Config) -> tuple[NDArrays, int, dict[str, Scalar]]: """ Processes config, initializes client (if first round) and performs training based on the passed config. For FedRep, this coordinates calling the right training functions based on the passed steps. We need to @@ -243,7 +243,7 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict config (NDArrays): The config from the server. Returns: - Tuple[NDArrays, int, Dict[str, Scalar]]: The parameters following the local training along with the + tuple[NDArrays, int, dict[str, Scalar]]: The parameters following the local training along with the number of samples in the local training dataset and the computed metrics throughout the fit. Raises: @@ -305,17 +305,17 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict ) def train_fedrep_by_epochs( - self, head_epochs: int, rep_epochs: int, current_round: Optional[int] = None - ) -> Tuple[Dict[str, float], Dict[str, Scalar]]: + self, head_epochs: int, rep_epochs: int, current_round: int | None = None + ) -> tuple[dict[str, float], dict[str, Scalar]]: """ Train locally for the specified number of epochs. Args: epochs (int): The number of epochs for local training. - current_round (Optional[int]): The current FL round. + current_round (int | None): The current FL round. Returns: - Tuple[Dict[str, float], Dict[str, Scalar]]: The loss and metrics dictionary from the local training. + tuple[dict[str, float], dict[str, Scalar]]: The loss and metrics dictionary from the local training. Loss is a dictionary of one or more losses that represent the different components of the loss. """ # First we train the head module for head_epochs with the representations frozen in place @@ -344,8 +344,8 @@ def train_fedrep_by_epochs( return loss_dict_head, metrics_dict_head def train_fedrep_by_steps( - self, head_steps: int, rep_steps: int, current_round: Optional[int] = None - ) -> Tuple[Dict[str, float], Dict[str, Scalar]]: + self, head_steps: int, rep_steps: int, current_round: int | None = None + ) -> tuple[dict[str, float], dict[str, Scalar]]: """ Train locally for the specified number of steps. @@ -353,7 +353,7 @@ def train_fedrep_by_steps( steps (int): The number of steps to train locally. Returns: - Tuple[Dict[str, float], Dict[str, Scalar]]: The loss and metrics dictionary from the local training. + tuple[dict[str, float], dict[str, Scalar]]: The loss and metrics dictionary from the local training. Loss is a dictionary of one or more losses that represent the different components of the loss. """ assert isinstance(self.model, FedRepModel) @@ -382,7 +382,7 @@ def train_fedrep_by_steps( metrics_dict_head.update(metrics_dict_rep) return loss_dict_head, metrics_dict_head - def train_step(self, input: TorchInputType, target: TorchTargetType) -> Tuple[TrainingLosses, TorchPredType]: + def train_step(self, input: TorchInputType, target: TorchTargetType) -> tuple[TrainingLosses, TorchPredType]: """ Mechanics of training loop follow the FedRep paper: https://arxiv.org/pdf/2102.07078.pdf In order to reuse the train_step functionality, we switch between the appropriate optimizers depending on the @@ -390,11 +390,11 @@ def train_step(self, input: TorchInputType, target: TorchTargetType) -> Tuple[Tr Args: input (TorchInputType): input tensor to be run through the model. Here, TorchInputType is simply an alias - for the union of torch.Tensor and Dict[str, torch.Tensor]. + for the union of torch.Tensor and dict[str, torch.Tensor]. target (torch.Tensor): target tensor to be used to compute a loss given the model's outputs. Returns: - Tuple[TrainingLosses, Dict[str, torch.Tensor]]: The losses object from the train step along with + tuple[TrainingLosses, dict[str, torch.Tensor]]: The losses object from the train step along with a dictionary of any predictions produced by the model. """ diff --git a/fl4health/clients/fenda_client.py b/fl4health/clients/fenda_client.py index 8abaeb0a9..bbe9334ac 100644 --- a/fl4health/clients/fenda_client.py +++ b/fl4health/clients/fenda_client.py @@ -1,5 +1,5 @@ +from collections.abc import Sequence from pathlib import Path -from typing import Optional, Sequence import torch from flwr.common.typing import Config @@ -21,10 +21,10 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: """ This client is used to perform client-side training associated with the FENDA method described in @@ -40,7 +40,7 @@ def __init__( 'cuda' loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE. - checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional): A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None. @@ -48,7 +48,7 @@ def __init__( should send data to. Defaults to None. progress_bar (bool, optional): Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False - client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + client_name (str | None, optional): An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None. """ diff --git a/fl4health/clients/fenda_ditto_client.py b/fl4health/clients/fenda_ditto_client.py index 5df746ca9..76dfe3844 100644 --- a/fl4health/clients/fenda_ditto_client.py +++ b/fl4health/clients/fenda_ditto_client.py @@ -1,6 +1,6 @@ +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Optional, Sequence, Tuple import torch from flwr.common.logger import log @@ -25,10 +25,10 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, freeze_global_feature_extractor: bool = False, ) -> None: """ @@ -65,7 +65,7 @@ def __init__( 'cuda' loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE. - checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional): A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None. @@ -73,7 +73,7 @@ def __init__( should send data to. Defaults to None. progress_bar (bool, optional): Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False - client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + client_name (str | None, optional): An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None. freeze_global_feature_extractor (bool, optional): Determines whether we freeze the FENDA global feature @@ -239,7 +239,7 @@ def update_before_train(self, current_server_round: int) -> None: def predict( self, input: TorchInputType, - ) -> Tuple[TorchPredType, TorchFeatureType]: + ) -> tuple[TorchPredType, TorchFeatureType]: """ Computes the predictions for both the GLOBAL and LOCAL models and pack them into the prediction dictionary @@ -247,7 +247,7 @@ def predict( input (TorchInputType): Inputs to be fed into both models. Returns: - Tuple[TorchPredType, TorchFeatureType]: A tuple in which the first element + tuple[TorchPredType, TorchFeatureType]: A tuple in which the first element contains predictions indexed by name and the second element contains intermediate activations index by name. For Ditto+FENDA, we only need the predictions, so the second dictionary is simply empty. @@ -290,9 +290,9 @@ def compute_training_loss( optimize the global model is stored in the additional losses dictionary under "global_loss". Args: - preds (Dict[str, torch.Tensor]): Prediction(s) of the model(s) indexed by name. + preds (dict[str, torch.Tensor]): Prediction(s) of the model(s) indexed by name. All predictions included in the dictionary will be used to compute metrics. - features (Dict[str, torch.Tensor]): Feature(s) of the model(s) indexed by name. + features (dict[str, torch.Tensor]): Feature(s) of the model(s) indexed by name. target (torch.Tensor): Ground truth data to evaluate predictions against. Returns: diff --git a/fl4health/clients/flash_client.py b/fl4health/clients/flash_client.py index ef4a496c4..4c5541574 100644 --- a/fl4health/clients/flash_client.py +++ b/fl4health/clients/flash_client.py @@ -1,6 +1,6 @@ +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple, Union import torch from flwr.common.logger import log @@ -22,10 +22,10 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: """ This client is used to perform client-side training associated with the Flash method described in @@ -40,7 +40,7 @@ def __init__( 'cuda' loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE. - checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional): A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None. @@ -48,7 +48,7 @@ def __init__( should send data to. Defaults to None. progress_bar (bool, optional): Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False - client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + client_name (str | None, optional): An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None. """ @@ -63,9 +63,9 @@ def __init__( client_name=client_name, ) # gamma: Threshold for early stopping based on the change in validation loss. - self.gamma: Optional[float] = None + self.gamma: float | None = None - def process_config(self, config: Config) -> Tuple[Union[int, None], Union[int, None], int, bool, bool]: + def process_config(self, config: Config) -> tuple[int | None, int | None, int, bool, bool]: local_epochs, local_steps, current_server_round, evaluate_after_fit, pack_losses_with_val_metrics = ( super().process_config(config) ) @@ -77,8 +77,8 @@ def process_config(self, config: Config) -> Tuple[Union[int, None], Union[int, N return local_epochs, local_steps, current_server_round, evaluate_after_fit, pack_losses_with_val_metrics def train_by_epochs( - self, epochs: int, current_round: Optional[int] = None - ) -> Tuple[Dict[str, float], Dict[str, Scalar]]: + self, epochs: int, current_round: int | None = None + ) -> tuple[dict[str, float], dict[str, Scalar]]: self.model.train() local_step = 0 previous_loss = float("inf") diff --git a/fl4health/clients/instance_level_dp_client.py b/fl4health/clients/instance_level_dp_client.py index f903e66eb..4ab932b91 100644 --- a/fl4health/clients/instance_level_dp_client.py +++ b/fl4health/clients/instance_level_dp_client.py @@ -1,5 +1,5 @@ +from collections.abc import Sequence from pathlib import Path -from typing import Optional, Sequence import torch from flwr.common.typing import Config @@ -21,10 +21,10 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: """ Client for Instance/Record level Differentially Private Federated Averaging @@ -36,7 +36,7 @@ def __init__( 'cuda' loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE. - checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional): A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None. @@ -44,7 +44,7 @@ def __init__( should send data to. Defaults to None. progress_bar (bool, optional): Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False - client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + client_name (str | None, optional): An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None. """ diff --git a/fl4health/clients/mkmmd_clients/ditto_mkmmd_client.py b/fl4health/clients/mkmmd_clients/ditto_mkmmd_client.py index 9fa724c91..574dc653b 100644 --- a/fl4health/clients/mkmmd_clients/ditto_mkmmd_client.py +++ b/fl4health/clients/mkmmd_clients/ditto_mkmmd_client.py @@ -1,6 +1,6 @@ +from collections.abc import Sequence from logging import ERROR, INFO from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple import torch import torch.nn as nn @@ -25,15 +25,15 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, mkmmd_loss_weight: float = 10.0, - feature_extraction_layers: Optional[Sequence[str]] = None, + feature_extraction_layers: Sequence[str] | None = None, feature_l2_norm_weight: float = 0.0, beta_global_update_interval: int = 20, - num_accumulating_batches: Optional[int] = None, + num_accumulating_batches: int | None = None, ) -> None: """ This client implements the MK-MMD loss function in the Ditto framework. The MK-MMD loss is a measure of the @@ -47,7 +47,7 @@ def __init__( 'cuda' loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE. - checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional): A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None. @@ -55,11 +55,11 @@ def __init__( should send data to. Defaults to None. progress_bar (bool, optional): Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False - client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + client_name (str | None, optional): An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None. mkmmd_loss_weight (float, optional): weight applied to the MK-MMD loss. Defaults to 10.0. - feature_extraction_layers (Optional[Sequence[str]], optional): List of layers from which to extract + feature_extraction_layers (Sequence[str] | None, optional): List of layers from which to extract and flatten features. Defaults to None. feature_l2_norm_weight (float, optional): weight applied to the L2 norm of the features. Defaults to 0.0. @@ -104,7 +104,7 @@ def __init__( self.flatten_feature_extraction_layers = {layer: True for layer in feature_extraction_layers} else: self.flatten_feature_extraction_layers = {} - self.mkmmd_losses: Dict[str, MkMmdLoss] = {} + self.mkmmd_losses: dict[str, MkMmdLoss] = {} for layer in self.flatten_feature_extraction_layers.keys(): self.mkmmd_losses[layer] = MkMmdLoss( device=self.device, minimize_type_two_error=True, normalize_features=True, layer_name=layer @@ -144,7 +144,7 @@ def _should_optimize_betas(self, step: int) -> bool: weighted_mkmmd_loss = self.mkmmd_loss_weight != 0 return step_at_interval and valid_components_present and weighted_mkmmd_loss - def update_after_step(self, step: int, current_round: Optional[int] = None) -> None: + def update_after_step(self, step: int, current_round: int | None = None) -> None: if self.beta_global_update_interval > 0 and self._should_optimize_betas(step): # Get the feature distribution of the local and initial global features with evaluation # mode @@ -160,7 +160,7 @@ def update_after_step(self, step: int, current_round: Optional[int] = None) -> N def update_buffers( self, local_model: torch.nn.Module, initial_global_model: torch.nn.Module - ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: + ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: """ Update the feature buffer of the local and global features. @@ -169,7 +169,7 @@ def update_buffers( initial_global_model (torch.nn.Module): Initial global model to extract features from. Returns: - Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: A tuple containing the extracted + tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: A tuple containing the extracted features using the local and initial global models. """ @@ -223,18 +223,18 @@ def update_buffers( def predict( self, input: TorchInputType, - ) -> Tuple[TorchPredType, TorchFeatureType]: + ) -> tuple[TorchPredType, TorchFeatureType]: """ Computes the predictions for both the GLOBAL and LOCAL models and pack them into the prediction dictionary Args: input (TorchInputType): Inputs to be fed into the model. If input is - of type Dict[str, torch.Tensor], it is assumed that the keys of + of type dict[str, torch.Tensor], it is assumed that the keys of input match the names of the keyword arguments of self.model. forward(). Returns: - Tuple[TorchPredType, TorchFeatureType]: A tuple in which the + tuple[TorchPredType, TorchFeatureType]: A tuple in which the first element contains a dictionary of predictions indexed by name and the second element contains intermediate activations indexed by name. By passing features, we can compute all the @@ -261,7 +261,7 @@ def predict( return {"global": global_preds, "local": local_preds}, features - def _maybe_checkpoint(self, loss: float, metrics: Dict[str, Scalar], checkpoint_mode: CheckpointMode) -> None: + def _maybe_checkpoint(self, loss: float, metrics: dict[str, Scalar], checkpoint_mode: CheckpointMode) -> None: # Hooks need to be removed before checkpointing the model self.local_feature_extractor.remove_hooks() super()._maybe_checkpoint(loss=loss, metrics=metrics, checkpoint_mode=checkpoint_mode) @@ -323,7 +323,7 @@ def compute_training_loss( def compute_loss_and_additional_losses( self, preds: TorchPredType, features: TorchFeatureType, target: TorchTargetType - ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: """ Computes the loss and any additional losses given predictions of the model and ground truth data. @@ -333,7 +333,7 @@ def compute_loss_and_additional_losses( target (TorchTargetType): Ground truth data to evaluate predictions against. Returns: - Tuple[torch.Tensor, Dict[str, torch.Tensor]]: A tuple with: + tuple[torch.Tensor, dict[str, torch.Tensor]]: A tuple with: - The tensor for the loss - A dictionary of additional losses with their names and values, or None if there are no additional losses. diff --git a/fl4health/clients/mkmmd_clients/mr_mtl_mkmmd_client.py b/fl4health/clients/mkmmd_clients/mr_mtl_mkmmd_client.py index 642f6752c..8a1babb0a 100644 --- a/fl4health/clients/mkmmd_clients/mr_mtl_mkmmd_client.py +++ b/fl4health/clients/mkmmd_clients/mr_mtl_mkmmd_client.py @@ -1,6 +1,6 @@ +from collections.abc import Sequence from logging import ERROR, INFO from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple import torch from flwr.common.logger import log @@ -23,15 +23,15 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, mkmmd_loss_weight: float = 10.0, - feature_extraction_layers: Optional[Sequence[str]] = None, + feature_extraction_layers: Sequence[str] | None = None, feature_l2_norm_weight: float = 0.0, beta_global_update_interval: int = 20, - num_accumulating_batches: Optional[int] = None, + num_accumulating_batches: int | None = None, ) -> None: """ This client implements the MK-MMD loss function in the MR-MTL framework. The MK-MMD loss is a measure of the @@ -46,7 +46,7 @@ def __init__( 'cuda' loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE. - checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional): A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None. @@ -54,11 +54,11 @@ def __init__( should send data to. Defaults to None. progress_bar (bool, optional): Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False - client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + client_name (str | None, optional): An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None. mkmmd_loss_weight (float, optional): weight applied to the MK-MMD loss. Defaults to 10.0. - feature_extraction_layers (Optional[Sequence[str]], optional): List of layers from which to extract + feature_extraction_layers (Sequence[str] | None, optional): List of layers from which to extract and flatten features. Defaults to None. feature_l2_norm_weight (float, optional): weight applied to the L2 norm of the features. Defaults to 0.0. @@ -103,7 +103,7 @@ def __init__( self.flatten_feature_extraction_layers = {layer: True for layer in feature_extraction_layers} else: self.flatten_feature_extraction_layers = {} - self.mkmmd_losses: Dict[str, MkMmdLoss] = {} + self.mkmmd_losses: dict[str, MkMmdLoss] = {} for layer in self.flatten_feature_extraction_layers.keys(): self.mkmmd_losses[layer] = MkMmdLoss( device=self.device, minimize_type_two_error=True, normalize_features=True, layer_name=layer @@ -138,7 +138,7 @@ def _should_optimize_betas(self, step: int) -> bool: weighted_mkmmd_loss = self.mkmmd_loss_weight != 0 return step_at_interval and valid_components_present and weighted_mkmmd_loss - def update_after_step(self, step: int, current_round: Optional[int] = None) -> None: + def update_after_step(self, step: int, current_round: int | None = None) -> None: if self.beta_global_update_interval > 0 and self._should_optimize_betas(step): # Get the feature distribution of the local and initial global features with evaluation # mode @@ -154,7 +154,7 @@ def update_after_step(self, step: int, current_round: Optional[int] = None) -> N def update_buffers( self, local_model: torch.nn.Module, initial_global_model: torch.nn.Module - ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: + ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: """ Update the feature buffer of the local and global features. @@ -163,7 +163,7 @@ def update_buffers( initial_global_model (torch.nn.Module): Initial global model to extract features from. Returns: - Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: A tuple containing the extracted + tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: A tuple containing the extracted features using the local and initial global models. """ @@ -217,18 +217,18 @@ def update_buffers( def predict( self, input: TorchInputType, - ) -> Tuple[TorchPredType, TorchFeatureType]: + ) -> tuple[TorchPredType, TorchFeatureType]: """ Computes the predictions for both models and pack them into the prediction dictionary Args: input (TorchInputType): Inputs to be fed into the model. If input is - of type Dict[str, torch.Tensor], it is assumed that the keys of + of type dict[str, torch.Tensor], it is assumed that the keys of input match the names of the keyword arguments of self.model. forward(). Returns: - Tuple[TorchPredType, TorchFeatureType]: A tuple in which the + tuple[TorchPredType, TorchFeatureType]: A tuple in which the first element contains a dictionary of predictions indexed by name and the second element contains intermediate activations indexed by name. By passing features, we can compute all the @@ -252,7 +252,7 @@ def predict( return {"prediction": preds}, features - def _maybe_checkpoint(self, loss: float, metrics: Dict[str, Scalar], checkpoint_mode: CheckpointMode) -> None: + def _maybe_checkpoint(self, loss: float, metrics: dict[str, Scalar], checkpoint_mode: CheckpointMode) -> None: # Hooks need to be removed before checkpointing the model self.local_feature_extractor.remove_hooks() super()._maybe_checkpoint(loss=loss, metrics=metrics, checkpoint_mode=checkpoint_mode) @@ -314,7 +314,7 @@ def compute_training_loss( def compute_loss_and_additional_losses( self, preds: TorchPredType, features: TorchFeatureType, target: TorchTargetType - ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: """ Computes the loss and any additional losses given predictions of the model and ground truth data. @@ -324,7 +324,7 @@ def compute_loss_and_additional_losses( target (TorchTargetType): Ground truth data to evaluate predictions against. Returns: - Tuple[torch.Tensor, Dict[str, torch.Tensor]]: A tuple with: + tuple[torch.Tensor, dict[str, torch.Tensor]]: A tuple with: - The tensor for the loss - A dictionary of additional losses with their names and values, or None if there are no additional losses. diff --git a/fl4health/clients/model_merge_client.py b/fl4health/clients/model_merge_client.py index 6bed66108..7a31539f1 100644 --- a/fl4health/clients/model_merge_client.py +++ b/fl4health/clients/model_merge_client.py @@ -1,7 +1,7 @@ import datetime from abc import abstractmethod +from collections.abc import Sequence from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple, Union import torch import torch.nn as nn @@ -26,8 +26,8 @@ def __init__( model_path: Path, metrics: Sequence[Metric], device: torch.device, - reporters: Optional[Sequence[BaseReporter]] = None, - client_name: Optional[str] = None, + reporters: Sequence[BaseReporter] | None = None, + client_name: str | None = None, ) -> None: """ ModelMergeClient to support functionality to simply perform model merging across client @@ -112,7 +112,7 @@ def set_parameters(self, parameters: NDArrays, config: Config) -> None: assert self.initialized self.parameter_exchanger.pull_parameters(parameters, self.model) - def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict[str, Scalar]]: + def fit(self, parameters: NDArrays, config: Config) -> tuple[NDArrays, int, dict[str, Scalar]]: """ Initializes client, validates local client model on local test data and returns parameters, test dataset length and test metrics. Importantly, parameters from Server, which is empty, @@ -126,7 +126,7 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict config (NDArrays): The config from the server. Returns: - Tuple[NDArrays, int, Dict[str, Scalar]]: The local model parameters along with the + tuple[NDArrays, int, dict[str, Scalar]]: The local model parameters along with the number of samples in the local test dataset and the computed metrics of the local model on the local test dataset. @@ -149,9 +149,7 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict return self.get_parameters(config), self.num_test_samples, val_metrics - def _move_data_to_device( - self, data: Union[TorchInputType, TorchTargetType] - ) -> Union[TorchTargetType, TorchInputType]: + def _move_data_to_device(self, data: TorchInputType | TorchTargetType) -> TorchTargetType | TorchInputType: """ Moving data to self.device where data is intended to be either input to the model or the targets that the model is trying to achieve @@ -165,7 +163,7 @@ def _move_data_to_device( TorchInputType or TorchTargetType Returns: - Union[TorchTargetType, TorchInputType]: The data argument except now it's been moved to self.device + TorchTargetType | TorchInputType: The data argument except now it's been moved to self.device """ # Currently we expect both inputs and targets to be either tensors # or dictionaries of tensors @@ -175,18 +173,18 @@ def _move_data_to_device( return {key: value.to(self.device) for key, value in data.items()} else: raise TypeError( - "data must be of type torch.Tensor or Dict[str, torch.Tensor]. \ + "data must be of type torch.Tensor or dict[str, torch.Tensor]. \ If definition of TorchInputType or TorchTargetType has \ changed this method might need to be updated or split into \ two" ) - def validate(self) -> Dict[str, Scalar]: + def validate(self) -> dict[str, Scalar]: """ Validate the model on the test dataset. Returns: - Tuple[float, Dict[str, Scalar]]: The loss and a dictionary of metrics + tuple[float, dict[str, Scalar]]: The loss and a dictionary of metrics from test set. """ self.model.eval() @@ -200,7 +198,7 @@ def validate(self) -> Dict[str, Scalar]: return self.test_metric_manager.compute() - def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, Dict[str, Scalar]]: + def evaluate(self, parameters: NDArrays, config: Config) -> tuple[float, int, dict[str, Scalar]]: """ Evaluate the provided parameters using the locally held dataset. @@ -209,7 +207,7 @@ def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, Di config (Config): Configuration object from the server. Returns: - Tuple[float, int, Dict[str, Scalar]: The float represents the + tuple[float, int, dict[str, Scalar]: The float represents the loss which is assumed to be 0 for the ModelMergeClient. The int represents the number of examples in the local test dataset and the dictionary is the computed metrics on the test set. diff --git a/fl4health/clients/moon_client.py b/fl4health/clients/moon_client.py index 7acb00b02..aeb58a7e0 100644 --- a/fl4health/clients/moon_client.py +++ b/fl4health/clients/moon_client.py @@ -1,6 +1,6 @@ +from collections.abc import Sequence from logging import WARNING from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple import torch from flwr.common.logger import log @@ -23,10 +23,10 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, temperature: float = 0.5, contrastive_weight: float = 1.0, len_old_models_buffer: int = 1, @@ -43,7 +43,7 @@ def __init__( 'cuda' loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE. - checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional): A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None. @@ -51,7 +51,7 @@ def __init__( should send data to. Defaults to None. progress_bar (bool, optional): Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False - client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + client_name (str | None, optional): An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None. temperature (float, optional): Temperature used in the calculation of the contrastive loss. @@ -80,9 +80,9 @@ def __init__( # Saving previous local models and a global model at each communication round to compute contrastive loss self.len_old_models_buffer = len_old_models_buffer self.old_models_list: list[torch.nn.Module] = [] - self.global_model: Optional[torch.nn.Module] = None + self.global_model: torch.nn.Module | None = None - def predict(self, input: TorchInputType) -> Tuple[TorchPredType, TorchFeatureType]: + def predict(self, input: TorchInputType) -> tuple[TorchPredType, TorchFeatureType]: """ Computes the prediction(s) and features of the model(s) given the input. This function also produces the necessary features from the global_model (aggregated model from server) and old_models (previous client-side @@ -90,11 +90,11 @@ def predict(self, input: TorchInputType) -> Tuple[TorchPredType, TorchFeatureTyp Args: input (TorchInputType): Inputs to be fed into the model. TorchInputType is simply an alias - for the union of torch.Tensor and Dict[str, torch.Tensor]. Here, the MOON models require input to + for the union of torch.Tensor and dict[str, torch.Tensor]. Here, the MOON models require input to simply be of type torch.Tensor Returns: - Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: A tuple in which the first element + tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: A tuple in which the first element contains predictions indexed by name and the second element contains intermediate activations index by name. Specifically the features of the model, features of the global model and features of the old model are returned. All predictions included in dictionary will be used to compute metrics. @@ -117,14 +117,14 @@ def predict(self, input: TorchInputType) -> Tuple[TorchPredType, TorchFeatureTyp features.update({"global_features": global_model_features["features"]}) return preds, features - def update_after_train(self, local_steps: int, loss_dict: Dict[str, float], config: Config) -> None: + def update_after_train(self, local_steps: int, loss_dict: dict[str, float], config: Config) -> None: """ This function is called immediately after client-side training has completed. This function saves the final trained model to the list of old models to be used in subsequent server rounds Args: local_steps (int): Number of local steps performed during training - loss_dict (Dict[str, float]): Loss dictionary associated with training. + loss_dict (dict[str, float]): Loss dictionary associated with training. config (Config): The config from the server """ assert isinstance(self.model, SequentiallySplitModel) @@ -157,19 +157,19 @@ def compute_loss_and_additional_losses( preds: TorchPredType, features: TorchFeatureType, target: TorchTargetType, - ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: """ Computes the loss and any additional losses given predictions of the model and ground truth data. For MOON, the loss is the total loss (criterion and weighted contrastive loss) and the additional losses are the loss, (unweighted) contrastive loss, and total loss. Args: - preds (Dict[str, torch.Tensor]): Prediction(s) of the model(s) indexed by name. - features (Dict[str, torch.Tensor]): Feature(s) of the model(s) indexed by name. + preds (dict[str, torch.Tensor]): Prediction(s) of the model(s) indexed by name. + features (dict[str, torch.Tensor]): Feature(s) of the model(s) indexed by name. target (torch.Tensor): Ground truth data to evaluate predictions against. Returns: - Tuple[torch.Tensor, Union[Dict[str, torch.Tensor], None]]; A tuple with: + tuple[torch.Tensor, dict[str, torch.Tensor]]; A tuple with: - The tensor for the total loss - A dictionary with `loss`, `contrastive_loss` and `total_loss` keys and their calculated values. """ @@ -202,9 +202,9 @@ def compute_training_loss( base loss plus a model contrastive loss. Args: - preds (Dict[str, torch.Tensor]): Prediction(s) of the model(s) indexed by name. + preds (dict[str, torch.Tensor]): Prediction(s) of the model(s) indexed by name. All predictions included in dictionary will be used to compute metrics. - features: (Dict[str, torch.Tensor]): Feature(s) of the model(s) indexed by name. + features: (dict[str, torch.Tensor]): Feature(s) of the model(s) indexed by name. target: (torch.Tensor): Ground truth data to evaluate predictions against. Returns: @@ -234,9 +234,9 @@ def compute_evaluation_loss( base loss plus a model contrastive loss. Args: - preds (Dict[str, torch.Tensor]): Prediction(s) of the model(s) indexed by name. + preds (dict[str, torch.Tensor]): Prediction(s) of the model(s) indexed by name. All predictions included in dictionary will be used to compute metrics. - features: (Dict[str, torch.Tensor]): Feature(s) of the model(s) indexed by name. + features: (dict[str, torch.Tensor]): Feature(s) of the model(s) indexed by name. target: (torch.Tensor): Ground truth data to evaluate predictions against. Returns: diff --git a/fl4health/clients/mr_mtl_client.py b/fl4health/clients/mr_mtl_client.py index cd671771c..0fdbfa372 100644 --- a/fl4health/clients/mr_mtl_client.py +++ b/fl4health/clients/mr_mtl_client.py @@ -1,6 +1,6 @@ +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Dict, List, Optional, Sequence, Tuple import torch import torch.nn as nn @@ -22,10 +22,10 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: """ This client implements the MR-MTL algorithm from MR-MTL: On Privacy and Personalization in Cross-Silo @@ -46,7 +46,7 @@ def __init__( 'cuda' loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE. - checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional): A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None. @@ -54,7 +54,7 @@ def __init__( should send data to. Defaults to None. progress_bar (bool, optional): Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False - client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + client_name (str | None, optional): An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None. """ @@ -71,7 +71,7 @@ def __init__( # NOTE: The initial global model is used to house the aggregate weight updates at the beginning of a round, # because in MR-MTL, the local models are not updated with these aggregates. self.initial_global_model: nn.Module - self.initial_global_tensors: List[torch.Tensor] + self.initial_global_tensors: list[torch.Tensor] def setup_client(self, config: Config) -> None: """ @@ -155,12 +155,12 @@ def compute_training_loss( # Use the rest of the training loss computation from the AdaptiveDriftConstraintClient parent return super().compute_training_loss(preds, features, target) - def validate(self, include_losses_in_metrics: bool = False) -> Tuple[float, Dict[str, Scalar]]: + def validate(self, include_losses_in_metrics: bool = False) -> tuple[float, dict[str, Scalar]]: """ Validate the current model on the entire validation dataset. Returns: - Tuple[float, Dict[str, Scalar]]: The validation loss and a dictionary of metrics from validation. + tuple[float, dict[str, Scalar]]: The validation loss and a dictionary of metrics from validation. """ # ensure that the initial global model is in eval mode assert not self.initial_global_model.training diff --git a/fl4health/clients/nnunet_client.py b/fl4health/clients/nnunet_client.py index 89d44da6f..3bbc1bd6f 100644 --- a/fl4health/clients/nnunet_client.py +++ b/fl4health/clients/nnunet_client.py @@ -4,11 +4,12 @@ import pickle import time import warnings +from collections.abc import Sequence from contextlib import redirect_stdout from logging import DEBUG, ERROR, INFO, WARNING from os.path import exists, join from pathlib import Path -from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union +from typing import Any import numpy as np import torch @@ -64,22 +65,22 @@ def __init__( self, device: torch.device, dataset_id: int, - fold: Union[int, str], - data_identifier: Optional[str] = None, - plans_identifier: Optional[str] = None, + fold: int | str, + data_identifier: str | None = None, + plans_identifier: str | None = None, compile: bool = True, always_preprocess: bool = False, max_grad_norm: float = 12, - n_dataload_processes: Optional[int] = None, + n_dataload_processes: int | None = None, verbose: bool = True, - metrics: Optional[Sequence[Metric]] = None, + metrics: Sequence[Metric] | None = None, progress_bar: bool = False, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, - client_name: Optional[str] = None, - nnunet_trainer_class: Type[nnUNetTrainer] = nnUNetTrainer, - nnunet_trainer_class_kwargs: Optional[dict[str, Any]] = {}, + client_name: str | None = None, + nnunet_trainer_class: type[nnUNetTrainer] = nnUNetTrainer, + nnunet_trainer_class_kwargs: dict[str, Any] | None = {}, ) -> None: """ A client for training nnunet models. Requires the nnunet environment variables to be set. Also requires the @@ -91,14 +92,14 @@ def __init__( device (torch.device): Device indicator for where to send the model, batches, labels etc. Often 'cpu' or 'cuda' or 'mps' dataset_id (int): The nnunet dataset id for the local client dataset to use for training and validation. - fold (Union[int, str]): Which fold of the local client dataset to use for validation. nnunet defaults to + fold (int | str): Which fold of the local client dataset to use for validation. nnunet defaults to 5 folds (0 to 4). Can also be set to 'all' to use all the data for both training and validation. - data_identifier (Optional[str], optional): The nnunet data identifier prefix to use. The final data + data_identifier (str | None, optional): The nnunet data identifier prefix to use. The final data identifier will be {data_identifier}_config where 'config' is the nnunet config (eg. 2d, 3d_fullres, etc.). If preprocessed data already exists can be used to specify which preprocessed data to use. The default data_identifier prefix is the plans name used during training (see the plans_identifier argument). - plans_identifier (Optional[str], optional): Specify what to save the client's local copy of the plans file + plans_identifier (str | None, optional): Specify what to save the client's local copy of the plans file as. The client modifies the source plans json file sent from the server and makes a local copy. If left as default None, the plans identifier will be set as 'FL_Dataset000_plansname' where 000 is the dataset_id and plansname is the 'plans_name' value of the source plans file. @@ -111,7 +112,7 @@ def __init__( the client. max_grad_norm (float, optional): The maximum gradient norm to use for gradient clipping. Defaults to 12 which is the nnunetv2 2.5.1 default. - n_dataload_processes (Optional[int], optional): The number of processes to spawn for each nnunet + n_dataload_processes (int | None, optional): The number of processes to spawn for each nnunet dataloader. If left as None we use the nnunetv2 version 2.5.1 defaults for each config verbose (bool, optional): If True the client will log some extra INFO logs. Defaults to False unless the log level is DEBUG or lower. @@ -121,13 +122,13 @@ def __init__( to False loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE. - checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional): A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None. reporters (Sequence[BaseReporter], optional): A sequence of FL4Health reporters which the client should send data to. - nnunet_trainer_class (Type[nnUNetTrainer]): A nnUNetTrainer constructor. Useful for passing custom + nnunet_trainer_class (type[nnUNetTrainer]): A nnUNetTrainer constructor. Useful for passing custom nnUNetTrainer. Defaults to the standard nnUNetTrainer class. Must match the nnunet_trainer_class passed to the NnunetServer. nnunet_trainer_class_kwargs (dict[str, Any]): Additional kwargs to pass to nnunet_trainer_class. Defaults @@ -189,7 +190,7 @@ def __init__( log(INFO, "Switching pytorch model jit compile to OFF") os.environ["nnUNet_compile"] = str("false") - def train_step(self, input: TorchInputType, target: TorchTargetType) -> Tuple[TrainingLosses, TorchPredType]: + def train_step(self, input: TorchInputType, target: TorchTargetType) -> tuple[TrainingLosses, TorchPredType]: """ Given a single batch of input and target data, generate predictions, compute loss, update parameters and optionally update metrics if they exist. (ie backprop on a single batch of data). @@ -242,7 +243,7 @@ def train_step(self, input: TorchInputType, target: TorchTargetType) -> Tuple[Tr return losses, preds @use_default_signal_handlers # Dataloaders use multiprocessing - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: """ Gets the nnunet dataloaders and wraps them in another class to make them pytorch iterators @@ -251,7 +252,7 @@ def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: config (Config): The config file from the server Returns: - Tuple[DataLoader, DataLoader]: A tuple of length two. The client + tuple[DataLoader, DataLoader]: A tuple of length two. The client train and validation dataloaders as pytorch dataloaders """ start_time = time.time() @@ -350,7 +351,7 @@ def get_lr_scheduler(self, optimizer_key: str, config: Config) -> _LRScheduler: max_steps=total_steps, ) - def create_plans(self, config: Config) -> Dict[str, Any]: + def create_plans(self, config: Config) -> dict[str, Any]: """ Modifies the provided plans file to work with the local client dataset @@ -359,7 +360,7 @@ def create_plans(self, config: Config) -> Dict[str, Any]: 'nnunet_plans' key with a pickled dictionary as the value Returns: - Dict[str, Any]: The modified nnunet plans for the client + dict[str, Any]: The modified nnunet plans for the client """ # Get the nnunet plans specified by the server plans = pickle.loads(narrow_dict_type(config, "nnunet_plans", bytes)) @@ -549,7 +550,7 @@ def setup_client(self, config: Config) -> None: # We have to call parent method after setting up nnunet trainer super().setup_client(config) - def predict(self, input: TorchInputType) -> Tuple[TorchPredType, Dict[str, torch.Tensor]]: + def predict(self, input: TorchInputType) -> tuple[TorchPredType, dict[str, torch.Tensor]]: """ Generate model outputs. Overridden because nnunets output lists when deep supervision is on so we have to reformat the output into dicts @@ -559,7 +560,7 @@ def predict(self, input: TorchInputType) -> Tuple[TorchPredType, Dict[str, torch input (TorchInputType): The model inputs Returns: - Tuple[TorchPredType, Dict[str, torch.Tensor]]: A tuple in which the + tuple[TorchPredType, dict[str, torch.Tensor]]: A tuple in which the first element model outputs indexed by name. The second element is unused by this subclass and therefore is always an empty dict """ @@ -589,9 +590,9 @@ def predict(self, input: TorchInputType) -> Tuple[TorchPredType, Dict[str, torch def compute_loss_and_additional_losses( self, preds: TorchPredType, - features: Dict[str, torch.Tensor], + features: dict[str, torch.Tensor], target: TorchTargetType, - ) -> Tuple[torch.Tensor, Optional[Dict[str, torch.Tensor]]]: + ) -> tuple[torch.Tensor, dict[str, torch.Tensor] | None]: """ Checks the pred and target types and computes the loss. If device type is cuda, loss computed in mixed precision. @@ -599,13 +600,13 @@ def compute_loss_and_additional_losses( Args: preds (TorchPredType): Dictionary of model output tensors indexed by name - features (Dict[str, torch.Tensor]): Not used by this subclass + features (dict[str, torch.Tensor]): Not used by this subclass target (TorchTargetType): The targets to evaluate the predictions with. If multiple prediction tensors are given, target must be a dictionary with the same number of tensors Returns: - Tuple[torch.Tensor, Optional[Dict[str, torch.Tensor]]]: A tuple + tuple[torch.Tensor, dict[str, torch.Tensor] | None]: A tuple where the first element is the loss and the second element is an optional additional loss """ @@ -635,7 +636,7 @@ def compute_loss_and_additional_losses( return loss - def mask_data(self, pred: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def mask_data(self, pred: torch.Tensor, target: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """ Masks the pred and target tensors according to nnunet ignore_label. The number of classes in the input tensors should be at least 3 @@ -707,7 +708,7 @@ def update_metric_manager( else: m_target = list(target.values())[0] else: - raise TypeError("Was expecting target to be type Dict[str, torch.Tensor] or torch.Tensor") + raise TypeError("Was expecting target to be type dict[str, torch.Tensor] or torch.Tensor") # Check if target is one hot encoded. Prediction always is for nnunet # Add channel dimension if there isn't one @@ -745,10 +746,10 @@ def empty_cache(self) -> None: def get_client_specific_logs( self, - current_round: Optional[int], - current_epoch: Optional[int], + current_round: int | None, + current_epoch: int | None, logging_mode: LoggingMode, - ) -> Tuple[str, List[Tuple[LogLevel, str]]]: + ) -> tuple[str, list[tuple[LogLevel, str]]]: if logging_mode == LoggingMode.TRAIN: lr = float(self.optimizers["global"].param_groups[0]["lr"]) if current_epoch is None: @@ -759,11 +760,11 @@ def get_client_specific_logs( else: return "", [] - def get_client_specific_reports(self) -> Dict[str, Any]: + def get_client_specific_reports(self) -> dict[str, Any]: return {"learning_rate": float(self.optimizers["global"].param_groups[0]["lr"])} @use_default_signal_handlers # Experiment planner spawns a process I think - def get_properties(self, config: Config) -> Dict[str, Scalar]: + def get_properties(self, config: Config) -> dict[str, Scalar]: """ Return properties (sample counts and nnunet plans) of client. @@ -775,7 +776,7 @@ def get_properties(self, config: Config) -> Dict[str, Scalar]: config (Config): The config from the server Returns: - Dict[str, Scalar]: A dictionary containing the train and + dict[str, Scalar]: A dictionary containing the train and validation sample counts as well as the serialized nnunet plans """ # Check if nnunet plans have already been initialized @@ -818,14 +819,14 @@ def get_properties(self, config: Config) -> Dict[str, Scalar]: properties["enable_deep_supervision"] = self.nnunet_trainer.enable_deep_supervision return properties - def shutdown_dataloader(self, dataloader: Optional[DataLoader], dl_name: Optional[str] = None) -> None: + def shutdown_dataloader(self, dataloader: DataLoader | None, dl_name: str | None = None) -> None: """ The nnunet dataloader/augmenter uses multiprocessing under the hood, so the shutdown method terminates the child processes gracefully Args: dataloader (DataLoader): The dataloader to shutdown - dl_name (Optional[str]): A string that identifies the dataloader + dl_name (str | None): A string that identifies the dataloader to shutdown. Used for logging purposes. Defaults to None """ if dataloader is not None and isinstance(dataloader, nnUNetDataLoaderWrapper): @@ -865,6 +866,6 @@ def update_before_train(self, current_server_round: int) -> None: def transform_gradients(self, losses: TrainingLosses) -> None: """ Apply the gradient clipping performed by the default nnunet trainer. This is - the default behaviour for nnunet 2.5.1 + the default behavior for nnunet 2.5.1 """ nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) diff --git a/fl4health/clients/partial_weight_exchange_client.py b/fl4health/clients/partial_weight_exchange_client.py index 1cdd45de7..830f59ea8 100644 --- a/fl4health/clients/partial_weight_exchange_client.py +++ b/fl4health/clients/partial_weight_exchange_client.py @@ -1,7 +1,7 @@ import copy +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Optional, Sequence import torch import torch.nn as nn @@ -25,10 +25,10 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, store_initial_model: bool = False, ) -> None: """ @@ -44,7 +44,7 @@ def __init__( 'cuda' loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE. - checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional): A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None. @@ -52,7 +52,7 @@ def __init__( should send data to. Defaults to None. progress_bar (bool, optional): Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False - client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + client_name (str | None, optional): An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None. store_initial_model (bool): Indicates whether the client should store a copy of the model weights @@ -71,7 +71,7 @@ def __init__( client_name=client_name, ) # Initial model parameters to be used in selecting parameters to be exchanged during training. - self.initial_model: Optional[nn.Module] + self.initial_model: nn.Module | None # Parameter exchanger to be used in server-client exchange of dynamic layers. self.parameter_exchanger: PartialParameterExchanger self.store_initial_model = store_initial_model @@ -138,7 +138,7 @@ def get_parameters(self, config: Config) -> NDArrays: def set_parameters(self, parameters: NDArrays, config: Config, fitting_round: bool) -> None: """ - Sets the local model parameters transfered from the server using a parameter exchanger to coordinate how + Sets the local model parameters transferred from the server using a parameter exchanger to coordinate how parameters are set. In the first fitting round, we assume the full model is being diff --git a/fl4health/clients/perfcl_client.py b/fl4health/clients/perfcl_client.py index 695e67e38..ec1a17e78 100644 --- a/fl4health/clients/perfcl_client.py +++ b/fl4health/clients/perfcl_client.py @@ -1,5 +1,5 @@ +from collections.abc import Sequence from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple import torch from flwr.common.typing import Config @@ -24,10 +24,10 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, global_feature_loss_temperature: float = 0.5, local_feature_loss_temperature: float = 0.5, global_feature_contrastive_loss_weight: float = 1.0, @@ -47,7 +47,7 @@ def __init__( 'cuda' loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE. - checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional): A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None. @@ -55,7 +55,7 @@ def __init__( should send data to. Defaults to None. progress_bar (bool, optional): Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False - client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + client_name (str | None, optional): An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None. global_feature_loss_temperature (float, optional): Temperature to be used in the contrastive loss @@ -86,9 +86,9 @@ def __init__( # In order to compute the PerFCL losses, we need to save final local module and global modules from the # previous iteration of client-side training and initial global module passed to the client after server-side # aggregation at each communication round - self.old_local_module: Optional[torch.nn.Module] = None - self.old_global_module: Optional[torch.nn.Module] = None - self.initial_global_module: Optional[torch.nn.Module] = None + self.old_local_module: torch.nn.Module | None = None + self.old_global_module: torch.nn.Module | None = None + self.initial_global_module: torch.nn.Module | None = None def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: """ @@ -134,16 +134,16 @@ def _all_contrastive_loss_modules_defined(self) -> bool: and self.initial_global_module is not None ) - def predict(self, input: TorchInputType) -> Tuple[TorchPredType, TorchFeatureType]: + def predict(self, input: TorchInputType) -> tuple[TorchPredType, TorchFeatureType]: """ Computes the prediction(s) and features of the model(s) given the input. Args: input (TorchInputType): Inputs to be fed into the model. TorchInputType is simply an alias - for the union of torch.Tensor and Dict[str, torch.Tensor]. + for the union of torch.Tensor and dict[str, torch.Tensor]. Returns: - Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: A tuple in which the first element + tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: A tuple in which the first element contains predictions indexed by name and the second element contains intermediate activations index by name. Specifically the features of the model, features of the global model and features of the old model are returned. All predictions included in dictionary will be used to compute metrics. @@ -165,14 +165,14 @@ def predict(self, input: TorchInputType) -> Tuple[TorchPredType, TorchFeatureTyp return preds, features - def update_after_train(self, local_steps: int, loss_dict: Dict[str, float], config: Config) -> None: + def update_after_train(self, local_steps: int, loss_dict: dict[str, float], config: Config) -> None: """ This function is called after client-side training concludes. In this case, it is used to save the local and global feature extraction weights/modules to be used in the next round of client-side training. Args: local_steps (int): Number of steps performed during training - loss_dict (Dict[str, float]): Losses computed during training. + loss_dict (dict[str, float]): Losses computed during training. config (Config): The config from the server """ assert isinstance(self.model, PerFclModel) @@ -204,19 +204,19 @@ def compute_loss_and_additional_losses( preds: TorchPredType, features: TorchFeatureType, target: TorchTargetType, - ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: """ Computes the loss and any additional losses given predictions of the model and ground truth data. For PerFCL, the total loss is the standard criterion loss provided by the user and the PerFCL contrastive losses aimed at manipulating the local and global feature extractor latent spaces. Args: - preds (Dict[str, torch.Tensor]): Prediction(s) of the model(s) indexed by name. - features (Dict[str, torch.Tensor]): Feature(s) of the model(s) indexed by name. + preds (dict[str, torch.Tensor]): Prediction(s) of the model(s) indexed by name. + features (dict[str, torch.Tensor]): Feature(s) of the model(s) indexed by name. target (torch.Tensor): Ground truth data to evaluate predictions against. Returns: - Tuple[torch.Tensor, Union[Dict[str, torch.Tensor], None]]; A tuple with: + tuple[torch.Tensor, dict[str, torch.Tensor]]; A tuple with: - The tensor for the total loss - A dictionary with `loss`, `total_loss`, `global_feature_contrastive_loss`, and `local_feature_contrastive_loss` representing the various and relevant pieces of the loss @@ -263,9 +263,9 @@ def compute_evaluation_loss( additional loss components associated with the PerFCL loss function. Args: - preds (Dict[str, torch.Tensor]): Prediction(s) of the model(s) indexed by name. + preds (dict[str, torch.Tensor]): Prediction(s) of the model(s) indexed by name. All predictions included in dictionary will be used to compute metrics. - features: (Dict[str, torch.Tensor]): Feature(s) of the model(s) indexed by name. + features: (dict[str, torch.Tensor]): Feature(s) of the model(s) indexed by name. target: (torch.Tensor): Ground truth data to evaluate predictions against. Returns: diff --git a/fl4health/clients/scaffold_client.py b/fl4health/clients/scaffold_client.py index ee373563f..79dc962dd 100644 --- a/fl4health/clients/scaffold_client.py +++ b/fl4health/clients/scaffold_client.py @@ -1,7 +1,7 @@ import copy +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple import torch from flwr.common.logger import log @@ -19,7 +19,7 @@ from fl4health.utils.losses import LossMeterType, TrainingLosses from fl4health.utils.metrics import Metric -ScaffoldTrainStepOutput = Tuple[torch.Tensor, torch.Tensor] +ScaffoldTrainStepOutput = tuple[torch.Tensor, torch.Tensor] class ScaffoldClient(BasicClient): @@ -29,10 +29,10 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: """ Federated Learning Client for Scaffold strategy. @@ -46,7 +46,7 @@ def __init__( 'cuda' loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE. - checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional): A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None. @@ -54,7 +54,7 @@ def __init__( should send data to. Defaults to None. progress_bar (bool, optional): Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False - client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + client_name (str | None, optional): An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None. """ @@ -69,13 +69,13 @@ def __init__( client_name=client_name, ) self.learning_rate: float # eta_l in paper - self.client_control_variates: Optional[NDArrays] = None # c_i in paper - self.client_control_variates_updates: Optional[NDArrays] = None # delta_c_i in paper - self.server_control_variates: Optional[NDArrays] = None # c in paper + self.client_control_variates: NDArrays | None = None # c_i in paper + self.client_control_variates_updates: NDArrays | None = None # delta_c_i in paper + self.server_control_variates: NDArrays | None = None # c in paper # Scaffold require vanilla SGD as optimizer, will assert during setup_client - self.optimizers: Dict[str, torch.optim.Optimizer] + self.optimizers: dict[str, torch.optim.Optimizer] - self.server_model_weights: Optional[NDArrays] = None # x in paper + self.server_model_weights: NDArrays | None = None # x in paper self.parameter_exchanger: FullParameterExchangerWithPacking[NDArrays] def get_parameters(self, config: Config) -> NDArrays: @@ -242,7 +242,7 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: parameter_exchanger = FullParameterExchangerWithPacking(ParameterPackerWithControlVariates(model_size)) return parameter_exchanger - def update_after_train(self, local_steps: int, loss_dict: Dict[str, float], config: Config) -> None: + def update_after_train(self, local_steps: int, loss_dict: dict[str, float], config: Config) -> None: """ Called after training with the number of local_steps performed over the FL round and the corresponding loss dictionary. @@ -279,10 +279,10 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: ScaffoldClient.__init__( self, diff --git a/fl4health/clients/tabular_data_client.py b/fl4health/clients/tabular_data_client.py index c74617414..747ff0c52 100644 --- a/fl4health/clients/tabular_data_client.py +++ b/fl4health/clients/tabular_data_client.py @@ -1,11 +1,11 @@ +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Dict, List, Sequence, Union import pandas as pd import torch from flwr.common.logger import log -from flwr.common.typing import Config, NDArray, Optional, Scalar +from flwr.common.typing import Config, NDArray, Scalar from sklearn.pipeline import Pipeline from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule @@ -26,12 +26,12 @@ def __init__( metrics: Sequence[Metric], device: torch.device, id_column: str, - targets: Union[str, List[str]], + targets: str | list[str], loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: """ Client to facilitate federated feature space alignment, specifically for tabular data, and then perform @@ -44,11 +44,11 @@ def __init__( 'cuda' id_column (str): ID column. This is required for tabular encoding in cyclops, which we leverage. It should be unique per row, but need not necessarily be a meaningful identifier (i.e. could be row number) - targets (Union[str, List[str]]): The target column or columns name. This allows for multiple targets to + targets (str | list[str]): The target column or columns name. This allows for multiple targets to be specified if desired. loss_meter_type (LossMeterType, optional): Type of meter used to track and compute the losses over each batch. Defaults to LossMeterType.AVERAGE. - checkpoint_and_state_module (Optional[ClientCheckpointAndStateModule], optional): A module meant to handle + checkpoint_and_state_module (ClientCheckpointAndStateModule | None, optional): A module meant to handle both checkpointing and state saving. The module, and its underlying model and state checkpointing components will determine when and how to do checkpointing during client-side training. No checkpointing (state or model) is done if not provided. Defaults to None. @@ -56,7 +56,7 @@ def __init__( should send data to. Defaults to None. progress_bar (bool, optional): Whether or not to display a progress bar during client training and validation. Uses tqdm. Defaults to False - client_name (Optional[str], optional): An optional client name that uniquely identifies a client. + client_name (str | None, optional): An optional client name that uniquely identifies a client. If not passed, a hash is randomly generated. Client state will use this as part of its state file name. Defaults to None. """ @@ -80,7 +80,7 @@ def __init__( # The aligned data and targets, which are used to construct dataloaders. self.aligned_features: NDArray self.aligned_targets: NDArray - self.feature_specific_pipelines: Dict[str, Pipeline] = {} + self.feature_specific_pipelines: dict[str, Pipeline] = {} def setup_client(self, config: Config) -> None: """ @@ -143,7 +143,7 @@ def get_data_frame(self, config: Config) -> pd.DataFrame: """ raise NotImplementedError - def get_properties(self, config: Config) -> Dict[str, Scalar]: + def get_properties(self, config: Config) -> dict[str, Scalar]: """ Return properties of client to be sent to the server. Depending on whether the server has communicated the information diff --git a/fl4health/datasets/skin_cancer/load_data.py b/fl4health/datasets/skin_cancer/load_data.py index e4e358521..86027256d 100644 --- a/fl4health/datasets/skin_cancer/load_data.py +++ b/fl4health/datasets/skin_cancer/load_data.py @@ -6,10 +6,11 @@ import json import random +from collections.abc import Callable from concurrent.futures import ThreadPoolExecutor from logging import INFO from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any import torch import torchvision.transforms as transforms @@ -22,16 +23,16 @@ from fl4health.utils.sampler import LabelBasedSampler -def load_image(item: Dict[str, Any], transform: Optional[Callable]) -> Tuple[torch.Tensor, int]: +def load_image(item: dict[str, Any], transform: Callable | None) -> tuple[torch.Tensor, int]: """ Load and transform an image from a given item dictionary. Args: - item (Dict[str, Any]): A dictionary containing image path and labels. - transform (Optional[Callable]): Transformation function to apply to the images. + item (dict[str, Any]): A dictionary containing image path and labels. + transform (Callable | None): Transformation function to apply to the images. Returns: - Tuple[torch.Tensor, int]: A tuple containing the transformed image tensor and the target label. + tuple[torch.Tensor, int]: A tuple containing the transformed image tensor and the target label. """ image_path = item["img_path"] image = Image.open(image_path).convert("RGB") @@ -46,14 +47,14 @@ def load_image(item: Dict[str, Any], transform: Optional[Callable]) -> Tuple[tor def construct_skin_cancer_tensor_dataset( - data: List[Dict[str, Any]], transform: Optional[Callable] = None, num_workers: int = 8 + data: list[dict[str, Any]], transform: Callable | None = None, num_workers: int = 8 ) -> TensorDataset: """ Construct a TensorDataset for skin cancer data. Args: - data (List[Dict[str, Any]]): List of dictionaries containing image paths and labels. - transform (Optional[Callable]): Transformation function to apply to the images. Defaults to None. + data (list[dict[str, Any]]): List of dictionaries containing image paths and labels. + transform (Callable | None): Transformation function to apply to the images. Defaults to None. num_workers (int): Number of workers for parallel processing. Defaults to 8. Returns: @@ -73,14 +74,14 @@ def load_skin_cancer_data( data_dir: Path, dataset_name: str, batch_size: int, - split_percents: Tuple[float, float, float] = (0.7, 0.15, 0.15), - sampler: Optional[LabelBasedSampler] = None, - train_transform: Union[None, Callable] = None, - val_transform: Union[None, Callable] = None, - test_transform: Union[None, Callable] = None, - dataset_converter: Optional[DatasetConverter] = None, - seed: Optional[int] = None, -) -> Tuple[DataLoader, DataLoader, DataLoader, Dict[str, int]]: + split_percents: tuple[float, float, float] = (0.7, 0.15, 0.15), + sampler: LabelBasedSampler | None = None, + train_transform: Callable | None = None, + val_transform: Callable | None = None, + test_transform: Callable | None = None, + dataset_converter: DatasetConverter | None = None, + seed: int | None = None, +) -> tuple[DataLoader, DataLoader, DataLoader, dict[str, int]]: """ Load skin cancer dataset (training, validation, and test set). @@ -88,16 +89,16 @@ def load_skin_cancer_data( data_dir (Path): Directory containing the dataset files. dataset_name (str): Name of the dataset to load. batch_size (int): Batch size for the DataLoader. - split_percents (Tuple[float, float, float]): Percentages for splitting the data into train, val, and test sets. - sampler (Optional[LabelBasedSampler]): Sampler for the dataset. Defaults to None. - train_transform (Union[None, Callable]): Transformations to apply to the training data. Defaults to None. - val_transform (Union[None, Callable]): Transformations to apply to the validation data. Defaults to None. - test_transform (Union[None, Callable]): Transformations to apply to the test data. Defaults to None. - dataset_converter (Optional[DatasetConverter]): Converter to apply to the dataset. Defaults to None. - seed (Optional[int]): Random seed for shuffling data. Defaults to None. + split_percents (tuple[float, float, float]): Percentages for splitting the data into train, val, and test sets. + sampler (LabelBasedSampler | None): Sampler for the dataset. Defaults to None. + train_transform (Callable | None): Transformations to apply to the training data. Defaults to None. + val_transform (Callable | None): Transformations to apply to the validation data. Defaults to None. + test_transform (Callable | None): Transformations to apply to the test data. Defaults to None. + dataset_converter (DatasetConverter | None): Converter to apply to the dataset. Defaults to None. + seed (int | None): Random seed for shuffling data. Defaults to None. Returns: - Tuple[DataLoader, DataLoader, DataLoader, Dict[str, int]]: DataLoaders for the training, validation, + tuple[DataLoader, DataLoader, DataLoader, dict[str, int]]: DataLoaders for the training, validation, and test sets, and a dictionary with the number of examples in each set. """ if sum(split_percents) != 1.0: diff --git a/fl4health/datasets/skin_cancer/preprocess_skin.py b/fl4health/datasets/skin_cancer/preprocess_skin.py index 144bb4795..e718248f8 100644 --- a/fl4health/datasets/skin_cancer/preprocess_skin.py +++ b/fl4health/datasets/skin_cancer/preprocess_skin.py @@ -9,12 +9,13 @@ import json import os -from typing import Any, Callable, Dict, List +from collections.abc import Callable +from typing import Any import pandas as pd -def save_to_json(data: Dict[str, Any], path: str) -> None: +def save_to_json(data: dict[str, Any], path: str) -> None: """Saves a dictionary to a JSON file. Args: @@ -31,8 +32,8 @@ def process_client_data( data_path: str, image_path_func: Callable[[pd.Series], str], label_map_func: Callable[[pd.Series], str], - original_columns: List[str], - official_columns: List[str], + original_columns: list[str], + official_columns: list[str], ) -> None: """Processes and saves the client-specific dataset. @@ -45,7 +46,7 @@ def process_client_data( original_columns: The list of original columns for the dataset. official_columns: The list of official columns for the dataset. """ - preprocessed_data: Dict[str, Any] = { + preprocessed_data: dict[str, Any] = { "columns": official_columns, "original_columns": original_columns, "data": [], @@ -70,7 +71,7 @@ def process_client_data( save_to_json(preprocessed_data, os.path.join(data_path, f"{client_name}.json")) -def preprocess_isic_2019(data_path: str, official_columns: List[str]) -> None: +def preprocess_isic_2019(data_path: str, official_columns: list[str]) -> None: """Preprocesses the ISIC 2019 dataset. Args: @@ -90,7 +91,7 @@ def preprocess_isic_2019(data_path: str, official_columns: List[str]) -> None: Isic_2019_data_path = os.path.join(data_path, "ISIC_2019", "ISIC_2019_Training_Input") Barcelona_df = pd.read_csv(os.path.join(Isic_2019_path, "ISIC_2019_core.csv")) Barcelona_new = Barcelona_df[["image"] + official_columns + ["UNK"]] - preprocessed_data: Dict[str, Any] = { + preprocessed_data: dict[str, Any] = { "columns": official_columns, "original_columns": official_columns, "data": [], @@ -146,7 +147,7 @@ def ham_label_map_func(row: pd.Series) -> str: return Ham_labelmap[row["dx"]] -def preprocess_ham10000(data_path: str, official_columns: List[str]) -> None: +def preprocess_ham10000(data_path: str, official_columns: list[str]) -> None: """Preprocesses the HAM10000 dataset. Args: @@ -216,7 +217,7 @@ def pad_label_map_func(row: pd.Series) -> str: return Pad_ufes_20_labelmap[row["diagnostic"]] -def preprocess_pad_ufes_20(data_path: str, official_columns: List[str]) -> None: +def preprocess_pad_ufes_20(data_path: str, official_columns: list[str]) -> None: """Preprocesses the PAD-UFES-20 dataset. Args: @@ -286,7 +287,7 @@ def derm7pt_label_map_func(row: pd.Series) -> str: return Derm7pt_labelmap[row["diagnosis"]] -def preprocess_derm7pt(data_path: str, official_columns: List[str]) -> None: +def preprocess_derm7pt(data_path: str, official_columns: list[str]) -> None: """Preprocesses the Derm7pt dataset. Args: diff --git a/fl4health/feature_alignment/constants.py b/fl4health/feature_alignment/constants.py index 9fff609d7..e62e36995 100644 --- a/fl4health/feature_alignment/constants.py +++ b/fl4health/feature_alignment/constants.py @@ -1,8 +1,6 @@ -from typing import Union - from sklearn.feature_extraction.text import CountVectorizer, HashingVectorizer, TfidfTransformer, TfidfVectorizer -TextFeatureTransformer = Union[CountVectorizer, TfidfTransformer, TfidfVectorizer, HashingVectorizer] +TextFeatureTransformer = CountVectorizer | TfidfTransformer | TfidfVectorizer | HashingVectorizer # constants used in config for communication between # the server and clients. diff --git a/fl4health/feature_alignment/string_columns_transformer.py b/fl4health/feature_alignment/string_columns_transformer.py index 74bc3065f..a61b89972 100644 --- a/fl4health/feature_alignment/string_columns_transformer.py +++ b/fl4health/feature_alignment/string_columns_transformer.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Optional - import pandas as pd from sklearn.base import BaseEstimator, TransformerMixin @@ -17,7 +15,7 @@ class TextMulticolumnTransformer(BaseEstimator, TransformerMixin): def __init__(self, transformer: TextFeatureTransformer): self.transformer = transformer - def fit(self, X: pd.DataFrame, y: Optional[pd.DataFrame] = None) -> TextMulticolumnTransformer: + def fit(self, X: pd.DataFrame, y: pd.DataFrame | None = None) -> TextMulticolumnTransformer: joined_X = X.apply(lambda x: " ".join(x), axis=1) self.transformer.fit(joined_X) return self @@ -36,7 +34,7 @@ class TextColumnTransformer(BaseEstimator, TransformerMixin): def __init__(self, transformer: TextFeatureTransformer): self.transformer = transformer - def fit(self, X: pd.DataFrame, y: Optional[pd.DataFrame] = None) -> TextColumnTransformer: + def fit(self, X: pd.DataFrame, y: pd.DataFrame | None = None) -> TextColumnTransformer: assert isinstance(X, pd.DataFrame) and X.shape[1] == 1 self.transformer.fit(X[X.columns[0]]) return self diff --git a/fl4health/feature_alignment/tab_features_info_encoder.py b/fl4health/feature_alignment/tab_features_info_encoder.py index 04a98aae4..d8ac09f20 100644 --- a/fl4health/feature_alignment/tab_features_info_encoder.py +++ b/fl4health/feature_alignment/tab_features_info_encoder.py @@ -1,7 +1,6 @@ from __future__ import annotations import json -from typing import Dict, List, Optional, Union import pandas as pd from cyclops.data.df.feature import TabularFeatures @@ -18,37 +17,37 @@ class TabularFeaturesInfoEncoder: alignment on tabular datasets. Args: - tabular_features (List[TabularFeature]): List of all tabular features. - tabular_targets (List[TabularFeature]): List of all targets. + tabular_features (list[TabularFeature]): List of all tabular features. + tabular_targets (list[TabularFeature]): List of all targets. (Note: targets are not included in tabular_features) """ - def __init__(self, tabular_features: List[TabularFeature], tabular_targets: List[TabularFeature]) -> None: + def __init__(self, tabular_features: list[TabularFeature], tabular_targets: list[TabularFeature]) -> None: self.tabular_features = sorted(tabular_features, key=TabularFeature.get_feature_name) self.tabular_targets = sorted(tabular_targets, key=TabularFeature.get_feature_name) - def get_tabular_features(self) -> List[TabularFeature]: + def get_tabular_features(self) -> list[TabularFeature]: return self.tabular_features - def get_tabular_targets(self) -> List[TabularFeature]: + def get_tabular_targets(self) -> list[TabularFeature]: return self.tabular_targets - def get_feature_columns(self) -> List[str]: + def get_feature_columns(self) -> list[str]: return sorted([feature.get_feature_name() for feature in self.tabular_features]) - def get_target_columns(self) -> List[str]: + def get_target_columns(self) -> list[str]: return sorted([target.get_feature_name() for target in self.tabular_targets]) - def features_by_type(self, tabular_type: TabularType) -> List[TabularFeature]: + def features_by_type(self, tabular_type: TabularType) -> list[TabularFeature]: return sorted( [feature for feature in self.tabular_features if feature.get_feature_type() == tabular_type], key=TabularFeature.get_feature_name, ) - def type_to_features(self) -> Dict[TabularType, List[TabularFeature]]: + def type_to_features(self) -> dict[TabularType, list[TabularFeature]]: return {tabular_type: self.features_by_type(tabular_type) for tabular_type in TabularType} - def get_categories_list(self) -> List[MetaData]: + def get_categories_list(self) -> list[MetaData]: return [cat_feature.get_metadata() for cat_feature in self.features_by_type(TabularType.ORDINAL)] def get_target_dimension(self) -> int: @@ -63,7 +62,7 @@ def _construct_tab_feature( df: pd.DataFrame, feature_name: str, feature_type: TabularType, - fill_values: Optional[Dict[str, Scalar]], + fill_values: dict[str, Scalar] | None, ) -> TabularFeature: if fill_values is None or feature_name not in fill_values: fill_value = TabularType.get_default_fill_value(feature_type) @@ -87,8 +86,8 @@ def _construct_tab_feature( def encoder_from_dataframe( df: pd.DataFrame, id_column: str, - target_columns: Union[str, List[str]], - fill_values: Optional[Dict[str, Scalar]] = None, + target_columns: str | list[str], + fill_values: dict[str, Scalar] | None = None, ) -> TabularFeaturesInfoEncoder: features_list = sorted(df.columns.values.tolist()) features_list.remove(id_column) diff --git a/fl4health/feature_alignment/tab_features_preprocessor.py b/fl4health/feature_alignment/tab_features_preprocessor.py index 9abeb3047..c0449c28f 100644 --- a/fl4health/feature_alignment/tab_features_preprocessor.py +++ b/fl4health/feature_alignment/tab_features_preprocessor.py @@ -1,5 +1,4 @@ from logging import WARNING -from typing import Dict, List, Tuple import pandas as pd from flwr.common.logger import log @@ -35,8 +34,8 @@ class TabularFeaturesPreprocessor: """ def __init__(self, tab_feature_encoder: TabularFeaturesInfoEncoder) -> None: - self.features_to_pipelines: Dict[str, Pipeline] = {} - self.targets_to_pipelines: Dict[str, Pipeline] = {} + self.features_to_pipelines: dict[str, Pipeline] = {} + self.targets_to_pipelines: dict[str, Pipeline] = {} self.tabular_features = tab_feature_encoder.get_tabular_features() self.tabular_targets = tab_feature_encoder.get_tabular_targets() @@ -77,13 +76,13 @@ def get_default_string_pipeline(self, vocabulary: MetaData) -> Pipeline: return Pipeline(steps=[("vectorizer", TextColumnTransformer(TfidfVectorizer(vocabulary=vocabulary)))]) def initialize_default_pipelines( - self, tabular_features: List[TabularFeature], one_hot: bool - ) -> Dict[str, Pipeline]: + self, tabular_features: list[TabularFeature], one_hot: bool + ) -> dict[str, Pipeline]: """ Initialize a default Pipeline for every data column in tabular_features. Args: - tabular_features (List[TabularFeature]): list of tabular + tabular_features (list[TabularFeature]): list of tabular features in the data columns. """ columns_to_pipelines = {} @@ -106,7 +105,7 @@ def initialize_default_pipelines( columns_to_pipelines[feature_name] = feature_pipeline return columns_to_pipelines - def return_column_transformer(self, pipelines: Dict[str, Pipeline]) -> ColumnTransformer: + def return_column_transformer(self, pipelines: dict[str, Pipeline]) -> ColumnTransformer: transformers = [ (f"{feature_name}_pipeline", pipelines[feature_name], [feature_name]) for feature_name in sorted(pipelines.keys()) @@ -129,7 +128,7 @@ def set_feature_pipeline(self, feature_name: str, pipeline: Pipeline) -> None: else: log(WARNING, f"{feature_name} is neither a feature nor target and the provided pipeline will be ignored.") - def preprocess_features(self, df: pd.DataFrame) -> Tuple[NDArray, NDArray]: + def preprocess_features(self, df: pd.DataFrame) -> tuple[NDArray, NDArray]: # If the dataframe has an entire column missing, we need to fill it with some default value first. df_filled = self.fill_in_missing_columns(df) # After filling in missing columns, apply the feature alignment transform. diff --git a/fl4health/feature_alignment/tabular_feature.py b/fl4health/feature_alignment/tabular_feature.py index 797b02cd5..d1e3167ff 100644 --- a/fl4health/feature_alignment/tabular_feature.py +++ b/fl4health/feature_alignment/tabular_feature.py @@ -1,13 +1,12 @@ from __future__ import annotations import json -from typing import Optional, Union -from flwr.common.typing import Dict, List, Scalar +from flwr.common.typing import Scalar from fl4health.feature_alignment.tabular_type import TabularType -MetaData = Union[Dict[str, int], List[Scalar]] +MetaData = dict[str, int] | list[Scalar] class TabularFeature: @@ -15,8 +14,8 @@ def __init__( self, feature_name: str, feature_type: TabularType, - fill_value: Optional[Scalar], - metadata: Optional[MetaData] = None, + fill_value: Scalar | None, + metadata: MetaData | None = None, ) -> None: """ Information that represents a tabular feature. @@ -24,7 +23,7 @@ def __init__( Args: feature_name (str): name of the feature. feature_type (TabularType): data type of the feature. - fill_value (Optional[Scalar]): the default fill value for this feature when it is missing in a dataframe. + fill_value (Scalar | None): the default fill value for this feature when it is missing in a dataframe. metadata (MetaData, optional): metadata associated with this feature. For example, if the feature is categorical, then metadata would be all the categories. Defaults to None. """ diff --git a/fl4health/feature_alignment/tabular_type.py b/fl4health/feature_alignment/tabular_type.py index 4a87233b4..b2004f602 100644 --- a/fl4health/feature_alignment/tabular_type.py +++ b/fl4health/feature_alignment/tabular_type.py @@ -1,5 +1,6 @@ +from __future__ import annotations + from enum import Enum -from typing import Union from flwr.common.typing import Scalar @@ -11,7 +12,7 @@ class TabularType(str, Enum): STRING = "string" @staticmethod - def get_default_fill_value(tabular_type: Union["TabularType", str]) -> Scalar: + def get_default_fill_value(tabular_type: TabularType | str) -> Scalar: if tabular_type is TabularType.NUMERIC: return 0.0 elif tabular_type is TabularType.BINARY: diff --git a/fl4health/losses/deep_mmd_loss.py b/fl4health/losses/deep_mmd_loss.py index ca8b79255..6f2c96418 100644 --- a/fl4health/losses/deep_mmd_loss.py +++ b/fl4health/losses/deep_mmd_loss.py @@ -1,5 +1,3 @@ -from typing import Optional, Tuple - import numpy as np import torch @@ -103,7 +101,7 @@ def __init__( # Set the model to training mode if required to train the Deep Kernel self.training = False - def pairwise_distiance_squared(self, X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor: + def pairwise_distance_squared(self, X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor: """ Compute the paired distance between x and y. @@ -126,7 +124,7 @@ def h1_mean_var_gram( k_y: torch.Tensor, k_xy: torch.Tensor, is_var_computed: bool, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Compute value of MMD and std of MMD using kernel matrix. @@ -137,7 +135,7 @@ def h1_mean_var_gram( is_var_computed (bool): Whether to compute the variance of the MMD. Returns: - Tuple[torch.Tensor, Optional[torch.Tensor]]: The value of MMD and the variance of MMD + tuple[torch.Tensor, torch.Tensor | None]: The value of MMD and the variance of MMD if required to compute. """ nx = k_x.shape[0] @@ -176,7 +174,7 @@ def MMDu( epsilon: torch.Tensor, is_smooth: bool = True, is_var_computed: bool = True, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Compute value of deep-kernel MMD and std of deep-kernel MMD using merged data. @@ -193,21 +191,21 @@ def MMDu( Defaults to True. Returns: - Tuple[torch.Tensor, Optional[torch.Tensor]]: The value of MMD and the variance of MMD + tuple[torch.Tensor, torch.Tensor | None]: The value of MMD and the variance of MMD if required to compute. """ x = features[0:len_s, :] # fetch the sample 1 (features of deep networks) y = features[len_s:, :] # fetch the sample 2 (features of deep networks) - distance_xx = self.pairwise_distiance_squared(x, x) - distance_yy = self.pairwise_distiance_squared(y, y) - distance_xy = self.pairwise_distiance_squared(x, y) + distance_xx = self.pairwise_distance_squared(x, x) + distance_yy = self.pairwise_distance_squared(y, y) + distance_xy = self.pairwise_distance_squared(x, y) if is_smooth: x_original = features_org[0:len_s, :] # fetch the original sample 1 y_original = features_org[len_s:, :] # fetch the original sample 2 - distance_xx_original = self.pairwise_distiance_squared(x_original, x_original) - distance_yy_original = self.pairwise_distiance_squared(y_original, y_original) - distance_xy_original = self.pairwise_distiance_squared(x_original, y_original) + distance_xx_original = self.pairwise_distance_squared(x_original, x_original) + distance_yy_original = self.pairwise_distance_squared(y_original, y_original) + distance_xy_original = self.pairwise_distance_squared(x_original, y_original) kernel_x = (1 - epsilon) * torch.exp( -((distance_xx / sigma_phi) ** self.gaussian_degree) - distance_xx_original / sigma_q @@ -224,7 +222,7 @@ def MMDu( kernel_y = torch.exp(-distance_yy / sigma_phi) kernel_xy = torch.exp(-distance_xy / sigma_phi) - # kernel_x reprsents k_w(x_i, x_j), kernel_y represents k_w(y_i, y_j), kernel_xy represents + # kernel_x represents k_w(x_i, x_j), kernel_y represents k_w(y_i, y_j), kernel_xy represents # k_w(x_i, y_j) for all i, j in the sample X and sample Y defined in Equation (1) of the paper return self.h1_mean_var_gram(kernel_x, kernel_y, kernel_xy, is_var_computed) diff --git a/fl4health/losses/fenda_loss_config.py b/fl4health/losses/fenda_loss_config.py index 0150d93b9..b862dffaf 100644 --- a/fl4health/losses/fenda_loss_config.py +++ b/fl4health/losses/fenda_loss_config.py @@ -1,5 +1,3 @@ -from typing import Optional, Tuple - import torch from fl4health.losses.contrastive_loss import MoonContrastiveLoss @@ -36,9 +34,9 @@ def __init__(self, device: torch.device, contrastive_loss_weight: float, tempera class ConstrainedFendaLossContainer: def __init__( self, - perfcl_loss_config: Optional[PerFclLossContainer], - cosine_similarity_loss_config: Optional[CosineSimilarityLossContainer], - contrastive_loss_config: Optional[MoonContrastiveLossContainer], + perfcl_loss_config: PerFclLossContainer | None, + cosine_similarity_loss_config: CosineSimilarityLossContainer | None, + contrastive_loss_config: MoonContrastiveLossContainer | None, ) -> None: self.perfcl_loss_config = perfcl_loss_config self.cos_sim_loss_config = cosine_similarity_loss_config @@ -76,7 +74,7 @@ def compute_perfcl_loss( global_features: torch.Tensor, old_global_features: torch.Tensor, initial_global_features: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: assert self.perfcl_loss_config is not None global_feature_contrastive_loss, local_feature_contrastive_loss = self.perfcl_loss_config.perfcl_loss_function( local_features, old_local_features, global_features, old_global_features, initial_global_features diff --git a/fl4health/losses/mkmmd_loss.py b/fl4health/losses/mkmmd_loss.py index da9df2a1f..5049d0a9a 100644 --- a/fl4health/losses/mkmmd_loss.py +++ b/fl4health/losses/mkmmd_loss.py @@ -1,5 +1,4 @@ from logging import INFO -from typing import Optional import torch from flwr.common.logger import log @@ -10,11 +9,11 @@ class MkMmdLoss(torch.nn.Module): def __init__( self, device: torch.device, - gammas: Optional[torch.Tensor] = None, - betas: Optional[torch.Tensor] = None, + gammas: torch.Tensor | None = None, + betas: torch.Tensor | None = None, minimize_type_two_error: bool = True, normalize_features: bool = False, - layer_name: Optional[str] = None, + layer_name: str | None = None, perform_linear_approximation: bool = False, ) -> None: """ @@ -23,19 +22,19 @@ def __init__( Args: device (torch.device): Device onto which tensors should be moved - gammas (Optional[torch.Tensor], optional): These are known as the length-scales of the RBF functions used + gammas (torch.Tensor | None, optional): These are known as the length-scales of the RBF functions used to compute the Mk-MMD distances. The length of this list defines the number of kernels used in the norm measurement. If none, a default of 19 kernels is used. Defaults to None. - betas (Optional[torch.Tensor], optional): These are the linear coefficients used on the basis of kernels + betas (torch.Tensor | None, optional): These are the linear coefficients used on the basis of kernels to compute the Mk-MMD measure. If not provided, a unit-length, random default is constructed. These can be optimized using the functions of this class. Defaults to None. - minimize_type_two_error (Optional[bool], optional): Whether we're aiming to minimize the type II error in + minimize_type_two_error (bool | None, optional): Whether we're aiming to minimize the type II error in optimizing the betas or maximize it. The first coincides with trying to minimize feature distance. The second coincides with trying to maximize their feature distance. Defaults to True. - normalize_features (Optional[bool], optional): Whether to normalize the features to have unit length before + normalize_features (bool | None, optional): Whether to normalize the features to have unit length before computing the MK-MMD and optimizing betas. Defaults to False. - layer_name (Optional[str], optional): The name of the layer to extract features from. Defaults to None. - perform_linear_approximation (Optional[bool], optional): Whether to use linear approximations for the + layer_name (str | None, optional): The name of the layer to extract features from. Defaults to None. + perform_linear_approximation (bool | None, optional): Whether to use linear approximations for the estimates of the mean and covariance of the kernel values. Experimentally, we have found that the linear approximations largely hinder the statistical power of Mk-MMD. Defaults to False """ diff --git a/fl4health/losses/perfcl_loss.py b/fl4health/losses/perfcl_loss.py index fe7daea10..706781d15 100644 --- a/fl4health/losses/perfcl_loss.py +++ b/fl4health/losses/perfcl_loss.py @@ -1,5 +1,3 @@ -from typing import Tuple - import torch import torch.nn as nn @@ -24,7 +22,7 @@ def forward( global_features: torch.Tensor, old_global_features: torch.Tensor, initial_global_features: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: """ PerFCL loss implemented based on https://www.sciencedirect.com/science/article/pii/S0031320323002078. This paper introduced two contrastive loss functions: @@ -51,7 +49,7 @@ def forward( model at the start of client-side training. This feature extractor is the AGGREGATED weights across clients. Shape (batch_size, n_features) Returns: - Tuple[torch.Tensor, torch.Tensor]: Tuple containing the two components of the PerFCL loss function to + tuple[torch.Tensor, torch.Tensor]: Tuple containing the two components of the PerFCL loss function to be weighted and summed. The first tensor corresponds to the global feature loss, the second is associated with the local feature loss. """ diff --git a/fl4health/losses/weight_drift_loss.py b/fl4health/losses/weight_drift_loss.py index f18acf0ec..3f008202b 100644 --- a/fl4health/losses/weight_drift_loss.py +++ b/fl4health/losses/weight_drift_loss.py @@ -1,5 +1,3 @@ -from typing import List - import torch import torch.nn as nn @@ -15,7 +13,7 @@ def __init__( def _compute_weight_difference_inner_product(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return torch.pow(torch.linalg.norm(x - y), 2.0) - def forward(self, target_model: nn.Module, constraint_tensors: List[torch.Tensor], weight: float) -> torch.Tensor: + def forward(self, target_model: nn.Module, constraint_tensors: list[torch.Tensor], weight: float) -> torch.Tensor: # move model and tensors to device if needed target_model = target_model.to(self.device) constraint_tensors = [constraint_tensor.to(self.device) for constraint_tensor in constraint_tensors] @@ -24,7 +22,7 @@ def forward(self, target_model: nn.Module, constraint_tensors: List[torch.Tensor assert len(constraint_tensors) == len(model_weights) assert len(model_weights) > 0 - layer_inner_products: List[torch.Tensor] = [ + layer_inner_products: list[torch.Tensor] = [ self._compute_weight_difference_inner_product(constraint_layer_weights, model_layer_weights) for constraint_layer_weights, model_layer_weights in zip(constraint_tensors, model_weights) ] diff --git a/fl4health/model_bases/apfl_base.py b/fl4health/model_bases/apfl_base.py index 69daf90c6..b34e4624e 100644 --- a/fl4health/model_bases/apfl_base.py +++ b/fl4health/model_bases/apfl_base.py @@ -1,5 +1,4 @@ import copy -from typing import Dict, List import torch import torch.nn as nn @@ -29,7 +28,7 @@ def global_forward(self, input: torch.Tensor) -> torch.Tensor: def local_forward(self, input: torch.Tensor) -> torch.Tensor: return self.local_model(input) - def forward(self, input: torch.Tensor) -> Dict[str, torch.Tensor]: + def forward(self, input: torch.Tensor) -> dict[str, torch.Tensor]: # Forward return dictionary because APFL has multiple different prediction types global_logits = self.global_forward(input) local_logits = self.local_forward(input) @@ -70,8 +69,8 @@ def update_alpha(self) -> None: alpha = max(min(alpha, 1), 0) self.alpha = alpha - def layers_to_exchange(self) -> List[str]: - layers_to_exchange: List[str] = [ + def layers_to_exchange(self) -> list[str]: + layers_to_exchange: list[str] = [ layer for layer in self.state_dict().keys() if layer.startswith("global_model.") ] return layers_to_exchange diff --git a/fl4health/model_bases/autoencoders_base.py b/fl4health/model_bases/autoencoders_base.py index b96b0d053..a958637cb 100644 --- a/fl4health/model_bases/autoencoders_base.py +++ b/fl4health/model_bases/autoencoders_base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Callable, Optional, Tuple +from collections.abc import Callable import torch import torch.nn as nn @@ -58,7 +58,7 @@ def __init__( """ super().__init__(encoder, decoder) - def encode(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def encode(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: mu, logvar = self.encoder(input) return mu, logvar @@ -88,28 +88,26 @@ def __init__( self, encoder: nn.Module, decoder: nn.Module, - unpack_input_condition: Optional[Callable[[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]] = None, + unpack_input_condition: Callable[[torch.Tensor], tuple[torch.Tensor, torch.Tensor]] | None = None, ) -> None: """Conditional Variational Auto-Encoder model. Args: encoder (nn.Module): The encoder used to map input to latent space. decoder (nn.Module): The decoder used to reconstruct the input using a vector in latent space. - unpack_input_condition (Optional[Callable], optional): For unpacking the input and condition tensors. + unpack_input_condition (Callable | None, optional): For unpacking the input and condition tensors. """ super().__init__(encoder, decoder) self.unpack_input_condition = unpack_input_condition - def encode( - self, input: torch.Tensor, condition: Optional[torch.Tensor] = None - ) -> Tuple[torch.Tensor, torch.Tensor]: + def encode(self, input: torch.Tensor, condition: torch.Tensor | None = None) -> tuple[torch.Tensor, torch.Tensor]: # User can decide how to use the condition in the encoder, # ex: using the condition in the middle layers of encoder. mu, logvar = self.encoder(input, condition) return mu, logvar - def decode(self, latent_vector: torch.Tensor, condition: Optional[torch.Tensor] = None) -> torch.Tensor: + def decode(self, latent_vector: torch.Tensor, condition: torch.Tensor | None = None) -> torch.Tensor: # User can decide how to use the condition in the decoder, # ex: using the condition in the middle layers of decoder, or not using it at all. output = self.decoder(latent_vector, condition) diff --git a/fl4health/model_bases/ensemble_base.py b/fl4health/model_bases/ensemble_base.py index beeb1132c..f8ee79fc2 100644 --- a/fl4health/model_bases/ensemble_base.py +++ b/fl4health/model_bases/ensemble_base.py @@ -1,5 +1,4 @@ from enum import Enum -from typing import Dict, List, Optional import torch import torch.nn as nn @@ -13,16 +12,16 @@ class EnsembleAggregationMode(Enum): class EnsembleModel(nn.Module): def __init__( self, - ensemble_models: Dict[str, nn.Module], - aggregation_mode: Optional[EnsembleAggregationMode] = EnsembleAggregationMode.AVERAGE, + ensemble_models: dict[str, nn.Module], + aggregation_mode: EnsembleAggregationMode | None = EnsembleAggregationMode.AVERAGE, ) -> None: """ Class that acts a wrapper to an ensemble of models to be trained in federated manner with support for both voting and averaging prediction of individual models. Args: - ensemble_models (Dict[str, nn.Module]): A dictionary of models that make up the ensemble. - aggregation_mode (Optional[EnsembleAggregationMode]): The mode in which to aggregate the + ensemble_models (dict[str, nn.Module]): A dictionary of models that make up the ensemble. + aggregation_mode (EnsembleAggregationMode | None): The mode in which to aggregate the predictions of individual models. """ super().__init__() @@ -30,7 +29,7 @@ def __init__( self.ensemble_models = nn.ModuleDict(ensemble_models) self.aggregation_mode = aggregation_mode - def forward(self, input: torch.Tensor) -> Dict[str, torch.Tensor]: + def forward(self, input: torch.Tensor) -> dict[str, torch.Tensor]: """ Produce the predictions of the ensemble models given input data. @@ -38,7 +37,7 @@ def forward(self, input: torch.Tensor) -> Dict[str, torch.Tensor]: input (torch.Tensor): A batch of input data. Returns: - Dict[str, torch.Tensor]: A dictionary of predictions of the individual ensemble models + dict[str, torch.Tensor]: A dictionary of predictions of the individual ensemble models as well as prediction of the ensemble as a whole. """ preds = {} @@ -56,14 +55,14 @@ def forward(self, input: torch.Tensor) -> Dict[str, torch.Tensor]: return preds - def ensemble_vote(self, preds_list: List[torch.Tensor]) -> torch.Tensor: + def ensemble_vote(self, preds_list: list[torch.Tensor]) -> torch.Tensor: """ Produces the aggregated prediction of the ensemble via voting. Expects predictions to be in a format where the 0 axis represents the sample index and the -1 axis represents the class dimension. Args: - preds_list (List[torch.Tensor]): A list of predictions of the models in the ensemble. + preds_list (list[torch.Tensor]): A list of predictions of the models in the ensemble. Returns: torch.Tensor: The vote prediction of the ensemble. @@ -92,12 +91,12 @@ def ensemble_vote(self, preds_list: List[torch.Tensor]) -> torch.Tensor: return vote_preds - def ensemble_average(self, preds_list: List[torch.Tensor]) -> torch.Tensor: + def ensemble_average(self, preds_list: list[torch.Tensor]) -> torch.Tensor: """ Produces the aggregated prediction of the ensemble via averaging. Args: - preds_list (List[torch.Tensor]): A list of predictions of the models in the ensemble. + preds_list (list[torch.Tensor]): A list of predictions of the models in the ensemble. Returns: torch.Tensor: The average prediction of the ensemble. diff --git a/fl4health/model_bases/feature_extractor_buffer.py b/fl4health/model_bases/feature_extractor_buffer.py index 8285b5f6b..29928d21f 100644 --- a/fl4health/model_bases/feature_extractor_buffer.py +++ b/fl4health/model_bases/feature_extractor_buffer.py @@ -1,5 +1,5 @@ +from collections.abc import Callable from logging import INFO -from typing import Callable, Dict, List import torch import torch.nn as nn @@ -8,7 +8,7 @@ class FeatureExtractorBuffer: - def __init__(self, model: nn.Module, flatten_feature_extraction_layers: Dict[str, bool]) -> None: + def __init__(self, model: nn.Module, flatten_feature_extraction_layers: dict[str, bool]) -> None: """ This class is used to extract features from the intermediate layers of a neural network model and store them in a buffer. The features are extracted using additional hooks that are registered to the model. The extracted @@ -17,24 +17,24 @@ def __init__(self, model: nn.Module, flatten_feature_extraction_layers: Dict[str Args: model (nn.Module): The neural network model. - flatten_feature_extraction_layers (Dict[str, bool]): Dictionary of layers to extract features from them and + flatten_feature_extraction_layers (dict[str, bool]): Dictionary of layers to extract features from them and whether to flatten them. Keys are the layer names that are extracted from the named_modules and values are boolean. Attributes: model (nn.Module): The neural network model. - flatten_feature_extraction_layers (Dict[str, bool]): A dictionary specifying whether to flatten the feature + flatten_feature_extraction_layers (dict[str, bool]): A dictionary specifying whether to flatten the feature extraction layers. - fhooks (List[RemovableHandle]): A list to store the handles for removing hooks. + fhooks (list[RemovableHandle]): A list to store the handles for removing hooks. accumulate_features (bool): A flag indicating whether to accumulate features. - extracted_features_buffers (Dict[str, List[torch.Tensor]]): A dictionary to store the extracted features + extracted_features_buffers (dict[str, list[torch.Tensor]]): A dictionary to store the extracted features for each layer. """ self.model = model self.flatten_feature_extraction_layers = flatten_feature_extraction_layers - self.fhooks: List[RemovableHandle] = [] + self.fhooks: list[RemovableHandle] = [] self.accumulate_features: bool = False - self.extracted_features_buffers: Dict[str, List[torch.Tensor]] = { + self.extracted_features_buffers: dict[str, list[torch.Tensor]] = { layer: [] for layer in flatten_feature_extraction_layers.keys() } @@ -63,14 +63,14 @@ def clear_buffers(self) -> None: """ self.extracted_features_buffers = {layer: [] for layer in self.flatten_feature_extraction_layers.keys()} - def get_hierarchical_attr(self, module: nn.Module, layer_hierarchy: List[str]) -> nn.Module: + def get_hierarchical_attr(self, module: nn.Module, layer_hierarchy: list[str]) -> nn.Module: """ Traverse the hierarchical attributes of the module to get the desired attribute. Hooks should be registered to specific layers of the model, not to nn.Sequential or nn.ModuleList. Args: module (nn.Module): The nn.Module object to traverse. - layer_hierarchy (List[str]): The hierarchical list of name of desired layer. + layer_hierarchy (list[str]): The hierarchical list of name of desired layer. Returns: nn.Module: The desired layer of the model. @@ -80,14 +80,14 @@ def get_hierarchical_attr(self, module: nn.Module, layer_hierarchy: List[str]) - else: return self.get_hierarchical_attr(getattr(module, layer_hierarchy[0]), layer_hierarchy[1:]) - def find_last_common_prefix(self, prefix: str, layers_name: List[str]) -> str: + def find_last_common_prefix(self, prefix: str, layers_name: list[str]) -> str: """ Check the model's list of named modules to filter any layer that starts with the given prefix and return the last one. Args: prefix (str): The prefix of the layer name for registering the hook. - layers_name (List[str]): The list of named modules of the model. The assumption is that list of + layers_name (list[str]): The list of named modules of the model. The assumption is that list of named modules is sorted in the order of the model's forward pass with depth-first traversal. This will allow the user to specify the generic name of the layer instead of the full hierarchical name. @@ -112,9 +112,9 @@ def _maybe_register_hooks(self) -> None: # Find the last specific layer under a given generic name specific_layer = self.find_last_common_prefix(layer, named_layers) # Split the specific layer name by '.' to get the hierarchical attribute - layer_hierarchicy_list = specific_layer.split(".") + layer_hierarchy_list = specific_layer.split(".") self.fhooks.append( - self.get_hierarchical_attr(self.model, layer_hierarchicy_list).register_forward_hook( + self.get_hierarchical_attr(self.model, layer_hierarchy_list).register_forward_hook( self.forward_hook(layer) ) ) @@ -166,12 +166,12 @@ def flatten(self, features: torch.Tensor) -> torch.Tensor: return features.reshape(len(features), -1) - def get_extracted_features(self) -> Dict[str, torch.Tensor]: + def get_extracted_features(self) -> dict[str, torch.Tensor]: """ Returns a dictionary of extracted features. Returns: - features (Dict[str, torch.Tensor]): A dictionary where the keys are the layer names and the values are + features (dict[str, torch.Tensor]): A dictionary where the keys are the layer names and the values are the extracted features as torch Tensors. """ features = {} diff --git a/fl4health/model_bases/fedsimclr_base.py b/fl4health/model_bases/fedsimclr_base.py index 8f93bdeec..4ea3429df 100644 --- a/fl4health/model_bases/fedsimclr_base.py +++ b/fl4health/model_bases/fedsimclr_base.py @@ -1,7 +1,6 @@ from __future__ import annotations from pathlib import Path -from typing import Optional import torch import torch.nn as nn @@ -12,7 +11,7 @@ def __init__( self, encoder: nn.Module, projection_head: nn.Module = nn.Identity(), - prediction_head: Optional[nn.Module] = None, + prediction_head: nn.Module | None = None, pretrain: bool = True, ) -> None: """ @@ -26,7 +25,7 @@ def __init__( projection_head (nn.Module): Projection Head that maps output of encoder to final representation used in contrastive loss for pretraining stage. Defaults to identity transformation. - prediction_head (Optional[nn.Module]): Prediction head that maps + prediction_head (nn.Module | None): Prediction head that maps output of encoder to prediction in the finetuning stage. Defaults to None. pretrain (bool): Determines whether or not to use the projection_head diff --git a/fl4health/model_bases/fenda_base.py b/fl4health/model_bases/fenda_base.py index d75f767f2..a32ff1007 100644 --- a/fl4health/model_bases/fenda_base.py +++ b/fl4health/model_bases/fenda_base.py @@ -1,5 +1,3 @@ -from typing import Dict, List, Tuple - import torch import torch.nn as nn @@ -25,7 +23,7 @@ def __init__(self, local_module: nn.Module, global_module: nn.Module, model_head self, first_feature_extractor=local_module, second_feature_extractor=global_module, model_head=model_head ) - def layers_to_exchange(self) -> List[str]: + def layers_to_exchange(self) -> list[str]: return [ layer_name for layer_name in self.state_dict().keys() if layer_name.startswith("second_feature_extractor.") ] @@ -58,7 +56,7 @@ def __init__( super().__init__(local_module=local_module, global_module=global_module, model_head=model_head) self.flatten_features = flatten_features - def forward(self, input: torch.Tensor) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: + def forward(self, input: torch.Tensor) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: # input is expected to be of shape (batch_size, *) local_output = self.first_feature_extractor.forward(input) global_output = self.second_feature_extractor.forward(input) diff --git a/fl4health/model_bases/masked_layers/masked_conv.py b/fl4health/model_bases/masked_layers/masked_conv.py index 0d9bc1b8c..f997479ba 100644 --- a/fl4health/model_bases/masked_layers/masked_conv.py +++ b/fl4health/model_bases/masked_layers/masked_conv.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import List, Optional, Union - import torch import torch.nn as nn import torch.nn.functional as F @@ -20,13 +18,13 @@ def __init__( out_channels: int, kernel_size: _size_1_t, stride: _size_1_t = 1, - padding: Union[str, _size_1_t] = 0, + padding: str | _size_1_t = 0, dilation: _size_1_t = 1, groups: int = 1, bias: bool = True, padding_mode: str = "zeros", - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, ) -> None: """ Implementation of masked Conv1d layers. @@ -105,7 +103,7 @@ def from_pretrained(cls, conv_module: nn.Conv1d) -> MaskedConv1d: """ has_bias = conv_module.bias is not None # we create new variables below to make mypy happy since kernel_size has - # type Union[int, Tuple[int]] and kernel_size_ has type Tuple[int] + # type int | tuple[int] and kernel_size_ has type tuple[int] kernel_size_ = _single(conv_module.kernel_size) stride_ = _single(conv_module.stride) padding_ = conv_module.padding if isinstance(conv_module.padding, str) else _single(conv_module.padding) @@ -137,13 +135,13 @@ def __init__( out_channels: int, kernel_size: _size_2_t, stride: _size_2_t = 1, - padding: Union[str, _size_2_t] = 0, + padding: str | _size_2_t = 0, dilation: _size_2_t = 1, groups: int = 1, bias: bool = True, padding_mode: str = "zeros", - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, ) -> None: """ Implementation of masked Conv2d layers. @@ -251,13 +249,13 @@ def __init__( out_channels: int, kernel_size: _size_3_t, stride: _size_3_t = 1, - padding: Union[str, _size_3_t] = 0, + padding: str | _size_3_t = 0, dilation: _size_3_t = 1, groups: int = 1, bias: bool = True, padding_mode: str = "zeros", - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, ) -> None: """ Implementation of masked Conv2d layers. @@ -369,8 +367,8 @@ def __init__( bias: bool = True, dilation: _size_1_t = 1, padding_mode: str = "zeros", - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, ) -> None: """ Implementation of masked ConvTranspose1d layers. For more information on transposed convolution, @@ -433,7 +431,7 @@ def __init__( else: self.register_parameter("bias_scores", None) - def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor: + def forward(self, input: Tensor, output_size: list[int] | None = None) -> Tensor: # Note: the same check is already present in super().__init__ if self.padding_mode != "zeros": raise ValueError("Only `zeros` padding mode is supported for ConvTranspose1d") @@ -441,7 +439,7 @@ def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Ten # (The type ignore below is just used to resolve some small typing issue.) # One cannot replace List by Tuple or Sequence in "_output_padding" - # because TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`. + # because TorchScript does not support `Sequence[T]` or `tuple[T, ...]`. output_padding = self._output_padding( input, output_size, @@ -473,7 +471,7 @@ def from_pretrained(cls, conv_module: nn.ConvTranspose1d) -> MaskedConvTranspose """ has_bias = conv_module.bias is not None # we create new variables below to make mypy happy since kernel_size has - # type Union[int, Tuple[int]] and kernel_size_ has type Tuple[int] + # type int | tuple[int] and kernel_size_ has type tuple[int] kernel_size_ = _single(conv_module.kernel_size) stride_ = _single(conv_module.stride) padding_ = _single(conv_module.padding) @@ -513,8 +511,8 @@ def __init__( bias: bool = True, dilation: _size_2_t = 1, padding_mode: str = "zeros", - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, ) -> None: """ Implementation of masked ConvTranspose2d layers. For more information on transposed convolution, @@ -576,7 +574,7 @@ def __init__( else: self.register_parameter("bias_scores", None) - def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor: + def forward(self, input: Tensor, output_size: list[int] | None = None) -> Tensor: # Note: the same check is already present in super().__init__ if self.padding_mode != "zeros": raise ValueError("Only `zeros` padding mode is supported for ConvTranspose1d") @@ -613,7 +611,7 @@ def from_pretrained(cls, conv_module: nn.ConvTranspose2d) -> MaskedConvTranspose """ has_bias = conv_module.bias is not None # we create new variables below to make mypy happy since kernel_size has - # type Union[int, Tuple[int]] and kernel_size_ has type Tuple[int] + # type int | tuple[int] and kernel_size_ has type tuple[int] kernel_size_ = _pair(conv_module.kernel_size) stride_ = _pair(conv_module.stride) padding_ = _pair(conv_module.padding) @@ -653,8 +651,8 @@ def __init__( bias: bool = True, dilation: _size_3_t = 1, padding_mode: str = "zeros", - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, ) -> None: """ Implementation of masked ConvTranspose3d layers. For more information on transposed convolution, @@ -716,7 +714,7 @@ def __init__( else: self.register_parameter("bias_scores", None) - def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor: + def forward(self, input: Tensor, output_size: list[int] | None = None) -> Tensor: # Note: the same check is already present in super().__init__ if self.padding_mode != "zeros": raise ValueError("Only `zeros` padding mode is supported for ConvTranspose1d") @@ -753,7 +751,7 @@ def from_pretrained(cls, conv_module: nn.ConvTranspose3d) -> MaskedConvTranspose """ has_bias = conv_module.bias is not None # we create new variables below to make mypy happy since kernel_size has - # type Union[int, Tuple[int]] and kernel_size_ has type Tuple[int] + # type int | tuple[int] and kernel_size_ has type tuple[int] kernel_size_ = _triple(conv_module.kernel_size) stride_ = _triple(conv_module.stride) padding_ = _triple(conv_module.padding) diff --git a/fl4health/model_bases/masked_layers/masked_linear.py b/fl4health/model_bases/masked_layers/masked_linear.py index 38023f862..473c24f72 100644 --- a/fl4health/model_bases/masked_layers/masked_linear.py +++ b/fl4health/model_bases/masked_layers/masked_linear.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch import torch.nn as nn import torch.nn.functional as F @@ -15,8 +13,8 @@ def __init__( in_features: int, out_features: int, bias: bool = True, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, ) -> None: """ Implementation of masked linear layers. diff --git a/fl4health/model_bases/masked_layers/masked_normalization_layers.py b/fl4health/model_bases/masked_layers/masked_normalization_layers.py index bf5b37dd7..89bd61a0a 100644 --- a/fl4health/model_bases/masked_layers/masked_normalization_layers.py +++ b/fl4health/model_bases/masked_layers/masked_normalization_layers.py @@ -1,5 +1,3 @@ -from typing import List, Optional, Union - import torch import torch.nn as nn import torch.nn.functional as F @@ -9,7 +7,7 @@ from fl4health.utils.functions import bernoulli_sample -TorchShape = Union[int, List[int], torch.Size] +TorchShape = int | list[int] | torch.Size class MaskedLayerNorm(nn.LayerNorm): @@ -19,8 +17,8 @@ def __init__( eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, ) -> None: """ Implementation of the masked Layer Normalization module. When elementwise_affine is True, @@ -135,8 +133,8 @@ def __init__( momentum: float = 0.1, affine: bool = True, track_running_stats: bool = True, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, ) -> None: """ Base class for masked batch normalization modules of various dimensions. When affine is True, diff --git a/fl4health/model_bases/moon_base.py b/fl4health/model_bases/moon_base.py index ad6db9078..8d09e886c 100644 --- a/fl4health/model_bases/moon_base.py +++ b/fl4health/model_bases/moon_base.py @@ -1,5 +1,3 @@ -from typing import Optional, Tuple - import torch import torch.nn as nn @@ -8,7 +6,7 @@ class MoonModel(SequentiallySplitModel): def __init__( - self, base_module: nn.Module, head_module: nn.Module, projection_module: Optional[nn.Module] = None + self, base_module: nn.Module, head_module: nn.Module, projection_module: nn.Module | None = None ) -> None: """ A MOON Model is a specific type of sequentially split model, where one may specify an optional projection @@ -19,7 +17,7 @@ def __init__( Args: base_module (nn.Module): Feature extractor component of the model head_module (nn.Module): Classification (or other type) of head used by the model - projection_module (Optional[nn.Module], optional): An optional module for manipulating the features before + projection_module (nn.Module | None, optional): An optional module for manipulating the features before they are passed to the head_module. Defaults to None. """ @@ -28,7 +26,7 @@ def __init__( super().__init__(base_module, head_module, flatten_features=True) self.projection_module = projection_module - def sequential_forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def sequential_forward(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """ Overriding the sequential forward of the SequentiallySplitModel parent to allow for the injection of a projection module into the forward pass. The remainder of the functionality stays the same. That is, @@ -38,7 +36,7 @@ def sequential_forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.T input (torch.Tensor): Input to the model forward pass. Expected to be of shape (batch_size, *) Returns: - Tuple[torch.Tensor, torch.Tensor]: Returns the predictions and features tensor from the sequential forward + tuple[torch.Tensor, torch.Tensor]: Returns the predictions and features tensor from the sequential forward """ x = self.base_module.forward(input) # A projection module is optionally specified for MOON models. If no module is provided, it is simply skipped diff --git a/fl4health/model_bases/parallel_split_models.py b/fl4health/model_bases/parallel_split_models.py index 0dd222eba..1ec98b09b 100644 --- a/fl4health/model_bases/parallel_split_models.py +++ b/fl4health/model_bases/parallel_split_models.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod from enum import Enum -from typing import Dict, Tuple import torch import torch.nn as nn @@ -65,7 +64,7 @@ def __init__( self.second_feature_extractor = second_feature_extractor self.model_head = model_head - def forward(self, input: torch.Tensor) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: + def forward(self, input: torch.Tensor) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: first_output = self.first_feature_extractor.forward(input) second_output = self.second_feature_extractor.forward(input) preds = {"prediction": self.model_head.forward(first_output, second_output)} diff --git a/fl4health/model_bases/partial_layer_exchange_model.py b/fl4health/model_bases/partial_layer_exchange_model.py index 1e74937c2..e908b0652 100644 --- a/fl4health/model_bases/partial_layer_exchange_model.py +++ b/fl4health/model_bases/partial_layer_exchange_model.py @@ -1,10 +1,9 @@ from abc import ABC, abstractmethod -from typing import List import torch.nn as nn class PartialLayerExchangeModel(nn.Module, ABC): @abstractmethod - def layers_to_exchange(self) -> List[str]: + def layers_to_exchange(self) -> list[str]: raise NotImplementedError diff --git a/fl4health/model_bases/pca.py b/fl4health/model_bases/pca.py index 2844a77c9..9ed69bcfc 100644 --- a/fl4health/model_bases/pca.py +++ b/fl4health/model_bases/pca.py @@ -1,5 +1,4 @@ from logging import INFO, WARNING -from typing import Optional, Tuple import torch import torch.nn as nn @@ -59,7 +58,7 @@ def __init__(self, low_rank: bool = False, full_svd: bool = False, rank_estimati self.singular_values: Parameter self.data_mean: Tensor - def forward(self, X: Tensor, center_data: bool) -> Tuple[Tensor, Tensor]: + def forward(self, X: Tensor, center_data: bool) -> tuple[Tensor, Tensor]: """ Perform PCA on the data matrix X by computing its SVD. @@ -71,7 +70,7 @@ def forward(self, X: Tensor, center_data: bool) -> Tuple[Tensor, Tensor]: will be thrown if it is not. Returns: - Tuple[Tensor, Tensor]: The principal components (i.e., right singular vectors) + tuple[Tensor, Tensor]: The principal components (i.e., right singular vectors) and their corresponding singular values. Note: the algorithm assumes that the rows of X are the data points (after reshaping as needed). @@ -143,13 +142,13 @@ def prepare_data_forward(self, X: Tensor, center_data: bool) -> Tensor: assert torch.allclose(torch.zeros(data_mean.size()), data_mean, atol=1e-6) return X - def project_lower_dim(self, X: Tensor, k: Optional[int] = None, center_data: bool = False) -> Tensor: + def project_lower_dim(self, X: Tensor, k: int | None = None, center_data: bool = False) -> Tensor: """ Project input data X onto the top k principal components. Args: X (Tensor): Input data matrix whose rows are the data points. - k (Optional[int], optional): The number of principal components + k (int | None, optional): The number of principal components onto which projection is done. If k is None, then all principal components will be used in the projection. Defaults to None. center_data (bool): If true, then the *training* data mean (learned in the forward pass) @@ -196,7 +195,7 @@ def project_back(self, X_lower_dim: Tensor, add_mean: bool = False) -> Tensor: else: return torch.matmul(X_lower_dim_prime, self.principal_components[:, :k].T) - def compute_reconstruction_error(self, X: Tensor, k: Optional[int], center_data: bool = False) -> float: + def compute_reconstruction_error(self, X: Tensor, k: int | None, center_data: bool = False) -> float: """ Compute the reconstruction error of X under PCA reconstruction. @@ -207,7 +206,7 @@ def compute_reconstruction_error(self, X: Tensor, k: Optional[int], center_data: Args: X (Tensor): Input data tensor whose rows represent data points. - k (Optional[int]): The number of principal components onto which + k (int | None): The number of principal components onto which projection is applied. center_data (bool): Indicates whether to subtract data mean prior to projecting the data into a lower-dimensional subspace, and whether to add @@ -224,7 +223,7 @@ def compute_reconstruction_error(self, X: Tensor, k: Optional[int], center_data: reconstruction = self.project_back(X_lower_dim, add_mean=center_data) return (torch.linalg.norm(reconstruction - X) ** 2).item() / N - def compute_projection_variance(self, X: Tensor, k: Optional[int], center_data: bool = False) -> float: + def compute_projection_variance(self, X: Tensor, k: int | None, center_data: bool = False) -> float: """ Compute the variance of the data matrix X after projection via PCA. @@ -234,7 +233,7 @@ def compute_projection_variance(self, X: Tensor, k: Optional[int], center_data: Args: X (Tensor): input data tensor whose rows represent data points. - k (Optional[int]): the number of principal components onto which + k (int | None): the number of principal components onto which projection is applied. Returns: float: variance after projection as defined above. diff --git a/fl4health/model_bases/perfcl_base.py b/fl4health/model_bases/perfcl_base.py index 6b4be4d48..df6860056 100644 --- a/fl4health/model_bases/perfcl_base.py +++ b/fl4health/model_bases/perfcl_base.py @@ -1,5 +1,3 @@ -from typing import Dict, List, Tuple - import torch import torch.nn as nn @@ -26,12 +24,12 @@ def __init__(self, local_module: nn.Module, global_module: nn.Module, model_head self, first_feature_extractor=local_module, second_feature_extractor=global_module, model_head=model_head ) - def layers_to_exchange(self) -> List[str]: + def layers_to_exchange(self) -> list[str]: return [ layer_name for layer_name in self.state_dict().keys() if layer_name.startswith("second_feature_extractor.") ] - def forward(self, input: torch.Tensor) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: + def forward(self, input: torch.Tensor) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: # input is expected to be of shape (batch_size, *) local_output = self.first_feature_extractor.forward(input) global_output = self.second_feature_extractor.forward(input) diff --git a/fl4health/model_bases/sequential_split_models.py b/fl4health/model_bases/sequential_split_models.py index 175048e01..7a856c95b 100644 --- a/fl4health/model_bases/sequential_split_models.py +++ b/fl4health/model_bases/sequential_split_models.py @@ -1,5 +1,3 @@ -from typing import Dict, List, Tuple - import torch import torch.nn as nn @@ -39,7 +37,7 @@ def _flatten_features(self, features: torch.Tensor) -> torch.Tensor: """ return features.reshape(len(features), -1) - def sequential_forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def sequential_forward(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """ Run a forward pass using the sequentially split modules base_module -> head_module. @@ -47,13 +45,13 @@ def sequential_forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.T input (torch.Tensor): Input to the model forward pass. Expected to be of shape (batch_size, *) Returns: - Tuple[torch.Tensor, torch.Tensor]: Returns the predictions and features tensor from the sequential forward + tuple[torch.Tensor, torch.Tensor]: Returns the predictions and features tensor from the sequential forward """ features = self.base_module.forward(input) predictions = self.head_module.forward(features) return predictions, features - def forward(self, input: torch.Tensor) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: + def forward(self, input: torch.Tensor) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: """ Run a forward pass using the sequentially split modules base_module -> head_module. Features from the base_module are stored either in their original shapes are flattened to be of shape (batch_size, -1) depending @@ -63,7 +61,7 @@ def forward(self, input: torch.Tensor) -> Tuple[Dict[str, torch.Tensor], Dict[st input (torch.Tensor): Input to the model forward pass. Expected to be of shape (batch_size, *) Returns: - Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: Dictionaries of predictions and features + tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: Dictionaries of predictions and features """ predictions, features = self.sequential_forward(input) predictions_dict = {"prediction": predictions} @@ -82,13 +80,13 @@ class SequentiallySplitExchangeBaseModel(SequentiallySplitModel, PartialLayerExc those belonging to the base_module. """ - def layers_to_exchange(self) -> List[str]: + def layers_to_exchange(self) -> list[str]: """ Names of the layers of the model to be exchanged with the server. For these models, we only exchange layers associated with the base_model. Returns: - List[str]: The names of the layers to be exchanged with the server. This is used by the FixedLayerExchanger + list[str]: The names of the layers to be exchanged with the server. This is used by the FixedLayerExchanger class """ return [layer_name for layer_name in self.state_dict().keys() if layer_name.startswith("base_module.")] diff --git a/fl4health/parameter_exchange/fedpm_exchanger.py b/fl4health/parameter_exchange/fedpm_exchanger.py index 577b143cd..6dc704bc6 100644 --- a/fl4health/parameter_exchange/fedpm_exchanger.py +++ b/fl4health/parameter_exchange/fedpm_exchanger.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch import torch.nn as nn from flwr.common.typing import Config, NDArrays @@ -13,7 +11,7 @@ class FedPmExchanger(DynamicLayerExchanger): def __init__(self) -> None: super().__init__(select_scores_and_sample_masks) - def pull_parameters(self, parameters: NDArrays, model: nn.Module, config: Optional[Config] = None) -> None: + def pull_parameters(self, parameters: NDArrays, model: nn.Module, config: Config | None = None) -> None: current_state = model.state_dict() layer_params, layer_names = self.unpack_parameters(parameters) for layer_name, layer_param in zip(layer_names, layer_params): diff --git a/fl4health/parameter_exchange/full_exchanger.py b/fl4health/parameter_exchange/full_exchanger.py index 0dbfb587c..16799028b 100644 --- a/fl4health/parameter_exchange/full_exchanger.py +++ b/fl4health/parameter_exchange/full_exchanger.py @@ -1,5 +1,4 @@ from collections import OrderedDict -from typing import Optional import torch import torch.nn as nn @@ -10,13 +9,13 @@ class FullParameterExchanger(ParameterExchanger): def push_parameters( - self, model: nn.Module, initial_model: Optional[nn.Module] = None, config: Optional[Config] = None + self, model: nn.Module, initial_model: nn.Module | None = None, config: Config | None = None ) -> NDArrays: # Sending all of parameters ordered by state_dict keys # NOTE: Order matters, because it is relied upon by pull_parameters below return [val.cpu().numpy() for _, val in model.state_dict().items()] - def pull_parameters(self, parameters: NDArrays, model: nn.Module, config: Optional[Config] = None) -> None: + def pull_parameters(self, parameters: NDArrays, model: nn.Module, config: Config | None = None) -> None: # Assumes all model parameters are contained in parameters # The state_dict is reconstituted because parameters is simply a list of bytes params_dict = zip(model.state_dict().keys(), parameters) diff --git a/fl4health/parameter_exchange/layer_exchanger.py b/fl4health/parameter_exchange/layer_exchanger.py index 46c3c9598..f7a344632 100644 --- a/fl4health/parameter_exchange/layer_exchanger.py +++ b/fl4health/parameter_exchange/layer_exchanger.py @@ -1,4 +1,5 @@ -from typing import List, Optional, Set, Tuple, Type, TypeVar +from collections.abc import Set +from typing import TypeVar import torch import torch.nn as nn @@ -13,7 +14,7 @@ class FixedLayerExchanger(ParameterExchanger): - def __init__(self, layers_to_transfer: List[str]) -> None: + def __init__(self, layers_to_transfer: list[str]) -> None: self.layers_to_transfer = layers_to_transfer def apply_layer_filter(self, model: nn.Module) -> NDArrays: @@ -22,11 +23,11 @@ def apply_layer_filter(self, model: nn.Module) -> NDArrays: return [model_state_dict[layer_to_transfer].cpu().numpy() for layer_to_transfer in self.layers_to_transfer] def push_parameters( - self, model: nn.Module, initial_model: Optional[nn.Module] = None, config: Optional[Config] = None + self, model: nn.Module, initial_model: nn.Module | None = None, config: Config | None = None ) -> NDArrays: return self.apply_layer_filter(model) - def pull_parameters(self, parameters: NDArrays, model: nn.Module, config: Optional[Config] = None) -> None: + def pull_parameters(self, parameters: NDArrays, model: nn.Module, config: Config | None = None) -> None: current_state = model.state_dict() # update the correct layers to new parameters for layer_name, layer_parameters in zip(self.layers_to_transfer, parameters): @@ -40,7 +41,7 @@ class LayerExchangerWithExclusions(ParameterExchanger): is provided with the model in order to extract the proper layers to be exchanged based on the exclusion criteria """ - def __init__(self, model: nn.Module, module_exclusions: Set[Type[TorchModule]]) -> None: + def __init__(self, model: nn.Module, module_exclusions: Set[type[TorchModule]]) -> None: # module_exclusion is a set of nn.Module types that should NOT be exchanged with the server. # {nn.BatchNorm1d} self.module_exclusions = module_exclusions @@ -55,9 +56,9 @@ def __init__(self, model: nn.Module, module_exclusions: Set[Type[TorchModule]]) } # Needs to be an ordered collection to facilitate exchange consistency between server and client # NOTE: Layers here refers to a collection of parameters in the state dictionary - self.layers_to_transfer: List[str] = self.get_layers_to_transfer(model) + self.layers_to_transfer: list[str] = self.get_layers_to_transfer(model) - def should_module_be_excluded(self, module: Type[TorchModule]) -> bool: + def should_module_be_excluded(self, module: type[TorchModule]) -> bool: return type(module) in self.module_exclusions def should_layer_be_excluded(self, layer_name: str) -> bool: @@ -68,7 +69,7 @@ def should_layer_be_excluded(self, layer_name: str) -> bool: # We filter out any parameters prefixed with the name of an excluded module, as stored in modules_to_filter return any([layer_name.startswith(module_to_filter) for module_to_filter in self.modules_to_filter]) - def get_layers_to_transfer(self, model: nn.Module) -> List[str]: + def get_layers_to_transfer(self, model: nn.Module) -> list[str]: # We store the state dictionary keys that do not correspond to excluded modules as held in modules_to_filter return [name for name in model.state_dict().keys() if not self.should_layer_be_excluded(name)] @@ -80,11 +81,11 @@ def apply_layer_filter(self, model: nn.Module) -> NDArrays: return [model_state_dict[layer_to_transfer].cpu().numpy() for layer_to_transfer in self.layers_to_transfer] def push_parameters( - self, model: nn.Module, initial_model: Optional[nn.Module] = None, config: Optional[Config] = None + self, model: nn.Module, initial_model: nn.Module | None = None, config: Config | None = None ) -> NDArrays: return self.apply_layer_filter(model) - def pull_parameters(self, parameters: NDArrays, model: nn.Module, config: Optional[Config] = None) -> None: + def pull_parameters(self, parameters: NDArrays, model: nn.Module, config: Config | None = None) -> None: current_state = model.state_dict() # update the correct layers to new parameters. Assumes order of parameters is the same as in push_parameters for layer_name, layer_parameters in zip(self.layers_to_transfer, parameters): @@ -92,7 +93,7 @@ def pull_parameters(self, parameters: NDArrays, model: nn.Module, config: Option model.load_state_dict(current_state, strict=True) -class DynamicLayerExchanger(PartialParameterExchanger[List[str]]): +class DynamicLayerExchanger(PartialParameterExchanger[list[str]]): def __init__( self, layer_selection_function: LayerSelectionFunction, @@ -113,17 +114,17 @@ def __init__( self.parameter_packer = ParameterPackerWithLayerNames() def select_parameters( - self, model: nn.Module, initial_model: Optional[nn.Module] = None - ) -> Tuple[NDArrays, List[str]]: + self, model: nn.Module, initial_model: nn.Module | None = None + ) -> tuple[NDArrays, list[str]]: return self.layer_selection_function(model, initial_model) def push_parameters( - self, model: nn.Module, initial_model: Optional[nn.Module] = None, config: Optional[Config] = None + self, model: nn.Module, initial_model: nn.Module | None = None, config: Config | None = None ) -> NDArrays: layers_to_transfer, layer_names = self.select_parameters(model, initial_model) return self.pack_parameters(layers_to_transfer, layer_names) - def pull_parameters(self, parameters: NDArrays, model: nn.Module, config: Optional[Config] = None) -> None: + def pull_parameters(self, parameters: NDArrays, model: nn.Module, config: Config | None = None) -> None: current_state = model.state_dict() # update the correct layers to new parameters layer_params, layer_names = self.unpack_parameters(parameters) diff --git a/fl4health/parameter_exchange/packing_exchanger.py b/fl4health/parameter_exchange/packing_exchanger.py index a0d363de0..9a4c5bad7 100644 --- a/fl4health/parameter_exchange/packing_exchanger.py +++ b/fl4health/parameter_exchange/packing_exchanger.py @@ -1,4 +1,4 @@ -from typing import Generic, Tuple, TypeVar +from typing import Generic, TypeVar from flwr.common.typing import NDArrays @@ -16,5 +16,5 @@ def __init__(self, parameter_packer: ParameterPacker[T]) -> None: def pack_parameters(self, model_weights: NDArrays, additional_parameters: T) -> NDArrays: return self.parameter_packer.pack_parameters(model_weights, additional_parameters) - def unpack_parameters(self, packed_parameters: NDArrays) -> Tuple[NDArrays, T]: + def unpack_parameters(self, packed_parameters: NDArrays) -> tuple[NDArrays, T]: return self.parameter_packer.unpack_parameters(packed_parameters) diff --git a/fl4health/parameter_exchange/parameter_exchanger_base.py b/fl4health/parameter_exchange/parameter_exchanger_base.py index 8a16db996..fa4c22b86 100644 --- a/fl4health/parameter_exchange/parameter_exchanger_base.py +++ b/fl4health/parameter_exchange/parameter_exchanger_base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Optional, TypeVar +from typing import TypeVar import torch.nn as nn from flwr.common.typing import Config, NDArrays @@ -8,12 +8,12 @@ class ParameterExchanger(ABC): @abstractmethod def push_parameters( - self, model: nn.Module, initial_model: Optional[nn.Module] = None, config: Optional[Config] = None + self, model: nn.Module, initial_model: nn.Module | None = None, config: Config | None = None ) -> NDArrays: raise NotImplementedError @abstractmethod - def pull_parameters(self, parameters: NDArrays, model: nn.Module, config: Optional[Config] = None) -> None: + def pull_parameters(self, parameters: NDArrays, model: nn.Module, config: Config | None = None) -> None: raise NotImplementedError diff --git a/fl4health/parameter_exchange/parameter_packer.py b/fl4health/parameter_exchange/parameter_packer.py index 14692920c..65b6906a3 100644 --- a/fl4health/parameter_exchange/parameter_packer.py +++ b/fl4health/parameter_exchange/parameter_packer.py @@ -1,9 +1,9 @@ from abc import ABC, abstractmethod -from typing import Generic, Tuple, TypeVar +from typing import Generic, TypeVar import numpy as np import torch -from flwr.common.typing import List, NDArray, NDArrays +from flwr.common.typing import NDArray, NDArrays from torch import Tensor T = TypeVar("T") @@ -15,7 +15,7 @@ def pack_parameters(self, model_weights: NDArrays, additional_parameters: T) -> raise NotImplementedError @abstractmethod - def unpack_parameters(self, packed_parameters: NDArrays) -> Tuple[NDArrays, T]: + def unpack_parameters(self, packed_parameters: NDArrays) -> tuple[NDArrays, T]: raise NotImplementedError @@ -29,7 +29,7 @@ def __init__(self, size_of_model_params: int) -> None: def pack_parameters(self, model_weights: NDArrays, additional_parameters: NDArrays) -> NDArrays: return model_weights + additional_parameters - def unpack_parameters(self, packed_parameters: NDArrays) -> Tuple[NDArrays, NDArrays]: + def unpack_parameters(self, packed_parameters: NDArrays) -> tuple[NDArrays, NDArrays]: return packed_parameters[: self.size_of_model_params], packed_parameters[self.size_of_model_params :] @@ -37,7 +37,7 @@ class ParameterPackerWithClippingBit(ParameterPacker[float]): def pack_parameters(self, model_weights: NDArrays, additional_parameters: float) -> NDArrays: return model_weights + [np.array(additional_parameters)] - def unpack_parameters(self, packed_parameters: NDArrays) -> Tuple[NDArrays, float]: + def unpack_parameters(self, packed_parameters: NDArrays) -> tuple[NDArrays, float]: # The last entry in the parameters list is assumed to be a clipping bound (even if we're evaluating) split_size = len(packed_parameters) - 1 model_parameters = packed_parameters[:split_size] @@ -49,7 +49,7 @@ class ParameterPackerAdaptiveConstraint(ParameterPacker[float]): def pack_parameters(self, model_weights: NDArrays, extra_adaptive_variable: float) -> NDArrays: return model_weights + [np.array(extra_adaptive_variable)] - def unpack_parameters(self, packed_parameters: NDArrays) -> Tuple[NDArrays, float]: + def unpack_parameters(self, packed_parameters: NDArrays) -> tuple[NDArrays, float]: # The last entry is an extra packed adaptive constraint variable (information to allow for adaptation) split_size = len(packed_parameters) - 1 model_parameters = packed_parameters[:split_size] @@ -60,11 +60,11 @@ def unpack_parameters(self, packed_parameters: NDArrays) -> Tuple[NDArrays, floa return model_parameters, extra_adaptive_variable -class ParameterPackerWithLayerNames(ParameterPacker[List[str]]): - def pack_parameters(self, model_weights: NDArrays, weights_names: List[str]) -> NDArrays: +class ParameterPackerWithLayerNames(ParameterPacker[list[str]]): + def pack_parameters(self, model_weights: NDArrays, weights_names: list[str]) -> NDArrays: return model_weights + [np.array(weights_names)] - def unpack_parameters(self, packed_parameters: NDArrays) -> Tuple[NDArrays, List[str]]: + def unpack_parameters(self, packed_parameters: NDArrays) -> tuple[NDArrays, list[str]]: """ Assumption: packed_parameters is a list containing model parameters followed by an NDArray that contains the corresponding names of those parameters. @@ -75,7 +75,7 @@ def unpack_parameters(self, packed_parameters: NDArrays) -> Tuple[NDArrays, List return model_parameters, param_names -class SparseCooParameterPacker(ParameterPacker[Tuple[NDArrays, NDArrays, List[str]]]): +class SparseCooParameterPacker(ParameterPacker[tuple[NDArrays, NDArrays, list[str]]]): """ This parameter packer is responsible for selecting an arbitrary set of parameters and then representing them in the sparse COO tensor format, which requires knowing @@ -90,12 +90,12 @@ class SparseCooParameterPacker(ParameterPacker[Tuple[NDArrays, NDArrays, List[st """ def pack_parameters( - self, model_parameters: NDArrays, additional_parameters: Tuple[NDArrays, NDArrays, List[str]] + self, model_parameters: NDArrays, additional_parameters: tuple[NDArrays, NDArrays, list[str]] ) -> NDArrays: parameter_indices, tensor_shapes, tensor_names = additional_parameters return model_parameters + parameter_indices + tensor_shapes + [np.array(tensor_names)] - def unpack_parameters(self, packed_parameters: NDArrays) -> Tuple[NDArrays, Tuple[NDArrays, NDArrays, List[str]]]: + def unpack_parameters(self, packed_parameters: NDArrays) -> tuple[NDArrays, tuple[NDArrays, NDArrays, list[str]]]: # The names of the tensors is wrapped in a list, which is then transformed into an NDArrays of length 1 # before packing. assert len(packed_parameters) % 3 == 1 @@ -107,7 +107,7 @@ def unpack_parameters(self, packed_parameters: NDArrays) -> Tuple[NDArrays, Tupl return model_parameters, (parameter_indices, tensor_shapes, tensor_names) @staticmethod - def extract_coo_info_from_dense(x: Tensor) -> Tuple[NDArray, NDArray, NDArray]: + def extract_coo_info_from_dense(x: Tensor) -> tuple[NDArray, NDArray, NDArray]: """ Take a dense tensor x and extract the information required (namely, its nonzero values, their indices within the tensor, and the shape of x) @@ -119,7 +119,7 @@ def extract_coo_info_from_dense(x: Tensor) -> Tuple[NDArray, NDArray, NDArray]: x (Tensor): Input dense tensor. Returns: - Tuple[NDArray, NDArray, NDArray]: The nonzero values of x, + tuple[NDArray, NDArray, NDArray]: The nonzero values of x, the indices of those values within x, and the shape of x. """ selected_parameters = x[torch.nonzero(x, as_tuple=True)].cpu().numpy() diff --git a/fl4health/parameter_exchange/parameter_selection_criteria.py b/fl4health/parameter_exchange/parameter_selection_criteria.py index 4a3ff4551..b8f4c551c 100644 --- a/fl4health/parameter_exchange/parameter_selection_criteria.py +++ b/fl4health/parameter_exchange/parameter_selection_criteria.py @@ -1,6 +1,5 @@ import math from functools import partial -from typing import Dict, List, Optional, Tuple import torch import torch.nn as nn @@ -69,7 +68,7 @@ def select_layers_by_threshold( select_drift_more: bool, model: nn.Module, initial_model: nn.Module, -) -> Tuple[NDArrays, List[str]]: +) -> tuple[NDArrays, list[str]]: """ Return those layers of model that deviate (in l2 norm) away from corresponding layers of self.initial_model by at least (or at most) self.threshold. @@ -98,7 +97,7 @@ def select_layers_by_percentage( select_drift_more: bool, model: nn.Module, initial_model: nn.Module, -) -> Tuple[NDArrays, List[str]]: +) -> tuple[NDArrays, list[str]]: names_to_norm_drift = {} initial_model_states = initial_model.state_dict() model_states = model.state_dict() @@ -119,21 +118,21 @@ def select_layers_by_percentage( # Score generating functions used for selecting arbitrary sets of weights. # The ones implemented here are those that demonstrated good performance in the super-mask paper. # Link to this paper: https://arxiv.org/abs/1905.01067 -def largest_final_magnitude_scores(model: nn.Module, initial_model: Optional[nn.Module]) -> Dict[str, Tensor]: +def largest_final_magnitude_scores(model: nn.Module, initial_model: nn.Module | None) -> dict[str, Tensor]: names_to_scores = {} for tensor_name, tensor_values in model.state_dict().items(): names_to_scores[tensor_name] = torch.abs(tensor_values) return names_to_scores -def smallest_final_magnitude_scores(model: nn.Module, initial_model: Optional[nn.Module]) -> Dict[str, Tensor]: +def smallest_final_magnitude_scores(model: nn.Module, initial_model: nn.Module | None) -> dict[str, Tensor]: names_to_scores = {} for tensor_name, tensor_values in model.state_dict().items(): names_to_scores[tensor_name] = (-1) * torch.abs(tensor_values) return names_to_scores -def largest_magnitude_change_scores(model: nn.Module, initial_model: Optional[nn.Module]) -> Dict[str, Tensor]: +def largest_magnitude_change_scores(model: nn.Module, initial_model: nn.Module | None) -> dict[str, Tensor]: assert initial_model is not None names_to_scores = {} current_model_states = model.state_dict() @@ -144,7 +143,7 @@ def largest_magnitude_change_scores(model: nn.Module, initial_model: Optional[nn return names_to_scores -def smallest_magnitude_change_scores(model: nn.Module, initial_model: Optional[nn.Module]) -> Dict[str, Tensor]: +def smallest_magnitude_change_scores(model: nn.Module, initial_model: nn.Module | None) -> dict[str, Tensor]: assert initial_model is not None names_to_scores = {} current_model_states = model.state_dict() @@ -155,7 +154,7 @@ def smallest_magnitude_change_scores(model: nn.Module, initial_model: Optional[n return names_to_scores -def largest_increase_in_magnitude_scores(model: nn.Module, initial_model: Optional[nn.Module]) -> Dict[str, Tensor]: +def largest_increase_in_magnitude_scores(model: nn.Module, initial_model: nn.Module | None) -> dict[str, Tensor]: assert initial_model is not None names_to_scores = {} current_model_states = model.state_dict() @@ -166,7 +165,7 @@ def largest_increase_in_magnitude_scores(model: nn.Module, initial_model: Option return names_to_scores -def smallest_increase_in_magnitude_scores(model: nn.Module, initial_model: Optional[nn.Module]) -> Dict[str, Tensor]: +def smallest_increase_in_magnitude_scores(model: nn.Module, initial_model: nn.Module | None) -> dict[str, Tensor]: assert initial_model is not None names_to_scores = {} current_model_states = model.state_dict() @@ -186,8 +185,8 @@ def _sample_masks(score_tensor: Tensor) -> NDArray: def _process_masked_module( - module: nn.Module, model_state_dict: Dict[str, Tensor], module_name: Optional[str] = None -) -> Tuple[NDArrays, List[str]]: + module: nn.Module, model_state_dict: dict[str, Tensor], module_name: str | None = None +) -> tuple[NDArrays, list[str]]: """ Perform Bernoulli sampling using the weight and bias scores of a masked module. @@ -196,8 +195,8 @@ def _process_masked_module( "module" can either be a submodule of the model trained in FedPM, or it can a standalone module itself. In the latter case, the argument "model_state_dict" should be the same as "module.state_dict()". In either case, it is assumed that module is a masked module. - model_state_dict (Dict[str, Tensor]): the state dictionary of the model trained in FedPM. - module_name (Optional[str]): the name of module if module is a submodule of the model trained in FedPM. + model_state_dict (dict[str, Tensor]): the state dictionary of the model trained in FedPM. + module_name (str | None): the name of module if module is a submodule of the model trained in FedPM. This is used to access the weight and bias score tensors in model_state_dict. Defaults to None. """ masks_to_exchange = [] @@ -222,7 +221,7 @@ def _process_masked_module( return masks_to_exchange, score_tensor_names -def select_scores_and_sample_masks(model: nn.Module, initial_model: Optional[nn.Module]) -> Tuple[NDArrays, List[str]]: +def select_scores_and_sample_masks(model: nn.Module, initial_model: nn.Module | None) -> tuple[NDArrays, list[str]]: """ Selection function that first selects the "weight_scores" and "bias_scores" parameters for the masked layers, and then samples binary masks based on those scores to send to the server. diff --git a/fl4health/parameter_exchange/partial_parameter_exchanger.py b/fl4health/parameter_exchange/partial_parameter_exchanger.py index 94a598420..82792c8be 100644 --- a/fl4health/parameter_exchange/partial_parameter_exchanger.py +++ b/fl4health/parameter_exchange/partial_parameter_exchanger.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Generic, Optional, Tuple, TypeVar +from typing import Generic, TypeVar import torch.nn as nn from flwr.common.typing import NDArrays @@ -18,13 +18,13 @@ def __init__(self, parameter_packer: ParameterPacker[T]) -> None: def pack_parameters(self, model_weights: NDArrays, additional_parameters: T) -> NDArrays: return self.parameter_packer.pack_parameters(model_weights, additional_parameters) - def unpack_parameters(self, packed_parameters: NDArrays) -> Tuple[NDArrays, T]: + def unpack_parameters(self, packed_parameters: NDArrays) -> tuple[NDArrays, T]: return self.parameter_packer.unpack_parameters(packed_parameters) @abstractmethod def select_parameters( self, model: nn.Module, - initial_model: Optional[nn.Module] = None, - ) -> Tuple[NDArrays, T]: + initial_model: nn.Module | None = None, + ) -> tuple[NDArrays, T]: raise NotImplementedError diff --git a/fl4health/parameter_exchange/sparse_coo_parameter_exchanger.py b/fl4health/parameter_exchange/sparse_coo_parameter_exchanger.py index a3b27afb1..0506c23c3 100644 --- a/fl4health/parameter_exchange/sparse_coo_parameter_exchanger.py +++ b/fl4health/parameter_exchange/sparse_coo_parameter_exchanger.py @@ -1,6 +1,6 @@ import math +from collections.abc import Callable from logging import INFO, WARNING -from typing import Callable, Dict, List, Optional, Tuple import torch import torch.nn as nn @@ -12,10 +12,10 @@ from fl4health.parameter_exchange.parameter_packer import SparseCooParameterPacker from fl4health.parameter_exchange.partial_parameter_exchanger import PartialParameterExchanger -ScoreGenFunction = Callable[[nn.Module, Optional[nn.Module]], Dict[str, Tensor]] +ScoreGenFunction = Callable[[nn.Module, nn.Module | None], dict[str, Tensor]] -class SparseCooParameterExchanger(PartialParameterExchanger[Tuple[NDArrays, NDArrays, List[str]]]): +class SparseCooParameterExchanger(PartialParameterExchanger[tuple[NDArrays, NDArrays, list[str]]]): def __init__(self, sparsity_level: float, score_gen_function: ScoreGenFunction) -> None: """ Parameter exchanger for sparse tensors. @@ -42,7 +42,7 @@ def __init__(self, sparsity_level: float, score_gen_function: ScoreGenFunction) self.parameter_packer: SparseCooParameterPacker = SparseCooParameterPacker() self.score_gen_function = score_gen_function - def generate_parameter_scores(self, model: nn.Module, initial_model: Optional[nn.Module]) -> Dict[str, Tensor]: + def generate_parameter_scores(self, model: nn.Module, initial_model: nn.Module | None) -> dict[str, Tensor]: """Calling the score generating function to produce parameter scores.""" return self.score_gen_function(model, initial_model) @@ -56,8 +56,8 @@ def _check_unique_score(self, param_scores: Tensor) -> None: ) def select_parameters( - self, model: nn.Module, initial_model: Optional[nn.Module] = None - ) -> Tuple[NDArrays, Tuple[NDArrays, NDArrays, List[str]]]: + self, model: nn.Module, initial_model: nn.Module | None = None + ) -> tuple[NDArrays, tuple[NDArrays, NDArrays, list[str]]]: """ Select model parameters according to the sparsity level and pack them into the sparse COO format to be exchanged. @@ -82,7 +82,7 @@ def select_parameters( initial_model (nn.Module): Initial model. Returns: - Tuple[NDArrays, Tuple[NDArrays, NDArrays, List[str]]]: the selected parameters + tuple[NDArrays, tuple[NDArrays, NDArrays, list[str]]]: the selected parameters and other information, as detailed above. """ all_parameter_scores = self.generate_parameter_scores(model, initial_model) @@ -126,7 +126,7 @@ def select_parameters( return (selected_parameters_all_tensors, (selected_indices_all_tensors, tensor_shapes, tensor_names)) def push_parameters( - self, model: nn.Module, initial_model: Optional[nn.Module] = None, config: Optional[Config] = None + self, model: nn.Module, initial_model: nn.Module | None = None, config: Config | None = None ) -> NDArrays: selected_parameters, additional_parameters = self.select_parameters(model, initial_model) return self.pack_parameters( @@ -134,7 +134,7 @@ def push_parameters( additional_parameters=additional_parameters, ) - def pull_parameters(self, parameters: NDArrays, model: Module, config: Optional[Config] = None) -> None: + def pull_parameters(self, parameters: NDArrays, model: Module, config: Config | None = None) -> None: selected_parameters, additional_info = self.parameter_packer.unpack_parameters(parameters) indices, shapes, names = additional_info current_state = model.state_dict() diff --git a/fl4health/preprocessing/autoencoders/loss.py b/fl4health/preprocessing/autoencoders/loss.py index f054fe406..e74c0a929 100644 --- a/fl4health/preprocessing/autoencoders/loss.py +++ b/fl4health/preprocessing/autoencoders/loss.py @@ -1,5 +1,3 @@ -from typing import Tuple - import torch from torch.nn.modules.loss import _Loss @@ -39,14 +37,14 @@ def standard_normal_kl_divergence_loss(self, mu: torch.Tensor, logvar: torch.Ten kl_divergence_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) return kl_divergence_loss - def unpack_model_output(self, preds: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def unpack_model_output(self, preds: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Unpacks the model output tensor. Args: preds (torch.Tensor): Model predictions. Returns: - Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Unpacked output containing predictions, mu, and logvar. + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Unpacked output containing predictions, mu, and logvar. """ # This methods assumes "preds" are batch first, and preds are 2D dimensional (already flattened). assert ( diff --git a/fl4health/preprocessing/warmed_up_module.py b/fl4health/preprocessing/warmed_up_module.py index 7c366aa1d..4b33254da 100644 --- a/fl4health/preprocessing/warmed_up_module.py +++ b/fl4health/preprocessing/warmed_up_module.py @@ -2,7 +2,6 @@ import os from logging import INFO, WARNING from pathlib import Path -from typing import Optional import torch from flwr.common.logger import log @@ -13,18 +12,18 @@ class WarmedUpModule: def __init__( self, - pretrained_model: Optional[torch.nn.Module] = None, - pretrained_model_path: Optional[Path] = None, - weights_mapping_path: Optional[Path] = None, + pretrained_model: torch.nn.Module | None = None, + pretrained_model_path: Path | None = None, + weights_mapping_path: Path | None = None, ) -> None: """Initialize the WarmedUpModule with the pretrained model states and weights mapping dict. Args: - pretrained_model (Optional[torch.nn.Module]): Pretrained model. + pretrained_model (torch.nn.Module | None): Pretrained model. This is mutually exclusive with pretrained_model_path. - pretrained_model_path (Optional[Path]): Path of the pretrained model. + pretrained_model_path (Path | None): Path of the pretrained model. This is mutually exclusive with pretrained_model. - weights_mapping_dir (Optional[str], optional): Path of to json file of the weights mapping dict. + weights_mapping_dir (str | None, optional): Path of to json file of the weights mapping dict. If models are not exactly the same, a weights mapping dict is needed to map the weights of the pretrained model to the target model. """ @@ -54,7 +53,7 @@ def __init__( log(INFO, "Weights mapping dict is not provided. Matching states directly, based on target model's keys.") self.weights_mapping_dict = None - def get_matching_component(self, key: str) -> Optional[str]: + def get_matching_component(self, key: str) -> str | None: """Get the matching component of the key from the weights mapping dictionary. Since the provided mapping can contain partial names of the keys, this function is used to split the key of the target model and match it with the partial key in the mapping, returning the complete name of the key in the pretrained model. @@ -67,7 +66,7 @@ def get_matching_component(self, key: str) -> Optional[str]: key (str): Key to be matched in pretrained model. Returns: - Optional[str]: If no weights mapping dict is provided, returns the key. Otherwise, if the key is in the + str | None: If no weights mapping dict is provided, returns the key. Otherwise, if the key is in the weights mapping dict, returns the matching component of the key. Otherwise, returns None. """ diff --git a/fl4health/privacy/fl_accountants.py b/fl4health/privacy/fl_accountants.py index 16994322d..0117e8fb3 100644 --- a/fl4health/privacy/fl_accountants.py +++ b/fl4health/privacy/fl_accountants.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod from math import ceil -from typing import List, Optional, Union from fl4health.privacy.moments_accountant import ( FixedSamplingWithoutReplacement, @@ -22,9 +21,9 @@ def __init__( client_sampling_rate: float, noise_multiplier: float, epochs_per_round: int, - client_batch_sizes: List[int], - client_dataset_sizes: List[int], - moment_orders: Optional[List[float]] = None, + client_batch_sizes: list[int], + client_dataset_sizes: list[int], + moment_orders: list[float] | None = None, ) -> None: """ client_sampling_rate: probability that each client will be included in a round @@ -47,10 +46,10 @@ def __init__( self.accountant = MomentsAccountant(moment_orders) - def _calculate_batch_ratios(self, client_batch_sizes: List[int], client_dataset_sizes: List[int]) -> List[float]: + def _calculate_batch_ratios(self, client_batch_sizes: list[int], client_dataset_sizes: list[int]) -> list[float]: return [batch / dataset for batch, dataset in zip(client_batch_sizes, client_dataset_sizes)] - def _calculate_num_batches(self, client_batch_sizes: List[int], client_dataset_sizes: List[int]) -> List[int]: + def _calculate_num_batches(self, client_batch_sizes: list[int], client_dataset_sizes: list[int]) -> list[int]: return [ceil(dataset / batch) for batch, dataset in zip(client_batch_sizes, client_dataset_sizes)] def get_epsilon(self, server_updates: int, delta: float) -> float: @@ -75,21 +74,19 @@ def get_delta(self, server_updates: int, epsilon: float) -> float: class ClientLevelAccountant(ABC): - def __init__( - self, noise_multiplier: Union[float, List[float]], moment_orders: Optional[List[float]] = None - ) -> None: + def __init__(self, noise_multiplier: float | list[float], moment_orders: list[float] | None = None) -> None: self.noise_multiplier = noise_multiplier self.accountant = MomentsAccountant(moment_orders) @abstractmethod - def get_epsilon(self, server_updates: Union[int, List[int]], delta: float) -> float: + def get_epsilon(self, server_updates: int | list[int], delta: float) -> float: pass @abstractmethod - def get_delta(self, server_updates: Union[int, List[int]], epsilon: float) -> float: + def get_delta(self, server_updates: int | list[int], epsilon: float) -> float: pass - def _validate_server_updates(self, server_updates: Union[int, List[int]]) -> None: + def _validate_server_updates(self, server_updates: int | list[int]) -> None: if isinstance(server_updates, list): assert isinstance(self.noise_multiplier, list) assert len(server_updates) == len(self.noise_multiplier) @@ -104,9 +101,9 @@ class FlClientLevelAccountantPoissonSampling(ClientLevelAccountant): def __init__( self, - client_sampling_rate: Union[float, List[float]], - noise_multiplier: Union[float, List[float]], - moment_orders: Optional[List[float]] = None, + client_sampling_rate: float | list[float], + noise_multiplier: float | list[float], + moment_orders: list[float] | None = None, ) -> None: """ client_sampling_rate: probability that each client will be included in a round @@ -115,19 +112,19 @@ def __init__( parameters """ super().__init__(noise_multiplier, moment_orders) - self.sampling_strategy: Union[SamplingStrategy, List[PoissonSampling]] + self.sampling_strategy: SamplingStrategy | list[PoissonSampling] if isinstance(client_sampling_rate, list): self.sampling_strategy = [PoissonSampling(q) for q in client_sampling_rate] else: self.sampling_strategy = PoissonSampling(client_sampling_rate) - def get_epsilon(self, server_updates: Union[int, List[int]], delta: float) -> float: + def get_epsilon(self, server_updates: int | list[int], delta: float) -> float: """server_updates: number of central server updates performed""" self._validate_server_updates(server_updates) return self.accountant.get_epsilon(self.sampling_strategy, self.noise_multiplier, server_updates, delta) - def get_delta(self, server_updates: Union[int, List[int]], epsilon: float) -> float: + def get_delta(self, server_updates: int | list[int], epsilon: float) -> float: """server_updates: number of central server updates performed""" self._validate_server_updates(server_updates) return self.accountant.get_delta(self.sampling_strategy, self.noise_multiplier, server_updates, epsilon) @@ -142,9 +139,9 @@ class FlClientLevelAccountantFixedSamplingNoReplacement(ClientLevelAccountant): def __init__( self, n_total_clients: int, - n_clients_sampled: Union[int, List[int]], - noise_multiplier: Union[float, List[float]], - moment_orders: Optional[List[float]] = None, + n_clients_sampled: int | list[int], + noise_multiplier: float | list[float], + moment_orders: list[float] | None = None, ) -> None: """ n_total_clients: total number of clients to be sampled from @@ -154,7 +151,7 @@ def __init__( parameters """ super().__init__(noise_multiplier, moment_orders) - self.sampling_strategy: Union[SamplingStrategy, List[FixedSamplingWithoutReplacement]] + self.sampling_strategy: SamplingStrategy | list[FixedSamplingWithoutReplacement] if isinstance(n_clients_sampled, list): self.sampling_strategy = [ @@ -163,12 +160,12 @@ def __init__( else: self.sampling_strategy = FixedSamplingWithoutReplacement(n_total_clients, n_clients_sampled) - def get_epsilon(self, server_updates: Union[int, List[int]], delta: float) -> float: + def get_epsilon(self, server_updates: int | list[int], delta: float) -> float: """server_updates: number of central server updates performed""" self._validate_server_updates(server_updates) return self.accountant.get_epsilon(self.sampling_strategy, self.noise_multiplier, server_updates, delta) - def get_delta(self, server_updates: Union[int, List[int]], epsilon: float) -> float: + def get_delta(self, server_updates: int | list[int], epsilon: float) -> float: """server_updates: number of central server updates performed""" self._validate_server_updates(server_updates) return self.accountant.get_delta(self.sampling_strategy, self.noise_multiplier, server_updates, epsilon) diff --git a/fl4health/privacy/moments_accountant.py b/fl4health/privacy/moments_accountant.py index 86a03f7e9..393afcff2 100644 --- a/fl4health/privacy/moments_accountant.py +++ b/fl4health/privacy/moments_accountant.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import List, Optional, Sequence, Union +from collections.abc import Sequence from dp_accounting import ( DpEvent, @@ -41,7 +41,7 @@ def get_dp_event(self, noise_event: DpEvent) -> DpEvent: class MomentsAccountant: - def __init__(self, moment_orders: Optional[List[float]] = None) -> None: + def __init__(self, moment_orders: list[float] | None = None) -> None: """Moment orders are equivalent to lambda from Deep Learning with Differential Privacy (Abadi et. al. 2016). They form the set of moments to estimate the infimum of Theorem 2 part 2. The default values were taken from the tensorflow federated DP tutorial notebook: @@ -57,7 +57,7 @@ def __init__(self, moment_orders: Optional[List[float]] = None) -> None: self.moment_orders = moment_orders else: low_orders = [1.25, 1.5, 1.75, 2.0, 2.25, 2.5, 3.0, 3.5, 4.0, 4.5] - medium_orders: List[float] = list(range(5, 64)) + medium_orders: list[float] = list(range(5, 64)) high_orders = [128.0, 256.0, 512.0] self.moment_orders = low_orders + medium_orders + high_orders @@ -74,8 +74,8 @@ def _construct_dp_events( def _construct_dp_events_trajectory( self, sampling_strategies: Sequence[SamplingStrategy], - noise_multipliers: List[float], - updates_list: List[int], + noise_multipliers: list[float], + updates_list: list[int], ) -> DpEvent: # Given a list of parameters this assumes that the DP operations were performed in sequence event_builder = DpEventBuilder() @@ -85,9 +85,9 @@ def _construct_dp_events_trajectory( def _construct_rdp_accountant( self, - sampling_strategies: Union[SamplingStrategy, Sequence[SamplingStrategy]], - noise_multipliers: Union[float, List[float]], - updates: Union[int, List[int]], + sampling_strategies: SamplingStrategy | Sequence[SamplingStrategy], + noise_multipliers: float | list[float], + updates: int | list[int], ) -> RdpAccountant: if isinstance(sampling_strategies, SamplingStrategy): sampling_strategies = [sampling_strategies] @@ -106,9 +106,9 @@ def _construct_rdp_accountant( def _validate_accountant_input( self, - sampling_strategies: Union[SamplingStrategy, Sequence[SamplingStrategy]], - noise_multiplier: Union[float, List[float]], - updates: Union[int, List[int]], + sampling_strategies: SamplingStrategy | Sequence[SamplingStrategy], + noise_multiplier: float | list[float], + updates: int | list[int], ) -> None: all_lists = all( [ @@ -128,9 +128,9 @@ def _validate_accountant_input( def get_epsilon( self, - sampling_strategies: Union[SamplingStrategy, Sequence[SamplingStrategy]], - noise_multiplier: Union[float, List[float]], - updates: Union[int, List[int]], + sampling_strategies: SamplingStrategy | Sequence[SamplingStrategy], + noise_multiplier: float | list[float], + updates: int | list[int], delta: float, ) -> float: """ @@ -164,9 +164,9 @@ def get_epsilon( def get_delta( self, - sampling_strategies: Union[SamplingStrategy, Sequence[SamplingStrategy]], - noise_multiplier: Union[float, List[float]], - updates: Union[int, List[int]], + sampling_strategies: SamplingStrategy | Sequence[SamplingStrategy], + noise_multiplier: float | list[float], + updates: int | list[int], epsilon: float, ) -> float: """ diff --git a/fl4health/reporting/base_reporter.py b/fl4health/reporting/base_reporter.py index 040021130..016e67738 100644 --- a/fl4health/reporting/base_reporter.py +++ b/fl4health/reporting/base_reporter.py @@ -1,6 +1,6 @@ """Base Class for Reporters. -Super simple for now but keeping it in a seperate file in case we add more base methods. +Super simple for now but keeping it in a separate file in case we add more base methods. """ from typing import Any diff --git a/fl4health/servers/adaptive_constraint_servers/ditto_server.py b/fl4health/servers/adaptive_constraint_servers/ditto_server.py index 27a03bf20..759f86c5d 100644 --- a/fl4health/servers/adaptive_constraint_servers/ditto_server.py +++ b/fl4health/servers/adaptive_constraint_servers/ditto_server.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, Sequence +from collections.abc import Callable, Sequence from flwr.common.typing import Config, Scalar from flwr.server.client_manager import ClientManager @@ -17,7 +17,7 @@ def __init__( strategy: FedAvgWithAdaptiveConstraint, reporters: Sequence[BaseReporter] | None = None, checkpoint_and_state_module: AdaptiveConstraintServerCheckpointAndStateModule | None = None, - on_init_parameters_config_fn: Callable[[int], Dict[str, Scalar]] | None = None, + on_init_parameters_config_fn: Callable[[int], dict[str, Scalar]] | None = None, server_name: str | None = None, accept_failures: bool = True, ) -> None: @@ -44,7 +44,7 @@ def __init__( module is provided, no checkpointing or state preservation will happen. Defaults to None. NOTE: For Ditto, the model shared with the server is the GLOBAL MODEL, which isn't the target of FL training for this algorithm. However, one may still want to save this model for other purposes. - on_init_parameters_config_fn (Callable[[int], Dict[str, Scalar]] | None, optional): Function used to + on_init_parameters_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to configure how one asks a client to provide parameters from which to initialize all other clients by providing a Config dictionary. If this is none, then a blank config is sent with the parameter request (which is default behavior for flower servers). Defaults to None. diff --git a/fl4health/servers/adaptive_constraint_servers/fedprox_server.py b/fl4health/servers/adaptive_constraint_servers/fedprox_server.py index 9acfc4b55..7001b80c2 100644 --- a/fl4health/servers/adaptive_constraint_servers/fedprox_server.py +++ b/fl4health/servers/adaptive_constraint_servers/fedprox_server.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, Sequence +from collections.abc import Callable, Sequence from flwr.common.typing import Config, Scalar from flwr.server.client_manager import ClientManager @@ -17,7 +17,7 @@ def __init__( strategy: FedAvgWithAdaptiveConstraint, reporters: Sequence[BaseReporter] | None = None, checkpoint_and_state_module: AdaptiveConstraintServerCheckpointAndStateModule | None = None, - on_init_parameters_config_fn: Callable[[int], Dict[str, Scalar]] | None = None, + on_init_parameters_config_fn: Callable[[int], dict[str, Scalar]] | None = None, server_name: str | None = None, accept_failures: bool = True, ) -> None: @@ -43,7 +43,7 @@ def __init__( saving model artifacts to be used or evaluated after training. The latter is used to preserve training state (including models) such that if FL training is interrupted, the process may be restarted. If no module is provided, no checkpointing or state preservation will happen. Defaults to None. - on_init_parameters_config_fn (Callable[[int], Dict[str, Scalar]] | None, optional): Function used to + on_init_parameters_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to configure how one asks a client to provide parameters from which to initialize all other clients by providing a Config dictionary. If this is none, then a blank config is sent with the parameter request (which is default behavior for flower servers). Defaults to None. diff --git a/fl4health/servers/adaptive_constraint_servers/mrmtl_server.py b/fl4health/servers/adaptive_constraint_servers/mrmtl_server.py index 2f4aba07b..ed9d56d1f 100644 --- a/fl4health/servers/adaptive_constraint_servers/mrmtl_server.py +++ b/fl4health/servers/adaptive_constraint_servers/mrmtl_server.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, Sequence +from collections.abc import Callable, Sequence from flwr.common.typing import Config, Scalar from flwr.server.client_manager import ClientManager @@ -17,7 +17,7 @@ def __init__( strategy: FedAvgWithAdaptiveConstraint, reporters: Sequence[BaseReporter] | None = None, checkpoint_and_state_module: AdaptiveConstraintServerCheckpointAndStateModule | None = None, - on_init_parameters_config_fn: Callable[[int], Dict[str, Scalar]] | None = None, + on_init_parameters_config_fn: Callable[[int], dict[str, Scalar]] | None = None, server_name: str | None = None, accept_failures: bool = True, ) -> None: @@ -44,7 +44,7 @@ def __init__( module is provided, no checkpointing or state preservation will happen. Defaults to None. NOTE: For MR-MTL, the server model is an aggregation of the personal models, which isn't the target of FL training for this algorithm. However, one may still want to save this model for other purposes. - on_init_parameters_config_fn (Callable[[int], Dict[str, Scalar]] | None, optional): Function used to + on_init_parameters_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to configure how one asks a client to provide parameters from which to initialize all other clients by providing a Config dictionary. If this is none, then a blank config is sent with the parameter request (which is default behavior for flower servers). Defaults to None. diff --git a/fl4health/servers/base_server.py b/fl4health/servers/base_server.py index 53b929648..237a106c6 100644 --- a/fl4health/servers/base_server.py +++ b/fl4health/servers/base_server.py @@ -1,6 +1,6 @@ import datetime +from collections.abc import Callable, Sequence from logging import DEBUG, ERROR, INFO, WARNING -from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union import torch.nn as nn from flwr.common import EvaluateRes, Parameters @@ -29,10 +29,10 @@ def __init__( self, client_manager: ClientManager, fl_config: Config, - strategy: Optional[Strategy] = None, + strategy: Strategy | None = None, reporters: Sequence[BaseReporter] | None = None, checkpoint_and_state_module: BaseServerCheckpointAndStateModule | None = None, - on_init_parameters_config_fn: Callable[[int], Dict[str, Scalar]] | None = None, + on_init_parameters_config_fn: Callable[[int], dict[str, Scalar]] | None = None, server_name: str | None = None, accept_failures: bool = True, ) -> None: @@ -46,7 +46,7 @@ def __init__( In most cases it should be the "source of truth" for how FL training/evaluation should proceed. For example, the config used to produce the on_fit_config_fn and on_evaluate_config_fn for the strategy. NOTE: This config is DISTINCT from the Flwr server config, which is extremely minimal. - strategy (Optional[Strategy], optional): The aggregation strategy to be used by the server to handle. + strategy (Strategy | None, optional): The aggregation strategy to be used by the server to handle. client updates and other information potentially sent by the participating clients. If None the strategy is FedAvg as set by the flwr Server. Defaults to None. reporters (Sequence[BaseReporter] | None, optional): sequence of FL4Health reporters which the server @@ -56,7 +56,7 @@ def __init__( artifacts to be used or evaluated after training. The latter is used to preserve training state (including models) such that if FL training is interrupted, the process may be restarted. If no module is provided, no checkpointing or state preservation will happen. Defaults to None. - on_init_parameters_config_fn (Callable[[int], Dict[str, Scalar]] | None, optional): Function used to + on_init_parameters_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to configure how one asks a client to provide parameters from which to initialize all other clients by providing a Config dictionary. If this is none, then a blank config is sent with the parameter request (which is default behavior for flower servers). Defaults to None. @@ -89,7 +89,7 @@ def __init__( self.reports_manager.initialize(id=self.server_name) self._log_fl_config() - def update_before_fit(self, num_rounds: int, timeout: Optional[float]) -> None: + def update_before_fit(self, num_rounds: int, timeout: float | None) -> None: """ Hook method to allow the server to do some work before starting the fit process. In the base server, it is a no-op function, but it can be overridden in child classes for custom functionality. For example, the @@ -98,7 +98,7 @@ def update_before_fit(self, num_rounds: int, timeout: Optional[float]) -> None: Args: num_rounds (int): The number of server rounds of FL to be performed - timeout (Optional[float], optional): The server's timeout parameter. Useful if one is requesting + timeout (float | None, optional): The server's timeout parameter. Useful if one is requesting information from a client. Defaults to None, which indicates indefinite timeout. """ pass @@ -118,7 +118,7 @@ def report_centralized_eval(self, history: History, num_rounds: int) -> None: round_metrics.update({metric: vals[round][1]}) self.reports_manager.report({"eval_round_metrics_centralized": round_metrics}, round + 1) - def fit_with_per_round_checkpointing(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float]: + def fit_with_per_round_checkpointing(self, num_rounds: int, timeout: float | None) -> tuple[History, float]: """ Runs federated learning for a number of rounds. Heavily based on the fit method from the base server provided by flower (flwr.server.server.Server) except that it is resilient to preemptions. @@ -127,10 +127,10 @@ def fit_with_per_round_checkpointing(self, num_rounds: int, timeout: Optional[fl Args: num_rounds (int): The number of rounds to perform federated learning. - timeout (Optional[float]): The timeout for clients to return results in a given FL round. + timeout (float | None): The timeout for clients to return results in a given FL round. Returns: - Tuple[History, float]: The first element of the tuple is a history object containing the losses and + tuple[History, float]: The first element of the tuple is a history object containing the losses and metrics computed during training and validation. The second element of the tuple is the elapsed time in seconds. """ @@ -207,7 +207,7 @@ def fit_with_per_round_checkpointing(self, num_rounds: int, timeout: Optional[fl log(INFO, "FL finished in %s", str(elapsed_time)) return self.history, elapsed_time.total_seconds() - def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float]: + def fit(self, num_rounds: int, timeout: float | None) -> tuple[History, float]: """ Run federated learning for a number of rounds. This function also allows the server to perform some operations prior to fitting starting. This is useful, for example, if you need to communicate with the clients to @@ -215,11 +215,11 @@ def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float Args: num_rounds (int): Number of server rounds to run. - timeout (Optional[float]): The amount of time in seconds that the server will wait for results from the + timeout (float | None): The amount of time in seconds that the server will wait for results from the clients selected to participate in federated training. Returns: - Tuple[History, float]: The first element of the tuple is a history object containing the full set of + tuple[History, float]: The first element of the tuple is a history object containing the full set of FL training results, including things like aggregated loss and metrics. Tuple also contains the elapsed time in seconds for the round. """ @@ -255,8 +255,8 @@ def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float def fit_round( self, server_round: int, - timeout: Optional[float], - ) -> Optional[Tuple[Optional[Parameters], Dict[str, Scalar], FitResultsAndFailures]]: + timeout: float | None, + ) -> tuple[Parameters | None, dict[str, Scalar], FitResultsAndFailures] | None: """ This function is called at each round of federated training. The flow is generally the same as a flower server, where clients are sampled and client side training is requested from the clients that are chosen. @@ -264,11 +264,11 @@ def fit_round( Args: server_round (int): Current round number of the FL training. Begins at 1 - timeout (Optional[float]): Time that the server should wait (in seconds) for responses from the clients. + timeout (float | None): Time that the server should wait (in seconds) for responses from the clients. Defaults to None, which indicates indefinite timeout. Returns: - Optional[Tuple[Optional[Parameters], Dict[str, Scalar], FitResultsAndFailures]]: The results of training + tuple[Parameters | None, dict[str, Scalar], FitResultsAndFailures] | None: The results of training on the client sit. The first set of parameters are the AGGREGATED parameters from the strategy. The second is a dictionary of AGGREGATED metrics. The third component holds the individual (non-aggregated) parameters, loss, and metrics for successful and unsuccessful client-side training. @@ -304,17 +304,17 @@ def shutdown(self) -> None: self.reports_manager.report({"shutdown": str(datetime.datetime.now())}) self.reports_manager.shutdown() - def poll_clients_for_sample_counts(self, timeout: Optional[float]) -> List[int]: + def poll_clients_for_sample_counts(self, timeout: float | None) -> list[int]: """ Poll clients for sample counts from their training set, if you want to use this functionality your strategy needs to inherit from the StrategyWithPolling ABC and implement a configure_poll function. Args: - timeout (Optional[float]): Timeout for how long the server will wait for clients to report counts. If none + timeout (float | None): Timeout for how long the server will wait for clients to report counts. If none then the server waits indefinitely. Returns: - List[int]: The number of training samples held by each client in the pool of available clients. + list[int]: The number of training samples held by each client in the pool of available clients. """ # Poll clients for sample counts, if you want to use this functionality your strategy needs to inherit from # the StrategyWithPolling ABC and implement a configure_poll function @@ -327,7 +327,7 @@ def poll_clients_for_sample_counts(self, timeout: Optional[float]) -> List[int]: timeout=timeout, ) - sample_counts: List[int] = [ + sample_counts: list[int] = [ int(get_properties_res.properties["num_train_samples"]) for (_, get_properties_res) in results ] log(INFO, f"Polling complete: Retrieved {len(sample_counts)} sample counts") @@ -337,8 +337,8 @@ def poll_clients_for_sample_counts(self, timeout: Optional[float]) -> List[int]: def evaluate_round( self, server_round: int, - timeout: Optional[float], - ) -> Optional[Tuple[Optional[float], Dict[str, Scalar], EvaluateResultsAndFailures]]: + timeout: float | None, + ) -> tuple[float | None, dict[str, Scalar], EvaluateResultsAndFailures] | None: # By default the checkpointing works off of the aggregated evaluation loss from each of the clients # NOTE: parameter aggregation occurs **before** evaluation, so the parameters held by the server have been # updated prior to this function being called. @@ -426,7 +426,7 @@ def _load_server_state(self) -> bool: self.parameters = get_all_model_parameters(server_state["model"]) return True - def _terminate_after_unacceptable_failures(self, timeout: Optional[float]) -> None: + def _terminate_after_unacceptable_failures(self, timeout: float | None) -> None: assert not self.accept_failures # First we shutdown all clients involved in the FL training/evaluation if they can be. self.disconnect_all_clients(timeout=timeout) @@ -460,7 +460,7 @@ def _log_client_failures(self, failures: FitFailures | EvaluateFailures) -> None def _maybe_checkpoint( self, loss_aggregated: float, - metrics_aggregated: Dict[str, Scalar], + metrics_aggregated: dict[str, Scalar], server_round: int, ) -> None: """ @@ -469,12 +469,12 @@ def _maybe_checkpoint( Args: loss_aggregated (float): aggregated loss value that can be used to determine whether to checkpoint - metrics_aggregated (Dict[str, Scalar]): aggregated metrics from each of the clients for checkpointing + metrics_aggregated (dict[str, Scalar]): aggregated metrics from each of the clients for checkpointing server_round (int): What round of federated training we're on. This is just for logging purposes. """ self.checkpoint_and_state_module.maybe_checkpoint(self.parameters, loss_aggregated, metrics_aggregated) - def _get_initial_parameters(self, server_round: int, timeout: Optional[float]) -> Parameters: + def _get_initial_parameters(self, server_round: int, timeout: float | None) -> Parameters: """ Get initial parameters from one of the available clients. This function is the same as the parent function in the flower server class except that we make use of the on_parameter_initialization_config_fn to provide @@ -482,10 +482,10 @@ def _get_initial_parameters(self, server_round: int, timeout: Optional[float]) - NOTE: The default behavior of flower servers is to simply send over a blank config, but this is insufficient for certain uses, where the client requires additional information from the server. This is needed, for example - in nnUnet based Servers. An issue has been logged with flower: https://github.com/adap/flower/issues/3770 + in nnUnet-based Servers. An issue has been logged with flower: https://github.com/adap/flower/issues/3770 """ # Server-side parameter initialization - parameters: Optional[Parameters] = self.strategy.initialize_parameters(client_manager=self._client_manager) + parameters: Parameters | None = self.strategy.initialize_parameters(client_manager=self._client_manager) if parameters is not None: log(INFO, "Using initial global parameters provided by strategy") return parameters @@ -509,8 +509,8 @@ def _get_initial_parameters(self, server_round: int, timeout: Optional[float]) - return get_parameters_res.parameters def _unpack_metrics( - self, results: List[Tuple[ClientProxy, EvaluateRes]] - ) -> Tuple[List[Tuple[ClientProxy, EvaluateRes]], List[Tuple[ClientProxy, EvaluateRes]]]: + self, results: list[tuple[ClientProxy, EvaluateRes]] + ) -> tuple[list[tuple[ClientProxy, EvaluateRes]], list[tuple[ClientProxy, EvaluateRes]]]: val_results = [] test_results = [] @@ -539,23 +539,23 @@ def _unpack_metrics( def _handle_result_aggregation( self, server_round: int, - results: List[Tuple[ClientProxy, EvaluateRes]], - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], - ) -> Tuple[Optional[float], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, EvaluateRes]], + failures: list[tuple[ClientProxy, EvaluateRes] | BaseException], + ) -> tuple[float | None, dict[str, Scalar]]: val_results, test_results = self._unpack_metrics(results) # Aggregate the validation results - val_aggregated_result: Tuple[ - Optional[float], - Dict[str, Scalar], + val_aggregated_result: tuple[ + float | None, + dict[str, Scalar], ] = self.strategy.aggregate_evaluate(server_round, val_results, failures) val_loss_aggregated, val_metrics_aggregated = val_aggregated_result # Aggregate the test results if they are present if len(test_results) > 0: - test_aggregated_result: Tuple[ - Optional[float], - Dict[str, Scalar], + test_aggregated_result: tuple[ + float | None, + dict[str, Scalar], ] = self.strategy.aggregate_evaluate(server_round, test_results, failures) test_loss_aggregated, test_metrics_aggregated = test_aggregated_result @@ -569,8 +569,8 @@ def _handle_result_aggregation( def _evaluate_round( self, server_round: int, - timeout: Optional[float], - ) -> Optional[Tuple[Optional[float], Dict[str, Scalar], EvaluateResultsAndFailures]]: + timeout: float | None, + ) -> tuple[float | None, dict[str, Scalar], EvaluateResultsAndFailures] | None: """Validate current global model on a number of clients.""" # Get clients and their respective instructions from strategy client_instructions = self.strategy.configure_evaluate( diff --git a/fl4health/servers/client_level_dp_fed_avg_server.py b/fl4health/servers/client_level_dp_fed_avg_server.py index c080fbbd7..39f7167f0 100644 --- a/fl4health/servers/client_level_dp_fed_avg_server.py +++ b/fl4health/servers/client_level_dp_fed_avg_server.py @@ -1,7 +1,6 @@ -from collections.abc import Sequence +from collections.abc import Callable, Sequence from logging import INFO from math import ceil -from typing import Callable, Dict, List, Optional, Tuple from flwr.common.logger import log from flwr.common.typing import Config, Scalar @@ -31,8 +30,8 @@ def __init__( num_server_rounds: int, reporters: Sequence[BaseReporter] | None = None, checkpoint_and_state_module: ClippingBitServerCheckpointAndStateModule | None = None, - delta: Optional[int] = None, - on_init_parameters_config_fn: Callable[[int], Dict[str, Scalar]] | None = None, + delta: int | None = None, + on_init_parameters_config_fn: Callable[[int], dict[str, Scalar]] | None = None, server_name: str | None = None, accept_failures: bool = True, ) -> None: @@ -57,9 +56,9 @@ def __init__( artifacts to be used or evaluated after training. The latter is used to preserve training state (including models) such that if FL training is interrupted, the process may be restarted. If no module is provided, no checkpointing or state preservation will happen. Defaults to None. - delta (Optional[float], optional): The delta value for epsilon-delta DP accounting. If None it defaults to + delta (float | None, optional): The delta value for epsilon-delta DP accounting. If None it defaults to being 1/total_samples in the FL run. Defaults to None. - on_init_parameters_config_fn (Callable[[int], Dict[str, Scalar]] | None, optional): Function used to + on_init_parameters_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to configure how one asks a client to provide parameters from which to initialize all other clients by providing a Config dictionary. If this is none, then a blank config is sent with the parameter request (which is default behavior for flower servers). Defaults to None. @@ -89,17 +88,17 @@ def __init__( self.num_server_rounds = num_server_rounds self.delta = delta - def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float]: + def fit(self, num_rounds: int, timeout: float | None) -> tuple[History, float]: """ Run federated averaging for a number of rounds. Args: num_rounds (int): Number of server rounds to run. - timeout (Optional[float]): The amount of time in seconds that the server will wait for results from the + timeout (float | None): The amount of time in seconds that the server will wait for results from the clients selected to participate in federated training. Returns: - Tuple[History, float]: The first element of the tuple is a history object containing the full set of + tuple[History, float]: The first element of the tuple is a history object containing the full set of FL training results, including things like aggregated loss and metrics. Tuple also contains the elapsed time in seconds for the round. """ @@ -116,12 +115,12 @@ def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float return super().fit(num_rounds=num_rounds, timeout=timeout) - def setup_privacy_accountant(self, sample_counts: List[int]) -> None: + def setup_privacy_accountant(self, sample_counts: list[int]) -> None: """ Sets up FL Accountant and computes privacy loss based on class attributes and retrieved sample counts. Args: - sample_counts (List[int]): These should be the total number of training examples fetched from all clients + sample_counts (list[int]): These should be the total number of training examples fetched from all clients during the sample polling process. """ assert isinstance(self.strategy, ClientLevelDPFedAvgM) diff --git a/fl4health/servers/evaluate_server.py b/fl4health/servers/evaluate_server.py index 00ea4a3fa..c0cae1541 100644 --- a/fl4health/servers/evaluate_server.py +++ b/fl4health/servers/evaluate_server.py @@ -2,7 +2,6 @@ from collections.abc import Sequence from logging import INFO, WARNING from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union import torch from flwr.common import EvaluateIns, EvaluateRes, MetricsAggregationFn, Parameters, Scalar @@ -23,9 +22,9 @@ def __init__( self, client_manager: ClientManager, fraction_evaluate: float, - model_checkpoint_path: Optional[Path] = None, - evaluate_config: Optional[Dict[str, Scalar]] = None, - evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, + model_checkpoint_path: Path | None = None, + evaluate_config: dict[str, Scalar] | None = None, + evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None, accept_failures: bool = True, min_available_clients: int = 1, reporters: Sequence[BaseReporter] | None = None, @@ -35,11 +34,11 @@ def __init__( client_manager (ClientManager): Determines the mechanism by which clients are sampled by the server, if they are to be sampled at all. fraction_evaluate (float): Fraction of clients used during evaluation. - model_checkpoint_path (Optional[Path], optional): Server side model checkpoint path to load global model + model_checkpoint_path (Path | None, optional): Server side model checkpoint path to load global model from. Defaults to None. - evaluate_config (Optional[Dict[str, Scalar]], optional): Configuration dictionary to configure evaluation + evaluate_config (dict[str, Scalar] | None, optional): Configuration dictionary to configure evaluation on clients. Defaults to None. - evaluate_metrics_aggregation_fn (Optional[MetricsAggregationFn], optional): Metrics aggregation function. + evaluate_metrics_aggregation_fn (MetricsAggregationFn | None, optional): Metrics aggregation function. Defaults to None. accept_failures (bool, optional): Whether or not accept rounds containing failures. Defaults to True. min_available_clients (int, optional): Minimum number of total clients in the system. Defaults to 1. @@ -78,18 +77,18 @@ def load_model_checkpoint_to_parameters(self) -> Parameters: log(INFO, "Model loaded and state converted to parameters") return parameters - def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float]: + def fit(self, num_rounds: int, timeout: float | None) -> tuple[History, float]: """ In order to head off training and only run eval, we have to override the fit function as this is essentially the entry point for federated learning from the app. Args: num_rounds (int): Not used. - timeout (Optional[float]): Timeout in seconds that the server should wait for the clients to respond. + timeout (float | None): Timeout in seconds that the server should wait for the clients to respond. If none, then it will wait for the minimum number to respond indefinitely. Returns: - Tuple[History, float]: The first element of the tuple is a History object containing the aggregated + tuple[History, float]: The first element of the tuple is a History object containing the aggregated metrics returned from the clients. Tuple also contains elapsed time in seconds for round. """ history = History() @@ -134,17 +133,17 @@ def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float def federated_evaluate( self, - timeout: Optional[float], - ) -> Optional[Tuple[Optional[float], Dict[str, Scalar], EvaluateResultsAndFailures]]: + timeout: float | None, + ) -> tuple[float | None, dict[str, Scalar], EvaluateResultsAndFailures] | None: """ Validate current global model on a number of clients. Args: - timeout (Optional[float]): Timeout in seconds that the server should wait for the clients to response. + timeout (float | None): Timeout in seconds that the server should wait for the clients to response. If none, then it will wait for the minimum number to respond indefinitely. Returns: - Optional[Tuple[Optional[float], Dict[str, Scalar], EvaluateResultsAndFailures]]: The first value is the + tuple[float | None, dict[str, Scalar], EvaluateResultsAndFailures] | None: The first value is the loss, which is ignored since we pack loss from the global and local models into the metrics dictionary The second is the aggregated metrics passed from the clients, the third is the set of raw results and failure objects returned by the clients. @@ -177,21 +176,21 @@ def federated_evaluate( # Aggregate the evaluation results, note that we assume that the losses have been packed and aggregated with # the metrics. A dummy loss is returned by each of the clients. We therefore return none for the aggregated # loss - aggregated_result: Tuple[ - Optional[float], - Dict[str, Scalar], + aggregated_result: tuple[ + float | None, + dict[str, Scalar], ] = self.aggregate_evaluate(results, failures) _, metrics_aggregated = aggregated_result return None, metrics_aggregated, (results, failures) - def configure_evaluate(self) -> List[Tuple[ClientProxy, EvaluateIns]]: + def configure_evaluate(self) -> list[tuple[ClientProxy, EvaluateIns]]: """ Configure the next round of evaluation. This handles the two different was that a set of clients might be sampled. Returns: - List[Tuple[ClientProxy, EvaluateIns]]: List of configuration instructions for the clients selected by the + list[tuple[ClientProxy, EvaluateIns]]: List of configuration instructions for the clients selected by the client manager for evaluation. These configuration objects are sent to the clients to customize evaluation. """ @@ -218,22 +217,22 @@ def configure_evaluate(self) -> List[Tuple[ClientProxy, EvaluateIns]]: def aggregate_evaluate( self, - results: List[Tuple[ClientProxy, EvaluateRes]], - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], - ) -> Tuple[Optional[float], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, EvaluateRes]], + failures: list[tuple[ClientProxy, EvaluateRes] | BaseException], + ) -> tuple[float | None, dict[str, Scalar]]: """ Aggregate evaluation results using the evaluate_metrics_aggregation_fn provided. Note that a dummy loss is returned as we assume that it was packed into the metrics dictionary for this functionality. Args: - results (List[Tuple[ClientProxy, EvaluateRes]]): List of results objects that have the metrics returned + results (list[tuple[ClientProxy, EvaluateRes]]): List of results objects that have the metrics returned from each client, if successful, along with the number of samples used in the evaluation. - failures (List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]]): Failures reported by the clients + failures (list[tuple[ClientProxy, EvaluateRes] | BaseException]): Failures reported by the clients along with the client id, the results that we passed, if any, and the associated exception if one was raised. Returns: - Tuple[Optional[float], Dict[str, Scalar]]: A dummy float for the "loss" (these are packed with the metrics) + tuple[float | None, dict[str, Scalar]]: A dummy float for the "loss" (these are packed with the metrics) and the aggregated metrics dictionary. """ if not results: diff --git a/fl4health/servers/fedpm_server.py b/fl4health/servers/fedpm_server.py index 23c5a5d4a..8476129d5 100644 --- a/fl4health/servers/fedpm_server.py +++ b/fl4health/servers/fedpm_server.py @@ -1,5 +1,4 @@ -from collections.abc import Sequence -from typing import Callable, Dict, Optional, Tuple +from collections.abc import Callable, Sequence from flwr.common import Parameters from flwr.common.typing import Config, Scalar @@ -20,7 +19,7 @@ def __init__( strategy: FedPm, reporters: Sequence[BaseReporter] | None = None, checkpoint_and_state_module: LayerNamesServerCheckpointAndStateModule | None = None, - on_init_parameters_config_fn: Callable[[int], Dict[str, Scalar]] | None = None, + on_init_parameters_config_fn: Callable[[int], dict[str, Scalar]] | None = None, server_name: str | None = None, accept_failures: bool = True, reset_frequency: int = 1, @@ -42,7 +41,7 @@ def __init__( artifacts to be used or evaluated after training. The latter is used to preserve training state (including models) such that if FL training is interrupted, the process may be restarted. If no module is provided, no checkpointing or state preservation will happen. Defaults to None. - on_init_parameters_config_fn (Callable[[int], Dict[str, Scalar]] | None, optional): Function used to + on_init_parameters_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to configure how one asks a client to provide parameters from which to initialize all other clients by providing a Config dictionary. If this is none, then a blank config is sent with the parameter request (which is default behavior for flower servers). Defaults to None. @@ -74,8 +73,8 @@ def __init__( def fit_round( self, server_round: int, - timeout: Optional[float], - ) -> Optional[Tuple[Optional[Parameters], Dict[str, Scalar], FitResultsAndFailures]]: + timeout: float | None, + ) -> tuple[Parameters | None, dict[str, Scalar], FitResultsAndFailures] | None: assert isinstance(self.strategy, FedPm) # If self.reset_frequency == x, then the beta priors are reset every x fitting rounds. # Note that (server_round + 1) % self.reset_frequency == 0 is to ensure that the priors diff --git a/fl4health/servers/instance_level_dp_server.py b/fl4health/servers/instance_level_dp_server.py index f01c446ba..b7773d37c 100644 --- a/fl4health/servers/instance_level_dp_server.py +++ b/fl4health/servers/instance_level_dp_server.py @@ -1,7 +1,6 @@ -from collections.abc import Sequence +from collections.abc import Callable, Sequence from logging import INFO from math import ceil -from typing import Callable, Dict, List, Optional, Tuple from flwr.common.logger import log from flwr.common.typing import Config, Scalar @@ -26,12 +25,12 @@ def __init__( batch_size: int, num_server_rounds: int, strategy: BasicFedAvg, - local_epochs: Optional[int] = None, - local_steps: Optional[int] = None, + local_epochs: int | None = None, + local_steps: int | None = None, checkpoint_and_state_module: OpacusServerCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, - delta: Optional[float] = None, - on_init_parameters_config_fn: Callable[[int], Dict[str, Scalar]] | None = None, + delta: float | None = None, + on_init_parameters_config_fn: Callable[[int], dict[str, Scalar]] | None = None, server_name: str | None = None, accept_failures: bool = True, ) -> None: @@ -53,10 +52,10 @@ def __init__( strategy (BasicFedAvg): The aggregation strategy to be used by the server to handle client updates and other information potentially sent by the participating clients. this must be an OpacusBasicFedAvg strategy to ensure proper treatment of the model in the Opacus framework - local_epochs (Optional[int], optional): Number of local epochs to be performed on the client-side. This is + local_epochs (int | None, optional): Number of local epochs to be performed on the client-side. This is used in privacy accounting. One of local_epochs or local_steps should be defined, but not both. Defaults to None. - local_steps (Optional[int], optional): Number of local steps to be performed on the client-side. This is + local_steps (int | None, optional): Number of local steps to be performed on the client-side. This is used in privacy accounting. One of local_epochs or local_steps should be defined, but not both. Defaults to None. checkpoint_and_state_module (OpacusServerCheckpointAndStateModule | None, optional): This module is used @@ -66,9 +65,9 @@ def __init__( module is provided, no checkpointing or state preservation will happen. Defaults to None. reporters (Sequence[BaseReporter] | None, optional): A sequence of FL4Health reporters which the client should send data to. - delta (Optional[float], optional): The delta value for epsilon-delta DP accounting. If None it defaults to + delta (float | None, optional): The delta value for epsilon-delta DP accounting. If None it defaults to being 1/total_samples in the FL run. Defaults to None. - on_init_parameters_config_fn (Callable[[int], Dict[str, Scalar]] | None, optional): Function used to + on_init_parameters_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to configure how one asks a client to provide parameters from which to initialize all other clients by providing a Config dictionary. If this is none, then a blank config is sent with the parameter request (which is default behavior for flower servers). Defaults to None. @@ -107,17 +106,17 @@ def __init__( self.num_server_rounds = num_server_rounds self.delta = delta - def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float]: + def fit(self, num_rounds: int, timeout: float | None) -> tuple[History, float]: """ Run federated averaging for a number of rounds. Args: num_rounds (int): Number of server rounds to run. - timeout (Optional[float]): The amount of time in seconds that the server will wait for results from the + timeout (float | None): The amount of time in seconds that the server will wait for results from the clients selected to participate in federated training. Returns: - Tuple[History, float]: The first element of the tuple is a history object containing the full + tuple[History, float]: The first element of the tuple is a history object containing the full set of FL training results, including things like aggregated loss and metrics. Tuple also includes elapsed time in seconds for round. """ @@ -128,12 +127,12 @@ def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float return super().fit(num_rounds=num_rounds, timeout=timeout) - def setup_privacy_accountant(self, sample_counts: List[int]) -> None: + def setup_privacy_accountant(self, sample_counts: list[int]) -> None: """ Sets up FL Accountant and computes privacy loss based on class attributes and retrieved sample counts. Args: - sample_counts (List[int]): These should be the total number of training examples fetched from all clients + sample_counts (list[int]): These should be the total number of training examples fetched from all clients during the sample polling process. """ # Ensures that we're using a fraction sampler of the diff --git a/fl4health/servers/model_merge_server.py b/fl4health/servers/model_merge_server.py index e12452bfb..ed90eff0e 100644 --- a/fl4health/servers/model_merge_server.py +++ b/fl4health/servers/model_merge_server.py @@ -1,7 +1,7 @@ import datetime import timeit +from collections.abc import Sequence from logging import INFO, WARNING -from typing import Dict, Optional, Sequence, Tuple import torch.nn as nn from flwr.common.logger import log @@ -25,12 +25,12 @@ def __init__( self, *, client_manager: ClientManager, - strategy: Optional[Strategy] = None, - server_model: Optional[nn.Module] = None, - checkpointer: Optional[LatestTorchModuleCheckpointer] = None, - parameter_exchanger: Optional[ParameterExchanger] = None, + strategy: Strategy | None = None, + server_model: nn.Module | None = None, + checkpointer: LatestTorchModuleCheckpointer | None = None, + parameter_exchanger: ParameterExchanger | None = None, reporters: Sequence[BaseReporter] | None = None, - server_name: Optional[str] = None, + server_name: str | None = None, ) -> None: """ ModelMergeServer provides functionality to fetch client weights, perform a simple average, @@ -38,21 +38,21 @@ def __init__( Args: client_manager (ClientManager): Determines the mechanism by which clients are sampled by the server, if they are to be sampled at all. - strategy (Optional[Strategy], optional): The aggregation strategy to be used by the server to handle + strategy (Strategy | None, optional): The aggregation strategy to be used by the server to handle client updates sent by the participating clients. Must be ModelMergeStrategy. - checkpointer (Optional[LatestTorchCheckpointer], optional): To be provided if the server should perform + checkpointer (LatestTorchCheckpointer | None, optional): To be provided if the server should perform server side checkpointing on the merged model. If none, then no server-side checkpointing is performed. Defaults to None. - server_model (Optional[nn.Module]): Optional model to be hydrated with parameters from model merge if doing + server_model (nn.Module | None): Optional model to be hydrated with parameters from model merge if doing server side checkpointing. Must only be provided if checkpointer is also provided. Defaults to None. - parameter_exchanger (Optional[ExchangerType], optional): A parameter exchanger used to facilitate + parameter_exchanger (ExchangerType | None, optional): A parameter exchanger used to facilitate server-side model checkpointing if a checkpointer has been defined. If not provided then checkpointing will not be done unless the _hydrate_model_for_checkpointing function is overridden. Because the server only sees numpy arrays, the parameter exchanger is used to insert the numpy arrays into a provided model. Defaults to None. reporters (Sequence[BaseReporter], optional): A sequence of FL4Health reporters which the server should send data to before and after each round. - server_name (Optional[str]): An optional string name to uniquely identify server. + server_name (str | None): An optional string name to uniquely identify server. """ assert isinstance(strategy, ModelMergeStrategy) assert (server_model is None and checkpointer is None and parameter_exchanger is None) or ( @@ -69,7 +69,7 @@ def __init__( self.reports_manager = ReportsManager(reporters) self.reports_manager.initialize(id=self.server_name) - def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float]: + def fit(self, num_rounds: int, timeout: float | None) -> tuple[History, float]: """ Performs a fit round in which the local client weights are evaluated on their test set, uploaded to the server and averaged, then redistributed to clients for evaluation. @@ -77,11 +77,11 @@ def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float Args: num_rounds (int): Not used. - timeout (Optional[float]): Timeout in seconds that the server should wait for the clients to respond. + timeout (float | None): Timeout in seconds that the server should wait for the clients to respond. If none, then it will wait for the minimum number to respond indefinitely. Returns: - Tuple[History, float]: The first element of the tuple is a History object containing the aggregated + tuple[History, float]: The first element of the tuple is a History object containing the aggregated metrics returned from the clients. Tuple also contains elapsed time in seconds for round. """ self.reports_manager.report({"host_type": "server", "fit_start": datetime.datetime.now()}) @@ -162,7 +162,7 @@ def _hydrate_model_for_checkpointing(self) -> nn.Module: return self.server_model def _maybe_checkpoint( - self, loss_aggregated: float, metrics_aggregated: Dict[str, Scalar], server_round: int + self, loss_aggregated: float, metrics_aggregated: dict[str, Scalar], server_round: int ) -> None: """ Method to checkpoint merged model on server side if the checkpointer, server_model and @@ -170,7 +170,7 @@ def _maybe_checkpoint( Args: loss_aggregated (float): Not used. - metrics_aggregated (Dict[str, Scalar]): Not used. + metrics_aggregated (dict[str, Scalar]): Not used. server_round (int): Not used. """ if self.checkpointer and self.server_model and self.parameter_exchanger: diff --git a/fl4health/servers/nnunet_server.py b/fl4health/servers/nnunet_server.py index 018a2d691..378b36e06 100644 --- a/fl4health/servers/nnunet_server.py +++ b/fl4health/servers/nnunet_server.py @@ -2,7 +2,7 @@ import warnings from collections.abc import Callable, Sequence from logging import INFO -from typing import Any, Dict, Optional, Tuple, Type, Union +from typing import Any import torch.nn as nn from flwr.common import Parameters @@ -25,9 +25,9 @@ from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer from nnunetv2.utilities.plans_handling.plans_handler import PlansManager -FIT_CFG_FN = Callable[[int, Parameters, ClientManager], list[Tuple[ClientProxy, FitIns]]] -EVAL_CFG_FN = Callable[[int, Parameters, ClientManager], list[Tuple[ClientProxy, EvaluateIns]]] -CFG_FN = Union[FIT_CFG_FN, EVAL_CFG_FN] +FIT_CFG_FN = Callable[[int, Parameters, ClientManager], list[tuple[ClientProxy, FitIns]]] +EVAL_CFG_FN = Callable[[int, Parameters, ClientManager], list[tuple[ClientProxy, EvaluateIns]]] +CFG_FN = FIT_CFG_FN | EVAL_CFG_FN def add_items_to_config_fn(fn: CFG_FN, items: Config) -> CFG_FN: @@ -60,13 +60,13 @@ def __init__( self, client_manager: ClientManager, fl_config: Config, - on_init_parameters_config_fn: Callable[[int], Dict[str, Scalar]], + on_init_parameters_config_fn: Callable[[int], dict[str, Scalar]], strategy: Strategy | None = None, reporters: Sequence[BaseReporter] | None = None, checkpoint_and_state_module: NnUnetServerCheckpointAndStateModule | None = None, server_name: str | None = None, accept_failures: bool = True, - nnunet_trainer_class: Type[nnUNetTrainer] = nnUNetTrainer, + nnunet_trainer_class: type[nnUNetTrainer] = nnUNetTrainer, ) -> None: """ A Basic FlServer with added functionality to ask a client to initialize the global nnunet plans if one was not @@ -79,7 +79,7 @@ def __init__( In most cases it should be the "source of truth" for how FL training/evaluation should proceed. For example, the config used to produce the on_fit_config_fn and on_evaluate_config_fn for the strategy. NOTE: This config is DISTINCT from the Flwr server config, which is extremely minimal. - on_init_parameters_config_fn (Callable[[int], Dict[str, Scalar]]): Function used to configure how one + on_init_parameters_config_fn (Callable[[int], dict[str, Scalar]]): Function used to configure how one asks a client to provide parameters from which to initialize all other clients by providing a Config dictionary. For NnunetServers this is a required function to provide the additional information necessary to a client for parameter initialization @@ -100,7 +100,7 @@ def __init__( accept_failures (bool, optional): Determines whether the server should accept failures during training or evaluation from clients or not. If set to False, this will cause the server to shutdown all clients and throw an exception. Defaults to True. - nnunet_trainer_class (Type[nnUNetTrainer]): nnUNetTrainer class. + nnunet_trainer_class (type[nnUNetTrainer]): nnUNetTrainer class. Useful for passing custom nnUNetTrainer. Defaults to the standard nnUNetTrainer class. Must match the nnunet_trainer_class passed to the NnunetClient. """ @@ -151,7 +151,7 @@ def initialize_server_model(self) -> None: self.checkpoint_and_state_module.model = model - def update_before_fit(self, num_rounds: int, timeout: Optional[float]) -> None: + def update_before_fit(self, num_rounds: int, timeout: float | None) -> None: """ Hook method to allow the server to do some additional initialization prior to fitting. NunetServer uses this method to sample a client for properties which are required to initialize the server. @@ -166,7 +166,7 @@ def update_before_fit(self, num_rounds: int, timeout: Optional[float]) -> None: Args: num_rounds (int): The number of server rounds of FL to be performed - timeout (Optional[float], optional): The server's timeout parameter. Useful if one is requesting + timeout (float | None, optional): The server's timeout parameter. Useful if one is requesting information from a client. Defaults to None, which indicates indefinite timeout. """ diff --git a/fl4health/servers/polling.py b/fl4health/servers/polling.py index 51527bde5..22462e026 100644 --- a/fl4health/servers/polling.py +++ b/fl4health/servers/polling.py @@ -1,19 +1,18 @@ import concurrent.futures -from typing import List, Optional, Tuple, Union from flwr.common.typing import Code, GetPropertiesIns, GetPropertiesRes from flwr.server.client_proxy import ClientProxy -PollResultsAndFailures = Tuple[ - List[Tuple[ClientProxy, GetPropertiesRes]], - List[Union[Tuple[ClientProxy, GetPropertiesRes], BaseException]], +PollResultsAndFailures = tuple[ + list[tuple[ClientProxy, GetPropertiesRes]], + list[tuple[ClientProxy, GetPropertiesRes] | BaseException], ] def _handle_finished_future_after_poll( future: concurrent.futures.Future, - results: List[Tuple[ClientProxy, GetPropertiesRes]], - failures: List[Union[Tuple[ClientProxy, GetPropertiesRes], BaseException]], + results: list[tuple[ClientProxy, GetPropertiesRes]], + failures: list[tuple[ClientProxy, GetPropertiesRes] | BaseException], ) -> None: """ Convert finished future into either a result or a failure for polling. @@ -21,8 +20,8 @@ def _handle_finished_future_after_poll( Args: future (concurrent.futures.Future): The future returned by a client executing polling. It is either added to results if there are no exceptions or failures if there are any. - results (List[Tuple[ClientProxy, GetPropertiesRes]]): Set of good results from clients that have accumulated. - failures (List[Union[Tuple[ClientProxy, GetPropertiesRes], BaseException]]): The set of failing results that + results (list[tuple[ClientProxy, GetPropertiesRes]]): Set of good results from clients that have accumulated. + failures (list[tuple[ClientProxy, GetPropertiesRes] | BaseException]): The set of failing results that have accumulated for the polling. """ @@ -33,7 +32,7 @@ def _handle_finished_future_after_poll( return # Successfully received a result from a client - result: Tuple[ClientProxy, GetPropertiesRes] = future.result() + result: tuple[ClientProxy, GetPropertiesRes] = future.result() _, res = result # Check result status code @@ -45,7 +44,7 @@ def _handle_finished_future_after_poll( failures.append(result) -def poll_client(client: ClientProxy, ins: GetPropertiesIns) -> Tuple[ClientProxy, GetPropertiesRes]: +def poll_client(client: ClientProxy, ins: GetPropertiesIns) -> tuple[ClientProxy, GetPropertiesRes]: """ Get Properties of client. This is run for each client to extract the properties from the target client. @@ -55,27 +54,27 @@ def poll_client(client: ClientProxy, ins: GetPropertiesIns) -> Tuple[ClientProxy properties. Returns: - Tuple[ClientProxy, GetPropertiesRes]: Returns the resulting properties from the client response. + tuple[ClientProxy, GetPropertiesRes]: Returns the resulting properties from the client response. """ property_res: GetPropertiesRes = client.get_properties(ins=ins, timeout=None, group_id=None) return client, property_res def poll_clients( - client_instructions: List[Tuple[ClientProxy, GetPropertiesIns]], - max_workers: Optional[int], - timeout: Optional[float], + client_instructions: list[tuple[ClientProxy, GetPropertiesIns]], + max_workers: int | None, + timeout: float | None, ) -> PollResultsAndFailures: """ Poll clients concurrently on all selected clients. Args: - client_instructions (List[Tuple[ClientProxy, GetPropertiesIns]]): This is the set of instructions for the + client_instructions (list[tuple[ClientProxy, GetPropertiesIns]]): This is the set of instructions for the polling to be passed to each client. Each client is represented by a single ClientProxy in the list. - max_workers (Optional[int]): This is the maximum number of concurrent workers to be used by the server to + max_workers (int | None): This is the maximum number of concurrent workers to be used by the server to poll the clients. This should be set if pooling an extremely large number, if none a maximum of 32 workers are used. - timeout (Optional[float]): How long the executor should wait to receive a response before moving on. + timeout (float | None): How long the executor should wait to receive a response before moving on. Returns: PollResultsAndFailures: Object holding the results and failures associate with the concurrent polling. @@ -92,8 +91,8 @@ def poll_clients( ) # Gather results - results: List[Tuple[ClientProxy, GetPropertiesRes]] = [] - failures: List[Union[Tuple[ClientProxy, GetPropertiesRes], BaseException]] = [] + results: list[tuple[ClientProxy, GetPropertiesRes]] = [] + failures: list[tuple[ClientProxy, GetPropertiesRes] | BaseException] = [] for future in finished_fs: _handle_finished_future_after_poll(future=future, results=results, failures=failures) diff --git a/fl4health/servers/scaffold_server.py b/fl4health/servers/scaffold_server.py index 111924fd3..027588113 100644 --- a/fl4health/servers/scaffold_server.py +++ b/fl4health/servers/scaffold_server.py @@ -1,6 +1,5 @@ -from collections.abc import Sequence +from collections.abc import Callable, Sequence from logging import DEBUG, ERROR, INFO -from typing import Callable, Dict, Optional, Tuple from flwr.common import Parameters, ndarrays_to_parameters, parameters_to_ndarrays from flwr.common.logger import log @@ -27,7 +26,7 @@ def __init__( strategy: Scaffold, reporters: Sequence[BaseReporter] | None = None, checkpoint_and_state_module: ScaffoldServerCheckpointAndStateModule | None = None, - on_init_parameters_config_fn: Callable[[int], Dict[str, Scalar]] | None = None, + on_init_parameters_config_fn: Callable[[int], dict[str, Scalar]] | None = None, server_name: str | None = None, accept_failures: bool = True, warm_start: bool = False, @@ -53,7 +52,7 @@ def __init__( artifacts to be used or evaluated after training. The latter is used to preserve training state (including models) such that if FL training is interrupted, the process may be restarted. If no module is provided, no checkpointing or state preservation will happen. Defaults to None. - on_init_parameters_config_fn (Callable[[int], Dict[str, Scalar]] | None, optional): Function used to + on_init_parameters_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to configure how one asks a client to provide parameters from which to initialize all other clients by providing a Config dictionary. If this is none, then a blank config is sent with the parameter request (which is default behavior for flower servers). Defaults to None. @@ -85,7 +84,7 @@ def __init__( ) self.warm_start = warm_start - def _get_initial_parameters(self, server_round: int, timeout: Optional[float]) -> Parameters: + def _get_initial_parameters(self, server_round: int, timeout: float | None) -> Parameters: """ Overrides the _get_initial_parameters in the flwr server base class to strap on the possibility of a warm_start for SCAFFOLD. Initializes parameters (models weights and control variates) of the server. @@ -95,7 +94,7 @@ def _get_initial_parameters(self, server_round: int, timeout: Optional[float]) - Args: server_round (int): The current server round. - timeout (Optional[float]): If the server strategy object does not have a server-side initial parameters + timeout (float | None): If the server strategy object does not have a server-side initial parameters function defined, then one of the clients is polled and their model parameters are returned in order to initialize the models of all clients. Timeout defines how long to wait for a response. @@ -160,19 +159,19 @@ def _get_initial_parameters(self, server_round: int, timeout: Optional[float]) - return initial_parameters - def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float]: + def fit(self, num_rounds: int, timeout: float | None) -> tuple[History, float]: """ Run the SCAFFOLD FL algorithm for a fixed number of rounds. This overrides the base server fit class just to ensure that the provided strategy is a Scaffold strategy object before proceeding. Args: num_rounds (int): Number of rounds of FL to perform (i.e. server rounds). - timeout (Optional[float]): Timeout associated with queries to the clients in seconds. The server waits for + timeout (float | None): Timeout associated with queries to the clients in seconds. The server waits for timeout seconds before moving on without any unresponsive clients. If None, there is no timeout and the server waits for the minimum number of clients to be available set in the strategy. Returns: - Tuple[History, float]: The first element of the tuple is a history object containing the full set of + tuple[History, float]: The first element of the tuple is a history object containing the full set of FL training results, including things like aggregated loss and metrics. Tuple also includes elapsed time in seconds for round. """ @@ -189,13 +188,13 @@ def __init__( batch_size: int, num_server_rounds: int, strategy: OpacusScaffold, - local_epochs: Optional[int] = None, - local_steps: Optional[int] = None, - delta: Optional[float] = None, + local_epochs: int | None = None, + local_steps: int | None = None, + delta: float | None = None, checkpoint_and_state_module: DpScaffoldServerCheckpointAndStateModule | None = None, warm_start: bool = False, reporters: Sequence[BaseReporter] | None = None, - on_init_parameters_config_fn: Callable[[int], Dict[str, Scalar]] | None = None, + on_init_parameters_config_fn: Callable[[int], dict[str, Scalar]] | None = None, server_name: str | None = None, accept_failures: bool = True, ) -> None: @@ -214,10 +213,10 @@ def __init__( DP-SGD. batch_size (int): The batch size to be used in training on the client-side. Used in privacy accounting. num_server_rounds (int): The number of server rounds to be done in FL training. Used in privacy accounting - local_epochs (Optional[int], optional): Number of local epochs to be performed on the client-side. This is + local_epochs (int | None, optional): Number of local epochs to be performed on the client-side. This is used in privacy accounting. One of local_epochs or local_steps should be defined, but not both. Defaults to None. - local_steps (Optional[int], optional): Number of local steps to be performed on the client-side. This is + local_steps (int | None, optional): Number of local steps to be performed on the client-side. This is used in privacy accounting. One of local_epochs or local_steps should be defined, but not both. Defaults to None. strategy (Scaffold): The aggregation strategy to be used by the server to handle client updates and @@ -232,11 +231,11 @@ def __init__( gradients. The clients will perform a training pass (without updating the weights) in order to provide a "warm" estimate of the SCAFFOLD control variates. If false, variates are initialized to 0. Defaults to False. - delta (Optional[float], optional): The delta value for epsilon-delta DP accounting. If None it defaults to + delta (float | None, optional): The delta value for epsilon-delta DP accounting. If None it defaults to being 1/total_samples in the FL run. Defaults to None. reporters (Sequence[BaseReporter], optional): A sequence of FL4Health reporters which the client should send data to. - on_init_parameters_config_fn (Callable[[int], Dict[str, Scalar]] | None, optional): Function used to + on_init_parameters_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to configure how one asks a client to provide parameters from which to initialize all other clients by providing a Config dictionary. If this is none, then a blank config is sent with the parameter request (which is default behavior for flower servers). Defaults to None. @@ -276,18 +275,18 @@ def __init__( accept_failures=accept_failures, ) - def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float]: + def fit(self, num_rounds: int, timeout: float | None) -> tuple[History, float]: """ Run DP Scaffold algorithm for the specified number of rounds. Args: num_rounds (int): Number of rounds of FL to perform (i.e. server rounds). - timeout (Optional[float]): Timeout associated with queries to the clients in seconds. The server waits for + timeout (float | None): Timeout associated with queries to the clients in seconds. The server waits for timeout seconds before moving on without any unresponsive clients. If None, there is no timeout and the server waits for the minimum number of clients to be available set in the strategy. Returns: - Tuple[History, float]: First element of tuple is history object containing the full set of FL + tuple[History, float]: First element of tuple is history object containing the full set of FL training results, including aggregated loss and metrics. Tuple also includes the elapsed time in seconds for round. """ diff --git a/fl4health/servers/tabular_feature_alignment_server.py b/fl4health/servers/tabular_feature_alignment_server.py index 3a689d84e..1a7366900 100644 --- a/fl4health/servers/tabular_feature_alignment_server.py +++ b/fl4health/servers/tabular_feature_alignment_server.py @@ -1,7 +1,7 @@ import random +from collections.abc import Callable, Sequence from functools import partial from logging import DEBUG, INFO, WARNING -from typing import Callable, Dict, Optional, Sequence, Tuple from flwr.common import Parameters from flwr.common.logger import log @@ -31,10 +31,10 @@ def __init__( config: Config, initialize_parameters: Callable[..., Parameters], strategy: BasicFedAvg, - tabular_features_source_of_truth: Optional[TabularFeaturesInfoEncoder] = None, + tabular_features_source_of_truth: TabularFeaturesInfoEncoder | None = None, reporters: Sequence[BaseReporter] | None = None, checkpoint_and_state_module: BaseServerCheckpointAndStateModule | None = None, - on_init_parameters_config_fn: Callable[[int], Dict[str, Scalar]] | None = None, + on_init_parameters_config_fn: Callable[[int], dict[str, Scalar]] | None = None, server_name: str | None = None, accept_failures: bool = True, ) -> None: @@ -47,16 +47,16 @@ def __init__( config (Config): This should be the configuration that was used to setup the federated alignment. In most cases it should be the "source of truth" for how FL alignment should proceed. NOTE: This config is DISTINCT from the Flwr server config, which is extremely minimal. - strategy (Optional[Strategy], optional): The aggregation strategy to be used by the server to handle. + strategy (Strategy | None, optional): The aggregation strategy to be used by the server to handle. client updates and other information potentially sent by the participating clients. If None the strategy is FedAvg as set by the flwr Server. - wandb_reporter (Optional[ServerWandBReporter], optional): To be provided if the server is to log + wandb_reporter (ServerWandBReporter | None, optional): To be provided if the server is to log information and results to a Weights and Biases account. If None is provided, no logging occurs. Defaults to None. - checkpointer (Optional[TorchCheckpointer], optional): To be provided if the server should perform + checkpointer (TorchCheckpointer | None, optional): To be provided if the server should perform server side checkpointing based on some criteria. If none, then no server-side checkpointing is performed. Defaults to None. - tab_features_source_of_truth (Optional[TabularFeaturesInfoEncoder]): The information that is required + tab_features_source_of_truth (TabularFeaturesInfoEncoder | None): The information that is required for aligning client features. If it is not specified, then the server will randomly poll a client and gather this information from its data source. reporters (Sequence[BaseReporter] | None, optional): sequence of FL4Health reporters which the server @@ -66,7 +66,7 @@ def __init__( artifacts to be used or evaluated after training. The latter is used to preserve training state (including models) such that if FL training is interrupted, the process may be restarted. If no module is provided, no checkpointing or state preservation will happen. Defaults to None. - on_init_parameters_config_fn (Callable[[int], Dict[str, Scalar]] | None, optional): Function used to + on_init_parameters_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to configure how one asks a client to provide parameters from which to initialize all other clients by providing a Config dictionary. If this is none, then a blank config is sent with the parameter request (which is default behavior for flower servers). Defaults to None. @@ -99,7 +99,7 @@ def __init__( self.tab_features_info = tabular_features_source_of_truth self.initialize_parameters = initialize_parameters self.source_info_gathered = False - self.dimension_info: Dict[str, int] = {} + self.dimension_info: dict[str, int] = {} # ensure that self.strategy has type BasicFedAvg so its on_fit_config_fn can be specified. assert isinstance(self.strategy, BasicFedAvg), "This server is only compatible with BasicFedAvg at this time" self.strategy.on_fit_config_fn = partial(fit_config, self.fl_config, self.source_info_gathered) @@ -108,13 +108,13 @@ def _set_dimension_info(self, input_dimension: int, output_dimension: int) -> No self.dimension_info[INPUT_DIMENSION] = input_dimension self.dimension_info[OUTPUT_DIMENSION] = output_dimension - def _get_initial_parameters(self, server_round: int, timeout: Optional[float]) -> Parameters: + def _get_initial_parameters(self, server_round: int, timeout: float | None) -> Parameters: assert INPUT_DIMENSION in self.dimension_info and OUTPUT_DIMENSION in self.dimension_info input_dimension = self.dimension_info[INPUT_DIMENSION] output_dimension = self.dimension_info[OUTPUT_DIMENSION] return self.initialize_parameters(input_dimension, output_dimension) - def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float]: + def fit(self, num_rounds: int, timeout: float | None) -> tuple[History, float]: """Run federated averaging for a number of rounds.""" assert isinstance(self.strategy, BasicFedAvg) @@ -154,7 +154,7 @@ def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float # are aligned and global model is initialized. return super().fit(num_rounds=num_rounds, timeout=timeout) - def poll_clients_for_feature_info(self, timeout: Optional[float]) -> str: + def poll_clients_for_feature_info(self, timeout: float | None) -> str: log(INFO, "Feature information source unspecified. Polling clients for feature information.") assert isinstance(self.strategy, BasicFedAvg) client_instructions = self.strategy.configure_poll(server_round=1, client_manager=self._client_manager) @@ -169,7 +169,7 @@ def poll_clients_for_feature_info(self, timeout: Optional[float]) -> str: feature_info = str(get_properties_res.properties[FEATURE_INFO]) return feature_info - def poll_clients_for_dimension_info(self, timeout: Optional[float]) -> Tuple[int, int]: + def poll_clients_for_dimension_info(self, timeout: float | None) -> tuple[int, int]: log(INFO, "Waiting for Clients to align features and then polling for dimension information.") assert isinstance(self.strategy, BasicFedAvg) client_instructions = self.strategy.configure_poll(server_round=1, client_manager=self._client_manager) diff --git a/fl4health/strategies/aggregate_utils.py b/fl4health/strategies/aggregate_utils.py index ecef5eaeb..48a3aa7a9 100644 --- a/fl4health/strategies/aggregate_utils.py +++ b/fl4health/strategies/aggregate_utils.py @@ -1,17 +1,16 @@ from functools import reduce -from typing import List, Tuple import numpy as np from flwr.common import NDArrays from flwr.server.strategy.aggregate import aggregate, weighted_loss_avg -def aggregate_results(results: List[Tuple[NDArrays, int]], weighted: bool = True) -> NDArrays: +def aggregate_results(results: list[tuple[NDArrays, int]], weighted: bool = True) -> NDArrays: """ Compute weighted or unweighted average. Args: - results (List[Tuple[NDArrays, int]]): This is a set of NDArrays (list of numpy arrays) and the number of + results (list[tuple[NDArrays, int]]): This is a set of NDArrays (list of numpy arrays) and the number of relevant samples from each client (training or validation samples where appropriate). These are to be aggregated together in a weighted or unweighted average. The NDArrays most often represent model states. weighted (bool, optional): Whether or not the aggregation is a weighted average (by the sample counts @@ -33,12 +32,12 @@ def aggregate_results(results: List[Tuple[NDArrays, int]], weighted: bool = True return [reduce(np.add, layer_updates) for layer_updates in zip(*weighted_weights)] -def aggregate_losses(results: List[Tuple[int, float]], weighted: bool = True) -> float: +def aggregate_losses(results: list[tuple[int, float]], weighted: bool = True) -> float: """ Aggregate evaluation results obtained from multiple clients. Args: - results (List[Tuple[int, float]]): A list of sample counts and loss values (in that order). The sample counts + results (list[tuple[int, float]]): A list of sample counts and loss values (in that order). The sample counts from each client (training or validation samples where appropriate) are used if weighted averaging is requested. weighted (bool, optional): Whether or not the aggregation is a weighted average (by the sample counts diff --git a/fl4health/strategies/basic_fedavg.py b/fl4health/strategies/basic_fedavg.py index 27f09efee..7317fe34d 100644 --- a/fl4health/strategies/basic_fedavg.py +++ b/fl4health/strategies/basic_fedavg.py @@ -1,5 +1,5 @@ +from collections.abc import Callable from logging import INFO, WARNING -from typing import Callable, Dict, List, Optional, Tuple, Union from flwr.common import ( EvaluateIns, @@ -38,18 +38,15 @@ def __init__( min_fit_clients: int = 2, min_evaluate_clients: int = 2, min_available_clients: int = 2, - evaluate_fn: Optional[ - Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], - ] - ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + evaluate_fn: ( + Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None + ) = None, + on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None, + on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None, accept_failures: bool = True, - initial_parameters: Optional[Parameters] = None, - fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, - evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, + initial_parameters: Parameters | None = None, + fit_metrics_aggregation_fn: MetricsAggregationFn | None = None, + evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None, weighted_aggregation: bool = True, weighted_eval_losses: bool = True, ) -> None: @@ -72,20 +69,18 @@ def __init__( min_evaluate_clients (int, optional): Minimum number of clients used during validation. Defaults to 2. min_available_clients (int, optional): Minimum number of total clients in the system. Defaults to 2. - evaluate_fn (Optional[ - Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]] - ]): + evaluate_fn (Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None): Optional function used for central server-side evaluation. Defaults to None. - on_fit_config_fn (Optional[Callable[[int], Dict[str, Scalar]]], optional): + on_fit_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to configure training by providing a configuration dictionary. Defaults to None. - on_evaluate_config_fn (Optional[Callable[[int], Dict[str, Scalar]]], optional): + on_evaluate_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to configure client-side validation by providing a Config dictionary. Defaults to None. accept_failures (bool, optional): Whether or not accept rounds containing failures. Defaults to True. - initial_parameters (Optional[Parameters], optional): Initial global model parameters. Defaults to None. - fit_metrics_aggregation_fn (Optional[MetricsAggregationFn], optional): Metrics aggregation function. + initial_parameters (Parameters | None, optional): Initial global model parameters. Defaults to None. + fit_metrics_aggregation_fn (MetricsAggregationFn | None, optional): Metrics aggregation function. Defaults to None. - evaluate_metrics_aggregation_fn (Optional[MetricsAggregationFn], optional): Metrics aggregation function. + evaluate_metrics_aggregation_fn (MetricsAggregationFn | None, optional): Metrics aggregation function. Defaults to None. weighted_aggregation (bool, optional): Determines whether parameter aggregation is a linearly weighted average or a uniform average. FedAvg default is weighted average by client dataset counts. @@ -113,7 +108,7 @@ def __init__( def configure_fit( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, FitIns]]: + ) -> list[tuple[ClientProxy, FitIns]]: """ This function configures a sample of clients for a training round. It handles the case where the client manager has a sample fraction vs. a sample function (to allow for more flexible sampling). @@ -127,7 +122,7 @@ def configure_fit( client_manager (ClientManager): The manager used to sample from the available clients. Returns: - List[Tuple[ClientProxy, FitIns]]: List of sampled client identifiers and the configuration/parameters to + list[tuple[ClientProxy, FitIns]]: List of sampled client identifiers and the configuration/parameters to be sent to each client (packaged as FitIns). """ @@ -150,7 +145,7 @@ def configure_fit( def configure_evaluate( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, EvaluateIns]]: + ) -> list[tuple[ClientProxy, EvaluateIns]]: """ This function configures a sample of clients for a evaluation round. It handles the case where the client manager has a sample fraction vs. a sample function (to allow for more flexible sampling). @@ -164,7 +159,7 @@ def configure_evaluate( client_manager (ClientManager): The manager used to sample from the available clients. Returns: - List[Tuple[ClientProxy, EvaluateIns]]: List of sampled client identifiers and the configuration/parameters + list[tuple[ClientProxy, EvaluateIns]]: List of sampled client identifiers and the configuration/parameters to be sent to each client (packaged as EvaluateIns). """ @@ -192,7 +187,7 @@ def configure_evaluate( def configure_poll( self, server_round: int, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, GetPropertiesIns]]: + ) -> list[tuple[ClientProxy, GetPropertiesIns]]: """ This function configures everything required to request properties from ALL of the clients. The client manger, regardless of type, is instructed to grab all available clients to perform the polling process. @@ -202,7 +197,7 @@ def configure_poll( client_manager (ClientManager): The manager used to sample all available clients. Returns: - List[Tuple[ClientProxy, GetPropertiesIns]]: List of sampled client identifiers and the configuration + list[tuple[ClientProxy, GetPropertiesIns]]: List of sampled client identifiers and the configuration to be sent to each client (packaged as GetPropertiesIns). """ config = {} @@ -225,22 +220,22 @@ def configure_poll( def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[tuple[ClientProxy, FitRes] | BaseException], + ) -> tuple[Parameters | None, dict[str, Scalar]]: """ Aggregate the results from the federated fit round. This is done with either weighted or unweighted FedAvg, depending on the settings used for the strategy. Args: server_round (int): Indicates the server round we're currently on. - results (List[Tuple[ClientProxy, FitRes]]): The client identifiers and the results of their local training + results (list[tuple[ClientProxy, FitRes]]): The client identifiers and the results of their local training that need to be aggregated on the server-side. - failures (List[Union[Tuple[ClientProxy, FitRes], BaseException]]): These are the results and exceptions + failures (list[tuple[ClientProxy, FitRes] | BaseException]): These are the results and exceptions from clients that experienced an issue during training, such as timeouts or exceptions. Returns: - Tuple[Optional[Parameters], Dict[str, Scalar]]: The aggregated model weights and the metrics dictionary. + tuple[Parameters | None, dict[str, Scalar]]: The aggregated model weights and the metrics dictionary. """ if not results: return None, {} @@ -273,21 +268,21 @@ def aggregate_fit( def aggregate_evaluate( self, server_round: int, - results: List[Tuple[ClientProxy, EvaluateRes]], - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], - ) -> Tuple[Optional[float], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, EvaluateRes]], + failures: list[tuple[ClientProxy, EvaluateRes] | BaseException], + ) -> tuple[float | None, dict[str, Scalar]]: """ Aggregate the metrics and losses returned from the clients as a result of the evaluation round. Args: - results (List[Tuple[ClientProxy, EvaluateRes]]): The client identifiers and the results of their local + results (list[tuple[ClientProxy, EvaluateRes]]): The client identifiers and the results of their local evaluation that need to be aggregated on the server-side. These results are loss values and the metrics dictionary. - failures (List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]]): These are the results and + failures (list[tuple[ClientProxy, EvaluateRes] | BaseException]): These are the results and exceptions from clients that experienced an issue during evaluation, such as timeouts or exceptions. Returns: - Tuple[Optional[float], Dict[str, Scalar]]: Aggregated loss values and the aggregated metrics. The metrics + tuple[float | None, dict[str, Scalar]]: Aggregated loss values and the aggregated metrics. The metrics are aggregated according to evaluate_metrics_aggregation_fn. """ if not results: @@ -325,17 +320,14 @@ def __init__( min_fit_clients: int = 2, min_evaluate_clients: int = 2, min_available_clients: int = 2, - evaluate_fn: Optional[ - Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], - ] - ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + evaluate_fn: ( + Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None + ) = None, + on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None, + on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None, accept_failures: bool = True, - fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, - evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, + fit_metrics_aggregation_fn: MetricsAggregationFn | None = None, + evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None, weighted_aggregation: bool = True, weighted_eval_losses: bool = True, ) -> None: @@ -358,19 +350,17 @@ def __init__( min_evaluate_clients (int, optional): Minimum number of clients used during validation. Defaults to 2. min_available_clients (int, optional): Minimum number of total clients in the system. Defaults to 2. - evaluate_fn (Optional[ - Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]] - ]): + evaluate_fn (Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None): Optional function used for central server-side evaluation. Defaults to None. - on_fit_config_fn (Optional[Callable[[int], Dict[str, Scalar]]], optional): + on_fit_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to configure training by providing a configuration dictionary. Defaults to None. - on_evaluate_config_fn (Optional[Callable[[int], Dict[str, Scalar]]], optional): + on_evaluate_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to configure client-side validation by providing a Config dictionary. Defaults to None. accept_failures (bool, optional): Whether or not accept rounds containing failures. Defaults to True. - fit_metrics_aggregation_fn (Optional[MetricsAggregationFn], optional): Metrics aggregation function. + fit_metrics_aggregation_fn (MetricsAggregationFn | None, optional): Metrics aggregation function. Defaults to None. - evaluate_metrics_aggregation_fn (Optional[MetricsAggregationFn], optional): Metrics aggregation function. + evaluate_metrics_aggregation_fn (MetricsAggregationFn | None, optional): Metrics aggregation function. Defaults to None. weighted_aggregation (bool, optional): Determines whether parameter aggregation is a linearly weighted average or a uniform average. FedAvg default is weighted average by client dataset counts. diff --git a/fl4health/strategies/client_dp_fedavgm.py b/fl4health/strategies/client_dp_fedavgm.py index af2cecf02..5c3428865 100644 --- a/fl4health/strategies/client_dp_fedavgm.py +++ b/fl4health/strategies/client_dp_fedavgm.py @@ -1,6 +1,6 @@ import math +from collections.abc import Callable from logging import INFO, WARNING -from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np from flwr.common import ( @@ -38,21 +38,18 @@ def __init__( fraction_fit: float = 1.0, fraction_evaluate: float = 1.0, min_available_clients: int = 2, - evaluate_fn: Optional[ - Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], - ] - ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + evaluate_fn: ( + Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None + ) = None, + on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None, + on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None, accept_failures: bool = True, - initial_parameters: Optional[Parameters] = None, - fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, - evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, + initial_parameters: Parameters | None = None, + fit_metrics_aggregation_fn: MetricsAggregationFn | None = None, + evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None, weighted_aggregation: bool = False, weighted_eval_losses: bool = True, - per_client_example_cap: Optional[float] = None, + per_client_example_cap: float | None = None, adaptive_clipping: bool = False, server_learning_rate: float = 1.0, clipping_learning_rate: float = 1.0, @@ -77,28 +74,26 @@ def __init__( fraction_evaluate (float, optional): Fraction of clients used during validation. Defaults to 1.0. min_available_clients (int, optional): Minimum number of clients used during validation. Defaults to 2. - evaluate_fn (Optional[ - Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]] - ]): + evaluate_fn (Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None): Optional function used for central server-side evaluation. Defaults to None. - on_fit_config_fn (Optional[Callable[[int], Dict[str, Scalar]]], optional): + on_fit_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to configure training by providing a configuration dictionary. Defaults to None. - on_evaluate_config_fn (Optional[Callable[[int], Dict[str, Scalar]]], optional): + on_evaluate_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to configure client-side validation by providing a Config dictionary. Defaults to None. accept_failures (bool, optional): Whether or not accept rounds containing failures. Defaults to True. - initial_parameters (Optional[Parameters], optional): Initial global model parameters. This strategy assumes + initial_parameters (Parameters | None, optional): Initial global model parameters. This strategy assumes that the initial parameters is not None. So they need to be set in spite of the optional tag. - fit_metrics_aggregation_fn (Optional[MetricsAggregationFn], optional): Metrics aggregation function. + fit_metrics_aggregation_fn (MetricsAggregationFn | None, optional): Metrics aggregation function. Defaults to None. - evaluate_metrics_aggregation_fn (Optional[MetricsAggregationFn], optional): Metrics aggregation function. + evaluate_metrics_aggregation_fn (MetricsAggregationFn | None, optional): Metrics aggregation function. Defaults to None. weighted_aggregation (bool, optional): Determines whether the FedAvg update is weighted by client dataset size or unweighted. Defaults to False. weighted_eval_losses (bool, optional): Determines whether losses during evaluation are linearly weighted averages or a uniform average. FedAvg default is weighted average of the losses by client dataset counts. Defaults to True. - per_client_example_cap (Optional[float], optional): The maximum number samples per client. hat{w} in + per_client_example_cap (float | None, optional): The maximum number samples per client. hat{w} in https://arxiv.org/pdf/1710.06963.pdf. Defaults to None. adaptive_clipping (bool, optional): If enabled, the model expects the last entry of the parameter list to be a binary value indicating whether or not the batch gradient was clipped. Defaults to False. @@ -152,8 +147,8 @@ def __init__( # Weighted averaging requires list of sample counts # to compute client weights. Set by server after polling clients. - self.sample_counts: Optional[List[int]] = None - self.m_t: Optional[NDArrays] = None + self.sample_counts: list[int] | None = None + self.m_t: NDArrays | None = None def __repr__(self) -> str: rep = f"ClientLevelDPFedAvgM(accept_failures={self.accept_failures})" @@ -181,19 +176,19 @@ def modify_noise_multiplier(self) -> float: return pow(sqrt_argument, -0.5) def split_model_weights_and_clipping_bits( - self, results: List[Tuple[ClientProxy, FitRes]] - ) -> Tuple[List[Tuple[NDArrays, int]], NDArrays]: + self, results: list[tuple[ClientProxy, FitRes]] + ) -> tuple[list[tuple[NDArrays, int]], NDArrays]: """ Given results from an FL round of training, this function splits the result into sets of (weights, training counts) and clipping bits. The split is required because the clipping bits are packed with the weights in order to communicate them back to the server. The parameter packer facilitates this splitting. Args: - results (List[Tuple[ClientProxy, FitRes]]): The client identifiers and the results of their local training + results (list[tuple[ClientProxy, FitRes]]): The client identifiers and the results of their local training that need to be aggregated on the server-side. In this strategy, the clients pack the weights to be aggregated along with a clipping bit calculated during training. Returns: - Tuple[List[Tuple[NDArrays, int]], NDArrays]: The first tuple is the set of (weights, training counts) per + tuple[list[tuple[NDArrays, int]], NDArrays]: The first tuple is the set of (weights, training counts) per client. The second is a set of clipping bits, one for each client. """ # Sorting the results by elements and sample counts. This is primarily to reduce numerical fluctuations in @@ -203,7 +198,7 @@ def split_model_weights_and_clipping_bits( (weights, sample_counts) for _, weights, sample_counts in decode_and_pseudo_sort_results(results) ] - weights_and_counts: List[Tuple[NDArrays, int]] = [] + weights_and_counts: list[tuple[NDArrays, int]] = [] clipping_bits: NDArrays = [] for weights, sample_count in decoded_and_sorted_results: updated_weights, clipping_bit = self.parameter_packer.unpack_parameters(weights) @@ -274,9 +269,9 @@ def update_clipping_bound(self, clipping_bits: NDArrays) -> None: def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[tuple[ClientProxy, FitRes] | BaseException], + ) -> tuple[Parameters | None, dict[str, Scalar]]: """ Aggregate fit using averaging of weights (can be unweighted or weighted) and inject noise and optionally perform adaptive clipping updates. @@ -286,14 +281,14 @@ def aggregate_fit( Args: server_round (int): Indicates the server round we're currently on. - results (List[Tuple[ClientProxy, FitRes]]): The client identifiers and the results of their local training + results (list[tuple[ClientProxy, FitRes]]): The client identifiers and the results of their local training that need to be aggregated on the server-side. In this strategy, the clients pack the weights to be aggregated along with a clipping bit calculated during their local training cycle. - failures (List[Union[Tuple[ClientProxy, FitRes], BaseException]]): These are the results and exceptions + failures (list[tuple[ClientProxy, FitRes] | BaseException]): These are the results and exceptions from clients that experienced an issue during training, such as timeouts or exceptions. Returns: - Tuple[Optional[Parameters], Dict[str, Scalar]]: The aggregated model weights and the metrics dictionary. + tuple[Parameters | None, dict[str, Scalar]]: The aggregated model weights and the metrics dictionary. For this strategy, the server also packs a clipping bound to be sent to the clients. This is sent even if adaptive clipping is turned off and the value simply remains constant. """ @@ -364,7 +359,7 @@ def aggregate_fit( def configure_fit( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, FitIns]]: + ) -> list[tuple[ClientProxy, FitIns]]: """ This function configures a sample of clients for a training round. Due to the privacy accounting, this strategy requires that the sampling manager be of type BaseFractionSamplingManager. @@ -380,7 +375,7 @@ def configure_fit( be BaseFractionSamplingManager, which has a sample_fraction function built in. Returns: - List[Tuple[ClientProxy, FitIns]]: List of sampled client identifiers and the configuration/parameters to + list[tuple[ClientProxy, FitIns]]: List of sampled client identifiers and the configuration/parameters to be sent to each client (packaged as FitIns). """ # This strategy requires the client manager to be of type at least BaseFractionSamplingManager @@ -399,7 +394,7 @@ def configure_fit( def configure_evaluate( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, EvaluateIns]]: + ) -> list[tuple[ClientProxy, EvaluateIns]]: """ This function configures a sample of clients for an eval round. Due to the privacy accounting, this strategy requires that the sampling manager be of type BaseFractionSamplingManager. @@ -415,7 +410,7 @@ def configure_evaluate( be BaseFractionSamplingManager, which has a sample_fraction function built in. Returns: - List[Tuple[ClientProxy, EvaluateIns]]: List of sampled client identifiers and the configuration/parameters + list[tuple[ClientProxy, EvaluateIns]]: List of sampled client identifiers and the configuration/parameters to be sent to each client (packaged as EvaluateIns) """ diff --git a/fl4health/strategies/fedavg_dynamic_layer.py b/fl4health/strategies/fedavg_dynamic_layer.py index 9d2fd8cfe..c2bbddc9d 100644 --- a/fl4health/strategies/fedavg_dynamic_layer.py +++ b/fl4health/strategies/fedavg_dynamic_layer.py @@ -1,7 +1,7 @@ from collections import defaultdict +from collections.abc import Callable from functools import reduce from logging import WARNING -from typing import Callable, DefaultDict, Dict, List, Optional, Tuple, Union import numpy as np from flwr.common import MetricsAggregationFn, NDArray, NDArrays, Parameters, ndarrays_to_parameters @@ -23,18 +23,15 @@ def __init__( min_fit_clients: int = 2, min_evaluate_clients: int = 2, min_available_clients: int = 2, - evaluate_fn: Optional[ - Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], - ] - ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + evaluate_fn: ( + Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None + ) = None, + on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None, + on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None, accept_failures: bool = True, - initial_parameters: Optional[Parameters] = None, - fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, - evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, + initial_parameters: Parameters | None = None, + fit_metrics_aggregation_fn: MetricsAggregationFn | None = None, + evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None, weighted_aggregation: bool = True, weighted_eval_losses: bool = True, ) -> None: @@ -48,20 +45,18 @@ def __init__( min_fit_clients (int, optional): Minimum number of clients used during fitting. Defaults to 2. min_evaluate_clients (int, optional): Minimum number of clients used during validation. Defaults to 2. min_available_clients (int, optional): Minimum number of clients used during validation. Defaults to 2. - evaluate_fn (Optional[ - Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]] - ]): + evaluate_fn (Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None): Optional function used for central server-side evaluation. Defaults to None. - on_fit_config_fn (Optional[Callable[[int], Dict[str, Scalar]]], optional): + on_fit_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to configure training by providing a configuration dictionary. Defaults to None. - on_evaluate_config_fn (Optional[Callable[[int], Dict[str, Scalar]]], optional): + on_evaluate_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to configure client-side validation by providing a Config dictionary. Defaults to None. accept_failures (bool, optional): Whether or not accept rounds containing failures. Defaults to True. - initial_parameters (Optional[Parameters], optional): Initial global model parameters. Defaults to None. - fit_metrics_aggregation_fn (Optional[MetricsAggregationFn], optional): Metrics aggregation function. + initial_parameters (Parameters | None, optional): Initial global model parameters. Defaults to None. + fit_metrics_aggregation_fn (MetricsAggregationFn | None, optional): Metrics aggregation function. Defaults to None. - evaluate_metrics_aggregation_fn (Optional[MetricsAggregationFn], optional): Metrics aggregation function. + evaluate_metrics_aggregation_fn (MetricsAggregationFn | None, optional): Metrics aggregation function. Defaults to None. weighted_aggregation (bool, optional): Determines whether parameter aggregation is a linearly weighted average or a uniform average. FedAvg default is weighted average by client dataset counts. @@ -91,9 +86,9 @@ def __init__( def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[tuple[ClientProxy, FitRes] | BaseException], + ) -> tuple[Parameters | None, dict[str, Scalar]]: """ Aggregate the results from the federated fit round. The aggregation requires some special treatment, as the participating clients are allowed to exchange an arbitrary set of weights. So before aggregation takes place @@ -101,14 +96,14 @@ def aggregate_fit( Args: server_round (int): Indicates the server round we're currently on. - results (List[Tuple[ClientProxy, FitRes]]): The client identifiers and the results of their local training + results (list[tuple[ClientProxy, FitRes]]): The client identifiers and the results of their local training that need to be aggregated on the server-side. In this scheme, the clients pack the layer weights into the results object along with the weight values to allow for alignment during aggregation. - failures (List[Union[Tuple[ClientProxy, FitRes], BaseException]]): These are the results and exceptions + failures (list[tuple[ClientProxy, FitRes] | BaseException]): These are the results and exceptions from clients that experienced an issue during training, such as timeouts or exceptions. Returns: - Tuple[Optional[Parameters], Dict[str, Scalar]]: The aggregated model weights and the metrics dictionary. + tuple[Parameters | None, dict[str, Scalar]]: The aggregated model weights and the metrics dictionary. For dynamic layer exchange we also pack in the names of all of the layers that were aggregated in this phase to allow client's to insert the values into the proper areas of their models. """ @@ -148,19 +143,19 @@ def aggregate_fit( return ndarrays_to_parameters(parameters), metrics_aggregated - def aggregate(self, results: List[Tuple[NDArrays, int]]) -> Dict[str, NDArray]: + def aggregate(self, results: list[tuple[NDArrays, int]]) -> dict[str, NDArray]: """ Aggregate the different layers across clients that have contributed to a layer. This aggregation may be weighted or unweighted. The called functions handle layer alignment. Args: - results (List[Tuple[NDArrays, int]]): The weight results from each client's local training that need to be + results (list[tuple[NDArrays, int]]): The weight results from each client's local training that need to be aggregated on the server-side and the number of training samples held on each client. In this scheme, the clients pack the layer weights into the results object along with the weight values to allow for alignment during aggregation. Returns: - Dict[str, NDArray]: A dictionary mapping the name of the layer that was aggregated to the aggregated + dict[str, NDArray]: A dictionary mapping the name of the layer that was aggregated to the aggregated weights. """ if self.weighted_aggregation: @@ -168,24 +163,24 @@ def aggregate(self, results: List[Tuple[NDArrays, int]]) -> Dict[str, NDArray]: else: return self.unweighted_aggregate(results) - def weighted_aggregate(self, results: List[Tuple[NDArrays, int]]) -> Dict[str, NDArray]: + def weighted_aggregate(self, results: list[tuple[NDArrays, int]]) -> dict[str, NDArray]: """ Results consists of the layer weights (and their names) sent by clients who participated in this round of training. Since each client can send an arbitrary subset of layers, the aggregate performs weighted averaging for each layer separately. Args: - results (List[Tuple[NDArrays, int]]): The weight results from each client's local training that need to be + results (list[tuple[NDArrays, int]]): The weight results from each client's local training that need to be aggregated on the server-side and the number of training samples held on each client. In this scheme, the clients pack the layer weights into the results object along with the weight values to allow for alignment during aggregation. Returns: - Dict[str, NDArray]: A dictionary mapping the name of the layer that was aggregated to the aggregated + dict[str, NDArray]: A dictionary mapping the name of the layer that was aggregated to the aggregated weights. """ - names_to_layers: DefaultDict[str, List[NDArray]] = defaultdict(list) - total_num_examples: DefaultDict[str, int] = defaultdict(int) + names_to_layers: defaultdict[str, list[NDArray]] = defaultdict(list) + total_num_examples: defaultdict[str, int] = defaultdict(int) for packed_layers, num_examples in results: layers, names = self.parameter_packer.unpack_parameters(packed_layers) @@ -200,24 +195,24 @@ def weighted_aggregate(self, results: List[Tuple[NDArrays, int]]) -> Dict[str, N return name_to_layers_aggregated - def unweighted_aggregate(self, results: List[Tuple[NDArrays, int]]) -> Dict[str, NDArray]: + def unweighted_aggregate(self, results: list[tuple[NDArrays, int]]) -> dict[str, NDArray]: """ Results consists of the layer weights (and their names) sent by clients who participated in this round of training. Since each client can send an arbitrary subset of layers, the aggregate performs uniform averaging for each layer separately. Args: - results (List[Tuple[NDArrays, int]]): The weight results from each client's local training that need to be + results (list[tuple[NDArrays, int]]): The weight results from each client's local training that need to be aggregated on the server-side and the number of training samples held on each client. In this scheme, the clients pack the layer weights into the results object along with the weight values to allow for alignment during aggregation. Returns: - Dict[str, NDArray]: A dictionary mapping the name of the layer that was aggregated to the aggregated + dict[str, NDArray]: A dictionary mapping the name of the layer that was aggregated to the aggregated weights. """ - names_to_layers: DefaultDict[str, List[NDArray]] = defaultdict(list) - total_num_clients: DefaultDict[str, int] = defaultdict(int) + names_to_layers: defaultdict[str, list[NDArray]] = defaultdict(list) + total_num_clients: defaultdict[str, int] = defaultdict(int) for packed_layers, _ in results: layers, names = self.parameter_packer.unpack_parameters(packed_layers) diff --git a/fl4health/strategies/fedavg_sparse_coo_tensor.py b/fl4health/strategies/fedavg_sparse_coo_tensor.py index 1c7703d7f..ccbbc11a2 100644 --- a/fl4health/strategies/fedavg_sparse_coo_tensor.py +++ b/fl4health/strategies/fedavg_sparse_coo_tensor.py @@ -1,7 +1,7 @@ from collections import defaultdict +from collections.abc import Callable from functools import reduce from logging import WARNING -from typing import Callable, DefaultDict, Dict, List, Optional, Tuple, Union import torch from flwr.common import MetricsAggregationFn, NDArrays, Parameters, ndarrays_to_parameters @@ -24,18 +24,15 @@ def __init__( min_fit_clients: int = 2, min_evaluate_clients: int = 2, min_available_clients: int = 2, - evaluate_fn: Optional[ - Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], - ] - ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + evaluate_fn: ( + Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None + ) = None, + on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None, + on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None, accept_failures: bool = True, - initial_parameters: Optional[Parameters] = None, - fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, - evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, + initial_parameters: Parameters | None = None, + fit_metrics_aggregation_fn: MetricsAggregationFn | None = None, + evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None, weighted_aggregation: bool = True, weighted_eval_losses: bool = True, ) -> None: @@ -61,20 +58,18 @@ def __init__( min_fit_clients (int, optional): _description_. Defaults to 2. min_evaluate_clients (int, optional): Minimum number of clients used during validation. Defaults to 2. min_available_clients (int, optional): Minimum number of clients used during validation. Defaults to 2. - evaluate_fn (Optional[ - Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]] - ]): + evaluate_fn (Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None): Optional function used for central server-side evaluation. Defaults to None. - on_fit_config_fn (Optional[Callable[[int], Dict[str, Scalar]]], optional): + on_fit_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to configure training by providing a configuration dictionary. Defaults to None. - on_evaluate_config_fn (Optional[Callable[[int], Dict[str, Scalar]]], optional): + on_evaluate_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to configure client-side validation by providing a Config dictionary. Defaults to None. accept_failures (bool, optional): Whether or not accept rounds containing failures. Defaults to True. - initial_parameters (Optional[Parameters], optional): Initial global model parameters. Defaults to None. - fit_metrics_aggregation_fn (Optional[MetricsAggregationFn], optional): Metrics aggregation function. + initial_parameters (Parameters | None, optional): Initial global model parameters. Defaults to None. + fit_metrics_aggregation_fn (MetricsAggregationFn | None, optional): Metrics aggregation function. Defaults to None. - evaluate_metrics_aggregation_fn (Optional[MetricsAggregationFn], optional): Metrics aggregation function. + evaluate_metrics_aggregation_fn (MetricsAggregationFn | None, optional): Metrics aggregation function. Defaults to None. weighted_aggregation (bool, optional): Determines whether parameter aggregation is a linearly weighted average or a uniform average. FedAvg default is weighted average by client dataset counts. @@ -104,9 +99,9 @@ def __init__( def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[tuple[ClientProxy, FitRes] | BaseException], + ) -> tuple[Parameters | None, dict[str, Scalar]]: """ Aggregate the results from the federated fit round. The aggregation requires some special treatment, as the participating clients are allowed to exchange an arbitrary set of parameters. So before aggregation takes place @@ -121,14 +116,14 @@ def aggregate_fit( Args: server_round (int): Indicates the server round we're currently on. - results (List[Tuple[ClientProxy, FitRes]]): The client identifiers and the results of their local training + results (list[tuple[ClientProxy, FitRes]]): The client identifiers and the results of their local training that need to be aggregated on the server-side. In this scheme, the clients pack the tensor names into the results object along with the weight values to allow for alignment during aggregation. - failures (List[Union[Tuple[ClientProxy, FitRes], BaseException]]): These are the results and exceptions + failures (list[tuple[ClientProxy, FitRes] | BaseException]): These are the results and exceptions from clients that experienced an issue during training, such as timeouts or exceptions. Returns: - Tuple[Optional[Parameters], Dict[str, Scalar]]: The aggregated model weights and the metrics dictionary. + tuple[Parameters | None, dict[str, Scalar]]: The aggregated model weights and the metrics dictionary. For sparse tensor exchange we also pack in the names of all of the tensors that were aggregated in this phase to allow clients to insert the values into the proper areas of their models. """ @@ -178,14 +173,14 @@ def aggregate_fit( return ndarrays_to_parameters(packed_parameters), metrics_aggregated - def aggregate(self, results: List[Tuple[NDArrays, int]]) -> Dict[str, Tensor]: + def aggregate(self, results: list[tuple[NDArrays, int]]) -> dict[str, Tensor]: """ Aggregate the different tensors across clients that have contributed to a certain tensor. This aggregation may be weighted or unweighted. The called functions handle tensor alignment. Args: - results (List[Tuple[NDArrays, int]]): The weight results from each client's local training + results (list[tuple[NDArrays, int]]): The weight results from each client's local training that need to be aggregated on the server-side and the number of training samples held on each client. @@ -193,7 +188,7 @@ def aggregate(self, results: List[Tuple[NDArrays, int]]) -> Dict[str, Tensor]: the weight values to allow for alignment during aggregation. Returns: - Dict[str, Tensor]: A dictionary mapping the name of the tensor that was aggregated to the aggregated + dict[str, Tensor]: A dictionary mapping the name of the tensor that was aggregated to the aggregated weights. """ if self.weighted_aggregation: @@ -201,7 +196,7 @@ def aggregate(self, results: List[Tuple[NDArrays, int]]) -> Dict[str, Tensor]: else: return self.unweighted_aggregate(results) - def weighted_aggregate(self, results: List[Tuple[NDArrays, int]]) -> Dict[str, Tensor]: + def weighted_aggregate(self, results: list[tuple[NDArrays, int]]) -> dict[str, Tensor]: """ "results" consist of four parts: the exchanged (nonzero) parameter values, their coordinates within the tensor to which they belong, @@ -222,18 +217,18 @@ def weighted_aggregate(self, results: List[Tuple[NDArrays, int]]) -> Dict[str, T Note: this method performs weighted averaging. Args: - results (List[Tuple[NDArrays, int]]): The weight results from each client's local training that need to be + results (list[tuple[NDArrays, int]]): The weight results from each client's local training that need to be aggregated on the server-side and the number of training samples held on each client. The weight results consist of four parts, as detailed above. In this scheme, the clients pack the layer names into the results object along with the weight values to allow for alignment during aggregation. Returns: - Dict[str, Tensor]: A dictionary mapping the name of the tensor that was aggregated to the aggregated + dict[str, Tensor]: A dictionary mapping the name of the tensor that was aggregated to the aggregated weights. """ - names_to_dense_tensors: DefaultDict[str, List[Tensor]] = defaultdict(list) - total_num_examples: DefaultDict[str, int] = defaultdict(int) + names_to_dense_tensors: defaultdict[str, list[Tensor]] = defaultdict(list) + total_num_examples: defaultdict[str, int] = defaultdict(int) for packed_parameters, num_examples in results: nonzero_parameter_values, additional_info = self.parameter_packer.unpack_parameters(packed_parameters) @@ -263,7 +258,7 @@ def weighted_aggregate(self, results: List[Tuple[NDArrays, int]]) -> Dict[str, T return names_to_tensors_aggregated - def unweighted_aggregate(self, results: List[Tuple[NDArrays, int]]) -> Dict[str, Tensor]: + def unweighted_aggregate(self, results: list[tuple[NDArrays, int]]) -> dict[str, Tensor]: """ "results" consist of four parts: the exchanged (nonzero) parameter values, their coordinates within the tensor to which they belong, @@ -283,18 +278,18 @@ def unweighted_aggregate(self, results: List[Tuple[NDArrays, int]]) -> Dict[str, Note: this method performs uniform averaging. Args: - results (List[Tuple[NDArrays, int]]): The weight results from each client's local training that need to be + results (list[tuple[NDArrays, int]]): The weight results from each client's local training that need to be aggregated on the server-side and the number of training samples held on each client. The weight results consist of four parts, as detailed above. In this scheme, the clients pack the layer names into the results object along with the weight values to allow for alignment during aggregation. Returns: - Dict[str, Tensor]: A dictionary mapping the name of the tensor that was aggregated to the aggregated + dict[str, Tensor]: A dictionary mapping the name of the tensor that was aggregated to the aggregated weights. """ - names_to_dense_tensors: DefaultDict[str, List[Tensor]] = defaultdict(list) - total_num_clients: DefaultDict[str, int] = defaultdict(int) + names_to_dense_tensors: defaultdict[str, list[Tensor]] = defaultdict(list) + total_num_clients: defaultdict[str, int] = defaultdict(int) for packed_parameters, _ in results: nonzero_parameter_values, additional_info = self.parameter_packer.unpack_parameters(packed_parameters) diff --git a/fl4health/strategies/fedavg_with_adaptive_constraint.py b/fl4health/strategies/fedavg_with_adaptive_constraint.py index d60c346b9..2c4d7593c 100644 --- a/fl4health/strategies/fedavg_with_adaptive_constraint.py +++ b/fl4health/strategies/fedavg_with_adaptive_constraint.py @@ -1,5 +1,5 @@ +from collections.abc import Callable from logging import INFO, WARNING -from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np from flwr.common import MetricsAggregationFn, NDArrays, Parameters, ndarrays_to_parameters, parameters_to_ndarrays @@ -22,18 +22,15 @@ def __init__( min_fit_clients: int = 2, min_evaluate_clients: int = 2, min_available_clients: int = 2, - evaluate_fn: Optional[ - Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], - ] - ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + evaluate_fn: ( + Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None + ) = None, + on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None, + on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None, accept_failures: bool = True, initial_parameters: Parameters, - fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, - evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, + fit_metrics_aggregation_fn: MetricsAggregationFn | None = None, + evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None, initial_loss_weight: float = 1.0, adapt_loss_weight: bool = False, loss_weight_delta: float = 0.1, @@ -62,20 +59,18 @@ def __init__( min_evaluate_clients (int, optional): Minimum number of clients used during validation. Defaults to 2. min_available_clients (int, optional): Minimum number of total clients in the system. Defaults to 2. - evaluate_fn (Optional[ - Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]] - ]): + evaluate_fn (Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None): Optional function used for central server-side evaluation. Defaults to None. - on_fit_config_fn (Optional[Callable[[int], Dict[str, Scalar]]], optional): + on_fit_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to configure training by providing a configuration dictionary. Defaults to None. - on_evaluate_config_fn (Optional[Callable[[int], Dict[str, Scalar]]], optional): + on_evaluate_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to configure client-side validation by providing a Config dictionary. Defaults to None. accept_failures (bool, optional): Whether or not accept rounds containing failures. Defaults to True. initial_parameters (Parameters): Initial global model parameters. - fit_metrics_aggregation_fn (Optional[MetricsAggregationFn], optional): Metrics aggregation function. + fit_metrics_aggregation_fn (MetricsAggregationFn | None, optional): Metrics aggregation function. Defaults to None. - evaluate_metrics_aggregation_fn (Optional[MetricsAggregationFn], optional): Metrics aggregation function. + evaluate_metrics_aggregation_fn (MetricsAggregationFn | None, optional): Metrics aggregation function. Defaults to None. initial_loss_weight (float): Initial loss weight (mu in FedProx). If adaptivity is false, then this is the constant weight used for all clients. @@ -132,23 +127,23 @@ def __init__( def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[tuple[ClientProxy, FitRes] | BaseException], + ) -> tuple[Parameters | None, dict[str, Scalar]]: """ Aggregate the results from the federated fit round and, if applicable, determine whether the constraint weight should be updated based on the aggregated loss seen on the clients. Args: server_round (int): Indicates the server round we're currently on. - results (List[Tuple[ClientProxy, FitRes]]): The client identifiers and the results of their local training + results (list[tuple[ClientProxy, FitRes]]): The client identifiers and the results of their local training that need to be aggregated on the server-side. For adaptive constraints, the clients pack the weights to be aggregated along with the training loss seen during their local training cycle. - failures (List[Union[Tuple[ClientProxy, FitRes], BaseException]]): These are the results and exceptions + failures (list[tuple[ClientProxy, FitRes] | BaseException]): These are the results and exceptions from clients that experienced an issue during training, such as timeouts or exceptions. Returns: - Tuple[Optional[Parameters], Dict[str, Scalar]]: The aggregated model weights and the metrics dictionary. + tuple[Parameters | None, dict[str, Scalar]]: The aggregated model weights and the metrics dictionary. For adaptive constraints, the server also packs a constraint weight to be sent to the clients. This is sent even if adaptive constraint weights are turned off and the value simply remains constant. """ @@ -166,8 +161,8 @@ def aggregate_fit( ] # Convert results with packed params of model weights and training loss - weights_and_counts: List[Tuple[NDArrays, int]] = [] - train_losses_and_counts: List[Tuple[int, float]] = [] + weights_and_counts: list[tuple[NDArrays, int]] = [] + train_losses_and_counts: list[tuple[int, float]] = [] for weights, sample_count in decoded_and_sorted_results: updated_weights, train_loss = self.parameter_packer.unpack_parameters(weights) weights_and_counts.append((updated_weights, sample_count)) diff --git a/fl4health/strategies/feddg_ga.py b/fl4health/strategies/feddg_ga.py index 108758707..b88eb2aee 100644 --- a/fl4health/strategies/feddg_ga.py +++ b/fl4health/strategies/feddg_ga.py @@ -1,6 +1,6 @@ +from collections.abc import Callable from enum import Enum from logging import INFO, WARNING -from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np from flwr.common import EvaluateIns, MetricsAggregationFn, NDArrays, Parameters, ndarrays_to_parameters @@ -57,8 +57,8 @@ class FairnessMetric: def __init__( self, metric_type: FairnessMetricType, - metric_name: Optional[str] = None, - signal: Optional[float] = None, + metric_name: str | None = None, + signal: float | None = None, ): """ Instantiates a fairness metric with a type and optional metric name and @@ -95,19 +95,16 @@ def __init__( min_fit_clients: int = 2, min_evaluate_clients: int = 2, min_available_clients: int = 2, - evaluate_fn: Optional[ - Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], - ] - ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + evaluate_fn: ( + Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None + ) = None, + on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None, + on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None, accept_failures: bool = True, - initial_parameters: Optional[Parameters] = None, - fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, - evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, - fairness_metric: Optional[FairnessMetric] = None, + initial_parameters: Parameters | None = None, + fit_metrics_aggregation_fn: MetricsAggregationFn | None = None, + evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None, + fairness_metric: FairnessMetric | None = None, adjustment_weight_step_size: float = 0.2, ): """ @@ -124,23 +121,19 @@ def __init__( Minimum number of clients used during validation. Defaults to 2. min_available_clients : int, optional Minimum number of total clients in the system. Defaults to 2. - evaluate_fn : - Optional[ - Callable[[int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]]] - ] + evaluate_fn (Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None): Optional function used for validation. Defaults to None. - on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional + on_fit_config_fn : Callable[[int], dict[str, Scalar]], optional Function used to configure training. Must be specified for this strategy. Defaults to None. - on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional + on_evaluate_config_fn : Callable[[int], dict[str, Scalar]], optional Function used to configure validation. Must be specified for this strategy. Defaults to None. accept_failures : bool, optional Whether or not accept rounds containing failures. Defaults to True. initial_parameters : Parameters, optional Initial global model parameters. - fit_metrics_aggregation_fn : Optional[MetricsAggregationFn] + fit_metrics_aggregation_fn : MetricsAggregationFn | None Metrics aggregation function, optional. - evaluate_metrics_aggregation_fn : Optional[MetricsAggregationFn] + evaluate_metrics_aggregation_fn : MetricsAggregationFn | None Metrics aggregation function, optional. fairness_metric : FairnessMetric, optional. The metric to evaluate the local model of each client against the global model in order to @@ -184,18 +177,18 @@ def __init__( log(INFO, f"FedDG-GA Strategy initialized with weight_step_size of {self.adjustment_weight_step_size}") log(INFO, f"FedDG-GA Strategy initialized with FairnessMetric {self.fairness_metric}") - self.train_metrics: Dict[str, Dict[str, Scalar]] = {} - self.evaluation_metrics: Dict[str, Dict[str, Scalar]] = {} - self.num_rounds: Optional[int] = None - self.initial_adjustment_weight: Optional[float] = None - self.adjustment_weights: Dict[str, float] = {} + self.train_metrics: dict[str, dict[str, Scalar]] = {} + self.evaluation_metrics: dict[str, dict[str, Scalar]] = {} + self.num_rounds: int | None = None + self.initial_adjustment_weight: float | None = None + self.adjustment_weights: dict[str, float] = {} def configure_fit( self, server_round: int, parameters: Parameters, client_manager: ClientManager, - ) -> List[Tuple[ClientProxy, FitIns]]: + ) -> list[tuple[ClientProxy, FitIns]]: """ Configure the next round of training. @@ -210,7 +203,7 @@ def configure_fit( connected clients. It must be an instance of FixedSamplingClientManager. Returns: - (List[Tuple[ClientProxy, FitIns]]) the input for the clients' fit function. + (list[tuple[ClientProxy, FitIns]]) the input for the clients' fit function. """ assert isinstance( client_manager, FixedSamplingClientManager @@ -246,7 +239,7 @@ def configure_fit( def configure_evaluate( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, EvaluateIns]]: + ) -> list[tuple[ClientProxy, EvaluateIns]]: assert isinstance( client_manager, FixedSamplingClientManager ), f"Client manager is not of type FixedSamplingClientManager: {type(client_manager)}" @@ -263,9 +256,9 @@ def configure_evaluate( def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[tuple[ClientProxy, FitRes] | BaseException], + ) -> tuple[Parameters | None, dict[str, Scalar]]: """ Aggregate fit results by weighing them against the adjustment weights and then summing them. @@ -273,11 +266,11 @@ def aggregate_fit( Args: server_round: (int) the current server round. - results: (List[Tuple[ClientProxy, FitRes]]) The clients' fit results. - failures: (List[Union[Tuple[ClientProxy, FitRes], BaseException]]) the clients' fit failures. + results: (list[tuple[ClientProxy, FitRes]]) The clients' fit results. + failures: (list[tuple[ClientProxy, FitRes] | BaseException]) the clients' fit failures. Returns: - (Tuple[Optional[Parameters], Dict[str, Scalar]]) A tuple containing the aggregated parameters + (tuple[Parameters | None, dict[str, Scalar]]) A tuple containing the aggregated parameters and the aggregated fit metrics. """ if not results: @@ -305,9 +298,9 @@ def aggregate_fit( def aggregate_evaluate( self, server_round: int, - results: List[Tuple[ClientProxy, EvaluateRes]], - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], - ) -> Tuple[Optional[float], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, EvaluateRes]], + failures: list[tuple[ClientProxy, EvaluateRes] | BaseException], + ) -> tuple[float | None, dict[str, Scalar]]: """ Aggregate evaluation losses using weighted average. @@ -316,11 +309,11 @@ def aggregate_evaluate( Args: server_round: (int) the current server round. - results: (List[Tuple[ClientProxy, FitRes]]) The clients' evaluate results. - failures: (List[Union[Tuple[ClientProxy, FitRes], BaseException]]) the clients' evaluate failures. + results: (list[tuple[ClientProxy, FitRes]]) The clients' evaluate results. + failures: (list[tuple[ClientProxy, FitRes] | BaseException]) the clients' evaluate failures. Returns: - (Tuple[Optional[float], Dict[str, Scalar]]) A tuple containing the aggregated evaluation loss + (tuple[float | None, dict[str, Scalar]]) A tuple containing the aggregated evaluation loss and the aggregated evaluation metrics. """ @@ -340,12 +333,12 @@ def aggregate_evaluate( return loss_aggregated, metrics_aggregated - def weight_and_aggregate_results(self, results: List[Tuple[ClientProxy, FitRes]]) -> NDArrays: + def weight_and_aggregate_results(self, results: list[tuple[ClientProxy, FitRes]]) -> NDArrays: """ Aggregate results by weighing them against the adjustment weights and then summing them. Args: - results: (List[Tuple[ClientProxy, FitRes]]) The clients' fit results. + results: (list[tuple[ClientProxy, FitRes]]) The clients' fit results. Returns: (NDArrays) the weighted and aggregated results. @@ -363,7 +356,7 @@ def weight_and_aggregate_results(self, results: List[Tuple[ClientProxy, FitRes]] # reducing numerical fluctuation. decoded_and_sorted_results = decode_and_pseudo_sort_results(results) - aggregated_results: Optional[NDArrays] = None + aggregated_results: NDArrays | None = None for client_proxy, weights, _ in decoded_and_sorted_results: cid = client_proxy.cid @@ -390,14 +383,14 @@ def weight_and_aggregate_results(self, results: List[Tuple[ClientProxy, FitRes]] assert aggregated_results is not None return aggregated_results - def update_weights_by_ga(self, server_round: int, cids: List[str]) -> None: + def update_weights_by_ga(self, server_round: int, cids: list[str]) -> None: """ Update the self.adjustment_weights dictionary by calculating the new weights based on the current server round, fit and evaluation metrics. Args: server_round: (int) the current server round. - cids: (List[str]) the list of client ids that participated in this round. + cids: (list[str]) the list of client ids that participated in this round. """ generalization_gaps = [] # calculating local vs global metric difference (generalization gaps) diff --git a/fl4health/strategies/feddg_ga_with_adaptive_constraint.py b/fl4health/strategies/feddg_ga_with_adaptive_constraint.py index 8ab069492..75b236bc4 100644 --- a/fl4health/strategies/feddg_ga_with_adaptive_constraint.py +++ b/fl4health/strategies/feddg_ga_with_adaptive_constraint.py @@ -1,5 +1,5 @@ +from collections.abc import Callable from logging import INFO, WARNING -from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np from flwr.common import MetricsAggregationFn, NDArrays, Parameters, ndarrays_to_parameters, parameters_to_ndarrays @@ -19,24 +19,21 @@ def __init__( min_fit_clients: int = 2, min_evaluate_clients: int = 2, min_available_clients: int = 2, - evaluate_fn: Optional[ - Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], - ] - ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + evaluate_fn: ( + Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None + ) = None, + on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None, + on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None, accept_failures: bool = True, initial_parameters: Parameters, - fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, - evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, + fit_metrics_aggregation_fn: MetricsAggregationFn | None = None, + evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None, initial_loss_weight: float = 1.0, adapt_loss_weight: bool = False, loss_weight_delta: float = 0.1, loss_weight_patience: int = 5, weighted_train_losses: bool = False, - fairness_metric: Optional[FairnessMetric] = None, + fairness_metric: FairnessMetric | None = None, adjustment_weight_step_size: float = 0.2, ): """ @@ -50,21 +47,17 @@ def __init__( min_fit_clients (int, optional): Minimum number of clients used during training. Defaults to 2. min_evaluate_clients (int, optional): Minimum number of clients used during validation. Defaults to 2. min_available_clients (int, optional): Minimum number of total clients in the system. Defaults to 2. - evaluate_fn : - Optional[ - Callable[[int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]]] - ] + evaluate_fn (Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None): Optional function used for validation. Defaults to None. - on_fit_config_fn (Optional[Callable[[int], Dict[str, Scalar]]], optional): Function used to configure + on_fit_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to configure training. Defaults to None. - on_evaluate_config_fn (Optional[Callable[[int], Dict[str, Scalar]]], optional): Function used to configure + on_evaluate_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to configure validation. Defaults to None initial_parameters (Parameters): Initial global model parameters. accept_failures (bool, optional): Whether or not accept rounds containing failures. Defaults to True. - fit_metrics_aggregation_fn (Optional[MetricsAggregationFn], optional): + fit_metrics_aggregation_fn (MetricsAggregationFn | None, optional): Metrics aggregation function, Defaults to None. - evaluate_metrics_aggregation_fn (Optional[MetricsAggregationFn], optional): + evaluate_metrics_aggregation_fn (MetricsAggregationFn | None, optional): Metrics aggregation function. Defaults to None. initial_loss_weight (float, optional): Initial penalty loss weight (mu in FedProx). If adaptivity is false, then this is the constant weight used for all clients. Defaults to 1.0. @@ -79,7 +72,7 @@ def __init__( weighted_train_losses (bool, optional): Determines whether the training losses from the clients should be aggregated using a weighted or unweighted average. These aggregated losses are used to adjust the proximal weight in the adaptive setting. Defaults to False. - fairness_metric (Optional[FairnessMetric], optional): he metric to evaluate the local model of each + fairness_metric (FairnessMetric | None, optional): he metric to evaluate the local model of each client against the global model in order to determine their adjustment weight for aggregation. Can be set to any default metric in FairnessMetricType or set to use a custom metric. Optional, default is FairnessMetric(FairnessMetricType.LOSS) when specified as None. @@ -122,9 +115,9 @@ def __init__( def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[tuple[ClientProxy, FitRes] | BaseException], + ) -> tuple[Parameters | None, dict[str, Scalar]]: """ Aggregate fit results by weighing them against the adjustment weights and then summing them. @@ -135,11 +128,11 @@ def aggregate_fit( Args: server_round: (int) the current server round. - results: (List[Tuple[ClientProxy, FitRes]]) The clients' fit results. - failures: (List[Union[Tuple[ClientProxy, FitRes], BaseException]]) the clients' fit failures. + results: (list[tuple[ClientProxy, FitRes]]) The clients' fit results. + failures: (list[tuple[ClientProxy, FitRes] | BaseException]) the clients' fit failures. Returns: - (Tuple[Optional[Parameters], Dict[str, Scalar]]) A tuple containing the aggregated parameters + (tuple[Parameters | None, dict[str, Scalar]]) A tuple containing the aggregated parameters and the aggregated fit metrics. For adaptive constraints, the server also packs a constraint weight to be sent to the clients. This is sent even if adaptive constraint weights are turned off and the value simply remains constant. @@ -175,7 +168,7 @@ def aggregate_fit( parameters = self.parameter_packer.pack_parameters(weights_aggregated, self.loss_weight) return ndarrays_to_parameters(parameters), metrics_aggregated - def _unpack_weights_and_losses(self, results: List[Tuple[ClientProxy, FitRes]]) -> List[Tuple[int, float]]: + def _unpack_weights_and_losses(self, results: list[tuple[ClientProxy, FitRes]]) -> list[tuple[int, float]]: """ This function takes results returned from a fit round from each of the participating clients and unpacks the information into the appropriate objects. The parameters contained in the FitRes object are unpacked to @@ -185,13 +178,13 @@ def _unpack_weights_and_losses(self, results: List[Tuple[ClientProxy, FitRes]]) NOTE: The results that are passed to this function are MODIFIED IN-PLACE Args: - results (List[Tuple[ClientProxy, FitRes]]): The results produced in a fitting round by each of the clients + results (list[tuple[ClientProxy, FitRes]]): The results produced in a fitting round by each of the clients these the FitRes object contains both model weights and training losses which need to be processed. Returns: - List[Tuple[int, float]]: A list of the training losses produced by client training + list[tuple[int, float]]: A list of the training losses produced by client training """ - train_losses_and_counts: List[Tuple[int, float]] = [] + train_losses_and_counts: list[tuple[int, float]] = [] for _, fit_res in results: sample_count = fit_res.num_examples updated_weights, train_loss = self.parameter_packer.unpack_parameters( diff --git a/fl4health/strategies/fedpca.py b/fl4health/strategies/fedpca.py index 6d9546783..0bdc67fc5 100644 --- a/fl4health/strategies/fedpca.py +++ b/fl4health/strategies/fedpca.py @@ -1,5 +1,5 @@ +from collections.abc import Callable from logging import INFO, WARNING -from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np from flwr.common import MetricsAggregationFn, NDArray, NDArrays, Parameters, ndarrays_to_parameters @@ -20,18 +20,15 @@ def __init__( min_fit_clients: int = 2, min_evaluate_clients: int = 2, min_available_clients: int = 2, - evaluate_fn: Optional[ - Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], - ] - ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + evaluate_fn: ( + Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None + ) = None, + on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None, + on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None, accept_failures: bool = True, - initial_parameters: Optional[Parameters] = None, - fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, - evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, + initial_parameters: Parameters | None = None, + fit_metrics_aggregation_fn: MetricsAggregationFn | None = None, + evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None, weighted_aggregation: bool = True, weighted_eval_losses: bool = True, svd_merging: bool = True, @@ -45,20 +42,18 @@ def __init__( fraction_fit (float, optional): Fraction of clients used during training. Defaults to 1.0. Defaults to 1.0. fraction_evaluate (float, optional): Fraction of clients used during validation. Defaults to 1.0. min_available_clients (int, optional): Minimum number of clients used during validation. Defaults to 2. - evaluate_fn (Optional[ - Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]] - ]): + evaluate_fn (Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None): Optional function used for central server-side evaluation. Defaults to None. - on_fit_config_fn (Optional[Callable[[int], Dict[str, Scalar]]], optional): + on_fit_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to configure training by providing a configuration dictionary. Defaults to None. - on_evaluate_config_fn (Optional[Callable[[int], Dict[str, Scalar]]], optional): + on_evaluate_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to configure client-side validation by providing a Config dictionary. Defaults to None. accept_failures (bool, optional): Whether or not accept rounds containing failures. Defaults to True. - initial_parameters (Optional[Parameters], optional): Initial global model parameters. Defaults to None. - fit_metrics_aggregation_fn (Optional[MetricsAggregationFn], optional): Metrics aggregation function. + initial_parameters (Parameters | None, optional): Initial global model parameters. Defaults to None. + fit_metrics_aggregation_fn (MetricsAggregationFn | None, optional): Metrics aggregation function. Defaults to None. - evaluate_metrics_aggregation_fn (Optional[MetricsAggregationFn], optional): Metrics aggregation function. + evaluate_metrics_aggregation_fn (MetricsAggregationFn | None, optional): Metrics aggregation function. Defaults to None. weighted_aggregation (bool, optional): Determines whether parameter aggregation is a linearly weighted average or a uniform average. FedAvg default is weighted average by client dataset counts. @@ -92,22 +87,22 @@ def __init__( def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[tuple[ClientProxy, FitRes] | BaseException], + ) -> tuple[Parameters | None, dict[str, Scalar]]: """ Aggregate client parameters. In this case, merge all clients' local principal components. Args: server_round (int): Indicates the server round we're currently on. - results (List[Tuple[ClientProxy, FitRes]]): The client identifiers and the results of their local training + results (list[tuple[ClientProxy, FitRes]]): The client identifiers and the results of their local training that need to be aggregated on the server-side. In this scheme, the clients pack the layer weights into the results object along with the weight values to allow for alignment during aggregation. - failures (List[Union[Tuple[ClientProxy, FitRes], BaseException]]): These are the results and exceptions + failures (list[tuple[ClientProxy, FitRes] | BaseException]): These are the results and exceptions from clients that experienced an issue during training, such as timeouts or exceptions. Returns: - Tuple[Optional[Parameters], Dict[str, Scalar]]: The aggregated parameters and the metrics dictionary. + tuple[Parameters | None, dict[str, Scalar]]: The aggregated parameters and the metrics dictionary. In this case, the parameters are the new singular vectors and their corresponding singular values. """ if not results: @@ -153,7 +148,7 @@ def aggregate_fit( def merge_subspaces_svd( self, client_singular_vectors: NDArrays, client_singular_values: NDArrays - ) -> Tuple[NDArray, NDArray]: + ) -> tuple[NDArray, NDArray]: """ Produce the principal components for all the data distributed across clients by merging the principal components belonging to each local dataset. @@ -191,7 +186,7 @@ def merge_subspaces_svd( client_singular_values (NDArrays): Singular values corresponding to local PCs. Returns: - Tuple[NDArray, NDArray]: merged PCs and corresponding singular values. + tuple[NDArray, NDArray]: merged PCs and corresponding singular values. Note: This method assumes that the *columns* of U_i's are the local principal components. @@ -213,7 +208,7 @@ def merge_subspaces_svd( def merge_subspaces_qr( self, client_singular_vectors: NDArrays, client_singular_values: NDArrays - ) -> Tuple[NDArray, NDArray]: + ) -> tuple[NDArray, NDArray]: """ Produce the principal components (PCs) for all the data distributed across clients by merging the PCs belonging to each local dataset. @@ -244,7 +239,7 @@ def merge_subspaces_qr( client_singular_values (NDArrays): Singular values corresponding to local PCs. Returns: - Tuple[NDArray, NDArray]: merged PCs and corresponding singular values. + tuple[NDArray, NDArray]: merged PCs and corresponding singular values. Note: Similar to merge_subspaces_svd, this method assumes that the *columns* of U_i's are @@ -261,8 +256,8 @@ def merge_subspaces_qr( return self.merge_two_subspaces_qr((U, np.diag(S)), (U_last, np.diag(S_last))) def merge_two_subspaces_qr( - self, subspace1: Tuple[NDArray, NDArray], subspace2: Tuple[NDArray, NDArray] - ) -> Tuple[NDArray, NDArray]: + self, subspace1: tuple[NDArray, NDArray], subspace2: tuple[NDArray, NDArray] + ) -> tuple[NDArray, NDArray]: U1, S1 = subspace1 U2, S2 = subspace2 diff --git a/fl4health/strategies/fedpm.py b/fl4health/strategies/fedpm.py index 29402e8c4..4511b3db6 100644 --- a/fl4health/strategies/fedpm.py +++ b/fl4health/strategies/fedpm.py @@ -1,6 +1,6 @@ from collections import defaultdict +from collections.abc import Callable from functools import reduce -from typing import Callable, DefaultDict, Dict, List, Optional, Tuple import numpy as np from flwr.common import MetricsAggregationFn, NDArray, NDArrays, Parameters @@ -18,18 +18,15 @@ def __init__( min_fit_clients: int = 2, min_evaluate_clients: int = 2, min_available_clients: int = 2, - evaluate_fn: Optional[ - Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], - ] - ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + evaluate_fn: ( + Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None + ) = None, + on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None, + on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None, accept_failures: bool = True, - initial_parameters: Optional[Parameters] = None, - fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, - evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, + initial_parameters: Parameters | None = None, + fit_metrics_aggregation_fn: MetricsAggregationFn | None = None, + evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None, weighted_eval_losses: bool = True, bayesian_aggregation: bool = True, ) -> None: @@ -47,20 +44,18 @@ def __init__( min_fit_clients (int, optional): Minimum number of clients used during fitting. Defaults to 2. min_evaluate_clients (int, optional): Minimum number of clients used during validation. Defaults to 2. min_available_clients (int, optional): Minimum number of clients used during validation. Defaults to 2. - evaluate_fn (Optional[ - Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]] - ]): + evaluate_fn (Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None): Optional function used for central server-side evaluation. Defaults to None. - on_fit_config_fn (Optional[Callable[[int], Dict[str, Scalar]]], optional): + on_fit_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to configure training by providing a configuration dictionary. Defaults to None. - on_evaluate_config_fn (Optional[Callable[[int], Dict[str, Scalar]]], optional): + on_evaluate_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to configure client-side validation by providing a Config dictionary. Defaults to None. accept_failures (bool, optional): Whether or not accept rounds containing failures. Defaults to True. - initial_parameters (Optional[Parameters], optional): Initial global model parameters. Defaults to None. - fit_metrics_aggregation_fn (Optional[MetricsAggregationFn], optional): Metrics aggregation function. + initial_parameters (Parameters | None, optional): Initial global model parameters. Defaults to None. + fit_metrics_aggregation_fn (MetricsAggregationFn | None, optional): Metrics aggregation function. Defaults to None. - evaluate_metrics_aggregation_fn (Optional[MetricsAggregationFn], optional): Metrics aggregation function. + evaluate_metrics_aggregation_fn (MetricsAggregationFn | None, optional): Metrics aggregation function. Defaults to None. weighted_aggregation (bool, optional): Determines whether parameter aggregation is a linearly weighted average or a uniform average. FedAvg default is weighted average by client dataset counts. @@ -87,16 +82,16 @@ def __init__( weighted_eval_losses=weighted_eval_losses, ) # Parameters for Beta distribution. - self.beta_parameters: Dict[str, Tuple[NDArray, NDArray]] = {} + self.beta_parameters: dict[str, tuple[NDArray, NDArray]] = {} self.bayesian_aggregation = bayesian_aggregation - def aggregate(self, results: List[Tuple[NDArrays, int]]) -> Dict[str, NDArray]: + def aggregate(self, results: list[tuple[NDArrays, int]]) -> dict[str, NDArray]: if not self.bayesian_aggregation: return super().aggregate(results) else: return self.aggregate_bayesian(results) - def aggregate_bayesian(self, results: List[Tuple[NDArrays, int]]) -> Dict[str, NDArray]: + def aggregate_bayesian(self, results: list[tuple[NDArrays, int]]) -> dict[str, NDArray]: """ Perform posterior update to the Beta distribution parameters based on the binary masks sent by the clients. @@ -124,8 +119,8 @@ def aggregate_bayesian(self, results: List[Tuple[NDArrays, int]]) -> Dict[str, N In the beginning, alpha and beta are initialized to arrays of all ones. """ - names_to_layers: DefaultDict[str, List[NDArray]] = defaultdict(list) - total_num_clients: DefaultDict[str, int] = defaultdict(int) + names_to_layers: defaultdict[str, list[NDArray]] = defaultdict(list) + total_num_clients: defaultdict[str, int] = defaultdict(int) # unpack the parameters and initialize the beta parameters to be all ones if they have not already # been initialized. @@ -139,7 +134,7 @@ def aggregate_bayesian(self, results: List[Tuple[NDArrays, int]]) -> Dict[str, N beta = np.ones_like(layer) self.beta_parameters[name] = (alpha, beta) - aggregation_result: Dict[str, NDArray] = {} + aggregation_result: dict[str, NDArray] = {} # posterior update of the beta parameters and using them # to compute the final result. diff --git a/fl4health/strategies/flash.py b/fl4health/strategies/flash.py index 97fbe8dae..c6f426efd 100644 --- a/fl4health/strategies/flash.py +++ b/fl4health/strategies/flash.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, List, Optional, Tuple, Union +from collections.abc import Callable import numpy as np from flwr.common import ( @@ -24,18 +24,15 @@ def __init__( min_fit_clients: int = 2, min_evaluate_clients: int = 2, min_available_clients: int = 2, - evaluate_fn: Optional[ - Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], - ] - ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + evaluate_fn: ( + Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None + ) = None, + on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None, + on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None, accept_failures: bool = True, initial_parameters: Parameters, - fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, - evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, + fit_metrics_aggregation_fn: MetricsAggregationFn | None = None, + evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None, eta: float = 1e-1, eta_l: float = 1e-1, beta_1: float = 0.9, @@ -59,20 +56,20 @@ def __init__( Minimum number of clients used during validation. Defaults to 2. min_available_clients : int, optional Minimum number of total clients in the system. Defaults to 2. - evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]]]] + evaluate_fn : Callable[[int, NDArrays, dict[str, Scalar] | None, + tuple[float, dict[str, Scalar]]]] | None Optional function used for validation. Defaults to None. - on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional + on_fit_config_fn : Callable[[int], dict[str, Scalar]], optional Function used to configure training. Defaults to None. - on_evaluate_config_fn : Callable[[int], Dict[str, Scalar]], optional + on_evaluate_config_fn : Callable[[int], dict[str, Scalar]], optional Function used to configure validation. Defaults to None. accept_failures : bool, optional Whether or not accept rounds containing failures. Defaults to True. initial_parameters : Parameters Initial global model parameters. - fit_metrics_aggregation_fn : Optional[MetricsAggregationFn] + fit_metrics_aggregation_fn : MetricsAggregationFn | None Metrics aggregation function, optional. - evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] + evaluate_metrics_aggregation_fn: MetricsAggregationFn | None Metrics aggregation function, optional. eta : float, optional Server-side learning rate. Defaults to 1e-1. @@ -145,9 +142,9 @@ def _update_parameters(self, delta_t: NDArrays) -> None: def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[tuple[ClientProxy, FitRes] | BaseException], + ) -> tuple[Parameters | None, dict[str, Scalar]]: """Aggregate fit results using the Flash method.""" fedavg_parameters_aggregated, metrics_aggregated = super().aggregate_fit( diff --git a/fl4health/strategies/model_merge_strategy.py b/fl4health/strategies/model_merge_strategy.py index 2fdeba842..cd5c8207c 100644 --- a/fl4health/strategies/model_merge_strategy.py +++ b/fl4health/strategies/model_merge_strategy.py @@ -1,5 +1,5 @@ +from collections.abc import Callable from logging import WARNING -from typing import Callable, Dict, List, Optional, Tuple, Union from flwr.common import ( EvaluateIns, @@ -33,17 +33,18 @@ def __init__( min_fit_clients: int = 2, min_evaluate_clients: int = 2, min_available_clients: int = 2, - evaluate_fn: Optional[ + evaluate_fn: ( Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], + [int, NDArrays, dict[str, Scalar]], + tuple[float, dict[str, Scalar]] | None, ] - ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + | None + ) = None, + on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None, + on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None, accept_failures: bool = True, - fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, - evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, + fit_metrics_aggregation_fn: MetricsAggregationFn | None = None, + evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None, weighted_aggregation: bool = True ) -> None: """ @@ -61,19 +62,17 @@ def __init__( min_evaluate_clients (int, optional): Minimum number of clients used during validation. Defaults to 2. min_available_clients (int, optional): Minimum number of total clients in the system. Defaults to 2. - evaluate_fn (Optional[ - Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]] - ]): + evaluate_fn (Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None): Optional function used for central server-side evaluation. Defaults to None. - on_fit_config_fn (Optional[Callable[[int], Dict[str, Scalar]]], optional): + on_fit_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to configure training by providing a configuration dictionary. Defaults to None. - on_evaluate_config_fn (Optional[Callable[[int], Dict[str, Scalar]]], optional): + on_evaluate_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to configure client-side validation by providing a Config dictionary. Defaults to None. accept_failures (bool, optional): Whether or not accept rounds containing failures. Defaults to True. - fit_metrics_aggregation_fn (Optional[MetricsAggregationFn], optional): Metrics aggregation function. + fit_metrics_aggregation_fn (MetricsAggregationFn | None, optional): Metrics aggregation function. Defaults to None. - evaluate_metrics_aggregation_fn (Optional[MetricsAggregationFn], optional): Metrics aggregation function. + evaluate_metrics_aggregation_fn (MetricsAggregationFn | None, optional): Metrics aggregation function. Defaults to None. counts. Defaults to True. weighted_aggregation (bool, optional): Determines whether parameter aggregation is a linearly weighted @@ -95,7 +94,7 @@ def __init__( def configure_fit( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, FitIns]]: + ) -> list[tuple[ClientProxy, FitIns]]: """ Sample and configure clients for a fit round. @@ -108,7 +107,7 @@ def configure_fit( client_manager (ClientManager): The manager used to sample from the available clients. Returns: - List[Tuple[ClientProxy, FitIns]]: List of sampled client identifiers and the configuration/parameters to + list[tuple[ClientProxy, FitIns]]: List of sampled client identifiers and the configuration/parameters to be sent to each client (packaged as FitIns). """ config = {} @@ -129,7 +128,7 @@ def configure_fit( def configure_evaluate( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, EvaluateIns]]: + ) -> list[tuple[ClientProxy, EvaluateIns]]: """ Sample and configure clients for a evaluation round. @@ -140,7 +139,7 @@ def configure_evaluate( client_manager (ClientManager): The manager used to sample from the available clients. Returns: - List[Tuple[ClientProxy, EvaluateIns]]: List of sampled client identifiers and the configuration/parameters + list[tuple[ClientProxy, EvaluateIns]]: List of sampled client identifiers and the configuration/parameters to be sent to each client (packaged as EvaluateIns). """ # Do not configure federated evaluation if fraction eval is 0. @@ -167,21 +166,21 @@ def configure_evaluate( def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[tuple[ClientProxy, FitRes] | BaseException], + ) -> tuple[Parameters | None, dict[str, Scalar]]: """ Performs model merging by taking an unweighted average of client weights and metrics. Args: server_round (int): Indicates the server round we're currently on. Only one round for ModelMergeStrategy. - results (List[Tuple[ClientProxy, FitRes]]): The client identifiers and the results of their local fit + results (list[tuple[ClientProxy, FitRes]]): The client identifiers and the results of their local fit that need to be aggregated on the server-side. - failures (List[Union[Tuple[ClientProxy, FitRes], BaseException]]): These are the results and exceptions + failures (list[tuple[ClientProxy, FitRes] | BaseException]): These are the results and exceptions from clients that experienced an issue during fit, such as timeouts or exceptions. Returns: - Tuple[Optional[Parameters], Dict[str, Scalar]]: The aggregated model weights and the metrics dictionary. + tuple[Parameters | None, dict[str, Scalar]]: The aggregated model weights and the metrics dictionary. """ if not results: return None, {} @@ -214,22 +213,22 @@ def aggregate_fit( def aggregate_evaluate( self, server_round: int, - results: List[Tuple[ClientProxy, EvaluateRes]], - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]], - ) -> Tuple[Optional[float], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, EvaluateRes]], + failures: list[tuple[ClientProxy, EvaluateRes] | BaseException], + ) -> tuple[float | None, dict[str, Scalar]]: """ Aggregate the metrics returned from the clients as a result of the evaluation round. ModelMergeStrategy assumes only metrics will be computed on client and loss is set to None. Args: - results (List[Tuple[ClientProxy, EvaluateRes]]): The client identifiers and the results of their local + results (list[tuple[ClientProxy, EvaluateRes]]): The client identifiers and the results of their local evaluation that need to be aggregated on the server-side. These results are loss values (None in this case) and the metrics dictionary. - failures (List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]]): These are the results and + failures (list[tuple[ClientProxy, EvaluateRes] | BaseException]): These are the results and exceptions from clients that experienced an issue during evaluation, such as timeouts or exceptions. Returns: - Tuple[Optional[float], Dict[str, Scalar]]: Aggregated loss values and the aggregated metrics. The metrics + tuple[float | None, dict[str, Scalar]]: Aggregated loss values and the aggregated metrics. The metrics are aggregated according to evaluate_metrics_aggregation_fn. """ if not results: @@ -248,7 +247,7 @@ def aggregate_evaluate( return None, metrics_aggregated - def evaluate(self, server_round: int, parameters: Parameters) -> Optional[Tuple[float, Dict[str, Scalar]]]: + def evaluate(self, server_round: int, parameters: Parameters) -> tuple[float, dict[str, Scalar]] | None: """ Evaluate the model parameters after the merging has occurred. This function can be used to perform centralized (i.e., server-side) evaluation of model parameters. @@ -258,7 +257,7 @@ def evaluate(self, server_round: int, parameters: Parameters) -> Optional[Tuple[ parameters: Parameters The current model parameters after merging has occurred. Returns: - Optional[Tuple[float, Dict[str, Scalar]]]: A Tuple containing loss and a + tuple[float, dict[str, Scalar]] | None: A Tuple containing loss and a dictionary containing task-specific metrics (e.g., accuracy). """ if self.evaluate_fn is None: @@ -271,7 +270,7 @@ def evaluate(self, server_round: int, parameters: Parameters) -> Optional[Tuple[ loss, metrics = eval_res return loss, metrics - def initialize_parameters(self, client_manager: ClientManager) -> Optional[Parameters]: + def initialize_parameters(self, client_manager: ClientManager) -> Parameters | None: """ Required definition of parent class. ModelMergeStrategy does not support server side initialization. Parameters are always set to None diff --git a/fl4health/strategies/noisy_aggregate.py b/fl4health/strategies/noisy_aggregate.py index 852a92043..5871a5a57 100644 --- a/fl4health/strategies/noisy_aggregate.py +++ b/fl4health/strategies/noisy_aggregate.py @@ -1,5 +1,4 @@ from functools import reduce -from typing import List, Tuple import numpy as np from flwr.common import NDArray, NDArrays @@ -22,13 +21,13 @@ def add_noise_to_array(layer: NDArray, noise_std_dev: float, denominator: int) - return (1.0 / denominator) * (layer + layer_noise) -def add_noise_to_ndarrays(client_model_updates: List[NDArrays], sigma: float, n_clients: int) -> NDArrays: +def add_noise_to_ndarrays(client_model_updates: list[NDArrays], sigma: float, n_clients: int) -> NDArrays: """ This function adds centered gaussian noise (with standard deviation sigma) to the uniform average of the list of the numpy arrays provided. Args: - client_model_updates (List[NDArrays]): List of lists of numpy arrays. Each member of the list represents a + client_model_updates (list[NDArrays]): List of lists of numpy arrays. Each member of the list represents a set of numpy arrays, each of which should be averaged element-wise with the corresponding array from the other lists. These will have centered gaussian noise added. sigma (float): The standard deviation of the centered gaussian noise to be added to each element. @@ -46,13 +45,13 @@ def add_noise_to_ndarrays(client_model_updates: List[NDArrays], sigma: float, n_ def gaussian_noisy_unweighted_aggregate( - results: List[Tuple[NDArrays, int]], noise_multiplier: float, clipping_bound: float + results: list[tuple[NDArrays, int]], noise_multiplier: float, clipping_bound: float ) -> NDArrays: """ Compute unweighted average of weights. Apply gaussian noise to the sum of these weights prior to normalizing. Args: - results (List[Tuple[NDArrays, int]]): List of tuples containing the model updates and the number of samples + results (list[tuple[NDArrays, int]]): List of tuples containing the model updates and the number of samples for each client. noise_multiplier (float): The multiplier on the clipping bound to determine the std of noise applied to weight updates. @@ -70,7 +69,7 @@ def gaussian_noisy_unweighted_aggregate( def gaussian_noisy_weighted_aggregate( - results: List[Tuple[NDArrays, int]], + results: list[tuple[NDArrays, int]], noise_multiplier: float, clipping_bound: float, fraction_fit: float, @@ -84,7 +83,7 @@ def gaussian_noisy_weighted_aggregate( Args: - results (List[Tuple[NDArrays, int]]): List of tuples containing the model updates and the number of samples + results (list[tuple[NDArrays, int]]): List of tuples containing the model updates and the number of samples for each client. noise_multiplier (float): The multiplier on the clipping bound to determine the std of noise applied to weight updates. @@ -97,8 +96,8 @@ def gaussian_noisy_weighted_aggregate( NDArrays: Noised model update for a given round. """ n_clients = len(results) - client_model_updates: List[NDArrays] = [] - client_n_points: List[int] = [] + client_model_updates: list[NDArrays] = [] + client_n_points: list[int] = [] for weights, n_points in results: client_model_updates.append(weights) client_n_points.append(n_points) diff --git a/fl4health/strategies/scaffold.py b/fl4health/strategies/scaffold.py index 8075c3539..bff2accd2 100644 --- a/fl4health/strategies/scaffold.py +++ b/fl4health/strategies/scaffold.py @@ -1,6 +1,6 @@ +from collections.abc import Callable from functools import reduce from logging import WARNING -from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch.nn as nn @@ -32,22 +32,19 @@ def __init__( fraction_fit: float = 1.0, fraction_evaluate: float = 1.0, min_available_clients: int = 2, - evaluate_fn: Optional[ - Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], - ] - ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + evaluate_fn: ( + Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None + ) = None, + on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None, + on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None, accept_failures: bool = True, initial_parameters: Parameters, - fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, - evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, + fit_metrics_aggregation_fn: MetricsAggregationFn | None = None, + evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None, weighted_eval_losses: bool = True, learning_rate: float = 1.0, - initial_control_variates: Optional[Parameters] = None, - model: Optional[nn.Module] = None, + initial_control_variates: Parameters | None = None, + model: nn.Module | None = None, ) -> None: """ Scaffold Federated Learning strategy. Implementation based on https://arxiv.org/pdf/1910.06378.pdf @@ -57,29 +54,27 @@ def __init__( fraction_evaluate (float, optional): Fraction of clients used during validation. Defaults to 1.0. min_available_clients (int, optional): Minimum number of total clients in the system. Defaults to 2. - evaluate_fn (Optional[ - Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]] - ]): + evaluate_fn (Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None): Optional function used for central server-side evaluation. Defaults to None. - on_fit_config_fn (Optional[Callable[[int], Dict[str, Scalar]]], optional): + on_fit_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to configure training by providing a configuration dictionary. Defaults to None. - on_evaluate_config_fn (Optional[Callable[[int], Dict[str, Scalar]]], optional): + on_evaluate_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to configure client-side validation by providing a Config dictionary. Defaults to None. accept_failures (bool, optional):Whether or not accept rounds containing failures. Defaults to True. - fit_metrics_aggregation_fn (Optional[MetricsAggregationFn], optional): Metrics aggregation function. + fit_metrics_aggregation_fn (MetricsAggregationFn | None, optional): Metrics aggregation function. Defaults to None. - evaluate_metrics_aggregation_fn (Optional[MetricsAggregationFn], optional): Metrics aggregation function. + evaluate_metrics_aggregation_fn (MetricsAggregationFn | None, optional): Metrics aggregation function. Defaults to None. weighted_eval_losses (bool, optional): Determines whether losses during evaluation are linearly weighted averages or a uniform average. FedAvg default is weighted average of the losses by client dataset counts. Defaults to True. learning_rate (float, optional): Learning rate for server side optimization. Defaults to 1.0. - initial_control_variates (Optional[Parameters], optional): These are the initial set of control variates + initial_control_variates (Parameters | None, optional): These are the initial set of control variates to use for the scaffold strategy both on the server and client sides. It is optional, but if it is not provided, the strategy must receive a model that reflects the architecture to be used on the clients. Defaults to None. - model (Optional[nn.Module], optional): If provided and initial_control_variates is not, this is used to + model (nn.Module | None, optional): If provided and initial_control_variates is not, this is used to set the server control variates and the initial control variates on the client side to all zeros. If initial_control_variates are provided, they take precedence. Defaults to None. """ @@ -107,7 +102,7 @@ def __init__( self.parameter_packer = ParameterPackerWithControlVariates(len(self.server_model_weights)) def initialize_control_variates( - self, initial_control_variates: Optional[Parameters], model: Optional[nn.Module] + self, initial_control_variates: Parameters | None, model: nn.Module | None ) -> Parameters: """ This is a helper function for the SCAFFOLD strategy init function to initialize the server_control_variates. @@ -115,11 +110,11 @@ def initialize_control_variates( architecture. Args: - initial_control_variates (Optional[Parameters]): These are the initial set of control variates + initial_control_variates (Parameters | None): These are the initial set of control variates to use for the scaffold strategy both on the server and client sides. It is optional, but if it is not provided, the strategy must receive a model that reflects the architecture to be used on the clients. Defaults to None. - model (Optional[nn.Module]): If provided and initial_control_variates is not, this is used to + model (nn.Module | None): If provided and initial_control_variates is not, this is used to set the server control variates and the initial control variates on the client side to all zeros. If initial_control_variates are provided, they take precedence. Defaults to None. @@ -150,9 +145,9 @@ def initialize_control_variates( def aggregate_fit( self, server_round: int, - results: List[Tuple[ClientProxy, FitRes]], - failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]], - ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]: + results: list[tuple[ClientProxy, FitRes]], + failures: list[tuple[ClientProxy, FitRes] | BaseException], + ) -> tuple[Parameters | None, dict[str, Scalar]]: """ Performs server-side aggregation of model weights and control variates associated with the SCAFFOLD method Both model weights and control variates are aggregated through UNWEIGHTED averaging consistent with the paper. @@ -163,14 +158,14 @@ def aggregate_fit( Args: server_round (int): What round of FL we're on (from servers perspective). - results (List[Tuple[ClientProxy, FitRes]]): These are the "successful" training run results. By default + results (list[tuple[ClientProxy, FitRes]]): These are the "successful" training run results. By default these results are the only ones used in aggregation, even if some of the failed clients have partial results (in the failures list). - failures (List[Union[Tuple[ClientProxy, FitRes], BaseException]]): This is the list of clients that + failures (list[tuple[ClientProxy, FitRes] | BaseException]): This is the list of clients that "failed" during the training phase for one reason or another, including timeouts and exceptions. Returns: - Tuple[Optional[Parameters], Dict[str, Scalar]]: The aggregated weighted and metrics dictionary. The + tuple[Parameters | None, dict[str, Scalar]]: The aggregated weighted and metrics dictionary. The parameters are optional and will be none in the even that there are no successful clients or there were failures and they are not accepted. """ @@ -246,13 +241,13 @@ def compute_updated_parameters( return updated_parameters - def aggregate(self, params: List[NDArrays]) -> NDArrays: + def aggregate(self, params: list[NDArrays]) -> NDArrays: """ Simple unweighted average to aggregate params, consistent with SCAFFOLD paper. This is "element-wise" averaging. Args: - params (List[NDArrays]): numpy arrays whose entries are to be averaged together. + params (list[NDArrays]): numpy arrays whose entries are to be averaged together. Returns: NDArrays: element-wise average over the list of numpy arrays. @@ -266,7 +261,7 @@ def aggregate(self, params: List[NDArrays]) -> NDArrays: def configure_fit_all( self, server_round: int, parameters: Parameters, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, FitIns]]: + ) -> list[tuple[ClientProxy, FitIns]]: """ This function configures ALL clients for a training round. That is, it forces the client manager to grab all of the available clients to participate in the training round. By default, the manager will at least wait for @@ -284,7 +279,7 @@ def configure_fit_all( be BaseFractionSamplingManager, which has a "sample all" function built in. Returns: - List[Tuple[ClientProxy, FitIns]]: List of sampled client identifiers and the configuration/parameters to + list[tuple[ClientProxy, FitIns]]: List of sampled client identifiers and the configuration/parameters to be sent to each client (packaged as FitIns). """ @@ -360,17 +355,14 @@ def __init__( fraction_fit: float = 1.0, fraction_evaluate: float = 1.0, min_available_clients: int = 2, - evaluate_fn: Optional[ - Callable[ - [int, NDArrays, Dict[str, Scalar]], - Optional[Tuple[float, Dict[str, Scalar]]], - ] - ] = None, - on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, - on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None, + evaluate_fn: ( + Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None + ) = None, + on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None, + on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None, accept_failures: bool = True, - fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, - evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None, + fit_metrics_aggregation_fn: MetricsAggregationFn | None = None, + evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None, weighted_eval_losses: bool = True, learning_rate: float = 1.0, ) -> None: @@ -391,19 +383,17 @@ def __init__( fraction_evaluate (float, optional): Fraction of clients used during validation. Defaults to 1.0. min_available_clients (int, optional): Minimum number of total clients in the system. Defaults to 2. - evaluate_fn (Optional[ - Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]] - ]): + evaluate_fn (Callable[[int, NDArrays, dict[str, Scalar]], tuple[float, dict[str, Scalar]] | None] | None): Optional function used for central server-side evaluation. Defaults to None. - on_fit_config_fn (Optional[Callable[[int], Dict[str, Scalar]]], optional): + on_fit_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to configure training by providing a configuration dictionary. Defaults to None. - on_evaluate_config_fn (Optional[Callable[[int], Dict[str, Scalar]]], optional): + on_evaluate_config_fn (Callable[[int], dict[str, Scalar]] | None, optional): Function used to configure client-side validation by providing a Config dictionary. Defaults to None. accept_failures (bool, optional):Whether or not accept rounds containing failures. Defaults to True. - fit_metrics_aggregation_fn (Optional[MetricsAggregationFn], optional): Metrics aggregation function. + fit_metrics_aggregation_fn (MetricsAggregationFn | None, optional): Metrics aggregation function. Defaults to None. - evaluate_metrics_aggregation_fn (Optional[MetricsAggregationFn], optional): Metrics aggregation function. + evaluate_metrics_aggregation_fn (MetricsAggregationFn | None, optional): Metrics aggregation function. Defaults to None. weighted_eval_losses (bool, optional): Determines whether losses during evaluation are linearly weighted averages or a uniform average. FedAvg default is weighted average of the losses by client dataset diff --git a/fl4health/strategies/strategy_with_poll.py b/fl4health/strategies/strategy_with_poll.py index 0fc97e4a2..b1b5c8891 100644 --- a/fl4health/strategies/strategy_with_poll.py +++ b/fl4health/strategies/strategy_with_poll.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from typing import List, Tuple from flwr.common import GetPropertiesIns from flwr.server.client_manager import ClientManager @@ -15,5 +14,5 @@ class StrategyWithPolling(ABC): @abstractmethod def configure_poll( self, server_round: int, client_manager: ClientManager - ) -> List[Tuple[ClientProxy, GetPropertiesIns]]: + ) -> list[tuple[ClientProxy, GetPropertiesIns]]: pass diff --git a/fl4health/utils/client.py b/fl4health/utils/client.py index 3272415ff..f6b0cd0ea 100644 --- a/fl4health/utils/client.py +++ b/fl4health/utils/client.py @@ -1,8 +1,9 @@ import copy import os +from collections.abc import Iterable from inspect import currentframe, getframeinfo from logging import INFO, LogRecord -from typing import Any, Dict, Iterable, TypeVar +from typing import Any, TypeVar import torch import torch.nn as nn @@ -19,7 +20,7 @@ def fold_loss_dict_into_metrics( - metrics: Dict[str, Scalar], loss_dict: Dict[str, float], logging_mode: LoggingMode + metrics: dict[str, Scalar], loss_dict: dict[str, float], logging_mode: LoggingMode ) -> None: # Prefixing the loss value keys with the mode from which they are generated if logging_mode is LoggingMode.VALIDATION: @@ -61,7 +62,7 @@ def move_data_to_device(data: T, device: torch.device) -> T: return {key: value.to(device) for key, value in data.items()} else: raise TypeError( - "data must be of type torch.Tensor or Dict[str, torch.Tensor]. If definition of TorchInputType or " + "data must be of type torch.Tensor or dict[str, torch.Tensor]. If definition of TorchInputType or " "TorchTargetType has changed this method might need to be updated or split into two." ) @@ -73,12 +74,12 @@ def check_if_batch_is_empty_and_verify_input(input: TorchInputType) -> bool: NOTE: This function assumes the input is BATCH FIRST Args: - input (TorchInputType): Input batch. input can be of type torch.Tensor or Dict[str, torch.Tensor], and in the + input (TorchInputType): Input batch. input can be of type torch.Tensor or dict[str, torch.Tensor], and in the latter case, the batch is considered to be empty if all tensors in the dictionary have length zero. Raises: - TypeError: Raised if input is not of type torch.Tensor or Dict[str, torch.Tensor]. - ValueError: Raised if input has type Dict[str, torch.Tensor] and not all tensors within the dictionary have + TypeError: Raised if input is not of type torch.Tensor or dict[str, torch.Tensor]. + ValueError: Raised if input has type dict[str, torch.Tensor] and not all tensors within the dictionary have the same size. Returns: @@ -95,7 +96,7 @@ def check_if_batch_is_empty_and_verify_input(input: TorchInputType) -> bool: else: return first_val_len == 0 else: - raise TypeError("Input must be of type torch.Tensor or Dict[str, torch.Tensor].") + raise TypeError("Input must be of type torch.Tensor or dict[str, torch.Tensor].") def clone_and_freeze_model(model: nn.Module) -> nn.Module: diff --git a/fl4health/utils/config.py b/fl4health/utils/config.py index d52e0fac2..eb70b2d69 100644 --- a/fl4health/utils/config.py +++ b/fl4health/utils/config.py @@ -1,4 +1,5 @@ -from typing import Any, Callable, Dict, Optional, Type, TypeVar +from collections.abc import Callable +from typing import Any, TypeVar import yaml @@ -14,7 +15,7 @@ class InvalidConfigError(ValueError): pass -def load_config(config_path: str) -> Dict[str, Any]: +def load_config(config_path: str) -> dict[str, Any]: """Load Configuration Dictionary""" with open(config_path, "r") as f: @@ -25,7 +26,7 @@ def load_config(config_path: str) -> Dict[str, Any]: return config -def check_config(config: Dict[str, Any]) -> None: +def check_config(config: dict[str, Any]) -> None: """Check if Configuration Dictionary is valid""" # Check for presence of required keys @@ -44,14 +45,14 @@ def check_config(config: Dict[str, Any]) -> None: raise InvalidConfigError(f"{key} must be greater than 0") -def narrow_dict_type(dictionary: Dict[str, Any], key: str, narrow_type_to: Type[T]) -> T: +def narrow_dict_type(dictionary: dict[str, Any], key: str, narrow_type_to: type[T]) -> T: """ Checks if a key exists in dictionary and if so, verify it is of type narrow_type_to. Args: - dictionary (Dict[str, Any]): A dictionary with string keys. + dictionary (dict[str, Any]): A dictionary with string keys. key (str): The key to check dictionary for. - narrow_type_to (Type[T]): The expected type of dictionary[key] + narrow_type_to (type[T]): The expected type of dictionary[key] Returns: T: The type-checked value at dictionary[key] @@ -75,8 +76,8 @@ def narrow_dict_type_and_set_attribute( dictionary: dict, dictionary_key: str, attribute_name: str, - narrow_type_to: Type[T], - func: Optional[Callable[[Any], Any]] = None, + narrow_type_to: type[T], + func: Callable[[Any], Any] | None = None, ) -> None: """ Checks a key exists in dictionary, verify its type and sets the corresponding attribute. @@ -86,9 +87,9 @@ def narrow_dict_type_and_set_attribute( Args: self (object): The object to set attribute to dictionary[dictionary_key]. - dictionary (Dict[str, Any]): A dictionary with string keys. + dictionary (dict[str, Any]): A dictionary with string keys. dictionary_key (str): The key to check dictionary for. - narrow_type_to (Type[T]): The expected type of dictionary[key]. + narrow_type_to (type[T]): The expected type of dictionary[key]. Raises: ValueError: If dictionary[key] is not of type narrow_type_to or diff --git a/fl4health/utils/data_generation.py b/fl4health/utils/data_generation.py index 47224e560..67a75b0a4 100644 --- a/fl4health/utils/data_generation.py +++ b/fl4health/utils/data_generation.py @@ -1,6 +1,5 @@ import math from abc import ABC, abstractmethod -from typing import List, Tuple import torch import torch.nn.functional as F @@ -84,12 +83,12 @@ def map_inputs_to_outputs(self, x: torch.Tensor, W: torch.Tensor, b: torch.Tenso samples = torch.multinomial(distributions, 1) return F.one_hot(samples, num_classes=self.output_dim).squeeze() - def generate(self) -> List[TensorDataset]: + def generate(self) -> list[TensorDataset]: """ Based on the class parameters, generate a list of synthetic TensorDatasets, one for each client. Returns: - List[TensorDataset]: Synthetic datasets for each client. + list[TensorDataset]: Synthetic datasets for each client. """ client_tensors = self.generate_client_tensors() assert ( @@ -99,14 +98,14 @@ def generate(self) -> List[TensorDataset]: return client_datasets @abstractmethod - def generate_client_tensors(self) -> List[Tuple[torch.Tensor, torch.Tensor]]: + def generate_client_tensors(self) -> list[tuple[torch.Tensor, torch.Tensor]]: """ Method to be implemented determining how to generate the tensors in the subclasses. Each of the subclasses uses the affine mapping, but the parameters for how that affine mapping is setup are different and determined in this function. Returns: - List[Tuple[torch.Tensor, torch.Tensor]]: input and output tensors for each of the clients. + list[tuple[torch.Tensor, torch.Tensor]]: input and output tensors for each of the clients. """ pass @@ -165,7 +164,7 @@ def __init__( def get_input_output_tensors( self, mu: float, v: torch.Tensor, sigma: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: """ This function takes values for the center of elements in the affine transformation elements (mu), the centers feature each of the input feature dimensions (v), and the covariance of those features (sigma) and produces @@ -180,7 +179,7 @@ def get_input_output_tensors( diagonal matrix as well. Returns: - Tuple[torch.Tensor, torch.Tensor]: X and Y for the clients synthetic dataset. Shape of X is + tuple[torch.Tensor, torch.Tensor]: X and Y for the clients synthetic dataset. Shape of X is n_samples x input dimension. Shape of Y is n_samples x output_dim and is one-hot encoded """ @@ -193,7 +192,7 @@ def get_input_output_tensors( return x, self.map_inputs_to_outputs(x, W, b) - def generate_client_tensors(self) -> List[Tuple[torch.Tensor, torch.Tensor]]: + def generate_client_tensors(self) -> list[tuple[torch.Tensor, torch.Tensor]]: """ For the Non-IID synthetic generator, this function uses the values of alpha and beta to sample the parameters that will be used to generate the synthetic datasets on each client. For each client, beta is used to sample @@ -202,9 +201,9 @@ def generate_client_tensors(self) -> List[Tuple[torch.Tensor, torch.Tensor]]: the larger the variance in these values, implying higher probability of heterogeneity. Returns: - List[Tuple[torch.Tensor, torch.Tensor]]: Set of input and output tensors for each client. + list[tuple[torch.Tensor, torch.Tensor]]: Set of input and output tensors for each client. """ - tensors_per_client: List[Tuple[torch.Tensor, torch.Tensor]] = [] + tensors_per_client: list[tuple[torch.Tensor, torch.Tensor]] = [] for _ in range(self.num_clients): B = torch.normal(0.0, self.beta, (1,)) # v_k in the FedProx paper @@ -267,13 +266,13 @@ def __init__( loc=torch.zeros(self.input_dim), covariance_matrix=self.input_covariance ) - def get_input_output_tensors(self) -> Tuple[torch.Tensor, torch.Tensor]: + def get_input_output_tensors(self) -> tuple[torch.Tensor, torch.Tensor]: """ As described in the original FedProx paper (Appendix C.1), the features are all sampled from a centered multidimensional normal distribution with diagonal covariance matrix shared across clients. Returns: - Tuple[torch.Tensor, torch.Tensor]: X and Y for the clients synthetic dataset. Shape of X is + tuple[torch.Tensor, torch.Tensor]: X and Y for the clients synthetic dataset. Shape of X is n_samples x input dimension. Shape of Y is n_samples x output_dim and is one-hot encoded """ # size of x should be samples_per_client x input_dim @@ -282,15 +281,15 @@ def get_input_output_tensors(self) -> Tuple[torch.Tensor, torch.Tensor]: return x, self.map_inputs_to_outputs(x, self.W, self.b) - def generate_client_tensors(self) -> List[Tuple[torch.Tensor, torch.Tensor]]: + def generate_client_tensors(self) -> list[tuple[torch.Tensor, torch.Tensor]]: """ For IID generation, this function is simple, as we need not sample any parameters per client for use in generation, as these are all shared across clients. Returns: - List[Tuple[torch.Tensor, torch.Tensor]]: Set of input and output tensors for each client. + list[tuple[torch.Tensor, torch.Tensor]]: Set of input and output tensors for each client. """ - tensors_per_client: List[Tuple[torch.Tensor, torch.Tensor]] = [] + tensors_per_client: list[tuple[torch.Tensor, torch.Tensor]] = [] for _ in range(self.num_clients): client_X, client_Y = self.get_input_output_tensors() tensors_per_client.append((client_X, client_Y)) diff --git a/fl4health/utils/dataset.py b/fl4health/utils/dataset.py index 81c17406d..256795013 100644 --- a/fl4health/utils/dataset.py +++ b/fl4health/utils/dataset.py @@ -1,13 +1,14 @@ import copy from abc import ABC, abstractmethod -from typing import Callable, Dict, List, Optional, Tuple, TypeVar, Union, cast +from collections.abc import Callable +from typing import TypeVar, cast import torch from torch.utils.data import Dataset class BaseDataset(ABC, Dataset): - def __init__(self, transform: Optional[Callable], target_transform: Optional[Callable]) -> None: + def __init__(self, transform: Callable | None, target_transform: Callable | None) -> None: self.transform = transform self.target_transform = target_transform @@ -26,7 +27,7 @@ def update_target_transform(self, g: Callable) -> None: self.target_transform = g @abstractmethod - def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: + def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError @abstractmethod @@ -38,15 +39,15 @@ class TensorDataset(BaseDataset): def __init__( self, data: torch.Tensor, - targets: Optional[torch.Tensor] = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, + targets: torch.Tensor | None = None, + transform: Callable | None = None, + target_transform: Callable | None = None, ) -> None: super().__init__(transform, target_transform) self.data = data self.targets = targets - def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: + def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]: assert self.targets is not None data, target = self.data[index], self.targets[index] @@ -67,15 +68,15 @@ class SslTensorDataset(TensorDataset): def __init__( self, data: torch.Tensor, - targets: Optional[torch.Tensor] = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, + targets: torch.Tensor | None = None, + transform: Callable | None = None, + target_transform: Callable | None = None, ) -> None: assert targets is not None, "SslTensorDataset targets must be None" super().__init__(data, targets, transform, target_transform) - def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: + def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]: data = self.data[index] assert self.target_transform is not None, "Target transform cannot be None." @@ -92,21 +93,21 @@ def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: class DictionaryDataset(Dataset): - def __init__(self, data: Dict[str, List[torch.Tensor]], targets: torch.Tensor) -> None: + def __init__(self, data: dict[str, list[torch.Tensor]], targets: torch.Tensor) -> None: """ A torch dataset that supports a dictionary of input data rather than just a torch.Tensor. This kind of dataset is useful when dealing with non-trivial inputs to a model. For example, a language model may require token ids AND attention masks. This dataset supports that functionality. Args: - data (Dict[str, List[torch.Tensor]]): A set of data for model training/input in the form of a dictionary + data (dict[str, list[torch.Tensor]]): A set of data for model training/input in the form of a dictionary of tensors. targets (torch.Tensor): Target tensor. """ self.data = data self.targets = targets - def __getitem__(self, index: int) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: + def __getitem__(self, index: int) -> tuple[dict[str, torch.Tensor], torch.Tensor]: return {key: val[index] for key, val in self.data.items()}, self.targets[index] def __len__(self) -> int: @@ -131,7 +132,7 @@ def __init__( self.data = data self.targets = targets - def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: + def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]: assert self.targets is not None data, target = self.data[index], self.targets[index] @@ -141,7 +142,7 @@ def __len__(self) -> int: return len(self.data) -D = TypeVar("D", bound=Union[TensorDataset, DictionaryDataset]) +D = TypeVar("D", bound=TensorDataset | DictionaryDataset) def select_by_indices(dataset: D, selected_indices: torch.Tensor) -> D: @@ -169,7 +170,7 @@ def select_by_indices(dataset: D, selected_indices: torch.Tensor) -> D: return cast(D, modified_dataset) elif isinstance(dataset, DictionaryDataset): new_targets = dataset.targets[selected_indices] - new_data: Dict[str, List[torch.Tensor]] = {} + new_data: dict[str, list[torch.Tensor]] = {} for key, val in dataset.data.items(): # Since val is a list of tensors, we can't directly index into it # using selected_indices. diff --git a/fl4health/utils/dataset_converter.py b/fl4health/utils/dataset_converter.py index ded37e269..8fb8d5560 100644 --- a/fl4health/utils/dataset_converter.py +++ b/fl4health/utils/dataset_converter.py @@ -1,5 +1,5 @@ +from collections.abc import Callable from functools import partial -from typing import Callable, Optional, Tuple, Union import torch @@ -9,8 +9,8 @@ class DatasetConverter(TensorDataset): def __init__( self, - converter_function: Callable[[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]], - dataset: Union[None, TensorDataset], + converter_function: Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]], + dataset: TensorDataset | None, ) -> None: """ Dataset converter classes are designed to re-format any dataset for a given training task, @@ -23,7 +23,7 @@ def __init__( self.converter_function = converter_function self.dataset = dataset - def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: + def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]: # Overriding this function from BaseDataset allows the converter to be compatible with the data transformers. # converter_function is applied after the transformers. assert self.dataset is not None, "Error: no dataset is set, use convert_dataset(your_dataset: TensorDataset)" @@ -47,10 +47,10 @@ def convert_dataset(self, dataset: TensorDataset) -> TensorDataset: class AutoEncoderDatasetConverter(DatasetConverter): def __init__( self, - condition: Union[None, str, torch.Tensor] = None, + condition: str | torch.Tensor | None = None, do_one_hot_encoding: bool = False, - custom_converter_function: Optional[Callable] = None, - condition_vector_size: Optional[int] = None, + custom_converter_function: Callable | None = None, + condition_vector_size: int | None = None, ) -> None: """ A dataset converter specific to formatting supervised data such as MNIST for @@ -60,10 +60,10 @@ def __init__( other converter functions can be added or passed to support other conditions. Args: - condition (Union[None, str, torch.Tensor]): Could be a fixed tensor used for all the data samples, + condition (str | torch.Tensor | None): Could be a fixed tensor used for all the data samples, None for non-conditional models, or a name(str) passed for other custom conversions like 'label'. do_one_hot_encoding (bool, optional): Should converter perform one hot encoding on the condition or not. - custom_converter_function (Optional[Callable], optional): User can define a new converter function. + custom_converter_function (Callable | None, optional): User can define a new converter function. """ self.condition = condition if isinstance(self.condition, torch.Tensor): @@ -135,15 +135,15 @@ def _setup_converter_function(self) -> Callable: return converter_function def _only_replace_target_with_data( - self, data: torch.Tensor, target: Union[None, torch.Tensor] - ) -> Tuple[torch.Tensor, torch.Tensor]: + self, data: torch.Tensor, target: torch.Tensor | None + ) -> tuple[torch.Tensor, torch.Tensor]: """The data converter function used for simple autoencoders or variational autoencoders.""" # Target in self-supervised training for autoencoder is the data. return data, data def _cat_input_condition( - self, data: torch.Tensor, target: Union[None, torch.Tensor] - ) -> Tuple[torch.Tensor, torch.Tensor]: + self, data: torch.Tensor, target: torch.Tensor | None + ) -> tuple[torch.Tensor, torch.Tensor]: """The data converter function used for conditional autoencoders. This converter is used when we have a torch tensor as condition for all the data samples. """ @@ -152,7 +152,7 @@ def _cat_input_condition( assert isinstance(self.condition, torch.Tensor), "Error: condition should be a torch tensor" return torch.cat([data.view(-1), self.condition]), data - def _cat_input_label(self, data: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def _cat_input_label(self, data: torch.Tensor, target: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """The data converter function used for conditional autoencoders. This converter is used when we want to condition each data sample on its label. """ @@ -163,7 +163,7 @@ def _cat_input_label(self, data: torch.Tensor, target: torch.Tensor) -> Tuple[to # We can flatten the data since self.data_shape is already saved. return torch.cat([data.view(-1), target]), data - def get_unpacking_function(self) -> Callable[[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: + def get_unpacking_function(self) -> Callable[[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]: condition_vector_size = self.get_condition_vector_size() return partial( AutoEncoderDatasetConverter.unpack_input_condition, @@ -174,7 +174,7 @@ def get_unpacking_function(self) -> Callable[[torch.Tensor], Tuple[torch.Tensor, @staticmethod def unpack_input_condition( packed_data: torch.Tensor, cond_vec_size: int, data_shape: torch.Size - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: """Unpacks model inputs (data and condition) from a tensor used in the training loop regardless of the converter function used to pack them. Unpacking relies on the size of the condition vector, and the original data shape which is saved before the packing process. @@ -183,7 +183,7 @@ def unpack_input_condition( packed_data (torch.Tensor): Data tensor used in the training loop as the input to the model. Returns: - Tuple[torch.Tensor, torch.Tensor]: Data in its original shape, and the condition vector + tuple[torch.Tensor, torch.Tensor]: Data in its original shape, and the condition vector to be fed into the model. """ # We assume data is "batch first". diff --git a/fl4health/utils/functions.py b/fl4health/utils/functions.py index 460aa949e..d52b70a20 100644 --- a/fl4health/utils/functions.py +++ b/fl4health/utils/functions.py @@ -1,4 +1,4 @@ -from typing import Any, List, Tuple +from typing import Any import numpy as np import torch @@ -26,7 +26,7 @@ def forward(bernoulli_probs: torch.Tensor) -> torch.Tensor: # type: ignore @staticmethod # inputs is a Tuple of all of the inputs passed to forward. # output is the output of the forward(). - def setup_context(ctx: Any, inputs: Tuple[torch.Tensor], output: torch.Tensor) -> None: + def setup_context(ctx: Any, inputs: tuple[torch.Tensor], output: torch.Tensor) -> None: assert len(inputs) == 1 (bernoulli_probs,) = inputs ctx.save_for_backward(bernoulli_probs) @@ -61,16 +61,16 @@ def select_zeroeth_element(array: np.ndarray) -> float: return array[indices] -def pseudo_sort_scoring_function(client_result: Tuple[ClientProxy, NDArrays, int]) -> float: +def pseudo_sort_scoring_function(client_result: tuple[ClientProxy, NDArrays, int]) -> float: """ - This function provides the "score" that is used to sort a list of Tuple[ClientProxy, NDArrays, int]. We select + This function provides the "score" that is used to sort a list of tuple[ClientProxy, NDArrays, int]. We select the zeroeth (index 0 across all dimensions) element from each of the arrays in the NDArrays list, sum them, and add the integer (client sample counts) to the sum to come up with a score for sorting. Note that the underlying numpy arrays in NDArrays may not all be of numerical type. So we limit to selecting elements from arrays of floats. Args: - client_result (Tuple[ClientProxy, NDArrays, int]]): Elements to use to determine the score. + client_result (tuple[ClientProxy, NDArrays, int]]): Elements to use to determine the score. Returns: float: Sum of a the zeroeth elements of each array in the NDArrays and the int of the tuple @@ -83,8 +83,8 @@ def pseudo_sort_scoring_function(client_result: Tuple[ClientProxy, NDArrays, int def decode_and_pseudo_sort_results( - results: List[Tuple[ClientProxy, FitRes]] -) -> List[Tuple[ClientProxy, NDArrays, int]]: + results: list[tuple[ClientProxy, FitRes]] +) -> list[tuple[ClientProxy, NDArrays, int]]: """ This function is used to convert the results of client training into NDArrays and to apply a pseudo sort based on the zeroeth elements in the weights and the sample counts. As long as the numpy seed has been set on the @@ -96,10 +96,10 @@ def decode_and_pseudo_sort_results( and are, therefore, not pinnable without a ton of work. Args: - results (List[Tuple[ClientProxy, FitRes]]): Results from a federated training round. + results (list[tuple[ClientProxy, FitRes]]): Results from a federated training round. Returns: - List[Tuple[ClientProxy, NDArrays, int]]: The ordered set of weights as NDarrays and the corresponding + list[tuple[ClientProxy, NDArrays, int]]: The ordered set of weights as NDarrays and the corresponding number of examples """ ndarrays_results = [ diff --git a/fl4health/utils/load_data.py b/fl4health/utils/load_data.py index d4936cc21..e125bd51d 100644 --- a/fl4health/utils/load_data.py +++ b/fl4health/utils/load_data.py @@ -1,8 +1,8 @@ import random import warnings +from collections.abc import Callable from logging import INFO from pathlib import Path -from typing import Callable, Dict, Optional, Tuple import numpy as np import torch @@ -28,8 +28,8 @@ def __call__(self, tensor: torch.Tensor) -> np.ndarray: def split_data_and_targets( - data: torch.Tensor, targets: torch.Tensor, validation_proportion: float = 0.2, hash_key: Optional[int] = None -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + data: torch.Tensor, targets: torch.Tensor, validation_proportion: float = 0.2, hash_key: int | None = None +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: total_size = data.shape[0] train_size = int(total_size * (1 - validation_proportion)) @@ -42,7 +42,7 @@ def split_data_and_targets( return train_data, train_targets, val_data, val_targets -def get_mnist_data_and_target_tensors(data_dir: Path, train: bool) -> Tuple[torch.Tensor, torch.Tensor]: +def get_mnist_data_and_target_tensors(data_dir: Path, train: bool) -> tuple[torch.Tensor, torch.Tensor]: mnist_dataset = MNIST(data_dir, train=train, download=True) data = torch.Tensor(mnist_dataset.data) targets = torch.Tensor(mnist_dataset.targets).long() @@ -51,11 +51,11 @@ def get_mnist_data_and_target_tensors(data_dir: Path, train: bool) -> Tuple[torc def get_train_and_val_mnist_datasets( data_dir: Path, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, + transform: Callable | None = None, + target_transform: Callable | None = None, validation_proportion: float = 0.2, - hash_key: Optional[int] = None, -) -> Tuple[TensorDataset, TensorDataset]: + hash_key: int | None = None, +) -> tuple[TensorDataset, TensorDataset]: data, targets = get_mnist_data_and_target_tensors(data_dir, True) train_data, train_targets, val_data, val_targets = split_data_and_targets( @@ -70,13 +70,13 @@ def get_train_and_val_mnist_datasets( def load_mnist_data( data_dir: Path, batch_size: int, - sampler: Optional[LabelBasedSampler] = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - dataset_converter: Optional[DatasetConverter] = None, + sampler: LabelBasedSampler | None = None, + transform: Callable | None = None, + target_transform: Callable | None = None, + dataset_converter: DatasetConverter | None = None, validation_proportion: float = 0.2, - hash_key: Optional[int] = None, -) -> Tuple[DataLoader, DataLoader, Dict[str, int]]: + hash_key: int | None = None, +) -> tuple[DataLoader, DataLoader, dict[str, int]]: """ Load MNIST Dataset (training and validation set). @@ -84,18 +84,18 @@ def load_mnist_data( data_dir (Path): The path to the MNIST dataset locally. Dataset is downloaded to this location if it does not already exist. batch_size (int): The batch size to use for the train and validation dataloader. - sampler (Optional[LabelBasedSampler]): Optional sampler to subsample dataset based on labels. - transform (Optional[Callable]): Optional transform to be applied to input samples. - target_transform (Optional[Callable]): Optional transform to be applied to targets. - dataset_converter (Optional[DatasetConverter]): Optional dataset converter used to convert + sampler (LabelBasedSampler | None): Optional sampler to subsample dataset based on labels. + transform (Callable | None): Optional transform to be applied to input samples. + target_transform (Callable | None): Optional transform to be applied to targets. + dataset_converter (DatasetConverter | None): Optional dataset converter used to convert the input and/or target of train and validation dataset. validation_proportion (float): A float between 0 and 1 specifying the proportion of samples to allocate to the validation dataset. Defaults to 0.2. - hash_key (Optional[int]): Optional hash key to create a reproducible split for train and validation + hash_key (int | None): Optional hash key to create a reproducible split for train and validation datasets. Returns: - Tuple[DataLoader, DataLoader, Dict[str, int]]: The train data loader, validation data loader + tuple[DataLoader, DataLoader, dict[str, int]]: The train data loader, validation data loader and a dictionary with the sample counts of datasets underpinning the respective data loaders. """ log(INFO, f"Data directory: {str(data_dir)}") @@ -130,9 +130,9 @@ def load_mnist_data( def load_mnist_test_data( data_dir: Path, batch_size: int, - sampler: Optional[LabelBasedSampler] = None, - transform: Optional[Callable] = None, -) -> Tuple[DataLoader, Dict[str, int]]: + sampler: LabelBasedSampler | None = None, + transform: Callable | None = None, +) -> tuple[DataLoader, dict[str, int]]: """ Load MNIST Test Dataset. @@ -140,11 +140,11 @@ def load_mnist_test_data( data_dir (Path): The path to the MNIST dataset locally. Dataset is downloaded to this location if it does not already exist. batch_size (int): The batch size to use for the test dataloader. - sampler (Optional[LabelBasedSampler]): Optional sampler to subsample dataset based on labels. - transform (Optional[Callable]): Optional transform to be applied to input samples. + sampler (LabelBasedSampler | None): Optional sampler to subsample dataset based on labels. + transform (Callable | None): Optional transform to be applied to input samples. Returns: - Tuple[DataLoader, Dict[str, int]]: The test data loader and a dictionary containing the sample count + tuple[DataLoader, dict[str, int]]: The test data loader and a dictionary containing the sample count of the test dataset. """ log(INFO, f"Data directory: {str(data_dir)}") @@ -169,7 +169,7 @@ def load_mnist_test_data( return evaluation_loader, num_examples -def get_cifar10_data_and_target_tensors(data_dir: Path, train: bool) -> Tuple[torch.Tensor, torch.Tensor]: +def get_cifar10_data_and_target_tensors(data_dir: Path, train: bool) -> tuple[torch.Tensor, torch.Tensor]: cifar_dataset = CIFAR10(data_dir, train=train, download=True) data = torch.Tensor(cifar_dataset.data) targets = torch.Tensor(cifar_dataset.targets).long() @@ -178,11 +178,11 @@ def get_cifar10_data_and_target_tensors(data_dir: Path, train: bool) -> Tuple[to def get_train_and_val_cifar10_datasets( data_dir: Path, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, + transform: Callable | None = None, + target_transform: Callable | None = None, validation_proportion: float = 0.2, - hash_key: Optional[int] = None, -) -> Tuple[TensorDataset, TensorDataset]: + hash_key: int | None = None, +) -> tuple[TensorDataset, TensorDataset]: data, targets = get_cifar10_data_and_target_tensors(data_dir, True) train_data, train_targets, val_data, val_targets = split_data_and_targets( @@ -198,10 +198,10 @@ def get_train_and_val_cifar10_datasets( def load_cifar10_data( data_dir: Path, batch_size: int, - sampler: Optional[LabelBasedSampler] = None, + sampler: LabelBasedSampler | None = None, validation_proportion: float = 0.2, - hash_key: Optional[int] = None, -) -> Tuple[DataLoader, DataLoader, Dict[str, int]]: + hash_key: int | None = None, +) -> tuple[DataLoader, DataLoader, dict[str, int]]: """ Load CIFAR10 Dataset (training and validation set). @@ -209,14 +209,14 @@ def load_cifar10_data( data_dir (Path): The path to the CIFAR10 dataset locally. Dataset is downloaded to this location if it does not already exist. batch_size (int): The batch size to use for the train and validation dataloader. - sampler (Optional[LabelBasedSampler]): Optional sampler to subsample dataset based on labels. + sampler (LabelBasedSampler | None): Optional sampler to subsample dataset based on labels. validation_proportion (float): A float between 0 and 1 specifying the proportion of samples to allocate to the validation dataset. Defaults to 0.2. - hash_key (Optional[int]): Optional hash key to create a reproducible split for train and validation + hash_key (int | None): Optional hash key to create a reproducible split for train and validation datasets. Returns: - Tuple[DataLoader, DataLoader, Dict[str, int]]: The train data loader, validation data loader + tuple[DataLoader, DataLoader, dict[str, int]]: The train data loader, validation data loader and a dictionary with the sample counts of datasets underpinning the respective data loaders. """ log(INFO, f"Data directory: {str(data_dir)}") @@ -246,8 +246,8 @@ def load_cifar10_data( def load_cifar10_test_data( - data_dir: Path, batch_size: int, sampler: Optional[LabelBasedSampler] = None -) -> Tuple[DataLoader, Dict[str, int]]: + data_dir: Path, batch_size: int, sampler: LabelBasedSampler | None = None +) -> tuple[DataLoader, dict[str, int]]: """ Load CIFAR10 Test Dataset. @@ -255,10 +255,10 @@ def load_cifar10_test_data( data_dir (Path): The path to the CIFAR10 dataset locally. Dataset is downloaded to this location if it does not already exist. batch_size (int): The batch size to use for the test dataloader. - sampler (Optional[LabelBasedSampler]): Optional sampler to subsample dataset based on labels. + sampler (LabelBasedSampler | None): Optional sampler to subsample dataset based on labels. Returns: - Tuple[DataLoader, Dict[str, int]]: The test data loader and a dictionary containing the sample count + tuple[DataLoader, dict[str, int]]: The test data loader and a dictionary containing the sample count of the test dataset. """ log(INFO, f"Data directory: {str(data_dir)}") diff --git a/fl4health/utils/losses.py b/fl4health/utils/losses.py index ac65a7b64..172f999eb 100644 --- a/fl4health/utils/losses.py +++ b/fl4health/utils/losses.py @@ -2,29 +2,29 @@ from abc import ABC, abstractmethod from enum import Enum -from typing import Dict, Generic, List, Optional, TypeVar, Union +from typing import Generic, TypeVar import torch class Losses(ABC): - def __init__(self, additional_losses: Optional[Dict[str, torch.Tensor]] = None) -> None: + def __init__(self, additional_losses: dict[str, torch.Tensor] | None = None) -> None: """ An abstract class to store the losses Args: - additional_losses (Optional[Dict[str, torch.Tensor]]): Optional dictionary of additional losses. + additional_losses (dict[str, torch.Tensor] | None): Optional dictionary of additional losses. """ self.additional_losses = additional_losses if additional_losses else {} - def as_dict(self) -> Dict[str, float]: + def as_dict(self) -> dict[str, float]: """ Produces a dictionary representation of the object with all of the losses. Returns: - Dict[str, float]: A dictionary with the additional losses if they exist. + dict[str, float]: A dictionary with the additional losses if they exist. """ - loss_dict: Dict[str, float] = {} + loss_dict: dict[str, float] = {} if self.additional_losses is not None: for key, val in self.additional_losses.items(): @@ -48,24 +48,24 @@ def aggregate(loss_meter: LossMeter) -> Losses: class EvaluationLosses(Losses): - def __init__(self, checkpoint: torch.Tensor, additional_losses: Optional[Dict[str, torch.Tensor]] = None) -> None: + def __init__(self, checkpoint: torch.Tensor, additional_losses: dict[str, torch.Tensor] | None = None) -> None: """ A class to store the checkpoint and additional_losses of a model along with a method to return a dictionary representation. Args: checkpoint (torch.Tensor): The loss used to checkpoint model (if checkpointing is enabled). - additional_losses (Optional[Dict[str, torch.Tensor]]): Optional dictionary of additional losses. + additional_losses (dict[str, torch.Tensor] | None): Optional dictionary of additional losses. """ super().__init__(additional_losses) self.checkpoint = checkpoint - def as_dict(self) -> Dict[str, float]: + def as_dict(self) -> dict[str, float]: """ Produces a dictionary representation of the object with all of the losses. Returns: - Dict[str, float]: A dictionary with the checkpoint loss, plus each one of the keys in + dict[str, float]: A dictionary with the checkpoint loss, plus each one of the keys in additional losses if they exist. """ loss_dict = super().as_dict() @@ -98,28 +98,28 @@ def aggregate(loss_meter: LossMeter[EvaluationLosses]) -> EvaluationLosses: class TrainingLosses(Losses): def __init__( self, - backward: Union[torch.Tensor, Dict[str, torch.Tensor]], - additional_losses: Optional[Dict[str, torch.Tensor]] = None, + backward: torch.Tensor | dict[str, torch.Tensor], + additional_losses: dict[str, torch.Tensor] | None = None, ) -> None: """ A class to store the backward and additional_losses of a model along with a method to return a dictionary representation. Args: - backward (Union[torch.Tensor, Dict[str, torch.Tensor]]): The backward loss or + backward (torch.Tensor | dict[str, torch.Tensor]): The backward loss or losses to optimize. In the normal case, backward is a Tensor corresponding to the loss of a model. In the case of an ensemble_model, backward is dictionary of losses. - additional_losses (Optional[Dict[str, torch.Tensor]]): Optional dictionary of additional losses. + additional_losses (dict[str, torch.Tensor] | None): Optional dictionary of additional losses. """ super().__init__(additional_losses) self.backward = backward if isinstance(backward, dict) else {"backward": backward} - def as_dict(self) -> Dict[str, float]: + def as_dict(self) -> dict[str, float]: """ Produces a dictionary representation of the object with all of the losses. Returns: - Dict[str, float]: A dictionary where each key represents one of the backward losses, + dict[str, float]: A dictionary where each key represents one of the backward losses, plus additional losses if they exist. """ loss_dict = super().as_dict() @@ -176,7 +176,7 @@ def __init__(self, loss_meter_type: LossMeterType, losses_type: type[LossesType] of the subclasses of Losses """ - self.losses_list: List[LossesType] = [] + self.losses_list: list[LossesType] = [] self.loss_meter_type = loss_meter_type self.losses_type = losses_type @@ -207,25 +207,25 @@ def compute(self) -> LossesType: @staticmethod def aggregate_losses_dict( - loss_list: List[Dict[str, torch.Tensor]], + loss_list: list[dict[str, torch.Tensor]], loss_meter_type: LossMeterType, - ) -> Dict[str, torch.Tensor]: + ) -> dict[str, torch.Tensor]: """ Aggregates a list of losses dictionaries into a single dictionary according to the loss meter aggregation type Args: - loss_list (List[Dict[str, torch.Tensor]]): A list of loss dictionaries + loss_list (list[dict[str, torch.Tensor]]): A list of loss dictionaries loss_meter_type (LossMeterType): The type of the loss meter to perform the aggregation Returns: - Dict[str, torch.Tensor]: A single dictionary with the aggregated losses according to the given loss + dict[str, torch.Tensor]: A single dictionary with the aggregated losses according to the given loss meter type """ # We don't know the keys of the dict (backward or additional losses) beforehand. We don't obtain them # from the first entry because losses can have different keys. We get list of all the keys from # all the losses. loss_keys = set(key for loss_dict_ in loss_list for key in loss_dict_.keys()) - loss_dict: Dict[str, torch.Tensor] = {} + loss_dict: dict[str, torch.Tensor] = {} for key in loss_keys: if loss_meter_type == LossMeterType.AVERAGE: loss = torch.mean(torch.FloatTensor([loss[key] for loss in loss_list if key in loss])) diff --git a/fl4health/utils/metric_aggregation.py b/fl4health/utils/metric_aggregation.py index 448e02754..ae9b1ba5e 100644 --- a/fl4health/utils/metric_aggregation.py +++ b/fl4health/utils/metric_aggregation.py @@ -1,24 +1,23 @@ from collections import defaultdict -from typing import DefaultDict, List, Tuple from flwr.common.typing import Metrics def uniform_metric_aggregation( - all_client_metrics: List[Tuple[int, Metrics]], -) -> Tuple[DefaultDict[str, int], Metrics]: + all_client_metrics: list[tuple[int, Metrics]], +) -> tuple[defaultdict[str, int], Metrics]: """ Function that aggregates client metrics and divides by the number of clients that contributed to metric. Args: - all_client_metrics (List[Tuple[int, Metrics]]): A list of tuples with the + all_client_metrics (list[tuple[int, Metrics]]): A list of tuples with the sample counts and metrics for each client. Returns: - Tuple[DefaultDict[str, int], Metrics]: Client counts per metric and the uniformly aggregated metrics. + tuple[defaultdict[str, int], Metrics]: Client counts per metric and the uniformly aggregated metrics. """ aggregated_metrics: Metrics = {} - total_client_count_by_metric: DefaultDict[str, int] = defaultdict(int) + total_client_count_by_metric: defaultdict[str, int] = defaultdict(int) # Run through all of the metrics for _, client_metrics in all_client_metrics: for metric_name, metric_value in client_metrics.items(): @@ -40,17 +39,17 @@ def uniform_metric_aggregation( def metric_aggregation( - all_client_metrics: List[Tuple[int, Metrics]], -) -> Tuple[int, Metrics]: + all_client_metrics: list[tuple[int, Metrics]], +) -> tuple[int, Metrics]: """ Function that computes a weighted aggregation of metrics normalized by the total number of samples. Args: - all_client_metrics (List[Tuple[int, Metrics]]): A list of tuples with the + all_client_metrics (list[tuple[int, Metrics]]): A list of tuples with the sample counts and metrics for each client. Returns: - Tuple[int, Metrics]: The total number of examples along with aggregated metrics. + tuple[int, Metrics]: The total number of examples along with aggregated metrics. """ aggregated_metrics: Metrics = {} total_examples = 0 @@ -93,13 +92,13 @@ def normalize_metrics(total_examples: int, aggregated_metrics: Metrics) -> Metri def uniform_normalize_metrics( - total_client_count_by_metric: DefaultDict[str, int], aggregated_metrics: Metrics + total_client_count_by_metric: defaultdict[str, int], aggregated_metrics: Metrics ) -> Metrics: """ Function that normalizes metrics based on how many clients contributed to the metric. Args: - total_client_count_by_metric (DefaultDict[str, int]): The count of clients that contributed to each metric. + total_client_count_by_metric (defaultdict[str, int]): The count of clients that contributed to each metric. aggregated_metrics (Metrics): Metrics that have been aggregated across clients. Returns: @@ -114,14 +113,14 @@ def uniform_normalize_metrics( def fit_metrics_aggregation_fn( - all_client_metrics: List[Tuple[int, Metrics]], + all_client_metrics: list[tuple[int, Metrics]], ) -> Metrics: """ Function for fit that computes a weighted aggregation of the client metrics and normalizes by the total number of samples. Args: - all_client_metrics (List[Tuple[int, Metrics]]): A list of tuples with the + all_client_metrics (list[tuple[int, Metrics]]): A list of tuples with the sample counts and metrics for each client. Returns: @@ -134,14 +133,14 @@ def fit_metrics_aggregation_fn( def evaluate_metrics_aggregation_fn( - all_client_metrics: List[Tuple[int, Metrics]], + all_client_metrics: list[tuple[int, Metrics]], ) -> Metrics: """ Function for evaluate that computes a weighted aggregation of the client metrics and normalizes by the total number of samples. Args: - all_client_metrics (List[Tuple[int, Metrics]]): A list of tuples with the + all_client_metrics (list[tuple[int, Metrics]]): A list of tuples with the sample counts and metrics for each client. Returns: @@ -154,14 +153,14 @@ def evaluate_metrics_aggregation_fn( def uniform_evaluate_metrics_aggregation_fn( - all_client_metrics: List[Tuple[int, Metrics]], + all_client_metrics: list[tuple[int, Metrics]], ) -> Metrics: """ Function for evaluate that computes aggregation of the client metrics and normalizes by the number of clients that contributed to the metric. Args: - all_client_metrics (List[Tuple[int, Metrics]]): A list of tuples with the + all_client_metrics (list[tuple[int, Metrics]]): A list of tuples with the sample counts and metrics for each client. Returns: diff --git a/fl4health/utils/metrics.py b/fl4health/utils/metrics.py index 74af911b6..4988c5da8 100644 --- a/fl4health/utils/metrics.py +++ b/fl4health/utils/metrics.py @@ -1,11 +1,11 @@ import copy from abc import ABC, abstractmethod +from collections.abc import Sequence from enum import Enum -from typing import Dict, List, Sequence, Tuple import numpy as np import torch -from flwr.common.typing import Metrics, Optional, Scalar +from flwr.common.typing import Metrics, Scalar from sklearn import metrics as sklearn_metrics from torchmetrics import Metric as TMetric @@ -47,12 +47,12 @@ def update(self, input: torch.Tensor, target: torch.Tensor) -> None: raise NotImplementedError @abstractmethod - def compute(self, name: Optional[str]) -> Metrics: + def compute(self, name: str | None) -> Metrics: """ Compute metric on accumulated input and output over updates. Args: - name (Optional[str]): Optional name used in conjunction with class attribute name + name (str | None): Optional name used in conjunction with class attribute name to define key in metrics dictionary. Raises: @@ -97,12 +97,12 @@ def update(self, input: torch.Tensor, target: torch.Tensor) -> None: """ self.metric.update(input, target.long()) - def compute(self, name: Optional[str]) -> Metrics: + def compute(self, name: str | None) -> Metrics: """ Compute value of underlying TorchMetric. Args: - name (Optional[str]): Optional name used in conjunction with class attribute name + name (str | None): Optional name used in conjunction with class attribute name to define key in metrics dictionary. Returns: @@ -127,8 +127,8 @@ def __init__(self, name: str) -> None: name (str): Name of the metric. """ super().__init__(name) - self.accumulated_inputs: List[torch.Tensor] = [] - self.accumulated_targets: List[torch.Tensor] = [] + self.accumulated_inputs: list[torch.Tensor] = [] + self.accumulated_targets: list[torch.Tensor] = [] def update(self, input: torch.Tensor, target: torch.Tensor) -> None: """ @@ -142,12 +142,12 @@ def update(self, input: torch.Tensor, target: torch.Tensor) -> None: self.accumulated_inputs.append(input) self.accumulated_targets.append(target) - def compute(self, name: Optional[str] = None) -> Metrics: + def compute(self, name: str | None = None) -> Metrics: """ Compute metric on accumulated input and output over updates. Args: - name (Optional[str]): Optional name used in conjunction with class attribute name + name (str | None): Optional name used in conjunction with class attribute name to define key in metrics dictionary. Raises: @@ -188,8 +188,8 @@ class TransformsMetric(Metric): def __init__( self, metric: Metric, - pred_transforms: Optional[Sequence[TorchTransformFunction]] = None, - target_transforms: Optional[Sequence[TorchTransformFunction]] = None, + pred_transforms: Sequence[TorchTransformFunction] | None = None, + target_transforms: Sequence[TorchTransformFunction] | None = None, ) -> None: """ A thin wrapper class to allow transforms to be applied to preds and @@ -197,11 +197,11 @@ def __init__( Args: metric (Metric): A FL4Health compatible metric - pred_transforms (Optional[Sequence[TorchTransformFunction]], optional): A + pred_transforms (Sequence[TorchTransformFunction] | None, optional): A list of transform functions to apply to the model predictions before computing the metrics. Each callable must accept and return a torch. Tensor. Use partial to set other arguments. - target_transforms (Optional[Sequence[TorchTransformFunction]], optional): A + target_transforms (Sequence[TorchTransformFunction] | None, optional): A list of transform functions to apply to the targets before computing the metrics. Each callable must accept and return a torch.Tensor. Use partial to set other arguments. @@ -220,7 +220,7 @@ def update(self, pred: torch.Tensor, target: torch.Tensor) -> None: self.metric.update(pred, target) - def compute(self, name: Optional[str]) -> Metrics: + def compute(self, name: str | None) -> Metrics: return self.metric.compute(name) def clear(self) -> None: @@ -232,8 +232,8 @@ def __init__( self, name: str = "BinarySoftDiceCoefficient", epsilon: float = 1.0e-7, - spatial_dimensions: Tuple[int, ...] = (2, 3, 4), - logits_threshold: Optional[float] = 0.5, + spatial_dimensions: tuple[int, ...] = (2, 3, 4), + logits_threshold: float | None = 0.5, ): """ Binary DICE Coefficient Metric with configurable spatial dimensions and logits threshold. @@ -241,7 +241,7 @@ def __init__( Args: name (str): Name of the metric. epsilon (float): Small float to add to denominator of DICE calculation to avoid divide by 0. - spatial_dimensions (Tuple[int, ...]): The spatial dimensions of the image within the prediction tensors. + spatial_dimensions (tuple[int, ...]): The spatial dimensions of the image within the prediction tensors. The default assumes that the images are 3D and have shape: batch_size, channel, spatial, spatial, spatial. logits_threshold: This is a threshold value where values above are classified as 1 @@ -333,7 +333,7 @@ class F1(SimpleMetric): def __init__( self, name: str = "F1 score", - average: Optional[str] = "weighted", + average: str | None = "weighted", ): """ Computes the F1 score using the sklearn f1_score function. As such, the values of average correspond to @@ -341,7 +341,7 @@ def __init__( Args: name (str, optional): Name of the metric. Defaults to "F1 score". - average (Optional[str], optional): Whether to perform averaging of the F1 scores and how. The values of + average (str | None, optional): Whether to perform averaging of the F1 scores and how. The values of this string corresponds to those of the sklearn f1_score function. See: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html Defaults to "weighted". @@ -369,7 +369,7 @@ def __init__(self, metrics: Sequence[Metric], metric_manager_name: str) -> None: """ self.original_metrics = metrics self.metric_manager_name = metric_manager_name - self.metrics_per_prediction_type: Dict[str, Sequence[Metric]] = {} + self.metrics_per_prediction_type: dict[str, Sequence[Metric]] = {} def update(self, preds: TorchPredType, target: TorchTargetType) -> None: """ @@ -425,7 +425,7 @@ def clear(self) -> None: self.metrics_per_prediction_type = {} def check_target_prediction_keys_equal( - self, preds: Dict[str, torch.Tensor], target: Dict[str, torch.Tensor] + self, preds: dict[str, torch.Tensor], target: dict[str, torch.Tensor] ) -> None: assert target.keys() == preds.keys(), ( "Received a dict with multiple targets, but the keys of the " diff --git a/fl4health/utils/nnunet_utils.py b/fl4health/utils/nnunet_utils.py index f07c16fbd..e3169e979 100644 --- a/fl4health/utils/nnunet_utils.py +++ b/fl4health/utils/nnunet_utils.py @@ -3,11 +3,12 @@ import signal import sys import warnings +from collections.abc import Callable, Sequence from enum import Enum from importlib import reload from logging import DEBUG, INFO, WARN, Logger from math import ceil -from typing import Any, Callable, Dict, List, Sequence, Tuple, Union, no_type_check +from typing import Any, no_type_check import numpy as np import torch @@ -64,7 +65,7 @@ def use_default_signal_handlers(fn: Callable) -> Callable: flwr 1.9.0 overrides the default signal handlers with handlers that raise an error on any interruption or termination. Since nnunet spawns child processes which inherit these handlers, when those subprocesses are terminated (which is expected - behaviour), the flwr signal handlers raise an error (which we don't want). + behavior), the flwr signal handlers raise an error (which we don't want). Flwr is expected to fix this in the next release. See the following issue: https://github.com/adap/flower/issues/3837 @@ -129,8 +130,8 @@ def set_nnunet_env(verbose: bool = False, **kwargs: str) -> None: # The two convert deepsupervision methods are necessary because fl4health requires # predictions, targets and inputs to be single torch.Tensors or Dicts of torch.Tensors def convert_deep_supervision_list_to_dict( - tensor_list: Union[List[torch.Tensor], Tuple[torch.Tensor]], num_spatial_dims: int -) -> Dict[str, torch.Tensor]: + tensor_list: list[torch.Tensor] | tuple[torch.Tensor], num_spatial_dims: int +) -> dict[str, torch.Tensor]: """ Converts a list of torch.Tensors to a dictionary. Names the keys for each tensor based on the spatial resolution of the tensor and its @@ -139,12 +140,12 @@ def convert_deep_supervision_list_to_dict( spatial dimensions of the tensors are last. Args: - tensor_list (List[torch.Tensor]): A list of tensors, usually either + tensor_list (list[torch.Tensor]): A list of tensors, usually either nnunet model outputs or targets, to be converted into a dictionary num_spatial_dims (int): The number of spatial dimensions. Assumes the spatial dimensions are last Returns: - Dict[str, torch.Tensor]: A dictionary containing the tensors as + dict[str, torch.Tensor]: A dictionary containing the tensors as values where the keys are 'i-XxYxZ' where i was the tensor's index in the list and X,Y,Z are the spatial dimensions of the tensor """ @@ -159,19 +160,19 @@ def convert_deep_supervision_list_to_dict( return tensors -def convert_deep_supervision_dict_to_list(tensor_dict: Dict[str, torch.Tensor]) -> List[torch.Tensor]: +def convert_deep_supervision_dict_to_list(tensor_dict: dict[str, torch.Tensor]) -> list[torch.Tensor]: """ Converts a dictionary of tensors back into a list so that it can be used by nnunet deep supervision loss functions Args: - tensor_dict (Dict[str, torch.Tensor]): Dictionary containing + tensor_dict (dict[str, torch.Tensor]): Dictionary containing torch.Tensors. The key values must start with 'X-' where X is an integer representing the index at which the tensor should be placed in the output list Returns: - List[torch.Tensor]: A list of torch.Tensors + list[torch.Tensor]: A list of torch.Tensors """ sorted_list = sorted(tensor_dict.items(), key=lambda x: int(x[0].split("-")[0])) return [tensor for key, tensor in sorted_list] @@ -252,15 +253,15 @@ def get_dataset_n_voxels(source_plans: dict, n_cases: int) -> float: return approx_n_voxels -def prepare_loss_arg(tensor: torch.Tensor | Dict[str, torch.Tensor]) -> Union[torch.Tensor, List[torch.Tensor]]: +def prepare_loss_arg(tensor: torch.Tensor | dict[str, torch.Tensor]) -> torch.Tensor | list[torch.Tensor]: """ Converts pred and target tensors into the proper data type to be passed to the nnunet loss functions. Args: - tensor (torch.Tensor | Dict[str, torch.Tensor]): The input tensor + tensor (torch.Tensor | dict[str, torch.Tensor]): The input tensor Returns: - torch.Tensor | List[torch.Tensor]: The tensor ready to be passed to the loss + torch.Tensor | list[torch.Tensor]: The tensor ready to be passed to the loss function. A single tensor if not using deep supervision and a list of tensors if deep supervision is on. """ @@ -277,26 +278,22 @@ def prepare_loss_arg(tensor: torch.Tensor | Dict[str, torch.Tensor]) -> Union[to class nnUNetDataLoaderWrapper(DataLoader): def __init__( self, - nnunet_augmenter: Union[SingleThreadedAugmenter, NonDetMultiThreadedAugmenter, MultiThreadedAugmenter], - nnunet_config: Union[NnunetConfig, str], + nnunet_augmenter: SingleThreadedAugmenter | NonDetMultiThreadedAugmenter | MultiThreadedAugmenter, + nnunet_config: NnunetConfig | str, infinite: bool = False, ) -> None: """ - Wraps nnunet dataloader classes using the pytorch dataloader to make - them pytorch compatible. Also handles some unique stuff specific to - nnunet such as deep supervision and infinite dataloaders. The nnunet - dataloaders should only be used for training and validation, not final testing. + Wraps nnunet dataloader classes using the pytorch dataloader to make them pytorch compatible. Also handles + some unique stuff specific to nnunet such as deep supervision and infinite dataloaders. The nnunet dataloaders + should only be used for training and validation, not final testing. Args: - nnunet_dataloader (Union[SingleThreadedAugmenter, - NonDetMultiThreadedAugmenter]): The dataloader used by nnunet - nnunet_config (NnUNetConfig): The nnunet config. Enum type helps - ensure that nnunet config is valid - infinite (bool, optional): Whether or not to treat the dataset - as infinite. The dataloaders sample data with replacement - either way. The only difference is that if set to False, a - StopIteration is generated after num_samples/batch_size steps. - Defaults to False. + nnunet_dataloader (SingleThreadedAugmenter | NonDetMultiThreadedAugmenter | MultiThreadedAugmenter): The + dataloader used by nnunet + nnunet_config (NnUNetConfig): The nnunet config. Enum type helps ensure that nnunet config is valid + infinite (bool, optional): Whether or not to treat the dataset as infinite. The dataloaders sample data + with replacement either way. The only difference is that if set to False, a StopIteration is + generated after num_samples/batch_size steps. Defaults to False. """ # The augmenter is a wrapper on the nnunet dataloader self.nnunet_augmenter = nnunet_augmenter @@ -317,7 +314,7 @@ def __init__( self.current_step = 0 self.infinite = infinite - def __next__(self) -> Tuple[torch.Tensor, Union[torch.Tensor, Dict[str, torch.Tensor]]]: + def __next__(self) -> tuple[torch.Tensor, torch.Tensor | dict[str, torch.Tensor]]: if not self.infinite and self.current_step == self.__len__(): self.reset() raise StopIteration # Raise stop iteration after epoch has completed @@ -328,7 +325,7 @@ def __next__(self) -> Tuple[torch.Tensor, Union[torch.Tensor, Dict[str, torch.Te # segmentations at various spatial scales/resolutions # nnUNet has a wrapper for loss functions to enable deep supervision inputs: torch.Tensor = batch["data"] - targets: Union[torch.Tensor, List[torch.Tensor]] = batch["target"] + targets: torch.Tensor | list[torch.Tensor] = batch["target"] if isinstance(targets, list): target_dict = convert_deep_supervision_list_to_dict(targets, self.num_spatial_dims) return inputs, target_dict @@ -390,7 +387,7 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: class StreamToLogger(io.StringIO): - def __init__(self, logger: Logger, level: Union[LogLevel, int]) -> None: + def __init__(self, logger: Logger, level: LogLevel | int) -> None: """ File-like stream object that redirects writes to a logger. Useful for redirecting stdout to a logger. @@ -429,8 +426,8 @@ def __init__( optimizer (Optimizer): The optimizer to apply LR scheduler to. initial_lr (float): The initial learning rate of the optimizer. max_steps (int): The maximum total number of steps across all FL rounds. - exponent (float): Controls how quickly LR descreases over time. Higher values - lead to more rapdid descent. Defaults to 0.9. + exponent (float): Controls how quickly LR decreases over time. Higher values + lead to more rapid descent. Defaults to 0.9. steps_per_lr (int): The number of steps per LR before decaying. (ie 10 means the LR will be constant for 10 steps prior to being decreased to the subsequent value). Defaults to 250 as that is the default for nnunet (decay LR once an epoch and epoch is 250 steps). diff --git a/fl4health/utils/parameter_extraction.py b/fl4health/utils/parameter_extraction.py index e6fd4bb09..7062e3953 100644 --- a/fl4health/utils/parameter_extraction.py +++ b/fl4health/utils/parameter_extraction.py @@ -1,4 +1,4 @@ -from typing import Iterable +from collections.abc import Iterable import torch import torch.nn as nn diff --git a/fl4health/utils/partitioners.py b/fl4health/utils/partitioners.py index b9b2492a2..a6212f15b 100644 --- a/fl4health/utils/partitioners.py +++ b/fl4health/utils/partitioners.py @@ -1,6 +1,6 @@ import math from logging import INFO, WARN -from typing import Dict, Generic, List, Optional, Tuple, TypeVar, Union +from typing import Generic, TypeVar import numpy as np import torch @@ -9,17 +9,17 @@ from fl4health.utils.dataset import DictionaryDataset, TensorDataset, select_by_indices T = TypeVar("T") -D = TypeVar("D", bound=Union[TensorDataset, DictionaryDataset]) +D = TypeVar("D", bound=TensorDataset | DictionaryDataset) class DirichletLabelBasedAllocation(Generic[T]): def __init__( self, number_of_partitions: int, - unique_labels: List[T], - min_label_examples: Optional[int] = None, - beta: Optional[float] = None, - prior_distribution: Optional[Dict[T, np.ndarray]] = None, + unique_labels: list[T], + min_label_examples: int | None = None, + beta: float | None = None, + prior_distribution: dict[T, np.ndarray] | None = None, ) -> None: """ The class supports partitioning of a dataset into a set of datasets (of the same type) via Dirichlet @@ -47,17 +47,17 @@ def __init__( Args: number_of_partitions (int): Number of new datasets that we want to break the current dataset into - unique_labels (List[T]): This is the set of labels through which we'll iterate to perform allocation - min_label_examples (Optional[int], optional): This is an optional input if you want to ensure a minimum + unique_labels (list[T]): This is the set of labels through which we'll iterate to perform allocation + min_label_examples (int | None, optional): This is an optional input if you want to ensure a minimum number of labels is present on each partition. If prior distribution is provided, this is ignored. NOTE: This does not guarantee feasibility. That is, if you have a very small beta and request a large minimum number here, you are unlikely to satisfy this request. In partitioning, if the minimum isn't satisfied, we resample from the Dirichlet distribution. This is repeated some limited number of times. Otherwise the partitioner "gives up". Defaults to None. - beta (Optional[float]): This controls the heterogeneity of the partition allocations. The smaller the beta, + beta (float | None): This controls the heterogeneity of the partition allocations. The smaller the beta, the more skewed the label assignments will be to different clients. It is mutually exclusive with given prior distribution. - prior_distribution (Optional[Dict[T, np.ndarray]], optional): This is an optional input if you want to + prior_distribution (dict[T, np.ndarray] | None, optional): This is an optional input if you want to provide a prior distribution for the Dirichlet distribution. This is useful if you want to make sure that the partitioning of test data is similar to the partitioning of the training data. Defaults to None. It is mutually exclusive with the beta parameter and min_label_examples. @@ -84,7 +84,7 @@ def __init__( def partition_label_indices( self, label: T, label_indices: torch.Tensor - ) -> Tuple[List[torch.Tensor], int, np.ndarray]: + ) -> tuple[list[torch.Tensor], int, np.ndarray]: """ Given a set of indices from the dataset corresponding to a particular label, the indices are allocated using a Dirichlet distribution, to the partitions. @@ -95,7 +95,7 @@ def partition_label_indices( that the tensor is 1D and it's len constitutes the number of total datapoints with the label. Returns: - List[torch.Tensor]: partitioned indices of datapoints with the corresponding label. + list[torch.Tensor]: partitioned indices of datapoints with the corresponding label. int: The minimum number of data points assigned to a partition. np.ndarray: The Dirichlet distribution used to partition the data points. """ @@ -149,8 +149,8 @@ def partition_label_indices( return partitioned_indices[:-1], min_samples, partition_allocations def partition_dataset( - self, original_dataset: D, max_retries: Optional[int] = 5 - ) -> Tuple[List[D], Dict[T, np.ndarray]]: + self, original_dataset: D, max_retries: int | None = 5 + ) -> tuple[list[D], dict[T, np.ndarray]]: """ Attempts partitioning of the original dataset up to max_retries times. Retries are potentially required if the user requests a minimum number of labels be assigned to each of the partitions. If the drawn Dirichlet @@ -159,7 +159,7 @@ def partition_dataset( Args: original_dataset (D): The dataset to be partitioned - max_retries (Optional[int], optional): Number of times to attempt to satisfy a user provided minimum + max_retries (int | None, optional): Number of times to attempt to satisfy a user provided minimum label-associated data points per partition. Set this value to None if you want to retry indefinitely. Defaults to 5. @@ -167,8 +167,8 @@ def partition_dataset( ValueError: Throws this error if the retries have been exhausted and the user provided minimum is not met. Returns: - Tuple[List[D], Dict[T, np.ndarray]]: List[D] is the partitioned datasets, length should correspond to - self.number_of_partitions. Dict[T, np.ndarray] is the Dirichlet distribution used to partition the data + tuple[list[D], dict[T, np.ndarray]]: list[D] is the partitioned datasets, length should correspond to + self.number_of_partitions. dict[T, np.ndarray] is the Dirichlet distribution used to partition the data points for each label. """ @@ -177,7 +177,7 @@ def partition_dataset( partitioned_indices = [torch.Tensor([]).int() for _ in range(self.number_of_partitions)] partition_attempts = 0 - partitioned_probabilities: Dict[T, np.ndarray] = {} + partitioned_probabilities: dict[T, np.ndarray] = {} for label in self.unique_labels: label_indices = torch.where(targets == label)[0].int() min_selected_labels = -1 diff --git a/fl4health/utils/privacy_utilities.py b/fl4health/utils/privacy_utilities.py index d481bdd12..c21f8e8e0 100644 --- a/fl4health/utils/privacy_utilities.py +++ b/fl4health/utils/privacy_utilities.py @@ -1,5 +1,5 @@ from logging import INFO, WARNING -from typing import Any, Tuple +from typing import Any import torch.nn as nn from flwr.common.logger import log @@ -8,7 +8,7 @@ from opacus.validators import ModuleValidator -def privacy_validate_and_fix_modules(model: nn.Module) -> Tuple[nn.Module, bool]: +def privacy_validate_and_fix_modules(model: nn.Module) -> tuple[nn.Module, bool]: """ This function runs Opacus model validation to ensure that the provided models layers are compatible with the privacy mechanisms in Opacus. The function attempts to use Opacus to replace any incompatible layers if possible. @@ -20,7 +20,7 @@ def privacy_validate_and_fix_modules(model: nn.Module) -> Tuple[nn.Module, bool] model (nn.Module): The model to be validated and potentially modified to be Opacus compliant. Returns: - Tuple[nn.Module, bool]: Returns a (possibly) modified pytorch model and a boolean indicating whether a + tuple[nn.Module, bool]: Returns a (possibly) modified pytorch model and a boolean indicating whether a reinitialization of any optimizers associated with the model will be required. Reinitialization of the optimizer parameters is required, for example, when the model layers are modified, yielding a mismatch in the optimizer parameters and the new model parameters. diff --git a/fl4health/utils/random.py b/fl4health/utils/random.py index a1f8237b5..b9afc24e3 100644 --- a/fl4health/utils/random.py +++ b/fl4health/utils/random.py @@ -1,7 +1,7 @@ import random import uuid from logging import INFO -from typing import Any, Dict, Optional, Tuple +from typing import Any import numpy as np import torch @@ -9,7 +9,7 @@ def set_all_random_seeds( - seed: Optional[int] = 42, use_deterministic_torch_algos: bool = False, disable_torch_benchmarking: bool = False + seed: int | None = 42, use_deterministic_torch_algos: bool = False, disable_torch_benchmarking: bool = False ) -> None: """ Set seeds for python random, numpy random, and pytorch random. It also offers the option to force pytorch to use @@ -24,7 +24,7 @@ def set_all_random_seeds( here: https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility Args: - seed (Optional[int], optional): The seed value to be used for random number generators. Default is 42. Seed + seed (int | None, optional): The seed value to be used for random number generators. Default is 42. Seed setting will no-op if the seed is explicitly set to None use_deterministic_torch_algos (bool, optional): Whether or not to set torch.use_deterministic_algorithms to True. Defaults to False. @@ -60,13 +60,13 @@ def unset_all_random_seeds() -> None: torch.use_deterministic_algorithms(False) -def save_random_state() -> Tuple[Tuple[Any, ...], Dict[str, Any], torch.Tensor]: +def save_random_state() -> tuple[tuple[Any, ...], dict[str, Any], torch.Tensor]: """ Save the state of the random number generators for Python, NumPy, and PyTorch. This will allow you to restore the state of the random number generators at a later time. Returns: - Tuple[Tuple[Any, ...], Dict[str, Any], torch.Tensor]: A tuple containing the state of the random number + tuple[tuple[Any, ...], dict[str, Any], torch.Tensor]: A tuple containing the state of the random number generators for Python, NumPy, and """ log(INFO, "Saving random state.") @@ -77,15 +77,15 @@ def save_random_state() -> Tuple[Tuple[Any, ...], Dict[str, Any], torch.Tensor]: def restore_random_state( - random_state: Tuple[Any, ...], numpy_state: Dict[str, Any], torch_state: torch.Tensor + random_state: tuple[Any, ...], numpy_state: dict[str, Any], torch_state: torch.Tensor ) -> None: """ Restore the state of the random number generators for Python, NumPy, and PyTorch. This will allow you to restore the state of the random number generators to a previously saved state. Args: - random_state (Tuple[Any, ...]): The state of the Python random number generator - numpy_state (Dict[str, Any]): The state of the NumPy random number generator + random_state (tuple[Any, ...]): The state of the Python random number generator + numpy_state (dict[str, Any]): The state of the NumPy random number generator torch_state (torch.Tensor): The state of the PyTorch random number generator """ log(INFO, "Restoring random state.") diff --git a/fl4health/utils/sampler.py b/fl4health/utils/sampler.py index caa53844d..c9f7ef0ad 100644 --- a/fl4health/utils/sampler.py +++ b/fl4health/utils/sampler.py @@ -1,7 +1,8 @@ import math from abc import ABC, abstractmethod +from collections.abc import Set from logging import INFO, WARN -from typing import Any, List, Optional, Set, TypeVar, Union +from typing import Any, TypeVar import numpy as np import torch @@ -10,17 +11,17 @@ from fl4health.utils.dataset import DictionaryDataset, TensorDataset, select_by_indices T = TypeVar("T") -D = TypeVar("D", bound=Union[TensorDataset, DictionaryDataset]) +D = TypeVar("D", bound=TensorDataset | DictionaryDataset) class LabelBasedSampler(ABC): - def __init__(self, unique_labels: List[Any]) -> None: + def __init__(self, unique_labels: list[Any]) -> None: """ This is an abstract class to be extended to create dataset samplers based on the class of samples. Args: - unique_labels (List[Any]): The full set of labels contained in the dataset. + unique_labels (list[Any]): The full set of labels contained in the dataset. """ self.unique_labels = unique_labels self.num_classes = len(self.unique_labels) @@ -31,7 +32,7 @@ def subsample(self, dataset: D) -> D: class MinorityLabelBasedSampler(LabelBasedSampler): - def __init__(self, unique_labels: List[T], downsampling_ratio: float, minority_labels: Set[T]) -> None: + def __init__(self, unique_labels: list[T], downsampling_ratio: float, minority_labels: Set[T]) -> None: """ This class is used to subsample a dataset so the classes are distributed in a non-IID way. In particular, the MinorityLabelBasedSampler explicitly downsamples classes based on the @@ -40,7 +41,7 @@ def __init__(self, unique_labels: List[T], downsampling_ratio: float, minority_l the resulting subsampled dataset. Args: - unique_labels (List[T]): The full set of labels contained in the dataset. + unique_labels (list[T]): The full set of labels contained in the dataset. downsampling_ratio (float): The percentage to which the specified "minority" labels are downsampled. For example, if a label L has 10 examples and the downsampling_ratio is 0.2, then 8 of the datapoints with label L are discarded. @@ -61,7 +62,7 @@ def subsample(self, dataset: D) -> D: D: New dataset with downsampled labels. """ assert dataset.targets is not None, "A label-based sampler requires targets but this dataset has no targets" - selected_indices_list: List[torch.Tensor] = [] + selected_indices_list: list[torch.Tensor] = [] for label in self.unique_labels: # Get indices of samples equal to the current label indices_of_label = (dataset.targets == label).nonzero() @@ -99,8 +100,8 @@ def _get_random_subsample(self, tensor_to_subsample: torch.Tensor, subsample_siz class DirichletLabelBasedSampler(LabelBasedSampler): def __init__( self, - unique_labels: List[Any], - hash_key: Optional[int] = None, + unique_labels: list[Any], + hash_key: int | None = None, sample_percentage: float = 0.5, beta: float = 100, ) -> None: @@ -119,12 +120,12 @@ class used to subsample a dataset so the classes of samples are distributed in a np.random.dirichlet([1000]*5): array([0.2066252 , 0.19644968, 0.20080513, 0.19992536, 0.19619462]) Args: - unique_labels (List[Any]): The full set of labels contained in the dataset. + unique_labels (list[Any]): The full set of labels contained in the dataset. sample_percentage (float, optional): The downsampling of the entire dataset to do. For example, if this value is 0.5 and the dataset is of size 100, we will end up with 50 total data points. Defaults to 0.5. beta (float, optional): This controls the heterogeneity of the label sampling. The smaller the beta, the more skewed the label assignments will be for the dataset. Defaults to 100. - hash_key (Optional[int], optional): Seed for the random number generators and samplers. Defaults to None. + hash_key (int | None, optional): Seed for the random number generators and samplers. Defaults to None. """ super().__init__(unique_labels) diff --git a/fl4health/utils/typing.py b/fl4health/utils/typing.py index a298bd86d..1b129ac9d 100644 --- a/fl4health/utils/typing.py +++ b/fl4health/utils/typing.py @@ -1,7 +1,6 @@ import logging from collections.abc import Callable from enum import Enum -from typing import List, Tuple, Union import torch import torch.nn as nn @@ -16,8 +15,8 @@ TorchTransformFunction = Callable[[torch.Tensor], torch.Tensor] LayerSelectionFunction = Callable[[nn.Module, nn.Module | None], tuple[NDArrays, list[str]]] -FitFailures = List[Union[Tuple[ClientProxy, FitRes], BaseException]] -EvaluateFailures = List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]] +FitFailures = list[tuple[ClientProxy, FitRes] | BaseException] +EvaluateFailures = list[tuple[ClientProxy, EvaluateRes] | BaseException] class LogLevel(Enum): diff --git a/mypy_disallow_legacy_types.py b/mypy_disallow_legacy_types.py new file mode 100644 index 000000000..e8e366bfc --- /dev/null +++ b/mypy_disallow_legacy_types.py @@ -0,0 +1,67 @@ +import re +import sys +from collections.abc import Set + +# List of files that we want to skip with this check. Currently empty. +files_to_ignore: Set[str] = {"Empty"} +file_types_to_ignore: Set[str] = {".png", ".pkl", ".pt", ".md"} +# List of disallowed types to search for that should no longer be imported from the typing library. These types +# have been migrated to either collections.abc or into core python +disallowed_types = [ + "Union", + "Optional", + "List", + "Dict", + "Sequence", + "Set", + "Callable", + "Iterable", + "Hashable", + "Generator", + "Tuple", + "Mapping", + "Type", +] + + +type_or = "|".join(disallowed_types) +comma_separated_types = ", ".join(disallowed_types) +file_suffixes = "|".join(file_types_to_ignore) +file_type_regex = rf".*({file_suffixes})$" + + +def filter_files_to_ignore(file_paths: list[str]) -> list[str]: + file_paths = [file_path for file_path in file_paths if file_path not in files_to_ignore] + file_paths = [file_path for file_path in file_paths if not re.match(file_type_regex, file_path)] + return file_paths + + +def construct_same_line_import_regex() -> str: + return rf"from typing import ([^\n]*?, )*({type_or})(\n|, [^\n]*?\n)" + + +def construct_multi_line_import_regex() -> str: + return rf"from typing import \(\n(\s{{4}}.*,\n)*\s{{4}}({type_or}),\n(\s{{4}}.*,\n)*\)$" + + +same_line_import_re = construct_same_line_import_regex() +multi_line_import_re = construct_multi_line_import_regex() + + +def discover_legacy_imports(file_paths: list[str]) -> None: + file_paths = filter_files_to_ignore(file_paths=file_paths) + for file_path in file_paths: + with open(file_path, mode="r") as file_handle: + file_contents = file_handle.read() + same_line_match = re.search(same_line_import_re, file_contents, flags=re.MULTILINE) + multi_line_match = re.search(multi_line_import_re, file_contents, flags=re.MULTILINE) + if same_line_match or multi_line_match: + raise ValueError( + f"A legacy mypy type is being imported in file {file_path}. " + f"Disallowed imports from the typing library are: {comma_separated_types}" + ) + + +if __name__ == "__main__": + file_relative_paths = sys.argv[1:] + discover_legacy_imports(file_relative_paths) diff --git a/research/ag_news/client_data.py b/research/ag_news/client_data.py index 72c61e648..3753879a3 100644 --- a/research/ag_news/client_data.py +++ b/research/ag_news/client_data.py @@ -1,6 +1,5 @@ from functools import partial from pathlib import Path -from typing import Dict, List, Optional, Tuple import datasets import torch @@ -13,8 +12,8 @@ def collate_fn_with_padding( - tokenizer: PreTrainedTokenizer, batch: List[Tuple[Dict[str, List[torch.Tensor]], torch.Tensor]] -) -> Tuple[Dict[str, List[torch.Tensor]], torch.Tensor]: + tokenizer: PreTrainedTokenizer, batch: list[tuple[dict[str, list[torch.Tensor]], torch.Tensor]] +) -> tuple[dict[str, list[torch.Tensor]], torch.Tensor]: """ Pad the sequences within a batch to the same length. """ @@ -26,7 +25,7 @@ def collate_fn_with_padding( def create_text_classification_dataset( - dataset: datasets.Dataset, column_names: List[str], target_name: str + dataset: datasets.Dataset, column_names: list[str], target_name: str ) -> DictionaryDataset: data_dict = {} for column_name in column_names: @@ -36,8 +35,8 @@ def create_text_classification_dataset( def construct_dataloaders( - batch_size: int, sample_percentage: float, beta: float, data_path: Optional[Path] = None -) -> Tuple[DataLoader, DataLoader]: + batch_size: int, sample_percentage: float, beta: float, data_path: Path | None = None +) -> tuple[DataLoader, DataLoader]: assert 0 <= sample_percentage <= 1 and beta > 0 sampler = DirichletLabelBasedSampler(list(range(4)), sample_percentage=sample_percentage, beta=beta) diff --git a/research/ag_news/dynamic_layer_exchange/client.py b/research/ag_news/dynamic_layer_exchange/client.py index 4dc920aee..75bde7254 100644 --- a/research/ag_news/dynamic_layer_exchange/client.py +++ b/research/ag_news/dynamic_layer_exchange/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple import flwr as fl import torch @@ -38,10 +38,10 @@ def __init__( exchange_percentage: float, norm_threshold: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, store_initial_model: bool = True, ) -> None: super().__init__( @@ -65,7 +65,7 @@ def get_model(self, config: Config) -> nn.Module: model = BertForSequenceClassification.from_pretrained("google-bert/bert-base-cased", num_labels=num_classes) return model.to(self.device) - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) sample_percentage = narrow_dict_type(config, "sample_percentage", float) beta = narrow_dict_type(config, "beta", float) @@ -106,7 +106,7 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: parameter_exchanger = DynamicLayerExchanger(layer_selection_function=layer_selection_function) return parameter_exchanger - def predict(self, input: TorchInputType) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: + def predict(self, input: TorchInputType) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: # Here the predict method is overwritten in order # to rename the key to match what comes with the hugging face datasets. outputs, features = super().predict(input) diff --git a/research/ag_news/dynamic_layer_exchange/run_fold_experiment.slrm b/research/ag_news/dynamic_layer_exchange/run_fold_experiment.slrm index c384dac5b..c5c0b0683 100644 --- a/research/ag_news/dynamic_layer_exchange/run_fold_experiment.slrm +++ b/research/ag_news/dynamic_layer_exchange/run_fold_experiment.slrm @@ -89,7 +89,7 @@ for RUN_NAME in "${RUN_NAMES[@]}"; do # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/ag_news/dynamic_layer_exchange/server.py b/research/ag_news/dynamic_layer_exchange/server.py index adb66adc6..0c09014ec 100644 --- a/research/ag_news/dynamic_layer_exchange/server.py +++ b/research/ag_news/dynamic_layer_exchange/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.logger import log @@ -66,7 +66,7 @@ def fit_config( ) -def main(config: Dict[str, Any], server_address: str) -> None: +def main(config: dict[str, Any], server_address: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/ag_news/sparse_tensor_exchange/client.py b/research/ag_news/sparse_tensor_exchange/client.py index 2b1aba8d9..4dd08ff4a 100644 --- a/research/ag_news/sparse_tensor_exchange/client.py +++ b/research/ag_news/sparse_tensor_exchange/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple import flwr as fl import torch @@ -37,10 +37,10 @@ def __init__( learning_rate: float, sparsity_level: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, store_initial_model: bool = True, ) -> None: super().__init__( @@ -62,7 +62,7 @@ def get_model(self, config: Config) -> nn.Module: model = BertForSequenceClassification.from_pretrained("google-bert/bert-base-cased", num_labels=num_classes) return model.to(self.device) - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) sample_percentage = narrow_dict_type(config, "sample_percentage", float) beta = narrow_dict_type(config, "beta", float) @@ -84,7 +84,7 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger: ) return parameter_exchanger - def predict(self, input: TorchInputType) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: + def predict(self, input: TorchInputType) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: outputs, features = super().predict(input) preds = {} preds["prediction"] = outputs["logits"] diff --git a/research/ag_news/sparse_tensor_exchange/run_fold_experiment.slrm b/research/ag_news/sparse_tensor_exchange/run_fold_experiment.slrm index f1fd74f7a..c3498a758 100644 --- a/research/ag_news/sparse_tensor_exchange/run_fold_experiment.slrm +++ b/research/ag_news/sparse_tensor_exchange/run_fold_experiment.slrm @@ -89,7 +89,7 @@ for RUN_NAME in "${RUN_NAMES[@]}"; do # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/ag_news/sparse_tensor_exchange/server.py b/research/ag_news/sparse_tensor_exchange/server.py index 6372df342..b84ae2453 100644 --- a/research/ag_news/sparse_tensor_exchange/server.py +++ b/research/ag_news/sparse_tensor_exchange/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.logger import log @@ -54,7 +54,7 @@ def fit_config( ) -def main(config: Dict[str, Any], server_address: str) -> None: +def main(config: dict[str, Any], server_address: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/cifar10/README.md b/research/cifar10/README.md index 88bcb4ab6..b2d11be93 100644 --- a/research/cifar10/README.md +++ b/research/cifar10/README.md @@ -5,12 +5,12 @@ The CIFAR-10 dataset consists of 60,000 32x32 color images across 10 classes, wi To do so we should run the following script: ```bash -python -m research.cifar10.preprocess --dataset_dir path_to_folder_for_dataset --save_dataset_dir path_to_save_partiotioned_dataset --seed seed --beta beta --num_clients num_clients +python -m research.cifar10.preprocess --dataset_dir path_to_folder_for_dataset --save_dataset_dir path_to_save_partitioned_dataset --seed seed --beta beta --num_clients num_clients ``` Where: -- `path_to_folder_for_datasett` is the path to the CIFAR-10 dataset. -- `path_to_save_partiotioned_dataset` is the path to save the partitioned dataset. +- `path_to_folder_for_dataset` is the path to the CIFAR-10 dataset. +- `path_to_save_partitioned_dataset` is the path to save the partitioned dataset. - `seed` is the seed to use for the random number generator to have reproducible splits. - `beta` is the heterogeneity level of the dataset. The lower the value, the more heterogeneity in the data distribution. - `num_clients` is the number of clients to partition the dataset into. @@ -76,4 +76,4 @@ Where: - `--eval_best_global_model` tells the evaluation script to search for the saved best global model on the server side. It looks for a model named `server_best_model.pkl` and evaluates it across all clients' individual datasets. - `--eval_last_global_model` tells the evaluation script to search for the saved last global model on the server side. It looks for a model named `server_last_model.pkl` and evaluates it across all clients' individual datasets. - `--eval_over_aggregated_test_data` tells the evaluation script to evaluate any model from the previous steps on the pooled test data. -- `--use_partitioned_data` tells the evaluation script to use preporcessed partitioned data for evaluation. If this flag is not set, the script will use the original CIFAR-10 dataset and partition a subset of data for each client based on a fixed seed. +- `--use_partitioned_data` tells the evaluation script to use preprocessed partitioned data for evaluation. If this flag is not set, the script will use the original CIFAR-10 dataset and partition a subset of data for each client based on a fixed seed. diff --git a/research/cifar10/adaptive_pfl/ditto/client.py b/research/cifar10/adaptive_pfl/ditto/client.py index adf14587d..fc8b0c5be 100644 --- a/research/cifar10/adaptive_pfl/ditto/client.py +++ b/research/cifar10/adaptive_pfl/ditto/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple import flwr as fl import torch @@ -35,10 +35,10 @@ def __init__( learning_rate: float, heterogeneity_level: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: super().__init__( data_path=data_path, @@ -56,7 +56,7 @@ def __init__( log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) train_loader, val_loader, _ = get_preprocessed_data( self.data_path, self.client_number, batch_size, self.heterogeneity_level @@ -66,7 +66,7 @@ def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: def get_criterion(self, config: Config) -> _Loss: return torch.nn.CrossEntropyLoss() - def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: + def get_optimizer(self, config: Config) -> dict[str, Optimizer]: global_optimizer = torch.optim.AdamW(self.global_model.parameters(), lr=self.learning_rate) local_optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.learning_rate) return {"global": global_optimizer, "local": local_optimizer} diff --git a/research/cifar10/adaptive_pfl/ditto/run_fold_experiment.slrm b/research/cifar10/adaptive_pfl/ditto/run_fold_experiment.slrm index fb3a6898d..ee7cf4ebb 100644 --- a/research/cifar10/adaptive_pfl/ditto/run_fold_experiment.slrm +++ b/research/cifar10/adaptive_pfl/ditto/run_fold_experiment.slrm @@ -62,7 +62,7 @@ if [[ "${SLURM_JOB_PARTITION}" == "t4v2" ]] || \ export NCCL_SOCKET_IFNAME=bond0 fi -# This environment variable must be set in order to force torch to use determinsitic algorithms. See documentation +# This environment variable must be set in order to force torch to use deterministic algorithms. See documentation # in fl4health/utils/random.py for more information export CUBLAS_WORKSPACE_CONFIG=:4096:8 @@ -102,7 +102,7 @@ do SEED="${SEEDS[i]}" # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/cifar10/adaptive_pfl/ditto/server.py b/research/cifar10/adaptive_pfl/ditto/server.py index de04aceb0..376f22a3f 100644 --- a/research/cifar10/adaptive_pfl/ditto/server.py +++ b/research/cifar10/adaptive_pfl/ditto/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.logger import log @@ -33,7 +33,7 @@ def fit_config( } -def main(config: Dict[str, Any], server_address: str, lam: float, adapt_loss_weight: bool) -> None: +def main(config: dict[str, Any], server_address: str, lam: float, adapt_loss_weight: bool) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/cifar10/adaptive_pfl/fedprox/client.py b/research/cifar10/adaptive_pfl/fedprox/client.py index 3c8de5279..0d02b012f 100644 --- a/research/cifar10/adaptive_pfl/fedprox/client.py +++ b/research/cifar10/adaptive_pfl/fedprox/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Optional, Sequence, Tuple import flwr as fl import torch @@ -35,10 +35,10 @@ def __init__( learning_rate: float, heterogeneity_level: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: super().__init__( data_path=data_path, @@ -56,7 +56,7 @@ def __init__( log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) train_loader, val_loader, _ = get_preprocessed_data( self.data_path, self.client_number, batch_size, self.heterogeneity_level diff --git a/research/cifar10/adaptive_pfl/fedprox/run_fold_experiment.slrm b/research/cifar10/adaptive_pfl/fedprox/run_fold_experiment.slrm index 4e63da1c3..88c66fed7 100644 --- a/research/cifar10/adaptive_pfl/fedprox/run_fold_experiment.slrm +++ b/research/cifar10/adaptive_pfl/fedprox/run_fold_experiment.slrm @@ -62,7 +62,7 @@ if [[ "${SLURM_JOB_PARTITION}" == "t4v2" ]] || \ export NCCL_SOCKET_IFNAME=bond0 fi -# This environment variable must be set in order to force torch to use determinsitic algorithms. See documentation +# This environment variable must be set in order to force torch to use deterministic algorithms. See documentation # in fl4health/utils/random.py for more information export CUBLAS_WORKSPACE_CONFIG=:4096:8 @@ -102,7 +102,7 @@ do SEED="${SEEDS[i]}" # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/cifar10/adaptive_pfl/fedprox/server.py b/research/cifar10/adaptive_pfl/fedprox/server.py index e204aa575..58114a621 100644 --- a/research/cifar10/adaptive_pfl/fedprox/server.py +++ b/research/cifar10/adaptive_pfl/fedprox/server.py @@ -2,7 +2,7 @@ import os from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.logger import log @@ -37,7 +37,7 @@ def fit_config( def main( - config: Dict[str, Any], + config: dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str, diff --git a/research/cifar10/adaptive_pfl/fenda_ditto/client.py b/research/cifar10/adaptive_pfl/fenda_ditto/client.py index 447e2b9b0..d4b29bbac 100644 --- a/research/cifar10/adaptive_pfl/fenda_ditto/client.py +++ b/research/cifar10/adaptive_pfl/fenda_ditto/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple import flwr as fl import torch @@ -36,10 +36,10 @@ def __init__( learning_rate: float, heterogeneity_level: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, freeze_global_feature_extractor: bool = False, ) -> None: super().__init__( @@ -59,7 +59,7 @@ def __init__( log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) train_loader, val_loader, _ = get_preprocessed_data( self.data_path, self.client_number, batch_size, self.heterogeneity_level @@ -69,7 +69,7 @@ def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: def get_criterion(self, config: Config) -> _Loss: return torch.nn.CrossEntropyLoss() - def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: + def get_optimizer(self, config: Config) -> dict[str, Optimizer]: global_optimizer = torch.optim.AdamW(self.global_model.parameters(), lr=self.learning_rate) local_optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.learning_rate) return {"global": global_optimizer, "local": local_optimizer} diff --git a/research/cifar10/adaptive_pfl/fenda_ditto/run_fold_experiment.slrm b/research/cifar10/adaptive_pfl/fenda_ditto/run_fold_experiment.slrm index b141e7110..ed32a243e 100644 --- a/research/cifar10/adaptive_pfl/fenda_ditto/run_fold_experiment.slrm +++ b/research/cifar10/adaptive_pfl/fenda_ditto/run_fold_experiment.slrm @@ -64,7 +64,7 @@ if [[ "${SLURM_JOB_PARTITION}" == "t4v2" ]] || \ export NCCL_SOCKET_IFNAME=bond0 fi -# This environment variable must be set in order to force torch to use determinsitic algorithms. See documentation +# This environment variable must be set in order to force torch to use deterministic algorithms. See documentation # in fl4health/utils/random.py for more information export CUBLAS_WORKSPACE_CONFIG=:4096:8 @@ -105,7 +105,7 @@ do SEED="${SEEDS[i]}" # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/cifar10/adaptive_pfl/fenda_ditto/server.py b/research/cifar10/adaptive_pfl/fenda_ditto/server.py index b15a03449..e2e99d62e 100644 --- a/research/cifar10/adaptive_pfl/fenda_ditto/server.py +++ b/research/cifar10/adaptive_pfl/fenda_ditto/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.logger import log @@ -33,7 +33,7 @@ def fit_config( } -def main(config: Dict[str, Any], server_address: str, lam: float, adapt_loss_weight: bool) -> None: +def main(config: dict[str, Any], server_address: str, lam: float, adapt_loss_weight: bool) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/cifar10/adaptive_pfl/mrmtl/client.py b/research/cifar10/adaptive_pfl/mrmtl/client.py index 2e9135f9f..14a54f20a 100644 --- a/research/cifar10/adaptive_pfl/mrmtl/client.py +++ b/research/cifar10/adaptive_pfl/mrmtl/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Optional, Sequence, Tuple import flwr as fl import torch @@ -35,10 +35,10 @@ def __init__( learning_rate: float, heterogeneity_level: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: super().__init__( data_path=data_path, @@ -56,7 +56,7 @@ def __init__( log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) train_loader, val_loader, _ = get_preprocessed_data( self.data_path, self.client_number, batch_size, self.heterogeneity_level diff --git a/research/cifar10/adaptive_pfl/mrmtl/run_fold_experiment.slrm b/research/cifar10/adaptive_pfl/mrmtl/run_fold_experiment.slrm index 9a6099b97..6161882f2 100644 --- a/research/cifar10/adaptive_pfl/mrmtl/run_fold_experiment.slrm +++ b/research/cifar10/adaptive_pfl/mrmtl/run_fold_experiment.slrm @@ -62,7 +62,7 @@ if [[ "${SLURM_JOB_PARTITION}" == "t4v2" ]] || \ export NCCL_SOCKET_IFNAME=bond0 fi -# This environment variable must be set in order to force torch to use determinsitic algorithms. See documentation +# This environment variable must be set in order to force torch to use deterministic algorithms. See documentation # in fl4health/utils/random.py for more information export CUBLAS_WORKSPACE_CONFIG=:4096:8 @@ -102,7 +102,7 @@ do SEED="${SEEDS[i]}" # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/cifar10/adaptive_pfl/mrmtl/server.py b/research/cifar10/adaptive_pfl/mrmtl/server.py index a7d7ebb7a..c1b192386 100644 --- a/research/cifar10/adaptive_pfl/mrmtl/server.py +++ b/research/cifar10/adaptive_pfl/mrmtl/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict, Optional +from typing import Any import flwr as fl from flwr.common.logger import log @@ -32,7 +32,7 @@ def __init__( self, client_manager: ClientManager, fl_config: Config, - strategy: Optional[Strategy] = None, + strategy: Strategy | None = None, ) -> None: assert isinstance( strategy, FedAvgWithAdaptiveConstraint @@ -58,7 +58,7 @@ def fit_config( } -def main(config: Dict[str, Any], server_address: str, lam: float, adapt_loss_weight: bool) -> None: +def main(config: dict[str, Any], server_address: str, lam: float, adapt_loss_weight: bool) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/cifar10/ditto/client.py b/research/cifar10/ditto/client.py index 15328b301..75cfa3413 100644 --- a/research/cifar10/ditto/client.py +++ b/research/cifar10/ditto/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple import flwr as fl import torch @@ -37,10 +37,10 @@ def __init__( learning_rate: float, heterogeneity_level: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, use_partitioned_data: bool = False, ) -> None: super().__init__( @@ -64,7 +64,7 @@ def setup_client(self, config: Config) -> None: assert 0 <= self.client_number < num_clients super().setup_client(config) - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) if self.use_partitioned_data: # The partitioned data should be generated prior to running the clients via preprocess_data function @@ -92,7 +92,7 @@ def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: ) return train_loader, val_loader - def get_test_data_loader(self, config: Config) -> Optional[DataLoader]: + def get_test_data_loader(self, config: Config) -> DataLoader | None: batch_size = narrow_dict_type(config, "batch_size", int) if self.use_partitioned_data: # The partitioned data should be generated prior to running the clients via preprocess_data function @@ -117,7 +117,7 @@ def get_test_data_loader(self, config: Config) -> Optional[DataLoader]: def get_criterion(self, config: Config) -> _Loss: return torch.nn.CrossEntropyLoss() - def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: + def get_optimizer(self, config: Config) -> dict[str, Optimizer]: # Following the implementation in pFL-Bench : A Comprehensive Benchmark for Personalized # Federated Learning (https://arxiv.org/pdf/2405.17724) for cifar10 dataset we use SGD optimizer global_optimizer = torch.optim.SGD(self.global_model.parameters(), lr=self.learning_rate, momentum=0.9) diff --git a/research/cifar10/ditto/run_fold_experiment.slrm b/research/cifar10/ditto/run_fold_experiment.slrm index f025f409a..14a960849 100644 --- a/research/cifar10/ditto/run_fold_experiment.slrm +++ b/research/cifar10/ditto/run_fold_experiment.slrm @@ -95,7 +95,7 @@ do SEED="${SEEDS[i]}" # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/cifar10/ditto/server.py b/research/cifar10/ditto/server.py index 92ee85dd1..e39e074d1 100644 --- a/research/cifar10/ditto/server.py +++ b/research/cifar10/ditto/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.logger import log @@ -33,7 +33,7 @@ def fit_config( } -def main(config: Dict[str, Any], server_address: str, lam: float) -> None: +def main(config: dict[str, Any], server_address: str, lam: float) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/cifar10/ditto_deep_mmd/client.py b/research/cifar10/ditto_deep_mmd/client.py index c1a2d54ad..382f1e58e 100644 --- a/research/cifar10/ditto_deep_mmd/client.py +++ b/research/cifar10/ditto_deep_mmd/client.py @@ -1,9 +1,9 @@ import argparse import os from collections import OrderedDict +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple import flwr as fl import torch @@ -43,10 +43,10 @@ def __init__( learning_rate: float, heterogeneity_level: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, deep_mmd_loss_weight: float = 10, deep_mmd_loss_depth: int = 1, use_partitioned_data: bool = False, @@ -75,7 +75,7 @@ def setup_client(self, config: Config) -> None: assert 0 <= self.client_number < num_clients super().setup_client(config) - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) # The partitioned data should be generated prior to running the clients via preprocess_data function # in the research/cifar10/preprocess.py file @@ -103,7 +103,7 @@ def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: ) return train_loader, val_loader - def get_test_data_loader(self, config: Config) -> Optional[DataLoader]: + def get_test_data_loader(self, config: Config) -> DataLoader | None: batch_size = narrow_dict_type(config, "batch_size", int) # The partitioned data should be generated prior to running the clients via preprocess_data function # in the research/cifar10/preprocess.py file @@ -128,7 +128,7 @@ def get_test_data_loader(self, config: Config) -> Optional[DataLoader]: def get_criterion(self, config: Config) -> _Loss: return torch.nn.CrossEntropyLoss() - def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: + def get_optimizer(self, config: Config) -> dict[str, Optimizer]: # Following the implementation in pFL-Bench : A Comprehensive Benchmark for Personalized # Federated Learning (https://arxiv.org/pdf/2405.17724) for cifar10 dataset we use SGD optimizer global_optimizer = torch.optim.SGD(self.global_model.parameters(), lr=self.learning_rate, momentum=0.9) diff --git a/research/cifar10/ditto_deep_mmd/run_fold_experiment.slrm b/research/cifar10/ditto_deep_mmd/run_fold_experiment.slrm index 815c326f7..f2f801988 100644 --- a/research/cifar10/ditto_deep_mmd/run_fold_experiment.slrm +++ b/research/cifar10/ditto_deep_mmd/run_fold_experiment.slrm @@ -101,7 +101,7 @@ do SEED="${SEEDS[i]}" # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/cifar10/ditto_deep_mmd/server.py b/research/cifar10/ditto_deep_mmd/server.py index 92ee85dd1..e39e074d1 100644 --- a/research/cifar10/ditto_deep_mmd/server.py +++ b/research/cifar10/ditto_deep_mmd/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.logger import log @@ -33,7 +33,7 @@ def fit_config( } -def main(config: Dict[str, Any], server_address: str, lam: float) -> None: +def main(config: dict[str, Any], server_address: str, lam: float) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/cifar10/ditto_mkmmd/client.py b/research/cifar10/ditto_mkmmd/client.py index ed8d18821..f092c4209 100644 --- a/research/cifar10/ditto_mkmmd/client.py +++ b/research/cifar10/ditto_mkmmd/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple import flwr as fl import torch @@ -43,10 +43,10 @@ def __init__( feature_l2_norm_weight: float = 1, mkmmd_loss_depth: int = 1, beta_global_update_interval: int = 20, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, use_partitioned_data: bool = False, ) -> None: super().__init__( @@ -76,7 +76,7 @@ def setup_client(self, config: Config) -> None: assert 0 <= self.client_number < num_clients super().setup_client(config) - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) # The partitioned data should be generated prior to running the clients via preprocess_data function # in the research/cifar10/preprocess.py file @@ -104,7 +104,7 @@ def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: ) return train_loader, val_loader - def get_test_data_loader(self, config: Config) -> Optional[DataLoader]: + def get_test_data_loader(self, config: Config) -> DataLoader | None: batch_size = narrow_dict_type(config, "batch_size", int) # The partitioned data should be generated prior to running the clients via preprocess_data function # in the research/cifar10/preprocess.py file @@ -129,7 +129,7 @@ def get_test_data_loader(self, config: Config) -> Optional[DataLoader]: def get_criterion(self, config: Config) -> _Loss: return torch.nn.CrossEntropyLoss() - def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: + def get_optimizer(self, config: Config) -> dict[str, Optimizer]: # Following the implementation in pFL-Bench : A Comprehensive Benchmark for Personalized # Federated Learning (https://arxiv.org/pdf/2405.17724) for cifar10 dataset we use SGD optimizer global_optimizer = torch.optim.SGD(self.global_model.parameters(), lr=self.learning_rate, momentum=0.9) diff --git a/research/cifar10/ditto_mkmmd/run_fold_experiment.slrm b/research/cifar10/ditto_mkmmd/run_fold_experiment.slrm index 9bc11c7b6..5ef1de7e5 100644 --- a/research/cifar10/ditto_mkmmd/run_fold_experiment.slrm +++ b/research/cifar10/ditto_mkmmd/run_fold_experiment.slrm @@ -107,7 +107,7 @@ do SEED="${SEEDS[i]}" # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/cifar10/ditto_mkmmd/server.py b/research/cifar10/ditto_mkmmd/server.py index 92ee85dd1..e39e074d1 100644 --- a/research/cifar10/ditto_mkmmd/server.py +++ b/research/cifar10/ditto_mkmmd/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.logger import log @@ -33,7 +33,7 @@ def fit_config( } -def main(config: Dict[str, Any], server_address: str, lam: float) -> None: +def main(config: dict[str, Any], server_address: str, lam: float) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/cifar10/evaluate_on_test.py b/research/cifar10/evaluate_on_test.py index b5a224b40..c80631d98 100644 --- a/research/cifar10/evaluate_on_test.py +++ b/research/cifar10/evaluate_on_test.py @@ -2,7 +2,6 @@ import copy from logging import INFO from pathlib import Path -from typing import Dict import torch from flwr.common.logger import log @@ -46,7 +45,7 @@ def main( ) -> None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") all_run_folder_dir = get_all_run_folders(artifact_dir) - test_results: Dict[str, float] = {} + test_results: dict[str, float] = {} metrics = [Accuracy("cifar10_accuracy")] all_pre_best_local_test_metrics = {run_folder_dir: 0.0 for run_folder_dir in all_run_folder_dir} diff --git a/research/cifar10/fed_dgga_pfl/ditto/client.py b/research/cifar10/fed_dgga_pfl/ditto/client.py index adf14587d..fc8b0c5be 100644 --- a/research/cifar10/fed_dgga_pfl/ditto/client.py +++ b/research/cifar10/fed_dgga_pfl/ditto/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple import flwr as fl import torch @@ -35,10 +35,10 @@ def __init__( learning_rate: float, heterogeneity_level: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: super().__init__( data_path=data_path, @@ -56,7 +56,7 @@ def __init__( log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) train_loader, val_loader, _ = get_preprocessed_data( self.data_path, self.client_number, batch_size, self.heterogeneity_level @@ -66,7 +66,7 @@ def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: def get_criterion(self, config: Config) -> _Loss: return torch.nn.CrossEntropyLoss() - def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: + def get_optimizer(self, config: Config) -> dict[str, Optimizer]: global_optimizer = torch.optim.AdamW(self.global_model.parameters(), lr=self.learning_rate) local_optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.learning_rate) return {"global": global_optimizer, "local": local_optimizer} diff --git a/research/cifar10/fed_dgga_pfl/ditto/run_fold_experiment.slrm b/research/cifar10/fed_dgga_pfl/ditto/run_fold_experiment.slrm index 48b257c7b..e6b9f3fd8 100644 --- a/research/cifar10/fed_dgga_pfl/ditto/run_fold_experiment.slrm +++ b/research/cifar10/fed_dgga_pfl/ditto/run_fold_experiment.slrm @@ -61,7 +61,7 @@ if [[ "${SLURM_JOB_PARTITION}" == "t4v2" ]] || \ export NCCL_SOCKET_IFNAME=bond0 fi -# This environment variable must be set in order to force torch to use determinsitic algorithms. See documentation +# This environment variable must be set in order to force torch to use deterministic algorithms. See documentation # in fl4health/utils/random.py for more information export CUBLAS_WORKSPACE_CONFIG=:4096:8 @@ -101,7 +101,7 @@ do SEED="${SEEDS[i]}" # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/cifar10/fed_dgga_pfl/ditto/server.py b/research/cifar10/fed_dgga_pfl/ditto/server.py index 4b21287e3..443b06f72 100644 --- a/research/cifar10/fed_dgga_pfl/ditto/server.py +++ b/research/cifar10/fed_dgga_pfl/ditto/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.logger import log @@ -38,7 +38,7 @@ def fit_config( } -def main(config: Dict[str, Any], server_address: str, lam: float, step_size: float) -> None: +def main(config: dict[str, Any], server_address: str, lam: float, step_size: float) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/cifar10/fed_dgga_pfl/fenda/client.py b/research/cifar10/fed_dgga_pfl/fenda/client.py index 1913bb819..727ea1c55 100644 --- a/research/cifar10/fed_dgga_pfl/fenda/client.py +++ b/research/cifar10/fed_dgga_pfl/fenda/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Optional, Sequence, Tuple import flwr as fl import torch @@ -35,10 +35,10 @@ def __init__( learning_rate: float, heterogeneity_level: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: super().__init__( data_path=data_path, @@ -56,7 +56,7 @@ def __init__( log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) train_loader, val_loader, _ = get_preprocessed_data( self.data_path, self.client_number, batch_size, self.heterogeneity_level diff --git a/research/cifar10/fed_dgga_pfl/fenda/run_fold_experiment.slrm b/research/cifar10/fed_dgga_pfl/fenda/run_fold_experiment.slrm index afd97e40b..9ea16c1a6 100644 --- a/research/cifar10/fed_dgga_pfl/fenda/run_fold_experiment.slrm +++ b/research/cifar10/fed_dgga_pfl/fenda/run_fold_experiment.slrm @@ -60,7 +60,7 @@ if [[ "${SLURM_JOB_PARTITION}" == "t4v2" ]] || \ export NCCL_SOCKET_IFNAME=bond0 fi -# This environment variable must be set in order to force torch to use determinsitic algorithms. See documentation +# This environment variable must be set in order to force torch to use deterministic algorithms. See documentation # in fl4health/utils/random.py for more information export CUBLAS_WORKSPACE_CONFIG=:4096:8 @@ -99,7 +99,7 @@ do SEED="${SEEDS[i]}" # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/cifar10/fed_dgga_pfl/fenda/server.py b/research/cifar10/fed_dgga_pfl/fenda/server.py index aadce7b7d..774259744 100644 --- a/research/cifar10/fed_dgga_pfl/fenda/server.py +++ b/research/cifar10/fed_dgga_pfl/fenda/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.logger import log @@ -37,7 +37,7 @@ def fit_config( } -def main(config: Dict[str, Any], server_address: str, step_size: float) -> None: +def main(config: dict[str, Any], server_address: str, step_size: float) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/cifar10/fed_dgga_pfl/fenda_ditto/client.py b/research/cifar10/fed_dgga_pfl/fenda_ditto/client.py index a8bad44b8..a289d342a 100644 --- a/research/cifar10/fed_dgga_pfl/fenda_ditto/client.py +++ b/research/cifar10/fed_dgga_pfl/fenda_ditto/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple import flwr as fl import torch @@ -36,10 +36,10 @@ def __init__( learning_rate: float, heterogeneity_level: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, freeze_global_feature_extractor: bool = False, ) -> None: super().__init__( @@ -59,7 +59,7 @@ def __init__( log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) train_loader, val_loader, _ = get_preprocessed_data( self.data_path, self.client_number, batch_size, self.heterogeneity_level @@ -69,7 +69,7 @@ def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: def get_criterion(self, config: Config) -> _Loss: return torch.nn.CrossEntropyLoss() - def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: + def get_optimizer(self, config: Config) -> dict[str, Optimizer]: global_optimizer = torch.optim.AdamW(self.global_model.parameters(), lr=self.learning_rate) local_optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.learning_rate) return {"global": global_optimizer, "local": local_optimizer} diff --git a/research/cifar10/fed_dgga_pfl/fenda_ditto/run_fold_experiment.slrm b/research/cifar10/fed_dgga_pfl/fenda_ditto/run_fold_experiment.slrm index e9ce6066b..963bbe26c 100644 --- a/research/cifar10/fed_dgga_pfl/fenda_ditto/run_fold_experiment.slrm +++ b/research/cifar10/fed_dgga_pfl/fenda_ditto/run_fold_experiment.slrm @@ -64,7 +64,7 @@ if [[ "${SLURM_JOB_PARTITION}" == "t4v2" ]] || \ export NCCL_SOCKET_IFNAME=bond0 fi -# This environment variable must be set in order to force torch to use determinsitic algorithms. See documentation +# This environment variable must be set in order to force torch to use deterministic algorithms. See documentation # in fl4health/utils/random.py for more information export CUBLAS_WORKSPACE_CONFIG=:4096:8 @@ -105,7 +105,7 @@ do SEED="${SEEDS[i]}" # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/cifar10/fed_dgga_pfl/fenda_ditto/server.py b/research/cifar10/fed_dgga_pfl/fenda_ditto/server.py index 9d5abe498..654a89acd 100644 --- a/research/cifar10/fed_dgga_pfl/fenda_ditto/server.py +++ b/research/cifar10/fed_dgga_pfl/fenda_ditto/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.logger import log @@ -38,7 +38,7 @@ def fit_config( } -def main(config: Dict[str, Any], server_address: str, lam: float, step_size: float) -> None: +def main(config: dict[str, Any], server_address: str, lam: float, step_size: float) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/cifar10/fedavg/client.py b/research/cifar10/fedavg/client.py index 3e8c18078..054b97384 100644 --- a/research/cifar10/fedavg/client.py +++ b/research/cifar10/fedavg/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Optional, Sequence, Tuple import flwr as fl import torch @@ -37,10 +37,10 @@ def __init__( learning_rate: float, heterogeneity_level: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, use_partitioned_data: bool = False, ) -> None: super().__init__( @@ -66,7 +66,7 @@ def setup_client(self, config: Config) -> None: assert 0 <= self.client_number < num_clients super().setup_client(config) - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) # The partitioned data should be generated prior to running the clients via preprocess_data function # in the research/cifar10/preprocess.py file @@ -94,7 +94,7 @@ def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: ) return train_loader, val_loader - def get_test_data_loader(self, config: Config) -> Optional[DataLoader]: + def get_test_data_loader(self, config: Config) -> DataLoader | None: batch_size = narrow_dict_type(config, "batch_size", int) # The partitioned data should be generated prior to running the clients via preprocess_data function # in the research/cifar10/preprocess.py file diff --git a/research/cifar10/fedavg/run_fold_experiment.slrm b/research/cifar10/fedavg/run_fold_experiment.slrm index bb9fcec5a..8d6d93229 100644 --- a/research/cifar10/fedavg/run_fold_experiment.slrm +++ b/research/cifar10/fedavg/run_fold_experiment.slrm @@ -92,7 +92,7 @@ do SEED="${SEEDS[i]}" # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/cifar10/fedavg/server.py b/research/cifar10/fedavg/server.py index 9eaee7252..cade1a0e2 100644 --- a/research/cifar10/fedavg/server.py +++ b/research/cifar10/fedavg/server.py @@ -2,7 +2,7 @@ import os from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.logger import log @@ -37,7 +37,7 @@ def fit_config( } -def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str) -> None: +def main(config: dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/cifar10/find_best_hp.py b/research/cifar10/find_best_hp.py index b44a63ad1..2737bb4fc 100644 --- a/research/cifar10/find_best_hp.py +++ b/research/cifar10/find_best_hp.py @@ -1,18 +1,17 @@ import argparse import os from logging import INFO -from typing import List, Optional import numpy as np from flwr.common.logger import log -def get_hp_folders(hp_sweep_dir: str) -> List[str]: +def get_hp_folders(hp_sweep_dir: str) -> list[str]: paths_in_hp_sweep_dir = [os.path.join(hp_sweep_dir, contents) for contents in os.listdir(hp_sweep_dir)] return [hp_folder for hp_folder in paths_in_hp_sweep_dir if os.path.isdir(hp_folder)] -def get_run_folders(hp_dir: str) -> List[str]: +def get_run_folders(hp_dir: str) -> list[str]: run_folder_names = [folder_name for folder_name in os.listdir(hp_dir) if "Run" in folder_name] return [os.path.join(hp_dir, run_folder_name) for run_folder_name in run_folder_names] @@ -27,7 +26,7 @@ def get_weighted_loss_from_server_log(run_folder_path: str) -> float: def main(hp_sweep_dir: str) -> None: hp_folders = get_hp_folders(hp_sweep_dir) - best_avg_loss: Optional[float] = None + best_avg_loss: float | None = None best_folder = "" for hp_folder in hp_folders: run_folders = get_run_folders(hp_folder) diff --git a/research/cifar10/personal_server.py b/research/cifar10/personal_server.py index a4d0334f7..ae3469497 100644 --- a/research/cifar10/personal_server.py +++ b/research/cifar10/personal_server.py @@ -1,5 +1,4 @@ from logging import INFO -from typing import Dict, Optional, Tuple from flwr.common.logger import log from flwr.common.typing import Config, Scalar @@ -24,20 +23,20 @@ def __init__( self, client_manager: ClientManager, fl_config: Config, - strategy: Optional[Strategy] = None, + strategy: Strategy | None = None, ) -> None: # Personal approaches don't train a "server" model. Rather, each client trains a client specific model with # some globally shared weights. So we don't checkpoint a global model super().__init__( client_manager=client_manager, fl_config=fl_config, strategy=strategy, checkpoint_and_state_module=None ) - self.best_aggregated_loss: Optional[float] = None + self.best_aggregated_loss: float | None = None def evaluate_round( self, server_round: int, - timeout: Optional[float], - ) -> Optional[Tuple[Optional[float], Dict[str, Scalar], EvaluateResultsAndFailures]]: + timeout: float | None, + ) -> tuple[float | None, dict[str, Scalar], EvaluateResultsAndFailures] | None: # loss_aggregated is the aggregated validation per step loss # aggregated over each client (weighted by num examples) eval_round_results = super().evaluate_round(server_round, timeout) diff --git a/research/cifar10/preprocess.py b/research/cifar10/preprocess.py index 58a478a1e..2690f6565 100644 --- a/research/cifar10/preprocess.py +++ b/research/cifar10/preprocess.py @@ -2,7 +2,6 @@ import os from logging import INFO from pathlib import Path -from typing import Dict, List, Tuple import numpy as np import torch @@ -18,7 +17,7 @@ def get_preprocessed_data( dataset_dir: Path, client_num: int, batch_size: int, beta: float -) -> Tuple[DataLoader, DataLoader, Dict[str, int]]: +) -> tuple[DataLoader, DataLoader, dict[str, int]]: transform = transforms.Compose( [ ToNumpy(), @@ -58,7 +57,7 @@ def get_preprocessed_data( def get_test_preprocessed_data( dataset_dir: Path, client_num: int, batch_size: int, beta: float -) -> Tuple[DataLoader, Dict[str, int]]: +) -> tuple[DataLoader, dict[str, int]]: transform = transforms.Compose( [ ToNumpy(), @@ -82,7 +81,7 @@ def get_test_preprocessed_data( def preprocess_data( dataset_dir: Path, num_clients: int, beta: float -) -> Tuple[List[TensorDataset], List[TensorDataset], List[TensorDataset]]: +) -> tuple[list[TensorDataset], list[TensorDataset], list[TensorDataset]]: # Get raw data data, targets = get_cifar10_data_and_target_tensors(dataset_dir, True) @@ -119,7 +118,7 @@ def preprocess_data( def save_preprocessed_data( - save_dataset_dir: Path, partitioned_datasets: List[TensorDataset], beta: float, mode: str + save_dataset_dir: Path, partitioned_datasets: list[TensorDataset], beta: float, mode: str ) -> None: save_dir_path = f"{save_dataset_dir}/beta_{beta}" os.makedirs(save_dir_path, exist_ok=True) diff --git a/research/cifar10/utils.py b/research/cifar10/utils.py index b2a823fe8..34a6ca976 100644 --- a/research/cifar10/utils.py +++ b/research/cifar10/utils.py @@ -1,5 +1,5 @@ import os -from typing import Dict, List, Sequence, Tuple +from collections.abc import Sequence import numpy as np import torch @@ -9,7 +9,7 @@ from fl4health.utils.metrics import Metric, MetricManager -def get_all_run_folders(artifact_dir: str) -> List[str]: +def get_all_run_folders(artifact_dir: str) -> list[str]: run_folder_names = [folder_name for folder_name in os.listdir(artifact_dir) if "Run" in folder_name] return [os.path.join(artifact_dir, run_folder_name) for run_folder_name in run_folder_names] @@ -26,13 +26,13 @@ def load_last_global_model(run_folder_dir: str) -> nn.Module: return model -def get_metric_avg_std(metrics: List[float]) -> Tuple[float, float]: +def get_metric_avg_std(metrics: list[float]) -> tuple[float, float]: mean = float(np.mean(metrics)) std = float(np.std(metrics, ddof=1)) return mean, std -def write_measurement_results(eval_write_path: str, results: Dict[str, float]) -> None: +def write_measurement_results(eval_write_path: str, results: dict[str, float]) -> None: with open(eval_write_path, "w") as f: for key, metric_value in results.items(): f.write(f"{key}: {metric_value}\n") diff --git a/research/flamby/fed_heart_disease/apfl/client.py b/research/flamby/fed_heart_disease/apfl/client.py index f0a0ba002..d278aded5 100644 --- a/research/flamby/fed_heart_disease/apfl/client.py +++ b/research/flamby/fed_heart_disease/apfl/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple import flwr as fl import torch @@ -33,10 +33,10 @@ def __init__( learning_rate: float, alpha_learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: super().__init__( data_path=data_path, @@ -54,7 +54,7 @@ def __init__( self.alpha_learning_rate = alpha_learning_rate self.client_number = client_number - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: train_dataset, validation_dataset = construct_fed_heard_disease_train_val_datasets( self.client_number, str(self.data_path) ) @@ -66,7 +66,7 @@ def get_model(self, config: Config) -> ApflModule: model: ApflModule = ApflModule(Baseline(), alpha_lr=self.alpha_learning_rate).to(self.device) return model - def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: + def get_optimizer(self, config: Config) -> dict[str, Optimizer]: local_optimizer = torch.optim.AdamW(self.model.local_model.parameters(), lr=self.learning_rate) global_optimizer = torch.optim.AdamW(self.model.global_model.parameters(), lr=self.learning_rate) return {"local": local_optimizer, "global": global_optimizer} diff --git a/research/flamby/fed_heart_disease/apfl/run_fold_experiment.slrm b/research/flamby/fed_heart_disease/apfl/run_fold_experiment.slrm index a03555a01..b35bec165 100644 --- a/research/flamby/fed_heart_disease/apfl/run_fold_experiment.slrm +++ b/research/flamby/fed_heart_disease/apfl/run_fold_experiment.slrm @@ -88,7 +88,7 @@ for RUN_NAME in "${RUN_NAMES[@]}"; do # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_heart_disease/apfl/server.py b/research/flamby/fed_heart_disease/apfl/server.py index 367f382be..9717c5900 100644 --- a/research/flamby/fed_heart_disease/apfl/server.py +++ b/research/flamby/fed_heart_disease/apfl/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flamby.datasets.fed_heart_disease import Baseline @@ -17,7 +17,7 @@ from research.flamby.utils import fit_config, summarize_model_info -def main(config: Dict[str, Any], server_address: str) -> None: +def main(config: dict[str, Any], server_address: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/flamby/fed_heart_disease/central/run_fold_experiment.slrm b/research/flamby/fed_heart_disease/central/run_fold_experiment.slrm index 7fc17cfe7..63d9bc904 100644 --- a/research/flamby/fed_heart_disease/central/run_fold_experiment.slrm +++ b/research/flamby/fed_heart_disease/central/run_fold_experiment.slrm @@ -76,7 +76,7 @@ for RUN_NAME in "${RUN_NAMES[@]}"; do # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_heart_disease/ditto/client.py b/research/flamby/fed_heart_disease/ditto/client.py index 0663d23ef..80d72f1e5 100644 --- a/research/flamby/fed_heart_disease/ditto/client.py +++ b/research/flamby/fed_heart_disease/ditto/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple import flwr as fl import torch @@ -33,10 +33,10 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: super().__init__( data_path=data_path, @@ -54,7 +54,7 @@ def __init__( assert 0 <= client_number < NUM_CLIENTS log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: train_dataset, validation_dataset = construct_fed_heard_disease_train_val_datasets( self.client_number, str(self.data_path) ) @@ -66,7 +66,7 @@ def get_model(self, config: Config) -> nn.Module: model: nn.Module = Baseline().to(self.device) return model - def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: + def get_optimizer(self, config: Config) -> dict[str, Optimizer]: # Note that the global optimizer operates on self.global_model.parameters() and local optimizer operates on # self.model.parameters(). global_optimizer = torch.optim.AdamW(self.global_model.parameters(), lr=self.learning_rate) diff --git a/research/flamby/fed_heart_disease/ditto/run_fold_experiment.slrm b/research/flamby/fed_heart_disease/ditto/run_fold_experiment.slrm index f796bf611..2c015753b 100644 --- a/research/flamby/fed_heart_disease/ditto/run_fold_experiment.slrm +++ b/research/flamby/fed_heart_disease/ditto/run_fold_experiment.slrm @@ -92,7 +92,7 @@ do SEED="${SEEDS[i]}" # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_heart_disease/ditto/server.py b/research/flamby/fed_heart_disease/ditto/server.py index 6acc6686d..0716bbd7a 100644 --- a/research/flamby/fed_heart_disease/ditto/server.py +++ b/research/flamby/fed_heart_disease/ditto/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flamby.datasets.fed_heart_disease import Baseline @@ -17,7 +17,7 @@ from research.flamby.utils import fit_config, summarize_model_info -def main(config: Dict[str, Any], server_address: str, lam: float) -> None: +def main(config: dict[str, Any], server_address: str, lam: float) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/flamby/fed_heart_disease/evaluate_on_holdout.py b/research/flamby/fed_heart_disease/evaluate_on_holdout.py index 32f944e1b..bd9119b88 100644 --- a/research/flamby/fed_heart_disease/evaluate_on_holdout.py +++ b/research/flamby/fed_heart_disease/evaluate_on_holdout.py @@ -1,6 +1,5 @@ import argparse from logging import INFO -from typing import Dict import torch from flamby.datasets.fed_heart_disease import BATCH_SIZE, NUM_CLIENTS, FedHeartDisease @@ -28,7 +27,7 @@ def main( ) -> None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") all_run_folder_dir = get_all_run_folders(artifact_dir) - test_results: Dict[str, float] = {} + test_results: dict[str, float] = {} metrics = [Accuracy("FedHeartDisease_accuracy")] all_local_test_metrics = {run_folder_dir: 0.0 for run_folder_dir in all_run_folder_dir} diff --git a/research/flamby/fed_heart_disease/fedadam/client.py b/research/flamby/fed_heart_disease/fedadam/client.py index ea3563c3c..ab09ccfba 100644 --- a/research/flamby/fed_heart_disease/fedadam/client.py +++ b/research/flamby/fed_heart_disease/fedadam/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Optional, Sequence, Tuple import flwr as fl import torch @@ -32,10 +32,10 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: super().__init__( data_path=data_path, @@ -53,7 +53,7 @@ def __init__( assert 0 <= client_number < NUM_CLIENTS log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: train_dataset, validation_dataset = construct_fed_heard_disease_train_val_datasets( self.client_number, str(self.data_path) ) diff --git a/research/flamby/fed_heart_disease/fedadam/run_fold_experiment.slrm b/research/flamby/fed_heart_disease/fedadam/run_fold_experiment.slrm index b801b7a4a..ce98be452 100644 --- a/research/flamby/fed_heart_disease/fedadam/run_fold_experiment.slrm +++ b/research/flamby/fed_heart_disease/fedadam/run_fold_experiment.slrm @@ -88,7 +88,7 @@ for RUN_NAME in "${RUN_NAMES[@]}"; do # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_heart_disease/fedadam/server.py b/research/flamby/fed_heart_disease/fedadam/server.py index 017f61b58..cf97fe5ef 100644 --- a/research/flamby/fed_heart_disease/fedadam/server.py +++ b/research/flamby/fed_heart_disease/fedadam/server.py @@ -2,7 +2,7 @@ import os from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flamby.datasets.fed_heart_disease import Baseline @@ -21,7 +21,7 @@ def main( - config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str, server_learning_rate: float + config: dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str, server_learning_rate: float ) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( diff --git a/research/flamby/fed_heart_disease/fedavg/client.py b/research/flamby/fed_heart_disease/fedavg/client.py index dd4f0a261..aab7490b5 100644 --- a/research/flamby/fed_heart_disease/fedavg/client.py +++ b/research/flamby/fed_heart_disease/fedavg/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Optional, Sequence, Tuple import flwr as fl import torch @@ -32,10 +32,10 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: super().__init__( data_path=data_path, @@ -53,7 +53,7 @@ def __init__( assert 0 <= client_number < NUM_CLIENTS log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: train_dataset, validation_dataset = construct_fed_heard_disease_train_val_datasets( self.client_number, str(self.data_path) ) diff --git a/research/flamby/fed_heart_disease/fedavg/run_fold_experiment.slrm b/research/flamby/fed_heart_disease/fedavg/run_fold_experiment.slrm index e2c45196d..ec34140b2 100644 --- a/research/flamby/fed_heart_disease/fedavg/run_fold_experiment.slrm +++ b/research/flamby/fed_heart_disease/fedavg/run_fold_experiment.slrm @@ -85,7 +85,7 @@ for RUN_NAME in "${RUN_NAMES[@]}"; do # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_heart_disease/fedavg/server.py b/research/flamby/fed_heart_disease/fedavg/server.py index 1859f0711..6eb494980 100644 --- a/research/flamby/fed_heart_disease/fedavg/server.py +++ b/research/flamby/fed_heart_disease/fedavg/server.py @@ -2,7 +2,7 @@ import os from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flamby.datasets.fed_heart_disease import Baseline @@ -20,7 +20,7 @@ from research.flamby.utils import fit_config, summarize_model_info -def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str) -> None: +def main(config: dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/flamby/fed_heart_disease/fedper/client.py b/research/flamby/fed_heart_disease/fedper/client.py index bbd9ea45e..c6fe5e113 100644 --- a/research/flamby/fed_heart_disease/fedper/client.py +++ b/research/flamby/fed_heart_disease/fedper/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Optional, Sequence, Tuple import flwr as fl import torch @@ -36,10 +36,10 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: super().__init__( data_path=data_path, @@ -58,7 +58,7 @@ def __init__( assert 0 <= client_number < NUM_CLIENTS log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: train_dataset, validation_dataset = construct_fed_heard_disease_train_val_datasets( self.client_number, str(self.data_path) ) diff --git a/research/flamby/fed_heart_disease/fedper/run_fold_experiment.slrm b/research/flamby/fed_heart_disease/fedper/run_fold_experiment.slrm index d7c18fcb8..ffd8acaba 100644 --- a/research/flamby/fed_heart_disease/fedper/run_fold_experiment.slrm +++ b/research/flamby/fed_heart_disease/fedper/run_fold_experiment.slrm @@ -89,7 +89,7 @@ do SEED="${SEEDS[i]}" # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_heart_disease/fedper/server.py b/research/flamby/fed_heart_disease/fedper/server.py index 6ac981335..bf7701b7c 100644 --- a/research/flamby/fed_heart_disease/fedper/server.py +++ b/research/flamby/fed_heart_disease/fedper/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.logger import log @@ -17,7 +17,7 @@ from research.flamby.utils import fit_config, summarize_model_info -def main(config: Dict[str, Any], server_address: str) -> None: +def main(config: dict[str, Any], server_address: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/flamby/fed_heart_disease/fedprox/client.py b/research/flamby/fed_heart_disease/fedprox/client.py index ea8915003..e384153c5 100644 --- a/research/flamby/fed_heart_disease/fedprox/client.py +++ b/research/flamby/fed_heart_disease/fedprox/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Optional, Sequence, Tuple import flwr as fl import torch @@ -32,10 +32,10 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: super().__init__( data_path=data_path, @@ -53,7 +53,7 @@ def __init__( assert 0 <= client_number < NUM_CLIENTS log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: train_dataset, validation_dataset = construct_fed_heard_disease_train_val_datasets( self.client_number, str(self.data_path) ) diff --git a/research/flamby/fed_heart_disease/fedprox/run_fold_experiment.slrm b/research/flamby/fed_heart_disease/fedprox/run_fold_experiment.slrm index 8b7172cbc..0ab0b6046 100644 --- a/research/flamby/fed_heart_disease/fedprox/run_fold_experiment.slrm +++ b/research/flamby/fed_heart_disease/fedprox/run_fold_experiment.slrm @@ -88,7 +88,7 @@ for RUN_NAME in "${RUN_NAMES[@]}"; do # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_heart_disease/fedprox/server.py b/research/flamby/fed_heart_disease/fedprox/server.py index e1526c21c..2e5d55d13 100644 --- a/research/flamby/fed_heart_disease/fedprox/server.py +++ b/research/flamby/fed_heart_disease/fedprox/server.py @@ -2,7 +2,7 @@ import os from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flamby.datasets.fed_heart_disease import Baseline @@ -19,7 +19,7 @@ from research.flamby.utils import fit_config, summarize_model_info -def main(config: Dict[str, Any], server_address: str, mu: float, checkpoint_stub: str, run_name: str) -> None: +def main(config: dict[str, Any], server_address: str, mu: float, checkpoint_stub: str, run_name: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/flamby/fed_heart_disease/fenda/client.py b/research/flamby/fed_heart_disease/fenda/client.py index 7834a3bf2..3cadd47da 100644 --- a/research/flamby/fed_heart_disease/fenda/client.py +++ b/research/flamby/fed_heart_disease/fenda/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Optional, Sequence, Tuple import flwr as fl import torch @@ -34,10 +34,10 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: super().__init__( data_path=data_path, @@ -55,7 +55,7 @@ def __init__( assert 0 <= client_number < NUM_CLIENTS log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: train_dataset, validation_dataset = construct_fed_heard_disease_train_val_datasets( self.client_number, str(self.data_path) ) diff --git a/research/flamby/fed_heart_disease/fenda/run_fold_experiment.slrm b/research/flamby/fed_heart_disease/fenda/run_fold_experiment.slrm index 6b8cb9e45..b963d3170 100644 --- a/research/flamby/fed_heart_disease/fenda/run_fold_experiment.slrm +++ b/research/flamby/fed_heart_disease/fenda/run_fold_experiment.slrm @@ -89,7 +89,7 @@ do SEED="${SEEDS[i]}" # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_heart_disease/fenda/server.py b/research/flamby/fed_heart_disease/fenda/server.py index 6a3fb9611..d1631d2f1 100644 --- a/research/flamby/fed_heart_disease/fenda/server.py +++ b/research/flamby/fed_heart_disease/fenda/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.logger import log @@ -17,7 +17,7 @@ from research.flamby.utils import fit_config, summarize_model_info -def main(config: Dict[str, Any], server_address: str) -> None: +def main(config: dict[str, Any], server_address: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/flamby/fed_heart_disease/local/run_fold_experiment.slrm b/research/flamby/fed_heart_disease/local/run_fold_experiment.slrm index a02539fb4..eb4e3f91d 100644 --- a/research/flamby/fed_heart_disease/local/run_fold_experiment.slrm +++ b/research/flamby/fed_heart_disease/local/run_fold_experiment.slrm @@ -79,7 +79,7 @@ for RUN_NAME in "${RUN_NAMES[@]}"; do # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_heart_disease/moon/client.py b/research/flamby/fed_heart_disease/moon/client.py index 16143e358..2ed3c719f 100644 --- a/research/flamby/fed_heart_disease/moon/client.py +++ b/research/flamby/fed_heart_disease/moon/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Optional, Sequence, Tuple import flwr as fl import torch @@ -34,10 +34,10 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, contrastive_weight: float = 10, ) -> None: super().__init__( @@ -57,7 +57,7 @@ def __init__( assert 0 <= client_number < NUM_CLIENTS log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: train_dataset, validation_dataset = construct_fed_heard_disease_train_val_datasets( self.client_number, str(self.data_path) ) diff --git a/research/flamby/fed_heart_disease/moon/run_fold_experiment.slrm b/research/flamby/fed_heart_disease/moon/run_fold_experiment.slrm index 02fc66ef4..d687cee8b 100644 --- a/research/flamby/fed_heart_disease/moon/run_fold_experiment.slrm +++ b/research/flamby/fed_heart_disease/moon/run_fold_experiment.slrm @@ -92,7 +92,7 @@ do SEED="${SEEDS[i]}" # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_heart_disease/moon/server.py b/research/flamby/fed_heart_disease/moon/server.py index 85bd9d2bb..4910aa852 100644 --- a/research/flamby/fed_heart_disease/moon/server.py +++ b/research/flamby/fed_heart_disease/moon/server.py @@ -2,7 +2,7 @@ import os from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.logger import log @@ -21,7 +21,7 @@ from research.flamby.utils import fit_config, summarize_model_info -def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str) -> None: +def main(config: dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/flamby/fed_heart_disease/perfcl/client.py b/research/flamby/fed_heart_disease/perfcl/client.py index fa4d6fd4c..9f0841ddd 100644 --- a/research/flamby/fed_heart_disease/perfcl/client.py +++ b/research/flamby/fed_heart_disease/perfcl/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Optional, Sequence, Tuple import flwr as fl import torch @@ -34,10 +34,10 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, mu: float = 10.0, gamma: float = 10.0, ) -> None: @@ -59,7 +59,7 @@ def __init__( assert 0 <= client_number < NUM_CLIENTS log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: train_dataset, validation_dataset = construct_fed_heard_disease_train_val_datasets( self.client_number, str(self.data_path) ) diff --git a/research/flamby/fed_heart_disease/perfcl/run_fold_experiment.slrm b/research/flamby/fed_heart_disease/perfcl/run_fold_experiment.slrm index 646b51482..c6e58012a 100644 --- a/research/flamby/fed_heart_disease/perfcl/run_fold_experiment.slrm +++ b/research/flamby/fed_heart_disease/perfcl/run_fold_experiment.slrm @@ -95,7 +95,7 @@ do SEED="${SEEDS[i]}" # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_heart_disease/perfcl/server.py b/research/flamby/fed_heart_disease/perfcl/server.py index c99441c81..43726c777 100644 --- a/research/flamby/fed_heart_disease/perfcl/server.py +++ b/research/flamby/fed_heart_disease/perfcl/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.logger import log @@ -17,7 +17,7 @@ from research.flamby.utils import fit_config, summarize_model_info -def main(config: Dict[str, Any], server_address: str) -> None: +def main(config: dict[str, Any], server_address: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/flamby/fed_heart_disease/scaffold/client.py b/research/flamby/fed_heart_disease/scaffold/client.py index 42bb46ffe..b65f41012 100644 --- a/research/flamby/fed_heart_disease/scaffold/client.py +++ b/research/flamby/fed_heart_disease/scaffold/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Optional, Sequence, Tuple import flwr as fl import torch @@ -32,10 +32,10 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: super().__init__( data_path=data_path, @@ -53,7 +53,7 @@ def __init__( assert 0 <= client_number < NUM_CLIENTS log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: train_dataset, validation_dataset = construct_fed_heard_disease_train_val_datasets( self.client_number, str(self.data_path) ) diff --git a/research/flamby/fed_heart_disease/scaffold/run_fold_experiment.slrm b/research/flamby/fed_heart_disease/scaffold/run_fold_experiment.slrm index daa99159e..e91ffb0cb 100644 --- a/research/flamby/fed_heart_disease/scaffold/run_fold_experiment.slrm +++ b/research/flamby/fed_heart_disease/scaffold/run_fold_experiment.slrm @@ -88,7 +88,7 @@ for RUN_NAME in "${RUN_NAMES[@]}"; do # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_heart_disease/scaffold/server.py b/research/flamby/fed_heart_disease/scaffold/server.py index 6d9a0b546..1ca8db7d6 100644 --- a/research/flamby/fed_heart_disease/scaffold/server.py +++ b/research/flamby/fed_heart_disease/scaffold/server.py @@ -2,7 +2,7 @@ import os from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flamby.datasets.fed_heart_disease import Baseline @@ -19,7 +19,7 @@ def main( - config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str, server_learning_rate: float + config: dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str, server_learning_rate: float ) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( diff --git a/research/flamby/fed_isic2019/apfl/apfl_model.py b/research/flamby/fed_isic2019/apfl/apfl_model.py index cfc9ff3fe..2aefc0505 100644 --- a/research/flamby/fed_isic2019/apfl/apfl_model.py +++ b/research/flamby/fed_isic2019/apfl/apfl_model.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch import torch.nn as nn from flamby.datasets.fed_isic2019 import Baseline @@ -21,7 +19,7 @@ class ApflEfficientNet(nn.Module): other approaches. """ - def __init__(self, frozen_blocks: Optional[int] = 13, turn_off_bn_tracking: bool = False): + def __init__(self, frozen_blocks: int | None = 13, turn_off_bn_tracking: bool = False): super().__init__() self.base_model = Baseline() # Freeze layers to reduce trainable parameters. diff --git a/research/flamby/fed_isic2019/apfl/client.py b/research/flamby/fed_isic2019/apfl/client.py index 224f7cb6e..cd100e7e3 100644 --- a/research/flamby/fed_isic2019/apfl/client.py +++ b/research/flamby/fed_isic2019/apfl/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple import flwr as fl import torch @@ -35,10 +35,10 @@ def __init__( learning_rate: float, alpha_learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: super().__init__( data_path=data_path, @@ -56,7 +56,7 @@ def __init__( self.alpha_learning_rate = alpha_learning_rate self.client_number = client_number - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: train_dataset, validation_dataset = construct_fedisic_train_val_datasets( self.client_number, str(self.data_path) ) @@ -75,7 +75,7 @@ def get_model(self, config: Config) -> nn.Module: ).to(self.device) return model - def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: + def get_optimizer(self, config: Config) -> dict[str, Optimizer]: local_optimizer: Optimizer = torch.optim.AdamW(self.model.local_model.parameters(), lr=self.learning_rate) global_optimizer: Optimizer = torch.optim.AdamW(self.model.global_model.parameters(), lr=self.learning_rate) return {"local": local_optimizer, "global": global_optimizer} diff --git a/research/flamby/fed_isic2019/apfl/run_fold_experiment.slrm b/research/flamby/fed_isic2019/apfl/run_fold_experiment.slrm index 50d242290..7a0dfa361 100644 --- a/research/flamby/fed_isic2019/apfl/run_fold_experiment.slrm +++ b/research/flamby/fed_isic2019/apfl/run_fold_experiment.slrm @@ -88,7 +88,7 @@ for RUN_NAME in "${RUN_NAMES[@]}"; do # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_isic2019/apfl/server.py b/research/flamby/fed_isic2019/apfl/server.py index 00a8fef49..0160ebb94 100644 --- a/research/flamby/fed_isic2019/apfl/server.py +++ b/research/flamby/fed_isic2019/apfl/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.logger import log @@ -17,7 +17,7 @@ from research.flamby.utils import fit_config, summarize_model_info -def main(config: Dict[str, Any], server_address: str) -> None: +def main(config: dict[str, Any], server_address: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/flamby/fed_isic2019/central/run_fold_experiment.slrm b/research/flamby/fed_isic2019/central/run_fold_experiment.slrm index 1e99cb622..6f7b1caed 100644 --- a/research/flamby/fed_isic2019/central/run_fold_experiment.slrm +++ b/research/flamby/fed_isic2019/central/run_fold_experiment.slrm @@ -76,7 +76,7 @@ for RUN_NAME in "${RUN_NAMES[@]}"; do # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_isic2019/ditto/client.py b/research/flamby/fed_isic2019/ditto/client.py index a22a57de8..f5a449aad 100644 --- a/research/flamby/fed_isic2019/ditto/client.py +++ b/research/flamby/fed_isic2019/ditto/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple import flwr as fl import torch @@ -33,10 +33,10 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: super().__init__( data_path=data_path, @@ -54,7 +54,7 @@ def __init__( assert 0 <= client_number < NUM_CLIENTS log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: train_dataset, validation_dataset = construct_fedisic_train_val_datasets( self.client_number, str(self.data_path) ) @@ -66,7 +66,7 @@ def get_model(self, config: Config) -> nn.Module: model: nn.Module = Baseline().to(self.device) return model - def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: + def get_optimizer(self, config: Config) -> dict[str, Optimizer]: # Note that the global optimizer operates on self.global_model.parameters() and local optimizer operates on # self.model.parameters(). global_optimizer = torch.optim.AdamW(self.global_model.parameters(), lr=self.learning_rate) diff --git a/research/flamby/fed_isic2019/ditto/run_fold_experiment.slrm b/research/flamby/fed_isic2019/ditto/run_fold_experiment.slrm index 278e5e565..8b35c7432 100644 --- a/research/flamby/fed_isic2019/ditto/run_fold_experiment.slrm +++ b/research/flamby/fed_isic2019/ditto/run_fold_experiment.slrm @@ -92,7 +92,7 @@ do SEED="${SEEDS[i]}" # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_isic2019/ditto/server.py b/research/flamby/fed_isic2019/ditto/server.py index 1172b0255..15002b54a 100644 --- a/research/flamby/fed_isic2019/ditto/server.py +++ b/research/flamby/fed_isic2019/ditto/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flamby.datasets.fed_isic2019 import Baseline @@ -17,7 +17,7 @@ from research.flamby.utils import fit_config, summarize_model_info -def main(config: Dict[str, Any], server_address: str, lam: float) -> None: +def main(config: dict[str, Any], server_address: str, lam: float) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/flamby/fed_isic2019/ditto_deep_mmd/client.py b/research/flamby/fed_isic2019/ditto_deep_mmd/client.py index 8ceefd44e..21ee0e424 100644 --- a/research/flamby/fed_isic2019/ditto_deep_mmd/client.py +++ b/research/flamby/fed_isic2019/ditto_deep_mmd/client.py @@ -1,9 +1,9 @@ import argparse import os from collections import OrderedDict +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple import flwr as fl import torch @@ -41,10 +41,10 @@ def __init__( loss_meter_type: LossMeterType = LossMeterType.AVERAGE, deep_mmd_loss_weight: float = 10, deep_mmd_loss_depth: int = 1, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: feature_extraction_layers_with_size = OrderedDict( list(FED_ISIC2019_BASELINE_LAYERS.items())[-1 * deep_mmd_loss_depth :] @@ -67,7 +67,7 @@ def __init__( assert 0 <= client_number < NUM_CLIENTS log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: train_dataset, validation_dataset = construct_fedisic_train_val_datasets( self.client_number, str(self.data_path) ) @@ -79,7 +79,7 @@ def get_model(self, config: Config) -> nn.Module: model: nn.Module = Baseline().to(self.device) return model - def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: + def get_optimizer(self, config: Config) -> dict[str, Optimizer]: # Note that the global optimizer operates on self.global_model.parameters() and local optimizer operates on # self.model.parameters(). global_optimizer = torch.optim.AdamW(self.global_model.parameters(), lr=self.learning_rate) diff --git a/research/flamby/fed_isic2019/ditto_deep_mmd/run_fold_experiment.slrm b/research/flamby/fed_isic2019/ditto_deep_mmd/run_fold_experiment.slrm index 1a622040f..5cf75b881 100644 --- a/research/flamby/fed_isic2019/ditto_deep_mmd/run_fold_experiment.slrm +++ b/research/flamby/fed_isic2019/ditto_deep_mmd/run_fold_experiment.slrm @@ -98,7 +98,7 @@ do SEED="${SEEDS[i]}" # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_isic2019/ditto_deep_mmd/server.py b/research/flamby/fed_isic2019/ditto_deep_mmd/server.py index 1172b0255..15002b54a 100644 --- a/research/flamby/fed_isic2019/ditto_deep_mmd/server.py +++ b/research/flamby/fed_isic2019/ditto_deep_mmd/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flamby.datasets.fed_isic2019 import Baseline @@ -17,7 +17,7 @@ from research.flamby.utils import fit_config, summarize_model_info -def main(config: Dict[str, Any], server_address: str, lam: float) -> None: +def main(config: dict[str, Any], server_address: str, lam: float) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/flamby/fed_isic2019/ditto_mkmmd/client.py b/research/flamby/fed_isic2019/ditto_mkmmd/client.py index 6add1e5da..f9273cf1a 100644 --- a/research/flamby/fed_isic2019/ditto_mkmmd/client.py +++ b/research/flamby/fed_isic2019/ditto_mkmmd/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple import flwr as fl import torch @@ -38,10 +38,10 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, mkmmd_loss_weight: float = 10, feature_l2_norm_weight: float = 1, mkmmd_loss_depth: int = 1, @@ -67,7 +67,7 @@ def __init__( assert 0 <= client_number < NUM_CLIENTS log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: train_dataset, validation_dataset = construct_fedisic_train_val_datasets( self.client_number, str(self.data_path) ) @@ -79,7 +79,7 @@ def get_model(self, config: Config) -> nn.Module: model: nn.Module = Baseline().to(self.device) return model - def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: + def get_optimizer(self, config: Config) -> dict[str, Optimizer]: # Note that the global optimizer operates on self.global_model.parameters() and local optimizer operates on # self.model.parameters(). global_optimizer = torch.optim.AdamW(self.global_model.parameters(), lr=self.learning_rate) diff --git a/research/flamby/fed_isic2019/ditto_mkmmd/run_fold_experiment.slrm b/research/flamby/fed_isic2019/ditto_mkmmd/run_fold_experiment.slrm index 648567687..7872f627f 100644 --- a/research/flamby/fed_isic2019/ditto_mkmmd/run_fold_experiment.slrm +++ b/research/flamby/fed_isic2019/ditto_mkmmd/run_fold_experiment.slrm @@ -101,7 +101,7 @@ do SEED="${SEEDS[i]}" # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_isic2019/ditto_mkmmd/server.py b/research/flamby/fed_isic2019/ditto_mkmmd/server.py index 1172b0255..15002b54a 100644 --- a/research/flamby/fed_isic2019/ditto_mkmmd/server.py +++ b/research/flamby/fed_isic2019/ditto_mkmmd/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flamby.datasets.fed_isic2019 import Baseline @@ -17,7 +17,7 @@ from research.flamby.utils import fit_config, summarize_model_info -def main(config: Dict[str, Any], server_address: str, lam: float) -> None: +def main(config: dict[str, Any], server_address: str, lam: float) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/flamby/fed_isic2019/evaluate_on_holdout.py b/research/flamby/fed_isic2019/evaluate_on_holdout.py index 505cc861d..42290969f 100644 --- a/research/flamby/fed_isic2019/evaluate_on_holdout.py +++ b/research/flamby/fed_isic2019/evaluate_on_holdout.py @@ -1,6 +1,5 @@ import argparse from logging import INFO -from typing import Dict import torch from flamby.datasets.fed_isic2019 import BATCH_SIZE, NUM_CLIENTS, FedIsic2019 @@ -28,7 +27,7 @@ def main( ) -> None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") all_run_folder_dir = get_all_run_folders(artifact_dir) - test_results: Dict[str, float] = {} + test_results: dict[str, float] = {} metrics = [BalancedAccuracy("FedIsic2019_balanced_accuracy")] all_local_test_metrics = {run_folder_dir: 0.0 for run_folder_dir in all_run_folder_dir} diff --git a/research/flamby/fed_isic2019/fedadam/client.py b/research/flamby/fed_isic2019/fedadam/client.py index 7ff8f181e..ecec3018a 100644 --- a/research/flamby/fed_isic2019/fedadam/client.py +++ b/research/flamby/fed_isic2019/fedadam/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Optional, Sequence, Tuple import flwr as fl import torch @@ -33,10 +33,10 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: super().__init__( data_path=data_path, @@ -54,7 +54,7 @@ def __init__( assert 0 <= client_number < NUM_CLIENTS log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: train_dataset, validation_dataset = construct_fedisic_train_val_datasets( self.client_number, str(self.data_path) ) diff --git a/research/flamby/fed_isic2019/fedadam/fedadam_model.py b/research/flamby/fed_isic2019/fedadam/fedadam_model.py index 60a290f64..5be6690ed 100644 --- a/research/flamby/fed_isic2019/fedadam/fedadam_model.py +++ b/research/flamby/fed_isic2019/fedadam/fedadam_model.py @@ -7,7 +7,7 @@ class FedAdamEfficientNet(nn.Module): """FedAdam implements server-side momentum in aggregating the updates from each client. For layers that carry state - that must remain non-negative, like BatchNormalization layers (present in EffcientNet), they may become negative + that must remain non-negative, like BatchNormalization layers (present in EfficientNet), they may become negative due to momentum carrying updates past the origin. For Batch Normalization this means that the variance state estimated during training and applied during evaluation may become negative. This blows up the model. In order to get around this issue, we modify all batch normalization layers in EfficientNet to not carry such state by diff --git a/research/flamby/fed_isic2019/fedadam/run_fold_experiment.slrm b/research/flamby/fed_isic2019/fedadam/run_fold_experiment.slrm index 7a5f63bae..432e604ab 100644 --- a/research/flamby/fed_isic2019/fedadam/run_fold_experiment.slrm +++ b/research/flamby/fed_isic2019/fedadam/run_fold_experiment.slrm @@ -88,7 +88,7 @@ for RUN_NAME in "${RUN_NAMES[@]}"; do # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_isic2019/fedadam/server.py b/research/flamby/fed_isic2019/fedadam/server.py index 6ca3d0e30..c71f09b49 100644 --- a/research/flamby/fed_isic2019/fedadam/server.py +++ b/research/flamby/fed_isic2019/fedadam/server.py @@ -2,7 +2,7 @@ import os from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.logger import log @@ -21,7 +21,7 @@ def main( - config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str, server_learning_rate: float + config: dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str, server_learning_rate: float ) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( diff --git a/research/flamby/fed_isic2019/fedavg/client.py b/research/flamby/fed_isic2019/fedavg/client.py index 514fe20fb..a6dbf5bc5 100644 --- a/research/flamby/fed_isic2019/fedavg/client.py +++ b/research/flamby/fed_isic2019/fedavg/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Optional, Sequence, Tuple import flwr as fl import torch @@ -32,10 +32,10 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: super().__init__( data_path=data_path, @@ -53,7 +53,7 @@ def __init__( assert 0 <= client_number < NUM_CLIENTS log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: train_dataset, validation_dataset = construct_fedisic_train_val_datasets( self.client_number, str(self.data_path) ) diff --git a/research/flamby/fed_isic2019/fedavg/run_fold_experiment.slrm b/research/flamby/fed_isic2019/fedavg/run_fold_experiment.slrm index dbb0aea7a..9a030a1cd 100644 --- a/research/flamby/fed_isic2019/fedavg/run_fold_experiment.slrm +++ b/research/flamby/fed_isic2019/fedavg/run_fold_experiment.slrm @@ -85,7 +85,7 @@ for RUN_NAME in "${RUN_NAMES[@]}"; do # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_isic2019/fedavg/server.py b/research/flamby/fed_isic2019/fedavg/server.py index 67ce57138..346885141 100644 --- a/research/flamby/fed_isic2019/fedavg/server.py +++ b/research/flamby/fed_isic2019/fedavg/server.py @@ -2,7 +2,7 @@ import os from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flamby.datasets.fed_isic2019 import Baseline @@ -20,7 +20,7 @@ from research.flamby.utils import fit_config -def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str) -> None: +def main(config: dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/flamby/fed_isic2019/fedper/client.py b/research/flamby/fed_isic2019/fedper/client.py index af6198451..316fe7e0b 100644 --- a/research/flamby/fed_isic2019/fedper/client.py +++ b/research/flamby/fed_isic2019/fedper/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Optional, Sequence, Tuple import flwr as fl import torch @@ -36,10 +36,10 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: super().__init__( data_path=data_path, @@ -58,7 +58,7 @@ def __init__( assert 0 <= client_number < NUM_CLIENTS log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: train_dataset, validation_dataset = construct_fedisic_train_val_datasets( self.client_number, str(self.data_path) ) diff --git a/research/flamby/fed_isic2019/fedper/fedper_model.py b/research/flamby/fed_isic2019/fedper/fedper_model.py index 2d4a1928f..6be7fc0f3 100644 --- a/research/flamby/fed_isic2019/fedper/fedper_model.py +++ b/research/flamby/fed_isic2019/fedper/fedper_model.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch import torch.nn as nn from efficientnet_pytorch import EfficientNet @@ -49,7 +47,7 @@ class BaseEfficientNet(nn.Module): it is not used in the FedPer experiments. """ - def __init__(self, frozen_blocks: Optional[int] = 13, turn_off_bn_tracking: bool = False): + def __init__(self, frozen_blocks: int | None = 13, turn_off_bn_tracking: bool = False): super().__init__() # include_top ensures that we just use feature extraction in the forward pass self.base_model = from_pretrained("efficientnet-b0", include_top=False) @@ -60,7 +58,7 @@ def __init__(self, frozen_blocks: Optional[int] = 13, turn_off_bn_tracking: bool def freeze_layers(self, frozen_blocks: int) -> None: # We freeze the bottom layers of the network. We always freeze the _conv_stem module, the _bn0 module and then - # we iterate throught the blocks freezing the specified number up to 15 (all of them) + # we iterate through the blocks freezing the specified number up to 15 (all of them) # Freeze the first two layers self.base_model._modules["_conv_stem"].requires_grad_(False) @@ -77,7 +75,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class FedIsic2019FedPerModel(SequentiallySplitExchangeBaseModel): - def __init__(self, frozen_blocks: Optional[int] = None, turn_off_bn_tracking: bool = False) -> None: + def __init__(self, frozen_blocks: int | None = None, turn_off_bn_tracking: bool = False) -> None: base_module = BaseEfficientNet(frozen_blocks, turn_off_bn_tracking=turn_off_bn_tracking) head_module = HeadClassifier(1280) super().__init__(base_module, head_module) diff --git a/research/flamby/fed_isic2019/fedper/run_fold_experiment.slrm b/research/flamby/fed_isic2019/fedper/run_fold_experiment.slrm index e68609537..ca72b108e 100644 --- a/research/flamby/fed_isic2019/fedper/run_fold_experiment.slrm +++ b/research/flamby/fed_isic2019/fedper/run_fold_experiment.slrm @@ -89,7 +89,7 @@ do SEED="${SEEDS[i]}" # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_isic2019/fedper/server.py b/research/flamby/fed_isic2019/fedper/server.py index 33085dd37..bccb46560 100644 --- a/research/flamby/fed_isic2019/fedper/server.py +++ b/research/flamby/fed_isic2019/fedper/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.logger import log @@ -17,7 +17,7 @@ from research.flamby.utils import fit_config, summarize_model_info -def main(config: Dict[str, Any], server_address: str) -> None: +def main(config: dict[str, Any], server_address: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/flamby/fed_isic2019/fedprox/client.py b/research/flamby/fed_isic2019/fedprox/client.py index 4e50bea61..3f932d4c8 100644 --- a/research/flamby/fed_isic2019/fedprox/client.py +++ b/research/flamby/fed_isic2019/fedprox/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Optional, Sequence, Tuple import flwr as fl import torch @@ -32,10 +32,10 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: super().__init__( data_path=data_path, @@ -53,7 +53,7 @@ def __init__( assert 0 <= client_number < NUM_CLIENTS log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: train_dataset, validation_dataset = construct_fedisic_train_val_datasets( self.client_number, str(self.data_path) ) diff --git a/research/flamby/fed_isic2019/fedprox/run_fold_experiment.slrm b/research/flamby/fed_isic2019/fedprox/run_fold_experiment.slrm index 44b3baad5..af04fdfc0 100644 --- a/research/flamby/fed_isic2019/fedprox/run_fold_experiment.slrm +++ b/research/flamby/fed_isic2019/fedprox/run_fold_experiment.slrm @@ -88,7 +88,7 @@ for RUN_NAME in "${RUN_NAMES[@]}"; do # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_isic2019/fedprox/server.py b/research/flamby/fed_isic2019/fedprox/server.py index ebffacfa9..93b5bc0b2 100644 --- a/research/flamby/fed_isic2019/fedprox/server.py +++ b/research/flamby/fed_isic2019/fedprox/server.py @@ -2,7 +2,7 @@ import os from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flamby.datasets.fed_isic2019 import Baseline @@ -19,7 +19,7 @@ from research.flamby.utils import fit_config -def main(config: Dict[str, Any], server_address: str, mu: float, checkpoint_stub: str, run_name: str) -> None: +def main(config: dict[str, Any], server_address: str, mu: float, checkpoint_stub: str, run_name: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/flamby/fed_isic2019/fenda/client.py b/research/flamby/fed_isic2019/fenda/client.py index c237b06b6..4380a60e5 100644 --- a/research/flamby/fed_isic2019/fenda/client.py +++ b/research/flamby/fed_isic2019/fenda/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Optional, Sequence, Tuple import flwr as fl import torch @@ -34,10 +34,10 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: super().__init__( data_path=data_path, @@ -55,7 +55,7 @@ def __init__( assert 0 <= client_number < NUM_CLIENTS log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: train_dataset, validation_dataset = construct_fedisic_train_val_datasets( self.client_number, str(self.data_path) ) diff --git a/research/flamby/fed_isic2019/fenda/fenda_model.py b/research/flamby/fed_isic2019/fenda/fenda_model.py index b45f1720f..9f81de3d2 100644 --- a/research/flamby/fed_isic2019/fenda/fenda_model.py +++ b/research/flamby/fed_isic2019/fenda/fenda_model.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch import torch.nn as nn import torch.nn.functional as F @@ -58,7 +56,7 @@ class LocalEfficientNet(nn.Module): other approaches. """ - def __init__(self, frozen_blocks: Optional[int] = 13, turn_off_bn_tracking: bool = False): + def __init__(self, frozen_blocks: int | None = 13, turn_off_bn_tracking: bool = False): super().__init__() # include_top ensures that we just use feature extraction in the forward pass self.base_model = from_pretrained("efficientnet-b0", include_top=False) @@ -69,7 +67,7 @@ def __init__(self, frozen_blocks: Optional[int] = 13, turn_off_bn_tracking: bool def freeze_layers(self, frozen_blocks: int) -> None: # We freeze the bottom layers of the network. We always freeze the _conv_stem module, the _bn0 module and then - # we iterate throught the blocks freezing the specified number up to 15 (all of them) + # we iterate through the blocks freezing the specified number up to 15 (all of them) # Freeze the first two layers self.base_model._modules["_conv_stem"].requires_grad_(False) @@ -97,7 +95,7 @@ class GlobalEfficientNet(nn.Module): other approaches. """ - def __init__(self, frozen_blocks: Optional[int] = 13, turn_off_bn_tracking: bool = False): + def __init__(self, frozen_blocks: int | None = 13, turn_off_bn_tracking: bool = False): super().__init__() # include_top ensures that we just use feature extraction in the forward pass self.base_model = from_pretrained("efficientnet-b0", include_top=False) @@ -108,7 +106,7 @@ def __init__(self, frozen_blocks: Optional[int] = 13, turn_off_bn_tracking: bool def freeze_layers(self, frozen_blocks: int) -> None: # We freeze the bottom layers of the network. We always freeze the _conv_stem module, the _bn0 module and then - # we iterate throught the blocks freezing the specified number up to 15 (all of them) + # we iterate through the blocks freezing the specified number up to 15 (all of them) # Freeze the first two layers self.base_model._modules["_conv_stem"].requires_grad_(False) @@ -124,7 +122,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class FedIsic2019FendaModel(FendaModel): - def __init__(self, frozen_blocks: Optional[int] = 13, turn_off_bn_tracking: bool = False) -> None: + def __init__(self, frozen_blocks: int | None = 13, turn_off_bn_tracking: bool = False) -> None: local_module = LocalEfficientNet(frozen_blocks, turn_off_bn_tracking=turn_off_bn_tracking) global_module = GlobalEfficientNet(frozen_blocks, turn_off_bn_tracking=turn_off_bn_tracking) model_head = FendaClassifier(ParallelFeatureJoinMode.CONCATENATE, 1280) diff --git a/research/flamby/fed_isic2019/fenda/run_fold_experiment.slrm b/research/flamby/fed_isic2019/fenda/run_fold_experiment.slrm index e8a25f042..54798afb9 100644 --- a/research/flamby/fed_isic2019/fenda/run_fold_experiment.slrm +++ b/research/flamby/fed_isic2019/fenda/run_fold_experiment.slrm @@ -89,7 +89,7 @@ do SEED="${SEEDS[i]}" # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_isic2019/fenda/server.py b/research/flamby/fed_isic2019/fenda/server.py index 3369eadb2..cfb7dcd79 100644 --- a/research/flamby/fed_isic2019/fenda/server.py +++ b/research/flamby/fed_isic2019/fenda/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.logger import log @@ -17,7 +17,7 @@ from research.flamby.utils import fit_config, summarize_model_info -def main(config: Dict[str, Any], server_address: str) -> None: +def main(config: dict[str, Any], server_address: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/flamby/fed_isic2019/local/run_fold_experiment.slrm b/research/flamby/fed_isic2019/local/run_fold_experiment.slrm index 4fee1423c..68133154b 100644 --- a/research/flamby/fed_isic2019/local/run_fold_experiment.slrm +++ b/research/flamby/fed_isic2019/local/run_fold_experiment.slrm @@ -79,7 +79,7 @@ for RUN_NAME in "${RUN_NAMES[@]}"; do # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_isic2019/moon/client.py b/research/flamby/fed_isic2019/moon/client.py index e12a759fa..656eeb9a3 100644 --- a/research/flamby/fed_isic2019/moon/client.py +++ b/research/flamby/fed_isic2019/moon/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Optional, Sequence, Tuple import flwr as fl import torch @@ -34,10 +34,10 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, contrastive_weight: float = 10, ) -> None: super().__init__( @@ -57,7 +57,7 @@ def __init__( assert 0 <= client_number < NUM_CLIENTS log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: train_dataset, validation_dataset = construct_fedisic_train_val_datasets( self.client_number, str(self.data_path) ) diff --git a/research/flamby/fed_isic2019/moon/moon_model.py b/research/flamby/fed_isic2019/moon/moon_model.py index 537f64e98..85f9b3b5b 100644 --- a/research/flamby/fed_isic2019/moon/moon_model.py +++ b/research/flamby/fed_isic2019/moon/moon_model.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch import torch.nn as nn from efficientnet_pytorch import EfficientNet @@ -48,7 +46,7 @@ class BaseEfficientNet(nn.Module): it is not used in the MOON experiments. """ - def __init__(self, frozen_blocks: Optional[int] = 13, turn_off_bn_tracking: bool = False): + def __init__(self, frozen_blocks: int | None = 13, turn_off_bn_tracking: bool = False): super().__init__() # include_top ensures that we just use feature extraction in the forward pass self.base_model = from_pretrained("efficientnet-b0", include_top=False) @@ -59,7 +57,7 @@ def __init__(self, frozen_blocks: Optional[int] = 13, turn_off_bn_tracking: bool def freeze_layers(self, frozen_blocks: int) -> None: # We freeze the bottom layers of the network. We always freeze the _conv_stem module, the _bn0 module and then - # we iterate throught the blocks freezing the specified number up to 15 (all of them) + # we iterate through the blocks freezing the specified number up to 15 (all of them) # Freeze the first two layers self.base_model._modules["_conv_stem"].requires_grad_(False) @@ -76,7 +74,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class FedIsic2019MoonModel(MoonModel): - def __init__(self, frozen_blocks: Optional[int] = None, turn_off_bn_tracking: bool = False) -> None: + def __init__(self, frozen_blocks: int | None = None, turn_off_bn_tracking: bool = False) -> None: base_module = BaseEfficientNet(frozen_blocks, turn_off_bn_tracking=turn_off_bn_tracking) head_module = HeadClassifier(1280) super().__init__(base_module, head_module) diff --git a/research/flamby/fed_isic2019/moon/run_fold_experiment.slrm b/research/flamby/fed_isic2019/moon/run_fold_experiment.slrm index 5514b3083..a89b22cbf 100644 --- a/research/flamby/fed_isic2019/moon/run_fold_experiment.slrm +++ b/research/flamby/fed_isic2019/moon/run_fold_experiment.slrm @@ -92,7 +92,7 @@ do SEED="${SEEDS[i]}" # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_isic2019/moon/server.py b/research/flamby/fed_isic2019/moon/server.py index b501ef17d..5797ef6c4 100644 --- a/research/flamby/fed_isic2019/moon/server.py +++ b/research/flamby/fed_isic2019/moon/server.py @@ -2,7 +2,7 @@ import os from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.logger import log @@ -21,7 +21,7 @@ from research.flamby.utils import fit_config, summarize_model_info -def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str) -> None: +def main(config: dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/flamby/fed_isic2019/mr_mtl_mkmmd/client.py b/research/flamby/fed_isic2019/mr_mtl_mkmmd/client.py index 1163cf9bb..2203e36a4 100644 --- a/research/flamby/fed_isic2019/mr_mtl_mkmmd/client.py +++ b/research/flamby/fed_isic2019/mr_mtl_mkmmd/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Optional, Sequence, Tuple import flwr as fl import torch @@ -42,10 +42,10 @@ def __init__( feature_l2_norm_weight: float = 1, mkmmd_loss_depth: int = 1, beta_global_update_interval: int = 20, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: super().__init__( data_path=data_path, @@ -67,7 +67,7 @@ def __init__( assert 0 <= client_number < NUM_CLIENTS log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: train_dataset, validation_dataset = construct_fedisic_train_val_datasets( self.client_number, str(self.data_path) ) diff --git a/research/flamby/fed_isic2019/mr_mtl_mkmmd/run_fold_experiment.slrm b/research/flamby/fed_isic2019/mr_mtl_mkmmd/run_fold_experiment.slrm index e072e977e..f7a69a835 100644 --- a/research/flamby/fed_isic2019/mr_mtl_mkmmd/run_fold_experiment.slrm +++ b/research/flamby/fed_isic2019/mr_mtl_mkmmd/run_fold_experiment.slrm @@ -101,7 +101,7 @@ do SEED="${SEEDS[i]}" # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_isic2019/mr_mtl_mkmmd/server.py b/research/flamby/fed_isic2019/mr_mtl_mkmmd/server.py index df01930fa..eb6a370a0 100644 --- a/research/flamby/fed_isic2019/mr_mtl_mkmmd/server.py +++ b/research/flamby/fed_isic2019/mr_mtl_mkmmd/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flamby.datasets.fed_isic2019 import Baseline @@ -17,7 +17,7 @@ from research.flamby.utils import fit_config, summarize_model_info -def main(config: Dict[str, Any], server_address: str, lam: float) -> None: +def main(config: dict[str, Any], server_address: str, lam: float) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/flamby/fed_isic2019/perfcl/client.py b/research/flamby/fed_isic2019/perfcl/client.py index 92fc99a82..086c39119 100644 --- a/research/flamby/fed_isic2019/perfcl/client.py +++ b/research/flamby/fed_isic2019/perfcl/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Optional, Sequence, Tuple import flwr as fl import torch @@ -34,10 +34,10 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, mu: float = 10.0, gamma: float = 10.0, ) -> None: @@ -59,7 +59,7 @@ def __init__( assert 0 <= client_number < NUM_CLIENTS log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: train_dataset, validation_dataset = construct_fedisic_train_val_datasets( self.client_number, str(self.data_path) ) diff --git a/research/flamby/fed_isic2019/perfcl/perfcl_model.py b/research/flamby/fed_isic2019/perfcl/perfcl_model.py index f7af692cf..5f9a0f589 100644 --- a/research/flamby/fed_isic2019/perfcl/perfcl_model.py +++ b/research/flamby/fed_isic2019/perfcl/perfcl_model.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch import torch.nn as nn import torch.nn.functional as F @@ -58,7 +56,7 @@ class LocalEfficientNet(nn.Module): other approaches. """ - def __init__(self, frozen_blocks: Optional[int] = 13, turn_off_bn_tracking: bool = False): + def __init__(self, frozen_blocks: int | None = 13, turn_off_bn_tracking: bool = False): super().__init__() # include_top ensures that we just use feature extraction in the forward pass self.base_model = from_pretrained("efficientnet-b0", include_top=False) @@ -69,7 +67,7 @@ def __init__(self, frozen_blocks: Optional[int] = 13, turn_off_bn_tracking: bool def freeze_layers(self, frozen_blocks: int) -> None: # We freeze the bottom layers of the network. We always freeze the _conv_stem module, the _bn0 module and then - # we iterate throught the blocks freezing the specified number up to 15 (all of them) + # we iterate through the blocks freezing the specified number up to 15 (all of them) # Freeze the first two layers self.base_model._modules["_conv_stem"].requires_grad_(False) @@ -97,7 +95,7 @@ class GlobalEfficientNet(nn.Module): other approaches. """ - def __init__(self, frozen_blocks: Optional[int] = 13, turn_off_bn_tracking: bool = False): + def __init__(self, frozen_blocks: int | None = 13, turn_off_bn_tracking: bool = False): super().__init__() # include_top ensures that we just use feature extraction in the forward pass self.base_model = from_pretrained("efficientnet-b0", include_top=False) @@ -108,7 +106,7 @@ def __init__(self, frozen_blocks: Optional[int] = 13, turn_off_bn_tracking: bool def freeze_layers(self, frozen_blocks: int) -> None: # We freeze the bottom layers of the network. We always freeze the _conv_stem module, the _bn0 module and then - # we iterate throught the blocks freezing the specified number up to 15 (all of them) + # we iterate through the blocks freezing the specified number up to 15 (all of them) # Freeze the first two layers self.base_model._modules["_conv_stem"].requires_grad_(False) @@ -124,7 +122,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class FedIsic2019PerFclModel(PerFclModel): - def __init__(self, frozen_blocks: Optional[int] = 13, turn_off_bn_tracking: bool = False) -> None: + def __init__(self, frozen_blocks: int | None = 13, turn_off_bn_tracking: bool = False) -> None: local_module = LocalEfficientNet(frozen_blocks, turn_off_bn_tracking=turn_off_bn_tracking) global_module = GlobalEfficientNet(frozen_blocks, turn_off_bn_tracking=turn_off_bn_tracking) model_head = PerFclClassifier(ParallelFeatureJoinMode.CONCATENATE, 1280) diff --git a/research/flamby/fed_isic2019/perfcl/run_fold_experiment.slrm b/research/flamby/fed_isic2019/perfcl/run_fold_experiment.slrm index e27fb8e17..9152f16e7 100644 --- a/research/flamby/fed_isic2019/perfcl/run_fold_experiment.slrm +++ b/research/flamby/fed_isic2019/perfcl/run_fold_experiment.slrm @@ -95,7 +95,7 @@ do SEED="${SEEDS[i]}" # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_isic2019/perfcl/server.py b/research/flamby/fed_isic2019/perfcl/server.py index a9510688e..5b78a6be9 100644 --- a/research/flamby/fed_isic2019/perfcl/server.py +++ b/research/flamby/fed_isic2019/perfcl/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.logger import log @@ -17,7 +17,7 @@ from research.flamby.utils import fit_config, summarize_model_info -def main(config: Dict[str, Any], server_address: str) -> None: +def main(config: dict[str, Any], server_address: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/flamby/fed_isic2019/scaffold/client.py b/research/flamby/fed_isic2019/scaffold/client.py index 8b0d4e51d..3aa9030d3 100644 --- a/research/flamby/fed_isic2019/scaffold/client.py +++ b/research/flamby/fed_isic2019/scaffold/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Optional, Sequence, Tuple import flwr as fl import torch @@ -32,10 +32,10 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: super().__init__( data_path=data_path, @@ -53,7 +53,7 @@ def __init__( assert 0 <= client_number < NUM_CLIENTS log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: train_dataset, validation_dataset = construct_fedisic_train_val_datasets( self.client_number, str(self.data_path) ) diff --git a/research/flamby/fed_isic2019/scaffold/run_fold_experiment.slrm b/research/flamby/fed_isic2019/scaffold/run_fold_experiment.slrm index e698458cd..e3569773d 100644 --- a/research/flamby/fed_isic2019/scaffold/run_fold_experiment.slrm +++ b/research/flamby/fed_isic2019/scaffold/run_fold_experiment.slrm @@ -88,7 +88,7 @@ for RUN_NAME in "${RUN_NAMES[@]}"; do # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_isic2019/scaffold/server.py b/research/flamby/fed_isic2019/scaffold/server.py index 705ccff9b..bd397de62 100644 --- a/research/flamby/fed_isic2019/scaffold/server.py +++ b/research/flamby/fed_isic2019/scaffold/server.py @@ -2,7 +2,7 @@ import os from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flamby.datasets.fed_isic2019 import Baseline @@ -19,7 +19,7 @@ def main( - config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str, server_learning_rate: float + config: dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str, server_learning_rate: float ) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( diff --git a/research/flamby/fed_ixi/apfl/client.py b/research/flamby/fed_ixi/apfl/client.py index 157b3714f..e33df89db 100644 --- a/research/flamby/fed_ixi/apfl/client.py +++ b/research/flamby/fed_ixi/apfl/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple import flwr as fl import torch @@ -35,10 +35,10 @@ def __init__( learning_rate: float, alpha_learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: super().__init__( data_path=data_path, @@ -56,7 +56,7 @@ def __init__( self.alpha_learning_rate = alpha_learning_rate self.client_number = client_number - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: train_dataset, validation_dataset = construct_fed_ixi_train_val_datasets( self.client_number, str(self.data_path) ) @@ -68,7 +68,7 @@ def get_model(self, config: Config) -> nn.Module: model: ApflModule = ApflModule(ApflUNet(), alpha_lr=self.alpha_learning_rate).to(self.device) return model - def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: + def get_optimizer(self, config: Config) -> dict[str, Optimizer]: local_optimizer = torch.optim.AdamW(self.model.local_model.parameters(), lr=self.learning_rate) global_optimizer = torch.optim.AdamW(self.model.global_model.parameters(), lr=self.learning_rate) return {"local": local_optimizer, "global": global_optimizer} diff --git a/research/flamby/fed_ixi/apfl/run_fold_experiment.slrm b/research/flamby/fed_ixi/apfl/run_fold_experiment.slrm index 588e441ba..065f1b9a0 100644 --- a/research/flamby/fed_ixi/apfl/run_fold_experiment.slrm +++ b/research/flamby/fed_ixi/apfl/run_fold_experiment.slrm @@ -88,7 +88,7 @@ for RUN_NAME in "${RUN_NAMES[@]}"; do # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_ixi/apfl/server.py b/research/flamby/fed_ixi/apfl/server.py index 0969db0bb..098f8b65a 100644 --- a/research/flamby/fed_ixi/apfl/server.py +++ b/research/flamby/fed_ixi/apfl/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.logger import log @@ -17,7 +17,7 @@ from research.flamby.utils import fit_config, summarize_model_info -def main(config: Dict[str, Any], server_address: str) -> None: +def main(config: dict[str, Any], server_address: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/flamby/fed_ixi/central/run_fold_experiment.slrm b/research/flamby/fed_ixi/central/run_fold_experiment.slrm index 5626afde0..c7d1a4709 100644 --- a/research/flamby/fed_ixi/central/run_fold_experiment.slrm +++ b/research/flamby/fed_ixi/central/run_fold_experiment.slrm @@ -76,7 +76,7 @@ for RUN_NAME in "${RUN_NAMES[@]}"; do # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_ixi/ditto/client.py b/research/flamby/fed_ixi/ditto/client.py index 1c4686367..0f244b703 100644 --- a/research/flamby/fed_ixi/ditto/client.py +++ b/research/flamby/fed_ixi/ditto/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple import flwr as fl import torch @@ -33,10 +33,10 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: super().__init__( data_path=data_path, @@ -54,7 +54,7 @@ def __init__( assert 0 <= client_number < NUM_CLIENTS log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: train_dataset, validation_dataset = construct_fed_ixi_train_val_datasets( self.client_number, str(self.data_path) ) @@ -69,7 +69,7 @@ def get_model(self, config: Config) -> nn.Module: model: nn.Module = Baseline(out_channels_first_layer=12).to(self.device) return model - def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: + def get_optimizer(self, config: Config) -> dict[str, Optimizer]: # Note that the global optimizer operates on self.global_model.parameters() and local optimizer operates on # self.model.parameters(). global_optimizer = torch.optim.AdamW(self.global_model.parameters(), lr=self.learning_rate) diff --git a/research/flamby/fed_ixi/ditto/run_fold_experiment.slrm b/research/flamby/fed_ixi/ditto/run_fold_experiment.slrm index d8155aca6..0c5851a0c 100644 --- a/research/flamby/fed_ixi/ditto/run_fold_experiment.slrm +++ b/research/flamby/fed_ixi/ditto/run_fold_experiment.slrm @@ -92,7 +92,7 @@ do SEED="${SEEDS[i]}" # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_ixi/ditto/server.py b/research/flamby/fed_ixi/ditto/server.py index 65a8f11b0..0f56ac05e 100644 --- a/research/flamby/fed_ixi/ditto/server.py +++ b/research/flamby/fed_ixi/ditto/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flamby.datasets.fed_ixi import Baseline @@ -17,7 +17,7 @@ from research.flamby.utils import fit_config, summarize_model_info -def main(config: Dict[str, Any], server_address: str, lam: float) -> None: +def main(config: dict[str, Any], server_address: str, lam: float) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/flamby/fed_ixi/evaluate_on_holdout.py b/research/flamby/fed_ixi/evaluate_on_holdout.py index fae17d9f2..21f2745f7 100644 --- a/research/flamby/fed_ixi/evaluate_on_holdout.py +++ b/research/flamby/fed_ixi/evaluate_on_holdout.py @@ -1,6 +1,5 @@ import argparse from logging import INFO -from typing import Dict import torch from flamby.datasets.fed_ixi import BATCH_SIZE, NUM_CLIENTS, FedIXITiny @@ -28,7 +27,7 @@ def main( ) -> None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") all_run_folder_dir = get_all_run_folders(artifact_dir) - test_results: Dict[str, float] = {} + test_results: dict[str, float] = {} metrics = [BinarySoftDiceCoefficient("FedIXI_dice")] all_local_test_metrics = {run_folder_dir: 0.0 for run_folder_dir in all_run_folder_dir} diff --git a/research/flamby/fed_ixi/fedadam/client.py b/research/flamby/fed_ixi/fedadam/client.py index 130b7f269..dfadab838 100644 --- a/research/flamby/fed_ixi/fedadam/client.py +++ b/research/flamby/fed_ixi/fedadam/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Optional, Sequence, Tuple import flwr as fl import torch @@ -33,10 +33,10 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: super().__init__( data_path=data_path, @@ -54,7 +54,7 @@ def __init__( assert 0 <= client_number < NUM_CLIENTS log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: train_dataset, validation_dataset = construct_fed_ixi_train_val_datasets( self.client_number, str(self.data_path) ) diff --git a/research/flamby/fed_ixi/fedadam/run_fold_experiment.slrm b/research/flamby/fed_ixi/fedadam/run_fold_experiment.slrm index 98b39be90..9bb961f16 100644 --- a/research/flamby/fed_ixi/fedadam/run_fold_experiment.slrm +++ b/research/flamby/fed_ixi/fedadam/run_fold_experiment.slrm @@ -88,7 +88,7 @@ for RUN_NAME in "${RUN_NAMES[@]}"; do # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_ixi/fedadam/server.py b/research/flamby/fed_ixi/fedadam/server.py index 8f7cf5105..4a29e5322 100644 --- a/research/flamby/fed_ixi/fedadam/server.py +++ b/research/flamby/fed_ixi/fedadam/server.py @@ -2,7 +2,7 @@ import os from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.logger import log @@ -21,7 +21,7 @@ def main( - config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str, server_learning_rate: float + config: dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str, server_learning_rate: float ) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( diff --git a/research/flamby/fed_ixi/fedavg/client.py b/research/flamby/fed_ixi/fedavg/client.py index c59f3bc7b..93e9cf275 100644 --- a/research/flamby/fed_ixi/fedavg/client.py +++ b/research/flamby/fed_ixi/fedavg/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Optional, Sequence, Tuple import flwr as fl import torch @@ -32,10 +32,10 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: super().__init__( data_path=data_path, @@ -53,7 +53,7 @@ def __init__( assert 0 <= client_number < NUM_CLIENTS log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: train_dataset, validation_dataset = construct_fed_ixi_train_val_datasets( self.client_number, str(self.data_path) ) diff --git a/research/flamby/fed_ixi/fedavg/run_fold_experiment.slrm b/research/flamby/fed_ixi/fedavg/run_fold_experiment.slrm index 5f8861bda..a7b7b9b61 100644 --- a/research/flamby/fed_ixi/fedavg/run_fold_experiment.slrm +++ b/research/flamby/fed_ixi/fedavg/run_fold_experiment.slrm @@ -85,7 +85,7 @@ for RUN_NAME in "${RUN_NAMES[@]}"; do # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_ixi/fedavg/server.py b/research/flamby/fed_ixi/fedavg/server.py index e8b6126ba..0bae7b1b4 100644 --- a/research/flamby/fed_ixi/fedavg/server.py +++ b/research/flamby/fed_ixi/fedavg/server.py @@ -2,7 +2,7 @@ import os from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flamby.datasets.fed_ixi import Baseline @@ -20,7 +20,7 @@ from research.flamby.utils import fit_config, summarize_model_info -def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str) -> None: +def main(config: dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/flamby/fed_ixi/fedper/client.py b/research/flamby/fed_ixi/fedper/client.py index b7b5fb035..b410f662a 100644 --- a/research/flamby/fed_ixi/fedper/client.py +++ b/research/flamby/fed_ixi/fedper/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Optional, Sequence, Tuple import flwr as fl import torch @@ -36,10 +36,10 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: super().__init__( data_path=data_path, @@ -58,7 +58,7 @@ def __init__( assert 0 <= client_number < NUM_CLIENTS log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: train_dataset, validation_dataset = construct_fed_ixi_train_val_datasets( self.client_number, str(self.data_path) ) diff --git a/research/flamby/fed_ixi/fedper/fedper_feature_extractor.py b/research/flamby/fed_ixi/fedper/fedper_feature_extractor.py index 819a0bff2..a4acf4d92 100644 --- a/research/flamby/fed_ixi/fedper/fedper_feature_extractor.py +++ b/research/flamby/fed_ixi/fedper/fedper_feature_extractor.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch import torch.nn as nn from flamby.datasets.fed_ixi.model import Decoder, Encoder, EncodingBlock @@ -21,15 +19,15 @@ def __init__( dimensions: int = 3, num_encoding_blocks: int = 3, out_channels_first_layer: int = 8, - normalization: Optional[str] = "batch", + normalization: str | None = "batch", pooling_type: str = "max", upsampling_type: str = "linear", preactivation: bool = False, residual: bool = False, padding: int = 1, padding_mode: str = "zeros", - activation: Optional[str] = "PReLU", - initial_dilation: Optional[int] = None, + activation: str | None = "PReLU", + initial_dilation: int | None = None, dropout: float = 0, ): super().__init__() diff --git a/research/flamby/fed_ixi/fedper/run_fold_experiment.slrm b/research/flamby/fed_ixi/fedper/run_fold_experiment.slrm index 13ede0b66..05e2a0a21 100644 --- a/research/flamby/fed_ixi/fedper/run_fold_experiment.slrm +++ b/research/flamby/fed_ixi/fedper/run_fold_experiment.slrm @@ -89,7 +89,7 @@ do SEED="${SEEDS[i]}" # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_ixi/fedper/server.py b/research/flamby/fed_ixi/fedper/server.py index d22204cb9..5a2ef74f1 100644 --- a/research/flamby/fed_ixi/fedper/server.py +++ b/research/flamby/fed_ixi/fedper/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.logger import log @@ -17,7 +17,7 @@ from research.flamby.utils import fit_config, summarize_model_info -def main(config: Dict[str, Any], server_address: str) -> None: +def main(config: dict[str, Any], server_address: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/flamby/fed_ixi/fedprox/client.py b/research/flamby/fed_ixi/fedprox/client.py index bc5fbe1d0..817e9ebd1 100644 --- a/research/flamby/fed_ixi/fedprox/client.py +++ b/research/flamby/fed_ixi/fedprox/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Optional, Sequence, Tuple import flwr as fl import torch @@ -32,10 +32,10 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: super().__init__( data_path=data_path, @@ -53,7 +53,7 @@ def __init__( assert 0 <= client_number < NUM_CLIENTS log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: train_dataset, validation_dataset = construct_fed_ixi_train_val_datasets( self.client_number, str(self.data_path) ) diff --git a/research/flamby/fed_ixi/fedprox/run_fold_experiment.slrm b/research/flamby/fed_ixi/fedprox/run_fold_experiment.slrm index 119fb06e7..ab3f5ab20 100644 --- a/research/flamby/fed_ixi/fedprox/run_fold_experiment.slrm +++ b/research/flamby/fed_ixi/fedprox/run_fold_experiment.slrm @@ -88,7 +88,7 @@ for RUN_NAME in "${RUN_NAMES[@]}"; do # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_ixi/fedprox/server.py b/research/flamby/fed_ixi/fedprox/server.py index 17c988d0e..3299b8d17 100644 --- a/research/flamby/fed_ixi/fedprox/server.py +++ b/research/flamby/fed_ixi/fedprox/server.py @@ -2,7 +2,7 @@ import os from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flamby.datasets.fed_ixi import Baseline @@ -19,7 +19,7 @@ from research.flamby.utils import fit_config, summarize_model_info -def main(config: Dict[str, Any], server_address: str, mu: float, checkpoint_stub: str, run_name: str) -> None: +def main(config: dict[str, Any], server_address: str, mu: float, checkpoint_stub: str, run_name: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/flamby/fed_ixi/fenda/client.py b/research/flamby/fed_ixi/fenda/client.py index 02897334c..6640774db 100644 --- a/research/flamby/fed_ixi/fenda/client.py +++ b/research/flamby/fed_ixi/fenda/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Optional, Sequence, Tuple import flwr as fl import torch @@ -34,10 +34,10 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: super().__init__( data_path=data_path, @@ -55,7 +55,7 @@ def __init__( assert 0 <= client_number < NUM_CLIENTS log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: train_dataset, validation_dataset = construct_fed_ixi_train_val_datasets( self.client_number, str(self.data_path) ) diff --git a/research/flamby/fed_ixi/fenda/fenda_feature_extractor.py b/research/flamby/fed_ixi/fenda/fenda_feature_extractor.py index 70bcf21cb..bfa0b56c6 100644 --- a/research/flamby/fed_ixi/fenda/fenda_feature_extractor.py +++ b/research/flamby/fed_ixi/fenda/fenda_feature_extractor.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch import torch.nn as nn from flamby.datasets.fed_ixi.model import Decoder, Encoder, EncodingBlock @@ -21,15 +19,15 @@ def __init__( dimensions: int = 3, num_encoding_blocks: int = 3, out_channels_first_layer: int = 8, - normalization: Optional[str] = "batch", + normalization: str | None = "batch", pooling_type: str = "max", upsampling_type: str = "linear", preactivation: bool = False, residual: bool = False, padding: int = 1, padding_mode: str = "zeros", - activation: Optional[str] = "PReLU", - initial_dilation: Optional[int] = None, + activation: str | None = "PReLU", + initial_dilation: int | None = None, dropout: float = 0, ): super().__init__() diff --git a/research/flamby/fed_ixi/fenda/run_fold_experiment.slrm b/research/flamby/fed_ixi/fenda/run_fold_experiment.slrm index 1997fbeda..a88ac6d21 100644 --- a/research/flamby/fed_ixi/fenda/run_fold_experiment.slrm +++ b/research/flamby/fed_ixi/fenda/run_fold_experiment.slrm @@ -89,7 +89,7 @@ do SEED="${SEEDS[i]}" # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_ixi/fenda/server.py b/research/flamby/fed_ixi/fenda/server.py index 1db939191..9510d1ba5 100644 --- a/research/flamby/fed_ixi/fenda/server.py +++ b/research/flamby/fed_ixi/fenda/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.logger import log @@ -17,7 +17,7 @@ from research.flamby.utils import fit_config, summarize_model_info -def main(config: Dict[str, Any], server_address: str) -> None: +def main(config: dict[str, Any], server_address: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/flamby/fed_ixi/local/run_fold_experiment.slrm b/research/flamby/fed_ixi/local/run_fold_experiment.slrm index be0a0ea8a..5392d1a77 100644 --- a/research/flamby/fed_ixi/local/run_fold_experiment.slrm +++ b/research/flamby/fed_ixi/local/run_fold_experiment.slrm @@ -79,7 +79,7 @@ for RUN_NAME in "${RUN_NAMES[@]}"; do # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_ixi/moon/client.py b/research/flamby/fed_ixi/moon/client.py index e4bfc68df..0a325d243 100644 --- a/research/flamby/fed_ixi/moon/client.py +++ b/research/flamby/fed_ixi/moon/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Optional, Sequence, Tuple import flwr as fl import torch @@ -34,10 +34,10 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, contrastive_weight: float = 10, ) -> None: super().__init__( @@ -57,7 +57,7 @@ def __init__( assert 0 <= client_number < NUM_CLIENTS log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: train_dataset, validation_dataset = construct_fed_ixi_train_val_datasets( self.client_number, str(self.data_path) ) diff --git a/research/flamby/fed_ixi/moon/moon_feature_extractor.py b/research/flamby/fed_ixi/moon/moon_feature_extractor.py index e86bc6607..70b361a94 100644 --- a/research/flamby/fed_ixi/moon/moon_feature_extractor.py +++ b/research/flamby/fed_ixi/moon/moon_feature_extractor.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch import torch.nn as nn from flamby.datasets.fed_ixi.model import Decoder, Encoder, EncodingBlock @@ -21,15 +19,15 @@ def __init__( dimensions: int = 3, num_encoding_blocks: int = 3, out_channels_first_layer: int = 8, - normalization: Optional[str] = "batch", + normalization: str | None = "batch", pooling_type: str = "max", upsampling_type: str = "linear", preactivation: bool = False, residual: bool = False, padding: int = 1, padding_mode: str = "zeros", - activation: Optional[str] = "PReLU", - initial_dilation: Optional[int] = None, + activation: str | None = "PReLU", + initial_dilation: int | None = None, dropout: float = 0, ): super().__init__() diff --git a/research/flamby/fed_ixi/moon/run_fold_experiment.slrm b/research/flamby/fed_ixi/moon/run_fold_experiment.slrm index 1466eb86b..ffa85daa5 100644 --- a/research/flamby/fed_ixi/moon/run_fold_experiment.slrm +++ b/research/flamby/fed_ixi/moon/run_fold_experiment.slrm @@ -92,7 +92,7 @@ do SEED="${SEEDS[i]}" # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_ixi/moon/server.py b/research/flamby/fed_ixi/moon/server.py index e2062b55b..56a8a24f7 100644 --- a/research/flamby/fed_ixi/moon/server.py +++ b/research/flamby/fed_ixi/moon/server.py @@ -2,7 +2,7 @@ import os from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.logger import log @@ -21,7 +21,7 @@ from research.flamby.utils import fit_config, summarize_model_info -def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str) -> None: +def main(config: dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/flamby/fed_ixi/perfcl/client.py b/research/flamby/fed_ixi/perfcl/client.py index 7e66bf3e3..0a52c087e 100644 --- a/research/flamby/fed_ixi/perfcl/client.py +++ b/research/flamby/fed_ixi/perfcl/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Optional, Sequence, Tuple import flwr as fl import torch @@ -34,10 +34,10 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, mu: float = 10.0, gamma: float = 10.0, ) -> None: @@ -59,7 +59,7 @@ def __init__( assert 0 <= client_number < NUM_CLIENTS log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: train_dataset, validation_dataset = construct_fed_ixi_train_val_datasets( self.client_number, str(self.data_path) ) diff --git a/research/flamby/fed_ixi/perfcl/perfcl_feature_extractor.py b/research/flamby/fed_ixi/perfcl/perfcl_feature_extractor.py index 9c41bf172..32768f6dd 100644 --- a/research/flamby/fed_ixi/perfcl/perfcl_feature_extractor.py +++ b/research/flamby/fed_ixi/perfcl/perfcl_feature_extractor.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch import torch.nn as nn from flamby.datasets.fed_ixi.model import Decoder, Encoder, EncodingBlock @@ -21,15 +19,15 @@ def __init__( dimensions: int = 3, num_encoding_blocks: int = 3, out_channels_first_layer: int = 8, - normalization: Optional[str] = "batch", + normalization: str | None = "batch", pooling_type: str = "max", upsampling_type: str = "linear", preactivation: bool = False, residual: bool = False, padding: int = 1, padding_mode: str = "zeros", - activation: Optional[str] = "PReLU", - initial_dilation: Optional[int] = None, + activation: str | None = "PReLU", + initial_dilation: int | None = None, dropout: float = 0, ): super().__init__() diff --git a/research/flamby/fed_ixi/perfcl/run_fold_experiment.slrm b/research/flamby/fed_ixi/perfcl/run_fold_experiment.slrm index df7a6396f..c41fdb48b 100644 --- a/research/flamby/fed_ixi/perfcl/run_fold_experiment.slrm +++ b/research/flamby/fed_ixi/perfcl/run_fold_experiment.slrm @@ -95,7 +95,7 @@ do SEED="${SEEDS[i]}" # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_ixi/perfcl/server.py b/research/flamby/fed_ixi/perfcl/server.py index 7226a14f2..6f3f09ecb 100644 --- a/research/flamby/fed_ixi/perfcl/server.py +++ b/research/flamby/fed_ixi/perfcl/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.logger import log @@ -17,7 +17,7 @@ from research.flamby.utils import fit_config, summarize_model_info -def main(config: Dict[str, Any], server_address: str) -> None: +def main(config: dict[str, Any], server_address: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/flamby/fed_ixi/scaffold/client.py b/research/flamby/fed_ixi/scaffold/client.py index f1e26d877..a8838d5b5 100644 --- a/research/flamby/fed_ixi/scaffold/client.py +++ b/research/flamby/fed_ixi/scaffold/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Optional, Sequence, Tuple import flwr as fl import torch @@ -32,10 +32,10 @@ def __init__( client_number: int, learning_rate: float, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, ) -> None: super().__init__( data_path=data_path, @@ -53,7 +53,7 @@ def __init__( assert 0 <= client_number < NUM_CLIENTS log(INFO, f"Client Name: {self.client_name}, Client Number: {self.client_number}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: train_dataset, validation_dataset = construct_fed_ixi_train_val_datasets( self.client_number, str(self.data_path) ) diff --git a/research/flamby/fed_ixi/scaffold/run_fold_experiment.slrm b/research/flamby/fed_ixi/scaffold/run_fold_experiment.slrm index 12ce4e40e..53e96182e 100644 --- a/research/flamby/fed_ixi/scaffold/run_fold_experiment.slrm +++ b/research/flamby/fed_ixi/scaffold/run_fold_experiment.slrm @@ -88,7 +88,7 @@ for RUN_NAME in "${RUN_NAMES[@]}"; do # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/flamby/fed_ixi/scaffold/server.py b/research/flamby/fed_ixi/scaffold/server.py index ea4adf6d7..41026315e 100644 --- a/research/flamby/fed_ixi/scaffold/server.py +++ b/research/flamby/fed_ixi/scaffold/server.py @@ -2,7 +2,7 @@ import os from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl from flamby.datasets.fed_ixi import Baseline @@ -19,7 +19,7 @@ def main( - config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str, server_learning_rate: float + config: dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str, server_learning_rate: float ) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( diff --git a/research/flamby/find_best_hp.py b/research/flamby/find_best_hp.py index 64e6a1d3d..e5177afd6 100644 --- a/research/flamby/find_best_hp.py +++ b/research/flamby/find_best_hp.py @@ -1,18 +1,17 @@ import argparse import os from logging import INFO -from typing import List, Optional import numpy as np from flwr.common.logger import log -def get_hp_folders(hp_sweep_dir: str) -> List[str]: +def get_hp_folders(hp_sweep_dir: str) -> list[str]: paths_in_hp_sweep_dir = [os.path.join(hp_sweep_dir, contents) for contents in os.listdir(hp_sweep_dir)] return [hp_folder for hp_folder in paths_in_hp_sweep_dir if os.path.isdir(hp_folder)] -def get_run_folders(hp_dir: str) -> List[str]: +def get_run_folders(hp_dir: str) -> list[str]: run_folder_names = [folder_name for folder_name in os.listdir(hp_dir) if "Run" in folder_name] return [os.path.join(hp_dir, run_folder_name) for run_folder_name in run_folder_names] @@ -39,7 +38,7 @@ def get_weighted_loss_from_server_log( def main(hp_sweep_dir: str, experiment_name: str, is_partial_efficient_net: bool) -> None: hp_folders = get_hp_folders(hp_sweep_dir) - best_avg_loss: Optional[float] = None + best_avg_loss: float | None = None best_folder = "" for hp_folder in hp_folders: run_folders = get_run_folders(hp_folder) diff --git a/research/flamby/flamby_data_utils.py b/research/flamby/flamby_data_utils.py index 8fbd861f0..f241fc799 100644 --- a/research/flamby/flamby_data_utils.py +++ b/research/flamby/flamby_data_utils.py @@ -1,5 +1,3 @@ -from typing import Tuple - from flamby.datasets.fed_heart_disease import FedHeartDisease from flamby.datasets.fed_isic2019 import FedIsic2019 from flamby.datasets.fed_ixi import FedIXITiny @@ -8,7 +6,7 @@ def construct_fedisic_train_val_datasets( client_number: int, dataset_dir: str, pooled: bool = False -) -> Tuple[FedIsic2019, FedIsic2019]: +) -> tuple[FedIsic2019, FedIsic2019]: # If pooled is True then client number is ignored full_train_dataset = FedIsic2019(center=client_number, train=True, pooled=pooled, data_path=dataset_dir) # Something weird is happening with the typing of the split sequence in random split. Punting with a mypy @@ -21,7 +19,7 @@ def construct_fed_heard_disease_train_val_datasets( client_number: int, dataset_dir: str, pooled: bool = False, -) -> Tuple[FedHeartDisease, FedHeartDisease]: +) -> tuple[FedHeartDisease, FedHeartDisease]: # If pooled is True then client number is ignored full_train_dataset = FedHeartDisease(center=client_number, train=True, pooled=pooled, data_path=dataset_dir) # Something weird is happening with the typing of the split sequence in random split. Punting with a mypy @@ -32,7 +30,7 @@ def construct_fed_heard_disease_train_val_datasets( def construct_fed_ixi_train_val_datasets( client_number: int, dataset_dir: str, pooled: bool = False -) -> Tuple[FedIXITiny, FedIXITiny]: +) -> tuple[FedIXITiny, FedIXITiny]: # If pooled is True then client number is ignored full_train_dataset = FedIXITiny( center=client_number, diff --git a/research/flamby/flamby_servers/full_exchange_server.py b/research/flamby/flamby_servers/full_exchange_server.py index 7e5097498..c04058fc7 100644 --- a/research/flamby/flamby_servers/full_exchange_server.py +++ b/research/flamby/flamby_servers/full_exchange_server.py @@ -1,5 +1,3 @@ -from typing import Optional - from flwr.common.typing import Config from flwr.server.client_manager import ClientManager from flwr.server.strategy import Strategy @@ -14,7 +12,7 @@ def __init__( self, client_manager: ClientManager, fl_config: Config, - strategy: Optional[Strategy] = None, + strategy: Strategy | None = None, checkpoint_and_state_module: BaseServerCheckpointAndStateModule | None = None, ) -> None: super().__init__( diff --git a/research/flamby/flamby_servers/personal_server.py b/research/flamby/flamby_servers/personal_server.py index 77f561e00..051397cc9 100644 --- a/research/flamby/flamby_servers/personal_server.py +++ b/research/flamby/flamby_servers/personal_server.py @@ -1,5 +1,4 @@ from logging import INFO -from typing import Dict, Optional, Tuple from flwr.common.logger import log from flwr.common.typing import Config, Scalar @@ -25,20 +24,20 @@ def __init__( self, client_manager: ClientManager, fl_config: Config, - strategy: Optional[Strategy] = None, + strategy: Strategy | None = None, ) -> None: # Personal approaches don't train a "server" model. Rather, each client trains a client specific model with # some globally shared weights. So we don't checkpoint a global model super().__init__( client_manager=client_manager, fl_config=fl_config, strategy=strategy, checkpoint_and_state_module=None ) - self.best_aggregated_loss: Optional[float] = None + self.best_aggregated_loss: float | None = None def evaluate_round( self, server_round: int, - timeout: Optional[float], - ) -> Optional[Tuple[Optional[float], Dict[str, Scalar], EvaluateResultsAndFailures]]: + timeout: float | None, + ) -> tuple[float | None, dict[str, Scalar], EvaluateResultsAndFailures] | None: # loss_aggregated is the aggregated validation per step loss # aggregated over each client (weighted by num examples) eval_round_results = super().evaluate_round(server_round, timeout) diff --git a/research/flamby/single_node_trainer.py b/research/flamby/single_node_trainer.py index e49e9108b..afee32197 100644 --- a/research/flamby/single_node_trainer.py +++ b/research/flamby/single_node_trainer.py @@ -1,6 +1,5 @@ import os from logging import INFO -from typing import Dict, Tuple import torch import torch.nn as nn @@ -33,14 +32,14 @@ def __init__( self.train_loader: DataLoader self.val_loader: DataLoader - def _maybe_checkpoint(self, loss: float, metrics: Dict[str, Scalar]) -> None: + def _maybe_checkpoint(self, loss: float, metrics: dict[str, Scalar]) -> None: if self.checkpointer: self.checkpointer.maybe_checkpoint(self.model, loss, metrics) def _handle_reporting( self, loss: float, - metrics_dict: Dict[str, Scalar], + metrics_dict: dict[str, Scalar], is_validation: bool = False, ) -> None: metric_string = "\t".join([f"{key}: {str(val)}" for key, val in metrics_dict.items()]) @@ -50,7 +49,7 @@ def _handle_reporting( f"Centralized {metric_prefix} Loss: {loss} \n" f"Centralized {metric_prefix} Metrics: {metric_string}", ) - def train_step(self, input: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + def train_step(self, input: torch.Tensor, target: torch.Tensor) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: # forward pass on the model preds = self.model(input) loss = self.criterion(preds, target) diff --git a/research/flamby/utils.py b/research/flamby/utils.py index 42febbef3..cb5e67fc6 100644 --- a/research/flamby/utils.py +++ b/research/flamby/utils.py @@ -1,7 +1,7 @@ import os import warnings +from collections.abc import Sequence from logging import INFO -from typing import Dict, List, Sequence, Tuple import numpy as np import torch @@ -29,7 +29,7 @@ def fit_config( } -def get_initial_model_info_with_control_variates(client_model: nn.Module) -> Tuple[Parameters, Parameters]: +def get_initial_model_info_with_control_variates(client_model: nn.Module) -> tuple[Parameters, Parameters]: # Initializing the model parameters on the server side. model_weights = [val.cpu().numpy() for _, val in client_model.state_dict().items()] # Initializing the control variates to zero, as suggested in the original scaffold paper @@ -37,12 +37,12 @@ def get_initial_model_info_with_control_variates(client_model: nn.Module) -> Tup return ndarrays_to_parameters(model_weights), ndarrays_to_parameters(control_variates) -def get_all_run_folders(artifact_dir: str) -> List[str]: +def get_all_run_folders(artifact_dir: str) -> list[str]: run_folder_names = [folder_name for folder_name in os.listdir(artifact_dir) if "Run" in folder_name] return [os.path.join(artifact_dir, run_folder_name) for run_folder_name in run_folder_names] -def write_measurement_results(eval_write_path: str, results: Dict[str, float]) -> None: +def write_measurement_results(eval_write_path: str, results: dict[str, float]) -> None: with open(eval_write_path, "w") as f: for key, metric_value in results.items(): f.write(f"{key}: {metric_value}\n") @@ -60,7 +60,7 @@ def load_global_model(run_folder_dir: str) -> nn.Module: return model -def get_metric_avg_std(metrics: List[float]) -> Tuple[float, float]: +def get_metric_avg_std(metrics: list[float]) -> tuple[float, float]: mean = float(np.mean(metrics)) std = float(np.std(metrics, ddof=1)) return mean, std diff --git a/research/flamby/visualization_scripts/average_performance.py b/research/flamby/visualization_scripts/average_performance.py index d5a2cdf32..de6f91035 100644 --- a/research/flamby/visualization_scripts/average_performance.py +++ b/research/flamby/visualization_scripts/average_performance.py @@ -1,12 +1,11 @@ import argparse import os -from typing import Dict, List from research.flamby.visualization_scripts.average_performance_configs import fed_isic_file_names_to_info -def process_results_dict(results_lines: List[str]) -> Dict[str, float]: - results_dict: Dict[str, float] = {} +def process_results_dict(results_lines: list[str]) -> dict[str, float]: + results_dict: dict[str, float] = {} for results_line in results_lines: split_line = results_line.split(":") results_dict[split_line[0]] = float(split_line[1]) @@ -14,7 +13,7 @@ def process_results_dict(results_lines: List[str]) -> Dict[str, float]: def process_results_to_matlab_string( - chart_method_names: List[str], chart_means: List[float], chart_std_devs: List[float] + chart_method_names: list[str], chart_means: list[float], chart_std_devs: list[float] ) -> None: out_string = "method = {'" chart_method_names_joined = "', '".join(chart_method_names) diff --git a/research/flamby/visualization_scripts/average_performance_configs.py b/research/flamby/visualization_scripts/average_performance_configs.py index 12d2ac11d..616971abc 100644 --- a/research/flamby/visualization_scripts/average_performance_configs.py +++ b/research/flamby/visualization_scripts/average_performance_configs.py @@ -1,8 +1,6 @@ -from typing import List, Tuple - # File name mapped to tuples of name appearing on the graph, keys for the mean, and keys for the std dev # NOTE: that only some methods with both server and local models have multiple mean and std dev keys -fed_isic_file_names_to_info: List[Tuple[str, str, Tuple[List[str], List[str]]]] = [ +fed_isic_file_names_to_info: list[tuple[str, str, tuple[list[str], list[str]]]] = [ ( "central_eval_performance.txt", "Central", @@ -111,7 +109,7 @@ # File name mapped to tuples of name appearing on the graph, keys for the mean, and keys for the std dev # NOTE: that only some methods with both server and local models have multiple mean and std dev keys -fed_heart_disease_file_names_to_info: List[Tuple[str, str, Tuple[List[str], List[str]]]] = [ +fed_heart_disease_file_names_to_info: list[tuple[str, str, tuple[list[str], list[str]]]] = [ ( "central_eval_performance_small_model.txt", "Central_S", @@ -276,7 +274,7 @@ # File name mapped to tuples of name appearing on the graph, keys for the mean, and keys for the std dev # NOTE: that only some methods with both server and local models have multiple mean and std dev keys -fed_ixi_file_names_to_info: List[Tuple[str, str, Tuple[List[str], List[str]]]] = [ +fed_ixi_file_names_to_info: list[tuple[str, str, tuple[list[str], list[str]]]] = [ ( "central_eval_performance.txt", "Central", diff --git a/research/flamby/visualization_scripts/model_generalization.py b/research/flamby/visualization_scripts/model_generalization.py index 9cc900640..2e44f38e2 100644 --- a/research/flamby/visualization_scripts/model_generalization.py +++ b/research/flamby/visualization_scripts/model_generalization.py @@ -1,12 +1,11 @@ import argparse import os -from typing import Dict, List from research.flamby.visualization_scripts.model_generalization_configs import fed_isic_file_names_to_info -def process_results_dict(results_lines: List[str]) -> Dict[str, float]: - results_dict: Dict[str, float] = {} +def process_results_dict(results_lines: list[str]) -> dict[str, float]: + results_dict: dict[str, float] = {} for results_line in results_lines: split_line = results_line.split(":") results_dict[split_line[0]] = float(split_line[1]) @@ -14,7 +13,7 @@ def process_results_dict(results_lines: List[str]) -> Dict[str, float]: def process_results_to_matlab_string( - chart_method_names: List[str], chart_means: List[List[float]], chart_variable_names: List[str] + chart_method_names: list[str], chart_means: list[list[float]], chart_variable_names: list[str] ) -> None: out_string = "" c_string = "C = [" diff --git a/research/flamby/visualization_scripts/model_generalization_configs.py b/research/flamby/visualization_scripts/model_generalization_configs.py index c7e2027ed..a01647386 100644 --- a/research/flamby/visualization_scripts/model_generalization_configs.py +++ b/research/flamby/visualization_scripts/model_generalization_configs.py @@ -1,7 +1,5 @@ -from typing import List, Tuple - # File name mapped to tuples of name appearing on the graph, variable name for array, keys for the mean -fed_isic_file_names_to_info: List[Tuple[str, str, str, List[str]]] = [ +fed_isic_file_names_to_info: list[tuple[str, str, str, list[str]]] = [ ( "client_0_eval_performance.txt", "Local 0", @@ -226,7 +224,7 @@ ] # File name mapped to tuples of name appearing on the graph, variable name for array -fed_heart_disease_file_names_to_info: List[Tuple[str, str, str, List[str]]] = [ +fed_heart_disease_file_names_to_info: list[tuple[str, str, str, list[str]]] = [ ( "client_0_eval_performance_small_model.txt", "Local 0_S", @@ -538,7 +536,7 @@ ] # File name mapped to tuples of name appearing on the graph, variable name for array, keys for the mean -fed_ixi_file_names_to_info: List[Tuple[str, str, str, List[str]]] = [ +fed_ixi_file_names_to_info: list[tuple[str, str, str, list[str]]] = [ ( "client_0_eval_performance.txt", "Local 0", diff --git a/research/gemini/apfl/client.py b/research/gemini/apfl/client.py index b46fe9fca..fb91f02c8 100644 --- a/research/gemini/apfl/client.py +++ b/research/gemini/apfl/client.py @@ -2,7 +2,6 @@ import os from logging import INFO from pathlib import Path -from typing import Dict, List, Tuple import flwr as fl import torch @@ -27,15 +26,15 @@ GlobalPreds = torch.Tensor PersonalPreds = torch.Tensor -ApflTrainStepOutputs = Tuple[LocalLoss, GlobalLoss, PersonalLoss, LocalPreds, GlobalPreds, PersonalPreds] +ApflTrainStepOutputs = tuple[LocalLoss, GlobalLoss, PersonalLoss, LocalPreds, GlobalPreds, PersonalPreds] class GeminiApflClient(ApflClient): def __init__( self, data_path: Path, - metrics: List[Metric], - hospitals_id: List[str], + metrics: list[Metric], + hospitals_id: list[str], device: torch.device, learning_task: str, learning_rate: float, @@ -85,7 +84,7 @@ def setup_client(self, config: Config) -> None: super().setup_client(config) - def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict[str, Scalar]]: + def fit(self, parameters: NDArrays, config: Config) -> tuple[NDArrays, int, dict[str, Scalar]]: if not self.initialized: self.setup_client(config) @@ -105,7 +104,7 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict metric_values, ) - def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, Dict[str, Scalar]]: + def evaluate(self, parameters: NDArrays, config: Config) -> tuple[float, int, dict[str, Scalar]]: if not self.initialized: self.setup_client(config) @@ -135,7 +134,7 @@ def train_step(self, input: torch.Tensor, target: torch.Tensor) -> ApflTrainStep global_loss.backward() self.global_optimizer.step() - # Make sure gradients are zero prior to foward passes of global and local model + # Make sure gradients are zero prior to forward passes of global and local model # to generate personalized predictions # NOTE: We zero the global optimizer grads because they are used (after the backward calculation below) # to update the scalar alpha (see update_alpha() where .grad is called.) @@ -159,7 +158,7 @@ def train_step(self, input: torch.Tensor, target: torch.Tensor) -> ApflTrainStep def train_by_epochs( self, epochs: int, global_meter: Meter, local_meter: Meter, personal_meter: Meter - ) -> Dict[str, Scalar]: + ) -> dict[str, Scalar]: self.model.train() for epoch in range(epochs): loss_dict = {"personal": 0.0, "local": 0.0, "global": 0.0} @@ -192,14 +191,14 @@ def train_by_epochs( global_metrics = global_meter.compute() local_metrics = local_meter.compute() personal_metrics = personal_meter.compute() - metrics: Dict[str, Scalar] = {**global_metrics, **local_metrics, **personal_metrics} + metrics: dict[str, Scalar] = {**global_metrics, **local_metrics, **personal_metrics} log(INFO, f"Performed {epochs} Epochs of Local training") return metrics def validate( self, global_meter: Meter, local_meter: Meter, personal_meter: Meter - ) -> Tuple[float, Dict[str, Scalar]]: + ) -> tuple[float, dict[str, Scalar]]: self.model.eval() loss_dict = {"global": 0.0, "personal": 0.0, "local": 0.0} global_meter.clear() @@ -230,7 +229,7 @@ def validate( global_metrics = global_meter.compute() local_metrics = local_meter.compute() personal_metrics = personal_meter.compute() - metrics: Dict[str, Scalar] = {**global_metrics, **local_metrics, **personal_metrics} + metrics: dict[str, Scalar] = {**global_metrics, **local_metrics, **personal_metrics} self._maybe_checkpoint(loss_dict["personal"]) return loss_dict["personal"], metrics diff --git a/research/gemini/apfl/run_fold_experiment.slrm b/research/gemini/apfl/run_fold_experiment.slrm index 0b8073c36..a9535d51f 100644 --- a/research/gemini/apfl/run_fold_experiment.slrm +++ b/research/gemini/apfl/run_fold_experiment.slrm @@ -38,7 +38,7 @@ else fi -# Create the artficat directory +# Create the artifact directory mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) @@ -61,7 +61,7 @@ for RUN_NAME in "${RUN_NAMES[@]}"; do # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/gemini/apfl/server.py b/research/gemini/apfl/server.py index b7023eefd..8a3a0b39e 100644 --- a/research/gemini/apfl/server.py +++ b/research/gemini/apfl/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict, List, Optional, Tuple +from typing import Any import flwr as fl import torch.nn as nn @@ -23,18 +23,18 @@ class GeminiAPFLServer(FlServer): def __init__( self, client_manager: ClientManager, - strategy: Optional[Strategy] = None, + strategy: Strategy | None = None, ) -> None: # APFL doesn't train a "server" model. Rather, each client trains a client specific model with some globally # shared weights. So we don't checkpoint a global model super().__init__(client_manager, strategy, checkpointer=None) - self.best_aggregated_loss: Optional[float] = None + self.best_aggregated_loss: float | None = None def evaluate_round( self, server_round: int, - timeout: Optional[float], - ) -> Optional[Tuple[Optional[float], Dict[str, Scalar], EvaluateResultsAndFailures]]: + timeout: float | None, + ) -> tuple[float | None, dict[str, Scalar], EvaluateResultsAndFailures] | None: # loss_aggregated is the aggregated validation per step loss # aggregated over each client (weighted by num examples) eval_round_results = super().evaluate_round(server_round, timeout) @@ -63,14 +63,14 @@ def evaluate_round( return loss_aggregated, metrics_aggregated, (results, failures) -def fit_metrics_aggregation_fn(all_client_metrics: List[Tuple[int, Metrics]]) -> Metrics: +def fit_metrics_aggregation_fn(all_client_metrics: list[tuple[int, Metrics]]) -> Metrics: # This function is run by the server to aggregate metrics returned by each clients fit function # NOTE: The first value of the tuple is number of examples for FedAvg total_examples, aggregated_metrics = metric_aggregation(all_client_metrics) return normalize_metrics(total_examples, aggregated_metrics) -def evaluate_metrics_aggregation_fn(all_client_metrics: List[Tuple[int, Metrics]]) -> Metrics: +def evaluate_metrics_aggregation_fn(all_client_metrics: list[tuple[int, Metrics]]) -> Metrics: # This function is run by the server to aggregate metrics returned by each clients evaluate function # NOTE: The first value of the tuple is number of examples for FedAvg total_examples, aggregated_metrics = metric_aggregation(all_client_metrics) @@ -96,7 +96,7 @@ def fit_config( } -def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str) -> None: +def main(config: dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment # mappings = get_mappings(Path(ENCOUNTERS_FILE)) diff --git a/research/gemini/central/run_central.sh b/research/gemini/central/run_central.sh index 5edc5a492..a26c11849 100644 --- a/research/gemini/central/run_central.sh +++ b/research/gemini/central/run_central.sh @@ -45,7 +45,7 @@ for RUN_NAME in "${RUN_NAMES[@]}"; do # create the run directory RUN_DIR="${EXPERIMENT_DIRECTORY}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then @@ -65,7 +65,7 @@ do # RUN_NAME="Run1" # RUN_DIR="${EXPERIMENT_DIRECTORY}${RUN_NAME}/" -# echo "Starting Run and logging artifcats at ${RUN_DIR}" +# echo "Starting Run and logging artifacts at ${RUN_DIR}" # mkdir "${RUN_DIR}" # OUTPUT_FILE="${RUN_DIR}output.out" diff --git a/research/gemini/central/run_test.sh b/research/gemini/central/run_test.sh index e08c715e7..9a54163e5 100644 --- a/research/gemini/central/run_test.sh +++ b/research/gemini/central/run_test.sh @@ -7,7 +7,7 @@ #SBATCH --mem=5120MB #SBATCH --partition=gpu #SBATCH --qos=hipri -#SBATCH --job-name=centeral-testing +#SBATCH --job-name=central-testing #SBATCH --output=%j_%x.out #SBATCH --error=%j_%x.err #SBATCH --mail-user=your_email@vectorinstitute.ai diff --git a/research/gemini/central/test.py b/research/gemini/central/test.py index c53cb4bd6..108d93f2e 100644 --- a/research/gemini/central/test.py +++ b/research/gemini/central/test.py @@ -2,7 +2,6 @@ import os from logging import INFO from pathlib import Path -from typing import Dict, List import torch import torch.nn as nn @@ -20,15 +19,15 @@ def load_centralized_model(run_folder_dir: str) -> nn.Module: return model -def write_measurement_results(eval_write_path: str, metric, results: Dict[str, float]) -> None: +def write_measurement_results(eval_write_path: str, metric, results: dict[str, float]) -> None: metric_write_path = os.path.join(eval_write_path, f"{metric.name}_metric.txt") with open(metric_write_path, "w") as f: - for key, metric_vaue in results.items(): - f.write(f"{key}: {metric_vaue}\n") + for key, metric_value in results.items(): + f.write(f"{key}: {metric_value}\n") def main( - data_path: Path, artifact_dir: str, eval_write_path: str, n_clients: int, hospitals: List[str], learning_task: str + data_path: Path, artifact_dir: str, eval_write_path: str, n_clients: int, hospitals: list[str], learning_task: str ) -> None: device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") all_run_folder_dir = get_all_run_folders(artifact_dir) @@ -45,7 +44,7 @@ def main( pooled_test_loader = load_test_delirium(data_path, 64) for metric in metrics: - test_results: Dict[str, float] = {} + test_results: dict[str, float] = {} all_clients_test_metrics = {run_folder_dir: 0.0 for run_folder_dir in all_run_folder_dir} client_test_metrics = {client_id: [] for client_id in range(n_clients)} diff --git a/research/gemini/central/train.py b/research/gemini/central/train.py index 2cca3f1df..8755183f1 100644 --- a/research/gemini/central/train.py +++ b/research/gemini/central/train.py @@ -2,7 +2,6 @@ import os from logging import INFO from pathlib import Path -from typing import List import torch import torch.nn as nn @@ -18,7 +17,7 @@ def main( data_path: Path, - metrics: List[Metric], + metrics: list[Metric], device: torch.device, learning_task: str, batch_size: int, diff --git a/research/gemini/delirium_models/fedper_model.py b/research/gemini/delirium_models/fedper_model.py index 1a8b8b55c..0fde4de61 100644 --- a/research/gemini/delirium_models/fedper_model.py +++ b/research/gemini/delirium_models/fedper_model.py @@ -4,7 +4,7 @@ from fl4health.model_bases.fedper_base import FedPerModel -class FedPerGloalFeatureExtractor(nn.Module): +class FedPerGlobalFeatureExtractor(nn.Module): def __init__(self, input_dim: int) -> None: super().__init__() self.fc1 = nn.Linear(input_dim, 256 * 4) @@ -28,7 +28,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -class FedPerGloalFeatureExtractor_het(nn.Module): +class FedPerGlobalFeatureExtractor_het(nn.Module): def __init__(self, input_dim: int) -> None: super().__init__() self.fc1 = nn.Linear(input_dim, 256 * 2) @@ -65,6 +65,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class DeliriumFedPerModel(FedPerModel): def __init__(self, input_dim: int, output_dim: int) -> None: - base_module = FedPerGloalFeatureExtractor_het(input_dim) + base_module = FedPerGlobalFeatureExtractor_het(input_dim) head_module = FedPerLocalPredictionHead(output_dim) super().__init__(base_module, head_module) diff --git a/research/gemini/ditto/client.py b/research/gemini/ditto/client.py index 75b81f3f1..d24c31864 100644 --- a/research/gemini/ditto/client.py +++ b/research/gemini/ditto/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Dict, List, Optional, Sequence, Tuple import flwr as fl import torch @@ -34,14 +34,14 @@ def __init__( data_path: Path, metrics: Sequence[Metric], device: torch.device, - hospital_id: List[str], + hospital_id: list[str], learning_rate: float, learning_task: str, lam: float, checkpoint_stub: str, run_name: str = "", loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[TorchModuleCheckpointer] = None, + checkpointer: TorchModuleCheckpointer | None = None, ) -> None: # Checkpointing: create a string of the names of the hospitals self.hospitals = hospital_id @@ -66,7 +66,7 @@ def __init__( log(INFO, f"Client Name: {self.client_name} Client hospitals {self.hospitals}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = self.narrow_dict_type(config, "batch_size", int) if self.learning_task == "mortality": ( @@ -85,7 +85,7 @@ def get_model(self, config: Config) -> nn.Module: model: nn.Module = delirium_model(input_dim=8093, output_dim=1).to(self.device) return model - def get_optimizer(self, config: Config) -> Dict[str, Optimizer]: + def get_optimizer(self, config: Config) -> dict[str, Optimizer]: # Note that the global optimizer operates on self.global_model.parameters() and global_optimizer = torch.optim.AdamW(self.global_model.parameters(), lr=self.learning_rate) local_optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.learning_rate) diff --git a/research/gemini/ditto/run_fold_experiment.slrm b/research/gemini/ditto/run_fold_experiment.slrm index 44942ca6e..968d5f77a 100644 --- a/research/gemini/ditto/run_fold_experiment.slrm +++ b/research/gemini/ditto/run_fold_experiment.slrm @@ -37,7 +37,7 @@ else fi -# Create the artficat directory +# Create the artifact directory mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) @@ -62,7 +62,7 @@ do # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" ((SEED++)) - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/gemini/ditto/server.py b/research/gemini/ditto/server.py index 53d5fd09f..ecf6841fa 100644 --- a/research/gemini/ditto/server.py +++ b/research/gemini/ditto/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl import torch.nn as nn @@ -40,7 +40,7 @@ def fit_config( } -def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str) -> None: +def main(config: dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/gemini/evaluation/delirium_readme.md b/research/gemini/evaluation/delirium_readme.md index 24869d0ac..1a8dcd066 100644 --- a/research/gemini/evaluation/delirium_readme.md +++ b/research/gemini/evaluation/delirium_readme.md @@ -125,7 +125,7 @@ python -m evaluation.evaluate_on_holdout --artifact_dir "FedOpt/delirium_runs/hp ``` -## SCAFFODL +## SCAFFOLD ### hyper-parameters ``` python -m evaluation.find_best_hp --hp_sweep_dir "Scaffold/delirium_runs/hp_sweep_results" diff --git a/research/gemini/evaluation/evaluate_on_holdout.py b/research/gemini/evaluation/evaluate_on_holdout.py index 0deac6ace..a3e8f72c4 100644 --- a/research/gemini/evaluation/evaluate_on_holdout.py +++ b/research/gemini/evaluation/evaluate_on_holdout.py @@ -1,7 +1,6 @@ import argparse from logging import INFO from pathlib import Path -from typing import Dict, List import torch from data.data import load_test_delirium, load_test_mortality @@ -23,7 +22,7 @@ def main( dataset_dir: Path, eval_write_path: str, n_clients: int, - hospitals: List[str], + hospitals: list[str], learning_task: str, eval_global_model: bool, is_apfl: bool, @@ -41,7 +40,7 @@ def main( # metrics = [] for metric in metrics: - test_results: Dict[str, float] = {} + test_results: dict[str, float] = {} all_local_test_metrics = {run_folder_dir: 0.0 for run_folder_dir in all_run_folder_dir} all_server_test_metrics = {run_folder_dir: 0.0 for run_folder_dir in all_run_folder_dir} diff --git a/research/gemini/evaluation/find_best_hp.py b/research/gemini/evaluation/find_best_hp.py index f5e9d512b..8e5b0bc0c 100644 --- a/research/gemini/evaluation/find_best_hp.py +++ b/research/gemini/evaluation/find_best_hp.py @@ -1,18 +1,17 @@ import argparse import os from logging import INFO -from typing import List, Optional import numpy as np from flwr.common.logger import log -def get_hp_folders(hp_sweep_dir: str) -> List[str]: +def get_hp_folders(hp_sweep_dir: str) -> list[str]: paths_in_hp_sweep_dir = [os.path.join(hp_sweep_dir, contents) for contents in os.listdir(hp_sweep_dir)] return [hp_folder for hp_folder in paths_in_hp_sweep_dir if os.path.isdir(hp_folder)] -def get_run_folders(hp_dir: str) -> List[str]: +def get_run_folders(hp_dir: str) -> list[str]: run_folder_names = [folder_name for folder_name in os.listdir(hp_dir) if "Run" in folder_name] return [os.path.join(hp_dir, run_folder_name) for run_folder_name in run_folder_names] @@ -49,7 +48,7 @@ def create_avg_loss_clients(run_folder_path: str, client_id: int) -> float: def main(hp_sweep_dir: str, client_id: int) -> None: hp_folders = get_hp_folders(hp_sweep_dir) - best_avg_loss: Optional[float] = None + best_avg_loss: float | None = None best_folder = "" for hp_folder in hp_folders: log(INFO, f"Now analyzing: {hp_folder}") diff --git a/research/gemini/evaluation/utils.py b/research/gemini/evaluation/utils.py index e6dc18fd5..ab7ac2279 100644 --- a/research/gemini/evaluation/utils.py +++ b/research/gemini/evaluation/utils.py @@ -1,6 +1,6 @@ import os import warnings -from typing import Dict, List, Sequence, Tuple +from collections.abc import Sequence import numpy as np import torch @@ -12,16 +12,16 @@ warnings.filterwarnings("ignore", category=UserWarning) -def get_all_run_folders(artifact_dir: str) -> List[str]: +def get_all_run_folders(artifact_dir: str) -> list[str]: run_folder_names = [folder_name for folder_name in os.listdir(artifact_dir) if "Run" in folder_name] return [os.path.join(artifact_dir, run_folder_name) for run_folder_name in run_folder_names] -def write_measurement_results(eval_write_path: str, metric, results: Dict[str, float]) -> None: +def write_measurement_results(eval_write_path: str, metric, results: dict[str, float]) -> None: metric_write_path = os.path.join(eval_write_path, f"{metric.name}_metric.txt") with open(metric_write_path, "w") as f: - for key, metric_vaue in results.items(): - f.write(f"{key}: {metric_vaue}\n") + for key, metric_value in results.items(): + f.write(f"{key}: {metric_value}\n") def load_local_model(run_folder_dir: str, hospital_names: str) -> nn.Module: @@ -38,7 +38,7 @@ def load_global_model(run_folder_dir: str) -> nn.Module: return model -def get_metric_avg_std(metrics: List[float]) -> Tuple[float, float]: +def get_metric_avg_std(metrics: list[float]) -> tuple[float, float]: mean = float(np.mean(metrics)) std = float(np.std(metrics, ddof=1)) return mean, std diff --git a/research/gemini/fedavg/client.py b/research/gemini/fedavg/client.py index ab4ba1154..f83ed24b6 100644 --- a/research/gemini/fedavg/client.py +++ b/research/gemini/fedavg/client.py @@ -2,7 +2,6 @@ import os from logging import INFO from pathlib import Path -from typing import Dict, List, Tuple import flwr as fl import torch @@ -24,8 +23,8 @@ class GeminiFedAvgClient(NumpyFlClient): def __init__( self, data_path: Path, - metrics: List[Metric], - hospitals_id: List[str], + metrics: list[Metric], + hospitals_id: list[str], device: torch.device, learning_task: str, learning_rate: float, @@ -69,7 +68,7 @@ def setup_client(self, config: Config) -> None: super().setup_client(config) - def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict[str, Scalar]]: + def fit(self, parameters: NDArrays, config: Config) -> tuple[NDArrays, int, dict[str, Scalar]]: if not self.initialized: self.setup_client(config) @@ -88,7 +87,7 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict metric_values, ) - def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, Dict[str, Scalar]]: + def evaluate(self, parameters: NDArrays, config: Config) -> tuple[float, int, dict[str, Scalar]]: if not self.initialized: self.setup_client(config) @@ -109,7 +108,7 @@ def train_by_epochs( current_server_round: int, epochs: int, meter: Meter, - ) -> Dict[str, Scalar]: + ) -> dict[str, Scalar]: self.model.train() for local_epoch in range(epochs): meter.clear() @@ -128,7 +127,7 @@ def train_by_epochs( # Return final training metrics return metrics - def validate(self, current_server_round: int, meter: Meter) -> Tuple[float, Dict[str, Scalar]]: + def validate(self, current_server_round: int, meter: Meter) -> tuple[float, dict[str, Scalar]]: self.model.eval() val_loss_sum = 0 with torch.no_grad(): diff --git a/research/gemini/fedavg/run_fold_experiment.slrm b/research/gemini/fedavg/run_fold_experiment.slrm index ae1b5080f..e06add283 100644 --- a/research/gemini/fedavg/run_fold_experiment.slrm +++ b/research/gemini/fedavg/run_fold_experiment.slrm @@ -37,7 +37,7 @@ fi -# Create the artficat directory +# Create the artifact directory mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) @@ -59,7 +59,7 @@ for RUN_NAME in "${RUN_NAMES[@]}"; do # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/gemini/fedavg/server.py b/research/gemini/fedavg/server.py index aa83e83ea..ce19ff33b 100644 --- a/research/gemini/fedavg/server.py +++ b/research/gemini/fedavg/server.py @@ -2,7 +2,7 @@ import os from functools import partial from logging import INFO -from typing import Any, Dict, List, Optional, Tuple +from typing import Any import flwr as fl import torch.nn as nn @@ -15,7 +15,7 @@ from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger -from fl4health.reporting.fl_wanb import ServerWandBReporter +from fl4health.reporting.fl_wandb import ServerWandBReporter from fl4health.servers.server import FlServer from fl4health.utils.config import load_config from research.gemini.delirium_models.NN import NN as delirium_model @@ -28,9 +28,9 @@ def __init__( self, client_manager: ClientManager, client_model: nn.Module, - strategy: Optional[Strategy] = None, - checkpointer: Optional[BestMetricTorchCheckpointer] = None, - wandb_reporter: Optional[ServerWandBReporter] = None, + strategy: Strategy | None = None, + checkpointer: BestMetricTorchCheckpointer | None = None, + wandb_reporter: ServerWandBReporter | None = None, ) -> None: self.client_model = client_model # To help with model rehydration @@ -49,8 +49,8 @@ def _maybe_checkpoint(self, checkpoint_metric: float) -> None: def evaluate_round( self, server_round: int, - timeout: Optional[float], - ) -> Optional[Tuple[Optional[float], Dict[str, Scalar], EvaluateResultsAndFailures]]: + timeout: float | None, + ) -> tuple[float | None, dict[str, Scalar], EvaluateResultsAndFailures] | None: # loss_aggregated is the aggregated validation per step loss # aggregated over each client (weighted by num examples) eval_round_results = super().evaluate_round(server_round, timeout) @@ -62,14 +62,14 @@ def evaluate_round( return loss_aggregated, metrics_aggregated, (results, failures) -def fit_metrics_aggregation_fn(all_client_metrics: List[Tuple[int, Metrics]]) -> Metrics: +def fit_metrics_aggregation_fn(all_client_metrics: list[tuple[int, Metrics]]) -> Metrics: # This function is run by the server to aggregate metrics returned by each clients fit function # NOTE: The first value of the tuple is number of examples for FedAvg total_examples, aggregated_metrics = metric_aggregation(all_client_metrics) return normalize_metrics(total_examples, aggregated_metrics) -def evaluate_metrics_aggregation_fn(all_client_metrics: List[Tuple[int, Metrics]]) -> Metrics: +def evaluate_metrics_aggregation_fn(all_client_metrics: list[tuple[int, Metrics]]) -> Metrics: # This function is run by the server to aggregate metrics returned by each clients evaluate function # NOTE: The first value of the tuple is number of examples for FedAvg total_examples, aggregated_metrics = metric_aggregation(all_client_metrics) @@ -100,7 +100,7 @@ def fit_config( } -def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str) -> None: +def main(config: dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( diff --git a/research/gemini/fedopt/client.py b/research/gemini/fedopt/client.py index 8043fb457..7bfb87873 100644 --- a/research/gemini/fedopt/client.py +++ b/research/gemini/fedopt/client.py @@ -2,7 +2,6 @@ import os from logging import INFO from pathlib import Path -from typing import Dict, List, Tuple import flwr as fl import torch @@ -24,8 +23,8 @@ class GeminiFedOptClient(NumpyFlClient): def __init__( self, data_path: Path, - metrics: List[Metric], - hospitals_id: List[str], + metrics: list[Metric], + hospitals_id: list[str], device: torch.device, learning_task: str, learning_rate: float, @@ -70,7 +69,7 @@ def setup_client(self, config: Config) -> None: super().setup_client(config) - def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict[str, Scalar]]: + def fit(self, parameters: NDArrays, config: Config) -> tuple[NDArrays, int, dict[str, Scalar]]: if not self.initialized: self.setup_client(config) @@ -89,7 +88,7 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict metric_values, ) - def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, Dict[str, Scalar]]: + def evaluate(self, parameters: NDArrays, config: Config) -> tuple[float, int, dict[str, Scalar]]: if not self.initialized: self.setup_client(config) @@ -110,7 +109,7 @@ def train_by_epochs( current_server_round: int, epochs: int, meter: Meter, - ) -> Dict[str, Scalar]: + ) -> dict[str, Scalar]: self.model.train() for local_epoch in range(epochs): meter.clear() @@ -131,7 +130,7 @@ def train_by_epochs( # Return final training metrics return metrics - def validate(self, current_server_round: int, meter: Meter) -> Tuple[float, Dict[str, Scalar]]: + def validate(self, current_server_round: int, meter: Meter) -> tuple[float, dict[str, Scalar]]: self.model.eval() val_loss_sum = 0 with torch.no_grad(): diff --git a/research/gemini/fedopt/run_fold_experiment.slrm b/research/gemini/fedopt/run_fold_experiment.slrm index 3ef40fdd3..5f2dc2161 100644 --- a/research/gemini/fedopt/run_fold_experiment.slrm +++ b/research/gemini/fedopt/run_fold_experiment.slrm @@ -38,7 +38,7 @@ fi -# Create the artficat directory +# Create the artifact directory mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) @@ -59,7 +59,7 @@ for RUN_NAME in "${RUN_NAMES[@]}"; do # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/gemini/fedopt/server.py b/research/gemini/fedopt/server.py index 06dd6cbe7..e4c79ef5e 100644 --- a/research/gemini/fedopt/server.py +++ b/research/gemini/fedopt/server.py @@ -2,7 +2,7 @@ import os from functools import partial from logging import INFO -from typing import Any, Dict, List, Optional, Tuple +from typing import Any import flwr as fl import torch.nn as nn @@ -29,8 +29,8 @@ def __init__( self, client_manager: ClientManager, client_model: nn.Module, - strategy: Optional[Strategy] = None, - checkpointer: Optional[BestMetricTorchCheckpointer] = None, + strategy: Strategy | None = None, + checkpointer: BestMetricTorchCheckpointer | None = None, ) -> None: self.client_model = client_model # To help with model rehydration @@ -49,8 +49,8 @@ def _maybe_checkpoint(self, checkpoint_metric: float) -> None: def evaluate_round( self, server_round: int, - timeout: Optional[float], - ) -> Optional[Tuple[Optional[float], Dict[str, Scalar], EvaluateResultsAndFailures]]: + timeout: float | None, + ) -> tuple[float | None, dict[str, Scalar], EvaluateResultsAndFailures] | None: # loss_aggregated is the aggregated validation per step loss # aggregated over each client (weighted by num examples) eval_round_results = super().evaluate_round(server_round, timeout) @@ -62,14 +62,14 @@ def evaluate_round( return loss_aggregated, metrics_aggregated, (results, failures) -def fit_metrics_aggregation_fn(all_client_metrics: List[Tuple[int, Metrics]]) -> Metrics: +def fit_metrics_aggregation_fn(all_client_metrics: list[tuple[int, Metrics]]) -> Metrics: # This function is run by the server to aggregate metrics returned by each clients fit function # NOTE: The first value of the tuple is number of examples for FedAvg total_examples, aggregated_metrics = metric_aggregation(all_client_metrics) return normalize_metrics(total_examples, aggregated_metrics) -def evaluate_metrics_aggregation_fn(all_client_metrics: List[Tuple[int, Metrics]]) -> Metrics: +def evaluate_metrics_aggregation_fn(all_client_metrics: list[tuple[int, Metrics]]) -> Metrics: # This function is run by the server to aggregate metrics returned by each clients evaluate function # NOTE: The first value of the tuple is number of examples for FedAvg total_examples, aggregated_metrics = metric_aggregation(all_client_metrics) @@ -98,7 +98,7 @@ def fit_config( def main( - config: Dict[str, Any], + config: dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str, diff --git a/research/gemini/fedper/client.py b/research/gemini/fedper/client.py index f3877ca81..ebed9aa85 100644 --- a/research/gemini/fedper/client.py +++ b/research/gemini/fedper/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import List, Optional, Sequence, Tuple import flwr as fl import torch @@ -36,13 +36,13 @@ def __init__( data_path: Path, metrics: Sequence[Metric], device: torch.device, - hospital_id: List[str], + hospital_id: list[str], learning_rate: float, learning_task: str, checkpoint_stub: str, run_name: str = "", loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[TorchModuleCheckpointer] = None, + checkpointer: TorchModuleCheckpointer | None = None, ) -> None: # Checkpointing: create a string of the names of the hospitals self.hospitals = hospital_id @@ -66,7 +66,7 @@ def __init__( log(INFO, f"Client Name: {self.client_name} Client hospitals {self.hospitals}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = self.narrow_dict_type(config, "batch_size", int) if self.learning_task == "mortality": ( diff --git a/research/gemini/fedper/run_fold_experiment.slrm b/research/gemini/fedper/run_fold_experiment.slrm index e6e99c6c3..c1595e905 100644 --- a/research/gemini/fedper/run_fold_experiment.slrm +++ b/research/gemini/fedper/run_fold_experiment.slrm @@ -36,7 +36,7 @@ else fi -# Create the artficat directory +# Create the artifact directory mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) @@ -61,7 +61,7 @@ do # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" ((SEED++)) - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/gemini/fedper/server.py b/research/gemini/fedper/server.py index 85a816366..b9e79bb0c 100644 --- a/research/gemini/fedper/server.py +++ b/research/gemini/fedper/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl import torch.nn as nn @@ -40,7 +40,7 @@ def fit_config( } -def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str) -> None: +def main(config: dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/gemini/fedprox/client.py b/research/gemini/fedprox/client.py index 98decf7eb..eb31ce7c8 100644 --- a/research/gemini/fedprox/client.py +++ b/research/gemini/fedprox/client.py @@ -2,7 +2,6 @@ import os from logging import INFO from pathlib import Path -from typing import Dict, List, Tuple import flwr as fl import torch @@ -24,8 +23,8 @@ class GeminiFedProxClient(FedProxClient): def __init__( self, data_path: Path, - metrics: List[Metric], - hospitals_id: List[str], + metrics: list[Metric], + hospitals_id: list[str], device: torch.device, learning_task: str, learning_rate: float, @@ -71,7 +70,7 @@ def setup_client(self, config: Config) -> None: super().setup_client(config) - def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict[str, Scalar]]: + def fit(self, parameters: NDArrays, config: Config) -> tuple[NDArrays, int, dict[str, Scalar]]: if not self.initialized: self.setup_client(config) @@ -88,7 +87,7 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict metric_values, ) - def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, Dict[str, Scalar]]: + def evaluate(self, parameters: NDArrays, config: Config) -> tuple[float, int, dict[str, Scalar]]: if not self.initialized: self.setup_client(config) diff --git a/research/gemini/fedprox/run_fold_experiment.slrm b/research/gemini/fedprox/run_fold_experiment.slrm index 5c64de52a..5ac845dbf 100644 --- a/research/gemini/fedprox/run_fold_experiment.slrm +++ b/research/gemini/fedprox/run_fold_experiment.slrm @@ -37,7 +37,7 @@ else fi -# Create the artficat directory +# Create the artifact directory mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) @@ -60,7 +60,7 @@ for RUN_NAME in "${RUN_NAMES[@]}"; do # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/gemini/fedprox/server.py b/research/gemini/fedprox/server.py index aa4babffa..99f1b3618 100644 --- a/research/gemini/fedprox/server.py +++ b/research/gemini/fedprox/server.py @@ -2,7 +2,7 @@ import os from functools import partial from logging import INFO -from typing import Any, Dict, List, Optional, Tuple +from typing import Any import flwr as fl import torch.nn as nn @@ -15,7 +15,7 @@ from fl4health.checkpointing.checkpointer import BestMetricTorchCheckpointer from fl4health.parameter_exchange.full_exchanger import FullParameterExchanger -from fl4health.reporting.fl_wanb import ServerWandBReporter +from fl4health.reporting.fl_wandb import ServerWandBReporter from fl4health.servers.server import FlServer from fl4health.utils.config import load_config from research.gemini.delirium_models.NN import NN as delirium_model @@ -28,9 +28,9 @@ def __init__( self, client_manager: ClientManager, client_model: nn.Module, - strategy: Optional[Strategy] = None, - checkpointer: Optional[BestMetricTorchCheckpointer] = None, - wandb_reporter: Optional[ServerWandBReporter] = None, + strategy: Strategy | None = None, + checkpointer: BestMetricTorchCheckpointer | None = None, + wandb_reporter: ServerWandBReporter | None = None, ) -> None: self.client_model = client_model # To help with model rehydration @@ -49,8 +49,8 @@ def _maybe_checkpoint(self, checkpoint_metric: float) -> None: def evaluate_round( self, server_round: int, - timeout: Optional[float], - ) -> Optional[Tuple[Optional[float], Dict[str, Scalar], EvaluateResultsAndFailures]]: + timeout: float | None, + ) -> tuple[float | None, dict[str, Scalar], EvaluateResultsAndFailures] | None: # loss_aggregated is the aggregated validation per step loss # aggregated over each client (weighted by num examples) eval_round_results = super().evaluate_round(server_round, timeout) @@ -62,14 +62,14 @@ def evaluate_round( return loss_aggregated, metrics_aggregated, (results, failures) -def fit_metrics_aggregation_fn(all_client_metrics: List[Tuple[int, Metrics]]) -> Metrics: +def fit_metrics_aggregation_fn(all_client_metrics: list[tuple[int, Metrics]]) -> Metrics: # This function is run by the server to aggregate metrics returned by each clients fit function # NOTE: The first value of the tuple is number of examples for FedAvg total_examples, aggregated_metrics = metric_aggregation(all_client_metrics) return normalize_metrics(total_examples, aggregated_metrics) -def evaluate_metrics_aggregation_fn(all_client_metrics: List[Tuple[int, Metrics]]) -> Metrics: +def evaluate_metrics_aggregation_fn(all_client_metrics: list[tuple[int, Metrics]]) -> Metrics: # This function is run by the server to aggregate metrics returned by each clients evaluate function # NOTE: The first value of the tuple is number of examples for FedAvg total_examples, aggregated_metrics = metric_aggregation(all_client_metrics) @@ -101,7 +101,7 @@ def fit_config( } -def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str) -> None: +def main(config: dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/gemini/fenda/client.py b/research/gemini/fenda/client.py index a5911bfe5..f67cdfb8e 100644 --- a/research/gemini/fenda/client.py +++ b/research/gemini/fenda/client.py @@ -2,7 +2,6 @@ import os from logging import INFO from pathlib import Path -from typing import Dict, List, Tuple import flwr as fl import torch @@ -30,8 +29,8 @@ class GeminiFendaClient(NumpyFlClient): def __init__( self, data_path: Path, - metrics: List[Metric], - hospitals_id: List[str], + metrics: list[Metric], + hospitals_id: list[str], device: torch.device, learning_task: str, learning_rate: float, @@ -79,7 +78,7 @@ def setup_client(self, config: Config) -> None: super().setup_client(config) - def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict[str, Scalar]]: + def fit(self, parameters: NDArrays, config: Config) -> tuple[NDArrays, int, dict[str, Scalar]]: if not self.initialized: self.setup_client(config) @@ -98,7 +97,7 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict metric_values, ) - def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, Dict[str, Scalar]]: + def evaluate(self, parameters: NDArrays, config: Config) -> tuple[float, int, dict[str, Scalar]]: if not self.initialized: self.setup_client(config) @@ -119,7 +118,7 @@ def train_by_epochs( current_server_round: int, epochs: int, meter: Meter, - ) -> Dict[str, Scalar]: + ) -> dict[str, Scalar]: self.model.train() for local_epoch in range(epochs): meter.clear() @@ -140,7 +139,7 @@ def train_by_epochs( # Return final training metrics return metrics - def validate(self, current_server_round: int, meter: Meter) -> Tuple[float, Dict[str, Scalar]]: + def validate(self, current_server_round: int, meter: Meter) -> tuple[float, dict[str, Scalar]]: self.model.eval() val_loss_sum = 0 with torch.no_grad(): diff --git a/research/gemini/fenda/run_fold_experiment.slrm b/research/gemini/fenda/run_fold_experiment.slrm index f320e1cb6..893cb1893 100644 --- a/research/gemini/fenda/run_fold_experiment.slrm +++ b/research/gemini/fenda/run_fold_experiment.slrm @@ -37,7 +37,7 @@ fi -# Create the artficat directory +# Create the artifact directory mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) @@ -62,7 +62,7 @@ for RUN_NAME in "${RUN_NAMES[@]}"; do # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/gemini/fenda/server.py b/research/gemini/fenda/server.py index f2ead4b20..2d30e0bc2 100644 --- a/research/gemini/fenda/server.py +++ b/research/gemini/fenda/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict, List, Optional, Tuple +from typing import Any import flwr as fl import torch.nn as nn @@ -29,18 +29,18 @@ class GeminiFendaServer(FlServer): def __init__( self, client_manager: ClientManager, - strategy: Optional[Strategy] = None, + strategy: Strategy | None = None, ) -> None: # FENDA doesn't train a "server" model. Rather, each client trains a client specific model with some globally # shared weights. So we don't checkpoint a global model super().__init__(client_manager, strategy, checkpointer=None) - self.best_aggregated_loss: Optional[float] = None + self.best_aggregated_loss: float | None = None def evaluate_round( self, server_round: int, - timeout: Optional[float], - ) -> Optional[Tuple[Optional[float], Dict[str, Scalar], EvaluateResultsAndFailures]]: + timeout: float | None, + ) -> tuple[float | None, dict[str, Scalar], EvaluateResultsAndFailures] | None: # loss_aggregated is the aggregated validation per step loss # aggregated over each client (weighted by num examples) eval_round_results = super().evaluate_round(server_round, timeout) @@ -69,14 +69,14 @@ def evaluate_round( return loss_aggregated, metrics_aggregated, (results, failures) -def fit_metrics_aggregation_fn(all_client_metrics: List[Tuple[int, Metrics]]) -> Metrics: +def fit_metrics_aggregation_fn(all_client_metrics: list[tuple[int, Metrics]]) -> Metrics: # This function is run by the server to aggregate metrics returned by each clients fit function # NOTE: The first value of the tuple is number of examples for FedAvg total_examples, aggregated_metrics = metric_aggregation(all_client_metrics) return normalize_metrics(total_examples, aggregated_metrics) -def evaluate_metrics_aggregation_fn(all_client_metrics: List[Tuple[int, Metrics]]) -> Metrics: +def evaluate_metrics_aggregation_fn(all_client_metrics: list[tuple[int, Metrics]]) -> Metrics: # This function is run by the server to aggregate metrics returned by each clients evaluate function # NOTE: The first value of the tuple is number of examples for FedAvg total_examples, aggregated_metrics = metric_aggregation(all_client_metrics) @@ -105,7 +105,7 @@ def fit_config( } -def main(config: Dict[str, Any], server_address: str) -> None: +def main(config: dict[str, Any], server_address: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment # mappings = get_mappings(Path(ENCOUNTERS_FILE)) diff --git a/research/gemini/local/run_fold_experiment.slrm b/research/gemini/local/run_fold_experiment.slrm index e4c0a2aae..7308e42c6 100644 --- a/research/gemini/local/run_fold_experiment.slrm +++ b/research/gemini/local/run_fold_experiment.slrm @@ -35,7 +35,7 @@ else HOSPITALS=( "107" ) fi -# Create the artficat directory +# Create the artifact directory mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) @@ -58,7 +58,7 @@ for RUN_NAME in "${RUN_NAMES[@]}"; do # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/gemini/local/train.py b/research/gemini/local/train.py index 1b0ff09f0..eabab3106 100644 --- a/research/gemini/local/train.py +++ b/research/gemini/local/train.py @@ -2,7 +2,6 @@ import os from logging import INFO from pathlib import Path -from typing import List import torch import torch.nn as nn @@ -18,9 +17,9 @@ def main( data_path: Path, - metrics: List[Metric], + metrics: list[Metric], device: torch.device, - hospitals_id: List[str], + hospitals_id: list[str], learning_task: str, batch_size: int, num_epochs: int, diff --git a/research/gemini/moon/client.py b/research/gemini/moon/client.py index 576f8ddf7..b30d7d2e9 100644 --- a/research/gemini/moon/client.py +++ b/research/gemini/moon/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import List, Optional, Sequence, Tuple import flwr as fl import torch @@ -34,14 +34,14 @@ def __init__( data_path: Path, metrics: Sequence[Metric], device: torch.device, - hospital_id: List[str], + hospital_id: list[str], learning_rate: float, learning_task: str, checkpoint_stub: str, run_name: str = "", loss_meter_type: LossMeterType = LossMeterType.AVERAGE, contrastive_weight: float = 10, - checkpointer: Optional[TorchModuleCheckpointer] = None, + checkpointer: TorchModuleCheckpointer | None = None, ) -> None: # Checkpointing: create a string of the names of the hospitals self.hospitals = hospital_id @@ -66,7 +66,7 @@ def __init__( log(INFO, f"Client Name: {self.client_name} Client hospitals {self.hospitals}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = self.narrow_dict_type(config, "batch_size", int) if self.learning_task == "mortality": ( diff --git a/research/gemini/moon/run_fold_experiment.slrm b/research/gemini/moon/run_fold_experiment.slrm index 4d3110a63..763c6336f 100644 --- a/research/gemini/moon/run_fold_experiment.slrm +++ b/research/gemini/moon/run_fold_experiment.slrm @@ -37,7 +37,7 @@ else fi -# Create the artficat directory +# Create the artifact directory mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) @@ -62,7 +62,7 @@ do # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" ((SEED++)) - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/gemini/moon/server.py b/research/gemini/moon/server.py index 8133da667..b9bb46f74 100644 --- a/research/gemini/moon/server.py +++ b/research/gemini/moon/server.py @@ -2,7 +2,7 @@ import os from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl import torch.nn as nn @@ -42,7 +42,7 @@ def fit_config( } -def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str) -> None: +def main(config: dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/gemini/mortality_models/fedper_model.py b/research/gemini/mortality_models/fedper_model.py index 94ca1e1e9..184311523 100644 --- a/research/gemini/mortality_models/fedper_model.py +++ b/research/gemini/mortality_models/fedper_model.py @@ -4,7 +4,7 @@ from fl4health.model_bases.fedper_base import FedPerModel -class FedPerGloalFeatureExtractor(nn.Module): +class FedPerGlobalFeatureExtractor(nn.Module): def __init__(self, input_dim: int) -> None: super().__init__() self.fc1 = nn.Linear(input_dim, 256 * 2) @@ -33,6 +33,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class GeminiFedPerModel(FedPerModel): def __init__(self, input_dim: int, output_dim: int) -> None: - base_module = FedPerGloalFeatureExtractor(input_dim) + base_module = FedPerGlobalFeatureExtractor(input_dim) head_module = FedPerLocalPredictionHead(output_dim) super().__init__(base_module, head_module) diff --git a/research/gemini/perfcl/client.py b/research/gemini/perfcl/client.py index 841baf4a8..6e5edb747 100644 --- a/research/gemini/perfcl/client.py +++ b/research/gemini/perfcl/client.py @@ -1,8 +1,8 @@ import argparse import os +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import List, Optional, Sequence, Tuple import flwr as fl import torch @@ -34,14 +34,14 @@ def __init__( data_path: Path, metrics: Sequence[Metric], device: torch.device, - hospital_id: List[str], + hospital_id: list[str], learning_rate: float, learning_task: str, checkpoint_stub: str, run_name: str = "", loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpointer: Optional[TorchModuleCheckpointer] = None, - extra_loss_weights: Tuple[float, float] = (10, 10), + checkpointer: TorchModuleCheckpointer | None = None, + extra_loss_weights: tuple[float, float] = (10, 10), ) -> None: # Checkpointing: create a string of the names of the hospitals self.hospitals = hospital_id @@ -66,7 +66,7 @@ def __init__( log(INFO, f"Client Name: {self.client_name} Client hospitals {self.hospitals}") - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = self.narrow_dict_type(config, "batch_size", int) if self.learning_task == "mortality": ( diff --git a/research/gemini/perfcl/run_fold_experiment.slrm b/research/gemini/perfcl/run_fold_experiment.slrm index feaf4c2ec..87d1930e4 100644 --- a/research/gemini/perfcl/run_fold_experiment.slrm +++ b/research/gemini/perfcl/run_fold_experiment.slrm @@ -38,7 +38,7 @@ else fi -# Create the artficat directory +# Create the artifact directory mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) @@ -63,7 +63,7 @@ do # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" ((SEED++)) - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/gemini/perfcl/server.py b/research/gemini/perfcl/server.py index 92746713a..1ac3e60f4 100644 --- a/research/gemini/perfcl/server.py +++ b/research/gemini/perfcl/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from logging import INFO -from typing import Any, Dict +from typing import Any import flwr as fl import torch.nn as nn @@ -40,7 +40,7 @@ def fit_config( } -def main(config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str) -> None: +def main(config: dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/gemini/scaffold/client.py b/research/gemini/scaffold/client.py index d3c2b23c1..d47c8fdcb 100644 --- a/research/gemini/scaffold/client.py +++ b/research/gemini/scaffold/client.py @@ -2,7 +2,6 @@ import os from logging import INFO from pathlib import Path -from typing import Dict, List, Tuple import flwr as fl import torch @@ -24,8 +23,8 @@ class GeminiScaffoldclient(ScaffoldClient): def __init__( self, data_path: Path, - metrics: List[Metric], - hospitals_id: List[str], + metrics: list[Metric], + hospitals_id: list[str], device: torch.device, learning_task: str, learning_rate: float, @@ -74,7 +73,7 @@ def setup_client(self, config: Config) -> None: super().setup_client(config) - def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict[str, Scalar]]: + def fit(self, parameters: NDArrays, config: Config) -> tuple[NDArrays, int, dict[str, Scalar]]: if not self.initialized: self.setup_client(config) @@ -92,7 +91,7 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict metric_values, ) - def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, Dict[str, Scalar]]: + def evaluate(self, parameters: NDArrays, config: Config) -> tuple[float, int, dict[str, Scalar]]: if not self.initialized: self.setup_client(config) @@ -163,7 +162,7 @@ def train_by_rounds( self, local_steps: int, meter: Meter, - ) -> Dict[str, Scalar]: + ) -> dict[str, Scalar]: self.model.train() running_loss = 0.0 meter.clear() @@ -201,7 +200,7 @@ def train_by_rounds( return metrics - def validate(self, meter: Meter) -> Tuple[float, Dict[str, Scalar]]: + def validate(self, meter: Meter) -> tuple[float, dict[str, Scalar]]: self.model.eval() running_loss = 0.0 meter.clear() diff --git a/research/gemini/scaffold/run_fold_experiment.slrm b/research/gemini/scaffold/run_fold_experiment.slrm index 877b35620..6849df185 100644 --- a/research/gemini/scaffold/run_fold_experiment.slrm +++ b/research/gemini/scaffold/run_fold_experiment.slrm @@ -37,7 +37,7 @@ fi -# Create the artficat directory +# Create the artifact directory mkdir "${ARTIFACT_DIR}" RUN_NAMES=( "Run1" "Run2" "Run3" "Run4" "Run5" ) @@ -59,7 +59,7 @@ for RUN_NAME in "${RUN_NAMES[@]}"; do # create the run directory RUN_DIR="${ARTIFACT_DIR}${RUN_NAME}/" - echo "Starting Run and logging artifcats at ${RUN_DIR}" + echo "Starting Run and logging artifacts at ${RUN_DIR}" if [ -d "${RUN_DIR}" ] then # Directory already exists, we check if the done.out file exists diff --git a/research/gemini/scaffold/server.py b/research/gemini/scaffold/server.py index d3ddd4f1a..81e24995b 100644 --- a/research/gemini/scaffold/server.py +++ b/research/gemini/scaffold/server.py @@ -2,7 +2,7 @@ import os from functools import partial from logging import INFO -from typing import Any, Dict, List, Optional, Tuple +from typing import Any import flwr as fl import numpy as np @@ -33,8 +33,8 @@ def __init__( self, client_manager: ClientManager, client_model: nn.Module, - strategy: Optional[Strategy] = None, - checkpointer: Optional[BestMetricTorchCheckpointer] = None, + strategy: Strategy | None = None, + checkpointer: BestMetricTorchCheckpointer | None = None, ) -> None: self.client_model = client_model # To help with model rehydration @@ -56,8 +56,8 @@ def _maybe_checkpoint(self, checkpoint_metric: float) -> None: def evaluate_round( self, server_round: int, - timeout: Optional[float], - ) -> Optional[Tuple[Optional[float], Dict[str, Scalar], EvaluateResultsAndFailures]]: + timeout: float | None, + ) -> tuple[float | None, dict[str, Scalar], EvaluateResultsAndFailures] | None: # loss_aggregated is the aggregated validation per step loss # aggregated over each client (weighted by num examples) eval_round_results = super().evaluate_round(server_round, timeout) @@ -69,14 +69,14 @@ def evaluate_round( return loss_aggregated, metrics_aggregated, (results, failures) -def fit_metrics_aggregation_fn(all_client_metrics: List[Tuple[int, Metrics]]) -> Metrics: +def fit_metrics_aggregation_fn(all_client_metrics: list[tuple[int, Metrics]]) -> Metrics: # This function is run by the server to aggregate metrics returned by each clients fit function # NOTE: The first value of the tuple is number of examples for FedAvg total_examples, aggregated_metrics = metric_aggregation(all_client_metrics) return normalize_metrics(total_examples, aggregated_metrics) -def evaluate_metrics_aggregation_fn(all_client_metrics: List[Tuple[int, Metrics]]) -> Metrics: +def evaluate_metrics_aggregation_fn(all_client_metrics: list[tuple[int, Metrics]]) -> Metrics: # This function is run by the server to aggregate metrics returned by each clients evaluate function # NOTE: The first value of the tuple is number of examples for FedAvg total_examples, aggregated_metrics = metric_aggregation(all_client_metrics) @@ -106,7 +106,7 @@ def fit_config( def main( - config: Dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str, server_learning_rate: float + config: dict[str, Any], server_address: str, checkpoint_stub: str, run_name: str, server_learning_rate: float ) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment diff --git a/research/gemini/servers/full_exchange_server.py b/research/gemini/servers/full_exchange_server.py index 47ab4d2a5..bf4ff2c06 100644 --- a/research/gemini/servers/full_exchange_server.py +++ b/research/gemini/servers/full_exchange_server.py @@ -1,5 +1,3 @@ -from typing import Optional - import torch.nn as nn from flwr.server.client_manager import ClientManager from flwr.server.strategy import Strategy @@ -14,8 +12,8 @@ def __init__( self, client_manager: ClientManager, model: nn.Module, - strategy: Optional[Strategy] = None, - checkpointer: Optional[TorchModuleCheckpointer] = None, + strategy: Strategy | None = None, + checkpointer: TorchModuleCheckpointer | None = None, ) -> None: # To help with model rehydration parameter_exchanger = FullParameterExchanger() diff --git a/research/gemini/servers/personal_server.py b/research/gemini/servers/personal_server.py index fdfc1fbf3..54822ee1c 100644 --- a/research/gemini/servers/personal_server.py +++ b/research/gemini/servers/personal_server.py @@ -1,5 +1,4 @@ from logging import INFO -from typing import Dict, Optional, Tuple from flwr.common.logger import log from flwr.common.typing import Scalar @@ -23,18 +22,18 @@ class PersonalServer(FlServer): def __init__( self, client_manager: ClientManager, - strategy: Optional[Strategy] = None, + strategy: Strategy | None = None, ) -> None: # Personal approaches don't train a "server" model. Rather, each client trains a client specific model with # some globally shared weights. So we don't checkpoint a global model super().__init__(client_manager, strategy, checkpointer=None) - self.best_aggregated_loss: Optional[float] = None + self.best_aggregated_loss: float | None = None def evaluate_round( self, server_round: int, - timeout: Optional[float], - ) -> Optional[Tuple[Optional[float], Dict[str, Scalar], EvaluateResultsAndFailures]]: + timeout: float | None, + ) -> tuple[float | None, dict[str, Scalar], EvaluateResultsAndFailures] | None: # loss_aggregated is the aggregated validation per step loss # aggregated over each client (weighted by num examples) eval_round_results = super().evaluate_round(server_round, timeout) diff --git a/research/gemini/simple_metric_aggregation.py b/research/gemini/simple_metric_aggregation.py index c6386f3f2..d825a660c 100644 --- a/research/gemini/simple_metric_aggregation.py +++ b/research/gemini/simple_metric_aggregation.py @@ -1,12 +1,11 @@ from collections import defaultdict -from typing import DefaultDict, List, Tuple from flwr.common.typing import Metrics -def uniform_metric_aggregation(all_client_metrics: List[Tuple[int, Metrics]]) -> Tuple[DefaultDict[str, int], Metrics]: +def uniform_metric_aggregation(all_client_metrics: list[tuple[int, Metrics]]) -> tuple[defaultdict[str, int], Metrics]: aggregated_metrics: Metrics = {} - total_client_count_by_metric: DefaultDict[str, int] = defaultdict(int) + total_client_count_by_metric: defaultdict[str, int] = defaultdict(int) # Run through all of the metrics for _, client_metrics in all_client_metrics: for metric_name, metric_value in client_metrics.items(): @@ -27,7 +26,7 @@ def uniform_metric_aggregation(all_client_metrics: List[Tuple[int, Metrics]]) -> return total_client_count_by_metric, aggregated_metrics -def metric_aggregation(all_client_metrics: List[Tuple[int, Metrics]]) -> Tuple[int, Metrics]: +def metric_aggregation(all_client_metrics: list[tuple[int, Metrics]]) -> tuple[int, Metrics]: aggregated_metrics: Metrics = {} total_examples = 0 # Run through all of the metrics @@ -59,7 +58,7 @@ def normalize_metrics(total_examples: int, aggregated_metrics: Metrics) -> Metri def uniform_normalize_metrics( - total_client_count_by_metric: DefaultDict[str, int], aggregated_metrics: Metrics + total_client_count_by_metric: defaultdict[str, int], aggregated_metrics: Metrics ) -> Metrics: # Normalize all metric values by the total count of clients that contributed to the metric. normalized_metrics: Metrics = {} @@ -69,21 +68,21 @@ def uniform_normalize_metrics( return normalized_metrics -def fit_metrics_aggregation_fn(all_client_metrics: List[Tuple[int, Metrics]]) -> Metrics: +def fit_metrics_aggregation_fn(all_client_metrics: list[tuple[int, Metrics]]) -> Metrics: # This function is run by the server to aggregate metrics returned by each clients fit function # NOTE: The first value of the tuple is number of examples for FedAvg total_examples, aggregated_metrics = metric_aggregation(all_client_metrics) return normalize_metrics(total_examples, aggregated_metrics) -def evaluate_metrics_aggregation_fn(all_client_metrics: List[Tuple[int, Metrics]]) -> Metrics: +def evaluate_metrics_aggregation_fn(all_client_metrics: list[tuple[int, Metrics]]) -> Metrics: # This function is run by the server to aggregate metrics returned by each clients evaluate function # NOTE: The first value of the tuple is number of examples for FedAvg total_examples, aggregated_metrics = metric_aggregation(all_client_metrics) return normalize_metrics(total_examples, aggregated_metrics) -def uniform_evaluate_metrics_aggregation_fn(all_client_metrics: List[Tuple[int, Metrics]]) -> Metrics: +def uniform_evaluate_metrics_aggregation_fn(all_client_metrics: list[tuple[int, Metrics]]) -> Metrics: # This function is run by the server to aggregate metrics returned by each clients evaluate function # NOTE: The first value of the tuple is number of examples for FedAvg, but it is not used here. total_client_count_by_metric, aggregated_metrics = uniform_metric_aggregation(all_client_metrics) diff --git a/research/picai/central/README.md b/research/picai/central/README.md index c390ef37d..75974c93b 100644 --- a/research/picai/central/README.md +++ b/research/picai/central/README.md @@ -2,7 +2,7 @@ The following instructions outline training and validating a simple U-Net model on the Preprocessed Dataset described in the [PICAI Documentation](/research/picai/README.md) using a centralized setup. Running the centralized example out of the box is as simple as executing the command below. -An example of the usage is below. Note that the script needs to be run from the top level of the FL4Health repository. Moreover, a python environment with the required libraries must already exist. See the main PICAI documentation Cluster [PICAI Documentation](/research/picai/README.md) for instructions on creating and activating the environment required to exectute the following code. The commands below should be run from the top level directory: +An example of the usage is below. Note that the script needs to be run from the top level of the FL4Health repository. Moreover, a python environment with the required libraries must already exist. See the main PICAI documentation Cluster [PICAI Documentation](/research/picai/README.md) for instructions on creating and activating the environment required to execute the following code. The commands below should be run from the top level directory: ```bash python research/picai/central/train.py --overviews_dir /path/to/overviews_dir --fold --run_name diff --git a/research/picai/data/README.md b/research/picai/data/README.md index aa4329598..e3a2b79ac 100644 --- a/research/picai/data/README.md +++ b/research/picai/data/README.md @@ -72,7 +72,7 @@ The Preprocessed Dataset can be generated by running the `prepare_annotations.py `prepare_annotations.py` is a simple script that copies the human and ai-derived annotations into a specified folder. An example invocation is as follows: ``` -python research/picai/preprare_annoations.py --human_annotations_dir /path/to/human/annotations --ai_annotations_dir /path/to/ai/annotations --annotations_write_dir /path/to/write/dir +python research/picai/prepare_annotations.py --human_annotations_dir /path/to/human/annotations --ai_annotations_dir /path/to/ai/annotations --annotations_write_dir /path/to/write/dir ``` `prepare_data.py` is the main preprocessing script that takes in a number of arguments related to the location of the raw dataset and details about the preprocessing and produces a preprocessed dataset with an associated dataset overview file. Here is an example invocation: diff --git a/research/picai/data/data_utils.py b/research/picai/data/data_utils.py index 99d7bd15b..6cb9dbc49 100644 --- a/research/picai/data/data_utils.py +++ b/research/picai/data/data_utils.py @@ -1,9 +1,9 @@ import json import os import random +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import List, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -91,7 +91,7 @@ def __call__(self, data: torch.Tensor) -> torch.Tensor: def get_img_transform() -> Compose: """ Basic transformation pipeline for images that includes ensuring type and shape of data, - performing z score normalization, random roation, intensity scaling and adjusting contrast. + performing z score normalization, random rotation, intensity scaling and adjusting contrast. Returns: Compose: Image transformation pipeline. @@ -119,18 +119,18 @@ def get_seg_transform() -> Compose: return Compose(transforms) -def z_score_norm(image: torch.Tensor, quantile: Optional[float] = None) -> torch.Tensor: +def z_score_norm(image: torch.Tensor, quantile: float | None = None) -> torch.Tensor: """ Function that performs instance wise Z-score normalization (mean=0; stdev=1), where intensities below or above the given percentile are discarded. Args: image (torch.Tensor): N-dimensional image to be normalized and optionally clipped. - quantile (Optional[float]): Quantile used to set threshold to clip activations. + quantile (float | None): Quantile used to set threshold to clip activations. If None, no clipping occurs. If a quantile is specified, must be 0 =< 0.5 Returns: - torch.Tensor: Z-Score Normalized vesrion of input that is clipped if a quantile is specified. + torch.Tensor: Z-Score Normalized version of input that is clipped if a quantile is specified. """ image = image.float() @@ -157,7 +157,7 @@ def get_img_and_seg_paths( include_t2w: bool = True, include_adc: bool = True, include_hbv: bool = True, -) -> Tuple[List[List[str]], List[str], torch.Tensor]: +) -> tuple[list[list[str]], list[str], torch.Tensor]: """ Gets the image paths, segmentation paths and label proportions for the specified fold. Exclude t2w, adc or hbv scan if specified. @@ -173,9 +173,9 @@ def get_img_and_seg_paths( include_hbv (bool): Whether or not to include hbv Sequence as part of the input data. Returns: - Tuple[Sequence[Sequence[str]], Sequence[str], torch.Tensor]: The first element of the returned tuple + tuple[Sequence[Sequence[str]], Sequence[str], torch.Tensor]: The first element of the returned tuple is a list of list of strings where the outer list represents a list of file paths corresponding - to the diffferent MRI Sequences for a given patient exam. The second element is a list of strings + to the different MRI Sequences for a given patient exam. The second element is a list of strings representing the associated segmentation labels. The final element of the returned tuple is a torch tensor that give the class proportions. """ @@ -217,25 +217,25 @@ def get_img_and_seg_paths( def split_img_and_seg_paths( - img_paths: List[List[str]], seg_paths: List[str], splits: int, seed: int = 0 -) -> Tuple[List[List[List[str]]], List[List[str]]]: + img_paths: list[list[str]], seg_paths: list[str], splits: int, seed: int = 0 +) -> tuple[list[list[list[str]]], list[list[str]]]: """ Split image and segmentation paths into a number of mutually exclusive sets. img_paths (Sequence[Sequence[str]]: List of list of strings where the outer list represents - a list of file paths corresponding to the diffferent MRI Sequences for a given patient exam. + a list of file paths corresponding to the different MRI Sequences for a given patient exam. seg_paths (Sequence[str]): List of strings representing the segmentation labels associated with images. splits (int): The number of splits to partition the dataset. Returns: - Tuple[Sequence[Sequence[str]], Sequence[str]]: The image and segmentation paths for + tuple[Sequence[Sequence[str]], Sequence[str]]: The image and segmentation paths for images and segmentation labels. """ assert len(img_paths) == len(seg_paths) random.seed(seed) client_assignments = [random.choice([i for i in range(splits)]) for _ in range(len(img_paths))] - client_img_paths: List[List[List[str]]] = [[] for _ in range(splits)] - client_seg_paths: List[List[str]] = [[] for _ in range(splits)] + client_img_paths: list[list[list[str]]] = [[] for _ in range(splits)] + client_seg_paths: list[list[str]] = [[] for _ in range(splits)] for i, assignment in enumerate(client_assignments): client_img_paths[assignment].append(img_paths[i]) client_seg_paths[assignment].append(seg_paths[i]) @@ -244,7 +244,7 @@ def split_img_and_seg_paths( def get_dataloader( - img_paths: Union[Sequence[Sequence[str]], Sequence[str]], + img_paths: Sequence[Sequence[str]] | Sequence[str], seg_paths: Sequence[str], batch_size: int, img_transform: Compose, @@ -256,10 +256,10 @@ def get_dataloader( Initializes and returns MONAI Dataloader. Args: img_paths (Sequence[Sequence[str]]: List of list of strings where the outer list represents a - list of file paths corresponding to the diffferent MRI Sequences for a given patient exam. + list of file paths corresponding to the different MRI Sequences for a given patient exam. seg_paths (Sequence[str]): List of strings representing the segmentation labels associated with images. batch_size (str): The number of samples per batch yielded by the DataLoader. - img_transorm (Compose): The series of transformations applied to input images during dataloading. + img_transform (Compose): The series of transformations applied to input images during dataloading. seg_transform (Compose): The series of transformations applied to the segmentation labels during dataloading. shuffle (bool): Whether or not to shuffle the dataset. num_workers (int): The number of workers used by the DataLoader. @@ -268,7 +268,7 @@ def get_dataloader( DataLoader: MONAI dataloader. """ # Ignoring type of image_files because Sequence[Sequence[str]] is valid input - # list of files interpreted as multi-parametric sequnce. Supported by image loader: + # list of files interpreted as multi-parametric sequence. Supported by image loader: # https://docs.monai.io/en/stable/transforms.html#loadimage used by ImageDataset: # https://docs.monai.io/en/latest/data.html#monai.data.ImageDataset ds = ImageDataset( diff --git a/research/picai/data/prepare_annotations.py b/research/picai/data/prepare_annotations.py index a72989fc2..1f76d5ec2 100644 --- a/research/picai/data/prepare_annotations.py +++ b/research/picai/data/prepare_annotations.py @@ -6,7 +6,7 @@ def prepare_annotations(human_annotations_dir: str, ai_annotations_dir: str, annotations_write_dir: str) -> None: """ - Copy seperate annotation sources (ie human and ai-derived annotations) to a central location. + Copy separate annotation sources (ie human and ai-derived annotations) to a central location. Args: human_annotations_dir (str): The path to the folder containing human annotations. diff --git a/research/picai/data/prepare_data.py b/research/picai/data/prepare_data.py index b5b6634bc..46a214cfd 100644 --- a/research/picai/data/prepare_data.py +++ b/research/picai/data/prepare_data.py @@ -1,8 +1,8 @@ import argparse import json import os +from collections.abc import Sequence from pathlib import Path -from typing import Dict, List, Optional, Sequence, Tuple, Union import numpy as np import SimpleITK as sitk @@ -28,13 +28,13 @@ ] -def get_labels(paths_for_each_sample: Sequence[Tuple[Sequence[Path], Path]]) -> Sequence[float]: +def get_labels(paths_for_each_sample: Sequence[tuple[Sequence[Path], Path]]) -> Sequence[float]: """ Get the label of each sample. The label is negative if no foreground objects exist, as per annotation, else positive. Args: - paths_for_each_sample (Sequence[Tuple[Sequence[Path], Path]]): A sequence in which each member + paths_for_each_sample (Sequence[tuple[Sequence[Path], Path]]): A sequence in which each member is tuple where the first entry is a list of scan paths and the second in the annotation path. Returns: @@ -52,33 +52,33 @@ def get_labels(paths_for_each_sample: Sequence[Tuple[Sequence[Path], Path]]) -> def filter_split_on_subject_id( - scan_annotation_label_list: Sequence[Tuple[Sequence[str], str, float]], - split: Dict[str, Sequence[str]], + scan_annotation_label_list: Sequence[tuple[Sequence[str], str, float]], + split: dict[str, Sequence[str]], train: bool, -) -> Dict[str, Union[Sequence[float], Sequence[str]]]: +) -> dict[str, Sequence[float] | Sequence[str]]: """ Filters the scan_annotation_label_list to only include samples with a subject_id apart of split. Returns Dict with image paths, label paths and case labels Args: - scan_annotation_label_list (Sequence[Tuple[List[str], str, float]]): A sequence where each member + scan_annotation_label_list (Sequence[tuple[list[str], str, float]]): A sequence where each member is a tuple where the first entry is a list of scan paths, the second entry is the annotation path and the third entry is the label of the sample. - split (Dict[str, Sequence[str]]): A Dict of sequences of subject_ids included in the current split. + split (dict[str, Sequence[str]]): A Dict of sequences of subject_ids included in the current split. Dict contains two keys: train and val. train (bool): Whether to use the train or the test split. Returns: - Dict[str, Union[Sequence[float], Sequence[str]]]: A Dict containing image_paths, label_paths + dict[str, Sequence[float] | Sequence[str]]: A Dict containing image_paths, label_paths and case_label for each sample part of the split. """ train_or_val_string = "train" if train else "val" - filtered_scan_annotation_label_list: Sequence[Tuple[Sequence[str], str, float]] = [ + filtered_scan_annotation_label_list: Sequence[tuple[Sequence[str], str, float]] = [ (scan_paths, annotation_path, label) for (scan_paths, annotation_path, label) in scan_annotation_label_list if any([subject_id in annotation_path for subject_id in split[train_or_val_string]]) ] - labeled_data: Dict[str, Union[Sequence[float], Sequence[str]]] = {} + labeled_data: dict[str, Sequence[float] | Sequence[str]] = {} labeled_data["image_paths"], labeled_data["label_paths"], labeled_data["case_label"] = zip( *filtered_scan_annotation_label_list ) @@ -86,9 +86,9 @@ def filter_split_on_subject_id( def generate_dataset_json( - paths_for_each_sample: Sequence[Tuple[Sequence[Path], Path]], + paths_for_each_sample: Sequence[tuple[Sequence[Path], Path]], write_dir: Path, - splits_path: Optional[Path] = None, + splits_path: Path | None = None, ) -> None: """ Generates JSON file(s) that include the image_paths, label_paths and case_labels. @@ -96,10 +96,10 @@ def generate_dataset_json( If no splits_path is supplied, a single JSON file will be created with all of the samples. Args: - paths_for_each_sample (Sequence[Tuple[Sequence[Path], Path]]): A sequence in which each member + paths_for_each_sample (Sequence[tuple[Sequence[Path], Path]]): A sequence in which each member is tuple where the first entry is a list of scan paths and the second in the annotation path. write_dir (Path): The directory to write the dataset file(s). - splits_path (Optional[Path]): The path to the desired spits. JSON file with key for each split. + splits_path (Path | None): The path to the desired spits. JSON file with key for each split. Each key contains nested keys train and val. Inside the nested keys are lists of subject_id strings to be included in the split. """ @@ -114,7 +114,7 @@ def generate_dataset_json( if splits_path is None: # If splits_path is None, create a singe dataset overview - labeled_data: Dict[str, Union[Sequence[str], Sequence[float]]] = {} + labeled_data: dict[str, Sequence[str] | Sequence[float]] = {} labeled_data["image_paths"], labeled_data["label_paths"], labeled_data["case_label"] = zip( *scan_annotation_label_list ) @@ -143,19 +143,19 @@ def generate_dataset_json( json.dump(val_labeled_data, f) -def preprare_data( +def prepare_data( scans_read_dir: Path, annotation_read_dir: Path, scans_write_dir: Path, annotation_write_dir: Path, overview_write_dir: Path, - size: Optional[Tuple[int, int, int]] = None, - physical_size: Optional[Tuple[float, float, float]] = None, - spacing: Optional[Tuple[float, float, float]] = None, + size: tuple[int, int, int] | None = None, + physical_size: tuple[float, float, float] | None = None, + spacing: tuple[float, float, float] | None = None, scan_extension: str = "mha", annotation_extension: str = ".nii.gz", num_threads: int = 4, - splits_path: Optional[Path] = None, + splits_path: Path | None = None, ) -> None: """ Runs preprocessing on data with specified settings. @@ -170,17 +170,17 @@ def preprare_data( All annotations are written into same directory. overviews_write_dir (Path): The path where the dataset json files are located. For each split 1-5, there is a train and validation file with scan paths, label paths and case labels. - size (Optional[Tuple[int, int, int]]): Desired dimensions of preprocessed scans in voxels. + size (tuple[int, int, int] | None): Desired dimensions of preprocessed scans in voxels. Triplet of the form: Depth x Height x Width. - physical_size (Optional[Tuple[float, float, float]]): Desired dimensions of preprocessed scans in mm. + physical_size (tuple[float, float, float] | None): Desired dimensions of preprocessed scans in mm. Simply the product of the number of voxels by the spacing along a particular dimension: Triplet of the form: Depth x Height x Width. - spacing (Optional[Tuple[float, float, float]]): Desired spacing of preprocessed scans in in mm/voxel. + spacing (tuple[float, float, float] | None): Desired spacing of preprocessed scans in in mm/voxel. Triplet of the form: Depth x Height x Width. scan_extension (str): The expected extension of scan file paths. annotation_extension (str): The expected extension of annotation file paths. num_threads (str): The number of threads to use during preprocessing. - splits_path (Optional[Path]): The path to the file containing the splits. + splits_path (Path | None): The path to the file containing the splits. """ settings = PreprocessingSettings( scans_write_dir, @@ -191,7 +191,7 @@ def preprare_data( ) valid_annotation_filenames = [f for f in os.listdir(annotation_read_dir) if f.endswith(annotation_extension)] - samples: List[Case] = [] + samples: list[Case] = [] for annotation_filename in valid_annotation_filenames: # Annotation filename is subject id (ie patient_id study_id) # We use it to get the corresponding scan paths @@ -292,7 +292,7 @@ def main() -> None: raise ValueError("Argument spacing must have length 3") spacing = (float(args.spacing[0]), float(args.spacing[1]), float(args.spacing[2])) if args.spacing else None - preprare_data( + prepare_data( args.scans_read_dir, args.annotation_read_dir, args.scans_write_dir, diff --git a/research/picai/data/preprocessing.py b/research/picai/data/preprocessing.py index b7fcad410..bea1a2552 100644 --- a/research/picai/data/preprocessing.py +++ b/research/picai/data/preprocessing.py @@ -2,11 +2,11 @@ import os from abc import ABC, abstractmethod +from collections.abc import Sequence from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from functools import partial, reduce from pathlib import Path -from typing import List, Optional, Sequence, Tuple import numpy as np import SimpleITK as sitk @@ -20,9 +20,9 @@ def __init__( self, scans_write_dir: Path, annotation_write_dir: Path, - size: Optional[Tuple[int, int, int]], - physical_size: Optional[Tuple[float, float, float]], - spacing: Optional[Tuple[float, float, float]], + size: tuple[int, int, int] | None, + physical_size: tuple[float, float, float] | None, + spacing: tuple[float, float, float] | None, ) -> None: """ Dataclass encapsulating parameters of preprocessing. @@ -30,13 +30,13 @@ def __init__( Args: scans_write_dir (Path): The directory to write the preprocessed scans. annotation_write_dir (Path): The directory to write the preprocessed annotation. - size (Optional[Tuple[int, int, int]]): Tuple of 3 int representing size of scan in voxels. + size (tuple[int, int, int] | None): Tuple of 3 int representing size of scan in voxels. In the format of Depth x Height x Width. If None, preprocessed scans and annotations retain their original size. - physical_size (Optional[Tuple[float, float, float]]): Tuple of 3 float representing actual scan size in mm. + physical_size (tuple[float, float, float] | None): Tuple of 3 float representing actual scan size in mm. In the format of Depth x Height x Width. If None and size and spacing are not None, physical_size will be inferred. - spacing (Optional[Tuple[float, float, float]]): Tuple of 3 float representing spacing between voxels + spacing (tuple[float, float, float] | None): Tuple of 3 float representing spacing between voxels of scan in mm/voxel. In the format of Depth x Height x Width. If None, preprocessed scans and annotations retain their original spacing. """ @@ -79,7 +79,7 @@ def __init__( self.annotations_path = annotations_path self.settings = settings - self.scans: List[sitk.Image] + self.scans: list[sitk.Image] self.annotation: sitk.Image @abstractmethod @@ -94,13 +94,13 @@ def read(self) -> None: raise NotImplementedError @abstractmethod - def write(self) -> Tuple[Sequence[Path], Path]: + def write(self) -> tuple[Sequence[Path], Path]: """ Abstract method to be implemented by children that writes the preprocessed scans and annotation to their destination and returns the file paths. Returns: - Tuple[Sequence[Path], Path]: A tuple in which the first entry is a sequence of file paths + tuple[Sequence[Path], Path]: A tuple in which the first entry is a sequence of file paths for the scans and the second entry is the file path to the corresponding annotation. Raises: @@ -141,7 +141,7 @@ def read(self) -> None: self.scans = [sitk.ReadImage(path) for path in self.scan_paths] self.annotation = sitk.ReadImage(self.annotations_path) - def write(self) -> Tuple[Sequence[Path], Path]: + def write(self) -> tuple[Sequence[Path], Path]: """ Writes preprocessed scans and annotations from PICAI dataset to disk and returns the scan file paths and annotation file path in a tuple. @@ -154,7 +154,7 @@ def write(self) -> Tuple[Sequence[Path], Path]: below. Returns: - Tuple[Sequence[Path], Path]: A tuple in which the first entry is a sequence of file paths + tuple[Sequence[Path], Path]: A tuple in which the first entry is a sequence of file paths for the scans and the second entry is the file path to the corresponding annotation. """ modality_suffix_map = {"t2w": "0000", "adc": "0001", "hbv": "0002"} @@ -335,7 +335,7 @@ def __call__(self, case: Case) -> Case: return case -def apply_transform(case: Case, transforms: Sequence[PreprocessingTransform]) -> Tuple[Sequence[Path], Path]: +def apply_transform(case: Case, transforms: Sequence[PreprocessingTransform]) -> tuple[Sequence[Path], Path]: """ Reads in scans and annotation, applies sequence of transformations, and writes resulting case to disk. Returns tuple with scan paths and corresponding annotation path. @@ -345,7 +345,7 @@ def apply_transform(case: Case, transforms: Sequence[PreprocessingTransform]) -> transforms (Sequence[PreprocessingTransform]): The sequence of transformation to be applied. Returns: - Tuple[Sequence[Path], Path]: A tuple in which the first entry is a sequence of file paths + tuple[Sequence[Path], Path]: A tuple in which the first entry is a sequence of file paths for the scans and the second entry is the file path to the corresponding annotation. Raises: @@ -361,18 +361,18 @@ def apply_transform(case: Case, transforms: Sequence[PreprocessingTransform]) -> def preprocess( - cases: List[Case], transforms: Sequence[PreprocessingTransform], num_threads: int = 4 -) -> Sequence[Tuple[Sequence[Path], Path]]: + cases: list[Case], transforms: Sequence[PreprocessingTransform], num_threads: int = 4 +) -> Sequence[tuple[Sequence[Path], Path]]: """ Preprocesses a list of cases according to the specified transformations. Args: - cases (List[Case]): A list of cases to be preprocessed. + cases (list[Case]): A list of cases to be preprocessed. transforms (Sequence[PreprocessingTransform]): The sequence of transformation to be applied. nums_threads (int): The number of threads to use for preprocessing. Returns: - Sequence[Tuple[Sequence[Path], Path]]: A sequence of tuples in which the first entry is a sequence of + Sequence[tuple[Sequence[Path], Path]]: A sequence of tuples in which the first entry is a sequence of file paths for the scans and the second entry is the file path to the corresponding annotation. Raises: diff --git a/research/picai/data/preprocessing_transforms.py b/research/picai/data/preprocessing_transforms.py index 9d0eda883..de75ca24f 100644 --- a/research/picai/data/preprocessing_transforms.py +++ b/research/picai/data/preprocessing_transforms.py @@ -1,15 +1,13 @@ -from typing import Optional, Tuple, Union - import numpy as np import SimpleITK as sitk def resample_img( image: sitk.Image, - spacing: Tuple[float, float, float], - size: Optional[Tuple[int, int, int]] = None, + spacing: tuple[float, float, float], + size: tuple[int, int, int] | None = None, is_label: bool = False, - pad_value: Optional[Union[float, int]] = 0.0, + pad_value: float | int | None = 0.0, ) -> sitk.Image: """ Resample images to target resolution spacing. @@ -18,12 +16,12 @@ def resample_img( Args: image (sitk.Image): Image to be resized. - spacing (Tuple[float, float, float]): Target spacing between voxels in mm. + spacing (tuple[float, float, float]): Target spacing between voxels in mm. Expected to be in Depth x Height x Width format. - size (Tuple[int, int, int]): Target size in voxels. + size (tuple[int, int, int]): Target size in voxels. Expected to be in Depth x Height x Width format. is_label (bool): Whether or not this is an annotation. - pad_value (Optional[Union[float, int]]): Amount of padding to use. + pad_value (float | int | None): Amount of padding to use. Returns: sitk.Image: The resampled image. @@ -73,9 +71,9 @@ def resample_img( def input_verification_crop_or_pad( image: sitk.Image, - size: Tuple[int, int, int] = (20, 256, 256), - physical_size: Optional[Tuple[float, float, float]] = None, -) -> Tuple[Tuple[int, int, int], Tuple[int, int, int]]: + size: tuple[int, int, int] = (20, 256, 256), + physical_size: tuple[float, float, float] | None = None, +) -> tuple[tuple[int, int, int], tuple[int, int, int]]: """ Calculate target size for cropping and/or padding input image. @@ -83,13 +81,13 @@ def input_verification_crop_or_pad( Args: image (sitk.Image): Image to be resized. - size (Tuple[int, int, int]): Target size in voxels. + size (tuple[int, int, int]): Target size in voxels. Expected to be in Depth x Height x Width format. - physical_size (Tuple[float, float, float]): Target size in mm. (Number of Voxels x Spacing) + physical_size (tuple[float, float, float]): Target size in mm. (Number of Voxels x Spacing) Expected to be in Depth x Height x Width format. Returns: - Tuple[Tuple[int, int, int], Tuple[int, int, int]]: + tuple[tuple[int, int, int], tuple[int, int, int]]: Shape of original image (in convention of SimpleITK (x, y, z) or numpy (z, y, x)) and Size of target image (in convention of SimpleITK (x, y, z) or numpy (z, y, x)) """ @@ -135,8 +133,8 @@ def input_verification_crop_or_pad( def crop_or_pad( image: sitk.Image, - size: Tuple[int, int, int], - physical_size: Optional[Tuple[float, float, float]] = None, + size: tuple[int, int, int], + physical_size: tuple[float, float, float] | None = None, crop_only: bool = False, ) -> sitk.Image: """ @@ -146,9 +144,9 @@ def crop_or_pad( Args: image (sitk.Image): Image to be resized. - size (Tuple[int, int, int]): Target size in voxels. + size (tuple[int, int, int]): Target size in voxels. Expected to be in Depth x Height x Width format. - physical_size (Tuple[float, float, float]): Target size in mm. (Number of Voxels x Spacing) + physical_size (tuple[float, float, float]): Target size in mm. (Number of Voxels x Spacing) Expected to be in Depth x Height x Width format. Returns: @@ -158,7 +156,7 @@ def crop_or_pad( shape, size = input_verification_crop_or_pad(image, size, physical_size) # Since the subarrays are being set in the below for loop - # We have to ensure that they are seperate lists + # We have to ensure that they are separate lists # and not the same reference (ie [[0, 0]] * rank) padding = [[0, 0] for _ in range(len(size))] slicer = [slice(None) for _ in range(len(size))] diff --git a/research/picai/fedavg/README.md b/research/picai/fedavg/README.md index 32bc4a210..57eee86dc 100644 --- a/research/picai/fedavg/README.md +++ b/research/picai/fedavg/README.md @@ -1,6 +1,6 @@ # Running FedAvg Example -The following instructions outline training and validating a simple U-Net model on the Preprocessed PICAI Dataset described in the [PICAI Documentation](/research/picai/preprocessing/README.md) in a federated manner across two clients using FedAvg. The dataset is partitioned randomly in a uniform manner based on the number of clients. The provided script spins up server and clients on the same machine which is demonstrated below. See the main [PICAI Documentation](/research/picai/README.md) for instructions on creating and activating the environment required to exectute the following code. The following commands can must executed from the root directory of the reporsitory. First, spin up the server as follows: +The following instructions outline training and validating a simple U-Net model on the Preprocessed PICAI Dataset described in the [PICAI Documentation](/research/picai/preprocessing/README.md) in a federated manner across two clients using FedAvg. The dataset is partitioned randomly in a uniform manner based on the number of clients. The provided script spins up server and clients on the same machine which is demonstrated below. See the main [PICAI Documentation](/research/picai/README.md) for instructions on creating and activating the environment required to execute the following code. The following commands can must executed from the root directory of the repository. First, spin up the server as follows: ```bash python -m research.picai.fedavg.server --config-path path/to/config.yaml --artifact_dir path/to/artifact_dir --n_client diff --git a/research/picai/fedavg/client.py b/research/picai/fedavg/client.py index bcd227b11..0515a690a 100644 --- a/research/picai/fedavg/client.py +++ b/research/picai/fedavg/client.py @@ -1,7 +1,7 @@ import argparse +from collections.abc import Sequence from logging import INFO from pathlib import Path -from typing import Optional, Sequence, Tuple import flwr as fl import torch @@ -39,12 +39,12 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, overviews_dir: Path = Path("./"), - data_partition: Optional[int] = None, + data_partition: int | None = None, ) -> None: super().__init__( data_path=data_path, @@ -61,7 +61,7 @@ def __init__( self.overviews_dir = overviews_dir self.class_proportions: torch.Tensor - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: train_img_paths, train_seg_paths, class_proportions = get_img_and_seg_paths( self.overviews_dir, int(config["fold_id"]), True ) diff --git a/research/picai/fedavg/server.py b/research/picai/fedavg/server.py index e54ceecb5..6c86836c1 100644 --- a/research/picai/fedavg/server.py +++ b/research/picai/fedavg/server.py @@ -2,7 +2,7 @@ from functools import partial from logging import INFO from pathlib import Path -from typing import Any, Dict +from typing import Any import flwr as fl from flwr.common.logger import log @@ -38,7 +38,7 @@ def fit_config( } -def main(config: Dict[str, Any], server_address: str, n_clients: int, artifact_dir: str) -> None: +def main(config: dict[str, Any], server_address: str, n_clients: int, artifact_dir: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/picai/fl_nnunet/README.md b/research/picai/fl_nnunet/README.md index 1bcdb3133..fe09138b6 100644 --- a/research/picai/fl_nnunet/README.md +++ b/research/picai/fl_nnunet/README.md @@ -27,7 +27,7 @@ Then start a single or multiple clients in different sessions using the followin python -m research.picai.fl_nnunet.start_client --dataset-id 012 ``` -The federated training will commmence once n_clients have been instantiated. +The federated training will commence once n_clients have been instantiated. ## Running on Vector Cluster A slurm script has been made available to launch the experiments on the Vector Cluster. This script will automatically handle relaunching the job if it times out. The script `run_fl_single_node.slrm` first spins up a server and subsequently the clients to perform an FL experiment on the same machine. The commands below should be run from the top level directory: diff --git a/research/picai/fl_nnunet/nnunet_utils.py b/research/picai/fl_nnunet/nnunet_utils.py index c731b098b..efadc9b4e 100644 --- a/research/picai/fl_nnunet/nnunet_utils.py +++ b/research/picai/fl_nnunet/nnunet_utils.py @@ -4,7 +4,7 @@ from collections.abc import Callable from enum import Enum from logging import WARNING, Logger -from typing import Any, Dict, List, Tuple, Union +from typing import Any import numpy as np import torch @@ -91,8 +91,8 @@ def new_fn(*args: Any, **kwargs: Any) -> Any: # The two convert deepsupervision methods are necessary because fl4health requires # predictions, targets and inputs to be single torch.Tensors or Dicts of torch.Tensors def convert_deepsupervision_list_to_dict( - tensor_list: Union[List[torch.Tensor], Tuple[torch.Tensor]], num_spatial_dims: int -) -> Dict[str, torch.Tensor]: + tensor_list: list[torch.Tensor] | tuple[torch.Tensor], num_spatial_dims: int +) -> dict[str, torch.Tensor]: """ Converts a list of torch.Tensors to a dictionary. Names the keys for each tensor based on the spatial resolution of the tensor and its @@ -101,12 +101,12 @@ def convert_deepsupervision_list_to_dict( spatial dimensions of the tensors are last. Args: - tensor_list (List[torch.Tensor]): A list of tensors, usually either + tensor_list (list[torch.Tensor]): A list of tensors, usually either nnunet model outputs or targets, to be converted into a dictionary num_spatial_dims (int): The number of spatial dimensions. Assumes the spatial dimensions are last Returns: - Dict[str, torch.Tensor]: A dictionary containing the tensors as + dict[str, torch.Tensor]: A dictionary containing the tensors as values where the keys are 'i-XxYxZ' where i was the tensor's index in the list and X,Y,Z are the spatial dimensions of the tensor """ @@ -121,19 +121,19 @@ def convert_deepsupervision_list_to_dict( return tensors -def convert_deepsupervision_dict_to_list(tensor_dict: Dict[str, torch.Tensor]) -> List[torch.Tensor]: +def convert_deepsupervision_dict_to_list(tensor_dict: dict[str, torch.Tensor]) -> list[torch.Tensor]: """ Converts a dictionary of tensors back into a list so that it can be used by nnunet deep supervision loss functions Args: - tensor_dict (Dict[str, torch.Tensor]): Dictionary containing + tensor_dict (dict[str, torch.Tensor]): Dictionary containing torch.Tensors. The key values must start with 'X-' where X is an integer representing the index at which the tensor should be placed in the output list Returns: - List[torch.Tensor]: A list of torch.Tensors + list[torch.Tensor]: A list of torch.Tensors """ sorted_list = sorted(tensor_dict.items(), key=lambda x: int(x[0].split("-")[0])) return [tensor for key, tensor in sorted_list] @@ -142,26 +142,22 @@ def convert_deepsupervision_dict_to_list(tensor_dict: Dict[str, torch.Tensor]) - class nnUNetDataLoaderWrapper(DataLoader): def __init__( self, - nnunet_augmenter: Union[SingleThreadedAugmenter, NonDetMultiThreadedAugmenter, MultiThreadedAugmenter], - nnunet_config: Union[NnunetConfig, str], + nnunet_augmenter: SingleThreadedAugmenter | NonDetMultiThreadedAugmenter | MultiThreadedAugmenter, + nnunet_config: NnunetConfig | str, infinite: bool = False, ) -> None: """ - Wraps nnunet dataloader classes using the pytorch dataloader to make - them pytorch compatible. Also handles some unique stuff specific to - nnunet such as deep supervision and infinite dataloaders. The nnunet - dataloaders should only be used for training and validation, not final testing. + Wraps nnunet dataloader classes using the pytorch dataloader to make them pytorch compatible. Also handles + some unique stuff specific to nnunet such as deep supervision and infinite dataloaders. The nnunet dataloaders + should only be used for training and validation, not final testing. Args: - nnunet_dataloader (Union[SingleThreadedAugmenter, - NonDetMultiThreadedAugmenter]): The dataloader used by nnunet - nnunet_config (NnUNetConfig): The nnunet config. Enum type helps - ensure that nnunet config is valid - infinite (bool, optional): Whether or not to treat the dataset - as infinite. The dataloaders sample data with replacement - either way. The only difference is that if set to False, a - StopIteration is generated after num_samples/batch_size steps. - Defaults to False. + nnunet_dataloader (SingleThreadedAugmenter | NonDetMultiThreadedAugmenter | MultiThreadedAugmenter): The + dataloader used by nnunet + nnunet_config (NnUNetConfig): The nnunet config. Enum type helps ensure that nnunet config is valid + infinite (bool, optional): Whether or not to treat the dataset as infinite. The dataloaders sample data + with replacement either way. The only difference is that if set to False, a StopIteration is + generated after num_samples/batch_size steps. Defaults to False. """ # The augmenter is a wrapper on the nnunet dataloader self.nnunet_augmenter = nnunet_augmenter @@ -182,7 +178,7 @@ def __init__( self.current_step = 0 self.infinite = infinite - def __next__(self) -> Tuple[torch.Tensor, Union[torch.Tensor, Dict[str, torch.Tensor]]]: + def __next__(self) -> tuple[torch.Tensor, torch.Tensor | dict[str, torch.Tensor]]: if not self.infinite and self.current_step == self.__len__(): self.reset() raise StopIteration # Raise stop iteration after epoch has completed @@ -192,7 +188,7 @@ def __next__(self) -> Tuple[torch.Tensor, Union[torch.Tensor, Dict[str, torch.Te # Note: When deep supervision is on, target is a list of segmentations at various scales # nnUNet has a wrapper for loss functions to enable deep supervision inputs: torch.Tensor = batch["data"] - targets: Union[torch.Tensor, List[torch.Tensor]] = batch["target"] + targets: torch.Tensor | list[torch.Tensor] = batch["target"] if isinstance(targets, list): target_dict = convert_deepsupervision_list_to_dict(targets, self.num_spatial_dims) return inputs, target_dict @@ -254,7 +250,7 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: class StreamToLogger(io.StringIO): - def __init__(self, logger: Logger, level: Union[LogLevel, int]) -> None: + def __init__(self, logger: Logger, level: LogLevel | int) -> None: """ File-like stream object that redirects writes to a logger. Useful for redirecting stdout to a logger. diff --git a/research/picai/fl_nnunet/start_client.py b/research/picai/fl_nnunet/start_client.py index 8aa656b33..0482eb54f 100644 --- a/research/picai/fl_nnunet/start_client.py +++ b/research/picai/fl_nnunet/start_client.py @@ -4,7 +4,6 @@ from functools import partial from logging import INFO from pathlib import Path -from typing import Optional, Union from fl4health.checkpointing.checkpointer import PerRoundStateCheckpointer from fl4health.checkpointing.client_module import ClientCheckpointAndStateModule @@ -28,15 +27,15 @@ def main( dataset_id: int, - data_identifier: Optional[str], - plans_identifier: Optional[str], + data_identifier: str | None, + plans_identifier: str | None, always_preprocess: bool, server_address: str, - fold: Union[str, int], + fold: str | int, verbose: bool, compile: bool, - intermediate_client_state_dir: Optional[str] = None, - client_name: Optional[str] = None, + intermediate_client_state_dir: str | None = None, + client_name: str | None = None, ) -> None: # Log device and server address @@ -181,7 +180,7 @@ def main( update_console_handler(level=args.logLevel) # Convert fold to an integer if it is not 'all' - fold: Union[int, str] = "all" if args.fold == "all" else int(args.fold) + fold: int | str = "all" if args.fold == "all" else int(args.fold) main( dataset_id=args.dataset_id, diff --git a/research/picai/fl_nnunet/start_server.py b/research/picai/fl_nnunet/start_server.py index 730d56f7c..fb0437f83 100644 --- a/research/picai/fl_nnunet/start_server.py +++ b/research/picai/fl_nnunet/start_server.py @@ -4,7 +4,6 @@ import warnings from functools import partial from pathlib import Path -from typing import Optional from fl4health.checkpointing.checkpointer import PerRoundStateCheckpointer from fl4health.checkpointing.server_module import NnUnetServerCheckpointAndStateModule @@ -35,9 +34,9 @@ def get_config( n_server_rounds: int, batch_size: int, n_clients: int, - nnunet_plans: Optional[str] = None, - local_epochs: Optional[int] = None, - local_steps: Optional[int] = None, + nnunet_plans: str | None = None, + local_epochs: int | None = None, + local_steps: int | None = None, ) -> Config: # Create config config: Config = { @@ -60,8 +59,8 @@ def get_config( def main( config: dict, server_address: str, - intermediate_server_state_dir: Optional[str] = None, - server_name: Optional[str] = None, + intermediate_server_state_dir: str | None = None, + server_name: str | None = None, ) -> None: # Partial function with everything set except current server round fit_config_fn = partial( diff --git a/research/picai/losses.py b/research/picai/losses.py index e7d73716c..9c0787966 100644 --- a/research/picai/losses.py +++ b/research/picai/losses.py @@ -10,7 +10,7 @@ def __init__(self, alpha: float = 1.0, gamma: float = 1.0, reduction: str = "sum where the scaling factor decays to zero as the confidence in the correct class increases. Args: - alpha (float): The weight assocaited with the the positive class. Usually set inversely proportional + alpha (float): The weight associated with the the positive class. Usually set inversely proportional to the amount of samples in a given class. gamma (float): The exponent to raise the residual between the predicted probability and and ground truth. Higher values of gamma lead to emphasizing the contribution of harder examples in the loss. diff --git a/research/picai/model_utils.py b/research/picai/model_utils.py index 37996a779..dddd4da66 100644 --- a/research/picai/model_utils.py +++ b/research/picai/model_utils.py @@ -1,4 +1,4 @@ -from typing import Optional, Sequence, Tuple +from collections.abc import Sequence import torch import torch.nn as nn @@ -6,18 +6,18 @@ def get_model( - device: Optional[torch.device] = None, + device: torch.device | None = None, model_type: str = "unet", spatial_dims: int = 3, in_channels: int = 3, out_channels: int = 2, channels: Sequence[int] = [32, 64, 128, 256, 512, 1024], - strides: Sequence[Tuple[int, ...]] = [(2, 2, 2), (1, 2, 2), (1, 2, 2), (1, 2, 2), (2, 2, 2)], + strides: Sequence[tuple[int, ...]] = [(2, 2, 2), (1, 2, 2), (1, 2, 2), (1, 2, 2), (2, 2, 2)], ) -> nn.Module: """Select neural network architecture for given run""" if model_type == "unet": - # ignore typing for strides argument because Sequence[Tuple[int, ...]] is valid input type + # ignore typing for strides argument because Sequence[tuple[int, ...]] is valid input type # https://docs.monai.io/en/stable/networks.html#unet model = UNet( spatial_dims=spatial_dims, diff --git a/research/picai/monai_scripts/auto3dseg.py b/research/picai/monai_scripts/auto3dseg.py index 3c1383f23..2d9d8f5e0 100644 --- a/research/picai/monai_scripts/auto3dseg.py +++ b/research/picai/monai_scripts/auto3dseg.py @@ -6,7 +6,6 @@ import argparse import os from os.path import join -from typing import Dict, Optional import numpy as np import yaml @@ -14,7 +13,7 @@ from monai.apps.auto3dseg.auto_runner import AutoRunner -def gen_dataset_list(data_dir: str, output_path: Optional[str] = None, ext: str = ".nii.gz") -> str: +def gen_dataset_list(data_dir: str, output_path: str | None = None, ext: str = ".nii.gz") -> str: """ Generates a MONAI dataset list for an nnUNet structured dataset @@ -24,7 +23,7 @@ def gen_dataset_list(data_dir: str, output_path: Optional[str] = None, ext: str Args: data_dir (str): Path to the nnUNet_raw dataset. - output_path (Optional[str]): Where and what to save the file as. Must be a json. + output_path (str | None): Where and what to save the file as. Must be a json. Default is to save as datalist.json in the data_dir Returns: @@ -41,11 +40,11 @@ def gen_dataset_list(data_dir: str, output_path: Optional[str] = None, ext: str # Initialize datalist # The values to the testing and training keys should be a list of dictionaries # where each dictionary contains information about a single case - datalist: Dict[str, list] = {"testing": [], "training": []} + datalist: dict[str, list] = {"testing": [], "training": []} # nnUNet datasets store images as unique-case-identifier_xxxx.ext - # xxxx is a 4 digit integer representing the channel/modaility. - # ext is the file extenstion + # xxxx is a 4 digit integer representing the channel/modality. + # ext is the file extention # Labels are stored as unique-case-identifier.ext as they do not have multiple channels if os.path.exists(test_dir): # nnUNet Datasets do not always have test sets diff --git a/research/picai/monai_scripts/readme.md b/research/picai/monai_scripts/readme.md index cf804b42d..7ec548e5e 100644 --- a/research/picai/monai_scripts/readme.md +++ b/research/picai/monai_scripts/readme.md @@ -6,7 +6,7 @@ This directory contains work to integrate the Monai AutoSeg3d pipeline. This file runs the monai autoseg3d pipeline on an nnunet structured dataset. -Autoseg3d is designed to work with the very common [MSD](http://medicaldecathlon.com/) dataset format. The nnunet derives it's dataset format from the MSD guidelines but alters it slightly. One of the main changes is that different modalities/channels are stored as different files. Although nnunet provides a script to convert an MSD dataset into an nnunet dataset, we'd rather not have multiple local copies of the same raw dataset. It's easier to get monai's autoseg3d to work with nnunet datasets than to get nnunet to work with MSD datasets, therefore we choose the nnunet dataset structure as our standard that will work with everything. In the standard monai workflow, a path to the image file is set as the value to the 'image' keys in the datalist json. To make autoseg3d run on nnunet datasets, one must simply replace this value with a list of filepaths, where each filepath points to one of the image channels/modalities for a particular image. Importantly, the order of the channels in this list must be consistent for all images as monai just concatenates seperate images at the beginning of the pipeline. +Autoseg3d is designed to work with the very common [MSD](http://medicaldecathlon.com/) dataset format. The nnunet derives it's dataset format from the MSD guidelines but alters it slightly. One of the main changes is that different modalities/channels are stored as different files. Although nnunet provides a script to convert an MSD dataset into an nnunet dataset, we'd rather not have multiple local copies of the same raw dataset. It's easier to get monai's autoseg3d to work with nnunet datasets than to get nnunet to work with MSD datasets, therefore we choose the nnunet dataset structure as our standard that will work with everything. In the standard monai workflow, a path to the image file is set as the value to the 'image' keys in the datalist json. To make autoseg3d run on nnunet datasets, one must simply replace this value with a list of filepaths, where each filepath points to one of the image channels/modalities for a particular image. Importantly, the order of the channels in this list must be consistent for all images as monai just concatenates separate images at the beginning of the pipeline. Use the ```--help``` flag for a list of arguments that can be passed to the autoseg3d.py script. To run the default autoseg3d pipeline on a nnunet structured dataset run the following command diff --git a/research/picai/nnunet_scripts/README.md b/research/picai/nnunet_scripts/README.md index 232cae84f..b93c048a3 100644 --- a/research/picai/nnunet_scripts/README.md +++ b/research/picai/nnunet_scripts/README.md @@ -88,12 +88,12 @@ The PICAI competition from which the picai datasets originates used the followin $$PICAI\ Score=\frac{AUROC+AP}{2} $$ -Where AUROC is the Area Under the Reciever Operating Characteristic curve and AP is the Average Precision. +Where AUROC is the Area Under the Receiver Operating Characteristic curve and AP is the Average Precision. The eval.py script computes all of these metrics plus a few more under the hood such as: - Precision Recall (PR) Curve -- Reciever Operating Characteristic (ROC) curve -- Free-Response Reciever Operating Characteristic (FROC) curve +- Receiver Operating Characteristic (ROC) curve +- Free-Response Receiver Operating Characteristic (FROC) curve For more information on the evaluation metrics see the [picai_eval](https://github.com/DIAGNijmegen/picai_eval) repo diff --git a/research/picai/nnunet_scripts/eval.py b/research/picai/nnunet_scripts/eval.py index 26c7f3e11..fd14c989c 100644 --- a/research/picai/nnunet_scripts/eval.py +++ b/research/picai/nnunet_scripts/eval.py @@ -4,10 +4,11 @@ import os import time import warnings +from collections.abc import Callable, Hashable from logging import INFO from os.path import exists, join from pathlib import Path -from typing import Any, Callable, Dict, Hashable, List, Optional, Tuple, Union +from typing import Any import numpy as np import SimpleITK as sitk @@ -25,14 +26,14 @@ # logger.setLevel(multiprocessing.SUBDEBUG) -def read_image(path: Union[Path, str], npz_key: Optional[str] = None) -> NDArray: +def read_image(path: Path | str, npz_key: str | None = None) -> NDArray: """Taken from picai eval. Had to change one line so that they wouldn't throw away additional channels. They were assuming binary segmentation. Also made it work for any npz file Args: - path (Union[Path, str]): Path to the image file - npz_key (Optional[str]): If the file type is .npz, then a key must be + path (Path | str): Path to the image file + npz_key (str | None): If the file type is .npz, then a key must be provided to access the numpy array from the NpzFile object """ if isinstance(path, Path): @@ -54,8 +55,8 @@ def read_image(path: Union[Path, str], npz_key: Optional[str] = None) -> NDArray def scan_folder_for_cases( - folder: Union[str, Path], postfixes: Optional[List[str]] = None, extensions: Optional[List[str]] = None -) -> List[str]: + folder: str | Path, postfixes: list[str] | None = None, extensions: list[str] | None = None +) -> list[str]: if postfixes is None: postfixes = [""] if extensions is None: @@ -79,12 +80,12 @@ def scan_folder_for_cases( def get_case_files( - folder: Union[str, Path], - case_identifiers: List[str], - postfixes: Optional[List[str]] = None, - extensions: Optional[List[str]] = None, + folder: str | Path, + case_identifiers: list[str], + postfixes: list[str] | None = None, + extensions: list[str] | None = None, basename: bool = False, -) -> List[str]: +) -> list[str]: if postfixes is None: postfixes = [""] if extensions is None: @@ -108,31 +109,31 @@ def get_case_files( def generate_detection_map( - probability_map: Union[NDArray, str, Path], - save_path: Union[str, Path], - npz_key: Optional[str] = "probabilities", - transforms: Optional[List[Callable[[NDArray], NDArray]]] = None, + probability_map: NDArray | str | Path, + save_path: str | Path, + npz_key: str | None = "probabilities", + transforms: list[Callable[[NDArray], NDArray]] | None = None, **kwargs: Any, ) -> None: """ Generates a detection map from a probability map by doing lesion - extraction. Supports multiclass probability maps by extracting a seperate + extraction. Supports multiclass probability maps by extracting a separate lesion detection map for each class/channel. Args: - probability_map (Union[NDArray, str, Path]): One hot encoded + probability_map (NDArray | str | Path]): One hot encoded probability map for a single image. Should be shape (num_classes, ...). num_classes includes the background class. If the probability maps are .npz files then a npz_key should be provided - save_path (Union[str, Path]): Path to save the detection map. Will be + save_path (str | Path): Path to save the detection map. Will be saved as a numpy compressed .npz file under key 'detection_map'. Detection map will have shape (num_classes - 1, ...) since a detection map is not computed for the background class. - npz_key (Optional[str]): If probability_map is a path to a .npz + npz_key (str | None): If probability_map is a path to a .npz file then a key must be provided to access the numpy array from the NpzFile object. Defaults to 'probabilities' - transforms (Optional[List[Callable[[NDArray], NDArray]]]): A list of + transforms (list[Callable[[NDArray], NDArray]] | None): A list of transform functions to apply to the probability map before passing it to the lesion extraction method. The functions will be applied in the order they are given. Can be used, for example, to one hot @@ -159,16 +160,16 @@ def generate_detection_map( def generate_detection_maps( - input_folder: Union[str, Path], - output_folder: Union[str, Path], - transforms: Optional[List[Callable[[NDArray], NDArray]]] = None, - npz_key: Optional[str] = "probabilities", - num_threads: Optional[int] = None, - postfixes: Optional[List[str]] = None, - extensions: Optional[List[str]] = None, + input_folder: str | Path, + output_folder: str | Path, + transforms: list[Callable[[NDArray], NDArray]] | None = None, + npz_key: str | None = "probabilities", + num_threads: int | None = None, + postfixes: list[str] | None = None, + extensions: list[str] | None = None, verbose: bool = True, **kwargs: Any, -) -> List[str]: +) -> list[str]: """ Extracts lesions from predicted probability maps and saves the predicted lesions as detection maps in the output folder. @@ -178,7 +179,7 @@ def generate_detection_maps( Args: input_folder (_type_): Path to the folder containing the predicted - probability maps. Each probability map must be saved as a seperate + probability maps. Each probability map must be saved as a separate file where the files basename will be used ato derive the case identifier. The probability maps must be one hot encoded and have shape (num_classes, ...) where num_classes includes the background @@ -191,7 +192,7 @@ def generate_detection_maps( generated for the background class. Ie. (num_lesion_classes == num_classes - 1). Note that this method will overwrite existing files if they already exist - transforms (Optional[List[Callable[[NDArray], NDArray]]]): A list of + transforms (list[Callable[[NDArray], NDArray]] | None): A list of transform functions to apply to the probability map before passing it to the lesion extraction method. The functions will be applied in the order they are given. Can be used, for example, to one hot @@ -207,12 +208,12 @@ def generate_detection_maps( num_threads (int, optional): The maximum number of threads to allow when extracting the detection maps. If left as None, the number of threads is automatically determined. - postfixes (Optional[List[str]]): File postfixes (endings after the + postfixes (list[str] | None): File postfixes (endings after the unique identifier but before the extension). Detection maps will only be generated for files with one or more of the specified postfixes. Postfixes are omitted from the returned case identifiers. Defaults to [""]. - extensions (List[str], optional): File extensions to allow. Detection + extensions (list[str], optional): File extensions to allow. Detection maps will only be generated for files with on of the specified file extensions. File extensions are omitted from the returned case identifiers. Defaults to [".npz", ".npy", ".nii.gz", ".nii", @@ -223,7 +224,7 @@ def generate_detection_maps( function from the report_guided_annotation API. Returns: - List[str]: A list of unique case identifiers. The case identifiers are + list[str]: A list of unique case identifiers. The case identifiers are the file basenames of the chosen input probability map files stripped of the the specified postfixes and their extension. """ @@ -264,8 +265,8 @@ def one_hot_ndarray(input: NDArray, num_classes: int) -> NDArray: def evaluate_case_multichannel( - detection_map: Union[NDArray, str, Path], ground_truth: Union[NDArray, str, Path], **kwargs: Any -) -> Tuple[List[Tuple[int, float, float]], float, float, str]: + detection_map: NDArray | str | Path, ground_truth: NDArray | str | Path, **kwargs: Any +) -> tuple[list[tuple[int, float, float]], float, float, str]: if isinstance(detection_map, (str, Path)): detection_map = read_image(detection_map, npz_key="detection_map") if isinstance(ground_truth, (str, Path)): @@ -308,11 +309,11 @@ def evaluate_case_multichannel( def get_picai_metrics( - detection_map_folder: Union[str, Path], - ground_truth_annotations_folder: Union[str, Path], - num_threads: Optional[int] = None, - sample_weights: Optional[List[float]] = None, - case_identifiers: Optional[List[str]] = None, + detection_map_folder: str | Path, + ground_truth_annotations_folder: str | Path, + num_threads: int | None = None, + sample_weights: list[float] | None = None, + case_identifiers: list[str] | None = None, verbose: bool = True, **kwargs: Any, ) -> PicaiEvalMetrics: @@ -322,13 +323,13 @@ def get_picai_metrics( allow multiclass evaluation Args: - detection_maps_folder (Union[str, Path]): Path to the folder + detection_maps_folder (str | Path): Path to the folder containing the detection maps ground_truth_annotations_folder (NDArray): The ground truth annotations. Must have shape (num_samples, num_classes or num_lesion_classes, ...). If num_classes is provided, the function will attempt to remove the background class from index 0 for you - case_identifiers (Optional[Iterable[str]], optional): A list of case + case_identifiers (list[str] | None, optional): A list of case identifiers. If not provided the subjects will be identified by their index Defaults to None. verbose (bool): Whether or not to print a log statement summarizing @@ -365,13 +366,13 @@ def get_picai_metrics( sample_weights = [1] * len(case_ids) # Initialize variables to hold results - case_targets: Dict[Hashable, int] = {} - case_weights: Dict[Hashable, float] = {} - case_preds: Dict[Hashable, float] = {} - lesion_results: Dict[Hashable, List[Tuple[int, float, float]]] = {} - lesion_weights: Dict[Hashable, List[float]] = {} + case_targets: dict[Hashable, int] = {} + case_weights: dict[Hashable, float] = {} + case_preds: dict[Hashable, float] = {} + lesion_results: dict[Hashable, list[tuple[int, float, float]]] = {} + lesion_weights: dict[Hashable, list[float]] = {} - # Evaluation must be calculated seperately for each class + # Evaluation must be calculated separately for each class with concurrent.futures.ThreadPoolExecutor(num_threads) as pool: futures = { pool.submit(evaluate_case_multichannel, detection_map=det_map, ground_truth=gt, idx=case, **kwargs): case diff --git a/research/picai/nnunet_scripts/nnunet_launch.slrm b/research/picai/nnunet_scripts/nnunet_launch.slrm index a0b95d693..ad47383f6 100644 --- a/research/picai/nnunet_scripts/nnunet_launch.slrm +++ b/research/picai/nnunet_scripts/nnunet_launch.slrm @@ -16,7 +16,7 @@ UNET_CONFIG=$2 FOLD=$3 VENV_PATH=$4 PLANS_IDENTIFIER=${5:-'plans'} # Default value of plans string -PRETAINED_WEIGHTS=${6:-''} # Default value of empty string +PRETRAINED_WEIGHTS=${6:-''} # Default value of empty string # Set environment paths that nnUNet expects export nnUNet_raw="/ssd003/projects/aieng/public/PICAI/nnUNet/nnUNet_raw" diff --git a/research/picai/nnunet_scripts/nnunet_launch_fold_experiment.slrm b/research/picai/nnunet_scripts/nnunet_launch_fold_experiment.slrm index f0a8df2a1..66f7b643a 100644 --- a/research/picai/nnunet_scripts/nnunet_launch_fold_experiment.slrm +++ b/research/picai/nnunet_scripts/nnunet_launch_fold_experiment.slrm @@ -13,7 +13,7 @@ DATASET_NAME=$1 UNET_CONFIG=$2 VENV_PATH=$3 PLANS_IDENTIFIER=${4:-'plans'} -PRETAINED_WEIGHTS=${5:-''} # Default value of empty string +PRETRAINED_WEIGHTS=${5:-''} # Default value of empty string SOURCE_DATASET_NAME=${6:-''} # Default value of empty string SOURCE_PLANS_IDENTIFIER=${7:-'plans'} # Default value of plans string diff --git a/research/picai/nnunet_scripts/old_nnunet_inference/eval_old.py b/research/picai/nnunet_scripts/old_nnunet_inference/eval_old.py index 100bf824a..253f00550 100644 --- a/research/picai/nnunet_scripts/old_nnunet_inference/eval_old.py +++ b/research/picai/nnunet_scripts/old_nnunet_inference/eval_old.py @@ -3,8 +3,9 @@ import os import warnings from collections import defaultdict +from collections.abc import Iterable from os.path import join -from typing import Any, Dict, Iterable, List, Optional +from typing import Any import numpy as np from numpy.typing import NDArray @@ -20,21 +21,21 @@ def load_images_from_folder( folder: str, - case_identifiers: List[str], - postfixes: Optional[List[str]] = None, - extensions: List[str] = [".nii.gz", ".nii", ".mha", ".mhd", ".npz", ".npy"], + case_identifiers: list[str], + postfixes: list[str] | None = None, + extensions: list[str] = [".nii.gz", ".nii", ".mha", ".mhd", ".npz", ".npy"], ) -> NDArray: """ Loads images from a folder given a list of case identifiers Args: folder (str): The folder containing the images - case_identifiers (List[str]): A list of case identifiers for each + case_identifiers (list[str]): A list of case identifiers for each file. Typically just the filenames without the extension - postfixes (Optional[List[str]], optional): A list of strings to append + postfixes (list[str] | None, optional): A list of strings to append to the case identifiers when looking for files. For example '_labels'. Defaults to None. - extensions (List[str], optional): A list of possible image extensions. + extensions (list[str], optional): A list of possible image extensions. Defaults to [".nii.gz", ".nii", ".mha", ".mhd", ".npz", ".npy"]. Returns: @@ -86,7 +87,7 @@ def get_detection_maps(probability_maps: NDArray) -> NDArray: def get_picai_metrics( detection_maps: NDArray, ground_truth_annotations: NDArray, - case_identifiers: Optional[Iterable[str]] = None, + case_identifiers: Iterable[str] | None = None, **kwargs: Any ) -> PicaiEvalMetrics: """ @@ -102,7 +103,7 @@ class should not be included in the detection maps have shape (num_samples, num_classes or num_lesion_classes, ...). If num_classes is provided, the function will attempt to remove the background class from index 0 for you - case_identifiers (Optional[Iterable[str]], optional): A list of case + case_identifiers (Iterable[str] | None, optional): A list of case identifiers. If not provided the subjects will be identified by their index Defaults to None. **kwargs: Keyword arguments for the picai_eval.evaluate function @@ -129,9 +130,9 @@ class should not be included in the detection maps detection_maps.shape == ground_truth_annotations.shape ), "Got unexpected shapes for detection maps and ground truth annotations" - # Evaluation must be calculated seperately for each class + # Evaluation must be calculated separately for each class num_classes = detection_maps.shape[1] - metrics: List[PicaiEvalMetrics] = [] + metrics: list[PicaiEvalMetrics] = [] for cls in range(num_classes): metrics.append( evaluate( @@ -144,10 +145,10 @@ class should not be included in the detection maps subject_list = metrics[0].subject_list assert isinstance(subject_list, list), "Got unexpected subject list from picai eval metrics object" - lesion_results: Dict[Any, list] = defaultdict(list) - lesion_weights: Dict[Any, list] = defaultdict(list) - case_targets: Dict[Any, int] = defaultdict(int) - case_preds: Dict[Any, float] = defaultdict(float) + lesion_results: dict[Any, list] = defaultdict(list) + lesion_weights: dict[Any, list] = defaultdict(list) + case_targets: dict[Any, int] = defaultdict(int) + case_preds: dict[Any, float] = defaultdict(float) for s in subject_list: # Ignoring mypy errors here for now because i don't know how to get around them [lesion_results[s].extend(m.lesion_results[s]) for m in metrics] # type: ignore diff --git a/research/picai/nnunet_scripts/old_nnunet_inference/predict_and_eval_old.py b/research/picai/nnunet_scripts/old_nnunet_inference/predict_and_eval_old.py index 73add9f44..6ef40656a 100644 --- a/research/picai/nnunet_scripts/old_nnunet_inference/predict_and_eval_old.py +++ b/research/picai/nnunet_scripts/old_nnunet_inference/predict_and_eval_old.py @@ -4,7 +4,6 @@ import warnings from logging import INFO from os.path import exists, join -from typing import Optional import numpy as np from flwr.common.logger import log @@ -21,7 +20,7 @@ def pred_and_eval( config_path: str, inputs_folder: str, labels_folder: str, - output_folder: Optional[str] = None, + output_folder: str | None = None, save_probability_maps: bool = False, save_detection_maps: bool = False, save_annotations: bool = False, @@ -48,13 +47,13 @@ def pred_and_eval( inputs_folder (str): Path to the folder containing the raw input data that has not been processed by nnunet yet. File names must follow the nnunet convention where each channel modality is stored as a - seperate file. File names should be case-identifier_0000 where + separate file. File names should be case-identifier_0000 where 0000 is a 4 digit integer representing the channel/modality of the image. All cases must have the same N channels numbered from 0 to N. labels_folder (str): Path to the folder containing the ground truth annotation maps. File names must match the case identifiers of the input images - output_folder (Optional[str], optional): Path to the output folder. By + output_folder (str | None, optional): Path to the output folder. By default the only output is a 'metrics.json' file containing the evaluation results. If left as none then nothing is saved. Defaults to None. @@ -194,7 +193,7 @@ def main() -> None: type=str, help="""Path to the folder containing the raw input data that has not been processed by nnunet yet. File names must follow the nnunet - convention where each channel modality is stored as a seperate + convention where each channel modality is stored as a separate file. File names should be case-identifier_0000 where 0000 is a 4 digit integer representing the channel/modality of the image. All cases must have the same N channels numbered from 0 to N.""", diff --git a/research/picai/nnunet_scripts/old_nnunet_inference/predict_old.py b/research/picai/nnunet_scripts/old_nnunet_inference/predict_old.py index d99702cfa..0038c4a57 100644 --- a/research/picai/nnunet_scripts/old_nnunet_inference/predict_old.py +++ b/research/picai/nnunet_scripts/old_nnunet_inference/predict_old.py @@ -5,9 +5,10 @@ import os import time import warnings +from collections.abc import Generator from logging import INFO from os.path import basename, isdir, join -from typing import Any, Dict, Generator, List, Optional, Tuple +from typing import Any import numpy as np import torch @@ -17,7 +18,7 @@ with warnings.catch_warnings(): # We get a bunch of scipy deprecation warnings from these packages - # Curiosly this only happens if flwr is imported first + # Curiously this only happens if flwr is imported first # Raised issue https://github.com/MIC-DKFZ/nnUNet/issues/2370 warnings.filterwarnings("ignore", category=DeprecationWarning) import nnunetv2 @@ -34,7 +35,7 @@ class MyNnUNetPredictor(nnUNetPredictor): def predict_from_data_iterator( self, data_iterator: Generator, return_probabilities: bool = False, num_processes: int = default_num_processes - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Override of the predict from data iterator class so that we can have it return the model outputs along with their output filenames and data @@ -124,13 +125,13 @@ def predict_from_data_iterator( return return_dict -def get_predictor(ckpt_list: List[str], nnunet_config: str, dataset_json: dict, plans: dict) -> nnUNetPredictor: +def get_predictor(ckpt_list: list[str], nnunet_config: str, dataset_json: dict, plans: dict) -> nnUNetPredictor: """ Returns an initialized nnUNetPredictor for a set of nnunet models with the same config and architecture Args: - ckpt_list (List[str]): A list containing the paths to the checkpoint + ckpt_list (list[str]): A list containing the paths to the checkpoint files for the nnunet models nnunet_config (str): The nnunet config of the the models specific in ckpt_list. @@ -148,12 +149,12 @@ def get_predictor(ckpt_list: List[str], nnunet_config: str, dataset_json: dict, """ # Helper function to make code cleaner - def check_for_ckpt_info(model: dict) -> Tuple[str, bool]: + def check_for_ckpt_info(model: dict) -> tuple[str, bool]: """ Checks model dict for trainer name and inference_allowed_mirroring_axes Returns: - Tuple[Optional[str], bool]: Tuple with elements trainer_name and + tuple[str | None, bool]: Tuple with elements trainer_name and inference_allowed_mirroring_axes. Defaults to ('nnUNetTrainer, False) """ @@ -166,7 +167,7 @@ def check_for_ckpt_info(model: dict) -> Tuple[str, bool]: return trainer_name, inference_allowed_mirror_axes - # Create unintialized predictor instance + # Create uninitialized predictor instance predictor = MyNnUNetPredictor(verbose=False, verbose_preprocessing=False, allow_tqdm=False) # Get parameters for each model and maybe some predictor init parameters @@ -225,10 +226,10 @@ def check_for_ckpt_info(model: dict) -> Tuple[str, bool]: def predict( config_path: str, input_folder: str, - probs_folder: Optional[str] = None, - annotations_folder: Optional[str] = None, + probs_folder: str | None = None, + annotations_folder: str | None = None, verbose: bool = True, -) -> Tuple[NDArray, NDArray, List[str]]: +) -> tuple[NDArray, NDArray, list[str]]: """ Uses multiprocessing to quickly do model inference for a single model, a group of models with the same nnunet config or an ensemble of different @@ -248,16 +249,16 @@ def predict( create a new json yourself with the 'label' and 'file_ending' keys and their corresponding values as specified by nnunet input_folder (str): Path to the folder containing the raw input data - that has notbeen processed by nnunet yet. File names must follow the + that has not been processed by nnunet yet. File names must follow the nnunet convention where each channel modality is stored as a - seperate file.File names should be case-identifier_0000 where 0000 + separate file.File names should be case-identifier_0000 where 0000 is a 4 digit integer representing the channel/modality of the image. All cases must have the same number of channels N numbered from 0 to N. - preds_folder (Optional[str]): [OPTIONAL] Path to the output folder to + preds_folder (str | None): [OPTIONAL] Path to the output folder to save the model predicted probabilities. If not provided the probabilities are not saved - annotations_folder (Optional[str]): [OPTIONAL] Path to the output + annotations_folder (str | None): [OPTIONAL] Path to the output folder to save the model predicted annotations. If not provided the annotations are not saved Returns: @@ -266,7 +267,7 @@ def predict( NDArray[int]: a numpy array with a single predicted annotation map for each input image. Unlike the predicted probabilities these are NOT one hot encoded. Shape: (num_samples, spatial_dims...) - List[str]: A list containing the unique case identifier for + list[str]: A list containing the unique case identifier for each prediction """ t_start = time.time() @@ -355,7 +356,7 @@ def predict( log(INFO, f"\tNum Classes: {shape[1]}") log(INFO, f"\tSpatial Dimensions {shape[2:]}") - # Save predicted probabilites if output_folder was provided + # Save predicted probabilities if output_folder was provided if probs_folder is not None: t = time.time() for pred, case in zip(final_probs, case_identifiers): @@ -428,7 +429,7 @@ def main() -> None: type=str, help="""Path to the folder containing the raw input data that has not been processed by nnunet yet. File names must follow the nnunet - convention where each channel modality is stored as a seperate + convention where each channel modality is stored as a separate file. File names should be case-identifier_0000 where 0000 is a 4 digit integer representing the channel/modality of the image. All cases must have the same N channels numbered from 0 to N.""", diff --git a/research/picai/nnunet_scripts/predict.py b/research/picai/nnunet_scripts/predict.py index ada061a90..f859747f7 100644 --- a/research/picai/nnunet_scripts/predict.py +++ b/research/picai/nnunet_scripts/predict.py @@ -7,7 +7,6 @@ import warnings from logging import INFO from os.path import basename, exists, isdir, join -from typing import List, Tuple import numpy as np import torch @@ -18,7 +17,7 @@ with warnings.catch_warnings(): # We get a bunch of scipy deprecation warnings from these packages - # Curiosly this only happens if flwr is imported first + # Curiously this only happens if flwr is imported first # Raised issue https://github.com/MIC-DKFZ/nnUNet/issues/2370 warnings.filterwarnings("ignore", category=DeprecationWarning) import nnunetv2 @@ -35,13 +34,13 @@ def yaml_join(loader: yaml.Loader, node: yaml.SequenceNode) -> str: return os.path.join(*seq) -def get_predictor(ckpt_list: List[str], nnunet_config: str, dataset_json: dict, plans: dict) -> nnUNetPredictor: +def get_predictor(ckpt_list: list[str], nnunet_config: str, dataset_json: dict, plans: dict) -> nnUNetPredictor: """ Returns an initialized nnUNetPredictor for a set of nnunet models with the same config and architecture Args: - ckpt_list (List[str]): A list containing the paths to the checkpoint + ckpt_list (list[str]): A list containing the paths to the checkpoint files for the nnunet models nnunet_config (str): The nnunet config of the the models specific in ckpt_list. @@ -57,12 +56,12 @@ def get_predictor(ckpt_list: List[str], nnunet_config: str, dataset_json: dict, """ # Helper function to make code cleaner - def check_for_ckpt_info(model: dict) -> Tuple[str, bool]: + def check_for_ckpt_info(model: dict) -> tuple[str, bool]: """ Checks model dict for trainer name and inference_allowed_mirroring_axes Returns: - Tuple[Optional[str], bool]: Tuple with elements trainer_name and + tuple[str | None, bool]: Tuple with elements trainer_name and inference_allowed_mirroring_axes. Defaults to ('nnUNetTrainer, False) """ @@ -75,7 +74,7 @@ def check_for_ckpt_info(model: dict) -> Tuple[str, bool]: return trainer_name, inference_allowed_mirror_axes - # Create unintialized predictor instance + # Create uninitialized predictor instance predictor = nnUNetPredictor(verbose=False, verbose_preprocessing=False, allow_tqdm=False) # Get parameters for each model and maybe some predictor init parameters @@ -163,14 +162,14 @@ def predict( base_path: &base_path /home/user/data dataset_json: !join [*base_path, 'PICAI', 'dataset.json'] input_folder (str): Path to the folder containing the raw input data - that has notbeen processed by nnunet yet. File names must follow the + that has not been processed by nnunet yet. File names must follow the nnunet convention where each channel modality is stored as a - seperate file.File names should be case-identifier_0000 where 0000 + separate file.File names should be case-identifier_0000 where 0000 is a 4 digit integer representing the channel/modality of the image. All cases must have the same number of channels N numbered from 0 to N. output_folder (str): Path to save the predicted probabilities and - predicted annotations. Each will be stored in a seperate + predicted annotations. Each will be stored in a separate subdirectory. Probabilities will be stored as .npz files. The NPZ file object will have the key 'probabilities'. The predicted annotations will be saved as the original input image @@ -180,7 +179,7 @@ def predict( annotations_folder_name (str): What to name the folder within the output folder that the predicted annotations will be stored in """ - # Note: I should split output folder into two seperate paths for model outputs + # Note: I should split output folder into two separate paths for model outputs t_start = time.time() # Add !join constructor to yaml so that config files can be more readable @@ -342,7 +341,7 @@ def main() -> None: type=str, help="""Path to the folder containing the raw input data that has not been processed by nnunet yet. File names must follow the nnunet - convention where each channel modality is stored as a seperate + convention where each channel modality is stored as a separate file. File names should be case-identifier_0000 where 0000 is a 4 digit integer representing the channel/modality of the image. All cases must have the same N channels numbered from 0 to N.""", @@ -352,7 +351,7 @@ def main() -> None: required=True, type=str, help="""[OPTIONAL] Path to save the predicted probabilities and - predicted annotations. Each will be stored in a seperate + predicted annotations. Each will be stored in a separate subdirectory. Probabilities will be stored as .npz files. The NPZ file object will have the key 'probabilities'. The predicted annotations will be saved as the original input image diff --git a/research/picai/nnunet_scripts/predict_and_eval.py b/research/picai/nnunet_scripts/predict_and_eval.py index 14e74c7d6..847c10f0f 100644 --- a/research/picai/nnunet_scripts/predict_and_eval.py +++ b/research/picai/nnunet_scripts/predict_and_eval.py @@ -38,7 +38,7 @@ def pred_and_eval( input_folder (str): Path to the folder containing the raw input data that has not been processed by nnunet yet. File names must follow the nnunet convention where each channel modality is stored as a - seperate file. File names should be case-identifier_0000 where + separate file. File names should be case-identifier_0000 where 0000 is a 4 digit integer representing the channel/modality. All cases must have the same N channels numbered from 0 to N. label_folder (str): Path to the folder containing the ground truth @@ -67,7 +67,7 @@ def pred_and_eval( npz_key="probabilities", num_threads=None, # Let threadpool determine optimal num threads postfixes=[""], - extensions=[".npz"], # Probablity maps saved as npz files + extensions=[".npz"], # Probability maps saved as npz files verbose=True, ) log(INFO, "") @@ -127,7 +127,7 @@ def main() -> None: type=str, help="""Path to the folder containing the raw input data that has not been processed by nnunet yet. File names must follow the nnunet - convention where each channel modality is stored as a seperate + convention where each channel modality is stored as a separate file. File names should be case-identifier_0000 where 0000 is a 4 digit integer representing the channel/modality of the image. All cases must have the same N channels numbered from 0 to N.""", diff --git a/research/picai/reporting/client.py b/research/picai/reporting/client.py index f1fefda87..80b9fa74b 100644 --- a/research/picai/reporting/client.py +++ b/research/picai/reporting/client.py @@ -1,6 +1,5 @@ import argparse from pathlib import Path -from typing import Optional, Tuple import flwr as fl import torch @@ -19,12 +18,12 @@ class CifarClient(BasicClient): - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) train_loader, val_loader, _ = load_cifar10_data(self.data_path, batch_size) return train_loader, val_loader - def get_test_data_loader(self, config: Config) -> Optional[DataLoader]: + def get_test_data_loader(self, config: Config) -> DataLoader | None: batch_size = narrow_dict_type(config, "batch_size", int) test_loader, _ = load_cifar10_test_data(self.data_path, batch_size) return test_loader diff --git a/research/picai/reporting/server.py b/research/picai/reporting/server.py index 9c9d698a4..97842ec58 100644 --- a/research/picai/reporting/server.py +++ b/research/picai/reporting/server.py @@ -1,6 +1,6 @@ import argparse from functools import partial -from typing import Any, Dict, Optional +from typing import Any import flwr as fl from flwr.common.typing import Config @@ -22,8 +22,8 @@ def fit_config( batch_size: int, current_server_round: int, - local_epochs: Optional[int] = None, - local_steps: Optional[int] = None, + local_epochs: int | None = None, + local_steps: int | None = None, ) -> Config: return { **make_dict_with_epochs_or_steps(local_epochs, local_steps), @@ -32,7 +32,7 @@ def fit_config( } -def main(config: Dict[str, Any]) -> None: +def main(config: dict[str, Any]) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/research/picai/single_node_trainer.py b/research/picai/single_node_trainer.py index 085d2cd5d..7cafbeec4 100644 --- a/research/picai/single_node_trainer.py +++ b/research/picai/single_node_trainer.py @@ -1,7 +1,6 @@ import os from logging import INFO from pathlib import Path -from typing import Dict, Tuple import torch import torch.nn as nn @@ -55,14 +54,14 @@ def __init__( ckpt = self.per_epoch_checkpointer.load_checkpoint(self.state_checkpoint_path) self.model, self.optimizer, self.epoch = ckpt["model"], ckpt["optimizer"], ckpt["epoch"] - def _maybe_checkpoint(self, loss: float, metrics: Dict[str, Scalar]) -> None: + def _maybe_checkpoint(self, loss: float, metrics: dict[str, Scalar]) -> None: if self.checkpointer: self.checkpointer.maybe_checkpoint(self.model, loss, metrics) def _handle_reporting( self, loss: float, - metrics_dict: Dict[str, Scalar], + metrics_dict: dict[str, Scalar], is_validation: bool = False, ) -> None: metric_string = "\t".join([f"{key}: {str(val)}" for key, val in metrics_dict.items()]) @@ -72,7 +71,7 @@ def _handle_reporting( f"Centralized {metric_prefix} Loss: {loss} \n" f"Centralized {metric_prefix} Metrics: {metric_string}", ) - def train_step(self, input: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + def train_step(self, input: torch.Tensor, target: torch.Tensor) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: self.model.train() # forward pass on the model preds = self.model(input) diff --git a/research/picai/utils.py b/research/picai/utils.py index 4b30af6cf..b1c462224 100644 --- a/research/picai/utils.py +++ b/research/picai/utils.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Any, List +from typing import Any class MultiAttributeEnum(Enum): @@ -51,7 +51,7 @@ def get_attribute_keys(self, attributes): Cat = ['Felis Catus', True, 0.25] Args: - attributes (Union[Dict[str, Any], List]): A list or dictionary of + attributes (dict[str, Any] | List): A list or dictionary of attribute values for the enum member. If a list is given then self.get_attribute_keys must be defined so that the class knows what to name the attributes. @@ -64,9 +64,9 @@ def get_attribute_keys(self, attributes): for key, value in zip(attribute_keys, attributes): setattr(self, key, value) - # Creat attributes that will be assigned for each member seperately - self.attribute_keys: List[str] - self.attribute_values: List[Any] + # Creat attributes that will be assigned for each member separately + self.attribute_keys: list[str] + self.attribute_values: list[Any] def __new__(cls, attributes: Any) -> Enum: # type: ignore """ @@ -97,17 +97,17 @@ def __new__(cls, attributes: Any) -> Enum: # type: ignore return obj # Return member - def keys(self) -> List[str]: + def keys(self) -> list[str]: """ Returns: - List[str]: a list containing the names of the attributes for this member + list[str]: a list containing the names of the attributes for this member """ return self.attribute_keys # These are set in __new__ for each member - def values(self) -> List[Any]: + def values(self) -> list[Any]: """ Returns: - List[Any]: A list of the attribute values for this member + list[Any]: A list of the attribute values for this member """ return self.attribute_values @@ -148,7 +148,7 @@ def (self, attributes): elif isinstance(attributes, list): return attributes[0] - def get_attribute_keys(self, attributes: List) -> List[str]: + def get_attribute_keys(self, attributes: list) -> list[str]: raise NotImplementedError( "Received a list of attributes but the self.get_attribute_keys class method was not implemented" ) diff --git a/tests/checkpointing/test_client_module.py b/tests/checkpointing/test_client_module.py index 4cf1a7439..0f831615a 100644 --- a/tests/checkpointing/test_client_module.py +++ b/tests/checkpointing/test_client_module.py @@ -1,5 +1,4 @@ from pathlib import Path -from typing import List import pytest import torch @@ -105,7 +104,7 @@ def test_client_checkpointer_module(tmp_path: Path) -> None: def test_client_checkpointer_module_with_sequence_of_checkpointers(tmp_path: Path) -> None: checkpoint_dir = tmp_path.joinpath("resources") checkpoint_dir.mkdir() - pre_aggregation_checkpointer: List[TorchModuleCheckpointer] = [ + pre_aggregation_checkpointer: list[TorchModuleCheckpointer] = [ BestLossTorchModuleCheckpointer(str(checkpoint_dir), "pre_agg_best.pkl"), LatestTorchModuleCheckpointer(str(checkpoint_dir), "pre_agg_latest.pkl"), ] diff --git a/tests/checkpointing/test_function_checkpointer.py b/tests/checkpointing/test_function_checkpointer.py index cf3cdb28f..0cb004778 100644 --- a/tests/checkpointing/test_function_checkpointer.py +++ b/tests/checkpointing/test_function_checkpointer.py @@ -1,11 +1,9 @@ -from typing import Dict - from flwr.common.typing import Scalar from fl4health.checkpointing.checkpointer import FunctionTorchModuleCheckpointer -def score_function(_: float, metrics: Dict[str, Scalar]) -> float: +def score_function(_: float, metrics: dict[str, Scalar]) -> float: accuracy = metrics["accuracy"] precision = metrics["precision"] assert isinstance(accuracy, float) @@ -17,8 +15,8 @@ def score_function(_: float, metrics: Dict[str, Scalar]) -> float: def test_function_checkpointer() -> None: function_checkpointer = FunctionTorchModuleCheckpointer("", "", score_function, maximize=True) loss_1, loss_2 = 1.0, 0.9 - metrics_1: Dict[str, Scalar] = {"accuracy": 0.87, "precision": 0.67, "f1": 0.76} - metrics_2: Dict[str, Scalar] = {"accuracy": 0.87, "precision": 0.9, "f1": 0.6} + metrics_1: dict[str, Scalar] = {"accuracy": 0.87, "precision": 0.67, "f1": 0.76} + metrics_2: dict[str, Scalar] = {"accuracy": 0.87, "precision": 0.9, "f1": 0.6} function_checkpointer.best_score = 0.85 # Should be false since the best score seen is set to 0.85 above diff --git a/tests/checkpointing/test_opacus_checkpointers.py b/tests/checkpointing/test_opacus_checkpointers.py index ae40f92ca..c457342f6 100644 --- a/tests/checkpointing/test_opacus_checkpointers.py +++ b/tests/checkpointing/test_opacus_checkpointers.py @@ -1,5 +1,4 @@ from pathlib import Path -from typing import Dict import pytest import torch @@ -72,7 +71,7 @@ def test_save_and_load_latest_checkpoint(tmp_path: Path) -> None: assert torch.equal(model_1.linear.weight, target_model.linear.weight) -def score_function(loss: float, metrics: Dict[str, Scalar]) -> float: +def score_function(loss: float, metrics: dict[str, Scalar]) -> float: accuracy = metrics["accuracy"] precision = metrics["precision"] assert isinstance(accuracy, float) @@ -95,8 +94,8 @@ def test_save_and_load_function_checkpoint(tmp_path: Path) -> None: opacus_checkpointer = OpacusCheckpointer(str(checkpoint_dir), checkpoint_name, score_function, maximize=True) loss_1, loss_2 = 1.0, 0.9 - metrics_1: Dict[str, Scalar] = {"accuracy": 0.87, "precision": 0.67, "f1": 0.76} - metrics_2: Dict[str, Scalar] = {"accuracy": 0.87, "precision": 0.90, "f1": 0.60} + metrics_1: dict[str, Scalar] = {"accuracy": 0.87, "precision": 0.67, "f1": 0.76} + metrics_2: dict[str, Scalar] = {"accuracy": 0.87, "precision": 0.90, "f1": 0.60} opacus_checkpointer.best_score = 0.85 # model_1 should not be checkpointed because the model score is lower than the best score set above diff --git a/tests/checkpointing/test_save_load.py b/tests/checkpointing/test_save_load.py index d31043abb..324deaaf6 100644 --- a/tests/checkpointing/test_save_load.py +++ b/tests/checkpointing/test_save_load.py @@ -1,5 +1,4 @@ from pathlib import Path -from typing import Dict import torch from flwr.common.typing import Scalar @@ -49,7 +48,7 @@ def test_save_and_load_latest_checkpoint(tmp_path: Path) -> None: assert torch.equal(model_1.linear.weight, loaded_model.linear.weight) -def score_function(loss: float, metrics: Dict[str, Scalar]) -> float: +def score_function(loss: float, metrics: dict[str, Scalar]) -> float: accuracy = metrics["accuracy"] precision = metrics["precision"] assert isinstance(accuracy, float) @@ -70,8 +69,8 @@ def test_save_and_load_function_checkpoint(tmp_path: Path) -> None: str(checkpoint_dir), checkpoint_name, score_function, maximize=True ) loss_1, loss_2 = 1.0, 0.9 - metrics_1: Dict[str, Scalar] = {"accuracy": 0.87, "precision": 0.67, "f1": 0.76} - metrics_2: Dict[str, Scalar] = {"accuracy": 0.87, "precision": 0.90, "f1": 0.60} + metrics_1: dict[str, Scalar] = {"accuracy": 0.87, "precision": 0.67, "f1": 0.76} + metrics_2: dict[str, Scalar] = {"accuracy": 0.87, "precision": 0.90, "f1": 0.60} function_checkpointer.best_score = 0.85 # model_1 should not be checkpointed because the model score is lower than the best score set above diff --git a/tests/clients/test_basic_client.py b/tests/clients/test_basic_client.py index 337d72f01..f8f89f4f6 100644 --- a/tests/clients/test_basic_client.py +++ b/tests/clients/test_basic_client.py @@ -1,7 +1,6 @@ import datetime from collections.abc import Sequence from pathlib import Path -from typing import Dict, Optional from unittest.mock import MagicMock import freezegun @@ -51,7 +50,7 @@ def test_json_reporter_shutdown() -> None: def test_metrics_reporter_fit() -> None: test_current_server_round = 2 test_loss_dict = {"test_loss": 123.123} - test_metrics: Dict[str, Scalar] = {"test_metric": 1234} + test_metrics: dict[str, Scalar] = {"test_metric": 1234} reporter = JsonReporter() fl_client = MockBasicClient(loss_dict=test_loss_dict, metrics=test_metrics, reporters=[reporter]) @@ -77,8 +76,8 @@ def test_metrics_reporter_fit() -> None: def test_metrics_reporter_evaluate() -> None: test_current_server_round = 2 test_loss = 123.123 - test_metrics: Dict[str, Scalar] = {"test_metric": 1234} - test_metrics_testing: Dict[str, Scalar] = {"testing_metric": 1234} + test_metrics: dict[str, Scalar] = {"test_metric": 1234} + test_metrics_testing: dict[str, Scalar] = {"testing_metric": 1234} test_metrics_final = { "test_metric": 1234, "testing_metric": 1234, @@ -154,10 +153,10 @@ def test_num_val_samples_correct() -> None: class MockBasicClient(BasicClient): def __init__( self, - loss_dict: Optional[Dict[str, float]] = None, - metrics: Optional[Dict[str, Scalar]] = None, - test_set_metrics: Optional[Dict[str, Scalar]] = None, - loss: Optional[float] = 0, + loss_dict: dict[str, float] | None = None, + metrics: dict[str, Scalar] | None = None, + test_set_metrics: dict[str, Scalar] | None = None, + loss: float | None = 0, reporters: Sequence[BaseReporter] | None = None, ): super().__init__(Path(""), [], torch.device(0), reporters=reporters) diff --git a/tests/clients/test_evaluate_client.py b/tests/clients/test_evaluate_client.py index a2968c20c..768648758 100644 --- a/tests/clients/test_evaluate_client.py +++ b/tests/clients/test_evaluate_client.py @@ -2,7 +2,6 @@ import math from collections.abc import Sequence from pathlib import Path -from typing import Dict, Optional, Union from unittest.mock import MagicMock import pytest @@ -19,11 +18,11 @@ def test_evaluate_merge_metrics(caplog: pytest.LogCaptureFixture) -> None: - global_metrics: Dict[str, Scalar] = { + global_metrics: dict[str, Scalar] = { "global_metric_1": 0.22, "local_metric_2": 0.11, } - local_metrics: Dict[str, Scalar] = {"local_metric_1": 0.1, "local_metric_2": 0.99} + local_metrics: dict[str, Scalar] = {"local_metric_1": 0.1, "local_metric_2": 0.99} merged_metrics = EvaluateClient.merge_metrics(global_metrics, local_metrics) # Test merge is good, local metrics are folded in last, so they take precedence when overlap exists assert merged_metrics == { @@ -111,7 +110,7 @@ def test_metrics_reporter_setup_client() -> None: @freeze_time("2012-12-12 12:12:12") def test_metrics_reporter_evaluate() -> None: test_loss = 123.123 - test_metrics: Dict[str, Union[bool, bytes, float, int, str]] = {"test_metric": 1234} + test_metrics: dict[str, bool | bytes | float | int | str] = {"test_metric": 1234} reporter = JsonReporter() evaluate_client = MockEvaluateClient(loss=test_loss, metrics=test_metrics, reporters=[reporter]) @@ -135,8 +134,8 @@ def test_metrics_reporter_evaluate() -> None: class MockEvaluateClient(EvaluateClient): def __init__( self, - loss: Optional[float] = None, - metrics: Optional[Dict[str, Scalar]] = None, + loss: float | None = None, + metrics: dict[str, Scalar] | None = None, reporters: Sequence[BaseReporter] | None = None, ): super().__init__(Path(""), [], torch.device(0), reporters=reporters) diff --git a/tests/clients/test_fedrep_client.py b/tests/clients/test_fedrep_client.py index a80e78705..d7792dce8 100644 --- a/tests/clients/test_fedrep_client.py +++ b/tests/clients/test_fedrep_client.py @@ -1,5 +1,3 @@ -from typing import Dict - import mock import pytest import torch @@ -148,14 +146,14 @@ def test_dictionary_modification_and_config_extraction(get_client: FedRepClient) torch.seed() # resetting the seed at the end, just to be safe -def get_optimizer_patch_1(self: FedRepClient, config: Config) -> Dict[str, Optimizer]: +def get_optimizer_patch_1(self: FedRepClient, config: Config) -> dict[str, Optimizer]: assert isinstance(self.model, FedRepModel) head_optimizer = torch.optim.AdamW(self.model.head_module.parameters(), lr=0.01) rep_optimizer = torch.optim.AdamW(self.model.base_module.parameters(), lr=0.01) return {"head": head_optimizer, "representation": rep_optimizer} -def get_optimizer_patch_2(self: FedRepClient, config: Config) -> Dict[str, Optimizer]: +def get_optimizer_patch_2(self: FedRepClient, config: Config) -> dict[str, Optimizer]: assert isinstance(self.model, FedRepModel) head_optimizer = torch.optim.AdamW(self.model.head_module.parameters(), lr=0.01) rep_optimizer = torch.optim.AdamW(self.model.base_module.parameters(), lr=0.01) diff --git a/tests/clients/test_instance_level.py b/tests/clients/test_instance_level.py index 29d7681ea..871413804 100644 --- a/tests/clients/test_instance_level.py +++ b/tests/clients/test_instance_level.py @@ -1,6 +1,5 @@ import copy from pathlib import Path -from typing import Tuple import pytest import torch @@ -28,7 +27,7 @@ def __init__(self, data_size: int = 100) -> None: def __len__(self) -> int: return len(self.data) - def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: + def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]: return self.data[index], self.targets[index] diff --git a/tests/clients/test_setup_client_basic.py b/tests/clients/test_setup_client_basic.py index 913c08fda..ff0da03ac 100644 --- a/tests/clients/test_setup_client_basic.py +++ b/tests/clients/test_setup_client_basic.py @@ -1,5 +1,4 @@ from pathlib import Path -from typing import Tuple import torch import torch.nn as nn @@ -15,7 +14,7 @@ class ClientForTest(BasicClient): - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: train_loader = DataLoader(TensorDataset(torch.ones((4, 4)), torch.ones((4)))) val_loader = DataLoader(TensorDataset(torch.ones((4, 4)), torch.ones((4)))) return train_loader, val_loader diff --git a/tests/models/test_ensemble_base.py b/tests/models/test_ensemble_base.py index 064603726..a3d41667a 100644 --- a/tests/models/test_ensemble_base.py +++ b/tests/models/test_ensemble_base.py @@ -1,5 +1,3 @@ -from typing import Dict - import mock import torch import torch.nn as nn @@ -9,7 +7,7 @@ def test_forward_average_mode() -> None: - models: Dict[str, nn.Module] = {"model_0": SmallCnn(), "model_1": SmallCnn()} + models: dict[str, nn.Module] = {"model_0": SmallCnn(), "model_1": SmallCnn()} ensemble_model = EnsembleModel(models, EnsembleAggregationMode.AVERAGE) data = torch.rand((64, 1, 28, 28)) ensemble_predictions = ensemble_model(data) @@ -23,7 +21,7 @@ def test_forward_average_mode() -> None: def test_forward_vote_mode() -> None: - models: Dict[str, nn.Module] = {"model_0": SmallCnn(), "model_1": SmallCnn()} + models: dict[str, nn.Module] = {"model_0": SmallCnn(), "model_1": SmallCnn()} ensemble_model = EnsembleModel(models, EnsembleAggregationMode.VOTE) data = torch.rand((64, 1, 28, 28)) ensemble_predictions = ensemble_model(data) diff --git a/tests/parameter_exchange/test_packing_exchanger.py b/tests/parameter_exchange/test_packing_exchanger.py index fe8b035e3..a2dbccfa5 100644 --- a/tests/parameter_exchange/test_packing_exchanger.py +++ b/tests/parameter_exchange/test_packing_exchanger.py @@ -1,5 +1,4 @@ import copy -from typing import List import numpy as np import pytest @@ -28,13 +27,13 @@ @pytest.fixture -def get_ndarrays(layer_sizes: List[List[int]]) -> NDArrays: +def get_ndarrays(layer_sizes: list[list[int]]) -> NDArrays: ndarrays = [np.ones(tuple(size)) for size in layer_sizes] return ndarrays @pytest.fixture -def get_sparse_tensors(num_tensors: int) -> List[Tensor]: +def get_sparse_tensors(num_tensors: int) -> list[Tensor]: tensors = [] for _ in range(num_tensors): x = torch.tensor([[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4], [5, 0, 0, 0]]) @@ -142,7 +141,7 @@ def test_parameter_packer_with_layer_names(get_ndarrays: NDArrays) -> None: # n @pytest.mark.parametrize("num_tensors", [6]) -def test_sparse_coo_parameter_packer(get_sparse_tensors: List[Tensor]) -> None: +def test_sparse_coo_parameter_packer(get_sparse_tensors: list[Tensor]) -> None: model_tensors = get_sparse_tensors tensor_names = ["tensor1", "tensor2", "tensor3", "tensor4", "tensor5", "tensor6"] parameter_nonzero_values = [] diff --git a/tests/preprocessing/test_ae_dim_reduction.py b/tests/preprocessing/test_ae_dim_reduction.py index ed0f1e4c9..62490923f 100644 --- a/tests/preprocessing/test_ae_dim_reduction.py +++ b/tests/preprocessing/test_ae_dim_reduction.py @@ -1,5 +1,4 @@ from pathlib import Path -from typing import Tuple import torch from torch.utils.data import DataLoader @@ -23,7 +22,7 @@ def __init__(self, data_size: int = 50, sample_vector_size: int = 10, condition_ def __len__(self) -> int: return len(self.data) - def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: + def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]: return self.data[index], self.targets[index] @@ -36,7 +35,7 @@ def __init__(self, data_size: int = 50, sample_vector_size: int = 10) -> None: def __len__(self) -> int: return len(self.data) - def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: + def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]: return self.data[index], self.targets[index] diff --git a/tests/preprocessing/test_ae_loss.py b/tests/preprocessing/test_ae_loss.py index b9e29abe2..42bc226bb 100644 --- a/tests/preprocessing/test_ae_loss.py +++ b/tests/preprocessing/test_ae_loss.py @@ -1,5 +1,3 @@ -from typing import Tuple - import torch from sklearn.metrics import mean_squared_error @@ -7,7 +5,7 @@ def kl_divergence_normal( - q_params: Tuple[torch.Tensor, torch.Tensor], p_params: Tuple[torch.Tensor, torch.Tensor] + q_params: tuple[torch.Tensor, torch.Tensor], p_params: tuple[torch.Tensor, torch.Tensor] ) -> torch.Tensor: mu_q, logvar_q = q_params mu_p, logvar_p = p_params diff --git a/tests/servers/test_base_server.py b/tests/servers/test_base_server.py index 8264c55bb..6d571f2b7 100644 --- a/tests/servers/test_base_server.py +++ b/tests/servers/test_base_server.py @@ -1,6 +1,5 @@ import datetime from pathlib import Path -from typing import List, Tuple, Union from unittest.mock import Mock, patch import pytest @@ -199,7 +198,7 @@ def test_unpack_metrics() -> None: }, ) - results: List[Tuple[ClientProxy, EvaluateRes]] = [(client_proxy, eval_res)] + results: list[tuple[ClientProxy, EvaluateRes]] = [(client_proxy, eval_res)] val_results, test_results = fl_server._unpack_metrics(results) @@ -244,11 +243,11 @@ def test_handle_result_aggregation() -> None: }, ) - results: List[Tuple[ClientProxy, EvaluateRes]] = [ + results: list[tuple[ClientProxy, EvaluateRes]] = [ (client_proxy1, eval_res1), (client_proxy2, eval_res2), ] - failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]] = [] + failures: list[tuple[ClientProxy, EvaluateRes] | BaseException] = [] server_round = 1 _, val_metrics_aggregated = fl_server._handle_result_aggregation(server_round, results, failures) diff --git a/tests/servers/test_polling.py b/tests/servers/test_polling.py index 58fdd26ef..6ea999915 100644 --- a/tests/servers/test_polling.py +++ b/tests/servers/test_polling.py @@ -1,5 +1,3 @@ -from typing import List, Tuple - from flwr.common.typing import Config, GetPropertiesIns from flwr.server.client_proxy import ClientProxy @@ -23,7 +21,7 @@ def test_poll_clients() -> None: clients = [CustomClientProxy(cid, count) for cid, count in zip(client_ids, sample_counts)] config: Config = {"test": 0} ins = GetPropertiesIns(config=config) - clients_instructions: List[Tuple[ClientProxy, GetPropertiesIns]] = [(client, ins) for client in clients] + clients_instructions: list[tuple[ClientProxy, GetPropertiesIns]] = [(client, ins) for client in clients] results, _ = poll_clients(client_instructions=clients_instructions, max_workers=None, timeout=None) diff --git a/tests/smoke_tests/load_from_checkpoint_example/client.py b/tests/smoke_tests/load_from_checkpoint_example/client.py index a21a48de0..c759cd8a8 100644 --- a/tests/smoke_tests/load_from_checkpoint_example/client.py +++ b/tests/smoke_tests/load_from_checkpoint_example/client.py @@ -1,6 +1,6 @@ import argparse +from collections.abc import Sequence from pathlib import Path -from typing import Dict, Optional, Sequence, Tuple import flwr as fl import torch @@ -30,10 +30,10 @@ def __init__( metrics: Sequence[Metric], device: torch.device, loss_meter_type: LossMeterType = LossMeterType.AVERAGE, - checkpoint_and_state_module: Optional[ClientCheckpointAndStateModule] = None, + checkpoint_and_state_module: ClientCheckpointAndStateModule | None = None, reporters: Sequence[BaseReporter] | None = None, progress_bar: bool = False, - client_name: Optional[str] = None, + client_name: str | None = None, seed: int = 42, ) -> None: super().__init__( @@ -48,12 +48,12 @@ def __init__( ) self.seed = seed - def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]: + def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]: batch_size = narrow_dict_type(config, "batch_size", int) train_loader, val_loader, _ = load_cifar10_data(self.data_path, batch_size) return train_loader, val_loader - def get_test_data_loader(self, config: Config) -> Optional[DataLoader]: + def get_test_data_loader(self, config: Config) -> DataLoader | None: batch_size = narrow_dict_type(config, "batch_size", int) test_loader, _ = load_cifar10_test_data(self.data_path, batch_size) return test_loader @@ -67,7 +67,7 @@ def get_optimizer(self, config: Config) -> Optimizer: def get_model(self, config: Config) -> nn.Module: return Net().to(self.device) - def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict[str, Scalar]]: + def fit(self, parameters: NDArrays, config: Config) -> tuple[NDArrays, int, dict[str, Scalar]]: set_all_random_seeds(self.seed) return super().fit(parameters, config) diff --git a/tests/smoke_tests/load_from_checkpoint_example/server.py b/tests/smoke_tests/load_from_checkpoint_example/server.py index 7b8787751..7057cf55e 100644 --- a/tests/smoke_tests/load_from_checkpoint_example/server.py +++ b/tests/smoke_tests/load_from_checkpoint_example/server.py @@ -1,7 +1,7 @@ import argparse from functools import partial from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any import flwr as fl from flwr.common.typing import Config @@ -28,8 +28,8 @@ def fit_config( batch_size: int, current_server_round: int, - local_epochs: Optional[int] = None, - local_steps: Optional[int] = None, + local_epochs: int | None = None, + local_steps: int | None = None, ) -> Config: return { **make_dict_with_epochs_or_steps(local_epochs, local_steps), @@ -38,7 +38,7 @@ def fit_config( } -def main(config: Dict[str, Any], intermediate_server_state_dir: str, server_name: str) -> None: +def main(config: dict[str, Any], intermediate_server_state_dir: str, server_name: str) -> None: # This function will be used to produce a config that is sent to each client to initialize their own environment fit_config_fn = partial( fit_config, diff --git a/tests/smoke_tests/run_smoke_test.py b/tests/smoke_tests/run_smoke_test.py index 142011521..9a2f1810f 100644 --- a/tests/smoke_tests/run_smoke_test.py +++ b/tests/smoke_tests/run_smoke_test.py @@ -4,7 +4,7 @@ import logging from enum import Enum from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any import yaml from flwr.common.typing import Config @@ -26,14 +26,14 @@ async def run_smoke_test( client_python_path: str, config_path: str, dataset_path: str, - checkpoint_path: Optional[str] = None, - assert_evaluation_logs: Optional[bool] = False, + checkpoint_path: str | None = None, + assert_evaluation_logs: bool | None = False, # The param below exists to work around an issue with some clients # not printing the "Current FL Round" log message reliably - skip_assert_client_fl_rounds: Optional[bool] = False, - seed: Optional[int] = None, - server_metrics: Optional[Dict[str, Any]] = None, - client_metrics: Optional[Dict[str, Any]] = None, + skip_assert_client_fl_rounds: bool | None = False, + seed: int | None = None, + server_metrics: dict[str, Any] | None = None, + client_metrics: dict[str, Any] | None = None, ) -> None: """Runs a smoke test for a given server, client, and dataset configuration. @@ -111,18 +111,18 @@ async def run_smoke_test( `batch_size`: the size of the batch, to be used by the dataset preloader dataset_path (str): the path of the dataset. Depending on which dataset is being used, it will ty to preload it to avoid problems when running on different runtimes. - checkpoint_path (Optional[str]): Optional, default None. If set, it will send that path as a checkpoint model + checkpoint_path (str | None): Optional, default None. If set, it will send that path as a checkpoint model to the client. - assert_evaluation_logs (Optional[bool]): Optional, default `False`. Set this to `True` if testing an + assert_evaluation_logs (bool | None): Optional, default `False`. Set this to `True` if testing an evaluation model, which produces different log outputs. - skip_assert_client_fl_rounds (Optional[str]): Optional, default `False`. If set to `True`, will skip the + skip_assert_client_fl_rounds (str | None): Optional, default `False`. If set to `True`, will skip the assertion of the "Current FL Round" message on the clients' logs. This is necessary because some clients (namely client_level_dp, client_level_dp_weighted, instance_level_dp) do not reliably print that message. - seed (Optional[int]): The random seed to be passed in to both the client and the server. - server_metrics (Optional[Dict[str, Any]]): A dictionary of metrics to be checked against the metrics file + seed (int | None): The random seed to be passed in to both the client and the server. + server_metrics (dict[str, Any] | None): A dictionary of metrics to be checked against the metrics file saved by the server. Should be in the same format as fl4health.reporting.metrics.MetricsReporter. Default is None. - client_metrics (Optional[Dict[str, Any]]): A dictionary of metrics to be checked against the metrics file + client_metrics (dict[str, Any] | None): A dictionary of metrics to be checked against the metrics file saved by the clients. Should be in the same format as fl4health.reporting.metrics.MetricsReporter. Default is None. """ @@ -296,9 +296,9 @@ async def run_fault_tolerance_smoke_test( config_path: str, partial_config_path: str, dataset_path: str, - server_metrics: Dict[str, Any], - client_metrics: Dict[str, Any], - seed: Optional[int] = None, + server_metrics: dict[str, Any], + client_metrics: dict[str, Any], + seed: int | None = None, intermediate_checkpoint_dir: str = "./", server_name: str = "server", ) -> None: @@ -323,10 +323,10 @@ async def run_fault_tolerance_smoke_test( dataset_path (str): the path of the dataset. Depending on which dataset is being used, it will ty to preload it to avoid problems when running on different runtimes. intermediate_checkpoint_dir (str): Path to store intermediate checkpoints for server and client. - seed (Optional[int]): The random seed to be passed in to both the client and the server. - server_metrics (Dict[str, Any]): A dictionary of metrics to be checked against the metrics file + seed (int | None): The random seed to be passed in to both the client and the server. + server_metrics (dict[str, Any]): A dictionary of metrics to be checked against the metrics file saved by the server. Should be in the same format as fl4health.reporting.metrics.MetricsReporter. - client_metrics (Dict[str, Any]): A dictionary of metrics to be checked against the metrics file + client_metrics (dict[str, Any]): A dictionary of metrics to be checked against the metrics file saved by the clients. Should be in the same format as fl4health.reporting.metrics.MetricsReporter. """ clear_metrics_folder() @@ -454,7 +454,7 @@ async def run_fault_tolerance_smoke_test( logger.info("All checks passed. Test finished.") -def _preload_dataset(dataset_path: str, config: Config, seed: Optional[int] = None) -> None: +def _preload_dataset(dataset_path: str, config: Config, seed: int | None = None) -> None: if "mnist" in dataset_path: logger.info("Preloading MNIST dataset...") @@ -538,8 +538,8 @@ class MetricType(Enum): DEFAULT_TOLERANCE = 0.0005 -def _assert_metrics(metric_type: MetricType, metrics_to_assert: Optional[Dict[str, Any]] = None) -> List[str]: - errors: List[str] = [] +def _assert_metrics(metric_type: MetricType, metrics_to_assert: dict[str, Any] | None = None) -> list[str]: + errors: list[str] = [] if metrics_to_assert is None: return errors @@ -563,10 +563,10 @@ def _assert_metrics(metric_type: MetricType, metrics_to_assert: Optional[Dict[st return errors -def _assert_metrics_dict(metrics_to_assert: Dict[str, Any], metrics_saved: Dict[str, Any]) -> List[str]: +def _assert_metrics_dict(metrics_to_assert: dict[str, Any], metrics_saved: dict[str, Any]) -> list[str]: errors = [] - def _assert(value: Any, saved_value: Any) -> Optional[str]: + def _assert(value: Any, saved_value: Any) -> str | None: # helper function to avoid code repetition tolerance = DEFAULT_TOLERANCE if isinstance(value, dict): @@ -620,7 +620,7 @@ def clear_metrics_folder() -> None: f.unlink() -def load_metrics_from_file(file_path: str) -> Dict[str, Any]: +def load_metrics_from_file(file_path: str) -> dict[str, Any]: with open(file_path, "r") as f: return json.load(f) diff --git a/tests/strategies/test_basic_fedavg.py b/tests/strategies/test_basic_fedavg.py index 6719cdf13..08618e948 100644 --- a/tests/strategies/test_basic_fedavg.py +++ b/tests/strategies/test_basic_fedavg.py @@ -1,5 +1,4 @@ import random -from typing import List, Tuple import numpy as np from flwr.common import ( @@ -33,7 +32,7 @@ def construct_fit_res(parameters: NDArrays, metric: float, num_examples: int) -> client1_res = construct_fit_res([np.ones((3, 3)), np.full((4, 4), 2)], 0.2, 50) client2_res = construct_fit_res([np.full((3, 3), 3), np.full((4, 4), 3)], 0.3, 100) client3_res = construct_fit_res([np.full((3, 3), 4), np.full((4, 4), 4)], 0.4, 200) -clients_res: List[Tuple[ClientProxy, FitRes]] = [ +clients_res: list[tuple[ClientProxy, FitRes]] = [ (CustomClientProxy("c0"), client0_res), (CustomClientProxy("c1"), client1_res), (CustomClientProxy("c2"), client2_res), @@ -41,7 +40,7 @@ def construct_fit_res(parameters: NDArrays, metric: float, num_examples: int) -> ] -def metrics_aggregation(to_aggregate: List[Tuple[int, Metrics]]) -> Metrics: +def metrics_aggregation(to_aggregate: list[tuple[int, Metrics]]) -> Metrics: # Select last set of metrics (dummy for test) return to_aggregate[-1][1] @@ -104,7 +103,7 @@ def construct_evaluate_res(loss: float, metric: float, num_examples: int) -> Eva client1_eval_res = construct_evaluate_res(1.0, 0.2, 50) client2_eval_res = construct_evaluate_res(3.0, 0.3, 100) client3_eval_res = construct_evaluate_res(4.0, 0.4, 200) -clients_eval_res: List[Tuple[ClientProxy, EvaluateRes]] = [ +clients_eval_res: list[tuple[ClientProxy, EvaluateRes]] = [ (CustomClientProxy("c0"), client0_eval_res), (CustomClientProxy("c1"), client1_eval_res), (CustomClientProxy("c2"), client2_eval_res), diff --git a/tests/strategies/test_client_dp_fedavgm.py b/tests/strategies/test_client_dp_fedavgm.py index fe0544631..0cc342b31 100644 --- a/tests/strategies/test_client_dp_fedavgm.py +++ b/tests/strategies/test_client_dp_fedavgm.py @@ -1,5 +1,3 @@ -from typing import List, Tuple - import numpy as np import pytest from flwr.common import Code, FitRes, NDArrays, Parameters, Status, ndarrays_to_parameters @@ -31,7 +29,7 @@ def construct_fit_res(parameters: NDArrays, metric: float, num_examples: int) -> client1_res = construct_fit_res([np.ones((3, 3)), np.full((4, 4), 2)], 0.2, 50) client2_res = construct_fit_res([np.full((3, 3), 3), np.full((4, 4), 3)], 0.3, 100) client3_res = construct_fit_res([np.full((3, 3), 4), np.full((4, 4), 4)], 0.4, 200) -clients_res: List[Tuple[ClientProxy, FitRes]] = [ +clients_res: list[tuple[ClientProxy, FitRes]] = [ (CustomClientProxy("c0"), client0_res), (CustomClientProxy("c1"), client1_res), (CustomClientProxy("c2"), client2_res), @@ -80,7 +78,7 @@ def test_unpacking_weights_and_clipping_bits() -> None: n_layers = 4 n_clients = 3 n_client_datapoints = 10 - fit_res_results: List[FitRes] = [ + fit_res_results: list[FitRes] = [ construct_fit_res( [np.random.rand(2, 3) for _ in range(n_layers)] + [np.random.binomial(1, 0.5, 1).astype(float)], 0.1, @@ -88,7 +86,7 @@ def test_unpacking_weights_and_clipping_bits() -> None: ) for _ in range(n_clients) ] - results: List[Tuple[ClientProxy, FitRes]] = list( + results: list[tuple[ClientProxy, FitRes]] = list( zip( [CustomClientProxy("c0"), CustomClientProxy("c1"), CustomClientProxy("c2"), CustomClientProxy("c3")], fit_res_results, diff --git a/tests/strategies/test_fedavg_sparse_coo_tensor.py b/tests/strategies/test_fedavg_sparse_coo_tensor.py index 185638d0d..dacb2d3e8 100644 --- a/tests/strategies/test_fedavg_sparse_coo_tensor.py +++ b/tests/strategies/test_fedavg_sparse_coo_tensor.py @@ -1,5 +1,3 @@ -from typing import List, Tuple - import numpy as np import torch from flwr.common import NDArray, NDArrays @@ -13,7 +11,7 @@ total_train_size = sum(client_train_sizes) -def create_coo_tensor_diagonal(n: int, all_ones: bool) -> Tuple[NDArray, NDArray, NDArray]: +def create_coo_tensor_diagonal(n: int, all_ones: bool) -> tuple[NDArray, NDArray, NDArray]: if all_ones: parameters = np.array([1 for _ in range(1, n + 1)]) else: @@ -25,8 +23,8 @@ def create_coo_tensor_diagonal(n: int, all_ones: bool) -> Tuple[NDArray, NDArray def create_client_parameters( - num_tensors: int, sizes: List[int], all_ones_lst: List[bool] -) -> Tuple[NDArrays, NDArrays, NDArrays]: + num_tensors: int, sizes: list[int], all_ones_lst: list[bool] +) -> tuple[NDArrays, NDArrays, NDArrays]: assert len(sizes) == num_tensors and len(sizes) == len(all_ones_lst) client_parameters = [] client_indices = [] diff --git a/tests/strategies/test_feddg_ga.py b/tests/strategies/test_feddg_ga.py index a562a2707..0166ce4ff 100644 --- a/tests/strategies/test_feddg_ga.py +++ b/tests/strategies/test_feddg_ga.py @@ -1,5 +1,4 @@ from copy import deepcopy -from typing import Dict, List, Tuple from unittest.mock import Mock import numpy as np @@ -17,14 +16,14 @@ def test_configure_fit_and_evaluate_success() -> None: fixed_sampling_client_manager = _apply_mocks_to_client_manager(FixedSamplingClientManager()) test_n_server_rounds = 3 - def on_fit_config_fn(server_round: int) -> Dict[str, Scalar]: + def on_fit_config_fn(server_round: int) -> dict[str, Scalar]: return { "n_server_rounds": test_n_server_rounds, "evaluate_after_fit": True, "pack_losses_with_val_metrics": True, } - def on_evaluate_config_fn(server_round: int) -> Dict[str, Scalar]: + def on_evaluate_config_fn(server_round: int) -> dict[str, Scalar]: return { "n_server_rounds": test_n_server_rounds, "pack_losses_with_val_metrics": True, @@ -53,7 +52,7 @@ def test_configure_fit_fail() -> None: strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) # Fails with bad client manager type - def on_fit_config_fn(server_round: int) -> Dict[str, Scalar]: + def on_fit_config_fn(server_round: int) -> dict[str, Scalar]: return { "n_server_rounds": 2, "evaluate_after_fit": True, @@ -65,7 +64,7 @@ def on_fit_config_fn(server_round: int) -> Dict[str, Scalar]: strategy.configure_fit(1, Parameters([], ""), simple_client_manager) # Fail with no n_server_rounds - def on_fit_config_fn_1(server_round: int) -> Dict[str, Scalar]: + def on_fit_config_fn_1(server_round: int) -> dict[str, Scalar]: return { "foo": 123, "evaluate_after_fit": True, @@ -79,7 +78,7 @@ def on_fit_config_fn_1(server_round: int) -> Dict[str, Scalar]: strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) # Fails with n_server_rounds not being an integer - def on_fit_config_fn_2(server_round: int) -> Dict[str, Scalar]: + def on_fit_config_fn_2(server_round: int) -> dict[str, Scalar]: return { "n_server_rounds": 1.1, "evaluate_after_fit": True, @@ -93,7 +92,7 @@ def on_fit_config_fn_2(server_round: int) -> Dict[str, Scalar]: strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) # Fails with evaluate_after_fit not being set - def on_fit_config_fn_3(server_round: int) -> Dict[str, Scalar]: + def on_fit_config_fn_3(server_round: int) -> dict[str, Scalar]: return { "n_server_rounds": 2, "pack_losses_with_val_metrics": True, @@ -104,7 +103,7 @@ def on_fit_config_fn_3(server_round: int) -> Dict[str, Scalar]: strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) # Fails with evaluate_after_fit not being True - def on_fit_config_fn_4(server_round: int) -> Dict[str, Scalar]: + def on_fit_config_fn_4(server_round: int) -> dict[str, Scalar]: return { "n_server_rounds": 2, "evaluate_after_fit": False, @@ -116,7 +115,7 @@ def on_fit_config_fn_4(server_round: int) -> Dict[str, Scalar]: strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) # Fails with pack_losses_with_val_metrics not being there - def on_fit_config_fn_5(server_round: int) -> Dict[str, Scalar]: + def on_fit_config_fn_5(server_round: int) -> dict[str, Scalar]: return { "n_server_rounds": 2, "evaluate_after_fit": True, @@ -127,7 +126,7 @@ def on_fit_config_fn_5(server_round: int) -> Dict[str, Scalar]: strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) # Fails with pack_losses_with_val_metrics not being True - def on_fit_config_fn_6(server_round: int) -> Dict[str, Scalar]: + def on_fit_config_fn_6(server_round: int) -> dict[str, Scalar]: return { "n_server_rounds": 2, "evaluate_after_fit": True, @@ -149,7 +148,7 @@ def test_configure_evaluate_fail() -> None: strategy.configure_evaluate(1, Parameters([], ""), fixed_sampling_client_manager) # Fails with bad client manager type - def on_evaluate_config_fn(server_round: int) -> Dict[str, Scalar]: + def on_evaluate_config_fn(server_round: int) -> dict[str, Scalar]: return { "n_server_rounds": 2, "pack_losses_with_val_metrics": True, @@ -160,7 +159,7 @@ def on_evaluate_config_fn(server_round: int) -> Dict[str, Scalar]: strategy.configure_evaluate(1, Parameters([], ""), simple_client_manager) # Fail with no pack_losses_with_val_metrics - def on_evaluate_config_fn_1(server_round: int) -> Dict[str, Scalar]: + def on_evaluate_config_fn_1(server_round: int) -> dict[str, Scalar]: return { "foo": 123, } @@ -170,7 +169,7 @@ def on_evaluate_config_fn_1(server_round: int) -> Dict[str, Scalar]: strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) # Fails with pack_losses_with_val_metrics not being True - def on_fit_config_fn_2(server_round: int) -> Dict[str, Scalar]: + def on_fit_config_fn_2(server_round: int) -> dict[str, Scalar]: return { "n_server_rounds": 1.1, "pack_losses_with_val_metrics": False, @@ -345,12 +344,12 @@ def _apply_mocks_to_client_manager(client_manager: ClientManager) -> ClientManag return client_manager -def _make_test_data() -> Tuple[List[Tuple[ClientProxy, FitRes]], List[Tuple[ClientProxy, EvaluateRes]]]: +def _make_test_data() -> tuple[list[tuple[ClientProxy, FitRes]], list[tuple[ClientProxy, EvaluateRes]]]: test_val_loss_key = FairnessMetricType.LOSS.value - test_fit_metrics_1: Dict[str, Scalar] = {test_val_loss_key: 1.0} - test_fit_metrics_2: Dict[str, Scalar] = {test_val_loss_key: 2.0} - test_eval_metrics_1: Dict[str, Scalar] = {"metric-1": 1.0, test_val_loss_key: 1.2} - test_eval_metrics_2: Dict[str, Scalar] = {"metric-2": 2.0, test_val_loss_key: 2.2} + test_fit_metrics_1: dict[str, Scalar] = {test_val_loss_key: 1.0} + test_fit_metrics_2: dict[str, Scalar] = {test_val_loss_key: 2.0} + test_eval_metrics_1: dict[str, Scalar] = {"metric-1": 1.0, test_val_loss_key: 1.2} + test_eval_metrics_2: dict[str, Scalar] = {"metric-2": 2.0, test_val_loss_key: 2.2} test_parameters_1 = ndarrays_to_parameters([np.array([1.0, 1.1])]) test_parameters_2 = ndarrays_to_parameters([np.array([2.0, 2.1])]) test_fit_results = [ diff --git a/tests/strategies/test_feddg_ga_with_adapt_constraint.py b/tests/strategies/test_feddg_ga_with_adapt_constraint.py index 6797a4c49..cad1da36a 100644 --- a/tests/strategies/test_feddg_ga_with_adapt_constraint.py +++ b/tests/strategies/test_feddg_ga_with_adapt_constraint.py @@ -1,5 +1,4 @@ from copy import deepcopy -from typing import Dict, List, Tuple from unittest.mock import Mock import numpy as np @@ -20,14 +19,14 @@ def test_configure_fit_and_evaluate_success() -> None: fixed_sampling_client_manager = _apply_mocks_to_client_manager(FixedSamplingClientManager()) test_n_server_rounds = 3 - def on_fit_config_fn(server_round: int) -> Dict[str, Scalar]: + def on_fit_config_fn(server_round: int) -> dict[str, Scalar]: return { "n_server_rounds": test_n_server_rounds, "evaluate_after_fit": True, "pack_losses_with_val_metrics": True, } - def on_evaluate_config_fn(server_round: int) -> Dict[str, Scalar]: + def on_evaluate_config_fn(server_round: int) -> dict[str, Scalar]: return { "n_server_rounds": test_n_server_rounds, "pack_losses_with_val_metrics": True, @@ -60,7 +59,7 @@ def test_configure_fit_fail() -> None: strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) # Fails with bad client manager type - def on_fit_config_fn(server_round: int) -> Dict[str, Scalar]: + def on_fit_config_fn(server_round: int) -> dict[str, Scalar]: return { "n_server_rounds": 2, "evaluate_after_fit": True, @@ -72,7 +71,7 @@ def on_fit_config_fn(server_round: int) -> Dict[str, Scalar]: strategy.configure_fit(1, Parameters([], ""), simple_client_manager) # Fail with no n_server_rounds - def on_fit_config_fn_1(server_round: int) -> Dict[str, Scalar]: + def on_fit_config_fn_1(server_round: int) -> dict[str, Scalar]: return { "foo": 123, "evaluate_after_fit": True, @@ -86,7 +85,7 @@ def on_fit_config_fn_1(server_round: int) -> Dict[str, Scalar]: strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) # Fails with n_server_rounds not being an integer - def on_fit_config_fn_2(server_round: int) -> Dict[str, Scalar]: + def on_fit_config_fn_2(server_round: int) -> dict[str, Scalar]: return { "n_server_rounds": 1.1, "evaluate_after_fit": True, @@ -100,7 +99,7 @@ def on_fit_config_fn_2(server_round: int) -> Dict[str, Scalar]: strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) # Fails with evaluate_after_fit not being set - def on_fit_config_fn_3(server_round: int) -> Dict[str, Scalar]: + def on_fit_config_fn_3(server_round: int) -> dict[str, Scalar]: return { "n_server_rounds": 2, "pack_losses_with_val_metrics": True, @@ -111,7 +110,7 @@ def on_fit_config_fn_3(server_round: int) -> Dict[str, Scalar]: strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) # Fails with evaluate_after_fit not being True - def on_fit_config_fn_4(server_round: int) -> Dict[str, Scalar]: + def on_fit_config_fn_4(server_round: int) -> dict[str, Scalar]: return { "n_server_rounds": 2, "evaluate_after_fit": False, @@ -123,7 +122,7 @@ def on_fit_config_fn_4(server_round: int) -> Dict[str, Scalar]: strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) # Fails with pack_losses_with_val_metrics not being there - def on_fit_config_fn_5(server_round: int) -> Dict[str, Scalar]: + def on_fit_config_fn_5(server_round: int) -> dict[str, Scalar]: return { "n_server_rounds": 2, "evaluate_after_fit": True, @@ -134,7 +133,7 @@ def on_fit_config_fn_5(server_round: int) -> Dict[str, Scalar]: strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) # Fails with pack_losses_with_val_metrics not being True - def on_fit_config_fn_6(server_round: int) -> Dict[str, Scalar]: + def on_fit_config_fn_6(server_round: int) -> dict[str, Scalar]: return { "n_server_rounds": 2, "evaluate_after_fit": True, @@ -156,7 +155,7 @@ def test_configure_evaluate_fail() -> None: strategy.configure_evaluate(1, Parameters([], ""), fixed_sampling_client_manager) # Fails with bad client manager type - def on_evaluate_config_fn(server_round: int) -> Dict[str, Scalar]: + def on_evaluate_config_fn(server_round: int) -> dict[str, Scalar]: return { "n_server_rounds": 2, "pack_losses_with_val_metrics": True, @@ -169,7 +168,7 @@ def on_evaluate_config_fn(server_round: int) -> Dict[str, Scalar]: strategy.configure_evaluate(1, Parameters([], ""), simple_client_manager) # Fail with no pack_losses_with_val_metrics - def on_evaluate_config_fn_1(server_round: int) -> Dict[str, Scalar]: + def on_evaluate_config_fn_1(server_round: int) -> dict[str, Scalar]: return { "foo": 123, } @@ -181,7 +180,7 @@ def on_evaluate_config_fn_1(server_round: int) -> Dict[str, Scalar]: strategy.configure_fit(1, Parameters([], ""), fixed_sampling_client_manager) # Fails with pack_losses_with_val_metrics not being True - def on_fit_config_fn_2(server_round: int) -> Dict[str, Scalar]: + def on_fit_config_fn_2(server_round: int) -> dict[str, Scalar]: return { "n_server_rounds": 1.1, "pack_losses_with_val_metrics": False, @@ -391,12 +390,12 @@ def _apply_mocks_to_client_manager(client_manager: ClientManager) -> ClientManag return client_manager -def _make_test_data() -> Tuple[List[Tuple[ClientProxy, FitRes]], List[Tuple[ClientProxy, EvaluateRes]]]: +def _make_test_data() -> tuple[list[tuple[ClientProxy, FitRes]], list[tuple[ClientProxy, EvaluateRes]]]: test_val_loss_key = FairnessMetricType.LOSS.value - test_fit_metrics_1: Dict[str, Scalar] = {test_val_loss_key: 1.0} - test_fit_metrics_2: Dict[str, Scalar] = {test_val_loss_key: 2.0} - test_eval_metrics_1: Dict[str, Scalar] = {"metric-1": 1.0, test_val_loss_key: 1.2} - test_eval_metrics_2: Dict[str, Scalar] = {"metric-2": 2.0, test_val_loss_key: 2.2} + test_fit_metrics_1: dict[str, Scalar] = {test_val_loss_key: 1.0} + test_fit_metrics_2: dict[str, Scalar] = {test_val_loss_key: 2.0} + test_eval_metrics_1: dict[str, Scalar] = {"metric-1": 1.0, test_val_loss_key: 1.2} + test_eval_metrics_2: dict[str, Scalar] = {"metric-2": 2.0, test_val_loss_key: 2.2} test_parameters_1 = ndarrays_to_parameters([np.array([1.0, 1.1]), np.array(1.5)]) test_parameters_2 = ndarrays_to_parameters([np.array([2.0, 2.1]), np.array(2.5)]) test_fit_results = [ diff --git a/tests/strategies/test_flash.py b/tests/strategies/test_flash.py index 36c4b0ec9..c16291d4b 100644 --- a/tests/strategies/test_flash.py +++ b/tests/strategies/test_flash.py @@ -1,5 +1,3 @@ -from typing import List, Tuple - import numpy as np from flwr.common import Code, FitRes, Metrics, NDArrays, Status, ndarrays_to_parameters from flwr.server.client_proxy import ClientProxy @@ -21,7 +19,7 @@ def construct_fit_res(parameters: NDArrays, metric: float, num_examples: int) -> client1_res = construct_fit_res([np.ones((3, 3)), np.full((4, 4), 2)], 0.2, 50) client2_res = construct_fit_res([np.full((3, 3), 3), np.full((4, 4), 3)], 0.3, 100) client3_res = construct_fit_res([np.full((3, 3), 4), np.full((4, 4), 4)], 0.4, 200) -clients_res_1: List[Tuple[ClientProxy, FitRes]] = [ +clients_res_1: list[tuple[ClientProxy, FitRes]] = [ (CustomClientProxy("c0"), client0_res), (CustomClientProxy("c1"), client1_res), (CustomClientProxy("c2"), client2_res), @@ -32,7 +30,7 @@ def construct_fit_res(parameters: NDArrays, metric: float, num_examples: int) -> client1_res = construct_fit_res([np.full((3, 3), 2.5), np.full((4, 4), 2.5)], 0.25, 60) client2_res = construct_fit_res([np.full((3, 3), 3.5), np.full((4, 4), 3.5)], 0.35, 110) client3_res = construct_fit_res([np.full((3, 3), 4.5), np.full((4, 4), 4.5)], 0.45, 210) -clients_res_2: List[Tuple[ClientProxy, FitRes]] = [ +clients_res_2: list[tuple[ClientProxy, FitRes]] = [ (CustomClientProxy("c0"), client0_res), (CustomClientProxy("c1"), client1_res), (CustomClientProxy("c2"), client2_res), @@ -40,7 +38,7 @@ def construct_fit_res(parameters: NDArrays, metric: float, num_examples: int) -> ] -def metrics_aggregation(to_aggregate: List[Tuple[int, Metrics]]) -> Metrics: +def metrics_aggregation(to_aggregate: list[tuple[int, Metrics]]) -> Metrics: # Select last set of metrics (dummy for test) return to_aggregate[-1][1] diff --git a/tests/strategies/test_model_merge_strategy.py b/tests/strategies/test_model_merge_strategy.py index f5a946362..479c840a6 100644 --- a/tests/strategies/test_model_merge_strategy.py +++ b/tests/strategies/test_model_merge_strategy.py @@ -1,5 +1,4 @@ import random -from typing import List, Tuple import numpy as np from flwr.common import ( @@ -33,7 +32,7 @@ def construct_fit_res(parameters: NDArrays, metric: float, num_examples: int) -> client1_res = construct_fit_res([np.ones((3, 3)), np.full((4, 4), 2)], 0.2, 50) client2_res = construct_fit_res([np.full((3, 3), 3), np.full((4, 4), 3)], 0.3, 100) client3_res = construct_fit_res([np.full((3, 3), 4), np.full((4, 4), 4)], 0.4, 200) -clients_res: List[Tuple[ClientProxy, FitRes]] = [ +clients_res: list[tuple[ClientProxy, FitRes]] = [ (CustomClientProxy("c0"), client0_res), (CustomClientProxy("c1"), client1_res), (CustomClientProxy("c2"), client2_res), @@ -41,7 +40,7 @@ def construct_fit_res(parameters: NDArrays, metric: float, num_examples: int) -> ] -def metrics_aggregation(to_aggregate: List[Tuple[int, Metrics]]) -> Metrics: +def metrics_aggregation(to_aggregate: list[tuple[int, Metrics]]) -> Metrics: # Select last set of metrics (dummy for test) return to_aggregate[-1][1] @@ -102,7 +101,7 @@ def construct_evaluate_res(loss: float, metric: float, num_examples: int) -> Eva client1_eval_res = construct_evaluate_res(1.0, 0.2, 50) client2_eval_res = construct_evaluate_res(3.0, 0.3, 100) client3_eval_res = construct_evaluate_res(4.0, 0.4, 200) -clients_eval_res: List[Tuple[ClientProxy, EvaluateRes]] = [ +clients_eval_res: list[tuple[ClientProxy, EvaluateRes]] = [ (CustomClientProxy("c0"), client0_eval_res), (CustomClientProxy("c1"), client1_eval_res), (CustomClientProxy("c2"), client2_eval_res), diff --git a/tests/test_picai/test_preprocess_transforms.py b/tests/test_picai/test_preprocess_transforms.py index ada9e3a5f..0a62664b7 100644 --- a/tests/test_picai/test_preprocess_transforms.py +++ b/tests/test_picai/test_preprocess_transforms.py @@ -1,6 +1,6 @@ +from collections.abc import Sequence from functools import partial from pathlib import Path -from typing import Sequence, Tuple import numpy as np import SimpleITK as sitk @@ -29,8 +29,8 @@ def __init__( scan_paths: Sequence[Path], annotation_path: Path, settings: PreprocessingSettings, - original_scan_sizes: Sequence[Tuple[int, int, int]] = [(256, 256, 20), (512, 512, 40), (128, 128, 10)], - original_annotation_size: Tuple[int, int, int] = (384, 384, 30), + original_scan_sizes: Sequence[tuple[int, int, int]] = [(256, 256, 20), (512, 512, 40), (128, 128, 10)], + original_annotation_size: tuple[int, int, int] = (384, 384, 30), ) -> None: super().__init__(scan_paths, annotation_path, settings) self.original_scan_sizes = original_scan_sizes @@ -41,7 +41,7 @@ def read(self) -> None: self.scans = [sitk.GetImageFromArray(scan) for scan in np_scans] self.annotation = sitk.GetImageFromArray(np.random.randint(0, 5, self.original_annotation_size)) - def write(self) -> Tuple[Sequence[Path], Path]: + def write(self) -> tuple[Sequence[Path], Path]: return ([Path("") for _ in range(3)], Path("")) diff --git a/tests/test_utils/assert_metrics_dict.py b/tests/test_utils/assert_metrics_dict.py index e1ec881d8..58fe3ebdb 100644 --- a/tests/test_utils/assert_metrics_dict.py +++ b/tests/test_utils/assert_metrics_dict.py @@ -1,11 +1,11 @@ -from typing import Any, Optional +from typing import Any from pytest import approx DEFAULT_TOLERANCE = 0.0005 -def _assert(value: Any, saved_value: Any, metric_key: str, tolerance: float = DEFAULT_TOLERANCE) -> Optional[str]: +def _assert(value: Any, saved_value: Any, metric_key: str, tolerance: float = DEFAULT_TOLERANCE) -> str | None: # helper function to avoid code repetition if isinstance(value, dict): # if the value is a dictionary, extract the target value and the custom tolerance diff --git a/tests/test_utils/custom_client_proxy.py b/tests/test_utils/custom_client_proxy.py index 14775d963..6c8dfaa86 100644 --- a/tests/test_utils/custom_client_proxy.py +++ b/tests/test_utils/custom_client_proxy.py @@ -1,5 +1,3 @@ -from typing import Optional - from flwr.common.typing import ( Code, DisconnectRes, @@ -27,8 +25,8 @@ def __init__(self, cid: str, num_samples: int = 1): def get_properties( self, ins: GetPropertiesIns, - timeout: Optional[float], - group_id: Optional[int], + timeout: float | None, + group_id: int | None, ) -> GetPropertiesRes: status: Status = Status(code=Code["OK"], message="Test") res = GetPropertiesRes(status=status, properties=self.properties) @@ -37,31 +35,31 @@ def get_properties( def get_parameters( self, ins: GetParametersIns, - timeout: Optional[float], - group_id: Optional[int], + timeout: float | None, + group_id: int | None, ) -> GetParametersRes: raise NotImplementedError def fit( self, ins: FitIns, - timeout: Optional[float], - group_id: Optional[int], + timeout: float | None, + group_id: int | None, ) -> FitRes: raise NotImplementedError def evaluate( self, ins: EvaluateIns, - timeout: Optional[float], - group_id: Optional[int], + timeout: float | None, + group_id: int | None, ) -> EvaluateRes: raise NotImplementedError def reconnect( self, ins: ReconnectIns, - timeout: Optional[float], - group_id: Optional[int], + timeout: float | None, + group_id: int | None, ) -> DisconnectRes: raise NotImplementedError diff --git a/tests/test_utils/models_for_test.py b/tests/test_utils/models_for_test.py index b7e1ad147..c889226ab 100644 --- a/tests/test_utils/models_for_test.py +++ b/tests/test_utils/models_for_test.py @@ -1,5 +1,3 @@ -from typing import List, Optional, Tuple, Union - import torch import torch.nn as nn import torch.nn.functional as F @@ -189,15 +187,15 @@ def __init__( dimensions: int = 3, num_encoding_blocks: int = 3, out_channels_first_layer: int = 8, - normalization: Optional[str] = "batch", + normalization: str | None = "batch", pooling_type: str = "max", upsampling_type: str = "linear", preactivation: bool = False, residual: bool = False, padding: int = 1, padding_mode: str = "zeros", - activation: Optional[str] = "PReLU", - initial_dilation: Optional[int] = None, + activation: str | None = "PReLU", + initial_dilation: int | None = None, dropout: float = 0, monte_carlo_dropout: float = 0, ): @@ -304,13 +302,13 @@ def __init__( dimensions: int, in_channels: int, out_channels: int, - normalization: Optional[str] = None, + normalization: str | None = None, kernel_size: int = 3, - activation: Optional[str] = "ReLU", - preactivation: Optional[bool] = False, + activation: str | None = "ReLU", + preactivation: bool | None = False, padding: int = 0, padding_mode: str = "zeros", - dilation: Optional[int] = None, + dilation: int | None = None, dropout: float = 0, ): super().__init__() @@ -373,7 +371,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.block(x) @staticmethod - def add_if_not_none(module_list: nn.ModuleList, module: Optional[nn.Module]) -> None: + def add_if_not_none(module_list: nn.ModuleList, module: nn.Module | None) -> None: if module is not None: module_list.append(module) @@ -391,13 +389,13 @@ def __init__( dimensions: int, upsampling_type: str, num_decoding_blocks: int, - normalization: Optional[str], + normalization: str | None, preactivation: bool = False, residual: bool = False, padding: int = 0, padding_mode: str = "zeros", - activation: Optional[str] = "ReLU", - initial_dilation: Optional[int] = None, + activation: str | None = "ReLU", + initial_dilation: int | None = None, dropout: float = 0, ): super().__init__() @@ -423,7 +421,7 @@ def __init__( if self.dilation is not None: self.dilation //= 2 - def forward(self, skip_connections: List[torch.Tensor], x: torch.Tensor) -> torch.Tensor: + def forward(self, skip_connections: list[torch.Tensor], x: torch.Tensor) -> torch.Tensor: zipped = zip(reversed(skip_connections), self.decoding_blocks) for skip_connection, decoding_block in zipped: x = decoding_block(skip_connection, x) @@ -436,13 +434,13 @@ def __init__( in_channels_skip_connection: int, dimensions: int, upsampling_type: str, - normalization: Optional[str], + normalization: str | None, preactivation: bool = True, residual: bool = False, padding: int = 0, padding_mode: str = "zeros", - activation: Optional[str] = "ReLU", - dilation: Optional[int] = None, + activation: str | None = "ReLU", + dilation: int | None = None, dropout: float = 0, ): super().__init__() @@ -554,13 +552,13 @@ def __init__( dimensions: int, pooling_type: str, num_encoding_blocks: int, - normalization: Optional[str], + normalization: str | None, preactivation: bool = False, residual: bool = False, padding: int = 0, padding_mode: str = "zeros", - activation: Optional[str] = "ReLU", - initial_dilation: Optional[int] = None, + activation: str | None = "ReLU", + initial_dilation: int | None = None, dropout: float = 0, ): super().__init__() @@ -595,8 +593,8 @@ def __init__( if self.dilation is not None: self.dilation *= 2 - def forward(self, x: torch.Tensor) -> Tuple[List[torch.Tensor], torch.Tensor]: - skip_connections: List[torch.Tensor] = [] + def forward(self, x: torch.Tensor) -> tuple[list[torch.Tensor], torch.Tensor]: + skip_connections: list[torch.Tensor] = [] for encoding_block in self.encoding_blocks: x, skip_connection = encoding_block(x) skip_connections.append(skip_connection) @@ -613,21 +611,21 @@ def __init__( in_channels: int, out_channels_first: int, dimensions: int, - normalization: Optional[str], - pooling_type: Optional[str], - preactivation: Optional[bool] = False, + normalization: str | None, + pooling_type: str | None, + preactivation: bool | None = False, is_first_block: bool = False, residual: bool = False, padding: int = 0, padding_mode: str = "zeros", - activation: Optional[str] = "ReLU", - dilation: Optional[int] = None, + activation: str | None = "ReLU", + dilation: int | None = None, dropout: float = 0, ): super().__init__() - self.preactivation: Optional[bool] = preactivation - self.normalization: Optional[str] = normalization + self.preactivation: bool | None = preactivation + self.normalization: str | None = normalization self.residual = residual @@ -681,7 +679,7 @@ def __init__( if pooling_type is not None: self.downsample = get_downsampling_layer(dimensions, pooling_type) - def forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + def forward(self, x: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: if self.residual: connection = self.conv_residual(x) x = self.conv1(x) @@ -710,7 +708,7 @@ def get_downsampling_layer(dimensions: int, pooling_type: str, kernel_size: int # Autoencoder: encoder and decoder units class VariationalEncoder(nn.Module): - def __init__(self, embedding_size: int = 2, condition_vector_size: Optional[int] = None) -> None: + def __init__(self, embedding_size: int = 2, condition_vector_size: int | None = None) -> None: super().__init__() if condition_vector_size is not None: self.fc_mu = nn.Linear(100 + condition_vector_size, embedding_size) @@ -719,28 +717,28 @@ def __init__(self, embedding_size: int = 2, condition_vector_size: Optional[int] self.fc_mu = nn.Linear(100, embedding_size) self.fc_logvar = nn.Linear(100, embedding_size) - def forward(self, x: torch.Tensor, condition: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, x: torch.Tensor, condition: torch.Tensor | None = None) -> tuple[torch.Tensor, torch.Tensor]: if condition is not None: return self.fc_mu(torch.cat((x, condition), dim=-1)), self.fc_logvar(torch.cat((x, condition), dim=-1)) return self.fc_mu(x), self.fc_logvar(x) class VariationalDecoder(nn.Module): - def __init__(self, embedding_size: int = 2, condition_vector_size: Optional[int] = None) -> None: + def __init__(self, embedding_size: int = 2, condition_vector_size: int | None = None) -> None: super().__init__() if condition_vector_size is not None: self.linear = nn.Linear(embedding_size + condition_vector_size, 100) else: self.linear = nn.Linear(embedding_size, 100) - def forward(self, x: torch.Tensor, condition: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward(self, x: torch.Tensor, condition: torch.Tensor | None = None) -> torch.Tensor: if condition is not None: return self.linear(torch.cat((x, condition), dim=-1)) return self.linear(x) class ConstantConvNet(nn.Module): - def __init__(self, constants: List[float]) -> None: + def __init__(self, constants: list[float]) -> None: assert len(constants) == 4 super().__init__() self.conv1 = nn.Conv2d(1, 6, 5, bias=False) diff --git a/tests/utils/dataset_converter_test.py b/tests/utils/dataset_converter_test.py index ccef3e220..8dc7d7a3e 100644 --- a/tests/utils/dataset_converter_test.py +++ b/tests/utils/dataset_converter_test.py @@ -1,5 +1,3 @@ -from typing import Tuple - import torch from torch.utils.data import DataLoader @@ -17,7 +15,7 @@ def test_dataset_converter() -> None: dummy_dataset = get_dummy_dataset() # Create a dummy converter function for testing - def dummy_converter(data: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def dummy_converter(data: torch.Tensor, target: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: return data, target # Test DatasetConverter diff --git a/tests/utils/functions_test.py b/tests/utils/functions_test.py index 1f8b024b9..1ada2863c 100644 --- a/tests/utils/functions_test.py +++ b/tests/utils/functions_test.py @@ -1,5 +1,3 @@ -from typing import List, Tuple - import numpy as np import pytest import torch @@ -79,7 +77,7 @@ def test_decode_and_pseudo_sort_results() -> None: client0_res = construct_fit_res([np.ones((3, 3)), np.ones((4, 4))], 0.1, 100) client1_res = construct_fit_res([np.ones((3, 3)), np.full((4, 4), 2.0)], 0.2, 75) client2_res = construct_fit_res([np.full((3, 3), 3.0), np.full((4, 4), 3.0)], 0.3, 50) - clients_res: List[Tuple[ClientProxy, FitRes]] = [ + clients_res: list[tuple[ClientProxy, FitRes]] = [ (CustomClientProxy("c0"), client0_res), (CustomClientProxy("c1"), client1_res), (CustomClientProxy("c2"), client2_res), diff --git a/tests/utils/metric_aggregation_test.py b/tests/utils/metric_aggregation_test.py index b10564821..149270663 100644 --- a/tests/utils/metric_aggregation_test.py +++ b/tests/utils/metric_aggregation_test.py @@ -1,5 +1,3 @@ -from typing import List, Tuple - from flwr.common.typing import Metrics from fl4health.utils.metric_aggregation import ( @@ -15,18 +13,18 @@ def test_metric_aggregation() -> None: n_clients = 10 int_metric_counts = [10 for _ in range(n_clients)] - int_metric_vals: List[Metrics] = [] + int_metric_vals: list[Metrics] = [] for i in range(n_clients): metric: Metrics = {"score": 1} if i < 5 else {"score": 2} int_metric_vals.append(metric) - int_metrics: List[Tuple[int, Metrics]] = [(count, val) for count, val in zip(int_metric_counts, int_metric_vals)] + int_metrics: list[tuple[int, Metrics]] = [(count, val) for count, val in zip(int_metric_counts, int_metric_vals)] int_total_examples, int_aggregated_metrics = metric_aggregation(int_metrics) assert int_total_examples == sum(int_metric_counts) gt_int_metrics: Metrics = {"score": 150.0} assert int_aggregated_metrics == gt_int_metrics float_metric_counts = [20 for _ in range(n_clients)] - float_metric_vals: List[Metrics] = [] + float_metric_vals: list[Metrics] = [] for i in range(n_clients): float_metric: Metrics = {"score": float(i)} float_metric_vals.append(float_metric) @@ -49,7 +47,7 @@ def test_normalize_metrics() -> None: def test_fit_metrics_aggregation_fn() -> None: n_clients = 10 metric_counts = [10 for _ in range(n_clients)] - metric_vals: List[Metrics] = [] + metric_vals: list[Metrics] = [] for i in range(n_clients): metric: Metrics = {"score": float(i)} metric_vals.append(metric) @@ -63,7 +61,7 @@ def test_fit_metrics_aggregation_fn() -> None: def test_evaluate_metrics_aggregation_fn() -> None: n_clients = 5 metric_counts = [20 for _ in range(n_clients)] - metric_vals: List[Metrics] = [] + metric_vals: list[Metrics] = [] for i in range(n_clients): metric: Metrics = {"score": float(i)} metric_vals.append(metric) @@ -76,7 +74,7 @@ def test_evaluate_metrics_aggregation_fn() -> None: def test_uniform_metric_aggregation() -> None: client_sample_counts = [100, 200, 100, 200, 100] vals = [5.0, 10.0, 20.0, 10.0, 10.0] - client_metric_vals: List[Tuple[int, Metrics]] = [] + client_metric_vals: list[tuple[int, Metrics]] = [] for count, val in zip(client_sample_counts, vals): client_metrics: Metrics = {"score": val} client_metric_vals.append((count, client_metrics)) @@ -89,7 +87,7 @@ def test_uniform_metric_aggregation() -> None: def test_uniform_evaluate_metrics_aggregation_fn() -> None: client_sample_counts = [100, 200, 100, 200, 100] vals = [5.0, 10.0, 20.0, 10.0, 10.0] - client_metric_vals: List[Tuple[int, Metrics]] = [] + client_metric_vals: list[tuple[int, Metrics]] = [] for count, val in zip(client_sample_counts, vals): client_metrics: Metrics = {"score": val} client_metric_vals.append((count, client_metrics)) diff --git a/tests/utils/sampler_test.py b/tests/utils/sampler_test.py index ae972310d..eb531c2be 100644 --- a/tests/utils/sampler_test.py +++ b/tests/utils/sampler_test.py @@ -1,5 +1,3 @@ -from typing import Tuple - import numpy as np import pytest import torch @@ -11,7 +9,7 @@ from fl4health.utils.sampler import DirichletLabelBasedSampler, MinorityLabelBasedSampler -def construct_synthetic_dataset() -> Tuple[SyntheticDataset, SyntheticDataset]: +def construct_synthetic_dataset() -> tuple[SyntheticDataset, SyntheticDataset]: # set seed for creation torch.manual_seed(42) random_inputs = torch.rand((20000, 3, 3))