Tags: python, tensorflow, iterator, tensorflow-datasets, tensorflow2.0

Retrieving the next element from a tf.data.Dataset in TensorFlow 2.0 beta


Before TensorFlow 2.0 beta, we could retrieve the first element of a tf.data.Dataset with a one-shot iterator, as shown below:

#!/usr/bin/python

import tensorflow as tf

train_dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0])
iterator = train_dataset.make_one_shot_iterator()
with tf.Session() as sess:
    # Prints 1.0, the first element of the dataset.
    print(sess.run(iterator.get_next()))
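For comparison, here is a minimal sketch of the same one-shot pattern running under TensorFlow 2.x through the tf.compat.v1 module. It assumes a TF 2.x install and builds the pipeline inside an explicit tf.Graph so that the graph-mode APIs work even though 2.x executes eagerly by default. The variable names (graph, next_element, first, second) are illustrative only.

```python
import tensorflow as tf

# Build the pipeline inside an explicit graph so that the TF1-style,
# session-based APIs work even though TF 2.x runs eagerly by default.
graph = tf.Graph()
with graph.as_default():
    train_dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0])
    # In TF 2.x, Dataset objects no longer have a make_one_shot_iterator()
    # method; the compat.v1 function takes the dataset as an argument.
    iterator = tf.compat.v1.data.make_one_shot_iterator(train_dataset)
    next_element = iterator.get_next()

with tf.compat.v1.Session(graph=graph) as sess:
    # Running the same get_next() op again advances the iterator.
    first = sess.run(next_element)   # 1.0
    second = sess.run(next_element)  # 2.0
    print(first, second)
```

This keeps old code running, but the session-based style is exactly what TF 2.x moves away from, so it is mainly useful during a migration.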

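In TensorFlow 2.0 beta, the one-shot iterator above is no longer part of the main API: Dataset.make_one_shot_iterator() and tf.Session have been removed and are available only through tf.compat.v1 (as tf.compat.v1.data.make_one_shot_iterator and tf.compat.v1.Session). To print all of the elements, we can iterate over the dataset with a plain for loop instead: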

#!/usr/bin/python

import tensorflow as tf

train_dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0])

for data in train_dataset:
    # Prints 1.0, 2.0, 3.0 and 4.0, one per line.
    print(data.numpy())

However, if we only want to retrieve exactly one element from a tf.data.Dataset, how can we do that in TensorFlow 2.0 beta? Calling next(train_dataset) fails with a TypeError, because a Dataset is iterable but is not itself an iterator. Retrieving one element was easy with the old one-shot iterator shown above, but it is not obvious with the new for-loop approach.
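Since the problem is that a Dataset is iterable but not an iterator, one way around it is to create the iterator explicitly with Python's built-in iter() and then call next() on it. This is a minimal sketch assuming eager execution, which is the TF 2.x default; the names iterator, first, second and restart are illustrative.

```python
import tensorflow as tf

train_dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0])

# iter() turns the iterable Dataset into an iterator that next() can advance.
iterator = iter(train_dataset)

first = float(next(iterator).numpy())   # 1.0
second = float(next(iterator).numpy())  # 2.0 (same iterator, so it advances)

# Every fresh iter() call starts again from the beginning of the dataset.
restart = float(next(iter(train_dataset)).numpy())  # 1.0

print(first, second, restart)
```

Keeping the same iterator object around is what gives "next element" behaviour, similar to repeatedly running get_next() on the old one-shot iterator.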

Any suggestions are welcome.


Solution

  • You can call .take(1) on the dataset and iterate over the result:

    for elem in train_dataset.take(1):
      # take(1) yields only the first element, so this prints 1.0 once.
      print(elem.numpy())
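take(n) returns a new Dataset containing at most the first n elements and leaves the original dataset unchanged, so take(1) always yields the first element rather than the "next" one. The sketch below, which assumes TF 2.x eager execution and uses illustrative names (first_two, third), shows take(n) on its own and combined with skip(k) to reach a later element in the same style.

```python
import tensorflow as tf

train_dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0])

# take(2) produces a new Dataset with at most the first two elements.
first_two = [float(elem.numpy()) for elem in train_dataset.take(2)]  # [1.0, 2.0]

# skip(2) drops the first two elements, then take(1) keeps the next one.
third = [float(elem.numpy()) for elem in train_dataset.skip(2).take(1)]  # [3.0]

# The original dataset still holds all four elements.
remaining = len(list(train_dataset))  # 4

print(first_two, third, remaining)
```

For stepping through elements one at a time, keeping a single iter() iterator is cheaper than rebuilding a skip/take pipeline for each position.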