Tags: python, keras, deep-learning, resnet, transfer-learning

Why is my resnet50 model in Keras not converging?


I am currently trying to classify images of integrated circuits as defective or non-defective. I already tried VGG16 and InceptionV3 and got really good results with both (95% validation accuracy and low validation loss). Now I wanted to try ResNet50, but my model is not converging. Its training accuracy also reaches 95%, but the validation loss keeps increasing while the validation accuracy stays stuck at 50%.

This is my script so far:

from keras.applications.resnet50 import ResNet50
from keras.optimizers import Adam
from keras.preprocessing import image
from keras.models import Model
from keras.layers import Dense, GlobalAveragePooling2D, Dropout
from keras import backend as K
from keras_preprocessing.image import ImageDataGenerator
import tensorflow as tf

class ResNet:
    def __init__(self):
        self.img_width, self.img_height = 224, 224  # Dimensions of cropped image
        self.classes_num = 2  # Number of classifications

        # Training configurations
        self.epochs = 32
        self.batch_size = 16  # Play with this to determine number of images to train on per epoch
        self.lr = 0.0001

    def build_model(self, train_path):
        train_data_path = train_path
        train_datagen = ImageDataGenerator(rescale=1. / 255, validation_split=0.25)

        train_generator = train_datagen.flow_from_directory(
            train_data_path,
            target_size=(self.img_height, self.img_width),
            color_mode="rgb",
            batch_size=self.batch_size,
            class_mode='categorical',
            subset='training')

        validation_generator = train_datagen.flow_from_directory(
            train_data_path,
            target_size=(self.img_height, self.img_width),
            color_mode="rgb",
            batch_size=self.batch_size,
            class_mode='categorical',
            subset='validation')

        # create the base pre-trained model
        base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(self.img_height, self.img_width, 3))

        # add a global spatial average pooling layer
        x = base_model.output
        x = GlobalAveragePooling2D()(x)
        # let's add a fully-connected layer
        x = Dense(1024, activation='relu')(x)
        #x = Dropout(0.3)(x)
        # and a logistic layer -- let's say we have 200 classes
        predictions = Dense(2, activation='softmax')(x)

        # this is the model we will train
        model = Model(inputs=base_model.input, outputs=predictions)

        # first: train only the top layers (which were randomly initialized)
        # i.e. freeze all convolutional InceptionV3 layers
        for layer in base_model.layers:
            layer.trainable = True

        # compile the model (should be done *after* setting layers to non-trainable)
        opt = Adam(self.lr)  # , decay=self.INIT_LR / self.NUM_EPOCHS)
        model.compile(opt, loss='binary_crossentropy', metrics=["accuracy"])

        # train the model on the new data for a few epochs
        from keras.callbacks import ModelCheckpoint, EarlyStopping
        import matplotlib.pyplot as plt

        checkpoint = ModelCheckpoint('resnetModel.h5', monitor='val_accuracy', verbose=1, save_best_only=True,
                                 save_weights_only=False, mode='auto', period=1)

        early = EarlyStopping(monitor='val_accuracy', min_delta=0, patience=16, verbose=1, mode='auto')
        hist = model.fit_generator(steps_per_epoch=self.batch_size, generator=train_generator,
                               validation_data=validation_generator, validation_steps=self.batch_size, epochs=self.epochs,
                               callbacks=[checkpoint, early])

        plt.plot(hist.history['accuracy'])
        plt.plot(hist.history['val_accuracy'])
        plt.plot(hist.history['loss'])
        plt.plot(hist.history['val_loss'])
        plt.title("model accuracy")
        plt.ylabel("Accuracy")
        plt.xlabel("Epoch")
        plt.legend(["Accuracy", "Validation Accuracy", "loss", "Validation Loss"])
        plt.show()

        plt.figure(1)


if __name__ == '__main__':
    x = ResNet()
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.compat.v1.Session(config=config)
    x.build_model("C:/Users/but/Desktop/dataScratch/Train")

And this is the training history of the model:

[Image: plot of the training history]

What could be the reason for ResNet to fail while VGG and Inception work? Are there any mistakes in my script?


Solution

  • In the code itself, at least, I don't see any mistakes that would affect the training process.

    # and a logistic layer -- let's say we have 200 classes
    predictions = Dense(2, activation='softmax')(x)
    

    These lines look a bit suspicious, but the mismatch is only in the comment (it mentions 200 classes while the layer has 2 units), so it should be okay.

    # first: train only the top layers (which were randomly initialized)
    # i.e. freeze all convolutional InceptionV3 layers
    for layer in base_model.layers:
        layer.trainable = True
    

    These are suspicious too: the comment talks about freezing, but the loop sets every layer to trainable. If you want to freeze the ResNet-50 layers, what you need to do is

    ...
    base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(self.img_height, self.img_width, 3))
    for layer in base_model.layers:
        layer.trainable = False
    ...
    

    But it turned out that layer.trainable = True was actually your intention, so this doesn't matter either.
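If you later want something between "freeze everything" and "train everything", one option is to freeze all but the last few layers. The helper below is only a minimal sketch: `freeze_all_but_last` is a hypothetical name, not a Keras function, and it assumes nothing beyond each layer object having a `trainable` attribute (as the entries of `base_model.layers` do). Stand-in objects are used here so the snippet runs without Keras installed.

```python
from types import SimpleNamespace


def freeze_all_but_last(layers, num_trainable):
    """Freeze every layer except the last `num_trainable` ones.

    `layers` is any sequence of objects with a boolean `trainable`
    attribute, for example `base_model.layers` in Keras.
    Returns the number of layers left trainable.
    """
    if num_trainable < 0:
        raise ValueError("num_trainable must be non-negative")
    # Index of the first layer that stays trainable.
    cutoff = max(len(layers) - num_trainable, 0)
    for index, layer in enumerate(layers):
        layer.trainable = index >= cutoff
    return sum(1 for layer in layers if layer.trainable)


# Stand-in layers (hypothetical), so the example does not need Keras.
fake_layers = [SimpleNamespace(name=f"layer_{i}", trainable=True) for i in range(5)]
trainable_count = freeze_all_but_last(fake_layers, 2)
print(trainable_count, [layer.trainable for layer in fake_layers])
```

With a real model you would pass `base_model.layers` instead of `fake_layers` and, as the comment in the question's script already notes, compile the model after changing the `trainable` flags.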

    First of all, if you are using the same code which you used for training VGG16 and Inception V3, it is unlikely that the code is the problem.

    Why don't you check the following possible causes?

    • The model may be too small or too big, so that it underfits or overfits. (Number of parameters)
    • The model may need more time to converge. (Train for more epochs)
    • ResNet may not be well suited to this classification task.
    • The pretrained weights you used may not be well suited to this classification task.
    • The learning rate may be too small or too big.
    • etc...
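To make a couple of these checks more concrete, here is a rough sketch of two small helpers. Both names (`diagnose_history`, `learning_rate_grid`) and all thresholds are assumptions made up for illustration, not part of Keras. The first reads a dictionary shaped like `hist.history` from the question's script and gives a coarse label. The second lists learning rates around the current one to try. The numbers in the example history are invented to resemble the behaviour described in the question; they are not the asker's real results.

```python
def diagnose_history(history, gap_tolerance=0.1, plateau_tolerance=0.01,
                     low_accuracy=0.7):
    """Return a coarse label for a Keras-style history dict.

    `history` maps metric names to per-epoch lists, like `hist.history`.
    All thresholds are arbitrary illustrative defaults.
    """
    train_acc = history["accuracy"][-1]
    val_acc = history["val_accuracy"][-1]
    loss = history["loss"]
    val_loss = history["val_loss"]

    # Big train/validation gap plus rising validation loss.
    if train_acc - val_acc > gap_tolerance and val_loss[-1] > val_loss[0]:
        return "overfitting"
    # Training loss still dropping noticeably: more epochs may help.
    if len(loss) >= 2 and loss[-2] - loss[-1] > plateau_tolerance:
        return "still improving"
    # Training has flattened out at a low accuracy.
    if train_acc < low_accuracy:
        return "underfitting"
    return "converged"


def learning_rate_grid(center=1e-4, factor=10.0, steps_each_side=1):
    """Learning rates spaced by `factor` around `center` (smallest first)."""
    return [center * factor ** k
            for k in range(-steps_each_side, steps_each_side + 1)]


# Invented values: high training accuracy, validation stuck at 50 %.
example_history = {
    "accuracy": [0.80, 0.90, 0.95],
    "val_accuracy": [0.50, 0.50, 0.50],
    "loss": [0.45, 0.25, 0.15],
    "val_loss": [0.9, 1.6, 2.4],
}
verdict = diagnose_history(example_history)
grid = learning_rate_grid(center=1e-4)
print(verdict, grid)
```

A label of "overfitting" here only says that training and validation behave very differently; it cannot tell you whether the cause is model size, the pretrained weights, or something else from the list above.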