Multi-Layer Transformers Gradient Can Be Approximated in Almost Linear Time

Abstract

The quadratic computational complexity of the self-attention mechanism in popular transformer architectures poses significant challenges for training and inference, particularly in terms of efficiency and memory requirements. To address these challenges, this paper introduces a novel fast computation method for gradient calculation in multi-layer transformer models. Our approach enables the computation of gradients for the entire multi-layer transformer model in almost linear time $n^{1+o(1)}$, where $n$ is the input sequence length. This breakthrough significantly reduces the computational bottleneck associated with the traditional quadratic time complexity. Our theory holds for any loss function and maintains a bounded approximation error across the entire model. Furthermore, our analysis holds when the multi-layer transformer model contains many practical sub-modules, such as residual connections, causal masking, and multi-head attention. By improving the efficiency of gradient computation in large language models, we hope that our work will facilitate more effective training and deployment of long-context language models based on our theoretical results.
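
The almost-linear-time claim hinges on never materializing the $n \times n$ attention matrix. The snippet below is not the authors' algorithm; it is a minimal sketch, assuming the unnormalized attention matrix $\exp(QK^\top/\sqrt{d})$ admits a low-rank factorization $UV^\top$ with $k = n^{o(1)}$ columns, of how the normalized attention output (and, by the chain rule, its gradient) can then be computed using only $n \times k$ and $n \times d$ intermediates, i.e., in roughly $O(nkd)$ time. The function name lowrank_attention and the specific shapes are illustrative choices, not from the paper.

import numpy as np

def lowrank_attention(U, V, X):
    # U, V: (n, k) nonnegative factors with U @ V.T ≈ exp(Q K^T / sqrt(d))
    # X:    (n, d) value matrix
    # Returns the row-normalized attention output applied to X, computed
    # without ever forming the n x n matrix U @ V.T.
    VX = V.T @ X                      # (k, d), costs O(n k d)
    numerator = U @ VX                # (n, d), costs O(n k d)
    row_sums = U @ V.sum(axis=0)      # (n,), row sums of U V^T, costs O(n k)
    return numerator / row_sums[:, None]

# Sanity check against the quadratic-time reference on a case where the
# factorization is exact by construction (hypothetical sizes).
rng = np.random.default_rng(0)
n, k, d = 1024, 16, 64
U = rng.random((n, k))                # nonnegative, so every row sum is positive
V = rng.random((n, k))
X = rng.standard_normal((n, d))

A = U @ V.T                           # reference only; O(n^2) memory
reference = (A / A.sum(axis=1, keepdims=True)) @ X
assert np.allclose(lowrank_attention(U, V, X), reference)

The sketch covers only the forward normalization; the paper's contribution is showing that such approximation can be carried through the backward pass of the full multi-layer model, with residual connections, causal masking, and multi-head attention, while keeping the error bounded.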

Cite

Text

Liang et al. "Multi-Layer Transformers Gradient Can Be Approximated in Almost Linear Time." NeurIPS 2024 Workshops: OPT, 2024.

Markdown

[Liang et al. "Multi-Layer Transformers Gradient Can Be Approximated in Almost Linear Time." NeurIPS 2024 Workshops: OPT, 2024.](https://mlanthology.org/neuripsw/2024/liang2024neuripsw-multilayer/)

BibTeX

@inproceedings{liang2024neuripsw-multilayer,
  title     = {{Multi-Layer Transformers Gradient Can Be Approximated in Almost Linear Time}},
  author    = {Liang, Yingyu and Sha, Zhizhou and Shi, Zhenmei and Song, Zhao and Zhou, Yufa},
  booktitle = {NeurIPS 2024 Workshops: OPT},
  year      = {2024},
  url       = {https://mlanthology.org/neuripsw/2024/liang2024neuripsw-multilayer/}
}