Simon Roburin*,
Yann de Mont-Marin*,
Andrei Bursuc,
Renaud Marlet,
Patrick Pérez,
Mathieu Aubry
*equal contribution
Batch Normalization (BN) is a prominent deep learning technique. In spite of its apparent simplicity, its implications for optimization are yet to be fully understood. In this paper, we study the optimization of neural networks with BN layers from a geometric perspective. We leverage the radial invariance of groups of parameters, such as neurons for multi-layer perceptrons or filters for convolutional neural networks, and translate several popular optimization schemes on the L2 unit hypersphere. This formulation and the associated geometric interpretation shed new light on the training dynamics and the relation between different optimization schemes. In particular, we use it to derive the effective learning rate of Adam and stochastic gradient descent (SGD) with momentum, and we show that in the presence of BN layers, performing SGD alone is actually equivalent to a variant of Adam constrained to the unit hypersphere. Our analysis also leads us to introduce new variants of Adam. We empirically show, over a variety of datasets and architectures, that they improve accuracy in classification tasks.
This repository implements the new optimizers proposed in the paper, AdamS and AdamSRT, and provides a training.py
script to reproduce the benchmark results presented in the paper.
To use the package properly you need Python 3, and CUDA 10 is recommended for acceleration. The installation is as follows:
- Clone the repo:
$ git clone https://github.com/ymontmarin/adamsrt
- Install this repository and the dependencies using pip:
$ pip install -e adamsrt
With this, you can edit the code on the fly and import the functions and classes of the package in other projects as well.
- Optional. To uninstall this package, run:
$ pip uninstall adamsrt
To import the package you just need:
import adamsrt
The package contains PyTorch Optimizer classes for the new optimizers proposed in the paper (AdamS, AdamSRT, SGDMRT), as well as classes to load classic models and datasets, and other optimizers (a PyTorch implementation of AdamG and the SGD variant of the paper, SGD-MRT):
- adamsrt: AdamSRT, AdamS
- adamsrt.models: resnet18, resnet20, vgg16
- adamsrt.dataloaders: get_dataloader_cifar10, get_dataloader_cifar100, get_dataloader_SVHN, get_dataloader_imagenet
- adamsrt.optimizers: AdamSRT, AdamS, SGDMRT, AdamG
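For example, the main components can be imported directly from these submodules (a minimal sketch assuming the module layout listed above):
from adamsrt import AdamS, AdamSRT
from adamsrt.models import resnet18
from adamsrt.dataloaders import get_dataloader_cifar10
from adamsrt.optimizers import AdamG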
The paper introduces a geometrical framework that identifies behaviours of Adam that do not translate easily to the context of optimization on a manifold. We introduce the rescaling and transport (RT) of the momentum and the standardization (S) of the step division in Adam to neutralize these effects. These transformations can be applied to Adam with only a few lines of code, yielding AdamS and AdamSRT.
These optimizers (AdamSRT, AdamS) are built to give a specific treatment to layers followed by BN (or another normalization layer).
To use them with PyTorch, you need to rely on PyTorch parameter groups (see the documentation).
They allow you to specify which parameters are followed by a normalization layer and to activate the special treatment with the option channel_wise=True
for these parameters.
The typical use for 2D convolutional networks where a convolutional layer is followed by a BN layer looks like:
from adamsrt import AdamSRT
par_groups = [{'params': model.conv_params(), 'channel_wise': True},
              {'params': model.other_params()}]
optimizer = AdamSRT(par_groups, lr=0.001, betas=(0.9, 0.99), weight_decay=1e-4)
optimizer.zero_grad()
loss_fn(model(input), target).backward()
optimizer.step()
A typical implementation of the methods that filter the parameters of a standard network (such as the torchvision models) looks like:
class CustomModel(MyModel):
    ...

    def conv_params(self):
        conv_params = []
        for name, param in self.named_parameters():
            if any(key in name for key in {'conv', 'downsample.0'}):
                conv_params.append(param)
        return conv_params
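The usage example above also calls model.other_params(). A possible companion method, shown here as an illustrative sketch rather than the exact implementation used in the repository, simply collects the remaining parameters:
    def other_params(self):
        # All parameters not selected by conv_params(), e.g. BN and fully connected layers
        other_params = []
        for name, param in self.named_parameters():
            if not any(key in name for key in {'conv', 'downsample.0'}):
                other_params.append(param)
        return other_params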
Advanced details on the use of the optimizers can be found in adamsrt/README.md.
AdamS and AdamSRT have been benchmarked over a range of classification datasets and architectures. They are compared to their classical counterpart Adam and its variants (AdamW, AdamG), as well as the state-of-the-art SGD-M. For each architecture and each dataset, we grid-searched every hyperparameter and selected only the best configuration to produce the following table.
Method | CIFAR10 ResNet20 | CIFAR10 ResNet18 | CIFAR10 VGG16 | CIFAR100 ResNet18 | CIFAR100 VGG16 | SVHN ResNet18 | SVHN VGG16 | ImageNet ResNet18 |
---|---|---|---|---|---|---|---|---|
SGD-M | 92.39 (0.12) | 95.10 (0.04) | 93.56 (0.05) | 77.08 (0.18) | 73.77 (0.10) | 95.96 (0.15) | 95.95 (0.09) | 69.86 (0.06) |
Adam | 90.98 (0.06) | 93.77 (0.20) | 92.83 (0.17) | 71.30 (0.36) | 68.43 (0.16) | 95.32 (0.23) | 95.57 (0.20) | 68.52 (0.10) |
AdamW | 90.19 (0.24) | 93.61 (0.12) | 92.53 (0.25) | 67.39 (0.27) | 71.37 (0.22) | 95.38 (0.15) | 95.60 (0.08) | - |
AdamG | 91.64 (0.17) | 94.67 (0.12) | 93.41 (0.17) | 73.76 (0.34) | 70.17 (0.20) | 95.73 (0.05) | 95.70 (0.25) | - |
AdamS (ours) | 91.15 (0.11) | 93.95 (0.23) | 92.92 (0.11) | 74.44 (0.22) | 68.73 (0.27) | 95.75 (0.09) | 95.66 (0.09) | 68.82 (0.22) |
AdamSRT (ours) | 91.81 (0.20) | 94.92 (0.05) | 93.75 (0.06) | 75.28 (0.35) | 71.45 (0.13) | 95.84 (0.07) | 95.82 (0.05) | 68.93 (0.19) |
AdamSRT outperforms classic Adam and its existing variants (AdamW and AdamG), closing the performance gap with SGD-M and even outperforming it on CIFAR10 with VGG16.
To reproduce the results of the paper, you can try both of the proposed methods, AdamS and AdamSRT. SGD-MRT can also be tested but leads to less systematic improvements.
You can also use the benchmark methods Adam, AdamG, AdamW and SGD.
As in the paper, training can be done on the public datasets CIFAR10, CIFAR100 and SVHN, with the architectures ResNet18, VGG16 and ResNet20 (the latter only for CIFAR10).
For the training, the best hyperparameters found in a previous grid search (cf. Appendix E, Table 4) are used for the chosen setting. They are referenced in the file best_hyper_parameters.py.
The training is done over 400 epochs with a step-wise learning rate schedule with 3 drops.
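As an illustration, such a step-wise schedule can be written with PyTorch's MultiStepLR; the milestone epochs and decay factor below are placeholders, not the exact values used by training.py:
from torch.optim.lr_scheduler import MultiStepLR

# Placeholder milestones and decay factor for a 400-epoch run with 3 learning rate drops
scheduler = MultiStepLR(optimizer, milestones=[150, 250, 350], gamma=0.1)

for epoch in range(400):
    train_one_epoch(model, optimizer)  # user-defined training loop (not part of the package)
    scheduler.step()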
To launch the training you just need to call training.py with the proper options:
cd adamsrt
python training.py --optimizer=adamsrt --model=resnet18 --dataloader=cifar100
Logs will give you the train and valid accuracy during the training as well as the test accuracy at the end of the training.
Options are:
- dataloader in ['cifar10', 'cifar100', 'svhn', 'imagenet']
- model in ['resnet18', 'resnet20', 'vgg16']
- optimizer in ['adams', 'adamsrt', 'adamw', 'adam', 'adamg', 'sgd', 'sgdmrt']
resnet20 is only available for cifar10, and imagenet is only available with resnet18.
To use imagenet you need to provide the path to the ImageNet data folder by editing the file adamsrt/config.py.
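As a purely hypothetical illustration (the actual variable name in adamsrt/config.py may differ), the edit could look like:
# adamsrt/config.py
IMAGENET_PATH = '/path/to/your/imagenet'  # hypothetical variable name pointing to the ImageNet root folder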
- For this project, we strongly relied on the PyTorch framework.
- All the experiments presented in the paper and in the table above were run on the valeo.ai GPU cluster infrastructure.
- We thank Gabriel de Marmiesse for his precious help in successfully conducting all of our experiments.
@article{roburin2020spherical,
author = {Roburin, Simon and de Mont-Marin, Yann and Bursuc, Andrei and Marlet, Renaud and Perez, Patrick and Aubry, Mathieu},
title = {Spherical Perspective on Learning with Batch Norm},
journal = {arXiv preprint arXiv:2020.xxxxx},
year = {2020}
}