I am trying to fit distribution p to distribution q with KL divergence.
import torch
p = torch.Tensor([0.1, 0.2, 0.7])
q = torch.Tensor([0.333, 0.334, 0.333])
So I calculated the KL divergence myself:
def kl_div(P, Q):
    return (P * (P / Q).log()).sum()
kl_div(p, q)
The result is tensor(0.2972).
Then I found that PyTorch already provides torch.nn.functional.kl_div. I think input should be the network's output and target should be a constant, so I treated p as the input and q as the target.
But the result of
torch.nn.functional.kl_div(p.log(), q, reduction='sum') # tensor(0.3245)
is different from mine.
And this one gives the same result as mine:
torch.nn.functional.kl_div(q.log(), p, reduction='sum') # tensor(0.2972)
So what went wrong? Is there a problem with my understanding of KL divergence, or did I pass the wrong arguments to torch.nn.functional.kl_div?
And another question: what if I have a zero in the distribution? For example:
p = torch.Tensor([0., 0.3, 0.7])
q = torch.Tensor([0.333, 0.334, 0.333])
I still need to calculate the KL divergence in this situation.
KL divergence is not symmetric: KL(P || Q) != KL(Q || P).
Pointwise, the KL divergence is defined as y_{true} * log(y_{true} / y_{pred}).
Following this, your function:
def kl_div(P, Q):
    return (P * (P / Q).log()).sum()
treats P as the true distribution. From your question, it sounds like you are treating P as the predicted distribution and Q as the true distribution, so you have things flipped relative to your function.
You may be confused because the mathematical notation KL(P || Q) defines P as a distribution of observations and Q as a "model" distribution, while the ML context uses P to denote the output of the model you are training and Q to denote ground-truth observations from a dataset.
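To make the mapping concrete, here is a small sketch using only your p and q from the question, lining up the handwritten formula with F.kl_div's argument order (the only assumption is reduction='sum', as in your snippets):
import torch
import torch.nn.functional as F

p = torch.Tensor([0.1, 0.2, 0.7])
q = torch.Tensor([0.333, 0.334, 0.333])

def kl_div(P, Q):
    # KL(P || Q), with P playing the role of the true distribution
    return (P * (P / Q).log()).sum()

# F.kl_div(input, target) computes target * (log(target) - input),
# i.e. input is the *log* of the predicted distribution and target
# is the true distribution.
print(kl_div(p, q))                           # tensor(0.2972)
print(F.kl_div(q.log(), p, reduction='sum'))  # tensor(0.2972) - p is the "true" distribution here
print(kl_div(q, p))                           # tensor(0.3245)
print(F.kl_div(p.log(), q, reduction='sum'))  # tensor(0.3245) - q is the "true" distribution here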
To your second question: zeros need care. If Q(i) = 0 while P(i) > 0, KL(P || Q) is infinite: P says event i is possible while Q says it is impossible, and there is no finite measure of that discrepancy. If instead P(i) = 0 while Q(i) > 0, as in your example, the term is conventionally taken as 0 (the 0 * log 0 = 0 convention), but in floating point 0 * log(0) evaluates to NaN, so a naive implementation still breaks.
You can fudge it by adding a small eps to your values, e.g. torch.nn.functional.kl_div((p + 1e-8).log(), q, reduction='sum'). However, if zero values come up often in your use case, you should consider a different metric.
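A minimal sketch of what happens with your second example and of two workarounds (the torch.where masking variant is only my illustration of the 0 * log 0 = 0 convention, not something F.kl_div does for you):
import torch
import torch.nn.functional as F

p = torch.Tensor([0., 0.3, 0.7])
q = torch.Tensor([0.333, 0.334, 0.333])

# Naive formula: the first term is 0 * log(0 / 0.333) = 0 * (-inf) = nan
print((p * (p / q).log()).sum())  # tensor(nan)

# Eps fudge: finite, but large, because the target q puts mass on an
# event that p + eps makes almost impossible
print(F.kl_div((p + 1e-8).log(), q, reduction='sum'))

# Masking sketch: apply the 0 * log 0 = 0 convention when the zero sits
# in the "true" slot (p here), so that entry contributes nothing.
# Note: if you backprop through torch.where, the nan in the discarded
# branch can still poison the gradients.
terms = torch.where(p > 0, p * (p / q).log(), torch.zeros_like(p))
print(terms.sum())  # roughly 0.488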