Tags: python, memory, parameters, pytorch, neural-network

How to get a flattened view of PyTorch model parameters?


I want a flattened view of the parameters belonging to my PyTorch model. Note that it should be a view and not a copy of the parameters; in other words, when I modify the parameters through the view, the model parameters should change as well. I can get the model parameters as follows:

import torch

model = torch.nn.Sequential(
    torch.nn.Linear(1, 10), 
    torch.nn.Tanh(), 
    torch.nn.Linear(10, 1)
)

params = list(model.parameters())

for p in params:
    print(p)

Here params is a list of tensors. I need it to be a single 1D tensor of all the parameters instead. It is trivial to do

params = torch.cat([p.flatten() for p in model.parameters()])
print(params.shape) # torch.Size([31])

However, modifying a parameter in params now does not change the actual model parameter, since torch.cat() copies the memory. Is it possible to get a 1D tensor view of the model parameters? (A quick check of the copy behaviour is shown below.)
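
For example, zeroing the concatenated tensor leaves the model weights untouched (an illustrative check, reusing params from above):

with torch.no_grad():
    params.zero_()         # modify the flattened tensor ...
print(model[0].weight)     # ... but the model weights are unchanged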


Solution

  • You should be able to achieve this by constructing the parameter tensor in 1D first and then making each model parameter a view into it:

    import torch
    
    model = torch.nn.Sequential(
        torch.nn.Linear(1, 10), 
        torch.nn.Tanh(), 
        torch.nn.Linear(10, 1)
    )
    
    def fuse_parameters(model):
        """Move model parameters to a contiguous tensor, and return that tensor."""
        n = sum(p.numel() for p in model.parameters())
        params = torch.zeros(n)
        i = 0
        for p in model.parameters():
            params_slice = params[i:i + p.numel()]   # slice of the fused buffer for this parameter
            params_slice.copy_(p.flatten())          # copy the current values into the buffer
            p.data = params_slice.view(p.shape)      # repoint the parameter at a view of the buffer
            i += p.numel()
        return params
    
    print("before fusing parameters")
    with torch.no_grad(): print(model(torch.ones(3, 1)).flatten())
    params = fuse_parameters(model)
    print("after fusing parameters")
    with torch.no_grad(): print(model(torch.ones(3, 1)).flatten())
    params.mul_(2)
    print("after modifying fused parameters")
    with torch.no_grad(): print(model(torch.ones(3, 1)).flatten())
    

    This prints:

    before fusing parameters
    tensor([-0.3356, -0.3356, -0.3356])
    after fusing parameters
    tensor([-0.3356, -0.3356, -0.3356])
    after modifying fused parameters
    tensor([-0.7728, -0.7728, -0.7728])
    

    (Doing this kind of thing post-hoc - creating a single tensor view out of multiple different tensors - does not appear to be supported in PyTorch as of 2022-12. There's preliminary support for nested tensors, but the nested tensors still copy data.)
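
    One way to confirm the aliasing (a minimal check, assuming the fuse_parameters setup above) is to compare data pointers; the first model parameter occupies offset 0 of the fused buffer, so it shares the fused tensor's data pointer:

    # the first parameter starts at offset 0 of the fused buffer
    print(model[0].weight.data_ptr() == params.data_ptr())  # True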


    If you also want to fuse grads

    You can do the same with gradients, storing the fused model gradients as a 1D tensor in params.grad:

    import torch
    
    model = torch.nn.Sequential(
        torch.nn.Linear(1, 10), 
        torch.nn.Tanh(), 
        torch.nn.Linear(10, 1)
    )
    
    def fuse_parameters_and_gradients(model):
        """Move model parameters and gradients to a contiguous tensor, and return that tensor."""
        n = sum(p.numel() for p in model.parameters())
        params = torch.zeros(n, requires_grad=True)
        params.grad = torch.zeros(n)
        i = 0
        for p in model.parameters():
            params_slice = params[i:i + p.numel()]
            with torch.no_grad(): params_slice.copy_(p.flatten())
            p.data = params_slice.view(p.shape)
            p.grad = params.grad[i:i + p.numel()].view(p.shape)  # the parameter's grad becomes a view of the fused grad buffer
            i += p.numel()
        return params
    
    params = fuse_parameters_and_gradients(model)
    
    print("params and grad before optimizer step")
    with torch.no_grad(): print(params[:3], params.grad[:3])
    
    opt = torch.optim.SGD(model.parameters(), 0.1)
    opt.zero_grad(); model(torch.ones(1, 1)).backward(); opt.step()
    print("\n(optimizer step)\n")
    
    print("params and grad after optimizer step")
    with torch.no_grad(): print(params[:3], params.grad[:3])
    

    This prints:

    params and grad before optimizer step
    tensor([0.8219, 0.6089, 0.8708], requires_grad=True) tensor([0., 0., 0.])
    
    (optimizer step)
    
    params and grad after optimizer step
    tensor([0.7974, 0.6136, 0.8743], requires_grad=True) tensor([ 0.2457, -0.0474, -0.0341])
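
    For illustration, the fused tensors can also be updated directly. A minimal sketch of a manual gradient step (assuming the fuse_parameters_and_gradients setup above; the 0.1 step size is arbitrary):

    params.grad.zero_()                   # zero in place; zero_grad(set_to_none=True) would detach the grad views
    model(torch.ones(1, 1)).backward()    # gradients accumulate into p.grad, i.e. into params.grad
    with torch.no_grad():
        params -= 0.1 * params.grad       # one step on the flat vector updates every model parameter through its view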