Transformer Learns Optimal Variable Selection in Group-Sparse Classification

Abstract

Transformers have demonstrated remarkable success across various applications. However, this success is not yet well understood in theory. In this work, we present a case study of how transformers can be trained to learn a classic statistical model with "group sparsity", where the input variables form multiple groups and the label depends only on the variables from one of the groups. We theoretically demonstrate that a one-layer transformer trained by gradient descent correctly leverages the attention mechanism to select variables, disregarding irrelevant ones and focusing on those beneficial for classification. We also show that a well-pretrained one-layer transformer can be adapted to new downstream tasks, achieving good prediction accuracy with a limited number of samples. Our study sheds light on how transformers effectively learn structured data.
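To make the group-sparse setup concrete, below is a minimal Python sketch, not the paper's exact construction: the input is split into G groups of d variables, each group is treated as a token, and the label depends only on one informative group. The names (G, d, w, g_star) and the hand-set attention scores are illustrative assumptions; the paper's result is that gradient descent drives a one-layer transformer's attention to concentrate on the informative group in this way.

import numpy as np

rng = np.random.default_rng(0)

G, d, n = 5, 8, 1000        # number of groups, group dimension, sample size (assumed values)
g_star = 2                  # the single informative group (hypothetical index)
w = rng.standard_normal(d)  # linear signal acting on the informative group

X = rng.standard_normal((n, G, d))   # each sample: G group "tokens" of dimension d
y = np.sign(X[:, g_star, :] @ w)     # label ignores all groups except g_star

# Toy attention-style selector: score each group, softmax the scores, and
# classify from the attention-weighted average of the group tokens. Attention
# concentrated on g_star recovers the optimal variable selection.
scores = np.zeros(G)
scores[g_star] = 10.0                # pretend training concentrated attention here
attn = np.exp(scores) / np.exp(scores).sum()
pooled = np.einsum('g,ngd->nd', attn, X)   # attention-weighted pooling over groups
pred = np.sign(pooled @ w)

print(f"accuracy with attention on the informative group: {(pred == y).mean():.3f}")

Because the softmax puts nearly all weight on g_star, the pooled representation is dominated by the informative group's variables and the classifier attains near-perfect accuracy, whereas uniform attention would dilute the signal with the G - 1 irrelevant groups.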

Cite

Text

Zhang et al. "Transformer Learns Optimal Variable Selection in Group-Sparse Classification." International Conference on Learning Representations, 2025.

Markdown

[Zhang et al. "Transformer Learns Optimal Variable Selection in Group-Sparse Classification." International Conference on Learning Representations, 2025.](https://mlanthology.org/iclr/2025/zhang2025iclr-transformer/)

BibTeX

@inproceedings{zhang2025iclr-transformer,
  title     = {{Transformer Learns Optimal Variable Selection in Group-Sparse Classification}},
  author    = {Zhang, Chenyang and Meng, Xuran and Cao, Yuan},
  booktitle = {International Conference on Learning Representations},
  year      = {2025},
  url       = {https://mlanthology.org/iclr/2025/zhang2025iclr-transformer/}
}