Algorithms for Mean-Field Variational Inference via Polyhedral Optimization in the Wasserstein Space

Abstract

We develop a theory of finite-dimensional polyhedral subsets over the Wasserstein space and optimization of functionals over them via first-order methods. Our main application is to the problem of mean-field variational inference, which seeks to approximate a distribution $\pi$ over $\mathbb{R}^d$ by a product measure $\pi^\star$. When $\pi$ is strongly log-concave and log-smooth, we provide (1) approximation rates certifying that $\pi^\star$ is close to the minimizer $\pi^\star_\diamond$ of the KL divergence over a polyhedral set $\mathcal{P}_\diamond$, and (2) an algorithm for minimizing $\text{KL}(\cdot\|\pi)$ over $\mathcal{P}_\diamond$ with accelerated complexity $O(\sqrt{\kappa} \log(\kappa d/\varepsilon^2))$, where $\kappa$ is the condition number of $\pi$.
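To make the mean-field problem concrete, here is a minimal sketch for the simplest strongly log-concave, log-smooth target: a Gaussian $\pi = N(m, \Lambda^{-1})$ with symmetric positive-definite precision $\Lambda$, whose condition number is $\kappa = \lambda_{\max}(\Lambda)/\lambda_{\min}(\Lambda)$. It uses classical coordinate-ascent variational inference (CAVI), not the polyhedral Wasserstein algorithm described above. For this target each optimal factor is Gaussian with variance $1/\Lambda_{ii}$ and mean satisfying $\mu_i = m_i - \Lambda_{ii}^{-1}\sum_{j \ne i} \Lambda_{ij}(\mu_j - m_j)$. The function name `mean_field_gaussian` and the example matrix are hypothetical.

```python
def mean_field_gaussian(m, precision, sweeps=100):
    """Coordinate-ascent VI (CAVI) for a Gaussian target N(m, precision^{-1}).

    Illustrative sketch only (classical CAVI, not the paper's polyhedral method).
    Returns the means and variances of the product approximation
    q = q_1 x ... x q_d, where each optimal factor q_i is Gaussian.
    """
    d = len(m)
    mu = [0.0] * d  # initial factor means
    for _ in range(sweeps):
        for i in range(d):
            # Expected coupling to the other coordinates under the current q_{-i}.
            coupling = sum(precision[i][j] * (mu[j] - m[j]) for j in range(d) if j != i)
            # Optimal Gaussian factor mean, holding the other factors fixed.
            mu[i] = m[i] - coupling / precision[i][i]
    # Each factor's variance is 1 / Lambda_ii, independent of the means.
    variances = [1.0 / precision[i][i] for i in range(d)]
    return mu, variances


m = [1.0, -2.0]
precision = [[2.0, 0.5], [0.5, 1.0]]  # hypothetical SPD precision matrix
mu, var = mean_field_gaussian(m, precision)
print(mu, var)
```

For this example the means converge to $m$ and the factor variances are $[0.5, 1.0]$, smaller than the true marginal variances $(\Lambda^{-1})_{ii} \approx [0.571, 1.143]$; this under-dispersion is the familiar behavior of minimizing $\text{KL}(\cdot\|\pi)$ over product measures. Here $\kappa = (3+\sqrt{2})/(3-\sqrt{2}) \approx 2.78$.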

Cite

Text

Jiang et al. "Algorithms for Mean-Field Variational Inference via Polyhedral Optimization in the Wasserstein Space." Conference on Learning Theory, 2024.

BibTeX

@inproceedings{jiang2024colt-algorithms,
  title     = {{Algorithms for Mean-Field Variational Inference via Polyhedral Optimization in the Wasserstein Space}},
  author    = {Jiang, Yiheng and Chewi, Sinho and Pooladian, Aram-Alexandre},
  booktitle = {Conference on Learning Theory},
  year      = {2024},
  pages     = {2720-2721},
  volume    = {247},
  url       = {https://mlanthology.org/colt/2024/jiang2024colt-algorithms/}
}