Tags: python, tensorflow, keras

Callback not working in TensorFlow to stop training


I have written a callback that should stop training once accuracy exceeds 99%. The problem is that I get the error below. Sometimes, when I resolve this error, the callback does not get called even though accuracy reaches 100%. The error:

'>' not supported between instances of 'NoneType' and 'float'

My code:

import tensorflow as tf
from tensorflow import keras


class myCallback(tf.keras.callbacks.Callback):

    def on_epoch_end(self, epoch, logs={}):
        if logs.get('accuracy') > 0.99:
            self.model.stop_training = True


def train_mnist():
    # Please write your code only where you are indicated.
    # please do not remove # model fitting inline comments.

    # YOUR CODE SHOULD START HERE

    # YOUR CODE SHOULD END HERE
    call = myCallback()
    mnist = tf.keras.datasets.mnist

    (x_train, y_train),(x_test, y_test) = mnist.load_data(path=path)
    # YOUR CODE SHOULD START HERE
    x_train = x_train/255
    y_train = y_train/255
    # YOUR CODE SHOULD END HERE
    model = tf.keras.models.Sequential([
        # YOUR CODE SHOULD START HERE
        keras.layers.Flatten(input_shape=(28, 28)),
        keras.layers.Dense(128, activation='relu'),
        keras.layers.Dense(10, activation='softmax')
        # YOUR CODE SHOULD END HERE
    ])

    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    
    # model fitting
    history = model.fit(  # YOUR CODE SHOULD START HERE
        x_train, y_train, epochs=9, callbacks=[call])
    # model fitting
    return history.epoch, history.history['acc'][-1]

Solution

  • There are two major problems with the above code:

    • Reaching 100% accuracy on the training set almost always means that your model is overfitting. That's bad. Instead, pass validation_split=0.2 to the .fit method and look for high accuracy on the validation set.
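    • What you are trying to build in your custom callback is largely provided already by keras.callbacks.EarlyStopping, which stops training once a monitored metric stops improving (rather than at a fixed threshold). It even has an option, restore_best_weights=True, to restore the weights from the best epoch. Note that it monitors val_loss by default; set monitor='val_accuracy' to track validation accuracy instead, which requires a validation split.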

    So, here's what you should do: stop using custom callbacks, since they take some mastery to get working, and use EarlyStopping with restore_best_weights=True instead. Always use validation_split and look for high accuracy on the validation set.
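    A minimal sketch of that approach, assuming TensorFlow 2.x. The helper names build_model and train, and the choice of patience=2, are hypothetical and not from the original code. It sets monitor='val_accuracy' explicitly because EarlyStopping watches val_loss by default. It also scales only the pixel values: the question's code divides y_train by 255 as well, which turns the integer class labels into fractions that no longer name any of the 10 classes.

```python
import tensorflow as tf


def build_model():
    # Same architecture as in the question
    model = tf.keras.models.Sequential([
        tf.keras.Input(shape=(28, 28)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax'),
    ])
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model


def train(x_train, y_train, epochs=9):
    # Scale only the pixel values; integer class labels must not be divided
    x_train = x_train / 255.0

    # Stop once val_accuracy has not improved for 2 consecutive epochs and,
    # when stopping, roll back to the weights from the best epoch
    early_stop = tf.keras.callbacks.EarlyStopping(
        monitor='val_accuracy', patience=2, restore_best_weights=True)

    model = build_model()
    history = model.fit(x_train, y_train,
                        epochs=epochs,
                        # Hold out the last 20% of the samples for validation
                        validation_split=0.2,
                        callbacks=[early_stop],
                        verbose=0)
    return model, history
```

    With MNIST this would be called as `(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()` followed by `model, history = train(x_train, y_train)`; `max(history.history['val_accuracy'])` then gives the best validation accuracy seen during training.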


    Did using built-in callbacks resolve your problem?
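    If you do want to keep a threshold-based custom callback, the TypeError most likely means logs.get('accuracy') returned None, i.e. nothing was logged under that key. The key depends on the version: TF 2.x logs metrics=['accuracy'] as 'accuracy', while older Keras/TensorFlow 1.x versions used 'acc' (the key the question's code reads from history.history at the end). The sketch below uses a hypothetical class name, AccuracyThreshold, with a default threshold of 0.99 to match the question; it checks both keys, skips the comparison when neither is present, and uses logs=None instead of a mutable {} default.

```python
import tensorflow as tf


class AccuracyThreshold(tf.keras.callbacks.Callback):
    """Stop training once the logged training accuracy exceeds a threshold."""

    def __init__(self, threshold=0.99):
        super().__init__()
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # TF 2.x logs the metric as 'accuracy'; older Keras versions used 'acc'
        acc = logs.get('accuracy', logs.get('acc'))
        if acc is None:
            # Metric not logged, e.g. 'accuracy' missing from compile(metrics=...)
            return
        if acc > self.threshold:
            print(f'\nReached {self.threshold:.0%} accuracy, stopping training.')
            self.model.stop_training = True
```

    It is passed to fit the same way as before, e.g. `callbacks=[AccuracyThreshold(0.99)]`.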