Tags: tensorflow, gpu, tensorflow2.0, gradienttape

Use of tf.GradientTape() exhausts all the GPU memory; without it, memory usage is fine


I'm working on Conv-TasNet, and the model I built has about 5.05 million trainable variables.

I want to train it with a custom training loop, and the problem is:

for i, (input_batch, target_batch) in enumerate(train_ds): # each shape is (64, 32000, 1)
    with tf.GradientTape() as tape:
        predicted_batch = cv_tasnet(input_batch, training=True) # model name
        loss = calculate_sisnr(predicted_batch, target_batch) # some custom loss
    trainable_vars = cv_tasnet.trainable_variables
    gradients = tape.gradient(loss, trainable_vars)
    cv_tasnet.optimizer.apply_gradients(zip(gradients, trainable_vars))

This part exhausts all the GPU memory (24 GB available).
When I tried the same loop without the tf.GradientTape() context,

for i, (input_batch, target_batch) in enumerate(train_ds):
        predicted_batch = cv_tasnet(input_batch, training=True)
        loss = calculate_sisnr(predicted_batch, target_batch)

This uses a reasonable amount of GPU memory (about 5-6 GB).
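
To quantify the gap, peak usage can be compared directly. A minimal sketch, assuming TF 2.5 or newer and a single GPU named 'GPU:0', reusing the names above (reset_memory_stats and get_memory_info are tf.config.experimental APIs; everything else is the code from the question):

import tensorflow as tf

tf.config.experimental.reset_memory_stats('GPU:0')
predicted_batch = cv_tasnet(input_batch, training=True)  # forward pass only
print(tf.config.experimental.get_memory_info('GPU:0')['peak'])

tf.config.experimental.reset_memory_stats('GPU:0')
with tf.GradientTape() as tape:  # forward pass recorded for autodiff
    predicted_batch = cv_tasnet(input_batch, training=True)
    loss = calculate_sisnr(predicted_batch, target_batch)
gradients = tape.gradient(loss, cv_tasnet.trainable_variables)
print(tf.config.experimental.get_memory_info('GPU:0')['peak'])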

I tried the same tf.GradientTape() pattern on the basic MNIST data, and it worked without a problem.
So does the size matter? But the same error arises even when I lower BATCH_SIZE to 32 or smaller.

Why does the first code block exhaust all the GPU memory?

Of course, I put

gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Currently, memory growth needs to be the same across GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized
        print(e)

this code in the very first cell.


Solution

  • Gradient tape triggers automatic differentiation, which requires keeping every intermediate activation, not just the weights, alive until tape.gradient() runs the backward pass. With inputs of shape (64, 32000, 1) flowing through a deep convolutional model, those stored activations dwarf the 5.05 million weights themselves, so autodiff needs several times the memory of the forward pass alone. This is normal. You'll have to manually tune your batch size until you find one that fits, then tune your LR; usually that tuning just means guess-and-check or a grid search (one way to shrink per-step memory is sketched below). (I am working on a product to do all of that for you, but I'm not here to plug it.)
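
A practical way to keep the effective batch size while shrinking per-step memory is gradient accumulation: run smaller micro-batches under the tape and apply the summed gradients every few steps. A minimal sketch, assuming the cv_tasnet, calculate_sisnr, and train_ds names from the question, with train_ds rebatched into smaller micro-batches; ACCUM_STEPS is a hypothetical knob, not part of the original code:

import tensorflow as tf

ACCUM_STEPS = 4  # e.g. 4 micro-batches of 16 act like one batch of 64
trainable_vars = cv_tasnet.trainable_variables
accum = [tf.Variable(tf.zeros_like(v), trainable=False) for v in trainable_vars]

for i, (input_batch, target_batch) in enumerate(train_ds):  # micro-batches
    with tf.GradientTape() as tape:
        predicted_batch = cv_tasnet(input_batch, training=True)
        loss = calculate_sisnr(predicted_batch, target_batch) / ACCUM_STEPS
    gradients = tape.gradient(loss, trainable_vars)
    for acc, g in zip(accum, gradients):
        acc.assign_add(g)  # sum scaled gradients across micro-batches
    if (i + 1) % ACCUM_STEPS == 0:
        cv_tasnet.optimizer.apply_gradients(zip(accum, trainable_vars))
        for acc in accum:
            acc.assign(tf.zeros_like(acc))  # reset for the next window

Only the current micro-batch's activations are alive under the tape, so peak memory drops roughly in proportion to ACCUM_STEPS, while the accumulated gradient matches the larger batch.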