Executing the code below sometimes leads to the loss going up during training, and then staying there. Why is that?
import tensorflow as tf
from tensorflow.keras import layers, losses, models
# Input features per sample, and number of synthetic training samples.
FEATURE_COUNT, TRAINING_SET_SIZE = 2, 128
def patch_nans(t: tf.Tensor) -> tf.Tensor:
    """Return ``t`` with every NaN entry replaced by zero.

    Note: this only sanitises the forward value. The gradient of ``tf.where``
    is zero for the replaced entries, so NaNs produced by earlier layers can
    still turn upstream weight gradients into NaN (0 * NaN = NaN).
    """
    is_missing = tf.math.is_nan(t)
    return tf.where(is_missing, x=tf.zeros_like(t), y=t)
def check_numerics(t: tf.Tensor) -> tf.Tensor:
    """Pass ``t`` through unchanged, raising an error if it holds NaN or Inf.

    The raised error carries the message "t".
    """
    return tf.debugging.check_numerics(tensor=t, message="t")
def get_model() -> models.Model:
    """Build the question's model: Dense(64) -> ReLU -> Dense(1) -> NaN patch -> NaN check.

    NOTE: NaNs are replaced only *after* the Dense layers, i.e. after they
    have already passed through the trainable kernels. The forward output is
    clean, but backprop computes each kernel gradient from the NaN-carrying
    layer inputs multiplied by a zero upstream gradient, and 0 * NaN = NaN.
    That NaN gradient is what corrupts the weights during training.
    """
    inputs = layers.Input(shape=[FEATURE_COUNT])
    first_dense = layers.Dense(units=64)(inputs)
    hidden = layers.ReLU()(first_dense)
    raw_output = layers.Dense(units=1)(hidden)
    # Hides NaNs from the loss, but not from the backward pass (see docstring).
    patched = layers.Lambda(patch_nans)(raw_output)
    checked = layers.Lambda(check_numerics)(patched)
    return models.Model(inputs, checked)
# Build and compile the model with plain SGD and mean squared error.
model = get_model()
model.compile(optimizer=tf.optimizers.SGD(), loss=losses.mean_squared_error)
model.summary()
# Synthetic data: log(x + 1) is NaN wherever x < -1. tf.maximum clamps the
# negative values (including the -inf at x == -1) to 0; the NaNs are kept,
# which is what this repro relies on.
features = tf.random.normal(shape=[TRAINING_SET_SIZE, FEATURE_COUNT])
features_with_nans = tf.maximum(tf.math.log(features + 1), tf.zeros_like(features))
labels = tf.random.normal(shape=[TRAINING_SET_SIZE, 1])
# Baseline loss of the untrained model.
model.evaluate(x=features_with_nans, y=labels, batch_size=8)
# Training: once a NaN gradient reaches the weights, the loss freezes.
model.fit(x=features_with_nans, y=labels, batch_size=8, epochs=4)
The model is a simple feed-forward model with two dense layers. The loss is MSE, and the training set contains no extreme values (apart from the NaNs).
Excerpt of a run where the loss goes up:
8/128 [>.............................] - ETA: 0s - loss: 0.4720
128/128 [==============================] - 0s 593us/sample - loss: 1.1050
Train on 128 samples
Epoch 1/4
8/128 [>.............................] - ETA: 3s - loss: 2.3937
128/128 [==============================] - 0s 2ms/sample - loss: 1.1096
Epoch 2/4
8/128 [>.............................] - ETA: 0s - loss: 1.1668
128/128 [==============================] - 0s 141us/sample - loss: 1.1202
Epoch 3/4
8/128 [>.............................] - ETA: 0s - loss: 1.0059
128/128 [==============================] - 0s 141us/sample - loss: 1.1202
Epoch 4/4
8/128 [>.............................] - ETA: 0s - loss: 1.6480
128/128 [==============================] - 0s 156us/sample - loss: 1.1202
Once you have `nan` in your model, you will have `nan` in the gradients; that's inevitable. Because the gradients are summed over the batch, a `nan` in them ends up in all the model's weights. Once the model's weights contain `nan`, you can't do anything with that model. Check it for yourself with `print(model.get_weights())` after training.
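The loss goes up because the model suddenly starts outputting only zeros. The weights are all `nan`, so every raw output is `nan`, and `patch_nans` turns it into zero. From then on the loss no longer changes: with all-zero predictions, the MSE is simply the mean of the squared labels, which is why it stays constant (1.1202 in the run above).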
Yes, I know this sounds strange, since you replaced the `nan`s before calculating the loss. The backward pass, however, still sees them. The gradient of `tf.where` is zero for the replaced entries. Backpropagation then multiplies that zero by the `nan`-carrying layer inputs when it computes the `Dense` kernel gradients (kernel gradient = inputs^T · upstream gradient). TensorFlow applies the chain rule mechanically: it does not skip the earlier layers just because an upstream gradient is zero, and `0 * nan = nan`.
If you really want to keep the `nan`s in your data (which doesn't sound like a good idea), you must remove them at the very beginning, before they reach any trainable layer.
Here is a proposal that removes the `nan`s at the beginning and then uses the same `nan` mask to zero the final outputs for those samples. It also sets the labels to zero where there are `nan`s. This way the loss stays well behaved:
import tensorflow.keras.backend as K
def removeNan(x):
    """Zero the entries of a tensor selected by a boolean mask.

    Args:
        x: a two-element sequence ``(t, nan_mask)``. ``nan_mask`` is True
            where ``t`` should be replaced by zero. If the two shapes differ,
            ``tf.where`` broadcasts them and the result takes the broadcast
            shape.

    Returns:
        ``t`` with the masked positions set to zero.
    """
    values, mask = x
    return tf.where(mask, tf.zeros_like(values), values)
def get_model2() -> models.Model:
    """Build a model that is robust to NaN features.

    NaNs are replaced with zeros *before* the first Dense layer, so no NaN
    ever reaches a trainable kernel, either in the forward pass or in its
    gradient. Samples that had at least one NaN feature then get their output
    forced to zero, so their prediction does not depend on the weights.

    Returns:
        A functional Keras model mapping (batch, FEATURE_COUNT) inputs to
        (batch, 1) outputs.
    """
    inp = layers.Input(shape=[FEATURE_COUNT])
    # Remove the NaNs before anything else, keeping the mask for the output.
    nan_mask = layers.Lambda(lambda x: tf.math.is_nan(x))(inp)
    mid = layers.Lambda(removeNan)([inp, nan_mask])
    mid = layers.Dense(units=64)(mid)
    mid = layers.ReLU()(mid)
    mid = layers.Dense(units=1)(mid)
    # nan_mask is (batch, FEATURE_COUNT) but the output is (batch, 1).
    # Applying it directly would make tf.where broadcast the output up to
    # FEATURE_COUNT columns, each masked by a different feature. Collapse it
    # to one flag per sample instead.
    row_mask = layers.Lambda(
        lambda m: tf.reduce_any(m, axis=-1, keepdims=True)
    )(nan_mask)
    out = layers.Lambda(removeNan)([mid, row_mask])
    return models.Model(inp, out)
# Same synthetic data as in the question.
features = tf.random.normal(shape=[TRAINING_SET_SIZE, FEATURE_COUNT])
features_with_nans = tf.maximum(tf.math.log(features + 1), tf.zeros_like(features))
labels = tf.random.normal(shape=[TRAINING_SET_SIZE, 1])
# Zero the label of every sample that has at least one NaN feature, so those
# samples agree with a zeroed prediction and add nothing to the loss. This
# uses the same criterion as the model's input mask (tf.math.is_nan), rather
# than the indirect 0 * sum(features) trick, which also fired on +/-inf.
# (The variable name is kept from the original; it holds no NaNs.)
row_has_nan = tf.reduce_any(
    tf.math.is_nan(features_with_nans), axis=-1, keepdims=True
)
labels_with_nans = tf.where(row_has_nan, tf.zeros_like(labels), labels)
# Build and compile the NaN-safe model.
model = get_model2()
model.compile(optimizer=tf.optimizers.SGD(), loss=losses.mean_squared_error)
model.summary()
# Fit, then inspect the weights: they should contain no NaNs.
model.fit(features_with_nans, labels_with_nans, batch_size=10, epochs=5)
print(model.get_weights())
Caution (unverified, please check): I read somewhere that on GPU or TPU the `nan`s may be internally replaced with zeros so that the hardware can be used. If that is true, you should use a different sentinel instead of `nan`, such as `-10000`, and use that value as the mask in the method I proposed.