
How to sample n pixels from different object classes in tensorflow?


Problem

I want to randomly sample n pixels from every instance class in an image.

Let's say my image is given as I with width w and height h. I also have a label image L, of the same shape as I, that describes the instance classes.

Current Approach

My current idea is to first reshape the labels into one long vector of shape (N_p, 1), where N_p is the number of pixels, and repeat it N_c times to get shape (N_p, N_c), where N_c is the number of classes. Next I repeat a vector l of all unique labels, with shape (1, N_c), to shape (N_p, N_c). Comparing the two matrices for equality gives a matrix with a one in row x and column y exactly when the pixel of row x belongs to the class of column y.
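The comparison step can be illustrated on a toy label image. This is only a sketch in NumPy rather than TensorFlow, and the names labels_flat and membership are chosen for illustration; broadcasting makes the explicit repetition unnecessary.

```python
import numpy as np

# Toy label image: 0 = background, 1 and 2 = two instances.
labels = np.array([[0, 0, 1],
                   [2, 2, 1]])

labels_flat = labels.reshape(-1)    # shape (N_p,)
classes = np.unique(labels_flat)    # shape (N_c,), sorted unique labels

# Broadcasting compares every pixel with every class:
# (N_p, 1) == (1, N_c) -> (N_p, N_c), one row per pixel, one column per class.
membership = (labels_flat[:, None] == classes[None, :]).astype(np.int32)

print(membership)
```

Each row of membership contains exactly one 1, in the column of the class that pixel belongs to.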

The next step is to concatenate a column of increasing index positions to that matrix. I can then randomly shuffle the rows of the combined matrix.

The only missing step is to extract n*N_c rows of that matrix: for each class, the first n rows that have a one in that class's column. Then, using the indices in the right part of the matrix, I can use

tf.gather_nd

to get the pixels out of the original image I.
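To make the whole pipeline concrete, here is a rough NumPy sketch of the steps described above. It is not TensorFlow code, and it assumes a stable descending sort as a stand-in for the selection step; all variable names are illustrative, and every class is assumed to have at least n pixels.

```python
import numpy as np

rng = np.random.default_rng(0)

labels = np.array([[0, 0, 1, 1],
                   [0, 2, 2, 1],
                   [0, 2, 2, 0]])
n = 2  # samples per class

labels_flat = labels.reshape(-1)
classes = np.unique(labels_flat)
membership = (labels_flat[:, None] == classes[None, :]).astype(np.int32)

# Append the pixel index as an extra column, then shuffle the rows together.
positions = np.arange(labels_flat.size)[:, None]
shuffled = rng.permutation(np.concatenate([membership, positions], axis=1))
member_rand, pos_rand = shuffled[:, :-1], shuffled[:, -1]

# Per class column, a stable descending sort lists the rows holding a one
# first, in their shuffled order; keep the first n of them.
order = np.argsort(-member_rand, axis=0, kind="stable")[:n]  # shape (n, N_c)
sampled = pos_rand[order.T]                                   # shape (N_c, n)

# Row c of this array contains only the label classes[c].
print(labels_flat[sampled])
```

The array sampled holds flat pixel indices, so the same indices can be used to look up the corresponding pixels in a flattened copy of I.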

Questions

  1. How can I achieve the missing operation in tensorflow? That is: get the n*N_c rows made up, for each column in the left part of the matrix, of the first n rows that have a one in that column.

  2. Are these operations efficient?

  3. Is there some simpler method?


Solution

    For anybody interested, here is the solution to my problem with the corresponding tensorflow code. I was on the right track; the missing function is

    tf.nn.top_k
    

    Here is some example code, written against the TensorFlow 1.x graph API, that samples sample_size pixels from each of an image's instance classes.

    import tensorflow as tf
    
    seed = 42
    
    width = 10
    height = 6
    embedding_dim = 3
    
    sample_size = 2
    
    image = tf.random_normal([height, width, embedding_dim], mean=0, stddev=4, seed=seed)
    labels = tf.constant([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                          [0, 0, 1, 1, 0, 0, 0, 0, 0, 0],
                          [0, 0, 1, 1, 0, 2, 2, 2, 0, 0],
                          [0, 0, 1, 1, 0, 2, 2, 2, 0, 0],
                          [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                          [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=tf.uint8)
    
    
    labels = tf.cast(labels, tf.int32)
    
    # First reshape to one vector
    image_v = tf.reshape(image, [-1, embedding_dim])
    labels_v = tf.reshape(labels, [-1])
    
    # Get the unique classes; indices[i] is the position of pixel i's label in classes
    classes, indices = tf.unique(labels_v)
    
    # Dimensions
    N_c = tf.shape(classes)[0]
    N_p = tf.shape(labels_v)[0]
    
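    # Helper matrices
    # I: each pixel's class index, repeated across N_c columns -> (N_p, N_c)
    I = tf.tile(tf.expand_dims(indices, -1), [1, N_c])
    # C: the class indices 0..N_c-1, repeated down N_p rows -> (N_p, N_c)
    C = tf.tile(tf.expand_dims(tf.range(N_c), 0), [N_p, 1])
    # E: one-hot membership, E[p, c] == 1 iff pixel p belongs to class c
    E = tf.cast(tf.equal(I, C), tf.int32)
    # P: 0-based pixel positions, so they can index image_v directly
    P = tf.expand_dims(tf.range(N_p), -1)
    # Shuffle membership and positions together, then split them again
    R = tf.concat([E, P], axis=1)
    R_rand = tf.random_shuffle(R, seed=seed)
    E_rand, P_rand = tf.split(R_rand, [N_c, 1], axis=1)
    M = tf.transpose(E_rand)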
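    # top_k on the 0/1 rows of M returns, per class, the positions of the first
    # sample_size ones in the shuffled order (ties keep the lower index first)
    _, topIndices = tf.nn.top_k(M, k=sample_size)
    topIndicesFlat = tf.expand_dims(tf.reshape(topIndices, [-1]), -1)
    # Map shuffled row positions back to pixel positions, then gather the pixels
    sampleIndices = tf.gather_nd(P_rand, topIndicesFlat)
    samples = tf.gather_nd(image_v, sampleIndices)
    
    # Evaluate everything in a single run so the random ops are sampled only once
    sess = tf.Session()
    tensors = [image,
               labels,
               image_v,
               labels_v,
               classes,
               indices,
               N_c,
               N_p,
               I,
               C,
               E,
               P,
               R,
               R_rand,
               E_rand,
               P_rand,
               M,
               topIndices,
               topIndicesFlat,
               sampleIndices,
               samples
               ]
    results = sess.run(tensors)
    print(results)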
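
One caveat worth noting: top_k always returns k indices per row, so a class with fewer than sample_size pixels would also pick up pixels from other classes. As a possible simpler variant, here is a sketch in NumPy rather than TensorFlow. It draws one random permutation of the pixels and keeps, per class, the first n matches. The helper name sample_per_class is hypothetical and not part of any library.

```python
import numpy as np

def sample_per_class(labels, n, rng):
    """Return {class: flat pixel indices}, at most n per class (hypothetical helper)."""
    labels_flat = np.asarray(labels).reshape(-1)
    perm = rng.permutation(labels_flat.size)   # pixel indices in random order
    shuffled_labels = labels_flat[perm]        # labels in that same order
    # For each class, keep the first n pixels of that class in the random order.
    return {int(c): perm[shuffled_labels == c][:n] for c in np.unique(labels_flat)}

labels = np.array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                   [0, 0, 1, 1, 0, 0, 0, 0, 0, 0],
                   [0, 0, 1, 1, 0, 2, 2, 2, 0, 0],
                   [0, 0, 1, 1, 0, 2, 2, 2, 0, 0],
                   [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])

picked = sample_per_class(labels, n=2, rng=np.random.default_rng(42))
for c, idx in picked.items():
    print(c, idx)
```

With this variant a class smaller than n simply yields fewer indices. The returned flat indices can be used to index a flattened image, in the same way sampleIndices is used with image_v above.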