tensorflow, keras, conv-neural-network, mnist, sparsecategoricalcrossentropy

How does tensorflow.keras.Model.fit() implicitly know to associate an int value with a probability distribution?


I'm learning TensorFlow and following along with the MIT intro to deep learning 2023 course, and I encountered something in the second lab that seems too convenient. The lab is about building a convolutional NN to recognize handwritten digits from the MNIST database. The training labels are a 1-D tensor of integers, all in the range 0-9, but the model itself outputs a tensor of probabilities, one for each digit.

The model is defined as:

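    import tensorflow as tf

    def build_cnn_model():
        cnn_model = tf.keras.Sequential([
            # Use parameters as shown in the diagram.
            # The first two positional arguments of Conv2D are the number of
            # filters and the kernel size, so this is 3 filters of size 3x3.
            tf.keras.layers.Conv2D(3, 3, input_shape=(28, 28, 1), activation='relu'),

            tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(1, 1)),

            # input_shape only matters on the first layer; Keras infers the
            # shapes of later layers, so it is omitted here.
            tf.keras.layers.Conv2D(3, 3),
            tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(1, 1)),

            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(128, activation='relu'),

            # One output per digit class (0-9); softmax makes them sum to 1.
            tf.keras.layers.Dense(10, activation='softmax')
        ])
        return cnn_model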

    cnn = build_cnn_model()

and then compiled with:

    cnn.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=1e-1),
                loss='sparse_categorical_crossentropy',
                metrics=['accuracy'])

and then I call fit as:

    cnn.fit(train_images, train_labels, batch_size=BATCH_SIZE, epochs=EPOCHS)

My thinking is that this is why we're using sparse_categorical_crossentropy as the loss function, since it doesn't use one-hot encoding, and I saw something in the documentation (that I didn't fully understand) about how it assumes a probability range. However, this seems like an incomplete explanation. What if I'd defined my model with an output layer of 20 values instead of 10? Or what if my vocabulary of training labels had more than 10 unique values? It all seems a little too convenient, and I think I'm missing something. Where exactly does the conversion between a range of probabilities and an integer value occur?

Thank you in advance!


Solution

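  • Nothing needs to be converted, actually. For one-hot targets, the cross-entropy reduces to -log(p), where p is the probability the model outputs for the correct class. The sparse variant of the loss simply uses the integer label as the index of that output, so a one-hot vector never has to be built.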

    If there are more outputs than classes (e.g. 20 output units but labels only go from 0 to 9), the model will simply learn to assign very low probabilities to the "extra" outputs, since they are never the correct class. If there are more labels than outputs (e.g. 10 output units but labels go from 0 to 19), an out-of-range label has no output to index, so the program will generally either crash or end up with a NaN loss, depending on the implementation.
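
To make the two points above concrete, here is a minimal pure-Python sketch of the arithmetic. It is not Keras's implementation: the helper functions, the example logits, and the choice to raise IndexError for an out-of-range label are all assumptions made for illustration. It only shows that indexing the output with the integer label gives the same number as the one-hot cross-entropy, and what happens when the number of outputs and the label range differ.

```python
import math


def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def categorical_crossentropy(one_hot, probs):
    """Cross-entropy against a one-hot target: -sum(t_i * log(p_i))."""
    return -sum(t * math.log(p) for t, p in zip(one_hot, probs) if t > 0)


def sparse_categorical_crossentropy(label, probs):
    """Same loss, but the integer label is used directly as an index."""
    if not 0 <= label < len(probs):
        # Illustrative stand-in for the "crash" case described in the answer.
        raise IndexError(f"label {label} is out of range for {len(probs)} outputs")
    return -math.log(probs[label])


# Hypothetical raw scores from a 10-output model for one image of a "3".
logits = [0.1, -1.2, 0.3, 2.5, 0.0, -0.7, 0.4, -0.3, 1.1, -2.0]
probs = softmax(logits)
label = 3

# One-hot version of the same label, built only for comparison.
one_hot = [1.0 if i == label else 0.0 for i in range(len(probs))]

dense_loss = categorical_crossentropy(one_hot, probs)
sparse_loss = sparse_categorical_crossentropy(label, probs)
print("one-hot loss:", dense_loss)
print("sparse loss: ", sparse_loss)

# 20 outputs, labels still 0-9: indexing works exactly as before. Outputs
# 10-19 are never the target, so any probability they receive only adds loss.
wide_probs = softmax(logits + [0.0] * 10)
wide_loss = sparse_categorical_crossentropy(label, wide_probs)
print("loss with 10 extra outputs:", wide_loss)

# 10 outputs, but a label of 15: there is nothing to index.
try:
    sparse_categorical_crossentropy(15, probs)
except IndexError as err:
    print("out-of-range label:", err)
```

Because the extra outputs can only take probability away from the correct class, gradient descent pushes them toward zero, which is the "learns to assign a very low probability" behaviour described in the answer.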