PyTorch demo code for the paper "Learning Neural Networks with Adaptive Regularization"
Learning Neural Networks with Adaptive Regularization
Han Zhao *, Yao-Hung Hubert Tsai *, Ruslan Salakhutdinov, and Geoffrey J. Gordon
Thirty-third Conference on Neural Information Processing Systems (NeurIPS), 2019. (*equal contribution)
If you use this code for your research and find it helpful, please cite our paper:
```bibtex
@inproceedings{zhao2019adaptive,
  title={Learning Neural Networks with Adaptive Regularization},
  author={Zhao, Han and Tsai, Yao-Hung Hubert and Salakhutdinov, Ruslan and Gordon, Geoffrey J},
  booktitle={Advances in Neural Information Processing Systems},
  year={2019}
}
```
In this paper we propose AdaReg, an adaptive and data-dependent regularization method for training neural networks on small-scale datasets. The following figure shows a schematic illustration of AdaReg during training:
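To summarize, for a fully connected layer (usually the last layer) with weight matrix $W \in \mathbb{R}^{d \times p}$, where $d$ is the number of outputs (`num_tasks` in the code below) and $p$ is the number of input features (`num_feats`), AdaReg maintains two additional covariance matrices:
- a row covariance matrix $\Omega_r \in \mathbb{R}^{d \times d}$, over the outputs;
- a column covariance matrix $\Omega_c \in \mathbb{R}^{p \times p}$, over the input features.
During the training phase, AdaReg updates $W$ and $(\Omega_r, \Omega_c)$ alternately, in a block coordinate descent fashion. As usual, $W$ can be updated with any off-the-shelf optimizer provided by PyTorch. For the two covariance matrices, we derive a closed-form algorithm that attains the optimal solution for any fixed $W$. Essentially, the algorithm contains two steps:
- Compute an SVD of a matrix of size $d \times d$ (for $\Omega_r$) or $p \times p$ (for $\Omega_c$).
- Truncate all the singular values, after inverting and rescaling them as `update_covs` below does, into the range $[u, v]$. That is, set every value smaller than $u$ to $u$, and every value greater than $v$ to $v$.
The bounds $u$ and $v$ are hyperparameters and should satisfy $0 < u < v$. In practice and in our experiments we keep them fixed. Pseudo-code of the algorithm is shown in the following figure.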
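Written out, the two closed-form steps give the updates below. This is our own reconstruction from the `update_covs` and `regularizer` code later in this README, not a formula quoted from the paper. Here $W \in \mathbb{R}^{d \times p}$ ($d$ = `num_tasks`, $p$ = `num_feats`), and $\mathcal{T}_{[u,v]}$ denotes the operation that keeps the eigenvectors of a symmetric matrix and clips its eigenvalues into $[u, v]$:

```math
\Omega_r \leftarrow \mathcal{T}_{[u,v]}\big(p\,(W \Omega_c W^\top)^{-1}\big), \qquad
\Omega_c \leftarrow \mathcal{T}_{[u,v]}\big(d\,(W^\top \Omega_r W)^{-1}\big), \qquad
R(W) = \lVert \Omega_r^{1/2} W \Omega_c^{1/2} \rVert_F^2 .
```

Both updates are computed from the current values of $\Omega_r$ and $\Omega_c$. When $W \Omega_c W^\top$ or $W^\top \Omega_r W$ is singular (for example $W^\top \Omega_r W$ whenever $d < p$), the inverse is taken eigenvalue by eigenvalue, and eigenvalues at or near zero end up clipped to $v$.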
Really simple! Here we give a minimal code snippet (in PyTorch) to illustrate the main idea. For the full implementation and more details, please see the function `BayesNet(args)` in `src/model.py`.
First, for the weight matrix $W$, we need to define two covariance matrices:
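```python
import torch
import torch.nn as nn

# Inside the layer's __init__: define the two covariance matrices, stored as matrix square roots.
self.sqrt_covt = nn.Parameter(torch.eye(self.num_tasks), requires_grad=False)
self.sqrt_covf = nn.Parameter(torch.eye(self.num_feats), requires_grad=False)
```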
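The snippet above is meant to live inside the `__init__` of an `nn.Module` that also owns the weight `W` (shape `num_tasks x num_feats`). As a hedged sketch, where the class name `Holder` and the sizes are hypothetical, the following checks one practical consequence of `requires_grad=False`: even if all of `model.parameters()` is handed to an optimizer with weight decay switched on, the two matrices receive no gradient and `opt.step()` leaves them untouched.

```python
import torch
import torch.nn as nn


class Holder(nn.Module):
    """Hypothetical container mirroring the snippet above; W is (num_tasks, num_feats)."""

    def __init__(self, num_feats, num_tasks):
        super().__init__()
        self.W = nn.Parameter(torch.randn(num_tasks, num_feats))
        # Frozen for autograd: these are changed only by the closed-form update.
        self.sqrt_covt = nn.Parameter(torch.eye(num_tasks), requires_grad=False)
        self.sqrt_covf = nn.Parameter(torch.eye(num_feats), requires_grad=False)


m = Holder(num_feats=5, num_tasks=3)
# Pass every parameter, frozen ones included, with weight decay enabled.
opt = torch.optim.SGD(m.parameters(), lr=0.1, weight_decay=0.1)
loss = (m.sqrt_covt @ m.W @ m.sqrt_covf).pow(2).sum()
loss.backward()
opt.step()

print(m.sqrt_covt.grad is None)                # True
print(torch.equal(m.sqrt_covt, torch.eye(3)))  # True: SGD skips parameters without a gradient
```

`register_buffer` is the more conventional PyTorch container for non-trainable tensors; the `nn.Parameter(..., requires_grad=False)` form used in this README works as well and keeps the matrices in `model.parameters()` and the `state_dict`.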
Since we optimize these two matrices with our own closed-form algorithm rather than by backpropagation, we set `requires_grad` to `False`. Next, implement a four-line thresholding function:
```python
def _thresholding(self, sv, lower, upper):
    """
    Two-way thresholding (clipping) of singular values, performed in place.
    :param sv: A 1-D tensor of singular values.
    :param lower: Lower bound of the truncation.
    :param upper: Upper bound of the truncation.
    :return: Thresholded singular values.
    """
    uidx = sv > upper
    lidx = sv < lower
    sv[uidx] = upper
    sv[lidx] = lower
    return sv
```
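The four-line function above is an elementwise clip, so `torch.clamp` computes the same result. One difference worth knowing: `_thresholding` writes into its argument, which is harmless in `update_covs` because the values passed in are freshly created tensors there. The sketch below checks the equivalence; `thresholding` is a hypothetical free-function copy, and the bounds `1e-3`/`1e3` are arbitrary illustration values, not settings from the paper.

```python
import torch


def thresholding(sv, lower, upper):
    """Hypothetical free-function copy of `_thresholding`: clips `sv` in place."""
    uidx = sv > upper
    lidx = sv < lower
    sv[uidx] = upper
    sv[lidx] = lower
    return sv


sv = torch.tensor([1e-6, 0.5, 3.0, 1e6])
lower, upper = 1e-3, 1e3  # illustration values only

clipped = torch.clamp(sv, min=lower, max=upper)  # returns a new tensor
masked = thresholding(sv.clone(), lower, upper)  # clone, because the helper mutates its input

print(torch.equal(clipped, masked))  # True
print(sv)  # unchanged, since only the clone was modified
```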
The overall algorithm for updating both covariance matrices can then be implemented as:
```python
def update_covs(self, lower, upper):
    """
    Update both the covariance matrix over rows and over columns, using the closed-form solutions.
    :param lower: Lower bound of the truncation.
    :param upper: Upper bound of the truncation.
    """
    covt = torch.mm(self.sqrt_covt, self.sqrt_covt.t())
    covf = torch.mm(self.sqrt_covf, self.sqrt_covf.t())
    ctask = torch.mm(torch.mm(self.W, covf), self.W.t())
    cfeat = torch.mm(torch.mm(self.W.t(), covt), self.W)
    # Compute SVD (both matrices are symmetric positive semi-definite).
    ct, st, _ = torch.svd(ctask.data)
    cf, sf, _ = torch.svd(cfeat.data)
    # Invert and rescale: spectra of num_feats * ctask^{-1} and num_tasks * cfeat^{-1}.
    st = self.num_feats / st
    sf = self.num_tasks / sf
    # Truncation of both sets of singular values.
    st = self._thresholding(st, lower, upper)
    st = torch.sqrt(st)
    sf = self._thresholding(sf, lower, upper)
    sf = torch.sqrt(sf)
    # Rebuild the matrix square roots: U diag(sqrt(s)) U^T.
    self.sqrt_covt.data = torch.mm(torch.mm(ct, torch.diag(st)), ct.t())
    self.sqrt_covf.data = torch.mm(torch.mm(cf, torch.diag(sf)), cf.t())
```
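Because `ctask` and `cfeat` are symmetric positive semi-definite, their SVD coincides with an eigendecomposition, so `torch.linalg.eigh` can stand in for `torch.svd` (which recent PyTorch releases deprecate in favor of `torch.linalg.svd`). The standalone sketch below is our own functional rewrite, not the repository's code: `closed_form_sqrt` is a hypothetical helper, and the sizes and bounds are illustration values. It also shows what happens to zero eigenvalues: with fewer rows than columns in `W`, `W.t() @ covt @ W` is rank-deficient, and those directions land on the upper bound.

```python
import torch


def closed_form_sqrt(M, scale, lower, upper):
    """Hypothetical helper: symmetric square root of T_[lower, upper](scale * M^{-1}).

    M is symmetric positive semi-definite, so its eigendecomposition plays the
    role of the SVD used in `update_covs`.
    """
    evals, evecs = torch.linalg.eigh(M)
    # Round-off can push zero eigenvalues slightly negative; flooring them at the
    # smallest positive float makes scale / evals huge (or inf), and the clamp
    # below then maps those directions to `upper`.
    evals = evals.clamp(min=torch.finfo(M.dtype).tiny)
    s = (scale / evals).clamp(min=lower, max=upper)
    return evecs @ torch.diag(s.sqrt()) @ evecs.t()


torch.manual_seed(0)
num_tasks, num_feats = 3, 5
lower, upper = 1e-3, 1e3  # illustration values only
W = torch.randn(num_tasks, num_feats, dtype=torch.float64)
sqrt_covt = torch.eye(num_tasks, dtype=torch.float64)
sqrt_covf = torch.eye(num_feats, dtype=torch.float64)

covt = sqrt_covt @ sqrt_covt.t()
covf = sqrt_covf @ sqrt_covf.t()
# As in `update_covs`, both new matrices are computed from the current ones.
new_sqrt_covt = closed_form_sqrt(W @ covf @ W.t(), num_feats, lower, upper)
new_sqrt_covf = closed_form_sqrt(W.t() @ covt @ W, num_tasks, lower, upper)

# W.t() @ covt @ W is 5 x 5 with rank <= 3, so at least two eigenvalues of the
# new column covariance sit at the upper bound.
print(torch.linalg.eigvalsh(new_sqrt_covf @ new_sqrt_covf.t()))
```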
Finally, we need to use the optimized covariance matrices to regularize the learning of our weight matrix (our goal!):
```python
def regularizer(self):
    """
    Compute the weight regularizer w.r.t. the weight matrix W.
    """
    r = torch.mm(torch.mm(self.sqrt_covt, self.W), self.sqrt_covf)
    return torch.sum(r * r)
```
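The value returned by `regularizer` is a squared Frobenius norm. Because `sqrt_covt` is a symmetric square root, it equals the trace form $\operatorname{tr}(\Omega_r W \Omega_c W^\top)$ with $\Omega_r$ = `sqrt_covt @ sqrt_covt.t()` and $\Omega_c$ = `sqrt_covf @ sqrt_covf.t()`, and it reduces to ordinary weight decay $\lVert W \rVert_F^2$ when both matrices are identities (their initial value). A quick numerical check, with random symmetric matrices standing in for the learned ones (sizes and seed are arbitrary):

```python
import torch

torch.manual_seed(0)
num_tasks, num_feats = 3, 5
W = torch.randn(num_tasks, num_feats, dtype=torch.float64)

# Arbitrary symmetric positive definite matrices standing in for sqrt_covt / sqrt_covf.
A = torch.randn(num_tasks, num_tasks, dtype=torch.float64)
B = torch.randn(num_feats, num_feats, dtype=torch.float64)
sqrt_covt = A @ A.t() + torch.eye(num_tasks, dtype=torch.float64)
sqrt_covf = B @ B.t() + torch.eye(num_feats, dtype=torch.float64)

# Regularizer exactly as written in `regularizer`.
r = sqrt_covt @ W @ sqrt_covf
reg = torch.sum(r * r)

# Trace form with the full covariances.
covt = sqrt_covt @ sqrt_covt.t()
covf = sqrt_covf @ sqrt_covf.t()
reg_trace = torch.trace(covt @ W @ covf @ W.t())

print(torch.allclose(reg, reg_trace))  # True
```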
Add this regularizer to your favorite objective function (cross-entropy, mean squared error, etc.) and backpropagate to update $W$. Done!
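Putting the pieces together, here is a hedged end-to-end sketch of the alternating scheme on synthetic data. Everything in it is illustrative rather than taken from the repository: the class name `AdaRegHead`, the weight `lam`, the learning rate, the bounds `lower=0.1`/`upper=10.0`, and the choice to run the closed-form step after every SGD step (see `demo.py` and `src/model.py` for the schedule the authors actually use). For brevity, `update_covs` here uses `torch.linalg.eigh` instead of `torch.svd`; for symmetric positive semi-definite inputs both yield the same updated matrices.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AdaRegHead(nn.Module):
    """Hypothetical last layer assembled from the README snippets (no bias term)."""

    def __init__(self, num_feats, num_tasks):
        super().__init__()
        self.num_feats, self.num_tasks = num_feats, num_tasks
        self.W = nn.Parameter(0.01 * torch.randn(num_tasks, num_feats))
        self.sqrt_covt = nn.Parameter(torch.eye(num_tasks), requires_grad=False)
        self.sqrt_covf = nn.Parameter(torch.eye(num_feats), requires_grad=False)

    def forward(self, x):
        return x @ self.W.t()  # (batch, num_feats) -> (batch, num_tasks)

    def regularizer(self):
        r = self.sqrt_covt @ self.W @ self.sqrt_covf
        return torch.sum(r * r)

    @torch.no_grad()
    def update_covs(self, lower, upper):
        covt = self.sqrt_covt @ self.sqrt_covt.t()
        covf = self.sqrt_covf @ self.sqrt_covf.t()
        # Both targets are built from the current covariances before either is overwritten.
        targets = (
            (self.sqrt_covt, self.W @ covf @ self.W.t(), self.num_feats),
            (self.sqrt_covf, self.W.t() @ covt @ self.W, self.num_tasks),
        )
        for sqrt_cov, M, scale in targets:
            evals, evecs = torch.linalg.eigh(M)
            s = (scale / evals.clamp(min=torch.finfo(M.dtype).tiny)).clamp(min=lower, max=upper)
            sqrt_cov.copy_(evecs @ torch.diag(s.sqrt()) @ evecs.t())


torch.manual_seed(0)
num_feats, num_tasks = 20, 4
lam, lower, upper = 1e-3, 0.1, 10.0  # illustration values only
x = torch.randn(64, num_feats)  # synthetic features
y = torch.randint(0, num_tasks, (64,))  # synthetic labels
head = AdaRegHead(num_feats, num_tasks)
opt = torch.optim.SGD([head.W], lr=0.1)

for step in range(50):
    # Block 1: one gradient step on W with both covariances held fixed.
    opt.zero_grad()
    loss = F.cross_entropy(head(x), y) + lam * head.regularizer()
    loss.backward()
    opt.step()
    # Block 2: closed-form update of both covariances with W held fixed.
    head.update_covs(lower, upper)

print(loss.item())
```

Keeping $u$ and $v$ moderate matters for plain SGD: the Hessian of the regularizer with respect to $W$ is $2\,(\Omega_c \otimes \Omega_r)$, whose largest eigenvalue can reach $2v^2$, so a large upper bound may call for a smaller learning rate or `lam`.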
To run the demo:
```shell
python demo.py --dataset CIFAR10
python demo.py --dataset MNIST
python demo.py --dataset MNIST --trainPartial --trainSize 1000 --batch_size 128
```
Please email the authors should you have any questions, comments, or suggestions.