TensorFlow: assign elements with tf.where


I have two arrays:

pob.shape = (2,49,20)  
rob = np.zeros((2,49,20))  

and I want to get the indices of the elements of pob whose value is != 0. In NumPy I can do this:

x,y,z = np.where(pob!=0)

e.g.:

x = [2,4,7]  
y = [3,5,5]  
z = [3,5,6]

Then I want to set the corresponding rows of rob:

rob[x,y,:] = np.ones((20))

How can I do this with TensorFlow tensors? I tried to use tf.where, but I can't get the index values out of the tensor object.


Solution

  • You could use tf.range() and tf.meshgrid() to create index matrices, then apply tf.where() with your condition to them to obtain the indices that satisfy it. The tricky part comes next: TF tensors do not support NumPy-style item assignment, so you can't simply write my_tensor[my_indices] = my_values.

    A workaround for your problem ("for all (i,j,k), if pob[i,j,k] != 0 then rob[i,j] = 1") could be as follows:

    import tensorflow as tf
    
    # Example values for demonstration:
    pob_val = [[[0, 0, 0], [1, 0, 0], [1, 0, 1]], [[1, 1, 1], [0, 0, 0], [0, 0, 0]]]
    pob = tf.constant(pob_val)
    pob_shape = tf.shape(pob)
    rob = tf.zeros(pob_shape)
    
    # Get the mask:
    mask = tf.cast(tf.not_equal(pob, 0), tf.uint8)
    
    # If pob[i, j, :] has at least one nonzero entry, mark all of mask[i, j, :] as True:
    mask = tf.cast(tf.reduce_max(mask, axis=-1, keepdims=True), tf.bool)
    mask = tf.tile(mask, [1, 1, pob_shape[-1]])
    
    # Apply mask:
    rob = tf.where(mask, tf.ones(pob_shape), rob)
    
    with tf.Session() as sess:
        rob_eval = sess.run(rob)
        print(rob_eval)
        # [[[0. 0. 0.]
        #   [1. 1. 1.]
        #   [1. 1. 1.]]
        #
        #  [[1. 1. 1.]
        #   [0. 0. 0.]
        #   [0. 0. 0.]]]
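
The snippet above is written against the TensorFlow 1.x API (tf.Session). As a rough sketch of the same idea, assuming TensorFlow 2.x with eager execution, the version below shows two variants: a broadcast mask, and an index-based update built from tf.where() and tf.tensor_scatter_nd_update(). The variable names (row_has_nonzero, rob_masked, ij, rob_scattered) are illustrative only and do not come from the original answer.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

# Same demonstration values as in the answer above.
pob = tf.constant([[[0, 0, 0], [1, 0, 0], [1, 0, 1]],
                   [[1, 1, 1], [0, 0, 0], [0, 0, 0]]])
rob = tf.zeros(tf.shape(pob), dtype=tf.float32)

# True where pob[i, j, :] contains at least one nonzero entry; shape (2, 3, 1).
row_has_nonzero = tf.reduce_any(tf.not_equal(pob, 0), axis=-1, keepdims=True)

# Variant 1: mask. In TF 2.x, tf.where broadcasts the (2, 3, 1) condition
# against the (2, 3, 3) operands, so no explicit tf.tile is needed.
rob_masked = tf.where(row_has_nonzero, tf.ones_like(rob), rob)

# Variant 2: indices. With a single argument, tf.where returns the
# coordinates of the True entries, here the (i, j) pairs; shape (n, 2).
ij = tf.where(row_has_nonzero[..., 0])

# One row of ones per (i, j) pair; shape (n, last_dim).
updates = tf.ones_like(tf.gather_nd(rob, ij))

# Returns a new tensor with rob[i, j, :] replaced by ones; rob is unchanged.
rob_scattered = tf.tensor_scatter_nd_update(rob, ij, updates)

print(ij.numpy())
print(rob_scattered.numpy())
```

Both variants should give the same result as the Session-based code. The mask variant is usually simpler; the index variant is closer to the NumPy pattern x, y, z = np.where(...) followed by an indexed assignment.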