Search code examples
Tags: python, tensorflow, tensorflow2.0, tensor

Selective meshgrid in Tensorflow


Given the following code:

import tensorflow as tf

def combine(x, y):
    """Print every (x_i, y_j) pair as one row of an (N, 2) tensor.

    Rows follow 'ij' meshgrid order: the ``x`` element varies slowest and
    the ``y`` element fastest, so N == size(x) * size(y).
    """
    grid_x, grid_y = tf.meshgrid(x, y, indexing='ij')
    # Put matching grid entries side by side on a new trailing axis, then
    # collapse the two grid axes into a single list of pairs.
    paired = tf.stack([grid_x, grid_y], axis=-1)
    pairs = tf.reshape(paired, [-1, 2])
    print(pairs)


x = tf.constant([11, 0, 7, 1])
combine(x, x)

I want to filter the `combo` tensor so that it keeps only the pairs (x[i], x[j]) with i < j, i.e. so that I obtain the following tensor: [(11, 0), (11, 7), (11, 1), (0, 7), (0, 1), (7, 1)]. Is it possible to do this in TensorFlow?


Solution

  • You can apply a boolean mask that keeps only the strictly upper-triangular part of the meshgrid to get the desired result:

    def combine(x, y):
        """Return every pair (x[i], y[j]) with i < j as an (N, 2) tensor.

        The pairs come from the strictly upper-triangular part of the 'ij'
        meshgrid of ``x`` and ``y``, in row-major order. For ``combine(x, x)``
        this yields each pair of distinct positions exactly once.

        The mask is derived from element *indices* rather than from
        ``tf.ones_like`` plus ``band_part`` subtraction. That makes it
        independent of the values' dtype (the ones/subtraction approach is
        not defined for bool or string tensors).
        """
        xx, yy = tf.meshgrid(x, y, indexing='ij')

        # Strictly upper-triangular mask: True where row index < column index.
        # Broadcasting a column of row indices against a row of column
        # indices produces a bool tensor with the same shape as xx.
        rows = tf.range(tf.shape(xx)[0])[:, tf.newaxis]
        cols = tf.range(tf.shape(xx)[1])[tf.newaxis, :]
        mask = rows < cols

        # Use fresh names so the x / y arguments are not shadowed.
        first = tf.boolean_mask(xx, mask)
        second = tf.boolean_mask(yy, mask)

        # Return the result (the original only printed it, so the
        # assignment `a = combine(x, x)` received None).
        return tf.stack([first, second], axis=1)

    x = tf.constant([11, 0, 7, 1])
    a = combine(x, x)
    print(a)
    
    #output
    tf.Tensor(
    [[11  0]
     [11  7]
     [11  1]
     [ 0  7]
     [ 0  1]
     [ 7  1]], shape=(6, 2), dtype=int32)