SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training

Abstract

Tabular data underpins numerous high-impact applications of machine learning from fraud detection to genomics and healthcare. Classical approaches to solving tabular problems, such as gradient boosting and random forests, are widely used by practitioners. However, recent deep learning methods have achieved a degree of performance competitive with popular techniques. We devise a hybrid deep learning approach to solving tabular data problems. Our method, SAINT, performs attention over both rows and columns, and it includes an enhanced embedding method. We also study a new contrastive self-supervised pre-training method for use when labels are scarce. SAINT consistently improves performance over previous deep learning methods, and it even performs competitively with gradient boosting methods, including XGBoost, CatBoost, and LightGBM, on average over 30 benchmark datasets in regression, binary classification, and multi-class classification tasks.
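The abstract's "attention over rows" refers to attending across samples in a batch, in addition to the usual attention across a row's columns. Below is a minimal sketch of that idea, not the authors' released code: the class name `IntersampleAttention` and all shapes are illustrative, and the flattening of each row's per-column embeddings into one vector before attending across the batch is one plausible realization of row attention.

```python
import torch
import torch.nn as nn


class IntersampleAttention(nn.Module):
    """Row ("intersample") attention sketch (hypothetical, not the paper's code).

    Each sample's per-column embeddings are flattened into a single vector,
    and standard multi-head attention is applied across the batch dimension,
    so information flows between rows (samples) rather than between the
    columns of one row.
    """

    def __init__(self, n_features: int, d: int, n_heads: int = 8):
        super().__init__()
        # embed_dim = n_features * d must be divisible by n_heads
        self.attn = nn.MultiheadAttention(n_features * d, n_heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, n_features, d) -- one embedding per column of each row
        b, n, d = x.shape
        rows = x.reshape(1, b, n * d)          # the batch becomes a length-b "sequence"
        out, _ = self.attn(rows, rows, rows)   # every row attends to every other row
        return out.reshape(b, n, d)


# Toy usage: a batch of 16 rows, 10 columns, 32-dim column embeddings.
x = torch.randn(16, 10, 32)
y = IntersampleAttention(n_features=10, d=32)(x)
print(y.shape)  # torch.Size([16, 10, 32])
```

In a full SAINT-style block this row attention would alternate with ordinary column-wise self-attention over each row's feature embeddings; the sketch isolates only the row direction.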

Cite

Text

Somepalli et al. "SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training." NeurIPS 2022 Workshops: TRL, 2022.

Markdown

[Somepalli et al. "SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training." NeurIPS 2022 Workshops: TRL, 2022.](https://mlanthology.org/neuripsw/2022/somepalli2022neuripsw-saint/)

BibTeX

@inproceedings{somepalli2022neuripsw-saint,
  title     = {{SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training}},
  author    = {Somepalli, Gowthami and Schwarzschild, Avi and Goldblum, Micah and Bruss, C. Bayan and Goldstein, Tom},
  booktitle = {NeurIPS 2022 Workshops: TRL},
  year      = {2022},
  url       = {https://mlanthology.org/neuripsw/2022/somepalli2022neuripsw-saint/}
}