
Getting the same result in two different runs in TensorFlow


I am trying to run the example from "Training a neural network on MNIST with Keras". I want to run it twice with the same initial weights and no shuffling, so that I get the SAME result in both runs. Here is the full code:

import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
import numpy as np

tf.enable_v2_behavior()
tfds.disable_progress_bar()

def normalize_img(image, label):
    """Normalizes images: `uint8` -> `float32`."""
    return tf.cast(image, tf.float32) / 255., label

def get_dataset():
    (ds_train, ds_test), ds_info = tfds.load(
        'mnist',
        split=['train', 'test'],
        shuffle_files=True,
        as_supervised=True,
        with_info=True,
    )

    ds_train = ds_train.map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds_train = ds_train.cache()
    # Shuffling is disabled so that both runs see the training batches in the same order.
    # ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
    ds_train = ds_train.batch(128)
    ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)

    ds_test = ds_test.map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds_test = ds_test.batch(128)
    ds_test = ds_test.cache()
    ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)

    return ds_train, ds_test

def keras_fit(ds_train, ds_test, verbose=True, init1='glorot_uniform', init2='glorot_uniform'):
    # https://www.tensorflow.org/datasets/keras_example
    model = tf.keras.models.Sequential([
      tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
      tf.keras.layers.Dense(128, activation='relu', kernel_initializer=init1),
      tf.keras.layers.Dense(10, activation='softmax', kernel_initializer=init2)
    ])
    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=tf.keras.optimizers.Adam(0.001),
        metrics=['accuracy'],
    )
    model.fit(
        ds_train,
        epochs=6,
        validation_data=ds_test,
        verbose=verbose,
        shuffle=False,  # note: Keras ignores this argument when the input is a tf.data.Dataset
    )
    return model.evaluate(ds_test, verbose=verbose)

def test_mnist():
    # Draw the initial weights once, so both runs start from identical values.
    init = tf.keras.initializers.GlorotUniform()
    init1 = tf.constant_initializer(init((784, 128)).numpy())
    init2 = tf.constant_initializer(init((128, 10)).numpy())
    ds_train, ds_test = get_dataset()
    keras1 = keras_fit(ds_train, ds_test, init1=init1, init2=init2)
    keras2 = keras_fit(ds_train, ds_test, init1=init1, init2=init2)
    print(keras1)
    print(keras2)

if __name__ == "__main__":
    test_mnist()

As I understand it, both layers are initialized with exactly the same values in each run, and I do not shuffle the data randomly (fit is called with shuffle=False). Shouldn't I get exactly the same result? What am I doing wrong?

I get very similar results, and sometimes the accuracy is identical, but not reliably.
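The first half of that reasoning (both runs start from identical weights) can be checked in isolation. The following is a minimal sketch, assuming TensorFlow 2.x; the `seed=0` argument and the `build_first_layer` helper are hypothetical additions for illustration, not part of the original script. It freezes one Glorot-uniform draw with `tf.constant_initializer`, as `test_mnist()` does, and compares the kernels of two freshly built layers.

```python
import numpy as np
import tensorflow as tf

# Draw one set of Glorot-uniform values, as test_mnist() does, then freeze them.
# seed=0 is an assumption added here only to make this snippet repeatable.
values = tf.keras.initializers.GlorotUniform(seed=0)((784, 128)).numpy()
frozen_init = tf.constant_initializer(values)


def build_first_layer():
    """Hypothetical helper: a Dense layer like the question's first hidden layer."""
    layer = tf.keras.layers.Dense(128, activation="relu", kernel_initializer=frozen_init)
    layer.build((None, 784))  # creates the kernel (and bias) from the initializers
    return layer


w1 = build_first_layer().kernel.numpy()
w2 = build_first_layer().kernel.numpy()

# Both layers should hold exactly the frozen values.
print(np.array_equal(w1, w2), np.array_equal(w1, values))
```

If both comparisons hold, the initialization is not where the two runs diverge.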

PS: I am getting the following message after each epoch:

[[{{node IteratorGetNext}}]]
         [[Shape/_6]]
2020-12-20 13:52:10.002034: W tensorflow/core/common_runtime/base_collective_executor.cc:217] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
         [[{{node IteratorGetNext}}]]

I tried essentially the same script with Fashion MNIST, and it worked: I get exactly the same result in both runs.

The difference is that the Fashion MNIST version uses a NumPy array rather than a tf.data.Dataset.

So the problem must be either some shuffling done by the tf.data.Dataset class, or something related to the warning message I get: maybe I don't finish iterating over the full dataset, and the next run starts from the point where the previous one stopped?
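The second guess (the next run resuming where the previous one stopped) can be tested with a toy pipeline. This is a minimal sketch, assuming TensorFlow 2.x; `tf.data.Dataset.range(10)` stands in for MNIST and `one_pass` is a hypothetical helper. Each Python `for` loop over a `tf.data.Dataset` creates a new iterator that starts from the first element, and with no `shuffle()` step the map, cache, batch and prefetch stages keep the same order on every pass.

```python
import tensorflow as tf

# Toy pipeline mirroring the question's: map -> cache -> batch -> prefetch, no shuffle().
ds = (
    tf.data.Dataset.range(10)
    .map(lambda x: x * 2, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .cache()
    .batch(4)
    .prefetch(tf.data.experimental.AUTOTUNE)
)


def one_pass(dataset):
    """Hypothetical helper: consume the dataset once, roughly like one training epoch."""
    return [batch.numpy().tolist() for batch in dataset]


first = one_pass(ds)   # fills the cache
second = one_pass(ds)  # a fresh iterator: starts again from element 0, reads from the cache
print(first)           # [[0, 2, 4, 6], [8, 10, 12, 14], [16, 18]]
print(first == second)
```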


Solution

  • OK, I found out why it worked sometimes and not other times: it worked when running on the CPU and not when running on the GPU.

    So if you have a GPU, this code will not give identical results across runs, but you can make it work by setting os.environ['CUDA_VISIBLE_DEVICES'] = '-1' at the start of the script, before TensorFlow is imported.

    My theory of why this happens is the following: say I have 2 batches. On the CPU, I train on batch 1, update the weights, and then train on batch 2. On a GPU, both batches are sent to train in parallel, so batch 2 works on different weights (the ones from before batch 1's update).
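A minimal sketch of the CPU-only workaround, assuming TensorFlow 2.1 or later (for `tf.config.list_physical_devices`): the environment variable goes at the very top of the script, and the final line checks that TensorFlow no longer sees any GPU.

```python
import os

# Hide every GPU from CUDA. This must happen before TensorFlow initializes its
# devices, so it is set before tensorflow is imported.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import tensorflow as tf  # noqa: E402  (import deliberately placed after the env var)

# With no visible GPUs, TensorFlow places every op on the CPU.
print(tf.config.list_physical_devices("GPU"))  # []
```

The rest of the script (`get_dataset()`, `keras_fit()`, `test_mnist()`) can follow unchanged.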