Tags: python, pytorch, torchscript

How to convert a TorchScript model in PyTorch to an ordinary nn.Module?


I am loading the TorchScript model as follows:

model = torch.jit.load("model.pt").to(device)

The child modules of this model are identified as RecursiveScriptModule. I would like to fine-tune the loaded weights, and to make that simpler (and to cast the weights to torch.float32) it would be preferable to convert everything to an ordinary PyTorch nn.Module.
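The observation above can be reproduced with a minimal sketch. It assumes a small hypothetical module, TinyNet, in place of the real model.pt, and uses an in-memory buffer instead of a file. After scripting, saving and reloading, the submodules are no longer nn.Linear instances but RecursiveScriptModule wrappers.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the real model stored in model.pt."""

    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Linear(4, 8)
        self.head = nn.Linear(8, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(torch.relu(self.encoder(x)))


# Script the module and serialize it to memory (stand-in for "model.pt").
buffer = io.BytesIO()
torch.jit.save(torch.jit.script(TinyNet()), buffer)
buffer.seek(0)

# Reload it the same way as in the question (on the CPU here).
model = torch.jit.load(buffer, map_location="cpu")

# Print the class name of every direct child module.
for name, child in model.named_children():
    print(name, type(child).__name__)
```

The reloaded object still behaves like a module (it subclasses nn.Module), but the original Python class, here TinyNet, is not recovered from the archive.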

The official docs (https://pytorch.org/docs/stable/jit.html) explain how to convert an nn.Module to TorchScript, but I have not found any examples of doing this in the opposite direction. Is there a way to do this?

P.S. An example of loading the pretrained model is given here: https://github.com/openai/CLIP/blob/main/notebooks/Interacting_with_CLIP.ipynb


Solution

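  • You may try to load it as is, e.g. state_dict = torch.jit.load(src).state_dict() (torch.load(src) merely forwards a TorchScript archive to torch.jit.load, so calling torch.jit.load directly is clearer). Then convert every value, keeping the keys unchanged: new_v = state_dict[k].cpu().float(). Finally, load the converted dict with load_state_dict into an instance of an nn.Module class that defines the same architecture; its parameter names and shapes must match the keys of the state dict.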
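The steps above can be sketched as follows. This is a minimal illustration under stated assumptions, not a general converter: TinyNet and scripted_to_module are hypothetical names, the architecture must already exist as Python code (TorchScript does not give you the original class back), and a half-precision scripted model in an in-memory buffer stands in for the real checkpoint. One small deviation from the answer: only floating-point tensors are cast to float32, so integer buffers such as BatchNorm's num_batches_tracked keep their dtype.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical architecture; its parameter names must match the checkpoint."""

    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Linear(4, 8)
        self.head = nn.Linear(8, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(torch.relu(self.encoder(x)))


def scripted_to_module(src, module: nn.Module) -> nn.Module:
    """Copy the weights of a TorchScript archive into an ordinary nn.Module."""
    scripted = torch.jit.load(src, map_location="cpu")
    state_dict = scripted.state_dict()
    # Keys stay unchanged; floating-point values are moved to the CPU as float32.
    converted = {
        k: v.cpu().float() if v.is_floating_point() else v.cpu()
        for k, v in state_dict.items()
    }
    # strict=True (the default) raises if any name or shape does not match.
    module.load_state_dict(converted)
    return module


# Simulate a float16 TorchScript checkpoint in memory (stand-in for "model.pt").
buffer = io.BytesIO()
torch.jit.save(torch.jit.script(TinyNet().half()), buffer)
buffer.seek(0)

model = scripted_to_module(buffer, TinyNet())
```

The result is a plain TinyNet whose submodules are regular nn.Linear layers with float32 parameters, so it can be moved with .to(device) and fine-tuned like any other nn.Module.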