python tensorflow machine-learning neural-network digits

Why do we need a y variable in Keras model.fit()?


I am working with the handwritten digits dataset (MNIST). The data is loaded as follows:

from tensorflow import keras

(X_train, y_train), (X_test, y_test) = keras.datasets.mnist.load_data()

This is the code for a neural network created to classify the digits:

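# A single dense layer: 784 pixel inputs, 10 outputs (one per digit)
model = keras.Sequential([
    keras.layers.Dense(10, input_shape=(784,), activation='sigmoid')
])

model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Assumed preprocessing (not shown in the post): flatten each 28x28 image to 784 values
X_train_flattened = X_train.reshape(len(X_train), 28 * 28)
model.fit(X_train_flattened, y_train, epochs=5)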

The question is: what is the function of y_train in model.fit()? This appears to be a classification problem, and it seems the network should need only the input (X_train_flattened) to be trained.


Solution

  • X_train_flattened provides the images as input. y_train holds the label for each image (the digit it shows, 0-9) and tells the model what it should aim to predict for that image.

    This is necessary in supervised machine learning (Supervised machine learning tutorial): the model can only learn which class each image belongs to if it is shown the correct class during training.

    The loss function (sparse_categorical_crossentropy here) measures how far the model's predictions are from the true labels (y_train). Without y_train, the model would have no basis for calculating this error and would not know how to improve.

    During training, the model uses the error (or loss) calculated from comparing its predictions to y_train to update its parameters through backpropagation. Here is the original paper on backpropagation, if it helps (Geoffrey Hinton, backpropagation).
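To make the role of the label concrete, here is a minimal pure-Python sketch of a single training step for one image. It is an illustration under simplifying assumptions, not what Keras does internally: it uses a softmax output instead of the sigmoid layer in the question, plain gradient descent instead of adam, and it updates the ten output scores directly rather than propagating the gradient back to layer weights. The names `softmax`, `gradient_wrt_scores`, and `learning_rate` are hypothetical helpers introduced for this example only.

```python
import math

def softmax(scores):
    # Turn raw scores into probabilities that sum to 1
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def sparse_categorical_crossentropy(probs, label):
    # The label is an integer class index, as in y_train.
    # The loss is the negative log-probability assigned to the true class.
    return -math.log(probs[label])

def gradient_wrt_scores(probs, label):
    # For softmax + cross-entropy the gradient is (probs - one_hot(label)),
    # so it cannot be computed without the label.
    return [p - (1.0 if k == label else 0.0) for k, p in enumerate(probs)]

scores = [0.0] * 10      # "untrained" model: every digit is equally likely
label = 3                # the true digit, playing the role of one entry of y_train
learning_rate = 0.5      # arbitrary value chosen for the illustration

loss_before = sparse_categorical_crossentropy(softmax(scores), label)

# One gradient-descent step: move the scores against the gradient
grad = gradient_wrt_scores(softmax(scores), label)
scores = [s - learning_rate * g for s, g in zip(scores, grad)]

loss_after = sparse_categorical_crossentropy(softmax(scores), label)
predicted = max(range(10), key=lambda k: scores[k])

print("loss before:", loss_before)
print("loss after: ", loss_after)
print("predicted digit:", predicted)
```

Both the loss and the gradient take `label` as an argument. If the label is removed, neither can be evaluated, which is why model.fit() needs y_train alongside the inputs.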