python pytorch lstm recurrent-neural-network

How to resolve "the derivative for '_cudnn_rnn_backward' is not implemented" for a WGAN-LSTM with gradient penalty


I am trying to train a WGAN with LSTMs as the critic and the generator for image generation on the MNIST dataset. Unfortunately, I keep running into this error message:

NotImplementedError: the derivative for '_cudnn_rnn_backward' is not implemented. Double backwards is not supported for CuDNN RNNs due to limitations in the CuDNN API. To run double backwards, please disable the CuDNN backend temporarily while running the forward pass of your RNN. For example: 
with torch.backends.cudnn.flags(enabled=False):
   output = model(inputs)

I have trouble understanding this error message, since I don't think I perform any double backward operations. Could you help me understand where this error comes from and how to resolve it?

Here are the relevant parts of my implementation:

Critic

import torch
import torch.nn as nn

class LSTM_Critic(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(LSTM_Critic, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # Each image row (IMG_SIZE pixels) is one time step
        self.lstm = nn.LSTM(IMG_SIZE, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x, labels):
        # Use the actual batch size: the last batch can be smaller than BATCH_SIZE
        batch_size = x.size(0)
        # Initial hidden and cell states (constants, so no gradient is needed)
        h0 = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=x.device)
        c0 = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=x.device)

        # (B, 1, IMG_SIZE, IMG_SIZE) -> (B, seq_len=IMG_SIZE, features=IMG_SIZE)
        x = x.reshape(batch_size, IMG_SIZE, IMG_SIZE)
        out, hidden = self.lstm(x, (h0, c0))  # out: (batch_size, seq_length, hidden_size)

        # Score each sequence from the output of its last time step
        out = self.fc(out[:, -1, :])
        return out
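For reference, one way to keep cuDNN's RNN kernel out of the picture entirely is to unroll the LSTM recurrence with ordinary layers. The sketch below is hypothetical: the class name UnrolledLSTMCritic, the single layer, and the fused four-gate layout are assumptions for illustration, not the original code and not how nn.LSTM is implemented internally. Because it uses only nn.Linear, sigmoid and tanh, every op supports double backward, so it should work with a gradient penalty on the GPU without touching the cuDNN flags, at the cost of a Python loop over the image rows.

```python
import torch
import torch.nn as nn


class UnrolledLSTMCritic(nn.Module):
    """Hypothetical single-layer LSTM critic built from nn.Linear, sigmoid and tanh.

    It never calls cuDNN's RNN kernel, and all of its ops support double backward,
    so it can sit inside a WGAN gradient penalty without disabling cuDNN.
    """

    def __init__(self, img_size: int, hidden_size: int):
        super().__init__()
        self.img_size = img_size
        self.hidden_size = hidden_size
        # Input-to-gates and hidden-to-gates projections for all 4 gates at once
        self.x2gates = nn.Linear(img_size, 4 * hidden_size)
        self.h2gates = nn.Linear(hidden_size, 4 * hidden_size, bias=False)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x: torch.Tensor, labels=None) -> torch.Tensor:
        # (B, 1, H, W) -> (B, seq_len=H, features=W): one image row per time step
        x = x.reshape(x.size(0), self.img_size, self.img_size)
        h = x.new_zeros(x.size(0), self.hidden_size)
        c = x.new_zeros(x.size(0), self.hidden_size)
        for t in range(x.size(1)):
            gates = self.x2gates(x[:, t, :]) + self.h2gates(h)
            i, f, g, o = gates.chunk(4, dim=1)  # input, forget, cell, output gates
            c = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
            h = torch.sigmoid(o) * torch.tanh(c)
        # Score from the last hidden state, like out[:, -1, :] in LSTM_Critic
        return self.fc(h)
```

Usage would mirror LSTM_Critic, e.g. critic = UnrolledLSTMCritic(IMG_SIZE, hidden_size).to(device) followed by critic(images, labels); the labels argument is accepted but ignored, as in the original critic.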

Initialisation

gen = LSTM_Generator(200, 100, num_layers, num_classes).to(device)
critic = LSTM_Critic(input_size, hidden_size, num_layers, num_classes).to(device)

initialize_weights(gen)
initialize_weights(critic)

# initialize optimizers
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
opt_critic = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))

gen.train()
critic.train()

Training

for epoch in range(NUM_EPOCHS):
    for batch_idx, (real, labels) in enumerate(tqdm(loader)):
        real = real.to(device)
        cur_batch_size = real.shape[0]
        labels = labels.to(device)

        # Train Critic: max E[critic(real)] - E[critic(fake)]
        # equivalent to minimizing the negative of that
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(device)
            fake = gen(noise, labels)
            critic_real = critic(real, labels).reshape(-1)
            critic_fake = critic(fake, labels).reshape(-1)

            gp = gradient_penalty(critic, labels, real, fake, device=device)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp
            )
            critic.zero_grad()
            loss_critic.backward(retain_graph=True)
            opt_critic.step()

Gradient Penalty

def gradient_penalty(critic, labels, real, fake, device="cpu"):
    BATCH_SIZE, C, H, W = real.shape
    alpha = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * alpha + fake * (1 - alpha)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images, labels)

    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]

    gradient = gradient.reshape(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty
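The create_graph=True above is what turns the later loss_critic.backward() into a double backward, even though backward() is only called once: the gradient returned by torch.autograd.grad stays differentiable, so backpropagating a loss built from it has to differentiate the critic's backward pass itself. A minimal sketch with a toy one-parameter "critic" (a hypothetical stand-in for the LSTM) shows the same pattern:

```python
import torch

# Toy stand-in for the critic: D(x) = w * x**3 with one learnable weight w (hypothetical)
w = torch.tensor(2.0, requires_grad=True)
x = torch.tensor(1.5, requires_grad=True)
score = w * x ** 3

# First backward, as in gradient_penalty(): dD/dx = 3 * w * x**2 = 13.5.
# create_graph=True records this backward pass so its result stays differentiable.
(grad_x,) = torch.autograd.grad(score, x, create_graph=True)

# Penalty on the gradient, analogous to (gradient_norm - 1) ** 2
penalty = (grad_x - 1) ** 2

# Second backward: differentiates grad_x = 3 * w * x**2 with respect to w, i.e. it
# backpropagates through the first backward pass. For a cuDNN nn.LSTM critic this
# is the step that raises the NotImplementedError.
penalty.backward()

# d(penalty)/dw = 2 * (13.5 - 1) * 3 * x**2 = 25 * 6.75 = 168.75
print(w.grad)
```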

The code fails at the line loss_critic.backward(retain_graph=True)


Solution

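  • AFAIK, WGAN-GP cannot be trained with a cuDNN-backed RNN critic because of a cuDNN limitation.

    The double backward comes from gradient_penalty: torch.autograd.grad(..., create_graph=True) records the critic's backward pass in the autograd graph, so gp is itself a function of the critic's gradients. When loss_critic.backward() backpropagates through LAMBDA_GP * gp, autograd has to differentiate that recorded backward pass, i.e. take a second derivative through the critic's nn.LSTM. cuDNN's RNN backward (_cudnn_rnn_backward) has no derivative implemented, hence the error.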

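    One workaround is to write the RNN yourself with JIT-fused (TorchScript) cells instead of using nn.LSTM, so the cuDNN RNN kernel is never called. See https://github.com/pytorch/pytorch/issues/5261#issuecomment-687330144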

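    Another solution is to disable cuDNN during the forward pass, as the error message suggests. PyTorch then falls back to its native RNN implementation, which supports double backward; it still runs on the GPU, but it is slower than cuDNN. Only the forward pass that gets backpropagated twice needs this, i.e. the critic call inside gradient_penalty:

    with torch.backends.cudnn.flags(enabled=False):
        mixed_scores = critic(interpolated_images, labels)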
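Applied to the question's gradient_penalty, a minimal sketch could look like this. It assumes, as in the question's training loop, that fake still requires grad so the interpolated images are part of the graph; it also replaces .repeat(1, C, H, W) with plain broadcasting, which gives the same result. Only the critic forward on the interpolated images runs without cuDNN.

```python
import torch


def gradient_penalty(critic, labels, real, fake, device="cpu"):
    """WGAN-GP penalty; cuDNN is disabled only for the critic forward pass on the
    interpolated images, because only that pass is backpropagated twice."""
    batch_size = real.shape[0]
    # One mixing coefficient per sample, broadcast over C, H, W
    alpha = torch.rand(batch_size, 1, 1, 1, device=device)
    interpolated_images = real * alpha + fake * (1 - alpha)

    # Native (non-cuDNN) RNN kernels support double backward; they are slower
    # than cuDNN but still run on the GPU.
    with torch.backends.cudnn.flags(enabled=False):
        mixed_scores = critic(interpolated_images, labels)

    # Gradient of the scores w.r.t. the images, kept differentiable for the penalty
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]

    gradient_norm = gradient.reshape(batch_size, -1).norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)
```

The critic_real and critic_fake calls in the training loop should be able to keep using cuDNN, since they are only differentiated once.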