A Hessian-Aware Stochastic Differential Equation for Modelling SGD
Abstract
Continuous-time approximations of Stochastic Gradient Descent (SGD) are a crucial tool for studying how SGD escapes from stationary points. However, existing stochastic differential equation (SDE) models fail to fully capture this escaping behavior, even for simple quadratic objectives. Building on a novel stochastic backward error analysis framework, we derive the Hessian-Aware Stochastic Modified Equation (HA-SME), an SDE that incorporates Hessian information of the objective function into both its drift and diffusion terms. Our analysis shows that HA-SME matches the best order of approximation error guaranteed by existing SDE models in the literature, while having a significantly reduced dependence on the smoothness parameter of the objective. Further, for quadratic objectives and under mild conditions, we prove that HA-SME is the first SDE model to recover the SGD dynamics exactly in the distributional sense. Consequently, when the landscape near a stationary point is well approximated by a quadratic, HA-SME is expected to accurately predict how SGD escapes from that point.
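The gap between SGD and classical SDE models on quadratics can already be seen in one dimension. The following is a minimal sketch under our own assumptions, not the paper's setting: f(x) = h x^2 / 2 with additive Gaussian gradient noise of scale sigma. It compares the exact moments of the SGD iterate with those of the classical first-order stochastic modified equation (SME) and with an Ornstein-Uhlenbeck SDE whose drift and diffusion are chosen from the curvature h so that sampling it every eta time units reproduces one SGD step. The curvature-dependent coefficients are only an illustration of the "Hessian-aware" idea, not the HA-SME formula derived in the paper, and all function names are hypothetical.

```python
import numpy as np

# Hypothetical 1-D setup (an assumption, not from the paper): f(x) = h * x**2 / 2,
# stochastic gradient h * x + sigma * xi with xi ~ N(0, 1), so one SGD step is
# x_{k+1} = (1 - eta * h) * x_k - eta * sigma * xi_k.

def sgd_moments(x0, h, sigma, eta, k):
    """Exact mean and variance of the SGD iterate x_k (x_k is Gaussian)."""
    r = 1.0 - eta * h  # per-step contraction factor
    mean = r**k * x0
    var = (eta * sigma) ** 2 * (1.0 - r ** (2 * k)) / (1.0 - r**2)
    return mean, var

def ou_moments(x0, a, s, t):
    """Mean and variance at time t of dX = -a X dt + s dW with X_0 = x0 (Gaussian)."""
    mean = np.exp(-a * t) * x0
    var = s**2 * (1.0 - np.exp(-2.0 * a * t)) / (2.0 * a)
    return mean, var

def sme_coeffs(h, sigma, eta):
    """Classical first-order SME: drift -h x, diffusion sqrt(eta) * sigma."""
    return h, np.sqrt(eta) * sigma

def matched_coeffs(h, sigma, eta):
    """Curvature-dependent OU coefficients (illustrative assumption) chosen so that
    one eta-step of the SDE equals one SGD step in distribution; needs 0 < eta*h < 1."""
    r = 1.0 - eta * h
    if not 0.0 < r < 1.0:
        raise ValueError("this sketch assumes 0 < eta * h < 1")
    a = -np.log(r) / eta  # makes exp(-a * eta) == r, matching the mean
    s = eta * sigma * np.sqrt(2.0 * a / (1.0 - r**2))  # matches per-step variance
    return a, s

x0, h, sigma, eta, k = 1.0, 2.0, 0.5, 0.1, 50
t = k * eta  # SGD step k corresponds to SDE time k * eta

sgd = sgd_moments(x0, h, sigma, eta, k)
sme = ou_moments(x0, *sme_coeffs(h, sigma, eta), t)
matched = ou_moments(x0, *matched_coeffs(h, sigma, eta), t)

print("SGD     mean=%.3e var=%.6f" % sgd)
print("SME     mean=%.3e var=%.6f" % sme)
print("matched mean=%.3e var=%.6f" % matched)
```

With eta * h = 0.2, the first-order SME's variance is smaller than SGD's by roughly the factor (2 - eta * h) / 2 = 0.9, and its mean also decays at the wrong rate, whereas the curvature-dependent SDE agrees with SGD up to floating-point error. Because both processes are Gaussian when started from a fixed point, matching mean and variance means matching distributions in this toy case.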
Cite
Li et al. "A Hessian-Aware Stochastic Differential Equation for Modelling SGD." ICML 2024 Workshops: HiLD, 2024.
@inproceedings{li2024icmlw-hessianaware,
title = {{A Hessian-Aware Stochastic Differential Equation for Modelling SGD}},
author = {Li, Xiang and Shen, Zebang and Zhang, Liang and He, Niao},
booktitle = {ICML 2024 Workshops: HiLD},
year = {2024},
url = {https://mlanthology.org/icmlw/2024/li2024icmlw-hessianaware/}
}