Tags: python, machine-learning, deep-learning, pytorch, torchscript

What is the right usage of the _extra_files arg in torch.jit.save?


One option I tried is pickling the vocab and saving it with the _extra_files arg:

import torch
import pickle

class Vocab(object):
    pass

vocab = Vocab()
with open('path/to/vocab.pkl', 'wb') as f:  # pickle writes bytes, so open in binary mode
    pickle.dump(vocab, f)

m = torch.jit.ScriptModule()

## I am not sure about the usage of this arg, the docs didn't help me
extra_files = torch._C.ExtraFilesMap()
extra_files['vocab.pkl'] = 'path/to/vocab.pkl'
# I also tried  pickle.dumps(vocab), and directly vocab

torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files)

## Load with extra files.
files = {'vocab.pkl': ''}
torch.jit.load('scriptmodule.pt', _extra_files = files)

This gives:

TypeError: import_ir_module(): incompatible function arguments. The following argument types are supported:
    1. (arg0: Callable[[List[str]], torch._C.ScriptModule], arg1: str, arg2: object, arg3: torch._C.ExtraFilesMap) -> None

The other option is obviously to load the pickle separately, but I was looking for a single-file option.

It would be nice if one could just add the vocab to the TorchScript file. It would also be nice to know if there is some reason for not doing this that I am not aware of.
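Separately from which map type the loader accepts, note that the snippet above sets extra_files['vocab.pkl'] to a path. The values in the extra-files map are written into the archive as each file's contents, so what belongs there is the pickled bytes of the vocab rather than its location on disk. A minimal standard-library sketch of producing and restoring those bytes, assuming a hypothetical Vocab class that holds a token list and a reverse lookup:

```python
import os
import pickle
import tempfile


class Vocab:
    """Hypothetical vocabulary: index-to-token list plus reverse lookup."""

    def __init__(self, tokens):
        self.itos = list(tokens)
        self.stoi = {tok: i for i, tok in enumerate(self.itos)}


vocab = Vocab(["<unk>", "hello", "world"])

# Option 1: serialize straight to bytes in memory.
payload = pickle.dumps(vocab)

# Option 2: if the vocab was already pickled to disk, read that file's bytes.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "vocab.pkl")
    with open(path, "wb") as f:  # binary mode: pickle produces bytes
        pickle.dump(vocab, f)
    with open(path, "rb") as f:
        payload_from_disk = f.read()

# These bytes (not the path) are what goes under extra_files['vocab.pkl'];
# after loading the model, pickle.loads turns them back into a Vocab.
restored = pickle.loads(payload_from_disk)
print(restored.itos)
```

Unpickling requires the Vocab class to be importable wherever the model is loaded, which is one reason the vocab cannot simply be attached to the TorchScript module as an arbitrary Python object.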


Solution

  • I believe that the documentation for torch.jit.load is incorrect: passing a plain dict, as its example does, fails with the TypeError above. You need to create a torch._C.ExtraFilesMap() object to load the saved files.

    The following is an example of how I got things to work.

    Step 1: Save model

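    import torch

    # traced_script_module: an existing ScriptModule (e.g. from torch.jit.trace)
    extra_files = torch._C.ExtraFilesMap()
    extra_files['foo.txt'] = 'bar'  # key = file name inside the archive, value = its contents
    traced_script_module.save(serialized_model_path, _extra_files=extra_files)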
    

    Step 2: Load model

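    files = torch._C.ExtraFilesMap()
    files['foo.txt'] = ''  # placeholder; replaced with the stored contents on load
    loaded_model = torch.jit.load(serialized_model_path, _extra_files=files)
    print(files)  # files['foo.txt'] now holds the saved contents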
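The ExtraFilesMap workaround above targets the PyTorch release the question ran into. In more recent releases, the form shown in the torch.jit.save and torch.jit.load docs, with plain dicts on both sides, works, and the loaded values come back as bytes. A minimal end-to-end sketch of the single-file setup the question asks for, assuming a recent PyTorch; the Vocab and Identity classes are hypothetical stand-ins for the real vocabulary and model:

```python
import os
import pickle
import tempfile

import torch


class Vocab:
    """Hypothetical vocabulary: index-to-token list plus reverse lookup."""

    def __init__(self, tokens):
        self.itos = list(tokens)
        self.stoi = {tok: i for i, tok in enumerate(self.itos)}


class Identity(torch.nn.Module):
    """Placeholder model; a traced or scripted real model works the same way."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


vocab = Vocab(["<unk>", "hello", "world"])
scripted = torch.jit.script(Identity())

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "scriptmodule.pt")

    # Save: key = file name inside the archive, value = its contents (bytes).
    torch.jit.save(scripted, path, _extra_files={"vocab.pkl": pickle.dumps(vocab)})

    # Load: the keys select which extra files to read; the placeholder
    # values are overwritten in place with the stored contents.
    extra_files = {"vocab.pkl": ""}
    loaded = torch.jit.load(path, _extra_files=extra_files)

# The model is fully in memory, so the temporary file can be gone by now.
restored = pickle.loads(extra_files["vocab.pkl"])
print(restored.itos)
```

The model and the vocab now travel in one .pt file, but the vocab is still ordinary pickled Python, so loading it needs a Python process with Vocab importable; a C++ loader would only see the raw bytes.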