Search code examples
pytorchdimensionsgenerative-adversarial-network

RuntimeError: Expected 4-dimensional input for 4-dimensional weight 128 256, but got 2-dimensional input of size [32, 128] instead


I am working on creating an image generator using conditional GAN as the base model. I've run across an error that I don't understand how to debug, even after searching for solutions online. I'm not sure if I should change the settings for training or do some adjustment to my model, or something else. Any help on what to do would be appreciated.

The CGAN model I am using:

class Generator(nn.Module):
    def __init__(self, classes, channels, img_size, latent_dim):
        super(Generator, self).__init__()
        self.classes = classes
        self.channels = channels
        self.img_size = img_size
        self.latent_dim = latent_dim
        self.img_shape = (self.channels, self.img_size, self.img_size)
        self.label_embedding = nn.Embedding(self.classes, self.classes) # process label information, behave as a lookup table

        self.model = nn.Sequential(
            *self._create_layer_1(self.latent_dim + self.classes, 128, False),
            *self._create_layer_2(128, 256),
            *self._create_layer_2(256, 512),
            *self._create_layer_2(512, 1024),
            nn.Linear(1024, int(np.prod(self.img_shape))),
            nn.Tanh()
        )

    def _create_layer_1(self, size_in, size_out, normalize=True):
        layers = [nn.Linear(size_in, size_out)]
        if normalize:
            layers.append(nn.BatchNorm1d(size_out))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        return layers

    def _create_layer_2(self, size_in, size_out, normalize=True):
        layers = [nn.ConvTranspose2d(size_in, size_out, 4, 2, 1, bias=False)]
        if normalize:
            layers.append(nn.BatchNorm1d(size_out))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        return layers

    def forward(self, noise, labels):
        z = torch.cat((self.label_embedding(labels), noise), -1)
        x = self.model(z)
        x = x.view(x.size(0), *self.img_shape)
        return x


class Discriminator(nn.Module):
    def __init__(self, classes, channels, img_size, latent_dim):
        super(Discriminator, self).__init__()
        self.classes = classes
        self.channels = channels
        self.img_size = img_size
        self.latent_dim = latent_dim
        self.img_shape = (self.channels, self.img_size, self.img_size)
        self.label_embedding = nn.Embedding(self.classes, self.classes)
        self.adv_loss = torch.nn.BCELoss()

        self.model = nn.Sequential(
            *self._create_layer_1(self.classes + int(np.prod(self.img_shape)), 1024, False, True),
            *self._create_layer_2(1024, 512, True, True),
            *self._create_layer_2(512, 256, True, True),
            *self._create_layer_2(256, 128, False, False),
            *self._create_layer_1(128, 1, False, False),
            nn.Sigmoid()
        )

    def _create_layer_1(self, size_in, size_out, drop_out=True, act_func=True):
        layers = [nn.Linear(size_in, size_out)]
        if drop_out:
            layers.append(nn.Dropout(0.4))
        if act_func:
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        return layers

    def _create_layer_2(self, size_in, size_out, drop_out=True, act_func=True):
        layers = [nn.Conv2d(size_in, size_out, 4, 2, 1, bias=False)]
        if drop_out:
            layers.append(nn.Dropout(0.4))
        if act_func:
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        return layers

    def forward(self, image, labels):
        x = torch.cat((image.view(image.size(0), -1), self.label_embedding(labels)), -1)
        return self.model(x)

    def loss(self, output, label):
        return self.adv_loss(output, label)

Code for initializing the model:

