
Extract random non-overlapping patches from image tensor in Tensorflow


I want to extract 3 random, non-overlapping 80x80 sub-images (patches) from an image tensor with TensorFlow. How can I do this? The picture below illustrates the idea.

Illustration of the problem
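One way to read "non-overlapping" is that no two axis-aligned boxes share a pixel. The following is a minimal sketch in plain Python (no TensorFlow), assuming an image of `height` x `width` pixels; the helper names `boxes_overlap` and `sample_boxes` are hypothetical. It uses rejection sampling: draw a random top-left corner anywhere in the image and keep it only if the new box is disjoint from the boxes already accepted. Unlike a fixed grid, positions are not restricted to multiples of 80.

```python
import random


def boxes_overlap(a, b, size):
    """True if two boxes of the given (h, w) size, given as (top, left), share a pixel."""
    (ta, la), (tb, lb) = a, b
    h, w = size
    # Half-open intervals [t, t + h) and [l, l + w) must intersect on both axes
    return ta < tb + h and tb < ta + h and la < lb + w and lb < la + w


def sample_boxes(height, width, size=(80, 80), n=3, max_tries=10_000, seed=None):
    """Rejection-sample n non-overlapping boxes; returns a list of (top, left)."""
    rng = random.Random(seed)
    h, w = size
    if h > height or w > width:
        raise ValueError("patch is larger than the image")
    boxes = []
    tries = 0
    while len(boxes) < n:
        if tries == max_tries:
            raise RuntimeError(f"could not place {n} boxes in {max_tries} tries")
        tries += 1
        # Random top-left corner such that the box stays inside the image
        cand = (rng.randint(0, height - h), rng.randint(0, width - w))
        if all(not boxes_overlap(cand, b, size) for b in boxes):
            boxes.append(cand)
    return boxes
```

Each `(top, left)` pair can then be cut out of the image tensor with `tf.image.crop_to_bounding_box(img, top, left, 80, 80)`. Rejection sampling can stall when the patches nearly fill the image, which is why the sketch caps the number of attempts with `max_tries`.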


Solution

  • I think I found a solution; if you have any suggestions, please share them.

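      import tensorflow as tf

      @tf.function
      def sample_img(img, frame_dim=(80, 80), seed=42, n=3, padding='VALID'):
          # Number of full, non-overlapping patches that fit on the grid
          # (the area ratio H*W // (h*w) can overcount, e.g. 6 vs 4 for 200x200)
          n_full = (img.shape[0] // frame_dim[0]) * (img.shape[1] // frame_dim[1])
          if n > n_full:
              padding = 'SAME'  # zero-pad the borders to gain partial patches

          # Stride == patch size, so the extracted patches do not overlap
          patches = tf.image.extract_patches(tf.expand_dims(img, 0),  # add batch axis
                                             sizes=[1, *frame_dim, 1],
                                             strides=[1, *frame_dim, 1],
                                             rates=[1, 1, 1, 1],
                                             padding=padding)

          # (1, rows, cols, h*w*c) -> (num_patches, h, w, c)
          patches_res = tf.reshape(patches, shape=(-1, *frame_dim, img.shape[2]))

          # Draw n distinct patch indices uniformly at random
          ixs = tf.reshape(tf.range(patches_res.shape[0], dtype=tf.int64), shape=(1, -1))
          ixs_sampled = tf.random.uniform_candidate_sampler(ixs, patches_res.shape[0], n,
                                                            unique=True,
                                                            range_max=patches_res.shape[0],
                                                            seed=seed)

          ixs_sampled_res = tf.reshape(ixs_sampled.sampled_candidates, shape=(n, 1))
          # The original snippet stops here; gathering the sampled patches completes it
          return tf.gather_nd(patches_res, ixs_sampled_res)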
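The padding switch depends on how many patches `tf.image.extract_patches` produces. With strides equal to the patch size, `'VALID'` keeps only full tiles, `floor(H/h) * floor(W/w)`, while `'SAME'` yields `ceil(H/h) * ceil(W/w)` tiles, the edge ones zero-padded. The sketch below (plain Python; `patch_grid` and `patch_offset` are hypothetical helper names) computes both counts and maps a flat patch index back to its pixel offset for the `'VALID'` case. It also shows why the area ratio `H*W // (h*w)` is not a safe count: for a 200x200 image and 80x80 patches it gives 6, but only 4 full patches exist.

```python
def patch_grid(height, width, size=(80, 80), padding="VALID"):
    """(rows, cols) of the patch grid when stride == patch size."""
    h, w = size
    if padding == "VALID":
        # Only patches that lie fully inside the image
        return height // h, width // w
    if padding == "SAME":
        # Integer ceiling: partial edge patches are kept (zero-padded)
        return (height + h - 1) // h, (width + w - 1) // w
    raise ValueError(f"unknown padding: {padding}")


def patch_offset(k, width, size=(80, 80)):
    """(top, left) pixel offset of flat patch index k under 'VALID' padding.

    Assumes the (1, rows, cols, ...) output was flattened in row-major order,
    as tf.reshape(patches, (-1, h, w, c)) does.
    """
    h, w = size
    cols = width // w
    row, col = divmod(k, cols)
    return row * h, col * w
```

So, when no padding is applied, a sampled index `k` in `sample_img` corresponds to the crop whose top-left corner is `patch_offset(k, W)`.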