Tags: python, pytorch, reinforcement-learning, deep-copy, pruning

How to solve deepcopy error of a pruned model in pytorch


I am trying to build an RL model where my actor network has some pruned connections. When I use the data collector SyncDataCollector from torchrl, the deepcopy it performs on the policy fails (see the error below).

This seems to be due to the pruned connections: pruning turns the layer's weight into a non-leaf tensor with a grad_fn, instead of a leaf parameter with requires_grad=True, as suggested in this post.
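To check what the pruning call does to the layer, here is a small inspection sketch (the checks are illustrative and not part of the original question). With torch.nn.utils.prune, the trainable tensor is re-registered as weight_orig, the mask becomes a buffer named weight_mask, and weight itself is a plain attribute holding their product, which makes it a non-leaf tensor with a grad_fn.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

layer = nn.Linear(1, 5)
mask = torch.tensor([1, 0, 0, 1, 0]).reshape(-1, 1)  # same mask as in the question
prune.custom_from_mask(layer, name="weight", mask=mask)

param_names = [name for name, _ in layer.named_parameters()]
buffer_names = [name for name, _ in layer.named_buffers()]

print(param_names)   # the original weight is now registered as 'weight_orig'
print(buffer_names)  # the mask is registered as the buffer 'weight_mask'
# 'weight' is a plain attribute computed as weight_mask * weight_orig,
# so it is the output of an autograd op rather than a graph leaf
print(layer.weight.is_leaf, layer.weight.grad_fn)
```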

Here is an example of the code I would like to run; SyncDataCollector attempts a deepcopy of the model:

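import torch
from torch import nn
import torch.nn.utils.prune as prune
from tensordict.nn import TensorDictModule
from torchrl.collectors import SyncDataCollector

device = torch.device("cpu")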

model = nn.Sequential(
    nn.Linear(1,5),
    nn.Linear(5,1)
)
mask = torch.tensor([1,0,0,1,0]).reshape(-1,1)
prune.custom_from_mask(model[0], name='weight', mask=mask)


policy_module = TensorDictModule(
    model, in_keys=["in"], out_keys=["out"]
)

env = FlyEnv()  # custom environment (definition not shown)

collector = SyncDataCollector(
    env,
    policy_module,
    frames_per_batch=1,
    total_frames=2,
    split_trajs=False,
    device=device,
)

And here is a minimal example that produces the error:

import torch
from torch import nn
from copy import deepcopy

import torch.nn.utils.prune as prune

device = torch.device("cpu")

model = nn.Sequential(
    nn.Linear(1,5),
    nn.Linear(5,1)
)
mask = torch.tensor([1,0,0,1,0]).reshape(-1,1)
prune.custom_from_mask(model[0], name='weight', mask=mask)

new_model = deepcopy(model)  # raises the RuntimeError below

The resulting error is:

RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment.  If you were attempting to deepcopy a module, this may be because of a torch.nn.utils.weight_norm usage, see https://github.com/pytorch/pytorch/pull/103001
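The restriction in this message applies to tensors, not to modules: any tensor that is the output of an autograd operation (a non-leaf) refuses to be deep-copied. A minimal sketch, independent of pruning, that triggers the same RuntimeError:

```python
import copy
import torch

leaf = torch.ones(3, requires_grad=True)  # created directly by the user: a graph leaf
non_leaf = leaf * 2                       # produced by an op on a leaf: has a grad_fn

copy.deepcopy(leaf)  # leaves can be deep-copied without problems

try:
    copy.deepcopy(non_leaf)
    raised = False
except RuntimeError as err:
    raised = True
    print(type(err).__name__, err)
```

A pruned layer's weight is exactly such a non-leaf tensor stored on the module, so deep-copying the module hits this check.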

I tried to remove the pruning with prune.remove(model[0], 'weight') and then calling model[0].requires_grad_(), which avoids the error, but then all the weights, including the pruned ones, are trained again.
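The following sketch illustrates why this workaround defeats the purpose. After prune.remove, weight is an ordinary trainable Parameter again (already with requires_grad=True), so deepcopy works, but the zeroed entries receive gradients like any other weight and an optimizer step would make them non-zero.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune
from copy import deepcopy

layer = nn.Linear(1, 5)
mask = torch.tensor([1, 0, 0, 1, 0]).reshape(-1, 1)
prune.custom_from_mask(layer, name="weight", mask=mask)

# Make the pruning permanent: 'weight' becomes a regular Parameter holding the zeros
prune.remove(layer, "weight")
copied = deepcopy(layer)  # no error any more

# ...but nothing keeps the pruned entries at zero during training
layer(torch.ones(4, 1)).sum().backward()
pruned_grads = layer.weight.grad[mask == 0]
print(pruned_grads)  # non-zero: an optimizer step would revive these weights
```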

I think masking the pruned weights "manually" before each forward pass might work, but it does not seem efficient (or elegant).
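For comparison, manual masking can be sketched as a small layer that multiplies its weight by a mask buffer inside forward. MaskedLinear below is a hypothetical helper written for this illustration, not part of PyTorch or TorchRL. Because the masked weight only exists as a temporary during forward, the module stores nothing but leaf parameters and buffers, so deepcopy works, and the pruned entries get zero gradient.

```python
import torch
from torch import nn
import torch.nn.functional as F
from copy import deepcopy


class MaskedLinear(nn.Linear):
    """Hypothetical helper: a Linear layer whose weight is multiplied by a fixed mask in forward."""

    def __init__(self, in_features, out_features, mask, bias=True):
        super().__init__(in_features, out_features, bias=bias)
        # A buffer follows .to(), deepcopy and state_dict, but is not trained
        self.register_buffer("weight_mask", mask.to(self.weight.dtype))

    def forward(self, x):
        # The masked weight is a temporary here and is never stored on the module
        return F.linear(x, self.weight * self.weight_mask, self.bias)


mask = torch.tensor([1, 0, 0, 1, 0]).reshape(-1, 1)
model = nn.Sequential(MaskedLinear(1, 5, mask), nn.Linear(5, 1))

new_model = deepcopy(model)  # works: only leaf parameters and buffers are stored

model(torch.ones(4, 1)).sum().backward()
print(model[0].weight.grad[mask == 0])  # zeros: pruned entries receive no gradient
```

On efficiency: torch.nn.utils.prune also recomputes the product of weight_orig and weight_mask in a forward pre-hook on every call, so this manual version does roughly the same amount of work; the difference is that it never keeps the product on the module between calls.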


Solution

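  • The error occurs because pruning moves the parameter to <param>_orig (here weight_orig), registers the mask as a buffer, and stores the masked value (weight) alongside them as a plain tensor attribute. When the SyncDataCollector takes the params and buffers out and puts them on the "meta" device to create a stateless policy, this additional value is ignored because it is neither a parameter nor a buffer anymore (and hence not caught by the call to "to"). It therefore remains a non-leaf tensor with a grad_fn, which deepcopy refuses to copy.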

    What you can do as a fix is to call

    policy_module.module[0].weight = policy_module.module[0].weight.detach()
    

    before creating the collector. This is safe because the weight attribute is recomputed from weight_orig and weight_mask during the next forward call anyway.

    TorchRL should perhaps handle the deepcopy better, although in this case the error is caused by a tensor requiring gradients in a place where it shouldn't. IMO, the pruning methods should compute "weight" during the forward call (as they do) but then prune
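A sketch of the suggested fix on the minimal example from the question, with plain copy.deepcopy standing in for the copy that SyncDataCollector makes (the collector itself is not exercised here). It also checks that weight is neither a parameter nor a buffer, and that the copied pruning hook rebuilds weight on the next forward pass.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune
from copy import deepcopy

model = nn.Sequential(nn.Linear(1, 5), nn.Linear(5, 1))
mask = torch.tensor([1, 0, 0, 1, 0]).reshape(-1, 1)
prune.custom_from_mask(model[0], name="weight", mask=mask)

# 'weight' is neither a parameter nor a buffer, so Module.to() does not touch it
assert "weight" not in dict(model[0].named_parameters())
assert "weight" not in dict(model[0].named_buffers())

# The fix: replace the non-leaf 'weight' with a detached (leaf) tensor before copying
model[0].weight = model[0].weight.detach()
new_model = deepcopy(model)  # no RuntimeError

# The pruning forward pre-hook is copied too, so 'weight' is rebuilt on the next call
x = torch.ones(4, 1)
out = new_model(x)
print(new_model[0].weight.grad_fn)     # a grad_fn again: weight_mask * weight_orig
print(new_model[0].weight[mask == 0])  # pruned entries are still zero
```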