Fast Convex Optimization for Two-Layer ReLU Networks: Equivalent Model Classes and Cone Decompositions

Abstract

We develop fast algorithms and robust software for convex optimization of two-layer neural networks with ReLU activation functions. Our work leverages a convex reformulation of the standard weight-decay penalized training problem as a set of group-ℓ1-regularized data-local models, where locality is enforced by polyhedral cone constraints. In the special case of zero regularization, we show that this problem is exactly equivalent to unconstrained optimization of a convex "gated ReLU" network. For problems with non-zero regularization, we show that convex gated ReLU models obtain data-dependent approximation bounds for the ReLU training problem. To optimize the convex reformulations, we develop an accelerated proximal gradient method and a practical augmented Lagrangian solver. We show that these approaches are faster than standard training heuristics for the non-convex problem, such as SGD, and outperform commercial interior-point solvers. Experimentally, we verify our theoretical results, explore the group-ℓ1 regularization path, and scale convex optimization for neural networks to image classification on MNIST and CIFAR-10.
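
For intuition, the following is a minimal Python/NumPy sketch of a (non-accelerated) proximal gradient solver for a convex gated ReLU model with a group-ℓ1 penalty. It is not the authors' released implementation: the function name relu_prox_gd, the random choice of gate vectors, the gate count, and the step-size rule are illustrative assumptions.

import numpy as np

def relu_prox_gd(X, y, n_gates=32, lam=1e-3, step=None, max_iter=500, seed=0):
    """Proximal gradient descent for a convex gated ReLU model (sketch).

    The model predicts sum_i D_i X w_i, where D_i = diag(1[X g_i >= 0]) for
    fixed gate vectors g_i, and training minimizes the convex objective
        0.5 * ||sum_i D_i X w_i - y||^2 + lam * sum_i ||w_i||_2,
    i.e. squared loss plus a group-l1 penalty with one group per gate.
    """
    rng = np.random.default_rng(seed)
    n, d = X.shape
    gates = rng.standard_normal((d, n_gates))       # illustrative: random gate vectors
    D = (X @ gates >= 0).astype(X.dtype)            # n x n_gates fixed activation patterns

    # Effective (fixed) feature map: column block i holds D_i X.
    feats = np.concatenate([D[:, [i]] * X for i in range(n_gates)], axis=1)

    if step is None:
        # 1 / Lipschitz constant of the smooth part: largest singular value squared.
        step = 1.0 / np.linalg.norm(feats, 2) ** 2

    W = np.zeros(d * n_gates)
    for _ in range(max_iter):
        grad = feats.T @ (feats @ W - y)            # gradient of the squared loss
        Z = W - step * grad                         # plain gradient step
        # Prox of the group-l1 penalty: block soft-thresholding, one block per gate.
        for i in range(n_gates):
            blk = Z[i * d:(i + 1) * d]
            norm = np.linalg.norm(blk)
            scale = max(0.0, 1.0 - step * lam / norm) if norm > 0 else 0.0
            W[i * d:(i + 1) * d] = scale * blk

    return W.reshape(n_gates, d), gates

Given a data matrix X of shape (n, d) and targets y of shape (n,), a call such as W, gates = relu_prox_gd(X, y, n_gates=16) returns one weight vector per gate. An accelerated variant would add a FISTA-style momentum step on top of this basic iteration; the cone-constrained ReLU problem treated in the paper additionally requires handling the polyhedral constraints, which this sketch omits.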

Cite

Text

Mishkin et al. "Fast Convex Optimization for Two-Layer ReLU Networks: Equivalent Model Classes and Cone Decompositions." International Conference on Machine Learning, 2022.

Markdown

[Mishkin et al. "Fast Convex Optimization for Two-Layer ReLU Networks: Equivalent Model Classes and Cone Decompositions." International Conference on Machine Learning, 2022.](https://mlanthology.org/icml/2022/mishkin2022icml-fast/)

BibTeX

@inproceedings{mishkin2022icml-fast,
  title     = {{Fast Convex Optimization for Two-Layer ReLU Networks: Equivalent Model Classes and Cone Decompositions}},
  author    = {Mishkin, Aaron and Sahiner, Arda and Pilanci, Mert},
  booktitle = {International Conference on Machine Learning},
  year      = {2022},
  pages     = {15770--15816},
  volume    = {162},
  url       = {https://mlanthology.org/icml/2022/mishkin2022icml-fast/}
}