Tags: python, pytorch, autograd, unet-neural-network

Can't fix torch autograd runtime error: UNet inplace operation


I can't fix the runtime error "one of the variables needed for gradient computation has been modified by an inplace operation".

I know that if I comment out loss.backward() the code will run, but I don't understand in which order I should call the functions to avoid this error.

When I run my wrapper with ResNet50 I don't experience any problems, but with UNet the RuntimeError occurs:

    for i, (x, y) in batch_iter:
        with torch.autograd.set_detect_anomaly(True):  # makes autograd report which op broke backward
            input, target = x.to(self.device), y.to(self.device)

            self.optimizer.zero_grad()  # clear gradients from the previous iteration
            if self.box_training:
                out = self.model(input)
            else:
                out = self.model(input).clamp(0, 1)

            loss = self.criterion(out, target)
            loss_value = loss.item()
            train_losses.append(loss_value)
            loss.backward()        # compute gradients
            self.optimizer.step()  # update parameters

            batch_iter.set_description(f'Training: (loss {loss_value:.4f})')

    self.training_loss.append(np.mean(train_losses))
    self.learning_rate.append(self.optimizer.param_groups[0]['lr'])

As the comments pointed out, I should provide the model:

And by looking at it, I actually found the problem:

model = UNet(in_channels=1,
             num_encoding_blocks=6,
             out_classes=1,
             padding=1,
             dimensions=2,
             out_channels_first_layer=32,
             normalization=None,
             pooling_type='max',
             upsampling_type='conv',
             preactivation=False,
             # residual=True,
             padding_mode='zeros',
             activation='ReLU',
             initial_dilation=None,
             dropout=0,
             monte_carlo_dropout=0
            )

It is residual = True, which I had commented out. I will look into the docs to see what is going on. Maybe if you have an idea, you can enlighten me.


Solution

  • Explanation:

    It looks like the UNet library you are using includes a += (in-place tensor addition) in the residual branch of the encoder:

    if self.residual:
        connection = self.conv_residual(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x += connection # <------- !!!
    

    In-place operations like += overwrite the contents of an existing tensor instead of allocating a new one. If autograd saved that tensor for the backward pass, the saved values are no longer the ones the gradient formula needs. PyTorch tracks a version counter on every tensor, detects during loss.backward() that a saved tensor was modified, and raises this RuntimeError.
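
    Here is a minimal reproduction, independent of UNet (torch.sigmoid is chosen because autograd saves its output for the backward pass):

    import torch

    a = torch.randn(3, requires_grad=True)
    b = torch.sigmoid(a)  # autograd saves sigmoid's output to compute its gradient
    b += 1                # in-place add modifies that saved tensor
    b.sum().backward()    # RuntimeError: ... modified by an inplace operation

    Replacing b += 1 with the out-of-place b = b + 1 lets backward() run without error.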

    Fix:

    If you want to train this network with residual enabled, you need to replace this += with an out-of-place add:

    if self.residual:
        connection = self.conv_residual(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = x + connection # <-------
    

    A similar edit is needed in the decoder. If you installed this unet library via pip, you should download it directly from GitHub instead so you can make these edits (and uninstall the pip version to avoid confusion).
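
    To make the safe pattern concrete, here is a minimal, self-contained sketch of a residual block built around the out-of-place add. The layer names mirror the snippet above; the channel counts and 1x1/3x3 kernel sizes are illustrative assumptions, not the library's actual configuration:

    import torch
    import torch.nn as nn

    class ResidualBlock(nn.Module):
        def __init__(self, channels):
            super().__init__()
            self.conv_residual = nn.Conv2d(channels, channels, kernel_size=1)  # skip path
            self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
            self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

        def forward(self, x):
            connection = self.conv_residual(x)
            x = self.conv1(x)
            x = self.conv2(x)
            return x + connection  # out-of-place add allocates a new tensor, safe for autograd

    block = ResidualBlock(8)
    out = block(torch.randn(1, 8, 16, 16))
    out.sum().backward()  # gradients flow through both the main and skip paths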

    For more information about why in-place operations can cause problems, see the "In-place operations with autograd" section of the PyTorch autograd notes (https://pytorch.org/docs/stable/notes/autograd.html).