I'm trying to write a layer in Keras that takes a flat tensor x of shape (batch_size, units), which contains no zero values, multiplied by a mask of the same shape. The layer should sort each row so that the values kept by the mask (the nonzero ones) are placed first in the output, followed by the zeros. The order of the values themselves doesn't matter. For clarity, here is an example (batch_size = 1, units = 8):
It seems simple, but I can't find a good solution. Any code or idea is appreciated.
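A minimal NumPy sketch of the intended behaviour follows. It assumes that "masked values first" means the nonzero entries go first and the zeros after, which is what the code below produces. The helper name masked_first and the input values are hypothetical and chosen only for illustration. They are not the original example. The trick is a stable argsort on the key "is this entry zero?".

```python
import numpy as np


def masked_first(x: np.ndarray) -> np.ndarray:
    """Hypothetical helper: move the nonzero entries of each row to the front.

    Zeros (masked-out entries) end up at the back of each row; nonzero
    entries keep their original relative order.
    """
    # Key is False (0) for kept entries and True (1) for zeros; a stable
    # sort keeps the original order inside each group.
    order = np.argsort(x == 0, axis=-1, kind="stable")
    # Reorder every row with its own permutation.
    return np.take_along_axis(x, order, axis=-1)


# Illustrative input, not the original example from the question.
x = np.array([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]])
mask = np.array([[0, 1, 1, 0, 1, 0, 0, 1]])
print(masked_first(x * mask))
```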
My current code is below. If you know a more efficient way, please let me know.
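```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


class Sort(keras.layers.Layer):
    def call(self, inputs):
        # Eager mode only: .numpy() is not available when call() is traced into a graph
        x = inputs.numpy()
        nonx, nony = x.nonzero()  # row/column indices of nonzero elements
        zx, zy = np.where(x == 0)
        # Index of the first zero; fall back to [0, 0] if there are no zeros
        # (then every slot is overwritten by a nonzero element anyway)
        zero = [zx[0], zy[0]] if zx.size else [0, 0]
        result = np.zeros(x.shape + (2,), dtype=np.int64)  # mapping matrix for gather_nd
        result[:, :, 0] = zero[0]
        result[:, :, 1] = zero[1]
        p = np.zeros(x.shape[0], dtype=np.int64)  # next free output slot in each row
        for i, j in zip(nonx, nony):
            result[i, p[i]] = [i, j]
            p[i] += 1
        return tf.gather_nd(inputs, result)
```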
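One possibly more efficient alternative is sketched below. It assumes TensorFlow 2.x as the backend and replaces the Python loop with a stable tf.argsort on an "is zero" key followed by a per-row tf.gather. The layer name SortNonzeroFirst is hypothetical. Because it never calls .numpy(), it should also work when Keras traces call() into a graph (for example during model.fit). The loop-based version cannot do that.

```python
import tensorflow as tf
from tensorflow import keras


class SortNonzeroFirst(keras.layers.Layer):
    """Hypothetical layer: nonzero entries of each row first, zeros last."""

    def call(self, inputs):
        # 0.0 for nonzero entries, 1.0 for zeros.
        is_zero = tf.cast(tf.equal(inputs, 0), tf.float32)
        # Stable ascending argsort: all nonzero entries come first and
        # keep their original relative order within the row.
        order = tf.argsort(is_zero, axis=-1, stable=True)
        # batch_dims=1 pairs row b of `order` with row b of `inputs`.
        return tf.gather(inputs, order, batch_dims=1)


x = tf.constant([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]])
mask = tf.constant([[0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0]])
print(SortNonzeroFirst()(x * mask).numpy())
```

Gradients flow through tf.gather to the gathered values. The sort indices are treated as constants, as the indices passed to tf.gather_nd are in the original version.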