Let's say we have two 2D tensors, each of shape (n, k). We want to concatenate them over all row-wise pairs, so that the resulting tensor has shape (n, n, 2*k).
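Stated element-wise with 0-based indices, the requested tensor C pairs row i of A with row j of B:

```latex
C_{i,j,m} =
\begin{cases}
A_{i,m} & 0 \le m < k \\
B_{j,\,m-k} & k \le m < 2k
\end{cases}
\qquad 0 \le i, j < n
```

For instance, C[0, 1] is row 0 of A followed by row 1 of B.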
For example:
A = [[a, b], [c, d]]; B = [[e, f], [g, h]]
The resulting tensor should be:
[[[a, b, e, f], [a, b, g, h]], [[c, d, e, f], [c, d, g, h]]]
Assume that the input tensors A and B have non-static shapes, so we cannot use a Python for loop over the values returned by tf.shape().
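For reference, with plain Python lists, where the shapes are known in advance, the requested result is just a nested loop over rows. This is the kind of per-row iteration the question rules out for tensors whose shapes are only known at run time; it is shown here only to pin down the expected output for the example above.

```python
# Plain Python lists: shapes are known up front, so looping over rows is fine here
A = [['a', 'b'], ['c', 'd']]
B = [['e', 'f'], ['g', 'h']]

# expected[i][j] is row i of A followed by row j of B
expected = [[row_a + row_b for row_b in B] for row_a in A]
print(expected)
# [[['a', 'b', 'e', 'f'], ['a', 'b', 'g', 'h']], [['c', 'd', 'e', 'f'], ['c', 'd', 'g', 'h']]]
```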
Any help is appreciated. Thank you very much.
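Use tf.concat together with tf.repeat and tf.tile. The repeat count and the tile multiples can be computed at run time with tf.shape(), so no Python loop is needed: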
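To see why this works, here is a sketch of the two intermediate tensors for the 2×2 example from the question. It uses tf.constant instead of NumPy for brevity; the variable names are illustrative only.

```python
import tensorflow as tf

# Same 2x2 example as in the question (n = 2 rows, k = 2 columns)
A = tf.constant([['a', 'b'], ['c', 'd']])
B = tf.constant([['e', 'f'], ['g', 'h']])

# tf.repeat copies each row of A consecutively, once per row of B:
# [[a, b], [a, b], [c, d], [c, d]]
repeated_A = tf.repeat(A, repeats=tf.shape(B)[0], axis=0)

# tf.tile stacks the whole of B once per row of A:
# [[e, f], [g, h], [e, f], [g, h]]
tiled_B = tf.tile(B, multiples=[tf.shape(A)[0], 1])

# Row r of the concatenation pairs A[r // n_b] with B[r % n_b], so reshaping
# to (n_a, n_b, 2k) places A[i] followed by B[j] at position [i, j]
pairs = tf.concat([repeated_A, tiled_B], axis=-1)
result = tf.reshape(pairs, [tf.shape(A)[0], tf.shape(B)[0], -1])
```

The key point is that repeat and tile enumerate the same (i, j) pairs in the same row-major order, which is exactly the order tf.reshape uses to split the first axis into two.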
```python
import numpy as np
import tensorflow as tf

# Input
A = tf.convert_to_tensor(np.array([['a', 'b'], ['c', 'd']]))
B = tf.convert_to_tensor(np.array([['e', 'f'], ['g', 'h']]))

# Row counts, known only at run time
n_a = tf.shape(A)[0]
n_b = tf.shape(B)[0]

# Repeat each row of A once per row of B (the row count, not tf.shape(A)[-1],
# which only happens to match when n == k), tile B once per row of A, and concat
C = tf.concat([tf.repeat(A, repeats=n_b, axis=0),
               tf.tile(B, multiples=[n_a, 1])],
              axis=-1)

# Reshape to the requested shape (n, n, 2*k)
C = tf.reshape(C, [n_a, n_b, -1])
print(repr(C))
# <tf.Tensor: shape=(2, 2, 4), dtype=string, numpy=
# array([[[b'a', b'b', b'e', b'f'],
#         [b'a', b'b', b'g', b'h']],
#
#        [[b'c', b'd', b'e', b'f'],
#         [b'c', b'd', b'g', b'h']]], dtype=object)>
```
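To check the non-static-shape requirement, one option is to wrap the same ops in a tf.function whose input_signature leaves both dimensions as None, so the shapes are only known when the graph runs. This is a sketch under two assumptions: `pairwise_concat` is a hypothetical name, and the inputs are string tensors as in the question (change the dtype in the TensorSpec for numeric data). The example uses n = 3 and k = 2, a case where n != k.

```python
import tensorflow as tf

# Hypothetical helper; both dimensions are unknown (None) when the graph is traced
@tf.function(input_signature=[
    tf.TensorSpec(shape=[None, None], dtype=tf.string),
    tf.TensorSpec(shape=[None, None], dtype=tf.string),
])
def pairwise_concat(A, B):
    """Return C with C[i, j] = concat(A[i], B[j]) along the last axis."""
    n_a = tf.shape(A)[0]
    n_b = tf.shape(B)[0]
    C = tf.concat([tf.repeat(A, repeats=n_b, axis=0),
                   tf.tile(B, multiples=[n_a, 1])],
                  axis=-1)
    return tf.reshape(C, [n_a, n_b, -1])

A = tf.constant([['a', 'b'], ['c', 'd'], ['e', 'f']])  # n = 3, k = 2
B = tf.constant([['u', 'v'], ['w', 'x'], ['y', 'z']])
C = pairwise_concat(A, B)
print(C.numpy().shape)  # (3, 3, 4)
```

Because nothing depends on Python-level shape values, the same traced function also accepts A of shape (m, k) and B of shape (n, k) with m != n, returning shape (m, n, 2*k).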