Combining Axes Preconditioners Through Kronecker Approximation for Deep Learning

Abstract

Adaptive-regularization-based optimization methods such as full-matrix Adagrad, which use gradient second-moment information, hold significant potential for fast convergence in deep neural network (DNN) training, but are memory-intensive and computationally demanding for large neural nets. We develop a technique called Combining AxeS PReconditioners (CASPR), which optimizes matrix-shaped DNN parameters by finding a different preconditioner for each mode/axis of the parameter and combining them using a Kronecker-sum-based approximation. We show tighter convergence guarantees in stochastic optimization compared to Shampoo, a Kronecker-product-based preconditioner that arises as a special case of CASPR. Furthermore, our experiments demonstrate that CASPR approximates the gradient second-moment matrix in full-matrix Adagrad more accurately, and shows significant improvement in training and generalization performance compared to existing practical adaptive-regularization-based methods such as Shampoo and Adam on a variety of tasks, including a graph neural network on OGBG-molpcba, a Transformer on a Universal Dependencies dataset, and autoregressive large language modeling on the C4 dataset.
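The abstract contrasts a Kronecker-product preconditioner (Shampoo) with a Kronecker-sum combination of per-axis preconditioners. The NumPy sketch below illustrates only the underlying algebra, under stated assumptions: for an m x n gradient G, per-axis statistics L = sum of G Gᵀ and R = sum of Gᵀ G, an inverse fourth root as in Shampoo, and a small damping term `eps`. All function names are hypothetical. The Kronecker-sum direction shown is an illustrative additive combination of the two axis preconditioners, not necessarily the exact CASPR update from the paper, which may apply further steps or scaling.

```python
import numpy as np

def inv_root(mat, p, eps=1e-6):
    """Return (mat + eps*I)^(-1/p) for a symmetric PSD matrix via eigendecomposition."""
    w, v = np.linalg.eigh(mat + eps * np.eye(mat.shape[0]))
    # v * d scales the eigenvector columns, so this is V diag(w^(-1/p)) V^T.
    return (v * w ** (-1.0 / p)) @ v.T

def axis_statistics(grads):
    """Accumulate per-axis second-moment statistics for an m x n parameter."""
    m, n = grads[0].shape
    L = np.zeros((m, m))
    R = np.zeros((n, n))
    for g in grads:
        L += g @ g.T  # row-axis (mode-1) statistic, m x m
        R += g.T @ g  # column-axis (mode-2) statistic, n x n
    return L, R

def kron_product_direction(g, L, R):
    """Shampoo-style: (L^{-1/4} kron R^{-1/4}) vec(g) equals L^{-1/4} g R^{-1/4}."""
    return inv_root(L, 4) @ g @ inv_root(R, 4)

def kron_sum_direction(g, L, R):
    """Hypothetical illustration of a Kronecker-sum combination:
    (L^{-1/4} kron I + I kron R^{-1/4}) vec(g) equals L^{-1/4} g + g R^{-1/4}."""
    return inv_root(L, 4) @ g + g @ inv_root(R, 4)

# Tiny check against explicit (mn x mn) matrices; the factored forms never build these.
rng = np.random.default_rng(0)
grads = [rng.standard_normal((3, 4)) for _ in range(5)]
L, R = axis_statistics(grads)
g = grads[-1]
Li, Ri = inv_root(L, 4), inv_root(R, 4)
vec_g = g.reshape(-1)  # row-major vec, matching np.kron's index layout
prod_dense = np.kron(Li, Ri) @ vec_g
sum_dense = (np.kron(Li, np.eye(4)) + np.kron(np.eye(3), Ri)) @ vec_g
print(np.allclose(prod_dense, kron_product_direction(g, L, R).reshape(-1)))
print(np.allclose(sum_dense, kron_sum_direction(g, L, R).reshape(-1)))
```

The factored forms only need roots of an m x m and an n x n matrix, whereas full-matrix Adagrad would work with an mn x mn matrix; the dense `np.kron` matrices appear here solely to check the identities on a toy size.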

Cite

Text

Duvvuri et al. "Combining Axes Preconditioners Through Kronecker Approximation for Deep Learning." International Conference on Learning Representations, 2024.

Markdown

[Duvvuri et al. "Combining Axes Preconditioners Through Kronecker Approximation for Deep Learning." International Conference on Learning Representations, 2024.](https://mlanthology.org/iclr/2024/duvvuri2024iclr-combining/)

BibTeX

@inproceedings{duvvuri2024iclr-combining,
  title     = {{Combining Axes Preconditioners Through Kronecker Approximation for Deep Learning}},
  author    = {Duvvuri, Sai Surya and Devvrit, Fnu and Anil, Rohan and Hsieh, Cho-Jui and Dhillon, Inderjit S},
  booktitle = {International Conference on Learning Representations},
  year      = {2024},
  url       = {https://mlanthology.org/iclr/2024/duvvuri2024iclr-combining/}
}