Tags: machine-learning, pytorch, dcgan

Need help understanding the label input in a CGAN


I am trying to implement a CGAN. I understand that in the convolutional generator and discriminator models, you condition the inputs by appending extra depth (channels) that represent the label. So if your data has 10 classes, both the generator and the discriminator take the base depth + 10 as their input volume.
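
To make that concrete, here is a minimal sketch of my understanding (CondDiscriminatorStub is a made-up name, and only the first conv layer is shown):

    import torch
    import torch.nn as nn

    class CondDiscriminatorStub(nn.Module):
        # Hypothetical first layer only: image channels + one channel per class.
        def __init__(self, img_channels=3, class_num=10):
            super().__init__()
            # 3 image channels + 10 label channels = 13 input channels
            self.conv1 = nn.Conv2d(img_channels + class_num, 64, kernel_size=4, stride=2, padding=1)

        def forward(self, x, y_fill):
            # y_fill: (N, class_num, H, W), the one-hot label broadcast over the spatial dims
            return self.conv1(torch.cat([x, y_fill], dim=1))

    D = CondDiscriminatorStub()
    x = torch.randn(8, 3, 32, 32)
    y_fill = torch.zeros(8, 10, 32, 32)
    y_fill[:, 3] = 1                    # pretend every image in the batch is class 3
    print(D(x, y_fill).shape)           # torch.Size([8, 64, 16, 16])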

However, I am reading various implementations online and I can't seem to find where they actually acquire this label. Surely CGANs can't be unsupervised, because you need the label as an input. E.g. in CIFAR-10, if you are training the discriminator on a real image of a frog, you would need the 'frog' annotation.

Here is one of the pieces of code I am studying:

class CGAN(object):
    def __init__(self, args):
        # parameters
        self.epoch = args.epoch
        self.batch_size = args.batch_size
        self.save_dir = args.save_dir
        self.result_dir = args.result_dir
        self.dataset = args.dataset
        self.log_dir = args.log_dir
        self.gpu_mode = args.gpu_mode
        self.model_name = args.gan_type
        self.input_size = args.input_size
        self.z_dim = 62
        self.class_num = 10
        self.sample_num = self.class_num ** 2

        # load dataset
        self.data_loader = dataloader(self.dataset, self.input_size, self.batch_size)
        data = self.data_loader.__iter__().__next__()[0]

        # networks init
        self.G = generator(input_dim=self.z_dim, output_dim=data.shape[1], input_size=self.input_size, class_num=self.class_num)
        self.D = discriminator(input_dim=data.shape[1], output_dim=1, input_size=self.input_size, class_num=self.class_num)
        self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2))
        self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2))

        if self.gpu_mode:
            self.G.cuda()
            self.D.cuda()
            self.BCE_loss = nn.BCELoss().cuda()
        else:
            self.BCE_loss = nn.BCELoss()

        print('---------- Networks architecture -------------')
        utils.print_network(self.G)
        utils.print_network(self.D)
        print('-----------------------------------------------')

        # fixed noise & condition
        self.sample_z_ = torch.zeros((self.sample_num, self.z_dim))
        for i in range(self.class_num):
            self.sample_z_[i*self.class_num] = torch.rand(1, self.z_dim)
            for j in range(1, self.class_num):
                self.sample_z_[i*self.class_num + j] = self.sample_z_[i*self.class_num]

        temp = torch.zeros((self.class_num, 1))
        for i in range(self.class_num):
            temp[i, 0] = i

        temp_y = torch.zeros((self.sample_num, 1))
        for i in range(self.class_num):
            temp_y[i*self.class_num: (i+1)*self.class_num] = temp

        self.sample_y_ = torch.zeros((self.sample_num, self.class_num)).scatter_(1, temp_y.type(torch.LongTensor), 1)
        if self.gpu_mode:
            self.sample_z_, self.sample_y_ = self.sample_z_.cuda(), self.sample_y_.cuda()

    def train(self):
        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []
        self.train_hist['per_epoch_time'] = []
        self.train_hist['total_time'] = []

        self.y_real_, self.y_fake_ = torch.ones(self.batch_size, 1), torch.zeros(self.batch_size, 1)
        if self.gpu_mode:
            self.y_real_, self.y_fake_ = self.y_real_.cuda(), self.y_fake_.cuda()

        self.D.train()
        print('training start!!')
        start_time = time.time()
        for epoch in range(self.epoch):
            self.G.train()
            epoch_start_time = time.time()
            for iter, (x_, y_) in enumerate(self.data_loader):
                if iter == self.data_loader.dataset.__len__() // self.batch_size:
                    break

                z_ = torch.rand((self.batch_size, self.z_dim))
                y_vec_ = torch.zeros((self.batch_size, self.class_num)).scatter_(1, y_.type(torch.LongTensor).unsqueeze(1), 1)
                y_fill_ = y_vec_.unsqueeze(2).unsqueeze(3).expand(self.batch_size, self.class_num, self.input_size, self.input_size)
                if self.gpu_mode:
                    x_, z_, y_vec_, y_fill_ = x_.cuda(), z_.cuda(), y_vec_.cuda(), y_fill_.cuda()

                # update D network
                self.D_optimizer.zero_grad()

                D_real = self.D(x_, y_fill_)
                D_real_loss = self.BCE_loss(D_real, self.y_real_)

                G_ = self.G(z_, y_vec_)
                D_fake = self.D(G_, y_fill_)
                D_fake_loss = self.BCE_loss(D_fake, self.y_fake_)

                D_loss = D_real_loss + D_fake_loss
                self.train_hist['D_loss'].append(D_loss.item())

                D_loss.backward()
                self.D_optimizer.step()

                # update G network
                self.G_optimizer.zero_grad()

                G_ = self.G(z_, y_vec_)
                D_fake = self.D(G_, y_fill_)
                G_loss = self.BCE_loss(D_fake, self.y_real_)
                self.train_hist['G_loss'].append(G_loss.item())

                G_loss.backward()
                self.G_optimizer.step()
It seems as if y_vec_ and y_fill_ are the labels for the images, but y_fill_, which is used to label the real images for the discriminator, is simply y_fill_ = y_vec_.unsqueeze(2).unsqueeze(3).expand(self.batch_size, self.class_num, self.input_size, self.input_size)

It doesn't seem like it's getting any information about the label from the dataset, so how is it giving the discriminator the correct label?

Thanks!


Solution

  • y_fill_ is based on y_vec_, which is based on y_, so they are reading the label info from the mini-batches, which is correct. You might be confused by the scatter operation; basically, what the code is doing is converting the integer label into a one-hot encoding.
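
As a standalone sketch (with made-up batch values rather than the repo's dataloader and models), this is the whole chain from the mini-batch label y_ to y_vec_ and y_fill_:

    import torch

    batch_size, class_num, input_size = 4, 10, 32

    # y_ arrives from the data loader together with the images,
    # e.g. the integer class labels of a CIFAR-10 batch.
    y_ = torch.tensor([6, 0, 3, 9])

    # scatter_ writes a 1 at column y_[i] of row i: an explicit one-hot encoding.
    y_vec_ = torch.zeros((batch_size, class_num)).scatter_(1, y_.unsqueeze(1), 1)

    # The one-hot vector is then broadcast over the spatial dimensions so it can
    # be concatenated with the image channels inside the discriminator.
    y_fill_ = y_vec_.unsqueeze(2).unsqueeze(3).expand(batch_size, class_num, input_size, input_size)

    print(y_vec_[0])      # one-hot row with the 1 at index 6
    print(y_fill_.shape)  # torch.Size([4, 10, 32, 32])

In current PyTorch the same one-hot step can also be written as torch.nn.functional.one_hot(y_, class_num).float(); the scatter_ call in the repo is just an older idiom for the same thing.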