Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mypy Typing Migration and Some typo fixes #299

Merged
merged 17 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
8 changes: 8 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ repos:
- id: nbqa-flake8
- id: nbqa-mypy

- repo: local
hooks:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool :)

- id: mypy legacy type check
name: mypy legacy type check
entry: python mypy_disallow_legacy_types.py
language: python
pass_filenames: true

ci:
autofix_commit_msg: |
[pre-commit.ci] Add auto fixes from pre-commit.com hooks
Expand Down
11 changes: 11 additions & 0 deletions CONTRIBUTING.MD
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,17 @@ The settings for `mypy` are in the `mypy.ini`, settings for `flake8` are contain

All of these checks and formatters are invoked by pre-commit hooks. These hooks are run remotely on GitHub. In order to ensure that your code conforms to these standards, and, therefore, passes the remote checks, you can install the pre-commit hooks to be run locally. This is done by running (with your environment active)

**Note**: We use the modern mypy types introduced in Python 3.10 and above. See some of the [documentation here](https://mypy.readthedocs.io/en/stable/builtin_types.html)

For example, this means that we're using `list[str], tuple[int, int], tuple[int, ...], dict[str, int], type[C]` as built-in types and `Iterable[int], Sequence[bool], Mapping[str, int], Callable[[...], ...]` from collections.abc (as now recommended by mypy).

We are also moving to the new Optional and Union specification style:
```python
Optional[typing_stuff] -> typing_stuff | None
Union[typing1, typing2] -> typing1 | typing2
Optional[Union[typing1, typing2]] -> typing1 | typing2 | None
```

```bash
pre-commit install
```
Expand Down
2 changes: 1 addition & 1 deletion examples/ae_examples/cvae_dim_example/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ from the FL4Health directory. The following arguments must be present in the spe
* `n_server_rounds`: The number of rounds to run FL
* `checkpoint_path`: path to save the best server model
* `latent_dim`: size of the latent vector in the CVAE or VAE model
* `cvae_model_path`: path to the saved CVAE model for dimesionality reduction
* `cvae_model_path`: path to the saved CVAE model for dimensionality reduction

**NOTE**: Instead of using a global CVAE for all the clients, you can pass personalized CVAE models to each client, but make sure that these models are previously trained in an FL setting, and are not very different, otherwise, that can lead the dimensionality reduction to map the data samples into different latent spaces which might increase the heterogeneity.

Expand Down
4 changes: 2 additions & 2 deletions examples/ae_examples/cvae_dim_example/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import argparse
from collections.abc import Sequence
from pathlib import Path
from typing import Sequence, Tuple

import flwr as fl
import torch
Expand All @@ -26,7 +26,7 @@ def __init__(self, data_path: Path, metrics: Sequence[Metric], device: torch.dev
super().__init__(data_path, metrics, device)
self.condition = condition

def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
batch_size = narrow_dict_type(config, "batch_size", int)
cvae_model_path = Path(narrow_dict_type(config, "cvae_model_path", str))
sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=100)
Expand Down
4 changes: 2 additions & 2 deletions examples/ae_examples/cvae_dim_example/server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import argparse
from functools import partial
from typing import Any, Dict
from typing import Any

import flwr as fl
from flwr.common.typing import Config
Expand Down Expand Up @@ -33,7 +33,7 @@ def fit_config(
}


def main(config: Dict[str, Any]) -> None:
def main(config: dict[str, Any]) -> None:
# This function will be used to produce a config that is sent to each client to initialize their own environment
fit_config_fn = partial(
fit_config,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import argparse
from collections.abc import Sequence
from pathlib import Path
from typing import Sequence, Tuple

import flwr as fl
import torch
Expand All @@ -25,7 +25,7 @@

def binary_class_condition_data_converter(
data: torch.Tensor, target: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
# Create a condition for each data sample.
# Condition is the binary representation of the target.
binary_representation = bin(int(target))[2:] # Convert to binary and remove the '0b' prefix
Expand Down Expand Up @@ -56,7 +56,7 @@ def setup_client(self, config: Config) -> None:
assert isinstance(self.model, ConditionalVae)
self.model.unpack_input_condition = self.autoencoder_converter.get_unpacking_function()

def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
batch_size = narrow_dict_type(config, "batch_size", int)
sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=100)
# To make sure pixels stay in the range [0.0, 1.0].
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
Expand All @@ -23,7 +21,7 @@ def __init__(
self.fc_mu = nn.Linear(64, latent_dim)
self.fc_logvar = nn.Linear(64, latent_dim)

def forward(self, input: torch.Tensor, condition: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def forward(self, input: torch.Tensor, condition: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
x = self.conv(input)
# Flatten the tensor
x = x.view(x.size(0), -1)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import argparse
from functools import partial
from typing import Any, Dict
from typing import Any

import flwr as fl
from flwr.common.typing import Config
Expand Down Expand Up @@ -32,7 +32,7 @@ def fit_config(
}


def main(config: Dict[str, Any]) -> None:
def main(config: dict[str, Any]) -> None:
# This function will be used to produce a config that is sent to each client to initialize their own environment
fit_config_fn = partial(
fit_config,
Expand Down
4 changes: 2 additions & 2 deletions examples/ae_examples/cvae_examples/mlp_cvae_example/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import argparse
from collections.abc import Sequence
from pathlib import Path
from typing import Sequence, Tuple

import flwr as fl
import torch
Expand Down Expand Up @@ -44,7 +44,7 @@ def setup_client(self, config: Config) -> None:
assert isinstance(self.model, ConditionalVae)
self.model.unpack_input_condition = self.autoencoder_converter.get_unpacking_function()

def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
batch_size = narrow_dict_type(config, "batch_size", int)
sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=100)
# ToTensor transform is used to make sure pixels stay in the range [0.0, 1.0].
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
Expand All @@ -19,7 +17,7 @@ def __init__(
self.fc_mu = nn.Linear(256, latent_dim)
self.fc_logvar = nn.Linear(256, latent_dim)

def forward(self, input: torch.Tensor, condition: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def forward(self, input: torch.Tensor, condition: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
input = torch.cat((input, condition), dim=-1)
x = F.relu(self.fc1(input))
x = F.relu(self.fc2(x))
Expand Down
4 changes: 2 additions & 2 deletions examples/ae_examples/cvae_examples/mlp_cvae_example/server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import argparse
from functools import partial
from typing import Any, Dict
from typing import Any

import flwr as fl
from flwr.common.typing import Config
Expand Down Expand Up @@ -32,7 +32,7 @@ def fit_config(
}


def main(config: Dict[str, Any]) -> None:
def main(config: dict[str, Any]) -> None:
# This function will be used to produce a config that is sent to each client to initialize their own environment
fit_config_fn = partial(
fit_config,
Expand Down
3 changes: 1 addition & 2 deletions examples/ae_examples/fedprox_vae_example/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import argparse
from pathlib import Path
from typing import Tuple

import flwr as fl
import torch
Expand All @@ -22,7 +21,7 @@


class VaeFedProxClient(FedProxClient):
def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
batch_size = narrow_dict_type(config, "batch_size", int)
sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75, beta=100)
# Flattening the input images to use an MLP-based variational autoencoder.
Expand Down
4 changes: 1 addition & 3 deletions examples/ae_examples/fedprox_vae_example/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
Expand All @@ -18,7 +16,7 @@ def __init__(
self.fc_mu = nn.Linear(256, latent_dim)
self.fc_logvar = nn.Linear(256, latent_dim)

def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def forward(self, input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
x = F.relu(self.fc1(input))
x = F.relu(self.fc2(x))
return self.fc_mu(x), self.fc_logvar(x)
Expand Down
4 changes: 2 additions & 2 deletions examples/ae_examples/fedprox_vae_example/server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import argparse
from functools import partial
from typing import Any, Dict
from typing import Any

import flwr as fl
from flwr.common.typing import Config
Expand Down Expand Up @@ -31,7 +31,7 @@ def fit_config(
}


def main(config: Dict[str, Any]) -> None:
def main(config: dict[str, Any]) -> None:
# This function will be used to produce a config that is sent to each client to initialize their own environment
fit_config_fn = partial(
fit_config,
Expand Down
5 changes: 2 additions & 3 deletions examples/apfl_example/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import argparse
from pathlib import Path
from typing import Dict, Tuple

import flwr as fl
import torch
Expand All @@ -22,7 +21,7 @@


class MnistApflClient(ApflClient):
def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
batch_size = narrow_dict_type(config, "batch_size", int)
sampler = DirichletLabelBasedSampler(list(range(10)), sample_percentage=0.75)
train_loader, val_loader, _ = load_mnist_data(self.data_path, batch_size, sampler)
Expand All @@ -31,7 +30,7 @@ def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
def get_model(self, config: Config) -> nn.Module:
return ApflModule(MnistNetWithBnAndFrozen()).to(self.device)

def get_optimizer(self, config: Config) -> Dict[str, Optimizer]:
def get_optimizer(self, config: Config) -> dict[str, Optimizer]:
local_optimizer = torch.optim.AdamW(self.model.local_model.parameters(), lr=0.01)
global_optimizer = torch.optim.AdamW(self.model.global_model.parameters(), lr=0.01)
return {"local": local_optimizer, "global": global_optimizer}
Expand Down
8 changes: 4 additions & 4 deletions examples/apfl_example/server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import argparse
from functools import partial
from typing import Any, Dict, Optional
from typing import Any

import flwr as fl
from flwr.common.typing import Config
Expand All @@ -22,8 +22,8 @@ def fit_config(
batch_size: int,
n_server_rounds: int,
current_round: int,
local_epochs: Optional[int] = None,
local_steps: Optional[int] = None,
local_epochs: int | None = None,
local_steps: int | None = None,
) -> Config:
return {
**make_dict_with_epochs_or_steps(local_epochs, local_steps),
Expand All @@ -33,7 +33,7 @@ def fit_config(
}


def main(config: Dict[str, Any]) -> None:
def main(config: dict[str, Any]) -> None:
# This function will be used to produce a config that is sent to each client to initialize their own environment
fit_config_fn = partial(
fit_config,
Expand Down
5 changes: 2 additions & 3 deletions examples/basic_example/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import argparse
from pathlib import Path
from typing import Optional, Tuple

import flwr as fl
import torch
Expand All @@ -18,12 +17,12 @@


class CifarClient(BasicClient):
def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
def get_data_loaders(self, config: Config) -> tuple[DataLoader, DataLoader]:
batch_size = narrow_dict_type(config, "batch_size", int)
train_loader, val_loader, _ = load_cifar10_data(self.data_path, batch_size)
return train_loader, val_loader

def get_test_data_loader(self, config: Config) -> Optional[DataLoader]:
def get_test_data_loader(self, config: Config) -> DataLoader | None:
batch_size = narrow_dict_type(config, "batch_size", int)
test_loader, _ = load_cifar10_test_data(self.data_path, batch_size)
return test_loader
Expand Down
8 changes: 4 additions & 4 deletions examples/basic_example/server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import argparse
from functools import partial
from typing import Any, Dict, Optional
from typing import Any

import flwr as fl
from flwr.common.typing import Config
Expand All @@ -21,8 +21,8 @@
def fit_config(
batch_size: int,
current_server_round: int,
local_epochs: Optional[int] = None,
local_steps: Optional[int] = None,
local_epochs: int | None = None,
local_steps: int | None = None,
) -> Config:
return {
**make_dict_with_epochs_or_steps(local_epochs, local_steps),
Expand All @@ -31,7 +31,7 @@ def fit_config(
}


def main(config: Dict[str, Any]) -> None:
def main(config: dict[str, Any]) -> None:
# This function will be used to produce a config that is sent to each client to initialize their own environment
fit_config_fn = partial(
fit_config,
Expand Down
Loading
Loading