class Model(object):
    def __init__(self,
                 name,
                 device,
                 data_loader,
                 classes,
                 channels,
                 img_size,
                 latent_dim,
                 style_dim=3):
        self.name = name
        self.device = device
        self.data_loader = data_loader
        self.classes = classes
        self.channels = channels
        self.img_size = img_size
        self.latent_dim = latent_dim
        self.style_dim = style_dim
        self.netG = cganG(self.classes, self.channels, self.img_size, self.latent_dim)
        self.netG.to(self.device)
        self.netD = cganD(self.classes, self.channels, self.img_size, self.latent_dim)
        self.netD.to(self.device)
        self.optim_G = None
        self.optim_D = None

    @property
    def generator(self):
        return self.netG

    @property
    def discriminator(self):
        return self.netD

    def create_optim(self, lr, alpha=0.5, beta=0.999):
        self.optim_G = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                        self.netG.parameters()),
                                        lr=lr,
                                        betas=(alpha, beta))
        self.optim_D = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                        self.netD.parameters()),
                                        lr=lr,
                                        betas=(alpha, beta))

    def _to_onehot(self, var, dim):
        res = torch.zeros((var.shape[0], dim), device=self.device)
        res[range(var.shape[0]), var] = 1.
        return res

    def train(self,
              epochs,
              log_interval=100,
              out_dir='',
              verbose=True):
        self.netG.train()
        self.netD.train()
        viz_z = torch.zeros((self.data_loader.batch_size, self.latent_dim), device=self.device)
        viz_noise = torch.randn(self.data_loader.batch_size, self.latent_dim, device=self.device)
        nrows = self.data_loader.batch_size // 8
        viz_label = torch.LongTensor(np.array([num for _ in range(nrows) for num in range(8)])).to(self.device)
        viz_onehot = self._to_onehot(viz_label, dim=self.classes)
        viz_style = torch.zeros((self.data_loader.batch_size, self.style_dim), device=self.device)
        total_time = time.time()
        for epoch in range(epochs):
            batch_time = time.time()
            for batch_idx, (data, target) in enumerate(self.data_loader):
                data, target = data.to(self.device), target.to(self.device)
                batch_size = data.size(0)
                real_label = torch.full((batch_size, 1), 1., device=self.device)
                fake_label = torch.full((batch_size, 1), 0., device=self.device)

                # Train G
                self.netG.zero_grad()
                z_noise = torch.randn(batch_size, self.latent_dim, device=self.device)
                x_fake_labels = torch.randint(0, self.classes, (batch_size,), device=self.device)
                x_fake = self.netG(z_noise, x_fake_labels)
                y_fake_g = self.netD(x_fake, x_fake_labels)
                g_loss = self.netD.loss(y_fake_g, real_label)
                g_loss.backward()
                self.optim_G.step()

                # Train D
                self.netD.zero_grad()
                y_real = self.netD(data, target)
                d_real_loss = self.netD.loss(y_real, real_label)
                y_fake_d = self.netD(x_fake.detach(), x_fake_labels)
                d_fake_loss = self.netD.loss(y_fake_d, fake_label)
                d_loss = (d_real_loss + d_fake_loss) / 2
                d_loss.backward()
                self.optim_D.step()

                if verbose and batch_idx % log_interval == 0 and batch_idx > 0:
                    print('Epoch {} [{}/{}] loss_D: {:.4f} loss_G: {:.4f} time: {:.2f}'.format(
                            epoch, batch_idx, len(self.data_loader),
                            d_loss.mean().item(),
                            g_loss.mean().item(),
                            time.time() - batch_time))
                    vutils.save_image(data, os.path.join(out_dir, 'real_samples.png'), normalize=True)
                    with torch.no_grad():
                        viz_sample = self.netG(viz_noise, viz_label)
                        vutils.save_image(viz_sample, os.path.join(out_dir, 'fake_samples_{}.png'.format(epoch)), nrow=8, normalize=True)
                    batch_time = time.time()   
            
            torch.save(self.netG.state_dict(), os.path.join(out_dir, 'netG_{}.pth'.format(epoch)))
            torch.save(self.netD.state_dict(), os.path.join(out_dir, 'netD_{}.pth'.format(epoch)))

            self.save_to(path=out_dir, name=self.name, verbose=False)
        if verbose:
            print('Total train time: {:.2f}'.format(time.time() - total_time))

Code for setting up the training:

def main():
    device = torch.device("cuda:0" if FLAGS.cuda else "cpu")
    if FLAGS.train:
        dataloader = torch.utils.data.DataLoader(
            dset.ImageFolder(FLAGS.data_dir, transforms.Compose([
                transforms.Resize(FLAGS.img_size),
                transforms.CenterCrop(FLAGS.img_size),
                transforms.ToTensor()
                ])),
                batch_size=FLAGS.batch_size,
                shuffle=True, 
                num_workers=4, 
                pin_memory=True
                )
        model = Model(FLAGS.model, device, dataloader, FLAGS.classes, FLAGS.channels, FLAGS.img_size, FLAGS.latent_dim)
        model.create_optim(FLAGS.lr)

        # Train
        print("Start training...\n")
        model.train(FLAGS.epochs, FLAGS.log_interval, FLAGS.out_dir, True)

