
TensorFlow Dataset API: map function producing multiple outputs per one input


TL;DR: My input_fn produces 7 images from 1 at each step. I would like to use each of them individually (1x1000x1000x3, 7 times) rather than as a single stacked image (7x1000x1000x3), so that I can shuffle them and mix them across multiple batches.

My data (images) is quite large: 5000x9000x3 features and 5000x9000x1 labels, so I store them as JPEG and feed the compressed images to input_fn; gen_fn decompresses them and parser_fn outputs 7x1000x1000x3 & 7x1000x1000x1 (7 random crops, always as a nested tuple). The thing is, I don't want my input to be those 7 images stacked together, but rather to treat that 7 as "the batch size".
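For reference, here is a minimal sketch of what gen_fn and parser_fn might look like (the file paths, the aligned-crop trick, and all names below are illustrative assumptions, not my exact code):

import tensorflow as tf

# Hypothetical list of (features, labels) JPEG file path pairs
image_pairs = [('features_0.jpg', 'labels_0.jpg')]

def gen_fn():
    for feature_path, label_path in image_pairs:
        with open(feature_path, 'rb') as f, open(label_path, 'rb') as g:
            yield (f.read(), g.read())  # compressed JPEG bytes, as a tuple

def parser_fn(features_jpeg, labels_jpeg):
    features = tf.cast(tf.image.decode_jpeg(features_jpeg, channels=3), tf.float32)
    labels = tf.cast(tf.image.decode_jpeg(labels_jpeg, channels=1), tf.float32)
    # Concatenate along channels so every random crop is taken at the
    # same position in both features and labels
    stacked = tf.concat([features, labels], axis=-1)
    crops = tf.stack([tf.random_crop(stacked, [1000, 1000, 4]) for _ in range(7)])
    return crops[..., :3], crops[..., 3:]  # ([7,1000,1000,3], [7,1000,1000,1])

My attempt at flattening those 7 crops into individual examples is as follows: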

# Where `gen_fn` yields JPEG-encoded strings (bytes in Python)
dataset = tf.data.Dataset.from_generator(gen_fn, (tf.string, tf.string))
print(dataset) # debug, shown below

# Produces ([7, 1000, 1000, 3], [7, 1000, 1000, 1])
dataset = dataset.map(parser_fn)
print(dataset) # debug, shown below

# Attempts to flatten out the 0th dimension
# Effectively produces ([1000, 1000, 3], [1000, 1000, 1])
dataset = dataset.flat_map(lambda x, y: tf.data.Dataset.from_tensor_slices((x, y)))
print(dataset) # debug, shown below

# Shuffle all of them to avoid biasing the network
# dataset = dataset.shuffle(63) # 9*7

# Prefetch, repeat (does not have any effect, tested)
dataset = dataset.prefetch(1)
print(dataset) # debug, shown below
dataset = dataset.repeat(1)
print(dataset) # debug, shown below

# Batch
dataset = dataset.batch(1)
print(dataset) # debug, shown below

itr = dataset.make_one_shot_iterator()
features, labels = itr.get_next()
return features, labels

This prints the following to the console:

<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.string, tf.string)>
<MapDataset shapes: ((7, 1000, 1000, 3), (7, 1000, 1000, 1)), types: (tf.float32, tf.float32)>
<FlatMapDataset shapes: ((1000, 1000, 3), (1000, 1000, 1)), types: (tf.float32, tf.float32)>
<PrefetchDataset shapes: ((1000, 1000, 3), (1000, 1000, 1)), types: (tf.float32, tf.float32)>
<RepeatDataset shapes: ((1000, 1000, 3), (1000, 1000, 1)), types: (tf.float32, tf.float32)>
<BatchDataset shapes: ((?, 1000, 1000, 3), (?, 1000, 1000, 1)), types: (tf.float32, tf.float32)>

It loads fine, but as soon as training starts I get TypeError: If shallow structure is a sequence, input must also be a sequence. Input has type: <class 'list'>. Removing the batch call entirely works fine and outputs 1000x1000x3 "batches".

As suggested in "How do I create padded batches in Tensorflow for tf.train.SequenceExample data using the DataSet API?", I have tried using padded_batch instead of batch:

dataset = dataset.padded_batch(self.__batch_size, dataset.output_shapes)

Resulting in

<PaddedBatchDataset shapes: ((?, 1000, 1000, 3), (?, 1000, 1000, 1)), types: (tf.float32, tf.float32)>

Sadly, the same error occurs.

A GitHub issue suggests repeating the initial image multiple times. But that means either decompressing the same image multiple times (which saves memory but is much slower) or repeating the full-resolution image multiple times (each repetition would cost around 400MB).
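For clarity, this is roughly what that suggestion amounts to (a sketch; single_crop_fn is a hypothetical variant of parser_fn that decodes and takes a single random crop):

# Instead of mapping `parser_fn`, repeat the compressed bytes 7 times
# and decode + crop each copy: saves memory, but decompresses the
# same JPEG 7 times
dataset = dataset.flat_map(
    lambda features, labels:
        tf.data.Dataset.from_tensors((features, labels)).repeat(7))
dataset = dataset.map(single_crop_fn)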

The images are the bottleneck in my pipeline. I could precompute all the crops, but that would mean losing some of the potential random crops and data augmentation. Repeating is not an option (time/memory constraints), and I can't get the code above to work. Any idea what might be wrong?


Solution

  • It turns out this code works perfectly and correctly flattens the 7 random crops produced from the initial image.

    The error actually came from my generator, which is not shown in the snippet above: it was yielding a list instead of a tuple. For TensorFlow to correctly understand that the generator yields 2 values (features, labels), it must yield a tuple; mine was yielding a list, effectively telling TensorFlow that there was only 1 value (see the sketch below).

    The major setback was that the error is thrown at runtime rather than while building the graph, so debugging turned out to be rather hard and mostly consisted of trial and error.
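    For anyone hitting the same TypeError, here is a minimal sketch of the difference (the JPEG bytes are placeholders, not my actual data):

        import tensorflow as tf

        def bad_gen():
            # Wrong: a list is treated as a single component by
            # TensorFlow's structure matching, which triggers the
            # TypeError at runtime
            yield [b'jpeg-features', b'jpeg-labels']

        def good_gen():
            # Right: a tuple matches the declared (tf.string, tf.string)
            # output structure, i.e. 2 separate values
            yield (b'jpeg-features', b'jpeg-labels')

        dataset = tf.data.Dataset.from_generator(good_gen, (tf.string, tf.string))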