deep-learning, pytorch, generative-adversarial-network

Runtime error in WGAN-GP algorithm when running on GPU


I am a newbie in PyTorch and I am running the WGAN-GP algorithm on Google Colab using the GPU runtime. I encountered the error below. The algorithm works fine on the None runtime, i.e. the CPU.

Error generated during training

0%|          | 0/3 [00:00<?, ?it/s]
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-18-7e1d4849a60a> in <module>
     19             # Calculate gradient penalty on real and fake images
     20             # generated by generator
---> 21             gp = gradient_penalty(netCritic, real_image, fake, device)
     22             critic_loss = -(torch.mean(critic_real_pred)
     23                             - torch.mean(critic_fake_pred)) + LAMBDA_GP * gp

<ipython-input-15-f84354d74f37> in gradient_penalty(netCritic, real_image, fake_image, device)
      8     # image
      9     # interpolated image ← alpha *real image  + (1 − alpha) fake image
---> 10     interpolated_image = (alpha*real_image) + (1-alpha) * fake_image
     11 
     12     # calculate the critic score on the interpolated image

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
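
The same kind of mismatch can be reproduced in isolation with a tiny example (not from my notebook, just to illustrate the error): one operand lives on the GPU while the other defaults to the CPU.

import torch

a = torch.rand(3, device="cuda")  # tensor on cuda:0
b = torch.rand(3)                 # tensor on the cpu (default device)
c = a * b                         # raises the same device-mismatch RuntimeError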

Snippet of my WGAN-GP code

def gradient_penalty(netCritic, real_image, fake_image, device=device):

    batch_size, channel, height, width = real_image.shape

    # alpha is selected randomly between 0 and 1
    alpha = torch.rand(batch_size, 1, 1, 1).repeat(1, channel, height, width)
    # interpolated image=randomly weighted average between a real and fake
    # image
    # interpolated image ← alpha *real image  + (1 − alpha) fake image
    interpolated_image = (alpha*real_image) + (1-alpha) * fake_image
    
    # calculate the critic score on the interpolated image
    interpolated_score = netCritic(interpolated_image)

    # take the gradient of the score wrt to the interpolated image
    gradient = torch.autograd.grad(inputs=interpolated_image,
                                   outputs=interpolated_score,
                                   retain_graph=True,
                                   create_graph=True,
                                   grad_outputs=torch.ones_like(interpolated_score)
                                   )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1)**2)
    return gradient_penalty


n_epochs = 2000
cur_step = 0
LAMBDA_GP = 10
display_step = 50
CRITIC_ITERATIONS = 5
nz = 100

for epoch in range(n_epochs):
    # Dataloader returns the batches
    for real_image, _ in tqdm(dataloader):
        cur_batch_size = real_image.shape[0]
        real_image = real_image.to(device)
        for _ in range(CRITIC_ITERATIONS):
            fake_noise = get_noise(cur_batch_size, nz, device=device)
            fake = netG(fake_noise)
            critic_fake_pred = netCritic(fake).reshape(-1)
            critic_real_pred = netCritic(real_image).reshape(-1)

            # Calculate gradient penalty on real and fake images
            # generated by generator
            gp = gradient_penalty(netCritic, real_image, fake, device)
            critic_loss = -(torch.mean(critic_real_pred)
                            - torch.mean(critic_fake_pred)) + LAMBDA_GP * gp
            netCritic.zero_grad()
            # To make a backward pass and retain the intermediary results
            critic_loss.backward(retain_graph=True)
            optimizerCritic.step()

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        gen_fake = netCritic(fake).reshape(-1)
        gen_loss = -torch.mean(gen_fake)
        netG.zero_grad()
        gen_loss.backward()
        # Update optimizer
        optimizerG.step()

        # Visualization code ##
        if cur_step % display_step == 0 and cur_step > 0:
            print(f"Step{cur_step}: GenLoss: {gen_loss}: CLoss: {critic_loss}")
            display_images(fake)
            display_images(real_image)
            gen_loss = 0
            critic_loss = 0
        cur_step += 1


        

I tried to introduce cuda() at lines 10 and 21 indicated in the error output, but it is not working.
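
For example, something like the following (one possible version of that change, not my exact code) still raises the same error:

# appending .cuda() to lines 10 and 21 -- still fails with the same
# device-mismatch RuntimeError
interpolated_image = ((alpha*real_image) + (1-alpha) * fake_image).cuda()
...
gp = gradient_penalty(netCritic, real_image, fake, device).cuda()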


Solution

  • Here is one approach to solving this kind of error:

    1. Read the error message and locate the exact line where it occurred:

      ... in gradient_penalty(netCritic, real_image, fake_image, device)
            8     # image
            9     # interpolated image ← alpha *real image  + (1 − alpha) fake image
      ---> 10     interpolated_image = (alpha*real_image) + (1-alpha) * fake_image
           11 
           12     # calculate the critic score on the interpolated image
      
      RuntimeError: Expected all tensors to be on the same device, 
                    but found at least two devices, cuda:0 and cpu!
      
    2. Look for input tensors that have not been transferred to the correct device, then look for intermediate tensors that have not been transferred either.
      Here alpha is assigned a random tensor, but it is never moved to the GPU! Note that appending .cuda() to the interpolated image cannot help: the error is raised while alpha (on the CPU) is being multiplied with real_image (on cuda:0), before the result even exists, so alpha itself has to be created on (or moved to) the right device.


      >>> alpha = torch.rand(batch_size, 1, 1, 1) \
                       .repeat(1, channel, height, width)
      
    3. Fix the issue and test:

      >>> alpha = torch.rand(batch_size, 1, 1, 1, device=fake_image.device) \
                       .repeat(1, channel, height, width)
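
    4. Putting it together, a corrected gradient_penalty would look roughly like this (a sketch: the only functional change from the question's code is that alpha is created on the same device as fake_image):

      def gradient_penalty(netCritic, real_image, fake_image, device=device):
          batch_size, channel, height, width = real_image.shape

          # alpha is created directly on the same device as the images
          alpha = torch.rand(batch_size, 1, 1, 1, device=fake_image.device) \
                       .repeat(1, channel, height, width)

          # interpolated image = randomly weighted average between
          # a real and a fake image
          interpolated_image = alpha * real_image + (1 - alpha) * fake_image

          # critic score on the interpolated image
          interpolated_score = netCritic(interpolated_image)

          # gradient of the score w.r.t. the interpolated image
          gradient = torch.autograd.grad(
              inputs=interpolated_image,
              outputs=interpolated_score,
              grad_outputs=torch.ones_like(interpolated_score),
              create_graph=True,
              retain_graph=True,
          )[0]
          gradient = gradient.view(gradient.shape[0], -1)
          gradient_norm = gradient.norm(2, dim=1)
          return torch.mean((gradient_norm - 1) ** 2)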