
Keras model with TensorFlow TFRecord Dataset error -- rank is undefined


I'm using a fairly standard TFRecord dataset. The records are tf.train.Example protobufs. The "image" feature is a 28 by 28 float64 tensor serialised with tf.io.serialize_tensor, and the "label" feature is an int64.

import tensorflow as tf

feature_description = {
    "image": tf.io.FixedLenFeature((), tf.string),
    "label": tf.io.FixedLenFeature((), tf.int64)}

image_shape = (28, 28)

def preprocess(example):
    example = tf.io.parse_single_example(example, feature_description)
    image, label = example["image"], example["label"]
    image = tf.io.parse_tensor(image, out_type=tf.float64)
    return image, label

batch_size = 32
dataset = tf.data.TFRecordDataset("data/train.tfrecord")\
                 .map(preprocess).batch(batch_size).prefetch(1)
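For reference, records in the format described above could be written roughly as follows. This is a sketch, not the asker's code: the helper name write_tfrecord is hypothetical, and the float64 dtype is assumed so that it matches out_type in preprocess.

```python
import numpy as np
import tensorflow as tf

def write_tfrecord(path, images, labels):
    """Write (image, label) pairs as tf.train.Example records.

    Hypothetical helper: assumes `images` has shape (N, 28, 28) and
    `labels` has shape (N,) with integer values.
    """
    with tf.io.TFRecordWriter(path) as writer:
        for image, label in zip(images, labels):
            # Serialise the whole 28x28 tensor into a single bytes string.
            image_bytes = tf.io.serialize_tensor(
                tf.constant(image, dtype=tf.float64)).numpy()
            example = tf.train.Example(features=tf.train.Features(feature={
                "image": tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[image_bytes])),
                "label": tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[int(label)])),
            }))
            writer.write(example.SerializeToString())
```

Note that only the raw bytes are stored under "image"; the (28, 28) shape travels inside those bytes, so nothing in the Example schema tells the reading pipeline what shape to expect.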

I have the following simple Keras model:

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=image_shape))
model.add(tf.keras.layers.Dense(10, activation="softmax"))
model.compile(loss="sparse_categorical_crossentropy", optimizer="sgd", metrics=["accuracy"])

and whenever I call fit or predict on this model with the dataset

model.fit(dataset)
model.predict(dataset)

I get the following error:

ValueError: Input 0 of layer sequential is incompatible with the layer: its rank is undefined, but the layer requires a defined rank.

Strangely, if I instead create an equivalent dataset from the in-memory arrays via tf.data.Dataset.from_tensor_slices((images, labels)), the error does not occur, even though it yields exactly the same items.
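The difference shows up in each dataset's element_spec, which describes the static shape Keras sees before any data is read. A minimal sketch, using random arrays as hypothetical stand-in data and an in-memory list of serialised strings in place of the TFRecord file:

```python
import numpy as np
import tensorflow as tf

images = np.random.rand(4, 28, 28)  # hypothetical float64 stand-in data

# Path 1: tensors serialised to strings and parsed back on the fly,
# mimicking what preprocess() does with each TFRecord.
serialized = [tf.io.serialize_tensor(img).numpy() for img in images]
parsed_ds = tf.data.Dataset.from_tensor_slices(serialized).map(
    lambda s: tf.io.parse_tensor(s, out_type=tf.float64))

# Path 2: the array sliced directly, so the shape comes from the array.
sliced_ds = tf.data.Dataset.from_tensor_slices(images)

print(parsed_ds.element_spec)  # static shape unknown: rank is undefined
print(sliced_ds.element_spec)  # static shape (28, 28)
```

Both datasets yield identical values; only the statically known shape differs.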


Solution

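  • Keras needs to know at least the rank of the model's input when the dataset is passed in. But tf.io.parse_tensor can decode a serialised tensor of any shape, and the decoding only happens on the fly as records are streamed, so the static shape of the parsed image is completely unknown (not even its rank is known). from_tensor_slices, by contrast, takes the shape from the in-memory array, which is why that version works.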

    This is easily fixed with tf.ensure_shape, which checks each tensor's shape at runtime and also gives the result a static shape that Keras can read:

    def preprocess(example):
        example = tf.io.parse_single_example(example, feature_description)
        image, label = example["image"], example["label"]
        image = tf.io.parse_tensor(image, out_type=tf.float64)
        image = tf.ensure_shape(image, image_shape)    # THE FIX: declare (and check) the (28, 28) shape
        return image, label
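A self-contained sketch of the fix end to end, assuming random in-memory stand-in data instead of data/train.tfrecord (the parse function below is a hypothetical simplification of preprocess that skips the Example protobuf layer). It uses tf.keras.Input rather than input_shape= so that it runs on both older and newer Keras versions:

```python
import numpy as np
import tensorflow as tf

image_shape = (28, 28)

# Hypothetical stand-in data: 64 random images with labels 0..9.
images = np.random.rand(64, *image_shape)
labels = np.random.randint(0, 10, size=64)
serialized = [tf.io.serialize_tensor(img).numpy() for img in images]

def parse(s, label):
    image = tf.io.parse_tensor(s, out_type=tf.float64)
    # Same fix as above: runtime check plus a static (28, 28) shape.
    image = tf.ensure_shape(image, image_shape)
    return image, label

dataset = (tf.data.Dataset.from_tensor_slices((serialized, labels))
           .map(parse).batch(32).prefetch(1))
print(dataset.element_spec)  # image spec now has shape (None, 28, 28)

model = tf.keras.Sequential([
    tf.keras.Input(shape=image_shape),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(loss="sparse_categorical_crossentropy", optimizer="sgd",
              metrics=["accuracy"])
history = model.fit(dataset, epochs=1, verbose=0)
```

A related option is image.set_shape(image_shape) inside the mapped function. It only declares the static shape of the traced tensor, whereas tf.ensure_shape also fails loudly at runtime if a record decodes to a different shape.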