Tags: python, tensorflow, tensorflow2.0, tensorflow-lite

Converting saved_model to TFLite model using TF 2.0


Currently I am working on converting a custom object detection model (trained using SSD and an Inception network) to a quantized TFLite model. I am able to convert the custom object detection model from a frozen graph to a quantized TFLite model using the following code snippet (using TensorFlow 1.4):

import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_frozen_graph(
    args["model"],
    input_shapes={'normalized_input_image_tensor': [1, 300, 300, 3]},
    input_arrays=['normalized_input_image_tensor'],
    output_arrays=['TFLite_Detection_PostProcess', 'TFLite_Detection_PostProcess:1',
                   'TFLite_Detection_PostProcess:2', 'TFLite_Detection_PostProcess:3'])

converter.allow_custom_ops = True
converter.post_training_quantize = True
tflite_model = converter.convert()
open(args["output"], "wb").write(tflite_model)

However, the tf.lite.TFLiteConverter.from_frozen_graph class method is not available in TensorFlow 2.0 (refer to this link). So I tried to convert the model using the tf.lite.TFLiteConverter.from_saved_model class method. The code snippet is shown below:

converter = tf.lite.TFLiteConverter.from_saved_model("/content/")  # path to the saved_model directory
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

The above code snippet throws the following error:

ValueError: None is only supported in the 1st dimension. Tensor 'image_tensor' has invalid shape '[None, None, None, 3]'.
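
For reference, the dynamic input shape can be confirmed by inspecting the SavedModel's serving signature; a minimal check, assuming the model was exported with the standard serving_default signature:

import tensorflow as tf

# Load the SavedModel and print its serving signature; the input spec
# should show image_tensor with shape (None, None, None, 3).
loaded = tf.saved_model.load("/content/")
print(loaded.signatures["serving_default"].structured_input_signature)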

I tried to pass input_shapes as an argument:

converter = tf.lite.TFLiteConverter.from_saved_model("/content/",input_shapes={"image_tensor" : [1,300,300,3]})

but it throws the following error:

TypeError: from_saved_model() got an unexpected keyword argument 'input_shapes'
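
A common workaround in TF 2.x is to pin the shape on the model's concrete function and convert with tf.lite.TFLiteConverter.from_concrete_functions instead; the following is only a sketch (it assumes the input tensor is the signature's first input, and it may still fail on this TF 1.x object detection export):

import tensorflow as tf

# Load the SavedModel and grab its serving signature.
model = tf.saved_model.load("/content/")
concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

# Pin the dynamic input shape before conversion.
concrete_func.inputs[0].set_shape([1, 300, 300, 3])

converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.allow_custom_ops = True
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()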

Am I missing something? Please feel free to correct me!


Solution

  • I got the solution using tf.compat.v1.lite.TFLiteConverter.from_frozen_graph. The compat.v1 module brings the TF 1.x functionality into TF 2.x. Following is the full code, with a quick sanity check after it:

    import tensorflow as tf

    converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
        "/content/tflite_graph.pb",
        input_shapes={'normalized_input_image_tensor': [1, 300, 300, 3]},
        input_arrays=['normalized_input_image_tensor'],
        output_arrays=['TFLite_Detection_PostProcess', 'TFLite_Detection_PostProcess:1',
                       'TFLite_Detection_PostProcess:2', 'TFLite_Detection_PostProcess:3'])

    # The model contains the TFLite_Detection_PostProcess custom op.
    converter.allow_custom_ops = True

    # Convert the model to a quantized TFLite model.
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_model = converter.convert()

    # Write the converted model to disk.
    open("/content/uno_mobilenetV2.tflite", "wb").write(tflite_model)