Scaling Ensemble Distribution Distillation to Many Classes with Proxy Targets

Abstract

Ensembles of machine learning models yield improved system performance as well as robust and interpretable uncertainty estimates; however, their inference costs can be prohibitively high. Ensemble Distribution Distillation (EnD²) is an approach that allows a single model to efficiently capture both the predictive performance and uncertainty estimates of an ensemble. For classification, this is achieved by training a Dirichlet distribution over the ensemble members' output distributions via the maximum-likelihood criterion. Although theoretically principled, this criterion, as this work shows, exhibits poor convergence when applied to large-scale tasks where the number of classes is very high. Specifically, we show that for the Dirichlet log-likelihood criterion, classes with low probability induce larger gradients than high-probability classes. Hence, during training, the model focuses on the distribution of the ensemble tail-class probabilities rather than on the probability of the correct and closely related classes. We propose a new training objective that minimizes the reverse KL-divergence to a *Proxy-Dirichlet* target derived from the ensemble. This loss resolves the gradient issues of EnD², as we demonstrate both theoretically and empirically on the ImageNet, LibriSpeech, and WMT17 En-De datasets, which contain 1,000, 5,000, and 40,000 classes, respectively.
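The sketch below (not the authors' released code) illustrates the kind of loss the abstract describes: a reverse KL-divergence from the model's predicted Dirichlet to a Proxy-Dirichlet target built from the ensemble's mean prediction. The `target_precision` constant and the `+1` offsets are illustrative assumptions; the paper derives the target precision from ensemble statistics rather than using a fixed value.

```python
# Minimal PyTorch sketch of a reverse-KL loss to a Proxy-Dirichlet target.
# Assumptions: fixed target_precision and +1 concentration offsets (placeholders
# for the ensemble-derived quantities used in the paper).
import torch


def dirichlet_kl(alpha: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
    """KL( Dir(alpha) || Dir(beta) ), computed per batch row."""
    alpha0 = alpha.sum(-1, keepdim=True)
    beta0 = beta.sum(-1, keepdim=True)
    return (
        torch.lgamma(alpha0.squeeze(-1))
        - torch.lgamma(alpha).sum(-1)
        - torch.lgamma(beta0.squeeze(-1))
        + torch.lgamma(beta).sum(-1)
        + ((alpha - beta) * (torch.digamma(alpha) - torch.digamma(alpha0))).sum(-1)
    )


def proxy_dirichlet_loss(logits: torch.Tensor,
                         ensemble_probs: torch.Tensor,
                         target_precision: float = 100.0) -> torch.Tensor:
    """Reverse KL from the model Dirichlet to a Proxy-Dirichlet target.

    logits:         [batch, num_classes] raw outputs of the distilled model
    ensemble_probs: [batch, num_members, num_classes] softmax outputs of the ensemble
    """
    # Model Dirichlet concentrations; the +1 offset keeps them >= 1.
    alpha = torch.exp(logits) + 1.0
    # Proxy target: mean ensemble prediction scaled by an assumed precision.
    mean_probs = ensemble_probs.mean(dim=1)
    beta = mean_probs * target_precision + 1.0
    # Reverse KL: KL(model || proxy target), averaged over the batch.
    return dirichlet_kl(alpha, beta).mean()


# Usage with random tensors (1000 classes, 10 ensemble members):
logits = torch.randn(8, 1000, requires_grad=True)
ens = torch.softmax(torch.randn(8, 10, 1000), dim=-1)
loss = proxy_dirichlet_loss(logits, ens)
loss.backward()
```

Because the KL is taken from the model to the target (rather than the other way around), the gradient contribution of tail classes is bounded by the model's own concentration parameters, which is the property the abstract argues the standard Dirichlet log-likelihood criterion lacks.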

Cite

Text

Ryabinin et al. "Scaling Ensemble Distribution Distillation to Many Classes with Proxy Targets." Neural Information Processing Systems, 2021.

Markdown

[Ryabinin et al. "Scaling Ensemble Distribution Distillation to Many Classes with Proxy Targets." Neural Information Processing Systems, 2021.](https://mlanthology.org/neurips/2021/ryabinin2021neurips-scaling/)

BibTeX

@inproceedings{ryabinin2021neurips-scaling,
  title     = {{Scaling Ensemble Distribution Distillation to Many Classes with Proxy Targets}},
  author    = {Ryabinin, Max and Malinin, Andrey and Gales, Mark},
  booktitle = {Neural Information Processing Systems},
  year      = {2021},
  url       = {https://mlanthology.org/neurips/2021/ryabinin2021neurips-scaling/}
}