SIMPLE: A Gradient Estimator for K-Subset Sampling

Abstract

$k$-subset sampling is ubiquitous in machine learning, enabling regularization and interpretability through sparsity. The challenge lies in rendering $k$-subset sampling amenable to end-to-end learning. This has typically been done by relaxing the reparameterized samples to allow for backpropagation, which introduces both bias and variance. In this work, we fall back to discrete $k$-subset sampling on the forward pass. This is coupled with using the gradient with respect to the exact marginals, computed efficiently, as a proxy for the true gradient. We show that our gradient estimator exhibits lower bias and variance compared to state-of-the-art estimators. Empirical results show improved performance on learning-to-explain and sparse-models benchmarks. We provide an algorithm for computing the exact ELBO for the $k$-subset distribution, obtaining significantly lower loss compared to state-of-the-art discrete sparse VAEs. All of our algorithms are exact and efficient.
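
Below is a minimal PyTorch sketch of the core idea described in the abstract: a discrete $k$-subset sample on the forward pass, with the backward pass routed through the gradient of the marginals. It is not the paper's implementation. In particular, the Gumbel-top-$k$ sampler and the $k \cdot \mathrm{softmax}$ marginals are simplifying stand-ins; the paper samples exactly from the $k$-subset distribution and computes the exact marginals efficiently.

import torch

class KSubsetMarginalSTE(torch.autograd.Function):
    """Sketch: discrete k-subset sample forward, marginal-gradient backward.

    Assumptions (not the paper's exact algorithm):
      * forward uses Gumbel-top-k as a stand-in k-subset sampler;
      * backward differentiates k * softmax(logits) as a stand-in for
        the exact k-subset marginals computed in the paper.
    """

    @staticmethod
    def forward(ctx, logits, k):
        # Perturb the logits with Gumbel noise and keep the top-k entries
        # to obtain a discrete 0/1 mask with exactly k ones per row.
        gumbel = -torch.log(-torch.log(torch.rand_like(logits)))
        idx = torch.topk(logits + gumbel, k, dim=-1).indices
        sample = torch.zeros_like(logits).scatter_(-1, idx, 1.0)
        ctx.save_for_backward(logits)
        ctx.k = k
        return sample

    @staticmethod
    def backward(ctx, grad_output):
        (logits,) = ctx.saved_tensors
        k = ctx.k
        # Route the incoming gradient through the (approximate) marginals
        # instead of the non-differentiable discrete sample.
        with torch.enable_grad():
            l = logits.detach().requires_grad_(True)
            marginals = k * torch.softmax(l, dim=-1)
            (grad_logits,) = torch.autograd.grad(marginals, l, grad_output)
        return grad_logits, None

# Hypothetical usage: select k=3 of 10 features and backpropagate end-to-end.
logits = torch.randn(2, 10, requires_grad=True)
mask = KSubsetMarginalSTE.apply(logits, 3)
loss = (mask * torch.randn(2, 10)).sum()
loss.backward()
print(mask.sum(dim=-1), logits.grad.shape)  # exactly 3 ones per row

The design point this sketch illustrates is that the forward sample stays exactly discrete (and hence sparse), while the learning signal comes from differentiating the marginals, which is what distinguishes this family of estimators from relaxation-based ones.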

Cite

Text

Ahmed et al. "SIMPLE: A Gradient Estimator for K-Subset Sampling." International Conference on Learning Representations, 2023.

Markdown

[Ahmed et al. "SIMPLE: A Gradient Estimator for K-Subset Sampling." International Conference on Learning Representations, 2023.](https://mlanthology.org/iclr/2023/ahmed2023iclr-simple/)

BibTeX

@inproceedings{ahmed2023iclr-simple,
  title     = {{SIMPLE: A Gradient Estimator for K-Subset Sampling}},
  author    = {Ahmed, Kareem and Zeng, Zhe and Niepert, Mathias and Van den Broeck, Guy},
  booktitle = {International Conference on Learning Representations},
  year      = {2023},
  url       = {https://mlanthology.org/iclr/2023/ahmed2023iclr-simple/}
}