Scalify: Scale Propagation for Efficient Low-Precision LLM Training
Abstract
Low-precision formats such as float8 have been introduced in machine learning accelerator hardware to improve computational efficiency for large language model training and inference. Nevertheless, adoption by the ML community has been slowed by the complex, and sometimes brittle, techniques required to match higher-precision training accuracy. In this work, we present Scalify, an end-to-end scale propagation paradigm for computational graphs, generalizing and formalizing existing tensor scaling methods. Experimental results show that Scalify supports out-of-the-box float8 matrix multiplication and gradient representation, as well as float16 optimizer state storage. Our JAX implementation of Scalify is open-sourced at [github.com/graphcore-research/jax-scalify](https://github.com/graphcore-research/jax-scalify).
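To illustrate the tensor-scaling idea the abstract refers to, the sketch below shows one plausible way a scaled tensor and scale propagation through a matrix multiplication could look in plain JAX. It is not the jax-scalify API: the function `scaled_matmul` and the `sqrt(K)` propagation rule are assumptions made for illustration, representing each tensor as low-precision `data` paired with a separate `scale` so that the value is `data * scale`.

```python
import jax
import jax.numpy as jnp

def scaled_matmul(a_data, a_scale, b_data, b_scale):
    # Hypothetical sketch: the matmul runs on well-scaled low-precision data,
    # while the output scale is propagated analytically. Dividing the data by
    # sqrt(K) (K = contraction dim) and multiplying the scale by sqrt(K) keeps
    # the stored data close to unit range, as suited to float8 storage.
    k = a_data.shape[-1]
    out_data = jnp.matmul(a_data, b_data) / jnp.sqrt(k)
    out_scale = a_scale * b_scale * jnp.sqrt(k)
    return out_data, out_scale

# Example usage: scales carry the dynamic range, data stays near unit variance.
a = jax.random.normal(jax.random.PRNGKey(0), (16, 64), dtype=jnp.float32)
b = jax.random.normal(jax.random.PRNGKey(1), (64, 32), dtype=jnp.float32)
out_data, out_scale = scaled_matmul(a, jnp.float32(1.0), b, jnp.float32(1.0))
value = out_data * out_scale  # recover the represented full-precision value
```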
Cite

Text

Balanca et al. "Scalify: Scale Propagation for Efficient Low-Precision LLM Training." ICML 2024 Workshops: WANT, 2024.

Markdown

[Balanca et al. "Scalify: Scale Propagation for Efficient Low-Precision LLM Training." ICML 2024 Workshops: WANT, 2024.](https://mlanthology.org/icmlw/2024/balanca2024icmlw-scalify/)

BibTeX
@inproceedings{balanca2024icmlw-scalify,
title = {{Scalify: Scale Propagation for Efficient Low-Precision LLM Training}},
author = {Balanca, Paul and Hosegood, Samuel and Luschi, Carlo and Fitzgibbon, Andrew W},
booktitle = {ICML 2024 Workshops: WANT},
year = {2024},
url = {https://mlanthology.org/icmlw/2024/balanca2024icmlw-scalify/}
}