Tags: tensorflow, indexing, neural-network, mri

How to extract one slice of a 3D image tensor?


Let's say I have a 64x64x64 3D image. I also have a vector x of length 64.

I want to take the slice at index argmax(x), like so:

image_2d = image_3d[:, argmax(x), :]
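For a single volume with no batch axis, this indexing works directly in NumPy, because np.argmax on a 1-D vector returns one integer. A minimal sketch, assuming random arrays as stand-ins for the real image and vector:

```python
import numpy as np

rng = np.random.default_rng(0)
image_3d = rng.random((64, 64, 64))  # stand-in for the real 3D image
x = rng.random(64)                   # stand-in for the length-64 vector

# np.argmax of a 1-D vector returns a single integer, which is a valid index
slice_index = np.argmax(x)
image_2d = image_3d[:, slice_index, :]

print(image_2d.shape)  # (64, 64)
```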

More precisely (in TensorFlow):

import tensorflow as tf

def extract_slice(x, image_3d):
    slice_index = tf.math.argmax(x, axis=1, output_type=tf.int32)  # the index has to be an int
    return image_3d[:, slice_index, :]

The error is:

Only integers, slices (`:`), ellipsis (`...`), tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid indices, got <tf.Tensor 'ArgMax_50:0' shape=(None,) dtype=int32>

The shapes of the parameters (via np.shape) are:

image_3d shape is (None, 64, 64, 64, 1)

x shape is (None, 64)

slice_index shape is (None,)

The trailing 1 in the image_3d shape is there because it is a sample from an array; I don't think it matters.
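These shapes point to the cause: reducing over axis=1 keeps the batch axis, so slice_index holds one index per sample, and Python-style indexing accepts an integer tensor only when it is a scalar. The dtype is already int32. A small eager-mode reproduction, assuming TensorFlow 2.x and random stand-in tensors with a batch of 2:

```python
import tensorflow as tf

# Random stand-ins: a batch of 2 volumes and 2 length-64 vectors
images = tf.random.uniform((2, 64, 64, 64, 1))
x = tf.random.uniform((2, 64))

# Reducing over axis=1 keeps the batch axis: one int32 index per sample
slice_index = tf.math.argmax(x, axis=1, output_type=tf.int32)
print(slice_index.shape, slice_index.dtype)  # (2,) <dtype: 'int32'>

# A scalar int32 tensor is a valid index (here sample 0's index is applied to the whole batch)
first = images[:, slice_index[0], :]
print(first.shape)  # (2, 64, 64, 1)

# ...but a 1-D int32 tensor is rejected, even though its dtype is int32
raised = False
try:
    images[:, slice_index, :]
except TypeError as err:
    raised = True
    print("TypeError:", err)
```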

I know that the None dimension is the batch size, which is unknown, but the other dimensions look fine, so what is the problem?

From what I understand, it looks like the index is not int32, but I DID make it tf.int32, so what could be the problem? Is int32 different from tf.int32? Or is the indexing I used not valid in TensorFlow, so that it should be a function call, something like tf.index(image, [:, slice_index, :])?

Thanks!


Solution

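  • tf.math.argmax(x, axis=1) returns a 1-D tensor with one index per sample (shape (None,)), but [...] indexing only accepts a scalar tensor. If the batch holds a single sample, reshape the index to a scalar: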

     slice_index = tf.reshape(slice_index, ())
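The reshape assumes the batch holds exactly one sample: with a batch of N > 1 the index tensor has N elements and cannot be reshaped to a scalar. A sketch of a batch-safe alternative (not part of the original answer), assuming TensorFlow 2.x, where tf.gather accepts batch_dims, so that each sample gets the slice chosen by its own argmax:

```python
import tensorflow as tf

def extract_slice(x, image_3d):
    """Pick one slice per sample: out[b] = image_3d[b, argmax(x[b])].

    x:        (batch, 64)
    image_3d: (batch, 64, 64, 64, 1)
    returns:  (batch, 64, 64, 1)
    """
    slice_index = tf.math.argmax(x, axis=1, output_type=tf.int32)  # (batch,)
    # batch_dims=1 pairs sample b of image_3d with slice_index[b]
    return tf.gather(image_3d, slice_index, axis=1, batch_dims=1)


# Usage with random stand-in data and a batch of 3
images = tf.random.uniform((3, 64, 64, 64, 1))
x = tf.random.uniform((3, 64))
slices = extract_slice(x, images)
print(slices.shape)  # (3, 64, 64, 1)

# The scalar reshape from the answer works only for a batch of exactly one
one_index = tf.reshape(tf.math.argmax(x[:1], axis=1, output_type=tf.int32), ())
single = images[:1][:, one_index, :]
print(single.shape)  # (1, 64, 64, 1)
```

Because tf.gather does not need the batch size to be known in advance, the same call should also work when the batch dimension is None.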