
TensorFlow: how to implement a sorting layer


I'm trying to write a Keras layer that takes a flat tensor x of shape (batch_size, units), which contains no zeros, multiplied by a mask of the same shape, and sorts each row so that the values kept by the mask (the nonzero entries of x * mask) come first in the output and the zeros come last. The order among the kept values doesn't matter. For clarity, here is an example (batch_size = 1, units = 8):

[Example image: batch_size = 1, units = 8]

It seems simple, but I can't find a good solution. Any code or ideas are appreciated.
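To pin down the intended output, here is a minimal NumPy reference that any layer implementation could be checked against. It assumes, as stated above, that x itself has no zeros, so the zeros in x * mask mark exactly the masked-out positions. The helper name sort_nonzero_first and the input values are hypothetical (they are not taken from the missing example image). It is not a Keras layer: it works on NumPy arrays, so it is neither differentiable nor usable inside a traced model.

```python
import numpy as np


def sort_nonzero_first(v):
    """Hypothetical reference: per row, nonzero entries first, zeros last."""
    # Sort key is False (0) for nonzero entries and True (1) for zeros,
    # so a stable ascending argsort lists the nonzero positions first
    # and keeps their original relative order.
    order = np.argsort(v == 0, axis=1, kind="stable")
    # Apply the per-row permutation.
    return np.take_along_axis(v, order, axis=1)


# Hypothetical values with batch_size = 1, units = 8.
x = np.array([[3., 1., 4., 1., 5., 9., 2., 6.]])
mask = np.array([[0., 1., 1., 0., 1., 0., 0., 1.]])
print(sort_nonzero_first(x * mask))  # [[1. 4. 5. 6. 0. 0. 0. 0.]]
```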


Solution

  • My current code is below. If you know a more efficient way, please let me know.

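    import numpy as np
    import tensorflow as tf
    from tensorflow import keras


    class Sort(keras.layers.Layer):
      # Moves the nonzero entries of each row to the front and the zeros to the back.
      # Relies on inputs.numpy(), so it only runs eagerly
      # (e.g. model.compile(..., run_eagerly=True)).

      def call(self, inputs):
        x = inputs.numpy()
        n_rows, n_cols = x.shape
        nonx, nony = x.nonzero()    # row/column idxs of nonzero elements
        # mapping matrix: output[i, k] = inputs[result[i, k, 0], result[i, k, 1]]
        # every slot defaults to column 0 of its own row (zeroed out below),
        # so the batch does not need to contain any zero
        result = np.zeros((n_rows, n_cols, 2), dtype='int64')
        result[:, :, 0] = np.arange(n_rows)[:, None]
        valid = np.zeros((n_rows, n_cols), dtype=x.dtype)   # 1 where a nonzero was placed
        p = np.zeros(n_rows, dtype='int64')   # next free output slot in each row
        for i, j in zip(nonx, nony):
          result[i, p[i]] = [i, j]
          valid[i, p[i]] = 1
          p[i] += 1

        # gather_nd is differentiable w.r.t. inputs; padding slots are set to zero
        y = tf.gather_nd(inputs, result) * valid
        return y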
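For comparison, here is a sketch of a graph-compatible alternative that avoids .numpy() and the Python loop, assuming TensorFlow 2.x. It builds a 0/1 sort key (0 for nonzero entries, 1 for zeros), uses tf.argsort with stable=True to get a per-row permutation that moves the nonzero entries first, and applies it with tf.gather(..., batch_dims=1). The class name SortNonzeroFirst and the example values are hypothetical. Like the code above, it assumes the unmasked entries of x are never exactly zero, so zeros identify the masked positions.

```python
import tensorflow as tf
from tensorflow import keras


class SortNonzeroFirst(keras.layers.Layer):
    """Hypothetical layer: per row, nonzero entries first, zeros last.

    Uses only TensorFlow ops, so it can be traced into a graph
    (tf.function / model.fit) without run_eagerly=True.
    """

    def call(self, inputs):
        # Sort key per element: 0 for nonzero entries, 1 for zeros.
        is_zero = tf.equal(inputs, tf.zeros_like(inputs))
        key = tf.cast(is_zero, tf.int32)
        # Stable ascending argsort: nonzero positions first, original order kept.
        order = tf.argsort(key, axis=-1, stable=True)
        # Row-wise gather: out[b, k] = inputs[b, order[b, k]].
        return tf.gather(inputs, order, axis=1, batch_dims=1)


# Usage with hypothetical values (batch_size = 1, units = 8).
x = tf.constant([[3., 1., 4., 1., 5., 9., 2., 6.]])
mask = tf.constant([[0., 1., 1., 0., 1., 0., 0., 1.]])
layer = SortNonzeroFirst()
y = layer(x * mask)
print(y.numpy())  # [[1. 4. 5. 6. 0. 0. 0. 0.]]
```

Because every op here is a TensorFlow op, the layer should also work inside model.fit without run_eagerly=True. Gradients reach inputs through tf.gather, while the integer sort indices carry no gradient.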