
Change number of classes of MNIST Tensorflow

Hi, I'm trying to adapt the TensorFlow beginners' tutorial for MNIST with softmax regression. The tutorial has 10 classes (for the digits 0-9). Now, with a different dataset (EMNIST), I have 62 classes for digits and letters. In the model of the original example I have:

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.matmul(x, W) + b

Here 784 is the total number of pixels in a 28x28 image and 10 is the number of classes. What I want is:

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 62]))
b = tf.Variable(tf.zeros([62]))
y = tf.matmul(x, W) + b
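The model depends on the number of classes only through the shapes of W, b and the label placeholder y_, so it can help to keep that number in a single constant. Below is a numpy-only sketch of the shape arithmetic (not TensorFlow code); the names NUM_PIXELS, NUM_CLASSES and softmax_regression are hypothetical.

```python
import numpy as np

# Hypothetical constants: the one place to change the number of classes.
NUM_PIXELS = 28 * 28   # 784 inputs per flattened 28x28 image
NUM_CLASSES = 62       # digits plus upper- and lower-case letters


def softmax_regression(x, W, b):
    """Logits of a linear softmax-regression model: x @ W + b."""
    return x @ W + b


W = np.zeros((NUM_PIXELS, NUM_CLASSES), dtype=np.float32)
b = np.zeros((NUM_CLASSES,), dtype=np.float32)

batch = np.random.rand(100, NUM_PIXELS).astype(np.float32)
logits = softmax_regression(batch, W, b)
print(logits.shape)  # (100, 62)
```

The label placeholder (y_ in the tutorial) would need the same NUM_CLASSES width, and so would every one-hot label fed into it.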

That is, the same model with 62 classes. But when execution reaches the part of the code where the next batch is fetched and fed to the training step:

for _ in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

I get this error:

Traceback (most recent call last):
  File "", line 77, in <module>, argv=[sys.argv[0]] + unparsed)
  File "C:\Users\Willy Barales\Anaconda3\lib\site-packages\tensorflow\python\platform\", line 48, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "", line 64, in main, feed_dict={x: batch_xs, y_: batch_ys})
  File "C:\Users\Willy Barales\Anaconda3\lib\site-packages\tensorflow\python\client\", line 789, in run
  File "C:\Users\Willy Barales\Anaconda3\lib\site-packages\tensorflow\python\client\", line 975, in _run
    % (np_val.shape,, str(subfeed_t.get_shape())))
ValueError: Cannot feed value of shape (100, 10) for Tensor 'Placeholder_1:0', which has shape '(?, 62)'
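The error says the labels coming out of next_batch() are one-hot vectors of length 10, while the y_ placeholder expects length 62. A minimal numpy sketch of a check that catches this before calling sess.run; the helper labels_match_model is hypothetical, not part of TensorFlow.

```python
import numpy as np


def labels_match_model(batch_ys, num_classes):
    """Hypothetical helper: True if one-hot labels fit a (?, num_classes) placeholder."""
    return batch_ys.ndim == 2 and batch_ys.shape[1] == num_classes


# Stand-in for what the unmodified loader yields: 10-column one-hot labels.
batch_ys = np.zeros((100, 10), dtype=np.float32)
if not labels_match_model(batch_ys, num_classes=62):
    print("label shape %s does not fit (?, 62)" % (batch_ys.shape,))
    # prints: label shape (100, 10) does not fit (?, 62)
```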

Any ideas on how to change the dataset for this example? Do I have to change something in the file where .next_batch() is implemented?

As far as I know, EMNIST has the exact same format as MNIST. Thanks in advance.
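Regarding next_batch(): as far as I can tell, in the tutorial's loader it only returns slices of image and label arrays that were already one-hot encoded when the data was read, so the label width is fixed at load time. A simplified, hypothetical stand-in (SimpleDataSet is illustrative; the real class also shuffles between epochs) shows why the batch keeps 10 columns:

```python
import numpy as np


class SimpleDataSet:
    """Hypothetical, simplified dataset with a next_batch() method.

    Labels are stored already one-hot encoded, so next_batch() only
    slices them; their width is decided when the data is loaded.
    """

    def __init__(self, images, labels):
        self._images = images
        self._labels = labels
        self._index = 0

    def next_batch(self, batch_size):
        start = self._index
        end = start + batch_size
        if end > len(self._images):  # wrap around; no shuffling in this sketch
            start, end = 0, batch_size
        self._index = end
        return self._images[start:end], self._labels[start:end]


images = np.zeros((500, 784), dtype=np.float32)
labels = np.zeros((500, 10), dtype=np.float32)  # encoded with 10 classes at load time
ds = SimpleDataSet(images, labels)
xs, ys = ds.next_batch(100)
print(ys.shape)  # (100, 10)
```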

Info on the new dataset:


  • All I had to do was edit the part of the file where the one-hot vectors are created from the labels, since those are what end up in batch_ys (thanks to Neijla for pointing this out):

    def extract_labels(f, one_hot=False, num_classes=62):

    Besides, of course, changing the number of classes in the model, as shown at the start of my question.
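The fix works because the one-hot width is chosen when the labels are encoded. As far as I know, extract_labels in TensorFlow's mnist.py passes num_classes to a dense_to_one_hot helper. Below is a numpy sketch of that encoding step, written with fancy indexing rather than a copy of the library code; the example labels (0-9 for digits, 10-61 for letters) are hypothetical.

```python
import numpy as np


def dense_to_one_hot(labels_dense, num_classes):
    """Convert integer class labels to one-hot rows of length num_classes."""
    num_labels = labels_dense.shape[0]
    labels_one_hot = np.zeros((num_labels, num_classes), dtype=np.float32)
    # Set a single 1 per row, at the column given by that row's label.
    labels_one_hot[np.arange(num_labels), labels_dense] = 1.0
    return labels_one_hot


# Hypothetical EMNIST-style labels: 0-9 digits, 10-61 letters.
labels = np.array([0, 9, 10, 61], dtype=np.uint8)
num_classes = int(labels.max()) + 1  # 62 here; assumes the largest class id is present
one_hot = dense_to_one_hot(labels, num_classes)
print(one_hot.shape)           # (4, 62)
print(one_hot.argmax(axis=1))  # [ 0  9 10 61]
```

Deriving num_classes from the data, as above, is just a quick sanity check that the labels really span 62 classes; hard-coding 62 in the model and in the encoder is equally valid as long as both agree.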