Weighted Neural Tangent Kernel: A Generalized and Improved Network-Induced Kernel

Abstract

The neural tangent kernel (NTK) has recently attracted intense study, as it describes the evolution of an over-parameterized neural network (NN) trained by gradient descent. However, it is now well-known that gradient descent is not always a good optimizer for NNs, which can partially explain the unsatisfactory practical performance of the NTK regression estimator. In this paper, we introduce the weighted neural tangent kernel (WNTK), a generalized and improved tool that captures an over-parameterized NN's training dynamics under adjusted gradient descent directions. Theoretically, in the infinite-width limit, we prove: (1) the stability of the WNTK at initialization and during training, and (2) the equivalence between the WNTK regression estimator and the corresponding NN estimator trained with different learning rates on different parameters. With the proposed weight update algorithm, the weight terms, or equivalently the NN's descent directions, can be trained through multiple-kernel optimization. Both empirical and analytical WNTKs outperform the corresponding NTKs in numerical experiments, consistent with the fact that adjusted gradient descent can outperform plain gradient descent in NN training.
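To make the construction concrete, the following is a minimal sketch (not the authors' implementation) of how an empirical weighted NTK could be computed for a toy scalar-output network: each parameter group contributes the inner product of its gradients at two inputs, scaled by a per-group weight, and setting all weights to one recovers the ordinary empirical NTK. The toy architecture, the parameter grouping, and the weight values here are illustrative assumptions; in the paper the weights would instead be learned via the proposed multiple-kernel optimization.

```python
# Minimal sketch of an empirical weighted NTK (hypothetical toy setup, not the paper's code).
import jax
import jax.numpy as jnp


def init_params(key, d_in=5, d_hidden=64):
    # Two-layer toy network with NTK-style scaling.
    k1, k2 = jax.random.split(key)
    return {
        "W1": jax.random.normal(k1, (d_hidden, d_in)) / jnp.sqrt(d_in),
        "W2": jax.random.normal(k2, (d_hidden,)) / jnp.sqrt(d_hidden),
    }


def f(params, x):
    # Scalar-output NN: f(x) = W2 . relu(W1 x)
    return jnp.dot(params["W2"], jax.nn.relu(params["W1"] @ x))


def weighted_ntk(params, w, x1, x2):
    # K_w(x1, x2) = sum_p w_p * <df(x1)/dtheta_p, df(x2)/dtheta_p>,
    # a per-parameter-group weighted sum of gradient inner products.
    g1 = jax.grad(f)(params, x1)
    g2 = jax.grad(f)(params, x2)
    return sum(w[name] * jnp.vdot(g1[name], g2[name]) for name in params)


key = jax.random.PRNGKey(0)
params = init_params(key)
x1, x2 = jax.random.normal(key, (2, 5))
w = {"W1": 1.0, "W2": 1.0}  # uniform weights recover the standard empirical NTK
print(weighted_ntk(params, w, x1, x2))
```

In this sketch the weights are fixed per layer; choosing them adaptively (as the paper does via its weight update algorithm) corresponds to running gradient descent with different learning rates on different parameter groups.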

Cite

Text

Tan et al. "Weighted Neural Tangent Kernel: A Generalized and Improved Network-Induced Kernel." Machine Learning, 2023. doi:10.1007/s10994-023-06356-3

Markdown

[Tan et al. "Weighted Neural Tangent Kernel: A Generalized and Improved Network-Induced Kernel." Machine Learning, 2023.](https://mlanthology.org/mlj/2023/tan2023mlj-weighted/) doi:10.1007/s10994-023-06356-3

BibTeX

@article{tan2023mlj-weighted,
  title     = {{Weighted Neural Tangent Kernel: A Generalized and Improved Network-Induced Kernel}},
  author    = {Tan, Lei and Wu, Shutong and Zhou, Wenxing and Huang, Xiaolin},
  journal   = {Machine Learning},
  year      = {2023},
  volume    = {112},
  pages     = {2871--2901},
  doi       = {10.1007/s10994-023-06356-3},
  url       = {https://mlanthology.org/mlj/2023/tan2023mlj-weighted/}
}