Jaxpruner: A Concise Library for Sparsity Research

Abstract

This paper introduces JaxPruner, an open-source, JAX-based pruning and sparse training library for machine learning research. JaxPruner aims to accelerate research on sparse neural networks by providing concise implementations of popular pruning and sparse training algorithms with minimal memory and latency overhead. Algorithms implemented in JaxPruner use a common API and work seamlessly with the popular optimization library Optax, which, in turn, enables easy integration with existing JAX-based libraries. We demonstrate this ease of integration with examples in four different codebases: Scenic, t5x, Dopamine, and FedJAX. We also provide baseline experiments on popular benchmarks. JaxPruner is hosted at github.com/google-research/jaxpruner.
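To illustrate the kind of composition the abstract describes, the following is a minimal, hypothetical sketch, not JaxPruner's actual API: a toy magnitude-pruning step written as an optax.GradientTransformation and chained after a standard Optax optimizer. The function name toy_magnitude_pruning and the 50% sparsity setting are assumptions made for this example only.

import jax
import jax.numpy as jnp
import optax


def toy_magnitude_pruning(sparsity: float) -> optax.GradientTransformation:
  """Toy example: zeroes the smallest-magnitude fraction of each tensor."""

  def init_fn(params):
    del params  # This toy transformation keeps no state.
    return optax.EmptyState()

  def update_fn(updates, state, params=None):
    def prune(update, param):
      new_param = param + update
      k = int(sparsity * new_param.size)
      if k == 0:
        return update
      # Threshold at the k-th smallest absolute value, then mask.
      threshold = jnp.sort(jnp.abs(new_param).ravel())[k - 1]
      mask = (jnp.abs(new_param) > threshold).astype(new_param.dtype)
      # Return an update that lands exactly on the masked parameters.
      return mask * new_param - param

    updates = jax.tree_util.tree_map(prune, updates, params)
    return updates, state

  return optax.GradientTransformation(init_fn, update_fn)


# Usage: chain the pruning step after a standard Optax optimizer.
optimizer = optax.chain(optax.adam(1e-3), toy_magnitude_pruning(sparsity=0.5))
params = {"w": jnp.array([0.1, -2.0, 0.03, 1.5])}
grads = {"w": jnp.array([0.2, -0.1, 0.05, 0.3])}
opt_state = optimizer.init(params)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)

Because the pruning step is just another gradient transformation, any training loop that already accepts an Optax optimizer can adopt it without structural changes, which is the integration property the abstract highlights.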

Cite

Text

Lee et al. "Jaxpruner: A Concise Library for Sparsity Research." Conference on Parsimony and Learning, 2024.

Markdown

[Lee et al. "Jaxpruner: A Concise Library for Sparsity Research." Conference on Parsimony and Learning, 2024.](https://mlanthology.org/cpal/2024/lee2024cpal-jaxpruner/)

BibTeX

@inproceedings{lee2024cpal-jaxpruner,
  title     = {{Jaxpruner: A Concise Library for Sparsity Research}},
  author    = {Lee, Joo Hyung and Park, Wonpyo and Mitchell, Nicole Elyse and Pilault, Jonathan and Ceron, Johan Samir Obando and Kim, Han-Byul and Lee, Namhoon and Frantar, Elias and Long, Yun and Yazdanbakhsh, Amir and Han, Woohyun and Agrawal, Shivani and Subramanian, Suvinay and Wang, Xin and Kao, Sheng-Chun and Zhang, Xingyao and Gale, Trevor and Bik, Aart J.C. and Ferev, Milen and Han, Zhonglin and Kim, Hong-Seok and Dauphin, Yann and Dziugaite, Gintare Karolina and Castro, Pablo Samuel and Evci, Utku},
  booktitle = {Conference on Parsimony and Learning},
  year      = {2024},
  pages     = {515--528},
  volume    = {234},
  url       = {https://mlanthology.org/cpal/2024/lee2024cpal-jaxpruner/}
}