tensorflow, tile

Tensorflow: How to tile elements of a Tensor with a "multiples" tensor?


Suppose I have an input tensor as follows:

import tensorflow as tf
a = tf.constant([[0, 0], [1, 1], [2, 2]])

and a "multiples" tensor:

mul = tf.constant([1, 3, 2])

I want the result tensor to look like this:

res = [[0, 0],
       [1, 1],
       [1, 1],
       [1, 1],
       [2, 2],
       [2, 2]]

The number of rows n in tensor a equals the number of elements in tensor mul. If n is fixed (here n = 3), I can use the following code:

# Tile row i of a mul[i] times, then concatenate the pieces along axis 0
res = tf.tile([a[0]], [mul[0], 1])

for i in range(1, 3):
    res = tf.concat((res, tf.tile([a[i]], [mul[i], 1])), 0)

But if I don't know n (it varies), how can I get the result?

I'd really appreciate any ideas!


Solution

  • Your easiest bet might be to use a py_func to turn [1, 3, 2] into something like [0, 1, 1, 1, 2, 2], which can be used as indices for tf.gather.

    If the maximum number of rows in the a tensor is known statically, I think you can use tf.dynamic_partition to get this, but it might lead to a large graph.

    Alternatively, a tf.while_loop coupled with a tf.TensorArray to store the outputs could work.
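A minimal sketch of the py_func idea, assuming TensorFlow 2.x with eager execution, where tf.py_function is the successor to tf.py_func (in TF 1.x graph code, tf.compat.v1.py_func plays the same role). The helper name _repeat_indices is hypothetical; NumPy's np.repeat builds the index vector and tf.gather picks the matching rows of a.

```python
import numpy as np
import tensorflow as tf

a = tf.constant([[0, 0], [1, 1], [2, 2]])
mul = tf.constant([1, 3, 2])


def _repeat_indices(m):
    # m arrives as an eager tensor; np.repeat turns [1, 3, 2] into [0, 1, 1, 1, 2, 2]
    m = m.numpy()
    return np.repeat(np.arange(len(m)), m).astype(np.int32)


# Build the row indices in Python, then gather those rows of a
idx = tf.py_function(_repeat_indices, [mul], tf.int32)
res = tf.gather(a, idx)
print(res.numpy().tolist())  # [[0, 0], [1, 1], [1, 1], [1, 1], [2, 2], [2, 2]]
```

The Python function runs outside the graph, so this is simple but not differentiable with respect to mul and will not be serialized into a SavedModel graph.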
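The tf.while_loop plus tf.TensorArray route could look roughly like this. It is a sketch assuming TensorFlow 2.x, and tile_rows is a hypothetical name. Because each iteration writes a block with a different number of rows, the TensorArray is created with infer_shape=False, and TensorArray.concat() joins the blocks along the first axis.

```python
import tensorflow as tf

a = tf.constant([[0, 0], [1, 1], [2, 2]])
mul = tf.constant([1, 3, 2])


@tf.function
def tile_rows(a, mul):
    n = tf.shape(a)[0]  # number of rows, known only at run time
    # One slot per input row; slot i holds row i tiled mul[i] times
    ta = tf.TensorArray(a.dtype, size=n, infer_shape=False)

    def body(i, ta):
        tiled = tf.tile(a[i:i + 1], [mul[i], 1])  # shape [mul[i], num_cols]
        return i + 1, ta.write(i, tiled)

    _, ta = tf.while_loop(lambda i, ta: i < n, body, [tf.constant(0), ta])
    # concat() stacks the slots along axis 0 even though their lengths differ
    return ta.concat()


res = tile_rows(a, mul)
```

Since n is read with tf.shape at run time, the same function should work for inputs with any number of rows, for example tile_rows(tf.constant([[5, 6]]), tf.constant([2])).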
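A different technique, not one of the suggestions above, avoids both py_func and an explicit loop by building the index vector inside the graph with tf.sequence_mask and tf.where. This is a sketch assuming TensorFlow 2.x; it allocates an n-by-max(mul) boolean mask, which is fine when the multiples are small.

```python
import tensorflow as tf

a = tf.constant([[0, 0], [1, 1], [2, 2]])
mul = tf.constant([1, 3, 2])

# Row i of the mask has mul[i] leading True values:
# [[T, F, F], [T, T, T], [T, T, F]]
mask = tf.sequence_mask(mul)
# tf.where returns the (row, col) coordinates of every True entry in row-major
# order, so column 0 is the row index repeated mul[i] times: [0, 1, 1, 1, 2, 2]
idx = tf.where(mask)[:, 0]
res = tf.gather(a, idx)
```

Rows with a multiple of 0 simply drop out. Recent TensorFlow 2.x releases also provide tf.repeat, and tf.repeat(a, mul, axis=0) should produce the same tensor directly.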