python tensorflow keras-layer tensorflow2.0

Multidimensional Tensor slicing


First things first: I'm relatively new to TensorFlow.

I'm trying to implement a custom layer in tensorflow.keras and I'm having a hard time achieving the following:

  1. I've got 3 Tensors (x, y, z), each of shape (?,49,3,3,32) [where ? is the batch size]
  2. On each Tensor I compute the sum over the 3rd and 4th axes [so I end up with 3 Tensors of shape (?,49,32)]
  3. By taking an argmax across these 3 Tensors of shape (?,49,32) I get a single (?,49,32) Tensor, A
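
The three steps above can be sketched in TensorFlow roughly as follows. This is only an illustration: the batch size of 2 and the random inputs are assumptions, not values from the actual layer.

```python
import tensorflow as tf

b = 2  # hypothetical batch size, chosen only for illustration
x = tf.random.uniform((b, 49, 3, 3, 32))
y = tf.random.uniform((b, 49, 3, 3, 32))
z = tf.random.uniform((b, 49, 3, 3, 32))

# Step 2: sum over the two 3x3 axes (axes 2 and 3, zero-based)
sums = [tf.reduce_sum(t, axis=(2, 3)) for t in (x, y, z)]  # each (b, 49, 32)

# Step 3: stack the three sums and take the argmax across them
A = tf.argmax(tf.stack(sums, axis=0), axis=0)  # (b, 49, 32), values in {0, 1, 2}
print(A.shape)
```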

Now I want to use this tensor to select slices from the initial x,y,z Tensors in the following form:

  • Each element of A identifies the selected Tensor (i.e. 0 = x, 1 = y, 2 = z).
  • The position of that element along the last dimension of A identifies the slice to extract along the last dimension of the selected Tensor.

I've tried to achieve the above using tf.gather, but I had no luck. Then I tried a series of tf.map_fn calls, which is ugly and computationally costly.

To simplify the above: let's say we have an array A of shape (3,3,3,32). The NumPy equivalent of what I am trying to achieve is:

import numpy as np
x = np.random.rand(3,3,32)
y = np.random.rand(3,3,32)
z = np.random.rand(3,3,32)
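# Sum each array over its two spatial axes, leaving one value per channel
x_sums = np.sum(x, axis=(0, 1))
y_sums = np.sum(y, axis=(0, 1))
z_sums = np.sum(z, axis=(0, 1))
# For each channel, index of the array (0 = x, 1 = y, 2 = z) with the largest sum
max_sums = np.argmax([x_sums, y_sums, z_sums], 0)
A = np.array([x, y, z])  # shape (3,3,3,32)
tmp = []
for i in range(len(max_sums)):
    tmp.append(A[max_sums[i], :, :, i])  # (3,3) slice of the winning array
# Stack along the last axis: shape (3,3,32), row/column order of each slice preserved
output = np.stack(tmp, axis=-1)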
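
For reference, the same selection can probably be written without the Python loop. The sketch below is one possible vectorized form using np.take_along_axis; the seeded generator and the loop-based check at the end are assumptions added for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
x, y, z = (rng.random((3, 3, 32)) for _ in range(3))

A = np.stack([x, y, z])                  # (3, 3, 3, 32): axis 0 selects x, y or z
max_sums = A.sum(axis=(1, 2)).argmax(0)  # (32,): winning array per channel

# Broadcast the (32,) index against A's trailing axes and pick along axis 0
output = np.take_along_axis(A, max_sums[None, None, None, :], axis=0)[0]  # (3, 3, 32)

# Loop-based reference for comparison
expected = np.stack([A[max_sums[i], :, :, i] for i in range(32)], axis=-1)
assert np.array_equal(output, expected)
```

The index array is given the same number of dimensions as A so that it broadcasts over the two spatial axes, which is what lets a single call replace the loop.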

Any suggestions? P.S.: I also tried tf.gather_nd, but had no luck.


Solution

This is how you can do that with tf.gather_nd:

    import tensorflow as tf
    
    # Make example data
    tf.random.set_seed(0)
    b = 10  # Batch size
    x = tf.random.uniform((b, 49, 3, 3, 32))
    y = tf.random.uniform((b, 49, 3, 3, 32))
    z = tf.random.uniform((b, 49, 3, 3, 32))
    # Stack tensors together
    data = tf.stack([x, y, z], axis=2)
    # Put reduction axes last
    data_t = tf.transpose(data, (0, 1, 5, 2, 3, 4))
    # Reduce
    s = tf.reduce_sum(data_t, axis=(4, 5))
    # Find largest sums
    idx = tf.argmax(s, 3)
    # Make gather indices
    data_shape = tf.shape(data_t, idx.dtype)
    bb, ii, jj = tf.meshgrid(*(tf.range(data_shape[i]) for i in range(3)), indexing='ij')
    # Gather result
    output_t = tf.gather_nd(data_t, tf.stack([bb, ii, jj, idx], axis=-1))
    # Reorder axes
    output = tf.transpose(output_t, (0, 1, 3, 4, 2))
    print(output.shape)
    # (10, 49, 3, 3, 32)
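
As a possible variation on the code above, the meshgrid indices can likely be avoided by passing batch_dims to tf.gather_nd, so that the first three axes are treated as batch axes. This is a sketch under the same shape assumptions; the smaller batch size and the one-hot cross-check are additions for illustration only.

```python
import tensorflow as tf

tf.random.set_seed(0)
b = 2  # hypothetical batch size
x, y, z = (tf.random.uniform((b, 49, 3, 3, 32)) for _ in range(3))

# Same layout as above: (b, 49, 32, tensor selector, 3, 3)
data_t = tf.transpose(tf.stack([x, y, z], axis=2), (0, 1, 5, 2, 3, 4))
idx = tf.argmax(tf.reduce_sum(data_t, axis=(4, 5)), axis=3)  # (b, 49, 32)

# Treat the first three axes as batch axes; each index picks one (3, 3) slice
picked = tf.gather_nd(data_t, idx[..., tf.newaxis], batch_dims=3)  # (b, 49, 32, 3, 3)
output = tf.transpose(picked, (0, 1, 3, 4, 2))  # (b, 49, 3, 3, 32)

# Cross-check against a one-hot mask over the tensor-selector axis
mask = tf.one_hot(idx, depth=3, dtype=data_t.dtype)[..., tf.newaxis, tf.newaxis]
masked = tf.reduce_sum(data_t * mask, axis=3)
tf.debugging.assert_near(picked, masked)
```

The one-hot version touches all three tensors and is only there as a sanity check; the gather is the part that mirrors the answer.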