pytorch, out-of-memory, gpu, conv-neural-network, lstm

GPU memory increasing at each batch (PyTorch)


I am trying to build a convolutional network using ConvLSTM layers (LSTM cells with convolutions instead of matrix multiplications), but my GPU memory usage increases with every batch, even though I delete variables and accumulate only the loss value (not its graph) at each iteration. I may be doing something wrong, but the exact same script ran without issues with another model (with more parameters and also using ConvLSTM layers).

Each batch is composed of num_batch x 3 images (grayscale), and I'm trying to predict the difference |Im(t+1)-Im(t)| from the input Im(t).

def main():
    config = Config()

    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=config.batch_size, num_workers=0, shuffle=True, drop_last=True)
    
    nb_img = len(train_dataset)
    util.clear_progress_dir()

    step_tensorboard = 0
    ###################################
    #          Model Setup            #
    ###################################

    model = fully_convLSTM()
    if torch.cuda.is_available():   
        model = model.float().cuda()

    lr = 0.001
    optimizer = torch.optim.Adam(model.parameters(),lr=lr) 

    util.enumerate_params([model])

    ###################################
    #          Training Loop          #
    ###################################

    model.train() #Put model in training mode

    train_loss_recon = []
    train_loss_recon2 = []
    
    for epoch in tqdm(range(config.num_epochs)):
  
        running_loss1 = 0.0
        running_loss2 = 0.0

        for i, (inputs, outputs) in enumerate(train_dataloader, 0):
            print(i)
            torch.cuda.empty_cache()
            gc.collect()
           
           # if torch.cuda.is_available():
            inputs  = autograd.Variable(inputs.float()).cuda()
            outputs = autograd.Variable(outputs.float()).cuda()

            im1 =  inputs[:,0,:,:,:]
            im2 =  inputs[:,1,:,:,:]
            im3 =  inputs[:,2,:,:,:]
            
            diff1 = torch.abs(im2 - im1).cuda().float()
            diff2 = torch.abs(im3 - im2).cuda().float()

            model.initialize_hidden()
            
            optimizer.zero_grad()
            pred1 = model.forward(im1)  
            loss = reconstruction_loss(diff1, pred1)
            loss.backward()
            # optimizer.step()
           
            model.update_hidden()
            
            optimizer.zero_grad()
            pred2 = model.forward(im2)  
            loss2 = reconstruction_loss(diff2, pred2)   
            loss2.backward()   
            optimizer.step()

            model.update_hidden()

            ## print statistics
      
            running_loss1 += loss.detach().data
            running_loss2 += loss2.detach().data
            
            if i==0:

                with torch.no_grad():
                    img_grid_diff_true = (diff2).cpu()
                    img_grid_diff_pred = (pred2).cpu()
                    
                    f, axes = plt.subplots(2, 4, figsize=(48,48))
                    for l in range(4):
                        axes[0, l].imshow(img_grid_diff_true[l].squeeze(0).squeeze(0), cmap='gray')
                        axes[1, l].imshow(img_grid_diff_pred[l].squeeze(0).squeeze(0), cmap='gray')

                    plt.show()
                    plt.close()
           
                    writer_recon_loss.add_scalar('Reconstruction loss', running_loss1, step_tensorboard)
                    writer_recon_loss2.add_scalar('Reconstruction loss2', running_loss2, step_tensorboard)

                    step_tensorboard += 1
            
            del pred1
            del pred2
            del im1
            del im2
            del im3
            del diff1
            del diff2#, im1_noised, im2_noised
            del inputs
            del outputs
            del loss
            del loss2
            for obj in gc.get_objects():
                if torch.is_tensor(obj) :
                    del obj
        
            torch.cuda.empty_cache()
            gc.collect()
     
        epoch_loss = running_loss1 / len(train_dataloader.dataset)
        epoch_loss2 = running_loss2/ len(train_dataloader.dataset)
        print(f"Epoch {epoch} loss reconstruction1: {epoch_loss:.6f}")
        print(f"Epoch {epoch} loss reconstruction2: {epoch_loss2:.6f}")
        
        train_loss_recon.append(epoch_loss)
        train_loss_recon2.append(epoch_loss2)
        
        del running_loss1, running_loss2, epoch_loss, epoch_loss2

Here is the model used:

class ConvLSTMCell(nn.Module):
    def __init__(self, input_channels, hidden_channels, kernel_size):
        super(ConvLSTMCell, self).__init__()

        # assert hidden_channels % 2 == 0

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        # self.num_features = 4

        self.padding = 1

        self.Wxi = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True)
        self.Whi = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False)
        self.Wxf = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True)
        self.Whf = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False)
        self.Wxc = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True)
        self.Whc = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False)
        self.Wxo = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True)
        self.Who = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False)

        self.Wci = None
        self.Wcf = None
        self.Wco = None

    def forward(self, x, h, c): ## Equation (3) in Convolutional LSTM Network: A Machine Learning Approach for Precipitation Nowcasting
        ci = torch.sigmoid(self.Wxi(x) + self.Whi(h) + c * self.Wci)
        cf = torch.sigmoid(self.Wxf(x) + self.Whf(h) + c * self.Wcf)
        cc = cf * c + ci * torch.tanh(self.Wxc(x) + self.Whc(h)) ###gt= tanh(cc)
        co = torch.sigmoid(self.Wxo(x) + self.Who(h) + cc * self.Wco) ##channel out = hidden channel
        ch = co * torch.tanh(cc)
        return ch, cc #short memory, long memory

    def init_hidden(self, batch_size, hidden, shape):
        if self.Wci is None:
            self.Wci = nn.Parameter(torch.zeros(1, hidden, shape[0], shape[1])).cuda()
            self.Wcf = nn.Parameter(torch.zeros(1, hidden, shape[0], shape[1])).cuda()
            self.Wco = nn.Parameter(torch.zeros(1, hidden, shape[0], shape[1])).cuda()
        else:
            assert shape[0] == self.Wci.size()[2], 'Input Height Mismatched!'
            assert shape[1] == self.Wci.size()[3], 'Input Width Mismatched!'
        return (autograd.Variable(torch.zeros(batch_size, hidden, shape[0], shape[1])).cuda(),
                autograd.Variable(torch.zeros(batch_size, hidden, shape[0], shape[1])).cuda())


class fully_convLSTM(nn.Module):
    def __init__(self):
        super(fully_convLSTM, self).__init__()
        layers = []
        self.hidden_list = [1,32,32,1]#,32,64,32,
        for k in range(len(self.hidden_list)-1):   # Define blocks of [ConvLSTM,BatchNorm,Relu]
            name_conv = "self.convLSTM" +str(k)
            cell_conv = ConvLSTMCell(self.hidden_list[k],self.hidden_list[k+1],3)
            setattr(self, name_conv, cell_conv)
            name_batchnorm = "self.batchnorm"+str(k)
            batchnorm=nn.BatchNorm2d(self.hidden_list[k+1])
            setattr(self, name_batchnorm, batchnorm)
            name_relu =" self.relu"+str(k)
            relu=nn.ReLU()
            setattr(self, name_relu, relu)
        self.sigmoid = nn.Sigmoid()
    
        self.internal_state=[]
        
    def initialize_hidden(self):  
        for k in range(len(self.hidden_list)-1):   
            name_conv = "self.convLSTM" +str(k)             
            (h,c) = getattr(self,name_conv).init_hidden(config.batch_size, self.hidden_list[k+1],(256,256))
            self.internal_state.append((h,c))          
        self.internal_state_new=[]
    def update_hidden(self):
        for i, hidden in enumerate(self.internal_state_new):
            self.internal_state[i] = (hidden[0].detach(), hidden[1].detach())
        self.internal_state_new = []        
    def forward(self, input):
        x = input
        for k in range(len(self.hidden_list)-1):
            name_conv = "self.convLSTM" +str(k)
            name_batchnorm = "self.batchnorm"+str(k)
            name_relu =" self.relu"+str(k)
            x, c = getattr(self,name_conv)(x, self.internal_state[k][1], self.internal_state[k][0]) 
            self.internal_state_new.append((x.detach(),c.detach()))
            x = getattr(self,name_batchnorm)(x)
            if k!= len(self.hidden_list)-2:
                x = getattr(self,name_relu)(x)
            else :
                x = self.sigmoid(x)
        return x

So my question is, what in my code is causing memory to accumulate during the training phase?


Solution

  • A few quick notes about the training code:

    • autograd.Variable has been deprecated for at least 8 minor versions (since PyTorch 0.4.0, plain tensors track gradients themselves); don't use it.
    • Calling gc.collect() is pointless here; PyTorch frees tensor memory on its own once nothing references the tensor.
    • Don't call torch.cuda.empty_cache() for each batch. PyTorch keeps freed GPU memory reserved (it doesn't give it back to the OS) so that it doesn't have to allocate it again for the next batch. Emptying the cache every batch makes your code slow; to be honest, don't use this function at all, PyTorch handles this.
    • Don't sprinkle random memory-cleaning calls around; that is most probably not where the problem is.
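Put together, the notes above suggest a training step roughly like the following sketch. The small nn.Conv2d model, the random batches and the MSE loss are placeholders assumed for illustration; they are not the asker's fully_convLSTM or reconstruction_loss.

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Placeholder model, loss and data (assumptions for illustration only).
model = nn.Conv2d(1, 1, kernel_size=3, padding=1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()
batches = [(torch.rand(4, 1, 16, 16), torch.rand(4, 1, 16, 16)) for _ in range(3)]

model.train()
running_loss = 0.0
for inputs, targets in batches:
    # Plain tensors are enough; no autograd.Variable wrapper is needed.
    inputs = inputs.float().to(device)
    targets = targets.float().to(device)

    optimizer.zero_grad()
    pred = model(inputs)             # call the module rather than model.forward(...)
    loss = criterion(pred, targets)
    loss.backward()
    optimizer.step()

    # .item() returns a Python float, so no graph or GPU tensor is kept alive.
    running_loss += loss.item()

epoch_loss = running_loss / len(batches)
print(f"epoch loss: {epoch_loss:.6f}")
```

There are no del statements, gc.collect() calls or empty_cache() calls here: once pred and loss are rebound on the next iteration, the previous tensors are no longer referenced and their memory can be reused.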

    Model

    Yes, the model is probably where the memory accumulates (although its code is hard to read).

    Take note of the self.internal_state list, and of the self.internal_state_new list as well.

    • Each time you call model.initialize_hidden(), a new set of tensors is appended to this list (and, as far as I can tell, never removed). Your training loop calls it once per batch.
    • self.internal_state_new is reset in update_hidden; maybe self.internal_state should be reset as well?
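The growth described in the two points above can be reproduced without PyTorch. In this minimal sketch, the classes LeakyState and ResetState are hypothetical stand-ins for the model, and short Python lists stand in for the (h, c) GPU tensors.

```python
class LeakyState:
    """Stand-in for the model: appends fresh state on every call."""

    def __init__(self, num_layers):
        self.num_layers = num_layers
        self.internal_state = []

    def initialize_hidden(self):
        for _ in range(self.num_layers):
            # In the real model these would be two GPU tensors (h, c).
            self.internal_state.append(([0.0], [0.0]))


class ResetState(LeakyState):
    """Same stand-in, but the list is rebuilt on every call."""

    def initialize_hidden(self):
        self.internal_state = [([0.0], [0.0]) for _ in range(self.num_layers)]


leaky, fixed = LeakyState(3), ResetState(3)
for _ in range(5):  # five "batches"
    leaky.initialize_hidden()
    fixed.initialize_hidden()

print(len(leaky.internal_state))  # 15: three new entries per batch
print(len(fixed.internal_state))  # 3: constant
```

Because the list keeps a reference to every old entry, the corresponding tensors in the real model can never be freed, no matter how many del statements follow in the training loop.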

    In essence, check your model's self.internal_state attribute: from what I can see, the list grows indefinitely. Explicitly initializing the state with zeros everywhere is also quite unusual, and there is probably no need to do it (e.g. PyTorch's RNN layers default to a zero initial state when none is passed; this is probably similar).
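One possible way to apply this, sketched under assumptions: keep the state in a single attribute that is dropped at the start of each sequence and created lazily as zeros. TinyRecurrentNet is a hypothetical, much simplified stand-in for fully_convLSTM (one plain recurrent convolution, not a ConvLSTM cell); only the state handling is the point.

```python
import torch
import torch.nn as nn


class TinyRecurrentNet(nn.Module):
    """Hypothetical stand-in for fully_convLSTM with a single recurrent state."""

    def __init__(self, channels=1):
        super().__init__()
        self.conv_x = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv_h = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.hidden = None  # no state yet; treated as zeros in forward

    def initialize_hidden(self):
        # Drop the previous sequence's state instead of appending to a list.
        self.hidden = None

    def forward(self, x):
        if self.hidden is None:
            # Lazily create the zero state with the right shape and device.
            self.hidden = torch.zeros_like(x)
        h = torch.tanh(self.conv_x(x) + self.conv_h(self.hidden))
        # Detach so the stored state does not keep this step's graph alive.
        self.hidden = h.detach()
        return h


model = TinyRecurrentNet()
for _ in range(3):                  # three "batches"
    model.initialize_hidden()
    for _ in range(2):              # two time steps per batch
        out = model(torch.rand(4, 1, 8, 8))

print(out.shape)                    # torch.Size([4, 1, 8, 8])
print(model.hidden.requires_grad)   # False
```

With this layout the model holds at most one state tensor at a time, so the amount of memory it retains does not depend on how many batches have been processed.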