Tags: python, machine-learning, deep-learning, chainer

VAE does not learn when I change the reconstruction loss from F.bernoulli_nll to F.mean_squared_error in Chainer


I want to use F.mean_squared_error instead of F.bernoulli_nll as the reconstruction loss in my VAE, using Chainer 5.0.0.

I am a Chainer 5.0.0 user and have implemented a VAE (Variational Autoencoder). I used Japanese-language articles as a reference.

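# Imports assumed by the code below (not shown in the original post)
import six
import chainer
import chainer.functions as F
import chainer.links as L
from chainer.functions.loss.vae import gaussian_kl_divergence

class VAE(chainer.Chain):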

    def __init__(self, n_in, n_latent, n_h, act_func=F.tanh):
        super(VAE, self).__init__()
        self.act_func = act_func
        with self.init_scope():
            # encoder
            self.le1        = L.Linear(n_in, n_h)
            self.le2        = L.Linear(n_h,  n_h)
            self.le3_mu     = L.Linear(n_h,  n_latent)
            self.le3_ln_var = L.Linear(n_h,  n_latent)

            # decoder
            self.ld1 = L.Linear(n_latent, n_h)
            self.ld2 = L.Linear(n_h,      n_h)
            self.ld3 = L.Linear(n_h,      n_in)

    def __call__(self, x, sigmoid=True):
        return self.decode(self.encode(x)[0], sigmoid)

    def encode(self, x):
        h1 = self.act_func(self.le1(x))
        h2 = self.act_func(self.le2(h1))
        mu = self.le3_mu(h2)
        ln_var = self.le3_ln_var(h2) 
        return mu, ln_var

    def decode(self, z, sigmoid=True):
        h1 = self.act_func(self.ld1(z))
        h2 = self.act_func(self.ld2(h1))
        h3 = self.ld3(h2)
        if sigmoid:
            return F.sigmoid(h3)
        else:
            return h3

    def get_loss_func(self, C=1.0, k=1):
        def lf(x):
            mu, ln_var = self.encode(x)
            batchsize = len(mu.data)
            # reconstruction error
            rec_loss = 0
            for l in six.moves.range(k):
                z = F.gaussian(mu, ln_var)
                z.name = "z"
                rec_loss += F.bernoulli_nll(x, self.decode(z, sigmoid=False)) / (k * batchsize)
            self.rec_loss = rec_loss
            self.rec_loss.name = "reconstruction error"
            self.latent_loss = C * gaussian_kl_divergence(mu, ln_var) / batchsize
            self.latent_loss.name = "latent loss"
            self.loss = self.rec_loss + self.latent_loss
            self.loss.name = "loss"
            return self.loss
        return lf

With this code, I trained my VAE on the MNIST and Fashion-MNIST datasets. After training, I confirmed that the VAE outputs images similar to the input images.

rec_loss is the reconstruction loss, which measures how far the decoded images are from the input images. I think we can use F.mean_squared_error instead of F.bernoulli_nll.

So I changed my code as follows.

rec_loss += F.mean_squared_error(x, self.decode(z)) / k

But after this change, training behaves strangely: the output images are all the same, which means they do not depend on the input images.

What is the problem?

I asked this question in Japanese (https://ja.stackoverflow.com/questions/55477/chainer%E3%81%A7vae%E3%82%92%E4%BD%9C%E3%82%8B%E3%81%A8%E3%81%8D%E3%81%ABloss%E9%96%A2%E6%95%B0%E3%82%92bernoulli-nll%E3%81%A7%E3%81%AF%E3%81%AA%E3%81%8Fmse%E3%82%92%E4%BD%BF%E3%81%86%E3%81%A8%E5%AD%A6%E7%BF%92%E3%81%8C%E9%80%B2%E3%81%BE%E3%81%AA%E3%81%84), but nobody has responded, so I am submitting the question here.

Solution?

When I replace

rec_loss += F.mean_squared_error(x, self.decode(z)) / k 

by

rec_loss += F.mean(F.sum((x - self.decode(z)) ** 2, axis=1))

the problem is solved.

But why?


Solution

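  • The two expressions are identical except for the normalization. F.mean(F.sum(..., axis=1)) sums the squared error over the input dimension (784 for flattened MNIST) and then averages only over the minibatch axis, while F.mean_squared_error averages over both the minibatch axis and the input dimension. For flattened MNIST, and assuming k is 1, the F.mean(F.sum(...)) loss is therefore 784 times larger than the F.mean_squared_error loss. The latent loss keeps the same scale in both cases, so with F.mean_squared_error the reconstruction term is about 784 times weaker relative to the KL term. The KL term then likely dominates and pushes the encoder output toward the prior, which would explain decoded images that do not depend on the input.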
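
The scale difference can be checked numerically. The following is a minimal sketch in plain Python rather than Chainer, assuming a hypothetical minibatch of 4 flattened 28x28 images filled with random values. The helper names mse_all_elements and sum_then_batch_mean are illustrative only and mirror what F.mean_squared_error and F.mean(F.sum(..., axis=1)) compute.

```python
import random


def mse_all_elements(x, y):
    """Mean squared error averaged over every element (batch and feature axes)."""
    n = len(x) * len(x[0])
    total = sum((a - b) ** 2 for xr, yr in zip(x, y) for a, b in zip(xr, yr))
    return total / n


def sum_then_batch_mean(x, y):
    """Squared error summed over features, then averaged over the batch only."""
    per_sample = [sum((a - b) ** 2 for a, b in zip(xr, yr)) for xr, yr in zip(x, y)]
    return sum(per_sample) / len(per_sample)


random.seed(0)
batchsize, n_in = 4, 784  # hypothetical: 4 flattened 28x28 images

# Random stand-ins for the input batch and the decoder output.
x = [[random.random() for _ in range(n_in)] for _ in range(batchsize)]
y = [[random.random() for _ in range(n_in)] for _ in range(batchsize)]

ratio = sum_then_batch_mean(x, y) / mse_all_elements(x, y)
print(ratio)  # approximately n_in, up to floating-point rounding
```

Under the same assumptions (k = 1, flattened 28x28 inputs), minimizing rec_loss / 784 + C * KL has the same minimizer as rec_loss + 784 * C * KL, so switching to F.mean_squared_error acts roughly like raising C from 1.0 to 784. Keeping F.mean_squared_error and lowering C accordingly may therefore be an alternative to rewriting the loss, though that has not been tested here.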