Mitigating Simplicity Bias in Deep Learning for Improved OOD Generalization and Robustness

Abstract

Neural networks (NNs) are known to exhibit simplicity bias: they tend to prefer learning 'simple' features over more 'complex' ones, even when the latter may be more informative. Simplicity bias can lead a model to make biased predictions with poor out-of-distribution (OOD) generalization and subgroup robustness. To address this, we propose a hypothesis that directly connects spurious features to simplicity bias: on many datasets, spurious features are simple features that are still predictive of the label. We empirically validate this hypothesis and then develop a framework that leverages it to learn more robust models. In our proposed framework, we first train a simple model, and then regularize the conditional mutual information with respect to it to obtain the final model. We theoretically study the effect of this regularization and show that it provably reduces reliance on spurious features in certain settings. We also empirically demonstrate the effectiveness of this framework in various problem settings and real-world applications, showing that it addresses simplicity bias by encouraging the model to use more features, enhances OOD generalization, and improves subgroup robustness and fairness.
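The two-stage idea in the abstract (train a simple model, then penalize the final model's conditional mutual information with respect to it) can be illustrated with a small numerical sketch. The code below is not the authors' implementation. It assumes hard, discrete predictions and uses a plug-in estimate of I(final; simple | label), whereas training in practice would need a differentiable estimator. The function names `conditional_mutual_information` and `regularized_objective` and the weight `lam` are hypothetical, chosen only for illustration.

```python
import numpy as np


def conditional_mutual_information(a, b, y):
    """Plug-in estimate of I(A; B | Y) in nats for discrete 1-D arrays.

    Illustrative assumption: a and b are hard predictions of the final and
    simple models, and y holds the labels.
    """
    a, b, y = (np.asarray(v) for v in (a, b, y))
    n = len(y)
    cmi = 0.0
    for label in np.unique(y):
        mask = y == label
        p_y = mask.sum() / n          # empirical P(Y = label)
        a_y, b_y = a[mask], b[mask]   # predictions restricted to this label
        for av in np.unique(a_y):
            p_a = np.mean(a_y == av)  # P(A = av | Y = label)
            for bv in np.unique(b_y):
                p_b = np.mean(b_y == bv)                   # P(B = bv | Y = label)
                p_ab = np.mean((a_y == av) & (b_y == bv))  # joint given the label
                if p_ab > 0:
                    cmi += p_y * p_ab * np.log(p_ab / (p_a * p_b))
    return cmi


def regularized_objective(task_loss, final_preds, simple_preds, labels, lam=1.0):
    """Task loss plus a CMI penalty; lam is a hypothetical trade-off weight."""
    penalty = conditional_mutual_information(final_preds, simple_preds, labels)
    return task_loss + lam * penalty


# Toy data: the simple model's predictions vary within each label group.
labels = np.array([0, 0, 0, 0, 1, 1, 1, 1])
simple_preds = np.array([0, 0, 1, 1, 0, 0, 1, 1])
copycat_preds = simple_preds.copy()                      # mirrors the simple model
independent_preds = np.array([0, 1, 0, 1, 0, 1, 0, 1])   # independent of it given the label

high = conditional_mutual_information(copycat_preds, simple_preds, labels)
low = conditional_mutual_information(independent_preds, simple_preds, labels)
```

In this toy case, a final model that copies the simple model receives a larger penalty than one whose predictions are conditionally independent of it given the label. This matches the intuition that the regularizer discourages reliance on the same simple features.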

Cite

Text

Vasudeva et al. "Mitigating Simplicity Bias in Deep Learning for Improved OOD Generalization and Robustness." Transactions on Machine Learning Research, 2024.

BibTeX

@article{vasudeva2024tmlr-mitigating,
  title     = {{Mitigating Simplicity Bias in Deep Learning for Improved OOD Generalization and Robustness}},
  author    = {Vasudeva, Bhavya and Shahabi, Kameron and Sharan, Vatsal},
  journal   = {Transactions on Machine Learning Research},
  year      = {2024},
  url       = {https://mlanthology.org/tmlr/2024/vasudeva2024tmlr-mitigating/}
}