Taylor-Mode Automatic Differentiation for Higher-Order Derivatives in JAX
Abstract
One way to achieve higher-order automatic differentiation (AD) is to implement first-order AD and apply it repeatedly. This nested approach works, but can result in combinatorial amounts of redundant work. This paper describes a more efficient method, already known but with a new presentation, and its implementation in JAX. We also study its application to neural ordinary differential equations, and in particular discuss some additional algorithmic improvements for higher-order AD of differential equations.
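The contrast drawn in the abstract can be made concrete in JAX itself. Below is a minimal sketch comparing nested first-order `jvp` calls, where each extra order re-differentiates the code that computed the previous derivative, against `jax.experimental.jet`, JAX's experimental Taylor-mode interface; the test function, expansion point, and truncation order are illustrative choices, not taken from the paper.

```python
import jax.numpy as jnp
from jax import jvp
from jax.experimental import jet

def f(x):
    return jnp.sin(x)

x0 = 0.5

# Nested first-order AD: each additional order differentiates the code that
# produced the previous derivative, so redundant work compounds with the order.
def df(x):
    return jvp(f, (x,), (1.0,))[1]

def d2f(x):
    return jvp(df, (x,), (1.0,))[1]

def d3f(x):
    return jvp(d2f, (x,), (1.0,))[1]

print(df(x0), d2f(x0), d3f(x0))

# Taylor mode: propagate a truncated Taylor series through f in a single pass.
# With the input path x(t) = x0 + t (series coefficients (1.0, 0.0, 0.0)), the
# output series contains the first, second, and third derivatives of f at x0.
y0, (y1, y2, y3) = jet.jet(f, (x0,), ((1.0, 0.0, 0.0),))
print(y1, y2, y3)
```

Both approaches return the same derivative values for this scalar example; the difference is that the Taylor-mode call traverses `f` once rather than once per order of nesting.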
Cite
Text
Bettencourt et al. "Taylor-Mode Automatic Differentiation for Higher-Order Derivatives in JAX." NeurIPS 2019 Workshops: Program_Transformations, 2019.
Markdown
[Bettencourt et al. "Taylor-Mode Automatic Differentiation for Higher-Order Derivatives in JAX." NeurIPS 2019 Workshops: Program_Transformations, 2019.](https://mlanthology.org/neuripsw/2019/bettencourt2019neuripsw-taylormode/)
BibTeX
@inproceedings{bettencourt2019neuripsw-taylormode,
title = {{Taylor-Mode Automatic Differentiation for Higher-Order Derivatives in JAX}},
author = {Bettencourt, Jesse and Johnson, Matthew J. and Duvenaud, David},
booktitle = {NeurIPS 2019 Workshops: Program_Transformations},
year = {2019},
url = {https://mlanthology.org/neuripsw/2019/bettencourt2019neuripsw-taylormode/}
}