Tags: python, python-3.x, machine-learning, pytorch, torch

Torch is not saving my frozen and optimized model


When I start my script, it runs fine until it hits the traced_model.save(args.save_path) statement; after that the script just stops running. Could someone please help me out with this?

import argparse
import torch
from model import SpeechRecognition
from collections import OrderedDict


def trace(model):
    model.eval()
    x = torch.rand(1, 81, 300)
    hidden = model._init_hidden(1)
    traced = torch.jit.trace(model, (x, hidden))
    return traced

def main(args):
    print("loading model from", args.model_checkpoint)
    checkpoint = torch.load(args.model_checkpoint, map_location=torch.device('cpu'))
    h_params = SpeechRecognition.hyper_parameters
    model = SpeechRecognition(**h_params)

    model_state_dict = checkpoint['state_dict']
    new_state_dict = OrderedDict()
    for k, v in model_state_dict.items():
        name = k.replace("model.", "") # remove `model.`
        new_state_dict[name] = v

    model.load_state_dict(new_state_dict)

    print("tracing model...")
    traced_model = trace(model)
    print("saving to", args.save_path)
    traced_model.save(args.save_path)
    print("Done!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="testing the wakeword engine")
    parser.add_argument('--model_checkpoint', type=str, default='your/checkpoint_file', required=False,
                        help='Checkpoint of model to optimize')
    parser.add_argument('--save_path', type=str, default='path/where/you/want/to/save/the/model', required=False,
                        help='path to save optimized model')

    args = parser.parse_args()
    main(args)

If you run the script you can see exactly where it stops working, because print("Done!") is never executed. Here is what it looks like in the terminal when I run the script:

loading model from C:/Users/supre/Documents/Python Programs/epoch=0-step=11999.ckpt
tracing model...
saving to C:/Users/supre/Documents/Python Programs

Solution

  • According to the PyTorch documentation, a common PyTorch convention is to save models using either a .pt or .pth file extension. Note that in your terminal output, args.save_path is a directory (C:/Users/supre/Documents/Python Programs) rather than a file, so traced_model.save() should instead be given a full path ending in a filename such as model.pt.

    To save model checkpoints or multiple components, organize them in a dictionary and use torch.save() to serialize the dictionary. For example,

    torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                ...
                }, PATH)
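For completeness, loading such a checkpoint is the mirror image: torch.load() deserializes the dictionary and each component is restored by key. A runnable sketch (the Linear model, optimizer, and saved values below are hypothetical placeholders, not from the question):

```python
import os
import tempfile

import torch

# Hypothetical model and optimizer, just to make the sketch self-contained
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

path = os.path.join(tempfile.gettempdir(), "checkpoint.tar")

# Save: bundle everything worth restoring into one dictionary
torch.save({
    'epoch': 5,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': 0.25,
}, path)

# Load: deserialize the dictionary and restore each component by key
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()  # or model.train() if resuming training
```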
    

    A common PyTorch convention is to save these checkpoints using the .tar file extension.
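Applied to the script in the question: traced_model.save() from torch.jit expects a full file path including a filename (conventionally ending in .pt), not a bare directory. A minimal sketch with a hypothetical module standing in for SpeechRecognition:

```python
import os
import tempfile

import torch

# Hypothetical stand-in for the real SpeechRecognition model
class Doubler(torch.nn.Module):
    def forward(self, x):
        return x * 2

model = Doubler()
model.eval()
traced = torch.jit.trace(model, torch.rand(1, 3))

# A full path ending in a .pt filename, not a bare directory
save_path = os.path.join(tempfile.gettempdir(), "traced_model.pt")
traced.save(save_path)

# The TorchScript file can be loaded back without the Python class definition
loaded = torch.jit.load(save_path)
```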

    Hope this answers your question.