How do you save a Tensorflow dataset to a file?


There are at least two more questions like this on SO but not a single one has been answered.

I have a dataset of the form:

<TensorSliceDataset shapes: ((512,), (512,), (512,), ()), types: (tf.int32, tf.int32, tf.int32, tf.int32)>

and another of the form:

<BatchDataset shapes: ((None, 512), (None, 512), (None, 512), (None,)), types: (tf.int32, tf.int32, tf.int32, tf.int32)>

I have looked and looked but I can't find the code to save these datasets to files that can be loaded later. The closest I got was this page in the TensorFlow docs, which suggests serializing the tensors using tf.io.serialize_tensor and then writing them to a file using tf.data.experimental.TFRecordWriter.
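For reference, the approach from that page works as-is when each element is a single tensor. Here is a minimal sketch, assuming a hypothetical dataset of (512,) int32 tensors and the arbitrary file name single.tfrecord:

```python
import tensorflow as tf

# Hypothetical example data: each element is a single (512,) int32 tensor
ds = tf.data.Dataset.from_tensor_slices(tf.zeros((100, 512), tf.int32))

# Write: each element becomes one scalar tf.string, which TFRecordWriter accepts
serialized = ds.map(tf.io.serialize_tensor)
writer = tf.data.experimental.TFRecordWriter('single.tfrecord')
writer.write(serialized)

# Read: parse each record back into an int32 tensor.
# parse_tensor cannot infer the static shape, so ensure_shape reapplies it.
restored = tf.data.TFRecordDataset('single.tfrecord').map(
    lambda x: tf.ensure_shape(tf.io.parse_tensor(x, tf.int32), [512]))
```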

However, when I tried this using the code:

dataset.map(tf.io.serialize_tensor)
writer = tf.data.experimental.TFRecordWriter('mydata.tfrecord')
writer.write(dataset)

I get an error on the first line:

TypeError: serialize_tensor() takes from 1 to 2 positional arguments but 4 were given

How can I modify the above (or do something else) to accomplish my goal?


Solution

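  • The error happens because Dataset.map unpacks each tuple element into separate positional arguments, so tf.io.serialize_tensor receives four tensors instead of one. (Note also that map returns a new dataset; it does not modify dataset in place.) TFRecordWriter seems to be the most convenient option, but unfortunately it can only write datasets whose elements are a single scalar tf.string tensor, such as one serialized tensor per element. Here are a couple of workarounds. First, since all your tensors have the same type and can be lined up along the last axis (after expanding the scalar to shape (1,)), you can concatenate them into one tensor and split it back on load: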

    import tensorflow as tf
    
    # Write
    a = tf.zeros((100, 512), tf.int32)
    ds = tf.data.Dataset.from_tensor_slices((a, a, a, a[:, 0]))
    print(ds)
    # <TensorSliceDataset shapes: ((512,), (512,), (512,), ()), types: (tf.int32, tf.int32, tf.int32, tf.int32)>
    def write_map_fn(x1, x2, x3, x4):
        return tf.io.serialize_tensor(tf.concat([x1, x2, x3, tf.expand_dims(x4, -1)], -1))
    ds = ds.map(write_map_fn)
    writer = tf.data.experimental.TFRecordWriter('mydata.tfrecord')
    writer.write(ds)
    
    # Read
    def read_map_fn(x):
        xp = tf.io.parse_tensor(x, tf.int32)
        # Optionally set shape
        xp.set_shape([1537])  # Do `xp.set_shape([None, 1537])` if using batches
        # Use `xp[:, :512], ..., xp[:, -1]` if using batches
        return xp[:512], xp[512:1024], xp[1024:1536], xp[-1]
    ds = tf.data.TFRecordDataset('mydata.tfrecord').map(read_map_fn)
    print(ds)
    # <MapDataset shapes: ((512,), (512,), (512,), ()), types: (tf.int32, tf.int32, tf.int32, tf.int32)>
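The same concatenation trick should also cover the BatchDataset from the question, because concatenating along the last axis works with a leading batch dimension. A sketch, assuming a batch size of 32 and the file name mydata_batched.tfrecord (both arbitrary):

```python
import tensorflow as tf

# Assumed example data with the same structure as the question's BatchDataset
a = tf.zeros((100, 512), tf.int32)
ds = tf.data.Dataset.from_tensor_slices((a, a, a, a[:, 0])).batch(32)

def write_map_fn(x1, x2, x3, x4):
    # Concatenate on the last axis: three (None, 512) + one (None, 1) -> (None, 1537)
    return tf.io.serialize_tensor(tf.concat([x1, x2, x3, tf.expand_dims(x4, -1)], -1))

writer = tf.data.experimental.TFRecordWriter('mydata_batched.tfrecord')
writer.write(ds.map(write_map_fn))

def read_map_fn(x):
    xp = tf.io.parse_tensor(x, tf.int32)
    # Leave the batch dimension unknown: the last batch may be smaller
    xp.set_shape([None, 1537])
    # Slice along the feature axis, keeping the batch axis intact
    return xp[:, :512], xp[:, 512:1024], xp[:, 1024:1536], xp[:, -1]

batched = tf.data.TFRecordDataset('mydata_batched.tfrecord').map(read_map_fn)
```

Keeping the batch dimension as None in set_shape matters here because the final batch (4 elements out of 100) is smaller than the others.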
    

    More generally, you can write a separate file per tensor and then read them all back and zip them together:

    import tensorflow as tf
    
    # Write
    a = tf.zeros((100, 512), tf.int32)
    ds = tf.data.Dataset.from_tensor_slices((a, a, a, a[:, 0]))
    for i, _ in enumerate(ds.element_spec):
        ds_i = ds.map(lambda *args: args[i]).map(tf.io.serialize_tensor)
        writer = tf.data.experimental.TFRecordWriter(f'mydata.{i}.tfrecord')
        writer.write(ds_i)
    
    # Read
    NUM_PARTS = 4  # must equal the number of files written above
    parts = []
    def read_map_fn(x):
        return tf.io.parse_tensor(x, tf.int32)
    for i in range(NUM_PARTS):
        parts.append(tf.data.TFRecordDataset(f'mydata.{i}.tfrecord').map(read_map_fn))
    ds = tf.data.Dataset.zip(tuple(parts))
    print(ds)
    # <ZipDataset shapes: (<unknown>, <unknown>, <unknown>, <unknown>), types: (tf.int32, tf.int32, tf.int32, tf.int32)>
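The zipped dataset reports <unknown> shapes because tf.io.parse_tensor cannot infer a static shape. One way to restore them is to reapply each component's shape and dtype from the original element_spec. A sketch, assuming that element_spec is available at read time (here it is simply taken from the original dataset; in practice you would have to record it yourself). It also uses small factory functions instead of a lambda that captures the loop variable:

```python
import tensorflow as tf

# Same example data as in the answer
a = tf.zeros((100, 512), tf.int32)
ds = tf.data.Dataset.from_tensor_slices((a, a, a, a[:, 0]))
# Assumption: the structure is known when reading
specs = ds.element_spec

def select_and_serialize(i):
    # Map function bound to component i
    def fn(*args):
        return tf.io.serialize_tensor(args[i])
    return fn

# Write one TFRecord file per component
for i in range(len(specs)):
    writer = tf.data.experimental.TFRecordWriter(f'mydata.{i}.tfrecord')
    writer.write(ds.map(select_and_serialize(i)))

def parse_with_spec(spec):
    # Parse with the stored dtype, then reapply the stored static shape
    def fn(x):
        return tf.ensure_shape(tf.io.parse_tensor(x, spec.dtype), spec.shape)
    return fn

parts = tuple(
    tf.data.TFRecordDataset(f'mydata.{i}.tfrecord').map(parse_with_spec(spec))
    for i, spec in enumerate(specs))
restored = tf.data.Dataset.zip(parts)
```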
    

    It is possible to have the whole dataset in a single file with multiple separate tensors per element, namely as a file of TFRecords containing tf.train.Examples, but I don't know if there is a way to create those within TensorFlow, that is, without having to get the data out of the dataset into Python and then write it to the records file.
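If the round trip through Python is acceptable, writing tf.train.Examples is straightforward. A minimal sketch, assuming hypothetical feature names x1 to x4 and the file name mydata.examples.tfrecord. Example protos store integers as int64, so values are cast back to int32 on read:

```python
import tensorflow as tf

a = tf.zeros((100, 512), tf.int32)
ds = tf.data.Dataset.from_tensor_slices((a, a, a, a[:, 0]))

# Hypothetical feature names; they only need to match between write and read
NAMES = ('x1', 'x2', 'x3', 'x4')

def int64_feature(values):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

# Write: elements are pulled into Python, one Example per dataset element
with tf.io.TFRecordWriter('mydata.examples.tfrecord') as writer:
    for x1, x2, x3, x4 in ds.as_numpy_iterator():
        feature = {
            'x1': int64_feature(x1.tolist()),
            'x2': int64_feature(x2.tolist()),
            'x3': int64_feature(x3.tolist()),
            'x4': int64_feature([int(x4)]),  # scalar stored as a length-1 list
        }
        example = tf.train.Example(features=tf.train.Features(feature=feature))
        writer.write(example.SerializeToString())

# Read: fixed-length features give back the original static shapes
feature_spec = {
    'x1': tf.io.FixedLenFeature([512], tf.int64),
    'x2': tf.io.FixedLenFeature([512], tf.int64),
    'x3': tf.io.FixedLenFeature([512], tf.int64),
    'x4': tf.io.FixedLenFeature([], tf.int64),
}

def parse_fn(record):
    parsed = tf.io.parse_single_example(record, feature_spec)
    return tuple(tf.cast(parsed[name], tf.int32) for name in NAMES)

restored = tf.data.TFRecordDataset('mydata.examples.tfrecord').map(parse_fn)
```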