Tags: python-3.x, tensorflow, dataset, tensorflow-2.0, tensorflow-2.x

Get length of a dataset in TensorFlow


# Read the source/target CSVs line by line and pair them up
source_dataset = tf.data.TextLineDataset('primary.csv')
target_dataset = tf.data.TextLineDataset('secondary.csv')
dataset = tf.data.Dataset.zip((source_dataset, target_dataset))
dataset = dataset.shard(10000, 0)
# Parse each comma-separated line into an int32 vector
dataset = dataset.map(lambda source, target: (tf.string_to_number(tf.string_split([source], delimiter=',').values, tf.int32),
                                              tf.string_to_number(tf.string_split([target], delimiter=',').values, tf.int32)))
# Build the decoder input/output sequences and record their lengths
dataset = dataset.map(lambda source, target: (source, tf.concat(([start_token], target), axis=0), tf.concat((target, [end_token]), axis=0)))
dataset = dataset.map(lambda source, target_in, target_out: (source, tf.size(source), target_in, target_out, tf.size(target_in)))

dataset = dataset.shuffle(NUM_SAMPLES)  # This is the important line of code

I would like to shuffle my entire dataset, but shuffle() requires a buffer_size (the number of samples to pull), and tf.size() does not work with a tf.data.Dataset.

How can I shuffle properly?


Solution

  • I was working with tf.data.FixedLengthRecordDataset() and ran into a similar problem. In my case, I was trying to take only a certain percentage of the raw data. Since I knew that all the records had a fixed length, a workaround for me was:

    import os  # needed for os.path.getsize / os.listdir below

    # filepath, filenames, percentage and bytesPerRecord are placeholders for your own data
    totalBytes = sum([os.path.getsize(os.path.join(filepath, filename)) for filename in os.listdir(filepath)])
    numRecordsToTake = tf.cast(0.01 * percentage * totalBytes / bytesPerRecord, tf.int64)
    dataset = tf.data.FixedLengthRecordDataset(filenames, bytesPerRecord).take(numRecordsToTake)
    

    In your case, my suggestion would be to count the number of records in 'primary.csv' and 'secondary.csv' directly in Python. Alternatively, for your purpose, setting the buffer_size argument doesn't really require counting the records: according to the accepted answer on the meaning of buffer_size, any value greater than the number of elements in the dataset guarantees a uniform shuffle across the whole dataset. So just passing in a really big number (one you are sure exceeds the dataset size) should also work. A minimal sketch of the first option follows below.
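
    For the first option, the sketch below assumes the 'primary.csv' file from the question, one record per line, and a hypothetical count_lines helper; it just counts lines in plain Python and feeds that count to shuffle():

    def count_lines(path):
        with open(path) as f:
            return sum(1 for _ in f)  # one record per line

    num_samples = count_lines('primary.csv')  # count the records in plain Python
    dataset = dataset.shuffle(num_samples)    # buffer covers every element, so the shuffle is uniform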