Tags: python, machine-learning, keras, deep-learning, image-recognition

How to convert keras sequential API to functional API


I am new to deep learning and am trying to convert this Sequential model to the Functional API so I can run it on the CIFAR-10 dataset. Here is the Sequential version:

from tensorflow.keras import layers, models

model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))

model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10))  # raw logits, no softmax

And here is my attempt at converting this into the functional API:

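from tensorflow.keras import layers
from tensorflow.keras.layers import (Input, Conv2D, MaxPooling2D,
                                     GlobalAveragePooling2D, Activation)
from tensorflow.keras.models import Model

input_shape = (32, 32, 3)  # CIFAR-10 image shape

model_input = Input(shape=input_shape)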

x = Conv2D(32, (3, 3), activation='relu',padding='valid')(model_input)
x = MaxPooling2D((2,2))(x)
x = Conv2D(32, (3, 3), activation='relu')(x)
x = MaxPooling2D((2,2))(x)
x = Conv2D(32, (3, 3))(x)

x = GlobalAveragePooling2D()(x)
x = Activation(activation='softmax')(x)

model = Model(model_input, x, name='nin_cnn')

x = layers.Flatten()
x = layers.Dense(64, activation='relu')
x = layers.Dense(10)

Here is the compile and train code:

import tensorflow as tf

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

history = model.fit(train_images, train_labels, epochs=10, 
                    validation_data=(test_images, test_labels))

The original Sequential model reaches an accuracy of 0.7175999879837036, while the functional version reaches only 0.0502999983727932. I'm not sure where I went wrong when rewriting the code; any help would be appreciated. Thanks.


Solution

  • Your two models are not the same. The second and third convolutional layers have 64 filters in your Sequential model but only 32 in your functional model, and the third one in the functional model has no ReLU activation. You also did not include the fully connected layers in the functional model: you created the Flatten and Dense layers only after constructing the model and never called them on x, so they are not part of it.

    If in doubt in the future, you can run

    model.summary()

    on both models and compare the layer types, output shapes, and parameter counts to check whether they match.
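For reference, here is a minimal sketch of the functional version rebuilt to match the Sequential model, assuming TensorFlow 2.x with `tf.keras` (`build_sequential` and `build_functional` are hypothetical helper names, not part of Keras). Besides restoring the 64-filter layers and calling `Flatten` and the `Dense` layers on `x`, it drops the attempt's `GlobalAveragePooling2D` and `Activation('softmax')`: the final `Dense(10)` has no activation, so it outputs logits, which is what `SparseCategoricalCrossentropy(from_logits=True)` in the compile step expects.

```python
# Sketch assuming TensorFlow 2.x (tf.keras); helper names are hypothetical.
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models


def build_sequential():
    # The Sequential model from the question, unchanged.
    model = models.Sequential()
    model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.Flatten())
    model.add(layers.Dense(64, activation='relu'))
    model.add(layers.Dense(10))
    return model


def build_functional(input_shape=(32, 32, 3)):
    # Same layers, filter counts and activations, written with the Functional API.
    model_input = layers.Input(shape=input_shape)
    x = layers.Conv2D(32, (3, 3), activation='relu')(model_input)
    x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Conv2D(64, (3, 3), activation='relu')(x)
    x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Conv2D(64, (3, 3), activation='relu')(x)
    # Each layer must be called on the previous tensor to become part of the graph.
    x = layers.Flatten()(x)
    x = layers.Dense(64, activation='relu')(x)
    # No softmax: the output is raw logits, matching from_logits=True below.
    outputs = layers.Dense(10)(x)
    # Build the Model only after the last layer has been connected.
    return models.Model(model_input, outputs, name='functional_cnn')


seq_model = build_sequential()
func_model = build_functional()
seq_model.summary()
func_model.summary()

# Copy the Sequential weights into the functional model; if the two
# architectures really match, they must produce the same outputs.
func_model.set_weights(seq_model.get_weights())
dummy = np.random.rand(4, 32, 32, 3).astype('float32')
same_outputs = np.allclose(np.asarray(seq_model(dummy)),
                           np.asarray(func_model(dummy)), atol=1e-5)
print('Parameter counts:', seq_model.count_params(), func_model.count_params())
print('Identical outputs with shared weights:', same_outputs)

func_model.compile(optimizer='adam',
                   loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                   metrics=['accuracy'])
```

Both `summary()` tables should list the same layers, output shapes, and parameter counts (the functional summary typically also shows an `InputLayer` row). Copying the weights with `set_weights` and comparing outputs on random input is a stricter check than reading the summaries: if a layer differed in shape or order, `set_weights` would fail, and if it differed only in activation, the outputs would diverge.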