Search code examples
Tags: python, deep-learning, neural-network, pytorch, generative-adversarial-network

How to efficiently initialise (and sanity-check) the weights of layers within complex (nested) modules in PyTorch?


I am looking for an efficient way to access nested modules and layers in order to set their weights.

I am replicating the DCGAN paper, and my code works as expected. In the paper, the authors state:

All weights were initialized from a zero-centered Normal distribution with standard deviation 0.02

This answer explains that it can be done with `torch.nn.init.normal_(nn.Conv2d(1, 1, 1, 1, 1).weight.data, 0.0, 0.02)`, but my model has a complex structure built from `nn.ModuleList`, `nn.Sequential` and custom sub-modules. What is the most efficient way of doing this?

By "complex", I mean a nested structure like the one in my implementation below:

'''
Implement the Deep Convolution Gan AKA DCGAN in Pytorch: Paper at https://arxiv.org/pdf/1511.06434v2.pdf
'''
import torch
import torch.nn as nn


class GeneratorBlock(nn.Module):
    '''
    One upsampling stage of the Generator: ConvTranspose2d -> optional BatchNorm2d -> ReLU.

    Note: the defaults kernel_size = 4, stride = 2, padding = 1 follow the paper. The transposed
    convolution only carries a bias when BatchNorm is disabled: BatchNorm subtracts the batch mean,
    which would cancel a constant conv bias anyway.
    '''
    def __init__(self, in_channels, out_channels, kernel_size = 4, stride = 2, padding = 1, use_batchnorm:bool = True):
        super().__init__()
        self.use_batchnorm = use_batchnorm
        conv_has_bias = not use_batchnorm
        self.transpose_conv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=conv_has_bias,
        )
        # Stays None (rather than a no-op module) when normalisation is switched off.
        if use_batchnorm:
            self.batch_norm = nn.BatchNorm2d(out_channels)
        else:
            self.batch_norm = None
        self.activation = nn.ReLU()  # The paper uses ReLU throughout the Generator network

    def forward(self, x):
        '''
        Upsample x, normalise it when BatchNorm is enabled, then apply ReLU.
        '''
        out = self.transpose_conv(x)
        if self.use_batchnorm:
            out = self.batch_norm(out)
        return self.activation(out)


class Generator(nn.Module):
    '''
    Generate images from random noise using transposed convolutions.
    The input is noise of shape [Batch, input_features, 1, 1] (100 channels by default). It is upsampled
    stage by stage, and a final Tanh maps the output into [-1, 1].
    '''
    def __init__(self, input_features = 100, base_feature = 128, final_channels:int = 1):
        '''
        nn.Sequential is used here to show the workings explicitly. To build the layers dynamically in a loop,
        see the nn.ModuleList() used in the Discriminator; both approaches work the same way.
        base_feature (128 by default) is the base for the other channel counts:
        input_features -> base*8 (1024) -> base*4 (512) -> base*2 (256) -> base (128) -> final_channels.
        args:
            input_features: Number of channels of the random noise from which an image is generated
            base_feature: Base number of feature maps (channels); the hidden layers use multiples of it
            final_channels: Number of channels of the generated image, i.e. what the Discriminator receives as input
        '''
        super().__init__()

        # The Discriminator does the same work with nn.ModuleList(). Here: 4 GeneratorBlocks + a final transposed conv.
        # The first block (stride 1, padding 0) turns a 1x1 input into 4x4. Every later layer
        # (kernel 4, stride 2, padding 1) doubles the spatial size, so a 1x1 input ends up 64x64.
        self.blocks = nn.Sequential(
            GeneratorBlock(in_channels = input_features, out_channels = base_feature * 8, stride = 1, padding = 0),  # noise -> 1024 features
            GeneratorBlock(in_channels = base_feature * 8, out_channels = base_feature * 4),  # 1024 -> 512 features
            GeneratorBlock(in_channels = base_feature * 4, out_channels = base_feature * 2),  # 512 -> 256 features
            GeneratorBlock(in_channels = base_feature * 2, out_channels = base_feature),  # 256 -> 128 features
            # 128 -> final_channels. This is a GeneratorBlock without BatchNorm and ReLU; Tanh is applied in forward() instead.
            nn.ConvTranspose2d(base_feature, final_channels, kernel_size = 4, stride = 2, padding = 1),
        )
        self.activation = nn.Tanh()  # Keeps the outputs in [-1, 1]

    def forward(self, x):
        '''
        Take random noise as input and generate images from it.
        '''
        return self.activation(self.blocks(x))
    

class DiscriminatorBlock(nn.Module):
    '''
    One downsampling stage of the Discriminator: Conv2d -> optional BatchNorm2d -> LeakyReLU(0.2).

    Note: the defaults kernel_size = 4, stride = 2, padding = 1 follow the paper. The convolution only
    carries a bias when BatchNorm is disabled, because BatchNorm would cancel a constant bias anyway.
    '''
    def __init__(self, in_channels, out_channels, kernel_size = 4, stride = 2, padding = 1, use_batchnorm:bool = True):
        super().__init__()
        self.use_batchnorm = use_batchnorm
        conv_has_bias = not use_batchnorm
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=conv_has_bias,
        )
        # Stays None (rather than a no-op module) when normalisation is switched off.
        if use_batchnorm:
            self.batch_norm = nn.BatchNorm2d(out_channels)
        else:
            self.batch_norm = None
        self.activation = nn.LeakyReLU(0.2)  # negative slope 0.2

    def forward(self, x):
        '''
        Downsample x, normalise it when BatchNorm is enabled, then apply LeakyReLU.
        '''
        out = self.conv(x)
        if self.use_batchnorm:
            out = self.batch_norm(out)
        return self.activation(out)
    

class Discriminator(nn.Module):
    '''
    CNN that scores whether the images generated by the Generator look as good as the real ones.
    With the default arguments the channels change as :: 1 -> 64 -> 128 -> 256 -> 512 -> 1
    '''
    def __init__(self, input_features = 1, output_features = 1, middle_features = (64, 128, 256)):
        '''
        In the paper, the Discriminator takes a [Batch, 1, 64, 64] image (e.g. from the Generator) and outputs a
        single number per sample in the batch.
        args:
            input_features: Number of channels of the input image
            output_features: Number of channels produced by the final convolution
            middle_features: Non-empty sequence (tuple or list) of hidden channel counts. The first block maps
                input_features -> middle_features[0], consecutive entries are chained in order, and one more block
                maps middle_features[-1] -> middle_features[-1] * 2. The default is a tuple so it cannot be
                shared and mutated across instances.
        '''
        super().__init__()
        self.layers = nn.ModuleList()  # Layers stacked in a loop; applied one by one in forward()

        # In the paper, the first layer does not use BatchNorm
        self.layers.append(DiscriminatorBlock(input_features, middle_features[0], use_batchnorm = False))  # 1 -> 64 by default

        # Each following block consumes the previous block's actual output channels, so the layers line up even
        # when the entries of middle_features do not double at every step.
        # Default: 64 -> 128, 128 -> 256, 256 -> 512 (4 blocks in total, as in the paper).
        channels = list(middle_features) + [middle_features[-1] * 2]
        for in_channels, out_channels in zip(channels[:-1], channels[1:]):
            self.layers.append(DiscriminatorBlock(in_channels, out_channels))

        # With a 64x64 input and the 4 default blocks the feature map is 4x4 at this point, and this
        # kernel-4 / stride-2 / no-padding conv reduces it to 1x1. Other input sizes or more blocks give other sizes.
        self.final_conv = nn.Conv2d(in_channels = channels[-1], out_channels = output_features, kernel_size = 4, stride = 2, padding = 0)  # 512 -> 1 by default
        self.sigmoid_layer = nn.Sigmoid()  # Squashes each score into (0, 1): how CLOSE the image is to a real one

    def forward(self, x):
        '''
        Run x through the conv blocks and the final conv, then return the sigmoid scores.
        '''
        for layer in self.layers:
            x = layer(x)

        return self.sigmoid_layer(self.final_conv(x))


def test_DCGAN_code():
    '''
    Smoke test: push a batch of 10 uniform-noise vectors through a fresh Generator, score the result with a
    fresh Discriminator, and print the tensor shapes at each stage.
    '''
    z = torch.rand(10, 100, 1, 1)
    fake_images = Generator()(z)
    scores = Discriminator()(fake_images)
    print('Model Built Successfully!!! Generating 10 random samples and their end results')
    summary = f"'Z' random Noise shape: {z.shape} || Generator output shape: {fake_images.shape} || Discriminator shape: {scores.shape}"
    print(summary)


Solution

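  • You can simply iterate over all submodules at the end of your `__init__` method. `self.modules()` walks the whole tree recursively, including everything inside `nn.Sequential` and `nn.ModuleList`. Note that `nn.ConvTranspose2d` is not a subclass of `nn.Conv2d`, so it has to be listed explicitly. Otherwise the Generator's transposed convolutions keep their default initialisation:

    class Generator(nn.Module):
        def __init__(self, ...):
            # all code here
            # ...
            # init weights, at the very bottom of __init__
            for sm in self.modules():
                if isinstance(sm, (nn.Conv2d, nn.ConvTranspose2d)):
                    # only (transposed) convolutions will be initialized in this way
                    torch.nn.init.normal_(sm.weight, 0.0, 0.02)

    The `torch.nn.init` functions already run without gradient tracking, so passing `sm.weight` directly is enough; `.data` is not needed. Equivalently, put the loop body in a function `init_fn(m)` and call `self.apply(init_fn)`, which calls it on every submodule and on the module itself. To sanity-check the result, print `sm.weight.mean()` and `sm.weight.std()` for those layers after construction; they should be close to 0 and 0.02.

    done.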