Tags: python-3.x, machine-learning, neural-network, autoencoder, chainer

Why does this VAE implementation sometimes add a sigmoid operation?


I'm building a Variational Autoencoder (VAE) in Python using the Chainer framework (link). I have found various working examples on GitHub and am trying to adapt one of them. I have been successful in getting it to run and it works just fine, but there's still something I don't understand.

In the following snippet, defining behavior for the decoder, there's an optional extra sigmoid function:

def decode(self, z, sigmoid=True):
    h = F.leaky_relu(self.ld1(z))
    h = F.leaky_relu(self.ld2(h))
    h = self.ld3(h)
    if sigmoid:
        return F.sigmoid(h)
    else:
        return h
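The two code paths are consistent with each other: calling decode with sigmoid=True is the same as applying a sigmoid to the raw output of the last layer. A minimal NumPy sketch of that pattern (a single made-up linear layer stands in for the three real decoder layers, purely for illustration):

```python
import numpy as np

def sigmoid(h):
    return 1.0 / (1.0 + np.exp(-h))

def decode(z, W, sigmoid_out=True):
    # toy stand-in for the decoder: one linear layer instead of three
    h = z @ W                      # raw, unbounded values ("logits")
    return sigmoid(h) if sigmoid_out else h

rng = np.random.default_rng(0)
z = rng.normal(size=(4, 8))
W = rng.normal(size=(8, 5))

logits = decode(z, W, sigmoid_out=False)  # what the loss function sees
probs = decode(z, W, sigmoid_out=True)    # what sample generation sees

# the sigmoid path is just the raw path squashed into (0, 1)
assert np.allclose(probs, sigmoid(logits))
```

So the flag only controls whether the caller receives raw logits or probabilities in (0, 1); the learned parameters are identical either way.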

This function is used with sigmoid=False during training, inside the loss function:

def lf(x):
    mu, ln_var = self.encode(x)
    batchsize = len(mu)

    # reconstruction loss
    rec_loss = 0
    for l in six.moves.range(k):
        z = F.gaussian(mu, ln_var)
                                                       # ↓here↓
        rec_loss += F.bernoulli_nll(x, self.decode(z, sigmoid=False)) / (k * batchsize)
    self.rec_loss = rec_loss

    # adding latent loss
    self.latent_loss = beta * gaussian_kl_divergence(mu, ln_var) / batchsize
    self.loss = self.rec_loss + self.latent_loss
    chainer.report({'rec_loss': self.rec_loss, 'latent_loss': self.latent_loss, 'loss': self.loss}, observer=self)
    return self.loss
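For reference, the latent term here uses the closed-form KL divergence between the encoder's diagonal Gaussian and a standard normal prior. A NumPy sketch of that formula (a sketch of the standard VAE expression, not Chainer's implementation):

```python
import numpy as np

def gaussian_kl_divergence(mu, ln_var):
    # KL( N(mu, exp(ln_var)) || N(0, I) ), summed over all elements:
    # 0.5 * sum( mu^2 + var - ln_var - 1 )
    var = np.exp(ln_var)
    return 0.5 * np.sum(mu ** 2 + var - ln_var - 1.0)

# a standard normal compared with itself has zero divergence
mu = np.zeros((2, 3))
ln_var = np.zeros((2, 3))
assert gaussian_kl_divergence(mu, ln_var) == 0.0
```

Note that this term only involves mu and ln_var from the encoder, so it is unaffected by whether the decoder output is sigmoided; the sigmoid question only concerns the reconstruction term.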

And it is used with sigmoid=True (implicitly, via the default) when generating examples after training:

z = C.Variable(np.random.normal(0, 1, (self._batchsize, args.dimz)).astype(np.float32))
with C.using_config('train', False), C.no_backprop_mode():
    xrand = self._model.decode(z)  # ←here
xrand = np.asarray(xrand.array).reshape(self._batchsize, 3, 18, 11)

Why this extra sigmoid function? What role does it fulfill? Why add it after training, but not during it?


Solution

  • Read the note in the documentation for F.bernoulli_nll: its input argument should not already be sigmoided, because the function applies the sigmoid internally. That is why sigmoid=False is specified when feeding the hidden variable to F.bernoulli_nll, while sigmoid=True is used when you actually want probabilities, e.g. when generating samples. (I had exactly the same confusion.)
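In other words, F.bernoulli_nll expects raw logits and applies the sigmoid itself, which lets it compute the loss in a numerically stable way; passing already-sigmoided values silently computes the wrong loss (a "double sigmoid"). A NumPy sketch of the idea (the same math, not Chainer's exact implementation):

```python
import numpy as np

def sigmoid(y):
    return 1.0 / (1.0 + np.exp(-y))

def bernoulli_nll(x, y):
    # -sum( x*log(p) + (1-x)*log(1-p) ) with p = sigmoid(y),
    # computed stably from the logits y as softplus(y) - x*y
    return np.sum(np.logaddexp(0.0, y) - x * y)

x = np.array([0.0, 1.0, 1.0])
logits = np.array([-2.0, 0.5, 3.0])

good = bernoulli_nll(x, logits)           # correct: raw logits go in
bad = bernoulli_nll(x, sigmoid(logits))   # wrong: sigmoid applied twice
```

The stable form never takes log of a probability directly, which is exactly why the function wants logits: feeding it probabilities both degrades numerical stability and changes the value of the loss.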