Why Adam Outperforms Gradient Descent on Language Models: A Heavy-Tailed Class Imbalance Problem

Abstract

We show that the heavy-tailed class imbalance found in language modeling tasks leads to difficulties in the optimization dynamics. When training with gradient descent, the loss associated with low-frequency classes decreases more slowly than the loss associated with high-frequency classes. Under this heavy-tailed class imbalance, most samples come from classes of low relative frequency, leading to a slow decrease of the overall average loss. Sign-based optimizers such as Adam and sign descent do not suffer from this problem and decrease the loss on all classes. We give evidence of this behavior when training a 2-layer transformer on language data, a linear model on synthetic data whose only property is a heavy-tailed class distribution, and a convolutional network on a modified MNIST dataset constructed to exhibit heavy-tailed class imbalance.
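
As a rough illustration of the synthetic linear-model setting described above, the sketch below (not the authors' code) trains a linear softmax classifier on data whose labels follow a Zipf-like heavy-tailed distribution, once with full-batch gradient descent and once with Adam, and reports the final loss averaged over the most frequent classes versus the remaining rare ones. The class count, noise level, learning rates, and the cutoff between "frequent" and "rare" classes are illustrative assumptions, not values from the paper.

# Minimal sketch: gradient descent vs. Adam on a linear classifier
# under a heavy-tailed (Zipf-like) class distribution.
# All hyperparameters below are illustrative assumptions.
import torch

torch.manual_seed(0)
n_classes, n_features, n_samples = 100, 64, 10_000

# Class k is sampled with probability proportional to 1 / (k + 1).
probs = 1.0 / torch.arange(1, n_classes + 1, dtype=torch.float)
probs /= probs.sum()
y = torch.multinomial(probs, n_samples, replacement=True)

# Each class has a random mean direction; inputs are noisy copies of it.
class_means = torch.randn(n_classes, n_features)
X = class_means[y] + 0.5 * torch.randn(n_samples, n_features)

def train(opt_name, steps=500):
    model = torch.nn.Linear(n_features, n_classes)
    opt = (torch.optim.SGD(model.parameters(), lr=0.1) if opt_name == "gd"
           else torch.optim.Adam(model.parameters(), lr=1e-2))
    loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
    for _ in range(steps):
        opt.zero_grad()
        per_sample = loss_fn(model(X), y)  # full-batch training
        per_sample.mean().backward()
        opt.step()
    # Average the final loss separately over frequent and rare classes.
    with torch.no_grad():
        per_sample = loss_fn(model(X), y)
    frequent = per_sample[y < 10].mean().item()   # 10 most frequent classes
    rare = per_sample[y >= 10].mean().item()      # remaining (rare) classes
    print(f"{opt_name:>4}: loss on frequent classes {frequent:.3f}, "
          f"on rare classes {rare:.3f}")

train("gd")
train("adam")

Under this setup one would expect gradient descent to leave the loss on the rare classes noticeably higher than on the frequent ones, while Adam reduces both, mirroring the behavior reported in the abstract.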

Cite

Text

Yadav et al. "Why Adam Outperforms Gradient Descent on Language Models: A Heavy-Tailed Class Imbalance Problem." NeurIPS 2023 Workshops: OPT, 2023.

Markdown

[Yadav et al. "Why Adam Outperforms Gradient Descent on Language Models: A Heavy-Tailed Class Imbalance Problem." NeurIPS 2023 Workshops: OPT, 2023.](https://mlanthology.org/neuripsw/2023/yadav2023neuripsw-adam/)

BibTeX

@inproceedings{yadav2023neuripsw-adam,
  title     = {{Why Adam Outperforms Gradient Descent on Language Models: A Heavy-Tailed Class Imbalance Problem}},
  author    = {Yadav, Robin and Kunstner, Frederik and Schmidt, Mark and Bietti, Alberto},
  booktitle = {NeurIPS 2023 Workshops: OPT},
  year      = {2023},
  url       = {https://mlanthology.org/neuripsw/2023/yadav2023neuripsw-adam/}
}