Squared Wasserstein-2 Loss Functions for Efficient Learning of Stochastic Differential Equations

Abstract

We provide an analysis of the squared Wasserstein-2 ( $W_2$ ) distance between two probability distributions associated with two stochastic differential equations (SDEs). Based on this analysis, we propose using squared $W_2$ distance-based loss functions to train parametrized neural networks in order to reconstruct SDEs from noisy data. Specifically, we propose minimizing a time-decoupled squared $W_2$ distance loss function. To demonstrate the practicality of our Wasserstein distance-based loss functions, we perform numerical experiments showing the efficiency of our method in learning SDEs that arise across a number of applications.
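As a rough illustration of the time-decoupled idea described above (not the authors' implementation; the function names and the summation over time points are our assumptions), note that for one-dimensional marginals with equal sample counts, the squared $W_2$ distance between two empirical distributions reduces to the mean squared difference between sorted samples. A time-decoupled loss can then sum this quantity over the marginal distributions at each observation time:

```python
import numpy as np

def w2_squared_1d(x, y):
    """Squared Wasserstein-2 distance between two 1D empirical
    distributions with the same number of samples: the mean squared
    difference between their order statistics (sorted samples)."""
    return np.mean((np.sort(x) - np.sort(y)) ** 2)

def time_decoupled_w2_loss(traj_a, traj_b):
    """Hypothetical time-decoupled loss: sum the squared W2 distance
    between the marginal distributions at each time point.

    traj_a, traj_b: arrays of shape (n_samples, n_timepoints),
    e.g. observed SDE trajectories vs. trajectories simulated from
    a parametrized model."""
    return sum(w2_squared_1d(traj_a[:, t], traj_b[:, t])
               for t in range(traj_a.shape[1]))
```

In this sketch the loss compares only the per-time marginals, so it decouples across time points, which is what makes it cheap to evaluate; matching the full joint law of the paths would require comparing distributions over trajectories.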

Cite

Text

Xia et al. "Squared Wasserstein-2 Loss Functions for Efficient Learning of Stochastic Differential Equations." Machine Learning, 2025. doi:10.1007/s10994-025-06908-9

Markdown

[Xia et al. "Squared Wasserstein-2 Loss Functions for Efficient Learning of Stochastic Differential Equations." Machine Learning, 2025.](https://mlanthology.org/mlj/2025/xia2025mlj-squared/) doi:10.1007/s10994-025-06908-9

BibTeX

@article{xia2025mlj-squared,
  title     = {{Squared Wasserstein-2 Loss Functions for Efficient Learning of Stochastic Differential Equations}},
  author    = {Xia, Mingtao and Li, Xiangting and Shen, Qijing and Chou, Tom},
  journal   = {Machine Learning},
  year      = {2025},
  pages     = {255},
  doi       = {10.1007/s10994-025-06908-9},
  volume    = {114},
  url       = {https://mlanthology.org/mlj/2025/xia2025mlj-squared/}
}