
How to insert certain values at certain indices of a tensor in tensorflow?


Let's say I have a tensor input of shape 100x1, another tensor inplace of shape 20x1, and an index_tensor of shape 100x1. The index_tensor marks the positions in input where I want to insert the values from inplace: it has exactly 20 True values, and the rest are False. How can this operation be achieved in TensorFlow?

[Image: illustration of the desired operation]
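For reference, the intended semantics can be written in plain NumPy with boolean-mask assignment. This is a sketch: the small 6x1 arrays and their values are made up and stand in for the 100x1 input, the 20x1 inplace and the 100x1 index_tensor. NumPy arrays are mutable, so this is exactly the in-place write that TensorFlow tensors do not allow:

```python
import numpy as np

# Hypothetical small stand-ins for input (100x1), inplace (20x1) and index_tensor (100x1)
input_np = np.array([[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]])
inplace_np = np.array([[-1.0], [-2.0]])
index_np = np.array([[False], [True], [False], [False], [True], [False]])

# Boolean-mask assignment fills the True positions, in row order, with the inplace values
expected = input_np.copy()
expected[index_np] = inplace_np.ravel()
# expected now holds [[1], [-1], [3], [4], [-2], [6]]; input_np is unchanged
```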

The assign operation works only on a tf.Variable, but I want to apply it to the output of tf.nn.rnn.

I read that one can use tf.scatter_nd, but it requires inplace and index_tensor to have the same shape.
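As far as the tf.scatter_nd signature goes, the updates do not have to match the mask's shape: indices has shape [N, 1] for a 1-D output (one row per position to fill), updates has shape [N], and shape gives the full output shape. A minimal sketch with made-up values, using tf.where to turn a boolean mask into such indices (TF 2.x eager mode assumed):

```python
import tensorflow as tf

mask = tf.constant([False, True, False, False, True, False])  # 6 positions, 2 True
updates = tf.constant([10.0, 20.0])                          # one value per True position

# tf.where on a boolean tensor returns the int64 coordinates of the True entries: shape [2, 1]
indices = tf.where(mask)

# shape must use the same integer type as indices, hence out_type=tf.int64
scattered = tf.scatter_nd(indices, updates, shape=tf.shape(mask, out_type=tf.int64))
# scattered -> [0., 10., 0., 0., 20., 0.]
```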

The reason I want this: I get an output from the RNN, extract some values from it, and feed them to a dense layer. I then want to insert the dense layer's output back into the original tensor obtained from the RNN. For certain reasons I do not want to apply the dense layer to the whole RNN output, and if I do not insert its result back into the RNN output, the dense layer is effectively useless.
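A minimal sketch of that pipeline, assuming TF 2.x eager mode: a random tensor stands in for the RNN output, tf.boolean_mask pulls out the selected entries, a Dense layer transforms them, and tf.tensor_scatter_nd_update writes them back into a new tensor. The shapes, the every-fifth-row mask, the Dense(1) layer and the sum loss are hypothetical choices for illustration; the point is that no in-place assignment is needed and gradients still reach the Dense weights.

```python
import tensorflow as tf

tf.random.set_seed(0)

# Hypothetical stand-ins: rnn_out plays the 100x1 RNN output,
# index_tensor marks 20 positions (every fifth row) to be replaced.
rnn_out = tf.random.normal([100, 1])
index_tensor = tf.reshape(tf.equal(tf.range(100) % 5, 0), [100, 1])
dense = tf.keras.layers.Dense(1)

with tf.GradientTape() as tape:
    selected = tf.boolean_mask(rnn_out, index_tensor)        # shape [20]
    transformed = dense(tf.reshape(selected, [-1, 1]))        # shape [20, 1]
    # (row, column) coordinates of the True entries, shape [20, 2]
    indices = tf.where(index_tensor)
    # Returns a new [100, 1] tensor; rnn_out itself is not modified
    merged = tf.tensor_scatter_nd_update(
        rnn_out, indices, tf.reshape(transformed, [-1]))
    loss = tf.reduce_sum(merged)

# Gradients flow through the scatter back to the Dense kernel and bias
grads = tape.gradient(loss, dense.trainable_variables)
```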

Any suggestion will be really appreciated.


Solution

  • Because tensors are immutable, you can't assign a new value to one or change it in place. Instead, build a new tensor from the original using standard operations:

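    import numpy as np
    import tensorflow as tf

    # Explicit int32 dtypes keep every operation below on the same integer type
    input_array = np.array([2, 4, 7, 11, 3, 8, 9, 19, 11, 7], dtype=np.int32)
    inplace_array = np.array([10, 20], dtype=np.int32)
    indices_array = np.array([0, 0, 1, 0, 0, 0, 1, 0, 0, 0], dtype=np.int32)
    # Positions of the 1s in the mask: [[2], [6]]
    indices = tf.cast(tf.where(tf.equal(indices_array, 1)), tf.int32)
    # New values scattered into zeros: [0, 0, 10, 0, 0, 0, 20, 0, 0, 0]
    scatter = tf.scatter_nd(indices, inplace_array, shape=tf.shape(input_array))
    # Inverted mask (logical_not needs a bool input): [1, 1, 0, 1, 1, 1, 0, 1, 1, 1]
    inverse_mask = tf.cast(tf.math.logical_not(tf.cast(indices_array, tf.bool)), tf.int32)
    # Original values with the target positions zeroed: [2, 4, 0, 11, 3, 8, 0, 19, 11, 7]
    input_array_zero_out = tf.multiply(inverse_mask, input_array)
    # Combined result: [2, 4, 10, 11, 3, 8, 20, 19, 11, 7]
    output = tf.add(input_array_zero_out, scatter)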
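The multiply-and-add step in the answer can also be written as a single select with tf.where(condition, x, y), which avoids building an integer mask. A sketch with made-up data in the question's original shapes (100x1 input and mask, 20x1 inplace), assuming TF 2.x; for a 2-D boolean mask, tf.where returns [row, column] pairs, which is the index format tf.scatter_nd expects for a 2-D output:

```python
import tensorflow as tf

# Made-up data with the shapes from the question
input_t = tf.reshape(tf.range(100, dtype=tf.float32), [100, 1])
index_tensor = tf.reshape(tf.equal(tf.range(100) % 5, 0), [100, 1])  # 20 True values
inplace = -tf.ones([20, 1])

# [20, 2] int64 coordinates (row, column) of the True entries
indices = tf.where(index_tensor)
# Place the 20 new values into a zero tensor with the input's shape
scattered = tf.scatter_nd(indices, tf.reshape(inplace, [-1]),
                          shape=tf.shape(input_t, out_type=tf.int64))
# Take the scattered value where the mask is True, the original value elsewhere
output = tf.where(index_tensor, scattered, input_t)
```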