Tags: python, deep-learning, pytorch, lstm

Variational Autoencoder (VAE) returns consistent output


I'm working on signal compression and reconstruction with a VAE. I trained the model on 1600 signal fragments, but the 1600 reconstructed signals all have very similar values; within a single batch the outputs are almost identical. Since this is a VAE, the loss function contains a binary cross-entropy (BCE) term, so the model's output should lie between 0 and 1 (the input data is also normalized to the 0-1 range).
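(For context, the 0-to-1 normalization is ordinary min-max scaling; the sketch below, with the illustrative helper minmax_scale applied per fragment, is only meant to show the idea, not my exact preprocessing code.)

import torch

def minmax_scale(fragment, eps=1e-8):
    # Scale one signal fragment into [0, 1] (illustrative helper, not the original preprocessing)
    lo, hi = fragment.min(), fragment.max()
    return (fragment - lo) / (hi - lo + eps)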

VAE model (LSTM):

import torch
import torch.nn as nn

class LSTM_VAE(nn.Module):
    def __init__(self,
                 input_size=3000,
                 hidden=[1024, 512, 256, 128, 64],
                 latent_size=64,
                 num_layers=8,
                 bidirectional=True):
        super().__init__()

        self.input_size = input_size
        self.hidden = hidden
        self.latent_size = latent_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional

        self.actv = nn.LeakyReLU()

        # Encoder LSTM: bidirectional, so its output width is hidden[0] * 2
        self.encode = nn.LSTM(input_size=self.input_size,
                              hidden_size=self.hidden[0],
                              num_layers=self.num_layers,
                              batch_first=True,
                              bidirectional=True)
        self.bn_encode = nn.BatchNorm1d(1)

        # Decoder LSTM: consumes the latent vector
        self.decode = nn.LSTM(input_size=self.latent_size,
                              hidden_size=self.hidden[2],
                              num_layers=self.num_layers,
                              batch_first=True,
                              bidirectional=True)
        self.bn_decode = nn.BatchNorm1d(1)

        # Encoder head: fc31 -> mu, fc32 -> log variance
        self.fc1 = nn.Linear(self.hidden[0]*2, self.hidden[1])
        self.fc2 = nn.Linear(self.hidden[1], self.hidden[2])
        self.fc31 = nn.Linear(self.hidden[2], self.latent_size)
        self.fc32 = nn.Linear(self.hidden[2], self.latent_size)
        self.bn1 = nn.BatchNorm1d(1)
        self.bn2 = nn.BatchNorm1d(1)
        self.bn3 = nn.BatchNorm1d(1)

        # Decoder head: maps the LSTM output back to the original signal length
        self.fc4 = nn.Linear(self.hidden[2]*2, self.hidden[1])
        self.fc5 = nn.Linear(self.hidden[1], self.hidden[0])
        self.fc6 = nn.Linear(self.hidden[0], self.input_size)
        self.bn4 = nn.BatchNorm1d(1)
        self.bn5 = nn.BatchNorm1d(1)
        self.bn6 = nn.BatchNorm1d(1)

    def encoder(self, x):
        x = torch.unsqueeze(x, 1)   # (batch, 1, input_size)
        x, _ = self.encode(x)
        x = self.actv(x)
        x = self.fc1(x)
        x = self.actv(x)
        x = self.fc2(x)
        x = self.actv(x)

        mu = self.fc31(x)
        log_var = self.fc32(x)

        return mu, log_var

    def decoder(self, z):
        z, _ = self.decode(z)
        z = self.bn_decode(z)
        z = self.actv(z)
        z = self.fc4(z)
        z = self.bn4(z)
        z = self.fc5(z)
        z = self.bn5(z)
        z = self.fc6(z)
        z = self.bn6(z)
        z = torch.sigmoid(z)        # keep the reconstruction in (0, 1) for BCE

        return torch.squeeze(z)

    def sampling(self, mu, log_var):
        # Reparameterization trick: z = mu + sigma * eps
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)

        return mu + eps * std

    def forward(self, x):
        mu, log_var = self.encoder(x.view(-1, self.input_size))
        z = self.sampling(mu, log_var)
        z = self.decoder(z)

        return z, mu, log_var

Loss function and train code:

import torch.nn.functional as F

def lossF(recon_x, x, mu, logvar, input_size):
    # Reconstruction term (BCE) + KL divergence between q(z|x) and N(0, I)
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, input_size), reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD

optim = torch.optim.Adam(model.parameters(), lr=opt.lr)

for epoch in range(opt.epoch):
    train_loss = 0
    for batch_idx, data in enumerate(train_set):
        data = data.to(device)
        optim.zero_grad()
        recon_x, mu, logvar = model(data)
        loss = lossF(recon_x, data, mu, logvar, opt.input_size)
        loss.backward()
        train_loss += loss.item()
        optim.step()

I built this code by referring to other people's example code and changed only a few parameters. I have since rebuilt the code, changed the dataset, and updated the parameters, but nothing has worked. If you have any suggestions for solving this problem, please let me know.


Solution

  • I've found the reason for the issue. It turns out that the decoder pushes its output into the range 0.4 to 0.6 to keep the BCE loss stable. BCE loss cannot reach 0 even when the prediction exactly matches the target, and the loss value is non-linear with respect to the output range. The easiest way for the network to lower the loss is to output roughly 0.5 everywhere, and that is exactly what my model did (see the sketch below). To avoid this, I standardized my data and added some outlier data to work around the BCE issue. A VAE is a complicated network, for sure.
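As a minimal sketch of that BCE behaviour (not part of the original training code): with targets strictly between 0 and 1, the BCE of a perfect prediction is still non-zero, and a constant output of 0.5 costs exactly log 2 per element regardless of the target, whereas MSE does go to 0 for a perfect prediction.

import torch
import torch.nn.functional as F

# Targets strictly inside (0, 1), as with min-max normalized signals
t = torch.linspace(0.05, 0.95, steps=10)

perfect  = F.binary_cross_entropy(t, t, reduction='none')                        # prediction == target
constant = F.binary_cross_entropy(torch.full_like(t, 0.5), t, reduction='none')  # always predict 0.5

print(perfect)           # non-zero everywhere, largest near t = 0.5
print(constant)          # exactly log(2) ~= 0.693 for every target
print(F.mse_loss(t, t))  # exactly 0 for a perfect prediction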