
What is the correct way to use KL divergence in PyTorch? And what if the distribution contains a zero?


I am trying to fit distribution p to distribution q with KL divergence.

import torch

p = torch.Tensor([0.1, 0.2, 0.7])
q = torch.Tensor([0.333, 0.334, 0.333])

So I calculated the KL divergence myself:

def kl_div(P, Q):
    return (P * (P / Q).log()).sum()

kl_div(p, q)

The result is tensor(0.2972).

Then I found that PyTorch has already implemented the torch.nn.functional.kl_div function.

I think input should be the network's output, and target is a constant.
So I treated p as the input and q as the target.

But the result of

torch.nn.functional.kl_div(p.log(), q, reduction='sum')  # tensor(0.3245)

is different from mine.

And this one gives the same result as mine:

torch.nn.functional.kl_div(q.log(), p, reduction='sum')   # tensor(0.2972)

So what went wrong?
Is there a problem with my understanding of KL divergence?
Or did I pass the arguments to torch.nn.functional.kl_div in the wrong order?


And another question:

What if the distribution contains a zero?

For example:

p = torch.Tensor([0., 0.3, 0.7])
q = torch.Tensor([0.333, 0.334, 0.333])

I still need to calculate the KL divergence in this situation.


Solution

  • KL divergence is not symmetric: in general, KL(P || Q) != KL(Q || P).

    Pointwise KL Divergence is defined as y_{true} * log(y_{true}/y_{pred}).

    Following this, your function:

    def kl_div(P, Q):
        return (P * (P / Q).log()).sum()
    

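    treats P as the true distribution and Q as the predicted one. torch.nn.functional.kl_div(input, target) takes its arguments the other way around: input is the prediction, given as log-probabilities, and target is the true distribution, so it computes sum(target * (log(target) - input)), i.e. KL(target || input). With your tensors, torch.nn.functional.kl_div(p.log(), q, reduction='sum') is therefore KL(q || p) = 0.3245, while torch.nn.functional.kl_div(q.log(), p, reduction='sum') is KL(p || q) = 0.2972, which matches your function. From your question, it sounds like you are treating p as the predicted distribution and q as the true distribution, so the first call is the one that matches your intent, and it is your hand-written function that has the roles flipped.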

    You may be confused because the mathematical notation KL(P || Q) puts the distribution of observations P first and the "model" distribution Q second, whereas torch.nn.functional.kl_div takes the model's output first and the target second, and your question uses p for the output of the model you are training and q for the fixed target.
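
To make the argument order concrete, here is a minimal sketch using the tensors from the question. The helper kl is a hypothetical name introduced only for this illustration; it spells out KL(true || pred) directly so that it can be compared against both ways of calling torch.nn.functional.kl_div.

```python
import torch
import torch.nn.functional as F

p = torch.tensor([0.1, 0.2, 0.7])        # the "network output" in the question
q = torch.tensor([0.333, 0.334, 0.333])  # the fixed target in the question


def kl(true_dist, pred_dist):
    """Hypothetical helper: KL(true || pred) = sum_i true_i * log(true_i / pred_i)."""
    return (true_dist * (true_dist / pred_dist).log()).sum()


# F.kl_div takes the prediction first, as log-probabilities, and the target
# second; it evaluates sum(target * (log(target) - input)).
kl_q_p = F.kl_div(p.log(), q, reduction='sum')  # p is the prediction: KL(q || p)
kl_p_q = F.kl_div(q.log(), p, reduction='sum')  # q is the prediction: KL(p || q)

print(kl_q_p, kl(q, p))  # both about 0.3245, the first value in the question
print(kl_p_q, kl(p, q))  # both about 0.2972, the second value in the question
```

If p really is the network output and q the target, the first call is the one to use as a loss; the second reproduces the hand-written function only because that function has the roles swapped.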

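    To your second question, it depends on which distribution contains the zero. With the usual convention 0 * log(0) = 0, a zero in the true distribution just contributes nothing: if P(i) = 0 and Q(i) > 0, that term of KL(P || Q) is 0 and the divergence stays finite (your hand-written function returns nan here only because it literally evaluates 0 * log(0)). The problem is a zero in the predicted distribution. If the prediction says event i is impossible while the true distribution says it is possible, KL(true || predicted) is infinite - there is no finite measure of that discrepancy. That is definitional to KL divergence, and it is what happens when your p with a zero is used as the network output against q.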

    You can fudge it by adding a small eps to the predicted values, e.g. torch.nn.functional.kl_div((p+1e-8).log(), q, reduction='sum'), which gives a large finite number that depends on the eps you choose. However, if zeros come up often in your use case, you should consider a different measure.
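
The zero cases can be checked with a short sketch. It assumes a PyTorch version that provides torch.xlogy, which computes x * log(y) and returns 0 wherever x is 0. The helper kl_safe, the eps value, and the renormalisation step are illustrative choices, not something prescribed by PyTorch.

```python
import torch
import torch.nn.functional as F

p = torch.tensor([0.0, 0.3, 0.7])        # contains a zero
q = torch.tensor([0.333, 0.334, 0.333])


def kl_safe(true_dist, pred_dist):
    """Hypothetical helper: KL(true || pred) with the convention 0 * log(0) = 0."""
    # xlogy(x, y) = x * log(y), defined as 0 where x == 0.
    return (torch.xlogy(true_dist, true_dist) - torch.xlogy(true_dist, pred_dist)).sum()


# Zero in the TRUE distribution: that term vanishes and the result is finite.
kl_p_q = kl_safe(p, q)
# F.kl_div also treats a zero target entry as contributing 0.
kl_p_q_builtin = F.kl_div(q.log(), p, reduction='sum')

# Zero in the PREDICTED distribution where the target is positive: infinite.
kl_q_p = kl_safe(q, p)

# The eps workaround: smooth the prediction, renormalise so it still sums
# to 1, and accept a large but finite value that depends on eps.
eps = 1e-8  # illustrative value
p_smooth = (p + eps) / (p + eps).sum()
kl_q_p_smooth = F.kl_div(p_smooth.log(), q, reduction='sum')

print(kl_p_q, kl_p_q_builtin)  # finite and equal up to rounding
print(kl_q_p)                  # inf
print(kl_q_p_smooth)           # finite, but dominated by the eps term
```

Because the smoothed value is driven almost entirely by the choice of eps, it is better read as a guard against inf or nan in a training loop than as a meaningful divergence.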