Taylor-Mode Automatic Differentiation for Higher-Order Derivatives in JAX

Abstract

One way to achieve higher-order automatic differentiation (AD) is to implement first-order AD and apply it repeatedly. This nested approach works, but it can result in combinatorial amounts of redundant work. This paper describes a more efficient method, Taylor-mode AD, which is already known but is given a new presentation here, together with its implementation in JAX. We also study its application to neural ordinary differential equations (ODEs) and, in particular, discuss some additional algorithmic improvements for higher-order AD of differential equations.
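The contrast between nesting first-order AD and a single higher-order pass can be sketched in JAX. This is a minimal sketch, assuming a JAX release that still ships the experimental `jax.experimental.jet` transformation (an experimental API whose details may change). The test function `f`, the point `x0 = 0.7`, and the helpers `nested_derivs` and `taylor_derivs` are hypothetical choices for illustration, not code from the paper. The coefficient convention used for `jet` (derivatives of the input curve rather than Taylor coefficients divided by k!) follows our reading of its docstring.

```python
import jax
import jax.numpy as jnp
from jax.experimental.jet import jet

# Use float64 so the two methods can be compared with a tight tolerance.
jax.config.update("jax_enable_x64", True)


def f(x):
    # Hypothetical smooth scalar test function chosen for illustration.
    return jnp.sin(x) * jnp.exp(x)


def nested_derivs(fun, x0, order):
    """Derivatives 1..order by repeatedly applying first-order AD (jax.grad)."""
    derivs = []
    g = fun
    for _ in range(order):
        g = jax.grad(g)  # differentiate the previous derivative program again
        derivs.append(g(x0))
    return derivs


def taylor_derivs(fun, x0, order):
    """Derivatives 1..order from a single Taylor-mode pass with jet.

    The input curve is x(t) = x0 + t, so its derivative series is
    (1, 0, ..., 0). In jet's (assumed) convention the output series holds
    the derivatives of f(x(t)) at t = 0, i.e. f'(x0), ..., f^(order)(x0).
    """
    one, zero = jnp.ones_like(x0), jnp.zeros_like(x0)
    series_in = ([one] + [zero] * (order - 1),)
    _, series_out = jet(fun, (x0,), series_in)
    return series_out


x0 = jnp.asarray(0.7)
order = 4
nested = nested_derivs(f, x0, order)
taylor = taylor_derivs(f, x0, order)
for k, (a, b) in enumerate(zip(nested, taylor), start=1):
    print(f"order {k}: nested grad = {float(a):.12f}, jet = {float(b):.12f}")
```

Both helpers should agree up to rounding. The difference lies in how the values are obtained: the nested version builds and differentiates a new derivative program at each order, while the `jet` version pushes one truncated Taylor polynomial through `f` in a single pass.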

Cite

Bettencourt et al. "Taylor-Mode Automatic Differentiation for Higher-Order Derivatives in JAX." NeurIPS 2019 Workshops: Program_Transformations, 2019.

BibTeX

@inproceedings{bettencourt2019neuripsw-taylormode,
  title     = {{Taylor-Mode Automatic Differentiation for Higher-Order Derivatives in JAX}},
  author    = {Bettencourt, Jesse and Johnson, Matthew J. and Duvenaud, David},
  booktitle = {NeurIPS 2019 Workshops: Program_Transformations},
  year      = {2019},
  url       = {https://mlanthology.org/neuripsw/2019/bettencourt2019neuripsw-taylormode/}
}