Search code examples
tensorflowkerasconv-neural-networkmnist

Extreme Overfitting CNN for MNIST


Im coding a simple CNN for classify mnist digits, something fairly simple, but the model is overfitting very fast, by a wide margin

i implement counter_overfitting techniqques like dropout, batch norm, data augmentation but the model simple never improves

import tensorflow as tf
import tensorflow
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import random
from PIL import Image

class ConvBlock(tf.keras.layers.Layer):
    """Convolutional Block featuring Conv2D + Pooling"""

    def __init__(self, conv_deep=1, kernels=32, kernel_size=3, pool_size=2, dropout_rate=0.4):
        super(ConvBlock, self).__init__(self)

        self.conv_layers = []
        self.pooling_layers = []
        self.bnorm_layers = []
        self.dropout_layers = []
        for index in range(0, conv_deep):
            self.conv_layers.append(tf.keras.layers.Conv2D(filters=kernels, kernel_size=kernel_size, padding="same", activation="relu"))
            self.pooling_layers.append(tf.keras.layers.MaxPool2D(pool_size=pool_size))
            self.bnorm_layers.append(tf.keras.layers.BatchNormalization())
            self.dropout_layers.append(tf.keras.layers.Dropout(dropout_rate))

    def call(self, inputs, training=False):
        output = inputs

        for (conv, pooling, bnorm, dropout) in zip(self.conv_layers, self.pooling_layers, self.bnorm_layers, self.dropout_layers):
            output = conv(output)
            output = pooling(output)
            output = bnorm(output)

            if training:
                output = dropout(output)

        return output

class DigitsClassifier(tf.keras.Model):
    """MNIST Digit Classifier"""

    def __init__(self):
        super(DigitsClassifier, self).__init__(self)

        self.conv_input = ConvBlock(conv_deep=2, kernels=32)
        self.conv_hiden = ConvBlock(conv_deep=1, kernels=16)

        self.flatten = tf.keras.layers.Flatten()
        self.hiden = tf.keras.layers.Dense(50, "relu")
        self.bnorm = tf.keras.layers.BatchNormalization()
        self.softmax = tf.keras.layers.Dense(10, "softmax")

    def call(self, inputs):

        output = self.conv_input(inputs)
        output = self.conv_hiden(output)

        output = self.flatten(output)
        output = self.hiden(output)
        output = self.bnorm(output)
        output = self.softmax(output)

        return output

#Load Train Data
(train_digits, train_labels), (eval_digits, eval_labels) = tf.keras.datasets.mnist.load_data("./Resources")
kaggle_digits = pd.read_csv("./Resources/test.csv").values

#Preprocess
train_digits = np.reshape(train_digits, [np.shape(train_digits)[0], 28, 28, 1])/255.0
eval_digits = np.reshape(eval_digits, [np.shape(eval_digits)[0], 28, 28, 1])/255.0
kaggle_digits = np.reshape(kaggle_digits, [np.shape(kaggle_digits)[0], 28, 28, 1])/255.0

#Generator
def get_sample(digits, return_labels=False, labels=None):
    if(return_labels):
        if(np.shape(digits)[0] == np.shape(labels)[0]):
            for index in range(0, np.shape(digits)[0]):
                yield (digits[index], labels[index])
        else:
            raise ValueError("Digits and Labels dont have the same numberof samples")
    else:
        for index in range(0, np.shape(digits)[0]):
            yield (digits[index])

def transform_sample(digit, label):
    rot = random.randint(-1, 2)
    t_digit = digit
    t_digit = tf.compat.v2.image.rot90(t_digit, rot)

    return t_digit, label

#Define datasets
train_ds = tf.data.Dataset.from_generator(get_sample, (tf.float32, tf.int32), args=[train_digits, True, train_labels]).map(transform_sample, 100).batch(1000).prefetch(2)
eval_ds = tf.data.Dataset.from_generator(get_sample, (tf.float32, tf.int32), args=[eval_digits, True, eval_labels]).batch(1000).prefetch(2)
kaggle_ds = tf.data.Dataset.from_generator(get_sample, (tf.float32), args=[kaggle_digits]).batch(1000).prefetch(2)


for digits, label in train_ds.take(1):
    print(label)
    sns.regplot(data=digits)
    plt.show()


#Define model and load weights (Pretrained on google colab notebook)
model = DigitsClassifier()
model.compile(tf.keras.optimizers.Adadelta(7.0), tf.keras.losses.SparseCategoricalCrossentropy())
model.fit(train_ds, epochs=50, verbose=2, validation_data=eval_ds)

at this point i really dont know what to try, i will reduce model complexity but i dont think that will help

PD: counter-intuitive, stop using data augmentation techniques let the model improves, my data augmentation simple consist in a map function transform_sample that perform a random 90 degree rotation to every image, or not rotation at all


Solution

  • Since the Model is Overfitting, you can

    1. Shuffle the Data, by using shuffle=True in cnn_model.fit. Code is shown below:

      model.fit(train_ds, epochs=50, verbose=2, shuffle = True, validation_data=eval_ds)

    2. Use Early Stopping. Code is shown below

      callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=15)

      model.fit(train_ds, epochs=50, verbose=2, callbacks=[callback], validation_data=eval_ds)

    3. Use Regularization. Code for Regularization is shown below (You can try l1 Regularization or l1_l2 Regularization as well). Code for the same is shown below.

    from tensorflow.keras.regularizers import l2

    Regularizer = l2(0.001)

    self.conv_layers.append(tf.keras.layers.Conv2D(filters=kernels, kernel_size=kernel_size, padding="same", activation="relu", activity_regularizer=Regularizer, kernel_regularizer=Regularizer))

    self.hiden = tf.keras.layers.Dense(50, "relu", activity_regularizer=Regularizer, kernel_regularizer=Regularizer))

    self.softmax = tf.keras.layers.Dense(10, "softmax", activity_regularizer=Regularizer, kernel_regularizer=Regularizer)

    1. Finally, if there still no change, you can try using Pre-Trained Models like ResNet or VGG Net, etc..