Tags: python, keras, deep-learning, neural-network

My autoencoder is not learning to predict values [Updated]


I am building a variational autoencoder in Keras, with input shape X = (1, 50) and output shape Y = (1, 20).

I have uploaded the dataset; you can download it from here.

I have a single input, and I want to learn the relation between the input and the output (the data is one-dimensional binary values), but I always get these results:

[plot: training results]

I tried changing the activation functions and the loss, with no improvement.

from keras.layers import Lambda, Input, Dense, Dropout
from keras.models import Model
from keras import backend as K, optimizers
import numpy as np
import matplotlib.pyplot as plt

# Function for reparameterization trick
def sampling(args):
    z_mean, z_log_var = args
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    epsilon = K.random_normal(shape=(batch, dim))
    return z_mean + K.exp(0.5 * z_log_var) * epsilon

# Load your data
# Note: Replace this with your actual data loading
# training_feature = X
# ground_truth_r = Y
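# A runnable stand-in (assumption: random binary data with the shapes this
# script uses; replace with the real dataset). X/Y are the training features
# and targets; XX/YY are the validation features and targets used below.
X = np.random.randint(0, 2, size=(1000, 32)).astype('float32')
Y = np.random.randint(0, 2, size=(1000, 20)).astype('float32')
XX = np.random.randint(0, 2, size=(200, 32)).astype('float32')
YY = np.random.randint(0, 2, size=(200, 20)).astype('float32')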

original_dim = 32  # Adjust according to your data shape
latent_dim = 32

# Encoder network
inputs_x = Input(shape=(original_dim, ), name='encoder_input')
inputs_x_dropout = Dropout(0.25)(inputs_x)
inter_x1 = Dense(128, activation='tanh')(inputs_x_dropout)
inter_x2 = Dense(64, activation='tanh')(inter_x1)
z_mean = Dense(latent_dim, name='z_mean')(inter_x2)
z_log_var = Dense(latent_dim, name='z_log_var')(inter_x2)
z = Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var])
encoder = Model(inputs_x, [z_mean, z_log_var, z], name='encoder')


# Decoder network for reconstruction
latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
inter_y1 = Dense(64, activation='tanh')(latent_inputs)
inter_y2 = Dense(128, activation='tanh')(inter_y1)

outputs_reconstruction = Dense(original_dim)(inter_y2)  # original_dim should be 32
decoder = Model(latent_inputs, outputs_reconstruction, name='decoder')
decoder.compile(optimizer='adam', loss='mean_squared_error')


from keras.layers import BatchNormalization

# Predictor network: maps latent codes to the targets
latent_input_for_predictor = Input(shape=(latent_dim,))

# Building the predictor model using the functional API
x = Dense(1024, activation='relu')(latent_input_for_predictor)
x = Dense(512, activation='relu')(x)
x = Dense(64, activation='relu')(x)
x = Dense(32, activation='relu')(x)
x = BatchNormalization()(x)
predictor_output = Dense(Y.shape[1], activation='linear')(x)  # Adjust the output dimension as per your requirement




# Create the model
predictor = Model(inputs=latent_input_for_predictor, outputs=predictor_output)

# Compile the model
optimizer = optimizers.Adam(learning_rate=0.001)
predictor.compile(loss='mean_squared_error', optimizer=optimizer, metrics=['accuracy'])

# Train the reconstruction model
history_reconstruction = decoder.fit(X, X, epochs=100, batch_size=100, shuffle=True, validation_data=(XX, XX))

# Encode the inputs, then train the prediction model on the latent codes
latent_representations = encoder.predict(X)[2]
history_prediction = predictor.fit(latent_representations, Y, epochs=100, batch_size=100, shuffle=True, validation_data=(encoder.predict(XX)[2], YY))




# Save models and plot training/validation loss
encoder.save("BrmEnco_Updated.h5", overwrite=True)
decoder.save("BrmDeco_Updated.h5", overwrite=True)
predictor.save("BrmPred_Updated.h5", overwrite=True)

plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(history_reconstruction.history['loss'], label='Decoder Training Loss')
plt.plot(history_reconstruction.history['val_loss'], label='Decoder Validation Loss')
plt.title('Decoder Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(history_prediction.history['loss'], label='Predictor Training Loss')
plt.plot(history_prediction.history['val_loss'], label='Predictor Validation Loss')
plt.title('Predictor Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend()

plt.show()

Solution

  • I ran it on the digits images from sklearn, so I increased the input shape from 32 to 36 and changed the predictor output to a 9-dimensional multilabel target. I normalised the input data and used a smaller batch size.

    Input data X is a 36-dimensional binary vector (e.g. sample 0 = [1 0 1 0 ...]) and y is a 9-dimensional multilabel indicator target. The multilabel information in y is: [is odd?, is prime?, is multiple of 3?, ...].

    I used binary cross-entropy loss for both X and Y, to strongly penalise wrong bits. Both the decoder and predictor outputs use sigmoid activations. I removed the dropout layer, as it seemed to hurt performance, and added some batch normalisation layers. I used the Nadam optimizer.

    I run .fit for one epoch at a time, then manually calculate the train and validation accuracies before running the next epoch. Initially I was using Keras' accuracy values, but found them to be incorrect for some outputs.


    [plot: loss and accuracy curves]

    The solid lines are loss, and show that the model is converging. The model memorises the train Y (dotted orange - 100%), and on the validation Y it gets about 70% right (dashed orange). So its performance with Y is quite good. It doesn't perfectly reconstruct X - it gets X exactly right about 50% of the time (dotted blue), and on the validation X it only gets it exactly right <5% of the time (dashed blue).

    Although its reconstructions seem to have relatively low accuracy, this is because if even a single pixel is wrong in the 36-dim reconstruction, the whole reconstruction is counted as incorrect (according to the metric specified in the comments). Its actual reconstructions are close to the originals and look quite good. Even though it doesn't get X exactly right, it is often very close, differing by a pixel here or there:

    [image: left: original X train; right: reconstruction]

    Since the reconstructions are close to the originals, that suggests to me the latent space is not bad, and is perhaps usable. Something to consider is whether X needs to be exact if one is interpolating (i.e. making up new values) in latent space; a small interpolation sketch follows the code below.

    from keras.layers import Lambda, Input, Dense, Dropout, BatchNormalization
    from keras.models import Model
    from keras import backend as K
    
    import numpy as np
    import matplotlib.pyplot as plt
    import tensorflow as tf
    
    from sklearn import set_config
    set_config(transform_output='default')
    
    #Set the random seed for consistent results
    import random
    random.seed(0)
    tf.random.set_seed(0)
    np.random.seed(0)
    #clear session for each run
    K.clear_session()
    
    #
    #Load digits data
    #
    from sklearn.datasets import load_digits
    from sklearn.preprocessing import StandardScaler
    
    digits = load_digits()
    X, Y_original = digits['data'], digits['target']
    
    #Create a multilabel-indicator Y: each y is a 9-dim *binary* vector
    Y_original = Y_original.reshape(-1, 1).astype(int)
    Y_multilabel = np.empty((Y_original.shape[0], 9))
    for idx, y in enumerate(Y_original):
        is_odd = bool(y % 2)
        is_prime = bool(y in [2, 3, 5, 7])  # primes among the digits 0-9
        is_multiple_3 = bool(y in [3, 6, 9])
        is_large = bool(y >= 5)
        is_extreme = bool(y in [0, 9])
        binary_digits = [int(d) for d in bin(y[0] ** 2 + 10)[-4:]] #last 4 binary digits of y**2 + 10
        
        Y_multilabel[idx, :] = np.array([
            is_odd, is_prime, is_multiple_3, is_large, is_extreme] + binary_digits
        ).astype(int)
    
    Y = Y_multilabel
    
    #Create a 36-dim binary X (8x8 images cropped to 6x6)
    X = X.reshape(-1, 8, 8)[:, 1:-1, 1:-1].reshape(-1, 36)
    X = StandardScaler().fit_transform(X)
    X = np.where(X < 0, 0, 1) #make X binary
    
    input_dim = X.shape[1]
    multilabel_size = Y_multilabel.shape[1]
    
    #View some samples
    f, axs = plt.subplots(5, 5, figsize=(4, 4), layout='tight')
    for i, ax in enumerate(axs.flatten()):
        ax.imshow(X[i, :].reshape(6, 6), cmap='binary')
        ax.axis('off')
        y_label = str(Y[i].astype(int)).replace(' ', '')[1:-1]
        ax.set_title(y_label, fontsize=8)
    f.suptitle('Samples from normalised digits data', fontsize=10)
    plt.show()
    
    # reparameterization trick
    # instead of sampling from Q(z|X), sample eps = N(0,I)
    # z = z_mean + sqrt(var)*eps
    def sampling(args):
        z_mean, z_log_var = args
        batch = K.shape(z_mean)[0]
        dim = K.int_shape(z_mean)[1]
        # by default, random_normal has mean=0 and std=1.0
        epsilon = K.random_normal(shape=(batch, dim))
        return z_mean + K.exp(0.5 * z_log_var) * epsilon
    
    # Define VAE model components
    intermediate_dim = 32
    latent_dim = 32
    
    # Encoder network
    inputs_x = Input(shape=(input_dim,), name='encoder_input')
    # inputs_x_dropout = Dropout(0.25)(inputs_x)  # dropout removed: it seemed to hurt performance
    h = Dense(1024, activation='relu')(inputs_x)
    h = BatchNormalization()(h)
    h = Dense(512, activation='relu')(h)
    h = BatchNormalization()(h)
    h = Dense(224, activation='relu')(h)
    h = BatchNormalization()(h)
    inter_x1 = Dense(128, activation='relu')(h)
    inter_x2 = Dense(intermediate_dim, activation='relu')(inter_x1)
    
    z_mean = Dense(latent_dim, name='z_mean')(inter_x2)
    z_log_var = Dense(latent_dim, name='z_log_var')(inter_x2)
    z = Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var])
    encoder = Model(inputs_x, [z_mean, z_log_var, z], name='encoder')
    
    # Decoder network for reconstruction
    latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
    inter_y1 = Dense(intermediate_dim, activation='relu')(latent_inputs)
    
    inter_y1 = Dense(224, activation='relu')(inter_y1)
    inter_y1 = BatchNormalization()(inter_y1)
    inter_y1 = Dense(512, activation='relu')(inter_y1)
    inter_y1 = BatchNormalization()(inter_y1)
    inter_y1 = Dense(1024, activation='relu')(inter_y1)
    inter_y1 = BatchNormalization()(inter_y1)
    
    inter_y2 = Dense(128, activation='relu')(inter_y1)
    outputs_reconstruction = Dense(input_dim, activation='sigmoid')(inter_y2)
    decoder = Model(latent_inputs, outputs_reconstruction, name='decoder')
    
    # Separate network for multilabel indicator prediction from inter_y2
    outputs_prediction = Dense(multilabel_size, activation='sigmoid')(inter_y2)
    predictor = Model(latent_inputs, outputs_prediction, name='predictor')
    
    # Instantiate VAE model with two outputs
    outputs_vae = [decoder(z), predictor(z)]
    vae = Model(inputs_x, outputs_vae, name='vae_mlp')
    vae.compile(optimizer='nadam', loss='binary_crossentropy')
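    # Note (an assumption on my part, not part of the original run): this objective
    # has only the two binary cross-entropy terms; the KL-divergence regulariser of
    # a standard VAE is not included. One way to add it, before compiling:
    # kl_loss = -0.5 * K.mean(K.sum(
    #     1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1))
    # vae.add_loss(kl_loss)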
    
    # Train the model
    # Train/validation split: hold out the last 360 samples (~20%) for validation
    val_size = 360
    
    X_trn = X[:-val_size]
    Y_trn = Y[:-val_size]
    
    X_val = X[-val_size:]
    Y_val = Y[-val_size:]
    
    from collections import defaultdict
    metrics = defaultdict(list)
    
    for epoch in range(70):
        history = vae.fit(X_trn, [X_trn, Y_trn], batch_size=32, shuffle=True)
    
        h = history.history
        metrics['trn_predictor_loss'].extend(h['predictor_loss'])
        metrics['trn_decoder_loss'].extend(h['decoder_loss'])
        metrics['trn_loss'].extend(h['loss'])
        
        #Manually calculate accuracy for trn and val
        for mode in ['trn', 'val']:
            XY = [X_trn, Y_trn] if mode == 'trn' else [X_val, Y_val]
            n_samples = len(XY[0])
            
            soft_recon, soft_pred = vae.predict(XY[0])
            
            hard_recon = (soft_recon > 0.5).astype(int)
            hard_pred = (soft_pred > 0.5).astype(int)
            
            recon_acc = sum(
                [np.array_equal(xhat, x) for xhat, x in zip(hard_recon, XY[0])]
            ) / n_samples * 100
            
            pred_acc = sum(
                [np.array_equal(yhat, y) for yhat, y in zip(hard_pred, XY[1])]
            ) / n_samples * 100
            
            metrics[mode + '_decoder_acc'].append(recon_acc)
            metrics[mode + '_predictor_acc'].append(pred_acc)
    
    plt.plot(metrics['trn_loss'], 'C3', lw=2, label='loss')
    plt.plot(metrics['trn_decoder_loss'], 'C0', lw=2, label='loss | decoder')
    plt.plot(metrics['trn_predictor_loss'], 'C1', lw=2, label='loss | predictor')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    
    ax2 = plt.gca().twinx()
    ax2.plot(metrics['trn_decoder_acc'], 'C0', ls=':', label='trn acc | decoder')
    ax2.plot(metrics['trn_predictor_acc'], 'C1', ls=':', label='trn acc | predictor')
    
    ax2.plot(metrics['val_decoder_acc'], 'C0', ls='--', label='val acc | decoder')
    ax2.plot(metrics['val_predictor_acc'], 'C1', ls='--', label='val acc | predictor')
    ax2.set_ylabel('accuracy (%)')
    plt.gcf().legend(bbox_to_anchor=(0.7, 1.1), ncol=2)
    plt.gcf().set_size_inches(7, 4)
    
    soft_recon, soft_pred = vae.predict(X)
    
    #Convert soft predictions (probabilities) to hard binary 0/1
    recon_binary = soft_recon > 0.5
    pred_binary = soft_pred > 0.5
    
    f, axs = plt.subplots(nrows=25, ncols=2, figsize=(3, 35))
    axs = axs.flatten()
    for i, ax in zip(range(1, len(axs), 2), (axs[1::2])):
        ax.imshow(recon_binary[i, :].reshape(6, 6), cmap='binary')
        axs[i - 1].imshow(X[i, :].reshape(6, 6), cmap='binary')
        
        y_multilabel = str(Y[i].astype(int)).replace(' ', '')[1:-1]
        yhat = str(pred_binary[i].astype(int)).replace(' ', '')[1:-1]
        ax.set_title(r'$\hat{y}$:' + yhat + '\n$y$:' + y_multilabel,
                     fontsize=8, fontproperties={'family': 'monospace'})
    # f.suptitle('Digit reconstructions and predictions', fontsize=10)
    [ax.axis('off') for ax in axs]
    plt.tight_layout()
    plt.show()
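    On the interpolation point above: a minimal sketch (an illustration assuming the
    encoder, decoder and X defined above are in scope; it is not part of the original
    run) that linearly interpolates between the z_mean codes of two training samples
    and decodes the intermediate points:

    # Interpolate between the latent codes of two samples and decode the results
    z_a = encoder.predict(X[0:1])[0]  # index 0 = z_mean, a deterministic code
    z_b = encoder.predict(X[1:2])[0]

    f, axs = plt.subplots(1, 7, figsize=(7, 1.5))
    for ax, w in zip(axs, np.linspace(0, 1, 7)):
        z_interp = (1 - w) * z_a + w * z_b          # convex combination in latent space
        x_interp = decoder.predict(z_interp) > 0.5  # decode, then binarise like X
        ax.imshow(x_interp.reshape(6, 6), cmap='binary')
        ax.set_title(f'w={w:.2f}', fontsize=8)
        ax.axis('off')
    plt.show()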