Tags: deep-learning, pytorch, ram, ray

Small PyTorch networks take almost 3 GB of RAM to train on MNIST


I am running into problems with PyTorch. I have to run experiments on many custom PyTorch models, and since there are so many to train, I tried to run them in parallel with ray[tune]. The code runs "fine", but I can't parallelize much because the system runs out of RAM: each PyTorch model in training uses about 2 to 2.5 GB of RAM. The networks are really tiny (3 layers in total) and I am testing on MNIST, so everything should be lightweight. Why does training such tiny networks use so much RAM? This becomes a real problem when running in parallel, since just 4 trainings at once take more than 8 GB of RAM. Is this normal?

Naturally, I first suspected my own code, so I wrote a self-contained PyTorch script that trains a tiny network on MNIST, to see how much memory it uses. The code is deliberately as standard as possible:

# FILENAME: selfcontained_pytorch_tester.py
# Importing the necessary modules for the training and testing of a simple network in pytorch with MNIST

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

# Checking if the GPU is available and setting the device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using {device} device")

# Loading the data
training_data = datasets.MNIST(
    root="Data_SelfContained_Pytorch_Testing/data_train_MNIST",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.MNIST(
    root="Data_SelfContained_Pytorch_Testing/data_test_MNIST",
    train=False,
    download=True,
    transform=ToTensor()
)


# Creating the dataloaders
dataloader_train = DataLoader(training_data, batch_size=64)
dataloader_test = DataLoader(test_data, batch_size=64)


# Defining the network
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)  # raw logits: no final ReLU, CrossEntropyLoss expects unbounded scores
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


# Creating the model
model = SimpleNN().to(device)

# Defining the loss function
loss_fn = nn.CrossEntropyLoss()

# Defining the optimizer Adam
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


# Defining the training function
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()  # test() switches the model to eval mode, so switch back at every epoch
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


# Defining the testing function
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)  # number of samples, used for the accuracy
    num_batches = len(dataloader)   # number of batches, used for the average loss
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")


# Training and testing the model
epochs = 5
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------------")
    train(dataloader_train, model, loss_fn, optimizer)
    test(dataloader_test, model, loss_fn)
print("Done!")

I then ran the script while monitoring RAM usage with htop, and even this simple, standard PyTorch script uses approximately 2.5 GB of RAM (a bit more, in fact).
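htop only shows the total at a given moment. A minimal sketch for logging resident memory from inside the script, so the jump can be tied to a specific line (for example the first `.to(device)`), is below. It assumes Linux, since it reads `VmRSS` from `/proc/self/status`; `current_rss_mb` and `log_rss` are hypothetical helper names, not part of PyTorch.

```python
def current_rss_mb() -> float:
    """Resident set size of this process in MB (the RES column in htop).

    Assumes Linux: VmRSS is read from /proc/self/status, where it is reported in kB.
    """
    with open("/proc/self/status") as status:
        for line in status:
            if line.startswith("VmRSS:"):
                # Line looks like "VmRSS:    123456 kB"
                return int(line.split()[1]) / 1024
    raise RuntimeError("VmRSS not found in /proc/self/status (is this Linux?)")


def log_rss(label: str) -> None:
    """Print the current RSS next to a label marking a point in the script."""
    print(f"[{label}] RSS = {current_rss_mb():.0f} MB")
```

Calling, for instance, `log_rss("after import torch")`, `log_rss("after model.to(device)")` and `log_rss("after epoch 1")` in the script above should show whether the extra memory appears at CUDA initialization or grows during training.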

Here is what htop shows (screenshot not reproduced here): VIRT is huge, and RES is approximately 2.9 GB.

What is going on? Is this normal for PyTorch?
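For scale, here is a back-of-the-envelope estimate of what the training itself needs to keep in memory: float32 weights, their gradients, Adam's two moment buffers per parameter, and the MNIST images. It assumes torchvision keeps the raw images in memory at one byte per pixel, and it ignores per-batch activations, which are tiny at batch size 64.

```python
# Rough memory budget for SimpleNN trained with Adam on MNIST.
# (in_features, out_features) of each nn.Linear in SimpleNN
layers = [(28 * 28, 512), (512, 512), (512, 10)]

params = sum(n_in * n_out + n_out for n_in, n_out in layers)  # weights + biases
bytes_per_float = 4  # float32

weights = params * bytes_per_float
grads = weights              # one gradient value per parameter
adam_state = 2 * weights     # Adam keeps exp_avg and exp_avg_sq per parameter

# Assumption: the 60,000 train + 10,000 test images are held as uint8 (1 byte per pixel)
mnist_uint8 = (60_000 + 10_000) * 28 * 28

total = weights + grads + adam_state + mnist_uint8
print(f"parameters:       {params:,}")
print(f"model+grads+Adam: {(weights + grads + adam_state) / 1e6:.1f} MB")
print(f"MNIST (uint8):    {mnist_uint8 / 1e6:.1f} MB")
print(f"total:            {total / 1e6:.1f} MB")
```

The total comes to roughly 65 MB, so the gigabytes reported by htop are not accounted for by the model, the optimizer state, or the dataset.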

Friends who train networks with TensorFlow told me these are absurd memory requirements for networks this small. Is that true?


UPDATE: I modified the code to check whether the CUDA device (the GPU) was to blame, and indeed, with device = 'cpu' the RAM used by the script drops to approximately 500 MB. In other words, using the NVIDIA GPU through CUDA costs more than 2 GB of extra RAM. Again: what is going on, and is this normal? I installed PyTorch with:
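To repeat this CPU-versus-CUDA comparison cleanly, each configuration should run in a fresh process, because memory taken during CUDA initialization is not given back while the process is alive. A minimal sketch, assuming Linux (where `ru_maxrss` is reported in kilobytes); `peak_rss_mb` is a hypothetical helper and the snippets passed to it are only illustrative.

```python
import os
import subprocess
import sys
from typing import Optional


def peak_rss_mb(snippet: str, extra_env: Optional[dict] = None) -> float:
    """Run `snippet` in a fresh Python process and return its peak RSS in MB.

    Assumes Linux, where resource.getrusage(...).ru_maxrss is in kilobytes.
    """
    env = dict(os.environ)
    env.update(extra_env or {})
    # The child prints its own peak RSS as the last line of its output.
    probe = (
        snippet
        + "\nimport resource"
        + "\nprint(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)"
    )
    result = subprocess.run(
        [sys.executable, "-c", probe],
        env=env,
        capture_output=True,
        text=True,
        check=True,
    )
    kilobytes = int(result.stdout.strip().splitlines()[-1])
    return kilobytes / 1024
```

For example, comparing `peak_rss_mb("import torch; torch.zeros(1, device='cpu')")` with `peak_rss_mb("import torch; torch.zeros(1, device='cuda')")` isolates the cost of bringing up CUDA, independent of the model and data; `extra_env` allows the same comparison under different environment variables.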

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

I also checked nvidia-smi (screenshot not reproduced here): it reports a different CUDA version from the 11.8 that this PyTorch build (cu118) targets. I have read that the two should be compatible. I wonder whether this could be the problem, but I don't think so: I ran the code under other configurations and got the same result.

Is there a problem, or is this RAM allocation necessary and intended behaviour when using CUDA? If it is, is there a way to pay this RAM price only once when running several trainings at the same time with Ray?


Solution

  • I think I have found the solution!

    Put this at the very start of the script, before any torch call that initializes CUDA:

    # At the start of the script
    import os
    os.environ['CUDA_MODULE_LOADING'] = 'LAZY'
    

    This makes the CUDA runtime load GPU kernel modules lazily, when they are first used, instead of loading all of them into RAM up front during CUDA initialization. The catch is that this only works with a recent enough CUDA: in my case it required CUDA 11.8 or later, and I had to update my drivers. With this change, the RAM needed when using CUDA dropped from the original 2.5 GB to 700 MB.

    That's it! I think this is the solution; everything is working fine now.
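On the Ray side of the question: each trial runs in its own worker process, and each process that uses the GPU initializes its own CUDA context, so the variable has to be present in every worker's environment before that happens, not only in the driver script. A minimal stdlib sketch of that idea with plain subprocesses follows; `worker_env` and `launch_workers` are hypothetical names. With Ray itself, the corresponding mechanism should be `ray.init(runtime_env={"env_vars": {"CUDA_MODULE_LOADING": "LAZY"}})`, which sets environment variables for Ray's worker processes.

```python
import os
import subprocess
import sys
from typing import List, Mapping, Optional


def worker_env(base: Optional[Mapping[str, str]] = None) -> dict:
    """Copy of an environment (default: the current one) with lazy CUDA module loading on."""
    env = dict(os.environ if base is None else base)
    env["CUDA_MODULE_LOADING"] = "LAZY"
    return env


def launch_workers(script: str, n_workers: int) -> List[str]:
    """Run `script` in n_workers parallel Python processes, each with CUDA_MODULE_LOADING=LAZY.

    Returns the stripped stdout of each worker, in launch order.
    """
    procs = [
        subprocess.Popen(
            [sys.executable, "-c", script],
            env=worker_env(),
            stdout=subprocess.PIPE,
            text=True,
        )
        for _ in range(n_workers)
    ]
    outputs = []
    for proc in procs:
        out, _ = proc.communicate()
        if proc.returncode != 0:
            raise RuntimeError(f"worker exited with code {proc.returncode}")
        outputs.append(out.strip())
    return outputs
```

Setting the variable in the child's environment, rather than with `os.environ` inside each training function, guarantees it is in place before the worker touches CUDA.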