I've been trying to do something pretty simple, but with no success. I have a tensor X of shape (None, 128) containing some scores; in other words, each sample in the batch has 128 scores. Now I apply Y = tf.math.top_k(X, k=a).indices, where a is the number of top scores to keep. For simplicity, let a = 95. Then Y has shape (None, 95). Up to this point everything works.
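A minimal sketch of this first step, assuming a hypothetical batch of 2 random score rows in place of the real X. It only illustrates that tf.math.top_k returns distinct int32 column positions per row, giving shape (batch, a):

```python
import tensorflow as tf

# Hypothetical stand-in for the scores tensor X: batch of 2, 128 scores per sample.
X = tf.random.uniform((2, 128))
a = 95

top = tf.math.top_k(X, k=a)
Y = top.indices  # int32 positions of the a largest scores in each row, shape (2, 95)
```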
Now my original data tensor has shape (None, 3969, 128). I want to apply some operation to the data at the top-k score positions, so I extracted it using:
ti = tf.reshape(Y, [Y.shape[-1], -1])  # ti has shape (95, None)
fs = tf.gather(data, ti[:, 0], axis=-1)  # fs has shape (None, 3969, 95)
and then applied my operation, say:
Z = fs * 0.7  # Z has shape (None, 3969, 95)
This also worked.
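One caveat about the extraction step: tf.reshape(Y, [Y.shape[-1], -1]) rearranges the elements of Y rather than transposing it, so ti[:, 0] equals Y[0] only when the batch size is 1. With a larger batch it mixes indices from several samples, and in either case every sample is gathered with the same indices. A per-sample gather can be written with the batch_dims argument of tf.gather. This sketch uses small hypothetical shapes (B, D, N, k) in place of (None, 3969, 128) and k = 95:

```python
import tensorflow as tf

# Hypothetical small shapes standing in for (None, 3969, 128) with k = 95.
B, D, N, k = 2, 5, 8, 3
X = tf.random.uniform((B, N))                                        # scores
data = tf.reshape(tf.range(B * D * N, dtype=tf.float32), (B, D, N))  # values

Y = tf.math.top_k(X, k=k).indices  # (B, k)

# batch_dims=1 pairs data[b] with Y[b], so each sample uses its own top-k columns.
fs = tf.gather(data, Y, axis=2, batch_dims=1)  # (B, D, k)
Z = fs * 0.7
```

Each fs[b, :, j] is the column data[b, :, Y[b, j]], so the layout of Z stays aligned with Y sample by sample.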
Now I want to create a new tensor F of shape (None, 3969, 128) that contains both the unchanged data (the columns whose scores are not in the top k) and the modified data from Z (the columns whose scores are in the top k), with every column kept in its original position, i.e., the modified data should end up exactly where it was in the original tensor. This is where I'm stuck.
I'm relatively new to TensorFlow, so apologies if I'm missing something simple or being unclear. I've been stuck on this for a few days now.
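The target F can be pinned down as a slow, loop-based NumPy reference. It is only meant for checking a vectorized TensorFlow version on small inputs; reference_F, the toy arrays, and the lambda standing in for the operation are hypothetical:

```python
import numpy as np

def reference_F(data, Y, op):
    """Loop-based statement of the desired result (hypothetical helper, for checking only).

    data: array of shape (B, D, N); Y: int array of shape (B, k) holding the
    top-k column indices of each sample; op: elementwise operation on selected columns.
    """
    F = data.copy()
    for b in range(data.shape[0]):
        cols = Y[b]
        # Selected columns get the modified values in their original positions;
        # every other column is copied unchanged.
        F[b][:, cols] = op(data[b][:, cols])
    return F

# Toy example: batch 2, D = 3, N = 4, k = 2.
data = np.arange(2 * 3 * 4, dtype=np.float32).reshape(2, 3, 4)
Y = np.array([[1, 3], [0, 2]])
F = reference_F(data, Y, lambda v: v * 0.7)
```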
One way to do this is to use tf.tensor_scatter_nd_update. You'll have to convert the indices returned by top_k into something that works with your data tensor, though. For that, you can use a combination of tf.tile and tf.unravel_index to convert from the (B, N) shape of X to the (B, D, N) shape of your data.
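To make the two building blocks concrete, here is a tiny sketch on a hypothetical (2, 3, 4) tensor: tf.unravel_index turns flat positions into per-dimension coordinates (13 = 1*12 + 0*4 + 1 gives (1, 0, 1)), and tf.tensor_scatter_nd_update writes values at those full coordinates while copying every other element unchanged:

```python
import tensorflow as tf

# Flat positions 13 and 5 in a tensor of shape (2, 3, 4).
coords = tf.unravel_index([13, 5], dims=[2, 3, 4])  # shape (3, 2): one row per dimension
indices = tf.transpose(coords)                      # shape (2, 3): one (b, d, n) row per element

# Write 1.0 at (1, 0, 1) and 2.0 at (0, 1, 1); all other entries keep the input values.
t = tf.zeros((2, 3, 4))
out = tf.tensor_scatter_nd_update(t, indices, [1.0, 2.0])
```

The transpose is needed because tf.unravel_index returns one row per dimension, while tensor_scatter_nd_update expects one index row per updated element.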
If we assume 3D data like you have, you could use something similar to this:
import tensorflow as tf

# assuming that X of shape (B, N) holds the scores, data of shape (B, D, N) holds the values,
# and k is the number of top scores to keep (e.g. k = 95)
# getting the dimensions of the data tensor
B, D, N = tf.unstack(tf.shape(data))
topk = tf.math.top_k(X, k=k)
# to get the absolute indices in the original tensor, we need to:
# - tile the indices from (B, k) to (B, D, k)
# - do index arithmetic to get indices into the flattened data tensor
topk_idx_tiled = tf.tile(topk.indices[:, None, :], [1, D, 1])
row_offsets = tf.reshape(tf.range(B * D) * N, [B, D])  # flat offset of each (b, d) row
flattened_indices = tf.reshape(row_offsets[..., None] + topk_idx_tiled, [-1])
# unraveling gives (b, d, n) coordinates, the index format tensor_scatter_nd_update expects
sc_idx = tf.transpose(tf.unravel_index(flattened_indices, tf.shape(data)))
# the updates must be the selected *data* values with the operation applied (here * 0.7),
# not the top-k scores from X; gather_nd keeps them in the same order as sc_idx
updates = tf.gather_nd(data, sc_idx) * 0.7
# scattering the updates into a copy of the original data
F = tf.tensor_scatter_nd_update(data, sc_idx, updates)
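If Z has already been computed per sample with shape (B, D, k), it can also be placed back without building explicit coordinates, using a one-hot projection. This is a different technique from the scatter above, offered as an alternative sketch: merge_topk is a hypothetical helper, it assumes Y holds distinct indices per row (which tf.math.top_k guarantees), and it materializes an extra (B, k, N) tensor:

```python
import tensorflow as tf

def merge_topk(data, Y, Z):
    """Place Z (B, D, k) back into data (B, D, N) at the columns listed in Y (B, k).

    Hypothetical helper; assumes the indices in each row of Y are distinct.
    """
    N = tf.shape(data)[-1]
    onehot = tf.one_hot(Y, depth=N, dtype=data.dtype)  # (B, k, N)
    mask = tf.reduce_sum(onehot, axis=1)               # (B, N), 1.0 at top-k columns
    # einsum moves Z[b, d, j] to column Y[b, j]; all other columns receive 0.
    scattered = tf.einsum('bdk,bkn->bdn', Z, onehot)
    # Keep data where mask is 0 and take the modified values where mask is 1.
    return data * (1.0 - mask[:, None, :]) + scattered

# Toy usage with hypothetical small shapes.
B, D, N, k = 2, 4, 6, 3
X = tf.random.uniform((B, N))
data = tf.random.uniform((B, D, N))
Y = tf.math.top_k(X, k=k).indices
Z = tf.gather(data, Y, axis=2, batch_dims=1) * 0.7
F = merge_topk(data, Y, Z)
```

The scatter version touches only the B*D*k selected elements, while this one does dense work proportional to B*D*k*N, so it is mainly attractive when N is small or when avoiding index bookkeeping matters more than speed.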