Tags: python, tensorflow2.0, tensorflow-datasets

How to unpack the data from a TensorFlow dataset?


This is my code for loading data from a TFRecord file:

import tensorflow as tf


def read_tfrecord(tfrecord, epochs, batch_size):

    dataset = tf.data.TFRecordDataset(tfrecord)

    def parse(record):
        features = {
            "image": tf.io.FixedLenFeature([], tf.string),
            "target": tf.io.FixedLenFeature([], tf.int64)
        }
        example = tf.io.parse_single_example(record, features)
        image = decode_image(example["image"])
        label = tf.cast(example["target"], tf.int32)
        return image, label

    dataset = dataset.map(parse)
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.prefetch(buffer_size=batch_size)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.repeat(epochs)

    return dataset


x_train, y_train = read_tfrecord(tfrecord=train_files, epochs=EPOCHS, batch_size=BATCH_SIZE)

I got the following error:

ValueError: too many values to unpack (expected 2)

My question is: how do I unpack the data from the dataset?


Solution

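  • read_tfrecord returns a single tf.data.Dataset object, not an (images, labels) tuple. Writing x_train, y_train = read_tfrecord(...) makes Python iterate over the dataset, which yields one (image, label) pair per batch, so any dataset with more than two batches raises this ValueError. To get one batch, create an iterator and take its next element: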

    dataset = read_tfrecord(tfrecord=train_files, epochs=EPOCHS, batch_size=BATCH_SIZE)
    
    iterator = iter(dataset)
    
    x, y = next(iterator)
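
To make the behaviour easier to reproduce without TFRecord files, here is a minimal sketch that assumes TensorFlow 2.x with eager execution. The in-memory arrays `images` and `labels` are hypothetical stand-ins for the parsed records, not part of the original code. It shows the failing unpack, taking a single batch, and looping over every batch:

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the parsed TFRecord data:
# 8 fake "images" of shape (4, 4, 3) and 8 integer labels.
images = np.zeros((8, 4, 4, 3), dtype=np.float32)
labels = np.arange(8, dtype=np.int32)

BATCH_SIZE = 2
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)

# The dataset yields 4 batches, so unpacking it into two names fails.
unpack_error = None
try:
    x_train, y_train = dataset
except ValueError as err:
    unpack_error = str(err)

# Take a single batch: each element is an (images, labels) pair.
x, y = next(iter(dataset))

# Or loop over every batch, unpacking each element as it arrives.
batch_shapes = [(tuple(xb.shape), tuple(yb.shape)) for xb, yb in dataset]
```

If the goal is training a Keras model, the dataset usually does not need to be unpacked at all: `model.fit` accepts a `tf.data.Dataset` of `(inputs, targets)` pairs directly.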