Tags: python-3.x, pytorch, computer-vision, pruning

Access weights, biases in model - PyTorch pruning


To implement L1-norm, unstructured, layer-wise pruning with the torch.nn.utils.prune methods l1_unstructured and random_unstructured, I have a toy LeNet-300-100 dense neural network as-

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune

class LeNet300(nn.Module):
    def __init__(self):
        super(LeNet300, self).__init__()

        # Define layers-
        self.fc1 = nn.Linear(in_features = 28 * 28 * 1, out_features = 300)
        self.fc2 = nn.Linear(in_features = 300, out_features = 100)
        self.output_layer = nn.Linear(in_features = 100, out_features = 10)

        self.weights_initialization()


    def forward(self, x):
        # x is expected to be flattened to shape (batch, 784)
        x = F.leaky_relu(self.fc1(x))
        x = F.leaky_relu(self.fc2(x))
        x = self.output_layer(x)
        return x


    def weights_initialization(self):
        '''
        When we define all the modules such as the layers in '__init__()'
        method above, these are all stored in 'self.modules()'.
        We go through each module one by one. This is the entire network,
        basically.
        '''
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                nn.init.constant_(m.bias, 1)


    def shape_computation(self, x):
        print(f"Input shape: {x.shape}")
        x = self.fc1(x)
        print(f"dense1 output shape: {x.shape}")
        x = self.fc2(x)
        print(f"dense2 output shape: {x.shape}")
        x = self.output_layer(x)
        print(f"output shape: {x.shape}")
        del x
        return None

# Initialize architecture-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LeNet300().to(device)

This has 266610 trainable parameters. When 20% of the remaining weights in the first two dense layers and 10% of those in the output layer are pruned per round, it takes 25 pruning rounds to reach 99.5% sparsity (measured over the weights). The pruning is done with- prune.l1_unstructured(module = fc, name = 'weight', amount = 0.2)
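The parameter count and the number of rounds can be checked with a few lines of plain Python. This is a sketch under two stated assumptions: each round removes the given fraction of the weights that are still unpruned (which is how repeated prune.l1_unstructured calls on the same parameter combine their masks), and sparsity is counted over weights only, with biases left out. PyTorch rounds the number of entries pruned per call to an integer, so real counts can differ slightly. The names layer_shapes, rates and weight_sparsity are illustrative only.

```python
# (in_features, out_features) of the three nn.Linear layers in LeNet-300-100
layer_shapes = {"fc1": (28 * 28, 300), "fc2": (300, 100), "output_layer": (100, 10)}
# Per-round pruning rate applied to each layer's weights
rates = {"fc1": 0.2, "fc2": 0.2, "output_layer": 0.1}

# Weights (in * out) plus biases (out) of every layer
total_params = sum(i * o + o for i, o in layer_shapes.values())


def weight_sparsity(rounds):
    """Fraction of weights pruned after `rounds` rounds.

    Assumes each round removes `rate` of the still-unpruned weights
    and ignores biases (not pruned in this count).
    """
    total = remaining = 0.0
    for name, (i, o) in layer_shapes.items():
        n = i * o
        total += n
        remaining += n * (1 - rates[name]) ** rounds
    return 1 - remaining / total


# Smallest number of rounds that reaches the 99.5% target
rounds = 0
while weight_sparsity(rounds) < 0.995:
    rounds += 1

print(total_params, rounds, round(weight_sparsity(rounds), 4))  # 266610 25 0.996
```

After 24 rounds the weight sparsity is about 99.49957%, just short of the target, which is why the 25th round is needed. If the 410 unpruned biases were also counted in the denominator, the same rates would need one more round.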

Iterating through the layers is done with-

for name, module in model.named_modules():
    if name == '':
        continue
    else:
        print(f"layer: {name}, module: {module}")

However, for a particular module, how can I access its weights and biases other than through module.weight and module.bias?

The idea is to use a layer-wise (for now) pruning function as-

# Prune multiple parameters/layers in a given model-
for name, module in model.named_modules():
  
    # prune 20% of weights/connections for all hidden layers-
    if isinstance(module, torch.nn.Linear) and name != 'output_layer':
        prune.l1_unstructured(module = module, name = 'weight', amount = 0.2)
    
    # prune 10% of weights/connections for output layer-
    elif isinstance(module, torch.nn.Linear) and name == 'output_layer':
        prune.l1_unstructured(module = module, name = 'weight', amount = 0.1)

Bias pruning will also be included.
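One way to fold bias pruning into the loop above and repeat it round by round is sketched below. It is an illustration, not the asker's final code: the LeNet-300-100 layout is rebuilt with nn.Sequential and an OrderedDict only so the layer names (fc1, fc2, output_layer) match the class above while keeping the block self-contained, prune_one_round and global_sparsity are hypothetical helper names, and the retraining step between rounds is left as a comment.

```python
from collections import OrderedDict

import torch.nn as nn
import torch.nn.utils.prune as prune

# Same layer sizes and names as LeNet300; nn.Sequential just keeps the sketch short.
model = nn.Sequential(OrderedDict([
    ("fc1", nn.Linear(28 * 28, 300)),
    ("act1", nn.LeakyReLU()),
    ("fc2", nn.Linear(300, 100)),
    ("act2", nn.LeakyReLU()),
    ("output_layer", nn.Linear(100, 10)),
]))


def prune_one_round(model, hidden_amount=0.2, output_amount=0.1):
    """One round of layer-wise L1 unstructured pruning on weights AND biases."""
    for name, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue
        amount = output_amount if name == "output_layer" else hidden_amount
        for param_name in ("weight", "bias"):
            # On repeated calls PyTorch combines the masks, so `amount` is a
            # fraction of the entries that are still unpruned.
            prune.l1_unstructured(module, name=param_name, amount=amount)


def global_sparsity(model):
    """Fraction of zero entries across all Linear weights and biases."""
    zeros, total = 0, 0
    for module in model.modules():
        if isinstance(module, nn.Linear):
            # module.weight / module.bias are the masked tensors once pruned
            for tensor in (module.weight, module.bias):
                zeros += int((tensor == 0).sum())
                total += tensor.numel()
    return zeros / total


sparsity_history = []
for round_idx in range(3):
    prune_one_round(model)
    # ... fine-tune / retrain the model here before the next round ...
    sparsity_history.append(global_sparsity(model))

print([f"{s:.4f}" for s in sparsity_history])
```

Looping over both "weight" and "bias" works because name is just the attribute name of the tensor inside the module; after the first round each Linear layer holds weight_orig and bias_orig as parameters and weight_mask and bias_mask as buffers.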


Solution

  • The name parameter is the attribute name of the parameter, within the module, on which the pruning will be applied (see documentation page). As such, you can provide either 'weight' or 'bias' in your case since you are focusing on nn.Linear exclusively.

    Additionally, the documentation states that prune.l1_unstructured:

    Modifies module in place (and also return the modified module) by:

    • adding a named buffer called name+'_mask' corresponding to the binary mask applied to the parameter name by the pruning method.

    • replacing the parameter name by its pruned version, while the original (unpruned) parameter is stored in a new parameter named name+'_orig'.

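    You can access the pruned tensors via the weight and bias attributes and the original parameters via weight_orig and bias_orig. Note that after pruning, weight and bias are no longer parameters but plain attributes recomputed as orig * mask (by a forward pre-hook), which is why they do not appear in state_dict():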

    import torch.nn as nn
    import torch.nn.utils.prune as prune

    m = nn.Linear(2, 3)
    m = prune.l1_unstructured(m, 'weight', amount=.2)
    m = prune.l1_unstructured(m, 'bias', amount=.2)

    >>> m.state_dict().keys()
    odict_keys(['weight_orig', 'bias_orig', 'weight_mask', 'bias_mask'])
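To see all of this on one small layer, here is a sketch that uses only prune.l1_unstructured and prune.remove together with nn.Module.named_parameters() and named_buffers(). The latter two also answer the original question of reaching a module's tensors by name rather than through module.weight and module.bias. The layer size (4 to 3) and the amounts are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)  # only for reproducibility of this sketch

m = nn.Linear(4, 3)
prune.l1_unstructured(m, name="weight", amount=0.25)  # round(0.25 * 12) = 3 entries
prune.l1_unstructured(m, name="bias", amount=0.34)    # round(0.34 * 3) = 1 entry

# Trainable parameters now hold the ORIGINAL (unpruned) values ...
param_names = [name for name, _ in m.named_parameters()]
# ... while the binary masks live in buffers.
buffer_names = [name for name, _ in m.named_buffers()]

# The pruned tensors are plain attributes equal to orig * mask.
weight_matches = torch.equal(m.weight, m.weight_orig * m.weight_mask)
bias_matches = torch.equal(m.bias, m.bias_orig * m.bias_mask)

# Number of masked-out entries per tensor
n_zero_weight = int((m.weight_mask == 0).sum())
n_zero_bias = int((m.bias_mask == 0).sum())

# Make the pruning permanent: drops *_orig, *_mask and the hook and leaves
# ordinary 'weight'/'bias' parameters with the pruned entries set to zero.
prune.remove(m, "weight")
prune.remove(m, "bias")
final_names = [name for name, _ in m.named_parameters()]
zeros_after_remove = int((m.weight == 0).sum())

print(param_names, buffer_names, n_zero_weight, n_zero_bias, final_names)
```

prune.remove is only needed once iterative pruning is finished; calling it between rounds would discard the masks that later rounds build on.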