if __name__ == '__main__':
    from utils import boolean_string
    parser.add_argument('--cuda', type=boolean_string, default=True, help='enable CUDA.')
    parser.add_argument('--train', type=boolean_string, default=True, help='train mode or eval mode.')
    parser.add_argument('--data_dir', type=str, default='../datasets', help='Directory for dataset.')
    parser.add_argument('--out_dir', type=str, default='output', help='Directory for output.')
    parser.add_argument('--epochs', type=int, default=800, help='number of epochs')
    parser.add_argument('--batch_size', type=int, default=32, help='size of batches')
    parser.add_argument('--lr', type=float, default=0.0002, help='learning rate')
    parser.add_argument('--latent_dim', type=int, default=62, help='latent space dimension')
    parser.add_argument('--classes', type=int, default=25, help='number of classes')
    parser.add_argument('--img_size', type=int, default=128, help='size of images')
    parser.add_argument('--channels', type=int, default=3, help='number of image channels')

Settings:

PyTorch version: 1.1.0
CUDA version: 9.0.176

         Args         |    Type    |    Value
--------------------------------------------------
  cuda                |  bool      |  True
  train               |  bool      |  True
  resume              |  bool      |  False
  data_dir            |  str       |  ../datasets
  out_dir             |  str       |  output
  epochs              |  int       |  800
  batch_size          |  int       |  32
  lr                  |  float     |  0.0002
  latent_dim          |  int       |  62
  classes             |  int       |  25
  img_size            |  int       |  128
  channels            |  int       |  3

Image size as input:

torch.Size([32, 3, 128, 128])

Model structure:

Generator(
  (label_embedding): Embedding(25, 25)
  (model): Sequential(
    (0): Linear(in_features=87, out_features=128, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace)
    (2): ConvTranspose2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace)
    (5): ConvTranspose2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace)
    (8): ConvTranspose2d(512, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace)
    (11): Linear(in_features=1024, out_features=49152, bias=True)
    (12): Tanh()
  )
)

Discriminator(
  (label_embedding): Embedding(25, 25)
  (adv_loss): BCELoss()
  (model): Sequential(
    (0): Linear(in_features=49177, out_features=1024, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace)
    (2): Conv2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): Dropout(p=0.4)
    (4): LeakyReLU(negative_slope=0.2, inplace)
    (5): Conv2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): Dropout(p=0.4)
    (7): LeakyReLU(negative_slope=0.2, inplace)
    (8): Conv2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): Linear(in_features=128, out_features=1, bias=True)
    (10): Sigmoid()
  )
)

The error I got:

File "main.py", line 121, in <module>
    main()
  File "main.py", line 56, in main
    model.train(FLAGS.epochs, FLAGS.log_interval, FLAGS.out_dir, True)
  File "build_gan.py", line 123, in train
    x_fake = self.netG(z_noise, x_fake_labels)
  File "anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "cgan.py", line 42, in forward
    x = self.model(z)
  File "anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "anaconda3/lib/python3.6/site-packages/torch/nn/modules/container.py", line 92, in forward
    input = module(input)
  File "anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "anaconda3/lib/python3.6/site-packages/torch/nn/modules/conv.py", line 796, in forward
    output_padding, self.groups, self.dilation)
RuntimeError: Expected 4-dimensional input for 4-dimensional weight 128 256, but got 2-dimensional input of size [32, 128] instead

I am using my own image dataset with 3 channels and 25 classes. I have tried to change the image size and kernel size but still got the same error. Any help on what should I do to debug would be highly appreciated.


Solution

  • The issue is actually with your model architecture. You are trying to place a conv2d layer just after a linear fully connected layer.The _create_layer_1 produces a 1d output. You are trying to feed this 1d output to a conv2d layer which expects a multidimensional input.

    From your code the best thing I feel to make it work in a single go would be to remove "_create_layer_2" function completely from generator class and use _create_layer_1 function to define all your layers(so that all layers are fully connected layers). Also do this for your discriminator

    If you still need to use conv2d. You should reshape the input to conv2d to a 2d tensor. Also you have to flatten the 2d tensor to 1d before your final linear layer. Or you could ditch the first linear nn.Linear layer and start with conv2d altogether.

    To summarise As you are designing GANs you might have experience developing CNNs. The point is you dont simply mixup conv2d/conv layers with linear layers without using proper flatten/reshape.

    Cheers