Tensorflow.keras: RNN to classify Mnist

I am trying to understand tensorflow.keras.layers.SimpleRNN by building a simple digit classifier. The MNIST digits are 28x28 images, so the main idea is to feed one row of the image at each time step t. I have seen this idea in some blogs, for instance this one, which presents this image:

So my RNN is like this:

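from tensorflow.keras import Sequential, layers
from tensorflow.keras.layers import Dense

self.model = Sequential()
self.model.add(layers.SimpleRNN(128, input_shape=(28, 28)))  # 28 time steps x 28 features
self.model.add(Dense(self.output_size, activation='softmax'))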

I know that an RNN is defined by the following equations:


Parameters: $W = \{w_{hh}, w_{xh}\}$ and $V = \{v\}$.

Input vector: $x_t$.

Update equations:

$$h_t = f(w_{hh} h_{t-1} + w_{xh} x_t)$$

$$y = v h_t$$
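A minimal pure-Python sketch of these equations applied to an MNIST-sized input. It assumes f = tanh (the default activation of SimpleRNN), a zero initial state h_0, no bias terms (matching the equations above), and small random weights standing in for trained ones; the image is dummy data. The helpers rnn_forward, matvec and random_matrix are hypothetical names for illustration. The point it shows is that each x_t is one 28-pixel row, while the length of h_t, and therefore of the layer's output, equals the number of units.

```python
import math
import random


def matvec(m, v):
    """Multiply a matrix (list of rows) by a vector (list)."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def random_matrix(rows, cols, rng):
    """Small random weights standing in for trained ones (assumption)."""
    return [[rng.uniform(-0.1, 0.1) for _ in range(cols)] for _ in range(rows)]


def rnn_forward(image, units, n_classes, seed=0):
    """Apply h_t = tanh(w_hh h_{t-1} + w_xh x_t) over the image rows, then y = v h_T."""
    rng = random.Random(seed)
    input_dim = len(image[0])                     # 28 features per time step (one row)
    w_xh = random_matrix(units, input_dim, rng)   # units x input_dim
    w_hh = random_matrix(units, units, rng)       # units x units
    v = random_matrix(n_classes, units, rng)      # n_classes x units
    h = [0.0] * units                             # h_0 = 0
    for x_t in image:                             # one image row per time step
        pre = [a + b for a, b in zip(matvec(w_hh, h), matvec(w_xh, x_t))]
        h = [math.tanh(p) for p in pre]
    y = matvec(v, h)                              # unnormalized class scores
    return h, y


# Dummy 28x28 "image" with pixel values in [0, 1).
image = [[(r * 28 + c) / 784 for c in range(28)] for r in range(28)]
h, y = rnn_forward(image, units=128, n_classes=10)
print(len(h), len(y))  # 128 10
```

In the Keras model, the Dense layer with softmax plays the role of v here, and it also adds a bias.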


  1. What exactly does "units=128" define? Is it the number of neurons in w_hh and w_xh? Is there anywhere I can find this information?

  2. If I run self.model.summary()

I get

Layer (type)                 Output Shape              Param #   
simple_rnn (SimpleRNN)       (None, 128)               20096     
dense_35 (Dense)             (None, 10)                1290      
Total params: 21,386
Trainable params: 21,386
Non-trainable params: 0

How do I go from the number of units to these numbers of parameters "20096" and "1290"?

  3. In this example the sequence always has the same length. However, if I am dealing with text, the sequences have variable length. So what exactly does input_shape=(28,28) mean? I could not find this information anywhere.


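    1. The units argument is the number of neurons in the layer, i.e. the dimensionality of the hidden state h_t, which is also the layer's output. With units=128 and 28 input features per time step, w_xh has 128x28 entries and w_hh has 128x128 entries. This is described in the SimpleRNN documentation.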

    2. The number of parameters depends on the layer's input size and on the number of units. For the SimpleRNN layer it is 128 * 128 (w_hh) + 128 * 28 (w_xh) + 128 (bias) = 20096 (see this answer). For the Dense layer it is 128 * 10 + 10 = 1290. The added 128 and 10 are the bias weights of each layer, which are enabled by default.
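The arithmetic above can be checked with a small sketch. The function names simple_rnn_params and dense_params are hypothetical; the terms mirror the three weights Keras creates for SimpleRNN (kernel of shape (input_dim, units), recurrent_kernel of shape (units, units), bias of shape (units,)) and the kernel plus bias of a Dense layer.

```python
def simple_rnn_params(units, input_dim, use_bias=True):
    """Trainable parameters of a SimpleRNN layer."""
    kernel = input_dim * units          # input-to-hidden weights (w_xh)
    recurrent_kernel = units * units    # hidden-to-hidden weights (w_hh)
    bias = units if use_bias else 0     # one bias per unit
    return kernel + recurrent_kernel + bias


def dense_params(input_dim, units, use_bias=True):
    """Trainable parameters of a Dense layer."""
    return input_dim * units + (units if use_bias else 0)


rnn = simple_rnn_params(units=128, input_dim=28)
dense = dense_params(input_dim=128, units=10)
print(rnn, dense, rnn + dense)  # 20096 1290 21386
```

Note that the number of time steps (28 rows) does not appear: the same weights are reused at every step, so the parameter count only depends on the feature size and the number of units.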

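    3. input_shape = (28, 28) describes a single sample and does not include the batch dimension: each sample is a sequence of 28 time steps, each a vector of 28 features (one image row, as depicted in your image). Keras adds the batch dimension itself, so the layer receives tensors of shape (batch_size, 28, 28). For variable-length sequences such as text, you can leave the time dimension unspecified, e.g. input_shape=(None, n_features). All sequences within one batch must still have the same length, so shorter inputs are usually padded and longer ones truncated to a common length, for example with tf.keras.preprocessing.sequence.pad_sequences; a Masking layer (or mask_zero=True in an Embedding layer) lets the RNN skip the padded steps.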
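A minimal sketch of the padding step for variable-length sequences, written in pure Python so it runs without TensorFlow. The helper pad_sequences_sketch is hypothetical; it is intended to mirror the default behaviour of tf.keras.preprocessing.sequence.pad_sequences (padding='pre', truncating='pre', value 0), but in a real model you would call that function instead.

```python
def pad_sequences_sketch(sequences, maxlen, value=0):
    """Pre-pad shorter sequences and pre-truncate longer ones to maxlen."""
    if maxlen <= 0:
        raise ValueError("maxlen must be positive")
    padded = []
    for seq in sequences:
        seq = list(seq)[-maxlen:]                  # keep the last maxlen items
        padded.append([value] * (maxlen - len(seq)) + seq)
    return padded


# Three token-id sequences of different lengths become one rectangular batch.
batch = [[5, 3], [7, 1, 4, 2], [9]]
print(pad_sequences_sketch(batch, maxlen=3))  # [[0, 5, 3], [1, 4, 2], [0, 0, 9]]
```

Each padded token would then typically be embedded (e.g. with an Embedding layer) to obtain the (time steps, features) layout the RNN expects.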