Tags: tensorflow, keras, deep-learning, lstm, autoencoder

How to correctly ignore padded or missing timesteps at decoding time in multi-feature sequences with an LSTM autoencoder


I am trying to learn a latent representation of text sequences (each with 3 features) by reconstructing them with an autoencoder. Some of the sequences are shorter than the maximum padded length, i.e. the number of timesteps I am considering (seq_length=15), so I am not sure whether the reconstruction will learn to ignore the padded timesteps when calculating the loss and accuracies.

I followed the suggestions from this answer to crop the outputs, but my losses are NaN, and so are several of the accuracies.

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

seq_length = 15  # maximum (padded) number of timesteps

# One integer-encoded input per feature; 0 is the padding value
input1 = keras.Input(shape=(seq_length,),name='input_1')
input2 = keras.Input(shape=(seq_length,),name='input_2')
input3 = keras.Input(shape=(seq_length,),name='input_3')
input1_emb = layers.Embedding(70,32,input_length=seq_length,mask_zero=True)(input1)
input2_emb = layers.Embedding(462,192,input_length=seq_length,mask_zero=True)(input2)
input3_emb = layers.Embedding(84,36,input_length=seq_length,mask_zero=True)(input3)
merged = layers.Concatenate()([input1_emb, input2_emb,input3_emb])
activ_func = 'tanh'
# Encoder: compress the merged embeddings into a 15-dimensional latent vector
encoded = layers.LSTM(120,activation=activ_func,input_shape=(seq_length,),return_sequences=True)(merged)
encoded = layers.LSTM(60,activation=activ_func,return_sequences=True)(encoded)
encoded = layers.LSTM(15,activation=activ_func)(encoded)

# Decoder reconstruct inputs
decoded1 = layers.RepeatVector(seq_length)(encoded)
decoded1 = layers.LSTM(60, activation= activ_func , return_sequences=True)(decoded1)
decoded1 = layers.LSTM(120, activation= activ_func , return_sequences=True,name='decoder1_last')(decoded1)

The decoder (decoded1) has an output shape of (None, 15, 120).

input_copy_1 = layers.TimeDistributed(layers.Dense(70, activation='softmax'))(decoded1)
input_copy_2 = layers.TimeDistributed(layers.Dense(462, activation='softmax'))(decoded1)
input_copy_3 = layers.TimeDistributed(layers.Dense(84, activation='softmax'))(decoded1)

For each output, I am trying to crop the zero-padded timesteps, as suggested by this answer. padding is 0 where the actual input was missing (zero due to padding) and 1 otherwise.

@tf.function
def cropOutputs(x):
    #x[0] is softmax of respective feature (time distributed) on top of decoder
    #x[1] is the actual input feature
    padding =  tf.cast( tf.not_equal(x[1][1],0), dtype=tf.keras.backend.floatx())
    print(padding)
    return x[0]*tf.tile(tf.expand_dims(padding, axis=-1),tf.constant([1,x[0].shape[2]], tf.int32))

Applying crop function to all three outputs.

input_copy_1 = layers.Lambda(cropOutputs, name='input_copy_1', output_shape=(None, 15, 70))([input_copy_1,input1])
input_copy_2 = layers.Lambda(cropOutputs, name='input_copy_2', output_shape=(None, 15, 462))([input_copy_2,input2])
input_copy_3 = layers.Lambda(cropOutputs, name='input_copy_3', output_shape=(None, 15, 84))([input_copy_3,input3])

My logic is to crop the padded timesteps of each feature (all 3 features of a sequence have the same length, so they are missing the same timesteps). Because each timestep has already been passed through a softmax of its feature's size (70, 462, 84), I zero out a timestep by building a multi-dimensional mask of zeros and ones with that feature size from the padding mask, and multiplying the corresponding softmax output by it.
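The masking logic described above can be sketched in NumPy, assuming integer inputs of shape (batch, seq_length) in which 0 marks padding. One thing worth checking in the TensorFlow version: indexing the input as x[1][1] selects only the second sample of the batch, so its mask would be applied to every sample. The sketch keeps the batch axis and relies on broadcasting instead of tiling. The helper name `zero_padded_steps` and the toy sizes are made up for illustration.

```python
import numpy as np

def zero_padded_steps(probs, token_ids):
    """Zero the rows of `probs` at timesteps where `token_ids` is 0 (padding).

    probs:     (batch, seq_length, n_classes) softmax outputs
    token_ids: (batch, seq_length) integer inputs, 0 = padding
    """
    # One mask value per sample and timestep: 1.0 = real, 0.0 = padded
    mask = (token_ids != 0).astype(probs.dtype)   # (batch, seq_length)
    # A trailing axis of size 1 broadcasts the mask over all classes
    return probs * mask[..., np.newaxis]

# Toy data: 2 sequences, 4 timesteps, 3 classes (hypothetical sizes)
token_ids = np.array([[5, 2, 0, 0],
                      [7, 1, 3, 0]])
probs = np.full((2, 4, 3), 1.0 / 3.0)

cropped = zero_padded_steps(probs, token_ids)
```

In TensorFlow the same broadcast should be expressible as `x[0] * tf.expand_dims(padding, -1)` with `padding` computed from the whole `x[1]`, which makes the `tf.tile` call unnecessary.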

I am not sure whether I am doing this right, as I get NaN losses for these reconstructed inputs, and also NaN accuracies for other tasks that I am learning jointly with this one (it only happens with this cropping).
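One plausible explanation for the NaN values, assuming a categorical cross-entropy loss on probability outputs: Keras rescales each predicted row so that it sums to 1 before taking the log, and a row that has been zeroed out by the crop turns into 0/0. The NumPy sketch below imitates that rescale-then-clip recipe; it is an approximation for illustration, not the Keras source, and `cc_per_step` is a made-up name.

```python
import numpy as np

def cc_per_step(y_true, y_pred, eps=1e-7):
    """Categorical cross-entropy per timestep for probability inputs."""
    with np.errstate(invalid="ignore", divide="ignore"):
        # Rescale each row to sum to 1; an all-zero row becomes 0/0 = nan
        y_pred = y_pred / y_pred.sum(axis=-1, keepdims=True)
    # Clipping does not repair nan values, they pass straight through
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return -(y_true * np.log(y_pred)).sum(axis=-1)

# Two timesteps: a real one (class 1) and a padded one (class 0)
y_true = np.array([[0.0, 1.0, 0.0],
                   [1.0, 0.0, 0.0]])
# The padded timestep has been zeroed out, as the crop function does
y_pred = np.array([[0.2, 0.7, 0.1],
                   [0.0, 0.0, 0.0]])

losses = cc_per_step(y_true, y_pred)   # second entry is nan
mean_loss = losses.mean()              # nan, one bad timestep poisons the mean
```

If this is what happens in the model, a single padded timestep is enough to turn the whole batch loss into NaN, which would also explain why the jointly trained outputs are affected once the gradients become NaN.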


Solution

  • If it helps someone, I ended up cropping the padded entries from the loss directly (taking some Keras code pointers from these answers).

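    import tensorflow as tf

    # masked_val_hotencoded is the one-hot encoding of the padding value;
    # it is defined elsewhere and not shown here.
    @tf.function
    def masked_cc_loss(y_true, y_pred):
        # True where the target timestep is padding
        mask = tf.keras.backend.all(tf.equal(y_true, masked_val_hotencoded), axis=-1)
        # 1.0 for real timesteps, 0.0 for padded ones
        mask = 1 - tf.cast(mask, tf.keras.backend.floatx())

        # The functional form keeps one loss value per timestep, shape
        # (batch, seq_length), so the mask is applied before any reduction
        loss = tf.keras.losses.categorical_crossentropy(y_true, y_pred) * mask

        return tf.keras.backend.sum(loss) / tf.keras.backend.sum(mask)  # average over the unmasked entries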
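To check the arithmetic of a masked loss like this one, here is a small NumPy sketch on toy data. It assumes one-hot targets in which class 0 is the padding class; `masked_cc_loss_np`, `pad_one_hot` and the numbers are hypothetical. The mask only has an effect if the cross-entropy is still per timestep when it is multiplied, so the sketch keeps it at shape (batch, seq_length) until the final sum.

```python
import numpy as np

def masked_cc_loss_np(y_true, y_pred, pad_one_hot, eps=1e-7):
    """Mean categorical cross-entropy over the non-padded timesteps only."""
    # True where the target row equals the one-hot padding vector
    is_pad = np.all(y_true == pad_one_hot, axis=-1)       # (batch, seq_length)
    mask = 1.0 - is_pad.astype(y_pred.dtype)
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    per_step = -(y_true * np.log(y_pred)).sum(axis=-1)    # (batch, seq_length)
    return (per_step * mask).sum() / mask.sum()

pad_one_hot = np.array([1.0, 0.0, 0.0])   # assumption: class 0 means padding

# One sequence with 3 timesteps: classes 1 and 2, then one padded step
y_true = np.array([[[0.0, 1.0, 0.0],
                    [0.0, 0.0, 1.0],
                    [1.0, 0.0, 0.0]]])
y_pred = np.array([[[0.1, 0.8, 0.1],
                    [0.2, 0.3, 0.5],
                    [0.9, 0.05, 0.05]]])

loss = masked_cc_loss_np(y_true, y_pred, pad_one_hot)

# Changing the prediction at the padded timestep should leave the loss unchanged
y_pred_other = y_pred.copy()
y_pred_other[0, 2] = [0.05, 0.05, 0.9]
loss_other = masked_cc_loss_np(y_true, y_pred_other, pad_one_hot)
```

Here the loss is the average of the two real timesteps, -(log 0.8 + log 0.5) / 2, and the padded timestep contributes nothing regardless of what the decoder predicts there.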