Learned Optimizers That Scale and Generalize

Abstract

Learning to learn has emerged as an important direction for achieving artificial intelligence. Two of the primary barriers to its adoption are an inability to scale to larger problems and a limited ability to generalize to new tasks. We introduce a learned gradient descent optimizer that generalizes well to new tasks and that has significantly lower memory and computational overhead. We achieve this by introducing a novel hierarchical RNN architecture, with minimal per-parameter overhead, augmented with additional architectural features that mirror the known structure of optimization tasks. We also develop a meta-training ensemble of small, diverse optimization tasks capturing common properties of loss landscapes. The optimizer learns to outperform RMSProp and Adam on problems in this corpus. More importantly, it performs comparably or better when applied to small convolutional neural networks, despite seeing no neural networks in its meta-training set. Finally, it generalizes to train Inception V3 and ResNet V2 architectures on the ImageNet dataset for thousands of steps, optimization problems that are of a vastly different scale than those it was trained on.
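The core loop the abstract describes is: define a parametric update rule, meta-train its parameters on an ensemble of small tasks, then apply it to tasks it has not seen. The sketch below illustrates only that loop and is not the paper's method. The task ensemble (random quadratics), the two-parameter RMSProp-like update family, the random-search meta-trainer, and every name (`make_task`, `learned_update`, `meta_loss`, `meta_train`) are hypothetical stand-ins chosen for brevity; the paper's optimizer is a hierarchical RNN, which is not reproduced here.

```python
import math
import random


def make_task(rng, dim=2):
    """Hypothetical small task: f(x) = 0.5 * sum(a_i * (x_i - c_i)^2).

    Random curvatures a_i and minima c_i stand in for the paper's
    meta-training ensemble, which is not reproduced here.
    """
    a = [rng.uniform(0.1, 10.0) for _ in range(dim)]
    c = [rng.uniform(-1.0, 1.0) for _ in range(dim)]

    def loss(x):
        return 0.5 * sum(ai * (xi - ci) ** 2 for ai, xi, ci in zip(a, x, c))

    def grad(x):
        return [ai * (xi - ci) for ai, xi, ci in zip(a, x, c)]

    return loss, grad, [0.0] * dim  # (loss fn, grad fn, initial point)


def learned_update(theta, g, v):
    """Per-parameter update with meta-parameters theta = (log_lr, logit_beta).

    v is a running average of squared gradients (an RMSProp-like feature).
    This hand-designed family is an assumption, not the paper's RNN.
    """
    log_lr, logit_beta = theta
    beta = 1.0 / (1.0 + math.exp(-logit_beta))
    v = beta * v + (1.0 - beta) * g * g
    step = -math.exp(log_lr) * g / (math.sqrt(v) + 1e-8)
    return step, v


def meta_loss(theta, tasks, unroll=20):
    """Mean (over tasks) of the average loss along an unrolled trajectory."""
    total = 0.0
    for loss, grad, x0 in tasks:
        x = list(x0)
        states = [0.0] * len(x)  # one optimizer state per parameter
        traj = 0.0
        for _ in range(unroll):
            g = grad(x)
            for i in range(len(x)):
                step, states[i] = learned_update(theta, g[i], states[i])
                x[i] += step
            traj += loss(x)
        total += traj / unroll
    return total / len(tasks)


def meta_train(tasks, iters=200, sigma=0.1, seed=0):
    """Meta-train theta by random-search hill climbing on the meta-loss.

    Random search is a swapped-in technique used only to keep the sketch
    dependency-free; it is not how the paper meta-trains its optimizer.
    """
    rng = random.Random(seed)
    theta = [math.log(0.01), 0.0]
    best = meta_loss(theta, tasks)
    for _ in range(iters):
        cand = [t + rng.gauss(0.0, sigma) for t in theta]
        val = meta_loss(cand, tasks)
        if val < best:  # keep the candidate only if it improves the meta-loss
            theta, best = cand, val
    return theta, best


if __name__ == "__main__":
    rng = random.Random(42)
    train_tasks = [make_task(rng) for _ in range(8)]
    held_out = [make_task(rng) for _ in range(8)]  # unseen tasks
    theta, train_val = meta_train(train_tasks)
    print("meta-trained theta:", theta)
    print("train meta-loss:", train_val)
    print("held-out meta-loss:", meta_loss(theta, held_out))
```

Averaging the loss over the whole unrolled trajectory, rather than scoring only the final iterate, rewards update rules that make fast progress early. Replacing `learned_update` with a small recurrent network carrying per-parameter hidden state, and replacing random search with gradient-based meta-training, would bring the sketch closer in spirit to what the abstract describes, at the cost of needing an autodiff library.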

Cite

Text

Wichrowska et al. "Learned Optimizers That Scale and Generalize." International Conference on Machine Learning, 2017.

Markdown

[Wichrowska et al. "Learned Optimizers That Scale and Generalize." International Conference on Machine Learning, 2017.](https://mlanthology.org/icml/2017/wichrowska2017icml-learned/)

BibTeX

@inproceedings{wichrowska2017icml-learned,
  title     = {{Learned Optimizers That Scale and Generalize}},
  author    = {Wichrowska, Olga and Maheswaranathan, Niru and Hoffman, Matthew W. and Colmenarejo, Sergio Gómez and Denil, Misha and de Freitas, Nando and Sohl-Dickstein, Jascha},
  booktitle = {International Conference on Machine Learning},
  year      = {2017},
  pages     = {3751--3760},
  volume    = {70},
  url       = {https://mlanthology.org/icml/2017/wichrowska2017icml-learned/}
}