Coordinate Descent on the Orthogonal Group for Recurrent Neural Network Training

Abstract

We address the poor scalability of learning algorithms for orthogonal recurrent neural networks by using stochastic coordinate descent on the orthogonal group, which yields a cost per iteration that grows linearly with the number of recurrent states. This contrasts with typical feasible algorithms such as stochastic Riemannian gradient descent, whose cost per iteration is cubic in the number of recurrent states and which therefore prohibit the use of large network architectures. Each coordinate descent step rotates two columns of the recurrent matrix. When the coordinate (i.e., the pair of indices of the rotated columns) is selected uniformly at random at each iteration, we prove convergence of the algorithm under standard assumptions on the loss function, the stepsize and the minibatch noise. In addition, we show numerically that the Riemannian gradient has an approximately sparse structure. Leveraging this observation, we propose a variant of the algorithm that relies on the Gauss-Southwell coordinate selection rule. Experiments on a benchmark recurrent neural network training problem show that the proposed approach is a very promising step towards the training of orthogonal recurrent neural networks with large architectures.
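The column-rotation step can be sketched as follows. This is a minimal illustration in Python with NumPy, not the authors' implementation, and it rests on several assumptions. The objective is a hypothetical toy loss f(X) = 0.5 ||X - T||_F^2 with a fixed target rotation T rather than an RNN loss. There is no minibatch noise and the stepsize is fixed. All function names (random_rotation, coordinate_derivative, rotate_columns, coordinate_descent) are hypothetical. Each step right-multiplies X by a Givens rotation R_ij(theta). The derivative of f along coordinate (i, j) is the (i, j) entry of X^T G - G^T X, where G is the Euclidean gradient. Computing that derivative and applying the rotation both touch only columns i and j, so each costs O(n). The Gauss-Southwell branch forms the full matrix at O(n^3) cost purely to show the selection rule; it does not reproduce how the paper exploits the approximate sparsity of the Riemannian gradient.

```python
import numpy as np


def random_rotation(n, rng):
    """Draw a random matrix in SO(n) (orthogonal with determinant +1)."""
    q, r = np.linalg.qr(rng.standard_normal((n, n)))
    q = q * np.sign(np.diag(r))  # fix column signs so the QR factor is unique
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]  # flip one column so that det(q) = +1
    return q


def loss(X, T):
    """Hypothetical toy objective f(X) = 0.5 * ||X - T||_F^2 (not an RNN loss)."""
    return 0.5 * np.sum((X - T) ** 2)


def coordinate_derivative(X, T, i, j):
    """d/dtheta f(X R_ij(theta)) at theta = 0, using only columns i and j: O(n).

    With G = X - T the Euclidean gradient, this equals x_i . g_j - x_j . g_i,
    the (i, j) entry of the skew-symmetric matrix X^T G - G^T X.
    """
    g_i = X[:, i] - T[:, i]
    g_j = X[:, j] - T[:, j]
    return X[:, i] @ g_j - X[:, j] @ g_i


def rotate_columns(X, i, j, theta):
    """In place X <- X R_ij(theta), a Givens rotation of columns i and j: O(n)."""
    c, s = np.cos(theta), np.sin(theta)
    xi, xj = X[:, i].copy(), X[:, j].copy()
    X[:, i] = c * xi - s * xj
    X[:, j] = s * xi + c * xj


def coordinate_descent(X, T, step=0.25, iters=3000, rule="uniform", seed=0):
    """Coordinate descent on SO(n) for the toy loss; returns (X, loss history)."""
    rng = np.random.default_rng(seed)
    n = X.shape[0]
    X = X.copy()
    history = [loss(X, T)]  # monitoring only: each loss evaluation costs O(n^2)
    for _ in range(iters):
        if rule == "uniform":
            # Pick a pair of distinct column indices uniformly at random.
            i, j = rng.choice(n, size=2, replace=False)
        elif rule == "gauss-southwell":
            # Naive Gauss-Southwell: largest |entry| of X^T G - G^T X, O(n^3) here.
            M = X.T @ (X - T)
            scores = np.abs(np.triu(M - M.T, k=1))
            i, j = np.unravel_index(np.argmax(scores), scores.shape)
        else:
            raise ValueError(f"unknown rule: {rule}")
        theta = -step * coordinate_derivative(X, T, i, j)  # gradient step on theta
        rotate_columns(X, i, j, theta)
        history.append(loss(X, T))
    return X, history


# Usage: recover a random rotation T starting from the identity.
rng = np.random.default_rng(42)
n = 8
T = random_rotation(n, rng)
X, hist = coordinate_descent(np.eye(n), T, rule="uniform")
print("loss:", hist[0], "->", hist[-1])
print("orthogonality error:", np.linalg.norm(X.T @ X - np.eye(n)))
```

For this toy loss, the second derivative along any single coordinate is bounded by 2 in absolute value, so a stepsize of 0.25 cannot increase the loss. X also stays orthogonal up to rounding because it is only ever multiplied by rotations. The identity start lies in SO(n), the same connected component as T.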

Cite

Text

Massart and Abrol. "Coordinate Descent on the Orthogonal Group for Recurrent Neural Network Training." AAAI Conference on Artificial Intelligence, 2022. doi:10.1609/AAAI.V36I7.20742

Markdown

[Massart and Abrol. "Coordinate Descent on the Orthogonal Group for Recurrent Neural Network Training." AAAI Conference on Artificial Intelligence, 2022.](https://mlanthology.org/aaai/2022/massart2022aaai-coordinate/) doi:10.1609/AAAI.V36I7.20742

BibTeX

@inproceedings{massart2022aaai-coordinate,
  title     = {{Coordinate Descent on the Orthogonal Group for Recurrent Neural Network Training}},
  author    = {Massart, Estelle M. and Abrol, Vinayak},
  booktitle = {AAAI Conference on Artificial Intelligence},
  year      = {2022},
  pages     = {7744--7751},
  doi       = {10.1609/AAAI.V36I7.20742},
  url       = {https://mlanthology.org/aaai/2022/massart2022aaai-coordinate/}
}