I'm trying to load the weights of a PyTorch model, but I'm getting this error: _pickle.UnpicklingError: invalid load key, '\x1f'.
Here is the code that loads the weights:
import os
import torch
import numpy as np
# from data_loader import VideoDataset
import timm
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device being used:', device)
mname = os.path.join('./CDF2_0.pth')
checkpoints = torch.load(mname, map_location=device)
print("Checkpoint loaded successfully.")
model = timm.create_model('legacy_xception', pretrained=True, num_classes=2).to(device)
model.load_state_dict(checkpoints['state_dict'])
model.eval()
I have tried different PyTorch versions. I have also tried to inspect the weights by changing the extension to .zip and opening the file with an archive manager, but I can't fix the issue. Here is a public link to the .pth file I'm trying to load. Any help is much appreciated, as I have around 40 trained models that took about a month to train!
This error typically occurs when you try to open a gzip file as if it were a pickle or PyTorch file, because gzip files start with a 1f byte. But this is not a valid gzip file either: it looks like a corrupted PyTorch file.
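The error message itself can be reproduced with the standard pickle module alone. Since the file doesn't start like a ZIP archive, torch.load presumably treats it as the older pure-pickle format, and the unpickler rejects the very first byte because 1f is not a pickle opcode. A minimal sketch (GZIP_PREFIX is just the first four bytes of a gzip header, not your whole file; unpickle_error is a hypothetical helper):

```python
import pickle

# First bytes of a gzip stream: magic 1f 8b, method 08 (deflate), flags 08 (FNAME set).
GZIP_PREFIX = b"\x1f\x8b\x08\x08"


def unpickle_error(data: bytes) -> str:
    """Try to unpickle raw bytes; return the error message, or '' on success."""
    try:
        pickle.loads(data)
    except pickle.UnpicklingError as exc:
        return str(exc)
    return ""


print(unpickle_error(GZIP_PREFIX))  # invalid load key, '\x1f'.
```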
Indeed, looking at the output of hexdump -C file.pt | head (shown below), most of it looks like a PyTorch file, which should be a ZIP archive (not gzip) containing a Python pickle file named data.pkl. But the first few bytes are wrong: instead of starting like a ZIP file as it should (bytes 50 4B, or ASCII PK), it starts like a gzip file (1f 8b 08 08). In fact, it's exactly as if the first 31 bytes had been replaced with a valid, empty gzip file (with a timestamp ff 35 29 67 pointing to November 4, 2024 9:00:47 PM GMT).
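The same reading of the header can be done in a few lines of Python. This sketch (describe_header is a hypothetical helper) checks the magic bytes and, for gzip, decodes bytes 4 to 7 as the MTIME field, a little-endian 32-bit Unix timestamp, which is how ff 35 29 67 turns into the date quoted above:

```python
import struct
from datetime import datetime, timezone

ZIP_MAGIC = b"PK\x03\x04"  # ZIP local file header signature (50 4b 03 04)
GZIP_MAGIC = b"\x1f\x8b"   # gzip member header (1f 8b)


def describe_header(header: bytes) -> str:
    """Classify the first bytes of a file, as when reading a hexdump by eye."""
    if header.startswith(ZIP_MAGIC):
        return "zip"
    if header.startswith(GZIP_MAGIC) and len(header) >= 8:
        # gzip MTIME: bytes 4-7, unsigned little-endian seconds since the epoch.
        (mtime,) = struct.unpack_from("<I", header, 4)
        when = datetime.fromtimestamp(mtime, tz=timezone.utc)
        return f"gzip, mtime {when.isoformat()}"
    return "unknown"


# First 8 bytes of the broken CDF2_0.pth, copied from the hexdump of your file.
print(describe_header(bytes.fromhex("1f 8b 08 08 ff 35 29 67")))
# gzip, mtime 2024-11-04T21:00:47+00:00
```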
Your file:
00000000 1f 8b 08 08 ff 35 29 67 02 ff 43 44 46 32 5f 30 |.....5)g..CDF2_0|
00000010 2e 70 74 68 00 03 00 00 00 00 00 00 00 00 00 44 |.pth...........D|
00000020 46 32 5f 30 2f 64 61 74 61 2e 70 6b 6c 46 42 0f |F2_0/data.pklFB.|
00000030 00 5a 5a 5a 5a 5a 5a 5a 5a 5a 5a 5a 5a 5a 5a 5a |.ZZZZZZZZZZZZZZZ|
00000040 80 02 7d 71 00 28 58 08 00 00 00 62 65 73 74 5f |..}q.(X....best_|
00000050 61 63 63 71 01 63 6e 75 6d 70 79 2e 63 6f 72 65 |accq.cnumpy.core|
...
(Inspecting the pickle data shows a dictionary {"best_acc": ..., "state_dict": ...} with the typical contents of a PyTorch model checkpoint.)
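The first bytes of data.pkl (offset 0x40 in the dump) match how pickle protocol 2, the 80 02 at the start, encodes a dict. A sketch with plain pickle and no PyTorch: the integer 1 stands in for the real best_acc array, so only the bytes up to and including the first key line up with the dump.

```python
import pickle

# Same keys as the checkpoint, pickled with protocol 2 (80 02 in the dump).
# best_acc is a plain int here, so everything after the "best_acc" key differs.
data = pickle.dumps({"best_acc": 1, "state_dict": {}}, protocol=2)

# 80 02        PROTO 2
# 7d           EMPTY_DICT
# 71 00        BINPUT 0   (memoize the dict)
# 28           MARK       (start of the key/value batch)
# 58 08000000  BINUNICODE, length 8 (little-endian), followed by b"best_acc"
# 71 01        BINPUT 1   (memoize the key)
print(data[:21].hex(" "))
# 80 02 7d 71 00 28 58 08 00 00 00 62 65 73 74 5f 61 63 63 71 01
```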
A valid zipped pickle produced by torch.save({"best_acc": np.array([1]), "state_dict": ...}, "CDF2_0.pth"):
00000000 50 4b 03 04 00 00 08 08 00 00 00 00 00 00 00 00 |PK..............|
00000010 00 00 00 00 00 00 00 00 00 00 0f 00 13 00 43 44 |..............CD|
00000020 46 32 5f 30 2f 64 61 74 61 2e 70 6b 6c 46 42 0f |F2_0/data.pklFB.|
00000030 00 5a 5a 5a 5a 5a 5a 5a 5a 5a 5a 5a 5a 5a 5a 5a |.ZZZZZZZZZZZZZZZ|
00000040 80 02 7d 71 00 28 58 08 00 00 00 62 65 73 74 5f |..}q.(X....best_|
00000050 61 63 63 71 01 63 6e 75 6d 70 79 2e 63 6f 72 65 |accq.cnumpy.core|
...
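The fixed-size ZIP local file header explains why data.pkl starts at 0x40: 30 bytes of header, then the 15-byte name "CDF2_0/data.pkl", then a 19-byte extra field ("FB", length 0x0f, and the run of Z padding). This is a sketch (parse_local_header is a hypothetical helper) that reads those fields from the valid dump using only struct:

```python
import struct

# ZIP local file header: signature, version, flags, method, mod time, mod date,
# CRC-32, compressed size, uncompressed size, name length, extra length (30 bytes).
LOCAL_HEADER = struct.Struct("<4sHHHHHIIIHH")


def parse_local_header(buf: bytes) -> dict:
    (sig, _version, flags, method, _mtime, _mdate,
     _crc, _csize, _usize, name_len, extra_len) = LOCAL_HEADER.unpack_from(buf, 0)
    if sig != b"PK\x03\x04":
        raise ValueError(f"not a ZIP local file header: {sig.hex(' ')}")
    name_start = LOCAL_HEADER.size
    return {
        "flags": flags,
        "method": method,  # 0 = stored, i.e. the pickle bytes are not compressed
        "name": buf[name_start:name_start + name_len].decode("utf-8"),
        "data_offset": name_start + name_len + extra_len,
    }


# First 0x42 bytes of the valid checkpoint, copied from the hexdump.
header = bytes.fromhex(
    "50 4b 03 04 00 00 08 08 00 00 00 00 00 00 00 00 "
    "00 00 00 00 00 00 00 00 00 00 0f 00 13 00 43 44 "
    "46 32 5f 30 2f 64 61 74 61 2e 70 6b 6c 46 42 0f "
    "00" + " 5a" * 15 + " 80 02"
)
info = parse_local_header(header)
print(info["name"], hex(info["data_offset"]))  # CDF2_0/data.pkl 0x40
print(header[info["data_offset"]:].hex(" "))   # 80 02 (pickle PROTO 2)
```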
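A gzip of an empty file with the same name and timestamp (created with gzip --best) is 31 bytes long, the same length as your file's corrupted prefix, and identical to it except for the 'Operating System' byte at offset 9 (ff, meaning 'unknown', in your file vs. 03, meaning Unix, here):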
00000000 1f 8b 08 08 ff 35 29 67 02 03 43 44 46 32 5f 30 |.....5)g..CDF2_0|
00000010 2e 70 74 68 00 03 00 00 00 00 00 00 00 00 00 |.pth...........|
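Python's own gzip module writes ff ('unknown') as the OS byte, so gzipping an empty payload with it, using the same name and timestamp, appears to reproduce your file's 31-byte prefix byte for byte. This is only consistent with the prefix having come from Python's gzip module; it doesn't prove how the file was corrupted. A sketch (empty_gzip is a hypothetical helper):

```python
import gzip
import io


def empty_gzip(name: str, mtime: int) -> bytes:
    """Gzip an empty payload with Python's gzip module, storing `name` in the header."""
    buf = io.BytesIO()
    # compresslevel defaults to 9, which writes XFL = 02, like gzip --best.
    with gzip.GzipFile(filename=name, mode="wb", fileobj=buf, mtime=mtime):
        pass  # write nothing: an empty file
    return buf.getvalue()


# 0x672935FF is the timestamp bytes ff 35 29 67 read as a little-endian integer.
prefix = empty_gzip("CDF2_0.pth", mtime=0x672935FF)
print(len(prefix))  # 31
print(prefix.hex(" "))
```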
Edit: Here's a script that might fix such files in general:
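#!/usr/bin/env python3
"""Repair a .pth checkpoint whose first bytes were overwritten (e.g. by an empty gzip)."""
import os
import sys
from pathlib import Path
from shutil import copy2
from tempfile import TemporaryDirectory

import numpy as np
import torch

CHUNK_SIZE = 4  # compare/replace the header 4 bytes at a time


def main(orig_path: Path) -> None:
    # Work on a copy so the original file is never modified.
    fixed_path = orig_path.with_suffix(".fixed.pth")
    copy2(orig_path, fixed_path)

    with TemporaryDirectory() as temp_dir:
        # Save a dummy checkpoint under the same file name: torch.save uses the
        # file stem as the archive's top-level folder (e.g. "CDF2_0/data.pkl"),
        # so its first bytes should match what the broken file originally had.
        temp_path = Path(temp_dir) / orig_path.name
        torch.save({"best_acc": np.array([1]), "state_dict": {}}, temp_path)

        with open(temp_path, "rb") as f_temp, open(fixed_path, "rb+") as f_fixed:
            # Overwrite the start of the copy with the reference bytes, chunk by
            # chunk, stopping at the first chunk where both files already agree
            # (i.e. once we are past the corrupted prefix).
            while True:
                content = f_fixed.read(CHUNK_SIZE)
                replacement = f_temp.read(CHUNK_SIZE)
                if content == replacement:
                    break
                print(f"Replacing {content!r} with {replacement!r}")
                f_fixed.seek(-CHUNK_SIZE, os.SEEK_CUR)  # go back to the chunk just read
                f_fixed.write(replacement)


if __name__ == "__main__":
    assert len(sys.argv) == 2, "Expected exactly one argument (the path to the broken .pth file)."
    main(Path(sys.argv[1]))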
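Before calling torch.load on the repaired copy, you can sanity-check it with the standard zipfile module. Note that zipfile.is_zipfile only looks for the end-of-central-directory record at the end of the file, so a file whose first bytes are corrupted can still pass it; testzip() additionally reads every member, including its local header at the start of the file. A sketch (check_archive is a hypothetical helper):

```python
import zipfile
from pathlib import Path
from typing import List

ZIP_LOCAL_MAGIC = b"PK\x03\x04"


def check_archive(path: Path) -> List[str]:
    """Return a list of problems found in a ZIP-based checkpoint (empty if none)."""
    problems = []
    with open(path, "rb") as f:
        magic = f.read(4)
    if magic != ZIP_LOCAL_MAGIC:
        problems.append(f"starts with {magic.hex(' ')} instead of 50 4b 03 04")
    if not zipfile.is_zipfile(path):
        # No end-of-central-directory record: the end of the file is damaged too.
        problems.append("no ZIP central directory found")
        return problems
    with zipfile.ZipFile(path) as zf:
        # testzip() reads each member (local header + data, CRC-checked) and
        # returns the name of the first one that fails, or None.
        bad = zf.testzip()
    if bad is not None:
        problems.append(f"member {bad!r} cannot be read")
    return problems
```

For example, check_archive(Path("CDF2_0.fixed.pth")) should return an empty list if the repair worked, while the original broken file should report both the gzip-like start and an unreadable data.pkl member.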