Deep Generalized Prediction Set Classifier and Its Theoretical Guarantees

Abstract

A standard classification rule returns a single-valued prediction for any observation without a confidence guarantee, which may lead to severe consequences in many critical applications when the uncertainty is high. In contrast, set-valued classification is a new paradigm for handling uncertainty in classification: it reports a set of plausible labels for observations in highly ambiguous regions. In this article, we propose the Deep Generalized Prediction Set (DeepGPS) method, a network-based set-valued classifier induced by acceptance region learning. DeepGPS is capable of identifying ambiguous observations and detecting out-of-distribution (OOD) observations. It is the first set-valued classifier of this kind that both comes with a theoretical guarantee and scales to large datasets. Our nontrivial proof shows that the risk of DeepGPS, defined as the expected size of the prediction set, attains optimality within a neural network hypothesis class while simultaneously achieving the user-prescribed class-specific accuracy. Additionally, by using a weighted loss, DeepGPS returns tighter acceptance regions, leading to more informative predictions and improved OOD detection performance. Empirically, our method outperforms the baselines on several benchmark datasets.
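To make the set-valued setup concrete, here is a minimal sketch of how a prediction set can be formed from per-class acceptance regions and how its two quantities of interest (average set size, class-specific accuracy) can be measured. It is NOT the DeepGPS training procedure: DeepGPS learns the acceptance regions with a neural network and a weighted loss, whereas this sketch assumes per-class scores are already given (e.g. from some trained network) and simply calibrates one threshold per class so that at least a fraction 1 - alpha_k of class-k calibration points are accepted. The function names `calibrate_thresholds`, `prediction_set`, and `evaluate`, and the toy scores, are hypothetical and for illustration only.

```python
import math
from typing import List, Sequence, Set, Tuple


def calibrate_thresholds(
    scores: Sequence[Sequence[float]],
    labels: Sequence[int],
    alphas: Sequence[float],
) -> List[float]:
    """Hypothetical plug-in calibration (not the DeepGPS method).

    For each class k, pick t_k so that at least (1 - alphas[k]) of the
    calibration points with label k satisfy score[k] >= t_k.
    """
    thresholds = []
    for k, alpha in enumerate(alphas):
        class_scores = sorted(s[k] for s, y in zip(scores, labels) if y == k)
        if not class_scores:
            raise ValueError(f"no calibration points for class {k}")
        # Index floor(alpha * n) leaves at least n - floor(alpha * n) >= (1 - alpha) * n
        # scores at or above the threshold (alpha < 1 keeps the index in range).
        j = math.floor(alpha * len(class_scores))
        thresholds.append(class_scores[j])
    return thresholds


def prediction_set(score: Sequence[float], thresholds: Sequence[float]) -> Set[int]:
    """Accept label k when the observation falls in class k's acceptance region.

    The set may contain several labels (an ambiguous observation) or none,
    which one could read as a flag that the input looks out-of-distribution.
    """
    return {k for k, (s, t) in enumerate(zip(score, thresholds)) if s >= t}


def evaluate(
    scores: Sequence[Sequence[float]],
    labels: Sequence[int],
    thresholds: Sequence[float],
) -> Tuple[float, List[float]]:
    """Return (average set size, per-class fraction of points whose label is covered)."""
    sets = [prediction_set(s, thresholds) for s in scores]
    avg_size = sum(len(c) for c in sets) / len(sets)
    coverage = []
    for k in range(len(thresholds)):
        idx = [i for i, y in enumerate(labels) if y == k]
        coverage.append(sum(k in sets[i] for i in idx) / len(idx) if idx else float("nan"))
    return avg_size, coverage


if __name__ == "__main__":
    # Toy two-class calibration data: (score for class 0, score for class 1).
    cal_scores = [(0.9, 0.1), (0.8, 0.2), (0.7, 0.55), (0.6, 0.3), (0.3, 0.6),
                  (0.1, 0.95), (0.2, 0.85), (0.65, 0.75), (0.3, 0.5), (0.5, 0.4)]
    cal_labels = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
    t = calibrate_thresholds(cal_scores, cal_labels, alphas=[0.2, 0.2])
    print("thresholds:", t)
    print("confident:", prediction_set((0.9, 0.1), t))
    print("ambiguous:", prediction_set((0.65, 0.7), t))
    print("empty (possible OOD):", prediction_set((0.1, 0.2), t))
    print("avg size, coverage:", evaluate(cal_scores, cal_labels, t))
```

The average set size is the empirical counterpart of the risk described above (the expected size of the prediction set), and the per-class coverage is the empirical counterpart of the class-specific accuracy constraint. Fixed thresholds on a fixed score function only trade these two quantities off; learning the acceptance regions themselves, as DeepGPS does, is what lets the expected set size shrink while the accuracy targets are still met.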

Cite

Text

Wang and Qiao. "Deep Generalized Prediction Set Classifier and Its Theoretical Guarantees." Transactions on Machine Learning Research, 2024.

BibTeX

@article{wang2024tmlr-deep-a,
  title     = {{Deep Generalized Prediction Set Classifier and Its Theoretical Guarantees}},
  author    = {Wang, Zhou and Qiao, Xingye},
  journal   = {Transactions on Machine Learning Research},
  year      = {2024},
  url       = {https://mlanthology.org/tmlr/2024/wang2024tmlr-deep-a/}
}