python, deep-learning, pytorch

PyTorch accumulates tensors from the previous batch(es)


I have been trying to implement a transformer-UNet hybrid model for image segmentation. Whenever I try to train the model, it runs out of memory. Initially, I thought this was due to the size of the model, so I tried reducing parameters such as the batch size, the number of attention heads, the number of transformer layers, etc.

All of these steps only delayed the inevitable out-of-memory error. I even tried cloud GPUs, but still no luck. Here's a screenshot from PyTorch's memory snapshot tool:

(screenshot of the memory snapshot)
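
For context, a snapshot like this can be recorded with PyTorch's memory-history helpers. These are underscore-prefixed internal APIs, so the exact signatures may differ between releases, and the file name below is just an example:

    import torch

    # Internal PyTorch API (roughly >= 2.1); signature may vary between releases
    torch.cuda.memory._record_memory_history(max_entries=100000)

    # ... run a few training steps here ...

    torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")   # write the snapshot dump
    torch.cuda.memory._record_memory_history(enabled=None)       # stop recording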

(I do empty the cache by calling torch.cuda.empty_cache().)
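
As far as I understand, empty_cache() only returns cached blocks that nothing references any more, so it cannot release tensors that are still reachable from Python. A tiny sketch of that behaviour:

    import torch

    x = torch.randn(1024, 1024, device="cuda")
    retained = [x]                          # a live Python reference keeps the block in use
    torch.cuda.empty_cache()                # only frees cached blocks with no references
    print(torch.cuda.memory_allocated())    # still > 0: x is reachable through `retained`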

I suspect that some tensors are being retained because I use a list to implement the skip connections (the necessary tensors are appended to the list).

import torch
from torch.nn import LazyConv2d

# relu, num_skip_conn and filters are defined at module level in the full script

class convolutionalEncoder(torch.nn.Module):
    
    class conv_block(torch.nn.Module):
 
        def __init__(self, out_channels, device) -> None:
            super().__init__()
            self.conv = LazyConv2d(out_channels= out_channels, kernel_size= 2,
                                   stride= (2, 2), padding= 'valid', device= device)
            
 
        def forward(self, X):
            global relu
 
            X = self.conv(X)
            X = relu(X)
 
            return X
 
    def __init__(self, device) -> None:
        super(convolutionalEncoder, self).__init__()
        self.conv_block_list = []
        self.skip_conn = []
        for i in range(num_skip_conn):
            self.conv_block_list.append(self.conv_block(filters[i], device))
 
    def forward(self, X):
 
        for i in range(num_skip_conn):
            X = self.conv_block_list[i](X)
            self.skip_conn.append(X)
 
        return self.skip_conn

Here's a link to the code

Here's a link to the pickle file (memory snapshot dump)


Solution

  • It looks like your code is building a list of every skip connection, and that list is never reset between batches:

    class convolutionalEncoder(torch.nn.Module):
        
        class conv_block(torch.nn.Module):
     
            ...
     
        def __init__(self, device) -> None:
            super(convolutionalEncoder, self).__init__()
            self.conv_block_list = []
            self.skip_conn = []
            for i in range(num_skip_conn):
                self.conv_block_list.append(self.conv_block(filters[i], device))
     
        def forward(self, X):
     
            for i in range(num_skip_conn):
                X = self.conv_block_list[i](X)
                self.skip_conn.append(X)
     
            return self.skip_conn
    

    skip_conn should be a local list inside forward, not an attribute that is retained between batches. What is happening is that the output tensors of every single batch (together with their autograd graphs) are appended to self.skip_conn, so you are essentially keeping every intermediate activation from the entire run alive in that list until you OOM. A minimal sketch of the failure mode is shown after the corrected code below.

    Just create a fresh local list on every forward pass instead:

        def forward(self, X):

            skip_conn = []
            for i in range(num_skip_conn):
                X = self.conv_block_list[i](X)
                skip_conn.append(X)

            return skip_conn
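
    A minimal, self-contained sketch of the failure mode (the Leaky module and the tensor sizes here are made up for illustration, not taken from your model): an attribute list that is never cleared keeps one activation tensor, plus its autograd graph, alive per batch, so allocated memory climbs on every step. With the local-list version above, the memory plateaus instead.

        import torch

        class Leaky(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lin = torch.nn.Linear(512, 512)
                self.skip_conn = []                    # persists across forward() calls

            def forward(self, x):
                x = self.lin(x)
                self.skip_conn.append(x)               # one extra tensor retained per batch
                return x

        model = Leaky().cuda()
        for step in range(5):
            out = model(torch.randn(64, 512, device="cuda"))
            out.sum().backward()
            print(step, torch.cuda.memory_allocated())  # keeps climbing batch after batch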