
How can I use a TensorFlow dataset (TFDS) as an input for a TensorFlow model?


I am currently working with the ImageNet data set, and as you may know it is extremely large.

I have preprocessed it from .tar files into tfrecord files.

I am currently loading it using:

train, val = tfds.load(*)

So I have two tf.data.Dataset objects: train and val.

I am then adjusting them using:

def resize_with_crop(image, label):
    i = image
    i = tf.cast(i, tf.float32)
    i = tf.image.resize_with_crop_or_pad(i, 224, 224)
    i = tf.keras.applications.mobilenet_v2.preprocess_input(i)
    return (i, label)

# Preprocess the images
train = train.map(resize_with_crop)
val = val.map(resize_with_crop)

which I am following from here.

I then try to fit my model with d = model.fit(train, validation_data=val, ...), where the first layer has shape (None, 224, 224, 3), and I receive the error: ValueError: Input 0 of layer conv2d is incompatible with the layer: expected ndim=4, found ndim=3

This issue (I believe) arises because the model is being given one image at a time, so the input doesn't have a 4D shape. I cannot hold the dataset in memory to restructure it as (None, 224, 224, 3) the way I would for a CIFAR-10 dataset.

My question is: now that the images are of shape (224, 224, 3), how can I use them with a TensorFlow model that expects a 4D input when I can't reshape the dataset in memory?

Or is there a way to adjust the tfds shape so that it works as an input for the model?

I am not sure I fully understand tfds, which is why I am having this issue. Additionally, I am sure that the labels will cause an issue (since they're ints), so how can I restructure the tfds labels to be one-hot encoded for the model?


Solution

  • tfds.load returns a tf.data.Dataset object, so you can do with the returned value(s) whatever is possible with any TensorFlow dataset.

    A 4D input is generally expected as (batch_size, height, width, channels). So, if your images have shape (224, 224, 3), you need to batch them in order to add the batch dimension the model expects.

    For batching a dataset, simply use .batch(batch_size):

    train = train.batch(32)
    val = val.batch(32)
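    As for the integer labels, they can be converted to one-hot vectors with an extra .map step before batching. Below is a minimal sketch using a small synthetic dataset as a stand-in for the tfds pipeline; NUM_CLASSES is hypothetical here (ImageNet would use 1000):

    import tensorflow as tf

    NUM_CLASSES = 10  # hypothetical; ImageNet would use 1000

    def to_one_hot(image, label):
        # Convert the integer label to a one-hot vector of length NUM_CLASSES
        return image, tf.one_hot(label, NUM_CLASSES)

    # Synthetic stand-in for the tfds pipeline: 8 random 224x224 RGB images
    images = tf.random.uniform((8, 224, 224, 3))
    labels = tf.constant([0, 1, 2, 3, 4, 5, 6, 7])
    ds = tf.data.Dataset.from_tensor_slices((images, labels))

    # Map labels to one-hot, then batch to add the leading batch dimension
    ds = ds.map(to_one_hot).batch(4)

    for x, y in ds.take(1):
        print(x.shape)  # (4, 224, 224, 3) -- now 4D, as the model expects
        print(y.shape)  # (4, 10)

    With one-hot labels, the model should be compiled with a categorical loss such as categorical_crossentropy; alternatively, you can keep the integer labels and use sparse_categorical_crossentropy, which avoids the extra map entirely.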