Tags: tensorflow, keras, tensorflow-datasets, tensorflow-lite

TFLiteConverter representative_dataset from keras.preprocessing.image_dataset_from_directory dataset


I've got a dataset coming in via

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  validation_split=validation_split,
  subset="training",
  seed=seed,
  image_size=(img_height, img_width),
  batch_size=batch_size)

(Based on code from https://www.tensorflow.org/tutorials/load_data/images with very minor changes to configuration.)

I'm converting the eventual model to a TFLite model, which works, but I think the model is too large for the end device, so I'm trying to run post-training quantization by supplying a representative_dataset (as in https://www.tensorflow.org/lite/performance/post_training_quantization).

However, I can't work out how to turn the dataset generated by image_dataset_from_directory into the format expected by representative_dataset.

The example provided has

def representative_dataset():
  for data in tf.data.Dataset.from_tensor_slices((images)).batch(1).take(100):
    yield [data.astype(tf.float32)]

I've tried things like

def representative_dataset():
  for data in train_ds.batch(1).take(100):
    yield [data.astype(tf.float32)]

but that didn't work.


Solution

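  • It looks like this was what I was looking for:

    def representative_dataset():
      for image_batch, labels_batch in train_ds:
        yield [image_batch]

    image_batch is already tf.float32, so no cast is needed. train_ds also comes out of image_dataset_from_directory already batched, and each element is an (images, labels) pair, so the extra .batch(1) and the .astype call from my earlier attempt don't apply.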
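    In case it helps, here is a minimal sketch of how this might be wired into the converter while capping how many batches are used for calibration. The names make_representative_dataset, quantize and num_batches are hypothetical, not from the TensorFlow docs. The converter calls follow the post-training quantization guide linked in the question, and the sketch assumes a trained Keras model plus a dataset that yields (images, labels) pairs like train_ds above.

```python
from itertools import islice


def make_representative_dataset(batches, num_batches=100):
    """Return a generator function in the shape the converter expects.

    `batches` is any re-iterable source of (image_batch, label_batch) pairs,
    such as `train_ds` from the question. `num_batches` is a hypothetical cap
    so that calibration does not walk the whole training set.
    """
    def representative_dataset():
        # islice asks `batches` for a fresh iterator on every call.
        for image_batch, _labels in islice(batches, num_batches):
            # One list entry per model input; labels are not needed here.
            yield [image_batch]
    return representative_dataset


def quantize(model, batches, num_batches=100):
    """Sketch: convert a Keras model using post-training quantization."""
    import tensorflow as tf  # imported lazily; only needed for conversion

    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.representative_dataset = make_representative_dataset(
        batches, num_batches)
    return converter.convert()  # serialized TFLite model as bytes
```

    Usage would be something like tflite_model = quantize(model, train_ds), then writing the returned bytes to a .tflite file. Whether 100 batches is a sensible amount of calibration data depends on batch_size and the dataset, so treat that default as a placeholder.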