tensorflow keras out-of-memory gpu cross-validation

OOM in second round of cross-validation


What I need help with / What I was wondering

I am performing cross-validation with the Keras API and have put all the code for one round of CV into a single function. The first round of CV works, but in the second round I get an OOM error when trying to build the next model.

  • Why is this happening?
  • How do I properly do this type of CV from a single Python process?
  • Is there a way to completely flush the GPU/TPU memory to control things like memory fragmentation?

import tensorflow as tf

def run_fold_training(k_fold, num_folds, batch_size):
    #clear graph
    tf.keras.backend.clear_session()
    #try to get tpu or else gpu
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        print('Device:', tpu.master())
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        strategy = tf.distribute.experimental.TPUStrategy(tpu)
    except Exception:  # no TPU found, fall back to the default strategy
        strategy = tf.distribute.get_strategy()
    print('Number of replicas:', strategy.num_replicas_in_sync)
    with strategy.scope():
        # make k-fold dataset
        ds = build_dataset()
        train_ds = ds.enumerate().filter(
            lambda i, ds, num_folds=num_folds, k_fold=k_fold: i % num_folds != k_fold).map(
            lambda i, ds: ds).batch(batch_size)
        test_ds = ds.enumerate().filter(
            lambda i, ds, num_folds=num_folds, k_fold=k_fold: i % num_folds == k_fold).map(
            lambda i, ds: ds).batch(batch_size)
        # make, train, evaluate model
        model = MyModel(**model_kwargs)
        model.compile(**compile_kwargs)
        model.fit(train_ds, epochs=25)
        results = model.evaluate(test_ds, return_dict=True)

    return results["score"]

num_folds = 5
batch_size = 8
cv_score = sum(run_fold_training(k, num_folds, batch_size) for k in range(num_folds)) / num_folds
print(f"Final {num_folds}-fold cross validation score is: {cv_score}")
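
The fold split used in the function above relies on the element index modulo num_folds. A minimal plain-Python sketch of the same partitioning, with no TensorFlow involved; split_fold is a hypothetical helper name used only for illustration:

```python
def split_fold(items, num_folds, k_fold):
    """Partition items by index modulo num_folds, as the enumerate/filter calls do."""
    train = [x for i, x in enumerate(items) if i % num_folds != k_fold]
    test = [x for i, x in enumerate(items) if i % num_folds == k_fold]
    return train, test


train, test = split_fold(list(range(10)), num_folds=5, k_fold=1)
print(train)  # [0, 2, 3, 4, 5, 7, 8, 9]
print(test)   # [1, 6]
```

Every element lands in the test split of exactly one fold, so the five rounds together evaluate on the whole dataset once.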

What I've tried so far

I'm clearing the Keras backend at the start of each CV round, and I'm also creating a new distribution strategy scope per round. I've already tried batch sizes of 1, 2, 4, and 8. For every batch size the first round runs fine, but I get an OOM error at the start of the next round.

It would be nice if...

It would be great if there were access to lower-level control over memory management. This could come in tiers of complexity. The simplest case would be a function that frees all device memory related to a certain graph. In TF1 I would have just made a new session per CV round, and this wouldn't have been a problem.
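
One knob that does exist for GPU memory is tf.config.experimental.set_memory_growth, which asks TensorFlow to allocate GPU memory on demand instead of reserving nearly all of it up front. The sketch below only shows how it is called; it does not free memory that is already allocated, and it is not confirmed to resolve the OOM described here. The helper name enable_memory_growth is hypothetical.

```python
import tensorflow as tf


def enable_memory_growth():
    """Request on-demand GPU memory allocation; returns how many GPUs accepted it."""
    enabled = 0
    for gpu in tf.config.list_physical_devices("GPU"):
        try:
            tf.config.experimental.set_memory_growth(gpu, True)
            enabled += 1
        except RuntimeError as err:
            # Raised if the GPU was already initialized; the call must come first.
            print(f"Could not set memory growth for {gpu.name}: {err}")
    return enabled


print("GPUs with memory growth enabled:", enable_memory_growth())
```

This has to run before any op or model touches the GPU, so it belongs at the very top of the script, ahead of the first call to run_fold_training.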

Environment information (if applicable)

  • Operating System: Ubuntu 18.04
  • Python version: 3.8
  • Docker: tensorflow/tensorflow:2.3.1-gpu

Solution

  • The answer was discovered by a friend. If there are references to graph ops or variables created outside the run_fold_training function, clear_session will not fully work. The solution is to make sure that an entirely new graph is created after clear_session, e.g. don't reuse optimizers across folds.
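
A minimal sketch of that fix, under stated assumptions: the model is a tiny stand-in for MyModel, the data is an in-memory NumPy array rather than the original dataset, and build_model, build_compile_kwargs, and the features/targets parameters are hypothetical names that do not appear in the question. The point is only that the optimizer, loss, and metric objects are constructed inside each fold, after clear_session, instead of living in a module-level dict such as compile_kwargs.

```python
import numpy as np
import tensorflow as tf


def build_model():
    # Hypothetical stand-in for MyModel(**model_kwargs).
    return tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(8, activation="relu"),
        tf.keras.layers.Dense(1),
    ])


def build_compile_kwargs():
    # New optimizer, loss, and metric objects on every call, so nothing
    # created for a previous fold is reused.
    return {
        "optimizer": tf.keras.optimizers.Adam(learning_rate=1e-3),
        "loss": tf.keras.losses.MeanSquaredError(),
        "metrics": [tf.keras.metrics.MeanAbsoluteError(name="score")],
    }


def run_fold_training(k_fold, num_folds, batch_size, features, targets):
    tf.keras.backend.clear_session()
    strategy = tf.distribute.get_strategy()
    # Same index-modulo split as in the question, done on NumPy arrays.
    fold_ids = np.arange(len(features)) % num_folds
    train_mask = fold_ids != k_fold
    with strategy.scope():
        # Everything that owns variables is created here, after clear_session.
        model = build_model()
        model.compile(**build_compile_kwargs())
    model.fit(features[train_mask], targets[train_mask],
              batch_size=batch_size, epochs=1, verbose=0)
    results = model.evaluate(features[~train_mask], targets[~train_mask],
                             batch_size=batch_size, verbose=0, return_dict=True)
    return results["score"]
```

If the original compile_kwargs held an optimizer instance at module level, replacing it with a factory like build_compile_kwargs is one way to apply the advice without restructuring the rest of the loop.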