How can I extract the weight and bias of Linear layers in PyTorch?


In model.state_dict(), model.parameters() and model.named_parameters(), the weights and biases of nn.Linear() modules are stored as separate entries, e.g. fc1.weight and fc1.bias. Is there a simple, Pythonic way to get both of them together for a given layer?
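
To make the problem concrete, here is a small sketch that lists the parameter names of a toy model. The Net class and its layer sizes are hypothetical and used only for illustration. Each Linear layer shows up as two separate entries, one for the weight and one for the bias:

```python
import torch
from torch import nn

# Hypothetical two-layer model, used only to show how parameters are named
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 3)
        self.fc2 = nn.Linear(3, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

model = Net()

# Each Linear layer contributes two entries: "<layer>.weight" and "<layer>.bias"
param_names = [name for name, _ in model.named_parameters()]
print(param_names)
# → ['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']

# state_dict() uses the same flat, per-tensor keys (this model has no buffers)
state_keys = list(model.state_dict().keys())
print(state_keys)
```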

Ideally, usage would look something like this:

    layer = model['fc1']
    print(layer.weight)
    print(layer.bias)
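
A plain nn.Module is not subscriptable, so model['fc1'] as written raises a TypeError. The sketch below shows three built-in ways to get close to that usage, assuming a hypothetical model whose layers are registered under names such as fc1. Note that nn.Module.get_submodule() is only available in recent PyTorch releases:

```python
from collections import OrderedDict
from torch import nn

# Hypothetical model: a named top-level Linear layer plus one nested in a Sequential
model = nn.Sequential(OrderedDict([
    ("fc1", nn.Linear(4, 3)),
    ("block", nn.Sequential(nn.Linear(3, 2))),
]))

# Option 1: attribute access works for submodules registered under a name
layer = model.fc1

# Option 2: get_submodule() takes a dotted path, so it also reaches nested layers
nested = model.get_submodule("block.0")

# Option 3: build a name -> module dict once, then index it like model['fc1']
modules = dict(model.named_modules())
layer_again = modules["fc1"]

print(layer.weight.shape, layer.bias.shape)    # torch.Size([3, 4]) torch.Size([3])
print(nested.weight.shape, nested.bias.shape)  # torch.Size([2, 3]) torch.Size([2])
```

The dict from named_modules() is the closest match to the indexing syntax in the question; it is keyed by the same dotted prefixes that appear in state_dict() keys, such as "block.0".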

Solution

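  • You can access the weight and bias of each linear layer directly through the module's weight and bias attributes. Iterating over model.named_modules() (rather than model.children(), which only yields direct submodules) also finds Linear layers nested inside containers such as nn.Sequential:

    from torch import nn
    
    # named_modules() walks the whole module tree, yielding (dotted_name, module) pairs
    for name, layer in model.named_modules():
        if isinstance(layer, nn.Linear):
            print(name)
            print(layer.weight)  # Parameter of shape (out_features, in_features)
            print(layer.bias)    # Parameter of shape (out_features,), or None if bias=False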
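
If you want both tensors grouped per layer rather than just printed, one option is to collect them into a dict keyed by layer name. The helper name linear_params and the model below are hypothetical; this is a sketch built only on named_modules() and the standard nn.Linear attributes:

```python
from torch import nn

def linear_params(model: nn.Module) -> dict:
    """Hypothetical helper: map each nn.Linear's dotted name to (weight, bias).

    bias is None for layers created with bias=False.
    """
    return {
        name: (module.weight, module.bias)
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear)
    }

# Hypothetical model: one top-level layer and one nested layer without a bias
model = nn.Sequential(
    nn.Linear(4, 3),
    nn.ReLU(),
    nn.Sequential(nn.Linear(3, 2, bias=False)),
)

params = linear_params(model)
for name, (weight, bias) in params.items():
    print(name, tuple(weight.shape), None if bias is None else tuple(bias.shape))
# prints:
# 0 (3, 4) (3,)
# 2.0 (2, 3) None
```

The returned tensors are the layers' live Parameter objects, so changes to them affect the model; call .detach() (and, if needed, .cpu().numpy()) when you only want a copy of the values.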