Tags: python, tensorflow, keras, deep-learning, gradient-exploding

In Keras, using SGD, why does model.fit() train smoothly while a step-wise training method gives exploding gradients and loss?


The exploding gradients and exploding loss only happen when the network is huge, so I won't post the entire network here. I have tried my best, though. Over the past two weeks I dug into the details of the source code to monitor some weights, and I hand-coded the update step to monitor the loss, weights, updates, gradients and hyperparameters and compare them with the internal state. I think I've done my homework before asking here.

The question: there are two training methods using the Keras API. The first is model.fit(); the second is a more customized one, meant for more complex training and networks. Although I kept nearly everything the same between the two, model.fit() does not produce an exploding loss, while the custom method does. Funnily enough, when I monitor many details on a much smaller network, everything looks identical between the two methods.

Environment:

# tensorflow 1.14
import numpy as np  # used by the custom training loop below
import tensorflow as tf
from tensorflow.keras import backend as K

For the model.fit() method:

# I've omitted the details of the two lines below because I can't share them. x is [10000, 32, 32, 3] image data, y is the one-hot label array of shape [10000, 10], and model is a regular Keras model.

x_train, y_train, x_test, y_test = get_data()
model = get_keras_model()

loss_fn = tf.keras.losses.CategoricalCrossentropy()
sgd = tf.keras.optimizers.SGD(lr=.1, momentum=0.9, nesterov=True)

model.compile(loss=loss_fn, optimizer=sgd, metrics=['accuracy'])
history = model.fit(x_train, y_train, batch_size=128, epochs=100, validation_data=(x_test, y_test))
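Both methods use the same SGD(lr=.1, momentum=0.9, nesterov=True) optimizer. When hand-coding the update step to compare against the optimizer's internal state, it helps to pin down the exact form of the update. The sketch below assumes the Nesterov formulation documented for tf.keras SGD (velocity = momentum * velocity - lr * g; w = w + momentum * velocity - lr * g) and applies it to a toy scalar loss f(w) = w**2; `nesterov_sgd_step` and `grad_quadratic` are hypothetical helpers, not Keras functions.

```python
def nesterov_sgd_step(w, v, grad, lr=0.1, momentum=0.9):
    """One Nesterov-momentum SGD step in the form documented for tf.keras SGD.

    v is the velocity (the optimizer's momentum accumulator), starting at 0.
    """
    v = momentum * v - lr * grad        # update the velocity
    w = w + momentum * v - lr * grad    # Nesterov "look-ahead" weight update
    return w, v


def grad_quadratic(w):
    # Gradient of the toy loss f(w) = w**2 (hypothetical stand-in for a real loss).
    return 2.0 * w


w, v = 1.0, 0.0
history = [w]
for _ in range(2):
    w, v = nesterov_sgd_step(w, v, grad_quadratic(w))
    history.append(w)

print([round(x, 4) for x in history])  # [1.0, 0.62, 0.2224]
```

Feeding in the real gradient and velocity values read from the graph lets you check each step of the custom loop against this formula.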

Custom Method:

x_train, y_train, x_test, y_test = get_data()
model = get_keras_model()

input = model.inputs[0]
y_true = tf.placeholder(dtype = tf.int32, shape = [None, 10])
y_pred = model.outputs[0]

loss_fn = tf.keras.losses.CategoricalCrossentropy()
loss = loss_fn(y_true, y_pred)
weights = model.trainable_weights
sgd = tf.keras.optimizers.SGD(lr=.1, momentum=0.9, nesterov=True)

training_updates = sgd.get_updates(loss, weights)
training_fn = K.function([y_true, input], [loss], training_updates)

num_train = 10000
steps_per_epoch = num_train // 128  # batch size 128
total_steps = steps_per_epoch * 100  # 100 epochs

for step in range(total_steps):  # iterate over step indices, not over an int
    idx = np.random.randint(0, 10000, 128)
    input_img = x_train[idx]
    ground_true = y_train[idx]

    cur_loss = training_fn([ground_true, input_img])

So, in short: same model, same loss function, same SGD optimizer, same image feed (I do control the image feed order, although the code here selects randomly from the training data). Is there anything in the internal process of model.fit() that would prevent the loss or gradients from exploding?
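The two loops also draw batches differently, which matters when comparing them step by step. Assuming fit() receives NumPy arrays with the default shuffle=True, each epoch walks over a fresh permutation of all 10000 samples in 79 batches (78 of 128 plus a final batch of 16), whereas the custom loop draws 78 batches of 128 indices with replacement via np.random.randint. The pure-Python sketch below mimics both schemes; `fit_style_batches` and `randint_style_batches` are hypothetical helpers, not Keras functions.

```python
import random


def fit_style_batches(num_samples, batch_size, rng):
    """Yield index batches like one epoch of fit(shuffle=True) (assumed):
    a fresh permutation sliced into consecutive batches; the last may be smaller."""
    idx = list(range(num_samples))
    rng.shuffle(idx)
    for start in range(0, num_samples, batch_size):
        yield idx[start:start + batch_size]


def randint_style_batches(num_samples, batch_size, steps, rng):
    """Yield index batches like the custom loop: independent draws with
    replacement, so within an "epoch" some samples repeat and others are skipped."""
    for _ in range(steps):
        yield [rng.randrange(num_samples) for _ in range(batch_size)]


rng = random.Random(0)
fit_batches = list(fit_style_batches(10000, 128, rng))
custom_batches = list(randint_style_batches(10000, 128, 10000 // 128, rng))

print(len(fit_batches), len(fit_batches[-1]))    # 79 16
print(len(custom_batches))                       # 78
print(len({i for b in fit_batches for i in b}))  # 10000: every sample seen exactly once
```

Since the feed order was controlled in the actual experiment, this is unlikely to explain the explosion, but matching the batching scheme removes one more variable when comparing the two methods.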


Solution

  • After digging into the source code, I found the cause of the exploding gradients. The corrected code (with minimal changes) is below:

    x_train, y_train, x_test, y_test = get_data()
    model = get_keras_model()
    
    input = model.inputs[0]
    y_true = tf.placeholder(dtype = tf.int32, shape = [None, 10])
    y_pred = model.outputs[0]
    
    loss_fn = tf.keras.losses.CategoricalCrossentropy()
    loss = loss_fn(y_true, y_pred)
    weights = model.trainable_weights
    sgd = tf.keras.optimizers.SGD(lr=.1, momentum=0.9, nesterov=True)
    
    training_updates = sgd.get_updates(loss, weights)
    
    # Correct:
    training_fn = K.function([y_true, input, K.symbolic_learning_phase()], [loss], training_updates)
    
    # Before:
    # training_fn = K.function([y_true, input], [loss], training_updates)
    
    num_train = 10000
    steps_per_epoch = num_train // 128  # batch size 128
    total_steps = steps_per_epoch * 100  # 100 epochs
    
    for step in range(total_steps):  # iterate over step indices, not over an int
        idx = np.random.randint(0, 10000, 128)
        input_img = x_train[idx]
        ground_true = y_train[idx]
    
        # Correct:
        cur_loss = training_fn([ground_true, input_img, True])
    
        # Before:
        # cur_loss = training_fn([ground_true, input_img])
    

    My understanding of this particular tensor, K.symbolic_learning_phase(), is that it defaults to False (you can see this where it is initialized in the source code), and layers such as BatchNormalization and Dropout behave differently in the training and testing phases. In this case the BatchNormalization layer is the cause of the exploding gradients, which also explains why some posts report exploding gradients with BatchNormalization layers. With the learning phase left at its default of False, the layer runs in inference mode even during training: it normalizes with its moving mean and variance instead of the statistics of the current batch. Its trainable weights, batch_normalization_1/gamma:0 and batch_normalization_1/beta:0, therefore do not train as intended, and they become nan very quickly during training.

    I noticed that not much Keras code that builds its training step this way (with get_updates() and K.function) actually feeds K.symbolic_learning_phase(); however, this is what the Keras API (model.fit()) does under the hood.
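To make the explanation above concrete, here is a simplified, pure-Python model of a BatchNormalization forward pass for a single feature. It is a sketch, not the Keras layer: `batch_norm` and the sample activations are hypothetical, and it assumes the Keras defaults of moving mean 0, moving variance 1 and epsilon 1e-3 for a freshly built layer. With `training=False` (the default learning phase) the layer divides by the moving variance instead of the batch variance, so large activations pass through essentially unchanged; with `training=True` they come out with zero mean and unit spread.

```python
import statistics


def batch_norm(x, gamma, beta, moving_mean, moving_var, training=False, eps=1e-3):
    """Simplified BatchNormalization forward pass for one feature.

    training=False mirrors the default learning phase (False): the layer
    normalizes with its moving statistics. training=True uses the statistics
    of the current batch, which is what happens when the phase is fed as True.
    """
    if training:
        mean = statistics.mean(x)
        var = statistics.pvariance(x, mu=mean)  # biased batch variance
    else:
        mean, var = moving_mean, moving_var
    return [gamma * (xi - mean) / (var + eps) ** 0.5 + beta for xi in x]


# Hypothetical large activations, as can appear deep inside a big network.
batch = [120.0, 95.0, 140.0, 110.0]

# Freshly initialized layer: gamma=1, beta=0, moving_mean=0, moving_var=1.
infer = batch_norm(batch, 1.0, 0.0, 0.0, 1.0, training=False)
train = batch_norm(batch, 1.0, 0.0, 0.0, 1.0, training=True)

print([round(v, 1) for v in infer])  # close to the raw inputs: nothing is normalized
print(round(statistics.pstdev(train), 3))  # about 1.0: normalized batch
```

One related difference, stated as an assumption to verify against the TF 1.14 source rather than a confirmed fact: model.fit() also runs the layers' own update ops (model.updates, which hold the BatchNormalization moving-average updates), while training_updates above contains only the optimizer's updates. If so, appending model.updates to training_updates would keep the moving statistics in sync for later evaluation.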