
Mixing multiple tf.data.Dataset?


I have three datasets, D1, D2 and D3, that output the same type of data. I'm trying to build a single pipeline that randomly outputs elements from D1, D2 or D3. I tried tf.data.Dataset.zip((D1, D2, D3)), but then I don't know how to flatten its output so that I can shuffle it and get an output like D1_element, D3_element, D1_element, D2_element, ... Here is a small example:

import tensorflow as tf

D1 = tf.data.Dataset.range(1,5)
D2 = tf.data.Dataset.range(5,10)
D3 = tf.data.Dataset.range(10,15)

zipped = tf.data.Dataset.zip((D1, D2, D3))
...
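Stated without TensorFlow, the behaviour being asked for is roughly this: at each step, pick one of the sources at random and emit its next element, until every source is exhausted. The following is a minimal plain-Python sketch under that assumption (a uniform choice among the sources that still have elements). `mix_randomly` is a hypothetical helper for illustration, not a TensorFlow API.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def mix_randomly(sources: List[Iterable[T]], seed: int = 0) -> Iterator[T]:
    """Hypothetical helper: yield elements of `sources` in random order.

    At every step one still-active source is chosen uniformly at random and
    its next element is emitted, so each source keeps its own internal order.
    """
    rng = random.Random(seed)
    iterators = [iter(s) for s in sources]
    while iterators:
        i = rng.randrange(len(iterators))
        try:
            yield next(iterators[i])
        except StopIteration:
            # This source is exhausted; never choose it again.
            iterators.pop(i)


# Same values as D1, D2 and D3 in the question.
mixed = list(mix_randomly([range(1, 5), range(5, 10), range(10, 15)]))
print(mixed)
```

Unlike zipping, this never drops elements: all 14 values appear exactly once, even though the sources have different lengths.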

Solution

  • I found the following solution, in case anybody is interested:

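    import tensorflow as tf  # TF 1.x code: tf.contrib was removed in TF 2.x
    
    def stack(*inputs):
        # map() unpacks each zipped tuple (d1, d2, d3) into separate arguments;
        # tf.stack turns them into a single vector of shape [3].
        return tf.stack(inputs)
    
    D1 = tf.data.Dataset.range(1, 5)    # 1..4   (4 elements)
    D2 = tf.data.Dataset.range(5, 10)   # 5..9   (5 elements)
    D3 = tf.data.Dataset.range(10, 15)  # 10..14 (5 elements)
    
    # zip stops at the shortest input, so only 4 triples are produced
    # and the last elements of D2 and D3 (9 and 14) are dropped.
    D = tf.data.Dataset.zip((D1, D2, D3))
    D = D.map(stack)
    # Split each [3] vector back into scalars: 1, 5, 10, 2, 6, 11, ...
    D = D.apply(tf.contrib.data.unbatch())
    # A buffer of 10 for 12 elements gives only an approximate shuffle.
    D = D.shuffle(10, seed=0)
    D = D.batch(3)
    D = D.prefetch(1)
    
    it = D.make_one_shot_iterator()
    next_element = it.get_next()
    
    with tf.Session() as sess:
        print(sess.run(next_element))  # first batch of 3 mixed elements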
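As an alternative to the zip/unbatch approach above (a swapped-in technique, not the answer's method), TensorFlow 2.x can pick the source of each element at random with `tf.data.Dataset.sample_from_datasets`. This sketch assumes TensorFlow 2.7 or later with eager execution; older 2.x releases expose the same function as `tf.data.experimental.sample_from_datasets`. Because nothing is zipped, the datasets may have different lengths and no element is dropped.

```python
import tensorflow as tf  # assumes TensorFlow >= 2.7, eager execution

D1 = tf.data.Dataset.range(1, 5)
D2 = tf.data.Dataset.range(5, 10)
D3 = tf.data.Dataset.range(10, 15)

# Each element is drawn from D1, D2 or D3 using uniform weights (weights=None).
# With stop_on_empty_dataset=False, exhausted datasets are skipped, so every
# element of all three datasets is eventually produced.
mixed = tf.data.Dataset.sample_from_datasets(
    [D1, D2, D3], weights=None, seed=0, stop_on_empty_dataset=False)

result = [int(x) for x in mixed.as_numpy_iterator()]
print(result)

# As in the answer, the mixed stream can then be batched and prefetched.
for batch in mixed.batch(3).prefetch(1).take(1):
    print(batch.numpy())
```

Each dataset is still read in order, so only the interleaving between D1, D2 and D3 is random; add a `shuffle` after the mixing step if the elements within each source should also be reordered.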