How would I randomly sample pixels in TensorFlow?


Assume I have a tensor of shape (b,h,w,d) with

b: batch_size
h: height of image
w: width of image
d: feature dimension

How would I sample a set of 100 randomly chosen pixels out of it, such that I have a tensor of shape (b, 100, d)?

Also, how would I do it if I do not know h and w, which is the case at test time?


Solution

  • The generic strategy is to build all the necessary pixel coordinates with TF ops and then apply tf.gather_nd. Because the coordinates are computed from tf.shape at run time, this works even when the shape is unknown.

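    import numpy as np
    # TF 1.x graph-mode API; tf.compat.v1 also provides it under TensorFlow 2.x
    import tensorflow.compat.v1 as tf
    tf.disable_v2_behavior()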
    
    b, h, w, d = 2, 4, 6, 4
    data = np.arange(b * h * w * d, dtype=np.float32).reshape(b, h, w, d)
    # forget the shape
    b, h, w, d = None, None, None, None
    # make sure TF has no idea what the original dimensions were
    data_pldhr = tf.placeholder(tf.float32, [None, None, None, None])
    # read the dimensions from the placeholder at run time, not from the NumPy array
    data_shape = tf.shape(data_pldhr)
    B_op, H_op, W_op, D_op = [data_shape[i] for i in range(4)]
    
    # choose the same pixel positions for every batch entry (the other case is even more trivial)
    REPEAT = 10  # number of pixels to sample (100 in the question)
    pixel_h = tf.random_uniform([REPEAT], minval=0, maxval=H_op, dtype=tf.int32)
    pixel_w = tf.random_uniform([REPEAT], minval=0, maxval=W_op, dtype=tf.int32)
    pixel_h = tf.expand_dims(pixel_h, axis=0)
    pixel_w = tf.expand_dims(pixel_w, axis=0)
    
    # repeat the positions for each batch entry: shape [B, REPEAT]
    pixel_h = tf.tile(pixel_h, [B_op, 1])
    pixel_w = tf.tile(pixel_w, [B_op, 1])
    
    # add the batch index of every sample: shape [B, REPEAT]
    b_idx = tf.tile(tf.expand_dims(tf.range(0, B_op), axis=-1), [1, REPEAT])
    
    # combine everything into (batch, row, col) triples: shape [B, REPEAT, 3]
    pixel_pos = tf.stack([b_idx, pixel_h, pixel_w], axis=-1)
    
    # gather one d-dimensional pixel vector per triple: shape [B, REPEAT, D]
    selected_pixels = tf.gather_nd(data_pldhr, pixel_pos)
    
    with tf.Session() as sess:
        ret = sess.run([tf.shape(selected_pixels), selected_pixels, pixel_pos], {data_pldhr: data})
        print(ret[0])  # shape of output: [ 2 10  4]
        print(ret[1])  # content with the previous shape
        print(ret[2])  # selected pixel positions
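To make it clear what tf.gather_nd returns for these stacked (batch, row, col) indices, the same lookup can be written with NumPy advanced indexing. This is only an illustrative sketch: the helper name gather_pixels_reference, the seeded generator and the variable R (standing in for REPEAT) are assumptions, not part of the answer above.

```python
import numpy as np

def gather_pixels_reference(data, pixel_pos):
    """Hypothetical NumPy equivalent of tf.gather_nd(data, pixel_pos).

    pixel_pos has shape (b, R, 3); each innermost (batch, row, col) triple
    selects one d-dimensional pixel vector, so the result is (b, R, d).
    """
    return data[pixel_pos[..., 0], pixel_pos[..., 1], pixel_pos[..., 2]]

rng = np.random.default_rng(0)
b, h, w, d, R = 2, 4, 6, 4, 10
data = np.arange(b * h * w * d, dtype=np.float32).reshape(b, h, w, d)

# Same random (row, col) positions for every batch entry, as in the tiled TF code.
pixel_h = np.tile(rng.integers(0, h, size=(1, R)), (b, 1))
pixel_w = np.tile(rng.integers(0, w, size=(1, R)), (b, 1))
b_idx = np.tile(np.arange(b)[:, None], (1, R))
pixel_pos = np.stack([b_idx, pixel_h, pixel_w], axis=-1)  # shape (b, R, 3)

selected = gather_pixels_reference(data, pixel_pos)
print(selected.shape)  # (2, 10, 4)
```

Feeding the pixel_pos returned by sess.run into this helper should reproduce ret[1], which is a convenient way to sanity-check the TF graph.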
    

    Note that instead of tiling along the batch dimension, you can directly draw different coordinates for each batch entry, e.g. for the rows (and likewise for the columns with maxval=W_op):

    count_op = tf.multiply(B_op, REPEAT)
    pixel_h = tf.reshape(tf.random_uniform([count_op], minval=0, maxval=H_op, dtype=tf.int32), [B_op, REPEAT])
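For TensorFlow 2.x (eager execution, no placeholders or sessions), per-entry sampling with unknown h and w might be sketched as follows. Note that this swaps in a different technique from the answer: instead of building (batch, row, col) triples for tf.gather_nd, it flattens the h*w grid and uses tf.gather with batch_dims=1. The function names sample_pixels and sample_100_pixels are hypothetical. Like the code above, it samples with replacement, so the same pixel may be picked more than once.

```python
import tensorflow as tf

def sample_pixels(images, num_samples):
    """Sample `num_samples` pixels independently for each batch entry.

    images: float tensor of shape (b, h, w, d); all four dimensions may be
    unknown until run time. Returns a tensor of shape (b, num_samples, d).
    """
    shape = tf.shape(images)  # dynamic shape, works when h and w are None
    b, h, w, d = shape[0], shape[1], shape[2], shape[3]
    # Flatten the spatial grid so each pixel has a single index in [0, h*w).
    flat = tf.reshape(images, tf.stack([b, h * w, d]))
    # One random flat index per (batch entry, sample), drawn separately per entry.
    idx = tf.random.uniform(tf.stack([b, num_samples]), minval=0,
                            maxval=h * w, dtype=tf.int32)
    # batch_dims=1 pairs idx[i] with flat[i], giving shape (b, num_samples, d).
    return tf.gather(flat, idx, axis=1, batch_dims=1)

# Traced once with fully unknown dimensions, as at test time.
@tf.function(input_signature=[tf.TensorSpec(shape=[None, None, None, None], dtype=tf.float32)])
def sample_100_pixels(images):
    return sample_pixels(images, 100)

images = tf.reshape(tf.range(2 * 4 * 6 * 4, dtype=tf.float32), [2, 4, 6, 4])
out = sample_100_pixels(images)           # shape (2, 100, 4)
other = sample_100_pixels(tf.ones([3, 5, 7, 2]))  # shape (3, 100, 2)
```

Flattening avoids building a (b, R, 3) coordinate tensor; if you prefer the answer's tf.gather_nd formulation, the per-entry row and column indices can be recovered from idx as idx // w and idx % w.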