From 3a226839de015b97bcc2d58dbff321c4f8778021 Mon Sep 17 00:00:00 2001
From: Jeremy Fix
Date: Sat, 4 Jan 2025 19:16:26 +0100
Subject: [PATCH] import of the example from torchcvnn

---
 mnist_conv/README.md        |  30 +++++
 mnist_conv/mnist.py         | 188 ++++++++++++++++++++++++++++
 mnist_conv/requirements.txt |   3 +
 mnist_conv/utils.py         | 238 ++++++++++++++++++++++++++++++++++++
 4 files changed, 459 insertions(+)
 create mode 100644 mnist_conv/README.md
 create mode 100644 mnist_conv/mnist.py
 create mode 100644 mnist_conv/requirements.txt
 create mode 100644 mnist_conv/utils.py

diff --git a/mnist_conv/README.md b/mnist_conv/README.md
new file mode 100644
index 0000000..9b0e35b
--- /dev/null
+++ b/mnist_conv/README.md
@@ -0,0 +1,30 @@
+# MNIST classification in the spectral domain
+
+This simple example demonstrates how to code and run a complex-valued neural network for classification.
+
+The task does not necessarily make sense in itself, but it provides complex-valued inputs: we classify the MNIST digits from their spectral representation.
+
+```bash
+python -m pip install -r requirements.txt
+python mnist.py
+```
+
+An expected output is:
+
+```bash
+Logging to ./logs/CMNIST_0
+>> Training
+100%|██████| 844/844 [00:17<00:00, 48.61it/s]
+>> Testing
+[Step 0] Train : CE 0.20 Acc 0.94 | Valid : CE 0.08 Acc 0.97 | Test : CE 0.06 Acc 0.98[>> BETTER <<]
+
+>> Training
+100%|██████| 844/844 [00:16<00:00, 51.69it/s]
+>> Testing
+[Step 1] Train : CE 0.06 Acc 0.98 | Valid : CE 0.06 Acc 0.98 | Test : CE 0.05 Acc 0.98[>> BETTER <<]
+
+>> Training
+100%|██████| 844/844 [00:15<00:00, 53.47it/s]
+>> Testing
+[Step 2] Train : CE 0.04 Acc 0.99 | Valid : CE 0.04 Acc 0.99 | Test : CE 0.04 Acc 0.99[>> BETTER <<]
+```
diff --git a/mnist_conv/mnist.py b/mnist_conv/mnist.py
new file mode 100644
index 0000000..823b1c5
--- /dev/null
+++ b/mnist_conv/mnist.py
@@ -0,0 +1,188 @@
+# MIT License
+
+# Copyright (c) 2023 Jérémy Fix
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+"""
+# Example using complex-valued neural networks to classify MNIST from the Fourier transform of the digits.
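+
+The dataset transform maps every digit to the spectral domain before it
+reaches the network. A minimal sketch of what this pipeline produces (the
+shape and dtype below are what the training script assumes):
+
+```python
+import torch
+
+image = torch.rand(1, 28, 28)  # a grayscale digit, values in [0, 1]
+spectrum = torch.fft.fft(image)  # 1D FFT of each image row
+print(spectrum.dtype, spectrum.shape)  # torch.complex64 torch.Size([1, 28, 28])
+```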
+
+Requires dependencies:
+    python3 -m pip install torchcvnn torchvision tqdm
+"""
+
+# Standard imports
+import random
+from typing import List
+
+# External imports
+import torch
+import torch.nn as nn
+import torchvision
+import torchvision.transforms.v2 as v2_transforms
+
+import torchcvnn.nn as c_nn
+
+# Local imports
+import utils
+
+
+def conv_block(in_c: int, out_c: int, cdtype: torch.dtype) -> List[nn.Module]:
+    """
+    Builds a basic building block of
+    `Conv2d`-`BatchNorm2d`-`Cardioid`-`Conv2d`-`BatchNorm2d`-`Cardioid`-`AvgPool2d`
+
+    Arguments:
+        in_c: the number of input channels
+        out_c: the number of output channels
+        cdtype: the dtype of the complex values (expected to be torch.complex64 or torch.complex32)
+    """
+    return [
+        nn.Conv2d(in_c, out_c, kernel_size=3, stride=1, padding=1, dtype=cdtype),
+        c_nn.BatchNorm2d(out_c),
+        c_nn.Cardioid(),
+        nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1, dtype=cdtype),
+        c_nn.BatchNorm2d(out_c),
+        c_nn.Cardioid(),
+        c_nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
+    ]
+
+
+def train():
+    """
+    Builds the dataloaders and the complex-valued model, then runs the
+    training loop.
+
+    Sample output:
+    ```.bash
+    (venv) me@host:~$ python mnist.py
+    Logging to ./logs/CMNIST_0
+    >> Training
+    100%|██████| 844/844 [00:17<00:00, 48.61it/s]
+    >> Testing
+    [Step 0] Train : CE 0.20 Acc 0.94 | Valid : CE 0.08 Acc 0.97 | Test : CE 0.06 Acc 0.98[>> BETTER <<]
+
+    >> Training
+    100%|██████| 844/844 [00:16<00:00, 51.69it/s]
+    >> Testing
+    [Step 1] Train : CE 0.06 Acc 0.98 | Valid : CE 0.06 Acc 0.98 | Test : CE 0.05 Acc 0.98[>> BETTER <<]
+
+    >> Training
+    100%|██████| 844/844 [00:15<00:00, 53.47it/s]
+    >> Testing
+    [Step 2] Train : CE 0.04 Acc 0.99 | Valid : CE 0.04 Acc 0.99 | Test : CE 0.04 Acc 0.99[>> BETTER <<]
+
+    [...]
+    ```
+
+    """
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    valid_ratio = 0.1
+    batch_size = 64
+    epochs = 10
+    cdtype = torch.complex64
+
+    # Dataloading. torch.fft.fft computes a 1D FFT along the last dimension
+    # (the image rows), so the network receives complex-valued inputs.
+    train_valid_dataset = torchvision.datasets.MNIST(
+        root="./data",
+        train=True,
+        download=True,
+        transform=v2_transforms.Compose([v2_transforms.PILToTensor(), torch.fft.fft]),
+    )
+    test_dataset = torchvision.datasets.MNIST(
+        root="./data",
+        train=False,
+        download=True,
+        transform=v2_transforms.Compose([v2_transforms.PILToTensor(), torch.fft.fft]),
+    )
+
+    all_indices = list(range(len(train_valid_dataset)))
+    random.shuffle(all_indices)
+    split_idx = int(valid_ratio * len(train_valid_dataset))
+    valid_indices, train_indices = all_indices[:split_idx], all_indices[split_idx:]
+
+    # Train dataloader
+    train_dataset = torch.utils.data.Subset(train_valid_dataset, train_indices)
+    train_loader = torch.utils.data.DataLoader(
+        train_dataset, batch_size=batch_size, shuffle=True
+    )
+
+    # Valid dataloader
+    valid_dataset = torch.utils.data.Subset(train_valid_dataset, valid_indices)
+    valid_loader = torch.utils.data.DataLoader(
+        valid_dataset, batch_size=batch_size, shuffle=False
+    )
+
+    # Test dataloader
+    test_loader = torch.utils.data.DataLoader(
+        test_dataset, batch_size=batch_size, shuffle=False
+    )
+
+    # Model
+    conv_model = nn.Sequential(
+        *conv_block(1, 16, cdtype),
+        *conv_block(16, 16, cdtype),
+        *conv_block(16, 32, cdtype),
+        *conv_block(32, 32, cdtype),
+        nn.Flatten(),
+    )
+
+    # Run a dummy forward pass to infer the number of features feeding the
+    # classification head
+    with torch.no_grad():
+        conv_model.eval()
+        dummy_input = torch.zeros((64, 1, 28, 28), dtype=cdtype, requires_grad=False)
+        out_conv = conv_model(dummy_input).view(64, -1)
+    lin_model = nn.Sequential(
+        nn.Linear(out_conv.shape[-1], 124, dtype=cdtype),
+        c_nn.Cardioid(),
+        nn.Linear(124, 10, dtype=cdtype),
+        c_nn.Mod(),
+    )
+    model = nn.Sequential(conv_model, lin_model)
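+    # Optional sanity check (a sketch of the assumed interface): the final
+    # c_nn.Mod() maps the complex activations to their real magnitudes, so the
+    # logits are real-valued and can feed nn.CrossEntropyLoss directly.
+    with torch.no_grad():
+        probe = model(torch.zeros((2, 1, 28, 28), dtype=cdtype))
+    assert probe.shape == (2, 10) and not probe.is_complex()
+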
+    model.to(device)
+
+    # Loss, optimizer, callbacks
+    f_loss = nn.CrossEntropyLoss()
+    optim = torch.optim.Adam(model.parameters(), lr=3e-4)
+    logpath = utils.generate_unique_logpath("./logs", "CMNIST")
+    print(f"Logging to {logpath}")
+    # 4 is the number of input dimensions (B, C, H, W), used by the optional
+    # ONNX export of the checkpoint
+    checkpoint = utils.ModelCheckpoint(model, logpath, 4, min_is_best=True)
+
+    # Training loop
+    for e in range(epochs):
+        print(">> Training")
+        train_loss, train_acc = utils.train_epoch(
+            model, train_loader, f_loss, optim, device
+        )
+
+        print(">> Testing")
+        valid_loss, valid_acc = utils.test_epoch(model, valid_loader, f_loss, device)
+        test_loss, test_acc = utils.test_epoch(model, test_loader, f_loss, device)
+        updated = checkpoint.update(valid_loss)
+        better_str = "[>> BETTER <<]" if updated else ""
+
+        print(
+            f"[Step {e}] Train : CE {train_loss:5.2f} Acc {train_acc:5.2f} | Valid : CE {valid_loss:5.2f} Acc {valid_acc:5.2f} | Test : CE {test_loss:5.2f} Acc {test_acc:5.2f}"
+            + better_str
+        )
+
+
+if __name__ == "__main__":
+    train()
diff --git a/mnist_conv/requirements.txt b/mnist_conv/requirements.txt
new file mode 100644
index 0000000..1ebd8dc
--- /dev/null
+++ b/mnist_conv/requirements.txt
@@ -0,0 +1,3 @@
+torchcvnn
+torchvision
+tqdm
diff --git a/mnist_conv/utils.py b/mnist_conv/utils.py
new file mode 100644
index 0000000..49a3888
--- /dev/null
+++ b/mnist_conv/utils.py
@@ -0,0 +1,238 @@
+# coding: utf-8
+# MIT License
+
+# Copyright (c) 2023 Jeremy Fix
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
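+
+"""
+Training utilities shared by mnist.py: one-epoch train and test loops, a
+best-model checkpoint callback and unique log-directory creation.
+
+A minimal usage sketch (`model`, `train_loader` and `device` are assumed to be
+defined as in mnist.py):
+
+```python
+f_loss = torch.nn.CrossEntropyLoss()
+optim = torch.optim.Adam(model.parameters(), lr=3e-4)
+loss, acc = train_epoch(model, train_loader, f_loss, optim, device)
+```
+"""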
+
+# Standard imports
+import os
+from typing import Tuple
+
+# External imports
+import torch
+import torch.nn as nn
+import tqdm
+
+# import torch.onnx
+
+
+def train_epoch(
+    model: nn.Module,
+    loader: torch.utils.data.DataLoader,
+    f_loss: nn.Module,
+    optim: torch.optim.Optimizer,
+    device: torch.device,
+) -> Tuple[float, float]:
+    """
+    Run one training epoch over all the minibatches of the dataloader
+
+    Arguments:
+        model: the model to train
+        loader: an iterable dataloader
+        f_loss: the loss
+        optim: the optimizer
+        device: the device on which to run the code
+
+    Returns:
+        The averaged training loss
+        The averaged training accuracy
+    """
+    model.train()
+
+    loss_avg = 0
+    acc_avg = 0
+    num_samples = 0
+    for inputs, outputs in tqdm.tqdm(loader):
+        inputs = inputs.to(device)
+        outputs = outputs.to(device)
+
+        # Forward propagate through the model
+        pred_outputs = model(inputs)
+
+        # Forward propagate through the loss
+        loss = f_loss(pred_outputs, outputs)
+
+        # Backward pass and update
+        optim.zero_grad()
+        loss.backward()
+        optim.step()
+
+        num_samples += inputs.shape[0]
+
+        # The loss is averaged over the minibatch; re-weight it by the batch
+        # size so that the final division yields a per-sample average
+        loss_avg += inputs.shape[0] * loss.item()
+        pred_cls = pred_outputs.argmax(dim=-1)
+        acc_avg += (pred_cls == outputs).sum().item()
+
+    return loss_avg / num_samples, acc_avg / num_samples
+
+
+def test_epoch(
+    model: nn.Module,
+    loader: torch.utils.data.DataLoader,
+    f_loss: nn.Module,
+    device: torch.device,
+) -> Tuple[float, float]:
+    """
+    Run one evaluation pass over all the minibatches of the dataloader
+
+    Arguments:
+        model: the model to evaluate
+        loader: an iterable dataloader
+        f_loss: the loss
+        device: the device on which to run the code
+
+    Returns:
+        The averaged test loss
+        The averaged test accuracy
+
+    """
+    model.eval()
+
+    loss_avg = 0
+    acc_avg = 0
+    num_samples = 0
+    with torch.no_grad():
+        for inputs, outputs in loader:
+            inputs = inputs.to(device)
+            outputs = outputs.to(device)
+
+            # Forward propagate through the model
+            pred_outputs = model(inputs)
+
+            # Forward propagate through the loss
+            loss = f_loss(pred_outputs, outputs)
+
+            loss_avg += inputs.shape[0] * loss.item()
+            pred_cls = pred_outputs.argmax(dim=-1)
+            acc_avg += (pred_cls == outputs).sum().item()
+            num_samples += inputs.shape[0]
+
+    return loss_avg / num_samples, acc_avg / num_samples
+
+
+class ModelCheckpoint(object):
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        savepath: str,
+        num_input_dims: int,
+        min_is_best: bool = True,
+    ) -> None:
+        """
+        Best-model checkpoint callback
+
+        Arguments:
+            model: the model to save
+            savepath: the location where to save the model's parameters
+            num_input_dims: the number of dimensions of the input tensor (required for the optional ONNX export)
+            min_is_best: whether a lower score (True) or a higher score (False) is considered better
+        """
+        self.model = model
+        self.savepath = savepath
+        self.num_input_dims = num_input_dims
+        self.best_score = None
+        if min_is_best:
+            self.is_better = self.lower_is_better
+        else:
+            self.is_better = self.higher_is_better
+
+    def lower_is_better(self, score: float) -> bool:
+        """
+        Test if the provided score is lower than the best score found so far
+
+        Arguments:
+            score: the score to test
+
+        Returns:
+            res: whether the provided score is lower than the best score so far
+ """ + return self.best_score is None or score < self.best_score + + def higher_is_better(self, score: float) -> bool: + """ + Test if the provided score is higher than the best score found so far + + Arguments: + score: the score to test + + Returns: + res : is the provided score higher than the best score so far ? + """ + return self.best_score is None or score > self.best_score + + def update(self, score: float) -> bool: + """ + If the provided score is better than the best score registered so far, + saves the model's parameters on disk as a pytorch tensor + + Arguments: + score: the new score to consider + + Returns: + res: whether or not the provided score is better than the best score + registered so far + """ + if self.is_better(score): + self.model.eval() + + torch.save( + self.model.state_dict(), os.path.join(self.savepath, "best_model.pt") + ) + + # torch.onnx.export( + # self.model, + # dummy_input, + # os.path.join(self.savepath, "best_model.onnx"), + # verbose=False, + # input_names=["input"], + # output_names=["output"], + # dynamic_axes={ + # "input": {0: "batch"}, + # "output": {0: "batch"}, + # }, + # ) + + self.best_score = score + return True + return False + + +def generate_unique_logpath(logdir: str, raw_run_name: str) -> str: + """ + Generate a unique directory name and create it if necessary + + Arguments: + logdir: the prefix directory + raw_run_name: the base name + + Returns: + log_path: a non-existent path like logdir/raw_run_name_xxxx + where xxxx is an int + """ + i = 0 + while True: + run_name = raw_run_name + "_" + str(i) + log_path = os.path.join(logdir, run_name) + if not os.path.isdir(log_path): + os.makedirs(log_path) + return log_path + i = i + 1