
TensorFlow: remember the index after getting the maximum box


Assume that I have two arrays of boxes with shapes (?, b1, 4) and (?, b2, 4) respectively (treat ? as an unknown batch size):

box1: [[[1,2,3,4], [2,3,4,5], [3,4,5,6]...]...]
box2: [[[4,3,2,1], [3,2,5,4], [4,3,5,6]...]...]

(the numbers above are arbitrary)

I want to:

  1. in each batch, for each box A in box1, find in box2 the box B which has the maximum IOU (intersection over union) with A (in the same batch, of course), and then append the tuple (A, B) to a list list_max.

  2. append to list_nonmax all the boxes in box2 that do not have the maximum IOU with any box in box1 (separated by batch, of course).

You can assume that:

  1. b1 and b2 are both Python variables, not TensorFlow tensors.

  2. functions for calculating the IOU between single boxes or between batches of boxes already exist and can be used as-is:

    iou_single_box(box1, box2) : both box1 and box2 are of shape (4,).

    iou_multiple_boxes(bbox1, bbox2) : bbox1 and bbox2 have shapes (b1, 4) and (b2, 4) respectively.

    iou_batch_boxes(bbbox1, bbbox2) : bbbox1 and bbbox2 have shapes (?, b1, 4) and (?, b2, 4) respectively (treat ? as an unknown batch size).
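The question leaves the box format and the helpers' bodies open. For experimenting without them, here is a NumPy sketch of a batched IOU under an assumed [x1, y1, x2, y2] corner format with positive-area boxes. `iou_batch_boxes_np` is a hypothetical stand-in, and it returns shape (batch, b1, b2), which is the shape a per-box argmax over box2 needs.

```python
import numpy as np

def iou_batch_boxes_np(bbbox1, bbbox2):
    """Pairwise IOU per batch (hypothetical helper).

    ASSUMES boxes are [x1, y1, x2, y2] corners with x2 > x1 and y2 > y1.
    bbbox1: (batch, b1, 4), bbbox2: (batch, b2, 4) -> (batch, b1, b2)
    """
    a = bbbox1[:, :, None, :]  # (batch, b1, 1, 4)
    b = bbbox2[:, None, :, :]  # (batch, 1, b2, 4)
    # Width/height of the overlap rectangle, clipped at 0 for disjoint boxes
    iw = np.clip(np.minimum(a[..., 2], b[..., 2]) - np.maximum(a[..., 0], b[..., 0]), 0, None)
    ih = np.clip(np.minimum(a[..., 3], b[..., 3]) - np.maximum(a[..., 1], b[..., 1]), 0, None)
    inter = iw * ih
    area_a = (a[..., 2] - a[..., 0]) * (a[..., 3] - a[..., 1])  # (batch, b1, 1)
    area_b = (b[..., 2] - b[..., 0]) * (b[..., 3] - b[..., 1])  # (batch, 1, b2)
    return inter / (area_a + area_b - inter)

# One batch, two boxes in box1, three boxes in box2
box1 = np.array([[[0, 0, 2, 2], [1, 1, 3, 3]]], dtype=float)
box2 = np.array([[[0, 0, 2, 2], [2, 2, 4, 4], [0, 0, 1, 1]]], dtype=float)
iou = iou_batch_boxes_np(box1, box2)
```

Broadcasting the two inputs against each other computes all b1 x b2 pairs per batch in one pass, without Python loops.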

I find this particularly hard in TensorFlow, especially the list_nonmax case: padding and then tf.reduce_max() easily give the box tuples with the maximum IOU, but I cannot find a way to keep track of their indices so that I can extract the remaining boxes for list_nonmax.
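Before vectorizing, it may help to pin down the intended result with a plain loop that a TensorFlow version can be checked against. This NumPy sketch assumes the pairwise IOU values are already available as an array of shape (batch, b1, b2); `split_max_nonmax` is a hypothetical name, and the IOU numbers in the example are made up.

```python
import numpy as np

def split_max_nonmax(box1, box2, iou):
    """Plain-loop reference for both goals (hypothetical helper).

    box1: (batch, b1, 4), box2: (batch, b2, 4)
    iou:  (batch, b1, b2), with iou[n, i, j] = IOU(box1[n, i], box2[n, j])
    Returns one list per batch, so results stay separated by batch.
    """
    list_max, list_nonmax = [], []
    for n in range(box1.shape[0]):
        # Goal 1: for each A in box1, the index of the box2 entry with max IOU
        best = [int(np.argmax(iou[n, i])) for i in range(box1.shape[1])]
        list_max.append([(box1[n, i], box2[n, j]) for i, j in enumerate(best)])
        # Goal 2: box2 entries that are not the best match of any A
        list_nonmax.append([box2[n, j] for j in range(box2.shape[1]) if j not in best])
    return list_max, list_nonmax

# Example with made-up IOU values (batch=2, b1=2, b2=3)
box1 = np.arange(16, dtype=np.float32).reshape(2, 2, 4)
box2 = np.arange(2, 26, dtype=np.float32).reshape(2, 3, 4)
iou = np.array([[[9.0, 8, 7], [1, 2, 3]],
                [[0.0, 1, 2], [0, 5, 0]]])
list_max, list_nonmax = split_max_nonmax(box1, box2, iou)
```

The number of entries in each list_nonmax[n] can differ between batches, which is the part that does not map onto a single dense tensor.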


Solution

  • You need tf.nn.top_k() for this. It returns both the largest values along the last dimension and the indices where they occur.

    val, idx = tf.nn.top_k( iou_batch_boxes( bbbox1, bbbox2 ), k = 1 )
    

    gives you, for each batch and each box in box1, the index of the box2 entry with the maximum IOU. With k = 1, idx has shape ( ?, b1, 1 ).
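A NumPy stand-in makes the shapes concrete (NumPy here only mirrors the TensorFlow op for illustration). With k = 1, top_k keeps a trailing axis of size 1, and np.argmax over the last axis picks the same index, since both return the lower index on ties.

```python
import numpy as np

# Dummy IOU values with shape (batch=2, b1=2, b2=3)
iou = np.array([[[9.0, 8, 7], [1, 2, 3]],
                [[0.0, 1, 2], [0, 5, 0]]])

# Counterpart of tf.nn.top_k(iou, k=1): index and value of the largest entry
# along the last axis, keeping a trailing axis of size k = 1
idx = np.argmax(iou, axis=-1)[..., None]     # (2, 2, 1)
val = np.take_along_axis(iou, idx, axis=-1)  # (2, 2, 1)
```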

    To get list_max, use tf.gather_nd() to pick box2's entries at idx and tf.stack() them with the matching box1 entries along axis 1. Here is working code with a dummy IOU function (TensorFlow 1.x, since it uses tf.Session):

    import tensorflow as tf  # TensorFlow 1.x API (tf.Session)
    
    box1 = tf.reshape( tf.constant( range( 16 ), dtype = tf.float32 ), ( 2, 2, 4 ) )
    box2 = tf.reshape( tf.constant( range( 2, 26 ), dtype = tf.float32 ), ( 2, 3, 4 ) )
    batch_size = box1.get_shape().as_list()[ 0 ]
    b1 = box1.get_shape().as_list()[ 1 ]
    b2 = box2.get_shape().as_list()[ 1 ]
    
    def dummy_iou_batch_boxes( box1, box2 ):
        # Stand-in for iou_batch_boxes(): fixed IOU values of shape ( batch, b1, b2 ) = ( 2, 2, 3 )
        return tf.constant( [ [ [ 9.0, 8, 7 ], [ 1, 2, 3 ] ],
                              [ [ 0.0, 1, 2 ], [ 0, 5, 0 ] ] ] )
    
    iou = dummy_iou_batch_boxes( box1, box2 )
    # Index of the box2 entry with the largest IOU for each box1, shape ( batch, b1, 1 )
    val, idx = tf.nn.top_k( iou, k = 1 )
    idx = tf.reshape( idx, ( batch_size, b1 ) )
    # Turn the indices into ( batch, box1 index, box2 index ) coordinates
    one_hot_idx = tf.one_hot( idx, depth = b2 )
    full_idx = tf.where( tf.equal( 1.0, one_hot_idx ) )
    box1_idx = full_idx[ :, 0 : 2 ]      # columns ( batch, box1 index )
    box2_idx = full_idx[ :, 0 : 3 : 2 ]  # columns ( batch, box2 index )
    box12 = tf.gather_nd( box1, box1_idx )
    box22 = tf.gather_nd( box2, box2_idx )
    # Each row is one ( A, B ) pair: shape ( batch * b1, 2, 4 )
    list_max = tf.stack( [ box12, box22 ], axis = 1 )
    
    with tf.Session() as sess:
        print( sess.run( list_max ) )
    
    

    will output:

    [[[ 0.  1.  2.  3.]
      [ 2.  3.  4.  5.]]

     [[ 4.  5.  6.  7.]
      [10. 11. 12. 13.]]

     [[ 8.  9. 10. 11.]
      [22. 23. 24. 25.]]

     [[12. 13. 14. 15.]
      [18. 19. 20. 21.]]]

    If you want this as a list or tuple, you can use tf.unstack() on the above list_max.
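Because every box in box1 gets exactly one match, the pairs can also be kept separated by batch instead of flattened into ( batch * b1, 2, 4 ). The NumPy sketch below (an illustration, not the answer's TensorFlow code) does the gather with np.take_along_axis, reusing the index values from the dummy IOU. In TensorFlow 2.x, tf.gather(box2, idx, batch_dims=1) should play the same role, though that is worth checking against your version.

```python
import numpy as np

box1 = np.arange(16, dtype=np.float32).reshape(2, 2, 4)
box2 = np.arange(2, 26, dtype=np.float32).reshape(2, 3, 4)
idx = np.array([[0, 2], [2, 1]])  # (batch, b1) argmax indices from the dummy IOU

# For every (batch n, box1 i) pick box2[n, idx[n, i]]: shape (batch, b1, 4)
best_box2 = np.take_along_axis(box2, idx[..., None], axis=1)
# Pair each box1 entry with its match: shape (batch, b1, 2, 4)
list_max = np.stack([box1, best_box2], axis=2)
```

list_max.reshape(-1, 2, 4) gives the same rows, in the same order, as the flattened TensorFlow output above.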

    To get list_nonmax, merge the indices into a mask. I believe I've already answered this in another of your questions, but the key line is:

    mask = tf.reduce_max( tf.one_hot( idx, depth = b2 ), axis = -2 )
    

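    This gives a mask of shape ( batch, b2 ) that is 1.0 where that box2 entry is the max-IOU match for at least one box in box1 and 0.0 elsewhere. idx must have shape ( batch, b1 ), i.e. reshaped as in the code above; with the raw ( batch, b1, 1 ) output of top_k, axis -2 would be the wrong axis.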
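The same mask can be reproduced in NumPy to check the shapes (an illustration only): indexing an identity matrix with idx gives the one-hot encoding that tf.one_hot produces, and the max over the box1 axis marks every box2 entry that won at least once.

```python
import numpy as np

b2 = 3
idx = np.array([[0, 2], [2, 1]])  # (batch, b1) argmax indices from the dummy IOU

# np.eye(b2)[idx] is a one-hot encoding like tf.one_hot(idx, depth=b2),
# shape (batch, b1, b2); the max over axis -2 collapses the box1 axis
mask = np.eye(b2)[idx].max(axis=-2)  # (batch, b2)
```

Here box2 entry 1 of batch 0 and entry 0 of batch 1 are never a best match, so their mask values are 0.0.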

    From here, you can either use the mask or get the list of indices with tf.where() like this:

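    was_never_max_idx = tf.where( tf.equal( mask, 0.0 ) )  # ( batch, box2 index ) pairs
    list_nonmax = tf.gather_nd( box2, was_never_max_idx )   # shape ( N, 4 )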
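Since the number of non-max boxes can differ between batches, the result does not fit one dense ( batch, k, 4 ) tensor. The NumPy sketch below mirrors tf.where and tf.gather_nd for illustration: the flat result keeps the batch index as the first column of the indices, and the per-batch Python lists are an alternative when the batch size is known.

```python
import numpy as np

box2 = np.arange(2, 26, dtype=np.float32).reshape(2, 3, 4)
mask = np.array([[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]])

# (batch, box2 index) pairs that were never a max, like tf.where(mask == 0)
was_never_max_idx = np.argwhere(mask == 0)
# Flat gather, like tf.gather_nd(box2, was_never_max_idx): shape (N, 4)
nonmax_flat = box2[was_never_max_idx[:, 0], was_never_max_idx[:, 1]]
# One array per batch; their lengths may differ
list_nonmax = [box2[n][mask[n] == 0] for n in range(box2.shape[0])]
```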