
Append model checkpoints to existing file in PyTorch


In PyTorch, it is possible to save model checkpoints as follows:

import torch

# Create a model
model = torch.nn.Sequential(
    torch.nn.Linear(1, 50),
    torch.nn.Tanh(),
    torch.nn.Linear(50, 1)
)

# ... some training here
# Save checkpoint
torch.save(model.state_dict(), 'checkpoint.pt')

During my training procedure, I save a checkpoint every 100 epochs or so. Currently this results in a folder with many files, e.g.

checkpoint0.pt
checkpoint100.pt
checkpoint200.pt
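For reference, here is a minimal sketch of the periodic saving described above. It writes one file per checkpoint, which produces the clutter in question. It makes some assumptions that are not in the original: the 201-epoch loop, the `epoch % 100 == 0` check, and a temporary output directory (`out_dir`) used so the sketch does not write into the working directory. The training step itself is left as a placeholder comment.

```python
import os
import tempfile

import torch

model = torch.nn.Sequential(
    torch.nn.Linear(1, 50),
    torch.nn.Tanh(),
    torch.nn.Linear(50, 1),
)

# A temporary directory stands in for the real checkpoint folder (assumption).
out_dir = tempfile.mkdtemp()

for epoch in range(201):
    # ... one epoch of training here
    if epoch % 100 == 0:
        # One new file per checkpoint: checkpoint0.pt, checkpoint100.pt, ...
        torch.save(model.state_dict(), os.path.join(out_dir, f"checkpoint{epoch}.pt"))

print(sorted(os.listdir(out_dir)))  # ['checkpoint0.pt', 'checkpoint100.pt', 'checkpoint200.pt']
```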

I was wondering if it is possible to append checkpoints to an existing file, so that I don't clutter my disk with many small files and instead have a single file called checkpoints.pt. I have currently implemented this as follows:

import torch

# Create a model
model = torch.nn.Sequential(
    torch.nn.Linear(1, 50),
    torch.nn.Tanh(),
    torch.nn.Linear(50, 1)
)

# ... some training here
# Save 1st checkpoint
data = {'0': model.state_dict()}
torch.save(data, 'checkpoints.pt')

# ... some training here
# Save 2nd checkpoint
data = torch.load('checkpoints.pt')
data['100'] = model.state_dict()
torch.save(data, 'checkpoints.pt')

print(torch.load('checkpoints.pt'))

The problem is that this requires loading the entire existing file into memory before appending a new checkpoint, which is memory-intensive, especially since I have hundreds of checkpoints. Is there a way to do this (or something similar) without loading the existing checkpoints back into memory?


Solution

  • See this post on storing multiple pickled objects in the same file. In short, PyTorch checkpointing is built on pickle, so if you use a thin pickle wrapper instead of torch.save, you can append each new checkpoint to the end of a single file:

    import pickle  # Python 3's pickle already uses the C implementation (_pickle) when available

    def append_save(network, path):
        # "ab" opens the file in binary append mode, so each call adds one new pickle record at the end
        with open(path, "ab") as f:
            pickle.dump(network.state_dict(), f)


    To read the checkpoints back, you have to unpickle each state dict sequentially from the file until you reach the end:

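    def read_checkpoints(path):
        checkpoints = []
        with open(path, "rb") as f:
            while True:
                try:
                    # Each pickle.load call reads exactly one record written by append_save
                    checkpoints.append(pickle.load(f))
                except EOFError:
                    # Reached the end of the file: no more records
                    break
        return checkpoints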
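Note that `read_checkpoints` still builds a list of every checkpoint, so reading everything back costs as much memory as the dict approach in the question. Below is a minimal sketch of a lazier variant. It rests on two assumptions that are not part of the answer above. First, each record is stored as an `(epoch, state_dict)` tuple so the question's epoch labels survive. Second, the helper names `append_checkpoint`, `iter_checkpoints` and `load_checkpoint` are hypothetical. The generator unpickles one record at a time and never accumulates them in a list. Finding a specific epoch still means reading through every record before it, because the file has no index.

```python
import os
import pickle
import tempfile

import torch


def append_checkpoint(path, epoch, state_dict):
    """Append one (epoch, state_dict) record to the end of the file at `path`."""
    with open(path, "ab") as f:
        pickle.dump((epoch, state_dict), f, protocol=pickle.HIGHEST_PROTOCOL)


def iter_checkpoints(path):
    """Yield (epoch, state_dict) records one at a time, in the order they were appended."""
    with open(path, "rb") as f:
        while True:
            try:
                record = pickle.load(f)
            except EOFError:
                return  # end of file: no more records
            yield record


def load_checkpoint(path, epoch):
    """Return the state dict saved for `epoch`; raise KeyError if it was never saved."""
    for saved_epoch, state_dict in iter_checkpoints(path):
        if saved_epoch == epoch:
            return state_dict
    raise KeyError(epoch)


# Demo: a temporary file stands in for 'checkpoints.pt'.
model = torch.nn.Sequential(
    torch.nn.Linear(1, 50),
    torch.nn.Tanh(),
    torch.nn.Linear(50, 1),
)
path = os.path.join(tempfile.mkdtemp(), "checkpoints.pt")

for epoch in (0, 100, 200):
    # Stand-in for "... some training here": make each checkpoint distinguishable.
    with torch.no_grad():
        model[0].bias.fill_(float(epoch))
    append_checkpoint(path, epoch, model.state_dict())

model.load_state_dict(load_checkpoint(path, 100))
print([saved_epoch for saved_epoch, _ in iter_checkpoints(path)])  # [0, 100, 200]
print(model[0].bias[0].item())  # 100.0
```

Because `load_checkpoint` drops each non-matching record as it moves on, peak memory stays around the size of one or two checkpoints rather than growing with the number saved. The trade-off is read time, which grows with the position of the requested epoch in the file.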