Tags: tensorflow, keras, neural-network, classification, mnist

Simple classification neural network predicting the same class for every input


I am learning about neural networks and am creating a simple classification network. I am inexperienced, so I apologize if this is a dumb mistake. In the code below, I import my dataset, format the labels into one-hot vectors, and then use the simple network from TensorFlow's tutorial. I am using categorical cross entropy because my output is a rating, and if I am not mistaken, categorical cross entropy punishes less for numbers that are close to the correct classification. The classes run from 1 to 20, for ratings of 0.5-10 in 0.5 increments.

My accuracy is always 2-12%, which is obviously no good. When I run the model on my test data, it seems to pick one class and assign every image to it. Funnily enough, instead of giving varied probabilities, it returns a one-hot vector, with the model 100% confident that every test image belongs to that same class. My dataset is very small, I know, but I don't think even bad data should make the model classify everything identically at 100% confidence. The code:

from __future__ import absolute_import, division, print_function

import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt

(train_images, train_labels), (test_images, test_labels) = np.load("dataset.npy", allow_pickle=True)

train_labels = list(map(float, train_labels))
test_labels = list(map(float, test_labels))
train_labels = [int(i * 2) for i in train_labels]
test_labels = [int(i * 2) for i in test_labels]

train_zeros = np.zeros((307, 20))
test_zeros = np.zeros((103, 20))

for i in range(len(train_zeros)):
    train_zeros[i][train_labels[i] - 1] = 1
for i in range(len(test_zeros)):
    test_zeros[i][test_labels[i] - 1] = 1

model = keras.Sequential([
    keras.layers.Flatten(input_shape=(128, 128)),
    keras.layers.Dense(512, activation=tf.nn.relu),
    keras.layers.Dense(20, activation=tf.nn.softmax)
])

model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

model.fit(train_images, train_zeros, epochs=10)

predictions = model.predict(test_images)

print(predictions[0])

def plot_image(i, predictions_array, true_label, img):
    predictions_array, true_label, img = predictions_array[i], true_label[i], img[i]
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])

    plt.imshow(img, cmap=plt.cm.binary)

    # np.argmax returns a class index 0-19; adding 1 recovers the class
    # 1-20 stored in the labels, and halving converts back to a rating
    predicted_class = np.argmax(predictions_array) + 1
    if predicted_class == true_label:
        color = 'blue'
    else:
        color = 'red'

    plt.xlabel("{} {:2.0f}% ({})".format(predicted_class / 2,
                                         100 * np.max(predictions_array),
                                         true_label / 2),
               color=color)


def plot_value_array(i, predictions_array, true_label):
    predictions_array, true_label = predictions_array[i], true_label[i]
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    thisplot = plt.bar(range(20), predictions_array, color="#777777")
    plt.ylim([0, 1])
    predicted_label = np.argmax(predictions_array)

    thisplot[predicted_label].set_color('red')
    # true_label is a class 1-20; subtract 1 to index the 0-19 bars
    thisplot[true_label - 1].set_color('blue')

num_rows = 5
num_cols = 3
num_images = num_rows*num_cols
plt.figure(figsize=(2*2*num_cols, 2*num_rows))
for i in range(num_images):
    plt.subplot(num_rows, 2*num_cols, 2*i+1)
    plot_image(i, predictions, test_labels, test_images)
    plt.subplot(num_rows, 2*num_cols, 2*i+2)
    plot_value_array(i, predictions, test_labels)
plt.show()

Solution

  • You should definitely treat this as a regression problem rather than a classification problem.

    if I am not mistaken, categorical cross entropy punishes less for numbers that are close to the correct classification

    I am afraid this is not correct. Categorical cross entropy only looks at the probability your model assigns to the true class, so it penalizes mistaking a 4 for a 4.5 exactly as heavily as mistaking a 0.5 for a 10 (a quick demonstration follows below). That is clearly not what you want for ratings.

    I'd strongly recommend you treat this as a regression problem and switch to something like mean squared error as the loss function. Check out this tutorial for a complete worked example, and see the sketch at the end for how your model could be adapted.
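
    To see this concretely, here is a quick check with hypothetical numbers (run under TensorFlow 2.x, where eager execution lets you call .numpy() on the result): the loss reduces to -log of the probability given to the true class, so it makes no difference which wrong class the rest of the probability mass lands on.

        import numpy as np
        import tensorflow as tf

        # True class for both samples is index 7, i.e. class 8, i.e. rating 4.0
        y_true = np.zeros((2, 20), dtype="float32")
        y_true[:, 7] = 1.0

        # Both predictions give the true class probability 0.1; sample 0 puts
        # the rest on the neighbouring class (rating 4.5), while sample 1 puts
        # it on the farthest class (rating 10.0)
        y_pred = np.zeros((2, 20), dtype="float32")
        y_pred[0, 7], y_pred[0, 8] = 0.1, 0.9
        y_pred[1, 7], y_pred[1, 19] = 0.1, 0.9

        loss = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
        print(loss.numpy())  # both ~2.30 (= -log 0.1): identical penalty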
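
    And here is a minimal sketch of the regression version (a sketch, not a drop-in fix: I'm keeping your dataset.npy loading as-is and assuming the images are 0-255 pixel values, hence the division by 255). The float ratings themselves become the targets, and the network predicts a single continuous value:

        import numpy as np
        import tensorflow as tf
        from tensorflow import keras

        (train_images, train_labels), (test_images, test_labels) = np.load(
            "dataset.npy", allow_pickle=True)

        # Keep the ratings as floats (0.5-10) instead of one-hot classes
        train_labels = np.array(train_labels, dtype="float32")
        test_labels = np.array(test_labels, dtype="float32")

        # Assumption: pixels are 0-255, so scale them into 0-1
        train_images = np.array(train_images, dtype="float32") / 255.0
        test_images = np.array(test_images, dtype="float32") / 255.0

        model = keras.Sequential([
            keras.layers.Flatten(input_shape=(128, 128)),
            keras.layers.Dense(512, activation="relu"),
            keras.layers.Dense(1)  # single linear output: the predicted rating
        ])

        # MSE penalizes predictions by their squared distance from the true
        # rating, so predicting 4.5 for a true 4.0 costs far less than 10.0
        model.compile(optimizer="adam",
                      loss="mse",
                      metrics=["mae"])

        model.fit(train_images, train_labels, epochs=10)

        predictions = model.predict(test_images).squeeze()
        # Optionally snap predictions back onto the 0.5-point rating grid
        predictions = np.clip(np.round(predictions * 2) / 2, 0.5, 10.0)

    With this setup the reported mean absolute error is directly in rating units, so an MAE of 0.7 means the model is off by about 0.7 rating points on average.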