
Why VAE model in pytorch doesn’t use torch.nn.KLDivLoss?


  • I have seen several examples of VAE implementations.
  • A VAE loss has two parts: a reconstruction term (MSE) and a KL-divergence term.
  • In every example I have seen, the authors write the VAE loss (MSE + KL loss) by hand and do not use torch.nn.KLDivLoss.

One example can be found here: https://github.com/AntixK/PyTorch-VAE/blob/master/models/vanilla_vae.py

Why did they implement the KL-divergence part themselves instead of using torch.nn.KLDivLoss?


Solution

  • torch.nn.KLDivLoss computes the KL divergence between two discrete (categorical) distributions and takes the distributions p, q as input. It computes the following:

    \sum_{i=0}^{C-1} q[i]\log\frac{q[i]}{p[i]}

    However, for a VAE you need the KL divergence between two Gaussian distributions: the approximate posterior q(z|x) = N(\mu, \sigma^2) and the standard-normal prior N(0, I). KLDivLoss cannot compute this. Instead, it is computed with a closed-form formula:

    D_{KL}\big(N(\mu,\sigma^2)\,\|\,N(0,I)\big) = -\frac{1}{2}\sum_{j}\left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
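
    A minimal sketch contrasting the two cases (the tensor shapes and values here are made up for illustration): torch.nn.KLDivLoss applied to two categorical distributions, versus the hand-written closed-form Gaussian KL that VAE implementations use.

    ```python
    import torch
    import torch.nn.functional as F

    # 1) torch.nn.KLDivLoss: KL between two *discrete* distributions.
    #    The input must be log-probabilities; the target is plain probabilities.
    q = torch.tensor([0.4, 0.3, 0.3])    # target distribution q
    p = torch.tensor([0.5, 0.25, 0.25])  # predicted distribution p
    kl_discrete = F.kl_div(p.log(), q, reduction="sum")
    manual = (q * (q / p).log()).sum()   # sum_i q[i] * log(q[i] / p[i])
    assert torch.allclose(kl_discrete, manual)

    # 2) VAE case: closed-form KL( N(mu, sigma^2) || N(0, I) ), summed over
    #    latent dimensions -- this is the term VAE code writes by hand.
    mu = torch.randn(8, 16)       # hypothetical batch of posterior means
    log_var = torch.randn(8, 16)  # hypothetical log-variances
    kl_gauss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)
    assert (kl_gauss >= 0).all()  # KL divergence is always non-negative
    ```

    KLDivLoss needs the two distributions as tensors of (log-)probabilities, but a Gaussian posterior is described by its parameters (mu, log_var), not by a probability vector, which is why the closed-form expression is used instead.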