Rieoptax: Riemannian Optimization in JAX

Abstract

We present Rieoptax, an open-source Python library for Riemannian optimization in JAX. We show that many differential-geometric primitives, such as the Riemannian exponential and logarithm maps, are usually faster in Rieoptax than in existing Python frameworks, on both CPU and GPU. We support a range of basic and advanced stochastic optimization solvers, including Riemannian stochastic gradient, stochastic variance-reduced, and adaptive gradient methods. A distinguishing feature of the proposed toolbox is that it also supports differentially private optimization on Riemannian manifolds.
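To make the primitives named in the abstract concrete, here is a minimal sketch in plain JAX of the exponential and logarithm maps on the unit sphere, plus one Riemannian gradient step built on them. It assumes the sphere as the manifold and uses hypothetical helper names (`proj`, `exp_map`, `log_map`, `rsgd_step`). It is not Rieoptax's API. The toy problem, finding the leading eigenvector of a diagonal matrix, is only an illustration. In a stochastic setting, `egrad` would be a minibatch estimate rather than the full gradient.

```python
import jax
import jax.numpy as jnp

# Hypothetical helpers for the unit sphere S^{n-1}; NOT Rieoptax's API.

def proj(x, u):
    """Project an ambient vector u onto the tangent space at x."""
    return u - jnp.dot(x, u) * x

def exp_map(x, v):
    """Exponential map: move from x along the geodesic with initial velocity v."""
    norm_v = jnp.linalg.norm(v)
    safe = jnp.where(norm_v > 1e-12, norm_v, 1.0)  # avoid division by zero
    return jnp.where(norm_v > 1e-12,
                     jnp.cos(norm_v) * x + jnp.sin(norm_v) * v / safe,
                     x + v)

def log_map(x, y):
    """Logarithm map: tangent vector at x pointing to y (assumes y != -x)."""
    u = proj(x, y - x)  # equals y - <x, y> x
    norm_u = jnp.linalg.norm(u)
    theta = jnp.arccos(jnp.clip(jnp.dot(x, y), -1.0, 1.0))  # geodesic distance
    safe = jnp.where(norm_u > 1e-12, norm_u, 1.0)
    return jnp.where(norm_u > 1e-12, theta * u / safe, jnp.zeros_like(x))

def rsgd_step(x, egrad, lr):
    """One Riemannian gradient step: project the gradient, then follow the geodesic."""
    rgrad = proj(x, egrad)  # Riemannian gradient for the embedded sphere
    return exp_map(x, -lr * rgrad)

# Toy problem: minimize f(x) = -x^T A x on the sphere (leading eigenvector of A).
A = jnp.diag(jnp.array([3.0, 1.0, 0.5]))
cost = lambda x: -x @ A @ x
egrad_fn = jax.jit(jax.grad(cost))

x = jnp.ones(3) / jnp.sqrt(3.0)
for _ in range(200):
    x = rsgd_step(x, egrad_fn(x), lr=0.1)
```

The update maps the Euclidean gradient to the tangent space with `proj` and then moves along the geodesic with `exp_map`. Under this sketch, variance reduction or adaptive scaling would change only how the tangent direction inside `rsgd_step` is formed.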

Cite

Text

Utpala et al. "Rieoptax: Riemannian Optimization in JAX." NeurIPS 2022 Workshops: OPT, 2022.

Markdown

[Utpala et al. "Rieoptax: Riemannian Optimization in JAX." NeurIPS 2022 Workshops: OPT, 2022.](https://mlanthology.org/neuripsw/2022/utpala2022neuripsw-rieoptax/)

BibTeX

@inproceedings{utpala2022neuripsw-rieoptax,
  title     = {{Rieoptax: Riemannian Optimization in JAX}},
  author    = {Utpala, Saiteja and Han, Andi and Jawanpuria, Pratik and Mishra, Bamdev},
  booktitle = {NeurIPS 2022 Workshops: OPT},
  year      = {2022},
  url       = {https://mlanthology.org/neuripsw/2022/utpala2022neuripsw-rieoptax/}
}