I can't fix the runtime error "one of the variables needed for gradient computation has been modified by an inplace operation".
I know that if I comment out loss.backward() the code will run, but I don't understand in which order I should call the functions to avoid this error.
When I use my wrapper with ResNet50 I don't experience any problems, but with UNet the RuntimeError occurs:
for i, (x, y) in batch_iter:
    with torch.autograd.set_detect_anomaly(True):
        input, target = x.to(self.device), y.to(self.device)
        self.optimizer.zero_grad()
        if self.box_training:
            out = self.model(input)
        else:
            out = self.model(input).clamp(0, 1)
        loss = self.criterion(out, target)
        loss_value = loss.item()
        train_losses.append(loss_value)
        loss.backward()
        self.optimizer.step()
        batch_iter.set_description(f'Training: (loss {loss_value:.4f})')
self.training_loss.append(np.mean(train_losses))
self.learning_rate.append(self.optimizer.param_groups[0]['lr'])
As the comments pointed out, I should provide the model.
Looking at it, I actually found what the problem was:
model = UNet(
    in_channels=1,
    num_encoding_blocks=6,
    out_classes=1,
    padding=1,
    dimensions=2,
    out_channels_first_layer=32,
    normalization=None,
    pooling_type='max',
    upsampling_type='conv',
    preactivation=False,
    # residual=True,
    padding_mode='zeros',
    activation='ReLU',
    initial_dilation=None,
    dropout=0,
    monte_carlo_dropout=0,
)
It is residual = True, which I had commented out. I will look into the docs to see what is going on. If you have an idea, please enlighten me.
Explanation:
It looks like the UNet library you are using includes a += (in-place tensor addition) in the residual branch of the encoder:
if self.residual:
    connection = self.conv_residual(x)
    x = self.conv1(x)
    x = self.conv2(x)
    x += connection  # <------- !!!
In-place operations like += may overwrite information that is needed for gradient computation during loss.backward(). PyTorch detects when this necessary information has been overwritten, and complains.
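To see the mechanism in isolation, here is a minimal sketch that does not use the UNet library at all. It assumes an activation such as ReLU comes right before the in-place add (ReLU's backward pass reuses its saved output); the tensors `x`, `w`, and `h` are made up for illustration.

```python
import torch

x = torch.ones(3)
w = torch.ones(3, requires_grad=True)

h = torch.relu(x * w)  # autograd keeps relu's output for the backward pass
h += x                 # in-place add overwrites that saved output

try:
    h.sum().backward()
    error_message = None
except RuntimeError as err:
    # Autograd notices the saved tensor changed after it was recorded
    error_message = str(err)

print(error_message)
```

The forward pass runs without complaint; the error only surfaces in backward(), which is why commenting out loss.backward() makes the training loop appear to work.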
Fix:
If you want to train this network with residual enabled, you would need to replace this += with an out-of-place add:
if self.residual:
    connection = self.conv_residual(x)
    x = self.conv1(x)
    x = self.conv2(x)
    x = x + connection  # <-------
A similar edit is needed in the decoder. If you installed this unet library via pip, you would want to download it directly from GitHub instead so you can make these edits (and uninstall the pip version to avoid confusion).
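If it helps to check the pattern before patching the library, the sketch below runs a forward and backward pass through a small residual block that uses the out-of-place add. `ResidualBlock` is a hypothetical stand-in written for this answer, not the library's actual class; the layer sizes and the placement of ReLU after each convolution are assumptions.

```python
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    """Hypothetical stand-in for a residual encoding block (not the library's code)."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # 1x1 convolution so the skip connection matches the output channels
        self.conv_residual = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
        )

    def forward(self, x):
        connection = self.conv_residual(x)
        x = self.conv1(x)
        x = self.conv2(x)
        # Out-of-place add: creates a new tensor, so ReLU's saved output is untouched
        return x + connection


torch.manual_seed(0)
block = ResidualBlock(in_channels=1, out_channels=4)
out = block(torch.randn(2, 1, 8, 8))
out.mean().backward()  # no RuntimeError with the out-of-place add

grads_ok = all(p.grad is not None for p in block.parameters())
print(out.shape, grads_ok)
```

Swapping the return line for `x += connection; return x` in this sketch should bring the error back, which is a quick way to confirm that the add is the culprit.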
For more information about why in-place operations can cause problems, see this blog post or this section of the PyTorch docs.