KL Divergence — in layman’s terms

If we are asked to look at the three animals below and say how much of each one is cat and how much is dog, most of us would agree that

  • the first one is “all cat and no dog”
  • the second one is “more cat than dog”
  • the third one is “more dog than cat”

Images of the animals are from Marvin’s review of The Illustrated Encyclopedia of Cat Breeds.

If we want a neural-net-based model to do the same thing (something neural net models have become good at in the last few years), then

  1. we first need to generate some training data ourselves, labeling each picture with probability assignments like the ones above (e.g. 90% [.9] cat; 10% [.1] dog)
  2. then have the model predict these values for each image in our training set, and let the model keep improving its predictions based on how far off they are from the values humans assigned.
  3. Once the model does this successfully for a large number of pictures, it is likely to make predictions that agree with our estimates even for images of cats and dogs it has never seen before.

Let us focus on the last part of the second step above: how does the model calculate how far off its predicted percentages of ‘catness/dogness’ for an image are? This “how far off” measure could be computed in many ways; one approach is described below.

  • Simply adding up the predicted values and comparing the total with the sum of the human-labeled values won't work, since the values assigned to any image always add up to 100% (if a picture is 90% dog, it is inevitably 10% cat; the same argument applies with more than two categories, such as cats, dogs and cows).
  • However, we can do a weighted sum of scores, where
  • the weighting is the estimated probability (from a human or from the model) of “catness/dogness”,
  • and the score is some function of that same estimate (referred to below as score_function).
  • This gives us a single number for each picture that captures the distribution of our predictions for cats and dogs.
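A minimal Python sketch of this weighted sum of scores. It assumes score_function(p) = log(1/p), the choice this article arrives at later; the helper name weighted_sum_of_scores is hypothetical, and any other score function can be passed in through the score parameter.

```python
import math

def score_function(p):
    # Assumed score: log(1/p), i.e. -log(p). Requires p > 0.
    return math.log(1.0 / p)

def weighted_sum_of_scores(weights, score_inputs, score=score_function):
    # Weight each category's score by that category's probability and add them up.
    # Both lists hold probabilities in the same category order, e.g. [cat, dog].
    return sum(w * score(s) for w, s in zip(weights, score_inputs))

# Human label for the second image: 90% cat, 10% dog.
human = [0.9, 0.1]
print(weighted_sum_of_scores(human, human))  # about 0.325 (natural log)
```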

So we can now generate two numbers for each image:

  • the number based on human predictions for the image,
  • the number based on the model's predictions,
  • and then compare these two numbers.

For example, for the second image, the single number for the human prediction would be

  • .9*score_function(.9) + .1*score_function(.1)

So if the model came up with an estimate of 20% cat and 80% dog for the second image, we can calculate how far off it is from the human prediction by simply computing

  • .9*score_function(.9) + .1*score_function(.1) - (.2*score_function(.2) + .8*score_function(.8))
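The same comparison as a short Python sketch, again assuming score_function(p) = log(1/p) purely for illustration (single_number is a hypothetical helper name):

```python
import math

def score_function(p):
    # Assumed score: log(1/p); any other function could be plugged in here.
    return math.log(1.0 / p)

def single_number(probs):
    # Weighted sum of scores, using one distribution for both the weights and the score inputs.
    return sum(p * score_function(p) for p in probs)

human = [0.9, 0.1]  # [cat, dog] for the second image
model = [0.2, 0.8]  # the model's estimate for the same image

naive_difference = single_number(human) - single_number(model)
print(naive_difference)
```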

We can do better than this, though. The weighted sum approach above fails when our model comes up with the estimate 10% cat and 90% dog for the second image: the model's weighted sum, .1*score_function(.1) + .9*score_function(.9), contains exactly the same terms as the human one, so the difference is zero even though the estimates of ‘catness’ and ‘dogness’ are not the same. The true estimate is 90% cat and 10% dog, whereas the model's estimate is the reverse: 10% cat and 90% dog.
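A quick check of this failure case in Python, under the same assumed score_function(p) = log(1/p): the swapped prediction produces the same weighted-sum terms as the human label, only added in a different order.

```python
import math

def score_function(p):
    # Assumed score: log(1/p).
    return math.log(1.0 / p)

def single_number(probs):
    # Weighted sum of scores using one distribution for both weights and score inputs.
    return sum(p * score_function(p) for p in probs)

human = [0.9, 0.1]    # 90% cat, 10% dog
swapped = [0.1, 0.9]  # the model's reversed estimate: 10% cat, 90% dog

# The same two terms are added in a different order, so the naive difference is 0.
naive_difference = single_number(human) - single_number(swapped)
print(naive_difference)
```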

A simple way to address this is to compute the score of the model’s estimates for cats and dogs and weight them with the values humans assigned to that image. In the case above, if the model’s estimate for the second image is 10% cat and 90% dog, its prediction error would be

  • .9*score_function(.1) + .1*score_function(.9) - (.9*score_function(.9) + .1*score_function(.1))

This difference could still turn out to be zero if score_function gives the same value for the inputs .9 and .1 (i.e. score_function(.9) == score_function(.1)); in this example the difference simplifies to .8*(score_function(.1) - score_function(.9)). This is easily avoided by choosing a function that never gives the same value for two distinct inputs: any strictly monotonic function will do. [1]
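A small Python sketch comparing two score functions on the swapped prediction. The names prediction_error and bad_score are hypothetical and used only for illustration; bad_score(p) = p*(1 - p) is deliberately not monotonic, so it gives the same value at .9 and .1 (up to floating-point rounding), while good_score is the log(1/p) score.

```python
import math

def good_score(p):
    # Strictly monotonic (decreasing) score: log(1/p).
    return math.log(1.0 / p)

def bad_score(p):
    # Not monotonic: p * (1 - p) takes (almost) the same value at 0.9 and 0.1.
    return p * (1.0 - p)

def prediction_error(actual, predicted, score):
    # Score the predicted probabilities weighted by the actual (human) ones,
    # minus the same weighted sum computed with the actual probabilities alone.
    cross = sum(a * score(q) for a, q in zip(actual, predicted))
    own = sum(a * score(a) for a in actual)
    return cross - own

human = [0.9, 0.1]    # [cat, dog]
swapped = [0.1, 0.9]  # the model gets it backwards

print(prediction_error(human, swapped, bad_score))   # ~0: the swap goes unnoticed
print(prediction_error(human, swapped, good_score))  # 0.8 * log(9), clearly positive
```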

In essence, this difference computation for the prediction error is what KL divergence is.

So each time our ‘catness/dogness’ model makes a prediction of percentages/probabilities for an image during training, it calculates the prediction error using KL divergence, which is simply the calculation below:

  • actual_cat_image_prob*score_function(predicted_cat_image_prob) + actual_dog_image_prob*score_function(predicted_dog_image_prob) - (actual_cat_image_prob*score_function(actual_cat_image_prob) + actual_dog_image_prob*score_function(actual_dog_image_prob))

  • The computed difference is then used to adjust the neural net weights so that the predicted distribution moves closer to the actual distribution. The equation above shows that as the predicted distribution (e.g. predicted_cat_image_prob) gets closer to the human-labeled distribution (e.g. actual_cat_image_prob), the difference approaches zero.
  • The score_function in KL divergence is the log function applied to the inverse of the probability, log(1/p), which is the same as -log(p); it is strictly monotonic (decreasing) in p. The base of the logarithm doesn't matter: changing it only multiplies every score by the same constant, so both terms of the difference are scaled by the same factor and the prediction that minimizes the difference is unchanged. The value of this function has an interpretation described in the Additional details section below.
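A sketch of the calculation above generalized to lists of probabilities, in Python. The function name kl_divergence and its base parameter are illustrative assumptions; it uses score_function(p) = log(1/p), skips terms whose actual probability is 0 (the usual 0*log 0 = 0 convention), and requires predicted probabilities to be nonzero wherever the actual ones are.

```python
import math

def kl_divergence(actual, predicted, base=math.e):
    # sum_i actual_i * score(predicted_i) - sum_i actual_i * score(actual_i),
    # with score(p) = log(1/p) in the given base. Terms with actual_i == 0 are skipped.
    return sum(a * (math.log(1.0 / q, base) - math.log(1.0 / a, base))
               for a, q in zip(actual, predicted) if a > 0)

actual = [0.9, 0.1]  # [actual_cat_image_prob, actual_dog_image_prob] for the second image

# As the predicted cat probability approaches 0.9, the divergence shrinks toward 0.
values = []
for cat in [0.2, 0.5, 0.8, 0.89, 0.9]:
    values.append(kl_divergence(actual, [cat, 1.0 - cat]))
    print(cat, values[-1])

# Changing the log base only rescales the result: bits = nats / ln(2).
nats = kl_divergence(actual, [0.2, 0.8])
bits = kl_divergence(actual, [0.2, 0.8], base=2)
print(nats, bits, nats / math.log(2))
```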

Additional details

How come the explanation does not mention entropy and cross entropy?

  • If the weighted sum of scores we computed above uses the same distribution both for the weights and as the input to the scores (the H(P) term in the figure above), then it is the entropy of that probability distribution.

Why is it called entropy?

  • Entropy is a measure of uncertainty (or, equivalently, of information). In the pictures above, there are no dog features in the first picture; it is all cat. The probability distribution in this case is 1 and 0 for cat and dog respectively, and the weighted sum of scores (H(P) in the figure above, taking 0*log(1/0) as 0) is 0, i.e. zero uncertainty (or zero information: being told something we already know for certain is not useful). The second and third images, in contrast, mix cat and dog features, so there is some uncertainty in labeling them exclusively as a cat or a dog (e.g. the entropy of image 2 is .9*log(1/.9) + .1*log(1/.1)).
  • An extreme case of uncertainty is a fair die: every face is equally likely in a throw, so there is maximum uncertainty in its probability distribution. With log as our score function, the entropy in this case is 6*(1/6)*log 6 = log 6. In general, the entropy of a discrete distribution over N outcomes ranges from 0 (one outcome is certain) to log N (all outcomes equally likely).
  • So we can associate an uncertainty measure both with an individual outcome (which is just log(1/p): the lower the probability, the higher the uncertainty, or information content) and with the entire distribution the outcome is part of.
  • In summary, the entropy of a probability distribution is a measure of its uncertainty (information), computed as a weighted sum of the uncertainty (information content) of its outcomes (in the continuous case it is an integral rather than a weighted sum). There is a paper that clearly explains what kind of distribution we can conservatively yet optimally assume given certain known details of the distribution [5]. The video in the references below gives a nice interpretation of entropy as the average number of useful bits transmitted to represent the state of a system, given an underlying distribution over all its states [6].
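A small Python sketch of these entropy calculations, assuming the natural log (any base works and only rescales the numbers); surprisal and entropy are hypothetical helper names.

```python
import math

def surprisal(p, base=math.e):
    # Uncertainty (information content) of a single outcome: log(1/p).
    return math.log(1.0 / p, base)

def entropy(probs, base=math.e):
    # Weighted sum of surprisals; outcomes with probability 0 contribute 0.
    return sum(p * surprisal(p, base) for p in probs if p > 0)

image_1 = [1.0, 0.0]   # all cat, no dog
image_2 = [0.9, 0.1]   # mostly cat
fair_die = [1 / 6] * 6

print(entropy(image_1))                  # 0.0: no uncertainty
print(entropy(image_2))                  # between 0 and log 2
print(entropy(fair_die), math.log(6))    # the maximum for six outcomes: log 6
print(surprisal(0.1) > surprisal(0.9))   # the rarer outcome carries more information
```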

If the two distributions in the weighted sum of scores are different (weights from p(x), score inputs from q(x)), then it is the cross entropy, the H(P,Q) term in the figure above (the reverse case, with q(x) as the weighting distribution, is H(Q,P)).
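For reference, here are the standard definitions written with the log(1/x) score used in this article, where P is the weighting (human-labeled) distribution and Q is the model's prediction:

```latex
\begin{aligned}
H(P)   &= \sum_x p(x)\,\log\frac{1}{p(x)} \\
H(P,Q) &= \sum_x p(x)\,\log\frac{1}{q(x)} \\
D_{\mathrm{KL}}(P \,\|\, Q) &= H(P,Q) - H(P) = \sum_x p(x)\,\log\frac{p(x)}{q(x)}
\end{aligned}
```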

Are there any other loss functions like KL divergence?

  • Yes, there are other measures [2]. One of them, the Jensen-Shannon divergence, is built from KL divergence but is symmetric, so it lacks one of KL divergence's useful properties: asymmetry.
  • That is, if we are comparing two probability distributions P(X) and Q(X), computing KL divergence with P(X) as the weighting/reference distribution is not the same as computing it with Q(X) as the weighting/reference distribution. The Quora link in the references explains this well [3]. So the choice of which distribution to use as the weighting/reference distribution is determined by the problem.
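A short Python sketch of this asymmetry, with the Jensen-Shannon divergence (the symmetric measure discussed in [2]) for contrast; the distributions p and q are made up for illustration.

```python
import math

def kl_divergence(p, q):
    # KL(p || q): p is the weighting/reference distribution.
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def js_divergence(p, q):
    # Jensen-Shannon divergence: average KL of p and q to their midpoint m.
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)

p = [0.9, 0.1]
q = [0.5, 0.5]

print(kl_divergence(p, q), kl_divergence(q, p))  # different values: KL is asymmetric
print(js_divergence(p, q), js_divergence(q, p))  # the same value: JS is symmetric
```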

If our task were to classify the images above exclusively as either a cat or a dog, rather than as a distribution over both classes, then the second term on the right of the KL divergence equation becomes 0, since the labeled probability distributions in training would be one-hot vectors of the form [0 1] or [1 0], and log 1 is 0 (with 0*log 0 taken as 0). So KL divergence in this case becomes just the cross entropy error. Moreover, because of the one-hot vector, every term in the summation of the first term on the right is 0 except one, leaving just -log q(x) for the true label x: the negative log likelihood. As our prediction for the correct label (cat or dog) gets closer to 1, the true label value, the divergence goes to 0.
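A quick numeric check of this reduction in Python, using a made-up model output of 70% cat for an image labeled as a cat (one-hot [1, 0]):

```python
import math

def kl_divergence(p, q):
    # KL(p || q); terms where the label probability is 0 contribute 0.
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def cross_entropy(p, q):
    # H(p, q): weights from p, scores log(1/q) from q.
    return sum(pi * math.log(1.0 / qi) for pi, qi in zip(p, q) if pi > 0)

label = [1.0, 0.0]       # one-hot label: "cat"
prediction = [0.7, 0.3]  # hypothetical model output [cat, dog]

print(kl_divergence(label, prediction))  # these three values agree:
print(cross_entropy(label, prediction))  # KL, cross entropy,
print(-math.log(prediction[0]))          # and the negative log likelihood of "cat"

# As the predicted cat probability approaches 1, the divergence approaches 0.
print(kl_divergence(label, [0.999, 0.001]))
```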

  • One subtle point to note here is that we can use KL divergence for both tasks: (1) classifying an image as just one class, or (2) assigning an image a distribution over multiple classes (10% cat, 90% dog).
  • However, the training data needs to be different for the two cases. For the first task, the training data just has one-hot vectors as labels. For the second task, the training data needs to have entity distributions for the images (90% cat, 10% dog). We can't train our model to produce an entity distribution of 90% cat and 10% dog just by using KL divergence if our training data does not contain such instances.
  • In reality, we are unlikely to want a model to learn that an image is 90% cat and 10% dog. But in an NLP task such as generating an entity distribution for words, where the word “cell” could be one of several entity types (“phone”, “biological cell”, “prison cell”), we would need labeled training data that captures entity distributions. One-hot vectors as probability distributions won't suffice, since KL divergence will fit the “one-zero” distribution they encode. For the word “cell” we would then get one entity type at, say, 99% and all the other entity types near zero, which is not what we want.
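A toy illustration of this point in Python (a grid search over two-class predictions, not actual neural-network training): minimizing KL divergence against a soft label recovers the soft label, while a one-hot label pushes the prediction to the most extreme candidate available. The grid and the helper best_prediction are assumptions made for this example.

```python
import math

def kl_divergence(p, q):
    # KL(p || q); terms where the target probability is 0 contribute 0.
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

# Candidate predictions for P(cat), from 0.01 to 0.99 (dog gets the rest).
candidates = [i / 100 for i in range(1, 100)]

def best_prediction(target):
    # The candidate [cat, 1 - cat] with the smallest KL divergence from the target.
    return min(candidates, key=lambda cat: kl_divergence(target, [cat, 1.0 - cat]))

print(best_prediction([0.9, 0.1]))  # 0.9: the soft label is reproduced
print(best_prediction([1.0, 0.0]))  # 0.99: the one-hot label pushes toward the extreme
```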

References

  1. Logarithm is Strictly Increasing
  2. Why isn’t the Jensen-Shannon divergence used more often than the Kullback-Leibler (since JS is symmetric, thus possibly a better indicator of distance)?
  3. KL Divergence: Forward vs Reverse?
  4. Kullback–Leibler divergence — Wikipedia
  5. Probability distributions and maximum entropy
  6. This video (Feb 2018) very clearly and crisply explains entropy, cross entropy and KL divergence (unlike this verbose answer!)

Originally published at www.quora.com.

Machine learning practitioner
