Tags: python, machine-learning, mnist, tensorflow-federated

Why is the EMNIST data reshaped from (28, 28) to [-1, 784] instead of [0, 784] in an image classification problem?


This is a code snippet from https://www.tensorflow.org/federated/tutorials/federated_learning_for_image_classification

The example is an image classification problem using federated learning. The function below is the pre-processing function for the EMNIST data (each image is 28 x 28 pixels). Can anyone help me understand why the data is reshaped to [-1, 784]? As far as I understand, we convert each image from a two-dimensional to a one-dimensional array because it is easier to process. But I am not sure why -1 was included. Wouldn't [0, 784] have been enough?

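import collections

import tensorflow as tf

NUM_CLIENTS = 10
NUM_EPOCHS = 5
BATCH_SIZE = 20
SHUFFLE_BUFFER = 100
PREFETCH_BUFFER = 10

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch `pixels` and return the features as an `OrderedDict`."""
    # After batching, element['pixels'] has shape [batch, 28, 28];
    # the -1 lets tf.reshape infer `batch` from the element count.
    return collections.OrderedDict(
        x=tf.reshape(element['pixels'], [-1, 784]),
        y=tf.reshape(element['label'], [-1, 1]))

  # Repeat, shuffle, batch, then flatten each whole batch in one map call.
  return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER).batch(
      BATCH_SIZE).map(batch_format_fn).prefetch(PREFETCH_BUFFER)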

Solution

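  • The -1 tells tf.reshape to infer the size of this dimension from the total number of elements in the tensor; here the inferred dimension is the batch dimension. Since each EMNIST image is 28 x 28 pixels, a batch of N examples contains N x 28 x 28 = N x 784 pixels, so reshaping to [-1, 784] yields a tensor of shape [N, 784]. The -1 therefore makes this map function agnostic to batch size. That matters even with a fixed BATCH_SIZE, because Dataset.batch keeps a smaller final batch by default (drop_remainder=False), so N is not always 20. Only one dimension of the target shape may be -1.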

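    If we applied this map function before batching, each element would be a single 28 x 28 example, so we could hardcode the -1 as a 1 (or, more usefully, reshape to [784], since a [1, 784] element would be batched into a [batch, 1, 784] tensor). However, mapping per example before batching is an antipattern for tf.data.Dataset pipelines in general; see the vectorized mapping section of the guide on writing performant tf.data.Dataset pipelines.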

    We could not use a 0 here. tf.reshape requires the new shape to hold exactly as many elements as the input, and [0, 784] holds 0 x 784 = 0 elements. By the equation above, that only matches when N = 0, i.e. an empty batch, so [0, 784] would hardcode the assumption that the element contains no pixels at all.
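The inference rule for -1 described in the answer can be sketched in plain Python. `infer_shape` below is a hypothetical helper written for illustration, not a TensorFlow API; it mimics the documented behaviour of `tf.reshape` for a single -1 (the missing size is the total element count divided by the product of the other dimensions). The leftover batch of 7 examples is an assumed number, chosen only to show a smaller final batch.

```python
import math


def infer_shape(num_elements, shape):
    """Hypothetical helper: resolve a single -1 in `shape` the way reshape does.

    The -1 dimension becomes num_elements // (product of the other dimensions).
    Raises ValueError if the shape cannot hold exactly num_elements values.
    """
    if shape.count(-1) > 1:
        raise ValueError("at most one dimension may be -1")
    known = math.prod(d for d in shape if d != -1)
    if -1 in shape:
        if known == 0 or num_elements % known != 0:
            raise ValueError(f"cannot infer -1 for {num_elements} elements into {shape}")
        return [num_elements // known if d == -1 else d for d in shape]
    if known != num_elements:
        raise ValueError(f"{shape} holds {known} elements, not {num_elements}")
    return list(shape)


PIXELS = 28 * 28  # 784 pixels per EMNIST image

# A full batch (BATCH_SIZE = 20) and an assumed smaller final batch of 7 examples.
print(infer_shape(20 * PIXELS, [-1, 784]))  # [20, 784]
print(infer_shape(7 * PIXELS, [-1, 784]))   # [7, 784]
# Labels: a batch of 20 scalar labels reshaped to [-1, 1].
print(infer_shape(20, [-1, 1]))             # [20, 1]
```

The same [-1, 784] target works for both batch sizes, which is exactly why the map function does not need to know N.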
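The point about mapping before versus after batching can be illustrated with a NumPy analogue. This sketch assumes NumPy as a stand-in for tf.data so it runs without TensorFlow: `np.stack` plays the role of `Dataset.batch`, and the random 28 x 28 arrays are hypothetical examples. It shows that one vectorized reshape after batching gives the same [20, 784] result as reshaping each example to [784] first, and that a hardcoded [1, 784] per example leaves an extra axis after batching. It does not measure performance; the tf.data guide's argument is about per-element function-call overhead.

```python
import numpy as np

# Hypothetical stand-in for a dataset: 20 EMNIST-sized examples of random pixels.
rng = np.random.default_rng(0)
examples = [rng.random((28, 28), dtype=np.float32) for _ in range(20)]

# Vectorized: batch first (np.stack mimics Dataset.batch), then one reshape.
batched_then_reshaped = np.stack(examples).reshape(-1, 784)

# Per-example: reshape each element to [784], then batch.
reshaped_then_batched = np.stack([e.reshape(784) for e in examples])

# Per-example with a hardcoded [1, 784]: batching adds a leading axis on top.
ones_then_batched = np.stack([e.reshape(1, 784) for e in examples])

print(batched_then_reshaped.shape)  # (20, 784)
print(reshaped_then_batched.shape)  # (20, 784)
print(ones_then_batched.shape)      # (20, 1, 784)
print(np.array_equal(batched_then_reshaped, reshaped_then_batched))  # True
```

Because arrays are row-major, each example's 784 pixels stay contiguous, so flattening the whole batch at once produces the same rows as flattening each example separately.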
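The element-count argument against [0, 784] can be checked with NumPy, whose `reshape` applies the same rule as `tf.reshape` (the total number of elements must stay the same). `try_reshape` is a hypothetical helper, and the zero-filled arrays stand in for a batch of 20 EMNIST images and for an empty batch.

```python
import numpy as np


def try_reshape(array, shape):
    """Hypothetical helper: return the new shape, or None if NumPy rejects it."""
    try:
        return array.reshape(shape).shape
    except ValueError:
        return None


batch = np.zeros((20, 28, 28), dtype=np.float32)  # 20 * 784 = 15680 elements
empty = np.zeros((0, 28, 28), dtype=np.float32)   # an empty batch: 0 elements

print(try_reshape(batch, (-1, 784)))  # (20, 784)
print(try_reshape(batch, (0, 784)))   # None: 0 * 784 = 0 != 15680 elements
print(try_reshape(empty, (0, 784)))   # (0, 784): [0, 784] only fits an empty batch
```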