A Trace-Restricted Kronecker-Factored Approximation to Natural Gradient

Abstract

Second-order optimization methods can accelerate convergence by modifying the gradient using the curvature matrix, and there have been many attempts to apply them to training deep neural networks. In this work, inspired by diagonal approximations and factored approximations such as Kronecker-factored Approximate Curvature (KFAC), we propose a new approximation to the Fisher information matrix (FIM) called Trace-restricted Kronecker-factored Approximate Curvature (TKFAC), which preserves a certain trace relationship between the exact and the approximate FIM. In TKFAC, we approximate each block of the FIM as a Kronecker product of two smaller matrices, scaled by a coefficient related to the trace. We theoretically analyze TKFAC's approximation error and give an upper bound on it. We also propose a new damping technique for TKFAC on convolutional neural networks to maintain the advantage of second-order optimization methods during training. Experiments show that our method outperforms several state-of-the-art algorithms on some deep network architectures.
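The idea of a trace-scaled Kronecker factorization can be illustrated with a small, dependency-free sketch. It is not taken from the paper: it assumes a single fully connected layer whose exact FIM block is the average of `kron(a a^T, g g^T)` over samples, normalizes both factors to unit trace, and rescales by the trace of the exact block. Because tr(A ⊗ B) = tr(A) tr(B), the result has the same trace as the exact block. All function and variable names (`trace_restricted_kron`, `acts`, `grads`) and the synthetic data are hypothetical, and the paper's actual construction may differ in detail.

```python
import random


def outer(u, v):
    """Outer product u v^T as a list of rows."""
    return [[ui * vj for vj in v] for ui in u]


def kron(A, B):
    """Kronecker product of two dense matrices given as lists of rows."""
    return [[a * b for a in row_a for b in row_b] for row_a in A for row_b in B]


def trace(M):
    return sum(M[i][i] for i in range(len(M)))


def scale(M, c):
    return [[c * x for x in row] for row in M]


def mean_matrices(mats):
    """Entrywise average of a list of equally sized matrices."""
    n = len(mats)
    rows, cols = len(mats[0]), len(mats[0][0])
    return [[sum(m[i][j] for m in mats) / n for j in range(cols)]
            for i in range(rows)]


def trace_restricted_kron(acts, grads):
    """Return (exact block, trace-scaled approximation, plain Kronecker approximation).

    Assumption: the exact block is E[kron(a a^T, g g^T)] for layer inputs a
    and back-propagated gradients g. This is an illustrative setup only.
    """
    A = mean_matrices([outer(a, a) for a in acts])    # input second moment
    G = mean_matrices([outer(g, g) for g in grads])   # gradient second moment
    F = mean_matrices([kron(outer(a, a), outer(g, g))
                       for a, g in zip(acts, grads)])  # exact block (small sizes only)

    phi = scale(A, 1.0 / trace(A))  # unit-trace factor
    psi = scale(G, 1.0 / trace(G))  # unit-trace factor
    delta = trace(F)                # trace-related coefficient

    return F, scale(kron(phi, psi), delta), kron(A, G)


# Synthetic data in which gradients depend on the inputs, so the plain
# Kronecker approximation generally does not reproduce the exact trace.
rng = random.Random(0)
acts = [[rng.gauss(0, 1) for _ in range(3)] for _ in range(50)]
grads = [[rng.gauss(0, 1) * (1 + abs(a[0])) for _ in range(2)] for a in acts]

F, F_tr, F_kfac = trace_restricted_kron(acts, grads)
print("trace exact:        ", trace(F))
print("trace trace-scaled: ", trace(F_tr))
print("trace plain Kron.:  ", trace(F_kfac))
```

Forming the exact block is only feasible here because the layer is tiny. In practice tr(F) would have to be estimated without materializing F, for example from per-sample squared norms, since tr(kron(a a^T, g g^T)) = ||a||² ||g||².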

Cite

Text

Gao et al. "A Trace-Restricted Kronecker-Factored Approximation to Natural Gradient." AAAI Conference on Artificial Intelligence, 2021. doi:10.1609/AAAI.V35I9.16921

Markdown

[Gao et al. "A Trace-Restricted Kronecker-Factored Approximation to Natural Gradient." AAAI Conference on Artificial Intelligence, 2021.](https://mlanthology.org/aaai/2021/gao2021aaai-trace/) doi:10.1609/AAAI.V35I9.16921

BibTeX

@inproceedings{gao2021aaai-trace,
  title     = {{A Trace-Restricted Kronecker-Factored Approximation to Natural Gradient}},
  author    = {Gao, Kai-Xin and Liu, Xiao-Lei and Huang, Zheng-Hai and Wang, Min and Wang, Zidong and Xu, Dachuan and Yu, Fan},
  booktitle = {AAAI Conference on Artificial Intelligence},
  year      = {2021},
  pages     = {7519--7527},
  doi       = {10.1609/AAAI.V35I9.16921},
  url       = {https://mlanthology.org/aaai/2021/gao2021aaai-trace/}
}