Weight-Sharing Regularization

Abstract

Weight-sharing is ubiquitous in deep learning. Motivated by this, we propose a “weight-sharing regularization” penalty on the weights $w \in \mathbb{R}^d$ of a neural network, defined as $\mathcal{R}(w) = \frac{1}{d - 1}\sum_{1 \le j < i \le d} |w_i - w_j|$. We study the proximal mapping of $\mathcal{R}$ and provide an intuitive interpretation of it in terms of a physical system of interacting particles. We also parallelize existing algorithms for $\mathrm{prox}_{\mathcal{R}}$ to run on GPUs and find that one of them is fast in practice but slow ($O(d)$) on worst-case inputs. Using the physical interpretation, we design a novel parallel algorithm that runs in $O(\log^3 d)$ time when sufficient processors are available, thus guaranteeing fast training. Our experiments reveal that weight-sharing regularization enables fully connected networks to learn convolution-like filters even when pixels have been shuffled, while convolutional neural networks fail in this setting. Our code is available on GitHub: https://github.com/motahareh-sohrabi/weight-sharing-regularization
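To make the definitions concrete, here is a minimal NumPy sketch. It is not the authors' code, and the function names (`weight_sharing_penalty`, `prox_weight_sharing`, `_pava`) are hypothetical. The penalty is computed in $O(d \log d)$ by sorting: the $k$-th smallest entry (0-indexed) appears with sign $+$ in $k$ pairs and with sign $-$ in $d-1-k$ pairs, so its coefficient is $2k - d + 1$. For the proximal map $\mathrm{prox}_{\lambda\mathcal{R}}(v) = \arg\min_w \tfrac{1}{2}\|w - v\|^2 + \lambda\mathcal{R}(w)$, the sketch relies on a reduction we derive here rather than take from the paper. Because $\mathcal{R}$ is permutation-invariant, the minimizer preserves the ordering of $v$. On sorted vectors $\mathcal{R}$ is linear, so the problem becomes isotonic regression of a shifted vector, which we solve with the pool-adjacent-violators algorithm (PAVA). This is a sequential method, not the parallel $O(\log^3 d)$ algorithm described above.

```python
import numpy as np


def weight_sharing_penalty(w):
    """R(w) = 1/(d-1) * sum_{j<i} |w_i - w_j|, computed in O(d log d) via sorting."""
    w = np.asarray(w, dtype=float)
    d = w.size
    if d < 2:
        return 0.0
    s = np.sort(w)
    # The k-th smallest value (0-indexed) exceeds k values and is exceeded by
    # d-1-k values, so its net coefficient in the pairwise sum is 2k - (d - 1).
    coef = 2.0 * np.arange(d) - (d - 1)
    return float(coef @ s) / (d - 1)


def _pava(u):
    """Non-decreasing least-squares fit of u (pool-adjacent-violators)."""
    means, sizes = [], []
    for x in u:
        means.append(float(x))
        sizes.append(1)
        # Merge the last two blocks while they violate monotonicity.
        while len(means) > 1 and means[-2] > means[-1]:
            m2, n2 = means.pop(), sizes.pop()
            m1, n1 = means.pop(), sizes.pop()
            n = n1 + n2
            means.append((m1 * n1 + m2 * n2) / n)
            sizes.append(n)
    return np.repeat(means, sizes)


def prox_weight_sharing(v, lam):
    """Sketch of prox_{lam*R}(v) = argmin_w 0.5*||w - v||^2 + lam * R(w)."""
    v = np.asarray(v, dtype=float)
    d = v.size
    if d < 2:
        return v.copy()
    order = np.argsort(v)
    coef = 2.0 * np.arange(d) - (d - 1)
    # With w sorted like v, lam*R(w) is linear in w, so completing the square
    # shifts each sorted entry toward the center by lam * coef / (d - 1).
    u = v[order] - lam * coef / (d - 1)
    # Entries that would cross each other are merged ("collide") by PAVA.
    w_sorted = _pava(u)
    w = np.empty(d)
    w[order] = w_sorted  # undo the sort
    return w


if __name__ == "__main__":
    print(weight_sharing_penalty([3.0, 1.0, 2.0]))
    print(prox_weight_sharing([0.0, 1.0], 0.1))
    print(prox_weight_sharing([3.0, 1.0, 2.0], 10.0))
```

With two weights the prox shrinks their gap by $2\lambda$, so `prox_weight_sharing([0.0, 1.0], 0.1)` gives `[0.1, 0.9]`. For large $\lambda$ all weights collapse to their mean, so the third call returns `[2.0, 2.0, 2.0]`. Merging blocks in PAVA loosely mirrors particles that attract each other and stick together once they meet.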

Cite

Text

Shakerinava et al. "Weight-Sharing Regularization." Artificial Intelligence and Statistics, 2024.

Markdown

[Shakerinava et al. "Weight-Sharing Regularization." Artificial Intelligence and Statistics, 2024.](https://mlanthology.org/aistats/2024/shakerinava2024aistats-weightsharing/)

BibTeX

@inproceedings{shakerinava2024aistats-weightsharing,
  title     = {{Weight-Sharing Regularization}},
  author    = {Shakerinava, Mehran and MS Sohrabi, Motahareh and Ravanbakhsh, Siamak and Lacoste-Julien, Simon},
  booktitle = {Artificial Intelligence and Statistics},
  year      = {2024},
  pages     = {4204--4212},
  volume    = {238},
  url       = {https://mlanthology.org/aistats/2024/shakerinava2024aistats-weightsharing/}
}