python tensorflow tensorflow2.0 tensorflow-datasets

How to Save a Tensorflow Dataset


As the title says, I'm trying to save a TensorSliceDataset object to a file. According to TensorFlow's website, the tf.data.Dataset class has a save method, but it seems not to be implemented for TensorSliceDataset objects. Pickling didn't work for me either.

Example code

import tensorflow as tf
t = tf.range(10)
ds = tf.data.Dataset.from_tensor_slices(t)
ds.save()

This raises: AttributeError: 'TensorSliceDataset' object has no attribute 'save'
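The error is not specific to TensorSliceDataset: from_tensor_slices returns a subclass of tf.data.Dataset, so the object has exactly the methods the base class defines in the installed TensorFlow version. Dataset.save was only added to the base class in TensorFlow 2.10, so on older versions neither the base class nor the subclass has it. A minimal check, assuming any TensorFlow 2.x install:

```python
import tensorflow as tf

t = tf.range(10)
ds = tf.data.Dataset.from_tensor_slices(t)

# The concrete class is a subclass of tf.data.Dataset, so it inherits
# whatever methods the base class defines in this TensorFlow version.
print(type(ds).__name__)                # the concrete subclass name
print(isinstance(ds, tf.data.Dataset))  # True

# False on versions before 2.10, True from 2.10 on; both checks agree,
# because the subclass does not add or remove save() itself.
print(tf.__version__, hasattr(tf.data.Dataset, "save"), hasattr(ds, "save"))
```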


Solution

  • In TensorFlow 2.10.0 and later, you can use tf.data.Dataset.save and tf.data.Dataset.load:

    import tensorflow as tf
    
    print(tf.__version__)
    # 2.10.0
    
    path = '/content/'  # any writable directory; '/content/' is Colab's working directory
    t = tf.range(10)
    ds = tf.data.Dataset.from_tensor_slices(t)
    
    ds.save(path)  # equivalent to tf.data.Dataset.save(ds, path)
    new_ds = tf.data.Dataset.load(path)
    

    For versions before 2.10, use the experimental API, tf.data.experimental.save and tf.data.experimental.load:

    import tensorflow as tf
    
    path = '/content/'
    t = tf.range(10)
    ds = tf.data.Dataset.from_tensor_slices(t)
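    tf.data.experimental.save(ds, path)
    # Some older releases require element_spec when loading, so pass it explicitly
    new_ds = tf.data.experimental.load(path, element_spec=ds.element_spec)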
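If the same script has to run on both old and new TensorFlow versions, the two APIs above can be wrapped in a small helper that picks whichever one exists. This is a minimal sketch, not an official API: save_dataset and load_dataset are hypothetical helper names, and a temporary directory stands in for '/content/'. The loaded elements are sorted before comparing, on the assumption that reading back from multiple shards may not preserve the original order.

```python
import tempfile

import tensorflow as tf


def save_dataset(ds, path):
    """Hypothetical helper: save ds with whichever API this TensorFlow has."""
    if hasattr(tf.data.Dataset, "save"):      # TensorFlow 2.10+
        ds.save(path)
    else:                                     # older 2.x releases
        tf.data.experimental.save(ds, path)


def load_dataset(path, element_spec):
    """Hypothetical helper: load a dataset written by save_dataset.

    element_spec is passed explicitly because some older releases
    require it; newer releases accept it as an optional argument.
    """
    if hasattr(tf.data.Dataset, "load"):      # TensorFlow 2.10+
        return tf.data.Dataset.load(path, element_spec=element_spec)
    return tf.data.experimental.load(path, element_spec=element_spec)


path = tempfile.mkdtemp()  # fresh empty directory instead of '/content/'
ds = tf.data.Dataset.from_tensor_slices(tf.range(10))

save_dataset(ds, path)
new_ds = load_dataset(path, ds.element_spec)

# Sort because shard read order is not guaranteed to match write order.
print(sorted(int(x) for x in new_ds.as_numpy_iterator()))
```

Passing ds.element_spec costs nothing on new versions and avoids a missing-argument error on old ones, which is why the helper always supplies it.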