Tags: tensorflow, keras, input, efficientnet, ragged-tensors

Correct way to iterate over Keras ragged tensor


I have an input TensorFlow ragged tensor structured as [batch, num_images, width, height, channels], and I need to iterate over the num_images dimension to extract some features relevant for downstream applications. Example code is the following:

from tensorflow.keras.applications.efficientnet import EfficientNetB7
from tensorflow.keras.layers import Input
import tensorflow as tf


eff_net = EfficientNetB7(weights='imagenet', include_top=False)
input_claim = Input(shape=(None, 600, 600, 3), name='input_1', ragged=True)
# Intended: map eff_net over the batch dimension, i.e. over each
# [num_images, 600, 600, 3] stack of images
eff_out = tf.map_fn(fn=eff_net,
                    elems=input_claim, fn_output_signature=tf.float32)

The first dimension of the Input (num_images) is set to None because it can differ across data points; for this reason the input receives instances of tf.RaggedTensor.

This code breaks with a TypeError:

    TypeError: Could not build a TypeSpec for KerasTensor(type_spec=RaggedTensorSpec(TensorShape([None, None, 600, 600, 3]), tf.float32, 1, tf.int64), name='input_1', description="created by layer 'input_1'") of unsupported type <class 'keras.engine.keras_tensor.RaggedKerasTensor'>

I suspect there is a better way to perform this type of preprocessing, though.

Update: the num_images dimension is needed because (although not described here) I subsequently apply a reduce operation over it.
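
For illustration, that kind of reduce over the ragged num_images axis might look like this minimal sketch (the feature values below are made up, not my actual data):

import tensorflow as tf

# Hypothetical per-image features: shape [batch, num_images (ragged), feature_dim]
features = tf.ragged.constant(
    [[[1.0, 2.0], [3.0, 4.0]],   # data point with 2 images
     [[5.0, 6.0]]],              # data point with 1 image
    ragged_rank=1)

# Reducing over the ragged axis yields a dense [batch, feature_dim] tensor
pooled = tf.reduce_mean(features, axis=1)
print(pooled)
# tf.Tensor(
# [[2. 3.]
#  [5. 6.]], shape=(2, 2), dtype=float32)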


Solution

  • You can use tf.ragged.map_flat_values to achieve the same result: it applies a function to the flat (non-ragged) values of the ragged tensor and then reattaches the row splits. See also the sketch after the testing example below for plugging in the real EfficientNet.

    Create a model like:

    from tensorflow import keras
    from tensorflow.keras import layers
    import tensorflow as tf


    def eff_net(x):  # dummy eff_net for testing that returns [batch, dim]
        return tf.random.normal(shape=tf.shape(x)[:2])

    input_claim = keras.Input(shape=(None, 600, 600, 3), name='input_1', ragged=True)

    class RaggedMapLayer(layers.Layer):
        def call(self, x):
            # apply eff_net to the flat values [total_images, 600, 600, 3]
            # and reattach the original row splits
            return tf.ragged.map_flat_values(eff_net, x)

    outputs = RaggedMapLayer()(input_claim)

    model = keras.Model(inputs=input_claim, outputs=outputs)


    Testing:

    inputs = tf.RaggedTensor.from_row_splits(
        tf.random.normal(shape=(10, 600, 600, 3)), row_splits=[0, 2, 5, 10])
    # shape [3, None, 600, 600, 3]

    model(inputs).shape
    # [3, None, 600]
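
    Wrapping the op in a custom Layer means it only ever sees real RaggedTensors at call time, rather than the symbolic RaggedKerasTensor that tf.map_fn failed on in the question. Putting the pieces together with the real EfficientNetB7 could look like the sketch below; the pooling='avg' argument (so each image maps to a 2560-dim vector and the row splits still line up) and the final mean over num_images are assumptions based on your description of the downstream reduce, not the only way to do it:

    from tensorflow import keras
    from tensorflow.keras import layers
    from tensorflow.keras.applications.efficientnet import EfficientNetB7
    import tensorflow as tf

    # pooling='avg' makes eff_net map [total_images, 600, 600, 3] to
    # per-image feature vectors [total_images, 2560], so the row splits
    # of the ragged input can be reattached to the result
    eff_net = EfficientNetB7(weights='imagenet', include_top=False, pooling='avg')

    class RaggedMapLayer(layers.Layer):
        def call(self, x):
            # apply eff_net to the flat values and restore the ragged structure
            return tf.ragged.map_flat_values(eff_net, x)

    class RaggedReduceMean(layers.Layer):
        def call(self, x):
            # assumed downstream reduce: mean over the ragged num_images axis,
            # yielding a dense [batch, 2560] tensor
            return tf.reduce_mean(x, axis=1)

    input_claim = keras.Input(shape=(None, 600, 600, 3), name='input_1', ragged=True)
    features = RaggedMapLayer()(input_claim)   # [batch, None, 2560]
    pooled = RaggedReduceMean()(features)      # [batch, 2560]
    model = keras.Model(inputs=input_claim, outputs=pooled)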