I'm trying to create a custom LSTMCell in TensorFlow. My machine has only a CPU with 24 GB of RAM (no GPU). First, I wrote an LSTMCell that is meant to replicate the default LSTMCell. The code is given below:
class LSTMCell(tf.keras.layers.AbstractRNNCell):
    """LSTM cell with fused gate weights, meant to be wrapped in ``tf.keras.layers.RNN``.

    The four gate pre-activations are produced together by one input matmul
    and one recurrent matmul against kernels of width ``4 * units``, then
    split into four equal slices.

    Args:
        units: Width of the hidden state and of the carry state.
        **kwargs: Forwarded to the base layer constructor.
    """

    def __init__(self, units, **kwargs):
        self.units = units
        super(LSTMCell, self).__init__(**kwargs)

    @property
    def state_size(self):
        """Sizes of the recurrent states, in the order ``call`` returns them: ``[h, c]``."""
        # AbstractRNNCell leaves this abstract; the wrapping RNN layer needs it
        # to create the initial states.
        return [self.units, self.units]

    @property
    def output_size(self):
        """Size of the per-timestep output ``h``."""
        return self.units

    def build(self, input_shape):
        """Create the fused kernel, recurrent kernel and bias.

        Args:
            input_shape: Shape of the per-timestep input; only the last
                dimension (feature size) is used.
        """
        input_dim = input_shape[-1]
        self.kernel = self.add_weight(
            shape=(input_dim, self.units * 4),
            name='kernel',
            initializer='uniform',
        )
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units * 4),
            name='recurrent_kernel',
            initializer='uniform',
        )
        self.bias = self.add_weight(
            shape=(self.units * 4,),
            name='bias',
            initializer='uniform',
        )

    def _compute_carry_and_output_fused(self, z, c_tm1):
        """Turn the four gate pre-activations into the new carry and output gate.

        Args:
            z: Sequence of four tensors, in the order input gate, forget gate,
                candidate, output gate.
            c_tm1: Carry state from the previous timestep.

        Returns:
            Tuple ``(c, o)``: the new carry state and the output-gate activation.
        """
        z0, z1, z2, z3 = z
        i = K.sigmoid(z0)
        f = K.sigmoid(z1)
        c = f * c_tm1 + i * K.tanh(z2)
        o = K.sigmoid(z3)
        return c, o

    def call(self, inputs, states, training=None):
        """Run one timestep.

        Args:
            inputs: Input tensor for the current timestep.
            states: Previous states; ``states[0]`` is ``h`` and ``states[1]``
                is ``c``.
            training: Unused.

        Returns:
            Tuple ``(h, [h, c])``: the step output and the new states.
        """
        h_tm1 = states[0]
        c_tm1 = states[1]

        z = K.dot(inputs, self.kernel)
        z += K.dot(h_tm1, self.recurrent_kernel)
        z = K.bias_add(z, self.bias)
        z = tf.split(z, num_or_size_splits=4, axis=1)

        c, o = self._compute_carry_and_output_fused(z, c_tm1)
        # NOTE(review): the stock Keras LSTM uses tanh(c) here; sigmoid is kept
        # as written -- confirm whether this deviation is intentional.
        h = o * K.sigmoid(c)

        # The most recent step's tensors are kept on the instance, as before.
        self.h = h
        self.c = c
        return h, [h, c]
