tensorflow, tensorflow-datasets, tensorflow-estimator, tensorflow1.15

tf.data: how to construct a batch from two different kinds of data?


I want to use tf.data to build batches of size 16 in which the first half, [:8], is data of kind A and the second half, [8:16], is data of kind B.

This is easy to do without tf.data. With tf.data, the code for the first dataset could be:

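import tensorflow as tf  # TensorFlow 1.15

# input_files, name_to_features and batch_size are defined elsewhere.
def _decode_record(record, name_to_features):
    # Parse one serialized tf.Example into a dict of tensors.
    example = tf.parse_single_example(record, name_to_features)
    return example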

dataA = tf.data.TFRecordDataset(input_files)
dataA = dataA.apply(
            tf.contrib.data.map_and_batch(
                lambda record: _decode_record(record, name_to_features),
                batch_size=batch_size)
           )

What should the next step be? I tried:

dataB = tf.data.TFRecordDataset(input_files2)
dataB = dataB.apply(
            tf.contrib.data.map_and_batch(
                lambda record: _decode_record(record, name_to_features),
                batch_size=batch_size)
           )
dataC = dataA.concatenate(dataB)

But concatenate only appends the whole of dataB to the end of dataA, so no batch contains both kinds of data.

Note also that concatenate requires name_to_features to be the same for dataA and dataB, which means I would have to pad the records with a lot of dummy data.
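
To make the concatenate behaviour concrete, here is a minimal sketch with toy integer datasets standing in for the TFRecord data. It assumes TensorFlow 2.x with eager execution (needed for `as_numpy_iterator`); under 1.15 the iterator would have to be evaluated in a session instead. The names `data_a`, `data_b` and `data_c` are illustrative only.

```python
import tensorflow as tf

# Toy stand-ins (assumption): 1s play the role of data A, 2s of data B.
data_a = tf.data.Dataset.from_tensor_slices([1, 1, 1, 1]).batch(2)
data_b = tf.data.Dataset.from_tensor_slices([2, 2, 2, 2]).batch(2)

# concatenate() yields every batch of data_a first, then every batch of
# data_b, so no single batch ever mixes the two kinds of data.
data_c = data_a.concatenate(data_b)

batches = [batch.tolist() for batch in data_c.as_numpy_iterator()]
print(batches)  # [[1, 1], [1, 1], [2, 2], [2, 2]]
```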

I don't want to use tf.cond or tf.where to distinguish the two kinds of data inside the model_fn of tf.estimator, because that is also very hard to debug.


Solution

  • You can zip the datasets together, batch the (dataA, dataB) pairs with half the target batch size (8), and concatenate the two halves of each batch:

    import tensorflow as tf
    
    dataset_1 = tf.data.Dataset.from_tensors(1).repeat(100)
    dataset_2 = tf.data.Dataset.from_tensors(2).repeat(100)
    
    dataset = tf.data.Dataset.zip((dataset_1, dataset_2))
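    dataset = dataset.batch(8)  # 8 (a, b) pairs per batch
    dataset = dataset.map(lambda a, b: tf.concat([a, b], 0))  # 8 + 8 = 16

    # Printing like this assumes eager execution (TF 2.x).
    for batch in dataset:
        print(batch)
    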

    Produces

    tf.Tensor([1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2], shape=(16,), dtype=int32)
    tf.Tensor([1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2], shape=(16,), dtype=int32)
    tf.Tensor([1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2], shape=(16,), dtype=int32)
    tf.Tensor([1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2], shape=(16,), dtype=int32)
    ...
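
The same pattern should carry over to parsed records whose feature sets differ, because zip keeps each dataset's structure separate and nothing has to be padded to a common name_to_features. The following is only a sketch under assumptions: the feature names ("x", "y"), the in-memory `from_tensor_slices` datasets standing in for parsed TFRecords, and TensorFlow 2.x eager execution are all illustrative and not taken from the question.

```python
import tensorflow as tf

# Hypothetical stand-ins for two parsed TFRecord datasets whose
# feature dictionaries have different keys and dtypes.
data_a = tf.data.Dataset.from_tensor_slices({"x": tf.range(100)})
data_b = tf.data.Dataset.from_tensor_slices({"y": tf.fill([100], 2.0)})

# Pair the elements, then batch 8 pairs at a time: every batch carries
# 8 examples of kind A and 8 examples of kind B.
dataset = tf.data.Dataset.zip((data_a, data_b)).batch(8, drop_remainder=True)

# Keep the two halves under separate keys rather than concatenating
# them, since their features differ.
dataset = dataset.map(lambda a, b: {"a": a, "b": b})

first = next(iter(dataset))
print(first["a"]["x"].numpy(), first["b"]["y"].numpy())
```

Note that zip stops as soon as the shorter of the two datasets is exhausted, so calling `.repeat()` on the smaller one may be needed if the datasets differ in length.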