python tensorflow tensorflow-datasets

Giving tf.data.Dataset.map a list of image path pairs to load, but it seems to read only 2


I am learning how to use TensorFlow with a piece of old code provided at a workshop a couple of years back (i.e. it should have been tested and should work). However, it didn't work, and my investigation led me back to step 1: loading the data and making sure it has loaded properly.

The data reading pipeline is as follows:

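1. The load function
```
import tensorflow as tf

@tf.function
def load(path_pair):
    # path_pair holds [image_path, masks_path]
    image_path = path_pair[0]
    masks_path = path_pair[1]

    # Read and decode the single-channel input image
    image_raw = tf.io.read_file(image_path)
    image = tf.io.decode_image(
        image_raw, channels=1, dtype=tf.uint8
    )

    # Read and decode the masks (NUM_CONTOURS is defined elsewhere)
    masks_raw = tf.io.read_file(masks_path)
    masks = tf.io.decode_image(
        masks_raw, channels=NUM_CONTOURS, dtype=tf.uint8
    )

    # Scale uint8 pixel values to floats in [0, 1]
    return image / 255, masks / 255
```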
2. The function used to create the dataset
```
def create_datasets(dataset_type):
    # get_path_pairs returns a list of 2-tuples, each holding the image path and mask path to load
    path_pairs = get_path_pairs(dataset_type)
    dataset = tf.data.Dataset.from_tensor_slices(path_pairs)
    # Shuffle over the whole list, reshuffling every epoch
    dataset = dataset.shuffle(
        len(path_pairs),
        reshuffle_each_iteration=True,
    )
    dataset = dataset.map(load)

    # After this call, each dataset element is a batch of up to BATCH_SIZE samples
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    return dataset
```



When I use the create_datasets function on a dataset that contains 818 data pairs and check the size of the loaded dataset using len(dataset), it tells me that only 2 items are loaded.

Solution

  • The problem is that you are batching your dataset, so len(dataset) returns the number of batches, not the number of elements in your dataset. To count the elements you can, for instance, iterate over the batches and sum their sizes:

    num_samples = 0
    for batch in dataset:
        # batch is an (images, masks) tuple; batch[0] is the batch of images
        num_samples += len(batch[0])
    print(num_samples)
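
To see why a dataset of 818 pairs can report a length of 2, the batch-count arithmetic can be sketched in plain Python. This is a minimal sketch rather than TensorFlow code: `expected_num_batches` is a hypothetical helper, and the batch sizes below are assumed values, since the question does not state `BATCH_SIZE`. It assumes `Dataset.batch` keeps the final partial batch unless `drop_remainder=True` is passed.

```python
def expected_num_batches(num_samples, batch_size, drop_remainder=False):
    """Hypothetical helper: how many batches batching num_samples elements yields."""
    if drop_remainder:
        # The final partial batch is discarded.
        return num_samples // batch_size
    # The final partial batch is kept, so round up (integer ceiling division).
    return (num_samples + batch_size - 1) // batch_size


NUM_SAMPLES = 818  # dataset size from the question

# Assumed batch sizes, for illustration only.
for batch_size in (32, 409, 512, 818):
    print(batch_size, expected_num_batches(NUM_SAMPLES, batch_size))
# prints:
# 32 26
# 409 2
# 512 2
# 818 1
```

Under these assumptions, a length of 2 for 818 samples corresponds to a `BATCH_SIZE` anywhere from 409 to 817, so the data may well be loading correctly. Calling `len()` on the dataset before `.batch()` is applied would report the number of individual samples instead.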