Video card: gtx1070ti 8Gb, batchsize 64, input image size 128*128. I had such UNET with resnet152 as encoder which worked pretty fine:
class UNetResNet(nn.Module):
def __init__(self, encoder_depth, num_classes, num_filters=32, dropout_2d=0.2,
pretrained=False, is_deconv=False):
super().__init__()
self.num_classes = num_classes
self.dropout_2d = dropout_2d
if encoder_depth == 34:
self.encoder = torchvision.models.resnet34(pretrained=pretrained)
bottom_channel_nr = 512
elif encoder_depth == 101:
self.encoder = torchvision.models.resnet101(pretrained=pretrained)
bottom_channel_nr = 2048
elif encoder_depth == 152:
self.encoder = torchvision.models.resnet152(pretrained=pretrained)
bottom_channel_nr = 2048
else:
raise NotImplementedError('only 34, 101, 152 version of Resnet are implemented')
self.pool = nn.MaxPool2d(2, 2)
self.relu = nn.ReLU(inplace=True)
self.conv1 = nn.Sequential(self.encoder.conv1,
self.encoder.bn1,
self.encoder.relu,
self.pool) #from that pool layer I would like to get rid off
self.conv2 = self.encoder.layer1
self.conv3 = self.encoder.layer2
self.conv4 = self.encoder.layer3
self.conv5 = self.encoder.layer4
self.center = DecoderCenter(bottom_channel_nr, num_filters * 8 *2, num_filters * 8, False)
self.dec5 = DecoderBlockV(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
self.dec4 = DecoderBlockV(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
self.dec3 = DecoderBlockV(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv)
self.dec2 = DecoderBlockV(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2,
is_deconv)
self.dec1 = DecoderBlockV(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
self.dec0 = ConvRelu(num_filters, num_filters)
self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)
def forward(self, x):
conv1 = self.conv1(x)
conv2 = self.conv2(conv1)
conv3 = self.conv3(conv2)
conv4 = self.conv4(conv3)
conv5 = self.conv5(conv4)
center = self.center(conv5)
dec5 = self.dec5(torch.cat([center, conv5], 1))
dec4 = self.dec4(torch.cat([dec5, conv4], 1))
dec3 = self.dec3(torch.cat([dec4, conv3], 1))
dec2 = self.dec2(torch.cat([dec3, conv2], 1))
dec1 = self.dec1(dec2)
dec0 = self.dec0(dec1)
return self.final(F.dropout2d(dec0, p=self.dropout_2d))
# blocks
class DecoderBlockV(nn.Module):
def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True):
super(DecoderBlockV2, self).__init__()
self.in_channels = in_channels
if is_deconv:
self.block = nn.Sequential(
ConvRelu(in_channels, middle_channels),
nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2,
padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
else:
self.block = nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear'),
ConvRelu(in_channels, middle_channels),
ConvRelu(middle_channels, out_channels),
)
def forward(self, x):
return self.block(x)
class DecoderCenter(nn.Module):
def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True):
super(DecoderCenter, self).__init__()
self.in_channels = in_channels
if is_deconv:
"""
Paramaters for Deconvolution were chosen to avoid artifacts, following
link https://distill.pub/2016/deconv-checkerboard/
"""
self.block = nn.Sequential(
ConvRelu(in_channels, middle_channels),
nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2,
padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
else:
self.block = nn.Sequential(
ConvRelu(in_channels, middle_channels),
ConvRelu(middle_channels, out_channels)
)
def forward(self, x):
return self.block(x)
Then I edited my class looks to make it work without pooling layer:
class UNetResNet(nn.Module):
def __init__(self, encoder_depth, num_classes, num_filters=32, dropout_2d=0.2,
pretrained=False, is_deconv=False):
super().__init__()
self.num_classes = num_classes
self.dropout_2d = dropout_2d
if encoder_depth == 34:
self.encoder = torchvision.models.resnet34(pretrained=pretrained)
bottom_channel_nr = 512
elif encoder_depth == 101:
self.encoder = torchvision.models.resnet101(pretrained=pretrained)
bottom_channel_nr = 2048
elif encoder_depth == 152:
self.encoder = torchvision.models.resnet152(pretrained=pretrained)
bottom_channel_nr = 2048
else:
raise NotImplementedError('only 34, 101, 152 version of Resnet are implemented')
self.relu = nn.ReLU(inplace=True)
self.input_adjust = nn.Sequential(self.encoder.conv1,
self.encoder.bn1,
self.encoder.relu)
self.conv1 = self.encoder.layer1
self.conv2 = self.encoder.layer2
self.conv3 = self.encoder.layer3
self.conv4 = self.encoder.layer4
self.dec4 = DecoderBlockV(bottom_channel_nr, num_filters * 8 * 2, num_filters * 8, is_deconv)
self.dec3 = DecoderBlockV(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
self.dec2 = DecoderBlockV(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv)
self.dec1 = DecoderBlockV(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2,is_deconv)
self.final = nn.Conv2d(num_filters * 2 * 2, num_classes, kernel_size=1)
def forward(self, x):
input_adjust = self.input_adjust(x)
conv1 = self.conv1(input_adjust)
conv2 = self.conv2(conv1)
conv3 = self.conv3(conv2)
center = self.conv4(conv3)
dec4 = self.dec4(center) #now without centblock
dec3 = self.dec3(torch.cat([dec4, conv3], 1))
dec2 = self.dec2(torch.cat([dec3, conv2], 1))
dec1 = F.dropout2d(self.dec1(torch.cat([dec2, conv1], 1)), p=self.dropout_2d)
return self.final(dec1)
is_deconv - in both cases True. After changing it stop to work with batchsize 64, only with with size of 16 or with batchsize 64 but with resnet16 only - otherwise out of cuda memory. What am I doing wrong?
Full stack of error:
~/Desktop/ml/salt/open-solution-salt-identification-master/common_blocks/unet_models.py in forward(self, x)
418 conv1 = self.conv1(input_adjust)
419 conv2 = self.conv2(conv1)
--> 420 conv3 = self.conv3(conv2)
421 center = self.conv4(conv3)
422 dec4 = self.dec4(center)
~/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
355 result = self._slow_forward(*input, **kwargs)
356 else:
--> 357 result = self.forward(*input, **kwargs)
358 for hook in self._forward_hooks.values():
359 hook_result = hook(self, input, result)
~/anaconda3/lib/python3.6/site-packages/torch/nn/modules/container.py in forward(self, input)
65 def forward(self, input):
66 for module in self._modules.values():
---> 67 input = module(input)
68 return input
69
~/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
355 result = self._slow_forward(*input, **kwargs)
356 else:
--> 357 result = self.forward(*input, **kwargs)
358 for hook in self._forward_hooks.values():
359 hook_result = hook(self, input, result)
~/anaconda3/lib/python3.6/site-packages/torchvision-0.2.0-py3.6.egg/torchvision/models/resnet.py in forward(self, x)
79
80 out = self.conv2(out)
---> 81 out = self.bn2(out)
82 out = self.relu(out)
The problem is that you do not have enough memory, as already mentioned in the comments.
To be more specific, the problem lies in the increased size due to the removal of the max pooling, as you already correctly narrowed it down. The point of max pooling - aside from the increased invariance of the setting - is to reduce the image size, such that it costs you less memory because you need to store less activations for the backpropagation part.
A good answer of what max pooling does might be helpful. My go-to one is the one on Quora. As you also correctly figured out already is the fact that the batch size also plays a huge role in terms of memory consumption. Generally, it is also preferred to use smaller batch sizes anyways, as long as you don't have a skyrocketing processing time after.