Fast Trainable Projection for Robust Fine-Tuning
Abstract
Robust fine-tuning aims to achieve competitive in-distribution (ID) performance while maintaining out-of-distribution (OOD) robustness when transferring a pre-trained model to a downstream task. Recently, projected gradient descent has been used successfully for robust fine-tuning by explicitly constraining, through projection, how far the fine-tuned model deviates from its initialization. However, two algorithmic limitations prevent this method from being adopted more widely: scalability and efficiency. In this paper, we propose a new projection-based fine-tuning algorithm, Fast Trainable Projection (FTP), which learns per-layer projection constraints computationally efficiently, resulting in an average 35% speedup on our benchmarks compared to prior works. FTP can be combined with existing optimizers such as AdamW and used in a plug-and-play fashion. Finally, we show that FTP is a special instance of hyper-optimizers, which tune the hyper-parameters of optimizers in a learnable manner through nested differentiation. Empirically, we show superior robustness on OOD datasets, including domain shifts and natural corruptions, across four different vision tasks with five different pre-trained models. Additionally, we demonstrate that FTP is broadly applicable and beneficial to other learning scenarios, such as low-label and continual learning settings, thanks to its easy adaptability. The code will be available at https://github.com/GT-RIPL/FTP.git.
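The projection idea underlying this line of work can be illustrated with a minimal sketch (assumptions, not the authors' implementation): after each optimizer step, project each parameter tensor back onto an L2 ball around its pre-trained initialization. The function name `project_to_init` and the fixed radius `gamma` are hypothetical; FTP itself learns the per-layer constraints rather than fixing them.

```python
import torch

@torch.no_grad()
def project_to_init(model, init_state, gamma=1.0):
    """Project each parameter onto {w : ||w - w0||_2 <= gamma}.

    Illustrative sketch only: `gamma` is a fixed hypothetical radius here,
    whereas FTP learns per-layer projection constraints during training.
    """
    for name, p in model.named_parameters():
        w0 = init_state[name]          # pre-trained weights for this layer
        delta = p - w0                 # deviation from initialization
        norm = delta.norm()
        if norm > gamma:               # outside the ball: scale back to the boundary
            p.copy_(w0 + delta * (gamma / norm))

# Plug-and-play usage with an existing optimizer such as AdamW
# (`compute_loss` and `loader` are placeholders):
#
# init_state = {n: p.detach().clone() for n, p in model.named_parameters()}
# optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# for batch in loader:
#     loss = compute_loss(model, batch)
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()
#     project_to_init(model, init_state, gamma=1.0)
```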
Cite
Text
Tian et al. "Fast Trainable Projection for Robust Fine-Tuning." Neural Information Processing Systems, 2023.
Markdown
[Tian et al. "Fast Trainable Projection for Robust Fine-Tuning." Neural Information Processing Systems, 2023.](https://mlanthology.org/neurips/2023/tian2023neurips-fast/)
BibTeX
@inproceedings{tian2023neurips-fast,
title = {{Fast Trainable Projection for Robust Fine-Tuning}},
author = {Tian, Junjiao and Liu, Yen-Cheng and Smith, James S and Kira, Zsolt},
booktitle = {Neural Information Processing Systems},
year = {2023},
url = {https://mlanthology.org/neurips/2023/tian2023neurips-fast/}
}