
3D cropping inside a TensorFlow/Keras model


I'm trying to create a model that performs 3D cropping on an input tensor based on given bounding box coordinates. However, I keep receiving the following error message:

TypeError: Exception encountered when calling layer "tf.__operators__.getitem_6" (type SlicingOpLambda).

Only integers, slices (`:`), ellipsis (`...`), tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid indices, got <tf.Tensor 'Shape:0' shape=(6,) dtype=int32>

Call arguments received by layer "tf.__operators__.getitem_6" (type SlicingOpLambda):
  • tensor=tf.Tensor(shape=(None, 100, 100, 100, 3), dtype=float32)
  • slice_spec=({'start': 'None', 'stop': 'None', 'step': 'None'}, {'start': 'tf.Tensor(shape=(6,), dtype=int32)', 'stop': 'tf.Tensor(shape=(6,), dtype=int32)', 'step': 'None'}, {'start': 'tf.Tensor(shape=(6,), dtype=int32)', 'stop': 'tf.Tensor(shape=(6,), dtype=int32)', 'step': 'None'}, {'start': 'tf.Tensor(shape=(6,), dtype=int32)', 'stop': 'tf.Tensor(shape=(6,), dtype=int32)', 'step': 'None'}, {'start': 'None', 'stop': 'None', 'step': 'None'})
  • var=None

Here's a simplified version of my code:

import tensorflow as tf

def crop_3d(x: tf.Tensor, bounding_box: tf.Tensor):
    bounding_box = tf.cast(bounding_box, dtype=tf.int32)
    z_offset = bounding_box[0]
    y_offset = bounding_box[1]
    x_offset = bounding_box[2]
    z_size = bounding_box[3]
    y_size = bounding_box[4]
    x_size = bounding_box[5]
    cropped = x[:, z_offset:z_offset + z_size, y_offset:y_offset + y_size, x_offset:x_offset + x_size, :]
    return cropped

input_image = tf.keras.layers.Input(shape=(100, 100, 100, 3), name="inputs")
bounding_box = tf.keras.layers.Input(shape=(6,), name="bounding_box", dtype=tf.int32)
x = crop_3d(input_image, bounding_box)
model = tf.keras.Model(inputs=[input_image, bounding_box], outputs=x)

If it makes any difference, the size of the extracted bounding box is constant; I only want to specify the offset values for each batch item.


Solution

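  • The problem was that the batch size was not set in the Input layers and not considered in the crop function. Keras adds a batch dimension to every Input, so bounding_box has shape (batch, 6) and bounding_box[0] is a whole row of shape (6,) rather than a scalar, which is why it is rejected as a slice index. The version below fixes the batch size and crops each batch item separately; the loop over x.shape[0] needs that static batch size. Here is the corrected code: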

    import numpy as np
    import tensorflow as tf
    
    def crop_3d(x: tf.Tensor, bounding_box: tf.Tensor):
        bounding_box = tf.cast(bounding_box, dtype=tf.int32)
        cropped_outputs = []
        for i in range(x.shape[0]):
            z_offset = bounding_box[i][0]
            y_offset = bounding_box[i][1]
            x_offset = bounding_box[i][2]
            z_size = bounding_box[i][3]
            y_size = bounding_box[i][4]
            x_size = bounding_box[i][5]
            cropped = x[i, z_offset:z_offset + z_size, y_offset:y_offset + y_size, x_offset:x_offset + x_size, :]
            cropped_outputs.append(cropped)
    
        return tf.stack(cropped_outputs)
    
    
    input_image = tf.keras.layers.Input(shape=(100, 100, 100, 3), batch_size=1, name="image")
    bounding_box = tf.keras.layers.Input(shape=(6,), batch_size=1, name="bounding_box", dtype=tf.int32)
    x = crop_3d(input_image, bounding_box)
    model = tf.keras.Model(inputs=[input_image, bounding_box], outputs=x)
    
    # Random float32 data: np.empty is uninitialized and may hold NaN, which never compares equal
    image_np = np.random.rand(1, 100, 100, 100, 3).astype(np.float32)
    image = tf.convert_to_tensor(image_np, dtype=tf.float32)
    # Layout: (z_offset, y_offset, x_offset, z_size, y_size, x_size)
    bounding_box = np.array([[5, 7, 1, 15, 25, 65]])
    bounding_box = tf.convert_to_tensor(bounding_box, dtype=tf.int32)
    res = model([image, bounding_box]).numpy()
    assert np.all(image_np[:, 5:20, 7:32, 1:66] == res)
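Since the question says the crop size is constant and only the offsets change per batch item, a variant that does not depend on a fixed batch size may also be worth trying. The following is a minimal sketch, not part of the original answer: it assumes the crop size is known when the model is built and that only the three offsets (z, y, x) are fed as an input. The class name Crop3D and its crop_size argument are hypothetical; tf.map_fn and tf.slice are standard TensorFlow ops.

```python
import numpy as np
import tensorflow as tf


class Crop3D(tf.keras.layers.Layer):
    """Hypothetical helper: crops a fixed-size 3D window per batch item."""

    def __init__(self, crop_size, **kwargs):
        super().__init__(**kwargs)
        # (z_size, y_size, x_size), fixed when the layer is created (assumption)
        self.crop_size = tuple(int(s) for s in crop_size)

    def call(self, inputs):
        volumes, offsets = inputs
        volumes = tf.convert_to_tensor(volumes)
        offsets = tf.cast(offsets, tf.int32)

        def crop_one(args):
            volume, offset = args  # volume: (D, H, W, C), offset: (3,)
            # Start at (z, y, x) and at channel 0
            begin = tf.concat([offset, [0]], axis=0)
            # -1 keeps every remaining element along the channel axis
            size = list(self.crop_size) + [-1]
            return tf.slice(volume, begin, size)

        # Apply crop_one to each (volume, offset) pair along the batch axis
        return tf.map_fn(
            crop_one, (volumes, offsets), fn_output_signature=volumes.dtype
        )


# Eager usage with a batch of two volumes and a different offset for each
volumes = np.random.rand(2, 20, 20, 20, 3).astype(np.float32)
offsets = np.array([[1, 2, 3], [10, 0, 8]], dtype=np.int32)
crops = Crop3D(crop_size=(4, 5, 6))([tf.constant(volumes), tf.constant(offsets)])
```

The same layer should also be callable on Keras Input tensors, for example an image input together with an Input(shape=(3,), dtype=tf.int32) for the offsets, but that has only been reasoned through here rather than checked across TensorFlow versions. Offsets that push the window outside the volume make tf.slice raise an error, so they would need to be validated or clipped beforehand.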