Learning with a Wasserstein Loss

Charlie Frogner (1), Chiyuan Zhang (1), Hossein Mobahi (2), Mauricio Araya-Polo (3), Tomaso Poggio (1)
(1) Center for Brains, Minds & Machines
McGovern Institute for Brain Research
Massachusetts Institute of Technology
(2) Computer Science & Artificial Intelligence Lab
Massachusetts Institute of Technology
(3) Shell Int'l Exploration & Production
Authors contributed equally.

Update: A Caffe implementation is now available, thanks to Prafulla Dhariwal and Jeevana Inala!

In the news

More-flexible machine learning - MIT News


Learning to predict multi-label outputs is challenging, but in many problems there is a natural metric on the outputs that can be used to improve predictions. In this paper we develop a loss function for multi-label learning, based on the Wasserstein distance. The Wasserstein distance provides a natural notion of dissimilarity for probability measures. Although optimizing with respect to the exact Wasserstein distance is costly, recent work has described a regularized approximation that is efficiently computed. We describe an efficient learning algorithm based on this regularization, as well as a novel extension of the Wasserstein distance from probability measures to unnormalized measures. We also describe a statistical learning bound for the loss. The Wasserstein loss can encourage smoothness of the predictions with respect to a chosen metric on the output space. We demonstrate this property on a real-data tag prediction problem, using the Yahoo Flickr Creative Commons dataset, outperforming a baseline that doesn't use the metric.
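In its usual formulation, the efficiently computed regularized approximation is the entropy-regularized (Sinkhorn) transport distance. Below is a minimal pure-Python sketch of the normalized case. It assumes that the prediction h and the target y are histograms summing to one, that the ground-cost matrix M comes from the chosen metric on the output labels, and that the kernel is the common Gibbs kernel exp(-λM). The function names, λ = 10, and the fixed iteration count are illustrative choices. They are not taken from the paper or from the Mocha.jl or Caffe layers, whose exact constants and stopping rules may differ. The code alternates the two Sinkhorn scalings u and v. It then returns the transport cost ⟨T, M⟩ and the centered dual potential log(u)/λ. Up to an additive constant, that potential is the gradient of the regularized objective ⟨T, M⟩ - H(T)/λ with respect to h.

```python
import math


def sinkhorn_wasserstein(h, y, M, lam=10.0, n_iter=200):
    """Entropy-regularized Wasserstein loss between histograms h and y (sketch).

    Assumes h and y each sum to one and h is strictly positive (e.g. a
    softmax output). M[i][j] is the ground cost between labels i and j.
    Returns (cost, grad): cost is <T, M> for the regularized transport
    plan T = diag(u) K diag(v); grad is the dual potential log(u) / lam,
    centered to sum to zero, i.e. the gradient of the regularized
    objective with respect to h up to an additive constant.
    """
    n, m = len(h), len(y)
    # Gibbs kernel K_ij = exp(-lam * M_ij). For large lam * M this underflows;
    # a log-domain implementation would be needed in that regime.
    K = [[math.exp(-lam * M[i][j]) for j in range(m)] for i in range(n)]
    u = [1.0] * n
    v = [1.0] * m
    for _ in range(n_iter):
        # Scale columns so T's column sums match y, then rows to match h.
        v = [y[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
        u = [h[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
    T = [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]
    cost = sum(T[i][j] * M[i][j] for i in range(n) for j in range(m))
    alpha = [math.log(ui) / lam for ui in u]
    mean = sum(alpha) / n
    return cost, [a - mean for a in alpha]


def softmax(z):
    zmax = max(z)  # subtract the max for numerical stability
    e = [math.exp(zi - zmax) for zi in z]
    s = sum(e)
    return [ei / s for ei in e]


def grad_wrt_logits(h, g):
    """Chain rule through h = softmax(z): dL/dz_k = h_k * (g_k - sum_j h_j g_j)."""
    dot = sum(hj * gj for hj, gj in zip(h, g))
    return [hk * (gk - dot) for hk, gk in zip(h, g)]


if __name__ == "__main__":
    # Three labels on a line, so the ground cost is |i - j|.
    M = [[abs(i - j) for j in range(3)] for i in range(3)]
    y = [0.0, 0.0, 1.0]  # target: all mass on label 2
    h = softmax([0.0, 0.5, 2.0])
    cost, g = sinkhorn_wasserstein(h, y, M)
    print("loss:", cost)
    print("d loss / d logits:", grad_wrt_logits(h, g))
```

The softmax Jacobian maps constant vectors to zero, so the additive constant left undetermined in the dual potential has no effect on the gradient with respect to the logits. This sketch covers only the normalized case. It does not implement the paper's extension to unnormalized measures.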


Charlie Frogner, Chiyuan Zhang, Hossein Mobahi, Mauricio Araya-Polo, Tomaso Poggio. Learning with a Wasserstein Loss. In Advances in Neural Information Processing Systems (NIPS) 28 (2015).

Code and data

Flickr tag prediction dataset
Bart Thomee, David A. Shamma, Gerald Friedland, Benjamin Elizalde, Karl Ni, Douglas Poland, Damian Borth, Li-Jia Li. The new data and new challenges in multimedia research. (arXiv:1503.01817)
Ken Chatfield, Karen Simonyan, Andrea Vedaldi, Andrew Zisserman. Return of the Devil in the Details: Delving Deep into Convolutional Nets. British Machine Vision Conference (2014).

Wasserstein loss layer
A normalized Wasserstein loss layer has been added to Mocha.jl.
A Caffe implementation is also available. (Thanks to Prafulla Dhariwal and Jeevana Inala!)