I am trying to train a WGAN with LSTMs as critic and generator for image generation on the MNIST dataset. Unfortunately I keep running into the error message:
NotImplementedError: the derivative for '_cudnn_rnn_backward' is not implemented. Double backwards is not supported for CuDNN RNNs due to limitations in the CuDNN API. To run double backwards, please disable the CuDNN backend temporarily while running the forward pass of your RNN. For example:
with torch.backends.cudnn.flags(enabled=False):
output = model(inputs)
I have trouble understanding this error message, since I don't think I perform any double-backward operations.
Could you help me understand where this error message comes from, and how to resolve it?
Here are the relevant parts of my implementation:
Critic
class LSTM_Critic(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(LSTM_Critic, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(IMG_SIZE, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x, labels):
        # Set initial hidden and cell states
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).requires_grad_().to(device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).requires_grad_().to(device)

        # Pass the input and initial states through the LSTM
        # out: tensor of shape (batch_size, seq_length, hidden_size)
        x = x.reshape(BATCH_SIZE, IMG_SIZE, IMG_SIZE)
        out, hidden = self.lstm(x, (h0.detach(), c0.detach()))

        # Feed the output of the last time step into the fully connected layer
        out = self.fc(out[:, -1, :])
        return out
Initialisation
gen = LSTM_Generator(200, 100, num_layers, num_classes).to(device)
critic = LSTM_Critic(input_size, hidden_size, num_layers, num_classes).to(device)
initialize_weights(gen)
initialize_weights(critic)

# Initialize optimizers
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
opt_critic = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))

gen.train()
critic.train()
Training
for epoch in range(NUM_EPOCHS):
    for batch_idx, (real, labels) in enumerate(tqdm(loader)):
        real = real.to(device)
        cur_batch_size = real.shape[0]
        labels = labels.to(device)

        # Train critic: max E[critic(real)] - E[critic(fake)],
        # equivalent to minimizing the negative of that
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(device)
            fake = gen(noise, labels)
            critic_real = critic(real, labels).reshape(-1)
            critic_fake = critic(fake, labels).reshape(-1)
            gp = gradient_penalty(critic, labels, real, fake, device=device)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp
            )
            critic.zero_grad()
            loss_critic.backward(retain_graph=True)
            opt_critic.step()
Gradient Penalty
def gradient_penalty(critic, labels, real, fake, device="cpu"):
    BATCH_SIZE, C, H, W = real.shape
    alpha = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * alpha + fake * (1 - alpha)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images, labels)

    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.reshape(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty
The code fails at the line loss_critic.backward(retain_graph=True).
AFAIK, WGAN-GP cannot work with cuDNN RNNs because of a cuDNN limitation: the gradient penalty calls torch.autograd.grad(..., create_graph=True), so loss_critic.backward() has to differentiate through that gradient, i.e. perform a double backward through the LSTM, and the derivative of _cudnn_rnn_backward is not implemented in the cuDNN API. That is where the error comes from, even though you never call .backward() twice yourself.
One workaround is to write JIT-fused RNNs. See https://github.com/pytorch/pytorch/issues/5261#issuecomment-687330144
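As a rough illustration of that idea (a hypothetical sketch, not the code from the linked issue): an LSTM cell written out in plain tensor ops and compiled with TorchScript consists of ordinary autograd operations, so double backward works through it, at the cost of looping over the sequence yourself.

import torch

# Hand-written LSTM cell built from ordinary autograd ops, so it
# supports double backward. Weight shapes follow nn.LSTM conventions:
# w_ih: (4*hidden, input), w_hh: (4*hidden, hidden), biases: (4*hidden,)
@torch.jit.script
def lstm_cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
    gates = torch.mm(x, w_ih.t()) + torch.mm(hx, w_hh.t()) + b_ih + b_hh
    i, f, g, o = gates.chunk(4, dim=1)
    i = torch.sigmoid(i)
    f = torch.sigmoid(f)
    g = torch.tanh(g)
    o = torch.sigmoid(o)
    cy = f * cx + i * g        # new cell state
    hy = o * torch.tanh(cy)    # new hidden state
    return hy, cy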
Another solution is to disable cuDNN during the forward pass, as the error message suggests. PyTorch then falls back to its native RNN implementation, which still runs on the GPU but is considerably slower than cuDNN.
with torch.backends.cudnn.flags(enabled=False):
output = model(inputs)
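Applied to your training loop (a minimal sketch using only the names from your snippets), that means recording the forward graph without cuDNN. Strictly, only the critic call inside gradient_penalty creates the graph that is differentiated twice, but wrapping the whole critic step is the simplest safe option; the backward call itself can stay outside the context manager.

for _ in range(CRITIC_ITERATIONS):
    noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(device)
    # Record the forward graph without cuDNN so that the later
    # double backward through the LSTM is supported
    with torch.backends.cudnn.flags(enabled=False):
        fake = gen(noise, labels)
        critic_real = critic(real, labels).reshape(-1)
        critic_fake = critic(fake, labels).reshape(-1)
        gp = gradient_penalty(critic, labels, real, fake, device=device)
    loss_critic = (
        -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp
    )
    critic.zero_grad()
    loss_critic.backward(retain_graph=True)
    opt_critic.step()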