I want to use mean_squared_error instead of F.bernoulli_nll as the reconstruction loss function in my VAE, using Chainer 5.0.0.
I am a Chainer 5.0.0 user. I have implemented a VAE (Variational Autoencoder), using Japanese articles for reference.
import six

import chainer
import chainer.functions as F
import chainer.links as L
from chainer.functions.loss.vae import gaussian_kl_divergence


class VAE(chainer.Chain):

    def __init__(self, n_in, n_latent, n_h, act_func=F.tanh):
        super(VAE, self).__init__()
        self.act_func = act_func
        with self.init_scope():
            # encoder
            self.le1 = L.Linear(n_in, n_h)
            self.le2 = L.Linear(n_h, n_h)
            self.le3_mu = L.Linear(n_h, n_latent)
            self.le3_ln_var = L.Linear(n_h, n_latent)
            # decoder
            self.ld1 = L.Linear(n_latent, n_h)
            self.ld2 = L.Linear(n_h, n_h)
            self.ld3 = L.Linear(n_h, n_in)

    def __call__(self, x, sigmoid=True):
        return self.decode(self.encode(x)[0], sigmoid)

    def encode(self, x):
        h1 = self.act_func(self.le1(x))
        h2 = self.act_func(self.le2(h1))
        mu = self.le3_mu(h2)
        ln_var = self.le3_ln_var(h2)
        return mu, ln_var

    def decode(self, z, sigmoid=True):
        h1 = self.act_func(self.ld1(z))
        h2 = self.act_func(self.ld2(h1))
        h3 = self.ld3(h2)
        if sigmoid:
            return F.sigmoid(h3)
        else:
            return h3

    def get_loss_func(self, C=1.0, k=1):
        def lf(x):
            mu, ln_var = self.encode(x)
            batchsize = len(mu.data)
            # reconstruction error
            rec_loss = 0
            for l in six.moves.range(k):
                z = F.gaussian(mu, ln_var)  # reparameterization trick
                z.name = "z"
                # bernoulli_nll takes logits, hence sigmoid=False
                rec_loss += F.bernoulli_nll(x, self.decode(z, sigmoid=False)) / (k * batchsize)
            self.rec_loss = rec_loss
            self.rec_loss.name = "reconstruction error"
            self.latent_loss = C * gaussian_kl_divergence(mu, ln_var) / batchsize
            self.latent_loss.name = "latent loss"
            self.loss = self.rec_loss + self.latent_loss
            self.loss.name = "loss"
            return self.loss
        return lf
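For reference, a minimal sketch of one way to drive get_loss_func in a manual training step (the Adam optimizer, the dimensions, and the random minibatch below are illustrative assumptions, not my exact setup):

import numpy as np
import chainer

model = VAE(n_in=784, n_latent=20, n_h=500)
optimizer = chainer.optimizers.Adam()
optimizer.setup(model)

x = np.random.rand(32, 784).astype(np.float32)  # stand-in for a real minibatch
loss = model.get_loss_func()(x)                 # rec_loss + latent_loss
model.cleargrads()
loss.backward()
optimizer.update()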
Using this code, I trained my VAE on the MNIST and Fashion-MNIST datasets, and I confirmed that after training it outputs images similar to the input images.
rec_loss is the reconstruction loss, which measures how far the decoded images are from the input images. I think we can use mean_squared_error instead of F.bernoulli_nll.
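For context, F.bernoulli_nll(x, y) treats y as the logit of a pixel-wise Bernoulli distribution and, with its default reduce='sum', computes sum(softplus(y) - x * y), i.e. the sigmoid cross-entropy summed over all elements; this is why decode is called with sigmoid=False in the loss above. A small sanity-check sketch (the random arrays are illustrative):

import numpy as np
import chainer.functions as F

x = np.random.rand(4, 784).astype(np.float32)   # "pixel" targets in [0, 1]
y = np.random.randn(4, 784).astype(np.float32)  # decoder logits (sigmoid=False)

nll = F.bernoulli_nll(x, y)            # default reduce='sum'
manual = F.sum(F.softplus(y) - x * y)  # the same quantity written out
assert np.allclose(nll.data, manual.data, rtol=1e-4)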
So I changed my code as below.
rec_loss += F.mean_squared_error(x, self.decode(z)) / k
But after this change, the training behaves strangely: every output image is the same, i.e. the outputs do not depend on the input images.
What is the problem?
I asked this question in Japanese (https://ja.stackoverflow.com/questions/55477/chainer%E3%81%A7vae%E3%82%92%E4%BD%9C%E3%82%8B%E3%81%A8%E3%81%8D%E3%81%ABloss%E9%96%A2%E6%95%B0%E3%82%92bernoulli-nll%E3%81%A7%E3%81%AF%E3%81%AA%E3%81%8F%E4mse%E3%82%92%E4%BD%BF%E3%81%86%E3%81%A8%E5%AD%A6%E7%BF%92%E3%81%8C%E9%80%B2%E3%81%BE%E3%81%AA%E3%81%84), but nobody has responded, so I am submitting the question here.
When I replace
rec_loss += F.mean_squared_error(x, self.decode(z)) / k
by
rec_loss += F.mean(F.sum((x - self.decode(z)) ** 2, axis=1))
the problem is solved.
But why?
They should be identical except that the latter, F.mean(F.sum(...)), only averages along the minibatch axis (it has already summed over the input data dimension, 784 in the case of flattened MNIST), while the former averages over both the minibatch axis and the input data dimension. This means that the latter loss, in the case of flattened MNIST, is 784 times larger (I'm assuming k is 1). That scale difference matters: with F.mean_squared_error the reconstruction term becomes 784 times smaller relative to the KL term C * gaussian_kl_divergence(mu, ln_var) / batchsize, so the KL term dominates, the encoder is pushed toward the prior, and the decoder produces (nearly) the same output for every input.
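A quick numerical check of this scale argument (a sketch; the batch size 32 and the random arrays are illustrative stand-ins for a real minibatch and the decoder output):

import numpy as np
import chainer.functions as F

x = np.random.rand(32, 784).astype(np.float32)
d = np.random.rand(32, 784).astype(np.float32)  # stand-in for self.decode(z)

per_element = F.mean_squared_error(x, d)          # mean over batch *and* the 784 dims
per_sample = F.mean(F.sum((x - d) ** 2, axis=1))  # sum over dims, mean over batch
assert np.allclose(per_sample.data, per_element.data * 784, rtol=1e-4)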