Search code examples
tensorflow, pytorch

Is there a TensorFlow equivalent of PyTorch's "index_select" function?


I am translating PyTorch code to TensorFlow, so I would like to know whether TensorFlow has an equivalent of PyTorch's index_select function.


Solution

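  • TensorFlow's tf.gather(input_, indices, axis=dim) performs the same selection directly, in a single op. The selection can also be implemented with tf.slice, by slicing out each selected index and concatenating the pieces along the same axis: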

    
    def tf_index_select(input_, dim, indices):
        """Select entries of ``input_`` along axis ``dim``, like ``torch.index_select``.

        Each selected index is cut out with ``tf.slice`` as a size-1 slab, and the
        slabs are concatenated back along ``dim``. The order (and any repeats) of
        ``indices`` is therefore preserved. ``tf.gather(input_, indices, axis=dim)``
        produces the same result in a single op.

        Args:
            input_ (tf.Tensor): input tensor; its rank must be statically known.
            dim (int): axis to select along; negative values count from the end.
            indices (list[int]): indices to select along ``dim``.

        Returns:
            tf.Tensor: same rank as ``input_``, with size ``len(indices)`` along
            ``dim`` and every other dimension unchanged.

        Raises:
            IndexError: if ``dim`` is outside ``[-rank, rank)``.
        """
        rank = len(input_.get_shape().as_list())
        if not -rank <= dim < rank:
            raise IndexError(f"dim {dim} is out of range for a tensor of rank {rank}")
        # Normalize negative axes so the begin/size lists and tf.concat agree on dim.
        if dim < 0:
            dim += rank

        # In tf.slice a size of -1 means "everything to the end of the axis", so
        # unknown (None) static dimensions are handled. Only the selected axis is
        # pinned to a single element.
        size = [-1] * rank
        size[dim] = 1

        slabs = []
        for idx in indices:
            begin = [0] * rank
            begin[dim] = idx
            slabs.append(tf.slice(input_, begin, size))

        if not slabs:
            # tf.concat rejects an empty list. Mirror torch.index_select by
            # returning a tensor with zero extent along dim.
            size[dim] = 0
            return tf.slice(input_, [0] * rank, size)
        return tf.concat(slabs, axis=dim)
    

    Here is an example showing that both produce the same result.

    
    import tensorflow as tf
    import torch
    import numpy as np

    # A 2 x 3 x 4 input; select indices 0 and 2 along dim 1 with both frameworks.
    a = np.arange(2 * 3 * 4).reshape(2, 3, 4)
    dim = 1
    indices = [0, 2]
    # array([[[ 0,  1,  2,  3],
    #         [ 4,  5,  6,  7],
    #         [ 8,  9, 10, 11]],

    #        [[12, 13, 14, 15],
    #         [16, 17, 18, 19],
    #         [20, 21, 22, 23]]])

    # pytorch
    res_torch = torch.tensor(a).index_select(dim, torch.tensor(indices))
    # tensor([[[ 0,  1,  2,  3],
    #          [ 8,  9, 10, 11]],

    #         [[12, 13, 14, 15],
    #          [20, 21, 22, 23]]])

    # tensorflow
    res = tf_index_select(tf.constant(a), dim, indices)
    # res.numpy():
    # array([[[ 0,  1,  2,  3],
    #         [ 8,  9, 10, 11]],

    #        [[12, 13, 14, 15],
    #         [20, 21, 22, 23]]])

    # Verify the equivalence instead of relying on the printed output above.
    np.testing.assert_array_equal(res_torch.numpy(), res.numpy())