Tags: python, tensorflow, tensorflow-datasets, tensorflow2.0

How to print one example of a dataset from tf.data?


I have a dataset in tf.data. How can I easily print (or grab) one element in my dataset?

Something like:

    print(dataset[0])

Solution

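  • In TF 1.x, create an iterator over the dataset and evaluate its get_next() op inside a session. Several iterator types are provided (one-shot, initializable, reinitializable and feedable); some of them are deprecated in later versions.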

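    import tensorflow as tf  # TF 1.x (graph mode)
    
    d = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4])
    # A one-shot iterator needs no explicit initialization. In later 1.x
    # releases this method is deprecated in favor of
    # tf.compat.v1.data.make_one_shot_iterator(d).
    diter = d.make_one_shot_iterator()
    e1 = diter.get_next()  # a graph op, not a value
    
    with tf.Session() as sess:
        # Prints 1; each further sess.run(e1) returns the next element.
        print(sess.run(e1))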
    

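  • In TF 2.x, eager execution is on by default and a Dataset is a regular Python iterable, so you can pull the first element with iter() and next():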

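    import tensorflow as tf  # TF 2.x (eager execution)
    
    d = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4])
    # next(iter(d)) returns the first element as a tf.Tensor;
    # .numpy() converts it to a NumPy value.
    print(next(iter(d)).numpy())  # prints 1
    
    # You can also loop over the full dataset, one element at a time.
    for elem in d:
        print(elem.numpy())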
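
A tf.data.Dataset does not support indexing, so there is no direct equivalent of dataset[0] for an arbitrary position. A minimal sketch for TF 2.x (2.1 or later, which is assumed here because it uses as_numpy_iterator()) combines skip() and take() inside a helper; the name get_element is hypothetical, not part of the TensorFlow API:

```python
import tensorflow as tf


def get_element(dataset, index):
    """Return the element at position `index` as NumPy value(s).

    Hypothetical helper: tf.data.Dataset has no __getitem__, so this
    skips `index` elements and takes the next one.
    """
    if index < 0:
        raise ValueError("index must be non-negative")
    subset = dataset.skip(index).take(1)
    # as_numpy_iterator() converts tensors (including nested tuples and
    # dicts) to NumPy values; next() raises StopIteration if it is empty.
    try:
        return next(subset.as_numpy_iterator())
    except StopIteration:
        raise IndexError(f"dataset has no element at index {index}") from None


d = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4])
print(get_element(d, 0))  # roughly what dataset[0] would mean
print(get_element(d, 2))

# Structured elements work the same way, e.g. (feature, label) pairs.
pairs = tf.data.Dataset.from_tensor_slices(([1.0, 2.0], [0, 1]))
print(get_element(pairs, 1))
```

Each call re-reads the pipeline from the start, so it costs time proportional to index and, on a shuffled dataset, may return a different element on every call. It is convenient for inspection, not for random access in a training loop.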