Tags: python, tensorflow, tfrecord

How to release memory from TF2 graphs when using TFRecordDataset


Sorry if there are any mistakes in this question. I come from a PyTorch background, but I need to use TFRecordDataset in order to read TFRecord files. Currently, this looks like the following:

import tensorflow as tf

class TFRecordReader:
    def __iter__(self):
        dataset = tf.data.TFRecordDataset(
            [self.tfrecord_path], compression_type="GZIP"
        )
        dataset = dataset.map(self._parse_example)
        self._tfr_iter = iter(dataset)
        return self  # __iter__ must return the iterator object

    def __next__(self):
        return next(self._tfr_iter)

However, I need to create multiple TFRecordReaders per PyTorch worker to do batch balancing. This gives me 4 TFRecordDatasets (one per balancing bucket) per worker per GPU, so with 4 workers and 4 GPUs I end up with 4 * 4 * 4 = 64 TFRecordDatasets in memory. I have enough CPU memory for this, but the memory used by the TFRecordDatasets is not being released: usage grows steadily over the course of training. I believe the problem is that the computation graph keeps growing (every time a new TFRecord is read, a new TFRecordDataset is created for it) but is never released.

How can I make sure that the memory used by the TFRecordDataset is released after I finish iterating through a single TFRecord?
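
For reference, one pattern that ties the dataset's lifetime to a single pass over the file is to build it inside a generator, so both the TFRecordDataset and its iterator are local variables that can be garbage-collected once iteration finishes. A minimal sketch of that pattern, assuming the same tfrecord_path attribute and _parse_example method as above (the constructor shown here is hypothetical):

import tensorflow as tf

class TFRecordReader:
    def __init__(self, tfrecord_path, parse_example):
        self.tfrecord_path = tfrecord_path
        self._parse_example = parse_example

    def __iter__(self):
        def _gen():
            # Both the dataset and its implicit iterator are locals of this
            # generator, so they can be released as soon as the generator
            # is exhausted or dropped.
            dataset = tf.data.TFRecordDataset(
                [self.tfrecord_path], compression_type="GZIP"
            )
            dataset = dataset.map(self._parse_example)
            for example in dataset:
                yield example

        return _gen()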

I tried:

def __iter__(self):
    with tf.Graph().as_default() as g:
        dataset = tf.data.TFRecordDataset(
            [self.tfrecord_path], compression_type="GZIP"
        )
        dataset = dataset.map(self._parse_example)

        tf.compat.v1.enable_eager_execution()
        self._tfr_iter = iter(dataset)

        while True:
            try:
                example_dict = next(self._tfr_iter)
                # ...

However, I get the following error:

RuntimeError: __iter__() is only supported inside of tf.function or when eager execution is enabled.
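
For context: in TF2 eager execution is enabled by default, but it is disabled inside a tf.Graph().as_default() context, which is why the iter(dataset) call above fails there. A quick standalone check illustrates this:

import tensorflow as tf

print(tf.executing_eagerly())      # True: TF2 runs eagerly by default

with tf.Graph().as_default():
    # Inside a v1-style graph context, eager execution is turned off,
    # so iterating a tf.data.Dataset here raises the RuntimeError above.
    print(tf.executing_eagerly())  # False

print(tf.executing_eagerly())      # True again outside the context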

I would really appreciate any advice on how to make sure the memory does not keep growing. I am using TensorFlow 2.5, for reference.


Solution

  • The issue turned out to be the PyTorch profiler being used with PyTorch Lightning; it was not a TensorFlow issue.

    See relevant issue here
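
For anyone hitting the same symptom: profiler memory typically grows when every training step is recorded, and with the plain torch.profiler API this can be bounded by passing a schedule so only a few steps per cycle are traced. A minimal, self-contained sketch (the training step and data below are placeholders, not from the original setup):

import torch
from torch.profiler import ProfilerActivity, profile, schedule

def train_step(batch):
    # Placeholder for a real training step.
    return (batch * 2).sum()

train_loader = [torch.randn(8, 4) for _ in range(10)]  # placeholder data

# Record only a few steps per cycle instead of every step, so the
# profiler's buffers stay bounded over the course of training.
prof_schedule = schedule(wait=1, warmup=1, active=3, repeat=1)

with profile(activities=[ProfilerActivity.CPU], schedule=prof_schedule) as prof:
    for batch in train_loader:
        train_step(batch)
        prof.step()  # advance the profiler schedule each iteration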