I want to use WGAN-GP, and when I run the code, it gives me an error:
def calculate_gradient_penalty(real_images, fake_images):
t = torch.rand(real_images.size(0), 1, 1, 1).to(real_images.device)
t = t.expand(real_images.size())
interpolates = t * real_images + (1 - t) * fake_images
interpolates.requires_grad_(True)
disc_interpolates = D(interpolates)
grad = torch.autograd.grad(
outputs=disc_interpolates, inputs=interpolates,
grad_outputs=torch.ones_like(disc_interpolates),
create_graph=True, retain_graph=True, allow_unused=True)[0]
grad_norm = torch.norm(torch.flatten(grad, start_dim=1), dim=1)
loss_gp = torch.mean((grad_norm - 1) ** 2) * lambda_term
return loss_gp
RuntimeError Traceback (most recent call last) in
/opt/conda/lib/python3.8/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph, inputs) 243 create_graph=create_graph, 244 inputs=inputs) --> 245 torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs) 246 247 def register_hook(self, hook):
/opt/conda/lib/python3.8/site-packages/torch/autograd/init.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs) 143 retain_graph = create_graph 144 --> 145 Variable.execution_engine.run_backward( 146 tensors, grad_tensors, retain_graph, create_graph, inputs, 147 allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag
RuntimeError: CUDA out of memory. Tried to allocate 64.00 MiB (GPU 2; 15.75 GiB total capacity; 13.76 GiB already allocated; 2.75 MiB free; 14.50 GiB reserved in total by PyTorch)
The train code:
%%time
d_progress = []
d_fake_progress = []
d_real_progress = []
penalty = []
g_progress = []
data = get_infinite_batches(benign_data_loader)
one = torch.FloatTensor([1]).to(device)
mone = (one * -1).to(device)
for g_iter in range(generator_iters):
print('----------G Iter-{}----------'.format(g_iter+1))
for p in D.parameters():
p.requires_grad = True # This is by Default
d_loss_real = 0
d_loss_fake = 0
Wasserstein_D = 0
for d_iter in range(critic_iter):
D.zero_grad()
images = data.__next__()
if images.size()[0] != batch_size:
continue
# Train Discriminator
# Real Images
images = images.to(device)
z = torch.randn(batch_size, 100, 1, 1).to(device)
d_loss_real = D(images)
d_loss_real = d_loss_real.mean(0).view(1)
d_loss_real.backward(mone)
# Fake Images
fake_images = G(z)
d_loss_fake = D(fake_images)
d_loss_fake = d_loss_fake.mean(0).view(1)
d_loss_fake.backward(one)
# Calculate Penalty
gradient_penalty = calculate_gradient_penalty(images.data, fake_images.data)
gradient_penalty.backward()
# Total Loss
d_loss = d_loss_fake - d_loss_real + gradient_penalty
Wasserstein_D = d_loss_real - d_loss_fake
d_optimizer.step()
print(f'D Iter:{d_iter+1}/{critic_iter} Loss:{d_loss.detach().cpu().numpy()}')
time.sleep(0.1)
d_progress.append(d_loss) # Store Loss
d_fake_progress.append(d_loss_fake)
d_real_progress.append(d_loss_real)
penalty.append(gradient_penalty)
# Generator Updata
for p in D.parameters():
p.requires_grad = False # Avoid Computation
# Train Generator
# Compute with Fake
G.zero_grad()
z = torch.randn(batch_size, 100, 1, 1).to(device)
fake_images = G(z)
g_loss = D(fake_images)
g_loss = g_loss.mean().mean(0).view(1)
g_loss.backward(one)
# g_cost = -g_loss
g_optimizer.step()
print(f'G Iter:{g_iter+1}/{generator_iters} Loss:{g_loss.detach().cpu().numpy()}')
g_progress.append(g_loss) # Store Loss
Does anyone know how to solve this problem?
All loss tensors which are saved outside of the optimization cycle (i.e. outside the for g_iter in range(generator_iters)
loop) need to be detached from the graph. Otherwise, you are keeping all previous computation graphs in memory.
As such, you should detach anything that gets appended to d_progress
, d_fake_progress
, d_real_progress
, penalty
, and g_progress
.
You can do so by converting the tensor to a scalar value with torch.Tensor.item
, the graph will free itself on the following iteration. Change the following lines:
d_progress.append(d_loss) # Store Loss
d_fake_progress.append(d_loss_fake)
d_real_progress.append(d_loss_real)
penalty.append(gradient_penalty)
#######
g_progress.append(g_loss) # Store Loss
to:
d_progress.append(d_loss.item()) # Store Loss
d_fake_progress.append(d_loss_fake.item())
d_real_progress.append(d_loss_real.item())
penalty.append(gradient_penalty.item())
#######
g_progress.append(g_loss.item()) # Store Loss