A Trace-Restricted Kronecker-Factored Approximation to Natural Gradient
Abstract
Second-order optimization methods can accelerate convergence by preconditioning the gradient with the curvature matrix, and there have been many attempts to use them for training deep neural networks. In this work, inspired by diagonal approximations and by factored approximations such as Kronecker-factored Approximate Curvature (KFAC), we propose a new approximation to the Fisher information matrix (FIM), called Trace-restricted Kronecker-factored Approximate Curvature (TKFAC), which preserves a certain trace relationship between the exact and the approximate FIM. In TKFAC, each block of the approximate FIM is decomposed as a Kronecker product of two smaller matrices, scaled by a trace-related coefficient. We analyze TKFAC's approximation error theoretically and give an upper bound on it. We also propose a new damping technique for TKFAC on convolutional neural networks that preserves the advantages of second-order optimization throughout training. Experiments show that our method outperforms several state-of-the-art algorithms on some deep network architectures.
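The NumPy sketch below illustrates the idea of a trace-restricted Kronecker factorization for one fully connected layer. It is not the authors' implementation and rests on these assumptions:

- The factors A = E[a aᵀ] and G = E[g gᵀ] are the usual KFAC factors.
- The coefficient is chosen as δ = tr(F) / (tr(A) tr(G)). Because tr(A ⊗ G) = tr(A) tr(G), this makes tr(δ A ⊗ G) = tr(F).
- The per-factor damping `damping * I` is a simple hypothetical stand-in. It is not the damping technique proposed in the paper.
- All function names are hypothetical.

The sketch uses the identity (δ A ⊗ G)⁻¹ vec(∇W) = vec(G⁻¹ ∇W A⁻¹) / δ. This identity avoids ever forming the full FIM block.

```python
import numpy as np

def tkfac_factors(acts, grads):
    """Hypothetical helper: acts is (n, d_in) layer inputs, grads is (n, d_out)
    back-propagated output gradients. Returns (delta, A, G) with F ~ delta * kron(A, G)."""
    n = acts.shape[0]
    A = acts.T @ acts / n    # E[a a^T], shape (d_in, d_in)
    G = grads.T @ grads / n  # E[g g^T], shape (d_out, d_out)
    # tr(F) = E[tr(a a^T) * tr(g g^T)] = E[||a||^2 * ||g||^2]
    trace_F = np.mean(np.sum(acts**2, axis=1) * np.sum(grads**2, axis=1))
    # Assumed trace-restricting coefficient: tr(delta * A (x) G) equals tr(F)
    delta = trace_F / (np.trace(A) * np.trace(G))
    return delta, A, G

def exact_fisher_block(acts, grads):
    """Dense FIM block for reference only (small layers).
    The per-sample column-major vec(g a^T) equals kron(a, g)."""
    vecs = np.stack([np.kron(a, g) for a, g in zip(acts, grads)])
    return vecs.T @ vecs / acts.shape[0]

def tkfac_precondition(dW, delta, A, G, damping=1e-3):
    """Apply (delta * (A + damping I) (x) (G + damping I))^{-1} to vec(dW),
    where dW has shape (d_out, d_in). The damping here is a simple placeholder."""
    A_d = A + damping * np.eye(A.shape[0])
    G_d = G + damping * np.eye(G.shape[0])
    left = np.linalg.solve(G_d, dW)        # G_d^{-1} dW
    right = np.linalg.solve(A_d, left.T).T  # (G_d^{-1} dW) A_d^{-1}, since A_d is symmetric
    return right / delta

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    acts = rng.standard_normal((64, 5))
    grads = rng.standard_normal((64, 3))
    delta, A, G = tkfac_factors(acts, grads)
    F = exact_fisher_block(acts, grads)
    # The approximation keeps the trace of the exact block
    print(np.isclose(np.trace(F), delta * np.trace(np.kron(A, G))))  # prints True
```

Plain KFAC would use δ = 1. Rescaling by δ costs only one extra scalar per layer. It leaves the Kronecker structure, and therefore the cheap inversion, unchanged.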
Cite
Gao et al. "A Trace-Restricted Kronecker-Factored Approximation to Natural Gradient." AAAI Conference on Artificial Intelligence, 2021. doi:10.1609/AAAI.V35I9.16921
@inproceedings{gao2021aaai-trace,
title = {{A Trace-Restricted Kronecker-Factored Approximation to Natural Gradient}},
author = {Gao, Kai-Xin and Liu, Xiao-Lei and Huang, Zheng-Hai and Wang, Min and Wang, Zidong and Xu, Dachuan and Yu, Fan},
booktitle = {AAAI Conference on Artificial Intelligence},
year = {2021},
pages = {7519-7527},
doi = {10.1609/AAAI.V35I9.16921},
url = {https://mlanthology.org/aaai/2021/gao2021aaai-trace/}
}