Search code examples
pytorchgenerative-adversarial-network

Debugging the discriminator implementation for a GAN


I'm having issues while including the discriminator to an implementation of an SRGAN. While training on the Flickr dataset, I see the discriminator fails to learn anything early on (with the BCELoss showing a value of 100) and never to recover. I played around it a bit and removed the sigmoid in the hopes of using BCEWithLogits as the loss. This led to the loss varying wildly in the beginning and getting to zero.

What is a good method to debug the discriminator implementation? I suspect way I'm calling the discriminator in training to have an issue.

class DiscriminatorConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super(DiscriminatorConvBlock, self).__init__()
        num_groups = 8
        self.conv1 = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1, bias=False), 
                                   nn.GroupNorm(num_groups, out_channels),
                                   nn.LeakyReLU(0.2, False),
                                 )
    def forward(self, x):
        out = self.conv1(x)
        return out

class Discriminator(nn.Module):
    def __init__(self, low_res_dim):
        super(Discriminator, self).__init__()
        img_d = int(low_res_dim / 4)
        self.conv1 = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False), 
                                  nn.LeakyReLU(0.2, False),
                                 )
        self.conv2 = DiscriminatorConvBlock(64, 64, 2)
        self.conv3 = DiscriminatorConvBlock(64, 128, 1)
        self.conv4 = DiscriminatorConvBlock(128, 128, 2)
        self.conv5 = DiscriminatorConvBlock(128, 256, 1)
        self.conv6 = DiscriminatorConvBlock(256, 256, 2)
        self.conv7 = DiscriminatorConvBlock(256, 512, 1)
        self.conv8 = DiscriminatorConvBlock(512, 512, 2)

        self.dense1 = nn.Linear(512 * img_d * img_d , 1024)
        self.leakyRelu = nn.LeakyReLU(0.2, False)
        self.dense2 = nn.Linear(1024 , 1)
        self.drop = nn.Dropout(0.3)
        
    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.conv5(out)
        out = self.conv6(out)
        out = self.conv7(out)
        out = self.conv8(out)
        out = out.view(-1, out.size(1) * out.size(2) * out.size(3))
        out = self.leakyRelu(self.dense1(out))
        out = self.dense2(out)
        out = torch.clamp_(out, 0.0, 1.0)
        return out


gen_model = Generator().to(device)
disc_model = Discriminator(low_res).to(device)

# VGG terms
vgg = models.vgg19(pretrained=True).to(device)
feature_nodes = ["features.35"]
feature_extractor = create_feature_extractor(vgg, feature_nodes)
feature_extractor_nodes = feature_nodes
normalizeT = transforms.Normalize([ 0.5, 0.5, 0.5 ], [ 0.5, 0.5, 0.5 ])

for model_parameters in feature_extractor.parameters():
    model_parameters.requires_grad = False
feature_extractor.eval()

gen_optimizer = optim.Adam(gen_model.parameters(),lr=1e-4)
disc_optimizer = optim.Adam(disc_model.parameters(),lr=1e-5)
gen_scheduler = CosineAnnealingWarmRestarts(gen_optimizer, 
                                        T_0 = 8,# Number of iterations for the first restart
                                        T_mult = 1, # A factor increases TiTi​ after a restart
                                        eta_min = 1e-5) # Minimum learning rate
disc_scheduler = CosineAnnealingWarmRestarts(disc_optimizer, 
                                        T_0 = 8,# Number of iterations for the first restart
                                        T_mult = 1, # A factor increases TiTi​ after a restart
                                        eta_min = 1e-6) # Minimum learning rate
mse_loss = nn.MSELoss()
vgg_loss = nn.MSELoss()
disc_loss = nn.BCEWithLogitsLoss()
disc_loss_generator = nn.BCEWithLogitsLoss()
gen_optimizer.zero_grad()
for epoch in range(num_epochs):
    gen_scheduler.step()
    disc_scheduler.step()
    for i, data in enumerate(tqdm.tqdm(dataloader)):
        input_images, labels = data
        # forward pass
        input_images = input_images.to(device)
            
        lowres_images = transforms.Resize(low_res)(input_images)
        gen_highres_images = gen_model(lowres_images.to(device))

        for model_parameters in disc_model.parameters():
            model_parameters.requires_grad = True
        # Discriminator
        disc_model.zero_grad()
        actual_label = disc_model(input_images.to(device))
        # Adversarial loss
        d2_loss = (disc_loss(actual_label, torch.ones_like(actual_label,dtype=torch.float)))
        d2_loss.backward()

        generated_label = disc_model(gen_highres_images.to(device))
        d1_loss = (disc_loss(generated_label, torch.zeros_like(generated_label,dtype=torch.float)))
        d1_loss.backward(retain_graph=True)

        errD = d2_loss + d1_loss
        disc_optimizer.step()

        gen_model.zero_grad() 
        # Perceptual loss
        mse = mse_loss(normalizeT(gen_highres_images), normalizeT(input_images))

        vgg_losses = []
        sr_feature = feature_extractor(normalizeT(input_images))
        gt_feature = feature_extractor(normalizeT(gen_highres_images))
        for i in range(len(feature_extractor_nodes)):
            vgg_losses.append(vgg_loss(sr_feature[feature_extractor_nodes[i]],
                                           gt_feature[feature_extractor_nodes[i]]))

        for model_parameters in disc_model.parameters():
            model_parameters.requires_grad = False
        actual_generated_label = disc_model(gen_highres_images.to(device))
        gen_disc_loss = disc_loss_generator(actual_generated_label, torch.ones_like(actual_label,dtype=torch.float))
        generator_loss = vgg_losses[0] +  mse  + gen_disc_loss
        generator_loss.backward()
        gen_optimizer.step()
        gen_optimizer.zero_grad()
        torch.cuda.empty_cache()

Solution

  • The problem is out = torch.clamp_(out, 0.0, 1.0). This doesn't make sense with what you want the model to do and the loss you are using.

    BCEWithLogitsLoss applies a sigmoid. The sigmoid function returns [0.5, 1] on the range [0, 1] you are clamping the output to. You are essentially forcing the model to predict class 1 with >=50% confidence for every example.