
Assigning values to a 2D tensor using indices in Tensorflow


I have a 2D tensor A, and I wish to replace its non-zero entries, in row-major order, with the values of another tensor B:

A = tf.constant([[1.0,0,1.0],[0,1.0,0],[1.0,0,1.0]],dtype=tf.float32)
B = tf.constant([1.0,2.0,3.0,4.0,5.0],dtype=tf.float32)

So I would like the final A to be:

A = tf.constant([[1.0,0.0,2.0],[0,3.0,0.0],[4.0,0.0,5.0]],dtype=tf.float32)

I get the indices of the non-zero elements of A as follows:

where_nonzero = tf.not_equal(A, tf.constant(0, dtype=tf.float32))
indices = tf.where(where_nonzero)

indices = <tf.Tensor: shape=(5, 2), dtype=int64, numpy=
array([[0, 0],
       [0, 2],
       [1, 1],
       [2, 0],
       [2, 2]])>
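The order of the rows in indices matters here: tf.where returns the coordinates in row-major order, so B[k] ends up at position indices[k], and B needs exactly one value per non-zero entry. The check below is an illustrative sketch, not part of the original question; it uses tf.debugging.assert_equal so a wrong-length B (for example a stray comma turning 4.0 into 4, 0) fails with a readable message, and it prints where each value of B will land.

```python
import tensorflow as tf

A = tf.constant([[1.0, 0, 1.0], [0, 1.0, 0], [1.0, 0, 1.0]], dtype=tf.float32)
B = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0], dtype=tf.float32)

# Coordinates of the non-zero entries, one row per entry, in row-major order.
indices = tf.where(tf.not_equal(A, 0.0))

# B must supply exactly one value per non-zero entry of A.
num_nonzero = tf.shape(indices)[0]
tf.debugging.assert_equal(
    num_nonzero,
    tf.size(B),
    message="B must have exactly one value per non-zero entry of A",
)

# Show which position of A each value of B will be written to.
for k in range(int(num_nonzero)):
    row, col = indices[k].numpy()
    print(f"B[{k}] = {B[k].numpy()} -> A[{row}, {col}]")
```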

Can someone please help with this?


Solution

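  • If I understand correctly, you should be able to use tf.tensor_scatter_nd_update, which returns a copy of A with each B[k] written at position indices[k]. B must contain exactly one value per non-zero entry of A (five here):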

    import tensorflow as tf
    
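    A = tf.constant([[1.0,0,1.0],[0,1.0,0],[1.0,0,1.0]],dtype=tf.float32)
    # One replacement value per non-zero entry of A (five in total).
    B = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0],dtype=tf.float32)
    
    # Coordinates of the non-zero entries of A, in row-major order.
    where_nonzero = tf.not_equal(A, tf.constant(0, dtype=tf.float32))
    indices = tf.where(where_nonzero)
    # Returns a new tensor with B[k] written at indices[k];
    # tf.constant tensors are immutable, so A is rebound rather than modified.
    A = tf.tensor_scatter_nd_update(A, indices, B)
    print(A)
    
    # Output: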
    tf.Tensor(
    [[1. 0. 2.]
     [0. 3. 0.]
     [4. 0. 5.]], shape=(3, 3), dtype=float32)
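Two variations on the same idea, sketched under the assumption that the zero entries of A should stay exactly zero. Because every position not listed in indices is already zero, tf.scatter_nd can build the same result from scratch into a zero tensor of A's shape. And if A is a tf.Variable rather than a constant, its scatter_nd_update method writes the values in place instead of returning a new tensor.

```python
import tensorflow as tf

A = tf.constant([[1.0, 0, 1.0], [0, 1.0, 0], [1.0, 0, 1.0]], dtype=tf.float32)
B = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0], dtype=tf.float32)
indices = tf.where(tf.not_equal(A, 0.0))  # int64 coordinates, row-major order

# Variation 1: scatter B into a fresh zero tensor with A's shape.
# This matches the answer only because every position outside `indices` is 0.
# tf.where returns int64 indices, so the shape is requested as int64 too.
from_zeros = tf.scatter_nd(indices, B, tf.shape(A, out_type=tf.int64))

# Variation 2: update a tf.Variable in place instead of creating a new tensor.
A_var = tf.Variable(A)
A_var.scatter_nd_update(indices, B)

print(from_zeros.numpy())
print(A_var.numpy())
```

The tf.tensor_scatter_nd_update call in the answer remains the more general choice, since it keeps whatever A holds outside indices; the tf.scatter_nd variation only agrees with it here because those entries are zero.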