Fancy indexing in tensorflow


I have implemented a 3D CNN with a custom loss function (Ax' - y)^2 where x' is a flattened and cropped vector of the 3D output from the CNN, y is the ground truth and A is a linear operator that takes an x and outputs a y. So I need a way to flatten the 3D output and crop it using fancy indexing before computing the loss.

Here is the NumPy code I am trying to replicate:

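import numpy as np

def flatten_crop(img_vol, indices, vol_shape, N):
    """
    :param img_vol: shape (145, 59, 82, N)
    :param indices: shape (396929,)
    :param vol_shape: (nVx, nVy, nVz), here (145, 59, 82)
    :param N: number of volumes (size of the last axis of img_vol)
    :return: shape (len(indices), N)
    """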
    nVx, nVy, nVz = vol_shape
    voxels = np.reshape(img_vol, (nVx * nVy * nVz, N), order='F')
    voxels = voxels[indices, :]
    return voxels

I tried using tf.gather_nd to do the same thing, but I am unable to generalize it to an arbitrary batch size. Here is my TensorFlow code for a batch size of 1 (i.e. a single 3D output):

voxels = tf.transpose(tf.reshape(tf.transpose(y_pred), (1, 145 * 59 * 82)))   # to flatten and reshape using Fortran-like index order
voxels = tf.gather_nd(voxels, tf.stack([indices, tf.zeros(len(indices), dtype=tf.dtypes.int32)], axis=1))    # indexing
voxels = tf.reshape(voxels, (voxels.shape[0], 1))

Currently I have this piece of code in my custom loss function, and I would like to generalize it to an arbitrary batch size. If you have an alternative suggestion for implementing this (such as a custom layer instead of putting it in the loss function), I am all ears!

Thank you.


Solution

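  • tf.gather with axis=-1 applies the same indices to every volume in the batch, so it works for any batch size. Try this code: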

    import tensorflow as tf
    y_pred = tf.random.uniform((10, 145, 59, 82))
    indices = tf.random.uniform((396929,), 0, 145*59*82, dtype=tf.int32)
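    voxels = tf.transpose(y_pred, perm=[0, 3, 2, 1])   # reverse the spatial axes; tf.reshape is row-major, so this gives Fortran-like index order
    voxels = tf.reshape(voxels, (-1, 145 * 59 * 82))   # flatten each volume to shape (batch, 145 * 59 * 82)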
    voxels = tf.gather(voxels, indices, axis=-1)
    voxels = tf.transpose(voxels)
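
One thing worth checking is the index order: tf.reshape always flattens in row-major (C) order, so matching NumPy's order='F' requires reversing the spatial axes before reshaping. The following is a minimal sketch of that check, not a definitive implementation. It assumes the batch axis comes first in y_pred, uses small made-up shapes instead of (145, 59, 82), and the names flatten_crop_np and flatten_crop_tf are hypothetical helpers introduced only for this comparison.

```python
import numpy as np
import tensorflow as tf


def flatten_crop_np(img_vol, indices, vol_shape, N):
    # NumPy reference from the question; img_vol has shape (nVx, nVy, nVz, N).
    nVx, nVy, nVz = vol_shape
    voxels = np.reshape(img_vol, (nVx * nVy * nVz, N), order='F')
    return voxels[indices, :]


def flatten_crop_tf(y_pred, indices, vol_shape):
    # Assumed layout: y_pred has shape (batch, nVx, nVy, nVz).
    nVx, nVy, nVz = vol_shape
    # Reverse the spatial axes so a row-major reshape walks x fastest,
    # which is the ordering order='F' produces in NumPy.
    reversed_axes = tf.transpose(y_pred, perm=[0, 3, 2, 1])
    flat = tf.reshape(reversed_axes, (-1, nVx * nVy * nVz))  # (batch, voxels)
    cropped = tf.gather(flat, indices, axis=-1)              # (batch, len(indices))
    return tf.transpose(cropped)                             # (len(indices), batch)


# Small hypothetical shapes so the comparison runs quickly.
vol_shape = (4, 3, 5)
batch_size = 2
rng = np.random.default_rng(0)
y_np = rng.random((batch_size, *vol_shape)).astype(np.float32)
idx_np = rng.integers(0, 4 * 3 * 5, size=7).astype(np.int32)

# The NumPy function expects the batch axis last, so move it there.
expected = flatten_crop_np(np.moveaxis(y_np, 0, -1), idx_np, vol_shape, batch_size)
actual = flatten_crop_tf(tf.constant(y_np), tf.constant(idx_np), vol_shape).numpy()
matches = np.allclose(expected, actual)
```

If matches is True for your own shapes and indices, the same function body can be called from the custom loss, or wrapped in a layer if you prefer to keep the loss itself simple.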