I have implemented a 3D CNN with a custom loss function (Ax' - y)^2, where x' is a flattened and cropped vector of the 3D output from the CNN, y is the ground truth, and A is a linear operator that maps an x to a y. So I need a way to flatten the 3D output and crop it using fancy indexing before computing the loss.
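For reference, here is a minimal NumPy sketch of that loss, assuming A can be written as a dense matrix (the real operator may be sparse or matrix-free) and using small hypothetical sizes. Each column of `X_crop` is one sample's x', which matches the `(len(indices), N)` layout that `flatten_crop` below returns.

```python
import numpy as np

def sum_squared_residual(A, x_crop, y):
    """||A x' - y||^2 for a single sample.

    A      : (m, n) dense matrix standing in for the linear operator (assumption)
    x_crop : (n,) flattened, cropped CNN output x'
    y      : (m,) ground truth
    """
    r = A @ x_crop - y
    return float(r @ r)

def batch_loss(A, X_crop, Y):
    # Columns are samples: X_crop is (n, N), Y is (m, N)
    R = A @ X_crop - Y               # (m, N) residuals
    return np.sum(R ** 2, axis=0)    # one loss value per sample

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 6))      # hypothetical m=4, n=6
X = rng.standard_normal((6, 3))      # 3 samples
Y = A @ X                            # perfect predictions, so the loss is zero
print(batch_loss(A, X, Y))
```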
Here is what I have tried. This is the NumPy code I am trying to replicate:
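```python
import numpy as np

def flatten_crop(img_vol, indices, vol_shape, N):
    """
    :param img_vol: shape (145, 59, 82, N)
    :param indices: shape (396929,)
    """
    nVx, nVy, nVz = vol_shape
    # Flatten the three spatial axes in Fortran (column-major) order
    voxels = np.reshape(img_vol, (nVx * nVy * nVz, N), order='F')
    # Keep only the selected voxels: result has shape (len(indices), N)
    voxels = voxels[indices, :]
    return voxels
```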
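With `order='F'` the first axis varies fastest, so voxel (i, j, k) ends up in flat row i + nVx*(j + nVy*k). A small check on a toy volume (hypothetical shape (3, 4, 5, 2), not the real data) confirms this and shows that `np.ravel_multi_index` with `order='F'` computes the same index:

```python
import numpy as np

nVx, nVy, nVz, N = 3, 4, 5, 2    # toy sizes, not the real (145, 59, 82)
img_vol = np.arange(nVx * nVy * nVz * N).reshape(nVx, nVy, nVz, N)

voxels = np.reshape(img_vol, (nVx * nVy * nVz, N), order='F')

def fortran_index(i, j, k):
    # Column-major: x varies fastest, then y, then z
    return i + nVx * (j + nVy * k)

i, j, k = 2, 1, 3
row = fortran_index(i, j, k)
print(row, voxels[row], img_vol[i, j, k, :])
print(np.ravel_multi_index((i, j, k), (nVx, nVy, nVz), order='F'))
```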
I tried using tf.gather_nd to perform the same operation, but I am unable to generalize it to an arbitrary batch size. Here is my TensorFlow code for a batch size of 1 (a single 3D output):
```python
voxels = tf.transpose(tf.reshape(tf.transpose(y_pred), (1, 145 * 59 * 82)))  # to flatten and reshape using Fortran-like index order
voxels = tf.gather_nd(voxels, tf.stack([indices, tf.zeros(len(indices), dtype=tf.dtypes.int32)], axis=1))  # indexing
voxels = tf.reshape(voxels, (voxels.shape[0], 1))
```
Currently this code lives in my custom loss function, and I would like it to work for an arbitrary batch size. Also, if you have an alternative way to implement this (such as a custom layer instead of putting it in the loss function), I am all ears!
Thank you.
Try this code:
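```python
import tensorflow as tf

# Batch of 10 CNN outputs, batch axis first: (N, 145, 59, 82)
y_pred = tf.random.uniform((10, 145, 59, 82))
indices = tf.random.uniform((396929,), 0, 145 * 59 * 82, dtype=tf.int32)

# tf.reshape is row-major, so reverse the spatial axes first to get Fortran-like index order
voxels = tf.reshape(tf.transpose(y_pred, perm=[0, 3, 2, 1]), (-1, 145 * 59 * 82))
voxels = tf.gather(voxels, indices, axis=-1)  # (N, len(indices))
voxels = tf.transpose(voxels)                 # (len(indices), N), same layout as the NumPy version
```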
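To check the batch-first pipeline against the original `flatten_crop`, here is a NumPy mirror of the same steps, using toy sizes (hypothetical). It relies on `np.reshape` being row-major by default, the same convention as `tf.reshape`, and on `np.take(..., axis=1)` selecting columns the way `tf.gather(..., axis=-1)` does on a 2D tensor:

```python
import numpy as np

def flatten_crop(img_vol, indices, vol_shape, N):
    # Reference from the question: batch axis last, Fortran-order flatten
    nVx, nVy, nVz = vol_shape
    voxels = np.reshape(img_vol, (nVx * nVy * nVz, N), order='F')
    return voxels[indices, :]

def flatten_crop_batch_first(y_pred, indices):
    # Same steps as the TensorFlow answer, batch axis first
    n = y_pred.shape[0]
    flat = np.transpose(y_pred, (0, 3, 2, 1)).reshape(n, -1)  # row-major reshape after reversing spatial axes
    return np.take(flat, indices, axis=1).T                    # (len(indices), N)

rng = np.random.default_rng(0)
N, shape = 3, (5, 4, 6)                 # toy sizes
y_pred = rng.standard_normal((N, *shape))
indices = rng.integers(0, np.prod(shape), size=20)

expected = flatten_crop(np.moveaxis(y_pred, 0, -1), indices, shape, N)
print(np.allclose(flatten_crop_batch_first(y_pred, indices), expected))
```

Because nothing depends on the batch size, the same transpose, reshape and gather steps work unchanged inside a loss function or a custom layer.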