When trying to compare probability distributions, common approaches (especially in machine learning) are cross entropy, KL divergence, the Kolmogorov-Smirnov statistic, and so on. However, if we think about a probability distribution, there is a very useful piece of **geometry** that none of these measures take into consideration: *the **domain** the distributions live on*. For instance, below we see our univariate distributions ($f, g, h$), and they all sit on top of the real number line, $\mathbb{R}$. It can prove very powerful to make use of that information rather than simply comparing the heights (function values) of our distributions. That is what the **Wasserstein distance** does. Our traditional measures (CE, KL div) capture *overlap* between distributions, not *displacement*.

![](Screen%20Shot%202022-11-17%20at%206.57.27%20AM.png)

![](Screen%20Shot%202022-11-17%20at%207.20.45%20AM.png)

A **transport plan** is really a joint probability distribution whose marginals are given by the two measures (red above). We can then define the **Wasserstein distance** as:

> The minimum, across all valid transport plans, of the cost of transport.

In the discrete case this reads

$$W(\mu, \nu) = \min_{\gamma \in \Gamma(\mu, \nu)} \sum_{x, y} \gamma(x, y) \, d(x, y)$$

where $\Gamma(\mu, \nu)$ is the set of all transport plans (joint distributions with marginals $\mu$ and $\nu$) and $d(x, y)$ is the ground distance between points of the domain.

### Probability as Geometry

![](Screen%20Shot%202022-11-17%20at%207.37.37%20AM.png)

![](Screen%20Shot%202022-11-17%20at%207.38.19%20AM.png)

![](Screen%20Shot%202022-11-17%20at%207.38.47%20AM.png)

![](Screen%20Shot%202022-11-17%20at%207.39.50%20AM.png)

TODO: add commonality described [here](https://youtu.be/MSbvkhAR0VY?t=723).

![](Screen%20Shot%202022-11-17%20at%207.45.54%20AM.png)

![](Screen%20Shot%202022-11-17%20at%207.48.23%20AM.png)

![](Screen%20Shot%202022-11-18%20at%207.42.13%20AM.png)

### In a context where the *true* distribution is a single point

![](Pasted%20image%2020221208072530.png)

## Python Implementation

Here is a working implementation that can be used as a loss when training a TensorFlow model. Note that it assumes `y_true` is one-hot: when the true distribution is a point mass, every valid transport plan must move all of `y_pred`'s mass to the true bin, so the Wasserstein distance reduces to the expected distance from each predicted bin to the true bin.

```python
import numpy as np
import tensorflow as tf
from functools import partial
from sklearn.metrics import pairwise_distances


def create_D(bins):
    """Create the (n_bins x n_bins) ground-distance matrix between bins."""
    # Currently using the bin index instead of the middle of the bin. Investigate.
    bin_mid = [i for i, x in enumerate(bins)]
    # bin_mid = [x.mid for x in bins]  # e.g. if bins are pandas Intervals
    bin_mid_array = np.array(bin_mid).reshape(-1, 1)
    D = pairwise_distances(bin_mid_array)
    D = tf.convert_to_tensor(D)
    return tf.cast(D, tf.float32)


def ws_loss(y_true, y_pred, D=None):
    """Wasserstein loss for a one-hot y_true (point mass on the true class)."""
    p = tf.cast(y_true, tf.float32)
    q = tf.cast(y_pred, tf.float32)
    # Index of the true class for each row of the batch.
    idx_of_true_class = tf.argmax(p, axis=1)
    # Row of D holding the distance from every class to the true class.
    distance_of_classes_to_true_class = tf.gather(
        D, indices=idx_of_true_class, axis=0
    )
    # Expected transport cost: predicted mass times distance to the true class.
    return tf.math.reduce_sum(
        q * distance_of_classes_to_true_class, axis=1
    )


D = create_D(bins)  # `bins` (the histogram bins) and `model` are defined elsewhere
wsl = partial(ws_loss, D=D)
model.compile(loss=wsl)
```

---
Date: 20221117
Links to:
Tags: #review
References:
* [Introduction to the Wasserstein distance - YouTube](https://www.youtube.com/watch?v=CDiol4LG2Ao)
* [Shape Analysis (Lecture 19): Optimal transport - YouTube](https://www.youtube.com/watch?v=MSbvkhAR0VY)
* [CS 182: Lecture 19: Part 3: GANs - YouTube](https://youtu.be/RdC4XeExDeY?t=505)
* [Approximating Wasserstein distances with PyTorch - Daniel Daza](https://dfdazac.github.io/sinkhorn.html)
* [Geometric Loss functions between sampled measures, images and volumes — GeomLoss](http://www.kernel-operations.io/geomloss/index.html)
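
As a sanity check on the loss above, here is a minimal sketch (not part of the original pipeline) comparing `ws_loss` against SciPy's 1-D `scipy.stats.wasserstein_distance`. Because `y_true` is a point mass, the optimal transport plan is forced, so the two should agree exactly. It assumes `create_D` and `ws_loss` from the block above are in scope; the bin count and prediction values are made up for illustration.

```python
import numpy as np
import tensorflow as tf
from scipy.stats import wasserstein_distance

# Made-up 5-bin example: y_true is one-hot on bin 1, y_pred spreads mass around.
n_bins = 5
positions = np.arange(n_bins, dtype=float)  # bin "midpoints" (indices, as in create_D)
y_true = np.array([[0.0, 1.0, 0.0, 0.0, 0.0]])
y_pred = np.array([[0.1, 0.2, 0.4, 0.2, 0.1]])

# SciPy's 1-D Wasserstein distance between the two weighted point sets.
ref = wasserstein_distance(
    positions, positions, u_weights=y_true[0], v_weights=y_pred[0]
)

# The custom loss, evaluated on the same inputs.
D = create_D(list(range(n_bins)))
ours = ws_loss(tf.constant(y_true), tf.constant(y_pred), D=D).numpy()[0]

print(ref, ours)  # both should be 1.2 for this example
```

If the two numbers disagree, the first thing to check is the ground-distance matrix `D`, i.e. whether the bins are represented by their indices or by their midpoints.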