Squared Wasserstein-2 Loss Functions for Efficient Learning of Stochastic Differential Equations
Abstract
We provide an analysis of the squared Wasserstein-2 ( $W_2$ ) distance between two probability distributions associated with two stochastic differential equations (SDEs). Based on this analysis, we propose squared $W_2$ distance-based loss functions for training parametrized neural networks to reconstruct SDEs from noisy data. Specifically, we propose minimizing a time-decoupled squared $W_2$ distance loss function. Numerical experiments demonstrate the efficiency of our Wasserstein distance-based loss functions in learning SDEs that arise across a number of applications.
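As a rough illustration of the kind of loss the abstract describes, the sketch below computes a time-decoupled squared $W_2$ loss between two ensembles of 1-D sample paths: at each time point, the squared $W_2$ distance between the empirical marginals is computed (for equal-sized 1-D samples this reduces to the mean squared difference of sorted values), and the per-time terms are summed. This is a minimal assumption-based sketch, not the authors' exact formulation; the function names and array shapes are illustrative.

```python
import numpy as np

def squared_w2_1d(x, y):
    # For 1-D empirical distributions with equal sample counts, the squared
    # Wasserstein-2 distance equals the mean squared difference of the
    # order statistics (sorted samples).
    return np.mean((np.sort(x) - np.sort(y)) ** 2)

def time_decoupled_w2_loss(paths_a, paths_b):
    # paths_*: arrays of shape (n_samples, n_timesteps). The loss sums the
    # squared W2 distance between the marginal distributions at each time
    # point, decoupling time as described in the abstract.
    return sum(squared_w2_1d(paths_a[:, t], paths_b[:, t])
               for t in range(paths_a.shape[1]))

# Sanity check on synthetic data (illustrative, not from the paper):
rng = np.random.default_rng(0)
paths = rng.normal(size=(200, 20))        # 200 sample paths, 20 time points
zero_loss = time_decoupled_w2_loss(paths, paths)
shift_loss = time_decoupled_w2_loss(paths, paths + 1.0)
```

Shifting every path by a constant shifts each marginal by that constant, so each per-time term equals the squared shift and the total loss is (number of time points) x (shift squared).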
Cite
Text
Xia et al. "Squared Wasserstein-2 Loss Functions for Efficient Learning of Stochastic Differential Equations." Machine Learning, 2025. doi:10.1007/s10994-025-06908-9
Markdown
[Xia et al. "Squared Wasserstein-2 Loss Functions for Efficient Learning of Stochastic Differential Equations." Machine Learning, 2025.](https://mlanthology.org/mlj/2025/xia2025mlj-squared/) doi:10.1007/s10994-025-06908-9
BibTeX
@article{xia2025mlj-squared,
title = {{Squared Wasserstein-2 Loss Functions for Efficient Learning of Stochastic Differential Equations}},
author = {Xia, Mingtao and Li, Xiangting and Shen, Qijing and Chou, Tom},
journal = {Machine Learning},
year = {2025},
pages = {255},
doi = {10.1007/s10994-025-06908-9},
volume = {114},
url = {https://mlanthology.org/mlj/2025/xia2025mlj-squared/}
}