Tags: tensorflow, machine-learning, keras, deep-learning, loss

Tensorflow 2: Customized Loss Function works differently from the original Keras SparseCategoricalCrossentropy


I just started working with TensorFlow 2.0 and followed the simple example from its official website.

import tensorflow as tf
import tensorflow.keras.layers as layers

mnist = tf.keras.datasets.mnist
(t_x, t_y), (v_x, v_y) = mnist.load_data()

model = tf.keras.Sequential()
model.add(layers.Flatten())
model.add(layers.Dense(128, activation="relu"))
model.add(layers.Dropout(0.2))
model.add(layers.Dense(10))

lossFunc = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

model.compile(optimizer='adam', loss=lossFunc,
              metrics=['accuracy'])
model.fit(t_x, t_y, epochs=5)

The output for the above code is:

Train on 60000 samples
Epoch 1/5
60000/60000 [==============================] - 4s 60us/sample - loss: 2.5368 - accuracy: 0.7455
Epoch 2/5
60000/60000 [==============================] - 3s 51us/sample - loss: 0.5846 - accuracy: 0.8446
Epoch 3/5
60000/60000 [==============================] - 3s 51us/sample - loss: 0.4751 - accuracy: 0.8757
Epoch 4/5
60000/60000 [==============================] - 3s 51us/sample - loss: 0.4112 - accuracy: 0.8915
Epoch 5/5
60000/60000 [==============================] - 3s 51us/sample - loss: 0.3732 - accuracy: 0.9018

However, if I change the lossFunc to the following:

def myfunc(y_true, y_pred):
    return lossFunc(y_true, y_pred)

which simply wraps the previous loss object, the model behaves completely differently. The output is:

Train on 60000 samples
Epoch 1/5
60000/60000 [==============================] - 4s 60us/sample - loss: 2.4444 - accuracy: 0.0889
Epoch 2/5
60000/60000 [==============================] - 3s 51us/sample - loss: 0.5696 - accuracy: 0.0933
Epoch 3/5
60000/60000 [==============================] - 3s 51us/sample - loss: 0.4493 - accuracy: 0.0947
Epoch 4/5
60000/60000 [==============================] - 3s 51us/sample - loss: 0.4046 - accuracy: 0.0947
Epoch 5/5
60000/60000 [==============================] - 3s 51us/sample - loss: 0.3805 - accuracy: 0.0943

The loss values are very similar, but the accuracy values are completely different. Does anyone know what is going on here, and what is the correct way to write your own loss function?


Solution

  • When you use a built-in loss function, you can pass the string 'accuracy' as a metric. Under the hood, TensorFlow selects the appropriate accuracy function for that loss (in your case, tf.keras.metrics.SparseCategoricalAccuracy()).

    When you define a custom loss function, TensorFlow can no longer infer which accuracy function to use, so the string 'accuracy' resolves to an unsuitable metric. In this case, you need to explicitly specify tf.keras.metrics.SparseCategoricalAccuracy(). Please check the GitHub gist here.

    The modified code and its output are as follows:

    model2 = tf.keras.Sequential()
    model2.add(layers.Flatten())
    model2.add(layers.Dense(128, activation="relu"))
    model2.add(layers.Dropout(0.2))
    model2.add(layers.Dense(10))
    
    lossFunc = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    
    model2.compile(optimizer='adam', loss=myfunc,
                  metrics=['accuracy',tf.keras.metrics.SparseCategoricalAccuracy()])
    model2.fit(t_x, t_y, epochs=5)
    

    output

    Train on 60000 samples
    Epoch 1/5
    60000/60000 [==============================] - 5s 81us/sample - loss: 2.2295 - accuracy: 0.0917 - sparse_categorical_accuracy: 0.7483
    Epoch 2/5
    60000/60000 [==============================] - 5s 76us/sample - loss: 0.5827 - accuracy: 0.0922 - sparse_categorical_accuracy: 0.8450
    Epoch 3/5
    60000/60000 [==============================] - 5s 76us/sample - loss: 0.4602 - accuracy: 0.0933 - sparse_categorical_accuracy: 0.8760
    Epoch 4/5
    60000/60000 [==============================] - 5s 76us/sample - loss: 0.4197 - accuracy: 0.0946 - sparse_categorical_accuracy: 0.8910
    Epoch 5/5
    60000/60000 [==============================] - 5s 76us/sample - loss: 0.3965 - accuracy: 0.0937 - sparse_categorical_accuracy: 0.8979
    <tensorflow.python.keras.callbacks.History at 0x7f5095286780>
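
    Note that the loss column is essentially identical in both runs: the wrapped function computes exactly the same quantity as the built-in loss, and only the metric inference changes. For intuition, here is a minimal NumPy sketch (my own illustration, not from the original answer) of the math that SparseCategoricalCrossentropy(from_logits=True) implements: a numerically stable log-softmax over the logits, then the negative log-probability of the true class.

    ```python
    import numpy as np

    def sparse_categorical_crossentropy_from_logits(y_true, logits):
        # Numerically stable log-softmax over the last axis
        shifted = logits - logits.max(axis=-1, keepdims=True)
        log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
        # Negative log-probability of the true class for each sample
        return -log_probs[np.arange(len(y_true)), y_true]

    # Toy batch: 2 samples, 3 classes
    logits = np.array([[2.0, 1.0, 0.1],
                       [0.5, 2.5, 0.3]])
    y_true = np.array([0, 1])
    per_sample = sparse_categorical_crossentropy_from_logits(y_true, logits)
    print(per_sample.mean())  # ≈ 0.3185
    ```

    Any wrapper that forwards y_true and y_pred to this computation unchanged will therefore produce the same loss values; the accuracy discrepancy comes purely from the metric selection described above.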
    

    Hope this helps
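
    As for the broader question of the "correct" way to write your own loss: besides a plain function, one idiomatic option is subclassing tf.keras.losses.Loss. The sketch below (my own addition, not from the original answer; model3 and MySparseCE are illustrative names) wraps the same sparse categorical cross-entropy. Note that call() should return per-sample losses, so it uses the functional form, which does not reduce over the batch.

    ```python
    import tensorflow as tf
    import tensorflow.keras.layers as layers

    class MySparseCE(tf.keras.losses.Loss):
        def call(self, y_true, y_pred):
            # Functional form returns one loss value per sample;
            # the base class handles the reduction to a scalar.
            return tf.keras.losses.sparse_categorical_crossentropy(
                y_true, y_pred, from_logits=True)

    model3 = tf.keras.Sequential([
        layers.Flatten(),
        layers.Dense(128, activation="relu"),
        layers.Dropout(0.2),
        layers.Dense(10),
    ])

    # A Loss subclass still doesn't let Keras infer the right accuracy,
    # so SparseCategoricalAccuracy must be specified explicitly here too.
    model3.compile(optimizer="adam",
                   loss=MySparseCE(),
                   metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
    ```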