Tags: python, tensorflow, keras, tensorflow-datasets, batching

Irregular/varying batch size in TensorFlow?


I have a TensorFlow dataset and would like to batch it so that the batches do not all have the same size: examples would be grouped into batches whose sizes are given by a vector of values rather than by a single fixed value.

Is there a way to do it within tensorflow?

And for a network without fixed batch size, is feeding irregular batches going to be a problem?

Thanks in advance!


Solution

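  • The answer is yes. model.fit() accepts a Python generator, and that generator can yield batches of different sizes. A model whose input shape leaves the batch dimension unspecified (the Keras default) should accept such batches without changes.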

    x_train_batches = ...  # some list of data batches of uneven length
    y_train_batches = ...  # list of target batches with the SAME lengths as x_train_batches

    def train_gen(x_train_batches, y_train_batches):
        # Cycle through the pre-built batches forever; fit() ends each epoch
        # after steps_per_epoch batches.
        i = 0
        num_batches = len(x_train_batches)
        while True:
            yield (x_train_batches[i % num_batches], y_train_batches[i % num_batches])
            i += 1

    # steps_per_epoch = number of batches, so one epoch is one pass over the data
    model.fit(train_gen(x_train_batches, y_train_batches), epochs=epochs,
              steps_per_epoch=len(x_train_batches))
    

    Another, more elegant, way would be to subclass tf.keras.utils.Sequence and feed it to the model:

    from tensorflow import keras

    class UnevenSequence(keras.utils.Sequence):
        def __init__(self, x_batches, y_batches):
            super().__init__()
            # x_batches, y_batches are lists of uneven batches
            self.x, self.y = x_batches, y_batches

        def __len__(self):
            # Number of batches per epoch
            return len(self.x)

        def __getitem__(self, idx):
            # Return the idx-th pre-built batch as (inputs, targets)
            return (self.x[idx], self.y[idx])

    my_uneven_sequence = UnevenSequence(x_train_batches, y_train_batches)

    model.fit(my_uneven_sequence, epochs=10)
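
Both approaches assume the uneven batches already exist as lists. The sketch below shows one possible way to build them from a vector of batch sizes, as the question describes. It assumes the data fits in memory as NumPy arrays; the helper name split_into_batches and the toy arrays are hypothetical, not part of the original answer.

```python
import numpy as np

def split_into_batches(x, y, batch_sizes):
    """Cut x and y into consecutive batches whose lengths follow batch_sizes."""
    if len(x) != len(y) or sum(batch_sizes) != len(x):
        raise ValueError("batch_sizes must sum to the number of examples")
    # np.split takes the indices at which to cut: the running totals of the
    # batch sizes, without the final total (which is just the end of the array).
    cut_points = np.cumsum(batch_sizes)[:-1]
    return np.split(x, cut_points), np.split(y, cut_points)

# Toy data (hypothetical): 10 examples with 2 features each, plus 10 targets.
x = np.arange(20, dtype="float32").reshape(10, 2)
y = np.arange(10)

x_train_batches, y_train_batches = split_into_batches(x, y, [3, 5, 2])
print([len(b) for b in x_train_batches])  # [3, 5, 2]
```

The two resulting lists can then be passed to train_gen or UnevenSequence as shown above.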