Tags: python, pytorch, freeze, pruning

Freezing Individual Weights in Pytorch


This question is not a duplicate of How to apply layer-wise learning rate in Pytorch? because it asks about freezing a subset of a tensor's entries during training rather than an entire layer.

I am trying out a PyTorch implementation of the Lottery Ticket Hypothesis.

For that, I want to freeze the weights in a model that are zero. Is the following a correct way to implement it?

for name, p in model.named_parameters():
    if 'weight' in name:
        tensor = p.data.cpu().numpy()
        grad_tensor = p.grad.data.cpu().numpy()
        grad_tensor = np.where(tensor == 0, 0, grad_tensor)
        p.grad.data = torch.from_numpy(grad_tensor).to(device)

Solution

  • What you have would work, provided you do it after loss.backward() and before optimizer.step() (referring to the common usage of those variable names). That said, it is a bit convoluted: there is no need to round-trip through NumPy. Also, if your weights are floating-point values, comparing them to exactly zero is probably a bad idea; introduce a small epsilon to account for this. IMO the following is a little cleaner than the solution you proposed (a complete toy example follows the snippet below):

    # locate zero-value weights before training loop
    EPS = 1e-6
    locked_masks = {n: torch.abs(w) < EPS for n, w in model.named_parameters() if n.endswith('weight')}
    
    ...
    
    for ... #training loop
    
        ...
    
        optimizer.zero_grad()
        loss.backward()
        # zero the gradients of interest
        for n, w in model.named_parameters():
            if w.grad is not None and n in locked_masks:
                w.grad[locked_masks[n]] = 0
        optimizer.step()
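
    If it helps, here is a minimal, self-contained sketch of the same idea in a complete toy training loop. Everything here (the model, data, simulated pruning step, and hyperparameters) is made up purely for illustration:

    import torch
    import torch.nn as nn
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # toy model and data, purely for illustration
    model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 2)).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    criterion = nn.CrossEntropyLoss()
    x = torch.randn(128, 20, device=device)
    y = torch.randint(0, 2, (128,), device=device)
    
    # simulate pruning: zero out roughly half of each weight tensor at random
    with torch.no_grad():
        for n, w in model.named_parameters():
            if n.endswith('weight'):
                w.mul_((torch.rand_like(w) > 0.5).float())
    
    # record which entries are (near) zero before the training loop starts
    EPS = 1e-6
    locked_masks = {n: torch.abs(w) < EPS
                    for n, w in model.named_parameters() if n.endswith('weight')}
    
    for epoch in range(5):
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        # zero the gradients of the frozen entries so optimizer.step() leaves them alone
        for n, w in model.named_parameters():
            if w.grad is not None and n in locked_masks:
                w.grad[locked_masks[n]] = 0
        optimizer.step()
    
    # sanity check: the pruned entries are still zero after training
    for n, w in model.named_parameters():
        if n in locked_masks:
            assert torch.all(w[locked_masks[n]].abs() < EPS)

    One caveat: zeroing the gradient keeps the masked weights at zero here because the optimizer starts from a fresh state. If the optimizer already has accumulated state before the mask is applied (e.g. momentum buffers from earlier training, as in iterative pruning), that stale state can still move the masked weights, so Lottery Ticket implementations often also re-apply the weight mask after optimizer.step(), or use torch.nn.utils.prune, which maintains the masks for you.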