What I need help with / What I was wondering
I am performing cross-validation using the Keras API, and have put all the code for one round of CV into a single function. The first round of CV runs fine, but on the second round I get an OOM error when trying to build the next model.
import tensorflow as tf

def run_fold_training(k_fold, num_folds, batch_size):
    # clear graph
    tf.keras.backend.clear_session()
    # try to get tpu or else gpu
    try:
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
        print('Device:', tpu.master())
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        strategy = tf.distribute.experimental.TPUStrategy(tpu)
    except:
        strategy = tf.distribute.get_strategy()
    print('Number of replicas:', strategy.num_replicas_in_sync)
    with strategy.scope():
        # make k-fold dataset
        ds = build_dataset()
        train_ds = ds.enumerate().filter(
            lambda i, ds, num_folds=num_folds, k_fold=k_fold: i % num_folds != k_fold).map(
            lambda i, ds: ds).batch(batch_size)
        test_ds = ds.enumerate().filter(
            lambda i, ds, num_folds=num_folds, k_fold=k_fold: i % num_folds == k_fold).map(
            lambda i, ds: ds).batch(batch_size)
        # make, train, evaluate model
        model = MyModel(**model_kwargs)
        model.compile(**compile_kwargs)
        model.fit(train_ds, epochs=25)
        results = model.evaluate(test_ds, return_dict=True)
    return results["score"]

num_folds = 5
batch_size = 8
cv_loss = sum([run_fold_training(k, num_folds, batch_size) for k in range(num_folds)]) / num_folds
print(f"Final {num_folds}-fold cross validation score is: {cv_loss}")
What I've tried so far
I'm clearing the Keras backend at the start of each CV round, and I'm also creating a new distribute strategy scope per round. I've already tried batch sizes of [1, 2, 4, 8]. For every batch size the first round completes fine, but OOM occurs at the start of the next round.
It would be nice if...
It would be great if there were access to lower-level control over memory management, perhaps in tiers of complexity. The simplest case would be a function that frees all device memory related to a certain graph. In TF1 I would have just made a new session per CV round, and this wouldn't be a problem.
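For reference, a minimal sketch of the TF1-style per-session pattern alluded to above (the function name and placeholder body are illustrative, not taken from the original code):

import tensorflow.compat.v1 as tf1

def run_fold_training_tf1(k_fold, num_folds, batch_size):
    # Build a fresh graph for this fold so nothing carries over between rounds.
    graph = tf1.Graph()
    with graph.as_default(), tf1.Session(graph=graph) as sess:
        # ... build ops for this fold, then train/evaluate via sess.run(...) ...
        pass
    # The session is closed here; the device memory held for this fold's graph
    # becomes reusable by the next round.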
Environment information (if applicable)
tensorflow/tensorflow:2.3.1-gpu
The answer was discovered by a friend. If there are references to graph ops/variables created outside the run_fold_training function, then clear_session will not completely work. The solution is to make sure that the entire new graph is created after clear_session, e.g. don't reuse optimizers across rounds.
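To make that concrete, here is a hedged sketch of the difference; the optimizer choice, learning rate, and loss are placeholders (the question's code passes **compile_kwargs instead), and MyModel / model_kwargs follow the question's code:

import tensorflow as tf

# Problematic: an object created at module level keeps references to the
# previous round's graph, so clear_session() cannot fully release that memory.
shared_optimizer = tf.keras.optimizers.Adam(1e-3)  # reused across folds -> leaks

def run_fold_training_bad(k_fold, num_folds, batch_size):
    tf.keras.backend.clear_session()
    model = MyModel(**model_kwargs)
    model.compile(optimizer=shared_optimizer, loss="mse")  # ties the new model to old graph state
    ...

# Fix: create every Keras object (optimizer, metrics, callbacks, the model itself)
# inside the function, after clear_session(), so each fold builds an entirely new graph.
def run_fold_training_good(k_fold, num_folds, batch_size):
    tf.keras.backend.clear_session()
    model = MyModel(**model_kwargs)
    model.compile(optimizer=tf.keras.optimizers.Adam(1e-3), loss="mse")
    ...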