Tags: python, deep-learning, pytorch, pickle, torch

Error loading PyTorch model checkpoint: _pickle.UnpicklingError: invalid load key, '\x1f'


I'm trying to load the weights of a PyTorch model, but I get this error: _pickle.UnpicklingError: invalid load key, '\x1f'.

Here is the weights loading code:

import os
import torch
import numpy as np
# from data_loader import VideoDataset
import timm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device being used:', device)

mname = os.path.join('./CDF2_0.pth')
checkpoints = torch.load(mname, map_location=device)
print("Checkpoint loaded successfully.")
model = timm.create_model('legacy_xception', pretrained=True, num_classes=2).to(device)
model.load_state_dict(checkpoints['state_dict'])
model.eval()

I have tried different PyTorch versions. I have also tried inspecting the weights by changing the extension to .zip and opening the file with an archive manager, but that didn't help. Here is a public link to the weights .pth file I'm trying to load. Any help is highly appreciated, as I have around 40 trained models that took about a month to train!


Solution

  • This error is typical when a gzip file is opened as if it were a pickle or PyTorch file, because gzip files start with the byte 1f. But this file is not a valid gzip either: it looks like a corrupted PyTorch file.

    Indeed, the output of hexdump -C file.pt | head (shown below) mostly looks like a PyTorch file, which should be a ZIP archive (not gzip) containing a Python pickle named data.pkl. But the first few bytes are wrong: instead of the ZIP signature (bytes 50 4b, ASCII PK), the file starts with a gzip header (1f 8b 08 08). In fact, it is exactly as if the first 31 bytes had been overwritten with a valid, empty gzip file (its timestamp ff 35 29 67 decodes to November 4, 2024, 9:00:47 PM GMT).

    Your file:

    00000000  1f 8b 08 08 ff 35 29 67  02 ff 43 44 46 32 5f 30  |.....5)g..CDF2_0|
    00000010  2e 70 74 68 00 03 00 00  00 00 00 00 00 00 00 44  |.pth...........D|
    00000020  46 32 5f 30 2f 64 61 74  61 2e 70 6b 6c 46 42 0f  |F2_0/data.pklFB.|
    00000030  00 5a 5a 5a 5a 5a 5a 5a  5a 5a 5a 5a 5a 5a 5a 5a  |.ZZZZZZZZZZZZZZZ|
    00000040  80 02 7d 71 00 28 58 08  00 00 00 62 65 73 74 5f  |..}q.(X....best_|
    00000050  61 63 63 71 01 63 6e 75  6d 70 79 2e 63 6f 72 65  |accq.cnumpy.core|
    ...
    

    (Inspecting the pickle data, we can see a dictionary {"best_acc": ..., "state_dict": ...} with the typical contents of a PyTorch model checkpoint.)
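
The same check can be done without hexdump by reading the first few bytes directly. This is only a minimal sketch: classify_magic and sniff are hypothetical helper names (not part of PyTorch), and they recognize only the two signatures discussed in this answer.

```python
def classify_magic(head: bytes) -> str:
    """Guess a container format from the first bytes of a file."""
    if head.startswith(b"PK\x03\x04"):
        return "zip"   # ZIP local file header, as in the valid checkpoint dump
    if head.startswith(b"\x1f\x8b"):
        return "gzip"  # gzip magic; the '\x1f' in the error message is this first byte
    return "unknown"


def sniff(path: str) -> str:
    """Classify a file on disk by its first four bytes."""
    with open(path, "rb") as f:
        return classify_magic(f.read(4))
```

For the broken checkpoint, sniff("CDF2_0.pth") should report gzip, while a freshly saved checkpoint should report zip.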

    A valid zipped pickle produced by torch.save({"best_acc": np.array([1]), "state_dict": ...}, "CDF2_0.pth"):

    00000000  50 4b 03 04 00 00 08 08  00 00 00 00 00 00 00 00  |PK..............|
    00000010  00 00 00 00 00 00 00 00  00 00 0f 00 13 00 43 44  |..............CD|
    00000020  46 32 5f 30 2f 64 61 74  61 2e 70 6b 6c 46 42 0f  |F2_0/data.pklFB.|
    00000030  00 5a 5a 5a 5a 5a 5a 5a  5a 5a 5a 5a 5a 5a 5a 5a  |.ZZZZZZZZZZZZZZZ|
    00000040  80 02 7d 71 00 28 58 08  00 00 00 62 65 73 74 5f  |..}q.(X....best_|
    00000050  61 63 63 71 01 63 6e 75  6d 70 79 2e 63 6f 72 65  |accq.cnumpy.core|
    ...
    

    A gzip file containing an empty file with the same name and timestamp (created with gzip --best) is 31 bytes long and matches your file's prefix, except for the single 'operating system' byte at offset 9, which is ff (unknown) in your file and 03 (Unix) here:

    00000000  1f 8b 08 08 ff 35 29 67  02 03 43 44 46 32 5f 30  |.....5)g..CDF2_0|
    00000010  2e 70 74 68 00 03 00 00  00 00 00 00 00 00 00     |.pth...........|
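
The gzip header fields can also be decoded in a few lines of Python to double-check this reading. The sketch below assumes the 31 prefix bytes are copied verbatim from the dump of the broken file; the variable names are illustrative only.

```python
import gzip
import struct
from datetime import datetime, timezone

# First 31 bytes of the broken file, copied from the hexdump above.
prefix = bytes.fromhex(
    "1f 8b 08 08 ff 35 29 67 02 ff 43 44 46 32 5f 30 "
    "2e 70 74 68 00 03 00 00 00 00 00 00 00 00 00"
)

# Fixed 10-byte gzip header: magic, method, flags, mtime (little-endian), XFL, OS.
magic, method, flags, mtime, xfl, os_id = struct.unpack("<2sBBIBB", prefix[:10])

# With the FNAME flag (0x08) set, a NUL-terminated file name follows the header.
name_end = prefix.index(b"\x00", 10)
name = prefix[10:name_end].decode("latin-1")

saved_at = datetime.fromtimestamp(mtime, tz=timezone.utc)

# Decompressing the prefix on its own succeeds and yields no data,
# so the prefix is a complete gzip stream with an empty payload.
payload = gzip.decompress(prefix)

print(len(prefix), name, saved_at.isoformat(), hex(xfl), hex(os_id), payload)
```

The decoded timestamp corresponds to 2024-11-04 21:00:47 UTC, the XFL byte 02 indicates maximum compression, and the OS byte is ff (unknown).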
    

    Edit: Here's a script that might fix such files in general:

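    #!/usr/bin/env python3
    """Best-effort repair of a torch checkpoint whose leading bytes were overwritten."""
    import os
    import sys
    from pathlib import Path
    from shutil import copy2
    from tempfile import TemporaryDirectory
    
    import numpy as np
    import torch
    
    # Compare and patch the two files a few bytes at a time.
    CHUNK_SIZE = 4
    
    
    def main(orig_path: Path) -> None:
        # Work on a copy so the original file is never modified.
        fixed_path = orig_path.with_suffix(".fixed.pth")
        copy2(orig_path, fixed_path)
    
        with TemporaryDirectory() as temp_dir:
            # Save a small reference checkpoint under the same file name, because the
            # archive's internal paths (e.g. CDF2_0/data.pkl) appear to be derived from it.
            temp_path = Path(temp_dir) / orig_path.name
            torch.save({"best_acc": np.array([1]), "state_dict": {}}, temp_path)
    
            with open(temp_path, "rb") as f_temp, open(fixed_path, "rb+") as f_fixed:
                while True:
                    content = f_fixed.read(CHUNK_SIZE)
                    replacement = f_temp.read(CHUNK_SIZE)
                    # Stop at the first chunk that already matches, or at the end of either file.
                    if content == replacement or not content or not replacement:
                        break
                    print(f"Replacing {content!r} with {replacement!r}")
                    # Step back over the chunk just read and overwrite it in place.
                    f_fixed.seek(-len(content), os.SEEK_CUR)
                    f_fixed.write(replacement)
    
    
    if __name__ == "__main__":
        if len(sys.argv) != 2:
            sys.exit("Expected exactly one argument (the path to the broken .pth file).")
        main(Path(sys.argv[1]))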
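
Before passing a repaired file to torch.load, it may be worth checking that the archive is structurally readable again. The following is a minimal sketch using only the standard library's zipfile module; first_bad_member is a hypothetical helper name. Note that zipfile.is_zipfile alone would not catch this kind of damage, since it only looks for the end-of-central-directory record at the end of the file.

```python
import zipfile


def first_bad_member(path: str):
    """Return the name of the first unreadable archive member, or None if all are fine."""
    with zipfile.ZipFile(path) as zf:
        # testzip() re-reads every member, checking file headers and CRCs.
        return zf.testzip()
```

If first_bad_member("CDF2_0.fixed.pth") returns None, the header repair most likely worked; loading the file with torch.load is still the final check.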