How do I get the indices of unique values in a TensorFlow tensor?


Suppose I have a 1D input tensor and I want the index of the first occurrence of each unique element in it.

input 1D tensor

[ 1  3  0  0  0  3  5  6  8  9 12  2  5  7  0 11  6  7  0  0]

expected output

Values:  [1, 3, 0, 5, 6, 8, 9, 12,  2,  7, 11]
indices: [0, 1, 2, 6, 7, 8, 9, 10, 11, 13, 15]

Here is my current approach.

import tensorflow as tf

input = [ 1,  3,  0,  0,  0,  3,  5,  6,  8,  9, 12,  2,  5,  7,  0, 11,  6,  7,  0,  0,]
unique_value_in_input, _ = tf.unique(input) # [1 3 0 5 6 8 9 12 2 7 11]
number_of_unique_value = tf.shape(unique_value_in_input)[0] # 11
# column vector of the unique values, so tf.equal() broadcasts it against each row
y = tf.reshape(unique_value_in_input, (number_of_unique_value, 1)) # [[1], [3], [0], [5], [6], [8], [9], ..]

input_matrix = tf.tile(input, [number_of_unique_value]) # repeat the tensor for tf.equal()
input_matrix = tf.reshape(input_matrix, [number_of_unique_value, -1]) # one copy of the input per unique value

# tf.where() returns [row, col] pairs: [[0 0] [1 1] [1 5] [2 2] [2 3] ...]; [:, -1] keeps the column
cols = tf.where(tf.equal(input_matrix, y))[:, -1]

Because a value can occur more than once, the tf.where() step returns every matching position, so the result contains several True entries for the same unique value. Is there a function that keeps only the first index for each value?


Solution

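  • The following should give the desired output. For each unique value, compare it with the input using tf.equal to get a boolean tensor, cast that tensor to integers, and take tf.argmax. Every match has the same value (1), and tf.argmax returns a single index for the maximum. Current TensorFlow documentation describes it as the smallest index when there are ties, which is the first occurrence of the value.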

    import tensorflow as tf
    
    input = tf.constant([ 1,  3,  0,  0,  0,  3,  5,  6,  8,  9, 12,  2,  5,  7,  0, 11,  6,  7,  0,  0,], tf.int64)
    
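    # Unique values in order of first appearance
    unique_vals, _ = tf.unique(input)
    # For each unique value, argmax over the 0/1 match mask picks one matching index
    res = tf.map_fn(
        lambda x: tf.argmax(tf.cast(tf.equal(input, x), tf.int64)),
        unique_vals)

    # tf.Session is the TensorFlow 1.x graph-mode API
    with tf.Session() as sess:
      print(sess.run(res))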
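To cross-check the result without running TensorFlow, the same first-occurrence rule can be sketched in plain Python. This is only a reference sketch and not part of the answer above: the helper name `first_indices_of_unique` is made up for illustration, and it assumes the input fits in an ordinary Python list.

```python
def first_indices_of_unique(values):
    """Return (unique_values, first_indices), both in order of first appearance."""
    first_seen = {}  # maps each value to the index where it first occurs
    for i, v in enumerate(values):
        if v not in first_seen:
            first_seen[v] = i
    # dicts keep insertion order (Python 3.7+), the same ordering tf.unique uses
    return list(first_seen.keys()), list(first_seen.values())


data = [1, 3, 0, 0, 0, 3, 5, 6, 8, 9, 12, 2, 5, 7, 0, 11, 6, 7, 0, 0]
values, indices = first_indices_of_unique(data)
print(values)   # [1, 3, 0, 5, 6, 8, 9, 12, 2, 7, 11]
print(indices)  # [0, 1, 2, 6, 7, 8, 9, 10, 11, 13, 15]
```

The printed lists match the expected output in the question, so they can serve as a ground truth when testing the TensorFlow version.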