python tensorflow machine-learning keras tensorflow-datasets

Validation set with TensorFlow Dataset


From Train and evaluate with Keras:

The argument validation_split (generating a holdout set from the training data) is not supported when training from Dataset objects, since this feature requires the ability to index the samples of the datasets, which is not possible in general with the Dataset API.

Is there a workaround? How can I still use a validation set with TF datasets?


Solution

  • No, you can't use validation_split (as the documentation clearly states), but you can pass validation_data instead and create the validation Dataset "manually".

    You can see an example in the same TensorFlow tutorial:

    import tensorflow as tf

    # Prepare the training dataset: shuffle, then batch
    train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)
    
    # Prepare the validation dataset: no shuffling needed, only batching
    val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
    val_dataset = val_dataset.batch(64)
    
    # `model` is assumed to be an already compiled Keras model
    model.fit(train_dataset, epochs=3, validation_data=val_dataset)
    

    You can create those two datasets from NumPy arrays ((x_train, y_train) and (x_val, y_val)) using simple slicing, as shown in the same tutorial:

    from tensorflow import keras

    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
    # Reserve the last 10,000 training samples for validation
    x_val = x_train[-10000:]
    y_val = y_train[-10000:]
    x_train = x_train[:-10000]
    y_train = y_train[:-10000]
    

    There are also other ways to create tf.data.Dataset objects; see the tf.data.Dataset documentation and related tutorials/notebooks.
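
If the data only exists as a tf.data.Dataset and cannot be sliced as arrays, one possible workaround is to carve out the validation set with Dataset.take and Dataset.skip. The following is a minimal sketch, not part of the tutorial quoted above: the toy features and labels, the 100/20 sizes, and the batch size are all assumptions made for illustration.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy data: 100 samples with 4 features and a binary label.
features = np.arange(400, dtype="float32").reshape(100, 4)
labels = (np.arange(100) % 2).astype("int32")
full_dataset = tf.data.Dataset.from_tensor_slices((features, labels))

# Assumed split: the first 20 elements become the validation set.
val_size = 20

# take() keeps the first `val_size` elements, skip() drops them,
# so the two datasets do not overlap.
val_dataset = full_dataset.take(val_size).batch(10)
train_dataset = full_dataset.skip(val_size).shuffle(buffer_size=80).batch(10)

# Both can then be passed to Keras, for example:
# model.fit(train_dataset, epochs=3, validation_data=val_dataset)
```

Note that the split here happens before any shuffling. If you shuffle the full dataset before calling take and skip, pass reshuffle_each_iteration=False to shuffle; otherwise the order changes every epoch and validation samples can leak into the training set.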