
Tensorflow1 concat 2D tensors with all row-wise permutations


Let's say we have two 2D tensors, A and B, both of shape (n, k). We want to concatenate them with all row-wise permutations, i.e. every row of A joined with every row of B, so that the resulting tensor has shape (n, n, 2*k).

Example:

A = [[a, b], [c, d]]; B = [[e, f], [g, h]]

The resulting tensor should be:

[[[a, b, e, f], [a, b, g, h]], [[c, d, e, f], [c, d, g, h]]]
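To pin down the intended semantics, here is a plain-Python reference sketch using nested lists (no TensorFlow; the helper name `row_pairs_concat` is hypothetical). It builds `result[i][j]` as row `i` of A followed by row `j` of B, which can serve as a ground truth when testing a tensor implementation.

```python
def row_pairs_concat(a, b):
    """Hypothetical reference helper: result[i][j] is row a[i] followed by row b[j]."""
    # Outer loop over rows of a, inner loop over rows of b
    # -> nested shape (len(a), len(b), len(a[0]) + len(b[0]))
    return [[row_a + row_b for row_b in b] for row_a in a]


A = [['a', 'b'], ['c', 'd']]
B = [['e', 'f'], ['g', 'h']]
result = row_pairs_concat(A, B)
```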

Assume that the input tensors A and B have non-static (dynamic) shapes, so we cannot use a Python for loop over the values returned by tf.shape().
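To make "non-static shapes" concrete, a minimal sketch assuming TensorFlow 2.x: inside a `tf.function` traced with an `input_signature` of `[None, None]`, the static dimension `x.shape[0]` is `None`, so it cannot drive a Python loop, while `tf.shape(x)[0]` is a tensor whose value is only known at run time. The names `count_rows` and `traced_static_rows` are hypothetical, for illustration only.

```python
import tensorflow as tf

# Hypothetical list used to record what the static shape looks like during tracing
traced_static_rows = []


@tf.function(input_signature=[tf.TensorSpec(shape=[None, None], dtype=tf.float32)])
def count_rows(x):
    # Python side effect: runs only while tracing, where the row count is unknown (None)
    traced_static_rows.append(x.shape[0])
    # Dynamic row count: a scalar int32 tensor evaluated at run time
    return tf.shape(x)[0]


n = count_rows(tf.zeros([3, 5]))
```

Because of the `input_signature`, a later call with a different row count reuses the same trace, which is exactly the situation where any shape-dependent logic has to be expressed with tensor ops rather than Python loops.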

Any help is appreciated. Thank you very much.


Solution

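  • Use tf.repeat to repeat each row of A once per row of B, tf.tile to tile B once per row of A, then tf.concat the two along the last axis and tf.reshape the result to (n, n, 2*k)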

    import tensorflow as tf
    import numpy as np
    
    # Input: two 2D tensors of shape (n, k); here n = k = 2
    A = tf.convert_to_tensor(np.array([['a', 'b'], ['c', 'd']]))
    B = tf.convert_to_tensor(np.array([['e', 'f'], ['g', 'h']]))
    
    # Dynamic row counts (tensors), so no static shape information is needed
    n_a = tf.shape(A)[0]
    n_b = tf.shape(B)[0]
    
    # Repeat each row of A n_b times (A0, A0, A1, A1) and tile B n_a times
    # (B0, B1, B0, B1), then concat along the last axis -> shape (n_a * n_b, 2k)
    C = tf.concat([tf.repeat(A, repeats=n_b, axis=0),
                   tf.tile(B, multiples=[n_a, 1])],
                  axis=-1)
    
    # Reshape to the requested shape (n_a, n_b, 2k)
    C = tf.reshape(C, [n_a, n_b, -1])
    
    print(C)
    
    # tf.Tensor(
    # [[[b'a' b'b' b'e' b'f']
    #   [b'a' b'b' b'g' b'h']]
    #
    #  [[b'c' b'd' b'e' b'f']
    #   [b'c' b'd' b'g' b'h']]], shape=(2, 2, 4), dtype=string)
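An alternative that skips the flatten-then-reshape step is to insert a new axis into each input and pair the rows by broadcasting. This is a sketch, not part of the original answer: the helper name `pairwise_row_concat` is hypothetical, it assumes TensorFlow 2.x, and it uses integer tensors for simplicity. Since `tf.concat` does not broadcast, both inputs are first expanded to the full `(n_a, n_b, k)` grid with `tf.broadcast_to`. It also allows A and B to have different row counts.

```python
import tensorflow as tf


# Hypothetical helper; the input_signature leaves both dimensions dynamic
@tf.function(input_signature=[tf.TensorSpec(shape=[None, None], dtype=tf.int32),
                              tf.TensorSpec(shape=[None, None], dtype=tf.int32)])
def pairwise_row_concat(a, b):
    n_a = tf.shape(a)[0]
    n_b = tf.shape(b)[0]
    # (n_a, 1, k_a) and (1, n_b, k_b): new axes so rows pair up by broadcasting
    a_exp = tf.expand_dims(a, axis=1)
    b_exp = tf.expand_dims(b, axis=0)
    # Materialize the broadcast explicitly, because tf.concat needs matching shapes
    a_full = tf.broadcast_to(a_exp, [n_a, n_b, tf.shape(a)[1]])
    b_full = tf.broadcast_to(b_exp, [n_a, n_b, tf.shape(b)[1]])
    # result[i, j] = concat(a[i], b[j]) -> shape (n_a, n_b, k_a + k_b)
    return tf.concat([a_full, b_full], axis=-1)


A = tf.constant([[1, 2], [3, 4]])
B = tf.constant([[5, 6], [7, 8], [9, 10]])  # 3 rows, so n_a != n_b
C = pairwise_row_concat(A, B)
```

Here C has shape (2, 3, 4), and C[i, j] is row i of A followed by row j of B.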