Tags: python, tensorflow, neural-network, lstm, recurrent-neural-network

Strange memory usage of a custom LSTM layer in TensorFlow; training gets killed


I'm trying to create a custom LSTMCell in TensorFlow. I have a CPU with 24GB of RAM (no GPU). First, I created an LSTMCell that works like the default LSTMCell. The code is given below:

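import tensorflow as tf
from tensorflow.keras import backend as K


class LSTMCell(tf.keras.layers.AbstractRNNCell):
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    # AbstractRNNCell requires these properties; tf.keras.layers.RNN uses
    # them to create the initial [h, c] states and to infer the output size.
    @property
    def state_size(self):
        return [self.units, self.units]

    @property
    def output_size(self):
        return self.units

    def build(self, input_shape):
        input_dim = input_shape[-1]
        # Fused weights for the four gates (i, f, c, o).
        self.kernel = self.add_weight(shape=(input_dim, self.units * 4), name='kernel', initializer='uniform')
        self.recurrent_kernel = self.add_weight(shape=(self.units, self.units * 4), name='recurrent_kernel', initializer='uniform')
        self.bias = self.add_weight(shape=(self.units * 4,), name='bias', initializer='uniform')

    def _compute_carry_and_output_fused(self, z, c_tm1):
        z0, z1, z2, z3 = z
        i = K.sigmoid(z0)  # input gate
        f = K.sigmoid(z1)  # forget gate
        c = f * c_tm1 + i * K.tanh(z2)  # new cell state
        o = K.sigmoid(z3)  # output gate
        return c, o

    def call(self, inputs, states, training=None):
        h_tm1 = states[0]  # previous hidden state
        c_tm1 = states[1]  # previous cell state
        z = K.dot(inputs, self.kernel)
        z += K.dot(h_tm1, self.recurrent_kernel)
        z = K.bias_add(z, self.bias)
        z = tf.split(z, num_or_size_splits=4, axis=1)
        c, o = self._compute_carry_and_output_fused(z, c_tm1)
        # The default LSTMCell applies tanh (not sigmoid) to the cell state here.
        h = o * K.tanh(c)
        self.h = h
        self.c = c
        return h, [h, c]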

This cell works fine and consumes only 8GB of RAM. Then I modified the cell for my needs, doubling the kernel parameters. The code is below:

class LSTMCell(tf.keras.layers.AbstractRNNCell):
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    @property
    def state_size(self):
        return [self.units, self.units]

    @property
    def output_size(self):
        return self.units

    def build(self, input_shape):
        input_dim = input_shape[-1]
        self.kernel = self.add_weight(shape=(input_dim, self.units * 4), name='kernel', initializer='uniform')
        self.recurrent_kernel = self.add_weight(shape=(self.units, self.units * 4), name='recurrent_kernel', initializer='uniform')
        self.bias = self.add_weight(shape=(self.units * 4,), name='bias', initializer='uniform')
        # Extra trainable weights with the same shapes as kernel and recurrent_kernel.
        self.kernel_bits = self.add_weight(shape=(input_dim, self.units * 4), name='_diffq_k', initializer='uniform', trainable=True)
        self.recurrent_kernel_bits = self.add_weight(shape=(self.units, self.units * 4), name='_diffq_rk', initializer='uniform', trainable=True)

    def _compute_carry_and_output_fused(self, z, c_tm1):
        z0, z1, z2, z3 = z
        i = K.sigmoid(z0)
        f = K.sigmoid(z1)
        c = f * c_tm1 + i * K.tanh(z2)
        o = K.sigmoid(z3)
        return c, o

    def call(self, inputs, states, training=None):
        h_tm1 = states[0]
        c_tm1 = states[1]
        # Each sum below creates a new kernel-sized tensor on every call of the cell.
        z = K.dot(inputs, self.kernel + self.kernel_bits)
        z += K.dot(h_tm1, self.recurrent_kernel + self.recurrent_kernel_bits)
        z = K.bias_add(z, self.bias)
        z = tf.split(z, num_or_size_splits=4, axis=1)
        c, o = self._compute_carry_and_output_fused(z, c_tm1)
        h = o * K.tanh(c)
        self.h = h
        self.c = c
        return h, [h, c]
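
The modified cell adds kernel_bits and recurrent_kernel_bits with the same shapes as kernel and recurrent_kernel, so the kernel parameters double while the bias stays the same size. A small arithmetic sketch for the sizes in this question (input_dim = units = 1024, float32 weights assumed), which also shows how large the two summed tensors created by each call of the cell are; the helper names are hypothetical:

```python
# Parameter counts and per-call temporary sizes for the two cells above.
# Assumption: float32 weights (4 bytes each).
BYTES_PER_FLOAT32 = 4


def lstm_param_counts(input_dim, units):
    kernel = input_dim * units * 4      # self.kernel
    recurrent = units * units * 4       # self.recurrent_kernel
    bias = units * 4                    # self.bias
    original = kernel + recurrent + bias
    # The modified cell adds kernel_bits and recurrent_kernel_bits,
    # which have the same shapes as kernel and recurrent_kernel.
    modified = original + kernel + recurrent
    return original, modified


def summed_kernel_bytes_per_call(input_dim, units):
    # Each call evaluates kernel + kernel_bits and
    # recurrent_kernel + recurrent_kernel_bits, producing two new
    # tensors with the shapes of the two kernels.
    floats = input_dim * units * 4 + units * units * 4
    return floats * BYTES_PER_FLOAT32


original, modified = lstm_param_counts(1024, 1024)
print(original, modified)  # 8392704 16781312
print(summed_kernel_bytes_per_call(1024, 1024) / 2**20, "MiB")  # 32.0 MiB
```
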

Now, when I train with this cell, it consumes all of my RAM within a few seconds and the process gets killed. Here is the model I'm using:

input_shape = (1874, 1024)  # (timesteps, features)
inputs = tf.keras.layers.Input(shape=input_shape, name="input_layer")
# RNN applies LSTMCell along the time axis; units equals the 1024 input features.
lstm = tf.keras.layers.RNN(LSTMCell(units=input_shape[1]), return_sequences=True)
x = lstm(inputs)
model = tf.keras.models.Model(inputs, x, name='my_model')
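
The model hands the cell to tf.keras.layers.RNN, which steps along the time axis (1874 steps here) and invokes the cell once per step, collecting every step's output because return_sequences=True. A simplified NumPy sketch of that loop follows; run_cell_over_time and counting_cell are hypothetical helpers, the feature size is shrunk to 4 to keep it light, and the real layer typically builds a symbolic while_loop rather than a Python loop:

```python
import numpy as np


def run_cell_over_time(cell_fn, inputs, initial_states):
    """Hypothetical stand-in for RNN(cell, return_sequences=True).

    inputs has shape (batch, timesteps, features); cell_fn(x_t, states)
    returns (output_t, new_states) and is called once per timestep.
    """
    states = initial_states
    outputs = []
    for t in range(inputs.shape[1]):
        out, states = cell_fn(inputs[:, t, :], states)
        outputs.append(out)
    # return_sequences=True: stack the per-step outputs along the time axis.
    return np.stack(outputs, axis=1)


calls = 0


def counting_cell(x_t, states):
    # Identity "cell" that only counts how often it is invoked.
    global calls
    calls += 1
    return x_t, states


out = run_cell_over_time(counting_cell, np.zeros((1, 1874, 4)), [None])
print(calls, out.shape)  # 1874 (1, 1874, 4)
```
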

With the same dataset, the RAM consumption of the two cells is very different. I tried reducing the input dimensions, and the largest LSTM I can train within my memory has 128 units. Above that, the RAM fills up and training gets killed. I did the same thing in PyTorch without any issue. Can anyone point out the cause of this problem?


Solution

  • I found the reason. LSTMCell's call method is invoked once for every element (timestep) of the sequence. With the given input shape of (1874, 1024), each forward pass runs the calculations in call 1874 times, and TensorFlow keeps these intermediate results in memory to compute the gradients. That was not my intention: I only wanted the calculation to run once per forward pass.
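
One way to run the per-sequence work once per forward pass, sketched under assumptions: the hypothetical DiffqLSTM layer below (not a Keras API and not the asker's final code) replaces the RNN plus cell pair. It computes kernel + kernel_bits and recurrent_kernel + recurrent_kernel_bits once, projects all timesteps through the input kernel in a single einsum, and runs only the recurrent part per step with tf.scan. It assumes zero initial states, no masking or dropout, and the standard tanh output activation. For a rough sense of what the per-step sums can cost, assume one copy of each summed kernel is retained per timestep for backpropagation (float32, ignoring batch size and all other activations): with 1024 units and 1024 input features that is 32 MiB per step, or about 1874 × 32 MiB ≈ 58.6 GiB per sequence, far beyond 24GB; with 128 units it is 2.25 MiB per step, or about 4.1 GiB.

```python
import tensorflow as tf


class DiffqLSTM(tf.keras.layers.Layer):
    """Hypothetical full-sequence LSTM layer (a sketch).

    The summed kernels are computed once per forward call and reused for
    every timestep, instead of being recomputed inside a per-step cell.
    """

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        input_dim = int(input_shape[-1])
        k_shape = (input_dim, self.units * 4)
        rk_shape = (self.units, self.units * 4)
        init = "random_uniform"
        self.kernel = self.add_weight(name="kernel", shape=k_shape, initializer=init)
        self.recurrent_kernel = self.add_weight(name="recurrent_kernel", shape=rk_shape, initializer=init)
        self.bias = self.add_weight(name="bias", shape=(self.units * 4,), initializer=init)
        self.kernel_bits = self.add_weight(name="_diffq_k", shape=k_shape, initializer=init)
        self.recurrent_kernel_bits = self.add_weight(name="_diffq_rk", shape=rk_shape, initializer=init)

    def call(self, inputs):
        # Computed once per forward call, not once per timestep.
        kernel = self.kernel + self.kernel_bits
        recurrent_kernel = self.recurrent_kernel + self.recurrent_kernel_bits

        # Input projection for all timesteps at once: (batch, time, 4 * units).
        z_x = tf.einsum("btd,dk->btk", inputs, kernel) + self.bias
        # tf.scan iterates over the leading axis, so move time first.
        z_x = tf.transpose(z_x, [1, 0, 2])

        batch = tf.shape(inputs)[0]
        h0 = tf.zeros([batch, self.units], dtype=inputs.dtype)
        c0 = tf.zeros([batch, self.units], dtype=inputs.dtype)

        def step(state, z_t):
            h_tm1, c_tm1 = state
            z = z_t + tf.matmul(h_tm1, recurrent_kernel)
            z0, z1, z2, z3 = tf.split(z, 4, axis=1)
            c = tf.sigmoid(z1) * c_tm1 + tf.sigmoid(z0) * tf.tanh(z2)
            h = tf.sigmoid(z3) * tf.tanh(c)
            return h, c

        hs, _ = tf.scan(step, z_x, initializer=(h0, c0))
        # Back to (batch, time, units), like return_sequences=True.
        return tf.transpose(hs, [1, 0, 2])
```

In the question's model, this layer would take the place of the RNN layer and be applied directly to the Input tensor, for example DiffqLSTM(units=input_shape[1]). Moving the input projection out of the loop is the same idea fused LSTM kernels use: only the recurrent matmul truly depends on the previous step.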