
How to delete different rows of a 3D tensor


For example, suppose I have a 3D tensor like this:

a = tf.constant([[[1,2,3],
                  [4,5,6],
                  [7,8,9]],
                 [[9,8,7],
                  [6,5,4],
                  [3,2,1]],
                 [[0,8,0],
                  [1,5,4],
                  [3,1,1]]])

I want to delete a different row from each of the three matrices, with one row index per matrix:

idx = [[1], 
       [0], 
       [2]]

The expected result is:

re = [[[1,2,3],
       [7,8,9]],
      [[6,5,4],
       [3,2,1]],
      [[0,8,0],
       [1,5,4]]]

How can I do this?
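
To pin down the intended semantics without TensorFlow, here is a plain-Python reference over nested lists. `delete_rows` is a hypothetical helper name, not a TensorFlow API; it simply slices around row `idx[i][0]` of matrix `i`:

```python
def delete_rows(batch, idx):
    """Hypothetical reference helper (nested lists, no TensorFlow): return new lists
    with row idx[i][0] removed from matrix i. The input `batch` is not modified."""
    return [matrix[:row] + matrix[row + 1:] for matrix, (row,) in zip(batch, idx)]


a = [[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
     [[9, 8, 7], [6, 5, 4], [3, 2, 1]],
     [[0, 8, 0], [1, 5, 4], [3, 1, 1]]]
idx = [[1], [0], [2]]

print(delete_rows(a, idx))
# [[[1, 2, 3], [7, 8, 9]], [[6, 5, 4], [3, 2, 1]], [[0, 8, 0], [1, 5, 4]]]
```

Any TensorFlow solution should agree with this reference on the example above.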


Solution

  • First approach: using tf.one_hot and tf.boolean_mask:

    # shape = (?,1,3); False at the row to delete, True elsewhere
    mask_idx = tf.cast(1 - tf.one_hot(idx, a.shape[1]), tf.bool)
    # shape = (?,3)
    result = tf.boolean_mask(a,mask_idx[:,0,:])
    # shape = (?,2,3)
    result = tf.reshape(result,shape=(-1,a.shape[1]-1,a.shape[2]))
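
The first approach can be traced step by step in plain Python, using nested lists instead of tensors. The helpers `one_hot`, `boolean_mask` and `reshape_rows` below are hypothetical stand-ins that mimic what the TensorFlow ops do on this particular input; they are not the TensorFlow implementations:

```python
# Plain-Python trace of the first approach: mask = 1 - one_hot(idx), boolean_mask, reshape.
# one_hot, boolean_mask and reshape_rows are hypothetical list-based stand-ins, not TF code.

def one_hot(indices, depth):
    """One-hot encode each index into a list of length `depth`."""
    return [[1 if j == i else 0 for j in range(depth)] for i in indices]


def boolean_mask(batch, mask):
    """Keep batch[b][r] wherever mask[b][r] is nonzero, flattening the first two axes,
    the way tf.boolean_mask behaves with a 2-D mask on a 3-D tensor."""
    return [row
            for matrix, keep_flags in zip(batch, mask)
            for row, keep in zip(matrix, keep_flags)
            if keep]


def reshape_rows(rows, rows_per_example):
    """Regroup a flat list of rows into consecutive chunks of `rows_per_example`."""
    return [rows[k:k + rows_per_example] for k in range(0, len(rows), rows_per_example)]


a = [[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
     [[9, 8, 7], [6, 5, 4], [3, 2, 1]],
     [[0, 8, 0], [1, 5, 4], [3, 1, 1]]]
idx = [[1], [0], [2]]
depth = len(a[0])  # plays the role of a.shape[1] == 3

# Equivalent of mask_idx[:, 0, :]: 0 at the row to delete, 1 elsewhere.
mask = [[1 - v for v in row] for row in one_hot([i[0] for i in idx], depth)]
print(mask)  # [[1, 0, 1], [0, 1, 1], [1, 1, 0]]

flat = boolean_mask(a, mask)            # 6 rows of length 3, i.e. shape (?, 3)
result = reshape_rows(flat, depth - 1)  # shape (?, 2, 3)
print(result)  # [[[1, 2, 3], [7, 8, 9]], [[6, 5, 4], [3, 2, 1]], [[0, 8, 0], [1, 5, 4]]]
```

The final reshape is valid only because every matrix loses exactly one row, so the flattened result always holds `batch * (rows - 1)` rows.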
    

  • Second approach: using tf.map_fn:

    # x[0]: one (3,3) matrix; x[1]: its [row] index, squeezed to a scalar
    result = tf.map_fn(lambda x: tf.boolean_mask(x[0], tf.cast(1 - tf.one_hot(tf.squeeze(x[1]), a.shape[1]), tf.bool)),
                       [a, idx],
                       dtype=tf.int32)
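
The second approach removes the row one matrix at a time. A plain-Python sketch of that control flow, where `map_fn` and `one_hot` are hypothetical list-based stand-ins for `tf.map_fn` and `tf.one_hot` (and `drop_row` is a hypothetical name for the lambda):

```python
# Plain-Python sketch of the second approach's control flow.
# map_fn and one_hot are hypothetical list-based stand-ins for tf.map_fn and tf.one_hot.

def map_fn(fn, elems):
    """Call fn on (elems[0][i], elems[1][i], ...) for each i and collect the results."""
    return [fn(x) for x in zip(*elems)]


def one_hot(index, depth):
    """One-hot encode a single index into a list of length `depth`."""
    return [1 if j == index else 0 for j in range(depth)]


def drop_row(x):
    """x pairs one matrix with its [row] index; unpacking the one-element list
    mirrors tf.squeeze(x[1])."""
    matrix, (row_to_drop,) = x
    keep_flags = [1 - v for v in one_hot(row_to_drop, len(matrix))]
    return [row for row, keep in zip(matrix, keep_flags) if keep]


a = [[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
     [[9, 8, 7], [6, 5, 4], [3, 2, 1]],
     [[0, 8, 0], [1, 5, 4], [3, 1, 1]]]
idx = [[1], [0], [2]]

result = map_fn(drop_row, [a, idx])
print(result)  # [[[1, 2, 3], [7, 8, 9]], [[6, 5, 4], [3, 2, 1]], [[0, 8, 0], [1, 5, 4]]]
```

Because each per-example mask has shape (3,), `tf.boolean_mask` already returns a (2, 3) slice, so no reshape is needed. The trade-off is that `tf.map_fn` calls the function once per element, which is typically slower than the single vectorized mask of the first approach.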
    

    An example:

    import tensorflow as tf
    
    a = tf.constant([[[1,2,3],[4,5,6],[7,8,9]],
                     [[9,8,7],[6,5,4],[3,2,1]],
                     [[0,8,0],[1,5,4],[3,1,1]]], dtype=tf.int32)
    idx = tf.constant([[1],[0],[2]],dtype=tf.int32)
    
    # First approach:
    # shape = (?,1,3); False at the row to delete, True elsewhere
    mask_idx = tf.cast(1 - tf.one_hot(idx, a.shape[1]), tf.bool)
    # shape = (?,3)
    result = tf.boolean_mask(a,mask_idx[:,0,:])
    # shape = (?,2,3)
    result = tf.reshape(result,shape=(-1,a.shape[1]-1,a.shape[2]))
    
    # Second approach:
    # x[0]: one (3,3) matrix; x[1]: its [row] index, squeezed to a scalar
    result = tf.map_fn(lambda x: tf.boolean_mask(x[0], tf.cast(1 - tf.one_hot(tf.squeeze(x[1]), a.shape[1]), tf.bool)),
                       [a, idx],
                       dtype=tf.int32)
    
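    # TensorFlow 1.x graph mode; under TF 2.x eager execution, skip the session and print(result.numpy())
    with tf.Session() as sess:
        print(sess.run(result))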
    
    # Output (identical for both approaches):
    [[[1 2 3]
      [7 8 9]]
    
     [[6 5 4]
      [3 2 1]]
    
     [[0 8 0]
      [1 5 4]]]