I trained a QAT (Quantization Aware Training) based model in PyTorch, and the training went smoothly. However, when I tried to load the weights into the fused model and run a test on the WIDER FACE dataset, I faced lots of errors:
```
(base) marian@u04-2:/mnt/s3user/Pytorch_Retinaface_quantized# python test_widerface.py --trained_model ./weights/mobilenet0.25_Final_quantized.pth --network mobile0.25layers:
Loading pretrained model from ./weights/mobilenet0.25_Final_quantized.pth
remove prefix 'module.'
Missing keys:235
Unused checkpoint keys:171
Used keys:65
Traceback (most recent call last):
  File "/root/.vscode/extensions/ms-python.python-2020.1.58038/pythonFiles/ptvsd_launcher.py", line 43, in <module>
    main(ptvsdArgs)
  File "/root/.vscode/extensions/ms-python.python-2020.1.58038/pythonFiles/lib/python/old_ptvsd/ptvsd/__main__.py", line 432, in main
    run()
  File "/root/.vscode/extensions/ms-python.python-2020.1.58038/pythonFiles/lib/python/old_ptvsd/ptvsd/__main__.py", line 316, in run_file
    runpy.run_path(target, run_name='__main__')
  File "/root/anaconda3/lib/python3.7/runpy.py", line 263, in run_path
    pkg_name=pkg_name, script_name=fname)
  File "/root/anaconda3/lib/python3.7/runpy.py", line 96, in _run_module_code
    mod_name, mod_spec, pkg_name, script_name)
  File "/root/anaconda3/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/mnt/f3user/Pytorch_Retinaface_quantized/test_widerface.py", line 114, in <module>
    net = load_model(net, args.trained_model, args.cpu)
  File "/mnt/f3user/Pytorch_Retinaface_quantized/test_widerface.py", line 95, in load_model
    model.load_state_dict(pretrained_dict, strict=False)
  File "/root/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 830, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for RetinaFace:
While copying the parameter named "ssh1.conv3X3.0.weight", whose dimensions in the model are torch.Size([32, 64, 3, 3]) and whose dimensions in the checkpoint are torch.Size([32, 64, 3, 3]).
While copying the parameter named "ssh1.conv5X5_2.0.weight", whose dimensions in the model are torch.Size([16, 16, 3, 3]) and whose dimensions in the checkpoint are torch.Size([16, 16, 3, 3]).
While copying the parameter named "ssh1.conv7x7_3.0.weight", whose dimensions in the model are torch.Size([16, 16, 3, 3]) and whose dimensions in the checkpoint are torch.Size([16, 16, 3, 3]).
While copying the parameter named "ssh2.conv3X3.0.weight", whose dimensions in the model are torch.Size([32, 64, 3, 3]) and whose dimensions in the checkpoint are torch.Size([32, 64, 3, 3]).
While copying the parameter named "ssh2.conv5X5_2.0.weight", whose dimensions in the model are torch.Size([16, 16, 3, 3]) and whose dimensions in the checkpoint are torch.Size([16, 16, 3, 3]).
.....
```
The full list can be found here.
Basically, the weights can't be found, and the scale and zero_point entries are also missing from the fused model.
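To see exactly which keys are missing or unused, a plain set comparison between the model's and the checkpoint's state_dict keys is enough. The sketch below uses a hypothetical helper, `compare_state_dict_keys`. It is not part of PyTorch or of the repo's test_widerface.py, but its three lists correspond to the "Missing keys / Unused checkpoint keys / Used keys" counts printed above. The key names in the demo are illustrative only. Recent PyTorch versions also return the missing and unexpected keys from `load_state_dict(..., strict=False)` as `missing_keys` and `unexpected_keys`.

```python
def compare_state_dict_keys(model_keys, checkpoint_keys):
    """Split key names into missing (model only), unused (checkpoint only) and used (both).

    Hypothetical helper: pass e.g. model.state_dict().keys() and the loaded
    checkpoint's .keys(); any iterables of strings work.
    """
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    return {
        "missing": sorted(model_keys - checkpoint_keys),
        "unused": sorted(checkpoint_keys - model_keys),
        "used": sorted(model_keys & checkpoint_keys),
    }


if __name__ == "__main__":
    # Illustrative (hypothetical) names: a float model vs. a converted checkpoint.
    model_keys = ["conv.weight", "conv.bias", "bn.weight"]
    ckpt_keys = ["conv.weight", "conv.bias", "conv.scale", "conv.zero_point"]
    for name, keys in compare_state_dict_keys(model_keys, ckpt_keys).items():
        print(f"{name}: {len(keys)} {keys}")
```

Keys that show up as "unused" with a `.scale` or `.zero_point` suffix are a hint that the checkpoint came from a converted (quantized) model while the target model is not.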
In case it matters, the following snippet is the actual training loop that was used to train and save the model:
```python
import torch

if __name__ == '__main__':
    # train()
    ...
    net = RetinaFace(cfg=cfg)
    print("Printing net...")
    print(net)
    net.fuse_model()
    ...
    net.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
    torch.quantization.prepare_qat(net, inplace=True)
    print(f'quantization preparation done.')
    ...
    quantized_model = net
    for i in range(max_epoch):
        net = net.to(device)
        train_one_epoch(net, data_loader, optimizer, criterion, cfg, gamma, i, step_index, device)
        if i in stepvalues:
            step_index += 1
        if i > 3:
            # stop updating the quantization ranges (observers) after a few epochs
            net.apply(torch.quantization.disable_observer)
        if i > 2:
            # freeze the batch-norm running mean/variance
            net.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
        net = net.cpu()
        # convert a copy of the QAT model to a real int8 model
        quantized_model = torch.quantization.convert(net.eval(), inplace=False)
        quantized_model.eval()
        # evaluate on test set ?!

    torch.save(net.state_dict(), save_folder + cfg['name'] + '_Final.pth')
    torch.save(quantized_model.state_dict(), save_folder + cfg['name'] + '_Final_quantized.pth')
    #torch.jit.save(torch.jit.script(quantized_model), save_folder + cfg['name'] + '_Final_quantized_jit.pth')
```
For testing, test_widerface.py is used, which can be accessed here.
You can view the keys here.
Why has this happened? How should this be taken care of?
I checked the names, created a new state_dict dictionary, and inserted the 112 keys that were present in both the checkpoint and the model, using the snippet below:
```python
import os
import torch

new_state_dict = {}
checkpoint_state_dict = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
model_state_dict = model.state_dict()  # computed once instead of on every iteration
for (ck, cp) in checkpoint_state_dict.items():
    for (mk, mp) in model_state_dict.items():
        kname, kext = os.path.splitext(ck)
        mname, mext = os.path.splitext(mk)
        # check the two parameters and see if they are the same;
        # if so, use the model's key naming scheme with the checkpoint's weights
        if kname + kext == mname + mext or kname + '.0' + kext == mname + mext:
            new_state_dict[mname + mext] = cp
        else:
            # keep quantization parameters under the checkpoint's own key
            if kext in ('.scale', '.zero_point'):
                new_state_dict[ck] = cp
```
I then used this new state_dict, yet I'm still getting the exact same errors, like this one:
```
RuntimeError: Error(s) in loading state_dict for RetinaFace:
While copying the parameter named "ssh1.conv3X3.0.weight", whose dimensions in the model are torch.Size([32, 64, 3, 3]) and whose dimensions in the checkpoint are torch.Size([32, 64, 3, 3]).
```
This is really frustrating and there is no documentation concerning this! I'm completely clueless here.
I finally found the cause. Error messages of the form

```
While copying the parameter named "xxx.weight", whose dimensions in the model are torch.Size([yyy]) and whose dimensions in the checkpoint are torch.Size([yyy]).
```

are actually generic messages, emitted only when an exception occurs while copying the parameter in question. That is why the two shapes are identical: a genuine shape mismatch is reported separately as a "size mismatch" error.
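One way to surface the hidden exception is to try each shared key on its own and keep the real error text. The sketch below uses a hypothetical debugging helper, `explain_copy_errors`, which is not a PyTorch API. It assumes the checkpoint is a plain dict of tensors as returned by `torch.load`. It roughly mirrors the `copy_` that `load_state_dict` performs under `torch.no_grad()`, but it copies into clones so the model is left unchanged.

```python
import torch
import torch.nn as nn


def explain_copy_errors(model: nn.Module, state_dict: dict) -> dict:
    """Return {key: underlying error message} for checkpoint entries that fail to copy.

    Hypothetical helper: copies into clones of the model's tensors, so the
    model's own parameters and buffers are not modified.
    """
    errors = {}
    own = model.state_dict()
    with torch.no_grad():
        for key, value in state_dict.items():
            if key not in own:
                continue  # unused checkpoint keys are a separate problem
            target = own[key]
            if not (torch.is_tensor(target) and torch.is_tensor(value)):
                continue  # e.g. packed-params metadata in some quantized modules
            if target.shape != value.shape:
                # copy_ may broadcast silently, so compare shapes explicitly
                errors[key] = (f"shape mismatch: model {tuple(target.shape)} "
                               f"vs checkpoint {tuple(value.shape)}")
                continue
            try:
                target.clone().copy_(value)
            except Exception as exc:  # the error the generic message does not show
                errors[key] = f"{type(exc).__name__}: {exc}"
    return errors
```

Calling `explain_copy_errors(net, pretrained_dict)` just before `load_state_dict` in a script like test_widerface.py would print the actual reason for each failing key.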
PyTorch developers could easily include the actual exception arguments in this misleading message, so that it would actually help debug the issue at hand. Anyway, the underlying exception turned out to be:

```
"copy_" not implemented for 'QInt8'
```
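Once you see that, the actual issue is clear. The checkpoint holds quantized (QInt8) tensors, but the model they are being copied into still has ordinary float parameters, and copy_ cannot copy a quantized tensor into a float one.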
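The likely fix is to rebuild the test model the same way it was trained before loading the quantized checkpoint. That means running fuse, then qconfig, then prepare_qat, then convert on a fresh network, and only then calling `load_state_dict`. Below is a minimal sketch under that assumption. `TinyNet` and `build_converted_model` are hypothetical stand-ins for RetinaFace and its setup. The `torch.quantization` calls are the same ones used in the training script above. The single forward pass on random data only stands in for QAT training, and its values are overwritten when the checkpoint is loaded.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for RetinaFace: quant -> conv -> relu -> dequant."""

    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.relu(self.conv(self.quant(x))))

    def fuse_model(self):
        # Same role as RetinaFace.fuse_model(): merge conv + relu into one module.
        torch.quantization.fuse_modules(self, [["conv", "relu"]], inplace=True)


def build_converted_model():
    """Repeat the training-time pipeline so module types and state_dict keys
    match a checkpoint saved from torch.quantization.convert(...)."""
    model = TinyNet()
    model.train()
    model.fuse_model()
    model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
    torch.quantization.prepare_qat(model, inplace=True)
    # One forward pass so the fake-quant observers hold real ranges
    # (real QAT training would happen here).
    model(torch.randn(2, 3, 16, 16))
    model.eval()
    return torch.quantization.convert(model, inplace=False)


if __name__ == "__main__":
    trained = build_converted_model()
    checkpoint = trained.state_dict()  # plays the role of '..._Final_quantized.pth'

    # Fails: a plain TinyNet() has float parameters, so copying QInt8 weights raises.
    try:
        TinyNet().load_state_dict(checkpoint, strict=False)
    except RuntimeError as exc:
        print("float model:", str(exc).splitlines()[0])

    # Works: a freshly converted model has matching quantized modules.
    restored = build_converted_model()
    restored.load_state_dict(checkpoint)
    x = torch.randn(1, 3, 16, 16)
    print("outputs match:", torch.equal(trained(x), restored(x)))
```

Applied to the question, this would mean calling `net.fuse_model()`, setting `net.qconfig`, then `prepare_qat` and `convert` on the RetinaFace instance in test_widerface.py before `load_model`. Saving the scripted model with `torch.jit.save` (the commented-out line in the training script) is an alternative that avoids rebuilding the architecture.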