python-3.x, tensorflow, tensorflow-datasets

How to save ParallelMapDataset?


I have an input dataset (let's call it ds) and a function that passes each element to an encoder (a model named embedder). I want to build a dataset of the resulting encodings and save it to a file. Here is what I tried:

Converter function:

def generate_embedding(image, label, embedder):
    # Encode the image and keep its label alongside the embedding.
    return (embedder(image)[0], label)

Converting:

embedding_ds = ds.map(
    lambda image, label: generate_embedding(image, label, embedder),
    num_parallel_calls=tf.data.AUTOTUNE,
)
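A self-contained sketch of this mapping step, for anyone who wants to reproduce it. The embedder here is a hypothetical stand-in (a tiny Keras model), and the toy ds of random 8x8 "images" is invented for illustration. It also assumes ds yields unbatched images, so the image is reshaped to a batch of one before the call and the first row of the output is taken, which is one possible reading of embedder(image)[0].

```python
import tensorflow as tf

# Hypothetical stand-in for the real encoder: maps a flattened 8x8 image
# to a 4-dimensional embedding.
embedder = tf.keras.Sequential([
    tf.keras.Input(shape=(64,)),
    tf.keras.layers.Dense(4),
])

def generate_embedding(image, label, embedder):
    # Assumption: `image` is a single (unbatched) example, so flatten it into
    # a batch of one, call the model, and keep the first (only) output row.
    flat = tf.reshape(image, (1, -1))
    return embedder(flat)[0], label

# Toy input dataset: 10 random 8x8 "images" with integer labels.
images = tf.random.uniform((10, 8, 8))
labels = tf.range(10)
ds = tf.data.Dataset.from_tensor_slices((images, labels))

embedding_ds = ds.map(
    lambda image, label: generate_embedding(image, label, embedder),
    num_parallel_calls=tf.data.AUTOTUNE,
)

# Each element is now an (embedding, label) pair.
print(embedding_ds.element_spec)
```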

Saving:

embedding_ds.save(path)

But I have a problem with embedding_ds: it's not a tf.data.Dataset (which is what I expected) but a tf.raw_ops.ParallelMapDataset, which doesn't have a save method. Can anybody give some advice?
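One way to check the type question directly: the object returned by map with num_parallel_calls has a parallel-map class as its concrete type, but that class subclasses tf.data.Dataset, so isinstance should return True. Whether a save method is present then depends on the installed TensorFlow version rather than on the parallel map itself. A minimal sketch, using a toy range dataset in place of the real ds:

```python
import tensorflow as tf

ds = tf.data.Dataset.range(5)
mapped = ds.map(lambda x: x * 2, num_parallel_calls=tf.data.AUTOTUNE)

# The concrete class name (a parallel-map dataset) varies between releases.
print(type(mapped).__name__)

# ...but it is still a tf.data.Dataset subclass.
print(isinstance(mapped, tf.data.Dataset))

# Whether Dataset.save exists depends on the installed TensorFlow version.
print(tf.__version__, hasattr(tf.data.Dataset, "save"))
```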


It looks like this problem occurs with my TensorFlow version (2.9.2) but not with 2.11.


Solution

  • Consider updating TensorFlow. In 2.11.0, this works:

    import tensorflow as tf
    
    ds = tf.data.Dataset.range(5)
    
    tf.__version__ # 2.11.0
    
    ds = ds.map(lambda e: (e + 3) % 5, num_parallel_calls=3)
    
    ds.save('test') # works
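If upgrading is not an option, TensorFlow 2.9 already provides tf.data.experimental.save and tf.data.experimental.load, which later releases deprecate in favor of Dataset.save and Dataset.load. The sketch below picks whichever is available; save_dataset and load_dataset are hypothetical helper names, and the temporary directory stands in for the real output path.

```python
import os
import tempfile

import tensorflow as tf

def save_dataset(dataset, path):
    """Save `dataset` to `path` (hypothetical helper).

    Uses Dataset.save on newer releases and falls back to
    tf.data.experimental.save on older ones such as 2.9.
    """
    if hasattr(dataset, "save"):
        dataset.save(path)
    else:
        tf.data.experimental.save(dataset, path)

def load_dataset(path, element_spec):
    """Load a dataset saved by save_dataset (hypothetical helper)."""
    if hasattr(tf.data.Dataset, "load"):
        return tf.data.Dataset.load(path, element_spec=element_spec)
    return tf.data.experimental.load(path, element_spec=element_spec)

ds = tf.data.Dataset.range(5).map(lambda e: (e + 3) % 5, num_parallel_calls=3)

path = os.path.join(tempfile.mkdtemp(), "test")
save_dataset(ds, path)
restored = load_dataset(path, ds.element_spec)

# Saved data may be split across shards, so element order after loading
# is not guaranteed to match the original order.
print(sorted(int(x) for x in restored.as_numpy_iterator()))
```

Passing element_spec explicitly keeps the load call working even if the saved metadata is not picked up automatically; for the embedding dataset in the question, embedding_ds.element_spec would be passed instead.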