Jensen-Shannon Divergence

An important way of comparing two probability distributions $P$ and $Q$ is the Kullback-Leibler divergence: $$ D(P||Q) = \int p(x) \log {p(x) \over q(x)} \; dx.$$ It has the nice properties that it’s non-negative, and equals zero iff $P = Q$.

One drawback of this quantity is that it’s not symmetric. We can see it as comparing a candidate distribution $Q$ to a true distribution $P$. What marks $P$ as the true distribution is that we’re taking expectations relative to it, implying that it determines how $x$ appears in the world. In fact, one way to think of the KL divergence is as the expected number of additional bits we’d need to code symbols that actually came from a distribution $P$, but that we encoded thinking they came from the distribution $Q$. Sometimes, our problem may not justify this asymmetry between the distributions, and we might prefer a symmetric measure.

A simple fix is the symmetrized quantity $$ {1 \over 2} (D(P||Q) + D(Q||P)).$$ However, this does not address a second, more serious problem: the KL divergence blows up if there are values of $x$ for which $q(x) = 0 \neq p(x)$. This blow-up is signalling that the distributions are very different, because there are values of $x$ which can occur under the true distribution but which the candidate distribution thinks should never occur. We might want to signal this difference in a less drastic way.
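To make the asymmetry and the blow-up concrete, here is a minimal sketch for discrete distributions stored as NumPy arrays. The function name `kl_divergence` and the example arrays are illustrative assumptions, not part of the original post, and sums replace the integral.

```python
import numpy as np

def kl_divergence(p, q):
    """D(P||Q) in nats for discrete distributions given as arrays (illustrative helper)."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    support = p > 0  # terms with p(x) = 0 contribute 0 by convention
    if np.any(q[support] == 0):
        return np.inf  # q(x) = 0 where p(x) > 0: the divergence blows up
    return float(np.sum(p[support] * np.log(p[support] / q[support])))

# Asymmetry: swapping the arguments changes the value.
p = np.array([0.5, 0.5])
q = np.array([0.9, 0.1])
kl_pq = kl_divergence(p, q)
kl_qp = kl_divergence(q, p)

# Blow-up: the candidate assigns zero probability to a value the true distribution can produce.
p_true = np.array([0.5, 0.5, 0.0])
q_cand = np.array([0.9, 0.0, 0.1])
kl_blowup = kl_divergence(p_true, q_cand)

# The symmetrized quantity inherits the blow-up.
symmetrized = 0.5 * (kl_divergence(p_true, q_cand) + kl_divergence(q_cand, p_true))
```

In this example both directions of the symmetrized quantity are infinite, since each distribution puts mass where the other has none.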

To fix this more serious problem, we need to guarantee that the support of the candidate distribution contains that of the true distribution. A natural way to guarantee this, without bringing in any third distribution, is to mix the true distribution into the candidate distribution. This ensures that the denominator of the logarithm always contains a $p(x)$, and so will never be 0 if the numerator isn’t. The most natural way to mix the distributions is by equal parts, giving us the mixture distribution $$ M \triangleq {1 \over 2} (P + Q).$$

We now have three distributions: the two original ones, $P$ and $Q$, and their mixture, $M$. The only way, up to constant factors, to form a symmetric quantity out of these by adding KL divergences, that also doesn’t blow up, is to compare the third distribution to the first two in equal parts. This gives us the Jensen-Shannon divergence: $$D_{JS} \triangleq {1 \over 2} \left(D(P||M) + D(Q||M)\right).$$

It retains some of the important properties of the KL-divergence: it’s non-negative, since its components are. It equals 0 when the two original distributions are equal, since in that case $P = Q = M$. And when the two distributions are maximally different, having disjoint supports, each KL term equals $\log 2$ (because $M$ has density ${1 \over 2} p(x)$ wherever $p(x) > 0$, and likewise for $Q$), so each weighted component contributes ${1 \over 2} \log 2$, yielding an overall divergence of $\log 2$, or 1 bit if we work in base-2.
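These properties are easy to check numerically. The following is a minimal sketch for discrete distributions; `js_divergence` and the example arrays are hypothetical names chosen for illustration.

```python
import numpy as np

def js_divergence(p, q):
    """Jensen-Shannon divergence in nats between two discrete distributions."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    m = 0.5 * (p + q)  # the equal-parts mixture M

    def kl(a, b):
        mask = a > 0  # b = m is strictly positive wherever a is, so no division by zero
        return np.sum(a[mask] * np.log(a[mask] / b[mask]))

    return float(0.5 * kl(p, m) + 0.5 * kl(q, m))

p = np.array([0.5, 0.5, 0.0, 0.0])
q = np.array([0.0, 0.0, 0.3, 0.7])  # support disjoint from p
r = np.array([0.4, 0.3, 0.2, 0.1])  # overlaps with p

js_same = js_divergence(p, p)       # identical distributions
js_disjoint = js_divergence(p, q)   # maximally different distributions
js_overlap = js_divergence(p, r)    # somewhere in between
```

Unlike the KL divergence, the value stays finite even though `p` and `q` have disjoint supports.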

Note that I haven’t shown that the divergence increases monotonically from 0 to $\log 2$ as the distributions diverge, though, intuitively, that should be the case.
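One restricted case can at least be checked numerically. The sketch below moves $Q$ away from $P$ along a straight line, $Q_t = (1 - t) P + t R$, where $R$ is a hypothetical distribution with support disjoint from $P$, and records the JS divergence at each step. This is an illustration along one particular path under those assumptions, not a proof of the general statement.

```python
import numpy as np

def js_divergence(p, q):
    """Jensen-Shannon divergence in nats between two discrete distributions."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    m = 0.5 * (p + q)

    def kl(a, b):
        mask = a > 0
        return np.sum(a[mask] * np.log(a[mask] / b[mask]))

    return float(0.5 * kl(p, m) + 0.5 * kl(q, m))

p = np.array([0.6, 0.4, 0.0, 0.0])
r = np.array([0.0, 0.0, 0.5, 0.5])  # hypothetical endpoint, disjoint from p

# Interpolate from Q = P (t = 0) to Q = R (t = 1) and track the divergence.
ts = np.linspace(0.0, 1.0, 11)
js_path = np.array([js_divergence(p, (1 - t) * p + t * r) for t in ts])

for t, value in zip(ts, js_path):
    print(f"t = {t:.1f}  JS = {value:.4f}")
```

Along this path the divergence starts at 0, ends at $\log 2$, and does not decrease in between.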

The maximum value of 1 bit hints at an interesting interpretation of the JS-divergence. Consider how we would generate samples from the mixture distribution $M$: we can first flip an unbiased coin, and then sample from $P$ or $Q$ based on whether the coin landed heads or tails. We can then ask whether we can tell how the coin landed from the sample we drew. If the distributions are very different, then we should be able to do this. If they’re identical, then the sample won’t tell us anything about the coin-toss.

We can measure this informativeness as the mutual information between our samples, $X$, and the results of the coin-toss, $Z$: $$I(X;Z) = H(Z) - H(Z|X) \le H(Z) = \log 2.$$ Computing this mutual information the other way yields the JS-divergence. Writing $p(x)$ and $q(x)$ for the densities of $P$ and $Q$, so that $X$ has the mixture density ${1 \over 2}(p(x) + q(x))$: \begin{align} I(X;Z) &= H(X) - H(X|Z)\\ H(X) &= -\int {1 \over 2} (p(x) + q(x)) \log \left({1\over 2} (p(x) + q(x)) \right) \; dx \\ -H(X|Z) &= \sum_{z \in \{0,1\}} \int p(x|z) p(z) \log p(x|z) \; dx\\ &= {1 \over 2} \int p(x) \log p(x) \; dx + {1 \over 2} \int q(x) \log q(x) \; dx \\ H(X) - H(X|Z) &= {1 \over 2} \int p(x) \log \left( { p(x) \over {1 \over 2} \left(p(x) + q(x)\right) } \right) \; dx\\ &+ {1 \over 2} \int q(x) \log \left( { q(x) \over {1 \over 2} \left(p(x) + q(x)\right) } \right) \; dx\\ &= {1 \over 2} \left(D_{KL}(P || M) + D_{KL}(Q || M) \right). \end{align}
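The identity can be verified numerically for discrete distributions. The sketch below builds the joint table of the coin $Z$ and the sample $X$, computes the mutual information directly from it, and compares the result with the JS divergence. The helper names and example arrays are assumptions made for illustration.

```python
import numpy as np

def js_divergence(p, q):
    """Jensen-Shannon divergence in nats between two discrete distributions."""
    m = 0.5 * (p + q)

    def kl(a, b):
        mask = a > 0
        return np.sum(a[mask] * np.log(a[mask] / b[mask]))

    return float(0.5 * kl(p, m) + 0.5 * kl(q, m))

def mutual_information(joint):
    """I(X;Z) in nats from a joint table with rows indexed by z and columns by x."""
    pz = joint.sum(axis=1, keepdims=True)  # marginal of the coin
    px = joint.sum(axis=0, keepdims=True)  # marginal of the sample (the mixture M)
    mask = joint > 0
    return float(np.sum(joint[mask] * np.log(joint[mask] / (pz * px)[mask])))

p = np.array([0.7, 0.2, 0.1, 0.0])
q = np.array([0.1, 0.3, 0.2, 0.4])

# Fair coin: row 0 is "sample from P", row 1 is "sample from Q".
joint = 0.5 * np.stack([p, q])

mi = mutual_information(joint)
js = js_divergence(p, q)
```

The two numbers agree, and both are bounded above by $H(Z) = \log 2$.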

So the JS-divergence measures how different two distributions are by whether we can distinguish them in a mixture.

Connection to Binary Classification

In fact, we can relate the JS-divergence to the accuracy of a binary classifier trained to discriminate the two distributions. This is shown in detail in e.g. “Generative Adversarial Networks” by Goodfellow et al. I’ll sketch the basic idea here.

Let our binary classifier $D(x)$ output the probability that the input $x$ was sampled from $P$, and let the binary variable $z_i$ indicate whether $x_i$ truly was drawn from $P$ ($z_i = 1$) or from $Q$ ($z_i = 0$). Then to train our classifier, we compute the likelihood of the (independent) labels given the observations as $$ p(z_1,\dots,z_N | x_1,\dots,x_N, D) = \prod_i p(z_i|x_i, D) = \prod_i D(x_i)^{z_i} (1 - D(x_i))^{1 - z_i}.$$ Taking negative logarithms gives the cross-entropy loss \begin{align} \mathcal{L}_{CE} &= -\sum_i \left[ z_i \log D(x_i) + (1 - z_i) \log (1 - D(x_i)) \right]\\ &\to -{1 \over 2} \sum_x \left[ p(x) \log D(x) + q(x) \log(1 - D(x)) \right],\end{align} where in the second line we’ve normalized by $N$ and taken the limit of infinitely many data points, assuming each sample is equally likely to come from $P$ or from $Q$.
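The limit can be illustrated with a small simulation. In the sketch below, the distributions, the fixed classifier outputs `d`, and the sample size are all arbitrary choices for illustration; labels are drawn from a fair coin, matching the equal-parts mixture.

```python
import numpy as np

rng = np.random.default_rng(0)

p = np.array([0.6, 0.3, 0.1])
q = np.array([0.2, 0.3, 0.5])
d = np.array([0.7, 0.5, 0.2])  # an arbitrary classifier: d[x] is its output for value x

n = 200_000
z = rng.integers(0, 2, size=n)          # z = 1 means the sample comes from P
x_from_p = rng.choice(3, size=n, p=p)
x_from_q = rng.choice(3, size=n, p=q)
x = np.where(z == 1, x_from_p, x_from_q)

# Cross-entropy loss normalized by the number of samples.
empirical_loss = -np.mean(z * np.log(d[x]) + (1 - z) * np.log(1 - d[x]))

# The infinite-data expression from the second line of the derivation.
limit_loss = -0.5 * np.sum(p * np.log(d) + q * np.log(1 - d))

print(empirical_loss, limit_loss)
```

The two values should agree up to sampling noise, which shrinks as `n` grows.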

To determine the optimal value our discriminator should output, we minimize this loss with respect to $D(x)$ at each $x$. The derivative is $$ {\partial \mathcal{L}_{CE} \over \partial D(x)} = -{1 \over 2} {p(x) \over D(x)} + {1 \over 2} {q(x) \over 1 - D(x)}.$$ Setting this to zero gives $p(x) (1 - D(x)) = q(x) D(x)$, so $$ D(x) = {p(x) \over p(x) + q(x)}.$$

Plugging this back into our expression for the loss, we get

\begin{align} \mathcal{L}_{CE} &= -{1 \over 2} \sum_x \left[ p(x) \log {p(x) \over p(x) +q(x)} + q(x) \log {q(x) \over p(x) +q(x)} \right]\\&= -{1 \over 2} \sum_x \left[ p(x) \log {p(x) \over 2 \cdot {1 \over 2}(p(x) +q(x))} + q(x) \log {q(x) \over 2\cdot {1 \over 2}(p(x) +q(x))} \right] \\&=-{1 \over 2} \sum_x \left[ p(x) \log {p(x) \over {1 \over 2}(p(x) +q(x))} + q(x) \log {q(x) \over {1 \over 2}(p(x) +q(x))} \right] + {1 \over 2} \sum_x \left[ p(x) \log 2 + q(x) \log 2 \right]\\ &= -{1 \over 2} D_{KL}(P||M) -{1 \over 2} D_{KL}(Q||M) + \log(2) \\ &=\log(2) -D_{JS}(P||Q).\end{align}
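As a numerical sanity check of this result, the sketch below evaluates the infinite-data loss at the optimal discriminator and compares it with $\log 2 - D_{JS}(P||Q)$. It also perturbs the discriminator slightly to confirm that the loss does not go down. The function names and example distributions are illustrative assumptions, and both distributions are taken to be strictly positive so that every logarithm is finite.

```python
import numpy as np

def js_divergence(p, q):
    """Jensen-Shannon divergence in nats (assumes strictly positive p and q)."""
    m = 0.5 * (p + q)
    return float(0.5 * np.sum(p * np.log(p / m)) + 0.5 * np.sum(q * np.log(q / m)))

def cross_entropy_loss(d, p, q):
    """Infinite-data cross-entropy loss for discriminator outputs d[x]."""
    return float(-0.5 * np.sum(p * np.log(d) + q * np.log(1.0 - d)))

p = np.array([0.6, 0.3, 0.1])
q = np.array([0.2, 0.3, 0.5])

d_opt = p / (p + q)                       # optimal discriminator
loss_opt = cross_entropy_loss(d_opt, p, q)
predicted = np.log(2) - js_divergence(p, q)

# Any small perturbation of the optimal outputs should not reduce the loss.
rng = np.random.default_rng(0)
d_perturbed = np.clip(d_opt + rng.uniform(-0.05, 0.05, size=d_opt.shape), 1e-6, 1 - 1e-6)
loss_perturbed = cross_entropy_loss(d_perturbed, p, q)
```

Here `loss_opt` matches `predicted`, and `loss_perturbed` is at least as large as `loss_opt`.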

So the loss of the optimal binary classifier is the negative JS divergence, up to an additive constant.

When the two distributions are disjoint, $D_{JS}(P||Q) = \log 2$ as we saw above, and the loss is minimized at 0. This makes sense: the classifier has no trouble discriminating the distributions and incurs no loss.

When the two distributions are identical, as GANs try to make them, the JS divergence is 0, and the loss is maximized at $\log 2$. This also makes sense: the classifier can’t discriminate the distributions at all, so its answer is no better than guessing.

$$ \blacksquare$$

