Tags: python, tensorflow, keras, pytorch, onnx

Unable to load .pb while converting PyTorch model to tf.keras


Context

I'm using tf.keras for a personal project and I need a pretrained AlexNet model. Unfortunately, this model is not available through tf.keras alone, so I downloaded the pretrained model with PyTorch, converted it to an ONNX file, and then exported it as a .pb file with the following code:

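import torch
import torchvision
import onnx
from onnx_tf.backend import prepare  # provided by the onnx-tf package

# Build AlexNet and load the pretrained weights
torch_pretrained = torchvision.models.alexnet()
torch_pretrained.load_state_dict(torch.load("alexnet.pth"))

# Export to ONNX by tracing with a dummy 224x224 RGB batch (Variable is deprecated; a plain tensor works)
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(torch_pretrained, dummy_input, "alexnet_pretrained.onnx")

# Convert the ONNX graph to TensorFlow and write it to the "alexnet" directory
onnx_pretrained = onnx.load("alexnet_pretrained.onnx")
tf_rep = prepare(onnx_pretrained)
tf_rep.export_graph("alexnet")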
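Whichever conversion route ends up working, it is worth confirming that the converted model reproduces the PyTorch logits on the same input. Below is a minimal NumPy sketch; `compare_outputs` and its default tolerances are hypothetical choices, not part of torch, onnx, onnx-tf or TensorFlow. One assumption to keep in mind: the onnx-tf graph typically keeps PyTorch's NCHW input layout, whereas a Keras model built channels-last expects NHWC, so the same input may need `x.transpose(0, 2, 3, 1)` before it is fed to the Keras side.

```python
import numpy as np


def compare_outputs(reference, converted, rtol=1e-3, atol=1e-4):
    """Compare logits from the original PyTorch model and a converted model.

    Both arguments are array-likes of shape (batch, num_classes), e.g. the
    PyTorch output via .detach().numpy() and the TensorFlow output via .numpy().
    Returns (max_abs_diff, same_top1, close). Tolerances are assumptions.
    """
    ref = np.asarray(reference, dtype=np.float32)
    out = np.asarray(converted, dtype=np.float32)
    if ref.shape != out.shape:
        raise ValueError(f"shape mismatch: {ref.shape} vs {out.shape}")
    # Largest element-wise deviation between the two sets of logits
    max_abs_diff = float(np.max(np.abs(ref - out)))
    # For classification, agreement on the predicted class matters most
    same_top1 = bool(np.array_equal(ref.argmax(axis=1), out.argmax(axis=1)))
    # Element-wise closeness within the chosen tolerances
    close = bool(np.allclose(ref, out, rtol=rtol, atol=atol))
    return max_abs_diff, same_top1, close
```

Small float differences are expected after conversion, which is why the sketch reports top-1 agreement separately from strict closeness.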

Issue

I'm now trying to load the .pb file with Keras, as explained here, using the following code:

import tensorflow as tf

model = tf.keras.models.load_model("alexnet")
model.summary()

And I get an error:

AttributeError: '_UserObject' object has no attribute 'summary'

I also get a warning while loading the model, but I don't think it's relevant:

WARNING:tensorflow:SavedModel saved prior to TF 2.5 detected when loading Keras model. Please ensure that you are saving the model with model.save() or tf.keras.models.save_model(), NOT tf.saved_model.save(). To confirm, there should be a file named "keras_metadata.pb" in the SavedModel directory.
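The warning names a concrete check: a SavedModel written by `model.save()` or `tf.keras.models.save_model()` should contain a `keras_metadata.pb` file. A directory written by onnx-tf's `export_graph()` presumably lacks it, since it is not saved through Keras, which would be consistent with `load_model()` handing back a generic `_UserObject`. A minimal sketch of that check, assuming the standard SavedModel layout with `saved_model.pb` at the top of the directory; `has_keras_metadata` is a hypothetical helper, not a TensorFlow API.

```python
import os


def has_keras_metadata(saved_model_dir):
    """Return True if saved_model_dir looks like a Keras-written SavedModel.

    Assumes the usual layout: saved_model.pb at the top level, plus
    keras_metadata.pb when the model was saved through Keras (TF >= 2.5).
    """
    has_graph = os.path.isfile(os.path.join(saved_model_dir, "saved_model.pb"))
    has_meta = os.path.isfile(os.path.join(saved_model_dir, "keras_metadata.pb"))
    return has_graph and has_meta
```

Calling `has_keras_metadata("alexnet")` on the exported directory would then show whether `tf.keras.models.load_model()` has the metadata it needs to rebuild a Keras `Model`.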

The loaded model has an obscure type, as you can see:

<tensorflow.python.saved_model.load.Loader._recreate_base_user_object.<locals>._UserObject object at 0x0000023137981BB0>

While researching, I found this, so I'm not the only one facing this issue.

Question

The easiest way would be to solve this specific issue, but if anyone knows another way to load a pretrained AlexNet model into tf.keras, that would also solve my actual problem.
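One alternative that avoids ONNX entirely is to rebuild AlexNet in tf.keras and copy the PyTorch weights across, transposing them to Keras' channels-last layouts. The sketch below covers only the weight reshaping, assuming torchvision's AlexNet layout (five Conv2d layers, then a Linear layer on the flattened 256x6x6 feature map); the function names are hypothetical, and rebuilding the Keras architecture (for example matching PyTorch's explicit padding with `ZeroPadding2D`) is not shown.

```python
import numpy as np


def conv_weight_to_keras(w):
    """PyTorch Conv2d weight (out, in, kH, kW) -> Keras Conv2D kernel (kH, kW, in, out)."""
    return np.transpose(w, (2, 3, 1, 0))


def dense_weight_to_keras(w):
    """PyTorch Linear weight (out, in) -> Keras Dense kernel (in, out)."""
    return np.transpose(w)


def flattened_dense_weight_to_keras(w, channels, height, width):
    """Linear layer that follows flattening of a conv feature map.

    PyTorch flattens NCHW maps in (C, H, W) order, while Keras Flatten on an
    NHWC map uses (H, W, C) order, so the input axis must be reordered too.
    """
    out_features = w.shape[0]
    w = w.reshape(out_features, channels, height, width)  # (out, C, H, W)
    w = w.transpose(0, 2, 3, 1)                            # (out, H, W, C)
    return w.reshape(out_features, height * width * channels).T  # (in, out)
```

For torchvision's AlexNet this would mean applying `conv_weight_to_keras` to the `features.*.weight` tensors, `flattened_dense_weight_to_keras(..., 256, 6, 6)` to `classifier.1.weight`, and `dense_weight_to_keras` to the remaining `classifier.*.weight` tensors (after `.numpy()`), with biases copied unchanged.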

Specs

Windows 10
python 3.9.7
tensorflow 2.6.0
torch 1.10.2
torchvision 0.11.3
onnx 1.10.2
onnx-tf 1.9.0

Solution

    I followed Jakub's suggestion and installed pytorch2keras (see this). Running its function to convert the PyTorch model directly into a Keras model worked.

    I only had to modify the module's code because of a dependency issue: it imports onnx.optimizer, which has since been moved out of onnx into the separate onnxoptimizer package. So I changed the import line from

    from onnx import optimizer

    to

    import onnxoptimizer as optimizer
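A variant of that edit would keep the module working with both old and new onnx installs by trying the standalone package first and falling back to the old location. The usual form is a `try: import onnxoptimizer as optimizer` / `except ImportError: from onnx import optimizer` pair; below is a small generalised sketch of the same idea. `import_first` is a hypothetical helper, not part of pytorch2keras or onnx.

```python
import importlib


def import_first(*module_names):
    """Import and return the first module in module_names that is installed."""
    for name in module_names:
        try:
            return importlib.import_module(name)
        except ImportError:  # also covers ModuleNotFoundError
            continue
    raise ImportError(f"none of {module_names} could be imported")


# Hypothetical patched import inside the module:
# optimizer = import_first("onnxoptimizer", "onnx.optimizer")
```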