Tags: python, tensorflow, slice, indices

How to get the indices of tensor values based on a column condition in TensorFlow


I have a tensor like this:

sim_topics = [[0.65 0.   0.   0.   0.42  0.   0.   0.51 0.   0.34 0.]
              [0.   0.51 0.   0.   0.52  0.   0.   0.   0.53 0.42 0.]
              [0.   0.32 0.   0.50 0.34  0.   0.   0.39 0.32 0.52 0.]
              [0.   0.23 0.37 0.   0.    0.37 0.37 0.   0.47 0.39 0.3 ]]

I want to get indices in this tensor based on a tensor condition:

masked_t = [True  False  True  False True True False True False True False]

So the output should be like this:

[[0.65 0. 0.   0.   0.42  0.   0.   0.51 0.   0.34 0.]
 [0.   0. 0.   0.   0.52  0.   0.   0.   0.   0.42 0.]
 [0.   0. 0.   0.   0.34  0.   0.   0.39 0.   0.52 0.]
 [0.   0. 0.37 0.   0.    0.37 0.   0.   0.   0.39 0.]]

So the condition applies to the columns of the initial tensor. What I actually need are the indices of the elements in the columns that are True in masked_t.

So the indices should be:

[[0, 0],
 [1, 0],
 [2, 0],
 [3, 0],
 [0, 2],
 [1, 2],
 [2, 2],
 [3, 2],
 ...]
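A minimal sketch of one way to build exactly these [row, column] pairs, assuming TensorFlow 2.x with eager execution: broadcast masked_t to the full shape of sim_topics, then call tf.where with a single argument, which returns the coordinates of every True entry. The names mask_2d, row_major and col_major are illustrative, not from the question.

```python
import tensorflow as tf

sim_topics = tf.constant(
    [[0.65, 0., 0., 0., 0.42, 0., 0., 0.51, 0., 0.34, 0.],
     [0., 0.51, 0., 0., 0.52, 0., 0., 0., 0.53, 0.42, 0.],
     [0., 0.32, 0., 0.50, 0.34, 0., 0., 0.39, 0.32, 0.52, 0.],
     [0., 0.23, 0.37, 0., 0., 0.37, 0.37, 0., 0.47, 0.39, 0.3]],
    dtype=tf.float64)
masked_t = tf.constant(
    [True, False, True, False, True, True, False, True, False, True, False])

# Stretch the [11] column mask to the [4, 11] shape of sim_topics.
mask_2d = tf.broadcast_to(masked_t, tf.shape(sim_topics))

# Single-argument tf.where lists True coordinates in row-major order:
# [[0, 0], [0, 2], [0, 4], ...].
row_major = tf.where(mask_2d)

# To list them column by column, as in the question, find the True entries
# of the transposed mask ([col, row] pairs) and swap the two columns.
col_major = tf.reverse(tf.where(tf.transpose(mask_2d)), axis=[1])
```

Both results contain all 24 positions (6 True columns times 4 rows), regardless of whether the value there is zero; only the ordering differs.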

This approach works when I apply the condition row-wise, but here I want to select specific columns based on a condition, so it raises an incompatibility error:

out = tf.cast(tf.zeros(shape=tf.shape(sim_topics), dtype=tf.float64), tf.float64)
indices = tf.where(tf.where(masked_t, out, sim_topics))
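The error most likely comes from TF 1.x-style tf.where (still available as tf.compat.v1.where), which accepts a rank-1 condition only when its length matches the first dimension of x and y, so it selects rows; that would explain why the row-wise version works. Note also that tf.where(masked_t, out, sim_topics) puts zeros where the mask is True, the opposite of the desired output. A sketch, assuming TensorFlow 2.x, where tf.where broadcasts condition, x and y against each other; the names masked, full_mask and masked_v1 are illustrative.

```python
import tensorflow as tf

sim_topics = tf.constant(
    [[0.65, 0., 0., 0., 0.42, 0., 0., 0.51, 0., 0.34, 0.],
     [0., 0.51, 0., 0., 0.52, 0., 0., 0., 0.53, 0.42, 0.],
     [0., 0.32, 0., 0.50, 0.34, 0., 0., 0.39, 0.32, 0.52, 0.],
     [0., 0.23, 0.37, 0., 0., 0.37, 0.37, 0., 0.47, 0.39, 0.3]],
    dtype=tf.float64)
masked_t = tf.constant(
    [True, False, True, False, True, True, False, True, False, True, False])

# TF 2.x: the [11] condition broadcasts against the [4, 11] tensors, so the
# column mask applies to every row. Keep sim_topics where True, else zero.
masked = tf.where(masked_t, sim_topics, tf.zeros_like(sim_topics))

# Broadcasting the condition to the full shape first also satisfies the
# stricter TF 1.x semantics, which need a same-shape condition here.
full_mask = tf.broadcast_to(masked_t, tf.shape(sim_topics))
masked_v1 = tf.compat.v1.where(full_mask, sim_topics, tf.zeros_like(sim_topics))
```

Calling single-argument tf.where on masked would return only the nonzero positions, so to get every position in the True columns, apply tf.where to full_mask instead.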

Solution

  • You can directly obtain your required tensor like this:

    result = tf.multiply(sim_topics, tf.cast(masked_t, dtype=tf.float64))
    

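    Broadcasting does the work: masked_t (shape [11]) is stretched to the shape of sim_topics ([4, 11]), so every row is multiplied element-wise by the same 0/1 mask. The cast dtype must match the dtype of sim_topics (tf.float64 here), since tf.multiply does not mix dtypes.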
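A short usage sketch of the multiply approach, assuming TensorFlow 2.x in eager mode. Casting to sim_topics.dtype instead of hard-coding tf.float64 is a small variation, not part of the original answer. The tf.boolean_mask line is an alternative for when the False columns should be dropped rather than zeroed; result and selected are illustrative names.

```python
import tensorflow as tf

sim_topics = tf.constant(
    [[0.65, 0., 0., 0., 0.42, 0., 0., 0.51, 0., 0.34, 0.],
     [0., 0.51, 0., 0., 0.52, 0., 0., 0., 0.53, 0.42, 0.],
     [0., 0.32, 0., 0.50, 0.34, 0., 0., 0.39, 0.32, 0.52, 0.],
     [0., 0.23, 0.37, 0., 0., 0.37, 0.37, 0., 0.47, 0.39, 0.3]],
    dtype=tf.float64)
masked_t = tf.constant(
    [True, False, True, False, True, True, False, True, False, True, False])

# Zero out the False columns: the [11] mask (as 0.0/1.0) broadcasts over rows.
result = tf.multiply(sim_topics, tf.cast(masked_t, sim_topics.dtype))

# Drop the False columns instead of zeroing them: shape [4, 6].
selected = tf.boolean_mask(sim_topics, masked_t, axis=1)
```

result reproduces the expected output from the question; selected keeps only columns 0, 2, 4, 5, 7 and 9.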