
How to delete a different row from each element of a 3D tensor

For example, take a 3D tensor like this:

a = tf.constant([[[1,2,3],

I want to delete a different row from each of the three elements, using these row indices:

idx = [[1], 

The result would be like this:

re = [[[1,2,3],

How can I do this?
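The question's tensors are cut off above, so here is a minimal NumPy sketch of the intended operation on hypothetical data (`np.arange(27)` reshaped to 3x3x3, not the question's values). For each element `i`, it removes row `idx[i]` and keeps the remaining rows in their original order. It assumes exactly one row is removed per element, which is what lets the flattened result be reshaped back.

```python
import numpy as np

# Hypothetical batch of three 3x3 matrices (the question's values are truncated).
a = np.arange(27).reshape(3, 3, 3)
# One row index to delete per element, shape (3, 1) as in the question.
idx = np.array([[1], [0], [2]])

# Broadcast row positions (3,) against idx (3, 1): shape (3, 3), False at each row to drop.
keep = np.arange(a.shape[1]) != idx
# Boolean indexing over the first two axes flattens them to (6, 3),
# so restore the (batch, rows - 1, cols) layout afterwards.
re = a[keep].reshape(a.shape[0], a.shape[1] - 1, a.shape[2])
print(re)
```

The same "compare positions against the index, mask, then reshape" idea is what the TensorFlow answer below does with `tf.one_hot` and `tf.boolean_mask`.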


  • First approach: using tf.one_hot and tf.boolean_mask:

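    # shape = (?,1,3): one_hot marks the row to delete; 1 - ... flips it,
    # and the cast turns it into a proper bool mask (False at the deleted row)
    mask_idx = tf.cast(1 - tf.one_hot(idx, a.shape[1]), tf.bool)
    # shape = (?,3): the kept rows of all elements, flattened in order
    result = tf.boolean_mask(a, mask_idx[:, 0, :])
    # shape = (?,2,3): restore the batch axis (exactly one row was removed per element)
    result = tf.reshape(result, shape=(-1, a.shape[1] - 1, a.shape[2]))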

    Second approach: using tf.map_fn:

    # map over (matrix, index) pairs and drop the indexed row from each matrix
    result = tf.map_fn(
        lambda x: tf.boolean_mask(
            x[0], tf.cast(1 - tf.one_hot(tf.squeeze(x[1]), a.shape[1]), tf.bool)),
        [a, idx],
        dtype=tf.int32)

    An example:

    import tensorflow as tf
    a = tf.constant([[[1,2,3],[4,5,6],[7,8,9]],
    idx = tf.constant([[1],[0],[2]],dtype=tf.int32)
    # First approach:
    # shape = (?,1,3), bool mask that is False at the row to delete
    mask_idx = tf.cast(1 - tf.one_hot(idx, a.shape[1]), tf.bool)
    # shape = (?,3)
    result = tf.boolean_mask(a, mask_idx[:, 0, :])
    # shape = (?,2,3)
    result = tf.reshape(result, shape=(-1, a.shape[1] - 1, a.shape[2]))
    # Second approach (produces the same result):
    result = tf.map_fn(
        lambda x: tf.boolean_mask(
            x[0], tf.cast(1 - tf.one_hot(tf.squeeze(x[1]), a.shape[1]), tf.bool)),
        [a, idx],
        dtype=tf.int32)
    with tf.Session() as sess:
        print(sess.run(result))
    # [[[1 2 3]
    #   [7 8 9]]
    #  [[6 5 4]
    #   [3 2 1]]
    #  [[0 8 0]
    #   [1 5 4]]]
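The example above uses the TensorFlow 1.x `tf.Session` API. Under TensorFlow 2.x eager execution no session is needed. The sketch below swaps in a different technique from the answer's masks: it computes, for each element, the indices of the rows to keep and then calls `tf.gather` with `batch_dims=1`. It assumes TensorFlow 2.x, exactly one row removed per element, and hypothetical data (the question's tensors are truncated); `delete_rows` is a hypothetical helper name.

```python
import tensorflow as tf

def delete_rows(a, idx):
    """Hypothetical helper: drop row idx[i, 0] from each a[i].

    a has shape (N, R, C) and idx has shape (N, 1); the result has shape (N, R - 1, C).
    """
    idx = tf.cast(idx, tf.int32)
    n_rows = tf.shape(a)[1]
    # Output row positions 0..R-2, shape (1, R - 1), broadcast against idx (N, 1).
    out_pos = tf.range(n_rows - 1)[tf.newaxis, :]
    # Output positions at or after the deleted row read from the next source row,
    # e.g. R = 3, idx = 1 gives [0, 2]; idx = 0 gives [1, 2]; idx = 2 gives [0, 1].
    keep = out_pos + tf.cast(out_pos >= idx, tf.int32)
    # Gather the kept rows separately for each batch element.
    return tf.gather(a, keep, batch_dims=1)

a = tf.reshape(tf.range(27), (3, 3, 3))  # hypothetical data
idx = tf.constant([[1], [0], [2]], dtype=tf.int32)
re = delete_rows(a, idx)
print(re.numpy())
```

Compared with `tf.boolean_mask`, this keeps the batch axis throughout, so no flatten-and-reshape step is needed and the output shape `(N, R - 1, C)` follows directly from the gather.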