
What is the number of nodes in the input layer of the following artificial neural network?


This is for the MNIST dataset:

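from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Flatten
model = Sequential()
model.add(Flatten(input_shape=(28, 28)))  # each 28x28 image becomes one flat vector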

Doesn't Flatten make the input layer have 784 (28 * 28) nodes, so that each pixel value is fed into its own neuron/node?


Solution

  • Yes. Flatten takes a multi-dimensional tensor and returns a one-dimensional vector per sample (the batch dimension is kept), so its output size is the product of the input dimensions. In your example, each image becomes a vector of 28 * 28 = 784 values, one per pixel.
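To see where the 784 comes from, here is a minimal NumPy sketch of the same reshape, assuming channels-last input of shape (batch, 28, 28). The `flatten` helper and the random batch are hypothetical stand-ins, not the Keras API or real MNIST data; for this input, Keras's `Flatten` layer should produce the same shape, keeping the batch dimension and merging the rest in row-major order.

```python
import numpy as np

# Hypothetical batch of 32 MNIST-sized grayscale images (random values stand in for real pixels).
rng = np.random.default_rng(0)
images = rng.random((32, 28, 28), dtype=np.float32)


def flatten(batch: np.ndarray) -> np.ndarray:
    """Collapse every non-batch dimension into one, in row-major (C) order."""
    return batch.reshape(batch.shape[0], -1)


flat = flatten(images)
print(images.shape, "->", flat.shape)  # (32, 28, 28) -> (32, 784)

# Pixel (row i, column j) of each image lands at position i * 28 + j of its vector.
i, j = 5, 17
assert np.array_equal(flat[:, i * 28 + j], images[:, i, j])
```

Note that Flatten itself has no trainable weights: the 784 values are simply the input features, and the next layer with neurons (for example a Dense layer) is what connects to them.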