This cell works fine and consumes only 8 GB of RAM. I then modified the cell to suit my needs, doubling the number of parameters. The code is below:
class LSTMCell(tf.keras.layers.AbstractRNNCell):
    """LSTM cell with fused gate weights plus an additive companion weight set.

    Each of the input and recurrent kernels has a second trainable matrix of
    the same shape (``kernel_bits`` / ``recurrent_kernel_bits``); the
    effective weight is the sum of the pair. Meant to be wrapped in
    ``tf.keras.layers.RNN``.

    Args:
        units: Width of the hidden state and of the carry state.
        **kwargs: Forwarded to the base layer constructor.
    """

    def __init__(self, units, **kwargs):
        self.units = units
        super(LSTMCell, self).__init__(**kwargs)

    @property
    def state_size(self):
        """Sizes of the recurrent states, in the order ``call`` returns them: ``[h, c]``."""
        # AbstractRNNCell leaves this abstract; the wrapping RNN layer needs it
        # to create the initial states.
        return [self.units, self.units]

    @property
    def output_size(self):
        """Size of the per-timestep output ``h``."""
        return self.units

    def build(self, input_shape):
        """Create the fused kernels, their additive companions, and the bias.

        Args:
            input_shape: Shape of the per-timestep input; only the last
                dimension (feature size) is used.
        """
        input_dim = input_shape[-1]
        self.kernel = self.add_weight(
            shape=(input_dim, self.units * 4),
            name='kernel',
            initializer='uniform',
        )
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units * 4),
            name='recurrent_kernel',
            initializer='uniform',
        )
        self.bias = self.add_weight(
            shape=(self.units * 4,),
            name='bias',
            initializer='uniform',
        )
        self.kernel_bits = self.add_weight(
            shape=(input_dim, self.units * 4),
            name='_diffq_k',
            initializer='uniform',
            trainable=True,
        )
        self.recurrent_kernel_bits = self.add_weight(
            shape=(self.units, self.units * 4),
            name='_diffq_rk',
            initializer='uniform',
            trainable=True,
        )

    def _compute_carry_and_output_fused(self, z, c_tm1):
        """Turn the four gate pre-activations into the new carry and output gate.

        Args:
            z: Sequence of four tensors, in the order input gate, forget gate,
                candidate, output gate.
            c_tm1: Carry state from the previous timestep.

        Returns:
            Tuple ``(c, o)``: the new carry state and the output-gate activation.
        """
        z0, z1, z2, z3 = z
        i = K.sigmoid(z0)
        f = K.sigmoid(z1)
        c = f * c_tm1 + i * K.tanh(z2)
        o = K.sigmoid(z3)
        return c, o

    def call(self, inputs, states, training=None):
        """Run one timestep.

        Args:
            inputs: Input tensor for the current timestep.
            states: Previous states; ``states[0]`` is ``h`` and ``states[1]``
                is ``c``.
            training: Unused.

        Returns:
            Tuple ``(h, [h, c])``: the step output and the new states.
        """
        h_tm1 = states[0]
        c_tm1 = states[1]

        # This method runs once per timestep. Forming ``kernel + kernel_bits``
        # here would create a fresh weight-sized tensor on every step, each of
        # which is retained for the backward pass. By distributivity,
        # x @ (A + B) == x @ A + x @ B, so the products are summed instead and
        # the only per-step intermediates are activation-sized.
        z = K.dot(inputs, self.kernel) + K.dot(inputs, self.kernel_bits)
        z += K.dot(h_tm1, self.recurrent_kernel)
        z += K.dot(h_tm1, self.recurrent_kernel_bits)
        z = K.bias_add(z, self.bias)
        z = tf.split(z, num_or_size_splits=4, axis=1)

        c, o = self._compute_carry_and_output_fused(z, c_tm1)
        # NOTE(review): the stock Keras LSTM uses tanh(c) here; sigmoid is kept
        # as written -- confirm whether this deviation is intentional.
        h = o * K.sigmoid(c)

        # The most recent step's tensors are kept on the instance, as before.
        self.h = h
        self.c = c
        return h, [h, c]
Now, when I try to train with this cell, it consumes all of my RAM within a few seconds and the process gets killed. The model that I'm using is given below:
# Build a single-layer recurrent model around the custom cell.
# The first entry of input_shape is the sequence length; the second is the
# per-step feature size, which is also used as the cell's unit count.
input_shape = (1874, 1024)
# NOTE: `input` shadows the builtin; the name is kept so that any code that
# refers to it keeps working.
input = tf.keras.layers.Input(shape=input_shape, name="input_layer")
lstm = tf.keras.layers.RNN(
    LSTMCell(units=input_shape[1]),
    return_sequences=True,
)
x = lstm(input)
model = tf.keras.models.Model(input, x, name='my_model')
For the same dataset, the RAM consumption of the two cells is very different. I have tried reducing the input dimensions, and the largest LSTM I can train within my memory capacity has 128 units. If I go above that, the RAM fills up and training gets killed. I have done the same thing in PyTorch, and there was no issue. Can anyone point out the cause of the problem I'm having?
I found the reason. The LSTMCell is called once for every element of the sequence. For the given input shape of (1874, 1024), the calculations in `call` are performed 1874 times in every forward pass, and the intermediate results of each of those calls are kept in memory in order to calculate the gradients. That was not my intention: I wanted the calculation to be done only once per forward pass.