
Tensorflow computing an aggregate function over lots of data in tfrecord format


I have a very large dataset stored in chunks of 5000 records across several tfrecord files. Together these records are much larger than my RAM. I would like to sample N = 0.05 * TOTAL_SIZE random indexes from the dataset and compute the mean and standard deviation over those records to normalize my data.
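One detail in the sampling step: drawing indexes with `np.random.randint` samples with replacement, so duplicates make the resulting set smaller than `sample_size`, and dividing a sum by `sample_size` then biases the mean toward zero. A minimal sketch that draws distinct indexes instead (the helper name `sample_indexes` is hypothetical):

```python
import numpy as np


def sample_indexes(count, fraction=0.05, seed=None):
    """Return a set of distinct record indexes covering `fraction` of `count`.

    replace=False guarantees exactly int(count * fraction) distinct indexes,
    so dividing an accumulated sum by that number gives an exact mean.
    """
    rng = np.random.default_rng(seed)
    sample_size = int(count * fraction)
    indexes = rng.choice(count, size=sample_size, replace=False)
    return set(indexes.tolist())


# Sampling with replacement (same distribution as np.random.randint)
# usually yields fewer distinct indexes than requested.
rng = np.random.default_rng(0)
with_replacement = set(rng.integers(0, 1000, size=500).tolist())
without_replacement = sample_indexes(1000, fraction=0.5, seed=0)
print(len(with_replacement), len(without_replacement))
```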

If it weren't for the dataset size this would be easy, but I run out of memory even when I just try to compute the sum of all the tensors I'm interested in.
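For scale, a back-of-the-envelope check using the (180, 2050) float32 feature shape from the code below shows that the running sum itself is only a few MiB, so the accumulator alone should not exhaust RAM:

```python
# Shape of one "val" feature, as declared in the FixedLenFeature spec.
ROWS, COLS = 180, 2050

values_per_record = ROWS * COLS            # number of floats in one record
record_bytes = values_per_record * 4       # float32 payload of one record
accumulator_bytes = values_per_record * 8  # np.zeros defaults to float64

print(f"values per record: {values_per_record:,}")
print(f"float32 record payload: {record_bytes / 2**20:.2f} MiB")
print(f"float64 accumulator: {accumulator_bytes / 2**20:.2f} MiB")
```

If memory keeps growing as records are read, something other than this fixed-size accumulator is growing per record.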

# NOTE: count is computed ahead of time by looping over all the tfrecord entries
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.python_io)

stat_graph = tf.Graph()
# Make stat_graph the default so the parsing ops land in the session's graph.
with stat_graph.as_default(), tf.device('/cpu:0'):
    sample_size = int(count * 0.05)
    # Draw distinct indexes so that exactly sample_size records are summed.
    random_indexes = set(np.random.choice(count, size=sample_size, replace=False))
    with tf.Session(graph=stat_graph) as sess:
        val_sum = np.zeros(shape=(180, 2050))
        index = 0  # position of the current record across all files
        for file in files:
            print("Reading from file: %s" % file)
            for record in tf.python_io.tf_record_iterator(file):
                features = tf.parse_single_example(
                    record,
                    features={
                        "val": tf.FixedLenFeature((180, 2050), tf.float32),
                    })
                if index in random_indexes:
                    val_sum += features["val"].eval(session=sess)
                index += 1
        val_mean = val_sum / sample_size

What is the right way to compute an aggregate function, e.g. the mean and/or standard deviation, over a tfrecord dataset?


Solution

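  • I think the problem is that tf.parse_single_example adds new ops to the graph every time it's called (and each record string passed in directly is embedded in the graph as a constant), so the graph, and your memory use, grows with every record you read. Instead, build the parsing op once and feed each serialized record to it through a placeholder: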

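    import numpy as np
    import tensorflow as tf  # TensorFlow 1.x API
    # files, random_indexes and sample_size as in the question.
    # Build the parsing op once, outside the loop; records are fed in at run time.
    record_placeholder = tf.placeholder(tf.string)
    features = tf.parse_single_example(
        record_placeholder,
        features={
            "val": tf.FixedLenFeature((180, 2050), tf.float32),
        })

    val_sum = np.zeros(shape=(180, 2050))
    index = 0
    with tf.Session() as sess:
        for file in files:
            for record in tf.python_io.tf_record_iterator(file):
                if index in random_indexes:
                    val_sum += features["val"].eval(
                        feed_dict={record_placeholder: record}, session=sess)
                index += 1
    val_mean = val_sum / sample_size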

    Let me know if this works since I have no way of testing it.
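The loop above only accumulates a sum for the mean, while the question also asks for the standard deviation. One option, sketched here with NumPy and not part of the answer itself, is Welford's online algorithm applied elementwise: it computes mean and variance in a single pass with constant memory. The class name `RunningStats` is hypothetical, and the synthetic arrays below merely stand in for parsed `features["val"]` values:

```python
import numpy as np


class RunningStats:
    """Elementwise running mean and variance (Welford's algorithm).

    Memory use is constant: one mean array and one M2 array of the
    record's shape, regardless of how many records are added.
    """

    def __init__(self, shape):
        self.n = 0
        self.mean = np.zeros(shape, dtype=np.float64)
        self.m2 = np.zeros(shape, dtype=np.float64)  # sum of squared deviations

    def add(self, x):
        x = np.asarray(x, dtype=np.float64)
        self.n += 1
        delta = x - self.mean            # deviation from the old mean
        self.mean += delta / self.n      # update the mean
        self.m2 += delta * (x - self.mean)  # uses old and new mean

    def std(self, ddof=0):
        """Population std by default (ddof=0), matching np.std."""
        if self.n - ddof <= 0:
            raise ValueError("not enough samples")
        return np.sqrt(self.m2 / (self.n - ddof))


# Check against NumPy on small synthetic "records".
rng = np.random.default_rng(0)
records = rng.normal(loc=3.0, scale=2.0, size=(50, 4, 5)).astype(np.float32)
stats = RunningStats(shape=(4, 5))
for r in records:
    stats.add(r)
print(np.allclose(stats.mean, records.astype(np.float64).mean(axis=0)))
print(np.allclose(stats.std(), records.astype(np.float64).std(axis=0)))
```

In the TensorFlow loop, `stats = RunningStats((180, 2050))` would replace `val_sum`, and `stats.add(...)` would replace `val_sum += ...`; `stats.mean` and `stats.std()` then give the normalization statistics.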