python, tensorflow, tensorflow-datasets

TensorFlow 2.6.0: How Do I Map One Element into Multiple Elements?


I am trying to build a CNN to classify medical images. The images are massive (roughly 50k by 30k pixels), so as part of my pipeline I want to break each one into 256 by 256 patches. I want to do this with Dataset.map so that I can cache the data later and make training easier.
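If the goal is to tile every image on a regular, non-overlapping grid (rather than cutting a few fixed crops), the per-image step can be sketched with tf.image.extract_patches inside Dataset.map, followed by unbatch. This is a sketch under assumptions: images are already decoded into [height, width, channels] float tensors, the small random images stand in for real data, the `image_to_patches` name is hypothetical, and edge pixels that do not fill a whole 256 x 256 patch are dropped.

```python
import tensorflow as tf

PATCH = 256  # patch side length in pixels, as in the question

def image_to_patches(image):
    """Hypothetical helper: tile one [H, W, C] image into non-overlapping
    PATCH x PATCH patches. Leftover border pixels are dropped (padding="VALID").
    Returns a tensor of shape [num_patches, PATCH, PATCH, C]."""
    channels = tf.shape(image)[-1]
    patches = tf.image.extract_patches(
        images=image[tf.newaxis, ...],   # add a batch dimension: [1, H, W, C]
        sizes=[1, PATCH, PATCH, 1],
        strides=[1, PATCH, PATCH, 1],    # stride equal to size, so no overlap
        rates=[1, 1, 1, 1],
        padding="VALID",
    )                                    # shape: [1, rows, cols, PATCH * PATCH * C]
    # Unflatten each patch and merge the grid positions into one leading axis.
    return tf.reshape(patches, [-1, PATCH, PATCH, channels])

# Stand-in data: 2 small RGB images instead of real whole-slide images.
images = tf.random.normal(shape=[2, 1024, 768, 3])
dataset = tf.data.Dataset.from_tensor_slices(images)

# map returns one (num_patches, 256, 256, 3) tensor per image;
# unbatch turns it into num_patches separate (256, 256, 3) elements.
dataset = dataset.map(image_to_patches).unbatch()

patches = list(dataset)
print(len(patches))       # 2 images * (4 * 3) patches = 24
print(patches[0].shape)   # (256, 256, 3)
```

The patches come out in row-major grid order (left to right, then top to bottom) for each image in turn, which makes it easy to map a patch back to its position in the original image.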

I found this thread that solves the problem in TensorFlow 1, but I have not been able to convert it to TensorFlow 2.

Sorry for asking, but could someone help me convert the code so it works in TensorFlow 2? I am a bit of a newbie, so any help is appreciated.


Solution

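  • Use tf.stack inside tf.data.Dataset.map so that each image produces a single tensor holding all of its patches, then call tf.data.Dataset.unbatch to split that tensor into separate dataset elements (see the official documentation for both):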

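    import tensorflow as tf
    
    # Stand-in for real data: 10 single-channel "images" of 1024 x 768 pixels.
    some_image_dataset = tf.random.normal(shape=[10, 1024, 768])
    dataset = tf.data.Dataset.from_tensor_slices(some_image_dataset)
    
    def some_patches_map_func(image):
        # Cut three 256 x 256 patches at fixed offsets and stack them into one
        # tensor of shape (3, 256, 256), so map still returns a single element.
        return tf.stack([
            image[10 : 10 + 256, 20 : 20 + 256],
            image[100 : 100 + 256, 100 : 100 + 256],
            image[500 : 500 + 256, 200 : 200 + 256],
        ])
    
    dataset = dataset.map(some_patches_map_func)  # elements: (3, 256, 256)
    dataset = dataset.unbatch().shuffle(10)       # elements: (256, 256), 30 in total
    dataset = dataset.batch(2)                    # elements: (2, 256, 256)
    
    iterator = iter(dataset)
    
    print(next(iterator).shape)  # (2, 256, 256)
    print(next(iterator).shape)  # (2, 256, 256)
    print(next(iterator).shape)  # (2, 256, 256)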
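The same one-to-many step can also be sketched with tf.data.Dataset.flat_map, where the mapped function returns a whole Dataset and the elements of those per-image datasets are concatenated into the output. This reuses the three fixed crops from the answer on the same kind of random stand-in data; the `crops_as_dataset` name is hypothetical.

```python
import tensorflow as tf

# Stand-in data: 10 single-channel images of 1024 x 768 pixels.
images = tf.random.normal(shape=[10, 1024, 768])
dataset = tf.data.Dataset.from_tensor_slices(images)

def crops_as_dataset(image):
    # Hypothetical helper: the same three fixed 256 x 256 crops as in the answer.
    crops = tf.stack([
        image[10:10 + 256, 20:20 + 256],
        image[100:100 + 256, 100:100 + 256],
        image[500:500 + 256, 200:200 + 256],
    ])
    # flat_map expects a Dataset back; slicing the (3, 256, 256) tensor
    # along its first axis yields three (256, 256) elements per image.
    return tf.data.Dataset.from_tensor_slices(crops)

dataset = dataset.flat_map(crops_as_dataset)

patches = list(dataset)
print(len(patches))       # 10 images * 3 crops = 30
print(patches[0].shape)   # (256, 256)
```

flat_map takes no num_parallel_calls argument, so the map(...).unbatch() pattern is the simpler of the two to parallelize; both produce the same stream of patches.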