
PyTorch - how to print the total number of trainable parameters for a single layer?


I'm trying to print the total number of trainable parameters for a single PyTorch layer. For example, given:

import torch.nn as nn

conv_layer = nn.Conv2d(in_channels=3,
                       out_channels=16,
                       kernel_size=3)

Based on other answers, here is what I've tried so far:

print(conv_layer)

prints:

Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))

and:

print(conv_layer.parameters())

prints:

<generator object Module.parameters at 0x7fe99b7e1740>

What am I missing?


Solution

  • Derived from this article:

    import torch.nn as nn
    
    conv_layer = nn.Conv2d(in_channels=3,
                           out_channels=16,
                           kernel_size=3)
    
    num_trainable_params = sum(p.numel() for p in conv_layer.parameters())
    
    print(f'\nnum_trainable_params = {num_trainable_params}\n')
    

    Since parameters() returns a generator, you can iterate over it; each element is a parameter tensor, and numel() gives the number of elements in that tensor. For this Conv2d layer there are two parameter tensors: the weight and the bias.
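    To make the iteration concrete, here is a short sketch that lists each parameter tensor by name with its shape and element count, and then sums only the tensors with requires_grad=True (which matters once a model has frozen layers):

    ```python
    import torch.nn as nn

    conv_layer = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)

    # Show each parameter tensor: the weight (16 x 3 x 3 x 3 = 432 elements)
    # and the bias (16 elements).
    for name, p in conv_layer.named_parameters():
        print(f'{name}: shape={tuple(p.shape)}, numel={p.numel()}')

    # Count only trainable parameters (those with requires_grad=True).
    num_trainable = sum(p.numel() for p in conv_layer.parameters() if p.requires_grad)
    print(num_trainable)  # 432 + 16 = 448
    ```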

    Also, this post cautions users who compare this approach against a TensorFlow model:

    If you use torch_model.parameters(), BatchNorm layers in torch show only 2 values: weight and bias, while in TensorFlow a BatchNorm layer shows 4 values: gamma, beta, moving mean, and moving variance.

    The latter two are non-trainable parameters, so they don't show up in torch_model.parameters(). If you compare torch_model.parameters() with tf_model.trainable_variables, they should match.
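    In PyTorch this is because the running statistics are registered as buffers rather than parameters. A minimal sketch illustrating the distinction:

    ```python
    import torch.nn as nn

    bn = nn.BatchNorm2d(num_features=16)

    # Trainable parameters: weight (gamma) and bias (beta).
    print([name for name, _ in bn.named_parameters()])  # ['weight', 'bias']

    # The running statistics are buffers, so they are excluded from parameters()
    # (named_buffers() also includes the num_batches_tracked counter).
    print([name for name, _ in bn.named_buffers()])
    ```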