I have been trying to implement a transformer-UNet hybrid model for image segmentation. Whenever I try to train the model, it keeps running out of memory. Initially, I thought this was due to the size of the model, so I tried reducing the batch size, the number of attention heads, the number of transformer layers, etc.
All these steps only delayed the inevitable out-of-memory error. I even tried cloud GPUs, but still no luck. Here's a screenshot from PyTorch's memory snapshot tool:
(I do empty the cache by calling torch.cuda.empty_cache.)
I suspect that some tensors are being retained because I use a list to implement the skip connections (the necessary tensors are appended to the list):
class convolutionalEncoder(torch.nn.Module):
    class conv_block(torch.nn.Module):
        def __init__(self, out_channels, device) -> None:
            super().__init__()
            self.conv = LazyConv2d(out_channels=out_channels, kernel_size=2,
                                   stride=(2, 2), padding='valid', device=device)

        def forward(self, X):
            global relu
            X = self.conv(X)
            X = relu(X)
            return X

    def __init__(self, device) -> None:
        super(convolutionalEncoder, self).__init__()
        self.conv_block_list = []
        self.skip_conn = []
        for i in range(num_skip_conn):
            self.conv_block_list.append(self.conv_block(filters[i], device))

    def forward(self, X):
        for i in range(num_skip_conn):
            X = self.conv_block_list[i](X)
            self.skip_conn.append(X)
        return self.skip_conn
It looks like your code is building a list of skip connections that is never reset between batches:
class convolutionalEncoder(torch.nn.Module):
    class conv_block(torch.nn.Module):
        ...

    def __init__(self, device) -> None:
        super(convolutionalEncoder, self).__init__()
        self.conv_block_list = []
        self.skip_conn = []    # lives for the whole lifetime of the module
        for i in range(num_skip_conn):
            self.conv_block_list.append(self.conv_block(filters[i], device))

    def forward(self, X):
        for i in range(num_skip_conn):
            X = self.conv_block_list[i](X)
            self.skip_conn.append(X)    # appends on every batch, never cleared
        return self.skip_conn
self.skip_conn should just be a local list, not an attribute that is retained between batches. What is happening is that the output tensors of every single batch (together with their autograd graphs) are appended to self.skip_conn, so you are essentially storing the activations of the entire dataset in that list until you OOM.
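To make the mechanism concrete, here is a minimal self-contained sketch (a hypothetical toy module, not your encoder) of the same pattern: appending each forward output to a module attribute keeps every batch's activations, and the autograd graphs hanging off them, alive for as long as the module exists.

import torch

class LeakyEncoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)
        self.outputs = []                 # attribute: lives as long as the module

    def forward(self, X):
        X = self.linear(X)
        self.outputs.append(X)            # retained across batches -> memory keeps growing
        return self.outputs

model = LeakyEncoder()
for step in range(3):
    model(torch.randn(4, 10))
print(len(model.outputs))                 # 3 -- one retained activation per forward pass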
Just create a new local list on every forward pass:
def forward(self, X):
    skip_conn = []    # local to this forward pass, freed once the caller is done with it
    for i in range(num_skip_conn):
        X = self.conv_block_list[i](X)
        skip_conn.append(X)
    return skip_conn
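As a quick sanity check (a hypothetical snippet, assuming a CUDA device is available and that the globals num_skip_conn, filters and relu from your script are defined), you can watch torch.cuda.memory_allocated() over a few dummy batches; with the local list it should stay roughly flat instead of growing on every iteration:

import torch

encoder = convolutionalEncoder(device='cuda')
for step in range(5):
    X = torch.randn(2, 3, 256, 256, device='cuda')
    skips = encoder(X)        # fresh list on every call
    # ... decoder / loss / backward would go here ...
    del skips                 # nothing else keeps these tensors alive
    print(step, torch.cuda.memory_allocated() // 2**20, 'MiB')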