keras.Model.fit does not work correctly with generator and sparse categorical crossentropy loss


tf.keras.Model.fit(x=generator) does not work correctly with the SparseCategoricalCrossentropy / sparse_categorical_crossentropy loss function when a generator supplies the training data. The same symptom is reported in "Accuracy killed when using ImageDataGenerator TensorFlow Keras".

Please advise whether this behaviour is expected, or point out where the code is incorrect.

Code excerpt (the entire code is at the bottom).

# --------------------------------------------------------------------------------
# CIFAR 10
# --------------------------------------------------------------------------------
USE_SPARCE_LABEL = True

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
x_train, x_validation, y_train, y_validation = train_test_split(
    x_train, y_train, test_size=0.2, random_state=42
)

# One Hot Encoding the labels when USE_SPARCE_LABEL is False
if not USE_SPARCE_LABEL:
    y_train = keras.utils.to_categorical(y_train, NUM_CLASSES)
    y_validation = keras.utils.to_categorical(y_validation, NUM_CLASSES)
    y_test = keras.utils.to_categorical(y_test, NUM_CLASSES)


# --------------------------------------------------------------------------------
# Model
# --------------------------------------------------------------------------------
model: Model = Model(
    inputs=inputs, outputs=outputs, name="cifar10"
)

# --------------------------------------------------------------------------------
# Compile
# --------------------------------------------------------------------------------
if USE_SPARCE_LABEL:
    loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)   # <--- causes the incorrect behavior
else:
    loss_fn=tf.keras.losses.CategoricalCrossentropy(from_logits=False)

learning_rate = 1e-3
model.compile(
    optimizer=Adam(learning_rate=learning_rate, beta_1=0.9, beta_2=0.999, epsilon=1e-08),
    loss=loss_fn,     # <---- sparse categorical causes the incorrect behavior
    metrics=["accuracy"]
)

# --------------------------------------------------------------------------------
# Train 
# --------------------------------------------------------------------------------
batch_size = 16
number_of_epochs = 10

def data_label_generator(x, y):
    def _f():
        index = 0
        length = len(x)
        try: 
            while True:                
                yield x[index:index+batch_size], y[index:index+batch_size]
                index = (index + batch_size) % length
        except StopIteration:
            return
        
    return _f

earlystop_callback = tf.keras.callbacks.EarlyStopping(
    patience=5,
    restore_best_weights=True,
    monitor='val_accuracy'
)

steps_per_epoch = len(y_train) // batch_size
validation_steps = (len(y_validation) // batch_size) - 1  # To avoid running out of data for validation

history = model.fit(
    x=data_label_generator(x_train, y_train)(),  # <--- Generator
    batch_size=batch_size,
    epochs=number_of_epochs,
    verbose=1,
    validation_data=data_label_generator(x_validation, y_validation)(),
    shuffle=True,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
    validation_batch_size=batch_size,
    callbacks=[
        earlystop_callback
    ]
)
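
As a quick sanity check on the training setup above, the sketch below reproduces the generator's wrap-around slicing on small stand-in lists, and the step arithmetic for the 40,000 / 10,000 split that train_test_split yields from CIFAR-10's 50,000 training images. The names make_batches, toy_x and toy_y are illustrative and not part of the original code.

```python
from itertools import islice


def make_batches(x, y, batch_size):
    """Endless generator using the same slicing logic as data_label_generator."""
    index = 0
    length = len(x)
    while True:
        yield x[index:index + batch_size], y[index:index + batch_size]
        index = (index + batch_size) % length


# Toy stand-ins for images and sparse labels. Each label is a 1-element list,
# mirroring the (N, 1) shape of the CIFAR-10 label array.
toy_x = list(range(10))
toy_y = [[i % 3] for i in range(10)]

# Take four batches of size 4 from 10 samples to see the wrap-around.
batches = list(islice(make_batches(toy_x, toy_y, batch_size=4), 4))
batch_sizes = [len(bx) for bx, _ in batches]

# Step arithmetic for the real split (assumed sizes: 40,000 train, 10,000 validation).
train_size, validation_size, batch_size = 40_000, 10_000, 16
steps_per_epoch = train_size // batch_size
validation_steps = validation_size // batch_size - 1
```

With 10 samples and a batch size of 4, the third batch is short (2 items) and the fourth restarts at index 2 rather than 0. In the question's setup 40,000 is a multiple of 16, so every training batch is full and steps_per_epoch comes out at 2500, matching the 2500/2500 shown in the logs.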

Symptom

Using sparse indices as the labels and SparseCategoricalCrossentropy as the loss function (USE_SPARCE_LABEL=True). The accuracy stays unstable and low, around 0.10, even though the loss keeps decreasing, and this triggers early stopping.

2500/2500 [...] - 24s 8ms/step - loss: 1.4824 - accuracy: 0.0998 - val_loss: 1.1893 - val_accuracy: 0.1003
Epoch 2/10
2500/2500 [...] - 21s 8ms/step - loss: 1.0730 - accuracy: 0.1010 - val_loss: 0.8896 - val_accuracy: 0.0832
Epoch 3/10
2500/2500 [...] - 20s 8ms/step - loss: 0.9272 - accuracy: 0.1016 - val_loss: 0.9150 - val_accuracy: 0.0720
Epoch 4/10
2500/2500 [...] - 20s 8ms/step - loss: 0.7987 - accuracy: 0.1019 - val_loss: 0.8087 - val_accuracy: 0.0864
Epoch 5/10
2500/2500 [...] - 20s 8ms/step - loss: 0.7081 - accuracy: 0.1012 - val_loss: 0.8707 - val_accuracy: 0.0928
Epoch 6/10
2500/2500 [...] - 21s 8ms/step - loss: 0.6056 - accuracy: 0.1019 - val_loss: 0.7688 - val_accuracy: 0.0851

Using one-hot encoding as the labels and CategoricalCrossentropy as the loss function (USE_SPARCE_LABEL=False). Works as expected.

2500/2500 [...] - 24s 8ms/step - loss: 1.4146 - accuracy: 0.4997 - val_loss: 1.0906 - val_accuracy: 0.6105
Epoch 2/10
2500/2500 [...] - 21s 9ms/step - loss: 1.0306 - accuracy: 0.6375 - val_loss: 0.9779 - val_accuracy: 0.6532
Epoch 3/10
2500/2500 [...] - 22s 9ms/step - loss: 0.8780 - accuracy: 0.6925 - val_loss: 0.8194 - val_accuracy: 0.7127
Epoch 4/10
2500/2500 [...] - 21s 8ms/step - loss: 0.7641 - accuracy: 0.7315 - val_loss: 0.9330 - val_accuracy: 0.7014
Epoch 5/10
2500/2500 [...] - 21s 8ms/step - loss: 0.6797 - accuracy: 0.7614 - val_loss: 0.7908 - val_accuracy: 0.7311
Epoch 6/10
2500/2500 [...] - 21s 9ms/step - loss: 0.6182 - accuracy: 0.7841 - val_loss: 0.7371 - val_accuracy: 0.7533
Epoch 7/10
2500/2500 [...] - 21s 9ms/step - loss: 0.4981 - accuracy: 0.8217 - val_loss: 0.8221 - val_accuracy: 0.7373
Epoch 8/10
2500/2500 [...] - 22s 9ms/step - loss: 0.4363 - accuracy: 0.8437 - val_loss: 0.7865 - val_accuracy: 0.7525
Epoch 9/10
2500/2500 [...] - 23s 9ms/step - loss: 0.3962 - accuracy: 0.8596 - val_loss: 0.8198 - val_accuracy: 0.7505
Epoch 10/10
2500/2500 [...] - 22s 9ms/step - loss: 0.3463 - accuracy: 0.8776 - val_loss: 0.8472 - val_accuracy: 0.7512

Code

import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras import (
    __version__
)


from keras.layers import (
    Layer,
    Normalization,
    Conv2D,
    MaxPooling2D,
    BatchNormalization,
    Dense,
    Flatten,
    Dropout,
    Reshape,
    Activation,
    ReLU,
    LeakyReLU,
)
from keras.models import (
    Model,
)
from keras.optimizers import (
    Adam
)
from sklearn.model_selection import train_test_split

print("TensorFlow version: {}".format(tf.__version__))
tf.keras.__version__ = __version__
print("Keras version: {}".format(tf.keras.__version__))

# --------------------------------------------------------------------------------
# CIFAR 10
# --------------------------------------------------------------------------------
NUM_CLASSES = 10
INPUT_SHAPE = (32, 32, 3)
USE_SPARCE_LABEL = False   # Setting this to False makes it work as expected

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
x_train, x_validation, y_train, y_validation = train_test_split(
    x_train, y_train, test_size=0.2, random_state=42
)

# One Hot Encoding the labels
if not USE_SPARCE_LABEL:
    y_train = keras.utils.to_categorical(y_train, NUM_CLASSES)
    y_validation = keras.utils.to_categorical(y_validation, NUM_CLASSES)
    y_test = keras.utils.to_categorical(y_test, NUM_CLASSES)

# --------------------------------------------------------------------------------
# Model
# --------------------------------------------------------------------------------
inputs = tf.keras.Input(
    name='image',
    shape=INPUT_SHAPE,
    dtype=tf.float32
) 

x = Conv2D(                                           
    filters=32, 
    kernel_size=(3, 3), 
    strides=(1, 1), 
    padding="same",
    activation='relu', 
    input_shape=INPUT_SHAPE
)(inputs)
x = BatchNormalization()(x)
x = Conv2D(                                           
    filters=64, 
    kernel_size=(3, 3), 
    strides=(1, 1), 
    padding="same",
    activation='relu'
)(x)
x = MaxPooling2D(                                     
    pool_size=(2, 2)
)(x)
x = Dropout(0.20)(x)

x = Conv2D(                                           
    filters=128, 
    kernel_size=(3, 3), 
    strides=(1, 1), 
    padding="same",
    activation='relu'
)(x)
x = BatchNormalization()(x)
x = MaxPooling2D(                                     
    pool_size=(2, 2)
)(x)
x = Dropout(0.20)(x)

x = Flatten()(x)
x = Dense(300, activation="relu")(x)
x = BatchNormalization()(x)
x = Dropout(0.20)(x)
x = Dense(200, activation="relu")(x)
outputs = Dense(NUM_CLASSES, activation="softmax")(x)

model: Model = Model(
    inputs=inputs, outputs=outputs, name="cifar10"
)

# --------------------------------------------------------------------------------
# Compile
# --------------------------------------------------------------------------------
learning_rate = 1e-3

if USE_SPARCE_LABEL:
    loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
else:
    loss_fn=tf.keras.losses.CategoricalCrossentropy(from_logits=False)

model.compile(
    optimizer=Adam(learning_rate=learning_rate, beta_1=0.9, beta_2=0.999, epsilon=1e-08),
    loss=loss_fn,
    metrics=["accuracy"]
)
model.summary()


# --------------------------------------------------------------------------------
# Train
# --------------------------------------------------------------------------------
batch_size = 16
number_of_epochs = 10

def data_label_generator(x, y):
    def _f():
        index = 0
        length = len(x)
        try: 
            while True:                
                yield x[index:index+batch_size], y[index:index+batch_size]
                index = (index + batch_size) % length
        except StopIteration:
            return
        
    return _f

earlystop_callback = tf.keras.callbacks.EarlyStopping(
    patience=5,
    restore_best_weights=True,
    monitor='val_accuracy'
)

steps_per_epoch = len(y_train) // batch_size
validation_steps = (len(y_validation) // batch_size) - 1  # -1 to avoid running out of data for validation

history = model.fit(
    x=data_label_generator(x_train, y_train)(),
    batch_size=batch_size,
    epochs=number_of_epochs,
    verbose=1,
    validation_data=data_label_generator(x_validation, y_validation)(),
    shuffle=True,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
    validation_batch_size=batch_size,
    callbacks=[
        earlystop_callback
    ]
)

Environment

TensorFlow version: 2.14.1
Keras version: 2.14.0
Python 3.10.12
Ubuntu 22.04LTS

Workaround

The answer by innat worked: pass the metric instance explicitly instead of the string "accuracy".

model.compile(
    optimizer=Adam(learning_rate=learning_rate, beta_1=0.9, beta_2=0.999, epsilon=1e-08),
    #metrics=["accuracy"]
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')])
model.summary()
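
The workaround above hard-codes the sparse case. To keep the USE_SPARCE_LABEL switch from the question, one option is to choose the loss and the metric together, so they cannot drift apart. The sketch below does this with Keras string identifiers. The helper name select_loss_and_metric is hypothetical, and this variant has not been verified against the environment in the question.

```python
def select_loss_and_metric(use_sparse_label):
    """Return matching Keras string identifiers for loss, metric and monitor."""
    if use_sparse_label:
        loss = "sparse_categorical_crossentropy"
        metric = "sparse_categorical_accuracy"
    else:
        loss = "categorical_crossentropy"
        metric = "categorical_accuracy"
    # A metric passed by its string name is logged under that name, so the
    # validation value to monitor is the name with a "val_" prefix.
    monitor = "val_" + metric
    return loss, metric, monitor


loss_name, metric_name, monitor_name = select_loss_and_metric(use_sparse_label=True)

# Intended use (not executed here):
#   model.compile(optimizer=Adam(...), loss=loss_name, metrics=[metric_name])
#   tf.keras.callbacks.EarlyStopping(monitor=monitor_name, patience=5)
```

Note the monitor name: with an explicit metric string the log key is no longer val_accuracy, so the EarlyStopping callback has to watch the new key. The workaround avoids this by passing name='accuracy' to the metric instance.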

Solution

  • The behaviour observed with metrics=["accuracy"] when using sparse target vectors looks like a potential bug in the API. According to the documentation, the string identifier accuracy should be converted to the appropriate metric instance.

    When you pass the strings accuracy or acc, we convert this to one of tf.keras.metrics.BinaryAccuracy, tf.keras.metrics.CategoricalAccuracy, tf.keras.metrics.SparseCategoricalAccuracy based on the shapes of the targets and of the model output

    In your case, you need to use keras.metrics.SparseCategoricalAccuracy(name='accuracy') explicitly to make it work.
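
One plausible reading of the roughly 0.10 accuracy, assuming the string resolves to the one-hot style metric when the generator hides the label shape (the source does not confirm this), is sketched below in plain Python. The function names mirror the Keras metrics but are simplified re-implementations for illustration, not the library code.

```python
def argmax(values):
    """Index of the largest element (first one wins on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def categorical_accuracy(y_true, y_pred):
    """One-hot style: compare argmax of the label row with argmax of the prediction."""
    hits = [argmax(t) == argmax(p) for t, p in zip(y_true, y_pred)]
    return sum(hits) / len(hits)


def sparse_categorical_accuracy(y_true, y_pred):
    """Sparse style: compare the integer label itself with argmax of the prediction."""
    hits = [t[0] == argmax(p) for t, p in zip(y_true, y_pred)]
    return sum(hits) / len(hits)


# Four samples, three classes. Every prediction is correct.
y_pred = [
    [0.8, 0.1, 0.1],  # predicts class 0
    [0.1, 0.8, 0.1],  # predicts class 1
    [0.1, 0.1, 0.8],  # predicts class 2
    [0.2, 0.7, 0.1],  # predicts class 1
]
y_sparse = [[0], [1], [2], [1]]  # shape (4, 1), like the CIFAR-10 labels
y_one_hot = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0]]

sparse_on_sparse = sparse_categorical_accuracy(y_sparse, y_pred)  # 1.0
categorical_on_one_hot = categorical_accuracy(y_one_hot, y_pred)  # 1.0
categorical_on_sparse = categorical_accuracy(y_sparse, y_pred)    # 0.25
```

Applied to a (batch, 1) label array, the argmax over the last axis is always 0, so the one-hot style metric only counts how often the model predicts class 0. With ten roughly balanced classes that fraction sits near 0.10 regardless of how well the model is trained, which would be consistent with the logs in the question, where the loss falls while the accuracy does not move.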