Search code examples
pythontensorflowmachine-learningtfrecord

use numpy iterator for validation data


I am new to machine learning and i am trying to solve a problem with my code.
The training dataset I am using is saved in a tf.record file and since it is too large to be stored in memory I am using an iterator for the training set. The problem is that even the validation set is too big to put into memory (at least for my laptop with only 2GB of memory on the video card), so it is saved on disk as a tf.record too, I just don't think I can use the same trick as the iterator. What can I use then?

Code

#reading training and validation dataset
def read_tfrecord(example):
  tfrecord_format = (
      {
          "x": tf.io.FixedLenFeature([], tf.string),
          "y": tf.io.FixedLenFeature([], tf.string),
      }
  )
  example = tf.io.parse_single_example(example, tfrecord_format)
  x = tf.io.parse_tensor(example['x'], out_type=tf.float32)
  y = tf.io.parse_tensor(example['y'], out_type=tf.double)
    

  return x,y

filename = "train.tfrecord"
training_dataset = (tf.data.TFRecordDataset(filename).map(read_tfrecord))
iterator = training_dataset.repeat().prefetch(10).as_numpy_iterator()

filename = "validation.tfrecord"
validation_dataset = (tf.data.TFRecordDataset(filename).map(read_tfrecord))
val_iterator = validation_dataset.repeat().prefetch(10).as_numpy_iterator()

I call then fit method in this way

model.fit(iterator,
          validation_data=(val_iterator),
          epochs=35,
          verbose=1)

but the program fails to finish the first epoch, it stucks and never ends


Solution

  • found a solution using generator, I'll post the code

    #generator
    def generator(self, dataset, batch_size):
          ds = dataset.repeat().prefetch(tf.data.AUTOTUNE)
          iterator = iter(ds)
          x, y = iterator.get_next()
          
          while True:
            yield x, y
    
    #reading training and validation dataset
    def read_tfrecord(example):
      tfrecord_format = (
          {
              "x": tf.io.FixedLenFeature([], tf.string),
              "y": tf.io.FixedLenFeature([], tf.string),
          }
      )
      example = tf.io.parse_single_example(example, tfrecord_format)
      x = tf.io.parse_tensor(example['x'], out_type=tf.float32)
      y = tf.io.parse_tensor(example['y'], out_type=tf.double)
        
    
      return x,y
    
    filename = "train.tfrecord"
    training_dataset = tf.data.TFRecordDataset(filename).map(read_tfrecord)
    train_ds = generator(training_dataset, batch_size)
    
    filename = "validation.tfrecord"
    validation_dataset = (tf.data.TFRecordDataset(filename).map(read_tfrecord))
    valid_ds = generator(validation_dataset, batch_size)
    
    kwargs['validation_data'] = (valid_ds)
    
    #get your training step with something like this
    training_steps = x.shape[0]//batch_size
    validation_steps = x_val.shape[0]//batch_size
    
    model.fit(train_ds, steps_per_epoch = training_steps, validation_steps=validation_steps, **kwargs)